[
  {
    "path": ".commitlintrc.js",
    "content": "module.exports = {\n  extends: [\"@commitlint/config-conventional\"],\n  rules: {\n    // Configuration format: [level, applicability, value]\n    // level: error level, expressed as a number:\n    //     0 - disable the rule\n    //     1 - warning (does not block the commit)\n    //     2 - error (blocks the commit)\n    // applicability: whether the rule is applied as written or inverted:\n    //     \"always\" - apply the rule as written\n    //     \"never\" - invert the rule\n    // value: the value passed to the rule, e.g. a maximum length of 100.\n    // Refs: https://commitlint.js.org/reference/rules-configuration.html\n    \"header-max-length\": [2, \"always\", 100],\n    \"type-enum\": [\n      2,\n      \"always\",\n      [\"build\", \"chore\", \"ci\", \"docs\", \"feat\", \"fix\", \"perf\", \"refactor\", \"revert\", \"style\", \"test\", \"Release-As\"]\n    ]\n  }\n};\n"
  },
  {
    "path": ".deepsource.toml",
    "content": "version = 1\n\ntest_patterns = [\"tests/test_*.py\"]\n\nexclude_patterns = [\"examples/**\"]\n\n[[analyzers]]\nname = \"python\"\nenabled = true\n\n  [analyzers.meta]\n  runtime_version = \"3.x.x\"\n"
  },
  {
    "path": ".dockerignore",
    "content": "__pycache__\n*.pyc\n*.pyo\n*.pyd\n.Python\n.env\n.git\n\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug-report.md",
    "content": "---\nname: \"\\U0001F41B Bug Report\"\nabout: Submit a bug report to help us improve Qlib\nlabels: bug\n\n---\n\n## 🐛 Bug Description\n\n<!-- A clear and concise description of what the bug is. -->\n\n## To Reproduce\n\nSteps to reproduce the behavior:\n\n1.\n1.\n1.\n\n\n## Expected Behavior\n\n<!-- A clear and concise description of what you expected to happen. -->\n\n## Screenshot\n\n<!-- A screenshot of the error message or of anything that shouldn't appear -->\n\n## Environment\n\n**Note**: You can run `cd scripts && python collect_info.py all` in the project directory to collect system information\nand paste it here directly.\n\n - Qlib version:\n - Python version:\n - OS (`Windows`, `Linux`, `MacOS`):\n - Commit number (optional; please provide it if you are using the dev version):\n\n## Additional Notes\n\n<!-- Add any other information about the problem here. -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/documentation.md",
    "content": "---\nname: \"\\U0001F4D6 Documentation\"\nabout: Report an issue related to documentation\n\n---\n\n## 📖 Documentation\n\n<!-- Please specify whether the issue concerns the tutorial or the API reference, and describe it. -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature-request.md",
    "content": "---\nname: \"\\U0001F31F Feature Request\"\nabout: Request for a new Qlib feature\nlabels: enhancement\n\n---\n\n## 🌟 Feature Description\n<!-- A clear and concise description of the feature proposal -->\n\n## Motivation\n\n1. Application scenario\n2. Related work (papers, GitHub repos, etc.):\n3. Any other relevant and important information:\n\n<!-- Please describe why the feature is important. -->\n\n## Alternatives\n\n<!-- A short description of any alternative solutions or features you've considered. -->\n\n## Additional Notes\n\n<!-- Add any other context or screenshots about the feature request here. -->"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/question.md",
    "content": "---\nname: \"❓Questions & Help\"\nabout: Have some questions? We can offer help.\nlabels: question\n\n---\n\n## ❓ Questions and Help\n\nWe strongly suggest that you carefully read the [documentation](http://qlib.readthedocs.io/) of our library as well as the official [paper](https://arxiv.org/abs/2009.11189). If you still have questions after that, please describe them clearly in this issue."
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "<!--- Thank you for submitting a Pull Request! To help the review go smoothly, -->\n<!--- please make sure your Pull Request meets the following requirements: -->\n<!---   1. Provide a general summary of your changes in the Title above; -->\n<!---   2. Add an appropriate prefix to the title, such as `build:`, `chore:`, `ci:`, `docs:`, `feat:`, `fix:`, `perf:`, `refactor:`, `revert:`, `style:`, `test:` (Ref: https://www.conventionalcommits.org/). -->\n<!--- Category: -->\n<!--- Patch update: `fix:` -->\n<!---   Example: fix(auth): correct login validation issue -->\n<!--- Minor update (introduces new functionality): `feat:` -->\n<!---   Example: feat(parser): add ability to parse arrays -->\n<!--- Major update (breaking change): include `BREAKING CHANGE:` in the commit message footer, or add `!` immediately before the `:` in the title to indicate a breaking change. -->\n<!---   Example: feat(auth)!: remove support for old authentication method -->\n<!--- Other updates: `build:`, `chore:`, `ci:`, `docs:`, `perf:`, `refactor:`, `revert:`, `style:`, `test:`. -->\n\n## Description\n<!--- Describe your changes in detail -->\n\n## Motivation and Context\n<!--- Are there any related issues? If so, please put the link here. -->\n<!--- Why is this change required? What problem does it solve? -->\n\n## How Has This Been Tested?\n<!--- Put an `x` in all the boxes that apply: -->\n- [ ] Pass the test by running `pytest qlib/tests/test_all_pipeline.py` in the parent directory of `qlib`.\n- [ ] If you are adding a new feature, test it with your own test scripts.\n\n<!--- **ATTENTION**: If you are adding a new feature, please make sure your code is **correctly tested**. If our test scripts do not cover your cases, please add your own test scripts under the `tests` folder and run them. More information about test scripts can be found [here](https://docs.python.org/3/library/unittest.html#basic-example), or you can refer to the ones we provide under the `tests` folder. -->\n\n## Screenshots of Test Results (if appropriate):\n1. Pipeline test:\n2. Your own tests:\n\n## Types of changes\n<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->\n- [ ] Fix bugs\n- [ ] Add new feature\n- [ ] Update documentation\n"
  },
  {
    "path": ".github/brew_install.sh",
    "content": "#!/bin/bash\nset -u\n\n# First check if the OS is Linux.\nif [[ \"$(uname)\" = \"Linux\" ]]; then\n  HOMEBREW_ON_LINUX=1\nfi\n\n# On macOS, this script installs to /usr/local only.\n# On Linux, it installs to /home/linuxbrew/.linuxbrew if you have sudo access\n# and ~/.linuxbrew otherwise.\n# To install elsewhere (which is unsupported)\n# you can untar https://github.com/Homebrew/brew/tarball/master\n# anywhere you like.\nif [[ -z \"${HOMEBREW_ON_LINUX-}\" ]]; then\n  HOMEBREW_PREFIX=\"/usr/local\"\n  HOMEBREW_REPOSITORY=\"/usr/local/Homebrew\"\n  HOMEBREW_CACHE=\"${HOME}/Library/Caches/Homebrew\"\n\n  STAT=\"stat -f\"\n  CHOWN=\"/usr/sbin/chown\"\n  CHGRP=\"/usr/bin/chgrp\"\n  GROUP=\"admin\"\n  TOUCH=\"/usr/bin/touch\"\nelse\n  HOMEBREW_PREFIX_DEFAULT=\"/home/linuxbrew/.linuxbrew\"\n  HOMEBREW_CACHE=\"${HOME}/.cache/Homebrew\"\n\n  STAT=\"stat --printf\"\n  CHOWN=\"/bin/chown\"\n  CHGRP=\"/bin/chgrp\"\n  GROUP=\"$(id -gn)\"\n  TOUCH=\"/bin/touch\"\nfi\nBREW_REPO=\"https://github.com/Homebrew/brew\"\n\n# TODO: bump version when new macOS is released\nMACOS_LATEST_SUPPORTED=\"10.15\"\n# TODO: bump version when new macOS is released\nMACOS_OLDEST_SUPPORTED=\"10.13\"\n\n# For Homebrew on Linux\nREQUIRED_RUBY_VERSION=2.6  # https://github.com/Homebrew/brew/pull/6556\nREQUIRED_GLIBC_VERSION=2.13  # https://docs.brew.sh/Homebrew-on-Linux#requirements\n\n# no analytics during installation\nexport HOMEBREW_NO_ANALYTICS_THIS_RUN=1\nexport HOMEBREW_NO_ANALYTICS_MESSAGE_OUTPUT=1\n\n# string formatters\nif [[ -t 1 ]]; then\n  tty_escape() { printf \"\\033[%sm\" \"$1\"; }\nelse\n  tty_escape() { :; }\nfi\ntty_mkbold() { tty_escape \"1;$1\"; }\ntty_underline=\"$(tty_escape \"4;39\")\"\ntty_blue=\"$(tty_mkbold 34)\"\ntty_red=\"$(tty_mkbold 31)\"\ntty_bold=\"$(tty_mkbold 39)\"\ntty_reset=\"$(tty_escape 0)\"\n\nhave_sudo_access() {\n  local -a args\n  if [[ -n \"${SUDO_ASKPASS-}\" ]]; then\n    args=(\"-A\")\n  fi\n\n  if [[ -z \"${HAVE_SUDO_ACCESS-}\" ]]; then\n   
 if [[ -n \"${args[*]-}\" ]]; then\n      /usr/bin/sudo \"${args[@]}\" -l mkdir &>/dev/null\n    else\n      /usr/bin/sudo -l mkdir &>/dev/null\n    fi\n    HAVE_SUDO_ACCESS=\"$?\"\n  fi\n\n  if [[ -z \"${HOMEBREW_ON_LINUX-}\" ]] && [[ \"$HAVE_SUDO_ACCESS\" -ne 0 ]]; then\n    abort \"Need sudo access on macOS (e.g. the user $USER to be an Administrator)!\"\n  fi\n\n  return \"$HAVE_SUDO_ACCESS\"\n}\n\nshell_join() {\n  local arg\n  printf \"%s\" \"$1\"\n  shift\n  for arg in \"$@\"; do\n    printf \" \"\n    printf \"%s\" \"${arg// /\\ }\"\n  done\n}\n\nchomp() {\n  printf \"%s\" \"${1/\"$'\\n'\"/}\"\n}\n\nohai() {\n  printf \"${tty_blue}==>${tty_bold} %s${tty_reset}\\n\" \"$(shell_join \"$@\")\"\n}\n\nwarn() {\n  printf \"${tty_red}Warning${tty_reset}: %s\\n\" \"$(chomp \"$1\")\"\n}\n\nabort() {\n  printf \"%s\\n\" \"$1\"\n  exit 1\n}\n\nexecute() {\n  if ! \"$@\"; then\n    abort \"$(printf \"Failed during: %s\" \"$(shell_join \"$@\")\")\"\n  fi\n}\n\nexecute_sudo() {\n  local -a args=(\"$@\")\n  if [[ -n \"${SUDO_ASKPASS-}\" ]]; then\n    args=(\"-A\" \"${args[@]}\")\n  fi\n  if have_sudo_access; then\n    ohai \"/usr/bin/sudo\" \"${args[@]}\"\n    execute \"/usr/bin/sudo\" \"${args[@]}\"\n  else\n    ohai \"${args[@]}\"\n    execute \"${args[@]}\"\n  fi\n}\n\ngetc() {\n  local save_state\n  save_state=$(/bin/stty -g)\n  /bin/stty raw -echo\n  IFS= read -r -n 1 -d '' \"$@\"\n  /bin/stty \"$save_state\"\n}\n\nwait_for_user() {\n  local c\n  echo\n  echo \"Press RETURN to continue or any other key to abort\"\n  getc c\n  # we test for \\r and \\n because some stuff does \\r instead\n  if ! 
[[ \"$c\" == $'\\r' || \"$c\" == $'\\n' ]]; then\n    exit 1\n  fi\n}\n\nmajor_minor() {\n  echo \"${1%%.*}.$(x=\"${1#*.}\"; echo \"${x%%.*}\")\"\n}\n\nif [[ -z \"${HOMEBREW_ON_LINUX-}\" ]]; then\n  macos_version=\"$(major_minor \"$(/usr/bin/sw_vers -productVersion)\")\"\nfi\n\nversion_gt() {\n  [[ \"${1%.*}\" -gt \"${2%.*}\" ]] || [[ \"${1%.*}\" -eq \"${2%.*}\" && \"${1#*.}\" -gt \"${2#*.}\" ]]\n}\nversion_ge() {\n  [[ \"${1%.*}\" -gt \"${2%.*}\" ]] || [[ \"${1%.*}\" -eq \"${2%.*}\" && \"${1#*.}\" -ge \"${2#*.}\" ]]\n}\nversion_lt() {\n  [[ \"${1%.*}\" -lt \"${2%.*}\" ]] || [[ \"${1%.*}\" -eq \"${2%.*}\" && \"${1#*.}\" -lt \"${2#*.}\" ]]\n}\n\nshould_install_command_line_tools() {\n  if [[ -n \"${HOMEBREW_ON_LINUX-}\" ]]; then\n    return 1\n  fi\n\n  if version_gt \"$macos_version\" \"10.13\"; then\n    ! [[ -e \"/Library/Developer/CommandLineTools/usr/bin/git\" ]]\n  else\n    ! [[ -e \"/Library/Developer/CommandLineTools/usr/bin/git\" ]] ||\n      ! [[ -e \"/usr/include/iconv.h\" ]]\n  fi\n}\n\nget_permission() {\n  $STAT \"%A\" \"$1\"\n}\n\nuser_only_chmod() {\n  [[ -d \"$1\" ]] && [[ \"$(get_permission \"$1\")\" != \"755\" ]]\n}\n\nexists_but_not_writable() {\n  [[ -e \"$1\" ]] && ! [[ -r \"$1\" && -w \"$1\" && -x \"$1\" ]]\n}\n\nget_owner() {\n  $STAT \"%u\" \"$1\"\n}\n\nfile_not_owned() {\n  [[ \"$(get_owner \"$1\")\" != \"$(id -u)\" ]]\n}\n\nget_group() {\n  $STAT \"%g\" \"$1\"\n}\n\nfile_not_grpowned() {\n  [[ \" $(id -G \"$USER\") \" != *\" $(get_group \"$1\") \"*  ]]\n}\n\n# Please sync with 'test_ruby()' in 'Library/Homebrew/utils/ruby.sh' from Homebrew/brew repository.\ntest_ruby () {\n  if [[ ! 
-x $1 ]]\n  then\n    return 1\n  fi\n\n  \"$1\" --enable-frozen-string-literal --disable=gems,did_you_mean,rubyopt -rrubygems -e \\\n    \"abort if Gem::Version.new(RUBY_VERSION.to_s.dup).to_s.split('.').first(2) != \\\n              Gem::Version.new('$REQUIRED_RUBY_VERSION').to_s.split('.').first(2)\" 2>/dev/null\n}\n\nno_usable_ruby() {\n  local ruby_exec\n  IFS=$'\\n' # Do word splitting on new lines only\n  for ruby_exec in $(which -a ruby); do\n    if test_ruby \"$ruby_exec\"; then\n      return 1\n    fi\n  done\n  IFS=$' \\t\\n' # Restore IFS to its default value\n  return 0\n}\n\noutdated_glibc() {\n  local glibc_version\n  glibc_version=$(ldd --version | head -n1 | grep -o '[0-9.]*$' | grep -o '^[0-9]\\+\\.[0-9]\\+')\n  version_lt \"$glibc_version\" \"$REQUIRED_GLIBC_VERSION\"\n}\n\nif [[ -n \"${HOMEBREW_ON_LINUX-}\" ]] && no_usable_ruby && outdated_glibc\nthen\n    abort \"$(cat <<-EOFABORT\n\tHomebrew requires Ruby $REQUIRED_RUBY_VERSION which was not found on your system.\n\tHomebrew portable Ruby requires Glibc version $REQUIRED_GLIBC_VERSION or newer,\n\tand your Glibc version is too old.\n\tSee ${tty_underline}https://docs.brew.sh/Homebrew-on-Linux#requirements${tty_reset}\n\tInstall Ruby $REQUIRED_RUBY_VERSION and add its location to your PATH.\n\tEOFABORT\n    )\"\nfi\n\n# USER isn't always set so provide a fall back for the installer and subprocesses.\nif [[ -z \"${USER-}\" ]]; then\n  USER=\"$(chomp \"$(id -un)\")\"\n  export USER\nfi\n\n# Invalidate sudo timestamp before exiting (if it wasn't active before).\nif ! /usr/bin/sudo -n -v 2>/dev/null; then\n  trap '/usr/bin/sudo -k' EXIT\nfi\n\n# Things can fail later if `pwd` doesn't exist.\n# Also sudo prints a warning message for no good reason\ncd \"/usr\" || exit 1\n\n####################################################################### script\nif ! command -v git >/dev/null; then\n    abort \"$(cat <<EOABORT\nYou must install Git before installing Homebrew. 
See:\n  ${tty_underline}https://docs.brew.sh/Installation${tty_reset}\nEOABORT\n)\"\nfi\n\nif ! command -v curl >/dev/null; then\n    abort \"$(cat <<EOABORT\nYou must install cURL before installing Homebrew. See:\n  ${tty_underline}https://docs.brew.sh/Installation${tty_reset}\nEOABORT\n)\"\nfi\n\nif [[ -z \"${HOMEBREW_ON_LINUX-}\" ]]; then\n have_sudo_access\nelse\n  if [[ -n \"${CI-}\" ]] || [[ -w \"$HOMEBREW_PREFIX_DEFAULT\" ]] || [[ -w \"/home/linuxbrew\" ]] || [[ -w \"/home\" ]]; then\n    HOMEBREW_PREFIX=\"$HOMEBREW_PREFIX_DEFAULT\"\n  else\n    trap exit SIGINT\n    if [[ $(/usr/bin/sudo -n -l mkdir 2>&1) != *\"mkdir\"* ]]; then\n      ohai \"Select the Homebrew installation directory\"\n      echo \"- ${tty_bold}Enter your password${tty_reset} to install to ${tty_underline}${HOMEBREW_PREFIX_DEFAULT}${tty_reset} (${tty_bold}recommended${tty_reset})\"\n      echo \"- ${tty_bold}Press Control-D${tty_reset} to install to ${tty_underline}$HOME/.linuxbrew${tty_reset}\"\n      echo \"- ${tty_bold}Press Control-C${tty_reset} to cancel installation\"\n    fi\n    if have_sudo_access; then\n      HOMEBREW_PREFIX=\"$HOMEBREW_PREFIX_DEFAULT\"\n    else\n      HOMEBREW_PREFIX=\"$HOME/.linuxbrew\"\n    fi\n    trap - SIGINT\n  fi\n  HOMEBREW_REPOSITORY=\"${HOMEBREW_PREFIX}/Homebrew\"\nfi\n\nif [[ \"$UID\" == \"0\" ]]; then\n  abort \"Don't run this as root!\"\nelif [[ -d \"$HOMEBREW_PREFIX\" && ! -x \"$HOMEBREW_PREFIX\" ]]; then\n  abort \"$(cat <<EOABORT\nThe Homebrew prefix, ${HOMEBREW_PREFIX}, exists but is not searchable. If this is\nnot intentional, please restore the default permissions and try running the\ninstaller again:\n    sudo chmod 775 ${HOMEBREW_PREFIX}\nEOABORT\n)\"\nfi\n\nif [[ -z \"${HOMEBREW_ON_LINUX-}\" ]]; then\n  if version_lt \"$macos_version\" \"10.7\"; then\n    abort \"$(cat <<EOABORT\nYour Mac OS X version is too old. 
See:\n  ${tty_underline}https://github.com/mistydemeo/tigerbrew${tty_reset}\nEOABORT\n)\"\n  elif version_lt \"$macos_version\" \"10.10\"; then\n    abort \"Your OS X version is too old\"\n  elif version_gt \"$macos_version\" \"$MACOS_LATEST_SUPPORTED\" || \\\n    version_lt \"$macos_version\" \"$MACOS_OLDEST_SUPPORTED\"; then\n    who=\"We\"\n    what=\"\"\n    if version_gt \"$macos_version\" \"$MACOS_LATEST_SUPPORTED\"; then\n      what=\"pre-release version\"\n    else\n      who+=\" (and Apple)\"\n      what=\"old version\"\n    fi\n    ohai \"You are using macOS ${macos_version}.\"\n    ohai \"${who} do not provide support for this ${what}.\"\n\n    echo \"$(cat <<EOS\nThis installation may not succeed.\nAfter installation, you will encounter build failures with some formulae.\nPlease create pull requests instead of asking for help on Homebrew\\'s GitHub,\nDiscourse, Twitter or IRC. You are responsible for resolving any issues you\nexperience while you are running this ${what}.\nEOS\n)\n\"\n  fi\nfi\n\nohai \"This script will install:\"\necho \"${HOMEBREW_PREFIX}/bin/brew\"\necho \"${HOMEBREW_PREFIX}/share/doc/homebrew\"\necho \"${HOMEBREW_PREFIX}/share/man/man1/brew.1\"\necho \"${HOMEBREW_PREFIX}/share/zsh/site-functions/_brew\"\necho \"${HOMEBREW_PREFIX}/etc/bash_completion.d/brew\"\necho \"${HOMEBREW_REPOSITORY}\"\n\n# Keep relatively in sync with\n# https://github.com/Homebrew/brew/blob/master/Library/Homebrew/keg.rb\ndirectories=(bin etc include lib sbin share opt var\n             Frameworks\n             etc/bash_completion.d lib/pkgconfig\n             share/aclocal share/doc share/info share/locale share/man\n             share/man/man1 share/man/man2 share/man/man3 share/man/man4\n             share/man/man5 share/man/man6 share/man/man7 share/man/man8\n             var/log var/homebrew var/homebrew/linked\n             bin/brew)\ngroup_chmods=()\nfor dir in \"${directories[@]}\"; do\n  if exists_but_not_writable \"${HOMEBREW_PREFIX}/${dir}\"; 
then\n    group_chmods+=(\"${HOMEBREW_PREFIX}/${dir}\")\n  fi\ndone\n\n# zsh refuses to read from these directories if group writable\ndirectories=(share/zsh share/zsh/site-functions)\nzsh_dirs=()\nfor dir in \"${directories[@]}\"; do\n  zsh_dirs+=(\"${HOMEBREW_PREFIX}/${dir}\")\ndone\n\ndirectories=(bin etc include lib sbin share var opt\n             share/zsh share/zsh/site-functions\n             var/homebrew var/homebrew/linked\n             Cellar Caskroom Homebrew Frameworks)\nmkdirs=()\nfor dir in \"${directories[@]}\"; do\n  if ! [[ -d \"${HOMEBREW_PREFIX}/${dir}\" ]]; then\n    mkdirs+=(\"${HOMEBREW_PREFIX}/${dir}\")\n  fi\ndone\n\nuser_chmods=()\nif [[ \"${#zsh_dirs[@]}\" -gt 0 ]]; then\n  for dir in \"${zsh_dirs[@]}\"; do\n    if user_only_chmod \"${dir}\"; then\n      user_chmods+=(\"${dir}\")\n    fi\n  done\nfi\n\nchmods=()\nif [[ \"${#group_chmods[@]}\" -gt 0 ]]; then\n  chmods+=(\"${group_chmods[@]}\")\nfi\nif [[ \"${#user_chmods[@]}\" -gt 0 ]]; then\n  chmods+=(\"${user_chmods[@]}\")\nfi\n\nchowns=()\nchgrps=()\nif [[ \"${#chmods[@]}\" -gt 0 ]]; then\n  for dir in \"${chmods[@]}\"; do\n    if file_not_owned \"${dir}\"; then\n      chowns+=(\"${dir}\")\n    fi\n    if file_not_grpowned \"${dir}\"; then\n      chgrps+=(\"${dir}\")\n    fi\n  done\nfi\n\nif [[ \"${#group_chmods[@]}\" -gt 0 ]]; then\n  ohai \"The following existing directories will be made group writable:\"\n  printf \"%s\\n\" \"${group_chmods[@]}\"\nfi\nif [[ \"${#user_chmods[@]}\" -gt 0 ]]; then\n  ohai \"The following existing directories will be made writable by user only:\"\n  printf \"%s\\n\" \"${user_chmods[@]}\"\nfi\nif [[ \"${#chowns[@]}\" -gt 0 ]]; then\n  ohai \"The following existing directories will have their owner set to ${tty_underline}${USER}${tty_reset}:\"\n  printf \"%s\\n\" \"${chowns[@]}\"\nfi\nif [[ \"${#chgrps[@]}\" -gt 0 ]]; then\n  ohai \"The following existing directories will have their group set to ${tty_underline}${GROUP}${tty_reset}:\"\n  printf \"%s\\n\" 
\"${chgrps[@]}\"\nfi\nif [[ \"${#mkdirs[@]}\" -gt 0 ]]; then\n  ohai \"The following new directories will be created:\"\n  printf \"%s\\n\" \"${mkdirs[@]}\"\nfi\n\nif should_install_command_line_tools; then\n  ohai \"The Xcode Command Line Tools will be installed.\"\nfi\n\nif [[ -t 0 && -z \"${CI-}\" ]]; then\n  wait_for_user\nfi\n\nif [[ -d \"${HOMEBREW_PREFIX}\" ]]; then\n  if [[ \"${#chmods[@]}\" -gt 0 ]]; then\n    execute_sudo \"/bin/chmod\" \"u+rwx\" \"${chmods[@]}\"\n  fi\n  if [[ \"${#group_chmods[@]}\" -gt 0 ]]; then\n    execute_sudo \"/bin/chmod\" \"g+rwx\" \"${group_chmods[@]}\"\n  fi\n  if [[ \"${#user_chmods[@]}\" -gt 0 ]]; then\n    execute_sudo \"/bin/chmod\" \"755\" \"${user_chmods[@]}\"\n  fi\n  if [[ \"${#chowns[@]}\" -gt 0 ]]; then\n    execute_sudo \"$CHOWN\" \"$USER\" \"${chowns[@]}\"\n  fi\n  if [[ \"${#chgrps[@]}\" -gt 0 ]]; then\n    execute_sudo \"$CHGRP\" \"$GROUP\" \"${chgrps[@]}\"\n  fi\nelse\n  execute_sudo \"/bin/mkdir\" \"-p\" \"${HOMEBREW_PREFIX}\"\n  if [[ -z \"${HOMEBREW_ON_LINUX-}\" ]]; then\n    execute_sudo \"$CHOWN\" \"root:wheel\" \"${HOMEBREW_PREFIX}\"\n  else\n    execute_sudo \"$CHOWN\" \"$USER:$GROUP\" \"${HOMEBREW_PREFIX}\"\n  fi\nfi\n\nif [[ \"${#mkdirs[@]}\" -gt 0 ]]; then\n  execute_sudo \"/bin/mkdir\" \"-p\" \"${mkdirs[@]}\"\n  execute_sudo \"/bin/chmod\" \"g+rwx\" \"${mkdirs[@]}\"\n  execute_sudo \"$CHOWN\" \"$USER\" \"${mkdirs[@]}\"\n  execute_sudo \"$CHGRP\" \"$GROUP\" \"${mkdirs[@]}\"\nfi\n\nif ! 
[[ -d \"${HOMEBREW_CACHE}\" ]]; then\n  if [[ -z \"${HOMEBREW_ON_LINUX-}\" ]]; then\n    execute_sudo \"/bin/mkdir\" \"-p\" \"${HOMEBREW_CACHE}\"\n  else\n    execute \"/bin/mkdir\" \"-p\" \"${HOMEBREW_CACHE}\"\n  fi\nfi\nif exists_but_not_writable \"${HOMEBREW_CACHE}\"; then\n  execute_sudo \"/bin/chmod\" \"g+rwx\" \"${HOMEBREW_CACHE}\"\nfi\nif file_not_owned \"${HOMEBREW_CACHE}\"; then\n  execute_sudo \"$CHOWN\" \"$USER\" \"${HOMEBREW_CACHE}\"\nfi\nif file_not_grpowned \"${HOMEBREW_CACHE}\"; then\n  execute_sudo \"$CHGRP\" \"$GROUP\" \"${HOMEBREW_CACHE}\"\nfi\nif [[ -d \"${HOMEBREW_CACHE}\" ]]; then\n  execute \"$TOUCH\" \"${HOMEBREW_CACHE}/.cleaned\"\nfi\n\nif should_install_command_line_tools && version_ge \"$macos_version\" \"10.13\"; then\n  ohai \"Searching online for the Command Line Tools\"\n  # This temporary file prompts the 'softwareupdate' utility to list the Command Line Tools\n  clt_placeholder=\"/tmp/.com.apple.dt.CommandLineTools.installondemand.in-progress\"\n  execute_sudo \"$TOUCH\" \"$clt_placeholder\"\n\n  clt_label_command=\"/usr/sbin/softwareupdate -l |\n                      grep -B 1 -E 'Command Line Tools' |\n                      awk -F'*' '/^ *\\\\*/ {print \\$2}' |\n                      sed -e 's/^ *Label: //' -e 's/^ *//' |\n                      sort -V |\n                      tail -n1\"\n  clt_label=\"$(chomp \"$(/bin/bash -c \"$clt_label_command\")\")\"\n\n  if [[ -n \"$clt_label\" ]]; then\n    ohai \"Installing $clt_label\"\n    execute_sudo \"/usr/sbin/softwareupdate\" \"-i\" \"$clt_label\"\n    execute_sudo \"/bin/rm\" \"-f\" \"$clt_placeholder\"\n    execute_sudo \"/usr/bin/xcode-select\" \"--switch\" \"/Library/Developer/CommandLineTools\"\n  fi\nfi\n\n# Headless install may have failed, so fallback to original 'xcode-select' method\nif should_install_command_line_tools && test -t 0; then\n  ohai \"Installing the Command Line Tools (expect a GUI popup):\"\n  execute_sudo \"/usr/bin/xcode-select\" \"--install\"\n  echo 
\"Press any key when the installation has completed.\"\n  getc\n  execute_sudo \"/usr/bin/xcode-select\" \"--switch\" \"/Library/Developer/CommandLineTools\"\nfi\n\nif [[ -z \"${HOMEBREW_ON_LINUX-}\" ]] && ! output=\"$(/usr/bin/xcrun clang 2>&1)\" && [[ \"$output\" == *\"license\"* ]]; then\n  abort \"$(cat <<EOABORT\nYou have not agreed to the Xcode license.\nBefore running the installer again please agree to the license by opening\nXcode.app or running:\n    sudo xcodebuild -license\nEOABORT\n)\"\nfi\n\nohai \"Downloading and installing Homebrew...\"\n(\n  cd \"${HOMEBREW_REPOSITORY}\" >/dev/null || return\n\n  # we do it in four steps to avoid merge errors when reinstalling\n  execute \"git\" \"init\" \"-q\"\n\n  # \"git remote add\" will fail if the remote is defined in the global config\n  execute \"git\" \"config\" \"remote.origin.url\" \"${BREW_REPO}\"\n  execute \"git\" \"config\" \"remote.origin.fetch\" \"+refs/heads/*:refs/remotes/origin/*\"\n\n  # ensure we don't munge line endings on checkout\n  execute \"git\" \"config\" \"core.autocrlf\" \"false\"\n\n  execute \"git\" \"fetch\" \"origin\" \"--force\"\n  execute \"git\" \"fetch\" \"origin\" \"--tags\" \"--force\"\n\n  execute \"git\" \"reset\" \"--hard\" \"origin/master\"\n\n  execute \"ln\" \"-sf\" \"${HOMEBREW_REPOSITORY}/bin/brew\" \"${HOMEBREW_PREFIX}/bin/brew\"\n\n) || exit 1\n\nif [[ \":${PATH}:\" != *\":${HOMEBREW_PREFIX}/bin:\"* ]]; then\n  warn \"${HOMEBREW_PREFIX}/bin is not in your PATH.\"\nfi\n\nohai \"Installation successful!\"\necho\n\n# Use the shell's audible bell.\nif [[ -t 1 ]]; then\n  printf \"\\a\"\nfi\n\n# Use an extra newline and bold to avoid this being missed.\nohai \"Homebrew has enabled anonymous aggregate formulae and cask analytics.\"\necho \"$(cat <<EOS\n${tty_bold}Read the analytics documentation (and how to opt-out) here:\n  ${tty_underline}https://docs.brew.sh/Analytics${tty_reset}\nNo analytics data has been sent yet (or will be during this \\`install\\` 
run).\nEOS\n)\n\"\n\nohai \"Homebrew is run entirely by unpaid volunteers. Please consider donating:\"\necho \"$(cat <<EOS\n  ${tty_underline}https://github.com/Homebrew/brew#donations${tty_reset}\nEOS\n)\n\"\n\n(\n  cd \"${HOMEBREW_REPOSITORY}\" >/dev/null || return\n  execute \"git\" \"config\" \"--replace-all\" \"homebrew.analyticsmessage\" \"true\"\n  execute \"git\" \"config\" \"--replace-all\" \"homebrew.caskanalyticsmessage\" \"true\"\n) || exit 1\n\nohai \"Next steps:\"\necho \"- Run \\`brew help\\` to get started\"\necho \"- Further documentation: \"\necho \"    ${tty_underline}https://docs.brew.sh${tty_reset}\"\n\nif [[ -n \"${HOMEBREW_ON_LINUX-}\" ]]; then\n  case \"$SHELL\" in\n    */bash*)\n      if [[ -r \"$HOME/.bash_profile\" ]]; then\n        shell_profile=\"$HOME/.bash_profile\"\n      else\n        shell_profile=\"$HOME/.profile\"\n      fi\n      ;;\n    */zsh*)\n      shell_profile=\"$HOME/.zprofile\"\n      ;;\n    *)\n      shell_profile=\"$HOME/.profile\"\n      ;;\n  esac\n\n  echo \"- Install the Homebrew dependencies if you have sudo access:\"\n\n  if [[ $(command -v apt-get) ]]; then\n    echo \"    sudo apt-get install build-essential\"\n  elif [[ $(command -v yum) ]]; then\n    echo \"    sudo yum groupinstall 'Development Tools'\"\n  elif [[ $(command -v pacman) ]]; then\n    echo \"    sudo pacman -S base-devel\"\n  elif [[ $(command -v apk) ]]; then\n    echo \"    sudo apk add build-base\"\n  fi\n\n  cat <<EOS\n    See ${tty_underline}https://docs.brew.sh/linux${tty_reset} for more information\n- Add Homebrew to your ${tty_bold}PATH${tty_reset} in ${tty_underline}${shell_profile}${tty_reset}:\n    echo 'eval \\$(${HOMEBREW_PREFIX}/bin/brew shellenv)' >> ${shell_profile}\n    eval \\$(${HOMEBREW_PREFIX}/bin/brew shellenv)\n- We recommend that you install GCC:\n    brew install gcc\n\nEOS\nfi"
  },
  {
    "path": ".github/release-drafter.yml",
    "content": "name-template: 'v$RESOLVED_VERSION 🌈'\ntag-template: 'v$RESOLVED_VERSION'\ncategories:\n  - title: '🌟 Features'\n    labels:\n      - 'feature'\n      - 'enhancement'\n  - title: '🐛 Bug Fixes'\n    labels:\n      - 'fix'\n      - 'bugfix'\n      - 'bug'\n  - title: '📚 Documentation'\n    labels:\n      - 'doc'\n      - 'documentation'\n  - title: '🧹 Maintenance'\n    labels:\n      - 'maintenance'\nchange-template: '- $TITLE @$AUTHOR (#$NUMBER)'\nchange-title-escapes: '\\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks.\nversion-resolver:\n  major:\n    labels:\n      - 'major'\n  minor:\n    labels:\n      - 'minor'\n  patch:\n    labels:\n      - 'patch'\n  default: patch\ntemplate: |\n  ## Changes\n\n  $CHANGES\n"
  },
  {
    "path": ".github/workflows/lint_title.yml",
    "content": "name: Lint pull request title\n\non:\n  pull_request:\n    types:\n      - opened\n      - synchronize\n      - reopened\n      - edited\n\nconcurrency:\n  cancel-in-progress: true\n  group: ${{ github.workflow }}-${{ github.ref }}\n\njobs:\n  lint-title:\n    runs-on: ubuntu-latest\n    steps:\n      # The checkout is required because commitlint reads the .commitlintrc.js file from the project root directory.\n      - name: Checkout Repository\n        uses: actions/checkout@v4\n\n      - name: Setup Node.js\n        uses: actions/setup-node@v4\n        with:\n          node-version: '16'\n\n      - name: Install commitlint\n        run: npm install --save-dev @commitlint/{config-conventional,cli}\n\n      - name: Validate PR Title with commitlint\n        env:\n          BODY: ${{ github.event.pull_request.title }}\n        run: |\n          echo \"$BODY\" | npx commitlint --config .commitlintrc.js\n"
  },
  {
    "path": ".github/workflows/release.yml",
    "content": "name: Release\n\non:\n  push:\n    branches:\n      - main\n\npermissions:\n  contents: read\n\njobs:\n  release:\n    runs-on: ubuntu-latest\n    outputs:\n      release_created: ${{ steps.release_please.outputs.release_created }}\n\n    steps:\n      - name: Release please\n        id: release_please\n        uses: googleapis/release-please-action@v4\n        with:\n          token: ${{ secrets.PAT }}\n          release-type: simple\n\n  deploy_with_manylinux:\n    needs: release\n    permissions:\n      contents: write\n      pull-requests: read\n    runs-on: ubuntu-latest\n\n    steps:\n      - uses: actions/checkout@v4\n        if: needs.release.outputs.release_created == 'true'\n        with:\n          fetch-depth: 0\n\n      - name: Set up Python ${{ matrix.python-version }}\n        if: needs.release.outputs.release_created == 'true'\n        uses: actions/setup-python@v4\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Build wheel on Linux\n        if: needs.release.outputs.release_created == 'true'\n        uses: RalfG/python-wheels-manylinux-build@v0.7.1-manylinux2014_x86_64\n        with:\n          python-versions: 'cp38-cp38 cp39-cp39 cp310-cp310 cp311-cp311 cp312-cp312'\n          build-requirements: 'numpy cython'\n\n      - name: Install dependencies\n        if: needs.release.outputs.release_created == 'true'\n        run: |\n          python -m pip install twine\n\n      - name: Upload to TestPyPI\n        if: needs.release.outputs.release_created == 'true'\n        env:\n          TWINE_USERNAME: __token__\n          TWINE_PASSWORD: ${{ secrets.TESTPYPI }}\n        run: |\n          twine check dist/pyqlib-*-manylinux*.whl\n          twine upload --repository-url https://test.pypi.org/legacy/ dist/pyqlib-*-manylinux*.whl --verbose\n\n  deploy_with_bdist_wheel:\n    needs: release\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        # Testing showed that the pyqlib wheels built on macos-14 and macos-15 for Python 3.8, 3.9, 3.10, 3.11 and 3.12\n        # have identical filenames, so the duplicate wheel files cannot both be uploaded to PyPI.\n        # We therefore build only on macos-latest, which currently points to macos-15.\n        # Also, macos-13 will stop being supported on 2025-11-14.\n        # Refs: https://github.blog/changelog/2025-07-11-upcoming-changes-to-macos-hosted-runners-macos-latest-migration-and-xcode-support-policy-updates/\n        os: [windows-latest, macos-latest]\n        python-version: [\"3.8\", \"3.9\", \"3.10\", \"3.11\", \"3.12\"]\n\n    steps:\n      - uses: actions/checkout@v4\n        if: needs.release.outputs.release_created == 'true'\n        with:\n          fetch-depth: 0\n\n      - name: Set up Python ${{ matrix.python-version }}\n        if: needs.release.outputs.release_created == 'true'\n        uses: actions/setup-python@v4\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Install dependencies\n        if: needs.release.outputs.release_created == 'true'\n        run: |\n          make dev\n\n      - name: Build wheel on ${{ matrix.os }}\n        if: needs.release.outputs.release_created == 'true'\n        run: |\n          make build\n\n      - name: Upload to TestPyPI\n        if: needs.release.outputs.release_created == 'true'\n        env:\n          TWINE_USERNAME: __token__\n          TWINE_PASSWORD: ${{ secrets.TESTPYPI }}\n        run: |\n          twine check dist/*.whl\n          twine upload --repository-url https://test.pypi.org/legacy/ dist/*.whl --verbose\n"
  },
  {
    "path": ".github/workflows/stale.yml",
    "content": "name: Mark stale issues and pull requests\n\non:\n  schedule:\n  - cron: \"0 0/3 * * *\"\n\njobs:\n  stale:\n\n    runs-on: ubuntu-latest\n\n    steps:\n    - uses: actions/stale@v3\n      with:\n        repo-token: ${{ secrets.GITHUB_TOKEN }}\n        stale-issue-message: 'This issue is stale because it has been open for three months with no activity. Remove the stale label or comment on the issue, otherwise it will be closed in 5 days.'\n        stale-pr-message: 'This PR is stale because it has been open for a year with no activity. Remove the stale label or comment on the PR, otherwise it will be closed in 5 days.'\n        stale-issue-label: 'stale'\n        stale-pr-label: 'stale'\n        days-before-stale: 90\n        days-before-pr-stale: 365\n        days-before-close: 5\n        operations-per-run: 100\n        exempt-issue-labels: 'bug,enhancement'\n        remove-stale-when-updated: true\n"
  },
  {
    "path": ".github/workflows/test_qlib_from_pip.yml",
    "content": "name: Test qlib from pip\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\njobs:\n  build:\n    timeout-minutes: 120\n\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [windows-latest, ubuntu-24.04, ubuntu-22.04, macos-14, macos-15]\n        # In github action, using python 3.7, pip install will not match the latest version of the package.\n        # Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.\n        # All things considered, we have removed python 3.7.\n        python-version: [\"3.8\", \"3.9\", \"3.10\", \"3.11\", \"3.12\"]\n\n    steps:\n    - name: Test qlib from pip\n      uses: actions/checkout@v4\n      with:\n        fetch-depth: 0\n\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v4\n      with:\n        python-version: ${{ matrix.python-version }}\n\n    - name: Update pip to the latest version\n      run: |\n        python -m pip install --upgrade pip\n\n    - name: Qlib installation test\n      run: |\n        python -m pip install pyqlib\n\n    - name: Install Lightgbm for MacOS\n      if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}\n      run: |\n        brew update\n        brew install libomp || brew reinstall libomp\n        python -m pip install --no-binary=:all: lightgbm\n\n    - name: Download dependency data\n      run: |\n        cd ..\n        python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n        cd qlib\n\n    - name: Test workflow by config\n      run: |\n        qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml\n"
  },
  {
    "path": ".github/workflows/test_qlib_from_source.yml",
    "content": "name: Test qlib from source\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\njobs:\n  build:\n    timeout-minutes: 180\n    # we may retry for 3 times for `Unit tests with Pytest`\n\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [windows-latest, ubuntu-24.04, ubuntu-22.04, macos-14, macos-15]\n        # In github action, using python 3.7, pip install will not match the latest version of the package.\n        # Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.\n        # All things considered, we have removed python 3.7.\n        python-version: [\"3.8\", \"3.9\", \"3.10\", \"3.11\", \"3.12\"]\n\n    steps:\n    - name: Test qlib from source\n      uses: actions/checkout@v4\n      with:\n        fetch-depth: 0\n\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v4\n      with:\n        python-version: ${{ matrix.python-version }}\n\n    - name: Update pip to the latest version\n      run: |\n        python -m pip install --upgrade pip\n\n    - name: Installing pytorch for macos\n      if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}\n      run: |\n        python -m pip install torch torchvision torchaudio\n\n    - name: Installing pytorch for ubuntu\n      if: ${{ matrix.os == 'ubuntu-24.04' || matrix.os == 'ubuntu-22.04' }}\n      run: |\n        python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu\n\n    - name: Installing pytorch for windows\n      if: ${{ matrix.os == 'windows-latest' }}\n      run: |\n        python -m pip install torch torchvision torchaudio\n\n    - name: Set up Python tools\n      run: |\n        make dev\n\n    - name: Lint with Black\n      run: |\n        make black\n\n    - name: Make html with sphinx\n      # Since read the docs builds on ubuntu 22.04, we only need to test that the build passes 
on ubuntu 22.04.\n      if: ${{ matrix.os == 'ubuntu-22.04' }}\n      run: |\n        make docs-gen\n\n    - name: Check Qlib with pylint\n      run: |\n        make pylint\n\n    - name: Check Qlib with flake8\n      run: |\n        make flake8\n\n    - name: Check Qlib with mypy\n      run: |\n        make mypy\n    \n    # Due to issues that cannot be automatically fixed when running `nbqa black . -l 120 --check --diff` on Jupyter notebooks,\n    # we reverted to a version of `black` earlier than 26.1.0 before performing the checks.\n    - name: Check Qlib ipynb with nbqa\n      run: |\n        python -m pip install \"black<26.1\"\n        make nbqa\n\n    - name: Test data downloads\n      run: |\n        python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn\n        python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl\n\n    - name: Install Lightgbm for MacOS\n      if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}\n      run: |\n        brew update\n        brew install libomp || brew reinstall libomp\n        python -m pip install --no-binary=:all: lightgbm\n\n    - name: Check Qlib ipynb with nbconvert\n      run: |\n        make nbconvert\n\n    - name: Test workflow by config (install from source)\n      run: |\n        python -m pip install numba\n        python qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml\n\n    - name: Unit tests with Pytest (MacOS)\n      if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}\n      uses: nick-fields/retry@v2\n      with:\n        timeout_minutes: 60\n        max_attempts: 3\n        command: |\n          # Limit the number of threads in various libraries to prevent Segmentation faults caused by OpenMP multithreading conflicts under macOS.\n          export OMP_NUM_THREADS=1  # Limit the number of OpenMP threads\n          export MKL_NUM_THREADS=1  # 
Limit the number of Intel MKL threads\n          export NUMEXPR_NUM_THREADS=1  # Limit the number of NumExpr threads\n          export OPENBLAS_NUM_THREADS=1  # Limit the number of OpenBLAS threads\n          export VECLIB_MAXIMUM_THREADS=1  # Limit the number of macOS Accelerate/vecLib threads\n          cd tests\n          python -m pytest . -m \"not slow\" --durations=0\n\n    - name: Unit tests with Pytest (Ubuntu and Windows)\n      if: ${{ matrix.os != 'macos-13' && matrix.os != 'macos-14' && matrix.os != 'macos-15' }}\n      uses: nick-fields/retry@v2\n      with:\n        timeout_minutes: 60\n        max_attempts: 3\n        command: |\n          cd tests\n          python -m pytest . -m \"not slow\" --durations=0\n"
  },
  {
    "path": ".github/workflows/test_qlib_from_source_slow.yml",
    "content": "name: Test qlib from source slow\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\njobs:\n  build:\n    timeout-minutes: 720\n    # we may retry up to 3 times for `Unit tests with Pytest`\n\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [windows-latest, ubuntu-24.04, ubuntu-22.04, macos-14, macos-15]\n        # In github action, using python 3.7, pip install will not match the latest version of the package.\n        # Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.\n        # All things considered, we have removed python 3.7.\n        python-version: [\"3.8\", \"3.9\", \"3.10\", \"3.11\", \"3.12\"]\n\n    steps:\n    - name: Test qlib from source slow\n      uses: actions/checkout@v4\n      with:\n        fetch-depth: 0\n\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v4\n      with:\n        python-version: ${{ matrix.python-version }}\n\n    - name: Set up Python tools\n      run: |\n        make dev\n\n    - name: Download dependency data\n      run: |\n        python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn\n\n    - name: Install Lightgbm for MacOS\n      if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}\n      run: |\n        brew update\n        brew install libomp || brew reinstall libomp\n        python -m pip install --no-binary=:all: lightgbm\n\n    - name: Unit tests with Pytest\n      uses: nick-fields/retry@v2\n      with:\n        timeout_minutes: 240\n        max_attempts: 3\n        command: |\n          cd tests\n          python -m pytest . -m \"slow\" --durations=0\n"
  },
  {
    "path": ".gitignore",
    "content": "# https://github.com/github/gitignore/blob/master/Python.gitignore\n__pycache__/\n\n*.pyc\n*.pyd\n*.so\n*.ipynb\n.ipynb_checkpoints\n_build\nbuild/\ndist/\n\n*.pkl\n*.hd5\n*.csv\n\n.env\n.vim\n.nvimrc\n.vscode\n\nqlib/VERSION.txt\nqlib/data/_libs/expanding.cpp\nqlib/data/_libs/rolling.cpp\nqlib/_version.py\nexamples/estimator/estimator_example/\nexamples/rl/data/\nexamples/rl/checkpoints/\nexamples/rl/outputs/\nexamples/rl_order_execution/data/\nexamples/rl_order_execution/outputs/\n\n*.egg-info/\n\n# test related\ntest-output.xml\n.output\n.data\n\n# special software\nmlruns/\n\ntags\n\n.pytest_cache/\n.mypy_cache/\n.vscode/\n\n*.swp\n\n./pretrain\n.idea/\n.aider*\n"
  },
  {
    "path": ".mypy.ini",
    "content": "[mypy]\nexclude = (?x)(\n    ^qlib/backtest/high_performance_ds\\.py$\n    | ^qlib/contrib\n    | ^qlib/data\n    | ^qlib/model\n    | ^qlib/strategy\n    | ^qlib/tests\n    | ^qlib/utils\n    | ^qlib/workflow\n    | ^qlib/config\\.py$\n    | ^qlib/log\\.py$\n    | ^qlib/__init__\\.py$\n  )\nignore_missing_imports = true\ndisallow_incomplete_defs = true\nfollow_imports = skip\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n-   repo: https://github.com/psf/black\n    rev: 23.7.0\n    hooks:\n    -   id: black\n        args: [\"qlib\", \"-l 120\"]\n\n-   repo: https://github.com/PyCQA/flake8\n    rev: 4.0.1\n    hooks:\n        - id: flake8\n          args: [\"--ignore=E501,F541,E266,E402,W503,E731,E203\"]\n"
  },
  {
    "path": ".pylintrc",
    "content": "[TYPECHECK]\n# https://stackoverflow.com/a/53572939 \n# List of members which are set dynamically and missed by Pylint inference\n# system, and so shouldn't trigger E1101 when accessed.\ngenerated-members=numpy.*, torch.*\n"
  },
  {
    "path": ".readthedocs.yaml",
    "content": "# .readthedocs.yml\n# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details\n\n# Required\nversion: 2\n\n# Set the version of Python and other tools you might need\nbuild:\n  os: ubuntu-22.04\n  tools:\n    python: \"3.8\"\n\n# Build documentation in the docs/ directory with Sphinx\nsphinx:\n  configuration: docs/conf.py\n\n# Build all formats\nformats: all\n\n# Optionally set the version of Python and requirements required to build your docs\npython:\n  install:\n    - requirements: docs/requirements.txt\n    - method: pip\n      path: .\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": ""
  },
  {
    "path": "CHANGES.rst",
    "content": "Changelog\n=========\nHere you can see the full list of changes between each QLib release.\n\nVersion 0.1.0\n-------------\nThis is the initial release of QLib library.\n\nVersion 0.1.1\n-------------\nPerformance optimize. Add more features and operators.\n\nVersion 0.1.2\n-------------\n- Support operator syntax. Now ``High() - Low()`` is equivalent to ``Sub(High(), Low())``.\n- Add more technical indicators.\n\nVersion 0.1.3\n-------------\nBug fix and add instruments filtering mechanism.\n\nVersion 0.2.0\n-------------\n- Redesign ``LocalProvider`` database format for performance improvement.\n- Support load features as string fields.\n- Add scripts for database construction.\n- More operators and technical indicators.\n\nVersion 0.2.1\n-------------\n- Support registering user-defined ``Provider``.\n- Support use operators in string format, e.g. ``['Ref($close, 1)']`` is valid field format.\n- Support dynamic fields in ``$some_field`` format. And existing fields like ``Close()`` may be deprecated in the future.\n\nVersion 0.2.2\n-------------\n- Add ``disk_cache`` for reusing features (enabled by default).\n- Add ``qlib.contrib`` for experimental model construction and evaluation.\n\n\nVersion 0.2.3\n-------------\n- Add ``backtest`` module\n- Decoupling the Strategy, Account, Position, Exchange from the backtest module\n\nVersion 0.2.4\n-------------\n- Add ``profit attribution`` module\n- Add ``rick_control`` and ``cost_control`` strategies\n\nVersion 0.3.0\n-------------\n- Add ``estimator`` module\n\nVersion 0.3.1\n-------------\n- Add ``filter`` module\n\nVersion 0.3.2\n-------------\n- Add real price trading, if the ``factor`` field in the data set is incomplete, use ``adj_price`` trading\n- Refactor ``handler`` ``launcher`` ``trainer`` code\n- Support ``backtest`` configuration parameters in the configuration file\n- Fix bug in position ``amount`` is 0\n- Fix bug of ``filter`` module\n\nVersion 0.3.3\n-------------\n- Fix bug of 
``filter`` module\n\nVersion 0.3.4\n-------------\n- Support for ``finetune model``\n- Refactor ``fetcher`` code\n\nVersion 0.3.5\n-------------\n- Support multi-label training, you can provide multiple label in ``handler``. (But LightGBM doesn't support due to the algorithm itself)\n- Refactor ``handler`` code, dataset.py is no longer used, and you can deploy your own labels and features in ``feature_label_config``\n- Handler only offer DataFrame. Also, ``trainer`` and model.py only receive DataFrame\n- Change ``split_rolling_data``, we roll the data on market calendar now, not on normal date\n- Move some date config from ``handler`` to ``trainer``\n\nVersion 0.4.0\n-------------\n- Add `data` package that holds all data-related codes\n- Reform the data provider structure\n- Create a server for data centralized management `qlib-server <https://amc-msra.visualstudio.com/trading-algo/_git/qlib-server>`_\n- Add a `ClientProvider` to work with server\n- Add a pluggable cache mechanism\n- Add a recursive backtracking algorithm to inspect the furthest reference date for an expression\n\n.. 
note::\n    The ``D.instruments`` function does not support ``start_time``, ``end_time``, and ``as_list`` parameters, if you want to get the results of previous versions of ``D.instruments``, you can do this:\n\n\n    >>> from qlib.data import D\n    >>> instruments = D.instruments(market='csi500')\n    >>> D.list_instruments(instruments=instruments, start_time='2015-01-01', end_time='2016-02-15', as_list=True)\n\n\nVersion 0.4.1\n-------------\n- Add support Windows\n- Fix ``instruments`` type bug\n- Fix ``features`` is empty bug(It will cause failure in updating)\n- Fix ``cache`` lock and update bug\n- Fix use the same cache for the same field (the original space will add a new cache)\n- Change \"logger handler\" from config\n- Change model load support 0.4.0 later\n- The default value of the ``method`` parameter of ``risk_analysis`` function is changed from **ci** to **si**\n\n\nVersion 0.4.2\n-------------\n- Refactor DataHandler\n- Add ``Alpha360`` DataHandler\n\n\nVersion 0.4.3\n-------------\n- Implementing Online Inference and Trading Framework\n- Refactoring The interfaces of backtest and strategy module.\n\n\nVersion 0.4.4\n-------------\n- Optimize cache generation performance\n- Add report module\n- Fix bug when using ``ServerDatasetCache`` offline.\n- In the previous version of ``long_short_backtest``, there is a case of ``np.nan`` in long_short. The current version ``0.4.4`` has been fixed, so ``long_short_backtest`` will be different from the previous version.\n- In the ``0.4.2`` version of ``risk_analysis`` function, ``N`` is ``250``, and ``N`` is ``252`` from ``0.4.3``, so ``0.4.2`` is ``0.002122`` smaller than the ``0.4.3`` the backtest result is slightly different between ``0.4.2`` and ``0.4.3``.\n- refactor the argument of backtest function.\n    - **NOTE**:\n      - The default arguments of topk margin strategy is changed. 
Please pass the arguments explicitly if you want to get the same backtest result as previous version.\n      - The TopkWeightStrategy is changed slightly. It will try to sell the stocks more than ``topk``.  (The backtest result of TopkAmountStrategy remains the same)\n- The margin ratio mechanism is supported in the Topk Margin strategies.\n\n\nVersion 0.4.5\n-------------\n- Add multi-kernel implementation for both client and server.\n    - Support a new way to load data from client which skips dataset cache.\n    - Change the default dataset method from single kernel implementation to multi kernel implementation.\n- Accelerate the high frequency data reading by optimizing the relative modules.\n- Support a new method to write config file by using dict.\n\nVersion 0.4.6\n-------------\n- Some bugs are fixed\n    - The default config in `Version 0.4.5` is not friendly to daily frequency data.\n    - Backtest error in TopkWeightStrategy when `WithInteract=True`.\n\n\nVersion 0.5.0\n-------------\n- First opensource version\n    - Refine the docs, code\n    - Add baselines\n    - public data crawler\n\n\nVersion 0.8.0\n-------------\n- The backtest is greatly refactored.\n    - Nested decision execution framework is supported\n    - There are lots of changes for daily trading, it is hard to list all of them. 
But a few important changes are worth noting:\n        - The trading limitation is more accurate:\n            - In `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/backtest/exchange.py#L160>`__, longing and shorting actions share the same trading limitation.\n            - In `current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/backtest/exchange.py#L304>`__, the trading limitation differs between longing and shorting actions.\n        - The constant is different when calculating annualized metrics.\n            - `Current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/contrib/evaluate.py#L42>`_ uses a more accurate constant than `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/evaluate.py#L22>`__\n        - `A new version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/tests/data.py#L17>`__ of data is released. Due to the instability of the Yahoo data source, the data may differ if it is downloaded again.\n        - Users can compare the backtesting results of the `Current version <https://github.com/microsoft/qlib/tree/7c31012b507a3823117bddcc693fc64899460b2a/examples/benchmarks>`__ and the `previous version <https://github.com/microsoft/qlib/tree/v0.7.2/examples/benchmarks>`__\n\n\nOther Versions\n--------------\nPlease refer to `GitHub release notes <https://github.com/microsoft/qlib/releases>`_\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "# Microsoft Open Source Code of Conduct\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).\n\nResources:\n\n- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)\n- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)\n- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns\n"
  },
  {
    "path": "Dockerfile",
    "content": "FROM continuumio/miniconda3:latest\n\nWORKDIR /qlib\n\nCOPY . .\n\nRUN apt-get update && \\\n    apt-get install -y build-essential\n\nRUN conda create --name qlib_env python=3.8 -y\nRUN echo \"conda activate qlib_env\" >> ~/.bashrc\nENV PATH /opt/conda/envs/qlib_env/bin:$PATH\n\nRUN python -m pip install --upgrade pip\n\nRUN python -m pip install numpy==1.23.5\nRUN python -m pip install pandas==1.5.3\nRUN python -m pip install importlib-metadata==5.2.0\nRUN python -m pip install \"cloudpickle<3\"\nRUN python -m pip install scikit-learn==1.3.2\n\nRUN python -m pip install cython packaging tables matplotlib statsmodels\nRUN python -m pip install pybind11 cvxpy\n\nARG IS_STABLE=\"yes\"\n\nRUN if [ \"$IS_STABLE\" = \"yes\" ]; then \\\n        python -m pip install pyqlib; \\\n    else \\\n        python setup.py install; \\\n    fi\n"
  },
  {
    "path": "LICENSE",
    "content": "    MIT License\n\n    Copyright (c) Microsoft Corporation.\n\n    Permission is hereby granted, free of charge, to any person obtaining a copy\n    of this software and associated documentation files (the \"Software\"), to deal\n    in the Software without restriction, including without limitation the rights\n    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n    copies of the Software, and to permit persons to whom the Software is\n    furnished to do so, subject to the following conditions:\n\n    The above copyright notice and this permission notice shall be included in all\n    copies or substantial portions of the Software.\n\n    THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n    SOFTWARE\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "exclude tests/*\ninclude qlib/*\ninclude qlib/*/*\ninclude qlib/*/*/*\ninclude qlib/*/*/*/*\ninclude qlib/*/*/*/*/*\n"
  },
  {
    "path": "Makefile",
    "content": ".PHONY: clean deepclean prerequisite dependencies lightgbm rl develop lint docs package test analysis all install dev black pylint flake8 mypy nbqa nbconvert lint build upload docs-gen\n#You can modify it according to your terminal\nSHELL := /bin/bash\n\n########################################################################################\n# Variables\n########################################################################################\n\n# Documentation target directory, will be adapted to specific folder for readthedocs.\nPUBLIC_DIR := $(shell [ \"$$READTHEDOCS\" = \"True\" ] && echo \"$$READTHEDOCS_OUTPUT/html\" || echo \"public\")\n\nSO_DIR := qlib/data/_libs\nSO_FILES := $(wildcard $(SO_DIR)/*.so)\n\nifeq ($(OS),Windows_NT)\n    IS_WINDOWS = true\nelse\n    IS_WINDOWS = false\nendif\n\n########################################################################################\n# Development Environment Management\n########################################################################################\n# Remove common intermediate files.\nclean:\n\t-rm -rf \\\n\t\t$(PUBLIC_DIR) \\\n\t\tqlib/data/_libs/*.cpp \\\n\t\tqlib/data/_libs/*.so \\\n\t\tmlruns \\\n\t\tpublic \\\n\t\tbuild \\\n\t\t.coverage \\\n\t\t.mypy_cache \\\n\t\t.pytest_cache \\\n\t\t.ruff_cache \\\n\t\tPipfile* \\\n\t\tcoverage.xml \\\n\t\tdist \\\n\t\trelease-notes.md\n\n\tfind . -name '*.egg-info' -print0 | xargs -0 rm -rf\n\tfind . -name '*.pyc' -print0 | xargs -0 rm -f\n\tfind . -name '*.swp' -print0 | xargs -0 rm -f\n\tfind . -name '.DS_Store' -print0 | xargs -0 rm -f\n\tfind . 
-name '__pycache__' -print0 | xargs -0 rm -rf\n\n# Remove pre-commit hook, virtual environment alongside itermediate files.\ndeepclean: clean\n\tif command -v pre-commit > /dev/null 2>&1; then pre-commit uninstall --hook-type pre-push; fi\n\tif command -v pipenv >/dev/null 2>&1 && pipenv --venv >/dev/null 2>&1; then pipenv --rm; fi\n\n# Prerequisite section\n# What this code does is compile two Cython modules, rolling and expanding, using setuptools and Cython,\n# and builds them as binary expansion modules that can be imported directly into Python.\n# Since pyproject.toml can't do that, we compile it here.\n\n# pywinpty as a dependency of jupyter on windows, if you use pip install pywinpty installation,\n# will first download the tar.gz file, and then locally compiled and installed,\n# this will lead to some unnecessary trouble, so we choose to install the compiled whl file, to avoid trouble.\nprerequisite:\n\t@if [ -n \"$(SO_FILES)\" ]; then \\\n\t\techo \"Shared library files exist, skipping build.\"; \\\n\telse \\\n\t\techo \"No shared library files found, building...\"; \\\n\t\tpip install --upgrade setuptools wheel; \\\n\t\tpython -m pip install cython numpy; \\\n\t\tpython -c \"from setuptools import setup, Extension; from Cython.Build import cythonize; import numpy; extensions = [Extension('qlib.data._libs.rolling', ['qlib/data/_libs/rolling.pyx'], language='c++', include_dirs=[numpy.get_include()]), Extension('qlib.data._libs.expanding', ['qlib/data/_libs/expanding.pyx'], language='c++', include_dirs=[numpy.get_include()])]; setup(ext_modules=cythonize(extensions, language_level='3'), script_args=['build_ext', '--inplace'])\"; \\\n\tfi\n\n\t@if [ \"$(IS_WINDOWS)\" = \"true\" ]; then \\\n\t\tpython -m pip install pywinpty --only-binary=:all:; \\\n\tfi\n\n# Install the package in editable mode.\ndependencies:\n\tpython -m pip install --no-cache-dir -e .\n\nlightgbm:\n\tpython -m pip install --no-cache-dir lightgbm --prefer-binary\n\nrl:\n\tpython -m pip 
install --no-cache-dir -e .[rl]\n\ndevelop:\n\tpython -m pip install --no-cache-dir -e .[dev]\n\nlint:\n\tpython -m pip install --no-cache-dir -e .[lint]\n\ndocs:\n\tpython -m pip install --no-cache-dir -e .[docs]\n\npackage:\n\tpython -m pip install --no-cache-dir -e .[package]\n\ntest:\n\tpython -m pip install --no-cache-dir -e .[test]\n\nanalysis:\n\tpython -m pip install --no-cache-dir -e .[analysis]\n\nclient:\n\tpython -m pip install --no-cache-dir -e .[client]\n\nall:\n\tpython -m pip install --no-cache-dir -e .[pywinpty,dev,lint,docs,package,test,analysis,rl]\n\ninstall: prerequisite dependencies\n\ndev: prerequisite all\n\n########################################################################################\n# Lint and pre-commit\n########################################################################################\n\n# Check lint with black.\nblack:\n\tblack . -l 120 --check --diff --exclude qlib/_version.py\n\n# Check code folder with pylint.\n# TODO: These problems we will solve in the future. 
Important among them are: W0221, W0223, W0237, E1102\n# \tC0103: invalid-name\n# \tC0209: consider-using-f-string\n# \tR0402: consider-using-from-import\n# \tR1705: no-else-return\n# \tR1710: inconsistent-return-statements\n# \tR1725: super-with-arguments\n# \tR1735: use-dict-literal\n# \tW0102: dangerous-default-value\n# \tW0212: protected-access\n# \tW0221: arguments-differ\n# \tW0223: abstract-method\n# \tW0231: super-init-not-called\n# \tW0237: arguments-renamed\n# \tW0612: unused-variable\n# \tW0621: redefined-outer-name\n# \tW0622: redefined-builtin\n# \tFIXME: specify exception type\n# \tW0703: broad-except\n# \tW1309: f-string-without-interpolation\n# \tE1102: not-callable\n# \tE1136: unsubscriptable-object\n# \tW4904: deprecated-class\n# \tR0917: too-many-positional-arguments\n# \tE1123: unexpected-keyword-arg\n# References for disable error: https://pylint.pycqa.org/en/latest/user_guide/messages/messages_overview.html\n# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).\n# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962\npylint:\n\tpylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,W4904,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1730,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' qlib --init-hook=\"import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)\"\n\tpylint 
--disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,E1123,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' scripts --init-hook=\"import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)\"\n\n# Check code with flake8.\n# The following flake8 error codes were ignored:\n# E501 line too long\n# \tDescription: We have used black to limit the length of each line to 120.\n# F541 f-string is missing placeholders\n# \tDescription: The same thing is done when using pylint for detection.\n# E266 too many leading '#' for block comment\n# \tDescription: To make the code more readable, a lot of \"#\" is used.\n#         This error code appears centrally in:\n# \t\t\tqlib/backtest/executor.py\n# \t\t\tqlib/data/ops.py\n# \t\t\tqlib/utils/__init__.py\n# E402 module level import not at top of file\n# \tDescription: There are times when module level import is not available at the top of the file.\n# W503 line break before binary operator\n# \tDescription: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.\n# E731 do not assign a lambda expression, use a def\n# \tDescription: Restricts the use of lambda expressions, but at some point lambda expressions are required.\n# E203 whitespace before ':'\n# \tDescription: If there is whitespace before \":\", it cannot pass the black check.\nflake8:\n\tflake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores=\"__init__.py:F401,F403\" qlib\n\n# Check code with mypy.\n# https://github.com/python/mypy/issues/10600\nmypy:\n\tmypy qlib --install-types --non-interactive\n\tmypy qlib --verbose\n\n# Check ipynb with nbqa.\nnbqa:\n\tnbqa black . 
-l 120 --check --diff\n\tnbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}'\n\n# Check ipynb with nbconvert.(Run after data downloads)\n# TODO: Add more ipynb files in future\nnbconvert:\n\tjupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb\n\nlint: black pylint flake8 mypy nbqa\n\n########################################################################################\n# Package\n########################################################################################\n\n# Build the package.\nbuild:\n\tpython -m build --wheel\n\n# Upload the package.\nupload:\n\tpython -m twine upload dist/*\n\n########################################################################################\n# Documentation\n########################################################################################\n\ndocs-gen:\n\tpython -m sphinx.cmd.build -W docs $(PUBLIC_DIR)\n"
  },
  {
    "path": "README.md",
    "content": "[![Python Versions](https://img.shields.io/pypi/pyversions/pyqlib.svg?logo=python&logoColor=white)](https://pypi.org/project/pyqlib/#files)\n[![Platform](https://img.shields.io/badge/platform-linux%20%7C%20windows%20%7C%20macos-lightgrey)](https://pypi.org/project/pyqlib/#files)\n[![PyPI Versions](https://img.shields.io/pypi/v/pyqlib)](https://pypi.org/project/pyqlib/#history)\n[![Upload Python Package](https://github.com/microsoft/qlib/workflows/Upload%20Python%20Package/badge.svg)](https://pypi.org/project/pyqlib/)\n[![GitHub Actions Test Status](https://github.com/microsoft/qlib/workflows/Test/badge.svg?branch=main)](https://github.com/microsoft/qlib/actions)\n[![Documentation Status](https://readthedocs.org/projects/qlib/badge/?version=latest)](https://qlib.readthedocs.io/en/latest/?badge=latest)\n[![License](https://img.shields.io/pypi/l/pyqlib)](LICENSE)\n[![Join the chat at https://gitter.im/Microsoft/qlib](https://badges.gitter.im/Microsoft/qlib.svg)](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)\n\n## :newspaper: **What's NEW!** &nbsp;   :sparkling_heart: \n\nRecently released features\n\n### Introducing <a href=\"https://github.com/microsoft/RD-Agent\"><img src=\"docs/_static/img/rdagent_logo.png\" alt=\"RD_Agent\" style=\"height: 2em\"></a>: LLM-Based Autonomous Evolving Agents for Industrial Data-Driven R&D\n\nWe are excited to announce the release of **RD-Agent**📢, a powerful tool that supports automated factor mining and model optimization in quant investment R&D.\n\nRD-Agent is now available on [GitHub](https://github.com/microsoft/RD-Agent), and we welcome your star🌟!\n\nTo learn more, please visit our [♾️Demo page](https://rdagent.azurewebsites.net/). 
Here, you will find demo videos in both English and Chinese to help you better understand the scenario and usage of RD-Agent.\n\nWe have prepared several demo videos for you:\n| Scenario | Demo video (English) | Demo video (中文) |\n| --                      | ------    | ------    |\n| Quant Factor Mining | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=zh) |\n| Quant Factor Mining from reports | [Link](https://rdagent.azurewebsites.net/report_factor?lang=en) | [Link](https://rdagent.azurewebsites.net/report_factor?lang=zh) |\n| Quant Model Optimization | [Link](https://rdagent.azurewebsites.net/model_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/model_loop?lang=zh) |\n\n- 📃**Paper**: [R&D-Agent-Quant: A Multi-Agent Framework for Data-Centric Factors and Model Joint Optimization](https://arxiv.org/abs/2505.15155)\n- 👾**Code**: https://github.com/microsoft/RD-Agent/\n```BibTeX\n@misc{li2025rdagentquant,\n    title={R\\&D-Agent-Quant: A Multi-Agent Framework for Data-Centric Factors and Model Joint Optimization},\n    author={Yuante Li and Xu Yang and Xiao Yang and Minrui Xu and Xisen Wang and Weiqing Liu and Jiang Bian},\n    year={2025},\n    eprint={2505.15155},\n    archivePrefix={arXiv},\n    primaryClass={cs.AI}\n}\n```\n![image](https://github.com/user-attachments/assets/3198bc10-47ba-4ee0-8a8e-46d5ce44f45d)\n\n***\n\n| Feature | Status |\n| --                      | ------    |\n| [R&D-Agent-Quant](https://arxiv.org/abs/2505.15155) Published | Apply R&D-Agent to Qlib for quant trading | \n| BPQP for End-to-end learning | 📈Coming soon!([Under review](https://github.com/microsoft/qlib/pull/1863)) |\n| 🔥LLM-driven Auto Quant Factory🔥 | 🚀 Released in [♾️RD-Agent](https://github.com/microsoft/RD-Agent) on Aug 8, 2024 |\n| KRNN and Sandwich models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1414/) on May 26, 2023 |\n| Release Qlib v0.9.0 | 
:octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.9.0) on Dec 9, 2022 |\n| RL Learning Framework | :hammer: :chart_with_upwards_trend: Released on Nov 10, 2022. [#1332](https://github.com/microsoft/qlib/pull/1332), [#1322](https://github.com/microsoft/qlib/pull/1322), [#1316](https://github.com/microsoft/qlib/pull/1316), [#1299](https://github.com/microsoft/qlib/pull/1299), [#1263](https://github.com/microsoft/qlib/pull/1263), [#1244](https://github.com/microsoft/qlib/pull/1244), [#1169](https://github.com/microsoft/qlib/pull/1169), [#1125](https://github.com/microsoft/qlib/pull/1125), [#1076](https://github.com/microsoft/qlib/pull/1076) |\n| HIST and IGMTF models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1040) on Apr 10, 2022 |\n| Qlib [notebook tutorial](https://github.com/microsoft/qlib/tree/main/examples/tutorial) | 📖 [Released](https://github.com/microsoft/qlib/pull/1037) on Apr 7, 2022 |\n| Ibovespa index data | :rice: [Released](https://github.com/microsoft/qlib/pull/990) on Apr 6, 2022 |\n| Point-in-Time database | :hammer: [Released](https://github.com/microsoft/qlib/pull/343) on Mar 10, 2022 |\n| Arctic Provider Backend & Orderbook data example | :hammer: [Released](https://github.com/microsoft/qlib/pull/744) on Jan 17, 2022 |\n| Meta-Learning-based framework & DDG-DA | :chart_with_upwards_trend: :hammer: [Released](https://github.com/microsoft/qlib/pull/743) on Jan 10, 2022 |\n| Planning-based portfolio optimization | :hammer: [Released](https://github.com/microsoft/qlib/pull/754) on Dec 28, 2021 |\n| Release Qlib v0.8.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.8.0) on Dec 8, 2021 |\n| ADD model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/704) on Nov 22, 2021 |\n| ADARNN model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/689) on Nov 14, 2021 |\n| TCN model | :chart_with_upwards_trend: 
[Released](https://github.com/microsoft/qlib/pull/668) on Nov 4, 2021 |\n| Nested Decision Framework | :hammer: [Released](https://github.com/microsoft/qlib/pull/438) on Oct 1, 2021. [Example](https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py) and [Doc](https://qlib.readthedocs.io/en/latest/component/highfreq.html) |\n| Temporal Routing Adaptor (TRA) | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |\n| Transformer & Localformer | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |\n| Release Qlib v0.7.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |\n| TCTS Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 |\n| Online serving and automatic model rolling | :hammer: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |\n| DoubleEnsemble Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |\n| High-frequency data processing example | :hammer: [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |\n| High-frequency trading example | :chart_with_upwards_trend: [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 |\n| High-frequency data (1min) | :rice: [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 |\n| TabNet Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |\n\nFeatures released before 2021 are not listed here.\n\n<p align=\"center\">\n  <img src=\"docs/_static/img/logo/1.png\" />\n</p>\n\nQlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing 
production. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.\n\nAn increasing number of SOTA Quant research works/papers in diverse paradigms are being released in Qlib to collaboratively solve key challenges in quantitative investment. For example, 1) using supervised learning to mine the market's complex non-linear patterns from rich and heterogeneous financial data, 2) modeling the dynamic nature of the financial market using adaptive concept drift technology, and 3) using reinforcement learning to model continuous investment decisions and assist investors in optimizing their trading strategies.\n\nQlib contains the full ML pipeline of data processing, model training, and back-testing, and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution. \nFor more details, please refer to our paper [\"Qlib: An AI-oriented Quantitative Investment Platform\"](https://arxiv.org/abs/2009.11189).\n\n\n<table>\n  <tbody>\n    <tr>\n      <th>Frameworks, Tutorial, Data & DevOps</th>\n      <th>Main Challenges & Solutions in Quant Research</th>\n    </tr>\n    <tr>\n      <td>\n        <li><a href=\"#plans\"><strong>Plans</strong></a></li>\n        <li><a href=\"#framework-of-qlib\">Framework of Qlib</a></li>\n        <li><a href=\"#quick-start\">Quick Start</a></li>\n          <ul dir=\"auto\">\n            <li type=\"circle\"><a href=\"#installation\">Installation</a> </li>\n            <li type=\"circle\"><a href=\"#data-preparation\">Data Preparation</a></li>\n            <li type=\"circle\"><a href=\"#auto-quant-research-workflow\">Auto Quant Research Workflow</a></li>\n            <li type=\"circle\"><a href=\"#building-customized-quant-research-workflow-by-code\">Building Customized Quant Research Workflow by Code</a></li></ul>\n        <li><a href=\"#quant-dataset-zoo\"><strong>Quant Dataset 
Zoo</strong></a></li>\n        <li><a href=\"#learning-framework\">Learning Framework</a></li>\n        <li><a href=\"#more-about-qlib\">More About Qlib</a></li>\n        <li><a href=\"#offline-mode-and-online-mode\">Offline Mode and Online Mode</a>\n        <ul>\n          <li type=\"circle\"><a href=\"#performance-of-qlib-data-server\">Performance of Qlib Data Server</a></li></ul>\n        <li><a href=\"#related-reports\">Related Reports</a></li>\n        <li><a href=\"#contact-us\">Contact Us</a></li>\n        <li><a href=\"#contributing\">Contributing</a></li>\n      </td>\n      <td valign=\"baseline\">\n        <li><a href=\"#main-challenges--solutions-in-quant-research\">Main Challenges &amp; Solutions in Quant Research</a>\n          <ul>\n            <li type=\"circle\"><a href=\"#forecasting-finding-valuable-signalspatterns\">Forecasting: Finding Valuable Signals/Patterns</a>\n              <ul>\n                <li type=\"disc\"><a href=\"#quant-model-paper-zoo\"><strong>Quant Model (Paper) Zoo</strong></a>\n                  <ul>\n                    <li type=\"circle\"><a href=\"#run-a-single-model\">Run a Single Model</a></li>\n                    <li type=\"circle\"><a href=\"#run-multiple-models\">Run Multiple Models</a></li>\n                  </ul>\n                </li>\n              </ul>\n            </li>\n          <li type=\"circle\"><a href=\"#adapting-to-market-dynamics\">Adapting to Market Dynamics</a></li>\n          <li type=\"circle\"><a href=\"#reinforcement-learning-modeling-continuous-decisions\">Reinforcement Learning: modeling continuous decisions</a></li>\n          </ul>\n        </li>\n      </td>\n    </tr>\n  </tbody>\n</table>\n\n# Plans\nNew features under development (ordered by estimated release time).\nYour feedback on these features is very important.\n<!-- | Feature                        | Status      | -->\n<!-- | --                      | ------    | -->\n\n# Framework of Qlib\n\n<div style=\"align: center\">\n<img 
src=\"docs/_static/img/framework-abstract.jpg\" />\n</div>\n\nThe high-level framework of Qlib is shown above (users can find the [detailed framework](https://qlib.readthedocs.io/en/latest/introduction/introduction.html#framework) of Qlib's design when getting into the nitty-gritty).\nThe components are designed as loosely coupled modules, and each component can be used stand-alone.\n\nQlib provides a strong infrastructure to support Quant research. [Data](https://qlib.readthedocs.io/en/latest/component/data.html) is always an important part.\nA strong learning framework is designed to support diverse learning paradigms (e.g. [reinforcement learning](https://qlib.readthedocs.io/en/latest/component/rl.html), [supervised learning](https://qlib.readthedocs.io/en/latest/component/workflow.html#model-section)) and patterns at different levels (e.g. [market dynamic modeling](https://qlib.readthedocs.io/en/latest/component/meta.html)).\nBy modeling the market, [trading strategies](https://qlib.readthedocs.io/en/latest/component/strategy.html) generate trade decisions that are then executed. Multiple trading strategies and executors at different levels or granularities can be [nested to be optimized and run together](https://qlib.readthedocs.io/en/latest/component/highfreq.html).\nFinally, a comprehensive [analysis](https://qlib.readthedocs.io/en/latest/component/report.html) is provided, and the model can be [served online](https://qlib.readthedocs.io/en/latest/component/online.html) at a low cost.\n\n\n# Quick Start\n\nThis quick start guide tries to demonstrate that:\n1. It's very easy to build a complete Quant research workflow and try your ideas with _Qlib_.\n2. Even with *public data* and *simple models*, machine learning technologies **work very well** in practical Quant investment.\n\nHere is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** that shows how to install ``Qlib`` and run LightGBM with ``qrun``. 
**But**, please make sure you have already prepared the data by following the [instructions](#data-preparation).\n\n\n## Installation\n\nThis table shows the Python versions supported by `Qlib`:\n\n|               | install with pip      | install from source  |        plot        |\n| ------------- |:---------------------:|:--------------------:|:------------------:|\n| Python 3.8    | :heavy_check_mark:    | :heavy_check_mark:   | :heavy_check_mark: |\n| Python 3.9    | :heavy_check_mark:    | :heavy_check_mark:   | :heavy_check_mark: |\n| Python 3.10   | :heavy_check_mark:    | :heavy_check_mark:   | :heavy_check_mark: |\n| Python 3.11   | :heavy_check_mark:    | :heavy_check_mark:   | :heavy_check_mark: |\n| Python 3.12   | :heavy_check_mark:    | :heavy_check_mark:   | :heavy_check_mark: |\n\n**Note**: \n1. **Conda** is suggested for managing your Python environment. In some cases, using Python outside of a `conda` environment may result in missing header files, causing certain packages to fail to install.\n2. Installing cython in Python 3.6 raises errors when installing ``Qlib`` from source. If you use Python 3.6, it is recommended to *upgrade* Python to version 3.8 or higher, or to use `conda`'s Python to install ``Qlib`` from source.\n\n### Install with pip\nUsers can easily install ``Qlib`` with pip using the following command:\n\n```bash\n  pip install pyqlib\n```\n\n**Note**: pip installs the latest stable release of qlib. However, the main branch of qlib is under active development. If you want to test the latest scripts or functions in the main branch, 
please install qlib using the methods below.\n\n### Install from source\nAlternatively, users can install the latest dev version of ``Qlib`` from source with the following steps:\n\n* Before installing ``Qlib`` from source, users need to install some dependencies:\n\n  ```bash\n  pip install numpy\n  pip install --upgrade cython\n  ```\n\n* Clone the repository and install ``Qlib`` as follows.\n    ```bash\n    git clone https://github.com/microsoft/qlib.git && cd qlib\n    pip install .  # `pip install -e .[dev]` is recommended for development. Check details in docs/developer/code_standard_and_dev_guide.rst\n    ```\n\n**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps with the [CI workflow](.github/workflows/test_qlib_from_source.yml) may help you find the problem.\n\n**Tips for Mac**: If you are using a Mac with an M1 chip, you might encounter issues building the wheel for LightGBM due to a missing OpenMP dependency. To solve the problem, install OpenMP first with ``brew install libomp`` and then run ``pip install .`` to build it successfully. \n\n## Data Preparation\n❗ Due to a stricter data security policy, the official dataset is temporarily disabled. 
You can try [this data source](https://github.com/chenditc/investment_data/releases) contributed by the community.\nHere is an example of downloading the latest data:\n```bash\nwget https://github.com/chenditc/investment_data/releases/latest/download/qlib_bin.tar.gz\nmkdir -p ~/.qlib/qlib_data/cn_data\ntar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=1\nrm -f qlib_bin.tar.gz\n```\n\nThe official dataset below is expected to resume in the near future.\n\n\n----\n\nLoad and prepare data by running the following commands:\n\n### Get with module\n  ```bash\n  # get 1d data\n  python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n\n  # get 1min data\n  python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min\n\n  ```\n\n### Get from source\n\n  ```bash\n  # get 1d data\n  python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n\n  # get 1min data\n  python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min\n\n  ```\n\nThis dataset is created from public data collected by [crawler scripts](scripts/data_collector/), which have been released in\nthe same repository.\nUsers can recreate the same dataset with these scripts. [Description of dataset](https://github.com/microsoft/qlib/tree/main/scripts/data_collector#description-of-dataset)\n\n*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and might not be perfect.\nWe recommend that users prepare their own data if they have a high-quality dataset. 
For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.\n\n### Automatic update of daily frequency data (from Yahoo Finance)\n  > This step is *optional* if users only want to try their models and strategies on historical data.\n  > \n  > It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.\n  >\n  > **NOTE**: Users can't incrementally update data based on the offline data provided by Qlib (some fields are removed to reduce the data size). Users should use [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance) to download Yahoo data from scratch and then incrementally update it.\n  > \n  > For more information, please refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)\n\n  * Automatic update of data to the \"qlib\" directory each trading day (Linux)\n      * use *crontab*: `crontab -e`\n      * set up timed tasks:\n\n        ```\n        <minute> <hour> * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>\n        ```\n        * **script path**: *scripts/data_collector/yahoo/collector.py*\n        * **minute**, **hour**: the time of day at which the update runs on each weekday\n\n  * Manual update of data\n      ```\n      python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>\n      ```\n      * *trading_date*: start trading date\n      * *end_date*: end trading date (not included)\n\n### Checking the health of the data\n  * We provide a script to check the health of the data; you can run the following command to check whether the data is healthy.\n    ```\n    python 
scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data\n    
```\n  * You can also pass parameters to adjust the thresholds used by the check, for example:\n    ```\n    python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data --missing_data_num 30055 --large_step_threshold_volume 94485 --large_step_threshold_price 20\n    ```\n  * For more information about `check_data_health`, please refer to the [documentation](https://qlib.readthedocs.io/en/latest/component/data.html#checking-the-health-of-the-data).\n\n<!-- \n- Run the initialization code and get stock data:\n\n  ```python\n  import qlib\n  from qlib.data import D\n  from qlib.constant import REG_CN\n\n  # Initialization\n  mount_path = \"~/.qlib/qlib_data/cn_data\"  # target_dir\n  qlib.init(mount_path=mount_path, region=REG_CN)\n\n  # Get stock data by Qlib\n  # Load trading calendar with the given time range and frequency\n  print(D.calendar(start_time='2010-01-01', end_time='2017-12-31', freq='day')[:2])\n\n  # Parse a given market name into a stockpool config\n  instruments = D.instruments('csi500')\n  print(D.list_instruments(instruments=instruments, start_time='2010-01-01', end_time='2017-12-31', as_list=True)[:6])\n\n  # Load features of certain instruments in given time range\n  instruments = ['SH600000']\n  fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']\n  print(D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head())\n  ```\n -->\n\n## Docker images\n1. Pull a Docker image from Docker Hub\n    ```bash\n    docker pull pyqlib/qlib_image_stable:stable\n    ```\n2. Start a new Docker container\n    ```bash\n    docker run -it --name <container name> -v <Mounted local directory>:/app pyqlib/qlib_image_stable:stable\n    ```\n3. At this point you are in the Docker environment and can run the qlib scripts. 
An example:\n    ```bash\n    python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn\n    python qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml\n    ```\n4. Exit the container\n    ```bash\n    exit\n    ```\n5. Restart the container\n    ```bash\n    docker start -i -a <container name>\n    ```\n6. Stop the container\n    ```bash\n    docker stop <container name>\n    ```\n7. Delete the container\n    ```bash\n    docker rm <container name>\n    ```\n8. For more information, please refer to the [documentation](https://qlib.readthedocs.io/en/latest/developer/how_to_build_image.html).\n\n## Auto Quant Research Workflow\nQlib provides a tool named `qrun` to run the whole workflow automatically (including building the dataset, training models, backtesting, and evaluation). You can start an auto quant research workflow and generate graphical analysis reports with the following steps: \n\n1. Quant Research Workflow: Run `qrun` with the LightGBM workflow config ([workflow_config_lightgbm_Alpha158.yaml](examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml)) as follows.\n    ```bash\n      cd examples  # Avoid running the program under a directory that contains `qlib`\n      qrun benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml\n    ```\n    If users want to run `qrun` in debug mode, please use the following command:\n    ```bash\n    python -m pdb qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml\n    ```\n    The result of `qrun` is as follows; please refer to the [docs](https://qlib.readthedocs.io/en/latest/component/strategy.html#result) for more explanation of the result. 
\n\n    ```bash\n\n    'The following are analysis results of the excess return without cost.'\n                           risk\n    mean               0.000708\n    std                0.005626\n    annualized_return  0.178316\n    information_ratio  1.996555\n    max_drawdown      -0.081806\n    'The following are analysis results of the excess return with cost.'\n                           risk\n    mean               0.000512\n    std                0.005626\n    annualized_return  0.128982\n    information_ratio  1.444287\n    max_drawdown      -0.091078\n    ```\n    Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).\n\n2. Graphical Reports Analysis: First, run `python -m pip install .[analysis]` to install the required dependencies. Then run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports. \n    - Forecasting signal (model prediction) analysis\n      - Cumulative Return of groups\n      ![Cumulative Return](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_cumulative_return.png)\n      - Return distribution\n      ![long_short](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_long_short.png)\n      - Information Coefficient (IC)\n      ![Information Coefficient](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_IC.png)\n      ![Monthly IC](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_monthly_IC.png)\n      ![IC](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_NDQ.png)\n      - Auto Correlation of forecasting signal (model prediction)\n      ![Auto Correlation](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_auto_correlation.png)\n\n    - Portfolio analysis\n      - Backtest return\n      
![Report](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/report.png)\n      <!-- \n      - Score IC\n      ![Score IC](docs/_static/img/score_ic.png)\n      - Cumulative Return\n      ![Cumulative Return](docs/_static/img/cumulative_return.png)\n      - Risk Analysis\n      ![Risk Analysis](docs/_static/img/risk_analysis.png)\n      - Rank Label\n      ![Rank Label](docs/_static/img/rank_label.png)\n      -->\n    - [Explanation](https://qlib.readthedocs.io/en/latest/component/report.html) of the above results\n\n## Building Customized Quant Research Workflow by Code\nThe automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface that allows researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo of a customized Quant research workflow built by code.\n\n# Main Challenges & Solutions in Quant Research\nQuant investment is a unique scenario with many key challenges to be solved.\nCurrently, Qlib provides solutions for several of them.\n\n## Forecasting: Finding Valuable Signals/Patterns\nAccurate forecasting of stock price trends is an important part of constructing profitable portfolios.\nHowever, the huge amount of data in various formats in the financial market makes it challenging to build forecasting models.\n\nAn increasing number of SOTA Quant research works/papers, which focus on building forecasting models to mine valuable signals/patterns in complex financial data, are released in `Qlib`.\n\n\n### [Quant Model (Paper) Zoo](examples/benchmarks)\n\nHere is a list of models built on `Qlib`.\n- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](examples/benchmarks/XGBoost/)\n- [GBDT based on LightGBM (Guolin Ke, et al. NIPS 2017)](examples/benchmarks/LightGBM/)\n- [GBDT based on CatBoost (Liudmila Prokhorenkova, et al. 
NIPS 2018)](examples/benchmarks/CatBoost/)\n- [MLP based on pytorch](examples/benchmarks/MLP/)\n- [LSTM based on pytorch (Sepp Hochreiter, et al. Neural Computation 1997)](examples/benchmarks/LSTM/)\n- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](examples/benchmarks/GRU/)\n- [ALSTM based on pytorch (Yao Qin, et al. IJCAI 2017)](examples/benchmarks/ALSTM)\n- [GATs based on pytorch (Petar Velickovic, et al. 2017)](examples/benchmarks/GATs/)\n- [SFM based on pytorch (Liheng Zhang, et al. KDD 2017)](examples/benchmarks/SFM/)\n- [TFT based on tensorflow (Bryan Lim, et al. International Journal of Forecasting 2019)](examples/benchmarks/TFT/)\n- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](examples/benchmarks/TabNet/)\n- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](examples/benchmarks/DoubleEnsemble/)\n- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](examples/benchmarks/TCTS/)\n- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](examples/benchmarks/Transformer/)\n- [Localformer based on pytorch (Juyong Jiang, et al.)](examples/benchmarks/Localformer/)\n- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](examples/benchmarks/TRA/)\n- [TCN based on pytorch (Shaojie Bai, et al. 2018)](examples/benchmarks/TCN/)\n- [ADARNN based on pytorch (YunTao Du, et al. 2021)](examples/benchmarks/ADARNN/)\n- [ADD based on pytorch (Hongshun Tang, et al. 2020)](examples/benchmarks/ADD/)\n- [IGMTF based on pytorch (Wentao Xu, et al. 2021)](examples/benchmarks/IGMTF/)\n- [HIST based on pytorch (Wentao Xu, et al. 2021)](examples/benchmarks/HIST/)\n- [KRNN based on pytorch](examples/benchmarks/KRNN/)\n- [Sandwich based on pytorch](examples/benchmarks/Sandwich/)\n\nPRs adding new Quant models are highly welcome.\n\nThe performance of each model on the `Alpha158` and `Alpha360` datasets can be found [here](examples/benchmarks/README.md).\n\n### Run a single model\nAll the models listed above are runnable with ``Qlib``. 
Users can find the config files we provide and some details about each model in the [benchmarks](examples/benchmarks) folder. More information can be found in the model files listed above.\n\n`Qlib` provides three different ways to run a single model; users can pick the one that best fits their case:\n- Users can use the tool `qrun` mentioned above to run a model's workflow based on a config file.\n- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) in the `examples` folder.\n- Users can use the script [`run_all_model.py`](examples/run_all_model.py) in the `examples` folder to run a model. Here is an example of the shell command to use: `python run_all_model.py run --models=lightgbm`, where the `--models` argument can take any number of the models listed above (the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).\n    - **NOTE**: Each baseline has different environment dependencies; please make sure that your Python version meets the requirements (e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`).\n\n### Run multiple models\n`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supports *Linux* for now; other operating systems will be supported in the future. It also does not yet support running the same model multiple times in parallel; this will be addressed in future development.)\n\nThe script creates a separate virtual environment for each model and deletes the environments after training. 
Thus, only experiment results such as `IC` and `backtest` results are generated and stored.\n\nHere is an example of running all the models for 10 iterations:\n```bash\npython run_all_model.py run 10\n```\n\nIt also provides an API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py). \n\n### Breaking change\nIn `pandas`, `group_keys` is one of the parameters of the `groupby` method. Between versions 1.5 and 2.0 of `pandas`, the default value of `group_keys` changed from `no_default` to `True`, which causes qlib to raise errors at runtime. We therefore set `group_keys=False`, but this does not guarantee that the following programs run correctly:\n* qlib\\examples\\rl_order_execution\\scripts\\gen_training_orders.py\n* qlib\\examples\\benchmarks\\TRA\\src\\dataset.MTSDatasetH.py\n* qlib\\examples\\benchmarks\\TFT\\tft.py\n\n\n\n## [Adapting to Market Dynamics](examples/benchmarks_dynamic)\n\nDue to the non-stationary nature of the financial market, the data distribution may change across periods, which causes the performance of models built on training data to decay on future test data.\nAdapting forecasting models/strategies to market dynamics is therefore crucial to their performance.\n\nHere is a list of solutions built on `Qlib`.\n- [Rolling Retraining](examples/benchmarks_dynamic/baseline/)\n- [DDG-DA on pytorch (Wendi, et al. AAAI 2022)](examples/benchmarks_dynamic/DDG-DA/)\n\n## Reinforcement Learning: modeling continuous decisions\nQlib now supports reinforcement learning, a feature designed to model continuous investment decisions. 
This functionality assists investors in optimizing their trading strategies by learning from interactions with the environment to maximize some notion of cumulative reward.\n\nHere is a list of solutions built on `Qlib` categorized by scenario.\n\n### [RL for order execution](examples/rl_order_execution)\n[Here](https://qlib.readthedocs.io/en/latest/component/rl/overall.html#order-execution) is an introduction to this scenario. All the methods below are compared [here](examples/rl_order_execution).\n- [TWAP](examples/rl_order_execution/exp_configs/backtest_twap.yml)\n- [PPO: \"An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization\", IJCAI 2020](examples/rl_order_execution/exp_configs/backtest_ppo.yml)\n- [OPDS: \"Universal Trading for Order Execution with Oracle Policy Distillation\", AAAI 2021](examples/rl_order_execution/exp_configs/backtest_opds.yml)\n\n# Quant Dataset Zoo\nDatasets play a very important role in Quant. Here is a list of the datasets built on `Qlib`:\n\n| Dataset                                    | US Market | China Market |\n| --                                         | --        | --           |\n| [Alpha360](./qlib/contrib/data/handler.py) |  √        |  √           |\n| [Alpha158](./qlib/contrib/data/handler.py) |  √        |  √           |\n\n[Here](https://qlib.readthedocs.io/en/latest/advanced/alpha.html) is a tutorial on building datasets with `Qlib`.\nPRs that build new Quant datasets are highly welcome.\n\n\n# Learning Framework\nQlib is highly customizable and a lot of its components are learnable.\nThe learnable components are instances of `Forecast Model` and `Trading Agent`. They are learned based on the `Learning Framework` layer and then applied to multiple scenarios in the `Workflow` layer.\nThe learning framework leverages the `Workflow` layer as well (e.g. 
sharing `Information Extractor`, creating environments based on `Execution Env`).\n\nBased on learning paradigms, they can be categorized into reinforcement learning and supervised learning.\n- For supervised learning, the detailed docs can be found [here](https://qlib.readthedocs.io/en/latest/component/model.html).\n- For reinforcement learning, the detailed docs can be found [here](https://qlib.readthedocs.io/en/latest/component/rl.html). Qlib's RL framework leverages `Execution Env` in the `Workflow` layer to create environments. It's worth noting that `NestedExecutor` is supported as well. This enables users to jointly optimize strategies/models/agents at different levels (e.g. optimizing an order execution strategy for a specific portfolio management strategy).\n\n\n# More About Qlib\nIf you want to have a quick glance at the most frequently used components of qlib, you can try the notebooks [here](examples/tutorial/).\n\nThe detailed documentation is organized in [docs](docs/).\n[Sphinx](http://www.sphinx-doc.org) and the readthedocs theme are required to build the documentation in HTML format.\n```bash\ncd docs/\nconda install sphinx sphinx_rtd_theme -y\n# Otherwise, you can install them with pip\n# pip install sphinx sphinx_rtd_theme\nmake html\n```\nYou can also view the [latest documentation](http://qlib.readthedocs.io/) online directly.\n\nQlib is under active and continuing development. Our plans are in the roadmap, which is managed as a [GitHub project](https://github.com/microsoft/qlib/projects/1).\n\n\n\n# Offline Mode and Online Mode\nThe data server of Qlib can be deployed in either `Offline` mode or `Online` mode. The default mode is offline mode.\n\nIn `Offline` mode, the data is deployed locally.\n\nIn `Online` mode, the data is deployed as a shared data service. The data and its cache are shared by all the clients. Data retrieval performance is expected to improve due to a higher cache hit rate. 
It also consumes less disk space. The documentation for the online mode can be found in [Qlib-Server](https://qlib-server.readthedocs.io/). The online mode can be deployed automatically with [Azure CLI based scripts](https://qlib-server.readthedocs.io/en/latest/build.html#one-click-deployment-in-azure). The source code of the online data server can be found in the [Qlib-Server repository](https://github.com/microsoft/qlib-server).\n\n## Performance of Qlib Data Server\nThe performance of data processing is important to data-driven methods like AI technologies. As an AI-oriented platform, Qlib provides a solution for data storage and data processing. To demonstrate the performance of the Qlib data server, we\ncompare it with several other data storage solutions.\n\nWe evaluate the performance of several storage solutions by completing the same task,\nwhich creates a dataset (14 features/factors) from the basic OHLCV daily data of a stock market (800 stocks each day from 2007 to 2020). The task involves data queries and processing.\n\n|                         | HDF5      | MySQL     | MongoDB   | InfluxDB  | Qlib -E -D  | Qlib +E -D   | Qlib +E +D  |\n| --                      | ------    | ------    | --------  | --------- | ----------- | ------------ | ----------- |\n| Total (1CPU) (seconds)  | 184.4±3.7 | 365.3±7.5 | 253.6±6.7 | 368.2±3.6 | 147.0±8.8   | 47.6±1.0     | **7.4±0.3** |\n| Total (64CPU) (seconds) |           |           |           |           | 8.8±0.6     | **4.2±0.2**  |             |\n* `+(-)E` indicates with (without) `ExpressionCache`\n* `+(-)D` indicates with (without) `DatasetCache`\n\nMost general-purpose databases take too much time to load data. 
After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.\nSuch overheads greatly slow down the data loading process.\nQlib data are stored in a compact format, which can be efficiently combined into arrays for scientific computation.\n\n# Related Reports\n- [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/)\n- [Microsoft Builds an AI Quant Platform Too? And It's Open Source! (微软也搞AI量化平台？还是开源的！)](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)\n- [Qlib: The Industry's First Open-Source AI Quantitative Investment Platform (微矿Qlib：业内首个AI量化投资开源平台)](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)\n\n# Contact Us\n- If you have any issues, please create an issue [here](https://github.com/microsoft/qlib/issues/new/choose) or send a message on [Gitter](https://gitter.im/Microsoft/qlib).\n- If you want to make contributions to `Qlib`, please [create pull requests](https://github.com/microsoft/qlib/compare).\n- For other inquiries, you are welcome to contact us by email ([qlib@microsoft.com](mailto:qlib@microsoft.com)).\n  - We are recruiting new members (both FTEs and interns); your resumes are welcome!\n\nJoin IM discussion groups:\n|[Gitter](https://gitter.im/Microsoft/qlib)|\n|----|\n|![image](https://github.com/microsoft/qlib/blob/main/docs/_static/img/qrcode/gitter_qr.png)|\n\n# Contributing\nWe appreciate all contributions and thank all the contributors!\n<a href=\"https://github.com/microsoft/qlib/graphs/contributors\"><img src=\"https://contrib.rocks/image?repo=microsoft/qlib\" /></a>\n\nBefore we released Qlib as an open-source project on GitHub in Sep 2020, Qlib was an internal project in our group. Unfortunately, the internal commit history was not kept. Many members of our group have also contributed a lot to Qlib, including Ruihua Wang, Yinda Zhang, Haisu Yu, Shuyu Wang, Bochen Pang, and [Dong Zhou](https://github.com/evanzd/evanzd). 
Special thanks to [Dong Zhou](https://github.com/evanzd/evanzd) for his initial version of Qlib.\n\n## Guidance\n\nThis project welcomes contributions and suggestions.  \n**Here are some \n[code standards and development guidance](docs/developer/code_standard_and_dev_guide.rst) for submitting a pull request.**\n\nMaking contributions is not hard. Solving an issue (maybe just answering a question raised in the [issues list](https://github.com/microsoft/qlib/issues) or on [Gitter](https://gitter.im/Microsoft/qlib)), reporting or fixing a bug, improving the documentation, and even fixing a typo are all important contributions to Qlib.\n\nFor example, if you want to contribute to Qlib's documentation/code, you can follow the steps in the figure below.\n<p align=\"center\">\n  <img src=\"https://github.com/demon143/qlib/blob/main/docs/_static/img/change%20doc.gif\" />\n</p>\n\nIf you don't know how to start contributing, you can refer to the following examples.\n\n| Type | Examples |\n| -- | -- |\n| Solving issues | [Answer a question](https://github.com/microsoft/qlib/issues/749); [report](https://github.com/microsoft/qlib/issues/765) or [fix](https://github.com/microsoft/qlib/pull/792) a bug |\n| Docs | [Improve docs quality](https://github.com/microsoft/qlib/pull/797/files); [Fix a typo](https://github.com/microsoft/qlib/pull/774) |\n| Feature | Implement a [requested feature](https://github.com/microsoft/qlib/projects) like [this](https://github.com/microsoft/qlib/pull/754); [Refactor interfaces](https://github.com/microsoft/qlib/pull/539/files) |\n| Dataset | [Add a dataset](https://github.com/microsoft/qlib/pull/733) |\n| Models | [Implement a new model](https://github.com/microsoft/qlib/pull/689), [some instructions to contribute models](https://github.com/microsoft/qlib/tree/main/examples/benchmarks#contributing) |\n\n[Good first issues](https://github.com/microsoft/qlib/labels/good%20first%20issue) are labelled to indicate that they are good starting points for your 
contributions.\n\nYou can find some imperfect implementations in Qlib by running `rg 'TODO|FIXME' qlib`.\n\nIf you would like to become one of Qlib's maintainers to contribute more (e.g. helping merge PRs and triaging issues), please contact us by email ([qlib@microsoft.com](mailto:qlib@microsoft.com)). We are glad to help upgrade your permissions.\n\n## License\nMost contributions require you to agree to a\nContributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us\nthe right to use your contribution. For details, visit https://cla.opensource.microsoft.com.\n\nWhen you submit a pull request, a CLA bot will automatically determine whether you need to provide\na CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions\nprovided by the bot. You will only need to do this once across all repos using our CLA.\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).\nFor more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or\ncontact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.\n"
  },
  {
    "path": "SECURITY.md",
    "content": "<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->\n\n## Security\n\nMicrosoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).\n\nIf you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.\n\n## Reporting Security Issues\n\n**Please do not report security vulnerabilities through public GitHub issues.**\n\nInstead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).\n\nIf you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).  If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).\n\nYou should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). \n\nPlease include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:\n\n  * Type of issue (e.g. 
buffer overflow, SQL injection, cross-site scripting, etc.)\n  * Full paths of source file(s) related to the manifestation of the issue\n  * The location of the affected source code (tag/branch/commit or direct URL)\n  * Any special configuration required to reproduce the issue\n  * Step-by-step instructions to reproduce the issue\n  * Proof-of-concept or exploit code (if possible)\n  * Impact of the issue, including how an attacker might exploit the issue\n\nThis information will help us triage your report more quickly.\n\nIf you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.\n\n## Preferred Languages\n\nWe prefer all communications to be in English.\n\n## Policy\n\nMicrosoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).\n\n<!-- END MICROSOFT SECURITY.MD BLOCK -->"
  },
  {
    "path": "build_docker_image.sh",
    "content": "#!/bin/bash\n\ndocker_user=\"your_dockerhub_username\"\n\nread -p \"Do you want to build the nightly version of the qlib image? (default is stable) (yes/no): \" answer;\nanswer=$(echo \"$answer\" | tr '[:upper:]' '[:lower:]')\n\nif [ \"$answer\" = \"yes\" ]; then\n    # Build the nightly version of the qlib image\n    docker build --build-arg IS_STABLE=no -t qlib_image -f ./Dockerfile .\n    image_tag=\"nightly\"\nelse\n    # Build the stable version of the qlib image\n    docker build -t qlib_image -f ./Dockerfile .\n    image_tag=\"stable\"\nfi\n\nread -p \"Is it uploaded to docker hub? (default is no) (yes/no): \" answer;\nanswer=$(echo \"$answer\" | tr '[:upper:]' '[:lower:]')\n\nif [ \"$answer\" = \"yes\" ]; then\n    # Log in to Docker Hub\n    # If you are a new docker hub user, please verify your email address before proceeding with this step.\n    docker login\n    # Tag the Docker image\n    docker tag qlib_image \"$docker_user/qlib_image:$image_tag\"\n    # Push the Docker image to Docker Hub\n    docker push \"$docker_user/qlib_image:$image_tag\"\nelse\n    echo \"Not uploaded to docker hub.\"\nfi\n"
  },
  {
    "path": "docs/FAQ/FAQ.rst",
    "content": "\nQlib FAQ\n############\n\nQlib Frequently Asked Questions\n===============================\n.. contents::\n    :depth: 1\n    :local:\n    :backlinks: none\n\n------\n\n\n1. RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase...\n-----------------------------------------------------------------------------------------------------------------------------------\n\n.. code-block:: console\n\n    RuntimeError:\n            An attempt has been made to start a new process before the\n            current process has finished its bootstrapping phase.\n\n            This probably means that you are not using fork to start your\n            child processes and you have forgotten to use the proper idiom\n            in the main module:\n\n                if __name__ == '__main__':\n                    freeze_support()\n                    ...\n\n            The \"freeze_support()\" line can be omitted if the program\n            is not going to be frozen to produce an executable.\n\nThis is caused by the limitation of multiprocessing under windows OS. Please refer to `here <https://stackoverflow.com/a/24374798>`_ for more info.\n\n**Solution**: To select a start method you use the ``D.features`` in the if __name__ == '__main__' clause of the main module. For example:\n\n.. code-block:: python\n\n    import qlib\n    from qlib.data import D\n\n\n    if __name__ == \"__main__\":\n        qlib.init()\n        instruments = [\"SH600000\"]\n        fields = [\"$close\", \"$change\"]\n        df = D.features(instruments, fields, start_time='2010-01-01', end_time='2012-12-31')\n        print(df.head())\n\n\n\n2. qlib.data.cache.QlibCacheException: It sees the key(...) 
of the redis lock has existed in your redis db now.\n---------------------------------------------------------------------------------------------------------------\n\nThis means that the key of the redis lock already exists in your redis db. You can use the following commands to clear your redis keys and then rerun your commands:\n\n.. code-block:: console\n\n    $ redis-cli\n    > select 1\n    > flushdb\n\nIf the issue is not resolved, use ``keys *`` to find if multiple keys exist. If so, try using ``flushall`` to clear all the keys.\n\n.. note::\n\n    ``qlib.config.redis_task_db`` defaults to ``1``; users can choose another db with ``qlib.init(redis_task_db=<other_db>)``.\n\n\nAlso, feel free to post a new issue in our GitHub repository. We always check each issue carefully and try our best to solve them.\n\n3. ModuleNotFoundError: No module named 'qlib.data._libs.rolling'\n-----------------------------------------------------------------\n\n.. code-block:: python\n\n    #### Do not import the qlib package from the repository directory; otherwise qlib is imported from . without being compiled #####\n    Traceback (most recent call last):\n    File \"<stdin>\", line 1, in <module>\n    File \"qlib/qlib/__init__.py\", line 19, in init\n        from .data.cache import H\n    File \"qlib/qlib/data/__init__.py\", line 8, in <module>\n        from .data import (\n    File \"qlib/qlib/data/data.py\", line 20, in <module>\n        from .cache import H\n    File \"qlib/qlib/data/cache.py\", line 36, in <module>\n        from .ops import Operators\n    File \"qlib/qlib/data/ops.py\", line 19, in <module>\n        from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi\n    ModuleNotFoundError: No module named 'qlib.data._libs.rolling'\n\n- If the error occurs when importing the ``qlib`` package in the ``PyCharm`` IDE, users can execute the following command in the project root folder to compile the Cython files and build the extension modules in place:\n\n    .. 
code-block:: bash\n\n        python setup.py build_ext --inplace\n\n- If the error occurs when importing the ``qlib`` package with the ``python`` command, users need to change the working directory so that the script does not run in the project directory.\n\n\n4. BadNamespaceError: / is not a connected namespace\n----------------------------------------------------\n\n.. code-block:: python\n\n      File \"qlib_online.py\", line 35, in <module>\n        cal = D.calendar()\n      File \"e:\\code\\python\\microsoft\\qlib_latest\\qlib\\qlib\\data\\data.py\", line 973, in calendar\n        return Cal.calendar(start_time, end_time, freq, future=future)\n      File \"e:\\code\\python\\microsoft\\qlib_latest\\qlib\\qlib\\data\\data.py\", line 798, in calendar\n        self.conn.send_request(\n      File \"e:\\code\\python\\microsoft\\qlib_latest\\qlib\\qlib\\data\\client.py\", line 101, in send_request\n        self.sio.emit(request_type + \"_request\", request_content)\n      File \"G:\\apps\\miniconda\\envs\\qlib\\lib\\site-packages\\python_socketio-5.3.0-py3.8.egg\\socketio\\client.py\", line 369, in emit\n        raise exceptions.BadNamespaceError(\n      BadNamespaceError: / is not a connected namespace.\n\n- The ``python-socketio`` version used by qlib needs to match the ``python-socketio`` version used by qlib-server:\n\n    .. code-block:: bash\n\n        pip install -U python-socketio==<qlib-server python-socketio version>\n\n\n5. TypeError: send() got an unexpected keyword argument 'binary'\n----------------------------------------------------------------\n\n.. 
code-block:: python\n\n      File \"qlib_online.py\", line 35, in <module>\n        cal = D.calendar()\n      File \"e:\\code\\python\\microsoft\\qlib_latest\\qlib\\qlib\\data\\data.py\", line 973, in calendar\n        return Cal.calendar(start_time, end_time, freq, future=future)\n      File \"e:\\code\\python\\microsoft\\qlib_latest\\qlib\\qlib\\data\\data.py\", line 798, in calendar\n        self.conn.send_request(\n      File \"e:\\code\\python\\microsoft\\qlib_latest\\qlib\\qlib\\data\\client.py\", line 101, in send_request\n        self.sio.emit(request_type + \"_request\", request_content)\n      File \"G:\\apps\\miniconda\\envs\\qlib\\lib\\site-packages\\socketio\\client.py\", line 263, in emit\n        self._send_packet(packet.Packet(packet.EVENT, namespace=namespace,\n      File \"G:\\apps\\miniconda\\envs\\qlib\\lib\\site-packages\\socketio\\client.py\", line 339, in _send_packet\n        self.eio.send(ep, binary=binary)\n      TypeError: send() got an unexpected keyword argument 'binary'\n\n\n- The ``python-engineio`` version needs to be compatible with the ``python-socketio`` version (see https://github.com/miguelgrinberg/python-socketio#version-compatibility):\n\n    .. code-block:: bash\n\n        pip install -U python-engineio==<python-engineio version compatible with your python-socketio>\n        # or\n        pip install -U python-socketio==3.1.2 python-engineio==3.13.2\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = python3 -msphinx\nSPHINXPROJ    = Quantlab\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\tpip install -r requirements.txt\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/_static/demo.sh",
    "content": "#!/bin/sh\ngit clone https://github.com/microsoft/qlib.git\ncd qlib\nls\npip install pyqlib\n# or\n# pip install numpy\n# pip install --upgrade cython\n# python setup.py install\ncd examples\nls\nqrun benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
  },
  {
    "path": "docs/advanced/PIT.rst",
    "content": ".. _pit:\n\n============================\n(P)oint-(I)n-(T)ime Database\n============================\n.. currentmodule:: qlib\n\n\nIntroduction\n------------\nPoint-in-time data is a very important consideration when performing any sort of historical market analysis.\n\nFor example, let’s say we are backtesting a trading strategy and we are using the past five years of historical data as our input.\nOur model is assumed to trade once a day, at the market close, and we’ll say we are calculating the trading signal for 1 January 2020 in our backtest. At that point, we should only have data for 1 January 2020, 31 December 2019, 30 December 2019 etc.\n\nIn financial data (especially financial reports), the same piece of data may be amended for multiple times overtime.  If we only use the latest version for historical backtesting, data leakage will happen.\nPoint-in-time database is designed for solving this problem to make sure user get the right version of data at any historical timestamp. It will keep the performance of online trading and historical backtesting the same.\n\n\n\nData Preparation\n----------------\n\nQlib provides a crawler to help users to download financial data and then a converter to dump the data in Qlib format.\nPlease follow `scripts/data_collector/pit/README.md <https://github.com/microsoft/qlib/tree/main/scripts/data_collector/pit/>`_ to download and convert data.\nBesides, you can find some additional usage examples there.\n\n\nFile-based design for PIT data\n------------------------------\n\nQlib provides a file-based storage for PIT data.\n\nFor each feature, it contains 4 columns, i.e. date, period, value, _next.\nEach row corresponds to a statement.\n\nThe meaning of each feature with filename like `XXX_a.data`:\n\n- `date`: the statement's date of publication.\n- `period`: the period of the statement. (e.g. 
it is quarterly in most markets)\n    - If it is an annual period, it is an integer corresponding to the year.\n    - If it is a quarterly period, it is an integer like `<year><index of quarter>`. The last two decimal digits represent the quarter index; the others represent the year.\n- `value`: the described value\n- `_next`: the byte index of the next record for the same period (4294967295, the maximum `uint32` value, means there is no such record).\n\nBesides the feature data, an index file `XXX_a.index` is included to speed up queries.\n\nThe statements are sorted by `date` in ascending order from the beginning of the file.\n\n.. code-block:: python\n\n    # the data format from XXXX.data\n    array([(20070428, 200701, 0.090219  , 4294967295),\n           (20070817, 200702, 0.13933   , 4294967295),\n           (20071023, 200703, 0.24586301, 4294967295),\n           (20080301, 200704, 0.3479    ,         80),\n           (20080313, 200704, 0.395989  , 4294967295),\n           (20080422, 200801, 0.100724  , 4294967295),\n           (20080828, 200802, 0.24996801, 4294967295),\n           (20081027, 200803, 0.33412001, 4294967295),\n           (20090325, 200804, 0.39011699, 4294967295),\n           (20090421, 200901, 0.102675  , 4294967295),\n           (20090807, 200902, 0.230712  , 4294967295),\n           (20091024, 200903, 0.30072999, 4294967295),\n           (20100402, 200904, 0.33546099, 4294967295),\n           (20100426, 201001, 0.083825  , 4294967295),\n           (20100812, 201002, 0.200545  , 4294967295),\n           (20101029, 201003, 0.260986  , 4294967295),\n           (20110321, 201004, 0.30739301, 4294967295),\n           (20110423, 201101, 0.097411  , 4294967295),\n           (20110831, 201102, 0.24825101, 4294967295),\n           (20111018, 201103, 0.318919  , 4294967295),\n           (20120323, 201104, 0.4039    ,        420),\n           (20120411, 201104, 0.403925  , 4294967295),\n           (20120426, 201201, 0.112148  , 4294967295),\n           (20120810, 201202, 
0.26484701, 4294967295),\n           (20121026, 201203, 0.370487  , 4294967295),\n           (20130329, 201204, 0.45004699, 4294967295),\n           (20130418, 201301, 0.099958  , 4294967295),\n           (20130831, 201302, 0.21044201, 4294967295),\n           (20131016, 201303, 0.30454299, 4294967295),\n           (20140325, 201304, 0.394328  , 4294967295),\n           (20140425, 201401, 0.083217  , 4294967295),\n           (20140829, 201402, 0.16450299, 4294967295),\n           (20141030, 201403, 0.23408499, 4294967295),\n           (20150421, 201404, 0.319612  , 4294967295),\n           (20150421, 201501, 0.078494  , 4294967295),\n           (20150828, 201502, 0.137504  , 4294967295),\n           (20151023, 201503, 0.201709  , 4294967295),\n           (20160324, 201504, 0.26420501, 4294967295),\n           (20160421, 201601, 0.073664  , 4294967295),\n           (20160827, 201602, 0.136576  , 4294967295),\n           (20161029, 201603, 0.188062  , 4294967295),\n           (20170415, 201604, 0.244385  , 4294967295),\n           (20170425, 201701, 0.080614  , 4294967295),\n           (20170728, 201702, 0.15151   , 4294967295),\n           (20171026, 201703, 0.25416601, 4294967295),\n           (20180328, 201704, 0.32954201, 4294967295),\n           (20180428, 201801, 0.088887  , 4294967295),\n           (20180802, 201802, 0.170563  , 4294967295),\n           (20181029, 201803, 0.25522   , 4294967295),\n           (20190329, 201804, 0.34464401, 4294967295),\n           (20190425, 201901, 0.094737  , 4294967295),\n           (20190713, 201902, 0.        ,       1040),\n           (20190718, 201902, 0.175322  , 4294967295),\n           (20191016, 201903, 0.25581899, 4294967295)],\n          dtype=[('date', '<u4'), ('period', '<u4'), ('value', '<f8'), ('_next', '<u4')])\n    # - each row contains 20 bytes (4 + 4 + 8 + 4)\n\n\n    # The data format from XXXX.index. It consists of two parts:\n    # 1) the start index of the data. 
So the first part of the info will be like\n    2007\n    # 2) the remaining index data looks like the information below\n    #    - The data indicate the **byte index** of the first record of each period (4294967295 if the period has no record yet).\n    #    - e.g. the records at both byte 60 and byte 80 correspond to 200704; the byte index of the first occurrence (i.e. 60) is recorded in the data.\n    array([         0,         20,         40,         60,        100,\n                  120,        140,        160,        180,        200,\n                  220,        240,        260,        280,        300,\n                  320,        340,        360,        380,        400,\n                  440,        460,        480,        500,        520,\n                  540,        560,        580,        600,        620,\n                  640,        660,        680,        700,        720,\n                  740,        760,        780,        800,        820,\n                  840,        860,        880,        900,        920,\n                  940,        960,        980,       1000,       1020,\n                 1060, 4294967295], dtype=uint32)\n\n\n\n\nKnown limitations:\n\n- Currently, the PIT database is designed for quarterly or annual factors, which can handle fundamental data of financial reports in most markets.\n- Qlib leverages the file name to identify the type of the data. A file with a name like `XXX_q.data` corresponds to quarterly data. A file with a name like `XXX_a.data` corresponds to annual data.\n- The calculation of PIT is not performed in the optimal way. There is great potential to boost the performance of PIT data calculation.\n"
  },
  {
    "path": "docs/advanced/alpha.rst",
    "content": ".. _alpha:\n\n=========================\nBuilding Formulaic Alphas\n=========================\n.. currentmodule:: qlib\n\nIntroduction\n============\n\nIn quantitative trading practice, designing novel factors that can explain and predict future asset returns are of vital importance to the profitability of a strategy. Such factors are usually called alpha factors, or alphas in short.\n\n\nA formulaic alpha, as the name suggests, is a kind of alpha that can be presented as a formula or a mathematical expression.\n\n\nBuilding Formulaic Alphas in ``Qlib``\n=====================================\n\nIn ``Qlib``, users can easily build formulaic alphas.\n\nExample\n-------\n\n`MACD`, short for moving average convergence/divergence, is a formulaic alpha used in technical analysis of stock prices. It is designed to reveal changes in the strength, direction, momentum, and duration of a trend in a stock's price.\n\n`MACD` can be presented as the following formula:\n\n.. math::\n\n    MACD = 2\\times (DIF-DEA)\n\n.. note::\n\n    `DIF` means Differential value, which is 12-period EMA minus 26-period EMA.\n\n    .. math::\n\n        DIF = \\frac{EMA(CLOSE, 12) - EMA(CLOSE, 26)}{CLOSE}\n\n    `DEA` means a 9-period EMA of the DIF.\n\n    .. math::\n\n        DEA = EMA(DIF, 9)\n\nUsers can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:\n\n.. note:: Users need to initialize ``Qlib`` with `qlib.init` first.  Please refer to `initialization <../start/initialization.html>`_.\n\n.. code-block:: python\n\n    >> from qlib.data.dataset.loader import QlibDataLoader\n    >> MACD_EXP = '2 * ((EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9))'\n    >> fields = [MACD_EXP] # MACD\n    >> names = ['MACD']\n    >> labels = ['Ref($close, -2)/Ref($close, -1) - 1'] # label\n    >> label_names = ['LABEL']\n    >> data_loader_config = {\n    ..     \"feature\": (fields, names),\n    ..     
\"label\": (labels, label_names)\n    .. }\n    >> data_loader = QlibDataLoader(config=data_loader_config)\n    >> df = data_loader.load(instruments='csi300', start_time='2010-01-01', end_time='2017-12-31')\n    >> print(df)\n                            feature     label\n                               MACD     LABEL\n    datetime   instrument\n    2010-01-04 SH600000    0.008781 -0.019672\n               SH600004    0.006699 -0.014721\n               SH600006    0.005714  0.002911\n               SH600008    0.000798  0.009818\n               SH600009    0.017015 -0.017758\n    ...                         ...       ...\n    2017-12-29 SZ300124    0.015071 -0.005074\n               SZ300136   -0.015466  0.056352\n               SZ300144    0.013082  0.011853\n               SZ300251   -0.001026  0.021739\n               SZ300315   -0.007559  0.012455\n\nReference\n=========\n\nTo learn more about ``Data Loader``, please refer to `Data Loader <../component/data.html#data-loader>`_\n\nTo learn more about ``Data API``, please refer to `Data API <../component/data.html>`_\n"
  },
  {
    "path": "docs/advanced/serial.rst",
    "content": ".. _serial:\n\n=============\nSerialization\n=============\n.. currentmodule:: qlib\n\nIntroduction\n============\n``Qlib`` supports dumping the state of ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc. into a disk and reloading them.\n\nSerializable Class\n==================\n\n``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.\nWhen users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.\nHowever, users can use ``config`` method or override ``default_dump_all`` attribute to prevent this feature.\n\nUsers can also override ``pickle_backend`` attribute to choose a pickle backend. The supported value is \"pickle\" (default and common) and \"dill\" (dump more things such as function, more information in `here <https://pypi.org/project/dill/>`_).\n\nExample\n=======\n``Qlib``'s serializable class includes  ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc., which are subclass of  ``qlib.utils.serial.Serializable``.\nSpecifically, ``qlib.data.dataset.DatasetH`` is one of them. Users can serialize ``DatasetH`` as follows.\n\n.. code-block:: Python\n\n    ##=============dump dataset=============\n    dataset.to_pickle(path=\"dataset.pkl\") # dataset is an instance of qlib.data.dataset.DatasetH\n\n    ##=============reload dataset=============\n    with open(\"dataset.pkl\", \"rb\") as file_dataset:\n        dataset = pickle.load(file_dataset)\n\n.. note::\n    Only state of ``DatasetH`` should be saved on the disk, such as some `mean` and `variance` used for data normalization, etc.\n\n    After reloading the ``DatasetH``, users need to reinitialize it. 
This means that users can reset some states of ``DatasetH`` or ``QlibDataHandler``, such as `instruments`, `start_time`, `end_time` and `segments`, and generate new data according to these states (data is not state and should not be saved to disk).\n\nA more detailed example can be found at this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.\n\n\nAPI\n===\nPlease refer to `Serializable API <../reference/api.html#module-qlib.utils.serial.Serializable>`_.\n"
  },
  {
    "path": "docs/advanced/server.rst",
    "content": ".. _server:\n\n=============================\n``Online`` & ``Offline`` mode\n=============================\n.. currentmodule:: qlib\n\n\nIntroduction\n============\n\n``Qlib`` supports ``Online`` mode and ``Offline`` mode. Only the ``Offline`` mode is introduced in this document.\n\nThe ``Online`` mode is designed to solve the following problems:\n\n- Manage the data in a centralized way. Users don't have to manage data of different versions.\n- Reduce the amount of cache to be generated.\n- Make the data can be accessed in a remote way.\n\nQlib-Server\n===========\n\n``Qlib-Server`` is the assorted server system for ``Qlib``, which utilizes ``Qlib`` for basic calculations and provides extensive server system and cache mechanism. With QLibServer, the data provided for ``Qlib`` can be managed in a centralized manner. With ``Qlib-Server``, users can use ``Qlib`` in ``Online`` mode.\n\n\n\nReference\n=========\nIf users are interested in ``Qlib-Server`` and ``Online`` mode, please refer to `Qlib-Server Project <https://github.com/microsoft/qlib-server>`_ and `Qlib-Server Document <https://qlib-server.readthedocs.io/en/latest/>`_.\n"
  },
  {
    "path": "docs/advanced/task_management.rst",
    "content": ".. _task_management:\n\n===============\nTask Management\n===============\n.. currentmodule:: qlib\n\n\nIntroduction\n============\n\nThe `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.\nTo automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_. \nWith this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.The processes of task generation, model training and combine and collect data are shown in the following figure.\n\n.. image:: ../_static/img/Task-Gen-Recorder-Collector.svg\n    :align: center\n\nThis whole process can be used in `Online Serving <../component/online.html>`_.\n\nAn example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`__.\n\nTask Generating\n===============\nA ``task`` consists of `Model`, `Dataset`, `Record`, or anything added by users. \nThe specific task template can be viewed in \n`Task Section <../component/workflow.html#task-section>`_.\nEven though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template.\n\nHere is the base class of ``TaskGen``:\n\n.. autoclass:: qlib.workflow.task.gen.TaskGen\n    :members:\n    :noindex:\n\n``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.\nThis class allows users to verify the effect of data from different periods on the model in one experiment. 
More information can be found `here <../reference/api.html#TaskGen>`__.\n\nTask Storing\n============\nTo achieve higher efficiency and enable cluster operation, ``Task Manager`` stores all tasks in `MongoDB <https://www.mongodb.com/>`_.\n``TaskManager`` can fetch unfinished tasks automatically and manage the lifecycle of a set of tasks with error handling.\nUsers **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ before using this module.\n\nTo use ``TaskManager``, users need to provide the MongoDB URL and database name, either during `initialization <../start/initialization.html#Parameters>`_ or with a statement like this:\n\n    .. code-block:: python\n\n        from qlib.config import C\n        C[\"mongo\"] = {\n            \"task_url\": \"mongodb://localhost:27017/\",  # your MongoDB URL\n            \"task_db_name\": \"rolling_db\",  # database name\n        }\n\n.. autoclass:: qlib.workflow.task.manage.TaskManager\n    :members:\n    :noindex:\n\nMore information about ``Task Manager`` can be found `here <../reference/api.html#TaskManager>`__.\n\nTask Training\n=============\nAfter generating and storing the tasks, it's time to run the tasks that are in the *WAITING* status.\n``Qlib`` provides a method called ``run_task`` to run the tasks in the task pool; however, users can also customize how tasks are executed.\nAn easy way to get the ``task_func`` is to use ``qlib.model.trainer.task_train`` directly.\nIt will run the whole workflow defined by ``task``, which includes *Model*, *Dataset* and *Record*.\n\n.. autofunction:: qlib.workflow.task.manage.run_task\n    :noindex:\n\nMeanwhile, ``Qlib`` provides a module called ``Trainer``.\n\n.. autoclass:: qlib.model.trainer.Trainer\n    :members:\n    :noindex:\n\n``Trainer`` will train a list of tasks and return a list of model recorders.\n``Qlib`` offers two kinds of ``Trainer``: ``TrainerR`` is the simplest one, and ``TrainerRM`` is based on ``TaskManager`` and manages the task lifecycle automatically.
\nIf you do not want to use ``Task Manager`` to manage tasks, using ``TrainerR`` to train a list of tasks generated by ``TaskGen`` is enough.\n`Here <../reference/api.html#Trainer>`_ are the details about the different ``Trainer`` classes.\n\nTask Collecting\n===============\nBefore collecting model training results, you need to use ``qlib.init`` to specify the path of `mlruns`.\n\nTo collect the results of tasks after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.\n\n`Collector <../reference/api.html#Collector>`_ can collect objects from anywhere and process them, e.g., by merging, grouping or averaging. It works in two steps: ``collect`` (collect anything into a dict) and ``process_collect`` (process the collected dict).\n\n`Group <../reference/api.html#Group>`_ also works in two steps: ``group`` (group a set of objects based on `group_func` and turn them into a dict) and ``reduce`` (turn a dict into an ensemble based on some rule).\nFor example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}\n\n`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.\nFor example: {C1: object, C2: object} ---``Ensemble``---> object.\nYou can set the ensembles you want in the ``Collector``'s ``process_list``.\nCommon ensembles include ``AverageEnsemble`` and ``RollingEnsemble``. ``AverageEnsemble`` is used to ensemble the results of different models over the same time period. ``RollingEnsemble`` is used to merge the results of rolling models trained on different time periods into one result.\n\nSo the hierarchy is as follows: ``Collector``'s second step corresponds to ``Group``. 
And ``Group``'s second step corresponds to ``Ensemble``.\n\nFor more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.\n"
  },
  {
    "path": "docs/changelog/changelog.rst",
    "content": ".. include:: ../../CHANGES.rst\n"
  },
  {
    "path": "docs/component/data.rst",
    "content": ".. _data:\n\n==================================\nData Layer: Data Framework & Usage\n==================================\n\nIntroduction\n============\n\n``Data Layer`` provides user-friendly APIs to manage and retrieve data. It provides high-performance data infrastructure.\n\nIt is designed for quantitative investment. For example, users could build formulaic alphas with ``Data Layer`` easily. Please refer to `Building Formulaic Alphas <../advanced/alpha.html>`_ for more details.\n\nThe introduction of ``Data Layer`` includes the following parts.\n\n- Data Preparation\n- Data API\n- Data Loader\n- Data Handler\n- Dataset\n- Cache\n- Data and Cache File Structure\n\nHere is a typical example of Qlib data workflow\n\n- Users download data and converting data into Qlib format(with filename suffix `.bin`).  In this step, typically only some basic data are stored on disk(such as OHLCV).\n- Creating some basic features based on Qlib's expression Engine(e.g. \"Ref($close, 60) / $close\", the return of last 60 trading days). Supported operators in the expression engine can be found `here <https://github.com/microsoft/qlib/blob/main/qlib/data/ops.py>`__. This step is typically implemented in Qlib's `Data Loader <https://qlib.readthedocs.io/en/latest/component/data.html#data-loader>`_ which is a component of `Data Handler <https://qlib.readthedocs.io/en/latest/component/data.html#data-handler>`_ .\n- If users require more complicated data processing (e.g. data normalization),  `Data Handler <https://qlib.readthedocs.io/en/latest/component/data.html#data-handler>`_ support user-customized processors to process data(some predefined processors can be found `here <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`__).  The processors are different from operators in expression engine. 
They are designed for complicated data processing methods that are hard to implement with operators in the expression engine.\n- Finally, `Dataset <https://qlib.readthedocs.io/en/latest/component/data.html#dataset>`_ is responsible for preparing model-specific datasets from the data processed by Data Handler.\n\nData Preparation\n================\n\nQlib Format Data\n----------------\n\nWe've specially designed a data structure to manage financial data; please refer to the `File storage design section in Qlib paper <https://arxiv.org/abs/2009.11189>`_ for detailed information.\nSuch data are stored with the filename suffix `.bin` (we'll call them `.bin` files, `.bin` format, or qlib format). The `.bin` file is designed for scientific computing on financial data.\n\n``Qlib`` provides two different off-the-shelf datasets, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`__:\n\n========================  =================  ================\nDataset                   US Market          China Market\n========================  =================  ================\nAlpha360                  √                  √\n\nAlpha158                  √                  √\n========================  =================  ================\n\nAlso, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency dataset example through this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`__.\n\nQlib Format Dataset\n-------------------\n``Qlib`` provides an off-the-shelf dataset in `.bin` format; users can use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows. Users can also use numpy to load `.bin` files to validate the data.\nThe price and volume data look different from the actual trading prices because they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). 
You may also find that adjusted prices differ across data sources. This is because different data sources may adjust prices in different ways. Qlib normalizes the price on the first trading day of each stock to 1 when adjusting prices.\nUsers can leverage `$factor` to get the original trading price (e.g. `$close / $factor` to get the original close price).\n\nHere is a discussion about the price adjustment of Qlib:\n\n- https://github.com/microsoft/qlib/issues/991#issuecomment-1075252402\n\n\n.. code-block:: bash\n\n    # download 1d\n    python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n\n    # download 1min\n    python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min\n\nIn addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command:\n\n.. code-block:: bash\n\n    python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us\n\nAfter running the above commands, users can find China-Stock and US-Stock data in ``Qlib`` format in the ``~/.qlib/qlib_data/cn_data`` and ``~/.qlib/qlib_data/us_data`` directories, respectively.\n\n``Qlib`` also provides scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.\n\nOnce ``Qlib`` is initialized with this dataset, users can build and evaluate their own models with it. 
Please refer to `Initialization <../start/initialization.html>`_ for more details.\n\nAutomatic update of daily frequency data\n----------------------------------------\n\n  **It is recommended that users update the data manually once (\\-\\-trading_date 2021-05-25) and then set it to update automatically.**\n\n  For more information, refer to: `yahoo collector <https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#Automatic-update-of-daily-frequency-data>`_\n\n  - Automatic update of data to the \"qlib\" directory each trading day (Linux)\n      - use *crontab*: `crontab -e`\n      - set up timed tasks:\n\n        .. code-block:: bash\n\n            * * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>\n\n        - **script path**: *scripts/data_collector/yahoo/collector.py*\n\n  - Manual update of data\n\n      .. code-block:: bash\n\n        python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>\n\n      - *trading_date*: start trading day\n      - *end_date*: end trading day (not included)\n\n\n\nConverting CSV and Parquet Format into Qlib Format\n--------------------------------------------------\n\n``Qlib`` provides the script ``scripts/dump_bin.py`` to convert **any** data in CSV or Parquet format into `.bin` files (``Qlib`` format) as long as they are in the correct format.\n\nBesides downloading the prepared demo data, users can download demo data directly from the Collector as follows, for reference on the CSV format.\nHere are some examples:\n\nfor daily data:\n  .. 
code-block:: bash\n\n    python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10\n\nUsers can also provide their own data in CSV or Parquet format. However, the data **must satisfy** the following criteria:\n\n- The CSV or Parquet file is named after a specific stock *or* includes a column with the stock name\n\n    - Name the CSV or Parquet file after a stock: `SH600000.csv`, `AAPL.csv` or `SH600000.parquet`, `AAPL.parquet` (not case sensitive).\n\n    - The CSV or Parquet file includes a column with the stock name. Users **must** specify the column name when dumping the data. Here is an example:\n\n        .. code-block:: bash\n\n            python scripts/dump_bin.py dump_all ... --symbol_field_name symbol --file_suffix <.csv or .parquet>\n\n        where the data are in the following format:\n\n            +-----------+-------+\n            | symbol    | close |\n            +===========+=======+\n            | SH600000  | 120   |\n            +-----------+-------+\n\n- The CSV or Parquet file **must** include a date column, and users must specify the date column name when dumping the data. Here is an example:\n\n    .. code-block:: bash\n\n        python scripts/dump_bin.py dump_all ... 
--date_field_name date --file_suffix <.csv or .parquet>\n\n    where the data are in the following format:\n\n        +---------+------------+-------+------+----------+\n        | symbol  | date       | close | open | volume   |\n        +=========+============+=======+======+==========+\n        | SH600000| 2020-11-01 | 120   | 121  | 12300000 |\n        +---------+------------+-------+------+----------+\n        | SH600000| 2020-11-02 | 123   | 120  | 12300000 |\n        +---------+------------+-------+------+----------+\n\n\nSuppose that users prepare their CSV or Parquet format data in the directory ``~/.qlib/my_data``; they can run the following command to start the conversion.\n\n.. code-block:: bash\n\n    python scripts/dump_bin.py dump_all --data_path ~/.qlib/my_data --qlib_dir ~/.qlib/qlib_data/ --include_fields open,close,high,low,volume,factor --file_suffix <.csv or .parquet>\n\nFor other parameters supported when dumping the data into `.bin` files, users can run the following command:\n\n.. code-block:: bash\n\n    python scripts/dump_bin.py dump_all --help\n\nAfter conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/`.\n\n.. note::\n\n    The arguments of `--include_fields` should correspond to the column names of the CSV or Parquet files. The column names of the dataset provided by ``Qlib`` include at least open, close, high, low, volume and factor.\n\n    - `open`\n        The adjusted opening price\n    - `close`\n        The adjusted closing price\n    - `high`\n        The adjusted highest price\n    - `low`\n        The adjusted lowest price\n    - `volume`\n        The adjusted trading volume\n    - `factor`\n        The restoration factor. 
Normally, ``factor = adjusted_price / original_price``; for `adjusted price`, see `split adjusted <https://www.investopedia.com/terms/s/splitadjusted.asp>`_.\n\n    By the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.\n    If you want to use your own alpha factors that can't be calculated from OHLCV, such as PE and EPS, you can add them to the CSV or Parquet files together with OHLCV and then dump them into Qlib format data.\n\nChecking the health of the data\n-------------------------------\n\n``Qlib`` provides a script to check the health of the data.\n\n- The main points to check are as follows:\n\n    - Check if any data is missing in the DataFrame.\n\n    - Check if there are any large step changes above the threshold in the OHLCV columns.\n\n    - Check if any of the required columns (OHLCV) are missing in the DataFrame.\n\n    - Check if the 'factor' column is missing in the DataFrame.\n\n- You can run the following commands to check whether the data is healthy or not.\n\n    for daily data:\n        .. code-block:: bash\n\n            python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data\n\n    for 1min data:\n        .. code-block:: bash\n\n            python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data_1min --freq 1min\n\n- You can also add some parameters to adjust the checks.\n\n    - The available parameters are:\n\n        - freq: Frequency of data.\n\n        - large_step_threshold_price: Maximum permitted price change.\n\n        - large_step_threshold_volume: Maximum permitted volume change.\n\n        - missing_data_num: Maximum number of null values allowed.\n\n- For example, you can run the following commands to check the data with custom thresholds.\n\n    for daily data:\n        .. 
code-block:: bash\n\n            python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data --missing_data_num 30055 --large_step_threshold_volume 94485 --large_step_threshold_price 20\n\n    for 1min data:\n        .. code-block:: bash\n\n            python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data --freq 1min --missing_data_num 35806 --large_step_threshold_volume 3205452000000 --large_step_threshold_price 0.91\n\nStock Pool (Market)\n-------------------\n\n``Qlib`` defines a `stock pool <https://github.com/microsoft/qlib/blob/main/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml#L4>`_ as a list of stocks and their date ranges. Predefined stock pools (e.g. csi300) may be imported as follows.\n\n.. code-block:: bash\n\n    python collector.py --index_name CSI300 --qlib_dir <user qlib data dir> --method parse_instruments\n\n\nMultiple Stock Modes\n--------------------\n\n``Qlib`` now provides two different stock modes for users: China-Stock Mode & US-Stock Mode. Here are the settings that differ between these two modes:\n\n==============  =================  ================\nRegion          Trade Unit         Limit Threshold\n==============  =================  ================\nChina           100                0.099\n\nUS              1                  None\n==============  =================  ================\n\nThe `trade unit` defines the number of shares in one trading unit (the amount of a trade must be a multiple of it), and the `limit threshold` defines the bound on the percentage rise or fall of a stock's price.\n\n- If users use ``Qlib`` in China-Stock mode, China-Stock data is required. Users can use ``Qlib`` in China-Stock mode according to the following steps:\n    - Download China-Stock data in Qlib format; please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.\n    - Initialize ``Qlib`` in China-Stock mode\n        Suppose that users download their Qlib format data to the directory ``~/.qlib/qlib_data/cn_data``. 
Users only need to initialize ``Qlib`` as follows.\n\n        .. code-block:: python\n\n            import qlib\n            from qlib.constant import REG_CN\n            qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=REG_CN)\n\n\n- If users use ``Qlib`` in US-Stock mode, US-Stock data is required. ``Qlib`` also provides a script to download US-Stock data. Users can use ``Qlib`` in US-Stock mode according to the following steps:\n    - Download US-Stock data in Qlib format; please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.\n    - Initialize ``Qlib`` in US-Stock mode\n        Suppose that users prepare their Qlib format data in the directory ``~/.qlib/qlib_data/us_data``. Users only need to initialize ``Qlib`` as follows.\n\n        .. code-block:: python\n\n            import qlib\n            from qlib.constant import REG_US\n            qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US)\n\n\n.. note::\n\n    PRs for new data sources are highly welcome! Users can submit their data-crawling code as a PR like `the examples here <https://github.com/microsoft/qlib/tree/main/scripts>`_. We will then use the code to create a data cache on our server, which other users can use directly.\n\n\nData API\n========\n\nData Retrieval\n--------------\nUsers can use APIs in ``qlib.data`` to retrieve data; please refer to `Data Retrieval <../start/getdata.html>`_.\n\nFeature\n-------\n\n``Qlib`` provides `Feature` and `ExpressionOps` to fetch the features according to users' needs.\n\n- `Feature`\n    Load data from the data provider. 
Users can get features such as `$high`, `$low`, `$open`, `$close`, etc., which should correspond to the arguments of `--include_fields`; please refer to section `Converting CSV and Parquet Format into Qlib Format <#converting-csv-and-parquet-format-into-qlib-format>`_.\n\n- `ExpressionOps`\n    `ExpressionOps` uses operators to construct features.\n    To know more about ``Operator``, please refer to `Operator API <../reference/api.html#module-qlib.data.ops>`_.\n    ``Qlib`` also allows users to define their own custom ``Operator``; an example is given in ``tests/test_register_ops.py``.\n\nTo know more about ``Feature``, please refer to `Feature API <../reference/api.html#module-qlib.data.base>`_.\n\nFilter\n------\n``Qlib`` provides `NameDFilter` and `ExpressionDFilter` to filter the instruments according to users' needs.\n\n- `NameDFilter`\n    Name dynamic instrument filter. Filters the instruments based on a regular name format. A regular expression for the name rule is required.\n\n- `ExpressionDFilter`\n    Expression dynamic instrument filter. Filters the instruments based on a certain expression. An expression rule indicating a certain feature field is required.\n\n    - `basic features filter`: rule_expression = '$close/$open>5'\n    - `cross-sectional features filter`: rule_expression = '$rank($close)<10'\n    - `time-sequence features filter`: rule_expression = 'Ref($close, 3)>100'\n\nHere is a simple example showing how to use a filter in a basic ``Qlib`` workflow configuration file:\n\n.. 
code-block:: yaml\n\n    market: &market csi300\n\n    filter: &filter\n        filter_type: ExpressionDFilter\n        rule_expression: \"Ref($close, -2) / Ref($close, -1) > 1\"\n        filter_start_time: 2010-01-01\n        filter_end_time: 2010-01-07\n        keep: False\n\n    data_handler_config: &data_handler_config\n        start_time: 2010-01-01\n        end_time: 2021-01-22\n        fit_start_time: 2010-01-01\n        fit_end_time: 2015-12-31\n        instruments: *market\n        filter_pipe: [*filter]\n\nTo know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.\n\nReference\n---------\n\nTo know more about ``Data API``, please refer to `Data API <../reference/api.html#data>`_.\n\n\nData Loader\n===========\n\n``Data Loader`` in ``Qlib`` is designed to load raw data from the original data source. It is loaded and used in the ``Data Handler`` module.\n\nQlibDataLoader\n--------------\n\nThe ``QlibDataLoader`` class in ``Qlib`` is an interface that allows users to load raw data from the ``Qlib`` data source.\n\nStaticDataLoader\n----------------\n\nThe ``StaticDataLoader`` class in ``Qlib`` is an interface that allows users to load raw data from a file or from data provided directly.\n\n\nInterface\n---------\n\nHere are some interfaces of the ``DataLoader`` base class, from which ``QlibDataLoader`` inherits:\n\n.. 
autoclass:: qlib.data.dataset.loader.DataLoader\n    :members:\n    :noindex:\n\nAPI\n---\n\nTo know more about ``Data Loader``, please refer to `Data Loader API <../reference/api.html#module-qlib.data.dataset.loader>`_.\n\n\nData Handler\n============\n\nThe ``Data Handler`` module in ``Qlib`` is designed to handle common data processing methods that are used by most models.\n\nUsers can use ``Data Handler`` in an automatic workflow with ``qrun``; refer to `Workflow: Workflow Management <workflow.html>`_ for more details.\n\nDataHandlerLP\n-------------\n\nIn addition to using ``Data Handler`` in an automatic workflow with ``qrun``, users can use ``Data Handler`` as an independent module to easily preprocess data (standardization, NaN removal, etc.) and build datasets.\n\nTo achieve this, ``Qlib`` provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_. The core idea of this class is that some learnable ``Processors`` learn the parameters of data processing (e.g., the parameters for z-score normalization). When new data comes in, these `trained` ``Processors`` can process the new data directly, which makes it possible to process real-time data efficiently. More information about ``Processors`` is given in the next subsection.\n\n\nInterface\n---------\n\nHere are some important interfaces that ``DataHandlerLP`` provides:\n\n.. 
autoclass:: qlib.data.dataset.handler.DataHandlerLP\n    :members: __init__, fetch, get_cols\n    :noindex:\n\n\nIf users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``.\n\nUsers can also pass ``qlib.contrib.data.processor.ConfigSectionProcessor``, which provides some preprocessing methods for features defined by config, into the new handler.\n\n\nProcessor\n---------\n\nThe ``Processor`` module in ``Qlib`` is designed to be learnable, and it is responsible for data processing such as `normalization` and `dropping none/nan features/labels`.\n\n``Qlib`` provides the following ``Processors``:\n\n- ``DropnaProcessor``: `processor` that drops N/A features.\n- ``DropnaLabel``: `processor` that drops N/A labels.\n- ``TanhProcess``: `processor` that uses `tanh` to process noisy data.\n- ``ProcessInf``: `processor` that handles infinity values by replacing them with the mean of the column.\n- ``Fillna``: `processor` that handles N/A values by filling them with 0 or another given number.\n- ``MinMaxNorm``: `processor` that applies min-max normalization.\n- ``ZscoreNorm``: `processor` that applies z-score normalization.\n- ``RobustZScoreNorm``: `processor` that applies robust z-score normalization.\n- ``CSZScoreNorm``: `processor` that applies cross-sectional z-score normalization.\n- ``CSRankNorm``: `processor` that applies cross-sectional rank normalization.\n- ``CSZFillna``: `processor` that fills N/A values in a cross-sectional way with the mean of the column.\n\nUsers can also create their own `processor` by inheriting from the ``Processor`` base class. 
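\n\nThe minimal sketch below uses plain ``pandas`` to illustrate the learnable-processor idea. The class ``ToyZScoreNorm`` is hypothetical and does not inherit from ``Qlib``'s ``Processor``. It only follows a fit-then-apply pattern: ``fit`` learns per-column statistics on the fitting period, and calling the object applies those stored statistics to new data.\n\n.. code-block:: python\n\n    import pandas as pd\n\n    class ToyZScoreNorm:\n        # Hypothetical learnable processor (not a Qlib class).\n        def fit(self, df: pd.DataFrame) -> None:\n            # Learn the normalization parameters on the fitting period only.\n            self.mean_ = df.mean()\n            self.std_ = df.std(ddof=0)\n\n        def __call__(self, df: pd.DataFrame) -> pd.DataFrame:\n            # Reuse the learned per-column parameters on (possibly new) data.\n            return (df - self.mean_) / self.std_\n\n    train = pd.DataFrame({'f1': [1.0, 2.0, 3.0], 'f2': [10.0, 20.0, 30.0]})\n    new = pd.DataFrame({'f1': [4.0], 'f2': [40.0]})\n    proc = ToyZScoreNorm()\n    proc.fit(train)\n    print(proc(new))\n\nThe statistics are stored on the object, so new data can be processed without refitting. This is what allows real-time data to be processed efficiently.\n\n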
Please refer to the implementation of all the processors for more information (`Processor Link <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`_).\n\nTo know more about ``Processor``, please refer to `Processor API <../reference/api.html#module-qlib.data.dataset.processor>`_.\n\nExample\n-------\n\n``Data Handler`` can be run with ``qrun`` by modifying the configuration file, and can also be used as a single module.\n\nTo learn more about how to run ``Data Handler`` with ``qrun``, please refer to `Workflow: Workflow Management <workflow.html>`_.\n\n``Qlib`` provides an implemented data handler, `Alpha158`. The following example shows how to run `Alpha158` as a single module.\n\n.. note:: Users need to initialize ``Qlib`` with `qlib.init` first; please refer to `initialization <../start/initialization.html>`_.\n\n.. code-block:: Python\n\n    import qlib\n    from qlib.contrib.data.handler import Alpha158\n\n    data_handler_config = {\n        \"start_time\": \"2008-01-01\",\n        \"end_time\": \"2020-08-01\",\n        \"fit_start_time\": \"2008-01-01\",\n        \"fit_end_time\": \"2014-12-31\",\n        \"instruments\": \"csi300\",\n    }\n\n    if __name__ == \"__main__\":\n        qlib.init()\n        h = Alpha158(**data_handler_config)\n\n        # get all the columns of the data\n        print(h.get_cols())\n\n        # fetch all the labels\n        print(h.fetch(col_set=\"label\"))\n\n        # fetch all the features\n        print(h.fetch(col_set=\"feature\"))\n\n\n.. 
note:: In ``Alpha158``, ``Qlib`` uses the label `Ref($close, -2)/Ref($close, -1) - 1`, which means the price change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`. The reason is that after getting the close price of a Chinese stock on day T, the stock can be bought on day T+1 and sold on day T+2.\n\nAPI\n---\n\nTo know more about ``Data Handler``, please refer to `Data Handler API <../reference/api.html#module-qlib.data.dataset.handler>`_.\n\n\nDataset\n=======\n\nThe ``Dataset`` module in ``Qlib`` aims to prepare data for model training and inference.\n\nThe motivation of this module is to give different models the flexibility to process data in the way that suits them. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` values, while neural networks such as ``MLP`` will break down on such data.\n\nIf a user's model needs to process its data in a different way, the user can implement their own ``Dataset`` class. If the model's data processing is not special, ``DatasetH`` can be used directly.\n\nThe ``DatasetH`` class is a `dataset` with a `Data Handler`. Here is the most important interface of the class:\n\n.. autoclass:: qlib.data.dataset.__init__.DatasetH\n    :members:\n    :noindex:\n\nAPI\n---\n\nTo know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#dataset>`_.\n\n\nCache\n=====\n\n``Cache`` is an optional module that helps accelerate data provision by saving some frequently-used data as cache files. ``Qlib`` provides a `MemCache` class to cache the most-frequently-used data in memory, an inheritable `ExpressionCache` class, and an inheritable `DatasetCache` class.\n\nGlobal Memory Cache\n-------------------\n\n`MemCache` is a global memory cache mechanism that consists of three `MemCacheUnit` instances to cache **Calendar**, **Instruments**, and **Features**. 
The `MemCache` is defined globally in `cache.py` as `H`. Users can use `H['c']`, `H['i']`, and `H['f']` to get/set the calendar, instrument, and feature `memcache`.\n\n.. autoclass:: qlib.data.cache.MemCacheUnit\n    :members:\n    :noindex:\n\n.. autoclass:: qlib.data.cache.MemCache\n    :members:\n    :noindex:\n\n\nExpressionCache\n---------------\n\n`ExpressionCache` is a cache mechanism that saves expressions such as **Mean($close, 5)**. Users can inherit this base class to define their own cache mechanism that saves expressions according to the following steps.\n\n- Override the `self._uri` method to define how the cache file path is generated.\n- Override the `self._expression` method to define what data will be cached and how to cache it.\n\nThe following shows the details about the interfaces:\n\n.. autoclass:: qlib.data.cache.ExpressionCache\n    :members:\n    :noindex:\n\n``Qlib`` currently provides an implemented disk cache, `DiskExpressionCache`, which inherits from `ExpressionCache`. The expression data are stored on disk.\n\nDatasetCache\n------------\n\n`DatasetCache` is a cache mechanism that saves datasets. A dataset is determined by a stock pool configuration (or a series of instruments, though not recommended), a list of expressions or static feature fields, the start and end times of the collected features, and the frequency. Users can inherit this base class to define their own cache mechanism that saves datasets according to the following steps.\n\n- Override the `self._uri` method to define how the cache file path is generated.\n- Override the `self._dataset` method to define what data will be cached and how to cache it.\n\nThe following shows the details about the interfaces:\n\n.. autoclass:: qlib.data.cache.DatasetCache\n    :members:\n    :noindex:\n\n``Qlib`` currently provides an implemented disk cache, `DiskDatasetCache`, which inherits from `DatasetCache`. 
The dataset data are stored on disk.\n\n\n\nData and Cache File Structure\n=============================\n\nWe've specially designed a file structure to manage data and cache; please refer to the `File storage design section in the Qlib paper <https://arxiv.org/abs/2009.11189>`_ for detailed information. The file structure of data and cache is as follows.\n\n.. code-block::\n\n    - data/\n        [raw data] updated by data providers\n        - calendars/\n            - day.txt\n        - instruments/\n            - all.txt\n            - csi500.txt\n            - ...\n        - features/\n            - sh600000/\n                - open.day.bin\n                - close.day.bin\n                - ...\n            - ...\n        [cached data] updated when raw data is updated\n        - calculated features/\n            - sh600000/\n                - [hash(instrument, field_expression, freq)]\n                    - all-time expression-cache data file\n                    - .meta : an assorted meta file recording the instrument name, field name, freq, and visit times\n            - ...\n        - cache/\n            - [hash(stockpool_config, field_expression_list, freq)]\n                - all-time Dataset-cache data file\n                - .meta : an assorted meta file recording the stockpool config, field names and visit times\n                - .index : an assorted index file recording the line index of all calendars\n            - ...\n"
  },
  {
    "path": "docs/component/highfreq.rst",
    "content": ".. _highfreq:\n\n========================================================================\nDesign of Nested Decision Execution Framework for High-Frequency Trading\n========================================================================\n.. currentmodule:: qlib\n\nIntroduction\n============\n\nDaily trading (e.g. portfolio management) and intraday trading (e.g. order execution) are two hot topics in quantitative investment and are usually studied separately.\n\nTo get the joint trading performance of daily and intraday trading, they must interact with each other and be backtested jointly.\nTo support joint backtesting of strategies at multiple levels, a corresponding framework is required. None of the publicly available high-frequency trading frameworks considers multi-level joint trading, which makes the aforementioned backtesting inaccurate.\n\nBesides backtesting, the optimization of strategies at different levels is not standalone, because they can affect each other.\nFor example, the best portfolio management strategy may change with the performance of order execution (e.g. a portfolio with higher turnover may become a better choice when we improve the order execution strategies).\nTo achieve good overall performance, it is necessary to consider the interaction of strategies at different levels.\n\nTherefore, a new framework for trading on multiple levels is necessary to solve the problems mentioned above, for which we designed a nested decision execution framework that considers the interaction of strategies.\n\n.. image:: ../_static/img/framework.svg\n\nThe design of the framework is shown in the yellow part in the middle of the figure above. Each level consists of a ``Trading Agent`` and an ``Execution Env``. The ``Trading Agent`` has its own data processing module (``Information Extractor``), forecasting module (``Forecast Model``) and decision generator (``Decision Generator``). 
The trading algorithm generates decisions with the ``Decision Generator`` based on the forecast signals output by the ``Forecast Model``, and these decisions are passed to the ``Execution Env``, which returns the execution results.\n\nThe frequency of the trading algorithm, the decision content and the execution environment can be customized by users (e.g. intraday trading, daily-frequency trading, weekly-frequency trading), and the execution environment can nest a finer-grained trading algorithm and execution environment inside it (i.e. the sub-workflow in the figure; e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The flexibility of the nested decision execution framework makes it easy for users to explore the effects of combining trading strategies at different levels and to break down the optimization barriers between different levels of the trading algorithm.\n\nThe optimization for the nested decision execution framework can be implemented with the support of `QlibRL <./rl/overall.html>`_. To learn more about how to use QlibRL, see the API reference: `RL API <../reference/api.html#rl>`_.\n\nExample\n=======\n\nAn example of a nested decision execution framework for high-frequency trading can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.\n\n\nBesides the above example, here are some other related works on high-frequency trading in Qlib.\n\n- `Prediction with high-frequency data <https://github.com/microsoft/qlib/tree/main/examples/highfreq#benchmarks-performance-predicting-the-price-trend-in-high-frequency-data>`_\n- `Examples <https://github.com/microsoft/qlib/blob/main/examples/orderbook_data/>`_ to extract features from high-frequency data without fixed frequency.\n- `A paper <https://github.com/microsoft/qlib/tree/high-freq-execution#high-frequency-execution>`_ for high-frequency trading.\n"
  },
  {
    "path": "docs/component/meta.rst",
    "content": ".. _meta:\n\n======================================================\nMeta Controller: Meta-Task & Meta-Dataset & Meta-Model\n======================================================\n.. currentmodule:: qlib\n\n\nIntroduction\n============\n``Meta Controller`` provides guidance to ``Forecast Model``; it aims to learn regular patterns among a series of forecasting tasks and use the learned patterns to guide forthcoming forecasting tasks. Users can implement their own meta-model instance based on the ``Meta Controller`` module.\n\nMeta Task\n=========\n\nA `Meta Task` instance is the basic element in the meta-learning framework. It saves the data that can be used for the `Meta Model`. Multiple `Meta Task` instances may share the same `Data Handler`, controlled by `Meta Dataset`. Users should use `prepare_task_data()` to obtain the data that can be directly fed into the `Meta Model`.\n\n.. autoclass:: qlib.model.meta.task.MetaTask\n    :members:\n\nMeta Dataset\n============\n\n`Meta Dataset` controls the meta-information generating process. It is responsible for providing data for training the `Meta Model`. Users should use `prepare_tasks` to retrieve a list of `Meta Task` instances.\n\n.. autoclass:: qlib.model.meta.dataset.MetaTaskDataset\n    :members:\n\nMeta Model\n==========\n\nGeneral Meta Model\n------------------\nA `Meta Model` instance is the part that controls the workflow. The usage of the `Meta Model` includes:\n\n1. Users train their `Meta Model` with the `fit` function.\n2. The `Meta Model` instance guides the workflow by giving useful information via the `inference` function.\n\n.. autoclass:: qlib.model.meta.model.MetaModel\n    :members:\n\nMeta Task Model\n---------------\nThis type of meta-model may interact with task definitions directly, and the `Meta Task Model` is the class for such meta-models to inherit from. They guide the base tasks by modifying the base task definitions. 
The function `prepare_tasks` can be used to obtain the modified base task definitions.\n\n.. autoclass:: qlib.model.meta.model.MetaTaskModel\n    :members:\n\nMeta Guide Model\n----------------\nThis type of meta-model participates in the training process of the base forecasting model. The meta-model may guide the base forecasting models during their training to improve their performance.\n\n.. autoclass:: qlib.model.meta.model.MetaGuideModel\n    :members:\n\n\nExample\n=======\n``Qlib`` provides an implementation of the ``Meta Model`` module, ``DDG-DA``,\nwhich adapts to market dynamics.\n\n``DDG-DA`` includes four steps:\n\n1. Calculate meta-information and encapsulate it into ``Meta Task`` instances. All the meta-tasks form a ``Meta Dataset`` instance.\n2. Train ``DDG-DA`` based on the training data of the meta-dataset.\n3. Run inference with ``DDG-DA`` to get guide information.\n4. Apply the guide information to the forecasting models to improve their performance.\n\nThe `above example <https://github.com/microsoft/qlib/tree/main/examples/benchmarks_dynamic/DDG-DA>`_ can be found in ``examples/benchmarks_dynamic/DDG-DA/workflow.py``.\n"
  },
  {
    "path": "docs/component/model.rst",
    "content": ".. _model:\n\n===========================================\nForecast Model: Model Training & Prediction\n===========================================\n\nIntroduction\n============\n\n``Forecast Model`` is designed to generate the `prediction score` of stocks. Users can use the ``Forecast Model`` in an automatic workflow with ``qrun``; please refer to `Workflow: Workflow Management <workflow.html>`_.\n\nBecause the components in ``Qlib`` are designed in a loosely-coupled way, ``Forecast Model`` can also be used as an independent module.\n\nBase Class & Interface\n======================\n\n``Qlib`` provides a base class `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_ from which all models should inherit.\n\nThe base class provides the following interfaces:\n\n.. autoclass:: qlib.model.base.Model\n    :members:\n    :noindex:\n\n``Qlib`` also provides a base class `qlib.model.base.ModelFT <../reference/api.html#qlib.model.base.ModelFT>`_, which includes the method for finetuning the model.\n\nFor other interfaces such as `finetune`, please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.\n\nExample\n=======\n\n``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``MLP``, ``LSTM``, etc. These models are treated as the baselines of ``Forecast Model``. The following steps show how to run ``LightGBM`` as an independent module.\n\n- Initialize ``Qlib`` with `qlib.init` first; please refer to `Initialization <../start/initialization.html>`_.\n- Run the following code to get the `prediction score` `pred_score`.\n    .. 
code-block:: Python\n\n        from qlib.contrib.model.gbdt import LGBModel\n        from qlib.contrib.data.handler import Alpha158\n        from qlib.utils import init_instance_by_config, flatten_dict\n        from qlib.workflow import R\n        from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n\n        market = \"csi300\"\n        benchmark = \"SH000300\"\n\n        data_handler_config = {\n            \"start_time\": \"2008-01-01\",\n            \"end_time\": \"2020-08-01\",\n            \"fit_start_time\": \"2008-01-01\",\n            \"fit_end_time\": \"2014-12-31\",\n            \"instruments\": market,\n        }\n\n        task = {\n            \"model\": {\n                \"class\": \"LGBModel\",\n                \"module_path\": \"qlib.contrib.model.gbdt\",\n                \"kwargs\": {\n                    \"loss\": \"mse\",\n                    \"colsample_bytree\": 0.8879,\n                    \"learning_rate\": 0.0421,\n                    \"subsample\": 0.8789,\n                    \"lambda_l1\": 205.6999,\n                    \"lambda_l2\": 580.9768,\n                    \"max_depth\": 8,\n                    \"num_leaves\": 210,\n                    \"num_threads\": 20,\n                },\n            },\n            \"dataset\": {\n                \"class\": \"DatasetH\",\n                \"module_path\": \"qlib.data.dataset\",\n                \"kwargs\": {\n                    \"handler\": {\n                        \"class\": \"Alpha158\",\n                        \"module_path\": \"qlib.contrib.data.handler\",\n                        \"kwargs\": data_handler_config,\n                    },\n                    \"segments\": {\n                        \"train\": (\"2008-01-01\", \"2014-12-31\"),\n                        \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n                        \"test\": (\"2017-01-01\", \"2020-08-01\"),\n                    },\n                },\n            },\n        }\n\n        # model 
initialization\n        model = init_instance_by_config(task[\"model\"])\n        dataset = init_instance_by_config(task[\"dataset\"])\n\n        # start exp\n        with R.start(experiment_name=\"workflow\"):\n            # train\n            R.log_params(**flatten_dict(task))\n            model.fit(dataset)\n\n            # prediction\n            recorder = R.get_recorder()\n            sr = SignalRecord(model, dataset, recorder)\n            sr.generate()\n\n    .. note::\n\n        `Alpha158` is the data handler provided by ``Qlib``; please refer to `Data Handler <data.html#data-handler>`_.\n        `SignalRecord` is the `Record Template` in ``Qlib``; please refer to `Workflow <recorder.html#record-template>`_.\n\nThe above example is also given in ``examples/train_backtest_analyze.ipynb``.\nTechnically, the meaning of the model prediction depends on the label setting designed by the user.\nBy default, the score is normally the forecasting model's rating of the instruments: the higher the score, the more profitable the instrument is expected to be.\n\n\nCustom Model\n============\n\nQlib supports custom models. If users are interested in customizing their own models and integrating them into ``Qlib``, please refer to `Custom Model Integration <../start/integration.html>`_.\n\n\nAPI\n===\nPlease refer to `Model API <../reference/api.html#module-qlib.model.base>`_.\n"
  },
  {
    "path": "docs/component/online.rst",
    "content": ".. _online_serving:\n\n==============\nOnline Serving\n==============\n.. currentmodule:: qlib\n\n\nIntroduction\n============\n\n.. image:: ../_static/img/online_serving.png\n    :align: center\n\n\nIn addition to backtesting, one way to test whether a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.\n``Online Serving`` is a set of modules for online models using the latest data,\nwhich includes `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, and `Updater <#Updater>`_.\n\n`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.\nIf you have many models or `tasks` that need to be managed, please consider `Task Management <../advanced/task_management.html>`_.\nThe `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.\n\n**NOTE**: Users should keep their data source updated to support online serving. For example, Qlib provides `a batch of scripts <https://github.com/microsoft/qlib/blob/main/scripts/data_collector/yahoo/README.md#automatic-update-of-daily-frequency-datafrom-yahoo-finance>`_ to help users update Yahoo daily data.\n\nKnown limitations:\n\n- Currently, daily updating of the prediction for the next trading day is supported. However, generating orders for the next trading day is not supported due to the `limitations of public data <https://github.com/microsoft/qlib/issues/215#issuecomment-766293563>`_.\n\n\nOnline Manager\n==============\n\n.. automodule:: qlib.workflow.online.manager\n    :members:\n    :noindex:\n\nOnline Strategy\n===============\n\n.. automodule:: qlib.workflow.online.strategy\n    :members:\n    :noindex:\n\nOnline Tool\n===========\n\n.. 
automodule:: qlib.workflow.online.utils\n    :members:\n    :noindex:\n\nUpdater\n=======\n\n.. automodule:: qlib.workflow.online.update\n    :members:\n    :noindex:\n"
  },
  {
    "path": "docs/component/recorder.rst",
    "content": ".. _recorder:\n\n====================================\nQlib Recorder: Experiment Management\n====================================\n.. currentmodule:: qlib\n\nIntroduction\n============\n``Qlib`` contains an experiment management system named ``QlibRecorder``, which is designed to help users handle experiments and analyse results efficiently.\n\nThere are three components of the system:\n\n- `ExperimentManager`\n    a class that manages experiments.\n\n- `Experiment`\n    a class of experiment, and each instance of it is responsible for a single experiment.\n\n- `Recorder`\n    a class of recorder, and each instance of it is responsible for a single run.\n\nHere is a general view of the structure of the system:\n\n.. code-block::\n\n    ExperimentManager\n        - Experiment 1\n            - Recorder 1\n            - Recorder 2\n            - ...\n        - Experiment 2\n            - Recorder 1\n            - Recorder 2\n            - ...\n        - ...\n\nThis experiment management system defines a set of interfaces and provides a concrete implementation, ``MLflowExpManager``, which is based on the machine learning platform ``MLFlow`` (`link <https://mlflow.org/>`_).\n\nIf users set the implementation of ``ExpManager`` to be ``MLflowExpManager``, they can use the command `mlflow ui` to visualize and check the experiment results. For more information, please refer to the related documents `here <https://www.mlflow.org/docs/latest/cli.html#mlflow-ui>`_.\n\nQlib Recorder\n=============\n``QlibRecorder`` provides a high-level API for users to use the experiment management system. The interfaces are wrapped in the variable ``R`` in ``Qlib``, and users can directly use ``R`` to interact with the system. The following code shows how to import ``R`` in Python:\n\n.. code-block:: Python\n\n        from qlib.workflow import R\n\n``QlibRecorder`` includes several common APIs for managing `experiments` and `recorders` within a workflow. 
For more available APIs, please refer to the following sections about `Experiment Manager`, `Experiment` and `Recorder`.\n\nHere are the available interfaces of ``QlibRecorder``:\n\n.. autoclass:: qlib.workflow.__init__.QlibRecorder\n    :members:\n\nExperiment Manager\n==================\n\nThe ``ExpManager`` module in ``Qlib`` is responsible for managing different experiments. Most of the APIs of ``ExpManager`` are similar to ``QlibRecorder``, and the most important API is the ``get_exp`` method. Users can directly refer to the documentation above for detailed information about how to use the ``get_exp`` method.\n\n.. autoclass:: qlib.workflow.expm.ExpManager\n    :members: get_exp, list_experiments\n    :noindex:\n\nFor other interfaces such as `create_exp`, `delete_exp`, please refer to `Experiment Manager API <../reference/api.html#experiment-manager>`_.\n\nExperiment\n==========\n\nThe ``Experiment`` class is solely responsible for a single experiment, and it handles any operations related to an experiment. Basic methods such as `start` and `end` for an experiment are included. Besides, methods related to `recorders` are also available, such as `get_recorder` and `list_recorders`.\n\n.. autoclass:: qlib.workflow.exp.Experiment\n    :members: get_recorder, list_recorders\n    :noindex:\n\nFor other interfaces such as `search_records`, `delete_recorder`, please refer to `Experiment API <../reference/api.html#experiment>`_.\n\n``Qlib`` also provides a default ``Experiment``, which will be created and used in certain situations when users use APIs such as `log_metrics` or `get_exp`. If the default ``Experiment`` is used, related information will be logged when running ``Qlib``. 
Users can change the name of the default ``Experiment``, which is set to '`Experiment`', in the config file of ``Qlib`` or during ``Qlib``'s `initialization <../start/initialization.html#parameters>`_.\n\nRecorder\n========\n\nThe ``Recorder`` class is responsible for a single recorder. It handles detailed operations such as ``log_metrics`` and ``log_params`` of a single run. It is designed to help users easily track the results and other things generated during a run.\n\nHere are some important APIs that are not included in the ``QlibRecorder``:\n\n.. autoclass:: qlib.workflow.recorder.Recorder\n    :members: list_artifacts, list_metrics, list_params, list_tags\n    :noindex:\n\nFor other interfaces such as `save_objects`, `load_object`, please refer to `Recorder API <../reference/api.html#recorder>`_.\n\nRecord Template\n===============\n\nThe ``RecordTemp`` class enables generating experiment results, such as IC and backtest results, in a certain format. We provide three different `Record Template` classes:\n\n- ``SignalRecord``: This class generates the `prediction` results of the model.\n- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.\n\nHere is a simple example of what is done in ``SigAnaRecord``, which users can refer to if they want to calculate IC, Rank IC, and Long-Short Return with their own prediction and label.\n\n.. code-block:: Python\n\n    from qlib.contrib.eva.alpha import calc_ic, calc_long_short_return\n\n    ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])\n    long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])\n\n- ``PortAnaRecord``: This class generates the results of `backtest`. 
For detailed information about `backtest` and the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.\n\nHere is a simple example of what is done in ``PortAnaRecord``, which users can refer to if they want to run a backtest based on their own prediction and label.\n\n.. code-block:: Python\n\n    import pandas as pd\n\n    from qlib.contrib.strategy import TopkDropoutStrategy\n    from qlib.contrib.evaluate import backtest_daily, risk_analysis\n\n    # backtest\n    # `pred_score` is the user's own prediction, e.g. the output of `model.predict(dataset)`\n    STRATEGY_CONFIG = {\n        \"topk\": 50,\n        \"n_drop\": 5,\n        \"signal\": pred_score,\n    }\n    EXCHANGE_KWARGS = {\n        \"limit_threshold\": 0.095,\n        \"deal_price\": \"close\",\n        \"open_cost\": 0.0005,\n        \"close_cost\": 0.0015,\n        \"min_cost\": 5,\n    }\n\n    strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)\n    report_normal, positions_normal = backtest_daily(\n        start_time=\"2017-01-01\",\n        end_time=\"2020-08-01\",\n        strategy=strategy,\n        account=100000000,\n        benchmark=\"SH000300\",\n        exchange_kwargs=EXCHANGE_KWARGS,\n    )\n\n    # analysis\n    analysis = dict()\n    analysis[\"excess_return_without_cost\"] = risk_analysis(report_normal[\"return\"] - report_normal[\"bench\"])\n    analysis[\"excess_return_with_cost\"] = risk_analysis(report_normal[\"return\"] - report_normal[\"bench\"] - report_normal[\"cost\"])\n    analysis_df = pd.concat(analysis)  # type: pd.DataFrame\n    print(analysis_df)\n\nFor more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.\n\n\n\nKnown Limitations\n=================\n- Python objects are saved with pickle, which may cause issues when the environments used for dumping and loading the objects are different.\n"
  },
  {
    "path": "docs/component/report.rst",
    "content": ".. _report:\n\n=======================================\nAnalysis: Evaluation & Results Analysis\n=======================================\n\nIntroduction\n============\n\n``Analysis`` is designed to show the graphical reports of ``Intraday Trading``, which help users evaluate and analyse investment portfolios visually. The following are some graphics to view:\n\n- analysis_position\n    - report_graph\n    - score_ic_graph\n    - cumulative_return_graph\n    - risk_analysis_graph\n    - rank_label_graph\n\n- analysis_model\n    - model_performance_graph\n\n\nAll of the accumulated profit metrics (e.g. return, max drawdown) in Qlib are calculated by summation.\nThis avoids the metrics or the plots being skewed exponentially over time.\n\nGraphical Reports\n=================\n\nUsers can run the following code to get all supported reports.\n\n.. code-block:: python\n\n    >>> import qlib.contrib.report as qcr\n    >>> print(qcr.GRAPH_NAME_LIST)\n    ['analysis_position.report_graph', 'analysis_position.score_ic_graph', 'analysis_position.cumulative_return_graph', 'analysis_position.risk_analysis_graph', 'analysis_position.rank_label_graph', 'analysis_model.model_performance_graph']\n\n.. note::\n\n    For more details, please refer to the function's documentation, e.g. ``help(qcr.analysis_position.report_graph)``.\n\n\n\nUsage & Example\n===============\n\nUsage of `analysis_position.report`\n-----------------------------------\n\nAPI\n~~~\n\n.. automodule:: qlib.contrib.report.analysis_position.report\n    :members:\n    :noindex:\n\nGraphical Result\n~~~~~~~~~~~~~~~~\n\n.. 
note::\n\n    - Axis X: Trading day\n    - Axis Y:\n        - `cum bench`\n            Cumulative returns series of benchmark\n        - `cum return wo cost`\n            Cumulative returns series of portfolio without cost\n        - `cum return w cost`\n            Cumulative returns series of portfolio with cost\n        - `return wo mdd`\n            Maximum drawdown series of cumulative return without cost\n        - `return w cost mdd`\n            Maximum drawdown series of cumulative return with cost\n        - `cum ex return wo cost`\n            The `CAR` (cumulative abnormal return) series of the portfolio compared to the benchmark without cost.\n        - `cum ex return w cost`\n            The `CAR` (cumulative abnormal return) series of the portfolio compared to the benchmark with cost.\n        - `turnover`\n            Turnover rate series\n        - `cum ex return wo cost mdd`\n            Drawdown series of `CAR` (cumulative abnormal return) without cost\n        - `cum ex return w cost mdd`\n            Drawdown series of `CAR` (cumulative abnormal return) with cost\n    - The shaded part above: Maximum drawdown corresponding to `cum return wo cost`\n    - The shaded part below: Maximum drawdown corresponding to `cum ex return wo cost`\n\n.. image:: ../_static/img/analysis/report.png\n\n\nUsage of `analysis_position.score_ic`\n-------------------------------------\n\nAPI\n~~~\n\n.. automodule:: qlib.contrib.report.analysis_position.score_ic\n    :members:\n    :noindex:\n\n\nGraphical Result\n~~~~~~~~~~~~~~~~\n\n.. note::\n\n    - Axis X: Trading day\n    - Axis Y:\n        - `ic`\n            The `Pearson correlation coefficient` series between `label` and `prediction score`.\n            In the above example, the `label` is formulated as `Ref($close, -2)/Ref($close, -1)-1`. 
Please refer to `Data Feature <data.html#feature>`_ for more details.\n\n        - `rank_ic`\n            The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.\n\n.. image:: ../_static/img/analysis/score_ic.png\n\n\n.. Usage of `analysis_position.cumulative_return`\n.. ----------------------------------------------\n..\n.. API\n.. ~~~~~~~~~~~~~~~~\n..\n.. .. automodule:: qlib.contrib.report.analysis_position.cumulative_return\n..     :members:\n..\n.. Graphical Result\n.. ~~~~~~~~~~~~~~~~~\n..\n.. .. note::\n..\n..     - Axis X: Trading day\n..     - Axis Y:\n..         - Above axis Y: `(((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum()`\n..         - Below axis Y: Daily weight sum\n..     - In the **sell** graph, `y < 0` stands for profit; in other cases, `y > 0` stands for profit.\n..     - In the **buy_minus_sell** graph, the **y** value of the **weight** graph at the bottom is `buy_weight + sell_weight`.\n..     - In each graph, the **red line** in the histogram on the right represents the average.\n..\n.. .. image:: ../_static/img/analysis/cumulative_return_buy.png\n..\n.. .. image:: ../_static/img/analysis/cumulative_return_sell.png\n..\n.. .. image:: ../_static/img/analysis/cumulative_return_buy_minus_sell.png\n..\n.. .. image:: ../_static/img/analysis/cumulative_return_hold.png\n\n\nUsage of `analysis_position.risk_analysis`\n------------------------------------------\n\nAPI\n~~~\n\n.. automodule:: qlib.contrib.report.analysis_position.risk_analysis\n    :members:\n    :noindex:\n\n\nGraphical Result\n~~~~~~~~~~~~~~~~\n\n.. 
note::\n\n    - general graphics\n        - `std`\n            - `excess_return_without_cost`\n                The `Standard Deviation` of `CAR` (cumulative abnormal return) without cost.\n            - `excess_return_with_cost`\n                The `Standard Deviation` of `CAR` (cumulative abnormal return) with cost.\n        - `annualized_return`\n            - `excess_return_without_cost`\n                The `Annualized Rate` of `CAR` (cumulative abnormal return) without cost.\n            - `excess_return_with_cost`\n                The `Annualized Rate` of `CAR` (cumulative abnormal return) with cost.\n        -  `information_ratio`\n            - `excess_return_without_cost`\n                The `Information Ratio` without cost.\n            - `excess_return_with_cost`\n                The `Information Ratio` with cost.\n\n            To know more about `Information Ratio`, please refer to `Information Ratio – IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.\n        -  `max_drawdown`\n            - `excess_return_without_cost`\n                The `Maximum Drawdown` of `CAR` (cumulative abnormal return) without cost.\n            - `excess_return_with_cost`\n                The `Maximum Drawdown` of `CAR` (cumulative abnormal return) with cost.\n\n\n.. image:: ../_static/img/analysis/risk_analysis_bar.png\n    :align: center\n\n.. 
note::\n\n    - annualized_return/max_drawdown/information_ratio/std graphics\n        - Axis X: Trading days grouped by month\n        - Axis Y:\n            - annualized_return graphics\n                - `excess_return_without_cost_annualized_return`\n                    The `Annualized Rate` series of monthly `CAR` (cumulative abnormal return) without cost.\n                - `excess_return_with_cost_annualized_return`\n                    The `Annualized Rate` series of monthly `CAR` (cumulative abnormal return) with cost.\n            - max_drawdown graphics\n                - `excess_return_without_cost_max_drawdown`\n                    The `Maximum Drawdown` series of monthly `CAR` (cumulative abnormal return) without cost.\n                - `excess_return_with_cost_max_drawdown`\n                    The `Maximum Drawdown` series of monthly `CAR` (cumulative abnormal return) with cost.\n            - information_ratio graphics\n                - `excess_return_without_cost_information_ratio`\n                    The `Information Ratio` series of monthly `CAR` (cumulative abnormal return) without cost.\n                - `excess_return_with_cost_information_ratio`\n                    The `Information Ratio` series of monthly `CAR` (cumulative abnormal return) with cost.\n            - std graphics\n                - `excess_return_without_cost_std`\n                    The `Standard Deviation` series of monthly `CAR` (cumulative abnormal return) without cost.\n                - `excess_return_with_cost_std`\n                    The `Standard Deviation` series of monthly `CAR` (cumulative abnormal return) with cost.\n\n\n.. image:: ../_static/img/analysis/risk_analysis_annualized_return.png\n    :align: center\n\n.. image:: ../_static/img/analysis/risk_analysis_max_drawdown.png\n    :align: center\n\n.. image:: ../_static/img/analysis/risk_analysis_information_ratio.png\n    :align: center\n\n.. 
image:: ../_static/img/analysis/risk_analysis_std.png\n    :align: center\n\n..\n.. Usage of `analysis_position.rank_label`\n.. ---------------------------------------\n..\n.. API\n.. ~~~\n..\n.. .. automodule:: qlib.contrib.report.analysis_position.rank_label\n..     :members:\n..\n..\n.. Graphical Result\n.. ~~~~~~~~~~~~~~~~\n..\n.. .. note::\n..\n..     - hold/sell/buy graphics:\n..         - Axis X: Trading day\n..         - Axis Y:\n..             Average `ranking ratio`of `label` for stocks that is held/sold/bought on the trading day.\n..\n..             In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. The `ranking ratio` can be formulated as follows.\n..             .. math::\n..\n..                 ranking\\ ratio = \\frac{Ascending\\ Ranking\\ of\\ label}{Number\\ of\\ Stocks\\ in\\ the\\ Portfolio}\n..\n.. .. image:: ../_static/img/analysis/rank_label_hold.png\n..     :align: center\n..\n.. .. image:: ../_static/img/analysis/rank_label_buy.png\n..     :align: center\n..\n.. .. image:: ../_static/img/analysis/rank_label_sell.png\n..     :align: center\n..\n..\n\nUsage of `analysis_model.analysis_model_performance`\n----------------------------------------------------\n\nAPI\n~~~\n\n.. automodule:: qlib.contrib.report.analysis_model.analysis_model_performance\n    :members:\n    :noindex:\n\n\nGraphical Results\n~~~~~~~~~~~~~~~~~\n\n.. 
note::\n\n    - cumulative return graphics\n        - `Group1`:\n            The `Cumulative Return` series of the stock group with (`ranking ratio` of label <= 20%)\n        - `Group2`:\n            The `Cumulative Return` series of the stock group with (20% < `ranking ratio` of label <= 40%)\n        - `Group3`:\n            The `Cumulative Return` series of the stock group with (40% < `ranking ratio` of label <= 60%)\n        - `Group4`:\n            The `Cumulative Return` series of the stock group with (60% < `ranking ratio` of label <= 80%)\n        - `Group5`:\n            The `Cumulative Return` series of the stock group with (80% < `ranking ratio` of label)\n        - `long-short`:\n            The difference series between the `Cumulative Return` of `Group1` and that of `Group5`\n        - `long-average`:\n            The difference series between the `Cumulative Return` of `Group1` and the average `Cumulative Return` of all stocks.\n\n        The `ranking ratio` can be formulated as follows.\n            .. math::\n\n                ranking\\ ratio = \\frac{Ascending\\ Ranking\\ of\\ label}{Number\\ of\\ Stocks\\ in\\ the\\ Portfolio}\n\n.. image:: ../_static/img/analysis/analysis_model_cumulative_return.png\n    :align: center\n\n.. note::\n    - long-short/long-average\n        The distribution of long-short/long-average returns on each trading day\n\n\n.. image:: ../_static/img/analysis/analysis_model_long_short.png\n    :align: center\n\n.. TODO: ask xiao yang for detail\n\n.. note::\n    - Information Coefficient\n        - The `Pearson correlation coefficient` series between `labels` and `prediction scores` of stocks in the portfolio.\n        - These graphical reports can be used to evaluate the `prediction scores`.\n\n.. image:: ../_static/img/analysis/analysis_model_IC.png\n    :align: center\n\n.. note::\n    - Monthly IC\n        Monthly average of the `Information Coefficient`\n\n.. image:: ../_static/img/analysis/analysis_model_monthly_IC.png\n    :align: center\n\n.. 
note::\n    - IC\n        The distribution of the `Information Coefficient` on each trading day.\n    - IC Normal Dist. Q-Q\n        The `Quantile-Quantile Plot` compares the distribution of the daily `Information Coefficient` against a normal distribution.\n\n.. image:: ../_static/img/analysis/analysis_model_NDQ.png\n    :align: center\n\n.. note::\n    - Auto Correlation\n        - The `Pearson correlation coefficient` series, on each trading day, between the latest `prediction scores` of stocks in the portfolio and their `prediction scores` from `lag` days earlier.\n        - These graphical reports can be used to estimate the turnover rate.\n\n\n.. image:: ../_static/img/analysis/analysis_model_auto_correlation.png\n    :align: center\n"
  },
  {
    "path": "docs/component/rl/framework.rst",
    "content": "The Framework of QlibRL\n=======================\n\nQlibRL contains a full set of components that cover the entire lifecycle of an RL pipeline, including building the market simulator, shaping states & actions, training policies (strategies), and backtesting strategies in the simulated environment.\n\nQlibRL is built on top of the Tianshou and Gym frameworks. The high-level structure of QlibRL is demonstrated below:\n\n.. image:: ../../_static/img/QlibRL_framework.png\n   :width: 600\n   :align: center\n\nHere, we briefly introduce each component in the figure.\n\nEnvWrapper\n------------\nEnvWrapper is the complete encapsulation of the simulated environment. It receives actions from outside (policy/strategy/agent), simulates the changes in the market, and then returns rewards and updated states, thus forming an interaction loop.\n\nIn QlibRL, EnvWrapper is a subclass of gym.Env, so it implements all necessary interfaces of gym.Env. Any classes or pipelines that accept gym.Env should also accept EnvWrapper. Developers do not need to implement their own EnvWrapper to build their own environment. Instead, they only need to implement 4 components of the EnvWrapper:\n\n- `Simulator`\n    The simulator is the core component responsible for the environment simulation. Developers can implement all the logic that is directly related to the environment simulation in the Simulator in any way they like. In QlibRL, there are already two implementations of Simulator for single-asset trading: 1) ``SingleAssetOrderExecution``, which is built on Qlib's backtest toolkit and hence considers many practical trading details but is slow; 2) ``SingleAssetOrderExecutionSimple``, which is built on a simplified trading simulator that ignores many details (e.g. 
trading limitations, rounding) but is quite fast.\n- `State interpreter`\n    The state interpreter is responsible for \"interpreting\" states in the original format (the format provided by the simulator) into states in a format that the policy can understand, for example, transforming unstructured raw features into numerical tensors.\n- `Action interpreter`\n    The action interpreter is similar to the state interpreter, but instead of states, it interprets actions generated by the policy, converting them from the format provided by the policy to the format that is acceptable to the simulator.\n- `Reward function`\n    The reward function returns a numerical reward to the policy each time the policy takes an action.\n\nEnvWrapper combines these components into a complete environment. Such decomposition allows for better flexibility in development. For example, if developers want to train multiple types of policies in the same environment, they only need to design one simulator and design different state interpreters/action interpreters/reward functions for different types of policies.\n\nQlibRL has well-defined base classes for all 4 of these components. All developers need to do is define their own components by inheriting the base classes and then implementing all interfaces required by the base classes. The API for the above base components can be found `here <../../reference/api.html#module-qlib.rl>`__.\n\nPolicy\n------------\nQlibRL directly uses Tianshou's policies. Developers can use policies provided by Tianshou off the shelf, or implement their own policies by inheriting from Tianshou's policies.\n\nTraining Vessel & Trainer\n-------------------------\nAs their names suggest, training vessels and trainers are helper classes used in training. A training vessel is a ship that contains a simulator/interpreters/reward function/policy, and it controls the algorithm-related parts of training. 
Correspondingly, the trainer is responsible for controlling the runtime parts of training.\n\nAs you may have noticed, a training vessel itself holds all the required components to build an EnvWrapper rather than holding an instance of EnvWrapper directly. This allows the training vessel to create duplicates of EnvWrapper dynamically when necessary (for example, under parallel training).\n\nWith a training vessel, the trainer can then launch the training pipeline through a simple, Scikit-learn-like interface (i.e., ``trainer.fit()``).\n\nThe API for Trainer and TrainingVessel can be found `here <../../reference/api.html#module-qlib.rl.trainer>`__.\n\nThe RL module is designed in a loosely coupled way. Currently, the RL examples are integrated with concrete business logic,\nbut the core part of RL is much simpler than it appears.\nTo demonstrate this simple core, `a dedicated notebook <https://github.com/microsoft/qlib/tree/main/examples/rl/simple_example.ipynb>`__ for RL without business logic has been created.\n"
  },
  {
    "path": "docs/component/rl/guidance.rst",
    "content": "\n========\nGuidance\n========\n.. currentmodule:: qlib\n\nQlibRL can help users quickly get started and conveniently implement quantitative strategies based on reinforcement learning (RL) algorithms. For different user groups, we recommend the following paths for using QlibRL.\n\nBeginners to Reinforcement Learning Algorithms\n==============================================\nWhether you are a quantitative researcher who wants to understand what RL can do in trading or a learner who wants to get started with RL algorithms in trading scenarios, if you have limited knowledge of RL and want to skip the detailed settings to get started with RL algorithms quickly, we recommend the following sequence for learning QlibRL:\n - Learn the fundamentals of RL in `part1 <https://qlib.readthedocs.io/en/latest/component/rl/overall.html#reinforcement-learning>`_.\n - Understand the trading scenarios where RL methods can be applied in `part2 <https://qlib.readthedocs.io/en/latest/component/rl/overall.html#potential-application-scenarios-in-quantitative-trading>`_.\n - Run the examples in `part3 <https://qlib.readthedocs.io/en/latest/component/rl/quickstart.html>`_ to solve trading problems using RL.\n - If you want to further explore QlibRL and make some customizations, first understand the framework of QlibRL in `part4 <https://qlib.readthedocs.io/en/latest/component/rl/framework.html>`_ and then rewrite specific components according to your needs.\n\nReinforcement Learning Algorithm Researcher\n==============================================\nIf you are already familiar with existing RL algorithms and dedicated to RL research but lack domain knowledge in finance, and you want to validate the effectiveness of your algorithms in financial trading scenarios, we recommend the following steps to get started with QlibRL:\n - Understand the trading scenarios where RL methods can be applied in `part2 
<https://qlib.readthedocs.io/en/latest/component/rl/overall.html#potential-application-scenarios-in-quantitative-trading>`_.\n - Choose an RL application scenario (currently, QlibRL has implemented two scenario examples: order execution and algorithmic trading). Run the example in `part3 <https://qlib.readthedocs.io/en/latest/component/rl/quickstart.html>`_ to get it working.\n - Modify the `policy <https://github.com/microsoft/qlib/blob/main/qlib/rl/order_execution/policy.py>`_ part to incorporate your own RL algorithm.\n\nQuantitative Researcher\n=======================\nIf you have a certain level of financial domain knowledge and coding skills, and you want to explore the application of RL algorithms in the investment field, we recommend the following steps to explore QlibRL:\n - Learn the fundamentals of RL in `part1 <https://qlib.readthedocs.io/en/latest/component/rl/overall.html#reinforcement-learning>`_.\n - Understand the trading scenarios where RL methods can be applied in `part2 <https://qlib.readthedocs.io/en/latest/component/rl/overall.html#potential-application-scenarios-in-quantitative-trading>`_.\n - Run the examples in `part3 <https://qlib.readthedocs.io/en/latest/component/rl/quickstart.html>`_ to solve trading problems using RL.\n - Understand the framework of QlibRL in `part4 <https://qlib.readthedocs.io/en/latest/component/rl/framework.html>`_.\n - Choose a suitable RL algorithm based on the characteristics of the problem you want to solve (currently, QlibRL supports PPO and DQN algorithms based on Tianshou).\n - Design the MDP (Markov Decision Process) based on market trading rules and the problem you want to solve. 
Refer to the example in order execution and make corresponding modifications to the following modules: `State <https://github.com/microsoft/qlib/blob/main/qlib/rl/order_execution/state.py#L70>`_, `Metrics <https://github.com/microsoft/qlib/blob/main/qlib/rl/order_execution/state.py#L18>`_, `ActionInterpreter <https://github.com/microsoft/qlib/blob/main/qlib/rl/order_execution/interpreter.py#L199>`_, `StateInterpreter <https://github.com/microsoft/qlib/blob/main/qlib/rl/order_execution/interpreter.py#L68>`_, `Reward <https://github.com/microsoft/qlib/blob/main/qlib/rl/order_execution/reward.py>`_, `Observation <https://github.com/microsoft/qlib/blob/main/qlib/rl/order_execution/interpreter.py#L44>`_, `Simulator <https://github.com/microsoft/qlib/blob/main/qlib/rl/order_execution/simulator_simple.py>`_."
  },
  {
    "path": "docs/component/rl/overall.rst",
    "content": "=====================================================\nReinforcement Learning in Quantitative Trading\n=====================================================\n\nReinforcement Learning\n======================\nBeyond supervised learning tasks such as classification and regression, another important paradigm in machine learning is Reinforcement Learning (RL),\nwhich attempts to optimize a cumulative numerical reward signal by directly interacting with the environment under a few assumptions such as the Markov Decision Process (MDP).\n\nAs demonstrated in the following figure, an RL system consists of four elements: 1) the agent, 2) the environment the agent interacts with, 3) the policy that the agent follows to take actions on the environment, and 4) the reward signal from the environment to the agent.\nIn general, the agent perceives and interprets its environment, takes actions, and learns from rewards, seeking to maximize the long-term overall reward and thereby reach an optimal solution.\n\n.. image:: ../../_static/img/RL_framework.png\n   :width: 300\n   :align: center\n\nRL learns to produce actions by trial and error.\nBy sampling actions and observing which ones lead to the desired outcome, it obtains a policy that generates optimal actions.\nIn contrast to supervised learning, RL learns this not from a label but from a time-delayed label called a reward.\nThis scalar value lets us know whether the current outcome is good or bad. 
\nIn short, the goal of RL is to take actions that maximize reward.\n\nThe Qlib Reinforcement Learning toolkit (QlibRL) is an RL platform for quantitative investment, which supports implementing RL algorithms in Qlib.\n\n\nPotential Application Scenarios in Quantitative Trading\n=======================================================\nRL methods have demonstrated remarkable achievements in various applications, including game playing, resource allocation, recommendation systems, marketing, and advertising.\nIn the context of investment, which involves continuous decision-making, let's consider the example of the stock market. Investors strive to optimize their investment returns by effectively managing their positions and stock holdings through various buying and selling behaviors.\nFurthermore, investors carefully evaluate market conditions and stock-specific information before making each buying or selling decision. From an investor's perspective, this process can be viewed as a continuous decision-making process driven by interactions with the market. RL algorithms offer a promising approach to tackle such challenges.\nHere are several scenarios where RL holds potential for application in quantitative investment.\n\nOrder Execution\n---------------\nThe order execution task is to execute orders efficiently while considering multiple factors, including optimal prices, minimizing trading costs, reducing market impact, maximizing order fulfill rates, and achieving execution within a specified time frame. RL can be applied to such tasks by incorporating these objectives into the reward function and action selection process. Specifically, the RL agent interacts with the market environment, observes the state from market information, and makes decisions on the next execution step. 
The RL algorithm learns an optimal execution strategy through trial and error, aiming to maximize the expected cumulative reward, which incorporates the desired objectives.\n\n - General Setting\n    - Environment: The environment represents the financial market where order execution takes place. It encompasses variables such as the order book dynamics, liquidity, price movements, and market conditions.\n\n    - State: The state refers to the information available to the RL agent at a given time step. It typically includes features such as the current order book state (bid-ask spread, order depth), historical price data, historical trading volume, market volatility, and any other relevant information that can aid in decision-making.\n\n    - Action: The action is the decision made by the RL agent based on the observed state. In order execution, actions can include selecting the order size, price, and timing of execution.\n\n    - Reward: The reward is a scalar signal that indicates the performance of the RL agent's action in the environment. The reward function is designed to encourage actions that lead to efficient and cost-effective order execution. It typically considers multiple objectives, such as maximizing price advantages, minimizing trading costs (including transaction fees and slippage), reducing market impact (the effect of the order on the market price), and maximizing order fulfill rates.\n\n - Scenarios\n    - Single-asset order execution: Single-asset order execution focuses on the task of executing a single order for a specific asset, such as a stock or a cryptocurrency. The primary objective is to execute the order efficiently while considering factors such as maximizing price advantages, minimizing trading costs, reducing market impact, and achieving a high fulfill rate. The RL agent interacts with the market environment and makes decisions on order size, price, and timing of execution for that particular asset. 
The goal is to learn an optimal execution strategy for the single asset, maximizing the expected cumulative reward while considering the specific dynamics and characteristics of that asset.\n\n    - Multi-asset order execution: Multi-asset order execution expands the order execution task to involve multiple assets or securities. It typically involves executing a portfolio of orders across different assets simultaneously or sequentially. Unlike single-asset order execution, the focus is not only on the execution of individual orders but also on managing the interactions and dependencies between different assets within the portfolio. The RL agent needs to make decisions on the order sizes, prices, and timings for each asset in the portfolio, considering their interdependencies, cash constraints, market conditions, and transaction costs. The goal is to learn an optimal execution strategy that balances the execution efficiency for each asset while considering the overall performance and objectives of the portfolio as a whole.\n   \nThe choice of settings and RL algorithm depends on the specific requirements of the task, available data, and desired performance objectives. \n\nPortfolio Construction\n----------------------\nPortfolio construction is a process of selecting and allocating assets in an investment portfolio. RL provides a framework to optimize portfolio management decisions by learning from interactions with the market environment and maximizing long-term returns while considering risk management.\n - General Setting\n    - State: The state represents the current information about the market and the portfolio. It typically includes historical prices and volumes, technical indicators, and other relevant data.\n\n    - Action: The action corresponds to the decision of allocating capital to different assets in the portfolio. 
It determines the weights or proportions of investments in each asset.\n\n    - Reward: The reward is a metric that evaluates the performance of the portfolio. It can be defined in various ways, such as total return, risk-adjusted return, or other objectives like maximizing Sharpe ratio or minimizing drawdown.\n\n - Scenarios\n    - Stock market: RL can be used to construct portfolios of stocks, where the agent learns to allocate capital among different stocks.\n\n    - Cryptocurrency market: RL can be applied to construct portfolios of cryptocurrencies, where the agent learns to make allocation decisions.\n\n    - Foreign exchange (Forex) market: RL can be used to construct portfolios of currency pairs, where the agent learns to allocate capital across different currencies based on exchange rate data, economic indicators, and other factors.\n\nSimilarly, the choice of basic setting and algorithm depends on the specific requirements of the problem and the characteristics of the market."
  },
  {
    "path": "docs/component/rl/quickstart.rst",
    "content": "\nQuick Start\n============\n.. currentmodule:: qlib\n\nQlibRL provides an example implementation of a single-asset order execution task. The following is an example config file for training with QlibRL.\n\n.. code-block:: yaml\n\n    simulator:\n        # Each step spans 30 minutes\n        time_per_step: 30\n        # Upper bound of volume: null, or a float between 0 and 1 giving the upper bound as a fraction of the market volume\n        vol_limit: null\n    env:\n        # Concurrent environment workers.\n        concurrency: 1\n        # dummy or subproc or shmem. Corresponding to `parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.\n        parallel_mode: dummy\n    action_interpreter:\n        class: CategoricalActionInterpreter\n        kwargs:\n            # Candidate actions, it can be a list with length L: [a_1, a_2,..., a_L] or an integer n, in which case the list of length n+1 is auto-generated, i.e., [0, 1/n, 2/n,..., n/n].\n            values: 14\n            # Total number of steps (an upper-bound estimation)\n            max_step: 8\n        module_path: qlib.rl.order_execution.interpreter\n    state_interpreter:\n        class: FullHistoryStateInterpreter\n        kwargs:\n            # Number of dimensions in data.\n            data_dim: 6\n            # Equal to the total number of records. For example, in SAOE per minute, data_ticks is the length of the day in minutes.\n            data_ticks: 240\n            # The total number of steps (an upper-bound estimation). 
For example, 390min / 30min-per-step = 13 steps.\n            max_step: 8\n            # Provider of the processed data.\n            processed_data_provider:\n                class: PickleProcessedDataProvider\n                module_path: qlib.rl.data.pickle_styled\n                kwargs:\n                    data_dir: ./data/pickle_dataframe/feature\n        module_path: qlib.rl.order_execution.interpreter\n    reward:\n        class: PAPenaltyReward\n        kwargs:\n            # The penalty for a large volume in a short time.\n            penalty: 100.0\n        module_path: qlib.rl.order_execution.reward\n    data:\n        source:\n            order_dir: ./data/training_order_split\n            data_dir: ./data/pickle_dataframe/backtest\n            # number of time indexes\n            total_time: 240\n            # start time index\n            default_start_time: 0\n            # end time index\n            default_end_time: 240\n            proc_data_dim: 6\n        num_workers: 0\n        queue_size: 20\n    network:\n        class: Recurrent\n        module_path: qlib.rl.order_execution.network\n    policy:\n        class: PPO\n        kwargs:\n            lr: 0.0001\n        module_path: qlib.rl.order_execution.policy\n    runtime:\n        seed: 42\n        use_cuda: false\n    trainer:\n        max_epoch: 2\n        # Number of policy update passes over the data gathered in each collect\n        repeat_per_collect: 5\n        earlystop_patience: 2\n        # Number of episodes collected in each training iteration\n        episode_per_collect: 20\n        batch_size: 16\n        # Perform validation every n iterations\n        val_every_n_epoch: 1\n        checkpoint_path: ./checkpoints\n        checkpoint_every_n_iters: 1\n\n\nThe config file for backtesting is as follows:\n\n.. 
code-block:: yaml\n\n    order_file: ./data/backtest_orders.csv\n    start_time: \"9:45\"\n    end_time: \"14:44\"\n    qlib:\n        provider_uri_1min: ./data/bin\n        feature_root_dir: ./data/pickle\n        # features generated from today's information\n        feature_columns_today: [\n            \"$open\", \"$high\", \"$low\", \"$close\", \"$vwap\", \"$volume\",\n        ]\n        # features generated from yesterday's information\n        feature_columns_yesterday: [\n            \"$open_v1\", \"$high_v1\", \"$low_v1\", \"$close_v1\", \"$vwap_v1\", \"$volume_v1\",\n        ]\n    exchange:\n        # the expressions for buying and selling stock limitations\n        limit_threshold: ['$close == 0', '$close == 0']\n        # deal price for buying and selling\n        deal_price: [\"If($close == 0, $vwap, $close)\", \"If($close == 0, $vwap, $close)\"]\n    volume_threshold:\n        # volume limit applied to both buying and selling; \"cum\" means the value accumulates over time\n        all: [\"cum\", \"0.2 * DayCumsum($volume, '9:45', '14:44')\"]\n        # volume limit for buying\n        buy: [\"current\", \"$close\"]\n        # volume limit for selling; \"current\" means a real-time value that does not accumulate over time\n        sell: [\"current\", \"$close\"]\n    strategies:\n        30min:\n            class: TWAPStrategy\n            module_path: qlib.contrib.strategy.rule_strategy\n            kwargs: {}\n        1day:\n            class: SAOEIntStrategy\n            module_path: qlib.rl.order_execution.strategy\n            kwargs:\n                state_interpreter:\n                    class: FullHistoryStateInterpreter\n                    module_path: qlib.rl.order_execution.interpreter\n                    kwargs:\n                        max_step: 8\n                        data_ticks: 240\n                        data_dim: 6\n                        processed_data_provider:\n                            class: PickleProcessedDataProvider\n                            module_path: 
qlib.rl.data.pickle_styled\n                            kwargs:\n                                data_dir: ./data/pickle_dataframe/feature\n                action_interpreter:\n                    class: CategoricalActionInterpreter\n                    module_path: qlib.rl.order_execution.interpreter\n                    kwargs:\n                        values: 14\n                        max_step: 8\n                network:\n                    class: Recurrent\n                    module_path: qlib.rl.order_execution.network\n                    kwargs: {}\n                policy:\n                    class: PPO\n                    module_path: qlib.rl.order_execution.policy\n                    kwargs:\n                        lr: 1.0e-4\n                        # Local path to the latest model. The model is generated during training, so run training first if you want to backtest a trained policy. You can also remove this parameter to backtest a randomly initialized policy.\n                        weight_file: ./checkpoints/latest.pth\n    # Concurrent environment workers.\n    concurrency: 5\n\nWith the above config files, you can start training the agent with the following command:\n\n.. code-block:: console\n\n    $ python -m qlib.rl.contrib.train_onpolicy --config_path train_config.yml\n\nAfter training, you can backtest with the following command:\n\n.. 
code-block:: console\n\n    $ python -m qlib.rl.contrib.backtest --config_path backtest_config.yml\n\nIn this example, :class:`~qlib.rl.order_execution.simulator_qlib.SingleAssetOrderExecution` and :class:`~qlib.rl.order_execution.simulator_simple.SingleAssetOrderExecutionSimple` serve as examples of simulators, :class:`qlib.rl.order_execution.interpreter.FullHistoryStateInterpreter` and :class:`qlib.rl.order_execution.interpreter.CategoricalActionInterpreter` as examples of interpreters, :class:`qlib.rl.order_execution.policy.PPO` as an example of a policy, and :class:`qlib.rl.order_execution.reward.PAPenaltyReward` as an example of a reward.\nFor the single-asset order execution task, developers who have already defined their own simulator, interpreters, reward function, and policy can launch the training and backtest pipeline simply by modifying the corresponding settings in the config files.\nMore details about the example can be found `here <https://github.com/microsoft/qlib/blob/main/examples/rl/README.md>`_.\n\nIn the future, we will provide more examples for different scenarios, such as RL-based portfolio construction.\n"
  },
  {
    "path": "docs/component/rl/toctree.rst",
    "content": ".. _rl:\n\n========================================================================\nReinforcement Learning in Quantitative Trading\n========================================================================\n\n.. toctree::\n    Guidance <guidance>\n    Overall <overall>\n    Quick Start <quickstart>\n    Framework <framework>\n"
  },
  {
    "path": "docs/component/strategy.rst",
    "content": ".. _strategy:\n\n========================================\nPortfolio Strategy: Portfolio Management\n========================================\n.. currentmodule:: qlib\n\nIntroduction\n============\n\n``Portfolio Strategy`` is designed to adopt different portfolio strategies, which means that users can adopt different algorithms to generate investment portfolios based on the prediction scores of the ``Forecast Model``. Users can use the ``Portfolio Strategy`` in an automatic workflow through the ``Workflow`` module; please refer to `Workflow: Workflow Management <workflow.html>`_.\n\nBecause the components in ``Qlib`` are designed in a loosely-coupled way, ``Portfolio Strategy`` can also be used as an independent module.\n\n``Qlib`` provides several implemented portfolio strategies. ``Qlib`` also supports custom strategies, so users can customize strategies according to their own requirements.\n\nAfter users specify the models (forecasting signals) and strategies, running a backtest helps them check the performance of a custom model (forecasting signals) or strategy.\n\nBase Class & Interface\n======================\n\nBaseStrategy\n------------\n\nQlib provides a base class ``qlib.strategy.base.BaseStrategy``. All strategy classes need to inherit the base class and implement its interface.\n\n- `generate_trade_decision`\n    `generate_trade_decision` is a key interface that generates trade decisions in each trading bar.\n    How often this method is called depends on the executor frequency (\"time_per_step\"=\"day\" by default). 
However, the trading frequency can be decided by the user's implementation.\n    For example, if the user wants to trade weekly while `time_per_step` is \"day\" in the executor, the strategy can return a non-empty TradeDecision once a week and an empty one otherwise (like `this <https://github.com/microsoft/qlib/blob/main/qlib/contrib/strategy/signal_strategy.py#L132>`_).\n\nUsers can inherit `BaseStrategy` to customize their strategy class.\n\nWeightStrategyBase\n------------------\n\nQlib also provides a class ``qlib.contrib.strategy.WeightStrategyBase`` that is a subclass of `BaseStrategy`.\n\n`WeightStrategyBase` only focuses on the target positions, and automatically generates an order list based on positions. It provides the `generate_target_weight_position` interface.\n\n- `generate_target_weight_position`\n    - Generates the target position according to the current position and trading date. Cash is not considered in\n      the output weight distribution.\n    - Returns the target position.\n\n    .. note::\n        Here the `target position` means the target percentage of total assets.\n\n`WeightStrategyBase` implements the interface `generate_order_list`, whose process is as follows.\n\n- Call the `generate_target_weight_position` method to generate the target position.\n- Generate the target amount of stocks from the target position.\n- Generate the order list from the target amount.\n\nUsers can inherit `WeightStrategyBase` and implement the interface `generate_target_weight_position` to customize their strategy class, which only focuses on the target positions.\n\nImplemented Strategy\n====================\n\nQlib provides implemented strategy classes such as `TopkDropoutStrategy` and `EnhancedIndexingStrategy`.\n\nTopkDropoutStrategy\n-------------------\n`TopkDropoutStrategy` is a subclass of `BaseStrategy` and implements the interface `generate_order_list`, whose process is as follows.\n\n- Adopt the ``Topk-Drop`` algorithm to calculate the target amount of each stock.\n\n    .. 
note::\n        There are two parameters for the ``Topk-Drop`` algorithm:\n\n        - `Topk`: The number of stocks held\n        - `Drop`: The number of stocks sold on each trading day\n\n        In general, the number of stocks currently held is `Topk`, except at the beginning of trading, when it is zero.\n        For each trading day, let :math:`d` be the number of instruments that are currently held and whose rank is greater than :math:`K` (:math:`K` is `Topk`) when ranked by prediction score from high to low.\n        Then the :math:`d` currently held stocks with the worst `prediction score` will be sold, and the same number of unheld stocks with the best `prediction score` will be bought.\n\n        In general, :math:`d` equals `Drop`, especially when the pool of candidate instruments is large, :math:`K` is large, and `Drop` is small.\n\n        In most cases, the ``Topk-Drop`` algorithm sells and buys `Drop` stocks every trading day, which yields a turnover rate of :math:`2 \\times \\mathrm{Drop} / K`.\n\n        The following image illustrates a typical scenario.\n\n        .. image:: ../_static/img/topk_drop.png\n            :alt: Topk-Drop\n\n\n\n- Generate the order list from the target amount.\n\nEnhancedIndexingStrategy\n------------------------\n`EnhancedIndexingStrategy` implements enhanced indexing, which combines the arts of active management and passive management,\nwith the aim of outperforming a benchmark index (e.g., S&P 500) in terms of portfolio return while controlling\nthe risk exposure (a.k.a. tracking error).\n\nFor more information, please refer to `qlib.contrib.strategy.signal_strategy.EnhancedIndexingStrategy`\nand `qlib.contrib.strategy.optimizer.enhanced_indexing.EnhancedIndexingOptimizer`.\n\n\nUsage & Example\n===============\n\nFirst, users can create a model to get trading signals (the variable is named ``pred_score`` in the following examples).\n\nPrediction Score\n----------------\n\nThe `prediction score` is a pandas DataFrame. 
Its index is <datetime(pd.Timestamp), instrument(str)> and it must\ncontain a `score` column.\n\nA prediction sample is shown as follows.\n\n.. code-block:: python\n\n      datetime instrument     score\n    2019-01-04   SH600000 -0.505488\n    2019-01-04   SZ002531 -0.320391\n    2019-01-04   SZ000999  0.583808\n    2019-01-04   SZ300569  0.819628\n    2019-01-04   SZ001696 -0.137140\n                 ...            ...\n    2019-04-30   SZ000996 -1.027618\n    2019-04-30   SH603127  0.225677\n    2019-04-30   SH603126  0.462443\n    2019-04-30   SH603133 -0.302460\n    2019-04-30   SZ300760 -0.126383\n\nThe ``Forecast Model`` module can make predictions; please refer to `Forecast Model: Model Training & Prediction <model.html>`_.\n\nNormally, the prediction score is the output of the models. However, some models are trained on a label with a different scale, so the scale of the prediction score may differ from your expectation (e.g. the return of instruments).\n\nQlib does not add a step to rescale the prediction score to a unified scale, for the following reasons:\n\n- Not every trading strategy cares about the scale (e.g. TopkDropoutStrategy only cares about the order), so the strategy is responsible for rescaling the prediction score (e.g. some portfolio-optimization-based strategies may require a meaningful scale).\n- The model has the flexibility to define the target, loss, and data processing, so we don't think there is a silver bullet for rescaling the score directly based only on the model's outputs. If you want to scale it back to some meaningful values (e.g. stock returns), an intuitive solution is to fit a regression model between the model's recent outputs and your recent target values.\n\nRunning backtest\n----------------\n\n- In most cases, users can backtest their portfolio management strategy with ``backtest_daily``.\n\n    .. 
code-block:: python\n\n        from pprint import pprint\n\n        import qlib\n        import pandas as pd\n        from qlib.contrib.evaluate import backtest_daily\n        from qlib.contrib.evaluate import risk_analysis\n        from qlib.contrib.strategy import TopkDropoutStrategy\n\n        # init qlib; replace provider_uri with the directory of your Qlib data\n        qlib.init(provider_uri=\"~/.qlib/qlib_data/cn_data\")\n\n        CSI300_BENCH = \"SH000300\"\n        STRATEGY_CONFIG = {\n            \"topk\": 50,\n            \"n_drop\": 5,\n            # pred_score, pd.Series\n            \"signal\": pred_score,\n        }\n\n\n        strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)\n        report_normal, positions_normal = backtest_daily(\n            start_time=\"2017-01-01\", end_time=\"2020-08-01\", strategy=strategy_obj, benchmark=CSI300_BENCH\n        )\n        analysis = dict()\n        # default frequency will be daily (i.e. \"day\")\n        analysis[\"excess_return_without_cost\"] = risk_analysis(report_normal[\"return\"] - report_normal[\"bench\"])\n        analysis[\"excess_return_with_cost\"] = risk_analysis(report_normal[\"return\"] - report_normal[\"bench\"] - report_normal[\"cost\"])\n\n        analysis_df = pd.concat(analysis)  # type: pd.DataFrame\n        pprint(analysis_df)\n\n\n\n- If users would like to control their strategies in more detail (e.g. users have a more advanced version of the executor), they can follow this example.\n\n    .. 
code-block:: python\n\n        from pprint import pprint\n\n        import qlib\n        import pandas as pd\n        from qlib.utils.time import Freq\n        from qlib.utils import flatten_dict\n        from qlib.backtest import backtest, executor\n        from qlib.contrib.evaluate import risk_analysis\n        from qlib.contrib.strategy import TopkDropoutStrategy\n\n        # init qlib; replace provider_uri with the directory of your Qlib data\n        qlib.init(provider_uri=\"~/.qlib/qlib_data/cn_data\")\n\n        CSI300_BENCH = \"SH000300\"\n        # Benchmark is for calculating the excess return of your strategy.\n        # Its data format will be like **ONE normal instrument**.\n        # For example, you can query its data with the code below\n        # `D.features([\"SH000300\"], [\"$close\"], start_time='2010-01-01', end_time='2017-12-31', freq='day')`\n        # It is different from the argument `market`, which indicates a universe of stocks (e.g. **A SET** of stocks like csi300)\n        # For example, you can query all data from a stock market with the code below.\n        # `D.features(D.instruments(market='csi300'), [\"$close\"], start_time='2010-01-01', end_time='2017-12-31', freq='day')`\n\n        FREQ = \"day\"\n        STRATEGY_CONFIG = {\n            \"topk\": 50,\n            \"n_drop\": 5,\n            # pred_score, pd.Series\n            \"signal\": pred_score,\n        }\n\n        EXECUTOR_CONFIG = {\n            \"time_per_step\": \"day\",\n            \"generate_portfolio_metrics\": True,\n        }\n\n        backtest_config = {\n            \"start_time\": \"2017-01-01\",\n            \"end_time\": \"2020-08-01\",\n            \"account\": 100000000,\n            \"benchmark\": CSI300_BENCH,\n            \"exchange_kwargs\": {\n                \"freq\": FREQ,\n                \"limit_threshold\": 0.095,\n                \"deal_price\": \"close\",\n                \"open_cost\": 0.0005,\n                \"close_cost\": 0.0015,\n                \"min_cost\": 5,\n            },\n        }\n\n        # 
strategy object\n        strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)\n        # executor object\n        executor_obj = executor.SimulatorExecutor(**EXECUTOR_CONFIG)\n        # backtest\n        portfolio_metric_dict, indicator_dict = backtest(executor=executor_obj, strategy=strategy_obj, **backtest_config)\n        analysis_freq = \"{0}{1}\".format(*Freq.parse(FREQ))\n        # backtest info\n        report_normal, positions_normal = portfolio_metric_dict.get(analysis_freq)\n\n        # analysis\n        analysis = dict()\n        analysis[\"excess_return_without_cost\"] = risk_analysis(\n            report_normal[\"return\"] - report_normal[\"bench\"], freq=analysis_freq\n        )\n        analysis[\"excess_return_with_cost\"] = risk_analysis(\n            report_normal[\"return\"] - report_normal[\"bench\"] - report_normal[\"cost\"], freq=analysis_freq\n        )\n\n        analysis_df = pd.concat(analysis)  # type: pd.DataFrame\n        # log metrics\n        analysis_dict = flatten_dict(analysis_df[\"risk\"].unstack().T.to_dict())\n        # print out results\n        pprint(f\"The following are analysis results of benchmark return({analysis_freq}).\")\n        pprint(risk_analysis(report_normal[\"bench\"], freq=analysis_freq))\n        pprint(f\"The following are analysis results of the excess return without cost({analysis_freq}).\")\n        pprint(analysis[\"excess_return_without_cost\"])\n        pprint(f\"The following are analysis results of the excess return with cost({analysis_freq}).\")\n        pprint(analysis[\"excess_return_with_cost\"])\n\n\nResult\n------\n\nThe backtest results are in the following form:\n\n.. 
code-block:: python\n\n                                                      risk\n    excess_return_without_cost mean               0.000605\n                               std                0.005481\n                               annualized_return  0.152373\n                               information_ratio  1.751319\n                               max_drawdown      -0.059055\n    excess_return_with_cost    mean               0.000410\n                               std                0.005478\n                               annualized_return  0.103265\n                               information_ratio  1.187411\n                               max_drawdown      -0.075024\n\n\n- `excess_return_without_cost`\n    - `mean`\n        The mean of the excess (abnormal) return series without cost.\n    - `std`\n        The `Standard Deviation` of the excess (abnormal) return series without cost.\n    - `annualized_return`\n        The `Annualized Rate` of the excess (abnormal) return without cost.\n    - `information_ratio`\n        The `Information Ratio` without cost. Please refer to `Information Ratio – IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.\n    - `max_drawdown`\n        The `Maximum Drawdown` of the `CAR` (cumulative abnormal return) without cost. Please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.\n\n- `excess_return_with_cost`\n    - `mean`\n        The mean of the excess (abnormal) return series with cost.\n    - `std`\n        The `Standard Deviation` of the excess (abnormal) return series with cost.\n    - `annualized_return`\n        The `Annualized Rate` of the excess (abnormal) return with cost.\n    - `information_ratio`\n        The `Information Ratio` with cost. 
Please refer to `Information Ratio – IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.\n    - `max_drawdown`\n        The `Maximum Drawdown` of the `CAR` (cumulative abnormal return) with cost. Please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.\n\n\nReference\n=========\nTo learn more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.\n"
  },
  {
    "path": "docs/component/workflow.rst",
    "content": ".. _workflow:\n\n=============================\nWorkflow: Workflow Management\n=============================\n.. currentmodule:: qlib\n\nIntroduction\n============\n\nThe components in `Qlib Framework <../introduction/introduction.html#framework>`_ are designed in a loosely-coupled way. Users can build their own quant research workflow with these components, as in this `Example <https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py>`_.\n\n\nIn addition, ``Qlib`` provides a more user-friendly interface named ``qrun`` to automatically run the whole workflow defined by a configuration. Running the whole workflow is called an `execution`.\nWith ``qrun``, users can easily start an `execution`, which includes the following steps:\n\n- Data\n    - Loading\n    - Processing\n    - Slicing\n- Model\n    - Training and inference\n    - Saving & loading\n- Evaluation\n    - Forecast signal analysis\n    - Backtest\n\nFor each `execution`, ``Qlib`` has a complete system for tracking all the information as well as the artifacts generated during the training, inference, and evaluation phases. For more information about how ``Qlib`` handles this, please refer to the related document: `Recorder: Experiment Management <../component/recorder.html>`_.\n\nComplete Example\n================\n\nBefore getting into details, here is a complete example of ``qrun``, which defines the workflow of typical quant research.\nBelow is a typical config file of ``qrun``.\n\n.. 
code-block:: YAML\n\n    qlib_init:\n        provider_uri: \"~/.qlib/qlib_data/cn_data\"\n        region: cn\n    market: &market csi300\n    benchmark: &benchmark SH000300\n    data_handler_config: &data_handler_config\n        start_time: 2008-01-01\n        end_time: 2020-08-01\n        fit_start_time: 2008-01-01\n        fit_end_time: 2014-12-31\n        instruments: *market\n    port_analysis_config: &port_analysis_config\n        strategy:\n            class: TopkDropoutStrategy\n            module_path: qlib.contrib.strategy\n            kwargs:\n                topk: 50\n                n_drop: 5\n                signal: <PRED>\n        backtest:\n            start_time: 2017-01-01\n            end_time: 2020-08-01\n            account: 100000000\n            benchmark: *benchmark\n            exchange_kwargs:\n                limit_threshold: 0.095\n                deal_price: close\n                open_cost: 0.0005\n                close_cost: 0.0015\n                min_cost: 5\n    task:\n        model:\n            class: LGBModel\n            module_path: qlib.contrib.model.gbdt\n            kwargs:\n                loss: mse\n                colsample_bytree: 0.8879\n                learning_rate: 0.0421\n                subsample: 0.8789\n                lambda_l1: 205.6999\n                lambda_l2: 580.9768\n                max_depth: 8\n                num_leaves: 210\n                num_threads: 20\n        dataset:\n            class: DatasetH\n            module_path: qlib.data.dataset\n            kwargs:\n                handler:\n                    class: Alpha158\n                    module_path: qlib.contrib.data.handler\n                    kwargs: *data_handler_config\n                segments:\n                    train: [2008-01-01, 2014-12-31]\n                    valid: [2015-01-01, 2016-12-31]\n                    test: [2017-01-01, 2020-08-01]\n        record:\n            - class: SignalRecord\n              
module_path: qlib.workflow.record_temp\n              kwargs: {}\n            - class: PortAnaRecord\n              module_path: qlib.workflow.record_temp\n              kwargs:\n                  config: *port_analysis_config\n\nAfter saving the config into `configuration.yaml`, users can start the workflow and test their ideas with the single command below.\n\n.. code-block:: bash\n\n    qrun configuration.yaml\n\nIf users want to use ``qrun`` in debug mode, please use the following command:\n\n.. code-block:: bash\n\n    python -m pdb qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml\n\n.. note::\n\n    `qrun` will be placed in your $PATH directory when installing ``Qlib``.\n\n.. note::\n\n    The symbol `&` in a `yaml` file defines an anchor for a field, and `*` refers back to that anchor, which is useful when other fields include this parameter as part of their value. Taking the configuration file above as an example, users can directly change the value of `market` and `benchmark` without traversing the entire configuration file.\n\n\nConfiguration File\n==================\n\nLet's get into the details of ``qrun`` in this section.\nBefore using ``qrun``, users need to prepare a configuration file. The following content shows how to prepare each part of the configuration file.\n\nThe design logic of the configuration file is very simple. It predefines fixed workflows and provides this yaml interface for users to define how to initialize each component.\nIt follows the design of `init_instance_by_config <https://github.com/microsoft/qlib/blob/2aee9e0145decc3e71def70909639b5e5a6f4b58/qlib/utils/__init__.py#L264>`_. It defines the initialization of each component of Qlib, which typically includes the class and the initialization arguments.\n\nFor example, the following yaml and code are equivalent.\n\n.. 
code-block:: YAML\n\n    model:\n        class: LGBModel\n        module_path: qlib.contrib.model.gbdt\n        kwargs:\n            loss: mse\n            colsample_bytree: 0.8879\n            learning_rate: 0.0421\n            subsample: 0.8789\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            max_depth: 8\n            num_leaves: 210\n            num_threads: 20\n\n\n.. code-block:: python\n\n        from qlib.contrib.model.gbdt import LGBModel\n        kwargs = {\n            \"loss\": \"mse\",\n            \"colsample_bytree\": 0.8879,\n            \"learning_rate\": 0.0421,\n            \"subsample\": 0.8789,\n            \"lambda_l1\": 205.6999,\n            \"lambda_l2\": 580.9768,\n            \"max_depth\": 8,\n            \"num_leaves\": 210,\n            \"num_threads\": 20,\n        }\n        # unpack the dict so that each entry is passed as a keyword argument\n        model = LGBModel(**kwargs)\n\n\nQlib Init Section\n-----------------\n\nFirst, the configuration file needs to contain several basic parameters under the `qlib_init` field, which will be used for qlib initialization.\n\n.. code-block:: YAML\n\n    qlib_init:\n        provider_uri: \"~/.qlib/qlib_data/cn_data\"\n        region: cn\n\nThe meaning of each field is as follows:\n\n- `provider_uri`\n    Type: str. The URI of the Qlib data. For example, it could be the location where the data downloaded by ``get_data.py`` is stored.\n\n- `region`\n    - If `region` == \"us\", ``Qlib`` will be initialized in US-stock mode.\n    - If `region` == \"cn\", ``Qlib`` will be initialized in China-stock mode.\n\n    .. note::\n\n        The value of `region` should be aligned with the data stored in `provider_uri`.\n\n\nTask Section\n------------\n\nThe `task` field in the configuration corresponds to a `task`, which contains the parameters of three different subsections: `Model`, `Dataset` and `Record`.\n\nModel Section\n~~~~~~~~~~~~~\n\nIn the `task` field, the `model` section describes the parameters of the model to be used for training and inference. 
For more information about the base ``Model`` class, please refer to `Qlib Model <../component/model.html>`_.\n\n.. code-block:: YAML\n\n    model:\n        class: LGBModel\n        module_path: qlib.contrib.model.gbdt\n        kwargs:\n            loss: mse\n            colsample_bytree: 0.8879\n            learning_rate: 0.0421\n            subsample: 0.8789\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            max_depth: 8\n            num_leaves: 210\n            num_threads: 20\n\nThe meaning of each field is as follows:\n\n- `class`\n    Type: str. The name of the model class.\n\n- `module_path`\n    Type: str. The path of the module in qlib that contains the model class.\n\n- `kwargs`\n    The keyword arguments for the model. Please refer to the specific model implementation for more information: `models <https://github.com/microsoft/qlib/blob/main/qlib/contrib/model>`_.\n\n.. note::\n\n    ``Qlib`` provides a util named ``init_instance_by_config`` to initialize any class inside ``Qlib`` with a configuration that includes the fields `class`, `module_path` and `kwargs`.\n\nDataset Section\n~~~~~~~~~~~~~~~\n\nThe `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well as those for the ``DataHandler`` module. For more information about the ``Dataset`` module, please refer to `Qlib Data <../component/data.html#dataset>`_.\n\nThe keyword-argument configuration of the ``DataHandler`` is as follows:\n\n.. 
code-block:: YAML\n\n    data_handler_config: &data_handler_config\n        start_time: 2008-01-01\n        end_time: 2020-08-01\n        fit_start_time: 2008-01-01\n        fit_end_time: 2014-12-31\n        instruments: *market\n\nUsers can refer to the `DataHandler <../component/data.html#datahandler>`_ documentation for more information about the meaning of each field in the configuration.\n\nHere is the configuration for the ``Dataset`` module, which takes care of data preprocessing and slicing during the training and testing phases.\n\n.. code-block:: YAML\n\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n\nRecord Section\n~~~~~~~~~~~~~~\n\nThe `record` field describes the parameters of the ``Record`` module in ``Qlib``. ``Record`` is responsible for tracking the training process and results, such as the `Information Coefficient (IC)` and `backtest`, in a standard format.\n\nThe following snippet is the configuration of the `backtest` and of the `strategy` used in it:\n\n.. 
code-block:: YAML\n\n    port_analysis_config: &port_analysis_config\n        strategy:\n            class: TopkDropoutStrategy\n            module_path: qlib.contrib.strategy\n            kwargs:\n                topk: 50\n                n_drop: 5\n                signal: <PRED>\n        backtest:\n            start_time: 2017-01-01\n            end_time: 2020-08-01\n            account: 100000000\n            benchmark: *benchmark\n            exchange_kwargs:\n                limit_threshold: 0.095\n                deal_price: close\n                open_cost: 0.0005\n                close_cost: 0.0015\n                min_cost: 5\n\nFor more information about the meaning of each field in the configuration of `strategy` and `backtest`, users can look up the documents: `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.\n\nHere are the configuration details of different `Record Templates`, such as ``SignalRecord`` and ``PortAnaRecord``:\n\n.. code-block:: YAML\n\n    record:\n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: {}\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs:\n            config: *port_analysis_config\n\nFor more information about the ``Record`` module in ``Qlib``, users can refer to the related document: `Record <../component/recorder.html#record-template>`_.\n"
  },
  {
    "path": "docs/conf.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\n# QLib documentation build configuration file, created by\n# sphinx-quickstart on Wed Sep 27 15:16:05 2017.\n#\n# This file is execfile()d with the current directory set to its\n# containing dir.\n#\n# Note that not all possible configuration values are present in this\n# autogenerated file.\n#\n# All configuration values have a default; values that are commented out\n# serve to show the default.\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport sys\n\nfrom importlib.metadata import version as ver\n\n# -- General configuration ------------------------------------------------\n\n# If your documentation needs a minimal Sphinx version, state it here.\n#\n# needs_sphinx = '1.0'\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.todo\",\n    \"sphinx.ext.mathjax\",\n    \"sphinx.ext.napoleon\",\n]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffixes as a list of strings:\n#\n# source_suffix = ['.rst', '.md']\nsource_suffix = \".rst\"\n\n# The master toctree document.\nmaster_doc = \"index\"\n\n\n# General information about the project.\nproject = \"QLib\"\ncopyright = \"Microsoft\"\nauthor = \"Microsoft\"\n\n# The version info for the project you're documenting, acts as replacement for\n# |version| and |release|, also used in various other places throughout the\n# built documents.\n#\n# The short X.Y version.\nversion = ver(\"pyqlib\")\n# The full version, including alpha/beta/rc tags.\nrelease = version\n\n# The language for content autogenerated by Sphinx. Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = \"en_US\"\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# These patterns also affect html_static_path and html_extra_path\nexclude_patterns = [\"_build\", \"Thumbs.db\", \".DS_Store\", \"hidden\"]\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = \"sphinx\"\n\n# If true, '()' will be appended to :func: etc. cross-reference text.\nadd_function_parentheses = False\n\n# If true, the current module name will be prepended to all description\n# unit titles (such as .. 
function::).\nadd_module_names = True\n\n# If true, `todo` and `todoList` produce output, else they produce nothing.\ntodo_include_todos = True\n\n\n# -- Options for HTML output ----------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = \"sphinx_rtd_theme\"\n\nhtml_logo = \"_static/img/logo/1.png\"\n\n\n# Theme options are theme-specific and customize the look and feel of a theme\n# further.  For a list of options available for each theme, see the\n# documentation.\n# html_context = {\n#     \"display_github\": False,\n#     \"last_updated\": True,\n#     \"commit\": True,\n#     \"github_user\": \"Microsoft\",\n#     \"github_repo\": \"QLib\",\n#     'github_version': 'master',\n#     'conf_py_path': '/docs/',\n\n# }\n#\nhtml_theme_options = {\n    \"logo_only\": True,\n    \"collapse_navigation\": False,\n    \"navigation_depth\": 4,\n}\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. 
They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\n# html_static_path = ['_static']\n\n# Custom sidebar templates, must be a dictionary that maps document names\n# to template names.\n#\n# This is required for the alabaster theme\n# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars\nhtml_sidebars = {\n    \"**\": [\n        \"about.html\",\n        \"navigation.html\",\n        \"relations.html\",  # needs 'show_related': True theme option to display\n        \"searchbox.html\",\n    ]\n}\n\n\n# -- Options for HTMLHelp output ------------------------------------------\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = \"qlibdoc\"\n\n\n# -- Options for LaTeX output ---------------------------------------------\n\nlatex_elements = {\n    # The paper size ('letterpaper' or 'a4paper').\n    #\n    # 'papersize': 'letterpaper',\n    # The font size ('10pt', '11pt' or '12pt').\n    #\n    # 'pointsize': '10pt',\n    # Additional stuff for the LaTeX preamble.\n    #\n    # 'preamble': '',\n    # Latex figure (float) alignment\n    #\n    # 'figure_align': 'htbp',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title,\n#  author, documentclass [howto, manual, or own class]).\nlatex_documents = [\n    (master_doc, \"qlib.tex\", \"QLib Documentation\", \"Microsoft\", \"manual\"),\n]\n\n\n# -- Options for manual page output ---------------------------------------\n\n# One entry per manual page. List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [(master_doc, \"qlib\", \"QLib Documentation\", [author], 1)]\n\n\n# -- Options for Texinfo output -------------------------------------------\n\n# Grouping the document tree into Texinfo files. 
List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n    (\n        master_doc,\n        \"QLib\",\n        \"QLib Documentation\",\n        author,\n        \"QLib\",\n        \"One line description of project.\",\n        \"Miscellaneous\",\n    ),\n]\n\n\n# -- Options for Epub output ----------------------------------------------\n\n# Bibliographic Dublin Core info.\nepub_title = project\nepub_author = author\nepub_publisher = author\nepub_copyright = copyright\n\n# The unique identifier of the text. This can be an ISBN number\n# or the project homepage.\n#\n# epub_identifier = ''\n\n# A unique identification for the text.\n#\n# epub_uid = ''\n\n# A list of files that should not be packed into the epub file.\nepub_exclude_files = [\"search.html\"]\n\n\nautodoc_member_order = \"bysource\"\nautodoc_default_options = {\n    \"members\": True,\n    \"member-order\": \"bysource\",\n    \"special-members\": \"__init__\",\n}\n"
  },
  {
    "path": "docs/developer/code_standard_and_dev_guide.rst",
    "content": ".. _code_standard:\n\n=============\nCode Standard\n=============\n\nDocstring\n=========\nPlease use the `Numpydoc Style <https://stackoverflow.com/a/24385103>`_.\n\nContinuous Integration\n======================\nContinuous Integration (CI) tools help you stick to the quality standards by running tests every time you push a new commit and reporting the results to a pull request.\n\nWhen you submit a pull request, you can check whether your code passes the CI tests in the \"check\" section at the bottom of the web page.\n\n1. Qlib will check the code format with black. The CI check will fail if your code does not follow Qlib's formatting standard (e.g. a common error is mixing spaces and tabs).\n\n   You can fix the formatting by running the following commands.\n\n.. code-block:: bash\n\n    pip install black\n    python -m black . -l 120\n\n\n2. Qlib will check your code style with pylint. The checking command is implemented in the `GitHub Actions workflow <https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L66>`__.\n   Sometimes pylint's restrictions are not reasonable for a particular line. You can disable specific checks inline like this:\n\n.. code-block:: python\n\n    return -ICLoss()(pred, target, index)  # pylint: disable=E1130\n\n\n3. Qlib will check your code style with flake8. The checking command is implemented in the `GitHub Actions workflow <https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L73>`__.\n\n   You can run the same check locally with the following command.\n\n.. code-block:: bash\n\n    flake8 --ignore E501,F541,E402,F401,W503,E741,E266,E203,E302,E731,E262,F523,F821,F811,F841,E713,E265,W291,E712,E722,W293 qlib\n\n\n4.
Qlib has integrated pre-commit, which makes it easier for developers to format their code.\n\n   Just run the following two commands, and your code will be automatically formatted and checked with black and flake8 whenever you run git commit.\n\n.. code-block:: bash\n\n    pip install -e \".[dev]\"\n    pre-commit install\n\n\n=================================\nDevelopment Guidance\n=================================\n\nAs a developer, you often want to make changes to `Qlib` and have them take effect directly in your environment without reinstalling it. You can install `Qlib` in editable mode with the following command.\nThe `[dev]` option also installs packages that are useful when developing `Qlib` (e.g. pytest, sphinx).\n\n.. code-block:: bash\n\n    pip install -e \".[dev]\""
  },
  {
    "path": "docs/developer/how_to_build_image.rst",
    "content": ".. _docker_image:\n\n==================\nBuild Docker Image\n==================\n\nDockerfile\n==========\n\nThere is a **Dockerfile** in the root directory of the project from which you can build the docker image. The Dockerfile supports two build methods.\nWhen executing the build command, use the ``--build-arg IS_STABLE=...`` parameter to choose the image version. ``IS_STABLE`` defaults to ``yes``, which builds the ``stable`` version of the qlib image.\n\n1. For the ``stable`` version, use ``pip install pyqlib`` to build the qlib image.\n\n.. code-block:: bash\n\n    docker build --build-arg IS_STABLE=yes -t <image name> -f ./Dockerfile .\n\nBecause ``IS_STABLE`` defaults to ``yes``, the following command is equivalent:\n\n.. code-block:: bash\n\n    docker build -t <image name> -f ./Dockerfile .\n\n2. For the ``nightly`` version, use the current source code to build the qlib image.\n\n.. code-block:: bash\n\n    docker build --build-arg IS_STABLE=no -t <image name> -f ./Dockerfile .\n\nAuto build of qlib images\n=========================\n\n1. There is a **build_docker_image.sh** file in the root directory of the project, which can automatically build docker images and optionally upload them to your docker hub repository (configuration required).\n\n.. code-block:: bash\n\n    sh build_docker_image.sh\n    >>> Do you want to build the nightly version of the qlib image? (default is stable) (yes/no):\n    >>> Is it uploaded to docker hub? (default is no) (yes/no):\n\n2. If you want to upload the built image to your docker hub repository, you need to edit your **build_docker_image.sh** file first, fill in ``docker_user`` in the file, and then execute this file.\n\nHow to use qlib images\n======================\n1. Start a new Docker container\n\n.. code-block:: bash\n\n    docker run -it --name <container name> -v <Mounted local directory>:/app <image name>\n\n2. At this point you are in the docker environment and can run the qlib scripts. An example:\n\n..
code-block:: bash\n\n    python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn\n    python qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml\n\n3. Exit the container\n\n.. code-block:: bash\n\n    exit\n\n4. Restart the container\n\n.. code-block:: bash\n\n    docker start -i -a <container name>\n\n5. Stop the container\n\n.. code-block:: bash\n\n    docker stop <container name>\n\n6. Delete the container\n\n.. code-block:: bash\n\n    docker rm <container name>\n\n7. For more information on using docker, see the `docker documentation <https://docs.docker.com/reference/cli/docker/>`_.\n"
  },
  {
    "path": "docs/hidden/client.rst",
    "content": ".. _client:\n\nQlib Client-Server Framework\n============================\n\n.. currentmodule:: qlib\n\nIntroduction\n------------\nThe client-server framework is designed to solve the following problems:\n\n- Manage the data in a centralized way. Users don't have to manage data of different versions.\n- Reduce the amount of cache to be generated.\n- Make the data accessible remotely.\n\nWe maintain a server that provides the data.\n\nTo use the client-server framework, you have to initialize qlib with a specific config.\nHere is a typical initialization process.\n\nCommonly used parameters of qlib ``init`` are listed below. ``nfs-common`` must be installed on the machine where the client runs; to install it, execute ``sudo apt install nfs-common``:\n    - ``provider_uri``: nfs-server path; the format is ``host:data_dir``, for example: ``172.23.233.89:/data2/gaochao/sync_qlib/qlib``. When working offline, it can be a local data directory\n    - ``mount_path``: local data directory, ``provider_uri`` will be mounted to this directory\n    - ``auto_mount``: whether to automatically mount ``provider_uri`` to ``mount_path`` during qlib ``init``; you can also mount it manually: ``sudo mount.nfs <provider_uri> <mount_path>``. If on PAI, it is recommended to set ``auto_mount=True``\n    - ``flask_server``: data service host; if you are on the intranet, you can use the default host: 172.23.233.89\n    - ``flask_port``: data service port\n\n\nIf running on the 10.150.144.153 or 10.150.144.154 server, it's recommended to use the following code to ``init`` qlib:\n\n..
code-block:: python\n\n   >>> import qlib\n   >>> qlib.init(auto_mount=False, mount_path='/data/csdesign/qlib')\n   >>> from qlib.data import D\n   >>> D.features(['SH600000'], ['$close'], start_time='20080101', end_time='20090101').head()\n    [39336:MainThread](2019-05-28 21:35:42,800) INFO - Initialization - [__init__.py:16] - default_conf: client.\n    [39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:54] - qlib successfully initialized based on client settings.\n    [39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:56] - provider_uri=172.23.233.89:/data2/gaochao/sync_qlib/qlib\n    [39336:Thread-68](2019-05-28 21:35:42,809) INFO - Client - [client.py:28] - Connect to server ws://172.23.233.89:9710\n    [39336:Thread-72](2019-05-28 21:35:43,489) INFO - Client - [client.py:31] - Disconnect from server!\n    Opening /data/csdesign/qlib/cache/d239a3b191daa9a5b1b19a59beb47b33 in read-only mode\n    Out[5]:\n                               $close\n    instrument datetime\n    SH600000   2008-01-02  119.079704\n               2008-01-03  113.120125\n               2008-01-04  117.878860\n               2008-01-07  124.505539\n               2008-01-08  125.395004\n\n\nIf running on PAI, it's recommended to use the following code to ``init`` qlib:\n\n.. 
code-block:: python\n\n   >>> import qlib\n   >>> qlib.init(auto_mount=True, mount_path='/data/csdesign/qlib', provider_uri='172.23.233.89:/data2/gaochao/sync_qlib/qlib')\n   >>> from qlib.data import D\n   >>> D.features(['SH600000'], ['$close'], start_time='20080101', end_time='20090101').head()\n    [39336:MainThread](2019-05-28 21:35:42,800) INFO - Initialization - [__init__.py:16] - default_conf: client.\n    [39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:54] - qlib successfully initialized based on client settings.\n    [39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:56] - provider_uri=172.23.233.89:/data2/gaochao/sync_qlib/qlib\n    [39336:Thread-68](2019-05-28 21:35:42,809) INFO - Client - [client.py:28] - Connect to server ws://172.23.233.89:9710\n    [39336:Thread-72](2019-05-28 21:35:43,489) INFO - Client - [client.py:31] - Disconnect from server!\n    Opening /data/csdesign/qlib/cache/d239a3b191daa9a5b1b19a59beb47b33 in read-only mode\n    Out[5]:\n                               $close\n    instrument datetime\n    SH600000   2008-01-02  119.079704\n               2008-01-03  113.120125\n               2008-01-04  117.878860\n               2008-01-07  124.505539\n               2008-01-08  125.395004\n\n\nIf running on Windows, enable the **NFS** feature and set a valid **mount_path** as described below; then it's recommended to use the following code to ``init`` qlib:\n\n1. Enable the NFS feature on Windows:\n\n   * Open ``Programs and Features``.\n   * Click ``Turn Windows features on or off``.\n   * Scroll down, check the option ``Services for NFS``, then click OK.\n\n   Reference: https://graspingtech.com/mount-nfs-share-windows-10/\n\n2. Configure a valid ``mount_path``:\n\n   * On Windows, ``mount_path`` must be a drive letter that is not already in use, rather than a directory path.\n\n     * Valid examples: `H`, `i`, ...\n     * Invalid examples: `C`, `C:/user/name`, `qlib_data`, ...\n\n..
code-block:: python\n\n   >>> import qlib\n   >>> qlib.init(auto_mount=True, mount_path='H', provider_uri='172.23.233.89:/data2/gaochao/sync_qlib/qlib')\n   >>> from qlib.data import D\n   >>> D.features(['SH600000'], ['$close'], start_time='20080101', end_time='20090101').head()\n    [39336:MainThread](2019-05-28 21:35:42,800) INFO - Initialization - [__init__.py:16] - default_conf: client.\n    [39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:54] - qlib successfully initialized based on client settings.\n    [39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:56] - provider_uri=172.23.233.89:/data2/gaochao/sync_qlib/qlib\n    [39336:Thread-68](2019-05-28 21:35:42,809) INFO - Client - [client.py:28] - Connect to server ws://172.23.233.89:9710\n    [39336:Thread-72](2019-05-28 21:35:43,489) INFO - Client - [client.py:31] - Disconnect from server!\n    Opening /data/csdesign/qlib/cache/d239a3b191daa9a5b1b19a59beb47b33 in read-only mode\n    Out[5]:\n                               $close\n    instrument datetime\n    SH600000   2008-01-02  119.079704\n               2008-01-03  113.120125\n               2008-01-04  117.878860\n               2008-01-07  124.505539\n               2008-01-08  125.395004\n\n\nThe client mounts the data in `provider_uri` on `mount_path`. The server and client then communicate via Flask, and the data is transported through this NFS mount.\n\n\nIf you have local qlib data files and want to use them offline instead of through the client-server framework,\nthis is also possible with a specific config.\nYou can create such a config, e.g. `client_config_local.yml`:\n\n..
code-block:: YAML\n\n    provider_uri: /data/csdesign/qlib\n    calendar_provider: 'LocalCalendarProvider'\n    instrument_provider: 'LocalInstrumentProvider'\n    feature_provider: 'LocalFeatureProvider'\n    expression_provider: 'LocalExpressionProvider'\n    dataset_provider: 'LocalDatasetProvider'\n    provider: 'LocalProvider'\n    dataset_cache: 'SimpleDatasetCache'\n    local_cache_path: '~/.cache/qlib/'\n\n`provider_uri` is the directory of your local data.\n\n.. code-block:: python\n\n   >>> import qlib\n   >>> qlib.init_from_yaml_conf('client_config_local.yml')\n   >>> from qlib.data import D\n   >>> D.features(['SH600001'], ['$close'], start_time='20180101', end_time='20190101').head()\n    [21232:MainThread](2019-05-29 10:16:05,066) INFO - Initialization - [__init__.py:16] - default_conf: client.\n    [21232:MainThread](2019-05-29 10:16:05,066) INFO - Initialization - [__init__.py:54] - qlib successfully initialized based on client settings.\n    [21232:MainThread](2019-05-29 10:16:05,067) INFO - Initialization - [__init__.py:56] - provider_uri=/data/csdesign/qlib\n    Out[9]:\n                              $close\n    instrument datetime\n    SH600001   2008-01-02  21.082111\n               2008-01-03  23.195362\n               2008-01-04  23.874615\n               2008-01-07  24.880930\n               2008-01-08  24.277143\n\nLimitations\n-----------\n1. The following APIs under the client-server module may not be as fast as the older offline API:\n\n   - Cal.calendar\n   - Inst.list_instruments\n\n2. Rolling operation expressions with parameter `0` may not be updated correctly under the client-server framework.\n\nAPI\n---\n\nThe client is based on `python-socketio <https://python-socketio.readthedocs.io>`_, a framework that provides a WebSocket client for Python. The client only sends requests and receives results; it does not perform any calculation itself.\n\nClass\n~~~~~\n\n.. automodule:: qlib.data.client\n"
  },
  {
    "path": "docs/hidden/online.rst",
    "content": ".. _online:\n\nOnline\n======\n.. currentmodule:: qlib\n\nIntroduction\n------------\n\nWelcome to Online. This module simulates what real trading would look like with your model and strategy.\n\nJust like Estimator and other modules in Qlib, you need to determine the parameters through a configuration file,\nand in this module, you need to add an account in a folder to run the simulation. Then, on each subsequent day,\nthis module uses the newest information to trade for your account,\nand the performance can be viewed at any time through the defined API.\n\nEach account goes through the following processes, where ‘pred_date’ is the date on which you predict the target\npositions after trading, and ‘trade_date’ is the date on which the trading happens.\n\n- Generate the order list (pred_date)\n- Execute the order list (trade_date)\n- Update account (trade_date)\n\nAlternatively, you can simply create an account and use this module to test its performance over a period.\n\n- Simulate (start_date, end_date)\n\nThis module needs to save your account in a folder: the model and strategy are saved as pickle files,\nand the position and report are saved as Excel files.\nThe file structure is described in fileStruct_.\n\n\nExample\n-------\n\nLet's take an example.\n\n.. note:: Make sure you have the latest version of `qlib` installed.\n\nIf you want to use the models and data provided by `qlib`, you only need to do the following.\n\nFirst, write a simple configuration file as follows:\n\n..
code-block:: YAML\n\n    strategy:\n        class: TopkAmountStrategy\n        module_path: qlib.contrib.strategy\n        args:\n            market: csi500\n            trade_freq: 5\n\n    model:\n        class: ScoreFileModel\n        module_path: qlib.contrib.online.online_model\n        args:\n            loss: mse\n            model_path: ./model.bin\n\n    init_cash: 1000000000\n\nWe can then use this command to create a folder and trade from 2017-01-01 to 2018-08-01.\n\n.. code-block:: bash\n\n    online simulate -id v-test -config ./config/config.yaml -exchange_config ./config/exchange.yaml -start 2017-01-01 -end 2018-08-01 -path ./user_data/\n\nThe start date (2017-01-01) is the add date of the user, which is also the first predict date,\nand the end date (2018-08-01) is the last trade date. You can use the \"`online generate -date 2018-08-02...`\"\ncommand to continue generating the order list at the next trading date.\n\nIf your account was saved in \"./user_data/\", you can see the performance of your account compared to a benchmark by running:\n\n..
code-block:: bash\n\n    >> online show -id v-test -path ./user_data/ -bench SH000905\n\n    ...\n    Result of porfolio:\n                                                      risk\n    excess_return_without_cost mean               0.000605\n                               std                0.005481\n                               annualized_return  0.152373\n                               information_ratio  1.751319\n                               max_drawdown      -0.059055\n    excess_return_with_cost    mean               0.000410\n                               std                0.005478\n                               annualized_return  0.103265\n                               information_ratio  1.187411\n                               max_drawdown      -0.075024\n\n\nHere 'SH000905' represents csi500 and 'SH000300' represents csi300.\n\nManage your account\n-------------------\n\nAny account processed by `online` should be saved in a folder. You can use the following commands\nto manage your accounts.\n\n- add a new account\n    This adds a new account with user_id='v-test' and add_date='2019-10-15' in ./user_data.\n\n    .. code-block:: bash\n\n        >> online add_user -id {user_id} -config {config_file} -path {folder_path} -date {add_date}\n        >> online add_user -id v-test -config config.yaml -path ./user_data/ -date 2019-10-15\n\n- remove an account\n    .. code-block:: bash\n\n        >> online remove_user -id {user_id} -path {folder_path}\n        >> online remove_user -id v-test -path ./user_data/\n\n- show the performance\n    Here `benchmark` is the baseline that your account is compared against.\n\n    ..
code-block:: bash\n\n        >> online show -id {user_id} -path {folder_path} -bench {benchmark}\n        >> online show -id v-test -path ./user_data/ -bench SH000905\n\nThe default value of the 'date' parameter in all the commands below is the trade date\n(today, if today is a trading date and the information has been updated in `qlib`).\n\nThe 'generate' and 'update' commands check whether the input date is valid. The following 3 processes should\nbe called on each trading date.\n\n- generate the order list\n    Generate the order list for the trade date and save it in {folder_path}/{user_id}/temp/ as a JSON file.\n\n    .. code-block:: bash\n\n        >> online generate -date {date} -path {folder_path}\n        >> online generate -date 2019-10-16 -path ./user_data/\n\n- execute the order list\n    Execute the order list and save the transaction results in {folder_path}/{user_id}/temp/ for the trade date.\n\n    .. code-block:: bash\n\n        >> online execute -date {date} -exchange_config {exchange_config_path} -path {folder_path}\n        >> online execute -date 2019-10-16 -exchange_config ./config/exchange.yaml -path ./user_data/\n\n    A simple exchange config file can look like this:\n\n    .. code-block:: yaml\n\n        open_cost: 0.003\n        close_cost: 0.003\n        limit_threshold: 0.095\n        deal_price: vwap\n\n\n- update accounts\n    Update the accounts in \"{folder_path}/\" for the trade date.\n\n    .. code-block:: bash\n\n        >> online update -date {date} -path {folder_path}\n        >> online update -date 2019-10-16 -path ./user_data/\n\nAPI\n---\n\nAll these operations are defined in `qlib.contrib.online.operator`.\n\n.. automodule:: qlib.contrib.online.operator\n\n.. _fileStruct:\n\nFile structure\n--------------\n\n'user_data' indicates the root folder.\nIn the tree below, entries that contain nested items are folders; all other entries are files.\n\n..
code-block:: yaml\n\n    {user_folder}\n    │   users.csv: (Init date for each user)\n    │\n    └───{user_id1}: (users' sub-folder to save their data)\n    │   │   position.xlsx\n    │   │   report.csv\n    │   │   model_{user_id1}.pickle\n    │   │   strategy_{user_id1}.pickle\n    │   │\n    │   └───score\n    │   │   └───{YYYY}\n    │   │       └───{MM}\n    │   │           │   score_{YYYY-MM-DD}.csv\n    │   │\n    │   └───trade\n    │       └───{YYYY}\n    │           └───{MM}\n    │               │   orderlist_{YYYY-MM-DD}.json\n    │               │   transaction_{YYYY-MM-DD}.csv\n    │\n    └───{user_id2}\n    │   │   position.xlsx\n    │   │   report.csv\n    │   │   model_{user_id2}.pickle\n    │   │   strategy_{user_id2}.pickle\n    │   │\n    │   └───score\n    │   └───trade\n    ....\n\n\nConfiguration file\n------------------\n\nThe configuration file used in `online` should contain the model and strategy information.\n\nAbout the model\n~~~~~~~~~~~~~~~\n\nFirst, your configuration file needs a field for the model;\nthis field and its contents determine the model used to generate scores at the predict date.\n\nHere are two example configurations of ScoreFileModel, a model that reads a score file and returns the score at the trade date.\n\n.. code-block:: YAML\n\n     model:\n        class: ScoreFileModel\n        module_path: qlib.contrib.online.OnlineModel\n        args:\n            loss: mse\n\n.. code-block:: YAML\n\n     model:\n        class: ScoreFileModel\n        module_path: qlib.contrib.online.OnlineModel\n        args:\n            score_path: <your score path>\n\nIf your model is not one of the above, you need to implement it yourself.\nYour model should be a subclass of the models defined in 'qlib.contrib.model'.
And it must\nalso implement the 2 methods used by the `online` module.\n\n\nAbout the strategy\n~~~~~~~~~~~~~~~~~~\n\nYou need to define the strategy used to generate the order list at the predict date.\n\nHere is an example of a TopkDropoutStrategy configuration:\n\n.. code-block:: YAML\n\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy.strategy\n        args:\n            topk: 100\n            n_drop: 10\n\nGenerated files\n---------------\n\nThe 'online_generate' command creates the order list at {folder_path}/{user_id}/temp/,\nnamed orderlist_{YYYY-MM-DD}.json, where YYYY-MM-DD is the date on which those orders are to be executed.\n\nThe format of the JSON file is as follows:\n\n.. code-block:: python\n\n    {\n        'sell': {\n            '$stock_id1': '$amount1',\n            '$stock_id2': '$amount2',\n            ...\n        },\n        'buy': {\n            '$stock_id1': '$amount1',\n            '$stock_id2': '$amount2',\n            ...\n        }\n    }\n\nThen, after the order list is executed (either by 'online_execute' or other executors), a transaction file\nis also created at {folder_path}/{user_id}/temp/.\n"
  },
  {
    "path": "docs/hidden/tuner.rst",
    "content": ".. _tuner:\r\n\r\nTuner\r\n=====\r\n.. currentmodule:: qlib\r\n\r\nIntroduction\r\n------------\r\n\r\nWelcome to Tuner. This document assumes that you can already use Estimator proficiently and correctly.\r\n\r\nWith Tuner, you can find the optimal hyper-parameters and combinations of models, trainers, strategies and data labels.\r\n\r\nThe usage of the `tuner` program is similar to `estimator`: you need to provide the URL of the configuration file.\r\nThe `tuner` does the following:\r\n\r\n- Construct the tuner pipeline\r\n- Search for and save the best hyper-parameters of one tuner\r\n- Move on to the next tuner in the pipeline\r\n- Save the global best hyper-parameters and combination\r\n\r\nEach tuner consists of one combination of modules, and its goal is to search for the optimal hyper-parameters of this combination.\r\nThe pipeline consists of different tuners and aims to find the optimal combination of modules.\r\n\r\nThe results are printed on screen and saved to files; you can check them in your experiment's saved files.\r\n\r\nExample\r\n~~~~~~~\r\n\r\nLet's see an example.\r\n\r\nFirst, make sure you have the latest version of `qlib` installed.\r\n\r\nThen, you need to provide a configuration to set up the experiment.\r\nHere is a simple configuration example:\r\n\r\n..
code-block:: YAML\r\n\r\n    experiment:\r\n        name: tuner_experiment\r\n        tuner_class: QLibTuner\r\n    qlib_client:\r\n        auto_mount: False\r\n        logging_level: INFO\r\n    optimization_criteria:\r\n        report_type: model\r\n        report_factor: model_score\r\n        optim_type: max\r\n    tuner_pipeline:\r\n      -\r\n        model:\r\n            class: SomeModel\r\n            space: SomeModelSpace\r\n        trainer:\r\n            class: RollingTrainer\r\n        strategy:\r\n            class: TopkAmountStrategy\r\n            space: TopkAmountStrategySpace\r\n        max_evals: 2\r\n\r\n    time_period:\r\n        rolling_period: 360\r\n        train_start_date: 2005-01-01\r\n        train_end_date: 2014-12-31\r\n        validate_start_date: 2015-01-01\r\n        validate_end_date: 2016-06-30\r\n        test_start_date: 2016-07-01\r\n        test_end_date: 2018-04-30\r\n    data:\r\n        class: ALPHA360\r\n        provider_uri: /data/qlib\r\n        args:\r\n            start_date: 2005-01-01\r\n            end_date: 2018-04-30\r\n            dropna_label: True\r\n            dropna_feature: True\r\n        filter:\r\n            market: csi500\r\n            filter_pipeline:\r\n              -\r\n                class: NameDFilter\r\n                module_path: qlib.data.filter\r\n                args:\r\n                  name_rule_re: S(?!Z3)\r\n                  fstart_time: 2018-01-01\r\n                  fend_time: 2018-12-11\r\n              -\r\n                class: ExpressionDFilter\r\n                module_path: qlib.data.filter\r\n                args:\r\n                  rule_expression: $open/$factor<=45\r\n                  fstart_time: 2018-01-01\r\n                  fend_time: 2018-12-11\r\n    backtest:\r\n        normal_backtest_args:\r\n            limit_threshold: 0.095\r\n            account: 500000\r\n            benchmark: SH000905\r\n            deal_price: vwap\r\n        
long_short_backtest_args:\r\n            topk: 50\r\n\r\nNext, we run the following command, and you can see:\r\n\r\n.. code-block:: bash\r\n\r\n    ~/v-yindzh/Qlib/cfg$ tuner -c tuner_config.yaml\r\n\r\n    Searching params: {'model_space': {'colsample_bytree': 0.8870905643607678, 'lambda_l1': 472.3188735122233, 'lambda_l2': 92.75390994877243, 'learning_rate': 0.09741751430635413, 'loss': 'mse', 'max_depth': 8, 'num_leaves': 160, 'num_threads': 20, 'subsample': 0.7536051584789751}, 'strategy_space': {'buffer_margin': 250, 'topk': 40}}\r\n    ...\r\n    (Estimator experiment screen log)\r\n    ...\r\n    Searching params: {'model_space': {'colsample_bytree': 0.6667379039007301, 'lambda_l1': 382.10698024977904, 'lambda_l2': 117.02506488151757, 'learning_rate': 0.18514539615228137, 'loss': 'mse', 'max_depth': 6, 'num_leaves': 200, 'num_threads': 12, 'subsample': 0.9449255686969292}, 'strategy_space': {'buffer_margin': 200, 'topk': 30}}\r\n    ...\r\n    (Estimator experiment screen log)\r\n    ...\r\n    Local best params: {'model_space': {'colsample_bytree': 0.6667379039007301, 'lambda_l1': 382.10698024977904, 'lambda_l2': 117.02506488151757, 'learning_rate': 0.18514539615228137, 'loss': 'mse', 'max_depth': 6, 'num_leaves': 200, 'num_threads': 12, 'subsample': 0.9449255686969292}, 'strategy_space': {'buffer_margin': 200, 'topk': 30}}\r\n    Time cost: 489.87220 | Finished searching best parameters in Tuner 0.\r\n    Time cost: 0.00069 | Finished saving local best tuner parameters to: tuner_experiment/estimator_experiment/estimator_experiment_0/local_best_params.json .\r\n    Searching params: {'data_label_space': {'labels': ('Ref($vwap, -2)/Ref($vwap, -1) - 2',)}, 'model_space': {'input_dim': 158, 'lr': 0.001, 'lr_decay': 0.9100529502185579, 'lr_decay_steps': 162.48901403763966, 'optimizer': 'gd', 'output_dim': 1}, 'strategy_space': {'buffer_margin': 300, 'topk': 35}}\r\n    ...\r\n    (Estimator experiment screen log)\r\n    ...\r\n    Searching params: 
{'data_label_space': {'labels': ('Ref($vwap, -2)/Ref($vwap, -1) - 1',)}, 'model_space': {'input_dim': 158, 'lr': 0.1, 'lr_decay': 0.9882802970847494, 'lr_decay_steps': 164.76742865207729, 'optimizer': 'adam', 'output_dim': 1}, 'strategy_space': {'buffer_margin': 250, 'topk': 35}}\r\n    ...\r\n    (Estimator experiment screen log)\r\n    ...\r\n    Local best params: {'data_label_space': {'labels': ('Ref($vwap, -2)/Ref($vwap, -1) - 1',)}, 'model_space': {'input_dim': 158, 'lr': 0.1, 'lr_decay': 0.9882802970847494, 'lr_decay_steps': 164.76742865207729, 'optimizer': 'adam', 'output_dim': 1}, 'strategy_space': {'buffer_margin': 250, 'topk': 35}}\r\n    Time cost: 550.74039 | Finished searching best parameters in Tuner 1.\r\n    Time cost: 0.00023 | Finished saving local best tuner parameters to: tuner_experiment/estimator_experiment/estimator_experiment_1/local_best_params.json .\r\n    Time cost: 1784.14691 | Finished tuner pipeline.\r\n    Time cost: 0.00014 | Finished save global best tuner parameters.\r\n    Best Tuner id: 0.\r\n    You can check the best parameters at tuner_experiment/global_best_params.json.\r\n\r\n\r\nFinally, you can check the results of your experiment in the given path.\r\n\r\nConfiguration file\r\n------------------\r\n\r\nBefore using `tuner`, you need to prepare a configuration file. The following subsections show how to prepare each part of it.\r\n\r\nAbout the experiment\r\n~~~~~~~~~~~~~~~~~~~~\r\n\r\nFirst, your configuration file needs an `experiment` field, which determines the saving path and the tuner class.\r\n\r\nUsually, it contains the following content:\r\n\r\n.. code-block:: YAML\r\n\r\n    experiment:\r\n        name: tuner_experiment\r\n        tuner_class: QLibTuner\r\n\r\nAlso, there are some optional fields. 
The meaning of each field is as follows:\r\n\r\n- `name`\r\n    The experiment name, str type. The program uses it to create a directory that stores the process and the results of the whole experiment. The default value is `tuner_experiment`.\r\n\r\n- `dir`\r\n    The saving path, str type. The program creates the experiment directory under this path. The default value is the directory where the configuration file is located.\r\n\r\n- `tuner_class`\r\n    The class of the tuner, str type. It must be an already implemented tuner, such as `QLibTuner` in `qlib`, or a custom tuner that is a subclass of `qlib.contrib.tuner.Tuner`. The default value is `QLibTuner`.\r\n\r\n- `tuner_module_path`\r\n    The module path of the tuner implementation, str type; an absolute path is also supported. The default value is `qlib.contrib.tuner.tuner`.\r\n\r\nAbout the optimization criteria\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\nYou need to designate a factor to optimize, because the tuner needs a factor to decide which trial is better than the others.\r\nUsually, this factor comes from the results of `estimator`, such as the backtest results or the score of the model.\r\n\r\nThis part needs to contain the following fields:\r\n\r\n.. code-block:: YAML\r\n\r\n    optimization_criteria:\r\n        report_type: model\r\n        report_factor: model_pearsonr\r\n        optim_type: max\r\n\r\n- `report_type`\r\n    The type of the report, str type, determines which kind of report you want to use. For a backtest result type, you can choose `pred_long`, `pred_long_short`, `pred_short`, `excess_return_without_cost` or `excess_return_with_cost`. For the model result type, the only choice is `model`.\r\n\r\n- `report_factor`\r\n    The factor in the report that you want to optimize, str type. If your `report_type` is a backtest result type, you can choose `annualized_return`, `information_ratio`, `max_drawdown`, `mean` or `std`. 
If your `report_type` is the model result type, you can choose `model_score` or `model_pearsonr`.\r\n\r\n- `optim_type`\r\n    The optimization type, str type, determines how the factor is optimized. You can choose `max` to maximize the factor, `min` to minimize it, or `correlation`.\r\n    Note: `correlation` means that the factor's best value is 1, as with `model_pearsonr` (a correlation coefficient).\r\n\r\nIf you want to process the factor or fetch other kinds of factors, you can override the `objective` method in your own tuner.\r\n\r\nAbout the tuner pipeline\r\n~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\nThe tuner pipeline contains several tuners, and the `tuner` program processes each tuner in the pipeline. Each tuner searches for the optimal hyper-parameters of its specific combination of modules. The pipeline then compares the results of all tuners and selects the best combination together with its optimal hyper-parameters. Therefore, you need to configure both the pipeline and each tuner. Here is an example:\r\n\r\n.. code-block:: YAML\r\n\r\n    tuner_pipeline:\r\n      -\r\n        model:\r\n            class: SomeModel\r\n            space: SomeModelSpace\r\n        trainer:\r\n            class: RollingTrainer\r\n        strategy:\r\n            class: TopkAmountStrategy\r\n            space: TopkAmountStrategySpace\r\n        max_evals: 2\r\n\r\nEach item in the list represents one tuner and the modules it tunes. The `space` of each module is its hyper-parameter search space; you need to create and modify your search spaces in `/qlib/contrib/tuner/space.py`. The spaces are built with the `hyperopt` package; see https://github.com/hyperopt/hyperopt/wiki/FMin for details on how to use it.\r\n\r\n- model\r\n    You need to provide the `class` and the `space` of the model. 
If the model is your own implementation, you also need to provide its `module_path`.\r\n\r\n- trainer\r\n    You need to provide the `class` of the trainer. If the trainer is your own implementation, you also need to provide its `module_path`.\r\n\r\n- strategy\r\n    You need to provide the `class` and the `space` of the strategy. If the strategy is your own implementation, you also need to provide its `module_path`.\r\n\r\n- data_label\r\n    The label of the data. You can search for the kinds of labels that lead to better results. This part is optional, and you only need to provide its `space`.\r\n\r\n- max_evals\r\n    The maximum number of function evaluations allowed in this tuner. The default value is 10.\r\n\r\nIf you do not want to search some modules, you can fix their spaces in `space.py`. No default module is provided.\r\n\r\nAbout the time period\r\n~~~~~~~~~~~~~~~~~~~~~\r\n\r\nYou need to use the same dataset to evaluate the different `estimator` experiments within a `tuner` experiment, because two experiments that use different datasets are not comparable. You can specify `time_period` in the configuration file:\r\n\r\n.. code-block:: YAML\r\n\r\n    time_period:\r\n        rolling_period: 360\r\n        train_start_date: 2005-01-01\r\n        train_end_date: 2014-12-31\r\n        validate_start_date: 2015-01-01\r\n        validate_end_date: 2016-06-30\r\n        test_start_date: 2016-07-01\r\n        test_end_date: 2018-04-30\r\n\r\n- `rolling_period`\r\n    The rolling period, int type, indicates how many time steps the data window moves forward at each rolling step. The default value is `60`. 
This config is used only with `RollingTrainer`; otherwise, it is ignored.\r\n\r\n- `train_start_date`\r\n    Training start time, str type.\r\n\r\n- `train_end_date`\r\n    Training end time, str type.\r\n\r\n- `validate_start_date`\r\n    Validation start time, str type.\r\n\r\n- `validate_end_date`\r\n    Validation end time, str type.\r\n\r\n- `test_start_date`\r\n    Test start time, str type.\r\n\r\n- `test_end_date`\r\n    Test end time, str type. If `test_end_date` is `-1` or greater than the last date of the data, the last date of the data will be used as `test_end_date`.\r\n\r\nAbout the data and backtest\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\n`data` and `backtest` are shared by the whole `tuner` experiment, because different `estimator` experiments must use the same data and backtest method. These two parts of the config are therefore the same as those in the `estimator` configuration; see the `estimator` introduction for their precise definitions. Only an example is provided here.\r\n\r\n.. 
code-block:: YAML\r\n\r\n    data:\r\n        class: ALPHA360\r\n        provider_uri: /data/qlib\r\n        args:\r\n            start_date: 2005-01-01\r\n            end_date: 2018-04-30\r\n            dropna_label: True\r\n            dropna_feature: True\r\n            feature_label_config: /home/v-yindzh/v-yindzh/QLib/cfg/feature_config.yaml\r\n        filter:\r\n            market: csi500\r\n            filter_pipeline:\r\n              -\r\n                class: NameDFilter\r\n                module_path: qlib.data.filter\r\n                args:\r\n                  name_rule_re: S(?!Z3)\r\n                  fstart_time: 2018-01-01\r\n                  fend_time: 2018-12-11\r\n              -\r\n                class: ExpressionDFilter\r\n                module_path: qlib.data.filter\r\n                args:\r\n                  rule_expression: $open/$factor<=45\r\n                  fstart_time: 2018-01-01\r\n                  fend_time: 2018-12-11\r\n    backtest:\r\n        normal_backtest_args:\r\n            limit_threshold: 0.095\r\n            account: 500000\r\n            benchmark: SH000905\r\n            deal_price: vwap\r\n        long_short_backtest_args:\r\n            topk: 50\r\n\r\nExperiment Result\r\n-----------------\r\n\r\nAll the results are stored directly in the experiment directory, and you can check them in the corresponding files.\r\nThe following items are saved:\r\n\r\n- Global optimal parameters\r\n- Local optimal parameters of each tuner\r\n- Config file of this `tuner` experiment\r\n- Results of every `estimator` experiment in the process\r\n"
  },
  {
    "path": "docs/index.rst",
    "content": "======================\n``Qlib`` Documentation\n======================\n\n``Qlib`` is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.\n\n.. _user_guide:\n\nDocument Structure\n====================\n\n.. toctree::\n   :hidden:\n\n   Home <self>\n\n.. toctree::\n   :maxdepth: 3\n   :caption: GETTING STARTED:\n\n   Introduction <introduction/introduction.rst>\n   Quick Start <introduction/quick.rst>\n\n.. toctree::\n   :maxdepth: 3\n   :caption: FIRST STEPS:\n\n   Installation <start/installation.rst>\n   Initialization <start/initialization.rst>\n   Data Retrieval <start/getdata.rst>\n   Custom Model Integration <start/integration.rst>\n\n\n.. toctree::\n   :maxdepth: 3\n   :caption: MAIN COMPONENTS:\n\n   Workflow: Workflow Management <component/workflow.rst>\n   Data Layer: Data Framework & Usage <component/data.rst>\n   Forecast Model: Model Training & Prediction <component/model.rst>\n   Portfolio Management and Backtest <component/strategy.rst>\n   Nested Decision Execution: High-Frequency Trading <component/highfreq.rst>\n   Meta Controller: Meta-Task & Meta-Dataset & Meta-Model <component/meta.rst>\n   Qlib Recorder: Experiment Management <component/recorder.rst>\n   Analysis: Evaluation & Results Analysis <component/report.rst>\n   Online Serving: Online Management & Strategy & Tool <component/online.rst>\n   Reinforcement Learning <component/rl/toctree>\n\n.. toctree::\n   :maxdepth: 3\n   :caption: OTHER COMPONENTS/FEATURES/TOPICS:\n\n   Building Formulaic Alphas <advanced/alpha.rst>\n   Online & Offline mode <advanced/server.rst>\n   Serialization <advanced/serial.rst>\n   Task Management <advanced/task_management.rst>\n   Point-In-Time database <advanced/PIT.rst>\n\n.. 
toctree::\n   :maxdepth: 3\n   :caption: FOR DEVELOPERS:\n\n   Code Standard & Development Guidance <developer/code_standard_and_dev_guide.rst>\n   How to build image <developer/how_to_build_image.rst>\n\n.. toctree::\n   :maxdepth: 3\n   :caption: REFERENCE:\n\n   API <reference/api.rst>\n\n.. toctree::\n   :maxdepth: 3\n\n   FAQ <FAQ/FAQ.rst>\n\n.. toctree::\n   :maxdepth: 3\n   :caption: Change Log:\n\n   Change Log <changelog/changelog.rst>\n"
  },
  {
    "path": "docs/introduction/introduction.rst",
    "content": "===============================\n``Qlib``: Quantitative Platform\n===============================\n\nIntroduction\n============\n\n.. image:: ../_static/img/logo/white_bg_rec+word.png\n    :align: center\n\n``Qlib`` is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.\n\nWith ``Qlib``, users can easily try their ideas to create better Quant investment strategies.\n\nFramework\n=========\n\n\n.. image:: ../_static/img/framework.svg\n    :align: center\n\n\nAt the module level, Qlib is a platform that consists of the components above. The components are designed as loosely coupled modules, and each component can be used stand-alone.\n\nThis framework may be intimidating for users new to Qlib, because it tries to accurately include many details of Qlib's design.\nUsers new to Qlib can skip it at first and read it later.\n\n\n\n===========================  ==============================================================================\nName                         Description\n===========================  ==============================================================================\n`Infrastructure` layer       `Infrastructure` layer provides underlying support for Quant research.\n                             `DataServer` provides high-performance infrastructure for users to manage\n                             and retrieve raw data. `Trainer` provides a flexible interface to control\n                             the training process of models, which enables algorithms to control the\n                             training process.\n\n`Learning Framework` layer   The `Forecast Model` and `Trading Agent` are trainable. They are trained\n                             based on the `Learning Framework` layer and then applied to multiple scenarios\n                             in the `Workflow` layer. 
The supported learning paradigms can be categorized into\n                             reinforcement learning and supervised learning.  The learning framework\n                             leverages the `Workflow` layer as well (e.g. sharing `Information Extractor`,\n                             creating environments based on `Execution Env`).\n\n`Workflow` layer             `Workflow` layer covers the whole workflow of quantitative investment.\n                             Both supervised-learning-based strategies and RL-based strategies\n                             are supported.\n                             `Information Extractor` extracts data for models. `Forecast Model` focuses\n                             on producing all kinds of forecast signals (e.g. *alpha*, risk) for other\n                             modules.  With these signals, `Decision Generator` generates the target\n                             trading decisions (i.e. portfolio, orders).\n                             If RL-based strategies are adopted, the `Policy` is learned in an end-to-end way,\n                             and the trading decisions are generated directly.\n                             Decisions are executed by `Execution Env`\n                             (i.e. the trading market).  There may be multiple levels of `Strategy`\n                             and `Executor` (e.g. an *order executor trading strategy and intraday order executor*\n                             could behave like an intraday trading loop and be nested in\n                             *daily portfolio management trading strategy and interday trading executor*\n                             trading loop)\n\n`Interface` layer            `Interface` layer tries to present a user-friendly interface for the underlying\n                             system. 
The `Analyser` module provides users with detailed analysis reports of\n                             forecasting signals, portfolios and execution results.\n===========================  ==============================================================================\n\n- The modules with hand-drawn style are under development and will be released in the future.\n- The modules with dashed borders are highly user-customizable and extensible.\n\n(P.S. The framework image is created with https://draw.io/)\n"
  },
  {
    "path": "docs/introduction/quick.rst",
    "content": "\n===========\nQuick Start\n===========\n\nIntroduction\n============\n\nThis ``Quick Start`` guide tries to demonstrate that:\n\n- It's very easy to build a complete Quant research workflow and try users' ideas with ``Qlib``.\n- Even with public data and simple models, machine learning technologies work very well in practical Quant investment.\n\n\n\nInstallation\n============\n\nUsers can easily install ``Qlib`` according to the following steps:\n\n- Before installing ``Qlib`` from source, users need to install some dependencies:\n\n    .. code-block::\n\n        pip install numpy\n        pip install --upgrade cython\n\n- Clone the repository and install ``Qlib``:\n\n    .. code-block::\n\n        git clone https://github.com/microsoft/qlib.git && cd qlib\n        python setup.py install\n\nTo know more about `installation`, please refer to `Qlib Installation <../start/installation.html>`_.\n\nPrepare Data\n============\n\nLoad and prepare data by running the following code:\n\n.. code-block::\n\n    python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n\nThis dataset is created from public data collected by the crawler scripts in ``scripts/data_collector/``, which have been released in the same repository. Users can create the same dataset with these scripts.\n\nTo know more about `prepare data`, please refer to `Data Preparation <../component/data.html#data-preparation>`_.\n\nAuto Quant Research Workflow\n============================\n\n``Qlib`` provides a tool named ``qrun`` to run the whole workflow automatically (including building the dataset, training models, backtesting and evaluation). Users can start an auto quant research workflow and get a graphical report analysis by following these steps:\n\n- Quant Research Workflow:\n    - Run ``qrun`` with the LightGBM model config file `workflow_config_lightgbm.yaml` as follows.\n\n        .. 
code-block::\n\n            cd examples  # Avoid running the program in a directory that contains `qlib`\n            qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml\n\n\n    - Workflow result\n        The result of ``qrun`` is as follows; it is also the typical result of a ``Forecast model(alpha)``. Please refer to `Intraday Trading <../component/backtest.html>`_ for more details about the result.\n\n        .. code-block:: python\n\n                                                              risk\n            excess_return_without_cost mean               0.000605\n                                       std                0.005481\n                                       annualized_return  0.152373\n                                       information_ratio  1.751319\n                                       max_drawdown      -0.059055\n            excess_return_with_cost    mean               0.000410\n                                       std                0.005478\n                                       annualized_return  0.103265\n                                       information_ratio  1.187411\n                                       max_drawdown      -0.075024\n\n\n    To know more about `workflow` and `qrun`, please refer to `Workflow: Workflow Management <../component/workflow.html>`_.\n\n- Graphical Reports Analysis:\n    - Run ``examples/workflow_by_code.ipynb`` with Jupyter Notebook\n        Users can run portfolio analysis or prediction score (model prediction) analysis with ``examples/workflow_by_code.ipynb``.\n    - Graphical Reports\n        Users can get graphical reports of the analysis; please refer to `Analysis: Evaluation & Results Analysis <../component/report.html>`_ for more details.\n\n\n\nCustom Model Integration\n========================\n\n``Qlib`` provides a batch of models (such as ``LightGBM`` and ``MLP`` models) as examples of ``Forecast Model``. 
In addition to the default models, users can integrate their own custom models into ``Qlib``. If users are interested in custom models, please refer to `Custom Model Integration <../start/integration.html>`_.\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sphinx-build\r\n)\r\nset SOURCEDIR=.\r\nset BUILDDIR=_build\r\n\r\n%SPHINXBUILD% >NUL 2>NUL\r\nif errorlevel 9009 (\r\n\techo.\r\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\r\n\techo.installed, then set the SPHINXBUILD environment variable to point\r\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\r\n\techo.may add the Sphinx directory to PATH.\r\n\techo.\r\n\techo.If you don't have Sphinx installed, grab it from\r\n\techo.https://www.sphinx-doc.org/\r\n\texit /b 1\r\n)\r\n\r\nif \"%1\" == \"\" goto help\r\n\r\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\ngoto end\r\n\r\n:help\r\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\n\r\n:end\r\npopd\r\n"
  },
  {
    "path": "docs/reference/api.rst",
    "content": ".. _api:\n\n=============\nAPI Reference\n=============\n\n\n\nHere you can find all ``Qlib`` interfaces.\n\n\nData\n====\n\nProvider\n--------\n\n.. automodule:: qlib.data.data\n    :members:\n\nFilter\n------\n\n.. automodule:: qlib.data.filter\n    :members:\n\nClass\n-----\n.. automodule:: qlib.data.base\n    :members:\n\nOperator\n--------\n.. automodule:: qlib.data.ops\n    :members:\n\nCache\n-----\n.. autoclass:: qlib.data.cache.MemCacheUnit\n    :members:\n\n.. autoclass:: qlib.data.cache.MemCache\n    :members:\n\n.. autoclass:: qlib.data.cache.ExpressionCache\n    :members:\n\n.. autoclass:: qlib.data.cache.DatasetCache\n    :members:\n\n.. autoclass:: qlib.data.cache.DiskExpressionCache\n    :members:\n\n.. autoclass:: qlib.data.cache.DiskDatasetCache\n    :members:\n\n\nStorage\n-------\n.. autoclass:: qlib.data.storage.storage.BaseStorage\n    :members:\n\n.. autoclass:: qlib.data.storage.storage.CalendarStorage\n    :members:\n\n.. autoclass:: qlib.data.storage.storage.InstrumentStorage\n    :members:\n\n.. autoclass:: qlib.data.storage.storage.FeatureStorage\n    :members:\n\n.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin\n    :members:\n\n.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage\n    :members:\n\n.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage\n    :members:\n\n.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage\n    :members:\n\n\nDataset\n-------\n\nDataset Class\n~~~~~~~~~~~~~\n.. automodule:: qlib.data.dataset.__init__\n    :members:\n\nData Loader\n~~~~~~~~~~~\n.. automodule:: qlib.data.dataset.loader\n    :members:\n\nData Handler\n~~~~~~~~~~~~\n.. automodule:: qlib.data.dataset.handler\n    :members:\n\nProcessor\n~~~~~~~~~\n.. automodule:: qlib.data.dataset.processor\n    :members:\n\n\nContrib\n=======\n\nModel\n-----\n.. automodule:: qlib.model.base\n    :members:\n\nStrategy\n--------\n\n.. 
automodule:: qlib.contrib.strategy\n    :members:\n\nEvaluate\n--------\n\n.. automodule:: qlib.contrib.evaluate\n    :members:\n\n\nReport\n------\n\n.. automodule:: qlib.contrib.report.analysis_position.report\n    :members:\n\n\n\n.. automodule:: qlib.contrib.report.analysis_position.score_ic\n    :members:\n\n\n\n.. automodule:: qlib.contrib.report.analysis_position.cumulative_return\n    :members:\n\n\n\n.. automodule:: qlib.contrib.report.analysis_position.risk_analysis\n    :members:\n\n\n\n.. automodule:: qlib.contrib.report.analysis_position.rank_label\n    :members:\n\n\n\n.. automodule:: qlib.contrib.report.analysis_model.analysis_model_performance\n    :members:\n\n\nWorkflow\n========\n\n\nExperiment Manager\n------------------\n.. autoclass:: qlib.workflow.expm.ExpManager\n    :members:\n\nExperiment\n----------\n.. autoclass:: qlib.workflow.exp.Experiment\n    :members:\n\nRecorder\n--------\n.. autoclass:: qlib.workflow.recorder.Recorder\n    :members:\n\nRecord Template\n---------------\n.. automodule:: qlib.workflow.record_temp\n    :members:\n\nTask Management\n===============\n\n\nTaskGen\n-------\n.. automodule:: qlib.workflow.task.gen\n    :members:\n\nTaskManager\n-----------\n.. automodule:: qlib.workflow.task.manage\n    :members:\n\nTrainer\n-------\n.. automodule:: qlib.model.trainer\n    :members:\n\nCollector\n---------\n.. automodule:: qlib.workflow.task.collect\n    :members:\n\nGroup\n-----\n.. automodule:: qlib.model.ens.group\n    :members:\n\nEnsemble\n--------\n.. automodule:: qlib.model.ens.ensemble\n    :members:\n\nUtils\n-----\n.. automodule:: qlib.workflow.task.utils\n    :members:\n\n\nOnline Serving\n==============\n\n\nOnline Manager\n--------------\n.. automodule:: qlib.workflow.online.manager\n    :members:\n\nOnline Strategy\n---------------\n.. automodule:: qlib.workflow.online.strategy\n    :members:\n\nOnline Tool\n-----------\n.. 
automodule:: qlib.workflow.online.utils\n    :members:\n\n\nRecordUpdater\n-------------\n.. automodule:: qlib.workflow.online.update\n    :members:\n\n\nUtils\n=====\n\nSerializable\n------------\n\n.. automodule:: qlib.utils.serial\n    :members:\n\nRL\n==============\n\nBase Component\n--------------\n.. automodule:: qlib.rl\n    :members:\n    :imported-members:\n\nStrategy\n--------\n.. automodule:: qlib.rl.strategy\n    :members:\n    :imported-members:\n\nTrainer\n-------\n.. automodule:: qlib.rl.trainer\n    :members:\n    :imported-members:\n\nOrder Execution\n---------------\n.. automodule:: qlib.rl.order_execution\n    :members:\n    :imported-members:\n\nUtils\n---------------\n.. automodule:: qlib.rl.utils\n    :members:\n    :imported-members:"
  },
  {
    "path": "docs/requirements.txt",
    "content": "Cython\ncmake\nnumpy\nscipy\nscikit-learn\npandas\ntianshou\nsphinx_rtd_theme\n"
  },
  {
    "path": "docs/start/getdata.rst",
    "content": ".. _getdata:\n\n==============\nData Retrieval\n==============\n\n.. currentmodule:: qlib\n\nIntroduction\n============\n\nUsers can get stock data with ``Qlib``. The following examples demonstrate the basic user interface.\n\nExamples\n========\n\n\n``Qlib`` Initialization:\n\n.. note:: In order to get the data, users need to initialize ``Qlib`` with `qlib.init` first. Please refer to `initialization <initialization.html>`_.\n\nIf users have followed the steps in `initialization <initialization.html>`_ and downloaded the data, they can use the following code to initialize qlib:\n\n.. code-block:: python\n\n    >> import qlib\n    >> qlib.init(provider_uri='~/.qlib/qlib_data/cn_data')\n\n\nLoad the trading calendar with a given time range and frequency:\n\n.. code-block:: python\n\n   >> from qlib.data import D\n   >> D.calendar(start_time='2010-01-01', end_time='2017-12-31', freq='day')[:2]\n   [Timestamp('2010-01-04 00:00:00'), Timestamp('2010-01-05 00:00:00')]\n\nParse a given market name into a stock pool config:\n\n.. code-block:: python\n\n   >> from qlib.data import D\n   >> D.instruments(market='all')\n   {'market': 'all', 'filter_pipe': []}\n\nLoad the instruments of a certain stock pool in a given time range:\n\n.. code-block:: python\n\n   >> from qlib.data import D\n   >> instruments = D.instruments(market='csi300')\n   >> D.list_instruments(instruments=instruments, start_time='2010-01-01', end_time='2017-12-31', as_list=True)[:6]\n   ['SH600036', 'SH600110', 'SH600087', 'SH600900', 'SH600089', 'SZ000912']\n\nLoad dynamic instruments from a base market according to a name filter:\n\n.. 
code-block:: python\n\n   >> from qlib.data import D\n   >> from qlib.data.filter import NameDFilter\n   >> nameDFilter = NameDFilter(name_rule_re='SH[0-9]{4}55')\n   >> instruments = D.instruments(market='csi300', filter_pipe=[nameDFilter])\n   >> D.list_instruments(instruments=instruments, start_time='2015-01-01', end_time='2016-02-15', as_list=True)\n   ['SH600655', 'SH601555']\n\nLoad dynamic instruments from a base market according to an expression filter:\n\n.. code-block:: python\n\n   >> from qlib.data import D\n   >> from qlib.data.filter import ExpressionDFilter\n   >> expressionDFilter = ExpressionDFilter(rule_expression='$close>2000')\n   >> instruments = D.instruments(market='csi300', filter_pipe=[expressionDFilter])\n   >> D.list_instruments(instruments=instruments, start_time='2015-01-01', end_time='2016-02-15', as_list=True)\n   ['SZ000651', 'SZ000002', 'SH600655', 'SH600570']\n\nFor more details about filters, please refer to `Filter API <../component/data.html>`_.\n\nLoad features of certain instruments in a given time range:\n\n.. code-block:: python\n\n   >> from qlib.data import D\n   >> instruments = ['SH600000']\n   >> fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']\n   >> D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head().to_string()\n   '                           $close     $volume  Ref($close, 1)  Mean($close, 3)  $high-$low\n   ... instrument  datetime\n   ... SH600000    2010-01-04  86.778313  16162960.0       88.825928        88.061483    2.907631\n   ...             2010-01-05  87.433578  28117442.0       86.778313        87.679273    3.235252\n   ...             2010-01-06  85.713585  23632884.0       87.433578        86.641825    1.720009\n   ...             2010-01-07  83.788803  20813402.0       85.713585        85.645322    3.030487\n   ...             
2010-01-08  84.730675  16044853.0       83.788803        84.744354    2.047623'\n\nLoad features of a certain stock pool in a given time range:\n\n.. note:: With cache enabled, the qlib data server caches data for the requested stock pool and fields over the whole time range, so the first request may take longer to process than it would without cache. After the first time, requests with the same stock pool and fields hit the cache and are processed faster, even if the requested time period changes.\n\n.. code-block:: python\n\n   >> from qlib.data import D\n   >> from qlib.data.filter import NameDFilter, ExpressionDFilter\n   >> nameDFilter = NameDFilter(name_rule_re='SH[0-9]{4}55')\n   >> expressionDFilter = ExpressionDFilter(rule_expression='$close>Ref($close,1)')\n   >> instruments = D.instruments(market='csi300', filter_pipe=[nameDFilter, expressionDFilter])\n   >> fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']\n   >> D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head().to_string()\n   '                              $close        $volume  Ref($close, 1)  Mean($close, 3)  $high-$low\n   ... instrument  datetime\n   ... SH600655    2010-01-04  2699.567383  158193.328125     2619.070312      2626.097738  124.580566\n   ...             2010-01-08  2612.359619   77501.406250     2584.567627      2623.220133   83.373047\n   ...             2010-01-11  2712.982422  160852.390625     2612.359619      2636.636556  146.621582\n   ...             2010-01-12  2788.688232  164587.937500     2712.982422      2704.676758  128.413818\n   ...             2010-01-13  2790.604004  145460.453125     2788.688232      2764.091553  128.413818'\n\n\nFor more details about features, please refer to `Feature API <../component/data.html>`_.\n\n.. note:: When calling `D.features()` at the client, use the parameter `disk_cache=0` to skip the dataset cache and `disk_cache=1` to generate and use the dataset cache. 
In addition, when calling at the server, users can use `disk_cache=2` to update the dataset cache.\n\n\nWhen you are building complicated expressions, implementing all the expressions in a single string may not be easy.\nFor example, it looks quite long and complicated:\n\n.. code-block:: python\n\n   >> from qlib.data import D\n   >> data = D.features([\"sh600519\"], [\"(($high / $close) + ($open / $close)) * (($high / $close) + ($open / $close)) / (($high / $close) + ($open / $close))\"], start_time=\"20200101\")\n\n\nBut using string is not the only way to implement the expression. You can also implement expression by code.\nHere is an example which does the same thing as above examples.\n\n\n.. code-block:: python\n\n   >> from qlib.data.ops import *\n   >> f1 = Feature(\"high\") / Feature(\"close\")\n   >> f2 = Feature(\"open\") / Feature(\"close\")\n   >> f3 = f1 + f2\n   >> f4 = f3 * f3 / f3\n\n   >> data = D.features([\"sh600519\"], [f4], start_time=\"20200101\")\n   >> data.head()\n\n\nAPI\n===\nTo know more about how to use the Data, go to API Reference: `Data API <../reference/api.html#data>`_\n"
  },
  {
    "path": "docs/start/initialization.rst",
    "content": ".. _initialization:\n\n===================\nQlib Initialization\n===================\n\n.. currentmodule:: qlib\n\n\nInitialization\n==============\n\nPlease follow the steps below to initialize ``Qlib``.\n\nDownload and prepare the data: execute the following command to download stock data. Please note that the data is collected from `Yahoo Finance <https://finance.yahoo.com/lookup>`_ and might not be perfect. We recommend that users prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>`_ for more information about customized datasets.\n\n    .. code-block:: bash\n\n        python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n\nPlease refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`.\n\n\nInitialize Qlib before calling other APIs: run the following code in Python.\n\n    .. code-block:: Python\n\n        import qlib\n        # region in [REG_CN, REG_US]\n        from qlib.constant import REG_CN\n        provider_uri = \"~/.qlib/qlib_data/cn_data\"  # target_dir\n        qlib.init(provider_uri=provider_uri, region=REG_CN)\n\n.. note::\n   Do not import the qlib package in the repository directory of ``Qlib``; otherwise, errors may occur.\n\nParameters\n-------------------\n\nBesides `provider_uri` and `region`, `qlib.init` has other parameters.\nThe following are several important parameters of `qlib.init` (``Qlib`` has many configuration options and only some of them are listed here; more detailed settings can be found `here <https://github.com/microsoft/qlib/blob/main/qlib/config.py>`_):\n\n- `provider_uri`\n    Type: str. The URI of the Qlib data. 
For example, it could be the location where the data downloaded by ``get_data.py`` is stored.\n- `region`\n    Type: str, optional parameter (default: `qlib.constant.REG_CN`).\n        Currently, ``qlib.constant.REG_US`` ('us') and ``qlib.constant.REG_CN`` ('cn') are supported. Different values of `region` result in different stock market modes.\n        - ``qlib.constant.REG_US``: US stock market.\n        - ``qlib.constant.REG_CN``: China stock market.\n\n        Different modes result in different trading limitations and costs.\n        The region is just a `shortcut for defining a batch of configurations <https://github.com/microsoft/qlib/blob/528f74af099bf6156e9480bcd2bb28e453231212/qlib/config.py#L249>`_, which include the minimal trading order unit (``trade_unit``), the trading limitation (``limit_threshold``), etc. It is not mandatory: users can set the key configurations manually if the existing region settings can't meet their requirements.\n- `redis_host`\n    Type: str, optional parameter (default: \"127.0.0.1\"), host of `redis`.\n        The lock and cache mechanisms rely on redis.\n- `redis_port`\n    Type: int, optional parameter (default: 6379), port of `redis`.\n\n    .. note::\n\n        The value of `region` should be aligned with the data stored in `provider_uri`. Currently, ``scripts/get_data.py`` only provides China stock market data. If users want to use US stock market data, they should prepare their own US stock data in `provider_uri` and switch to US stock mode.\n\n    .. note::\n\n        If Qlib fails to connect to redis via `redis_host` and `redis_port`, the cache mechanism will not be used! Please refer to `Cache <../component/data.html#cache>`_ for details.\n- `exp_manager`\n    Type: dict, optional parameter, the setting of the `experiment manager` used in qlib. Users can specify an experiment manager class, as well as the tracking URI for all the experiments. 
However, please be aware that `exp_manager` only accepts a dictionary in the following style. For more information about `exp_manager`, users can refer to `Recorder: Experiment Management <../component/recorder.html>`_.\n\n    .. code-block:: Python\n\n        # For example, to set your tracking_uri to a <specific folder>, initialize qlib as below\n        qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager={\n            \"class\": \"MLflowExpManager\",\n            \"module_path\": \"qlib.workflow.expm\",\n            \"kwargs\": {\n                \"uri\": \"python_execution_path/mlruns\",\n                \"default_exp_name\": \"Experiment\",\n            }\n        })\n- `mongo`\n    Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_, which is used by some features such as `Task Management <../advanced/task_management.html>`_ for high-performance and clustered processing.\n    Users need to first follow the `installation <https://www.mongodb.com/try/download/community>`_ steps to install MongoDB and then access it via a URI.\n    Users can access MongoDB with credentials by setting \"task_url\" to a string like `\"mongodb://%s:%s@%s\" % (user, pwd, host + \":\" + port)`.\n\n    .. code-block:: Python\n\n        # For example, you can initialize qlib as below\n        qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={\n            \"task_url\": \"mongodb://localhost:27017/\",  # your mongo url\n            \"task_db_name\": \"rolling_db\", # the database name of Task Management\n        })\n\n- `logging_level`\n    The logging level for the system.\n\n- `kernels`\n    The number of processes used when calculating features in Qlib's expression engine. Setting it to 1 is very helpful when you are debugging an exception raised while calculating an expression.\n"
  },
  {
    "path": "docs/start/installation.rst",
    "content": ".. _installation:\n\n============\nInstallation\n============\n\n.. currentmodule:: qlib\n\n\n``Qlib`` Installation\n=====================\n.. note::\n\n   ``Qlib`` supports both `Windows` and `Linux`, and using it on `Linux` is recommended. ``Qlib`` supports Python 3, up to Python 3.8.\n\nUsers can easily install ``Qlib`` with pip using the following command:\n\n.. code-block:: bash\n\n   pip install pyqlib\n\n\nUsers can also install ``Qlib`` from the source code with the following steps:\n\n- Install the build dependencies (``numpy`` and ``cython``), clone the repository, and enter the root directory of ``Qlib``, which contains the file ``setup.py``.\n- Install ``Qlib`` from that directory. The following commands perform both steps:\n\n   .. code-block:: bash\n\n      $ pip install numpy\n      $ pip install --upgrade cython\n      $ git clone https://github.com/microsoft/qlib.git && cd qlib\n      $ python setup.py install\n\n.. note::\n   It's recommended to use anaconda/miniconda to set up the environment. ``Qlib`` needs the lightgbm and pytorch packages; install them with pip.\n\n\n\nUse the following code to verify that the installation succeeded:\n\n.. code-block:: python\n\n   >>> import qlib\n   >>> qlib.__version__\n   <LATEST VERSION>\n"
  },
  {
    "path": "docs/start/integration.rst",
    "content": "========================\nCustom Model Integration\n========================\n\nIntroduction\n============\n\n``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``MLP``, ``LSTM``, etc. These models are examples of ``Forecast Model``. In addition to the default models that ``Qlib`` provides, users can integrate their own custom models into ``Qlib``.\n\nUsers can integrate their own custom models with the following steps.\n\n- Define a custom model class, which should be a subclass of `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_.\n- Write a configuration file that describes the path and parameters of the custom model.\n- Test the custom model.\n\nCustom Model Class\n==================\nCustom models need to inherit from `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_ and override its methods.\n\n- Override the `__init__` method\n    - ``Qlib`` passes the initialization parameters to the \\_\\_init\\_\\_ method.\n    - The model hyperparameters in the configuration must be consistent with those defined in the `__init__` method.\n    - Code Example: In the following example, the model hyperparameters in the configuration file should contain parameters such as `loss: mse`.\n\n        .. 
code-block:: Python\n\n            def __init__(self, loss='mse', **kwargs):\n                if loss not in {'mse', 'binary'}:\n                    raise NotImplementedError\n                self._scorer = mean_squared_error if loss == 'mse' else roc_auc_score\n                # LightGBM parameters: the objective plus any extra hyperparameters from the config\n                self.params = {'objective': loss}\n                self.params.update(kwargs)\n                self.model = None\n\n- Override the `fit` method\n    - ``Qlib`` calls the fit method to train the model.\n    - The parameters must include the training `dataset`, as defined in the interface.\n    - The parameters can include some `optional` parameters with default values, such as `num_boost_round = 1000` for `GBDT`.\n    - Code Example: In the following example, `num_boost_round = 1000` is an optional parameter.\n\n        .. code-block:: Python\n\n            def fit(self, dataset: DatasetH, num_boost_round=1000, early_stopping_rounds=50, verbose_eval=20, evals_result=None, **kwargs):\n                if evals_result is None:\n                    evals_result = {}\n\n                # prepare dataset for lgb training and evaluation\n                df_train, df_valid = dataset.prepare(\n                    [\"train\", \"valid\"], col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L\n                )\n                x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n                x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n\n                # LightGBM needs a 1D array as its label\n                if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:\n                    y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)\n                else:\n                    raise ValueError(\"LightGBM doesn't support multi-label training\")\n\n                dtrain = lgb.Dataset(x_train.values, label=y_train)\n                dvalid = lgb.Dataset(x_valid.values, label=y_valid)\n\n                # fit the model\n                self.model = lgb.train(\n                    self.params,\n                    dtrain,\n                    num_boost_round=num_boost_round,\n                    
valid_sets=[dtrain, dvalid],\n                    valid_names=[\"train\", \"valid\"],\n                    # early stopping, logging and result recording are configured via LightGBM callbacks\n                    callbacks=[\n                        lgb.early_stopping(early_stopping_rounds),\n                        lgb.log_evaluation(period=verbose_eval),\n                        lgb.record_evaluation(evals_result),\n                    ],\n                    **kwargs\n                )\n\n- Override the `predict` method\n    - The parameters must include the parameter `dataset`, which will be used to get the test dataset.\n    - Return the `prediction score`.\n    - Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_ for the parameter types of the predict method.\n    - Code Example: In the following example, users use `LightGBM` to predict the label (such as `preds`) of the test data `x_test` and return it.\n\n        .. code-block:: Python\n\n            def predict(self, dataset: DatasetH, **kwargs) -> pd.Series:\n                if self.model is None:\n                    raise ValueError(\"model is not fitted yet!\")\n                x_test = dataset.prepare(\"test\", col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n                return pd.Series(self.model.predict(x_test.values), index=x_test.index)\n\n- Override the `finetune` method (Optional)\n    - This method is optional. Users who want to use it on their own models should inherit from the ``ModelFT`` base class, which includes the `finetune` interface.\n    - The parameters must include the parameter `dataset`.\n    - Code Example: In the following example, users use `LightGBM` as the model and finetune it.\n\n        .. 
code-block:: Python\n\n            def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):\n                # Continue training the existing model for more rounds\n                # (_prepare_data is assumed to be a helper of this class, not shown, that builds the LightGBM datasets)\n                dtrain, _ = self._prepare_data(dataset)\n                self.model = lgb.train(\n                    self.params,\n                    dtrain,\n                    num_boost_round=num_boost_round,\n                    init_model=self.model,\n                    valid_sets=[dtrain],\n                    valid_names=[\"train\"],\n                    callbacks=[lgb.log_evaluation(period=verbose_eval)],\n                )\n\nConfiguration File\n==================\n\nThe configuration file is described in detail in the `Workflow <../component/workflow.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the \"model\" field in the configuration file. The configuration describes which model to use and how to initialize it.\n\n- Example: The following example describes the `model` field of the configuration file for the custom LightGBM model mentioned above, where `module_path` is the module path, `class` is the class name, and `kwargs` contains the hyperparameters passed into the `__init__` method. All parameters in this field except `loss: mse` are passed to `self.params` through `\\*\\*kwargs` in `__init__`.\n\n    .. code-block:: YAML\n\n        model:\n            class: LGBModel\n            module_path: qlib.contrib.model.gbdt\n            kwargs:\n                loss: mse\n                colsample_bytree: 0.8879\n                learning_rate: 0.0421\n                subsample: 0.8789\n                lambda_l1: 205.6999\n                lambda_l2: 580.9768\n                max_depth: 8\n                num_leaves: 210\n                num_threads: 20\n\nUsers can find the configuration files of the ``Model`` baselines in ``examples/benchmarks``. 
The configuration of each model is located under the corresponding model folder.\n\nModel Testing\n=============\nAssuming that the configuration file is ``examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml``, users can run the following commands to test the custom model:\n\n.. code-block:: bash\n\n    cd examples  # Avoid running the program in a directory that contains `qlib`\n    qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml\n\n.. note:: ``qrun`` is a built-in command of ``Qlib``.\n\n``Model`` can also be tested as a single module. An example is given in ``examples/workflow_by_code.ipynb``.\n\n\nReference\n=========\n\nTo learn more about ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <../component/model.html>`_ and `Model API <../reference/api.html#module-qlib.model.base>`_.\n"
  },
  {
    "path": "examples/README.md",
    "content": "# Requirements\n\nHere are the minimal hardware requirements to run the `workflow_by_code` example:\n- Memory: 16G\n- Free disk space: 5G\n\n\n# NOTE\nThe results vary slightly across operating systems (the variance of the annualized return is less than 2%).\nThe evaluation results on the `README.md` page were obtained on Linux.\n"
  },
  {
    "path": "examples/benchmarks/ADARNN/README.md",
    "content": "# AdaRNN\n* Code: [https://github.com/jindongwang/transferlearning/tree/master/code/deep/adarnn](https://github.com/jindongwang/transferlearning/tree/master/code/deep/adarnn)\n* Paper: [AdaRNN: Adaptive Learning and Forecasting for Time Series](https://arxiv.org/pdf/2108.04443.pdf).\n\n"
  },
  {
    "path": "examples/benchmarks/ADARNN/requirements.txt",
    "content": "pandas==1.1.2\nnumpy==1.21.0\nscikit_learn==0.23.2\ntorch==1.7.0\n"
  },
  {
    "path": "examples/benchmarks/ADARNN/workflow_config_adarnn_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: ADARNN\n        module_path: qlib.contrib.model.pytorch_adarnn\n        kwargs:\n            d_feat: 6\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0.0\n            n_epochs: 200\n            lr: 1e-3\n            early_stop: 20\n            batch_size: 800\n            metric: loss\n            loss: mse\n            GPU: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: 
[2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/ADD/README.md",
    "content": "# ADD\n* Paper: [ADD: Augmented Disentanglement Distillation Framework for Improving Stock Trend Forecasting](https://arxiv.org/abs/2012.06289).\n\n"
  },
  {
    "path": "examples/benchmarks/ADD/requirements.txt",
    "content": "numpy==1.21.0\npandas==1.1.2\nscikit_learn==0.23.2\ntorch==1.7.0\n"
  },
  {
    "path": "examples/benchmarks/ADD/workflow_config_add_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: ADD\n        module_path: qlib.contrib.model.pytorch_add\n        kwargs:\n            d_feat: 6\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0.1\n            dec_dropout: 0.0\n            n_epochs: 200\n            lr: 1e-3\n            early_stop: 20\n            batch_size: 5000\n            metric: ic\n            base_model: GRU\n            gamma: 0.1\n            gamma_clip: 0.2\n            optimizer: adam\n            mu: 0.2\n            GPU: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                
module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record:\n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs:\n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs:\n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs:\n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/ALSTM/README.md",
    "content": "# ALSTM\r\n\r\n- ALSTM adds a temporal attentive aggregation layer on top of a standard LSTM.\r\n\r\n- Paper: A dual-stage attention-based recurrent neural network for time series prediction.\r\n\r\n  [https://www.ijcai.org/Proceedings/2017/0366.pdf](https://www.ijcai.org/Proceedings/2017/0366.pdf)\r\n\r\n- NOTE: The current implementation is a simplified version of ALSTM: an LSTM with attention.\r\n"
  },
  {
    "path": "examples/benchmarks/ALSTM/requirements.txt",
    "content": "numpy==1.21.0\npandas==1.1.2\nscikit_learn==0.23.2\ntorch==1.7.0\n"
  },
  {
    "path": "examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: FilterCol\n          kwargs:\n              fields_group: feature\n              col_list: [\"RESI5\", \"WVMA5\", \"RSQR5\", \"KLEN\", \"RSQR10\", \"CORR5\", \"CORD5\", \"CORR10\", \n                            \"ROC60\", \"RESI10\", \"VSTD5\", \"RSQR60\", \"CORR60\", \"WVMA60\", \"STD5\", \n                            \"RSQR20\", \"CORD60\", \"CORD10\", \"CORR20\", \"KLOW\"\n                        ]\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"] \n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: ALSTM\n        module_path: qlib.contrib.model.pytorch_alstm_ts\n        kwargs:\n            d_feat: 20\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0.0\n            n_epochs: 200\n            lr: 1e-3\n 
           early_stop: 10\n            batch_size: 800\n            metric: loss\n            loss: mse\n            n_jobs: 20\n            GPU: 0\n            rnn_type: GRU\n    dataset:\n        class: TSDatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n            step_len: 20\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: ALSTM\n        module_path: qlib.contrib.model.pytorch_alstm\n        kwargs:\n            d_feat: 6\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0.0\n            n_epochs: 200\n            lr: 1e-3\n            early_stop: 20\n            batch_size: 800\n            metric: loss\n            loss: mse\n            GPU: 0\n            rnn_type: GRU\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n  
              train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/CatBoost/README.md",
    "content": "# CatBoost\n* Code: [https://github.com/catboost/catboost](https://github.com/catboost/catboost)\n* Paper: CatBoost: unbiased boosting with categorical features. [https://proceedings.neurips.cc/paper/2018/file/14491b756b3a51daac41c24863285549-Paper.pdf](https://proceedings.neurips.cc/paper/2018/file/14491b756b3a51daac41c24863285549-Paper.pdf)."
  },
  {
    "path": "examples/benchmarks/CatBoost/requirements.txt",
    "content": "pandas==1.1.2\nnumpy==1.21.0\ncatboost==0.24.3\n"
  },
  {
    "path": "examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: CatBoostModel\n        module_path: qlib.contrib.model.catboost_model\n        kwargs:\n            loss: RMSE\n            learning_rate: 0.0421\n            subsample: 0.8789\n            max_depth: 6\n            num_leaves: 100\n            thread_count: 20\n            grow_policy: Lossguide\n            bootstrap_type: Poisson\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            
ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158_csi500.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi500\nbenchmark: &benchmark SH000905\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: CatBoostModel\n        module_path: qlib.contrib.model.catboost_model\n        kwargs:\n            loss: RMSE\n            learning_rate: 0.0421\n            subsample: 0.8789\n            max_depth: 6\n            num_leaves: 100\n            thread_count: 20\n            grow_policy: Lossguide\n            bootstrap_type: Poisson\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            
ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors: []\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: CatBoostModel\n        module_path: qlib.contrib.model.catboost_model\n        kwargs:\n            loss: RMSE\n            learning_rate: 0.0421\n            subsample: 0.8789\n            max_depth: 6\n            num_leaves: 100\n            thread_count: 20\n            grow_policy: Lossguide\n            bootstrap_type: Poisson\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          
kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360_csi500.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi500\nbenchmark: &benchmark SH000905\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors: []\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: CatBoostModel\n        module_path: qlib.contrib.model.catboost_model\n        kwargs:\n            loss: RMSE\n            learning_rate: 0.0421\n            subsample: 0.8789\n            max_depth: 6\n            num_leaves: 100\n            thread_count: 20\n            grow_policy: Lossguide\n            bootstrap_type: Poisson\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          
kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/DoubleEnsemble/README.md",
    "content": "# DoubleEnsemble\n* DoubleEnsemble is an ensemble framework that combines learning-trajectory-based sample reweighting with shuffling-based feature selection to address two problems in financial data: a low signal-to-noise ratio and a growing number of features. It identifies key samples from the training dynamics of each sample, and selects key features from the ablation impact of shuffling each feature. The framework works with a wide range of base models and can extract complex patterns while mitigating overfitting and instability in financial market prediction.\n* The code in Qlib is our own implementation.\n* Paper: DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis [https://arxiv.org/pdf/2010.01265.pdf](https://arxiv.org/pdf/2010.01265.pdf)."
  },
  {
    "path": "examples/benchmarks/DoubleEnsemble/requirements.txt",
    "content": "pandas==1.1.2\r\nnumpy==1.21.0\r\nlightgbm==3.1.0"
  },
  {
    "path": "examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml",
    "content": "qlib_init:\r\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\r\n    region: cn\r\nmarket: &market csi300\r\nbenchmark: &benchmark SH000300\r\ndata_handler_config: &data_handler_config\r\n    start_time: 2008-01-01\r\n    end_time: 2020-08-01\r\n    fit_start_time: 2008-01-01\r\n    fit_end_time: 2014-12-31\r\n    instruments: *market\r\nport_analysis_config: &port_analysis_config\r\n    strategy:\r\n        class: TopkDropoutStrategy\r\n        module_path: qlib.contrib.strategy\r\n        kwargs:\r\n            signal: <PRED>\r\n            topk: 50\r\n            n_drop: 5\r\n    backtest:\r\n        start_time: 2017-01-01\r\n        end_time: 2020-08-01\r\n        account: 100000000\r\n        benchmark: *benchmark\r\n        exchange_kwargs:\r\n            limit_threshold: 0.095\r\n            deal_price: close\r\n            open_cost: 0.0005\r\n            close_cost: 0.0015\r\n            min_cost: 5\r\ntask:\r\n    model:\r\n        class: DEnsembleModel\r\n        module_path: qlib.contrib.model.double_ensemble\r\n        kwargs:\r\n            base_model: \"gbm\"\r\n            loss: mse\r\n            num_models: 3\r\n            enable_sr: True\r\n            enable_fs: True\r\n            alpha1: 1\r\n            alpha2: 1\r\n            bins_sr: 10\r\n            bins_fs: 5\r\n            decay: 0.5\r\n            sample_ratios:\r\n                - 0.8\r\n                - 0.7\r\n                - 0.6\r\n                - 0.5\r\n                - 0.4\r\n            sub_weights:\r\n                - 1\r\n                - 1\r\n                - 1\r\n            epochs: 28\r\n            colsample_bytree: 0.8879\r\n            learning_rate: 0.2\r\n            subsample: 0.8789\r\n            lambda_l1: 205.6999\r\n            lambda_l2: 580.9768\r\n            max_depth: 8\r\n            num_leaves: 210\r\n            num_threads: 20\r\n            verbosity: -1\r\n    dataset:\r\n        class: DatasetH\r\n        module_path: 
qlib.data.dataset\r\n        kwargs:\r\n            handler:\r\n                class: Alpha158\r\n                module_path: qlib.contrib.data.handler\r\n                kwargs: *data_handler_config\r\n            segments:\r\n                train: [2008-01-01, 2014-12-31]\r\n                valid: [2015-01-01, 2016-12-31]\r\n                test: [2017-01-01, 2020-08-01]\r\n    record: \r\n        - class: SignalRecord\r\n          module_path: qlib.workflow.record_temp\r\n          kwargs: \r\n            model: <MODEL>\r\n            dataset: <DATASET>\r\n        - class: SigAnaRecord\r\n          module_path: qlib.workflow.record_temp\r\n          kwargs: \r\n            ana_long_short: False\r\n            ann_scaler: 252\r\n        - class: PortAnaRecord\r\n          module_path: qlib.workflow.record_temp\r\n          kwargs: \r\n            config: *port_analysis_config\r\n"
  },
  {
    "path": "examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158_csi500.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi500\nbenchmark: &benchmark SH000905\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: DEnsembleModel\n        module_path: qlib.contrib.model.double_ensemble\n        kwargs:\n            base_model: \"gbm\"\n            loss: mse\n            num_models: 6\n            enable_sr: True\n            enable_fs: True\n            alpha1: 1\n            alpha2: 1\n            bins_sr: 10\n            bins_fs: 5\n            decay: 0.5\n            sample_ratios:\n                - 0.8\n                - 0.7\n                - 0.6\n                - 0.5\n                - 0.4\n            sub_weights:\n                - 1\n                - 0.2\n                - 0.2\n                - 0.2\n                - 0.2\n                - 0.2\n            epochs: 28\n            colsample_bytree: 0.8879\n            learning_rate: 0.2\n            subsample: 0.8789\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            max_depth: 8\n            num_leaves: 210\n            num_threads: 20\n            verbosity: -1\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n          
      class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml",
    "content": "qlib_init:\r\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\r\n    region: cn\r\nmarket: &market csi300\r\nbenchmark: &benchmark SH000300\r\ndata_handler_config: &data_handler_config\r\n    start_time: 2008-01-01\r\n    end_time: 2020-08-01\r\n    fit_start_time: 2008-01-01\r\n    fit_end_time: 2014-12-31\r\n    instruments: *market\r\n    infer_processors: []\r\n    learn_processors:\r\n        - class: DropnaLabel\r\n        - class: CSRankNorm\r\n          kwargs:\r\n              fields_group: label\r\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\r\nport_analysis_config: &port_analysis_config\r\n    strategy:\r\n        class: TopkDropoutStrategy\r\n        module_path: qlib.contrib.strategy\r\n        kwargs:\r\n            signal: <PRED>\r\n            topk: 50\r\n            n_drop: 5\r\n    backtest:\r\n        start_time: 2017-01-01\r\n        end_time: 2020-08-01\r\n        account: 100000000\r\n        benchmark: *benchmark\r\n        exchange_kwargs:\r\n            limit_threshold: 0.095\r\n            deal_price: close\r\n            open_cost: 0.0005\r\n            close_cost: 0.0015\r\n            min_cost: 5\r\ntask:\r\n    model:\r\n        class: DEnsembleModel\r\n        module_path: qlib.contrib.model.double_ensemble\r\n        kwargs:\r\n            base_model: \"gbm\"\r\n            loss: mse\r\n            num_models: 3\r\n            enable_sr: True\r\n            enable_fs: True\r\n            alpha1: 1\r\n            alpha2: 1\r\n            bins_sr: 10\r\n            bins_fs: 5\r\n            decay: 0.5\r\n            sample_ratios:\r\n                - 0.8\r\n                - 0.7\r\n                - 0.6\r\n                - 0.5\r\n                - 0.4\r\n            sub_weights:\r\n                - 1\r\n                - 1\r\n                - 1\r\n            epochs: 136\r\n            colsample_bytree: 0.8879\r\n            learning_rate: 0.0421\r\n            subsample: 0.8789\r\n            lambda_l1: 
205.6999\r\n            lambda_l2: 580.9768\r\n            max_depth: 8\r\n            num_leaves: 210\r\n            num_threads: 20\r\n            verbosity: -1\r\n    dataset:\r\n        class: DatasetH\r\n        module_path: qlib.data.dataset\r\n        kwargs:\r\n            handler:\r\n                class: Alpha360\r\n                module_path: qlib.contrib.data.handler\r\n                kwargs: *data_handler_config\r\n            segments:\r\n                train: [2008-01-01, 2014-12-31]\r\n                valid: [2015-01-01, 2016-12-31]\r\n                test: [2017-01-01, 2020-08-01]\r\n    record: \r\n        - class: SignalRecord\r\n          module_path: qlib.workflow.record_temp\r\n          kwargs: \r\n            model: <MODEL>\r\n            dataset: <DATASET>\r\n        - class: SigAnaRecord\r\n          module_path: qlib.workflow.record_temp\r\n          kwargs:\r\n            ana_long_short: False\r\n            ann_scaler: 252\r\n        - class: PortAnaRecord\r\n          module_path: qlib.workflow.record_temp\r\n          kwargs: \r\n            config: *port_analysis_config\r\n"
  },
  {
    "path": "examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360_csi500.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi500\nbenchmark: &benchmark SH000905\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors: []\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: DEnsembleModel\n        module_path: qlib.contrib.model.double_ensemble\n        kwargs:\n            base_model: \"gbm\"\n            loss: mse\n            num_models: 6\n            enable_sr: True\n            enable_fs: True\n            alpha1: 1\n            alpha2: 1\n            bins_sr: 10\n            bins_fs: 5\n            decay: 0.5\n            sample_ratios:\n                - 0.8\n                - 0.7\n                - 0.6\n                - 0.5\n                - 0.4\n            sub_weights:\n                - 1\n                - 0.2\n                - 0.2\n                - 0.2\n                - 0.2\n                - 0.2\n            epochs: 136\n            colsample_bytree: 0.8879\n            learning_rate: 0.0421\n            subsample: 0.8789\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            
max_depth: 8\n            num_leaves: 210\n            num_threads: 20\n            verbosity: -1\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs:\n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_early_stop_Alpha158.yaml",
    "content": "qlib_init:\r\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\r\n    region: cn\r\nmarket: &market csi300\r\nbenchmark: &benchmark SH000300\r\ndata_handler_config: &data_handler_config\r\n    start_time: 2008-01-01\r\n    end_time: 2020-08-01\r\n    fit_start_time: 2008-01-01\r\n    fit_end_time: 2014-12-31\r\n    instruments: *market\r\nport_analysis_config: &port_analysis_config\r\n    strategy:\r\n        class: TopkDropoutStrategy\r\n        module_path: qlib.contrib.strategy\r\n        kwargs:\r\n            signal: <PRED>\r\n            topk: 50\r\n            n_drop: 5\r\n    backtest:\r\n        start_time: 2017-01-01\r\n        end_time: 2020-08-01\r\n        account: 100000000\r\n        benchmark: *benchmark\r\n        exchange_kwargs:\r\n            limit_threshold: 0.095\r\n            deal_price: close\r\n            open_cost: 0.0005\r\n            close_cost: 0.0015\r\n            min_cost: 5\r\ntask:\r\n    model:\r\n        class: DEnsembleModel\r\n        module_path: qlib.contrib.model.double_ensemble\r\n        kwargs:\r\n            base_model: \"gbm\"\r\n            loss: mse\r\n            num_models: 3\r\n            enable_sr: True\r\n            enable_fs: True\r\n            alpha1: 1\r\n            alpha2: 1\r\n            bins_sr: 10\r\n            bins_fs: 5\r\n            decay: 0.5\r\n            sample_ratios:\r\n                - 0.8\r\n                - 0.7\r\n                - 0.6\r\n                - 0.5\r\n                - 0.4\r\n            sub_weights:\r\n                - 1\r\n                - 1\r\n                - 1\r\n            epochs: 1000\r\n            early_stopping_rounds: 50\r\n            colsample_bytree: 0.8879\r\n            learning_rate: 0.2\r\n            subsample: 0.8789\r\n            lambda_l1: 205.6999\r\n            lambda_l2: 580.9768\r\n            max_depth: 8\r\n            num_leaves: 210\r\n            num_threads: 20\r\n            verbosity: -1\r\n    dataset:\r\n        
class: DatasetH\r\n        module_path: qlib.data.dataset\r\n        kwargs:\r\n            handler:\r\n                class: Alpha158\r\n                module_path: qlib.contrib.data.handler\r\n                kwargs: *data_handler_config\r\n            segments:\r\n                train: [2008-01-01, 2014-12-31]\r\n                valid: [2015-01-01, 2016-12-31]\r\n                test: [2017-01-01, 2020-08-01]\r\n    record: \r\n        - class: SignalRecord\r\n          module_path: qlib.workflow.record_temp\r\n          kwargs: \r\n            model: <MODEL>\r\n            dataset: <DATASET>\r\n        - class: SigAnaRecord\r\n          module_path: qlib.workflow.record_temp\r\n          kwargs: \r\n            ana_long_short: False\r\n            ann_scaler: 252\r\n        - class: PortAnaRecord\r\n          module_path: qlib.workflow.record_temp\r\n          kwargs: \r\n            config: *port_analysis_config\r\n"
  },
  {
    "path": "examples/benchmarks/GATs/README.md",
    "content": "# GATs\n* Graph Attention Networks (GATs) leverage masked self-attentional layers on graph-structured data. By stacking layers in which nodes attend over their neighborhoods’ features, GATs implicitly assign different weights to different nodes in a neighborhood, without requiring any costly matrix operation (such as inversion) or knowing the graph structure upfront.\n* The code in Qlib is our own PyTorch implementation.\n* Paper: Graph Attention Networks https://arxiv.org/pdf/1710.10903.pdf"
  },
  {
    "path": "examples/benchmarks/GATs/requirements.txt",
    "content": "pandas==1.1.2\nnumpy==1.21.0\nscikit_learn==0.23.2\ntorch==1.7.0\n"
  },
  {
    "path": "examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: FilterCol\n          kwargs:\n              fields_group: feature\n              col_list: [\"RESI5\", \"WVMA5\", \"RSQR5\", \"KLEN\", \"RSQR10\", \"CORR5\", \"CORD5\", \"CORR10\", \n                            \"ROC60\", \"RESI10\", \"VSTD5\", \"RSQR60\", \"CORR60\", \"WVMA60\", \"STD5\", \n                            \"RSQR20\", \"CORD60\", \"CORD10\", \"CORR20\", \"KLOW\"\n                        ]\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"] \nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: GATs\n        module_path: qlib.contrib.model.pytorch_gats_ts\n        kwargs:\n            d_feat: 20\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0.7\n            n_epochs: 200\n            lr: 1e-4\n     
       early_stop: 10\n            metric: loss\n            loss: mse\n            base_model: LSTM\n            model_path: \"benchmarks/LSTM/csi300_lstm_ts.pkl\"\n            GPU: 0\n    dataset:\n        class: TSDatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n            step_len: 20\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: GATs\n        module_path: qlib.contrib.model.pytorch_gats\n        kwargs:\n            d_feat: 6\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0.7\n            n_epochs: 200\n            lr: 1e-4\n            early_stop: 20\n            metric: loss\n            loss: mse\n            base_model: LSTM\n            model_path: \"benchmarks/LSTM/model_lstm_csi300.pkl\"\n            GPU: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: 
*data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/GRU/README.md",
    "content": "# Gated Recurrent Unit (GRU)\n* Paper: [Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation](https://aclanthology.org/D14-1179.pdf).\n"
  },
  {
    "path": "examples/benchmarks/GRU/requirements.txt",
    "content": "numpy==1.21.0\npandas==1.1.2\nscikit_learn==0.23.2\ntorch==1.7.0\n"
  },
  {
    "path": "examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: FilterCol\n          kwargs:\n              fields_group: feature\n              col_list: [\"RESI5\", \"WVMA5\", \"RSQR5\", \"KLEN\", \"RSQR10\", \"CORR5\", \"CORD5\", \"CORR10\", \n                            \"ROC60\", \"RESI10\", \"VSTD5\", \"RSQR60\", \"CORR60\", \"WVMA60\", \"STD5\", \n                            \"RSQR20\", \"CORD60\", \"CORD10\", \"CORR20\", \"KLOW\"\n                        ]\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"] \n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: GRU\n        module_path: qlib.contrib.model.pytorch_gru_ts\n        kwargs:\n            d_feat: 20\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0.0\n            n_epochs: 200\n            lr: 2e-4\n     
       early_stop: 10\n            batch_size: 800\n            metric: loss\n            loss: mse\n            n_jobs: 20\n            GPU: 0\n    dataset:\n        class: TSDatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n            step_len: 20\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: GRU\n        module_path: qlib.contrib.model.pytorch_gru\n        kwargs:\n            d_feat: 6\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0.0\n            n_epochs: 200\n            lr: 1e-3\n            early_stop: 20\n            batch_size: 800\n            metric: loss\n            loss: mse\n            GPU: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: 
[2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/GeneralPtNN/README.md",
    "content": "\n\n# Introduction\n\nWhat is GeneralPtNN?\n- It fixes the previous design, which failed to support both time-series and tabular data.\n- You can now run an NN model by simply replacing the PyTorch model structure.\n\nWe provide examples to demonstrate the effectiveness of the current design.\n- `workflow_config_gru.yaml` aligns with the previous results of [GRU (Kyunghyun Cho, et al.)](../README.md#Alpha158-dataset)\n  - `workflow_config_gru2mlp.yaml` demonstrates that a config can be converted from time-series to tabular data with minimal changes\n    - You only need to change the net and dataset classes to make the conversion.\n- `workflow_config_mlp.yaml` achieves similar functionality to [MLP](../README.md#Alpha158-dataset)\n\n# TODO\n\n- We will align existing models with the current design.\n\n- The result of `workflow_config_mlp.yaml` differs from the result of [MLP](../README.md#Alpha158-dataset) because GeneralPtNN uses a different stopping method than previous implementations. Specifically, GeneralPtNN controls training by the number of epochs, whereas previous implementations were controlled by max_steps.\n"
  },
  {
    "path": "examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: FilterCol\n          kwargs:\n              fields_group: feature\n              col_list: [\"RESI5\", \"WVMA5\", \"RSQR5\", \"KLEN\", \"RSQR10\", \"CORR5\", \"CORD5\", \"CORR10\", \n                            \"ROC60\", \"RESI10\", \"VSTD5\", \"RSQR60\", \"CORR60\", \"WVMA60\", \"STD5\", \n                            \"RSQR20\", \"CORD60\", \"CORD10\", \"CORR20\", \"KLOW\"\n                        ]\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"] \n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: GeneralPTNN\n        module_path: qlib.contrib.model.pytorch_general_nn\n        kwargs:\n            n_epochs: 200\n            lr: 2e-4\n            early_stop: 10\n            batch_size: 800\n            metric: loss\n            
loss: mse\n            n_jobs: 20\n            GPU: 0\n            pt_model_uri: \"qlib.contrib.model.pytorch_gru_ts.GRUModel\"\n            pt_model_kwargs: {\n                \"d_feat\": 20,\n                \"hidden_size\": 64,\n                \"num_layers\": 2,\n                \"dropout\": 0.,\n            }\n    dataset:\n        class: TSDatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n            step_len: 20\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: FilterCol\n          kwargs:\n              fields_group: feature\n              col_list: [\"RESI5\", \"WVMA5\", \"RSQR5\", \"KLEN\", \"RSQR10\", \"CORR5\", \"CORD5\", \"CORR10\", \n                            \"ROC60\", \"RESI10\", \"VSTD5\", \"RSQR60\", \"CORR60\", \"WVMA60\", \"STD5\", \n                            \"RSQR20\", \"CORD60\", \"CORD10\", \"CORR20\", \"KLOW\"\n                        ]\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"] \n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: GeneralPTNN\n        module_path: qlib.contrib.model.pytorch_general_nn\n        kwargs:\n            lr: 1e-3\n            n_epochs: 1\n            batch_size: 800\n            loss: mse\n            optimizer: adam\n            
pt_model_uri: \"qlib.contrib.model.pytorch_nn.Net\"\n            pt_model_kwargs: \n                input_dim: 20\n                layers: [20,]\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors: [\n        {\n            \"class\" : \"DropCol\", \n            \"kwargs\":{\"col_list\": [\"VWAP0\"]}\n        },\n        {\n             \"class\" : \"CSZFillna\", \n             \"kwargs\":{\"fields_group\": \"feature\"}\n        }\n    ]\n    learn_processors: [\n        {\n            \"class\" : \"DropCol\", \n            \"kwargs\":{\"col_list\": [\"VWAP0\"]}\n        },\n        {\n            \"class\" : \"DropnaProcessor\", \n            \"kwargs\":{\"fields_group\": \"feature\"}\n        },\n        \"DropnaLabel\",\n        {\n            \"class\": \"CSZScoreNorm\", \n            \"kwargs\": {\"fields_group\": \"label\"}\n        }\n    ]\n    process_type: \"independent\"\n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: GeneralPTNN\n        module_path: qlib.contrib.model.pytorch_general_nn\n        kwargs:\n            # FIXME: wrong parameters.\n            lr: 2e-3\n            batch_size: 8192\n            loss: mse\n            weight_decay: 0.0002\n            optimizer: adam\n            pt_model_uri: \"qlib.contrib.model.pytorch_nn.Net\"\n            pt_model_kwargs: 
\n                input_dim: 157\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/HIST/README.md",
    "content": "# HIST\n* Code: [https://github.com/Wentao-Xu/HIST](https://github.com/Wentao-Xu/HIST)\n* Paper: [HIST: A Graph-based Framework for Stock Trend Forecasting via Mining Concept-Oriented Shared Information](https://arxiv.org/abs/2110.13716)."
  },
  {
    "path": "examples/benchmarks/HIST/requirements.txt",
    "content": "pandas==1.1.2\nnumpy==1.21.0\nscikit_learn==0.23.2\ntorch==1.7.0"
  },
  {
    "path": "examples/benchmarks/HIST/workflow_config_hist_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: HIST\n        module_path: qlib.contrib.model.pytorch_hist\n        kwargs:\n            d_feat: 6\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0\n            n_epochs: 200\n            lr: 1e-4\n            early_stop: 20\n            metric: ic\n            loss: mse\n            base_model: LSTM\n            model_path: \"benchmarks/LSTM/model_lstm_csi300.pkl\"\n            stock2concept: \"benchmarks/HIST/qlib_csi300_stock2concept.npy\"\n            stock_index: \"benchmarks/HIST/qlib_csi300_stock_index.npy\"\n            GPU: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        
kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/IGMTF/README.md",
    "content": "# IGMTF\n* Code: [https://github.com/Wentao-Xu/IGMTF](https://github.com/Wentao-Xu/IGMTF)\n* Paper: [IGMTF: An Instance-wise Graph-based Framework for Multivariate Time Series Forecasting](https://arxiv.org/abs/2109.06489)."
  },
  {
    "path": "examples/benchmarks/IGMTF/requirements.txt",
    "content": "pandas==1.1.2\nnumpy==1.21.0\nscikit_learn==0.23.2\ntorch==1.7.0\n"
  },
  {
    "path": "examples/benchmarks/IGMTF/workflow_config_igmtf_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: IGMTF\n        module_path: qlib.contrib.model.pytorch_igmtf\n        kwargs:\n            d_feat: 6\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0\n            n_epochs: 200\n            lr: 1e-4\n            early_stop: 20\n            metric: ic\n            loss: mse\n            base_model: LSTM\n            model_path: \"benchmarks/LSTM/model_lstm_csi300.pkl\"\n            GPU: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: 
*data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/KRNN/README.md",
    "content": "# KRNN\n* Code: [https://github.com/microsoft/FOST/blob/main/fostool/model/krnn.py](https://github.com/microsoft/FOST/blob/main/fostool/model/krnn.py)\n\n\n# Introduction to the settings/configs\n* The original model in FOST uses torch_geometric, but we do not use it here.\n* Make sure your CUDA version matches your torch version so that the GPU can be used; we use CUDA==10.2 and torch.__version__==1.12.1\n\n"
  },
  {
    "path": "examples/benchmarks/KRNN/requirements.txt",
    "content": "numpy==1.23.4\npandas==1.5.2\n"
  },
  {
    "path": "examples/benchmarks/KRNN/workflow_config_krnn_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: KRNN\n        module_path: qlib.contrib.model.pytorch_krnn\n        kwargs:\n            fea_dim: 6\n            cnn_dim: 8\n            cnn_kernel_size: 3\n            rnn_dim: 8\n            rnn_dups: 2\n            rnn_layers: 2\n            n_epochs: 200\n            lr: 0.001\n            early_stop: 20\n            batch_size: 2000\n            metric: loss\n            GPU: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n 
               train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n\n"
  },
  {
    "path": "examples/benchmarks/LSTM/README.md",
    "content": "# Long Short-Term Memory (LSTM)\n* Paper: [Long Short-Term Memory](https://direct.mit.edu/neco/article-abstract/9/8/1735/6109/Long-Short-Term-Memory?redirectedFrom=fulltext).\n"
  },
  {
    "path": "examples/benchmarks/LSTM/requirements.txt",
    "content": "numpy==1.21.0\npandas==1.1.2\nscikit_learn==0.23.2\ntorch==1.7.0\n"
  },
  {
    "path": "examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: FilterCol\n          kwargs:\n              fields_group: feature\n              col_list: [\"RESI5\", \"WVMA5\", \"RSQR5\", \"KLEN\", \"RSQR10\", \"CORR5\", \"CORD5\", \"CORR10\", \n                            \"ROC60\", \"RESI10\", \"VSTD5\", \"RSQR60\", \"CORR60\", \"WVMA60\", \"STD5\", \n                            \"RSQR20\", \"CORD60\", \"CORD10\", \"CORR20\", \"KLOW\"\n                        ]\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LSTM\n        module_path: qlib.contrib.model.pytorch_lstm_ts\n        kwargs:\n            d_feat: 20\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0.0\n            n_epochs: 200\n            lr: 1e-3\n    
        early_stop: 10\n            batch_size: 800\n            metric: loss\n            loss: mse\n            n_jobs: 20\n            GPU: 0\n    dataset:\n        class: TSDatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n            step_len: 20\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LSTM\n        module_path: qlib.contrib.model.pytorch_lstm\n        kwargs:\n            d_feat: 6\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0.0\n            n_epochs: 200\n            lr: 1e-3\n            early_stop: 20\n            batch_size: 800\n            metric: loss\n            loss: mse\n            GPU: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: 
[2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/LightGBM/README.md",
    "content": "# LightGBM\n* Code: [https://github.com/microsoft/LightGBM](https://github.com/microsoft/LightGBM)\n* Paper: LightGBM: A Highly Efficient Gradient Boosting Decision Tree. [https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf](https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf).\n\n\n# Introduction to the settings/configs\n\n`workflow_config_lightgbm_multi_freq.yaml`\n- It uses data sources of multiple frequencies for daily prediction.\n"
  },
  {
    "path": "examples/benchmarks/LightGBM/features_resample_N.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nimport pandas as pd\n\nfrom qlib.data.inst_processor import InstProcessor\nfrom qlib.utils.resam import resam_calendar\n\n\nclass ResampleNProcessor(InstProcessor):\n    def __init__(self, target_frq: str, **kwargs):\n        self.target_frq = target_frq\n\n    def __call__(self, df: pd.DataFrame, *args, **kwargs):\n        df.index = pd.to_datetime(df.index)\n        res_index = resam_calendar(df.index, \"1min\", self.target_frq)\n        df = df.resample(self.target_frq).last().reindex(res_index)\n        return df\n"
  },
  {
    "path": "examples/benchmarks/LightGBM/features_sample.py",
    "content": "import datetime\nimport pandas as pd\n\nfrom qlib.data.inst_processor import InstProcessor\n\n\nclass Resample1minProcessor(InstProcessor):\n    \"\"\"Resample 1min data to day frequency by keeping only the bar at a specific minute (hour:minute) of each day.\"\"\"\n\n    def __init__(self, hour: int, minute: int, **kwargs):\n        self.hour = hour\n        self.minute = minute\n\n    def __call__(self, df: pd.DataFrame, *args, **kwargs):\n        df.index = pd.to_datetime(df.index)\n        df = df.loc[df.index.time == datetime.time(self.hour, self.minute)]\n        df.index = df.index.normalize()\n        return df\n"
  },
  {
    "path": "examples/benchmarks/LightGBM/multi_freq_handler.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nimport pandas as pd\n\nfrom qlib.data.dataset.loader import QlibDataLoader\nfrom qlib.contrib.data.handler import DataHandlerLP, _DEFAULT_LEARN_PROCESSORS, check_transform_proc\n\n\nclass Avg15minLoader(QlibDataLoader):\n    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:\n        df = super(Avg15minLoader, self).load(instruments, start_time, end_time)\n        if self.is_group:\n            # merge the feature_day (day freq) and feature_15min (1min freq, 15-minute averages) groups into a single \"feature\" group\n            df.columns = df.columns.map(lambda x: (\"feature\", x[1]) if x[0].startswith(\"feature\") else x)\n        return df\n\n\nclass Avg15minHandler(DataHandlerLP):\n    def __init__(\n        self,\n        instruments=\"csi500\",\n        start_time=None,\n        end_time=None,\n        freq=\"day\",\n        infer_processors=[],\n        learn_processors=_DEFAULT_LEARN_PROCESSORS,\n        fit_start_time=None,\n        fit_end_time=None,\n        process_type=DataHandlerLP.PTYPE_A,\n        filter_pipe=None,\n        inst_processors=None,\n        **kwargs,\n    ):\n        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)\n        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)\n        data_loader = Avg15minLoader(\n            config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processors=inst_processors\n        )\n        super().__init__(\n            instruments=instruments,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n            infer_processors=infer_processors,\n            learn_processors=learn_processors,\n            process_type=process_type,\n        )\n\n    def loader_config(self):\n        # Results for dataset: df: pd.DataFrame\n        #   len(df.columns) == 6 + 6 * 16, 
len(df.index.get_level_values(level=\"datetime\").unique()) == T\n        #   df.columns: close0, close1, ..., close16, open0, ..., open16, ..., vwap16\n        #       freq == day:\n        #           close0, open0, low0, high0, volume0, vwap0\n        #       freq == 1min:\n        #           close1, ..., close16, ..., vwap1, ..., vwap16\n        #   df.index.name == [\"datetime\", \"instrument\"]: pd.MultiIndex\n        # Example:\n        #                          feature                        ...                  label\n        #                           close0      open0       low0  ... vwap1 vwap16    LABEL0\n        # datetime   instrument                                   ...\n        # 2020-10-09 SH600000    11.794546  11.819587  11.769505  ...   NaN    NaN -0.005214\n        # 2020-10-15 SH600000    12.044961  11.944795  11.932274  ...   NaN    NaN -0.007202\n        # ...                          ...        ...        ...  ...   ...    ...       ...\n        # 2021-05-28 SZ300676     6.369684   6.495406   6.306568  ...   NaN    NaN -0.001321\n        # 2021-05-31 SZ300676     6.601626   6.465643   6.465130  ...   
NaN    NaN -0.023428\n\n        # features day: len(columns) == 6, freq = day\n        # $close is the closing price of the current trading day:\n        #   if the user needs to get the `close` before the last T days, use Ref($close, T-1), for example:\n        #                                    $close  Ref($close, 1)  Ref($close, 2)  Ref($close, 3)  Ref($close, 4)\n        #         instrument datetime\n        #         SH600519   2021-06-01  244.271530\n        #                    2021-06-02  242.205917      244.271530\n        #                    2021-06-03  242.229889      242.205917      244.271530\n        #                    2021-06-04  245.421524      242.229889      242.205917      244.271530\n        #                    2021-06-07  247.547089      245.421524      242.229889      242.205917      244.271530\n\n        # WARNING: Ref($close, N), if N == 0, Ref($close, N) ==> $close\n\n        fields = [\"$close\", \"$open\", \"$low\", \"$high\", \"$volume\", \"$vwap\"]\n        # names: close0, open0, ..., vwap0\n        names = list(map(lambda x: x.strip(\"$\") + \"0\", fields))\n\n        config = {\"feature_day\": (fields, names)}\n\n        # features 15min: len(columns) == 6 * 16, freq = 1min\n        #   $close is the closing price of the current trading day:\n        #       if the user gets 'close' for the i-th 15min of the last T days, use `Ref(Mean($close, 15), (T-1) * 240 + i * 15)`, for example:\n        #                                    Ref(Mean($close, 15), 225)  Ref(Mean($close, 15), 465)  Ref(Mean($close, 15), 705)\n        #             instrument datetime\n        #             SH600519   2021-05-31                  241.769897                  243.077942                  244.712997\n        #                        2021-06-01                  244.271530                  241.769897                  243.077942\n        #                        2021-06-02                  242.205917                  244.271530                  
241.769897\n\n        # WARNING: Ref(Mean($close, 15), N), if N == 0, Ref(Mean($close, 15), N) ==> Mean($close, 15)\n\n        # Results of the current script:\n        #   time:   09:30 --> 09:44,            ..., 14:45 --> 14:59\n        #   fields: Ref(Mean($close, 15), 225), ..., Mean($close, 15)\n        #   name:   close1,                     ..., close16\n        #\n\n        # Expression description: take close as an example\n        #   Mean($close, 15) ==> df[\"$close\"].rolling(15, min_periods=1).mean()\n        #   Ref(Mean($close, 15), 15) ==> df[\"$close\"].rolling(15, min_periods=1).mean().shift(15)\n\n        #   NOTE: only the last value of each trading day is kept after resampling; at that point,\n        #   Ref(Mean($close, 15), 240 - i * 15) is the average of the i-th 15 minutes of the day\n\n        # Average:\n        #   Average of the i-th 15-minute period of each trading day: 1 <= i <= 240 // 15\n        #       Avg(15minutes): Ref(Mean($close, 15), 240 - i * 15)\n        #\n        #   Average of the first 15 minutes of each trading day; i = 1\n        #       Avg(09:30 --> 09:44), df.loc[\"09:44\"]: Ref(Mean($close, 15), 240 - 1 * 15) ==> Ref(Mean($close, 15), 225)\n        #   Average of the last 15 minutes of each trading day; i = 16\n        #       Avg(14:45 --> 14:59), df.loc[\"14:59\"]: Ref(Mean($close, 15), 240 - 16 * 15) ==> Ref(Mean($close, 15), 0) ==> Mean($close, 15)\n\n        # 15min resample to day\n        #   df.resample(\"1d\").last()\n        tmp_fields = []\n        tmp_names = []\n        for i, _f in enumerate(fields):\n            _fields = [f\"Ref(Mean({_f}, 15), {j * 15})\" for j in range(1, 240 // 15)]\n            _names = [f\"{names[i][:-1]}{int(names[i][-1])+j}\" for j in range(240 // 15 - 1, 0, -1)]\n            _fields.append(f\"Mean({_f}, 15)\")\n            _names.append(f\"{names[i][:-1]}{int(names[i][-1])+240 // 15}\")\n            tmp_fields += _fields\n            tmp_names += _names\n        config[\"feature_15min\"] = (tmp_fields, tmp_names)\n        # label\n        config[\"label\"] = 
([\"Ref($close, -2)/Ref($close, -1) - 1\"], [\"LABEL0\"])\n        return config\n"
  },
  {
    "path": "examples/benchmarks/LightGBM/requirements.txt",
    "content": "pandas==1.1.2\nnumpy==1.21.0\nlightgbm\n"
  },
  {
    "path": "examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LGBModel\n        module_path: qlib.contrib.model.gbdt\n        kwargs:\n            loss: mse\n            colsample_bytree: 0.8879\n            learning_rate: 0.2\n            subsample: 0.8789\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            max_depth: 8\n            num_leaves: 210\n            num_threads: 20\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n     
       ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_csi500.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi500\nbenchmark: &benchmark SH000905\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LGBModel\n        module_path: qlib.contrib.model.gbdt\n        kwargs:\n            loss: mse\n            colsample_bytree: 0.9\n            learning_rate: 0.1\n            subsample: 0.9\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            max_depth: 8\n            num_leaves: 250\n            num_threads: 20\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n           
 ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml",
    "content": "qlib_init:\n    provider_uri:\n        day: \"~/.qlib/qlib_data/cn_data\"\n        1min: \"~/.qlib/qlib_data/cn_data_1min\"\n    region: cn\n    dataset_cache: null\n    maxtasksperchild: 1\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    # 1min closing time is 15:00:00\n    end_time: \"2020-08-01 15:00:00\"\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    freq:\n        label: day\n        feature: 1min\n    # with label as reference\n    inst_processors:\n        feature:\n            - class: Resample1minProcessor\n              module_path: features_sample.py\n              kwargs:\n                  hour: 14\n                  minute: 56\n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LGBModel\n        module_path: qlib.contrib.model.gbdt\n        kwargs:\n            loss: mse\n            colsample_bytree: 0.8879\n            learning_rate: 0.2\n            subsample: 0.8789\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            max_depth: 8\n            num_leaves: 210\n            num_threads: 20\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: 
[2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: {}\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors: []\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LGBModel\n        module_path: qlib.contrib.model.gbdt\n        kwargs:\n            loss: mse\n            colsample_bytree: 0.8879\n            learning_rate: 0.0421\n            subsample: 0.8789\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            max_depth: 8\n            num_leaves: 210\n            num_threads: 20\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: 
qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360_csi500.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi500\nbenchmark: &benchmark SH000905\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors: []\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LGBModel\n        module_path: qlib.contrib.model.gbdt\n        kwargs:\n            loss: mse\n            colsample_bytree: 0.8879\n            learning_rate: 0.0421\n            subsample: 0.8789\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            max_depth: 8\n            num_leaves: 210\n            num_threads: 20\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: 
qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    instruments: *market\n    data_loader:\n        class: QlibDataLoader\n        kwargs:\n            config:\n                feature:\n                    - [\"Resi($close, 15)/$close\", \"Std(Abs($close/Ref($close, 1)-1)*$volume, 5)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, 5)+1e-12)\", \"Rsquare($close, 5)\", \"($high-$low)/$open\", \"Rsquare($close, 10)\", \"Corr($close, Log($volume+1), 5)\", \"Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 5)\", \"Corr($close, Log($volume+1), 10)\", \"Rsquare($close, 20)\", \"Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 60)\", \"Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 10)\", \"Corr($close, Log($volume+1), 20)\", \"(Less($open, $close)-$low)/$open\"]\n                    - [\"RESI5\", \"WVMA5\", \"RSQR5\", \"KLEN\", \"RSQR10\", \"CORR5\", \"CORD5\", \"CORR10\", \"RSQR20\", \"CORD60\", \"CORD10\", \"CORR20\", \"KLOW\"]\n                label:\n                    - [\"Ref($close, -2)/Ref($close, -1) - 1\"]\n                    - [\"LABEL0\"]\n            freq: day\n\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSZScoreNorm\n          kwargs:\n            fields_group: label\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 
5\ntask:\n    model:\n        class: LGBModel\n        module_path: qlib.contrib.model.gbdt\n        kwargs:\n            loss: mse\n            colsample_bytree: 0.8879\n            learning_rate: 0.2\n            subsample: 0.8789\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            max_depth: 8\n            num_leaves: 210\n            num_threads: 20\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: DataHandlerLP\n                module_path: qlib.data.dataset.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml",
    "content": "qlib_init:\n    provider_uri:\n        day: \"~/.qlib/qlib_data/cn_data\"\n        1min: \"~/.qlib/qlib_data/cn_data_1min\"\n    region: cn\n    dataset_cache: null\n    maxtasksperchild: null\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    # 1min closing time is 15:00:00\n    end_time: \"2020-08-01 15:00:00\"\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    freq:\n        label: day\n        feature_15min: 1min\n        feature_day: day\n    # with label as reference\n    inst_processors:\n        feature_15min:\n            - class: ResampleNProcessor\n              module_path: features_resample_N.py\n              kwargs:\n                  target_frq: 1d\n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LGBModel\n        module_path: qlib.contrib.model.gbdt\n        kwargs:\n            loss: mse\n            colsample_bytree: 0.8879\n            learning_rate: 0.2\n            subsample: 0.8789\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            max_depth: 8\n            num_leaves: 210\n            num_threads: 20\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Avg15minHandler\n                module_path: multi_freq_handler.py\n                kwargs: *data_handler_config\n 
           segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record:\n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs:\n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs:\n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs:\n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/Linear/requirements.txt",
    "content": "numpy>=1.17.4\npandas>=1.0.1\nscikit-learn>=0.23.1\n"
  },
  {
    "path": "examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LinearModel\n        module_path: qlib.contrib.model.linear\n        kwargs:\n            estimator: ols\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: 
SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: True\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/Linear/workflow_config_linear_Alpha158_csi500.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi500\nbenchmark: &benchmark SH000905\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LinearModel\n        module_path: qlib.contrib.model.linear\n        kwargs:\n            estimator: ols\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: 
SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: True\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/Linear/workflow_config_linear_Alpha158_multi_pass_bt.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal:\n                - <MODEL> \n                - <DATASET>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LinearModel\n        module_path: qlib.contrib.model.linear\n        kwargs:\n            estimator: ols\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: 
<MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: True\n            ann_scaler: 252\n        - class: MultiPassPortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/Localformer/README.md",
    "content": "# Localformer\n"
  },
  {
    "path": "examples/benchmarks/Localformer/requirements.txt",
    "content": "numpy==1.21.0\r\npandas==1.1.2\r\ntorch==1.2.0"
  },
  {
    "path": "examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: FilterCol\n          kwargs:\n              fields_group: feature\n              col_list: [\"RESI5\", \"WVMA5\", \"RSQR5\", \"KLEN\", \"RSQR10\", \"CORR5\", \"CORD5\", \"CORR10\", \n                            \"ROC60\", \"RESI10\", \"VSTD5\", \"RSQR60\", \"CORR60\", \"WVMA60\", \"STD5\", \n                            \"RSQR20\", \"CORD60\", \"CORD10\", \"CORR20\", \"KLOW\"\n                        ]\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"] \n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LocalformerModel\n        module_path: qlib.contrib.model.pytorch_localformer_ts\n        kwargs:\n            seed: 0\n            n_jobs: 20\n    dataset:\n        class: TSDatasetH\n        module_path: qlib.data.dataset\n        
kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n            step_len: 20\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n              model: <MODEL>\n              dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n              ana_long_short: False\n              ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n              config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LocalformerModel\n        module_path: qlib.contrib.model.pytorch_localformer\n        kwargs:\n            d_feat: 6\n            seed: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n    - class: SignalRecord\n      module_path: qlib.workflow.record_temp\n      kwargs: 
\n        model: <MODEL>\n        dataset: <DATASET>\n    - class: SigAnaRecord\n      module_path: qlib.workflow.record_temp\n      kwargs: \n        ana_long_short: False\n        ann_scaler: 252\n    - class: PortAnaRecord\n      module_path: qlib.workflow.record_temp\n      kwargs: \n        config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/MLP/README.md",
    "content": "# Multi-Layer Perceptron (MLP)\n"
  },
  {
    "path": "examples/benchmarks/MLP/requirements.txt",
    "content": "pandas==1.1.2\nnumpy==1.21.0\nscikit_learn==0.23.2\ntorch==1.7.0\n"
  },
  {
    "path": "examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors: [\n        {\n            \"class\" : \"DropCol\", \n            \"kwargs\":{\"col_list\": [\"VWAP0\"]}\n        },\n        {\n             \"class\" : \"CSZFillna\", \n             \"kwargs\":{\"fields_group\": \"feature\"}\n        }\n    ]\n    learn_processors: [\n        {\n            \"class\" : \"DropCol\", \n            \"kwargs\":{\"col_list\": [\"VWAP0\"]}\n        },\n        {\n            \"class\" : \"DropnaProcessor\", \n            \"kwargs\":{\"fields_group\": \"feature\"}\n        },\n        \"DropnaLabel\",\n        {\n            \"class\": \"CSZScoreNorm\", \n            \"kwargs\": {\"fields_group\": \"label\"}\n        }\n    ]\n    process_type: \"independent\"\n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: DNNModelPytorch\n        module_path: qlib.contrib.model.pytorch_nn\n        kwargs:\n            loss: mse\n            lr: 0.002\n            optimizer: adam\n            max_steps: 8000\n            batch_size: 8192\n            GPU: 0\n            weight_decay: 0.0002\n            pt_model_kwargs:\n              input_dim: 157\n    dataset:\n        class: 
DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/MLP/workflow_config_mlp_Alpha158_csi500.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi500\nbenchmark: &benchmark SH000905\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors: [\n        {\n            \"class\" : \"DropCol\", \n            \"kwargs\":{\"col_list\": [\"VWAP0\"]}\n        },\n        {\n             \"class\" : \"CSZFillna\", \n             \"kwargs\":{\"fields_group\": \"feature\"}\n        }\n    ]\n    learn_processors: [\n        {\n            \"class\" : \"DropCol\", \n            \"kwargs\":{\"col_list\": [\"VWAP0\"]}\n        },\n        {\n            \"class\" : \"DropnaProcessor\", \n            \"kwargs\":{\"fields_group\": \"feature\"}\n        },\n        \"DropnaLabel\",\n        {\n            \"class\": \"CSZScoreNorm\", \n            \"kwargs\": {\"fields_group\": \"label\"}\n        }\n    ]\n    process_type: \"independent\"\n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: DNNModelPytorch\n        module_path: qlib.contrib.model.pytorch_nn\n        kwargs:\n            loss: mse\n            lr: 0.002\n            optimizer: adam\n            max_steps: 8000\n            batch_size: 8192\n            GPU: 0\n            weight_decay: 0.0002\n            pt_model_kwargs:\n              input_dim: 157\n    dataset:\n        class: 
DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: DNNModelPytorch\n        module_path: qlib.contrib.model.pytorch_nn\n        kwargs:\n            loss: mse\n            lr: 0.002\n            optimizer: adam\n            max_steps: 8000\n            batch_size: 4096\n            GPU: 0\n            pt_model_kwargs:\n              input_dim: 360\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 
2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/MLP/workflow_config_mlp_Alpha360_csi500.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi500\nbenchmark: &benchmark SH000905\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: DNNModelPytorch\n        module_path: qlib.contrib.model.pytorch_nn\n        kwargs:\n            loss: mse\n            lr: 0.002\n            optimizer: adam\n            max_steps: 8000\n            batch_size: 4096\n            GPU: 0\n            pt_model_kwargs:\n              input_dim: 360\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 
2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/README.md",
    "content": "# Benchmark Performance\nThis page lists a batch of methods designed for alpha seeking. Each method tries to give scores/predictions for all stocks each day (e.g., forecasting the future excess return of stocks). The scores/predictions of the models are used as the mined alpha. Investing in stocks with higher scores is expected to yield more profit.  \n\nThe alpha is evaluated in two ways.\n1. The correlation between the alpha and future return.\n1. Constructing a portfolio based on the alpha and evaluating its final total return.\n   - The explanation of the metrics can be found [here](https://qlib.readthedocs.io/en/latest/component/report.html#id4)\n\nHere are the results of each benchmark model on Qlib's `Alpha360` and `Alpha158` datasets, using China's A-share CSI300 data. The values of each metric are the mean and std calculated over 20 runs with different random seeds.\n\nThe numbers shown below reflect the performance of the entire `workflow` of each model. We will update the `workflow` and the models in the near future to achieve better results.\n<!-- \n> If you need to reproduce the results below, please use the **v1** dataset: `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn --version v1`\n>\n> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ -->\n\n> NOTE:\n> The backtest starting from version 0.8.0 is quite different from previous versions. Please check the changelog for the differences.\n\n> NOTE:\n> We have very limited resources to implement and fine-tune the models. We did our best to compare these models fairly, but some models may have greater potential than the table below suggests.  
Contributions that explore their potential are highly welcome.\n\n## Results on CSI300\n\n### Alpha158 dataset\n\n| Model Name                               | Dataset                             | IC          | ICIR        | Rank IC     | Rank ICIR   | Annualized Return | Information Ratio | Max Drawdown |\n|------------------------------------------|-------------------------------------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|\n| TCN(Shaojie Bai, et al.)                 | Alpha158                            | 0.0279±0.00 | 0.2181±0.01 | 0.0421±0.00 | 0.3429±0.01 | 0.0262±0.02       | 0.4133±0.25       | -0.1090±0.03 |\n| TabNet(Sercan O. Arik, et al.)           | Alpha158                            | 0.0204±0.01 | 0.1554±0.07 | 0.0333±0.00 | 0.2552±0.05 | 0.0227±0.04       | 0.3676±0.54       | -0.1089±0.08 |\n| Transformer(Ashish Vaswani, et al.)      | Alpha158                            | 0.0264±0.00 | 0.2053±0.02 | 0.0407±0.00 | 0.3273±0.02 | 0.0273±0.02       | 0.3970±0.26       | -0.1101±0.02 |\n| GRU(Kyunghyun Cho, et al.)               | Alpha158(with selected 20 features) | 0.0315±0.00 | 0.2450±0.04 | 0.0428±0.00 | 0.3440±0.03 | 0.0344±0.02       | 0.5160±0.25       | -0.1017±0.02 |\n| LSTM(Sepp Hochreiter, et al.)            | Alpha158(with selected 20 features) | 0.0318±0.00 | 0.2367±0.04 | 0.0435±0.00 | 0.3389±0.03 | 0.0381±0.03       | 0.5561±0.46       | -0.1207±0.04 |\n| Localformer(Juyong Jiang, et al.)        | Alpha158                            | 0.0356±0.00 | 0.2756±0.03 | 0.0468±0.00 | 0.3784±0.03 | 0.0438±0.02       | 0.6600±0.33       | -0.0952±0.02 |\n| SFM(Liheng Zhang, et al.)                | Alpha158                            | 0.0379±0.00 | 0.2959±0.04 | 0.0464±0.00 | 0.3825±0.04 | 0.0465±0.02       | 0.5672±0.29       | -0.1282±0.03 |\n| ALSTM (Yao Qin, et al.)                  
| Alpha158(with selected 20 features) | 0.0362±0.01 | 0.2789±0.06 | 0.0463±0.01 | 0.3661±0.05 | 0.0470±0.03       | 0.6992±0.47       | -0.1072±0.03 |\n| GATs (Petar Velickovic, et al.)          | Alpha158(with selected 20 features) | 0.0349±0.00 | 0.2511±0.01 | 0.0462±0.00 | 0.3564±0.01 | 0.0497±0.01       | 0.7338±0.19       | -0.0777±0.02 |\n| TRA(Hengxu Lin, et al.)                  | Alpha158(with selected 20 features) | 0.0404±0.00 | 0.3197±0.05 | 0.0490±0.00 | 0.4047±0.04 | 0.0649±0.02       | 1.0091±0.30       | -0.0860±0.02 |\n| Linear                                   | Alpha158                            | 0.0397±0.00 | 0.3000±0.00 | 0.0472±0.00 | 0.3531±0.00 | 0.0692±0.00       | 0.9209±0.00       | -0.1509±0.00 |\n| TRA(Hengxu Lin, et al.)                  | Alpha158                            | 0.0440±0.00 | 0.3535±0.05 | 0.0540±0.00 | 0.4451±0.03 | 0.0718±0.02       | 1.0835±0.35       | -0.0760±0.02 |\n| CatBoost(Liudmila Prokhorenkova, et al.) | Alpha158                            | 0.0481±0.00 | 0.3366±0.00 | 0.0454±0.00 | 0.3311±0.00 | 0.0765±0.00       | 0.8032±0.01       | -0.1092±0.00 |\n| XGBoost(Tianqi Chen, et al.)             | Alpha158                            | 0.0498±0.00 | 0.3779±0.00 | 0.0505±0.00 | 0.4131±0.00 | 0.0780±0.00       | 0.9070±0.00       | -0.1168±0.00 |\n| TFT (Bryan Lim, et al.)                  | Alpha158(with selected 20 features) | 0.0358±0.00 | 0.2160±0.03 | 0.0116±0.01 | 0.0720±0.03 | 0.0847±0.02       | 0.8131±0.19       | -0.1824±0.03 |\n| MLP                                      | Alpha158                            | 0.0376±0.00 | 0.2846±0.02 | 0.0429±0.00 | 0.3220±0.01 | 0.0895±0.02       | 1.1408±0.23       | -0.1103±0.02 |\n| LightGBM(Guolin Ke, et al.)              | Alpha158                            | 0.0448±0.00 | 0.3660±0.00 | 0.0469±0.00 | 0.3877±0.00 | 0.0901±0.00       | 1.0164±0.00       | -0.1038±0.00 |\n| DoubleEnsemble(Chuheng Zhang, et al.)    
| Alpha158                            | 0.0521±0.00 | 0.4223±0.01 | 0.0502±0.00 | 0.4117±0.01 | 0.1158±0.01       | 1.3432±0.11       | -0.0920±0.01 |\n\n### Alpha360 dataset\n\n| Model Name                                | Dataset  | IC          | ICIR        | Rank IC     | Rank ICIR   | Annualized Return | Information Ratio | Max Drawdown |\n|-------------------------------------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|\n| Transformer(Ashish Vaswani, et al.)       | Alpha360 | 0.0114±0.00 | 0.0716±0.03 | 0.0327±0.00 | 0.2248±0.02 | -0.0270±0.03      | -0.3378±0.37      | -0.1653±0.05 |\n| TabNet(Sercan O. Arik, et al.)            | Alpha360 | 0.0099±0.00 | 0.0593±0.00 | 0.0290±0.00 | 0.1887±0.00 | -0.0369±0.00      | -0.3892±0.00      | -0.2145±0.00 |\n| MLP                                       | Alpha360 | 0.0273±0.00 | 0.1870±0.02 | 0.0396±0.00 | 0.2910±0.02 | 0.0029±0.02       | 0.0274±0.23       | -0.1385±0.03 |\n| Localformer(Juyong Jiang, et al.)         | Alpha360 | 0.0404±0.00 | 0.2932±0.04 | 0.0542±0.00 | 0.4110±0.03 | 0.0246±0.02       | 0.3211±0.21       | -0.1095±0.02 |\n| CatBoost(Liudmila Prokhorenkova, et al.)  | Alpha360 | 0.0378±0.00 | 0.2714±0.00 | 0.0467±0.00 | 0.3659±0.00 | 0.0292±0.00       | 0.3781±0.00       | -0.0862±0.00 |\n| XGBoost(Tianqi Chen, et al.)              | Alpha360 | 0.0394±0.00 | 0.2909±0.00 | 0.0448±0.00 | 0.3679±0.00 | 0.0344±0.00       | 0.4527±0.02       | -0.1004±0.00 |\n| DoubleEnsemble(Chuheng Zhang, et al.)     | Alpha360 | 0.0390±0.00 | 0.2946±0.01 | 0.0486±0.00 | 0.3836±0.01 | 0.0462±0.01       | 0.6151±0.18       | -0.0915±0.01 |\n| LightGBM(Guolin Ke, et al.)               | Alpha360 | 0.0400±0.00 | 0.3037±0.00 | 0.0499±0.00 | 0.4042±0.00 | 0.0558±0.00       | 0.7632±0.00       | -0.0659±0.00 |\n| TCN(Shaojie Bai, et al.)                  
| Alpha360 | 0.0441±0.00 | 0.3301±0.02 | 0.0519±0.00 | 0.4130±0.01 | 0.0604±0.02       | 0.8295±0.34       | -0.1018±0.03 |\n| ALSTM (Yao Qin, et al.)                   | Alpha360 | 0.0497±0.00 | 0.3829±0.04 | 0.0599±0.00 | 0.4736±0.03 | 0.0626±0.02       | 0.8651±0.31       | -0.0994±0.03 |\n| LSTM(Sepp Hochreiter, et al.)             | Alpha360 | 0.0448±0.00 | 0.3474±0.04 | 0.0549±0.00 | 0.4366±0.03 | 0.0647±0.03       | 0.8963±0.39       | -0.0875±0.02 |\n| ADD                                       | Alpha360 | 0.0430±0.00 | 0.3188±0.04 | 0.0559±0.00 | 0.4301±0.03 | 0.0667±0.02       | 0.8992±0.34       | -0.0855±0.02 |\n| GRU(Kyunghyun Cho, et al.)                | Alpha360 | 0.0493±0.00 | 0.3772±0.04 | 0.0584±0.00 | 0.4638±0.03 | 0.0720±0.02       | 0.9730±0.33       | -0.0821±0.02 |\n| AdaRNN(Yuntao Du, et al.)                 | Alpha360 | 0.0464±0.01 | 0.3619±0.08 | 0.0539±0.01 | 0.4287±0.06 | 0.0753±0.03       | 1.0200±0.40       | -0.0936±0.03 |\n| GATs (Petar Velickovic, et al.)           | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02       | 1.1079±0.26       | -0.0894±0.03 |\n| TCTS(Xueqing Wu, et al.)                  | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03       | 1.2256±0.36       | -0.0857±0.02 |\n| TRA(Hengxu Lin, et al.)                   | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03       | 1.2789±0.42       | -0.0834±0.02 |\n| IGMTF(Wentao Xu, et al.)                  | Alpha360 | 0.0480±0.00 | 0.3589±0.02 | 0.0606±0.00 | 0.4773±0.01 | 0.0946±0.02       | 1.3509±0.25       | -0.0716±0.02 |\n| HIST(Wentao Xu, et al.)                   
| Alpha360 | 0.0522±0.00 | 0.3530±0.01 | 0.0667±0.00 | 0.4576±0.01 | 0.0987±0.02       | 1.3726±0.27       | -0.0681±0.01 |\n| KRNN                                      | Alpha360 | 0.0173±0.01 | 0.1210±0.06 | 0.0270±0.01 | 0.2018±0.04 | -0.0465±0.05      | -0.5415±0.62      | -0.2919±0.13 |\n| Sandwich                                  | Alpha360 | 0.0258±0.00 | 0.1924±0.04 | 0.0337±0.00 | 0.2624±0.03 | 0.0005±0.03       | 0.0001±0.33       | -0.1752±0.05 |\n\n\n- The selected 20 features are based on the feature importance of a LightGBM-based model.\n- The base model of DoubleEnsemble is LGBM.\n- The base model of TCTS is GRU.\n- About the datasets:\n  - Alpha158 is a tabular dataset. There are fewer spatial relationships between different features. Each feature is carefully designed by humans (a.k.a. feature engineering).\n  - Alpha360 contains raw price and volume data without much feature engineering. There are strong spatial relationships between the features in the time dimension.\n- The metrics fall into two categories:\n   - Signal-based evaluation: IC, ICIR, Rank IC, Rank ICIR\n      - ![equation](https://latex.codecogs.com/gif.latex?%5Ctext%7Bcorr%7D%28%5Ctextbf%7Bx%7D%2C%5Ctextbf%7By%7D%29%3D%5Cfrac%7B%5Csum_i%20%28x_i-%5Cbar%7Bx%7D%29%28y_i-%5Cbar%7By%7D%29%7D%7B%5Csqrt%7B%5Csum_i%28x_i-%5Cbar%7Bx%7D%29%5E2%5Csum_i%28y_i-%5Cbar%7By%7D%29%5E2%7D%7D)\n      - ![equation](https://latex.codecogs.com/gif.latex?%5Ctext%7BIC%7D%5E%7B%28t%29%7D%20%3D%20%5Ctext%7Bcorr%7D%28%5Chat%7B%5Ctextbf%7By%7D%7D%5E%7B%28t%29%7D%2C%20%5Ctextbf%7Bret%7D%5E%7B%28t%29%7D%29)\n      - ![equation](https://latex.codecogs.com/gif.latex?%5Ctext%7BICIR%7D%20%3D%20%5Cfrac%20%7B%5Ctext%7Bmean%7D%28%5Ctextbf%7BIC%7D%29%7D%20%7B%5Ctext%7Bstd%7D%28%5Ctextbf%7BIC%7D%29%7D)\n      - 
![equation](https://latex.codecogs.com/gif.latex?%5Ctext%7BRank%20IC%7D%5E%7B%28t%29%7D%20%3D%20%5Ctext%7Bcorr%7D%28%5Ctext%7Brank%7D%28%5Chat%7B%5Ctextbf%7By%7D%7D%5E%7B%28t%29%7D%29%2C%20%5Ctext%7Brank%7D%28%5Ctextbf%7Bret%7D%5E%7B%28t%29%7D%29%29)\n      - ![equation](https://latex.codecogs.com/gif.latex?%5Ctext%7BRank%20ICIR%7D%20%3D%20%5Cfrac%20%7B%5Ctext%7Bmean%7D%28%5Ctextbf%7BRank%20IC%7D%29%7D%20%7B%5Ctext%7Bstd%7D%28%5Ctextbf%7BRank%20IC%7D%29%7D)\n   - Portfolio-based metrics: Annualized Return, Information Ratio, Max Drawdown\n\n## Results on CSI500\nThe results on CSI500 are not complete. PRs for models on CSI500 are welcome!\n\nTransferring the CSI300 models to CSI500 is easy. You can try them with just a few commands:\n```\ncd examples/benchmarks/LightGBM\npip install -r requirements.txt\n\n# create a new config and switch the market to csi500 and the benchmark to SH000905\ncp workflow_config_lightgbm_Alpha158.yaml workflow_config_lightgbm_Alpha158_csi500.yaml\nsed -i \"s/csi300/csi500/g\"  workflow_config_lightgbm_Alpha158_csi500.yaml\nsed -i \"s/SH000300/SH000905/g\"  workflow_config_lightgbm_Alpha158_csi500.yaml\n\n# you can either run the model once\nqrun workflow_config_lightgbm_Alpha158_csi500.yaml\n\n# or run it multiple times automatically and get the summarized results.\ncd ../../\npython run_all_model.py run 3 lightgbm Alpha158 csi500  # for models with randomness, 
please run it for 20 times.\n```\n\n### Alpha158 dataset\n| Model Name | Dataset  | IC          | ICIR        | Rank IC     | Rank ICIR   | Annualized Return | Information Ratio | Max Drawdown |\n|------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|\n| Linear     | Alpha158 | 0.0332±0.00 | 0.3044±0.00 | 0.0462±0.00 | 0.4326±0.00 | 0.0382±0.00       | 0.1723±0.00       | -0.4876±0.00 |\n| MLP        | Alpha158 | 0.0229±0.01 | 0.2181±0.05 | 0.0360±0.00 | 0.3409±0.02 | 0.0043±0.02       | 0.0602±0.27       | -0.2184±0.04 |\n| LightGBM   | Alpha158 | 0.0399±0.00 | 0.4065±0.00 | 0.0482±0.00 | 0.5101±0.00 | 0.1284±0.00       | 1.5650±0.00       | -0.0635±0.00 |\n| CatBoost   | Alpha158 | 0.0345±0.00 | 0.2855±0.00 | 0.0417±0.00 | 0.3740±0.00 | 0.0496±0.00       | 0.5977±0.00       | -0.1496±0.00 |\n| DoubleEnsemble  | Alpha158 | 0.0380±0.00 | 0.3659±0.00 | 0.0442±0.00 | 0.4324±0.00 | 0.0382±0.00       | 0.1723±0.00       | -0.4876±0.00 |\n\n### Alpha360 dataset\n| Model Name | Dataset  | IC          | ICIR        | Rank IC     | Rank ICIR   | Annualized Return | Information Ratio | Max Drawdown |\n|------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|\n| MLP        | Alpha360 | 0.0258±0.00 | 0.2021±0.02 | 0.0426±0.00 | 0.3840±0.02 | 0.0022±0.02       | 0.0301±0.26       | -0.2064±0.02 |\n| LightGBM   | Alpha360 | 0.0400±0.00 | 0.3605±0.00 | 0.0536±0.00 | 0.5431±0.00 | 0.0505±0.00       | 0.7658±0.02       | -0.1880±0.00 |\n| CatBoost   | Alpha360 | 0.0382±0.00 | 0.3229±0.00 | 0.0489±0.00 | 0.4649±0.00 | 0.0297±0.00       | 0.4227±0.02       | -0.1499±0.01 |\n| DoubleEnsemble  | Alpha360 | 0.0361±0.00 | 0.3092±0.00 | 0.0499±0.00 | 0.4793±0.00 | 0.0382±0.00       | 0.1723±0.02       | -0.4876±0.00 |\n\n# Contributing\n\nYour contributions to new models are highly welcome!\n\nIf you want to contribute your new 
models, you can follow the steps below.\n1. Create a folder for your model.\n2. The folder should contain the following items (you can refer to [this example](https://github.com/microsoft/qlib/tree/main/examples/benchmarks/TCTS)).\n    - `requirements.txt`: required dependencies.\n    - `README.md`: a brief introduction to your model.\n    - `workflow_config_<model name>_<dataset>.yaml`: a configuration that can be read by `qrun`. You are encouraged to run your model on all datasets.\n3. You can integrate your model as a module [in this folder](https://github.com/microsoft/qlib/tree/main/qlib/contrib/model).\n4. Please update your results in the above **Benchmark Tables**, e.g. [Alpha360](#alpha360-dataset), [Alpha158](#alpha158-dataset) (the value of each metric is the mean and std over **20 runs** with different random seeds. You can accomplish the above operations with the automated [script](https://github.com/microsoft/qlib/blob/main/examples/run_all_model.py) provided by Qlib and get the final result in the .md file. If you don't have enough computational resources, you can ask for help in the PR).\n5. Update the info on the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).\n\nFinally, you can send a PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))\n\n\n# FAQ\n\nQ: What's the difference between models named `*.py` and `*_ts.py`?\n\nA: Models named `*_ts.py` are designed for `TSDatasetH` (`TSDatasetH` automatically creates time series from tabular data). Models named `*.py` are designed for `DatasetH` (`DatasetH` is usually used for tabular data, but users can still apply time-series models to tabular datasets if the columns have time-series relationships).\n"
  },
  {
    "path": "examples/benchmarks/SFM/README.md",
    "content": "# State-Frequency-Memory\n- State Frequency Memory (SFM) is a novel recurrent network that uses the Discrete Fourier Transform (DFT) to decompose the hidden states of memory cells and capture multi-frequency trading patterns from past market data to make stock price predictions.\n- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. [http://www.eecs.ucf.edu/~gqi/publications/kdd2017_stock.pdf](http://www.eecs.ucf.edu/~gqi/publications/kdd2017_stock.pdf)"
  },
  {
    "path": "examples/benchmarks/SFM/requirements.txt",
    "content": "pandas==1.1.2\nnumpy==1.21.0\nscikit_learn==0.23.2\ntorch==1.7.0\n"
  },
  {
    "path": "examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: SFM\n        module_path: qlib.contrib.model.pytorch_sfm\n        kwargs:\n            d_feat: 6\n            hidden_size: 64\n            output_dim: 32\n            freq_dim: 25\n            dropout_W: 0.5\n            dropout_U: 0.5\n            n_epochs: 20\n            lr: 1e-3\n            batch_size: 1600\n            early_stop: 20\n            eval_steps: 5\n            loss: mse\n            optimizer: adam\n            GPU: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n               
 kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/Sandwich/README.md",
    "content": "# Sandwich\n* Code: [https://github.com/microsoft/FOST/blob/main/fostool/model/sandwich.py](https://github.com/microsoft/FOST/blob/main/fostool/model/sandwich.py)\n\n\n# Introduction to the settings/configs\n* `torch_geometric` is used in the original model in FOST, but we do not use it here.\n* Make sure your CUDA version matches your torch version so that the GPU can be used; we use CUDA==10.2 and torch==1.12.1.\n\n"
  },
  {
    "path": "examples/benchmarks/Sandwich/requirements.txt",
    "content": "numpy==1.23.4\npandas==1.5.2\n"
  },
  {
    "path": "examples/benchmarks/Sandwich/workflow_config_sandwich_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: Sandwich\n        module_path: qlib.contrib.model.pytorch_sandwich\n        kwargs:\n            fea_dim: 6\n            cnn_dim_1: 16\n            cnn_dim_2: 16\n            cnn_kernel_size: 3\n            rnn_dim_1: 8\n            rnn_dim_2: 8\n            rnn_dups: 2\n            rnn_layers: 2\n            n_epochs: 200\n            lr: 0.001\n            early_stop: 20\n            batch_size: 2000\n            metric: loss\n            GPU: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n    
            kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n\n"
  },
  {
    "path": "examples/benchmarks/TCN/README.md",
    "content": "# TCN\n* Code: [https://github.com/locuslab/TCN](https://github.com/locuslab/TCN)\n* Paper: [An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling](https://arxiv.org/abs/1803.01271).\n\n"
  },
  {
    "path": "examples/benchmarks/TCN/requirements.txt",
    "content": "numpy==1.21.0\npandas==1.1.2\nscikit_learn==0.23.2\ntorch==1.7.0\n"
  },
  {
    "path": "examples/benchmarks/TCN/workflow_config_tcn_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: FilterCol\n          kwargs:\n              fields_group: feature\n              col_list: [\"RESI5\", \"WVMA5\", \"RSQR5\", \"KLEN\", \"RSQR10\", \"CORR5\", \"CORD5\", \"CORR10\", \n                            \"ROC60\", \"RESI10\", \"VSTD5\", \"RSQR60\", \"CORR60\", \"WVMA60\", \"STD5\", \n                            \"RSQR20\", \"CORD60\", \"CORD10\", \"CORR20\", \"KLOW\"\n                        ]\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: TCN\n        module_path: qlib.contrib.model.pytorch_tcn_ts\n        kwargs:\n            d_feat: 20\n            num_layers: 5\n            n_chans: 32\n            kernel_size: 7\n            dropout: 0.5\n            n_epochs: 200\n    
        lr: 1e-4\n            early_stop: 20\n            batch_size: 2000\n            metric: loss\n            loss: mse\n            optimizer: adam\n            n_jobs: 20\n            GPU: 0\n    dataset:\n        class: TSDatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n            step_len: 20\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/TCN/workflow_config_tcn_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: TCN\n        module_path: qlib.contrib.model.pytorch_tcn\n        kwargs:\n            d_feat: 6\n            num_layers: 5\n            n_chans: 128\n            kernel_size: 3\n            dropout: 0.5\n            n_epochs: 200\n            lr: 1e-3\n            early_stop: 20\n            batch_size: 2000\n            metric: loss\n            loss: mse\n            optimizer: adam\n            GPU: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n 
           segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/TCTS/README.md",
    "content": "# Temporally Correlated Task Scheduling for Sequence Learning\n### Background\nSequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which differ in how much input information they use or which future step they predict. In stock trend forecasting, as demonstrated in Figure 1, one can predict the price of a stock on different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework that makes use of these temporally correlated tasks so that they help each other.\n\n### Method\nGiven that there are usually multiple temporally correlated tasks, the key challenge lies in deciding which tasks to use and when to use them during training. This work introduces a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in the current minibatch) and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. 
The process is demonstrated in Figure 2.\n\n<p align=\"center\"> \n<img src=\"workflow.png\"/>\n</p>\n\nAt step <img src=\"https://latex.codecogs.com/png.latex?s\" title=\"s\" />, with training data <img src=\"https://latex.codecogs.com/png.latex?x_s,y_s\" title=\"x_s,y_s\" />, the scheduler <img src=\"https://latex.codecogs.com/png.latex?\\varphi\" title=\"\\varphi\" /> chooses a suitable task <img src=\"https://latex.codecogs.com/png.latex?T_{i_s}\" title=\"T_{i_s}\" /> (green solid lines) to update the model <img src=\"https://latex.codecogs.com/png.latex?f\" title=\"f\" /> (blue solid lines). After <img src=\"https://latex.codecogs.com/png.latex?S\" title=\"S\" /> steps, we evaluate the model <img src=\"https://latex.codecogs.com/png.latex?f\" title=\"f\" /> on the validation set and update the scheduler <img src=\"https://latex.codecogs.com/png.latex?\\varphi\" title=\"\\varphi\" /> (green dashed lines).\n\n### Experiments\nBecause the data versions and Qlib versions differ, the original data and data preprocessing methods used in the paper's experimental settings differ from those in the current Qlib version. Therefore, we provide two versions of the code, one for each setting: 1) the [code](https://github.com/lwwang1995/tcts) that can be used to reproduce the experimental results in the paper, and 2) the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) in the current Qlib baseline.\n\n#### Setting1\n* Dataset: We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020. We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
\n\n* The main task <img src=\"https://latex.codecogs.com/png.latex?T_k\" title=\"T_k\" /> refers to forecasting the return of stock <img src=\"https://latex.codecogs.com/png.latex?i\" title=\"i\" /> as follows:\n<div align=center>\n<img src=\"https://latex.codecogs.com/png.image?\\dpi{110}&space;r_{i}^{t,k}&space;=&space;\\frac{price_i^{t&plus;k}}{price_i^{t&plus;k-1}}-1\" title=\"r_{i}^{t,k} = \\frac{price_i^{t+k}}{price_i^{t+k-1}}-1\" />\n</div>\n\n* Temporally correlated task sets: <img src=\"https://latex.codecogs.com/png.latex?\\mathcal{T}_k&space;=&space;\\{T_1,&space;T_2,&space;...&space;,&space;T_k\\}\" title=\"\\mathcal{T}_k = \\{T_1, T_2, ... , T_k\\}\" />. In this paper, <img src=\"https://latex.codecogs.com/png.latex?\\mathcal{T}_3\" title=\"\\mathcal{T}_3\" />, <img src=\"https://latex.codecogs.com/png.latex?\\mathcal{T}_5\" title=\"\\mathcal{T}_5\" />, and <img src=\"https://latex.codecogs.com/png.latex?\\mathcal{T}_{10}\" title=\"\\mathcal{T}_{10}\" /> are used in <img src=\"https://latex.codecogs.com/png.latex?T_1\" title=\"T_1\" />, <img src=\"https://latex.codecogs.com/png.latex?T_2\" title=\"T_2\" />, and <img src=\"https://latex.codecogs.com/png.latex?T_3\" title=\"T_3\" />, respectively.\n\n#### Setting2\n* Dataset: We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020. We split the data into training (01/01/2008-12/31/2014), validation (01/01/2015-12/31/2016), and test sets (01/01/2017-08/01/2020) based on the transaction time.
\n\n* The main task <img src=\"https://latex.codecogs.com/png.latex?T_k\" title=\"T_k\" /> refers to forecasting the return of stock <img src=\"https://latex.codecogs.com/png.latex?i\" title=\"i\" /> as follows:\n<div align=center>\n<img src=\"https://latex.codecogs.com/png.image?\\dpi{110}&space;r_{i}^{t,k}&space;=&space;\\frac{price_i^{t&plus;1&plus;k}}{price_i^{t&plus;1}}-1\" title=\"r_{i}^{t,k} = \\frac{price_i^{t+1+k}}{price_i^{t+1}}-1\" />\n</div>\n\n* In the Qlib baseline, <img src=\"https://latex.codecogs.com/png.latex?\\mathcal{T}_3\" title=\"\\mathcal{T}_3\" /> is used in <img src=\"https://latex.codecogs.com/png.latex?T_1\" title=\"T_1\" />.\n\n### Experimental Result\nYou can find the experimental results of Setting1 in the [paper](http://proceedings.mlr.press/v139/wu21e/wu21e.pdf) and those of Setting2 on this [page](https://github.com/microsoft/qlib/tree/main/examples/benchmarks)."
  },
  {
    "path": "examples/benchmarks/TCTS/requirements.txt",
    "content": "pandas==1.1.2\nnumpy==1.21.0\nscikit_learn==0.23.2\ntorch==1.7.0"
  },
  {
    "path": "examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\", \n            \"Ref($close, -3) / Ref($close, -1) - 1\", \n            \"Ref($close, -4) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: TCTS\n        module_path: qlib.contrib.model.pytorch_tcts\n        kwargs:\n            d_feat: 6\n            hidden_size: 64\n            num_layers: 2\n            dropout: 0.3\n            n_epochs: 200\n            early_stop: 20\n            batch_size: 800\n            metric: loss\n            loss: mse\n            GPU: 0\n            fore_optimizer: adam\n            weight_optimizer: adam\n            output_dim: 3\n            fore_lr: 2e-3\n            weight_lr: 2e-3\n            steps: 3\n            target_label: 0\n           
 lowest_valid_performance: 0.993\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/TFT/README.md",
    "content": "# Temporal Fusion Transformers Benchmark\r\n## Source\r\n**Reference**: Lim, Bryan, et al. \"Temporal fusion transformers for interpretable multi-horizon time series forecasting.\" arXiv preprint arXiv:1912.09363 (2019).\r\n\r\n**GitHub**: https://github.com/google-research/google-research/tree/master/tft\r\n\r\n## Run the Workflow\r\nUsers can follow ``workflow_by_code_tft.py`` to run the benchmark.\r\n\r\n### Notes\r\n1. Please be **aware** that this script only supports `Python 3.6 - 3.7`.\r\n2. If the CUDA version on your machine is not 10.0, please run `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.\r\n3. The model must run on a GPU; otherwise, an error will be raised.\r\n4. New datasets should be registered in ``data_formatters``; for details, please visit the source.\r\n"
  },
  {
    "path": "examples/benchmarks/TFT/data_formatters/__init__.py",
    "content": "# coding=utf-8\r\n# Copyright 2020 The Google Research Authors.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n"
  },
  {
    "path": "examples/benchmarks/TFT/data_formatters/base.py",
    "content": "# coding=utf-8\r\n# Copyright 2020 The Google Research Authors.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\n# Lint as: python3\r\n\"\"\"Default data formatting functions for experiments.\r\n\r\nFor new datasets, inherit from GenericDataFormatter and implement\r\nall abstract functions.\r\n\r\nThese dataset-specific methods:\r\n1) Define the column and input types for tabular dataframes used by the model\r\n2) Perform the necessary input feature engineering & normalisation steps\r\n3) Revert the normalisation for predictions\r\n4) Are responsible for train, validation and test splits\r\n\r\n\r\n\"\"\"\r\n\r\nimport abc\r\nimport enum\r\n\r\n\r\n# Type definitions\r\nclass DataTypes(enum.IntEnum):\r\n    \"\"\"Defines numerical types of each column.\"\"\"\r\n\r\n    REAL_VALUED = 0\r\n    CATEGORICAL = 1\r\n    DATE = 2\r\n\r\n\r\nclass InputTypes(enum.IntEnum):\r\n    \"\"\"Defines input types of each column.\"\"\"\r\n\r\n    TARGET = 0\r\n    OBSERVED_INPUT = 1\r\n    KNOWN_INPUT = 2\r\n    STATIC_INPUT = 3\r\n    ID = 4  # Single column used as an entity identifier\r\n    TIME = 5  # Single column exclusively used as a time index\r\n\r\n\r\nclass GenericDataFormatter(abc.ABC):\r\n    \"\"\"Abstract base class for all data formatters.\r\n\r\n    Users can implement the abstract methods below to perform dataset-specific\r\n    manipulations.\r\n\r\n    \"\"\"\r\n\r\n    @abc.abstractmethod\r\n    def 
set_scalers(self, df):\r\n        \"\"\"Calibrates scalers using the data supplied.\"\"\"\r\n        raise NotImplementedError()\r\n\r\n    @abc.abstractmethod\r\n    def transform_inputs(self, df):\r\n        \"\"\"Performs feature transformation.\"\"\"\r\n        raise NotImplementedError()\r\n\r\n    @abc.abstractmethod\r\n    def format_predictions(self, df):\r\n        \"\"\"Reverts any normalisation to give predictions in original scale.\"\"\"\r\n        raise NotImplementedError()\r\n\r\n    @abc.abstractmethod\r\n    def split_data(self, df):\r\n        \"\"\"Performs the default train, validation and test splits.\"\"\"\r\n        raise NotImplementedError()\r\n\r\n    @property\r\n    @abc.abstractmethod\r\n    def _column_definition(self):\r\n        \"\"\"Defines order, input type and data type of each column.\"\"\"\r\n        raise NotImplementedError()\r\n\r\n    @abc.abstractmethod\r\n    def get_fixed_params(self):\r\n        \"\"\"Defines the fixed parameters used by the model for training.\r\n\r\n        Requires the following keys:\r\n          'total_time_steps': Defines the total number of time steps used by TFT\r\n          'num_encoder_steps': Determines length of LSTM encoder (i.e. 
history)\r\n          'num_epochs': Maximum number of epochs for training\r\n          'early_stopping_patience': Early stopping param for keras\r\n          'multiprocessing_workers': # of cpus for data processing\r\n\r\n\r\n        Returns:\r\n          A dictionary of fixed parameters, e.g.:\r\n\r\n          fixed_params = {\r\n              'total_time_steps': 252 + 5,\r\n              'num_encoder_steps': 252,\r\n              'num_epochs': 100,\r\n              'early_stopping_patience': 5,\r\n              'multiprocessing_workers': 5,\r\n          }\r\n        \"\"\"\r\n        raise NotImplementedError\r\n\r\n    # Shared functions across data-formatters\r\n    @property\r\n    def num_classes_per_cat_input(self):\r\n        \"\"\"Returns number of categories per relevant input.\r\n\r\n        This is subsequently required for keras embedding layers.\r\n        \"\"\"\r\n        return self._num_classes_per_cat_input\r\n\r\n    def get_num_samples_for_calibration(self):\r\n        \"\"\"Gets the default number of training and validation samples.\r\n\r\n        Used to sub-sample the data for network calibration; a value of -1 uses\r\n        all available samples.\r\n\r\n        Returns:\r\n          Tuple of (training samples, validation samples)\r\n        \"\"\"\r\n        return -1, -1\r\n\r\n    def get_column_definition(self):\r\n        \"\"\"Returns formatted column definition in order expected by the TFT.\"\"\"\r\n\r\n        column_definition = self._column_definition\r\n\r\n        # Sanity checks first.\r\n        # Ensure only one ID and time column exist\r\n        def _check_single_column(input_type):\r\n            length = len([tup for tup in column_definition if tup[2] == input_type])\r\n\r\n            if length != 1:\r\n                raise ValueError(\"Illegal number of inputs ({}) of type {}\".format(length, input_type))\r\n\r\n        _check_single_column(InputTypes.ID)\r\n        _check_single_column(InputTypes.TIME)\r\n\r\n        
identifier = [tup for tup in column_definition if tup[2] == InputTypes.ID]\r\n        time = [tup for tup in column_definition if tup[2] == InputTypes.TIME]\r\n        real_inputs = [\r\n            tup\r\n            for tup in column_definition\r\n            if tup[1] == DataTypes.REAL_VALUED and tup[2] not in {InputTypes.ID, InputTypes.TIME}\r\n        ]\r\n        categorical_inputs = [\r\n            tup\r\n            for tup in column_definition\r\n            if tup[1] == DataTypes.CATEGORICAL and tup[2] not in {InputTypes.ID, InputTypes.TIME}\r\n        ]\r\n\r\n        return identifier + time + real_inputs + categorical_inputs\r\n\r\n    def _get_input_columns(self):\r\n        \"\"\"Returns names of all input columns.\"\"\"\r\n        return [tup[0] for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}]\r\n\r\n    def _get_tft_input_indices(self):\r\n        \"\"\"Returns the relevant indexes and input sizes required by TFT.\"\"\"\r\n\r\n        # Functions\r\n        def _extract_tuples_from_data_type(data_type, defn):\r\n            return [tup for tup in defn if tup[1] == data_type and tup[2] not in {InputTypes.ID, InputTypes.TIME}]\r\n\r\n        def _get_locations(input_types, defn):\r\n            return [i for i, tup in enumerate(defn) if tup[2] in input_types]\r\n\r\n        # Start extraction\r\n        column_definition = [\r\n            tup for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}\r\n        ]\r\n\r\n        categorical_inputs = _extract_tuples_from_data_type(DataTypes.CATEGORICAL, column_definition)\r\n        real_inputs = _extract_tuples_from_data_type(DataTypes.REAL_VALUED, column_definition)\r\n\r\n        locations = {\r\n            \"input_size\": len(self._get_input_columns()),\r\n            \"output_size\": len(_get_locations({InputTypes.TARGET}, column_definition)),\r\n            \"category_counts\": self.num_classes_per_cat_input,\r\n          
  \"input_obs_loc\": _get_locations({InputTypes.TARGET}, column_definition),\r\n            \"static_input_loc\": _get_locations({InputTypes.STATIC_INPUT}, column_definition),\r\n            \"known_regular_inputs\": _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, real_inputs),\r\n            \"known_categorical_inputs\": _get_locations(\r\n                {InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, categorical_inputs\r\n            ),\r\n        }\r\n\r\n        return locations\r\n\r\n    def get_experiment_params(self):\r\n        \"\"\"Returns fixed model parameters for experiments.\"\"\"\r\n\r\n        required_keys = [\r\n            \"total_time_steps\",\r\n            \"num_encoder_steps\",\r\n            \"num_epochs\",\r\n            \"early_stopping_patience\",\r\n            \"multiprocessing_workers\",\r\n        ]\r\n\r\n        fixed_params = self.get_fixed_params()\r\n\r\n        for k in required_keys:\r\n            if k not in fixed_params:\r\n                raise ValueError(\"Field {}\".format(k) + \" missing from fixed parameter definitions!\")\r\n\r\n        fixed_params[\"column_definition\"] = self.get_column_definition()\r\n\r\n        fixed_params.update(self._get_tft_input_indices())\r\n\r\n        return fixed_params\r\n"
  },
  {
    "path": "examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Custom formatting functions for Alpha158 dataset.\n\nDefines dataset specific column definitions and data transformations.\n\"\"\"\n\nimport data_formatters.base\nimport libs.utils as utils\nimport sklearn.preprocessing\n\nGenericDataFormatter = data_formatters.base.GenericDataFormatter\nDataTypes = data_formatters.base.DataTypes\nInputTypes = data_formatters.base.InputTypes\n\n\nclass Alpha158Formatter(GenericDataFormatter):\n    \"\"\"Defines and formats data for the Alpha158 dataset.\n\n    Attributes:\n      column_definition: Defines input and data type of column used in the\n        experiment.\n      identifiers: Entity identifiers used in experiments.\n    \"\"\"\n\n    _column_definition = [\n        (\"instrument\", DataTypes.CATEGORICAL, InputTypes.ID),\n        (\"LABEL0\", DataTypes.REAL_VALUED, InputTypes.TARGET),\n        (\"date\", DataTypes.DATE, InputTypes.TIME),\n        (\"month\", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),\n        (\"day_of_week\", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),\n        # Selected features\n        (\"RESI5\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"WVMA5\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"RSQR5\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"KLEN\", DataTypes.REAL_VALUED, 
InputTypes.OBSERVED_INPUT),\n        (\"RSQR10\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"CORR5\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"CORD5\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"CORR10\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"ROC60\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"RESI10\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"VSTD5\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"RSQR60\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"CORR60\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"WVMA60\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"STD5\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"RSQR20\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"CORD60\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"CORD10\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"CORR20\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"KLOW\", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),\n        (\"const\", DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),\n    ]\n\n    def __init__(self):\n        \"\"\"Initialises formatter.\"\"\"\n\n        self.identifiers = None\n        self._real_scalers = None\n        self._cat_scalers = None\n        self._target_scaler = None\n        self._num_classes_per_cat_input = None\n\n    def split_data(self, df, valid_boundary=2016, test_boundary=2018):\n        \"\"\"Splits data frame into training-validation-test data frames.\n\n        This also calibrates scaling object, and transforms data for each split.\n\n        Args:\n          df: Source data frame to split.\n          valid_boundary: Starting year for validation data\n          test_boundary: Starting year for test data\n\n        Returns:\n          Tuple of transformed (train, valid, 
test) data.\n        \"\"\"\n\n        print(\"Formatting train-valid-test splits.\")\n\n        index = df[\"year\"]\n        train = df.loc[index < valid_boundary]\n        valid = df.loc[(index >= valid_boundary) & (index < test_boundary)]\n        test = df.loc[index >= test_boundary]\n\n        self.set_scalers(train)\n\n        return tuple(self.transform_inputs(data) for data in [train, valid, test])\n\n    def set_scalers(self, df):\n        \"\"\"Calibrates scalers using the data supplied.\n\n        Args:\n          df: Data to use to calibrate scalers.\n        \"\"\"\n        print(\"Setting scalers with training data...\")\n\n        column_definitions = self.get_column_definition()\n        id_column = utils.get_single_col_by_input_type(InputTypes.ID, column_definitions)\n        target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, column_definitions)\n\n        # Extract identifiers in case required\n        self.identifiers = list(df[id_column].unique())\n\n        # Format real scalers\n        real_inputs = utils.extract_cols_from_data_type(\n            DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}\n        )\n\n        data = df[real_inputs].values\n        self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data)\n        self._target_scaler = sklearn.preprocessing.StandardScaler().fit(\n            df[[target_column]].values\n        )  # used for predictions\n\n        # Format categorical scalers\n        categorical_inputs = utils.extract_cols_from_data_type(\n            DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}\n        )\n\n        categorical_scalers = {}\n        num_classes = []\n        for col in categorical_inputs:\n            # Set all to str so that we don't have mixed integer/string columns\n            srs = df[col].apply(str)\n            categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(srs.values)\n            
num_classes.append(srs.nunique())\n\n        # Set categorical scaler outputs\n        self._cat_scalers = categorical_scalers\n        self._num_classes_per_cat_input = num_classes\n\n    def transform_inputs(self, df):\n        \"\"\"Performs feature transformations.\n\n        This includes feature engineering, preprocessing and normalisation.\n\n        Args:\n          df: Data frame to transform.\n\n        Returns:\n          Transformed data frame.\n\n        \"\"\"\n        output = df.copy()\n\n        if self._real_scalers is None or self._cat_scalers is None:\n            raise ValueError(\"Scalers have not been set!\")\n\n        column_definitions = self.get_column_definition()\n\n        real_inputs = utils.extract_cols_from_data_type(\n            DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}\n        )\n        categorical_inputs = utils.extract_cols_from_data_type(\n            DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}\n        )\n\n        # Format real inputs\n        output[real_inputs] = self._real_scalers.transform(df[real_inputs].values)\n\n        # Format categorical inputs\n        for col in categorical_inputs:\n            string_df = df[col].apply(str)\n            output[col] = self._cat_scalers[col].transform(string_df)\n\n        return output\n\n    def format_predictions(self, predictions):\n        \"\"\"Reverts any normalisation to give predictions in original scale.\n\n        Args:\n          predictions: Dataframe of model predictions.\n\n        Returns:\n          Data frame of unnormalised predictions.\n        \"\"\"\n        output = predictions.copy()\n\n        column_names = predictions.columns\n\n        for col in column_names:\n            if col not in {\"forecast_time\", \"identifier\"}:\n                # Using [col] is for aligning with the format when fitting\n                output[col] = 
self._target_scaler.inverse_transform(predictions[[col]])\n\n        return output\n\n    # Default params\n    def get_fixed_params(self):\n        \"\"\"Returns fixed model parameters for experiments.\"\"\"\n\n        fixed_params = {\n            \"total_time_steps\": 6 + 6,\n            \"num_encoder_steps\": 6,\n            \"num_epochs\": 100,\n            \"early_stopping_patience\": 10,\n            \"multiprocessing_workers\": 5,\n        }\n\n        return fixed_params\n\n    def get_default_model_params(self):\n        \"\"\"Returns default optimised model parameters.\"\"\"\n\n        model_params = {\n            \"dropout_rate\": 0.4,\n            \"hidden_layer_size\": 160,\n            \"learning_rate\": 0.0001,\n            \"minibatch_size\": 128,\n            \"max_gradient_norm\": 0.0135,\n            \"num_heads\": 1,\n            \"stack_size\": 1,\n        }\n\n        return model_params\n"
  },
  {
    "path": "examples/benchmarks/TFT/expt_settings/__init__.py",
    "content": "# coding=utf-8\r\n# Copyright 2020 The Google Research Authors.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n"
  },
  {
    "path": "examples/benchmarks/TFT/expt_settings/configs.py",
    "content": "# coding=utf-8\r\n# Copyright 2020 The Google Research Authors.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\n# Lint as: python3\r\n\"\"\"Default configs for TFT experiments.\r\n\r\nContains the default output paths for data, serialised models and predictions\r\nfor the main experiments used in the publication.\r\n\"\"\"\r\n\r\nimport os\r\n\r\nimport data_formatters.qlib_Alpha158\r\n\r\n\r\nclass ExperimentConfig:\r\n    \"\"\"Defines experiment configs and paths to outputs.\r\n\r\n    Attributes:\r\n      root_folder: Root folder to contain all experimental outputs.\r\n      experiment: Name of experiment to run.\r\n      data_folder: Folder to store data for experiment.\r\n      model_folder: Folder to store serialised models.\r\n      results_folder: Folder to store results.\r\n      data_csv_path: Path to primary data csv file used in experiment.\r\n      hyperparam_iterations: Default number of random search iterations for\r\n        experiment.\r\n    \"\"\"\r\n\r\n    default_experiments = [\"Alpha158\"]\r\n\r\n    def __init__(self, experiment=\"volatility\", root_folder=None):\r\n        \"\"\"Creates configs based on default experiment chosen.\r\n\r\n        Args:\r\n          experiment: Name of experiment.\r\n          root_folder: Root folder to save all outputs of training.\r\n        \"\"\"\r\n\r\n        if experiment not in self.default_experiments:\r\n            raise ValueError(\"Unrecognised 
experiment={}\".format(experiment))\r\n\r\n        # Defines all relevant paths\r\n        if root_folder is None:\r\n            root_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), \"..\", \"outputs\")\r\n            print(\"Using root folder {}\".format(root_folder))\r\n\r\n        self.root_folder = root_folder\r\n        self.experiment = experiment\r\n        self.data_folder = os.path.join(root_folder, \"data\", experiment)\r\n        self.model_folder = os.path.join(root_folder, \"saved_models\", experiment)\r\n        self.results_folder = os.path.join(root_folder, \"results\", experiment)\r\n\r\n        # Creates folders if they don't exist\r\n        for relevant_directory in [self.root_folder, self.data_folder, self.model_folder, self.results_folder]:\r\n            if not os.path.exists(relevant_directory):\r\n                os.makedirs(relevant_directory)\r\n\r\n    @property\r\n    def data_csv_path(self):\r\n        csv_map = {\r\n            \"Alpha158\": \"Alpha158.csv\",\r\n        }\r\n\r\n        return os.path.join(self.data_folder, csv_map[self.experiment])\r\n\r\n    @property\r\n    def hyperparam_iterations(self):\r\n        return 240 if self.experiment == \"volatility\" else 60\r\n\r\n    def make_data_formatter(self):\r\n        \"\"\"Gets a data formatter object for experiment.\r\n\r\n        Returns:\r\n          Default DataFormatter per experiment.\r\n        \"\"\"\r\n\r\n        data_formatter_class = {\r\n            \"Alpha158\": data_formatters.qlib_Alpha158.Alpha158Formatter,\r\n        }\r\n\r\n        return data_formatter_class[self.experiment]()\r\n"
  },
  {
    "path": "examples/benchmarks/TFT/libs/__init__.py",
    "content": "# coding=utf-8\r\n# Copyright 2020 The Google Research Authors.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n"
  },
  {
    "path": "examples/benchmarks/TFT/libs/hyperparam_opt.py",
    "content": "# coding=utf-8\r\n# Copyright 2020 The Google Research Authors.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\n# Lint as: python3\r\n\"\"\"Classes used for hyperparameter optimisation.\r\n\r\nTwo main classes exist:\r\n1) HyperparamOptManager used for optimisation on a single machine/GPU.\r\n2) DistributedHyperparamOptManager for multiple GPUs on different machines.\r\n\"\"\"\r\n\r\nfrom __future__ import absolute_import\r\nfrom __future__ import division\r\nfrom __future__ import print_function\r\n\r\nimport collections\r\nimport os\r\nimport shutil\r\nimport libs.utils as utils\r\nimport numpy as np\r\nimport pandas as pd\r\n\r\nDeque = collections.deque\r\n\r\n\r\nclass HyperparamOptManager:\r\n    \"\"\"Manages hyperparameter optimisation using random search for a single GPU.\r\n\r\n    Attributes:\r\n      param_ranges: Discrete hyperparameter range for random search.\r\n      results: Dataframe of validation results.\r\n      fixed_params: Fixed model parameters per experiment.\r\n      saved_params: Dataframe of parameters trained.\r\n      best_score: Minimum validation loss observed thus far.\r\n      optimal_name: Key to best configuration.\r\n      hyperparam_folder: Where to save optimisation outputs.\r\n    \"\"\"\r\n\r\n    def __init__(self, param_ranges, fixed_params, model_folder, override_w_fixed_params=True):\r\n        \"\"\"Instantiates model.\r\n\r\n        Args:\r\n          param_ranges: 
Discrete hyperparameter range for random search.\r\n          fixed_params: Fixed model parameters per experiment.\r\n          model_folder: Folder to store optimisation artifacts.\r\n          override_w_fixed_params: Whether to override serialised fixed model\r\n            parameters with new supplied values.\r\n        \"\"\"\r\n\r\n        self.param_ranges = param_ranges\r\n\r\n        self._max_tries = 1000\r\n        self.results = pd.DataFrame()\r\n        self.fixed_params = fixed_params\r\n        self.saved_params = pd.DataFrame()\r\n\r\n        self.best_score = np.inf\r\n        self.optimal_name = \"\"\r\n\r\n        # Setup\r\n        # Create folder for saving if it does not exist\r\n        self.hyperparam_folder = model_folder\r\n        utils.create_folder_if_not_exist(self.hyperparam_folder)\r\n\r\n        self._override_w_fixed_params = override_w_fixed_params\r\n\r\n    def load_results(self):\r\n        \"\"\"Loads results from previous hyperparameter optimisation.\r\n\r\n        Returns:\r\n          A boolean indicating if previous results can be loaded.\r\n        \"\"\"\r\n        print(\"Loading results from\", self.hyperparam_folder)\r\n\r\n        results_file = os.path.join(self.hyperparam_folder, \"results.csv\")\r\n        params_file = os.path.join(self.hyperparam_folder, \"params.csv\")\r\n\r\n        if os.path.exists(results_file) and os.path.exists(params_file):\r\n            self.results = pd.read_csv(results_file, index_col=0)\r\n            self.saved_params = pd.read_csv(params_file, index_col=0)\r\n\r\n            if not self.results.empty:\r\n                self.results.loc[\"loss\"] = self.results.loc[\"loss\"].apply(float)\r\n                self.best_score = self.results.loc[\"loss\"].min()\r\n\r\n                is_optimal = self.results.loc[\"loss\"] == self.best_score\r\n                self.optimal_name = self.results.T[is_optimal].index[0]\r\n\r\n                return True\r\n\r\n        return False\r\n\r\n    def 
_get_params_from_name(self, name):\r\n        \"\"\"Returns previously saved parameters given a key.\"\"\"\r\n        params = self.saved_params\r\n\r\n        selected_params = dict(params[name])\r\n\r\n        if self._override_w_fixed_params:\r\n            for k in self.fixed_params:\r\n                selected_params[k] = self.fixed_params[k]\r\n\r\n        return selected_params\r\n\r\n    def get_best_params(self):\r\n        \"\"\"Returns the optimal hyperparameters thus far.\"\"\"\r\n\r\n        optimal_name = self.optimal_name\r\n\r\n        return self._get_params_from_name(optimal_name)\r\n\r\n    def clear(self):\r\n        \"\"\"Clears all previous results and saved parameters.\"\"\"\r\n        shutil.rmtree(self.hyperparam_folder)\r\n        os.makedirs(self.hyperparam_folder)\r\n        self.results = pd.DataFrame()\r\n        self.saved_params = pd.DataFrame()\r\n\r\n    def _check_params(self, params):\r\n        \"\"\"Checks that parameter map is properly defined.\"\"\"\r\n\r\n        valid_fields = list(self.param_ranges.keys()) + list(self.fixed_params.keys())\r\n        invalid_fields = [k for k in params if k not in valid_fields]\r\n        missing_fields = [k for k in valid_fields if k not in params]\r\n\r\n        if invalid_fields:\r\n            raise ValueError(\"Invalid Fields Found {} - Valid ones are {}\".format(invalid_fields, valid_fields))\r\n        if missing_fields:\r\n            raise ValueError(\"Missing Fields Found {} - Valid ones are {}\".format(missing_fields, valid_fields))\r\n\r\n    def _get_name(self, params):\r\n        \"\"\"Returns a unique key for the supplied set of params.\"\"\"\r\n\r\n        self._check_params(params)\r\n\r\n        fields = list(params.keys())\r\n        fields.sort()\r\n\r\n        return \"_\".join([str(params[k]) for k in fields])\r\n\r\n    def get_next_parameters(self, ranges_to_skip=None):\r\n        \"\"\"Returns the next set of parameters to optimise.\r\n\r\n        Args:\r\n          
ranges_to_skip: Explicitly defines a set of keys to skip.\r\n        \"\"\"\r\n        if ranges_to_skip is None:\r\n            ranges_to_skip = set(self.results.columns)\r\n\r\n        if not isinstance(self.param_ranges, dict):\r\n            raise ValueError(\"Only works for random search!\")\r\n\r\n        param_range_keys = list(self.param_ranges.keys())\r\n        param_range_keys.sort()\r\n\r\n        def _get_next():\r\n            \"\"\"Returns next hyperparameter set per try.\"\"\"\r\n\r\n            parameters = {k: np.random.choice(self.param_ranges[k]) for k in param_range_keys}\r\n\r\n            # Adds fixed params\r\n            for k in self.fixed_params:\r\n                parameters[k] = self.fixed_params[k]\r\n\r\n            return parameters\r\n\r\n        for _ in range(self._max_tries):\r\n            parameters = _get_next()\r\n            name = self._get_name(parameters)\r\n\r\n            if name not in ranges_to_skip:\r\n                return parameters\r\n\r\n        raise ValueError(\"Exceeded max number of hyperparameter searches!!\")\r\n\r\n    def update_score(self, parameters, loss, model, info=\"\"):\r\n        \"\"\"Updates the results from last optimisation run.\r\n\r\n        Args:\r\n          parameters: Hyperparameters used in optimisation.\r\n          loss: Validation loss obtained.\r\n          model: Model to serialise if required.\r\n          info: Any ancillary information to tag on to results.\r\n\r\n        Returns:\r\n          Boolean flag indicating if the model is the best seen so far.\r\n        \"\"\"\r\n\r\n        if np.isnan(loss):\r\n            loss = np.inf\r\n\r\n        if not os.path.isdir(self.hyperparam_folder):\r\n            os.makedirs(self.hyperparam_folder)\r\n\r\n        name = self._get_name(parameters)\r\n\r\n        is_optimal = self.results.empty or loss < self.best_score\r\n\r\n        # save the first model\r\n        if is_optimal:\r\n            # Try saving first, before updating 
info\r\n            if model is not None:\r\n                print(\"Optimal model found, updating\")\r\n                model.save(self.hyperparam_folder)\r\n            self.best_score = loss\r\n            self.optimal_name = name\r\n\r\n        self.results[name] = pd.Series({\"loss\": loss, \"info\": info})\r\n        self.saved_params[name] = pd.Series(parameters)\r\n\r\n        self.results.to_csv(os.path.join(self.hyperparam_folder, \"results.csv\"))\r\n        self.saved_params.to_csv(os.path.join(self.hyperparam_folder, \"params.csv\"))\r\n\r\n        return is_optimal\r\n\r\n\r\nclass DistributedHyperparamOptManager(HyperparamOptManager):\r\n    \"\"\"Manages distributed hyperparameter optimisation across many gpus.\"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n        param_ranges,\r\n        fixed_params,\r\n        root_model_folder,\r\n        worker_number,\r\n        search_iterations=1000,\r\n        num_iterations_per_worker=5,\r\n        clear_serialised_params=False,\r\n    ):\r\n        \"\"\"Instantiates optimisation manager.\r\n\r\n        This hyperparameter optimisation pre-generates #search_iterations\r\n        hyperparameter combinations and serialises them\r\n        at the start. At runtime, each worker goes through their own set of\r\n        parameter ranges. 
The pregeneration\r\n        allows for multiple workers to run in parallel on different machines without\r\n        resulting in parameter overlaps.\r\n\r\n        Args:\r\n          param_ranges: Discrete hyperparameter range for random search.\r\n          fixed_params: Fixed model parameters per experiment.\r\n          root_model_folder: Folder to store optimisation artifacts.\r\n          worker_number: Worker index defining which set of hyperparameters to\r\n            test.\r\n          search_iterations: Maximum number of random search iterations.\r\n          num_iterations_per_worker: How many iterations are handled per worker.\r\n          clear_serialised_params: Whether to regenerate hyperparameter\r\n            combinations.\r\n        \"\"\"\r\n\r\n        max_workers = int(np.ceil(search_iterations / num_iterations_per_worker))\r\n\r\n        # Sanity checks\r\n        if worker_number > max_workers:\r\n            raise ValueError(\r\n                \"Worker number ({}) cannot be larger than the total number of workers ({})!\".format(\r\n                    worker_number, max_workers\r\n                )\r\n            )\r\n        if worker_number > search_iterations:\r\n            raise ValueError(\r\n                \"Worker number ({}) cannot be larger than the max search iterations ({})!\".format(\r\n                    worker_number, search_iterations\r\n                )\r\n            )\r\n\r\n        print(\"*** Creating hyperparameter manager for worker {} ***\".format(worker_number))\r\n\r\n        hyperparam_folder = os.path.join(root_model_folder, str(worker_number))\r\n        super().__init__(param_ranges, fixed_params, hyperparam_folder, override_w_fixed_params=True)\r\n\r\n        serialised_ranges_folder = os.path.join(root_model_folder, \"hyperparams\")\r\n        if clear_serialised_params:\r\n            print(\"Regenerating hyperparameter list\")\r\n            if os.path.exists(serialised_ranges_folder):\r\n                shutil.rmtree(serialised_ranges_folder)\r\n\r\n        
utils.create_folder_if_not_exist(serialised_ranges_folder)\r\n\r\n        self.serialised_ranges_path = os.path.join(serialised_ranges_folder, \"ranges_{}.csv\".format(search_iterations))\r\n        self.hyperparam_folder = hyperparam_folder  # override\r\n        self.worker_num = worker_number\r\n        self.total_search_iterations = search_iterations\r\n        self.num_iterations_per_worker = num_iterations_per_worker\r\n        self.global_hyperparam_df = self.load_serialised_hyperparam_df()\r\n        self.worker_search_queue = self._get_worker_search_queue()\r\n\r\n    @property\r\n    def optimisation_completed(self):\r\n        return not self.worker_search_queue\r\n\r\n    def get_next_parameters(self):\r\n        \"\"\"Returns next dictionary of hyperparameters to optimise.\"\"\"\r\n        param_name = self.worker_search_queue.pop()\r\n\r\n        params = self.global_hyperparam_df.loc[param_name, :].to_dict()\r\n\r\n        # Always override!\r\n        for k in self.fixed_params:\r\n            print(\"Overriding saved {}: {}\".format(k, self.fixed_params[k]))\r\n\r\n            params[k] = self.fixed_params[k]\r\n\r\n        return params\r\n\r\n    def load_serialised_hyperparam_df(self):\r\n        \"\"\"Loads serialised hyperparameter ranges from file.\r\n\r\n        Returns:\r\n          DataFrame containing hyperparameter combinations.\r\n        \"\"\"\r\n        print(\r\n            \"Loading params for {} search iterations from {}\".format(\r\n                self.total_search_iterations, self.serialised_ranges_path\r\n            )\r\n        )\r\n\r\n        if os.path.exists(self.serialised_ranges_path):\r\n            df = pd.read_csv(self.serialised_ranges_path, index_col=0)\r\n        else:\r\n            print(\"Unable to load - regenerating search ranges instead\")\r\n            df = self.update_serialised_hyperparam_df()\r\n\r\n        return df\r\n\r\n    def update_serialised_hyperparam_df(self):\r\n        
\"\"\"Regenerates hyperparameter combinations and saves to file.\r\n\r\n        Returns:\r\n          DataFrame containing hyperparameter combinations.\r\n        \"\"\"\r\n        search_df = self._generate_full_hyperparam_df()\r\n\r\n        print(\r\n            \"Serialising params for {} search iterations to {}\".format(\r\n                self.total_search_iterations, self.serialised_ranges_path\r\n            )\r\n        )\r\n\r\n        search_df.to_csv(self.serialised_ranges_path)\r\n\r\n        return search_df\r\n\r\n    def _generate_full_hyperparam_df(self):\r\n        \"\"\"Generates actual hyperparameter combinations.\r\n\r\n        Returns:\r\n          DataFrame containing hyperparameter combinations.\r\n        \"\"\"\r\n\r\n        np.random.seed(131)  # for reproducibility of hyperparam list\r\n\r\n        name_list = []\r\n        param_list = []\r\n        for _ in range(self.total_search_iterations):\r\n            params = super().get_next_parameters(name_list)\r\n\r\n            name = self._get_name(params)\r\n\r\n            name_list.append(name)\r\n            param_list.append(params)\r\n\r\n        full_search_df = pd.DataFrame(param_list, index=name_list)\r\n\r\n        return full_search_df\r\n\r\n    def clear(self):  # reset when cleared\r\n        \"\"\"Clears results for hyperparameter manager and resets.\"\"\"\r\n        super().clear()\r\n        self.worker_search_queue = self._get_worker_search_queue()\r\n\r\n    def load_results(self):\r\n        \"\"\"Load results from file and queue parameter combinations to try.\r\n\r\n        Returns:\r\n          Boolean indicating if results were successfully loaded.\r\n        \"\"\"\r\n        success = super().load_results()\r\n\r\n        if success:\r\n            self.worker_search_queue = self._get_worker_search_queue()\r\n\r\n        return success\r\n\r\n    def _get_worker_search_queue(self):\r\n        \"\"\"Generates the queue of param combinations for current 
worker.\r\n\r\n        Returns:\r\n          Queue of hyperparameter combinations outstanding.\r\n        \"\"\"\r\n        global_df = self.assign_worker_numbers(self.global_hyperparam_df)\r\n        worker_df = global_df[global_df[\"worker\"] == self.worker_num]\r\n\r\n        left_overs = [s for s in worker_df.index if s not in self.results.columns]\r\n\r\n        return Deque(left_overs)\r\n\r\n    def assign_worker_numbers(self, df):\r\n        \"\"\"Updates parameter combinations with the index of the worker used.\r\n\r\n        Args:\r\n          df: DataFrame of parameter combinations.\r\n\r\n        Returns:\r\n          Updated DataFrame with worker number.\r\n        \"\"\"\r\n        output = df.copy()\r\n\r\n        n = self.total_search_iterations\r\n        batch_size = self.num_iterations_per_worker\r\n\r\n        max_worker_num = int(np.ceil(n / batch_size))\r\n\r\n        worker_idx = np.concatenate([np.tile(i + 1, self.num_iterations_per_worker) for i in range(max_worker_num)])\r\n\r\n        output[\"worker\"] = worker_idx[: len(output)]\r\n\r\n        return output\r\n"
  },
  {
    "path": "examples/benchmarks/TFT/libs/tft_model.py",
    "content": "# coding=utf-8\r\n# Copyright 2020 The Google Research Authors.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\n# Lint as: python3\r\n\"\"\"Temporal Fusion Transformer Model.\r\n\r\nContains the full TFT architecture and associated components. Defines functions\r\nfor training, evaluation and prediction using simple Pandas Dataframe inputs.\r\n\"\"\"\r\n\r\nfrom __future__ import absolute_import\r\nfrom __future__ import division\r\nfrom __future__ import print_function\r\n\r\nimport gc\r\nimport json\r\nimport os\r\nimport shutil\r\n\r\nimport data_formatters.base\r\nimport libs.utils as utils\r\nimport numpy as np\r\nimport pandas as pd\r\nimport tensorflow as tf\r\n\r\n# Layer definitions.\r\nconcat = tf.keras.backend.concatenate\r\nstack = tf.keras.backend.stack\r\nK = tf.keras.backend\r\nAdd = tf.keras.layers.Add\r\nLayerNorm = tf.keras.layers.LayerNormalization\r\nDense = tf.keras.layers.Dense\r\nMultiply = tf.keras.layers.Multiply\r\nDropout = tf.keras.layers.Dropout\r\nActivation = tf.keras.layers.Activation\r\nLambda = tf.keras.layers.Lambda\r\n\r\n# Default input types.\r\nInputTypes = data_formatters.base.InputTypes\r\n\r\n\r\n# Layer utility functions.\r\ndef linear_layer(size, activation=None, use_time_distributed=False, use_bias=True):\r\n    \"\"\"Returns simple Keras linear layer.\r\n\r\n    Args:\r\n      size: Output size\r\n      activation: Activation function to apply if required\r\n      
use_time_distributed: Whether to apply layer across time\r\n      use_bias: Whether bias should be included in layer\r\n    \"\"\"\r\n    linear = tf.keras.layers.Dense(size, activation=activation, use_bias=use_bias)\r\n    if use_time_distributed:\r\n        linear = tf.keras.layers.TimeDistributed(linear)\r\n    return linear\r\n\r\n\r\ndef apply_mlp(\r\n    inputs, hidden_size, output_size, output_activation=None, hidden_activation=\"tanh\", use_time_distributed=False\r\n):\r\n    \"\"\"Applies simple feed-forward network to an input.\r\n\r\n    Args:\r\n      inputs: MLP inputs\r\n      hidden_size: Hidden state size\r\n      output_size: Output size of MLP\r\n      output_activation: Activation function to apply on output\r\n      hidden_activation: Activation function to apply on input\r\n      use_time_distributed: Whether to apply across time\r\n\r\n    Returns:\r\n      Tensor for MLP outputs.\r\n    \"\"\"\r\n    if use_time_distributed:\r\n        hidden = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_size, activation=hidden_activation))(\r\n            inputs\r\n        )\r\n        return tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(output_size, activation=output_activation))(hidden)\r\n    else:\r\n        hidden = tf.keras.layers.Dense(hidden_size, activation=hidden_activation)(inputs)\r\n        return tf.keras.layers.Dense(output_size, activation=output_activation)(hidden)\r\n\r\n\r\ndef apply_gating_layer(x, hidden_layer_size, dropout_rate=None, use_time_distributed=True, activation=None):\r\n    \"\"\"Applies a Gated Linear Unit (GLU) to an input.\r\n\r\n    Args:\r\n      x: Input to gating layer\r\n      hidden_layer_size: Dimension of GLU\r\n      dropout_rate: Dropout rate to apply if any\r\n      use_time_distributed: Whether to apply across time\r\n      activation: Activation function to apply to the linear feature transform if\r\n        necessary\r\n\r\n    Returns:\r\n      Tuple of tensors for: (GLU output, 
gate)\r\n    \"\"\"\r\n\r\n    if dropout_rate is not None:\r\n        x = tf.keras.layers.Dropout(dropout_rate)(x)\r\n\r\n    if use_time_distributed:\r\n        activation_layer = tf.keras.layers.TimeDistributed(\r\n            tf.keras.layers.Dense(hidden_layer_size, activation=activation)\r\n        )(x)\r\n        gated_layer = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_layer_size, activation=\"sigmoid\"))(x)\r\n    else:\r\n        activation_layer = tf.keras.layers.Dense(hidden_layer_size, activation=activation)(x)\r\n        gated_layer = tf.keras.layers.Dense(hidden_layer_size, activation=\"sigmoid\")(x)\r\n\r\n    return tf.keras.layers.Multiply()([activation_layer, gated_layer]), gated_layer\r\n\r\n\r\ndef add_and_norm(x_list):\r\n    \"\"\"Applies skip connection followed by layer normalisation.\r\n\r\n    Args:\r\n      x_list: List of inputs to sum for skip connection\r\n\r\n    Returns:\r\n      Tensor output from layer.\r\n    \"\"\"\r\n    tmp = Add()(x_list)\r\n    tmp = LayerNorm()(tmp)\r\n    return tmp\r\n\r\n\r\ndef gated_residual_network(\r\n    x,\r\n    hidden_layer_size,\r\n    output_size=None,\r\n    dropout_rate=None,\r\n    use_time_distributed=True,\r\n    additional_context=None,\r\n    return_gate=False,\r\n):\r\n    \"\"\"Applies the gated residual network (GRN) as defined in the paper.\r\n\r\n    Args:\r\n      x: Network inputs\r\n      hidden_layer_size: Internal state size\r\n      output_size: Size of output layer\r\n      dropout_rate: Dropout rate if dropout is applied\r\n      use_time_distributed: Whether to apply network across time dimension\r\n      additional_context: Additional context vector to use if relevant\r\n      return_gate: Whether to return GLU gate for diagnostic purposes\r\n\r\n    Returns:\r\n      GRN output, or a tuple of (GRN output, GLU gate) if return_gate is True\r\n    \"\"\"\r\n\r\n    # Setup skip connection\r\n    if output_size is None:\r\n        output_size = hidden_layer_size\r\n        skip = x\r\n    
else:\r\n        linear = Dense(output_size)\r\n        if use_time_distributed:\r\n            linear = tf.keras.layers.TimeDistributed(linear)\r\n        skip = linear(x)\r\n\r\n    # Apply feedforward network\r\n    hidden = linear_layer(hidden_layer_size, activation=None, use_time_distributed=use_time_distributed)(x)\r\n    if additional_context is not None:\r\n        hidden = hidden + linear_layer(\r\n            hidden_layer_size, activation=None, use_time_distributed=use_time_distributed, use_bias=False\r\n        )(additional_context)\r\n    hidden = tf.keras.layers.Activation(\"elu\")(hidden)\r\n    hidden = linear_layer(hidden_layer_size, activation=None, use_time_distributed=use_time_distributed)(hidden)\r\n\r\n    gating_layer, gate = apply_gating_layer(\r\n        hidden, output_size, dropout_rate=dropout_rate, use_time_distributed=use_time_distributed, activation=None\r\n    )\r\n\r\n    if return_gate:\r\n        return add_and_norm([skip, gating_layer]), gate\r\n    else:\r\n        return add_and_norm([skip, gating_layer])\r\n\r\n\r\n# Attention Components.\r\ndef get_decoder_mask(self_attn_inputs):\r\n    \"\"\"Returns causal mask to apply for self-attention layer.\r\n\r\n    Args:\r\n      self_attn_inputs: Inputs to self attention layer to determine mask shape\r\n    \"\"\"\r\n    len_s = tf.shape(self_attn_inputs)[1]\r\n    bs = tf.shape(self_attn_inputs)[:1]\r\n    mask = K.cumsum(tf.eye(len_s, batch_shape=bs), 1)\r\n    return mask\r\n\r\n\r\nclass ScaledDotProductAttention:\r\n    \"\"\"Defines scaled dot product attention layer.\r\n\r\n    Attributes:\r\n      dropout: Dropout rate to use\r\n      activation: Normalisation function for scaled dot product attention (e.g.\r\n        softmax by default)\r\n    \"\"\"\r\n\r\n    def __init__(self, attn_dropout=0.0):\r\n        self.dropout = Dropout(attn_dropout)\r\n        self.activation = Activation(\"softmax\")\r\n\r\n    def __call__(self, q, k, v, mask):\r\n        \"\"\"Applies scaled 
dot product attention.\r\n\r\n        Args:\r\n          q: Queries\r\n          k: Keys\r\n          v: Values\r\n          mask: Masking if required -- adds a large negative value to masked logits so softmax gives them near-zero weight\r\n\r\n        Returns:\r\n          Tuple of (layer outputs, attention weights)\r\n        \"\"\"\r\n        temper = tf.sqrt(tf.cast(tf.shape(k)[-1], dtype=\"float32\"))\r\n        attn = Lambda(lambda x: K.batch_dot(x[0], x[1], axes=[2, 2]) / temper)([q, k])  # shape=(batch, q, k)\r\n        if mask is not None:\r\n            mmask = Lambda(lambda x: (-1e9) * (1.0 - K.cast(x, \"float32\")))(mask)  # large negative offset for masked positions\r\n            attn = Add()([attn, mmask])\r\n        attn = self.activation(attn)\r\n        attn = self.dropout(attn)\r\n        output = Lambda(lambda x: K.batch_dot(x[0], x[1]))([attn, v])\r\n        return output, attn\r\n\r\n\r\nclass InterpretableMultiHeadAttention:\r\n    \"\"\"Defines interpretable multi-head attention layer.\r\n\r\n    Attributes:\r\n      n_head: Number of heads\r\n      d_k: Key/query dimensionality per head\r\n      d_v: Value dimensionality\r\n      dropout: Dropout rate to apply\r\n      qs_layers: List of queries across heads\r\n      ks_layers: List of keys across heads\r\n      vs_layers: List of values across heads\r\n      attention: Scaled dot product attention layer\r\n      w_o: Output weight matrix to project internal state to the original TFT\r\n        state size\r\n    \"\"\"\r\n\r\n    def __init__(self, n_head, d_model, dropout):\r\n        \"\"\"Initialises layer.\r\n\r\n        Args:\r\n          n_head: Number of heads\r\n          d_model: TFT state dimensionality\r\n          dropout: Dropout discard rate\r\n        \"\"\"\r\n\r\n        self.n_head = n_head\r\n        self.d_k = self.d_v = d_k = d_v = d_model // n_head\r\n        self.dropout = dropout\r\n\r\n        self.qs_layers = []\r\n        self.ks_layers = []\r\n        self.vs_layers = []\r\n\r\n        # Use same value layer to facilitate interpretability\r\n    
    vs_layer = Dense(d_v, use_bias=False)\r\n\r\n        for _ in range(n_head):\r\n            self.qs_layers.append(Dense(d_k, use_bias=False))\r\n            self.ks_layers.append(Dense(d_k, use_bias=False))\r\n            self.vs_layers.append(vs_layer)  # use same vs_layer\r\n\r\n        self.attention = ScaledDotProductAttention()\r\n        self.w_o = Dense(d_model, use_bias=False)\r\n\r\n    def __call__(self, q, k, v, mask=None):\r\n        \"\"\"Applies interpretable multihead attention.\r\n\r\n        Using T to denote the number of time steps fed into the transformer.\r\n\r\n        Args:\r\n          q: Query tensor of shape=(?, T, d_model)\r\n          k: Key of shape=(?, T, d_model)\r\n          v: Values of shape=(?, T, d_model)\r\n          mask: Masking if required with shape=(?, T, T)\r\n\r\n        Returns:\r\n          Tuple of (layer outputs, attention weights)\r\n        \"\"\"\r\n        n_head = self.n_head\r\n\r\n        heads = []\r\n        attns = []\r\n        for i in range(n_head):\r\n            qs = self.qs_layers[i](q)\r\n            ks = self.ks_layers[i](k)\r\n            vs = self.vs_layers[i](v)\r\n            head, attn = self.attention(qs, ks, vs, mask)\r\n\r\n            head_dropout = Dropout(self.dropout)(head)\r\n            heads.append(head_dropout)\r\n            attns.append(attn)\r\n        head = K.stack(heads) if n_head > 1 else heads[0]\r\n        attn = K.stack(attns)\r\n\r\n        outputs = K.mean(head, axis=0) if n_head > 1 else head\r\n        outputs = self.w_o(outputs)\r\n        outputs = Dropout(self.dropout)(outputs)  # output dropout\r\n\r\n        return outputs, attn\r\n\r\n\r\nclass TFTDataCache:\r\n    \"\"\"Caches data for the TFT.\"\"\"\r\n\r\n    _data_cache = {}\r\n\r\n    @classmethod\r\n    def update(cls, data, key):\r\n        \"\"\"Updates cached data.\r\n\r\n        Args:\r\n          data: Source to update\r\n          key: Key to dictionary location\r\n        \"\"\"\r\n        
cls._data_cache[key] = data\r\n\r\n    @classmethod\r\n    def get(cls, key):\r\n        \"\"\"Returns data stored at key location.\"\"\"\r\n        return cls._data_cache[key].copy()\r\n\r\n    @classmethod\r\n    def contains(cls, key):\r\n        \"\"\"Returns boolean indicating whether key is present in cache.\"\"\"\r\n\r\n        return key in cls._data_cache\r\n\r\n\r\n# TFT model definitions.\r\nclass TemporalFusionTransformer:\r\n    \"\"\"Defines Temporal Fusion Transformer.\r\n\r\n    Attributes:\r\n      name: Name of model\r\n      time_steps: Total number of input time steps per forecast date (i.e. Width\r\n        of Temporal fusion decoder N)\r\n      input_size: Total number of inputs\r\n      output_size: Total number of outputs\r\n      category_counts: Number of categories per categorical variable\r\n      n_multiprocessing_workers: Number of workers to use for parallel\r\n        computations\r\n      column_definition: List of tuples of (string, DataType, InputType) that\r\n        define each column\r\n      quantiles: Quantiles to forecast for TFT\r\n      use_cudnn: Whether to use Keras CuDNNLSTM or standard LSTM layers\r\n      hidden_layer_size: Internal state size of TFT\r\n      dropout_rate: Dropout discard rate\r\n      max_gradient_norm: Maximum norm for gradient clipping\r\n      learning_rate: Initial learning rate of ADAM optimizer\r\n      minibatch_size: Size of minibatches for training\r\n      num_epochs: Maximum number of epochs for training\r\n      early_stopping_patience: Maximum number of iterations of non-improvement\r\n        before early stopping kicks in\r\n      num_encoder_steps: Size of LSTM encoder -- i.e. 
number of past time steps\r\n        before forecast date to use\r\n      num_stacks: Number of self-attention layers to apply (default is 1 for basic\r\n        TFT)\r\n      num_heads: Number of heads for interpretable multi-head attention\r\n      model: Keras model for TFT\r\n    \"\"\"\r\n\r\n    def __init__(self, raw_params, use_cudnn=False):\r\n        \"\"\"Builds TFT from parameters.\r\n\r\n        Args:\r\n          raw_params: Parameters to define TFT\r\n          use_cudnn: Whether to use CUDNN GPU optimised LSTM\r\n        \"\"\"\r\n\r\n        self.name = self.__class__.__name__\r\n\r\n        params = dict(raw_params)  # copy locally\r\n\r\n        # Data parameters\r\n        self.time_steps = int(params[\"total_time_steps\"])\r\n        self.input_size = int(params[\"input_size\"])\r\n        self.output_size = int(params[\"output_size\"])\r\n        self.category_counts = json.loads(str(params[\"category_counts\"]))\r\n        self.n_multiprocessing_workers = int(params[\"multiprocessing_workers\"])\r\n\r\n        # Relevant indices for TFT\r\n        self._input_obs_loc = json.loads(str(params[\"input_obs_loc\"]))\r\n        self._static_input_loc = json.loads(str(params[\"static_input_loc\"]))\r\n        self._known_regular_input_idx = json.loads(str(params[\"known_regular_inputs\"]))\r\n        self._known_categorical_input_idx = json.loads(str(params[\"known_categorical_inputs\"]))\r\n\r\n        self.column_definition = params[\"column_definition\"]\r\n\r\n        # Network params\r\n        self.quantiles = [0.1, 0.5, 0.9]\r\n        self.use_cudnn = use_cudnn  # Whether to use GPU optimised LSTM\r\n        self.hidden_layer_size = int(params[\"hidden_layer_size\"])\r\n        self.dropout_rate = float(params[\"dropout_rate\"])\r\n        self.max_gradient_norm = float(params[\"max_gradient_norm\"])\r\n        self.learning_rate = float(params[\"learning_rate\"])\r\n        self.minibatch_size = int(params[\"minibatch_size\"])\r\n        
self.num_epochs = int(params[\"num_epochs\"])\r\n        self.early_stopping_patience = int(params[\"early_stopping_patience\"])\r\n\r\n        self.num_encoder_steps = int(params[\"num_encoder_steps\"])\r\n        self.num_stacks = int(params[\"stack_size\"])\r\n        self.num_heads = int(params[\"num_heads\"])\r\n\r\n        # Serialisation options\r\n        self._temp_folder = os.path.join(params[\"model_folder\"], \"tmp\")\r\n        self.reset_temp_folder()\r\n\r\n        # Extra components to store Tensorflow nodes for attention computations\r\n        self._input_placeholder = None\r\n        self._attention_components = None\r\n        self._prediction_parts = None\r\n\r\n        print(\"*** {} params ***\".format(self.name))\r\n        for k in params:\r\n            print(\"# {} = {}\".format(k, params[k]))\r\n\r\n        # Build model\r\n        self.model = self.build_model()\r\n\r\n    def get_tft_embeddings(self, all_inputs):\r\n        \"\"\"Transforms raw inputs to embeddings.\r\n\r\n        Applies linear transformation onto continuous variables and uses embeddings\r\n        for categorical variables.\r\n\r\n        Args:\r\n          all_inputs: Inputs to transform\r\n\r\n        Returns:\r\n          Tensors for transformed inputs.\r\n        \"\"\"\r\n\r\n        time_steps = self.time_steps\r\n\r\n        # Sanity checks\r\n        for i in self._known_regular_input_idx:\r\n            if i in self._input_obs_loc:\r\n                raise ValueError(\"Observation cannot be known a priori!\")\r\n        for i in self._input_obs_loc:\r\n            if i in self._static_input_loc:\r\n                raise ValueError(\"Observation cannot be static!\")\r\n\r\n        if all_inputs.get_shape().as_list()[-1] != self.input_size:\r\n            raise ValueError(\r\n                \"Illegal number of inputs! 
Inputs observed={}, expected={}\".format(\r\n                    all_inputs.get_shape().as_list()[-1], self.input_size\r\n                )\r\n            )\r\n\r\n        num_categorical_variables = len(self.category_counts)\r\n        num_regular_variables = self.input_size - num_categorical_variables\r\n\r\n        embedding_sizes = [self.hidden_layer_size for i, size in enumerate(self.category_counts)]\r\n\r\n        embeddings = []\r\n        for i in range(num_categorical_variables):\r\n            embedding = tf.keras.Sequential(\r\n                [\r\n                    tf.keras.layers.InputLayer([time_steps]),\r\n                    tf.keras.layers.Embedding(\r\n                        self.category_counts[i], embedding_sizes[i], input_length=time_steps, dtype=tf.float32\r\n                    ),\r\n                ]\r\n            )\r\n            embeddings.append(embedding)\r\n\r\n        regular_inputs, categorical_inputs = (\r\n            all_inputs[:, :, :num_regular_variables],\r\n            all_inputs[:, :, num_regular_variables:],\r\n        )\r\n\r\n        embedded_inputs = [embeddings[i](categorical_inputs[Ellipsis, i]) for i in range(num_categorical_variables)]\r\n\r\n        # Static inputs\r\n        if self._static_input_loc:\r\n            static_inputs = [\r\n                tf.keras.layers.Dense(self.hidden_layer_size)(regular_inputs[:, 0, i : i + 1])\r\n                for i in range(num_regular_variables)\r\n                if i in self._static_input_loc\r\n            ] + [\r\n                embedded_inputs[i][:, 0, :]\r\n                for i in range(num_categorical_variables)\r\n                if i + num_regular_variables in self._static_input_loc\r\n            ]\r\n            static_inputs = tf.keras.backend.stack(static_inputs, axis=1)\r\n\r\n        else:\r\n            static_inputs = None\r\n\r\n        def convert_real_to_embedding(x):\r\n            \"\"\"Applies linear transformation for time-varying 
inputs.\"\"\"\r\n            return tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.hidden_layer_size))(x)\r\n\r\n        # Targets\r\n        obs_inputs = tf.keras.backend.stack(\r\n            [convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1]) for i in self._input_obs_loc], axis=-1\r\n        )\r\n\r\n        # Observed (a priori unknown) inputs\r\n        wired_embeddings = []\r\n        for i in range(num_categorical_variables):\r\n            if i not in self._known_categorical_input_idx and i + num_regular_variables not in self._input_obs_loc:\r\n                e = embeddings[i](categorical_inputs[:, :, i])\r\n                wired_embeddings.append(e)\r\n\r\n        unknown_inputs = []\r\n        for i in range(regular_inputs.shape[-1]):\r\n            if i not in self._known_regular_input_idx and i not in self._input_obs_loc:\r\n                e = convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1])\r\n                unknown_inputs.append(e)\r\n\r\n        if unknown_inputs + wired_embeddings:\r\n            unknown_inputs = tf.keras.backend.stack(unknown_inputs + wired_embeddings, axis=-1)\r\n        else:\r\n            unknown_inputs = None\r\n\r\n        # A priori known inputs\r\n        known_regular_inputs = [\r\n            convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1])\r\n            for i in self._known_regular_input_idx\r\n            if i not in self._static_input_loc\r\n        ]\r\n        known_categorical_inputs = [\r\n            embedded_inputs[i]\r\n            for i in self._known_categorical_input_idx\r\n            if i + num_regular_variables not in self._static_input_loc\r\n        ]\r\n\r\n        known_combined_layer = tf.keras.backend.stack(known_regular_inputs + known_categorical_inputs, axis=-1)\r\n\r\n        return unknown_inputs, known_combined_layer, obs_inputs, static_inputs\r\n\r\n    def _get_single_col_by_type(self, input_type):\r\n        \"\"\"Returns name of single 
column for input type.\"\"\"\r\n\r\n        return utils.get_single_col_by_input_type(input_type, self.column_definition)\r\n\r\n    def training_data_cached(self):\r\n        \"\"\"Returns boolean indicating if training data has been cached.\"\"\"\r\n\r\n        return TFTDataCache.contains(\"train\") and TFTDataCache.contains(\"valid\")\r\n\r\n    def cache_batched_data(self, data, cache_key, num_samples=-1):\r\n        \"\"\"Batches and caches data once for use during training.\r\n\r\n        Args:\r\n          data: Data to batch and cache\r\n          cache_key: Key used for cache\r\n          num_samples: Maximum number of samples to extract (-1 to use all data)\r\n        \"\"\"\r\n\r\n        if num_samples > 0:\r\n            TFTDataCache.update(self._batch_sampled_data(data, max_samples=num_samples), cache_key)\r\n        else:\r\n            TFTDataCache.update(self._batch_data(data), cache_key)\r\n\r\n        print('Cached data \"{}\" updated'.format(cache_key))\r\n\r\n    def _batch_sampled_data(self, data, max_samples):\r\n        \"\"\"Samples segments into a compatible format.\r\n\r\n        Args:\r\n          data: Source data to sample and batch\r\n          max_samples: Maximum number of samples in batch\r\n\r\n        Returns:\r\n          Dictionary of batched data with the maximum samples specified.\r\n        \"\"\"\r\n\r\n        if max_samples < 1:\r\n            raise ValueError(\"Illegal number of samples specified! 
samples={}\".format(max_samples))\r\n\r\n        id_col = self._get_single_col_by_type(InputTypes.ID)\r\n        time_col = self._get_single_col_by_type(InputTypes.TIME)\r\n\r\n        data.sort_values(by=[id_col, time_col], inplace=True)\r\n\r\n        print(\"Getting valid sampling locations.\")\r\n        valid_sampling_locations = []\r\n        split_data_map = {}\r\n        for identifier, df in data.groupby(id_col, group_keys=False):\r\n            print(\"Getting locations for {}\".format(identifier))\r\n            num_entries = len(df)\r\n            if num_entries >= self.time_steps:\r\n                valid_sampling_locations += [\r\n                    (identifier, self.time_steps + i) for i in range(num_entries - self.time_steps + 1)\r\n                ]\r\n            split_data_map[identifier] = df\r\n\r\n        inputs = np.zeros((max_samples, self.time_steps, self.input_size))\r\n        outputs = np.zeros((max_samples, self.time_steps, self.output_size))\r\n        time = np.empty((max_samples, self.time_steps, 1), dtype=object)\r\n        identifiers = np.empty((max_samples, self.time_steps, 1), dtype=object)\r\n\r\n        if max_samples > 0 and len(valid_sampling_locations) > max_samples:\r\n            print(\"Extracting {} samples...\".format(max_samples))\r\n            ranges = [\r\n                valid_sampling_locations[i]\r\n                for i in np.random.choice(len(valid_sampling_locations), max_samples, replace=False)\r\n            ]\r\n        else:\r\n            print(\"Max samples={} exceeds # available segments={}\".format(max_samples, len(valid_sampling_locations)))\r\n            ranges = valid_sampling_locations\r\n\r\n        id_col = self._get_single_col_by_type(InputTypes.ID)\r\n        time_col = self._get_single_col_by_type(InputTypes.TIME)\r\n        target_col = self._get_single_col_by_type(InputTypes.TARGET)\r\n        input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, 
InputTypes.TIME}]\r\n\r\n        for i, tup in enumerate(ranges):\r\n            if (i + 1) % 1000 == 0:\r\n                print(i + 1, \"of\", max_samples, \"samples done...\")\r\n            identifier, start_idx = tup\r\n            sliced = split_data_map[identifier].iloc[start_idx - self.time_steps : start_idx]\r\n            inputs[i, :, :] = sliced[input_cols]\r\n            outputs[i, :, :] = sliced[[target_col]]\r\n            time[i, :, 0] = sliced[time_col]\r\n            identifiers[i, :, 0] = sliced[id_col]\r\n\r\n        sampled_data = {\r\n            \"inputs\": inputs,\r\n            \"outputs\": outputs[:, self.num_encoder_steps :, :],\r\n            \"active_entries\": np.ones_like(outputs[:, self.num_encoder_steps :, :]),\r\n            \"time\": time,\r\n            \"identifier\": identifiers,\r\n        }\r\n\r\n        return sampled_data\r\n\r\n    def _batch_data(self, data):\r\n        \"\"\"Batches data for training.\r\n\r\n        Converts raw dataframe from a 2-D tabular format to a batched 3-D array\r\n        to feed into Keras model.\r\n\r\n        Args:\r\n          data: DataFrame to batch\r\n\r\n        Returns:\r\n          Dictionary of batched Numpy arrays; the inputs array has shape=(?, self.time_steps, self.input_size)\r\n        \"\"\"\r\n\r\n        # Functions.\r\n        def _batch_single_entity(input_data):\r\n            time_steps = len(input_data)\r\n            lags = self.time_steps\r\n            x = input_data.values\r\n            if time_steps >= lags:\r\n                return np.stack([x[i : time_steps - (lags - 1) + i, :] for i in range(lags)], axis=1)\r\n\r\n            else:\r\n                return None\r\n\r\n        id_col = self._get_single_col_by_type(InputTypes.ID)\r\n        time_col = self._get_single_col_by_type(InputTypes.TIME)\r\n        target_col = self._get_single_col_by_type(InputTypes.TARGET)\r\n        input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}]\r\n\r\n      
  data_map = {}\r\n        for _, sliced in data.groupby(id_col, group_keys=False):\r\n            col_mappings = {\"identifier\": [id_col], \"time\": [time_col], \"outputs\": [target_col], \"inputs\": input_cols}\r\n\r\n            for k in col_mappings:\r\n                cols = col_mappings[k]\r\n                arr = _batch_single_entity(sliced[cols].copy())\r\n\r\n                if k not in data_map:\r\n                    data_map[k] = [arr]\r\n                else:\r\n                    data_map[k].append(arr)\r\n\r\n        # Combine all data\r\n        for k in data_map:\r\n            # Wendi: Avoid returning None when the length is not enough\r\n            data_map[k] = np.concatenate([i for i in data_map[k] if i is not None], axis=0)\r\n\r\n        # Shorten target so we only get decoder steps\r\n        data_map[\"outputs\"] = data_map[\"outputs\"][:, self.num_encoder_steps :, :]\r\n\r\n        active_entries = np.ones_like(data_map[\"outputs\"])\r\n        if \"active_entries\" not in data_map:\r\n            data_map[\"active_entries\"] = active_entries\r\n        else:\r\n            data_map[\"active_entries\"].append(active_entries)\r\n\r\n        return data_map\r\n\r\n    def _get_active_locations(self, x):\r\n        \"\"\"Formats sample weights for Keras training.\"\"\"\r\n        return (np.sum(x, axis=-1) > 0.0) * 1.0\r\n\r\n    def _build_base_graph(self):\r\n        \"\"\"Returns graph defining layers of the TFT.\"\"\"\r\n\r\n        # Size definitions.\r\n        time_steps = self.time_steps\r\n        combined_input_size = self.input_size\r\n        encoder_steps = self.num_encoder_steps\r\n\r\n        # Inputs.\r\n        all_inputs = tf.keras.layers.Input(\r\n            shape=(\r\n                time_steps,\r\n                combined_input_size,\r\n            )\r\n        )\r\n\r\n        unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs)\r\n\r\n        # Isolate known and 
observed historical inputs.\r\n        if unknown_inputs is not None:\r\n            historical_inputs = concat(\r\n                [\r\n                    unknown_inputs[:, :encoder_steps, :],\r\n                    known_combined_layer[:, :encoder_steps, :],\r\n                    obs_inputs[:, :encoder_steps, :],\r\n                ],\r\n                axis=-1,\r\n            )\r\n        else:\r\n            historical_inputs = concat(\r\n                [known_combined_layer[:, :encoder_steps, :], obs_inputs[:, :encoder_steps, :]], axis=-1\r\n            )\r\n\r\n        # Isolate only known future inputs.\r\n        future_inputs = known_combined_layer[:, encoder_steps:, :]\r\n\r\n        def static_combine_and_mask(embedding):\r\n            \"\"\"Applies variable selection network to static inputs.\r\n\r\n            Args:\r\n              embedding: Transformed static inputs\r\n\r\n            Returns:\r\n              Tensor output for variable selection network\r\n            \"\"\"\r\n\r\n            # Add temporal features\r\n            _, num_static, _ = embedding.get_shape().as_list()\r\n\r\n            flatten = tf.keras.layers.Flatten()(embedding)\r\n\r\n            # Nonlinear transformation with gated residual network.\r\n            mlp_outputs = gated_residual_network(\r\n                flatten,\r\n                self.hidden_layer_size,\r\n                output_size=num_static,\r\n                dropout_rate=self.dropout_rate,\r\n                use_time_distributed=False,\r\n                additional_context=None,\r\n            )\r\n\r\n            sparse_weights = tf.keras.layers.Activation(\"softmax\")(mlp_outputs)\r\n            sparse_weights = K.expand_dims(sparse_weights, axis=-1)\r\n\r\n            trans_emb_list = []\r\n            for i in range(num_static):\r\n                e = gated_residual_network(\r\n                    embedding[:, i : i + 1, :],\r\n                    self.hidden_layer_size,\r\n                    
dropout_rate=self.dropout_rate,\r\n                    use_time_distributed=False,\r\n                )\r\n                trans_emb_list.append(e)\r\n\r\n            transformed_embedding = concat(trans_emb_list, axis=1)\r\n\r\n            combined = tf.keras.layers.Multiply()([sparse_weights, transformed_embedding])\r\n\r\n            static_vec = K.sum(combined, axis=1)\r\n\r\n            return static_vec, sparse_weights\r\n\r\n        static_encoder, static_weights = static_combine_and_mask(static_inputs)\r\n\r\n        static_context_variable_selection = gated_residual_network(\r\n            static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False\r\n        )\r\n        static_context_enrichment = gated_residual_network(\r\n            static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False\r\n        )\r\n        static_context_state_h = gated_residual_network(\r\n            static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False\r\n        )\r\n        static_context_state_c = gated_residual_network(\r\n            static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False\r\n        )\r\n\r\n        def lstm_combine_and_mask(embedding):\r\n            \"\"\"Apply temporal variable selection networks.\r\n\r\n            Args:\r\n              embedding: Transformed inputs.\r\n\r\n            Returns:\r\n              Processed tensor outputs.\r\n            \"\"\"\r\n\r\n            # Add temporal features\r\n            _, time_steps, embedding_dim, num_inputs = embedding.get_shape().as_list()\r\n\r\n            flatten = K.reshape(embedding, [-1, time_steps, embedding_dim * num_inputs])\r\n\r\n            expanded_static_context = K.expand_dims(static_context_variable_selection, axis=1)\r\n\r\n            # Variable selection weights\r\n            mlp_outputs, static_gate = 
gated_residual_network(\r\n                flatten,\r\n                self.hidden_layer_size,\r\n                output_size=num_inputs,\r\n                dropout_rate=self.dropout_rate,\r\n                use_time_distributed=True,\r\n                additional_context=expanded_static_context,\r\n                return_gate=True,\r\n            )\r\n\r\n            sparse_weights = tf.keras.layers.Activation(\"softmax\")(mlp_outputs)\r\n            sparse_weights = tf.expand_dims(sparse_weights, axis=2)\r\n\r\n            # Non-linear Processing & weight application\r\n            trans_emb_list = []\r\n            for i in range(num_inputs):\r\n                grn_output = gated_residual_network(\r\n                    embedding[Ellipsis, i],\r\n                    self.hidden_layer_size,\r\n                    dropout_rate=self.dropout_rate,\r\n                    use_time_distributed=True,\r\n                )\r\n                trans_emb_list.append(grn_output)\r\n\r\n            transformed_embedding = stack(trans_emb_list, axis=-1)\r\n\r\n            combined = tf.keras.layers.Multiply()([sparse_weights, transformed_embedding])\r\n            temporal_ctx = K.sum(combined, axis=-1)\r\n\r\n            return temporal_ctx, sparse_weights, static_gate\r\n\r\n        historical_features, historical_flags, _ = lstm_combine_and_mask(historical_inputs)\r\n        future_features, future_flags, _ = lstm_combine_and_mask(future_inputs)\r\n\r\n        # LSTM layer\r\n        def get_lstm(return_state):\r\n            \"\"\"Returns LSTM cell initialized with default parameters.\"\"\"\r\n            if self.use_cudnn:\r\n                lstm = tf.keras.layers.CuDNNLSTM(\r\n                    self.hidden_layer_size,\r\n                    return_sequences=True,\r\n                    return_state=return_state,\r\n                    stateful=False,\r\n                )\r\n            else:\r\n                lstm = tf.keras.layers.LSTM(\r\n                    
self.hidden_layer_size,\r\n                    return_sequences=True,\r\n                    return_state=return_state,\r\n                    stateful=False,\r\n                    # Additional params to ensure LSTM matches CuDNN, See TF 2.0 :\r\n                    # (https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM)\r\n                    activation=\"tanh\",\r\n                    recurrent_activation=\"sigmoid\",\r\n                    recurrent_dropout=0,\r\n                    unroll=False,\r\n                    use_bias=True,\r\n                )\r\n            return lstm\r\n\r\n        history_lstm, state_h, state_c = get_lstm(return_state=True)(\r\n            historical_features, initial_state=[static_context_state_h, static_context_state_c]\r\n        )\r\n\r\n        future_lstm = get_lstm(return_state=False)(future_features, initial_state=[state_h, state_c])\r\n\r\n        lstm_layer = concat([history_lstm, future_lstm], axis=1)\r\n\r\n        # Apply gated skip connection\r\n        input_embeddings = concat([historical_features, future_features], axis=1)\r\n\r\n        lstm_layer, _ = apply_gating_layer(lstm_layer, self.hidden_layer_size, self.dropout_rate, activation=None)\r\n        temporal_feature_layer = add_and_norm([lstm_layer, input_embeddings])\r\n\r\n        # Static enrichment layers\r\n        expanded_static_context = K.expand_dims(static_context_enrichment, axis=1)\r\n        enriched, _ = gated_residual_network(\r\n            temporal_feature_layer,\r\n            self.hidden_layer_size,\r\n            dropout_rate=self.dropout_rate,\r\n            use_time_distributed=True,\r\n            additional_context=expanded_static_context,\r\n            return_gate=True,\r\n        )\r\n\r\n        # Decoder self attention\r\n        self_attn_layer = InterpretableMultiHeadAttention(\r\n            self.num_heads, self.hidden_layer_size, dropout=self.dropout_rate\r\n        )\r\n\r\n        mask = 
get_decoder_mask(enriched)\r\n        x, self_att = self_attn_layer(enriched, enriched, enriched, mask=mask)\r\n\r\n        x, _ = apply_gating_layer(x, self.hidden_layer_size, dropout_rate=self.dropout_rate, activation=None)\r\n        x = add_and_norm([x, enriched])\r\n\r\n        # Nonlinear processing on outputs\r\n        decoder = gated_residual_network(\r\n            x, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=True\r\n        )\r\n\r\n        # Final skip connection\r\n        decoder, _ = apply_gating_layer(decoder, self.hidden_layer_size, activation=None)\r\n        transformer_layer = add_and_norm([decoder, temporal_feature_layer])\r\n\r\n        # Attention components for explainability\r\n        attention_components = {\r\n            # Temporal attention weights\r\n            \"decoder_self_attn\": self_att,\r\n            # Static variable selection weights\r\n            \"static_flags\": static_weights[Ellipsis, 0],\r\n            # Variable selection weights of past inputs\r\n            \"historical_flags\": historical_flags[Ellipsis, 0, :],\r\n            # Variable selection weights of future inputs\r\n            \"future_flags\": future_flags[Ellipsis, 0, :],\r\n        }\r\n\r\n        return transformer_layer, all_inputs, attention_components\r\n\r\n    def build_model(self):\r\n        \"\"\"Build model and defines training losses.\r\n\r\n        Returns:\r\n          Fully defined Keras model.\r\n        \"\"\"\r\n\r\n        with tf.variable_scope(self.name):\r\n            transformer_layer, all_inputs, attention_components = self._build_base_graph()\r\n\r\n            outputs = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.output_size * len(self.quantiles)))(\r\n                transformer_layer[Ellipsis, self.num_encoder_steps :, :]\r\n            )\r\n\r\n            self._attention_components = attention_components\r\n\r\n            adam = 
tf.keras.optimizers.Adam(lr=self.learning_rate, clipnorm=self.max_gradient_norm)\r\n\r\n            model = tf.keras.Model(inputs=all_inputs, outputs=outputs)\r\n\r\n            print(model.summary())\r\n\r\n            valid_quantiles = self.quantiles\r\n            output_size = self.output_size\r\n\r\n            class QuantileLossCalculator:\r\n                \"\"\"Computes the combined quantile loss for prespecified quantiles.\r\n\r\n                Attributes:\r\n                  quantiles: Quantiles to compute losses\r\n                \"\"\"\r\n\r\n                def __init__(self, quantiles):\r\n                    \"\"\"Initializes computer with quantiles for loss calculations.\r\n\r\n                    Args:\r\n                      quantiles: Quantiles to use for computations.\r\n                    \"\"\"\r\n                    self.quantiles = quantiles\r\n\r\n                def quantile_loss(self, a, b):\r\n                    \"\"\"Returns quantile loss for specified quantiles.\r\n\r\n                    Args:\r\n                      a: Targets\r\n                      b: Predictions\r\n                    \"\"\"\r\n                    quantiles_used = set(self.quantiles)\r\n\r\n                    loss = 0.0\r\n                    for i, quantile in enumerate(valid_quantiles):\r\n                        if quantile in quantiles_used:\r\n                            loss += utils.tensorflow_quantile_loss(\r\n                                a[Ellipsis, output_size * i : output_size * (i + 1)],\r\n                                b[Ellipsis, output_size * i : output_size * (i + 1)],\r\n                                quantile,\r\n                            )\r\n                    return loss\r\n\r\n            quantile_loss = QuantileLossCalculator(valid_quantiles).quantile_loss\r\n\r\n            model.compile(loss=quantile_loss, optimizer=adam, sample_weight_mode=\"temporal\")\r\n\r\n            self._input_placeholder = all_inputs\r\n\r\n     
   return model\r\n\r\n    def fit(self, train_df=None, valid_df=None):\r\n        \"\"\"Fits deep neural network for given training and validation data.\r\n\r\n        Args:\r\n          train_df: DataFrame for training data\r\n          valid_df: DataFrame for validation data\r\n        \"\"\"\r\n\r\n        print(\"*** Fitting {} ***\".format(self.name))\r\n\r\n        # Add relevant callbacks\r\n        callbacks = [\r\n            tf.keras.callbacks.EarlyStopping(monitor=\"val_loss\", patience=self.early_stopping_patience, min_delta=1e-4),\r\n            tf.keras.callbacks.ModelCheckpoint(\r\n                filepath=self.get_keras_saved_path(self._temp_folder),\r\n                monitor=\"val_loss\",\r\n                save_best_only=True,\r\n                save_weights_only=True,\r\n            ),\r\n            tf.keras.callbacks.TerminateOnNaN(),\r\n        ]\r\n\r\n        print(\"Getting batched_data\")\r\n        if train_df is None:\r\n            print(\"Using cached training data\")\r\n            train_data = TFTDataCache.get(\"train\")\r\n        else:\r\n            train_data = self._batch_data(train_df)\r\n\r\n        if valid_df is None:\r\n            print(\"Using cached validation data\")\r\n            valid_data = TFTDataCache.get(\"valid\")\r\n        else:\r\n            valid_data = self._batch_data(valid_df)\r\n\r\n        print(\"Using keras standard fit\")\r\n\r\n        def _unpack(data):\r\n            return data[\"inputs\"], data[\"outputs\"], self._get_active_locations(data[\"active_entries\"])\r\n\r\n        # Unpack without sample weights\r\n        data, labels, active_flags = _unpack(train_data)\r\n        val_data, val_labels, val_flags = _unpack(valid_data)\r\n\r\n        all_callbacks = callbacks\r\n\r\n        self.model.fit(\r\n            x=data,\r\n            y=np.concatenate([labels, labels, labels], axis=-1),\r\n            sample_weight=active_flags,\r\n            epochs=self.num_epochs,\r\n            
batch_size=self.minibatch_size,\r\n            validation_data=(val_data, np.concatenate([val_labels, val_labels, val_labels], axis=-1), val_flags),\r\n            callbacks=all_callbacks,\r\n            shuffle=True,\r\n            use_multiprocessing=True,\r\n            workers=self.n_multiprocessing_workers,\r\n        )\r\n\r\n        # Load best checkpoint again\r\n        tmp_checkpoint = self.get_keras_saved_path(self._temp_folder)\r\n        if os.path.exists(tmp_checkpoint):\r\n            self.load(self._temp_folder, use_keras_loadings=True)\r\n\r\n        else:\r\n            print(\"Cannot load from {}, skipping ...\".format(self._temp_folder))\r\n\r\n    def evaluate(self, data=None, eval_metric=\"loss\"):\r\n        \"\"\"Applies evaluation metric to the given data.\r\n\r\n        Args:\r\n          data: Dataframe for evaluation (cached validation data is used if None)\r\n          eval_metric: Evaluation metric to return, based on model definition.\r\n\r\n        Returns:\r\n          Computed evaluation loss.\r\n        \"\"\"\r\n\r\n        if data is None:\r\n            print(\"Using cached validation data\")\r\n            raw_data = TFTDataCache.get(\"valid\")\r\n        else:\r\n            raw_data = self._batch_data(data)\r\n\r\n        inputs = raw_data[\"inputs\"]\r\n        outputs = raw_data[\"outputs\"]\r\n        active_entries = self._get_active_locations(raw_data[\"active_entries\"])\r\n\r\n        metric_values = self.model.evaluate(\r\n            x=inputs,\r\n            y=np.concatenate([outputs, outputs, outputs], axis=-1),\r\n            sample_weight=active_entries,\r\n            workers=16,\r\n            use_multiprocessing=True,\r\n        )\r\n\r\n        metrics = pd.Series(metric_values, self.model.metrics_names)\r\n\r\n        return metrics[eval_metric]\r\n\r\n    def predict(self, df, return_targets=False):\r\n        \"\"\"Computes predictions for a given input dataset.\r\n\r\n        Args:\r\n          df: Input dataframe\r\n          return_targets: 
Whether to also return outputs aligned with predictions to\r\n            facilitate evaluation\r\n\r\n        Returns:\r\n          Dictionary of formatted dataframes keyed by quantile (e.g. p50), plus\r\n          targets when return_targets is True.\r\n        \"\"\"\r\n\r\n        data = self._batch_data(df)\r\n\r\n        inputs = data[\"inputs\"]\r\n        time = data[\"time\"]\r\n        identifier = data[\"identifier\"]\r\n        outputs = data[\"outputs\"]\r\n\r\n        combined = self.model.predict(inputs, workers=16, use_multiprocessing=True, batch_size=self.minibatch_size)\r\n\r\n        # Format output_csv\r\n        if self.output_size != 1:\r\n            raise NotImplementedError(\"Current version only supports 1D targets!\")\r\n\r\n        def format_outputs(prediction):\r\n            \"\"\"Returns formatted dataframes for prediction.\"\"\"\r\n\r\n            flat_prediction = pd.DataFrame(\r\n                prediction[:, :, 0], columns=[\"t+{}\".format(i) for i in range(self.time_steps - self.num_encoder_steps)]\r\n            )\r\n            cols = list(flat_prediction.columns)\r\n            flat_prediction[\"forecast_time\"] = time[:, self.num_encoder_steps - 1, 0]\r\n            flat_prediction[\"identifier\"] = identifier[:, 0, 0]\r\n\r\n            # Arrange in order\r\n            return flat_prediction[[\"forecast_time\", \"identifier\"] + cols]\r\n\r\n        # Extract predictions for each quantile into different entries\r\n        process_map = {\r\n            \"p{}\".format(int(q * 100)): combined[Ellipsis, i * self.output_size : (i + 1) * self.output_size]\r\n            for i, q in enumerate(self.quantiles)\r\n        }\r\n\r\n        if return_targets:\r\n            # Add targets if relevant\r\n            process_map[\"targets\"] = outputs\r\n\r\n        return {k: format_outputs(process_map[k]) for k in process_map}\r\n\r\n    def get_attention(self, df):\r\n        \"\"\"Computes TFT attention weights for a given dataset.\r\n\r\n        Args:\r\n          df: Input 
dataframe\r\n\r\n        Returns:\r\n            Dictionary of numpy arrays for temporal attention weights and variable\r\n              selection weights, along with their identifiers and time indices\r\n        \"\"\"\r\n\r\n        data = self._batch_data(df)\r\n        inputs = data[\"inputs\"]\r\n        identifiers = data[\"identifier\"]\r\n        time = data[\"time\"]\r\n\r\n        def get_batch_attention_weights(input_batch):\r\n            \"\"\"Returns weights for a given minibatch of data.\"\"\"\r\n            input_placeholder = self._input_placeholder\r\n            attention_weights = {}\r\n            for k in self._attention_components:\r\n                attention_weight = tf.keras.backend.get_session().run(\r\n                    self._attention_components[k], {input_placeholder: input_batch.astype(np.float32)}\r\n                )\r\n                attention_weights[k] = attention_weight\r\n            return attention_weights\r\n\r\n        # Compute number of batches\r\n        batch_size = self.minibatch_size\r\n        n = inputs.shape[0]\r\n        num_batches = n // batch_size\r\n        if n - (num_batches * batch_size) > 0:\r\n            num_batches += 1\r\n\r\n        # Split up inputs into batches\r\n        batched_inputs = [inputs[i * batch_size : (i + 1) * batch_size, Ellipsis] for i in range(num_batches)]\r\n\r\n        # Get attention weights, while avoiding large memory increases\r\n        attention_by_batch = [get_batch_attention_weights(batch) for batch in batched_inputs]\r\n        attention_weights = {}\r\n        for k in self._attention_components:\r\n            attention_weights[k] = []\r\n            for batch_weights in attention_by_batch:\r\n                attention_weights[k].append(batch_weights[k])\r\n\r\n            if len(attention_weights[k][0].shape) == 4:\r\n                tmp = np.concatenate(attention_weights[k], axis=1)\r\n            else:\r\n                tmp = np.concatenate(attention_weights[k], 
axis=0)\r\n\r\n            del attention_weights[k]\r\n            gc.collect()\r\n            attention_weights[k] = tmp\r\n\r\n        attention_weights[\"identifiers\"] = identifiers[:, 0, 0]\r\n        attention_weights[\"time\"] = time[:, :, 0]\r\n\r\n        return attention_weights\r\n\r\n    # Serialisation.\r\n    def reset_temp_folder(self):\r\n        \"\"\"Deletes and recreates folder with temporary Keras training outputs.\"\"\"\r\n        print(\"Resetting temp folder...\")\r\n        utils.create_folder_if_not_exist(self._temp_folder)\r\n        shutil.rmtree(self._temp_folder)\r\n        os.makedirs(self._temp_folder)\r\n\r\n    def get_keras_saved_path(self, model_folder):\r\n        \"\"\"Returns path to keras checkpoint.\"\"\"\r\n        return os.path.join(model_folder, \"{}.check\".format(self.name))\r\n\r\n    def save(self, model_folder):\r\n        \"\"\"Saves optimal TFT weights.\r\n\r\n        Args:\r\n          model_folder: Location to serialize model.\r\n        \"\"\"\r\n        # Allows for direct serialisation of tensorflow variables to avoid spurious\r\n        # issue with Keras that leads to different performance evaluation results\r\n        # when model is reloaded (https://github.com/keras-team/keras/issues/4875).\r\n\r\n        utils.save(tf.keras.backend.get_session(), model_folder, cp_name=self.name, scope=self.name)\r\n\r\n    def load(self, model_folder, use_keras_loadings=False):\r\n        \"\"\"Loads TFT weights.\r\n\r\n        Args:\r\n          model_folder: Folder containing serialized models.\r\n          use_keras_loadings: Whether to load from Keras checkpoint.\r\n\r\n        Returns:\r\n\r\n        \"\"\"\r\n        if use_keras_loadings:\r\n            # Loads temporary Keras model saved during training.\r\n            serialisation_path = self.get_keras_saved_path(model_folder)\r\n            print(\"Loading model from {}\".format(serialisation_path))\r\n            self.model.load_weights(serialisation_path)\r\n 
       else:\r\n            # Loads tensorflow graph for optimal models.\r\n            utils.load(tf.keras.backend.get_session(), model_folder, cp_name=self.name, scope=self.name)\r\n\r\n    @classmethod\r\n    def get_hyperparm_choices(cls):\r\n        \"\"\"Returns hyperparameter ranges for random search.\"\"\"\r\n        return {\r\n            \"dropout_rate\": [0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9],\r\n            \"hidden_layer_size\": [10, 20, 40, 80, 160, 240, 320],\r\n            \"minibatch_size\": [64, 128, 256],\r\n            \"learning_rate\": [1e-4, 1e-3, 1e-2],\r\n            \"max_gradient_norm\": [0.01, 1.0, 100.0],\r\n            \"num_heads\": [1, 4],\r\n            \"stack_size\": [1],\r\n        }\r\n"
  },
  {
    "path": "examples/benchmarks/TFT/libs/utils.py",
    "content": "# coding=utf-8\r\n# Copyright 2020 The Google Research Authors.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\n# Lint as: python3\r\n\"\"\"Generic helper functions used across codebase.\"\"\"\r\n\r\nimport os\r\nimport pathlib\r\n\r\nimport numpy as np\r\nimport tensorflow as tf\r\nfrom tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file\r\n\r\n\r\n# Generic.\r\ndef get_single_col_by_input_type(input_type, column_definition):\r\n    \"\"\"Returns name of single column.\r\n\r\n    Args:\r\n      input_type: Input type of column to extract\r\n      column_definition: Column definition list for experiment\r\n    \"\"\"\r\n\r\n    l = [tup[0] for tup in column_definition if tup[2] == input_type]\r\n\r\n    if len(l) != 1:\r\n        raise ValueError(\"Invalid number of columns for {}\".format(input_type))\r\n\r\n    return l[0]\r\n\r\n\r\ndef extract_cols_from_data_type(data_type, column_definition, excluded_input_types):\r\n    \"\"\"Extracts the names of columns that correspond to a defined data_type.\r\n\r\n    Args:\r\n      data_type: DataType of columns to extract.\r\n      column_definition: Column definition to use.\r\n      excluded_input_types: Set of input types to exclude\r\n\r\n    Returns:\r\n      List of names for columns with data type specified.\r\n    \"\"\"\r\n    return [tup[0] for tup in column_definition if tup[1] == data_type and tup[2] not in 
excluded_input_types]\r\n\r\n\r\n# Loss functions.\r\ndef tensorflow_quantile_loss(y, y_pred, quantile):\r\n    \"\"\"Computes quantile loss for tensorflow.\r\n\r\n    Standard quantile loss as defined in the \"Training Procedure\" section of\r\n    the main TFT paper\r\n\r\n    Args:\r\n      y: Targets\r\n      y_pred: Predictions\r\n      quantile: Quantile to use for loss calculations (between 0 & 1)\r\n\r\n    Returns:\r\n      Tensor for quantile loss.\r\n    \"\"\"\r\n\r\n    # Checks quantile\r\n    if quantile < 0 or quantile > 1:\r\n        raise ValueError(\"Illegal quantile value={}! Values should be between 0 and 1.\".format(quantile))\r\n\r\n    prediction_underflow = y - y_pred\r\n    q_loss = quantile * tf.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * tf.maximum(\r\n        -prediction_underflow, 0.0\r\n    )\r\n\r\n    return tf.reduce_sum(q_loss, axis=-1)\r\n\r\n\r\ndef numpy_normalised_quantile_loss(y, y_pred, quantile):\r\n    \"\"\"Computes normalised quantile loss for numpy arrays.\r\n\r\n    Uses the q-Risk metric as defined in the \"Training Procedure\" section of the\r\n    main TFT paper.\r\n\r\n    Args:\r\n      y: Targets\r\n      y_pred: Predictions\r\n      quantile: Quantile to use for loss calculations (between 0 & 1)\r\n\r\n    Returns:\r\n      Float for normalised quantile loss.\r\n    \"\"\"\r\n    prediction_underflow = y - y_pred\r\n    weighted_errors = quantile * np.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * np.maximum(\r\n        -prediction_underflow, 0.0\r\n    )\r\n\r\n    quantile_loss = weighted_errors.mean()\r\n    normaliser = y.abs().mean()\r\n\r\n    return 2 * quantile_loss / normaliser\r\n\r\n\r\n# OS related functions.\r\ndef create_folder_if_not_exist(directory):\r\n    \"\"\"Creates folder if it doesn't exist.\r\n\r\n    Args:\r\n      directory: Folder path to create.\r\n    \"\"\"\r\n    # Also creates directories recursively\r\n    pathlib.Path(directory).mkdir(parents=True, 
exist_ok=True)\r\n\r\n\r\n# Tensorflow related functions.\r\ndef get_default_tensorflow_config(tf_device=\"gpu\", gpu_id=0):\r\n    \"\"\"Creates tensorflow config for graphs to run on CPU or GPU.\r\n\r\n    Specifies whether to run graph on gpu or cpu and which GPU ID to use for multi\r\n    GPU machines.\r\n\r\n    Args:\r\n      tf_device: 'cpu' or 'gpu'\r\n      gpu_id: GPU ID to use if relevant\r\n\r\n    Returns:\r\n      Tensorflow config.\r\n    \"\"\"\r\n\r\n    if tf_device == \"cpu\":\r\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"  # for training on cpu\r\n        tf_config = tf.ConfigProto(log_device_placement=False, device_count={\"GPU\": 0})\r\n\r\n    else:\r\n        os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\r\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(gpu_id)\r\n\r\n        print(\"Selecting GPU ID={}\".format(gpu_id))\r\n\r\n        tf_config = tf.ConfigProto(log_device_placement=False)\r\n        tf_config.gpu_options.allow_growth = True\r\n\r\n    return tf_config\r\n\r\n\r\ndef save(tf_session, model_folder, cp_name, scope=None):\r\n    \"\"\"Saves Tensorflow graph to checkpoint.\r\n\r\n    Saves all trainable variables under a given variable scope to checkpoint.\r\n\r\n    Args:\r\n      tf_session: Session containing graph\r\n      model_folder: Folder to save models\r\n      cp_name: Name of Tensorflow checkpoint\r\n      scope: Variable scope containing variables to save\r\n    \"\"\"\r\n    # Save model\r\n    if scope is None:\r\n        saver = tf.train.Saver()\r\n    else:\r\n        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)\r\n        saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)\r\n\r\n    save_path = saver.save(tf_session, os.path.join(model_folder, \"{0}.ckpt\".format(cp_name)))\r\n    print(\"Model saved to: {0}\".format(save_path))\r\n\r\n\r\ndef load(tf_session, model_folder, cp_name, scope=None, verbose=False):\r\n    \"\"\"Loads Tensorflow graph from 
checkpoint.\r\n\r\n    Args:\r\n      tf_session: Session to load graph into\r\n      model_folder: Folder containing serialised model\r\n      cp_name: Name of Tensorflow checkpoint\r\n      scope: Variable scope to use.\r\n      verbose: Whether to print additional debugging information.\r\n    \"\"\"\r\n    # Load model proper\r\n    load_path = os.path.join(model_folder, \"{0}.ckpt\".format(cp_name))\r\n\r\n    print(\"Loading model from {0}\".format(load_path))\r\n\r\n    print_weights_in_checkpoint(model_folder, cp_name)\r\n\r\n    initial_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node])\r\n\r\n    # Saver\r\n    if scope is None:\r\n        saver = tf.train.Saver()\r\n    else:\r\n        var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)\r\n        saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)\r\n    # Load\r\n    saver.restore(tf_session, load_path)\r\n    all_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node])\r\n\r\n    if verbose:\r\n        print(\"Restored {0}\".format(\",\".join(initial_vars.difference(all_vars))))\r\n        print(\"Existing {0}\".format(\",\".join(all_vars.difference(initial_vars))))\r\n        print(\"All {0}\".format(\",\".join(all_vars)))\r\n\r\n    print(\"Done.\")\r\n\r\n\r\ndef print_weights_in_checkpoint(model_folder, cp_name):\r\n    \"\"\"Prints all weights in Tensorflow checkpoint.\r\n\r\n    Args:\r\n      model_folder: Folder containing checkpoint\r\n      cp_name: Name of checkpoint\r\n\r\n    Returns:\r\n\r\n    \"\"\"\r\n    load_path = os.path.join(model_folder, \"{0}.ckpt\".format(cp_name))\r\n\r\n    print_tensors_in_checkpoint_file(file_name=load_path, tensor_name=\"\", all_tensors=True, all_tensor_names=True)\r\n"
  },
  {
    "path": "examples/benchmarks/TFT/requirements.txt",
    "content": "tensorflow-gpu==1.15.0\r\npandas==1.1.0\r\n"
  },
  {
    "path": "examples/benchmarks/TFT/tft.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom pathlib import Path\nfrom typing import Union\nimport numpy as np\nimport pandas as pd\nimport tensorflow.compat.v1 as tf\nimport data_formatters.base\nimport expt_settings.configs\nimport libs.hyperparam_opt\nimport libs.tft_model\nimport libs.utils as utils\nimport os\nimport datetime as dte\n\n\nfrom qlib.model.base import ModelFT\nfrom qlib.data.dataset import DatasetH\nfrom qlib.data.dataset.handler import DataHandlerLP\n\n# To register new datasets, please add them here.\nALLOW_DATASET = [\"Alpha158\", \"Alpha360\"]\n# To register new datasets, please add their configurations here.\nDATASET_SETTING = {\n    \"Alpha158\": {\n        \"feature_col\": [\n            \"RESI5\",\n            \"WVMA5\",\n            \"RSQR5\",\n            \"KLEN\",\n            \"RSQR10\",\n            \"CORR5\",\n            \"CORD5\",\n            \"CORR10\",\n            \"ROC60\",\n            \"RESI10\",\n            \"VSTD5\",\n            \"RSQR60\",\n            \"CORR60\",\n            \"WVMA60\",\n            \"STD5\",\n            \"RSQR20\",\n            \"CORD60\",\n            \"CORD10\",\n            \"CORR20\",\n            \"KLOW\",\n        ],\n        \"label_col\": \"LABEL0\",\n    },\n    \"Alpha360\": {\n        \"feature_col\": [\n            \"HIGH0\",\n            \"LOW0\",\n            \"OPEN0\",\n            \"CLOSE1\",\n            \"HIGH1\",\n            \"VOLUME1\",\n            \"LOW1\",\n            \"VOLUME3\",\n            \"OPEN1\",\n            \"VOLUME4\",\n            \"CLOSE2\",\n            \"CLOSE4\",\n            \"VOLUME5\",\n            \"LOW2\",\n            \"CLOSE3\",\n            \"VOLUME2\",\n            \"HIGH2\",\n            \"LOW4\",\n            \"VOLUME8\",\n            \"VOLUME11\",\n        ],\n        \"label_col\": \"LABEL0\",\n    },\n}\n\n\ndef get_shifted_label(data_df, shifts=5, col_shift=\"LABEL0\"):\n    return 
data_df[[col_shift]].groupby(\"instrument\", group_keys=False).apply(lambda df: df.shift(shifts))\n\n\ndef fill_test_na(test_df):\n    test_df_res = test_df.copy()\n    feature_cols = ~test_df_res.columns.str.contains(\"label\", case=False)\n    test_feature_fna = (\n        test_df_res.loc[:, feature_cols].groupby(\"datetime\", group_keys=False).apply(lambda df: df.fillna(df.mean()))\n    )\n    test_df_res.loc[:, feature_cols] = test_feature_fna\n    return test_df_res\n\n\ndef process_qlib_data(df, dataset, fillna=False):\n    \"\"\"Prepare data to fit the TFT model.\n\n    Args:\n      df: Original DataFrame.\n      dataset: Dataset name, used as a key into DATASET_SETTING.\n      fillna: Whether to fill missing feature values with their per-datetime mean.\n\n    Returns:\n      Transformed DataFrame.\n\n    \"\"\"\n    # Several features selected manually\n    feature_col = DATASET_SETTING[dataset][\"feature_col\"]\n    label_col = [DATASET_SETTING[dataset][\"label_col\"]]\n    temp_df = df.loc[:, feature_col + label_col]\n    if fillna:\n        temp_df = fill_test_na(temp_df)\n    temp_df = temp_df.swaplevel()\n    temp_df = temp_df.sort_index()\n    temp_df = temp_df.reset_index(level=0)\n    dates = pd.to_datetime(temp_df.index)\n    temp_df[\"date\"] = dates\n    temp_df[\"day_of_week\"] = dates.dayofweek\n    temp_df[\"month\"] = dates.month\n    temp_df[\"year\"] = dates.year\n    temp_df[\"const\"] = 1.0\n    return temp_df\n\n\ndef process_predicted(df, col_name):\n    \"\"\"Transform the TFT predicted data into Qlib format.\n\n    Args:\n      df: Original DataFrame.\n      col_name: New column name.\n\n    Returns:\n      Transformed DataFrame.\n\n    \"\"\"\n    df_res = df.copy()\n    df_res = df_res.rename(columns={\"forecast_time\": \"datetime\", \"identifier\": \"instrument\", \"t+4\": col_name})\n    df_res = df_res.set_index([\"datetime\", \"instrument\"]).sort_index()\n    df_res = df_res[[col_name]]\n    return df_res\n\n\ndef format_score(forecast_df, col_name=\"pred\", label_shift=5):\n    pred = process_predicted(forecast_df, 
col_name=col_name)\n    pred = get_shifted_label(pred, shifts=-label_shift, col_shift=col_name)\n    pred = pred.dropna()[col_name]\n    return pred\n\n\ndef transform_df(df, col_name=\"LABEL0\"):\n    df_res = df[\"feature\"]\n    df_res[col_name] = df[\"label\"]\n    return df_res\n\n\nclass TFTModel(ModelFT):\n    \"\"\"TFT Model\"\"\"\n\n    def __init__(self, **kwargs):\n        self.model = None\n        self.params = {\"DATASET\": \"Alpha158\", \"label_shift\": 5}\n        self.params.update(kwargs)\n\n    def _prepare_data(self, dataset: DatasetH):\n        df_train, df_valid = dataset.prepare(\n            [\"train\", \"valid\"], col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L\n        )\n        return transform_df(df_train), transform_df(df_valid)\n\n    def fit(self, dataset: DatasetH, MODEL_FOLDER=\"qlib_tft_model\", USE_GPU_ID=0, **kwargs):\n        DATASET = self.params[\"DATASET\"]\n        LABEL_SHIFT = self.params[\"label_shift\"]\n\n        # Check support before indexing DATASET_SETTING, so an unknown dataset gets a clear error instead of a KeyError\n        if DATASET not in ALLOW_DATASET:\n            raise AssertionError(\"The dataset is not supported, please make a new formatter to fit this dataset\")\n\n        LABEL_COL = DATASET_SETTING[DATASET][\"label_col\"]\n\n        dtrain, dvalid = self._prepare_data(dataset)\n        dtrain.loc[:, LABEL_COL] = get_shifted_label(dtrain, shifts=LABEL_SHIFT, col_shift=LABEL_COL)\n        dvalid.loc[:, LABEL_COL] = get_shifted_label(dvalid, shifts=LABEL_SHIFT, col_shift=LABEL_COL)\n\n        train = process_qlib_data(dtrain, DATASET, fillna=True).dropna()\n        valid = process_qlib_data(dvalid, DATASET, fillna=True).dropna()\n\n        ExperimentConfig = expt_settings.configs.ExperimentConfig\n        config = ExperimentConfig(DATASET)\n        self.data_formatter = config.make_data_formatter()\n        self.model_folder = MODEL_FOLDER\n        self.gpu_id = USE_GPU_ID\n        self.label_shift = LABEL_SHIFT\n        self.expt_name = DATASET\n        self.label_col = LABEL_COL\n\n        use_gpu = (True, 
self.gpu_id)\n        # ===========================Training Process===========================\n        ModelClass = libs.tft_model.TemporalFusionTransformer\n        if not isinstance(self.data_formatter, data_formatters.base.GenericDataFormatter):\n            raise ValueError(\n                \"Data formatters should inherit from \"\n                + \"GenericDataFormatter! Type={}\".format(type(self.data_formatter))\n            )\n\n        default_keras_session = tf.keras.backend.get_session()\n\n        if use_gpu[0]:\n            self.tf_config = utils.get_default_tensorflow_config(tf_device=\"gpu\", gpu_id=use_gpu[1])\n        else:\n            self.tf_config = utils.get_default_tensorflow_config(tf_device=\"cpu\")\n\n        self.data_formatter.set_scalers(train)\n\n        # Sets up default params\n        fixed_params = self.data_formatter.get_experiment_params()\n        params = self.data_formatter.get_default_model_params()\n\n        params = {**params, **fixed_params}\n\n        if not os.path.exists(self.model_folder):\n            os.makedirs(self.model_folder)\n        params[\"model_folder\"] = self.model_folder\n\n        print(\"*** Begin training ***\")\n        best_loss = np.inf\n\n        tf.reset_default_graph()\n\n        self.tf_graph = tf.Graph()\n        with self.tf_graph.as_default():\n            self.sess = tf.Session(config=self.tf_config)\n            tf.keras.backend.set_session(self.sess)\n            self.model = ModelClass(params, use_cudnn=use_gpu[0])\n            self.sess.run(tf.global_variables_initializer())\n            self.model.fit(train_df=train, valid_df=valid)\n            print(\"*** Finished training ***\")\n            saved_model_dir = self.model_folder + \"/\" + \"saved_model\"\n            if not os.path.exists(saved_model_dir):\n                os.makedirs(saved_model_dir)\n            self.model.save(saved_model_dir)\n\n            def extract_numerical_data(data):\n                \"\"\"Strips out 
forecast time and identifier columns.\"\"\"\n                return data[[col for col in data.columns if col not in {\"forecast_time\", \"identifier\"}]]\n\n            # p50_loss = utils.numpy_normalised_quantile_loss(\n            #    extract_numerical_data(targets), extract_numerical_data(p50_forecast),\n            #    0.5)\n            # p90_loss = utils.numpy_normalised_quantile_loss(\n            #    extract_numerical_data(targets), extract_numerical_data(p90_forecast),\n            #    0.9)\n            tf.keras.backend.set_session(default_keras_session)\n        print(\"Training completed at {}.\".format(dte.datetime.now()))\n        # ===========================Training Process===========================\n\n    def predict(self, dataset):\n        if self.model is None:\n            raise ValueError(\"model is not fitted yet!\")\n        d_test = dataset.prepare(\"test\", col_set=[\"feature\", \"label\"])\n        d_test = transform_df(d_test)\n        d_test.loc[:, self.label_col] = get_shifted_label(d_test, shifts=self.label_shift, col_shift=self.label_col)\n        test = process_qlib_data(d_test, self.expt_name, fillna=True).dropna()\n\n        use_gpu = (True, self.gpu_id)\n        # ===========================Predicting Process===========================\n        default_keras_session = tf.keras.backend.get_session()\n\n        # Sets up default params\n        fixed_params = self.data_formatter.get_experiment_params()\n        params = self.data_formatter.get_default_model_params()\n        params = {**params, **fixed_params}\n\n        print(\"*** Begin predicting ***\")\n        tf.reset_default_graph()\n\n        with self.tf_graph.as_default():\n            tf.keras.backend.set_session(self.sess)\n            output_map = self.model.predict(test, return_targets=True)\n            targets = self.data_formatter.format_predictions(output_map[\"targets\"])\n            p50_forecast = self.data_formatter.format_predictions(output_map[\"p50\"])\n 
           p90_forecast = self.data_formatter.format_predictions(output_map[\"p90\"])\n            tf.keras.backend.set_session(default_keras_session)\n\n        predict50 = format_score(p50_forecast, \"pred\", 1)\n        predict90 = format_score(p90_forecast, \"pred\", 1)\n        predict = (predict50 + predict90) / 2  # self.label_shift\n        # ===========================Predicting Process===========================\n        return predict\n\n    def finetune(self, dataset: DatasetH):\n        \"\"\"\n        finetune model\n        Parameters\n        ----------\n        dataset : DatasetH\n            dataset for finetuning\n        \"\"\"\n        pass\n\n    def to_pickle(self, path: Union[Path, str]):\n        \"\"\"\n        The TensorFlow model can't be pickled directly,\n        so the TensorFlow objects are dropped here and should be saved separately.\n\n        **TODO**: Please implement the function to load the files\n\n        Parameters\n        ----------\n        path : Union[Path, str]\n            the target path to be dumped\n        \"\"\"\n        # FIXME: implementing saving tensorflow models\n        # save tensorflow model\n        # path = Path(path)\n        # path.mkdir(parents=True)\n        # self.model.save(path)\n\n        # save qlib model wrapper\n        drop_attrs = [\"model\", \"tf_graph\", \"sess\", \"data_formatter\"]\n        orig_attr = {}\n        for attr in drop_attrs:\n            # default to None: `tf_graph`, `sess` and `data_formatter` only exist after `fit`\n            orig_attr[attr] = getattr(self, attr, None)\n            setattr(self, attr, None)\n        try:\n            super(TFTModel, self).to_pickle(path)\n        finally:\n            # restore the TensorFlow objects even if dumping fails\n            for attr in drop_attrs:\n                setattr(self, attr, orig_attr[attr])\n"
  },
  {
    "path": "examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml",
    "content": "sys:\n    rel_path: .\nqlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: TFTModel\n        module_path: tft\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/TRA/README.md",
    "content": "# Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport\n\nTemporal Routing Adaptor (TRA) is designed to capture multiple trading patterns in stock market data. Please refer to [our paper](http://arxiv.org/abs/2106.12950) for more details.\n\nIf you find our work useful in your research, please cite:\n```\n@inproceedings{HengxuKDD2021,\n author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},\n title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},\n booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \\& Data Mining},\n series = {KDD '21},\n year = {2021},\n publisher = {ACM},\n}\n\n@article{yang2020qlib,\n  title={Qlib: An AI-oriented Quantitative Investment Platform},\n  author={Yang, Xiao and Liu, Weiqing and Zhou, Dong and Bian, Jiang and Liu, Tie-Yan},\n  journal={arXiv preprint arXiv:2009.11189},\n  year={2020}\n}\n```\n\n## Usage (Recommended)\n\n**Update**: `TRA` has been moved to `qlib.contrib.model.pytorch_tra` to support other `Qlib` components like `qlib.workflow` and the `Alpha158`/`Alpha360` datasets.\n\nPlease follow the official [doc](https://qlib.readthedocs.io/en/latest/component/workflow.html) to use `TRA` with `workflow`. 
Here we also provide several example config files:\n\n- `workflow_config_tra_Alpha360.yaml`: running `TRA` with the `Alpha360` dataset\n- `workflow_config_tra_Alpha158.yaml`: running `TRA` with the `Alpha158` dataset (with feature subsampling)\n- `workflow_config_tra_Alpha158_full.yaml`: running `TRA` with the `Alpha158` dataset (without feature subsampling)\n\nThe performance of `TRA` is reported in [Benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks).\n\n## Usage (Not Maintained)\n\nThis section is used to reproduce the results in the paper.\n\n### Running\n\nThe scripts we ran for the paper are collected in `run.sh`.\n\nThere are two ways to run the model:\n\n* Running from scripts with default parameters\n\n  You can run a config directly with the Qlib command `qrun`:\n  ```\n  qrun configs/config_alstm.yaml\n  ```\n\n* Running from code with self-defined parameters\n\n  Setting different parameters is also allowed; see the code in `example.py`:\n  ```\n  python example.py --config_file configs/config_alstm.yaml\n  ```\n\nHere, TRA is trained on top of a pretrained backbone model. 
Therefore, the `*_init.yaml` configs must be run before the corresponding TRA configs, which load the pretrained backbone through `model_init_state`.\n\n### Results\n\nAfter running the scripts, you can find the result files under `./output`:\n\n* `info.json` - config settings and result metrics.\n* `log.csv` - running logs.\n* `model.bin` - the model parameter dictionary.\n* `pred.pkl` - the prediction scores and output for inference.\n\nEvaluation metrics reported in the paper (generated with `qlib==0.7.1`; AR: annualized return, AV: annualized volatility, SR: Sharpe ratio, MDD: maximum drawdown):\n\n| Methods | MSE | MAE | IC | ICIR | AR | AV | SR | MDD |\n|-------|-------|------|-----|-----|-----|-----|-----|-----|\n|Linear|0.163|0.327|0.020|0.132|-3.2%|16.8%|-0.191|32.1%|\n|LightGBM|0.160(0.000)|0.323(0.000)|0.041|0.292|7.8%|15.5%|0.503|25.7%|\n|MLP|0.160(0.002)|0.323(0.003)|0.037|0.273|3.7%|15.3%|0.264|26.2%|\n|SFM|0.159(0.001)|0.321(0.001)|0.047|0.381|7.1%|14.3%|0.497|22.9%|\n|ALSTM|0.158(0.001)|0.320(0.001)|0.053|0.419|12.3%|13.7%|0.897|20.2%|\n|Trans.|0.158(0.001)|0.322(0.001)|0.051|0.400|14.5%|14.2%|1.028|22.5%|\n|ALSTM+TS|0.160(0.002)|0.321(0.002)|0.039|0.291|6.7%|14.6%|0.480|22.3%|\n|Trans.+TS|0.160(0.004)|0.324(0.005)|0.037|0.278|10.4%|14.7%|0.722|23.7%|\n|ALSTM+TRA(Ours)|0.157(0.000)|0.318(0.000)|0.059|0.460|12.4%|14.0%|0.885|20.4%|\n|Trans.+TRA(Ours)|0.157(0.000)|0.320(0.000)|0.056|0.442|16.1%|14.2%|1.133|23.1%|\n\nA more detailed demo of the experimental results in the paper can be found in `Report.ipynb`.\n\n## Common Issues\n\nFor help or issues using TRA, please submit a GitHub issue.\n\nIf the loss becomes `NaN`, check the `epsilon` parameter of the Sinkhorn algorithm: it is important to adjust `epsilon` to the scale of the input.\n"
  },
  {
    "path": "examples/benchmarks/TRA/configs/config_alstm.yaml",
    "content": "qlib_init:\n  provider_uri: \"~/.qlib/qlib_data/cn_data\"\n  region: cn\n\ndata_loader_config: &data_loader_config\n  class: StaticDataLoader\n  module_path: qlib.data.dataset.loader\n  kwargs:\n    config:\n      feature: data/feature.pkl\n      label: data/label.pkl\n\nmodel_config: &model_config\n  input_size: 16\n  hidden_size: 256\n  num_layers: 2\n  num_heads: 2\n  use_attn: True\n  dropout: 0.1\n\nnum_states: &num_states 1\n\ntra_config: &tra_config\n  num_states: *num_states\n  hidden_size: 16\n  tau: 1.0\n  src_info: LR_TPE\n\ntask:\n  model:\n    class: TRAModel\n    module_path: src/model.py\n    kwargs:\n      lr: 0.0002\n      n_epochs: 500\n      max_steps_per_epoch: 100\n      early_stop: 20\n      seed: 1000\n      logdir: output/test/alstm\n      model_type: LSTM\n      model_config: *model_config\n      tra_config: *tra_config\n      lamb: 1.0\n      rho: 0.99\n      freeze_model: False\n      model_init_state: \n  dataset:\n    class: MTSDatasetH\n    module_path: src/dataset.py\n    kwargs:\n      handler:\n        class: DataHandler\n        module_path: qlib.data.dataset.handler\n        kwargs:\n          data_loader: *data_loader_config\n      segments:\n        train: [2007-10-30, 2016-05-27]\n        valid: [2016-09-26, 2018-05-29]\n        test: [2018-09-21, 2020-06-30]\n      seq_len: 60\n      horizon: 21\n      num_states: *num_states\n      batch_size: 1024"
  },
  {
    "path": "examples/benchmarks/TRA/configs/config_alstm_tra.yaml",
    "content": "qlib_init:\n  provider_uri: \"~/.qlib/qlib_data/cn_data\"\n  region: cn\n\ndata_loader_config: &data_loader_config\n  class: StaticDataLoader\n  module_path: qlib.data.dataset.loader\n  kwargs:\n    config:\n      feature: data/feature.pkl\n      label: data/label.pkl\n\nmodel_config: &model_config\n  input_size: 16\n  hidden_size: 256\n  num_layers: 2\n  num_heads: 2\n  use_attn: True\n  dropout: 0.1\n\nnum_states: &num_states 10\n\ntra_config: &tra_config\n  num_states: *num_states\n  hidden_size: 16\n  tau: 1.0\n  src_info: LR_TPE\n\ntask:\n  model:\n    class: TRAModel\n    module_path: src/model.py\n    kwargs:\n      lr: 0.0001\n      n_epochs: 500\n      max_steps_per_epoch: 100\n      early_stop: 20\n      seed: 1000\n      logdir: output/test/alstm_tra\n      model_type: LSTM\n      model_config: *model_config\n      tra_config: *tra_config\n      lamb: 2.0\n      rho: 0.99\n      freeze_model: True\n      model_init_state: output/test/alstm_tra_init/model.bin\n  dataset:\n    class: MTSDatasetH\n    module_path: src/dataset.py\n    kwargs:\n      handler:\n        class: DataHandler\n        module_path: qlib.data.dataset.handler\n        kwargs:\n          data_loader: *data_loader_config\n      segments:\n        train: [2007-10-30, 2016-05-27]\n        valid: [2016-09-26, 2018-05-29]\n        test: [2018-09-21, 2020-06-30]\n      seq_len: 60\n      horizon: 21\n      num_states: *num_states\n      batch_size: 1024"
  },
  {
    "path": "examples/benchmarks/TRA/configs/config_alstm_tra_init.yaml",
    "content": "qlib_init:\n  provider_uri: \"~/.qlib/qlib_data/cn_data\"\n  region: cn\n\ndata_loader_config: &data_loader_config\n  class: StaticDataLoader\n  module_path: qlib.data.dataset.loader\n  kwargs:\n    config:\n      feature: data/feature.pkl\n      label: data/label.pkl\n\nmodel_config: &model_config\n  input_size: 16\n  hidden_size: 256\n  num_layers: 2\n  num_heads: 2\n  use_attn: True\n  dropout: 0.1\n\nnum_states: &num_states 3\n\ntra_config: &tra_config\n  num_states: *num_states\n  hidden_size: 16\n  tau: 1.0\n  src_info: LR_TPE\n\ntask:\n  model:\n    class: TRAModel\n    module_path: src/model.py\n    kwargs:\n      lr: 0.0002\n      n_epochs: 500\n      max_steps_per_epoch: 100\n      early_stop: 20\n      seed: 1000\n      logdir: output/test/alstm_tra_init\n      model_type: LSTM\n      model_config: *model_config\n      tra_config: *tra_config\n      lamb: 1.0\n      rho: 0.99\n      freeze_model: False\n      model_init_state: \n  dataset:\n    class: MTSDatasetH\n    module_path: src/dataset.py\n    kwargs:\n      handler:\n        class: DataHandler\n        module_path: qlib.data.dataset.handler\n        kwargs:\n          data_loader: *data_loader_config\n      segments:\n        train: [2007-10-30, 2016-05-27]\n        valid: [2016-09-26, 2018-05-29]\n        test: [2018-09-21, 2020-06-30]\n      seq_len: 60\n      horizon: 21\n      num_states: *num_states\n      batch_size: 512"
  },
  {
    "path": "examples/benchmarks/TRA/configs/config_transformer.yaml",
    "content": "qlib_init:\n  provider_uri: \"~/.qlib/qlib_data/cn_data\"\n  region: cn\n\ndata_loader_config: &data_loader_config\n  class: StaticDataLoader\n  module_path: qlib.data.dataset.loader\n  kwargs:\n    config:\n      feature: data/feature.pkl\n      label: data/label.pkl\n\nmodel_config: &model_config\n  input_size: 16\n  hidden_size: 64\n  num_layers: 2\n  num_heads: 4\n  use_attn: False\n  dropout: 0.1\n\nnum_states: &num_states 1\n\ntra_config: &tra_config\n  num_states: *num_states\n  hidden_size: 16\n  tau: 1.0\n  src_info: LR_TPE\n\ntask:\n  model:\n    class: TRAModel\n    module_path: src/model.py\n    kwargs:\n      lr: 0.0002\n      n_epochs: 500\n      max_steps_per_epoch: 100\n      early_stop: 20\n      seed: 1000\n      logdir: output/test/transformer\n      model_type: Transformer\n      model_config: *model_config\n      tra_config: *tra_config\n      lamb: 1.0\n      rho: 0.99\n      freeze_model: False\n      model_init_state: \n  dataset:\n    class: MTSDatasetH\n    module_path: src/dataset.py\n    kwargs:\n      handler:\n        class: DataHandler\n        module_path: qlib.data.dataset.handler\n        kwargs:\n          data_loader: *data_loader_config\n      segments:\n        train: [2007-10-30, 2016-05-27]\n        valid: [2016-09-26, 2018-05-29]\n        test: [2018-09-21, 2020-06-30]\n      seq_len: 60\n      horizon: 21\n      num_states: *num_states\n      batch_size: 1024"
  },
  {
    "path": "examples/benchmarks/TRA/configs/config_transformer_tra.yaml",
    "content": "qlib_init:\n  provider_uri: \"~/.qlib/qlib_data/cn_data\"\n  region: cn\n\ndata_loader_config: &data_loader_config\n  class: StaticDataLoader\n  module_path: qlib.data.dataset.loader\n  kwargs:\n    config:\n      feature: data/feature.pkl\n      label: data/label.pkl\n\nmodel_config: &model_config\n  input_size: 16\n  hidden_size: 64\n  num_layers: 2\n  num_heads: 4\n  use_attn: False\n  dropout: 0.1\n\nnum_states: &num_states 3\n\ntra_config: &tra_config\n  num_states: *num_states\n  hidden_size: 16\n  tau: 1.0\n  src_info: LR_TPE\n\ntask:\n  model:\n    class: TRAModel\n    module_path: src/model.py\n    kwargs:\n      lr: 0.0005\n      n_epochs: 500\n      max_steps_per_epoch: 100\n      early_stop: 20\n      seed: 1000\n      logdir: output/test/transformer_tra\n      model_type: Transformer\n      model_config: *model_config\n      tra_config: *tra_config\n      lamb: 1.0\n      rho: 0.99\n      freeze_model: True\n      model_init_state: output/test/transformer_tra_init/model.bin\n  dataset:\n    class: MTSDatasetH\n    module_path: src/dataset.py\n    kwargs:\n      handler:\n        class: DataHandler\n        module_path: qlib.data.dataset.handler\n        kwargs:\n          data_loader: *data_loader_config\n      segments:\n        train: [2007-10-30, 2016-05-27]\n        valid: [2016-09-26, 2018-05-29]\n        test: [2018-09-21, 2020-06-30]\n      seq_len: 60\n      horizon: 21\n      num_states: *num_states\n      batch_size: 512"
  },
  {
    "path": "examples/benchmarks/TRA/configs/config_transformer_tra_init.yaml",
    "content": "qlib_init:\n  provider_uri: \"~/.qlib/qlib_data/cn_data\"\n  region: cn\n\ndata_loader_config: &data_loader_config\n  class: StaticDataLoader\n  module_path: qlib.data.dataset.loader\n  kwargs:\n    config:\n      feature: data/feature.pkl\n      label: data/label.pkl\n\nmodel_config: &model_config\n  input_size: 16\n  hidden_size: 64\n  num_layers: 2\n  num_heads: 4\n  use_attn: False\n  dropout: 0.1\n\nnum_states: &num_states 3\n\ntra_config: &tra_config\n  num_states: *num_states\n  hidden_size: 16\n  tau: 1.0\n  src_info: LR_TPE\n\ntask:\n  model:\n    class: TRAModel\n    module_path: src/model.py\n    kwargs:\n      lr: 0.0002\n      n_epochs: 500\n      max_steps_per_epoch: 100\n      early_stop: 20\n      seed: 1000\n      logdir: output/test/transformer_tra_init\n      model_type: Transformer\n      model_config: *model_config\n      tra_config: *tra_config\n      lamb: 1.0\n      rho: 0.99\n      freeze_model: False\n      model_init_state: \n  dataset:\n    class: MTSDatasetH\n    module_path: src/dataset.py\n    kwargs:\n      handler:\n        class: DataHandler\n        module_path: qlib.data.dataset.handler\n        kwargs:\n          data_loader: *data_loader_config\n      segments:\n        train: [2007-10-30, 2016-05-27]\n        valid: [2016-09-26, 2018-05-29]\n        test: [2018-09-21, 2020-06-30]\n      seq_len: 60\n      horizon: 21\n      num_states: *num_states\n      batch_size: 512"
  },
  {
    "path": "examples/benchmarks/TRA/data/README.md",
    "content": "Data Link: https://drive.google.com/drive/folders/1fMqZYSeLyrHiWmVzygeI4sw3vp5Gt8cY?usp=sharing\n"
  },
  {
    "path": "examples/benchmarks/TRA/example.py",
    "content": "import argparse\n\nimport qlib\nfrom ruamel.yaml import YAML\nfrom qlib.utils import init_instance_by_config\n\n\ndef main(seed, config_file=\"configs/config_alstm.yaml\"):\n    # load the task config\n    with open(config_file) as f:\n        yaml = YAML(typ=\"safe\", pure=True)\n        config = yaml.load(f)\n\n    # override the random seed (and optionally the log dir) from the command line\n    # seed_suffix = \"/seed1000\" if \"init\" in config_file else f\"/seed{seed}\"\n    seed_suffix = \"\"\n    config[\"task\"][\"model\"][\"kwargs\"].update(\n        {\"seed\": seed, \"logdir\": config[\"task\"][\"model\"][\"kwargs\"][\"logdir\"] + seed_suffix}\n    )\n\n    # initialize workflow\n    qlib.init(\n        provider_uri=config[\"qlib_init\"][\"provider_uri\"],\n        region=config[\"qlib_init\"][\"region\"],\n    )\n    dataset = init_instance_by_config(config[\"task\"][\"dataset\"])\n    model = init_instance_by_config(config[\"task\"][\"model\"])\n\n    # train model\n    model.fit(dataset)\n\n\nif __name__ == \"__main__\":\n    # set params from cmd\n    parser = argparse.ArgumentParser(allow_abbrev=False)\n    parser.add_argument(\"--seed\", type=int, default=1000, help=\"random seed\")\n    parser.add_argument(\"--config_file\", type=str, default=\"configs/config_alstm.yaml\", help=\"config file\")\n    args = parser.parse_args()\n    main(**vars(args))\n"
  },
  {
    "path": "examples/benchmarks/TRA/requirements.txt",
    "content": "pandas==1.1.2\nnumpy==1.21.0\nscikit_learn==0.23.2\ntorch==1.7.0\nseaborn\n"
  },
  {
    "path": "examples/benchmarks/TRA/run.sh",
    "content": "#!/bin/bash\n\n# we used random seeds (1, 1000, 2000, 3000, 4000, 5000) in our experiments\n\n# Run directly with the Qlib command `qrun`;\n# each *_tra_init config pretrains the backbone that the matching *_tra config loads via model_init_state, so run it first\nqrun configs/config_alstm.yaml\n\nqrun configs/config_transformer.yaml\n\nqrun configs/config_transformer_tra_init.yaml\nqrun configs/config_transformer_tra.yaml\n\nqrun configs/config_alstm_tra_init.yaml\nqrun configs/config_alstm_tra.yaml\n\n\n# Or set different parameters (e.g. --seed) with example.py\npython example.py --config_file configs/config_alstm.yaml\n\npython example.py --config_file configs/config_transformer.yaml\n\npython example.py --config_file configs/config_transformer_tra_init.yaml\npython example.py --config_file configs/config_transformer_tra.yaml\n\npython example.py --config_file configs/config_alstm_tra_init.yaml\npython example.py --config_file configs/config_alstm_tra.yaml\n\n\n\n"
  },
  {
    "path": "examples/benchmarks/TRA/src/dataset.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport copy\nimport torch\nimport numpy as np\nimport pandas as pd\n\nfrom qlib.data.dataset import DatasetH\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n\ndef _to_tensor(x):\n    if not isinstance(x, torch.Tensor):\n        return torch.tensor(x, dtype=torch.float, device=device)\n    return x\n\n\ndef _create_ts_slices(index, seq_len):\n    \"\"\"\n    create time series slices from pandas index\n\n    Args:\n        index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order\n        seq_len (int): sequence length\n    \"\"\"\n    assert index.is_lexsorted(), \"index should be sorted\"\n\n    # number of dates for each code\n    sample_count_by_codes = pd.Series(0, index=index).groupby(level=0, group_keys=False).size().values\n\n    # start_index for each code\n    start_index_of_codes = np.roll(np.cumsum(sample_count_by_codes), 1)\n    start_index_of_codes[0] = 0\n\n    # all the [start, stop) indices of features\n    # features between [start, stop) are used to predict the `stop - 1` label\n    slices = []\n    for cur_loc, cur_cnt in zip(start_index_of_codes, sample_count_by_codes):\n        for stop in range(1, cur_cnt + 1):\n            end = cur_loc + stop\n            # clip at the first row of the current code so a window never spans two instruments;\n            # shorter windows are zero-padded to `seq_len` in `MTSDatasetH.__iter__`\n            start = max(end - seq_len, cur_loc)\n            slices.append(slice(start, end))\n    slices = np.array(slices)\n\n    return slices\n\n\ndef _get_date_parse_fn(target):\n    \"\"\"get date parse function\n\n    This function parses date arguments into the same format as `target`.\n\n    Example:\n        _get_date_parse_fn('20120101')('2017-01-01') => '20170101'\n        _get_date_parse_fn(20120101)('2017-01-01') => 20170101\n    \"\"\"\n    if isinstance(target, pd.Timestamp):\n        _fn = lambda x: pd.Timestamp(x)  # Timestamp('2020-01-01')\n    elif isinstance(target, str) and len(target) == 8:\n        _fn = lambda x: str(x).replace(\"-\", \"\")[:8]  # '20200201'\n    elif 
isinstance(target, int):\n        _fn = lambda x: int(str(x).replace(\"-\", \"\")[:8])  # 20200201\n    else:\n        _fn = lambda x: x\n    return _fn\n\n\nclass MTSDatasetH(DatasetH):\n    \"\"\"Memory Augmented Time Series Dataset\n\n    Args:\n        handler (DataHandler): data handler\n        segments (dict): data split segments\n        seq_len (int): time series sequence length\n        horizon (int): label horizon; the memory of the last `horizon` steps is masked to avoid leakage (for TRA)\n        num_states (int): number of memory states appended to the features (for TRA)\n        batch_size (int): batch size (a negative value -n means each batch holds n days)\n        shuffle (bool): whether to shuffle the data\n        pin_memory (bool): whether to preload the data as tensors on `device`\n        drop_last (bool): whether to drop the last batch if it is smaller than batch_size\n    \"\"\"\n\n    def __init__(\n        self,\n        handler,\n        segments,\n        seq_len=60,\n        horizon=0,\n        num_states=1,\n        batch_size=-1,\n        shuffle=True,\n        pin_memory=False,\n        drop_last=False,\n        **kwargs,\n    ):\n        assert horizon > 0, \"please specify `horizon` to avoid data leakage\"\n\n        self.seq_len = seq_len\n        self.horizon = horizon\n        self.num_states = num_states\n        self.batch_size = batch_size\n        self.shuffle = shuffle\n        self.drop_last = drop_last\n        self.pin_memory = pin_memory\n        self.params = (batch_size, drop_last, shuffle)  # for train/eval switch\n\n        super().__init__(handler, segments, **kwargs)\n\n    def setup_data(self, handler_kwargs: dict = None, **kwargs):\n        super().setup_data()\n\n        # change index to <code, date>\n        # NOTE: we will use inplace sort to reduce memory use\n        df = self.handler._data\n        df.index = df.index.swaplevel()\n        df.sort_index(inplace=True)\n\n        self._data = df[\"feature\"].values.astype(\"float32\")\n        self._label = df[\"label\"].squeeze().astype(\"float32\")\n        self._index = df.index\n\n        # 
add memory to feature\n        self._data = np.c_[self._data, np.zeros((len(self._data), self.num_states), dtype=np.float32)]\n\n        # padding tensor\n        self.zeros = np.zeros((self.seq_len, self._data.shape[1]), dtype=np.float32)\n\n        # preload data as tensors on `device` (GPU if available)\n        if self.pin_memory:\n            self._data = _to_tensor(self._data)\n            self._label = _to_tensor(self._label)\n            self.zeros = _to_tensor(self.zeros)\n\n        # create batch slices\n        self.batch_slices = _create_ts_slices(self._index, self.seq_len)\n\n        # create daily slices\n        index = [slc.stop - 1 for slc in self.batch_slices]\n        act_index = self.restore_index(index)\n        daily_slices = {date: [] for date in sorted(act_index.unique(level=1))}\n        for i, (code, date) in enumerate(act_index):\n            daily_slices[date].append(self.batch_slices[i])\n        self.daily_slices = list(daily_slices.values())\n\n    def _prepare_seg(self, slc, **kwargs):\n        fn = _get_date_parse_fn(self._index[0][1])\n\n        if isinstance(slc, slice):\n            start, stop = slc.start, slc.stop\n        elif isinstance(slc, (list, tuple)):\n            start, stop = slc\n        else:\n            raise NotImplementedError(f\"This type of input is not supported: {type(slc)}\")\n        start_date = fn(start)\n        end_date = fn(stop)\n        obj = copy.copy(self)  # shallow copy\n        # NOTE: `Serializable` skips attributes starting with `_` (such as `self._data`) when copying,\n        # so we manually assign them here\n        obj._data = self._data\n        obj._label = self._label\n        obj._index = self._index\n        new_batch_slices = []\n        for batch_slc in self.batch_slices:\n            date = self._index[batch_slc.stop - 1][1]\n            if start_date <= date <= end_date:\n                new_batch_slices.append(batch_slc)\n        obj.batch_slices = np.array(new_batch_slices)\n        new_daily_slices = []\n        for daily_slc in self.daily_slices:\n            date = 
self._index[daily_slc[0].stop - 1][1]\n            if start_date <= date <= end_date:\n                new_daily_slices.append(daily_slc)\n        obj.daily_slices = new_daily_slices\n        return obj\n\n    def restore_index(self, index):\n        if isinstance(index, torch.Tensor):\n            index = index.cpu().numpy()\n        return self._index[index]\n\n    def assign_data(self, index, vals):\n        if isinstance(self._data, torch.Tensor):\n            vals = _to_tensor(vals)\n        elif isinstance(vals, torch.Tensor):\n            vals = vals.detach().cpu().numpy()\n            index = index.detach().cpu().numpy()\n        self._data[index, -self.num_states :] = vals\n\n    def clear_memory(self):\n        self._data[:, -self.num_states :] = 0\n\n    # TODO: better train/eval mode design\n    def train(self):\n        \"\"\"enable training mode\"\"\"\n        self.batch_size, self.drop_last, self.shuffle = self.params\n\n    def eval(self):\n        \"\"\"enable evaluation mode\"\"\"\n        self.batch_size = -1\n        self.drop_last = False\n        self.shuffle = False\n\n    def _get_slices(self):\n        if self.batch_size < 0:\n            slices = self.daily_slices.copy()\n            batch_size = -1 * self.batch_size\n        else:\n            slices = self.batch_slices.copy()\n            batch_size = self.batch_size\n        return slices, batch_size\n\n    def __len__(self):\n        slices, batch_size = self._get_slices()\n        if self.drop_last:\n            return len(slices) // batch_size\n        return (len(slices) + batch_size - 1) // batch_size\n\n    def __iter__(self):\n        slices, batch_size = self._get_slices()\n        if self.shuffle:\n            np.random.shuffle(slices)\n\n        for i in range(len(slices))[::batch_size]:\n            if self.drop_last and i + batch_size > len(slices):\n                break\n            # get slices for this batch\n            slices_subset = slices[i : i + batch_size]\n        
    if self.batch_size < 0:\n                slices_subset = np.concatenate(slices_subset)\n            # collect data\n            data = []\n            label = []\n            index = []\n            for slc in slices_subset:\n                _data = self._data[slc].clone() if self.pin_memory else self._data[slc].copy()\n                if len(_data) != self.seq_len:\n                    if self.pin_memory:\n                        _data = torch.cat([self.zeros[: self.seq_len - len(_data)], _data], dim=0)\n                    else:\n                        _data = np.concatenate([self.zeros[: self.seq_len - len(_data)], _data], axis=0)\n                if self.num_states > 0:\n                    _data[-self.horizon :, -self.num_states :] = 0\n                data.append(_data)\n                label.append(self._label[slc.stop - 1])\n                index.append(slc.stop - 1)\n            # stack the samples into batch tensors\n            index = torch.tensor(index, device=device)\n            if isinstance(data[0], torch.Tensor):\n                data = torch.stack(data)\n                label = torch.stack(label)\n            else:\n                data = _to_tensor(np.stack(data))\n                label = _to_tensor(np.stack(label))\n            # yield -> generator\n            yield {\"data\": data, \"label\": label, \"index\": index}\n"
  },
  {
    "path": "examples/benchmarks/TRA/src/model.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport copy\nimport math\nimport json\nimport collections\nimport numpy as np\nimport pandas as pd\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\n\nfrom tqdm import tqdm\n\nfrom qlib.utils import get_or_create_path\nfrom qlib.log import get_module_logger\nfrom qlib.model.base import Model\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n\nclass TRAModel(Model):\n    def __init__(\n        self,\n        model_config,\n        tra_config,\n        model_type=\"LSTM\",\n        lr=1e-3,\n        n_epochs=500,\n        early_stop=50,\n        smooth_steps=5,\n        max_steps_per_epoch=None,\n        freeze_model=False,\n        model_init_state=None,\n        lamb=0.0,\n        rho=0.99,\n        seed=None,\n        logdir=None,\n        eval_train=True,\n        eval_test=False,\n        avg_params=True,\n        **kwargs,\n    ):\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n\n        self.logger = get_module_logger(\"TRA\")\n        self.logger.info(\"TRA Model...\")\n\n        self.model = eval(model_type)(**model_config).to(device)\n        if model_init_state:\n            self.model.load_state_dict(torch.load(model_init_state, map_location=\"cpu\")[\"model\"])\n        if freeze_model:\n            for param in self.model.parameters():\n                param.requires_grad_(False)\n        else:\n            self.logger.info(\"# model params: %d\" % sum([p.numel() for p in self.model.parameters()]))\n\n        self.tra = TRA(self.model.output_size, **tra_config).to(device)\n        self.logger.info(\"# tra params: %d\" % sum([p.numel() for p in self.tra.parameters()]))\n\n        self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=lr)\n\n        self.model_config = model_config\n        self.tra_config = tra_config\n        
self.lr = lr\n        self.n_epochs = n_epochs\n        self.early_stop = early_stop\n        self.smooth_steps = smooth_steps\n        self.max_steps_per_epoch = max_steps_per_epoch\n        self.lamb = lamb\n        self.rho = rho\n        self.seed = seed\n        self.logdir = logdir\n        self.eval_train = eval_train\n        self.eval_test = eval_test\n        self.avg_params = avg_params\n\n        if self.tra.num_states > 1 and not self.eval_train:\n            self.logger.warning(\"`eval_train` will be ignored when using TRA\")\n\n        if self.logdir is not None:\n            if os.path.exists(self.logdir):\n                self.logger.warning(f\"logdir {self.logdir} is not empty\")\n            os.makedirs(self.logdir, exist_ok=True)\n\n        self.fitted = False\n        self.global_step = -1\n\n    def train_epoch(self, data_set):\n        self.model.train()\n        self.tra.train()\n\n        data_set.train()\n\n        max_steps = self.n_epochs\n        if self.max_steps_per_epoch is not None:\n            max_steps = min(self.max_steps_per_epoch, self.n_epochs)\n\n        count = 0\n        total_loss = 0\n        total_count = 0\n        for batch in tqdm(data_set, total=max_steps):\n            count += 1\n            if count > max_steps:\n                break\n\n            self.global_step += 1\n\n            data, label, index = batch[\"data\"], batch[\"label\"], batch[\"index\"]\n\n            feature = data[:, :, : -self.tra.num_states]\n            hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]\n\n            hidden = self.model(feature)\n            pred, all_preds, prob = self.tra(hidden, hist_loss)\n\n            loss = (pred - label).pow(2).mean()\n\n            L = (all_preds.detach() - label[:, None]).pow(2)\n            L -= L.min(dim=-1, keepdim=True).values  # normalize & ensure positive input\n\n            data_set.assign_data(index, L)  # save loss to memory\n\n            if prob is not None:\n               
 P = sinkhorn(-L, epsilon=0.01)  # sample assignment matrix\n                lamb = self.lamb * (self.rho**self.global_step)\n                reg = prob.log().mul(P).sum(dim=-1).mean()\n                loss = loss - lamb * reg\n\n            loss.backward()\n            self.optimizer.step()\n            self.optimizer.zero_grad()\n\n            total_loss += loss.item()\n            total_count += len(pred)\n\n        total_loss /= total_count\n\n        return total_loss\n\n    def test_epoch(self, data_set, return_pred=False):\n        self.model.eval()\n        self.tra.eval()\n        data_set.eval()\n\n        preds = []\n        metrics = []\n        for batch in tqdm(data_set):\n            data, label, index = batch[\"data\"], batch[\"label\"], batch[\"index\"]\n\n            feature = data[:, :, : -self.tra.num_states]\n            hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]\n\n            with torch.no_grad():\n                hidden = self.model(feature)\n                pred, all_preds, prob = self.tra(hidden, hist_loss)\n\n            L = (all_preds - label[:, None]).pow(2)\n\n            L -= L.min(dim=-1, keepdim=True).values  # normalize & ensure positive input\n\n            data_set.assign_data(index, L)  # save loss to memory\n\n            X = np.c_[\n                pred.cpu().numpy(),\n                label.cpu().numpy(),\n            ]\n            columns = [\"score\", \"label\"]\n            if prob is not None:\n                X = np.c_[X, all_preds.cpu().numpy(), prob.cpu().numpy()]\n                columns += [\"score_%d\" % d for d in range(all_preds.shape[1])] + [\n                    \"prob_%d\" % d for d in range(all_preds.shape[1])\n                ]\n\n            pred = pd.DataFrame(X, index=index.cpu().numpy(), columns=columns)\n\n            metrics.append(evaluate(pred))\n\n            if return_pred:\n                preds.append(pred)\n\n        metrics = pd.DataFrame(metrics)\n        metrics = {\n     
       \"MSE\": metrics.MSE.mean(),\n            \"MAE\": metrics.MAE.mean(),\n            \"IC\": metrics.IC.mean(),\n            \"ICIR\": metrics.IC.mean() / metrics.IC.std(),\n        }\n\n        if return_pred:\n            preds = pd.concat(preds, axis=0)\n            preds.index = data_set.restore_index(preds.index)\n            preds.index = preds.index.swaplevel()\n            preds.sort_index(inplace=True)\n\n        return metrics, preds\n\n    def fit(self, dataset, evals_result=dict()):\n        train_set, valid_set, test_set = dataset.prepare([\"train\", \"valid\", \"test\"])\n\n        best_score = -1\n        best_epoch = 0\n        stop_rounds = 0\n        best_params = {\n            \"model\": copy.deepcopy(self.model.state_dict()),\n            \"tra\": copy.deepcopy(self.tra.state_dict()),\n        }\n        params_list = {\n            \"model\": collections.deque(maxlen=self.smooth_steps),\n            \"tra\": collections.deque(maxlen=self.smooth_steps),\n        }\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n        evals_result[\"test\"] = []\n\n        # train\n        self.fitted = True\n        self.global_step = -1\n\n        if self.tra.num_states > 1:\n            self.logger.info(\"init memory...\")\n            self.test_epoch(train_set)\n\n        for epoch in range(self.n_epochs):\n            self.logger.info(\"Epoch %d:\", epoch)\n\n            self.logger.info(\"training...\")\n            self.train_epoch(train_set)\n\n            self.logger.info(\"evaluating...\")\n            # average params for inference\n            params_list[\"model\"].append(copy.deepcopy(self.model.state_dict()))\n            params_list[\"tra\"].append(copy.deepcopy(self.tra.state_dict()))\n            self.model.load_state_dict(average_params(params_list[\"model\"]))\n            self.tra.load_state_dict(average_params(params_list[\"tra\"]))\n\n            # NOTE: during evaluating, the whole memory will be 
refreshed\n            if self.tra.num_states > 1 or self.eval_train:\n                train_set.clear_memory()  # NOTE: clear the shared memory\n                train_metrics = self.test_epoch(train_set)[0]\n                evals_result[\"train\"].append(train_metrics)\n                self.logger.info(\"\\ttrain metrics: %s\" % train_metrics)\n\n            valid_metrics = self.test_epoch(valid_set)[0]\n            evals_result[\"valid\"].append(valid_metrics)\n            self.logger.info(\"\\tvalid metrics: %s\" % valid_metrics)\n\n            if self.eval_test:\n                test_metrics = self.test_epoch(test_set)[0]\n                evals_result[\"test\"].append(test_metrics)\n                self.logger.info(\"\\ttest metrics: %s\" % test_metrics)\n\n            if valid_metrics[\"IC\"] > best_score:\n                best_score = valid_metrics[\"IC\"]\n                stop_rounds = 0\n                best_epoch = epoch\n                best_params = {\n                    \"model\": copy.deepcopy(self.model.state_dict()),\n                    \"tra\": copy.deepcopy(self.tra.state_dict()),\n                }\n            else:\n                stop_rounds += 1\n                if stop_rounds >= self.early_stop:\n                    self.logger.info(\"early stop @ %s\" % epoch)\n                    break\n\n            # restore parameters\n            self.model.load_state_dict(params_list[\"model\"][-1])\n            self.tra.load_state_dict(params_list[\"tra\"][-1])\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.model.load_state_dict(best_params[\"model\"])\n        self.tra.load_state_dict(best_params[\"tra\"])\n\n        metrics, preds = self.test_epoch(test_set, return_pred=True)\n        self.logger.info(\"test metrics: %s\" % metrics)\n\n        if self.logdir:\n            self.logger.info(\"save model & pred to local directory\")\n\n            pd.concat({name: pd.DataFrame(evals_result[name]) for 
name in evals_result}, axis=1).to_csv(\n                self.logdir + \"/logs.csv\", index=False\n            )\n\n            torch.save(best_params, self.logdir + \"/model.bin\")\n\n            preds.to_pickle(self.logdir + \"/pred.pkl\")\n\n            info = {\n                \"config\": {\n                    \"model_config\": self.model_config,\n                    \"tra_config\": self.tra_config,\n                    \"lr\": self.lr,\n                    \"n_epochs\": self.n_epochs,\n                    \"early_stop\": self.early_stop,\n                    \"smooth_steps\": self.smooth_steps,\n                    \"max_steps_per_epoch\": self.max_steps_per_epoch,\n                    \"lamb\": self.lamb,\n                    \"rho\": self.rho,\n                    \"seed\": self.seed,\n                    \"logdir\": self.logdir,\n                },\n                \"best_eval_metric\": -best_score,  # NOTE: multiplied by -1 so that the metric is minimized\n                \"metric\": metrics,\n            }\n            with open(self.logdir + \"/info.json\", \"w\") as f:\n                json.dump(info, f)\n\n    def predict(self, dataset, segment=\"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        test_set = dataset.prepare(segment)\n\n        metrics, preds = self.test_epoch(test_set, return_pred=True)\n        self.logger.info(\"test metrics: %s\" % metrics)\n\n        return preds\n\n\nclass LSTM(nn.Module):\n    \"\"\"LSTM Model\n\n    Args:\n        input_size (int): input size (# features)\n        hidden_size (int): hidden size\n        num_layers (int): number of hidden layers\n        use_attn (bool): whether to use an attention layer.\n            We use concat attention as in https://github.com/fulifeng/Adv-ALSTM/\n        dropout (float): dropout rate\n        input_drop (float): input dropout for data augmentation\n        noise_level (float): add gaussian noise to input for data augmentation\n    \"\"\"\n\n    def 
__init__(\n        self,\n        input_size=16,\n        hidden_size=64,\n        num_layers=2,\n        use_attn=True,\n        dropout=0.0,\n        input_drop=0.0,\n        noise_level=0.0,\n        *args,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.use_attn = use_attn\n        self.noise_level = noise_level\n\n        self.input_drop = nn.Dropout(input_drop)\n\n        self.rnn = nn.LSTM(\n            input_size=input_size,\n            hidden_size=hidden_size,\n            num_layers=num_layers,\n            batch_first=True,\n            dropout=dropout,\n        )\n\n        if self.use_attn:\n            self.W = nn.Linear(hidden_size, hidden_size)\n            self.u = nn.Linear(hidden_size, 1, bias=False)\n            self.output_size = hidden_size * 2\n        else:\n            self.output_size = hidden_size\n\n    def forward(self, x):\n        x = self.input_drop(x)\n\n        if self.training and self.noise_level > 0:\n            noise = torch.randn_like(x).to(x)\n            x = x + noise * self.noise_level\n\n        rnn_out, _ = self.rnn(x)\n        last_out = rnn_out[:, -1]\n\n        if self.use_attn:\n            laten = self.W(rnn_out).tanh()\n            scores = self.u(laten).softmax(dim=1)\n            att_out = (rnn_out * scores).sum(dim=1).squeeze()\n            last_out = torch.cat([last_out, att_out], dim=1)\n\n        return last_out\n\n\nclass PositionalEncoding(nn.Module):\n    # reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html\n    def __init__(self, d_model, dropout=0.1, max_len=5000):\n        super(PositionalEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n        pe = torch.zeros(max_len, d_model)\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, d_model, 
2).float() * (-math.log(10000.0) / d_model))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0).transpose(0, 1)\n        self.register_buffer(\"pe\", pe)\n\n    def forward(self, x):\n        x = x + self.pe[: x.size(0), :]\n        return self.dropout(x)\n\n\nclass Transformer(nn.Module):\n    \"\"\"Transformer Model\n\n    Args:\n        input_size (int): input size (# features)\n        hidden_size (int): hidden size\n        num_layers (int): number of transformer layers\n        num_heads (int): number of heads in transformer\n        dropout (float): dropout rate\n        input_drop (float): input dropout for data augmentation\n        noise_level (float): add gaussian noise to input for data augmentation\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size=16,\n        hidden_size=64,\n        num_layers=2,\n        num_heads=2,\n        dropout=0.0,\n        input_drop=0.0,\n        noise_level=0.0,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.noise_level = noise_level\n\n        self.input_drop = nn.Dropout(input_drop)\n\n        self.input_proj = nn.Linear(input_size, hidden_size)\n\n        self.pe = PositionalEncoding(input_size, dropout)\n        layer = nn.TransformerEncoderLayer(\n            nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4\n        )\n        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)\n\n        self.output_size = hidden_size\n\n    def forward(self, x):\n        x = self.input_drop(x)\n\n        if self.training and self.noise_level > 0:\n            noise = torch.randn_like(x).to(x)\n            x = x + noise * self.noise_level\n\n        x = x.permute(1, 0, 2).contiguous()  # the first dim needs 
to be sequence\n        x = self.pe(x)\n\n        x = self.input_proj(x)\n        out = self.encoder(x)\n\n        return out[-1]\n\n\nclass TRA(nn.Module):\n    \"\"\"Temporal Routing Adaptor (TRA)\n\n    TRA takes historical prediction errors & latent representation as inputs,\n    then routes the input sample to a specific predictor for training & inference.\n\n    Args:\n        input_size (int): input size (RNN/Transformer's hidden size)\n        num_states (int): number of latent states (i.e., trading patterns)\n            If `num_states=1`, then TRA falls back to traditional methods\n        hidden_size (int): hidden size of the router\n        tau (float): gumbel softmax temperature\n    \"\"\"\n\n    def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info=\"LR_TPE\"):\n        super().__init__()\n\n        self.num_states = num_states\n        self.tau = tau\n        self.src_info = src_info\n\n        if num_states > 1:\n            self.router = nn.LSTM(\n                input_size=num_states,\n                hidden_size=hidden_size,\n                num_layers=1,\n                batch_first=True,\n            )\n            self.fc = nn.Linear(hidden_size + input_size, num_states)\n\n        self.predictors = nn.Linear(input_size, num_states)\n\n    def forward(self, hidden, hist_loss):\n        preds = self.predictors(hidden)\n\n        if self.num_states == 1:\n            return preds.squeeze(-1), preds, None\n\n        # information type\n        router_out, _ = self.router(hist_loss)\n        if \"LR\" in self.src_info:\n            latent_representation = hidden\n        else:\n            latent_representation = torch.randn(hidden.shape).to(hidden)\n        if \"TPE\" in self.src_info:\n            temporal_pred_error = router_out[:, -1]\n        else:\n            temporal_pred_error = torch.randn(router_out[:, -1].shape).to(hidden)\n\n        out = self.fc(torch.cat([temporal_pred_error, latent_representation], 
dim=-1))\n        prob = F.gumbel_softmax(out, dim=-1, tau=self.tau, hard=False)\n\n        if self.training:\n            final_pred = (preds * prob).sum(dim=-1)\n        else:\n            final_pred = preds[range(len(preds)), prob.argmax(dim=-1)]\n\n        return final_pred, preds, prob\n\n\ndef evaluate(pred):\n    pred = pred.rank(pct=True)  # transform into percentiles\n    score = pred.score\n    label = pred.label\n    diff = score - label\n    MSE = (diff**2).mean()\n    MAE = (diff.abs()).mean()\n    IC = score.corr(label)\n    return {\"MSE\": MSE, \"MAE\": MAE, \"IC\": IC}\n\n\ndef average_params(params_list):\n    assert isinstance(params_list, (tuple, list, collections.deque))\n    n = len(params_list)\n    if n == 1:\n        return params_list[0]\n    new_params = collections.OrderedDict()\n    keys = None\n    for i, params in enumerate(params_list):\n        if keys is None:\n            keys = params.keys()\n        for k, v in params.items():\n            if k not in keys:\n                raise ValueError(\"the %d-th model has different params\" % i)\n            if k not in new_params:\n                new_params[k] = v / n\n            else:\n                new_params[k] += v / n\n    return new_params\n\n\ndef shoot_infs(inp_tensor):\n    \"\"\"Replaces inf by maximum of tensor\"\"\"\n    mask_inf = torch.isinf(inp_tensor)\n    ind_inf = torch.nonzero(mask_inf, as_tuple=False)\n    if len(ind_inf) > 0:\n        for ind in ind_inf:\n            if len(ind) == 2:\n                inp_tensor[ind[0], ind[1]] = 0\n            elif len(ind) == 1:\n                inp_tensor[ind[0]] = 0\n        m = torch.max(inp_tensor)\n        for ind in ind_inf:\n            if len(ind) == 2:\n                inp_tensor[ind[0], ind[1]] = m\n            elif len(ind) == 1:\n                inp_tensor[ind[0]] = m\n    return inp_tensor\n\n\ndef sinkhorn(Q, n_iters=3, epsilon=0.01):\n    # epsilon should be adjusted according to logits value's scale\n    with 
torch.no_grad():\n        Q = shoot_infs(Q)\n        Q = torch.exp(Q / epsilon)\n        for i in range(n_iters):\n            Q /= Q.sum(dim=0, keepdim=True)\n            Q /= Q.sum(dim=1, keepdim=True)\n    return Q\n"
  },
  {
    "path": "examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml",
    "content": "qlib_init:\n  provider_uri: \"~/.qlib/qlib_data/cn_data\"\n  region: cn\n\nmarket: &market csi300\nbenchmark: &benchmark SH000300\n\ndata_handler_config: &data_handler_config\n  start_time: 2008-01-01\n  end_time: 2020-08-01\n  fit_start_time: 2008-01-01\n  fit_end_time: 2014-12-31\n  instruments: *market\n  infer_processors:\n    - class: FilterCol\n      kwargs:\n        fields_group: feature\n        col_list: [\"RESI5\", \"WVMA5\", \"RSQR5\", \"KLEN\", \"RSQR10\", \"CORR5\", \"CORD5\", \"CORR10\",\n                   \"ROC60\", \"RESI10\", \"VSTD5\", \"RSQR60\", \"CORR60\", \"WVMA60\", \"STD5\",\n                   \"RSQR20\", \"CORD60\", \"CORD10\", \"CORR20\", \"KLOW\"]\n    - class: RobustZScoreNorm\n      kwargs:\n        fields_group: feature\n        clip_outlier: true\n    - class: Fillna\n      kwargs:\n        fields_group: feature\n  learn_processors:\n    - class: CSRankNorm\n      kwargs:\n        fields_group: label\n  label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\n\nnum_states: &num_states 3\n\nmemory_mode: &memory_mode sample\n\ntra_config: &tra_config\n  num_states: *num_states\n  rnn_arch: LSTM\n  hidden_size: 32\n  num_layers: 1\n  dropout: 0.0\n  tau: 1.0\n  src_info: LR_TPE\n\nmodel_config: &model_config\n  input_size: 20\n  hidden_size: 64\n  num_layers: 2\n  rnn_arch: LSTM\n  use_attn: True\n  dropout: 0.0\n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\n\ntask:\n  model:\n    class: TRAModel\n    module_path: 
qlib.contrib.model.pytorch_tra\n    kwargs:\n      tra_config: *tra_config\n      model_config: *model_config\n      model_type: RNN\n      lr: 1e-3\n      n_epochs: 100\n      max_steps_per_epoch:\n      early_stop: 20\n      logdir: output/Alpha158\n      seed: 0\n      lamb: 1.0\n      rho: 0.99\n      alpha: 0.5\n      transport_method: router\n      memory_mode: *memory_mode\n      eval_train: False\n      eval_test: True\n      pretrain: True\n      init_state:\n      freeze_model: False\n      freeze_predictors: False\n  dataset:\n    class: MTSDatasetH\n    module_path: qlib.contrib.data.dataset\n    kwargs:\n      handler:\n        class: Alpha158\n        module_path: qlib.contrib.data.handler\n        kwargs: *data_handler_config\n      segments:\n        train: [2008-01-01, 2014-12-31]\n        valid: [2015-01-01, 2016-12-31]\n        test: [2017-01-01, 2020-08-01]\n      seq_len: 60\n      input_size:\n      num_states: *num_states\n      batch_size: 1024\n      n_samples:\n      memory_mode: *memory_mode\n      drop_last: True\n  record:\n    - class: SignalRecord\n      module_path: qlib.workflow.record_temp\n      kwargs: \n        model: <MODEL>\n        dataset: <DATASET>\n    - class: SigAnaRecord\n      module_path: qlib.workflow.record_temp\n      kwargs: \n        ana_long_short: False\n        ann_scaler: 252\n    - class: PortAnaRecord\n      module_path: qlib.workflow.record_temp\n      kwargs: \n        config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml",
    "content": "qlib_init:\n  provider_uri: \"~/.qlib/qlib_data/cn_data\"\n  region: cn\n\nmarket: &market csi300\nbenchmark: &benchmark SH000300\n\ndata_handler_config: &data_handler_config\n  start_time: 2008-01-01\n  end_time: 2020-08-01\n  fit_start_time: 2008-01-01\n  fit_end_time: 2014-12-31\n  instruments: *market\n  infer_processors:\n    - class: RobustZScoreNorm\n      kwargs:\n        fields_group: feature\n        clip_outlier: true\n    - class: Fillna\n      kwargs:\n        fields_group: feature\n  learn_processors:\n    - class: CSRankNorm\n      kwargs:\n        fields_group: label\n  label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\n\nnum_states: &num_states 3\n\nmemory_mode: &memory_mode sample\n\ntra_config: &tra_config\n  num_states: *num_states\n  rnn_arch: LSTM\n  hidden_size: 32\n  num_layers: 1\n  dropout: 0.0\n  tau: 1.0\n  src_info: LR_TPE\n\nmodel_config: &model_config\n  input_size: 158\n  hidden_size: 256\n  num_layers: 2\n  rnn_arch: LSTM\n  use_attn: True\n  dropout: 0.2\n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\n\ntask:\n  model:\n    class: TRAModel\n    module_path: qlib.contrib.model.pytorch_tra\n    kwargs:\n      tra_config: *tra_config\n      model_config: *model_config\n      model_type: RNN\n      lr: 1e-3\n      n_epochs: 100\n      max_steps_per_epoch:\n      early_stop: 20\n      logdir: output/Alpha158_full\n      seed: 0\n      lamb: 1.0\n      rho: 0.99\n      alpha: 0.5\n      transport_method: router\n      
memory_mode: *memory_mode\n      eval_train: False\n      eval_test: True\n      pretrain: True\n      init_state:\n      freeze_model: False\n      freeze_predictors: False\n  dataset:\n    class: MTSDatasetH\n    module_path: qlib.contrib.data.dataset\n    kwargs:\n      handler:\n        class: Alpha158\n        module_path: qlib.contrib.data.handler\n        kwargs: *data_handler_config\n      segments:\n        train: [2008-01-01, 2014-12-31]\n        valid: [2015-01-01, 2016-12-31]\n        test: [2017-01-01, 2020-08-01]\n      seq_len: 60\n      input_size:\n      num_states: *num_states\n      batch_size: 1024\n      n_samples:\n      memory_mode: *memory_mode\n      drop_last: True\n  record:\n    - class: SignalRecord\n      module_path: qlib.workflow.record_temp\n      kwargs: \n        model: <MODEL>\n        dataset: <DATASET>\n    - class: SigAnaRecord\n      module_path: qlib.workflow.record_temp\n      kwargs: \n        ana_long_short: False\n        ann_scaler: 252\n    - class: PortAnaRecord\n      module_path: qlib.workflow.record_temp\n      kwargs:\n        config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml",
    "content": "qlib_init:\n  provider_uri: \"~/.qlib/qlib_data/cn_data\"\n  region: cn\n\nmarket: &market csi300\nbenchmark: &benchmark SH000300\n\ndata_handler_config: &data_handler_config\n  start_time: 2008-01-01\n  end_time: 2020-08-01\n  fit_start_time: 2008-01-01\n  fit_end_time: 2014-12-31\n  instruments: *market\n  infer_processors:\n    - class: RobustZScoreNorm\n      kwargs:\n        fields_group: feature\n        clip_outlier: true\n    - class: Fillna\n      kwargs:\n        fields_group: feature\n  learn_processors:\n    - class: CSRankNorm\n      kwargs:\n        fields_group: label\n  label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\n\nnum_states: &num_states 3\n\nmemory_mode: &memory_mode sample\n\ntra_config: &tra_config\n  num_states: *num_states\n  rnn_arch: LSTM\n  hidden_size: 32\n  num_layers: 1\n  dropout: 0.0\n  tau: 1.0\n  src_info: LR_TPE\n\nmodel_config: &model_config\n  input_size: 6\n  hidden_size: 64\n  num_layers: 2\n  rnn_arch: LSTM\n  use_attn: True\n  dropout: 0.0\n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\n\ntask:\n  model:\n    class: TRAModel\n    module_path: qlib.contrib.model.pytorch_tra\n    kwargs:\n      tra_config: *tra_config\n      model_config: *model_config\n      model_type: RNN\n      lr: 1e-3\n      n_epochs: 100\n      max_steps_per_epoch:\n      early_stop: 20\n      logdir: output/Alpha360\n      seed: 0\n      lamb: 1.0\n      rho: 0.99\n      alpha: 0.5\n      transport_method: router\n      memory_mode: 
*memory_mode\n      eval_train: False\n      eval_test: True\n      pretrain: True\n      init_state:\n      freeze_model: False\n      freeze_predictors: False\n  dataset:\n    class: MTSDatasetH\n    module_path: qlib.contrib.data.dataset\n    kwargs:\n      handler:\n        class: Alpha360\n        module_path: qlib.contrib.data.handler\n        kwargs: *data_handler_config\n      segments:\n        train: [2008-01-01, 2014-12-31]\n        valid: [2015-01-01, 2016-12-31]\n        test: [2017-01-01, 2020-08-01]\n      seq_len: 60\n      input_size: 6\n      num_states: *num_states\n      batch_size: 1024\n      n_samples:\n      memory_mode: *memory_mode\n      drop_last: True\n  record:\n    - class: SignalRecord\n      module_path: qlib.workflow.record_temp\n      kwargs: \n        model: <MODEL>\n        dataset: <DATASET>\n    - class: SigAnaRecord\n      module_path: qlib.workflow.record_temp\n      kwargs: \n        ana_long_short: False\n        ann_scaler: 252\n    - class: PortAnaRecord\n      module_path: qlib.workflow.record_temp\n      kwargs:\n        config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/TabNet/README.md",
    "content": "# TabNet\n* Code: [https://github.com/dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet)\n* Paper: [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/pdf/1908.07442.pdf).\n"
  },
  {
    "path": "examples/benchmarks/TabNet/requirements.txt",
    "content": "pandas==1.1.2\nnumpy==1.21.0\nscikit_learn==0.23.2\ntorch==1.7.0"
  },
  {
    "path": "examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: TabnetModel\n        module_path: qlib.contrib.model.pytorch_tabnet\n        kwargs:\n            d_feat: 158\n            pretrain: True\n            seed: 993\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                pretrain: [2008-01-01, 2014-12-31]\n                pretrain_validation: [2015-01-01, 2016-12-31]\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n             
   test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: TabnetModel\n        module_path: qlib.contrib.model.pytorch_tabnet\n        kwargs:\n            d_feat: 360\n            pretrain: True\n            seed: 993\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                pretrain: [2008-01-01, 2014-12-31]\n                pretrain_validation: [2015-01-01, 2016-12-31]\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n             
   test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/Transformer/README.md",
    "content": "# Transformer\n* Code: [https://github.com/tensorflow/tensor2tensor](https://github.com/tensorflow/tensor2tensor)\n* Paper: [Attention is All you Need](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf).\n"
  },
  {
    "path": "examples/benchmarks/Transformer/requirements.txt",
    "content": "numpy==1.21.0\r\npandas==1.1.2\r\ntorch==1.2.0"
  },
  {
    "path": "examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: FilterCol\n          kwargs:\n              fields_group: feature\n              col_list: [\"RESI5\", \"WVMA5\", \"RSQR5\", \"KLEN\", \"RSQR10\", \"CORR5\", \"CORD5\", \"CORR10\", \n                            \"ROC60\", \"RESI10\", \"VSTD5\", \"RSQR60\", \"CORR60\", \"WVMA60\", \"STD5\", \n                            \"RSQR20\", \"CORD60\", \"CORD10\", \"CORR20\", \"KLOW\"\n                        ]\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"] \n\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: TransformerModel\n        module_path: qlib.contrib.model.pytorch_transformer_ts\n        kwargs:\n            seed: 0\n            n_jobs: 20\n    dataset:\n        class: TSDatasetH\n        module_path: qlib.data.dataset\n        
kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n            step_len: 20\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: TransformerModel\n        module_path: qlib.contrib.model.pytorch_transformer\n        kwargs:\n            d_feat: 6\n            seed: 0\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n       
   kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/XGBoost/README.md",
    "content": "# XGBoost\n* Code: [https://github.com/dmlc/xgboost](https://github.com/dmlc/xgboost)\n* Paper: XGBoost: A Scalable Tree Boosting System. [https://dl.acm.org/doi/pdf/10.1145/2939672.2939785](https://dl.acm.org/doi/pdf/10.1145/2939672.2939785)."
  },
  {
    "path": "examples/benchmarks/XGBoost/requirements.txt",
    "content": "numpy==1.21.0\npandas==1.1.2\nxgboost==1.2.1"
  },
  {
    "path": "examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: XGBModel\n        module_path: qlib.contrib.model.xgboost\n        kwargs:\n            eval_metric: rmse\n            colsample_bytree: 0.8879\n            eta: 0.0421\n            max_depth: 8\n            n_estimators: 647\n            subsample: 0.8789\n            nthread: 20\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n        
  module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors: []\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: XGBModel\n        module_path: qlib.contrib.model.xgboost\n        kwargs:\n            eval_metric: rmse\n            colsample_bytree: 0.8879\n            eta: 0.0421\n            max_depth: 8\n            n_estimators: 647\n            subsample: 0.8789\n            nthread: 20\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha360\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            
dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks_dynamic/DDG-DA/Makefile",
    "content": ".PHONY: clean\n\nclean:\n\t-rm -r *.pkl mlruns || true\n"
  },
  {
    "path": "examples/benchmarks_dynamic/DDG-DA/README.md",
    "content": "# Introduction\nThis is the implementation of `DDG-DA` based on the `Meta Controller` component provided by `Qlib`.\n\nPlease refer to the paper for more details: *DDG-DA: Data Distribution Generation for Predictable Concept Drift Adaptation* [[arXiv](https://arxiv.org/abs/2201.04038)]\n\n\n# Background\nIn many real-world scenarios, we often deal with streaming data that is sequentially collected over time. Due to the non-stationary nature of the environment, the streaming data distribution may change in unpredictable ways, which is known as concept drift. To handle concept drift, previous methods first detect when/where the concept drift happens and then adapt models to fit the distribution of the latest data. However, there are still many cases where some underlying factors of environment evolution are predictable, making it possible to model the future concept drift trend of the streaming data, but such cases have not been fully explored in previous work.\n\nTherefore, we propose a novel method, `DDG-DA`, that can effectively forecast the evolution of data distribution and improve the performance of models. Specifically, we first train a predictor to estimate the future data distribution, then leverage it to generate training samples, and finally train models on the generated data.\n\n# Dataset\nThe data used in the paper are private, so we conduct experiments on Qlib's public dataset.\nThough the dataset is different, the conclusion remains the same. By applying `DDG-DA`, users can see rising trends in the test phase in both the proxy models' ICs and the performance of the forecasting models.\n\n# Run the Code\nUsers can try `DDG-DA` by running the following command:\n```bash\n    python workflow.py run\n```\n\nThe default forecasting model is `Linear`. Users can choose other forecasting models by changing the `forecast_model` parameter when `DDG-DA` is initialized. 
For example, users can try the `LightGBM` forecasting model by running the following command:\n```bash\n    python workflow.py --conf_path=../baseline/workflow_config_lightgbm_Alpha158.yaml run\n```\n\n# Results\nThe results of related methods on Qlib's public dataset can be found [here](../)\n\n# Requirements\nHere are the minimum hardware requirements to run DDG-DA's ``workflow.py``.\n* Memory: 45G\n* Disk: 4G\n\nA CPU-only PyTorch installation is enough for this example.\n"
  },
  {
    "path": "examples/benchmarks_dynamic/DDG-DA/requirements.txt",
    "content": "torch==1.10.0 \n"
  },
  {
    "path": "examples/benchmarks_dynamic/DDG-DA/vis_data.py",
    "content": "import numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\nfrom qlib.utils.pickle_utils import restricted_pickle_load\n\nsns.set(color_codes=True)\nplt.rcParams[\"font.sans-serif\"] = \"SimHei\"\nplt.rcParams[\"axes.unicode_minus\"] = False\nfrom tqdm.auto import tqdm\n\n# tqdm.pandas()  # for progress_apply\n# %matplotlib inline\n# %load_ext autoreload\n\n\n# # Meta Input\n\n# +\nwith open(\"./internal_data_s20.pkl\", \"rb\") as f:\n    data = restricted_pickle_load(f)\n\ndata.data_ic_df.columns.names = [\"start_date\", \"end_date\"]\n\ndata_sim = data.data_ic_df.droplevel(axis=1, level=\"end_date\")\n\ndata_sim.index.name = \"test datetime\"\n# -\n\nplt.figure(figsize=(40, 20))\nsns.heatmap(data_sim)\n\nplt.figure(figsize=(40, 20))\nsns.heatmap(data_sim.rolling(20).mean())\n\n# # Meta Model\n\nfrom qlib import auto_init\n\nauto_init()\nfrom qlib.workflow import R\n\nexp = R.get_exp(experiment_name=\"DDG-DA\")\nmeta_rec = exp.list_recorders(rtype=\"list\", max_results=1)[0]\nmeta_m = meta_rec.load_object(\"model\")\n\npd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].plot()\n\npd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].rolling(5).mean().plot()\n\n# # Meta Output\n\n# +\nwith open(\"./tasks_s20.pkl\", \"rb\") as f:\n    tasks = restricted_pickle_load(f)\n\ntask_df = {}\nfor t in tasks:\n    test_seg = t[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"]\n    if None not in test_seg:\n        # The last rolling is skipped.\n        task_df[test_seg] = t[\"reweighter\"].time_weight\ntask_df = pd.concat(task_df)\n\ntask_df.index.names = [\"OS_start\", \"OS_end\", \"IS_start\", \"IS_end\"]\ntask_df = task_df.droplevel([\"OS_end\", \"IS_end\"])\ntask_df = task_df.unstack(\"OS_start\")\n# -\n\nplt.figure(figsize=(40, 20))\nsns.heatmap(task_df.T)\n\nplt.figure(figsize=(40, 20))\nsns.heatmap(task_df.rolling(10).mean().T)\n\n# # Sub Models\n#\n# NOTE:\n# - this section assumes that the 
model is a Linear model!\n# - Other models do not support this analysis\n\nexp = R.get_exp(experiment_name=\"rolling_ds\")\n\n\ndef show_linear_weight(exp):\n    coef_df = {}\n    for r in exp.list_recorders(\"list\"):\n        t = r.load_object(\"task\")\n        if None in t[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"]:\n            continue\n        m = r.load_object(\"params.pkl\")\n        coef_df[t[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"]] = pd.Series(m.coef_)\n\n    coef_df = pd.concat(coef_df)\n\n    coef_df.index.names = [\"test_start\", \"test_end\", \"coef_idx\"]\n\n    coef_df = coef_df.droplevel(\"test_end\").unstack(\"coef_idx\").T\n\n    plt.figure(figsize=(40, 20))\n    sns.heatmap(coef_df)\n    plt.show()\n\n\nshow_linear_weight(R.get_exp(experiment_name=\"rolling_ds\"))\n\nshow_linear_weight(R.get_exp(experiment_name=\"rolling_models\"))\n"
  },
  {
    "path": "examples/benchmarks_dynamic/DDG-DA/workflow.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport os\nfrom pathlib import Path\nfrom typing import Union\n\nimport fire\n\nfrom qlib import auto_init\nfrom qlib.contrib.rolling.ddgda import DDGDA\nfrom qlib.tests.data import GetData\n\nDIRNAME = Path(__file__).absolute().resolve().parent\nBENCH_DIR = DIRNAME.parent / \"baseline\"\n\n\nclass DDGDABench(DDGDA):\n    # The configs referenced in README.md\n    CONF_LIST = [\n        BENCH_DIR / \"workflow_config_linear_Alpha158.yaml\",\n        BENCH_DIR / \"workflow_config_lightgbm_Alpha158.yaml\",\n    ]\n\n    DEFAULT_CONF = CONF_LIST[0]  # Linear by default for efficiency\n\n    def __init__(self, conf_path: Union[str, Path] = DEFAULT_CONF, horizon=20, **kwargs) -> None:\n        # Convert to Path for compatibility with older code that passes a str\n        conf_path = Path(conf_path)\n        super().__init__(conf_path=conf_path, horizon=horizon, working_dir=DIRNAME, **kwargs)\n\n        for f in self.CONF_LIST:\n            if conf_path.samefile(f):\n                break\n        else:\n            self.logger.warning(\"Model type is not in the benchmark!\")\n\n\nif __name__ == \"__main__\":\n    kwargs = {}\n    if os.environ.get(\"PROVIDER_URI\", \"\") == \"\":\n        GetData().qlib_data(exists_skip=True)\n    else:\n        kwargs[\"provider_uri\"] = os.environ[\"PROVIDER_URI\"]\n    auto_init(**kwargs)\n    fire.Fire(DDGDABench)\n"
  },
  {
    "path": "examples/benchmarks_dynamic/README.md",
    "content": "# Introduction\nDue to the non-stationary nature of the financial market, the data distribution may change across periods, which causes the performance of models built on training data to decay on future test data.\nAdapting forecasting models/strategies to market dynamics is therefore very important for their performance.\n\nThe table below shows the performance of different solutions on different forecasting models.\n\n## Alpha158 Dataset\nHere is the [crowd-sourced version of qlib data](data_collector/crowd_source/README.md): https://github.com/chenditc/investment_data/releases\n```bash\nwget https://github.com/chenditc/investment_data/releases/latest/download/qlib_bin.tar.gz\nmkdir -p ~/.qlib/qlib_data/cn_data\ntar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2\nrm -f qlib_bin.tar.gz\n```\n\n| Model Name       | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |\n|------------------|---------|------|------|---------|-----------|-------------------|-------------------|--------------|\n| RR[Linear]       |Alpha158 |0.0945|0.5989|0.1069   |0.6495     |0.0857             |1.3682             |-0.0986       |\n| DDG-DA[Linear]   |Alpha158 |0.0983|0.6157|0.1108   |0.6646     |0.0764             |1.1904             |-0.0769       |\n| RR[LightGBM]     |Alpha158 |0.0816|0.5887|0.0912   |0.6263     |0.0771             |1.3196             |-0.0909       |\n| DDG-DA[LightGBM] |Alpha158 |0.0878|0.6185|0.0975   |0.6524     |0.1261             |2.0096             |-0.0744       |\n\n- The label horizon of the `Alpha158` dataset is set to 20.\n- The rolling time intervals are set to 20 trading days.\n- The test rolling periods are from January 2017 to August 2020.\n- The results are based on the crowd-sourced version. 
The Yahoo version of qlib data does not contain `VWAP`, so all related factors are missing and filled with 0. This leads to a rank-deficient matrix (one that does not have full rank), which makes the lower-level optimization of DDG-DA unsolvable.\n"
  },
  {
    "path": "examples/benchmarks_dynamic/baseline/README.md",
    "content": "# Introduction\n\nThis is a framework for periodically Rolling Retraining (RR) forecasting models. RR adapts to market dynamics by periodically retraining on the latest data.\n\n## Run the Code\nUsers can try RR by running the following command:\n```bash\n    python rolling_benchmark.py run\n```\n\nThe default forecasting model is `Linear`. Users can choose other forecasting models by changing the `conf_path` parameter.\nFor example, users can try the `LightGBM` forecasting model by running the following command:\n```bash\n    python rolling_benchmark.py --conf_path=workflow_config_lightgbm_Alpha158.yaml run\n```\n"
  },
  {
    "path": "examples/benchmarks_dynamic/baseline/rolling_benchmark.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport os\nfrom pathlib import Path\nfrom typing import Union\n\nimport fire\n\nfrom qlib import auto_init\nfrom qlib.contrib.rolling.base import Rolling\nfrom qlib.tests.data import GetData\n\nDIRNAME = Path(__file__).absolute().resolve().parent\n\n\nclass RollingBenchmark(Rolling):\n    # The configs referenced in README.md\n    CONF_LIST = [DIRNAME / \"workflow_config_linear_Alpha158.yaml\", DIRNAME / \"workflow_config_lightgbm_Alpha158.yaml\"]\n\n    DEFAULT_CONF = CONF_LIST[0]\n\n    def __init__(self, conf_path: Union[str, Path] = DEFAULT_CONF, horizon=20, **kwargs) -> None:\n        # Convert to Path for compatibility with older code that passes a str\n        conf_path = Path(conf_path)\n        super().__init__(conf_path=conf_path, horizon=horizon, **kwargs)\n\n        for f in self.CONF_LIST:\n            if conf_path.samefile(f):\n                break\n        else:\n            self.logger.warning(\"Model type is not in the benchmark!\")\n\n\nif __name__ == \"__main__\":\n    kwargs = {}\n    if os.environ.get(\"PROVIDER_URI\", \"\") == \"\":\n        GetData().qlib_data(exists_skip=True)\n    else:\n        kwargs[\"provider_uri\"] = os.environ[\"PROVIDER_URI\"]\n    auto_init(**kwargs)\n    fire.Fire(RollingBenchmark)\n"
  },
  {
    "path": "examples/benchmarks_dynamic/baseline/workflow_config_lightgbm_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LGBModel\n        module_path: qlib.contrib.model.gbdt\n        kwargs:\n            loss: mse\n            colsample_bytree: 0.8879\n            learning_rate: 0.2\n            subsample: 0.8789\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            max_depth: 8\n            num_leaves: 210\n            num_threads: 20\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: False\n     
       ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/benchmarks_dynamic/baseline/workflow_config_linear_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\n    infer_processors:\n        - class: RobustZScoreNorm\n          kwargs:\n              fields_group: feature\n              clip_outlier: true\n        - class: Fillna\n          kwargs:\n              fields_group: feature\n    learn_processors:\n        - class: DropnaLabel\n        - class: CSRankNorm\n          kwargs:\n              fields_group: label\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: TopkDropoutStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            signal: <PRED>\n            topk: 50\n            n_drop: 5\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LinearModel\n        module_path: qlib.contrib.model.linear\n        kwargs:\n            estimator: ridge\n            alpha: 0.05\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record: \n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            model: <MODEL>\n            dataset: 
<DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            ana_long_short: True\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs: \n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/data_demo/README.md",
    "content": "# Introduction\nThe examples in this folder demonstrate some common usages of Qlib's data-related modules.\n"
  },
  {
    "path": "examples/data_demo/data_cache_demo.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nThe motivation of this demo\n- To show that Qlib's data modules are serializable: users can dump processed data to disk to avoid duplicated data preprocessing\n\"\"\"\n\nfrom copy import deepcopy\nfrom pathlib import Path\nimport pickle\nfrom pprint import pprint\nfrom ruamel.yaml import YAML\nimport subprocess\nfrom qlib.log import TimeInspector\n\nfrom qlib import init\nfrom qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.utils import init_instance_by_config\n\n# Use a path relative to this file so the demo works from any working directory\nDIRNAME = Path(__file__).absolute().resolve().parent\n\nif __name__ == \"__main__\":\n    init()\n\n    config_path = DIRNAME.parent / \"benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml\"\n\n    # 1) show original time\n    with TimeInspector.logt(\"The original time without handler cache:\"):\n        subprocess.run(f\"qrun {config_path}\", shell=True)\n\n    # 2) dump handler\n    yaml = YAML(typ=\"safe\", pure=True)\n    task_config = yaml.load(config_path.open())\n    hd_conf = task_config[\"task\"][\"dataset\"][\"kwargs\"][\"handler\"]\n    pprint(hd_conf)\n    hd: DataHandlerLP = init_instance_by_config(hd_conf)\n    hd_path = DIRNAME / \"handler.pkl\"\n    hd.to_pickle(hd_path, dump_all=True)\n\n    # 3) create new task with handler cache\n    new_task_config = deepcopy(task_config)\n    new_task_config[\"task\"][\"dataset\"][\"kwargs\"][\"handler\"] = f\"file://{hd_path}\"\n    new_task_config[\"sys\"] = {\"path\": [str(config_path.parent.resolve())]}\n    new_task_path = DIRNAME / \"new_task.yaml\"\n    print(\"The location of the new task\", new_task_path)\n\n    # save new task (ruamel's YAML object provides `dump`, not PyYAML's `safe_dump`)\n    with new_task_path.open(\"w\") as f:\n        yaml.dump(new_task_config, f)\n\n    # 4) train model with new task\n    with TimeInspector.logt(\"The time for task with handler cache:\"):\n        subprocess.run(f\"qrun 
{new_task_path}\", shell=True)\n"
  },
  {
    "path": "examples/data_demo/data_mem_resuse_demo.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nThe motivation of this demo\n- To show that processed data (a `DataHandlerLP` instance) can be kept in memory and reused across tasks to avoid duplicated data preprocessing\n\"\"\"\n\nfrom copy import deepcopy\nfrom pathlib import Path\nimport pickle\nfrom pprint import pprint\nfrom ruamel.yaml import YAML\nimport subprocess\n\nfrom qlib import init\nfrom qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.log import TimeInspector\nfrom qlib.model.trainer import task_train\nfrom qlib.utils import init_instance_by_config\n\n# Use a path relative to this file so the demo works from any working directory\nDIRNAME = Path(__file__).absolute().resolve().parent\n\nif __name__ == \"__main__\":\n    init()\n\n    repeat = 2\n    exp_name = \"data_mem_reuse_demo\"\n\n    config_path = DIRNAME.parent / \"benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml\"\n    yaml = YAML(typ=\"safe\", pure=True)\n    task_config = yaml.load(config_path.open())\n\n    # 1) without using processed data in memory\n    with TimeInspector.logt(\"The original time without reusing processed data in memory:\"):\n        for i in range(repeat):\n            task_train(task_config[\"task\"], experiment_name=exp_name)\n\n    # 2) prepare processed data in memory.\n    hd_conf = task_config[\"task\"][\"dataset\"][\"kwargs\"][\"handler\"]\n    pprint(hd_conf)\n    hd: DataHandlerLP = init_instance_by_config(hd_conf)\n\n    # 3) with reusing processed data in memory\n    new_task = deepcopy(task_config[\"task\"])\n    new_task[\"dataset\"][\"kwargs\"][\"handler\"] = hd\n    print(new_task)\n\n    with TimeInspector.logt(\"The time with reusing processed data in memory:\"):\n        # this saves the time spent reloading and processing data from disk (in `DataHandlerLP`)\n        # The backtest phase still takes a lot of time\n        for i in range(repeat):\n            task_train(new_task, experiment_name=exp_name)\n\n    # 4) User can change 
other parts of the task while still reusing the processed data in memory (handler)\n    new_task = deepcopy(task_config[\"task\"])\n    new_task[\"dataset\"][\"kwargs\"][\"handler\"] = hd\n    new_task[\"dataset\"][\"kwargs\"][\"segments\"][\"train\"] = (\"20100101\", \"20131231\")\n    with TimeInspector.logt(\"The time with reusing processed data in memory:\"):\n        task_train(new_task, experiment_name=exp_name)\n"
  },
  {
    "path": "examples/highfreq/README.md",
    "content": "# Introduction\nThis folder contains two examples:\n- A high-frequency dataset example\n- An example of predicting the price trend in high-frequency data\n\n## High-Frequency Dataset\n\nThis dataset is an example for RL-based high-frequency trading.\n\n### Get High-Frequency Data\n\nGet high-frequency data by running the following command:\n```bash\n    python workflow.py get_data\n```\n\n### Dump & Reload & Reinitialize the Dataset\n\nThe High-Frequency Dataset is implemented as `qlib.data.dataset.DatasetH` in `workflow.py`. `DatasetH` is a subclass of [`qlib.utils.serial.Serializable`](https://qlib.readthedocs.io/en/latest/advanced/serial.html), whose state can be dumped to or loaded from disk in `pickle` format.\n\n### About Reinitialization\n\nAfter reloading a `Dataset` from disk, `Qlib` also supports reinitializing it. This means users can reset some states of the `Dataset` or `DataHandler`, such as `instruments`, `start_time`, `end_time` and `segments`, and generate new data according to the new states.\n\nAn example is given in `workflow.py`; users can run it as follows.\n\n### Run the Code\n\nRun the example with the following command:\n```bash\n    python workflow.py dump_and_load_dataset\n```\n\n## Benchmark Performance (predicting the price trend in high-frequency data)\n\nHere are the results of models for predicting the price trend in high-frequency data. We will keep updating the benchmark models in the future.\n\n| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long Precision | Short Precision | Long-Short Average Return | Long-Short Average Sharpe |\n|---|---|---|---|---|---|---|---|---|---|\n| LightGBM | Alpha158 | 0.0349±0.00 | 0.3805±0.00 | 0.0435±0.00 | 0.4724±0.00 | 0.5111±0.00 | 0.5428±0.00 | 0.000074±0.00 | 0.2677±0.00 |\n"
  },
  {
    "path": "examples/highfreq/highfreq_handler.py",
    "content": "from qlib.data.dataset.handler import DataHandler, DataHandlerLP\nfrom qlib.contrib.data.handler import check_transform_proc\n\n\nclass HighFreqHandler(DataHandlerLP):\n    def __init__(\n        self,\n        instruments=\"csi300\",\n        start_time=None,\n        end_time=None,\n        infer_processors=[],\n        learn_processors=[],\n        fit_start_time=None,\n        fit_end_time=None,\n        drop_raw=True,\n    ):\n        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)\n        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)\n\n        data_loader = {\n            \"class\": \"QlibDataLoader\",\n            \"kwargs\": {\n                \"config\": self.get_feature_config(),\n                \"swap_level\": False,\n                \"freq\": \"1min\",\n            },\n        }\n        super().__init__(\n            instruments=instruments,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n            infer_processors=infer_processors,\n            learn_processors=learn_processors,\n            drop_raw=drop_raw,\n        )\n\n    def get_feature_config(self):\n        fields = []\n        names = []\n\n        template_if = \"If(IsNull({1}), {0}, {1})\"\n        template_paused = \"Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})\"\n        template_fillnan = \"BFillNan(FFillNan({0}))\"\n        # Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap\n        simpson_vwap = \"($open + 2*$high + 2*$low + $close)/6\"\n\n        def get_normalized_price_feature(price_field, shift=0):\n            \"\"\"Get normalized price feature ops\"\"\"\n            if shift == 0:\n                template_norm = \"Cut({0}/Ref(DayLast({1}), 240), 240, None)\"\n            else:\n                template_norm = \"Cut(Ref({0}, \" + str(shift) + 
\")/Ref(DayLast({1}), 240), 240, None)\"\n\n            feature_ops = template_norm.format(\n                template_if.format(\n                    template_fillnan.format(template_paused.format(\"$close\")),\n                    template_paused.format(price_field),\n                ),\n                template_fillnan.format(template_paused.format(\"$close\")),\n            )\n            return feature_ops\n\n        fields += [get_normalized_price_feature(\"$open\", 0)]\n        fields += [get_normalized_price_feature(\"$high\", 0)]\n        fields += [get_normalized_price_feature(\"$low\", 0)]\n        fields += [get_normalized_price_feature(\"$close\", 0)]\n        fields += [get_normalized_price_feature(simpson_vwap, 0)]\n        names += [\"$open\", \"$high\", \"$low\", \"$close\", \"$vwap\"]\n\n        fields += [get_normalized_price_feature(\"$open\", 240)]\n        fields += [get_normalized_price_feature(\"$high\", 240)]\n        fields += [get_normalized_price_feature(\"$low\", 240)]\n        fields += [get_normalized_price_feature(\"$close\", 240)]\n        fields += [get_normalized_price_feature(simpson_vwap, 240)]\n        names += [\"$open_1\", \"$high_1\", \"$low_1\", \"$close_1\", \"$vwap_1\"]\n\n        fields += [\n            \"Cut({0}/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)\".format(\n                \"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))\".format(\n                    template_paused.format(\"$volume\"),\n                    template_paused.format(simpson_vwap),\n                    template_paused.format(\"$low\"),\n                    template_paused.format(\"$high\"),\n                )\n            )\n        ]\n        names += [\"$volume\"]\n        fields += [\n            \"Cut(Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)\".format(\n                \"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))\".format(\n                
    template_paused.format(\"$volume\"),\n                    template_paused.format(simpson_vwap),\n                    template_paused.format(\"$low\"),\n                    template_paused.format(\"$high\"),\n                )\n            )\n        ]\n        names += [\"$volume_1\"]\n\n        return fields, names\n\n\nclass HighFreqBacktestHandler(DataHandler):\n    def __init__(\n        self,\n        instruments=\"csi300\",\n        start_time=None,\n        end_time=None,\n    ):\n        data_loader = {\n            \"class\": \"QlibDataLoader\",\n            \"kwargs\": {\n                \"config\": self.get_feature_config(),\n                \"swap_level\": False,\n                \"freq\": \"1min\",\n            },\n        }\n        super().__init__(\n            instruments=instruments,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n        )\n\n    def get_feature_config(self):\n        fields = []\n        names = []\n\n        template_if = \"If(IsNull({1}), {0}, {1})\"\n        template_paused = \"Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})\"\n        template_fillnan = \"BFillNan(FFillNan({0}))\"\n        # Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap\n        simpson_vwap = \"($open + 2*$high + 2*$low + $close)/6\"\n        fields += [\n            \"Cut({0}, 240, None)\".format(template_fillnan.format(template_paused.format(\"$close\"))),\n        ]\n        names += [\"$close0\"]\n        fields += [\n            \"Cut({0}, 240, None)\".format(\n                template_if.format(\n                    template_fillnan.format(template_paused.format(\"$close\")),\n                    template_paused.format(simpson_vwap),\n                )\n            )\n        ]\n        names += [\"$vwap0\"]\n        fields += [\n            \"Cut(If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, 
{2}))), 0, {0})), 240, None)\".format(\n                template_paused.format(\"$volume\"),\n                template_paused.format(simpson_vwap),\n                template_paused.format(\"$low\"),\n                template_paused.format(\"$high\"),\n            )\n        ]\n        names += [\"$volume0\"]\n\n        return fields, names\n"
  },
  {
    "path": "examples/highfreq/highfreq_ops.py",
    "content": "import numpy as np\nimport pandas as pd\nimport importlib\nfrom qlib.data.ops import ElemOperator, PairOperator\nfrom qlib.config import C\nfrom qlib.data.cache import H\nfrom qlib.data.data import Cal\nfrom qlib.contrib.ops.high_freq import get_calendar_day\n\n\nclass DayLast(ElemOperator):\n    \"\"\"DayLast Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    feature:\n        a series in which each value equals the last value of its day\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        _calendar = get_calendar_day(freq=freq)\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return series.groupby(_calendar[series.index], group_keys=False).transform(\"last\")\n\n\nclass FFillNan(ElemOperator):\n    \"\"\"FFillNan Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    feature:\n        the feature with NaN values forward-filled\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return series.ffill()\n\n\nclass BFillNan(ElemOperator):\n    \"\"\"BFillNan Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    feature:\n        the feature with NaN values backward-filled\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return series.bfill()\n\n\nclass Date(ElemOperator):\n    \"\"\"Date Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    feature:\n        a series in which each value is the date corresponding to feature.index\n    \"\"\"\n\n    def 
_load_internal(self, instrument, start_index, end_index, freq):\n        _calendar = get_calendar_day(freq=freq)\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return pd.Series(_calendar[series.index], index=series.index)\n\n\nclass Select(PairOperator):\n    \"\"\"Select Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance, select condition\n    feature_right : Expression\n        feature instance, select value\n\n    Returns\n    ----------\n    feature:\n        value(feature_right) that meets the condition(feature_left)\n\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series_condition = self.feature_left.load(instrument, start_index, end_index, freq)\n        series_feature = self.feature_right.load(instrument, start_index, end_index, freq)\n        return series_feature.loc[series_condition]\n\n\nclass IsNull(ElemOperator):\n    \"\"\"IsNull Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    feature:\n        A series indicating whether the feature is nan\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return series.isnull()\n\n\nclass Cut(ElemOperator):\n    \"\"\"Cut Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    l : int\n        l > 0, delete the first l elements of feature (default is None, which means 0)\n    r : int\n        r < 0, delete the last -r elements of feature (default is None, which means 0)\n    Returns\n    ----------\n    feature:\n        A series with the first l and last -r elements deleted from the feature.\n        Note: It is deleted from the raw data, not the sliced data\n    \"\"\"\n\n    def __init__(self, feature, l=None, r=None):\n        
self.l = l\n        self.r = r\n        if (self.l is not None and self.l <= 0) or (self.r is not None and self.r >= 0):\n            raise ValueError(\"Cut operator l should > 0 and r should < 0\")\n\n        super(Cut, self).__init__(feature)\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return series.iloc[self.l : self.r]\n\n    def get_extended_window_size(self):\n        ll = 0 if self.l is None else self.l\n        rr = 0 if self.r is None else abs(self.r)\n        lft_etd, rght_etd = self.feature.get_extended_window_size()\n        lft_etd = lft_etd + ll\n        rght_etd = rght_etd + rr\n        return lft_etd, rght_etd\n"
  },
  {
    "path": "examples/highfreq/highfreq_processor.py",
    "content": "import numpy as np\nimport pandas as pd\nfrom qlib.constant import EPS\nfrom qlib.data.dataset.processor import Processor\nfrom qlib.data.dataset.utils import fetch_df_by_index\n\n\nclass HighFreqNorm(Processor):\n    def __init__(self, fit_start_time, fit_end_time):\n        self.fit_start_time = fit_start_time\n        self.fit_end_time = fit_end_time\n\n    def fit(self, df_features):\n        fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level=\"datetime\")\n        del df_features\n        df_values = fetch_df.values\n        names = {\n            \"price\": slice(0, 10),\n            \"volume\": slice(10, 12),\n        }\n        self.feature_med = {}\n        self.feature_std = {}\n        self.feature_vmax = {}\n        self.feature_vmin = {}\n        for name, name_val in names.items():\n            part_values = df_values[:, name_val].astype(np.float32)\n            if name == \"volume\":\n                part_values = np.log1p(part_values)\n            self.feature_med[name] = np.nanmedian(part_values)\n            part_values = part_values - self.feature_med[name]\n            self.feature_std[name] = np.nanmedian(np.absolute(part_values)) * 1.4826 + EPS\n            part_values = part_values / self.feature_std[name]\n            self.feature_vmax[name] = np.nanmax(part_values)\n            self.feature_vmin[name] = np.nanmin(part_values)\n\n    def __call__(self, df_features):\n        df_features[\"date\"] = pd.to_datetime(\n            df_features.index.get_level_values(level=\"datetime\").to_series().dt.date.values\n        )\n        df_features.set_index(\"date\", append=True, drop=True, inplace=True)\n        df_values = df_features.values\n        names = {\n            \"price\": slice(0, 10),\n            \"volume\": slice(10, 12),\n        }\n\n        for name, name_val in names.items():\n            if name == \"volume\":\n                df_values[:, name_val] = 
np.log1p(df_values[:, name_val])\n            df_values[:, name_val] -= self.feature_med[name]\n            df_values[:, name_val] /= self.feature_std[name]\n            slice0 = df_values[:, name_val] > 3.0\n            slice1 = df_values[:, name_val] > 3.5\n            slice2 = df_values[:, name_val] < -3.0\n            slice3 = df_values[:, name_val] < -3.5\n\n            df_values[:, name_val][slice0] = (\n                3.0 + (df_values[:, name_val][slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5\n            )\n            df_values[:, name_val][slice1] = 3.5\n            df_values[:, name_val][slice2] = (\n                -3.0 - (df_values[:, name_val][slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5\n            )\n            df_values[:, name_val][slice3] = -3.5\n        idx = df_features.index.droplevel(\"datetime\").drop_duplicates()\n        idx.set_names([\"instrument\", \"datetime\"], inplace=True)\n\n        # Reshape is specifically for adapting to RL high-freq executor\n        feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240)\n        feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240)\n        df_new_features = pd.DataFrame(\n            data=np.concatenate((feat, feat_1), axis=1),\n            index=idx,\n            columns=[\"FEATURE_%d\" % i for i in range(12 * 240)],\n        ).sort_index()\n        return df_new_features\n"
  },
  {
    "path": "examples/highfreq/workflow.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nimport fire\n\nimport qlib\nfrom qlib.constant import REG_CN\nfrom qlib.config import HIGH_FREQ_CONFIG\n\nfrom qlib.utils import init_instance_by_config\nfrom qlib.utils.pickle_utils import restricted_pickle_load\nfrom qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.data.ops import Operators\nfrom qlib.data.data import Cal\nfrom qlib.tests.data import GetData\n\nfrom highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut\n\n\nclass HighfreqWorkflow:\n    SPEC_CONF = {\"custom_ops\": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], \"expression_cache\": None}\n\n    MARKET = \"all\"\n\n    start_time = \"2020-09-15 00:00:00\"\n    end_time = \"2021-01-18 16:00:00\"\n    train_end_time = \"2020-11-30 16:00:00\"\n    test_start_time = \"2020-12-01 00:00:00\"\n\n    DATA_HANDLER_CONFIG0 = {\n        \"start_time\": start_time,\n        \"end_time\": end_time,\n        \"fit_start_time\": start_time,\n        \"fit_end_time\": train_end_time,\n        \"instruments\": MARKET,\n        \"infer_processors\": [{\"class\": \"HighFreqNorm\", \"module_path\": \"highfreq_processor\"}],\n    }\n    DATA_HANDLER_CONFIG1 = {\n        \"start_time\": start_time,\n        \"end_time\": end_time,\n        \"instruments\": MARKET,\n    }\n\n    task = {\n        \"dataset\": {\n            \"class\": \"DatasetH\",\n            \"module_path\": \"qlib.data.dataset\",\n            \"kwargs\": {\n                \"handler\": {\n                    \"class\": \"HighFreqHandler\",\n                    \"module_path\": \"highfreq_handler\",\n                    \"kwargs\": DATA_HANDLER_CONFIG0,\n                },\n                \"segments\": {\n                    \"train\": (start_time, train_end_time),\n                    \"test\": (\n                        test_start_time,\n                        end_time,\n                    ),\n 
               },\n            },\n        },\n        \"dataset_backtest\": {\n            \"class\": \"DatasetH\",\n            \"module_path\": \"qlib.data.dataset\",\n            \"kwargs\": {\n                \"handler\": {\n                    \"class\": \"HighFreqBacktestHandler\",\n                    \"module_path\": \"highfreq_handler\",\n                    \"kwargs\": DATA_HANDLER_CONFIG1,\n                },\n                \"segments\": {\n                    \"train\": (start_time, train_end_time),\n                    \"test\": (\n                        test_start_time,\n                        end_time,\n                    ),\n                },\n            },\n        },\n    }\n\n    def _init_qlib(self):\n        \"\"\"initialize qlib\"\"\"\n        # use cn_data_1min data\n        QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}\n        provider_uri = QLIB_INIT_CONFIG.get(\"provider_uri\")\n        GetData().qlib_data(target_dir=provider_uri, interval=\"1min\", region=REG_CN, exists_skip=True)\n        qlib.init(**QLIB_INIT_CONFIG)\n\n    def _prepare_calender_cache(self):\n        \"\"\"preload the calendar into the cache\"\"\"\n\n        # This relies on the copy-on-write behavior of forked processes on Linux to avoid recomputing the calendar in each subprocess\n        # It may speed things up, but may not help on Windows or macOS\n        Cal.calendar(freq=\"1min\")\n        get_calendar_day(freq=\"1min\")\n\n    def get_data(self):\n        \"\"\"use dataset to get high-frequency data\"\"\"\n        self._init_qlib()\n        self._prepare_calender_cache()\n\n        dataset = init_instance_by_config(self.task[\"dataset\"])\n        xtrain, xtest = dataset.prepare([\"train\", \"test\"])\n        print(xtrain, xtest)\n\n        dataset_backtest = init_instance_by_config(self.task[\"dataset_backtest\"])\n        backtest_train, backtest_test = dataset_backtest.prepare([\"train\", \"test\"])\n        print(backtest_train, 
backtest_test)\n\n        return\n\n    def dump_and_load_dataset(self):\n        \"\"\"dump and load dataset state on disk\"\"\"\n        self._init_qlib()\n        self._prepare_calender_cache()\n        dataset = init_instance_by_config(self.task[\"dataset\"])\n        dataset_backtest = init_instance_by_config(self.task[\"dataset_backtest\"])\n\n        ##=============dump dataset=============\n        dataset.to_pickle(path=\"dataset.pkl\")\n        dataset_backtest.to_pickle(path=\"dataset_backtest.pkl\")\n\n        del dataset, dataset_backtest\n        ##=============reload dataset=============\n        with open(\"dataset.pkl\", \"rb\") as file_dataset:\n            dataset = restricted_pickle_load(file_dataset)\n\n        with open(\"dataset_backtest.pkl\", \"rb\") as file_dataset_backtest:\n            dataset_backtest = restricted_pickle_load(file_dataset_backtest)\n\n        self._prepare_calender_cache()\n        ##=============reinit dataset=============\n        dataset.config(\n            handler_kwargs={\n                \"start_time\": \"2021-01-19 00:00:00\",\n                \"end_time\": \"2021-01-25 16:00:00\",\n            },\n            segments={\n                \"test\": (\n                    \"2021-01-19 00:00:00\",\n                    \"2021-01-25 16:00:00\",\n                ),\n            },\n        )\n        dataset.setup_data(\n            handler_kwargs={\n                \"init_type\": DataHandlerLP.IT_LS,\n            },\n        )\n        dataset_backtest.config(\n            handler_kwargs={\n                \"start_time\": \"2021-01-19 00:00:00\",\n                \"end_time\": \"2021-01-25 16:00:00\",\n            },\n            segments={\n                \"test\": (\n                    \"2021-01-19 00:00:00\",\n                    \"2021-01-25 16:00:00\",\n                ),\n            },\n        )\n        dataset_backtest.setup_data(handler_kwargs={})\n\n        ##=============get data=============\n        
xtest = dataset.prepare(\"test\")\n        backtest_test = dataset_backtest.prepare(\"test\")\n\n        print(xtest, backtest_test)\n        return\n\n\nif __name__ == \"__main__\":\n    fire.Fire(HighfreqWorkflow)\n"
  },
  {
    "path": "examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data_1min\"\n    region: cn\nmarket: &market 'csi300'\nstart_time: &start_time \"2020-09-15 00:00:00\"\nend_time: &end_time \"2021-01-18 16:00:00\"\ntrain_end_time: &train_end_time \"2020-11-15 16:00:00\"\nvalid_start_time: &valid_start_time \"2020-11-16 00:00:00\"\nvalid_end_time: &valid_end_time \"2020-11-30 16:00:00\"\ntest_start_time: &test_start_time \"2020-12-01 00:00:00\"\ndata_handler_config: &data_handler_config\n    start_time: *start_time\n    end_time: *end_time\n    fit_start_time: *start_time\n    fit_end_time: *train_end_time\n    instruments: *market\n    freq: '1min'\n    infer_processors:\n        - class: 'RobustZScoreNorm'\n          kwargs:\n              fields_group: 'feature'\n              clip_outlier: false\n        - class: \"Fillna\"\n          kwargs:\n              fields_group: 'feature'\n    learn_processors:\n        - class: 'DropnaLabel'\n        - class: 'CSRankNorm'\n          kwargs:\n              fields_group: 'label'\n    label: [\"Ref($close, -2) / Ref($close, -1) - 1\"]\n\ntask:\n    model:\n        class: \"HFLGBModel\"\n        module_path: \"qlib.contrib.model.highfreq_gdbt_model\"\n        kwargs:\n            objective: 'binary'\n            metric: ['binary_logloss','auc']\n            verbosity: -1\n            learning_rate: 0.01\n            max_depth: 8\n            num_leaves: 150\n            lambda_l1: 1.5\n            lambda_l2: 1\n            num_threads: 20\n    dataset:\n        class: \"DatasetH\"\n        module_path: \"qlib.data.dataset\"\n        kwargs:\n            handler:\n                class: \"Alpha158\"\n                module_path: \"qlib.contrib.data.handler\"\n                kwargs: *data_handler_config\n            segments:\n                train: [*start_time, *train_end_time]\n                valid: [*valid_start_time, *valid_end_time]\n                test: [*test_start_time, *end_time]\n    record:\n        - 
class: \"SignalRecord\"\n          module_path: \"qlib.workflow.record_temp\"\n          kwargs: {}\n        - class: \"HFSignalRecord\"\n          module_path: \"qlib.workflow.record_temp\"\n          kwargs: {}"
  },
  {
    "path": "examples/hyperparameter/LightGBM/Readme.md",
    "content": "# LightGBM hyperparameter\n\n## Alpha158\nFirst terminal\n```\noptuna create-study --study LGBM_158 --storage sqlite:///db.sqlite3\noptuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3\n```\nSecond terminal\n```\npython hyperparameter_158.py\n```\n\n## Alpha360\nFirst terminal\n```\noptuna create-study --study LGBM_360 --storage sqlite:///db.sqlite3\noptuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3\n```\nSecond terminal\n```\npython hyperparameter_360.py\n```\n"
  },
  {
    "path": "examples/hyperparameter/LightGBM/hyperparameter_158.py",
    "content": "import qlib\nimport optuna\nfrom qlib.constant import REG_CN\nfrom qlib.utils import init_instance_by_config\nfrom qlib.tests.config import CSI300_DATASET_CONFIG\nfrom qlib.tests.data import GetData\n\n\ndef objective(trial):\n    task = {\n        \"model\": {\n            \"class\": \"LGBModel\",\n            \"module_path\": \"qlib.contrib.model.gbdt\",\n            \"kwargs\": {\n                \"loss\": \"mse\",\n                \"colsample_bytree\": trial.suggest_uniform(\"colsample_bytree\", 0.5, 1),\n                \"learning_rate\": trial.suggest_uniform(\"learning_rate\", 0, 1),\n                \"subsample\": trial.suggest_uniform(\"subsample\", 0, 1),\n                \"lambda_l1\": trial.suggest_loguniform(\"lambda_l1\", 1e-8, 1e4),\n                \"lambda_l2\": trial.suggest_loguniform(\"lambda_l2\", 1e-8, 1e4),\n                \"max_depth\": 10,\n                \"num_leaves\": trial.suggest_int(\"num_leaves\", 1, 1024),\n                \"feature_fraction\": trial.suggest_uniform(\"feature_fraction\", 0.4, 1.0),\n                \"bagging_fraction\": trial.suggest_uniform(\"bagging_fraction\", 0.4, 1.0),\n                \"bagging_freq\": trial.suggest_int(\"bagging_freq\", 1, 7),\n                \"min_data_in_leaf\": trial.suggest_int(\"min_data_in_leaf\", 1, 50),\n                \"min_child_samples\": trial.suggest_int(\"min_child_samples\", 5, 100),\n            },\n        },\n    }\n    evals_result = dict()\n    model = init_instance_by_config(task[\"model\"])\n    model.fit(dataset, evals_result=evals_result)\n    return min(evals_result[\"valid\"])\n\n\nif __name__ == \"__main__\":\n    provider_uri = \"~/.qlib/qlib_data/cn_data\"\n    GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)\n    qlib.init(provider_uri=provider_uri, region=\"cn\")\n\n    dataset = init_instance_by_config(CSI300_DATASET_CONFIG)\n\n    study = optuna.Study(study_name=\"LGBM_158\", storage=\"sqlite:///db.sqlite3\")\n  
  study.optimize(objective, n_jobs=6)\n"
  },
  {
    "path": "examples/hyperparameter/LightGBM/hyperparameter_360.py",
    "content": "import qlib\nimport optuna\nfrom qlib.constant import REG_CN\nfrom qlib.utils import init_instance_by_config\nfrom qlib.tests.data import GetData\nfrom qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS\n\nDATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS)\n\n\ndef objective(trial):\n    task = {\n        \"model\": {\n            \"class\": \"LGBModel\",\n            \"module_path\": \"qlib.contrib.model.gbdt\",\n            \"kwargs\": {\n                \"loss\": \"mse\",\n                \"colsample_bytree\": trial.suggest_uniform(\"colsample_bytree\", 0.5, 1),\n                \"learning_rate\": trial.suggest_uniform(\"learning_rate\", 0, 1),\n                \"subsample\": trial.suggest_uniform(\"subsample\", 0, 1),\n                \"lambda_l1\": trial.suggest_loguniform(\"lambda_l1\", 1e-8, 1e4),\n                \"lambda_l2\": trial.suggest_loguniform(\"lambda_l2\", 1e-8, 1e4),\n                \"max_depth\": 10,\n                \"num_leaves\": trial.suggest_int(\"num_leaves\", 1, 1024),\n                \"feature_fraction\": trial.suggest_uniform(\"feature_fraction\", 0.4, 1.0),\n                \"bagging_fraction\": trial.suggest_uniform(\"bagging_fraction\", 0.4, 1.0),\n                \"bagging_freq\": trial.suggest_int(\"bagging_freq\", 1, 7),\n                \"min_data_in_leaf\": trial.suggest_int(\"min_data_in_leaf\", 1, 50),\n                \"min_child_samples\": trial.suggest_int(\"min_child_samples\", 5, 100),\n            },\n        },\n    }\n\n    evals_result = dict()\n    model = init_instance_by_config(task[\"model\"])\n    model.fit(dataset, evals_result=evals_result)\n    return min(evals_result[\"valid\"])\n\n\nif __name__ == \"__main__\":\n    provider_uri = \"~/.qlib/qlib_data/cn_data\"\n    GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)\n    qlib.init(provider_uri=provider_uri, region=REG_CN)\n\n    dataset = 
init_instance_by_config(DATASET_CONFIG)\n\n    study = optuna.Study(study_name=\"LGBM_360\", storage=\"sqlite:///db.sqlite3\")\n    study.optimize(objective, n_jobs=6)\n"
  },
  {
    "path": "examples/hyperparameter/LightGBM/requirements.txt",
    "content": "pandas==1.1.2\nnumpy==1.21.0\nlightgbm==3.1.0\noptuna==2.7.0\noptuna-dashboard==0.4.1\n"
  },
  {
    "path": "examples/model_interpreter/feature.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\n\nimport qlib\nfrom qlib.constant import REG_CN\n\nfrom qlib.utils import init_instance_by_config\nfrom qlib.tests.data import GetData\nfrom qlib.tests.config import CSI300_GBDT_TASK\n\nif __name__ == \"__main__\":\n    # use default data\n    provider_uri = \"~/.qlib/qlib_data/cn_data\"  # target_dir\n    GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)\n\n    qlib.init(provider_uri=provider_uri, region=REG_CN)\n\n    ###################################\n    # train model\n    ###################################\n    # model initialization\n    model = init_instance_by_config(CSI300_GBDT_TASK[\"model\"])\n    dataset = init_instance_by_config(CSI300_GBDT_TASK[\"dataset\"])\n    model.fit(dataset)\n\n    # get model feature importance\n    feature_importance = model.get_feature_importance()\n    print(\"feature importance:\")\n    print(feature_importance)\n"
  },
  {
    "path": "examples/model_rolling/requirements.txt",
    "content": "xgboost\n"
  },
  {
    "path": "examples/model_rolling/task_manager_rolling.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nThis example shows how a TrainerRM works based on TaskManager with rolling tasks.\nAfter training, `task_collecting` shows how to collect the rolling results.\nThanks to TaskManager, the `worker` method offers a simple way to train tasks in multiple processes.\n\"\"\"\n\nfrom pprint import pprint\n\nimport fire\nimport qlib\nfrom qlib.constant import REG_CN\nfrom qlib.workflow import R\nfrom qlib.workflow.task.gen import RollingGen, task_generator\nfrom qlib.workflow.task.manage import TaskManager, run_task\nfrom qlib.workflow.task.collect import RecorderCollector\nfrom qlib.model.ens.group import RollingGroup\nfrom qlib.model.trainer import TrainerR, TrainerRM, task_train\nfrom qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG\n\n\nclass RollingTaskExample:\n    def __init__(\n        self,\n        provider_uri=\"~/.qlib/qlib_data/cn_data\",\n        region=REG_CN,\n        task_url=\"mongodb://10.0.0.4:27017/\",\n        task_db_name=\"rolling_db\",\n        experiment_name=\"rolling_exp\",\n        task_pool=None,  # set to a task pool name (e.g. \"rolling_task\") to use TrainerRM\n        task_config=None,\n        rolling_step=550,\n        rolling_type=RollingGen.ROLL_SD,\n    ):\n        # TaskManager config\n        if task_config is None:\n            task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]\n        mongo_conf = {\n            \"task_url\": task_url,\n            \"task_db_name\": task_db_name,\n        }\n        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)\n        self.experiment_name = experiment_name\n        if task_pool is None:\n            self.trainer = TrainerR(experiment_name=self.experiment_name)\n        else:\n            self.task_pool = task_pool\n            self.trainer = TrainerRM(self.experiment_name, self.task_pool)\n        self.task_config = task_config\n    
    self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)\n\n    # Reset all things to the first status, be careful to save important data\n    def reset(self):\n        print(\"========== reset ==========\")\n        if isinstance(self.trainer, TrainerRM):\n            TaskManager(task_pool=self.task_pool).remove()\n        exp = R.get_exp(experiment_name=self.experiment_name)\n        for rid in exp.list_recorders():\n            exp.delete_recorder(rid)\n\n    def task_generating(self):\n        print(\"========== task_generating ==========\")\n        tasks = task_generator(\n            tasks=self.task_config,\n            generators=self.rolling_gen,  # generate different date segments\n        )\n        pprint(tasks)\n        return tasks\n\n    def task_training(self, tasks):\n        print(\"========== task_training ==========\")\n        self.trainer.train(tasks)\n\n    def worker(self):\n        # NOTE: this is only used for TrainerRM\n        # train tasks by other progress or machines for multiprocessing. 
It is same as TrainerRM.worker.\n        print(\"========== worker ==========\")\n        run_task(task_train, self.task_pool, experiment_name=self.experiment_name)\n\n    def task_collecting(self):\n        print(\"========== task_collecting ==========\")\n\n        def rec_key(recorder):\n            task_config = recorder.load_object(\"task\")\n            model_key = task_config[\"model\"][\"class\"]\n            rolling_key = task_config[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"]\n            return model_key, rolling_key\n\n        def my_filter(recorder):\n            # only choose the results of \"LGBModel\"\n            model_key, rolling_key = rec_key(recorder)\n            if model_key == \"LGBModel\":\n                return True\n            return False\n\n        collector = RecorderCollector(\n            experiment=self.experiment_name,\n            process_list=RollingGroup(),\n            rec_key_func=rec_key,\n            rec_filter_func=my_filter,\n        )\n        print(collector())\n\n    def main(self):\n        self.reset()\n        tasks = self.task_generating()\n        self.task_training(tasks)\n        self.task_collecting()\n\n\nif __name__ == \"__main__\":\n    ## to see the whole process with your own parameters, use the command below\n    # python task_manager_rolling.py main --experiment_name=\"your_exp_name\"\n    fire.Fire(RollingTaskExample)\n"
  },
  {
    "path": "examples/nested_decision_execution/README.md",
    "content": "# Nested Decision Execution\n\nThis workflow is an example for nested decision execution in backtesting. Qlib supports nested decision execution in backtesting. It means that users can use different strategies to make trade decision in different frequencies.\n\n## Weekly Portfolio Generation and Daily Order Execution\n\nThis workflow provides an example that uses a DropoutTopkStrategy (a strategy based on the daily frequency Lightgbm model) in weekly frequency for portfolio generation and uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to execute orders in daily frequency. \n\n### Usage\n\nStart backtesting by running the following command:\n```bash\n    python workflow.py backtest\n```\n\nStart collecting data by running the following command:\n```bash\n    python workflow.py collect_data\n```\n\n## Daily Portfolio Generation and Minutely Order Execution\n\nThis workflow also provides a high-frequency example that uses a DropoutTopkStrategy for portfolio generation in daily frequency and uses SBBStrategyEMA to execute orders in minutely frequency. \n\n### Usage\n\nStart backtesting by running the following command:\n```bash\n    python workflow.py backtest_highfreq\n```"
  },
  {
    "path": "examples/nested_decision_execution/workflow.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\"\"\"\nThe expect result of `backtest` is following in current version\n\n'The following are analysis results of benchmark return(1day).'\n                       risk\nmean               0.000651\nstd                0.012472\nannualized_return  0.154967\ninformation_ratio  0.805422\nmax_drawdown      -0.160445\n'The following are analysis results of the excess return without cost(1day).'\n                       risk\nmean               0.001258\nstd                0.007575\nannualized_return  0.299303\ninformation_ratio  2.561219\nmax_drawdown      -0.068386\n'The following are analysis results of the excess return with cost(1day).'\n                       risk\nmean               0.001110\nstd                0.007575\nannualized_return  0.264280\ninformation_ratio  2.261392\nmax_drawdown      -0.071842\n[1706497:MainThread](2021-12-07 14:08:30,263) INFO - qlib.workflow - [record_temp.py:441] - Portfolio analysis record 'port_analysis_30minute.\npkl' has been saved as the artifact of the Experiment 2\n'The following are analysis results of benchmark return(30minute).'\n                       risk\nmean               0.000078\nstd                0.003646\nannualized_return  0.148787\ninformation_ratio  0.935252\nmax_drawdown      -0.142830\n('The following are analysis results of the excess return without '\n 'cost(30minute).')\n                       risk\nmean               0.000174\nstd                0.003343\nannualized_return  0.331867\ninformation_ratio  2.275019\nmax_drawdown      -0.074752\n'The following are analysis results of the excess return with cost(30minute).'\n                       risk\nmean               0.000155\nstd                0.003343\nannualized_return  0.294536\ninformation_ratio  2.018860\nmax_drawdown      -0.075579\n[1706497:MainThread](2021-12-07 14:08:30,277) INFO - qlib.workflow - [record_temp.py:441] - Portfolio analysis record 
'port_analysis_5minute.p\nkl' has been saved as the artifact of the Experiment 2\n'The following are analysis results of benchmark return(5minute).'\n                       risk\nmean               0.000015\nstd                0.001460\nannualized_return  0.172170\ninformation_ratio  1.103439\nmax_drawdown      -0.144807\n'The following are analysis results of the excess return without cost(5minute).'\n                       risk\nmean               0.000028\nstd                0.001412\nannualized_return  0.319771\ninformation_ratio  2.119563\nmax_drawdown      -0.077426\n'The following are analysis results of the excess return with cost(5minute).'\n                       risk\nmean               0.000025\nstd                0.001412\nannualized_return  0.281536\ninformation_ratio  1.866091\nmax_drawdown      -0.078194\n[1706497:MainThread](2021-12-07 14:08:30,287) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_1day\n.pkl' has been saved as the artifact of the Experiment 2\n'The following are analysis results of indicators(1day).'\n        value\nffr  0.945821\npa   0.000324\npos  0.542882\n[1706497:MainThread](2021-12-07 14:08:30,293) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_30mi\nnute.pkl' has been saved as the artifact of the Experiment 2\n'The following are analysis results of indicators(30minute).'\n        value\nffr  0.982910\npa   0.000037\npos  0.500806\n[1706497:MainThread](2021-12-07 14:08:30,302) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_5min\nute.pkl' has been saved as the artifact of the Experiment 2\n'The following are analysis results of indicators(5minute).'\n        value\nffr  0.991017\npa   0.000000\npos  0.000000\n[1706497:MainThread](2021-12-07 14:08:30,627) INFO - qlib.timer - [log.py:113] - Time cost: 0.014s | waiting `async_log` Done\n\"\"\"\n\nfrom copy import deepcopy\nimport qlib\nimport 
fire\nimport pandas as pd\nfrom qlib.constant import REG_CN\nfrom qlib.config import HIGH_FREQ_CONFIG\nfrom qlib.data import D\nfrom qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict\nfrom qlib.workflow import R\nfrom qlib.workflow.record_temp import SignalRecord, PortAnaRecord\nfrom qlib.tests.data import GetData\nfrom qlib.backtest import collect_data\n\n\nclass NestedDecisionExecutionWorkflow:\n    market = \"csi300\"\n    benchmark = \"SH000300\"\n    data_handler_config = {\n        \"start_time\": \"2008-01-01\",\n        \"end_time\": \"2021-05-31\",\n        \"fit_start_time\": \"2008-01-01\",\n        \"fit_end_time\": \"2014-12-31\",\n        \"instruments\": market,\n    }\n\n    task = {\n        \"model\": {\n            \"class\": \"LGBModel\",\n            \"module_path\": \"qlib.contrib.model.gbdt\",\n            \"kwargs\": {\n                \"loss\": \"mse\",\n                \"colsample_bytree\": 0.8879,\n                \"learning_rate\": 0.0421,\n                \"subsample\": 0.8789,\n                \"lambda_l1\": 205.6999,\n                \"lambda_l2\": 580.9768,\n                \"max_depth\": 8,\n                \"num_leaves\": 210,\n                \"num_threads\": 20,\n            },\n        },\n        \"dataset\": {\n            \"class\": \"DatasetH\",\n            \"module_path\": \"qlib.data.dataset\",\n            \"kwargs\": {\n                \"handler\": {\n                    \"class\": \"Alpha158\",\n                    \"module_path\": \"qlib.contrib.data.handler\",\n                    \"kwargs\": data_handler_config,\n                },\n                \"segments\": {\n                    \"train\": (\"2007-01-01\", \"2014-12-31\"),\n                    \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n                    \"test\": (\"2020-01-01\", \"2021-05-31\"),\n                },\n            },\n        },\n    }\n\n    exp_name = \"nested\"\n\n    port_analysis_config = {\n        
\"executor\": {\n            \"class\": \"NestedExecutor\",\n            \"module_path\": \"qlib.backtest.executor\",\n            \"kwargs\": {\n                \"time_per_step\": \"day\",\n                \"inner_executor\": {\n                    \"class\": \"NestedExecutor\",\n                    \"module_path\": \"qlib.backtest.executor\",\n                    \"kwargs\": {\n                        \"time_per_step\": \"30min\",\n                        \"inner_executor\": {\n                            \"class\": \"SimulatorExecutor\",\n                            \"module_path\": \"qlib.backtest.executor\",\n                            \"kwargs\": {\n                                \"time_per_step\": \"5min\",\n                                \"generate_portfolio_metrics\": True,\n                                \"verbose\": True,\n                                \"indicator_config\": {\n                                    \"show_indicator\": True,\n                                },\n                            },\n                        },\n                        \"inner_strategy\": {\n                            \"class\": \"TWAPStrategy\",\n                            \"module_path\": \"qlib.contrib.strategy.rule_strategy\",\n                        },\n                        \"generate_portfolio_metrics\": True,\n                        \"indicator_config\": {\n                            \"show_indicator\": True,\n                        },\n                    },\n                },\n                \"inner_strategy\": {\n                    \"class\": \"SBBStrategyEMA\",\n                    \"module_path\": \"qlib.contrib.strategy.rule_strategy\",\n                    \"kwargs\": {\n                        \"instruments\": market,\n                        \"freq\": \"1min\",\n                    },\n                },\n                \"track_data\": True,\n                \"generate_portfolio_metrics\": True,\n                
\"indicator_config\": {\n                    \"show_indicator\": True,\n                },\n            },\n        },\n        \"backtest\": {\n            \"start_time\": \"2020-09-20\",\n            \"end_time\": \"2021-05-20\",\n            \"account\": 100000000,\n            \"exchange_kwargs\": {\n                \"freq\": \"1min\",\n                \"limit_threshold\": 0.095,\n                \"deal_price\": \"close\",\n                \"open_cost\": 0.0005,\n                \"close_cost\": 0.0015,\n                \"min_cost\": 5,\n            },\n        },\n    }\n\n    def _init_qlib(self):\n        \"\"\"initialize qlib\"\"\"\n        provider_uri_day = \"~/.qlib/qlib_data/cn_data\"  # target_dir\n        GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version=\"v2\", exists_skip=True)\n        provider_uri_1min = HIGH_FREQ_CONFIG.get(\"provider_uri\")\n        GetData().qlib_data(\n            target_dir=provider_uri_1min, interval=\"1min\", region=REG_CN, version=\"v2\", exists_skip=True\n        )\n        provider_uri_map = {\"1min\": provider_uri_1min, \"day\": provider_uri_day}\n        qlib.init(provider_uri=provider_uri_map, dataset_cache=None, expression_cache=None)\n\n    def _train_model(self, model, dataset):\n        with R.start(experiment_name=self.exp_name):\n            R.log_params(**flatten_dict(self.task))\n            model.fit(dataset)\n            R.save_objects(**{\"params.pkl\": model})\n\n            # prediction\n            recorder = R.get_recorder()\n            sr = SignalRecord(model, dataset, recorder)\n            sr.generate()\n\n    def backtest(self):\n        self._init_qlib()\n        model = init_instance_by_config(self.task[\"model\"])\n        dataset = init_instance_by_config(self.task[\"dataset\"])\n        self._train_model(model, dataset)\n        strategy_config = {\n            \"class\": \"TopkDropoutStrategy\",\n            \"module_path\": \"qlib.contrib.strategy.signal_strategy\",\n    
        \"kwargs\": {\n                \"signal\": (model, dataset),\n                \"topk\": 50,\n                \"n_drop\": 5,\n            },\n        }\n        self.port_analysis_config[\"strategy\"] = strategy_config\n        self.port_analysis_config[\"backtest\"][\"benchmark\"] = self.benchmark\n\n        with R.start(experiment_name=self.exp_name, resume=True):\n            recorder = R.get_recorder()\n            par = PortAnaRecord(\n                recorder,\n                self.port_analysis_config,\n                indicator_analysis_method=\"value_weighted\",\n            )\n            par.generate()\n\n        # user could use following methods to analysis the position\n        # report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal_1day.pkl\")\n        # from qlib.contrib.report import analysis_position\n        # analysis_position.report_graph(report_normal_df)\n\n    def collect_data(self):\n        self._init_qlib()\n        model = init_instance_by_config(self.task[\"model\"])\n        dataset = init_instance_by_config(self.task[\"dataset\"])\n        self._train_model(model, dataset)\n        executor_config = self.port_analysis_config[\"executor\"]\n        backtest_config = self.port_analysis_config[\"backtest\"]\n        backtest_config[\"benchmark\"] = self.benchmark\n        strategy_config = {\n            \"class\": \"TopkDropoutStrategy\",\n            \"module_path\": \"qlib.contrib.strategy.signal_strategy\",\n            \"kwargs\": {\n                \"signal\": (model, dataset),\n                \"topk\": 50,\n                \"n_drop\": 5,\n            },\n        }\n        data_generator = collect_data(executor=executor_config, strategy=strategy_config, **backtest_config)\n        for trade_decision in data_generator:\n            print(trade_decision)\n\n    # the code below are for checking, users don't have to care about it\n    # The tests can be categorized into 2 types\n    # 1) comparing same 
backtest\n    # - Basic test idea: the shared accumulated value are equal in multiple levels\n    #   - Aligning the profit calculation between multiple levels and single levels.\n    # 2) comparing different backtest\n    # - Basic test idea:\n    #   - the daily backtest will be similar as multi-level(the data quality makes this gap smaller)\n\n    def check_diff_freq(self):\n        self._init_qlib()\n        exp = R.get_exp(experiment_name=\"backtest\")\n        rec = next(iter(exp.list_recorders().values()))  # assuming this will get the latest recorder\n        for check_key in \"account\", \"total_turnover\", \"total_cost\":\n            check_key = \"total_cost\"\n\n            acc_dict = {}\n            for freq in [\"30minute\", \"5minute\", \"1day\"]:\n                acc_dict[freq] = rec.load_object(f\"portfolio_analysis/report_normal_{freq}.pkl\")[check_key]\n            acc_df = pd.DataFrame(acc_dict)\n            acc_resam = acc_df.resample(\"1d\").last().dropna()\n            assert (acc_resam[\"30minute\"] == acc_resam[\"1day\"]).all()\n\n    def backtest_only_daily(self):\n        \"\"\"\n        This backtest is used for comparing the nested execution and single layer execution\n        Due to the low quality daily-level and miniute-level data, they are hardly comparable.\n        So it is used for detecting serious bugs which make the results different greatly.\n\n        .. 
code-block:: shell\n\n            [1724971:MainThread](2021-12-07 16:24:31,156) INFO - qlib.workflow - [record_temp.py:441] - Portfolio analysis record 'port_analysis_1day.pkl'\n            has been saved as the artifact of the Experiment 2\n            'The following are analysis results of benchmark return(1day).'\n                                   risk\n            mean               0.000651\n            std                0.012472\n            annualized_return  0.154967\n            information_ratio  0.805422\n            max_drawdown      -0.160445\n            'The following are analysis results of the excess return without cost(1day).'\n                                   risk\n            mean               0.001375\n            std                0.006103\n            annualized_return  0.327204\n            information_ratio  3.475016\n            max_drawdown      -0.024927\n            'The following are analysis results of the excess return with cost(1day).'\n                                   risk\n            mean               0.001184\n            std                0.006091\n            annualized_return  0.281801\n            information_ratio  2.998749\n            max_drawdown      -0.029568\n            [1724971:MainThread](2021-12-07 16:24:31,170) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_1day.\n            pkl' has been saved as the artifact of the Experiment 2\n            'The following are analysis results of indicators(1day).'\n                 value\n            ffr    1.0\n            pa     0.0\n            pos    0.0\n            [1724971:MainThread](2021-12-07 16:24:31,188) INFO - qlib.timer - [log.py:113] - Time cost: 0.007s | waiting `async_log` Done\n\n        \"\"\"\n        self._init_qlib()\n        model = init_instance_by_config(self.task[\"model\"])\n        dataset = init_instance_by_config(self.task[\"dataset\"])\n        self._train_model(model, dataset)\n        
strategy_config = {\n            \"class\": \"TopkDropoutStrategy\",\n            \"module_path\": \"qlib.contrib.strategy.signal_strategy\",\n            \"kwargs\": {\n                \"signal\": (model, dataset),\n                \"topk\": 50,\n                \"n_drop\": 5,\n            },\n        }\n        pa_conf = deepcopy(self.port_analysis_config)\n        pa_conf[\"strategy\"] = strategy_config\n        pa_conf[\"executor\"] = {\n            \"class\": \"SimulatorExecutor\",\n            \"module_path\": \"qlib.backtest.executor\",\n            \"kwargs\": {\n                \"time_per_step\": \"day\",\n                \"generate_portfolio_metrics\": True,\n                \"verbose\": True,\n            },\n        }\n        pa_conf[\"backtest\"][\"benchmark\"] = self.benchmark\n\n        with R.start(experiment_name=self.exp_name, resume=True):\n            recorder = R.get_recorder()\n            par = PortAnaRecord(recorder, pa_conf)\n            par.generate()\n\n\nif __name__ == \"__main__\":\n    fire.Fire(NestedDecisionExecutionWorkflow)\n"
  },
  {
    "path": "examples/online_srv/online_management_simulate.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nThis example is about how can simulate the OnlineManager based on rolling tasks.\n\"\"\"\n\nfrom pprint import pprint\nimport fire\nimport qlib\nfrom qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM\nfrom qlib.workflow import R\nfrom qlib.workflow.online.manager import OnlineManager\nfrom qlib.workflow.online.strategy import RollingStrategy\nfrom qlib.workflow.task.gen import RollingGen\nfrom qlib.workflow.task.manage import TaskManager\nfrom qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE\nimport pandas as pd\nfrom qlib.contrib.evaluate import backtest_daily\nfrom qlib.contrib.evaluate import risk_analysis\nfrom qlib.contrib.strategy import TopkDropoutStrategy\n\n\nclass OnlineSimulationExample:\n    def __init__(\n        self,\n        provider_uri=\"~/.qlib/qlib_data/cn_data\",\n        region=\"cn\",\n        exp_name=\"rolling_exp\",\n        task_url=\"mongodb://10.0.0.4:27017/\",  # not necessary when using TrainerR or DelayTrainerR\n        task_db_name=\"rolling_db\",  # not necessary when using TrainerR or DelayTrainerR\n        task_pool=\"rolling_task\",\n        rolling_step=80,\n        start_time=\"2018-09-10\",\n        end_time=\"2018-10-31\",\n        tasks=None,\n        trainer=\"TrainerR\",\n    ):\n        \"\"\"\n        Init OnlineManagerExample.\n\n        Args:\n            provider_uri (str, optional): the provider uri. Defaults to \"~/.qlib/qlib_data/cn_data\".\n            region (str, optional): the stock region. Defaults to \"cn\".\n            exp_name (str, optional): the experiment name. Defaults to \"rolling_exp\".\n            task_url (str, optional): your MongoDB url. Defaults to \"mongodb://10.0.0.4:27017/\".\n            task_db_name (str, optional): database name. 
Defaults to \"rolling_db\".\n            task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to \"rolling_task\".\n            rolling_step (int, optional): the step for rolling. Defaults to 80.\n            start_time (str, optional): the start time of simulating. Defaults to \"2018-09-10\".\n            end_time (str, optional): the end time of simulating. Defaults to \"2018-10-31\".\n            tasks (dict or list[dict]): a set of the task config waiting for rolling and training\n        \"\"\"\n        if tasks is None:\n            tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]\n        self.exp_name = exp_name\n        self.task_pool = task_pool\n        self.start_time = start_time\n        self.end_time = end_time\n        mongo_conf = {\n            \"task_url\": task_url,\n            \"task_db_name\": task_db_name,\n        }\n        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)\n        self.rolling_gen = RollingGen(\n            step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None\n        )  # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.\n        if trainer == \"TrainerRM\":\n            self.trainer = TrainerRM(self.exp_name, self.task_pool)\n        elif trainer == \"TrainerR\":\n            self.trainer = TrainerR(self.exp_name)\n        else:\n            # TODO: support all the trainers: TrainerR, TrainerRM, DelayTrainerR\n            raise NotImplementedError(f\"This type of input is not supported\")\n        self.rolling_online_manager = OnlineManager(\n            RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),\n            trainer=self.trainer,\n            begin_time=self.start_time,\n        )\n        self.tasks = tasks\n\n    # Reset all things to the first status, be careful to save 
important data\n    def reset(self):\n        if isinstance(self.trainer, TrainerRM):\n            TaskManager(self.task_pool).remove()\n        exp = R.get_exp(experiment_name=self.exp_name)\n        for rid in exp.list_recorders():\n            exp.delete_recorder(rid)\n\n    # Run this to run all workflow automatically\n    def main(self):\n        print(\"========== reset ==========\")\n        self.reset()\n        print(\"========== simulate ==========\")\n        self.rolling_online_manager.simulate(end_time=self.end_time)\n        print(\"========== collect results ==========\")\n        print(self.rolling_online_manager.get_collector()())\n        print(\"========== signals ==========\")\n        signals = self.rolling_online_manager.get_signals()\n        print(signals)\n        # Backtesting\n        # - the code is based on this example https://qlib.readthedocs.io/en/latest/component/strategy.html\n        CSI300_BENCH = \"SH000903\"\n        STRATEGY_CONFIG = {\n            \"topk\": 30,\n            \"n_drop\": 3,\n            \"signal\": signals.to_frame(\"score\"),\n        }\n        strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)\n        report_normal, positions_normal = backtest_daily(\n            start_time=signals.index.get_level_values(\"datetime\").min(),\n            end_time=signals.index.get_level_values(\"datetime\").max(),\n            strategy=strategy_obj,\n        )\n        analysis = dict()\n        analysis[\"excess_return_without_cost\"] = risk_analysis(report_normal[\"return\"] - report_normal[\"bench\"])\n        analysis[\"excess_return_with_cost\"] = risk_analysis(\n            report_normal[\"return\"] - report_normal[\"bench\"] - report_normal[\"cost\"]\n        )\n\n        analysis_df = pd.concat(analysis)  # type: pd.DataFrame\n        pprint(analysis_df)\n\n    def worker(self):\n        # train tasks by other progress or machines for multiprocessing\n        # FIXME: only can call after finishing simulation when 
using DelayTrainerRM, or there will be some exception.\n        print(\"========== worker ==========\")\n        if isinstance(self.trainer, TrainerRM):\n            self.trainer.worker()\n        else:\n            print(f\"{type(self.trainer)} is not supported for worker.\")\n\n\nif __name__ == \"__main__\":\n    ## to run all workflow automatically with your own parameters, use the command below\n    # python online_management_simulate.py main --experiment_name=\"your_exp_name\" --rolling_step=60\n    fire.Fire(OnlineSimulationExample)\n"
  },
  {
    "path": "examples/online_srv/rolling_online_management.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nThis example shows how OnlineManager works with rolling tasks.\nThere are four parts including first train, routine 1, add strategy and routine 2.\nFirstly, the OnlineManager will finish the first training and set trained models to `online` models.\nNext, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals\nThen, we will add some new strategies to the OnlineManager. This will finish first training of new strategies.\nFinally, the OnlineManager will finish second routine and update all strategies.\n\"\"\"\n\nimport os\nimport fire\nimport qlib\nfrom qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train\nfrom qlib.workflow import R\nfrom qlib.workflow.online.strategy import RollingStrategy\nfrom qlib.workflow.task.gen import RollingGen\nfrom qlib.workflow.online.manager import OnlineManager\nfrom qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING\nfrom qlib.workflow.task.manage import TaskManager\n\n\nclass RollingOnlineExample:\n    def __init__(\n        self,\n        provider_uri=\"~/.qlib/qlib_data/cn_data\",\n        region=\"cn\",\n        trainer=DelayTrainerRM(),  # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM\n        task_url=\"mongodb://10.0.0.4:27017/\",  # not necessary when using TrainerR or DelayTrainerR\n        task_db_name=\"rolling_db\",  # not necessary when using TrainerR or DelayTrainerR\n        rolling_step=550,\n        tasks=None,\n        add_tasks=None,\n    ):\n        if add_tasks is None:\n            add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]\n        if tasks is None:\n            tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]\n        mongo_conf = {\n            \"task_url\": task_url,  # your MongoDB 
url\n            \"task_db_name\": task_db_name,  # database name\n        }\n        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)\n        self.tasks = tasks\n        self.add_tasks = add_tasks\n        self.rolling_step = rolling_step\n        strategies = []\n        for task in tasks:\n            name_id = task[\"model\"][\"class\"]  # NOTE: Assumption: The model class can specify only one strategy\n            strategies.append(\n                RollingStrategy(\n                    name_id,\n                    task,\n                    RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),\n                )\n            )\n        self.trainer = trainer\n        self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)\n\n    _ROLLING_MANAGER_PATH = (\n        \".RollingOnlineExample\"  # the OnlineManager will dump to this file, for it can be loaded when calling routine.\n    )\n\n    def worker(self):\n        # train tasks by other progress or machines for multiprocessing\n        print(\"========== worker ==========\")\n        if isinstance(self.trainer, TrainerRM):\n            for task in self.tasks + self.add_tasks:\n                name_id = task[\"model\"][\"class\"]\n                self.trainer.worker(experiment_name=name_id)\n        else:\n            print(f\"{type(self.trainer)} is not supported for worker.\")\n\n    # Reset all things to the first status, be careful to save important data\n    def reset(self):\n        for task in self.tasks + self.add_tasks:\n            name_id = task[\"model\"][\"class\"]\n            TaskManager(task_pool=name_id).remove()\n            exp = R.get_exp(experiment_name=name_id)\n            for rid in exp.list_recorders():\n                exp.delete_recorder(rid)\n\n        if os.path.exists(self._ROLLING_MANAGER_PATH):\n            os.remove(self._ROLLING_MANAGER_PATH)\n\n    def first_run(self):\n        print(\"========== reset ==========\")\n        
self.reset()\n        print(\"========== first_run ==========\")\n        self.rolling_online_manager.first_train()\n        print(\"========== collect results ==========\")\n        print(self.rolling_online_manager.get_collector()())\n        print(\"========== dump ==========\")\n        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)\n\n    def routine(self):\n        print(\"========== load ==========\")\n        self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)\n        print(\"========== routine ==========\")\n        self.rolling_online_manager.routine()\n        print(\"========== collect results ==========\")\n        print(self.rolling_online_manager.get_collector()())\n        print(\"========== signals ==========\")\n        print(self.rolling_online_manager.get_signals())\n        print(\"========== dump ==========\")\n        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)\n\n    def add_strategy(self):\n        print(\"========== load ==========\")\n        self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)\n        print(\"========== add strategy ==========\")\n        strategies = []\n        for task in self.add_tasks:\n            name_id = task[\"model\"][\"class\"]  # NOTE: Assumption: The model class can specify only one strategy\n            strategies.append(\n                RollingStrategy(\n                    name_id,\n                    task,\n                    RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),\n                )\n            )\n        self.rolling_online_manager.add_strategy(strategies=strategies)\n        print(\"========== dump ==========\")\n        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)\n\n    def main(self):\n        self.first_run()\n        self.routine()\n        self.add_strategy()\n        self.routine()\n\n\nif __name__ == \"__main__\":\n    ####### to train the first version's 
models, use the command below\n    # python rolling_online_management.py first_run\n\n    ####### to update the models and predictions after the trading time, use the command below\n    # python rolling_online_management.py routine\n\n    ####### to define your own parameters, use `--`\n    # python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40\n    fire.Fire(RollingOnlineExample)\n"
  },
  {
    "path": "examples/online_srv/update_online_pred.py",
    "content": "# Copyright (c) Microsoft Corporation.\r\n# Licensed under the MIT License.\r\n\r\n\"\"\"\r\nThis example shows how OnlineTool works when we need update prediction.\r\nThere are two parts including first_train and update_online_pred.\r\nFirstly, we will finish the training and set the trained models to the `online` models.\r\nNext, we will finish updating online predictions.\r\n\"\"\"\r\n\r\nimport copy\r\nimport fire\r\nimport qlib\r\nfrom qlib.constant import REG_CN\r\nfrom qlib.model.trainer import task_train\r\nfrom qlib.workflow.online.utils import OnlineToolR\r\nfrom qlib.tests.config import CSI300_GBDT_TASK\r\n\r\ntask = copy.deepcopy(CSI300_GBDT_TASK)\r\n\r\ntask[\"record\"] = {\r\n    \"class\": \"SignalRecord\",\r\n    \"module_path\": \"qlib.workflow.record_temp\",\r\n}\r\n\r\n\r\nclass UpdatePredExample:\r\n    def __init__(\r\n        self, provider_uri=\"~/.qlib/qlib_data/cn_data\", region=REG_CN, experiment_name=\"online_srv\", task_config=task\r\n    ):\r\n        qlib.init(provider_uri=provider_uri, region=region)\r\n        self.experiment_name = experiment_name\r\n        self.online_tool = OnlineToolR(self.experiment_name)\r\n        self.task_config = task_config\r\n\r\n    def first_train(self):\r\n        rec = task_train(self.task_config, experiment_name=self.experiment_name)\r\n        self.online_tool.reset_online_tag(rec)  # set to online model\r\n\r\n    def update_online_pred(self):\r\n        self.online_tool.update_online_pred()\r\n\r\n    def main(self):\r\n        self.first_train()\r\n        self.update_online_pred()\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    ## to train a model and set it to online model, use the command below\r\n    # python update_online_pred.py first_train\r\n    ## to update online predictions once a day, use the command below\r\n    # python update_online_pred.py update_online_pred\r\n    ## to see the whole process with your own parameters, use the command below\r\n    # python 
update_online_pred.py main --experiment_name=\"your_exp_name\"\r\n    fire.Fire(UpdatePredExample)\r\n"
  },
  {
    "path": "examples/orderbook_data/README.md",
    "content": "# Introduction\n\nThis example tries to demonstrate how Qlib supports data without fixed shared frequency.\n\nFor example,\n- Daily prices volume data are fixed-frequency data. The data comes in a fixed frequency (i.e. daily)\n- Orders are not fixed data and they may come at any time point\n\nTo support such non-fixed-frequency, Qlib implements an Arctic-based backend.\nHere is an example to import and query data based on this backend.\n\n# Installation\n\nPlease refer to [the installation docs](https://docs.mongodb.com/manual/installation/) of mongodb.\nCurrent version of script with default value tries to connect localhost **via default port without authentication**.\n\nRun following command to install necessary libraries\n```\npip install pytest coverage gdown\npip install arctic  # NOTE: pip may fail to resolve the right package dependency !!! Please make sure the dependency are satisfied.\n```\n\n# Importing example data\n\n\n1. (Optional) Please follow the first part of [this section](https://github.com/microsoft/qlib#data-preparation) to **get 1min data** of Qlib.\n2. Please follow following steps to download example data\n```bash\ncd examples/orderbook_data/\ngdown https://drive.google.com/uc?id=15FuUqWn2rkCi8uhJYGEQWKakcEqLJNDG  # Proxies may be necessary here.\npython ../../scripts/get_data.py _unzip --file_path highfreq_orderbook_example_data.zip --target_dir .\n```\n\n3. 
Please import the example data to your mongo db\n```bash\npython create_dataset.py initialize_library  # Initialization Libraries\npython create_dataset.py import_data  # Initialization Libraries\n```\n\n# Query Examples\n\nAfter importing these data, you run `example.py` to create some high-frequency features.\n```bash\npytest -s --disable-warnings example.py   # If you want run all examples\npytest -s --disable-warnings example.py::TestClass::test_exp_10  # If you want to run specific example\n```\n\n\n# Known limitations\nExpression computing between different frequencies are not supported yet\n"
  },
  {
    "path": "examples/orderbook_data/create_dataset.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nNOTE:\n- This scripts is a demo to import example data import Qlib\n- !!!!!!!!!!!!!!!TODO!!!!!!!!!!!!!!!!!!!:\n    - Its structure is not well designed and very ugly, your contribution is welcome to make importing dataset easier\n\"\"\"\n\nfrom datetime import date, datetime as dt\nimport os\nfrom pathlib import Path\nimport random\nimport shutil\nimport time\nimport traceback\n\nfrom arctic import Arctic, chunkstore\nimport arctic\nfrom arctic import Arctic, CHUNK_STORE\nfrom arctic.chunkstore.chunkstore import CHUNK_SIZE\nimport fire\nfrom joblib import Parallel, delayed, parallel\nimport numpy as np\nimport pandas as pd\nfrom pandas import DataFrame\nfrom pandas.core.indexes.datetimes import date_range\nfrom pymongo.mongo_client import MongoClient\n\nDIRNAME = Path(__file__).absolute().resolve().parent\n\n# CONFIG\nN_JOBS = -1  # leaving one kernel free\nLOG_FILE_PATH = DIRNAME / \"log_file\"\nDATA_PATH = DIRNAME / \"raw_data\"\nDATABASE_PATH = DIRNAME / \"orig_data\"\nDATA_INFO_PATH = DIRNAME / \"data_info\"\nDATA_FINISH_INFO_PATH = DIRNAME / \"./data_finish_info\"\nDOC_TYPE = [\"Tick\", \"Order\", \"OrderQueue\", \"Transaction\", \"Day\", \"Minute\"]\nMAX_SIZE = 3000 * 1024 * 1024 * 1024\nALL_STOCK_PATH = DATABASE_PATH / \"all.txt\"\nARCTIC_SRV = \"127.0.0.1\"\n\n\ndef get_library_name(doc_type):\n    if str.lower(doc_type) == str.lower(\"Tick\"):\n        return \"ticks\"\n    else:\n        return str.lower(doc_type)\n\n\ndef is_stock(exchange_place, code):\n    if exchange_place == \"SH\" and code[0] != \"6\":\n        return False\n    if exchange_place == \"SZ\" and code[0] != \"0\" and code[:2] != \"30\":\n        return False\n    return True\n\n\ndef add_one_stock_daily_data(filepath, type, exchange_place, arc, date):\n    \"\"\"\n    exchange_place: \"SZ\" OR \"SH\"\n    type: \"tick\", \"orderbook\", ...\n    filepath: the path of csv\n    arc: arclink 
created by a process\n    \"\"\"\n    code = os.path.split(filepath)[-1].split(\".csv\")[0]\n    if exchange_place == \"SH\" and code[0] != \"6\":\n        return\n    if exchange_place == \"SZ\" and code[0] != \"0\" and code[:2] != \"30\":\n        return\n\n    df = pd.read_csv(filepath, encoding=\"gbk\", dtype={\"code\": str})\n    code = os.path.split(filepath)[-1].split(\".csv\")[0]\n\n    def format_time(day, hms):\n        day = str(day)\n        hms = str(hms)\n        if hms[0] == \"1\":  # >=10,\n            return (\n                \"-\".join([day[0:4], day[4:6], day[6:8]]) + \" \" + \":\".join([hms[:2], hms[2:4], hms[4:6] + \".\" + hms[6:]])\n            )\n        else:\n            return (\n                \"-\".join([day[0:4], day[4:6], day[6:8]]) + \" \" + \":\".join([hms[:1], hms[1:3], hms[3:5] + \".\" + hms[5:]])\n            )\n\n    ## Discard the entire row if wrong data timestamp encoutered.\n    timestamp = list(zip(list(df[\"date\"]), list(df[\"time\"])))\n    error_index_list = []\n    for index, t in enumerate(timestamp):\n        try:\n            pd.Timestamp(format_time(t[0], t[1]))\n        except Exception:\n            error_index_list.append(index)  ## The row number of the error line\n\n    # to-do: writting to logs\n\n    if len(error_index_list) > 0:\n        print(\"error: {}, {}\".format(filepath, len(error_index_list)))\n\n    df = df.drop(error_index_list)\n    timestamp = list(zip(list(df[\"date\"]), list(df[\"time\"])))  ## The cleaned timestamp\n    # generate timestamp\n    pd_timestamp = pd.DatetimeIndex(\n        [pd.Timestamp(format_time(timestamp[i][0], timestamp[i][1])) for i in range(len(df[\"date\"]))]\n    )\n    df = df.drop(columns=[\"date\", \"time\", \"name\", \"code\", \"wind_code\"])\n    # df = pd.DataFrame(data=df.to_dict(\"list\"), index=pd_timestamp)\n    df[\"date\"] = pd.to_datetime(pd_timestamp)\n    df.set_index(\"date\", inplace=True)\n\n    if str.lower(type) == \"orderqueue\":\n        ## 
extract ab1~ab50\n        df[\"ab\"] = [\n            \",\".join([str(int(row[\"ab\" + str(i + 1)])) for i in range(0, row[\"ab_items\"])])\n            for timestamp, row in df.iterrows()\n        ]\n        df = df.drop(columns=[\"ab\" + str(i) for i in range(1, 51)])\n\n    type = get_library_name(type)\n    # arc.initialize_library(type, lib_type=CHUNK_STORE)\n    lib = arc[type]\n\n    symbol = \"\".join([exchange_place, code])\n    if symbol in lib.list_symbols():\n        print(\"update {0}, date={1}\".format(symbol, date))\n        if df.empty:\n            return error_index_list\n        lib.update(symbol, df, chunk_size=\"D\")\n    else:\n        print(\"write {0}, date={1}\".format(symbol, date))\n        lib.write(symbol, df, chunk_size=\"D\")\n    return error_index_list\n\n\ndef add_one_stock_daily_data_wrapper(filepath, type, exchange_place, index, date):\n    pid = os.getpid()\n    code = os.path.split(filepath)[-1].split(\".csv\")[0]\n    arc = Arctic(ARCTIC_SRV)\n    try:\n        if index % 100 == 0:\n            print(\"index = {}, filepath = {}\".format(index, filepath))\n        error_index_list = add_one_stock_daily_data(filepath, type, exchange_place, arc, date)\n        if error_index_list is not None and len(error_index_list) > 0:\n            with open(os.path.join(LOG_FILE_PATH, \"temp_timestamp_error_{0}_{1}_{2}.txt\".format(pid, date, type)), \"a+\") as f:\n                f.write(\"{}, {}, {}\\n\".format(filepath, error_index_list, exchange_place + \"_\" + code))\n\n    except Exception as e:\n        info = traceback.format_exc()\n        print(\"error:\" + str(e))\n        with open(os.path.join(LOG_FILE_PATH, \"temp_fail_{0}_{1}_{2}.txt\".format(pid, date, type)), \"a+\") as f:\n            f.write(\"fail:\" + str(filepath) + \"\\n\" + str(e) + \"\\n\" + str(info) + \"\\n\")\n\n    finally:\n        arc.reset()\n\n\ndef add_data(tick_date, doc_type, stock_name_dict):\n    pid = os.getpid()\n\n    if doc_type not in DOC_TYPE:\n        print(\"doc_type not in {}\".format(DOC_TYPE))\n        return\n    try:\n        begin_time = time.time()\n        os.system(f\"cp {DATABASE_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)} {DATA_PATH}/\")\n\n        os.system(\n            f\"tar -xvzf {DATA_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)} -C {DATA_PATH}/ {tick_date + '_' + doc_type}/SH\"\n        )\n        os.system(\n            f\"tar -xvzf {DATA_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)} -C {DATA_PATH}/ {tick_date + '_' + doc_type}/SZ\"\n        )\n        os.system(f\"chmod 777 {DATA_PATH}\")\n        os.system(f\"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}\")\n        os.system(f\"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}/SH\")\n        os.system(f\"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}/SZ\")\n        os.system(f\"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}/SH/{tick_date}\")\n        os.system(f\"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}/SZ/{tick_date}\")\n\n        print(\"tick_date={}\".format(tick_date))\n\n        temp_data_path_sh = os.path.join(DATA_PATH, tick_date + \"_\" + doc_type, \"SH\", tick_date)\n        temp_data_path_sz = os.path.join(DATA_PATH, tick_date + \"_\" + doc_type, \"SZ\", tick_date)\n        is_files_exist = {\"sh\": os.path.exists(temp_data_path_sh), \"sz\": os.path.exists(temp_data_path_sz)}\n\n        sz_files = (\n            (\n                set([i.split(\".csv\")[0] for i in os.listdir(temp_data_path_sz) if i[:2] == \"30\" or i[0] == \"0\"])\n                & set(stock_name_dict[\"SZ\"])\n            )\n            if is_files_exist[\"sz\"]\n            else set()\n        )\n        sz_file_nums = len(sz_files) if is_files_exist[\"sz\"] else 0\n        sh_files = (\n            (\n                set([i.split(\".csv\")[0] for i in os.listdir(temp_data_path_sh) if i[0] == \"6\"])\n                & set(stock_name_dict[\"SH\"])\n            )\n            if is_files_exist[\"sh\"]\n            else set()\n        )\n        sh_file_nums = len(sh_files) if is_files_exist[\"sh\"] else 0\n        print(\"sz_file_nums:{}, sh_file_nums:{}\".format(sz_file_nums, sh_file_nums))\n\n        with (DATA_INFO_PATH / \"data_info_log_{}_{}\".format(doc_type, tick_date)).open(\"w+\") as f:\n            f.write(\"sz:{}, sh:{}, date:{}:\".format(sz_file_nums, sh_file_nums, tick_date) + \"\\n\")\n\n        if sh_file_nums > 0:\n            # write is not thread-safe, update may be thread-safe\n            Parallel(n_jobs=N_JOBS)(\n                delayed(add_one_stock_daily_data_wrapper)(\n                    os.path.join(temp_data_path_sh, name + \".csv\"), doc_type, \"SH\", index, tick_date\n                )\n                for index, name in enumerate(list(sh_files))\n            )\n        if sz_file_nums > 0:\n            # write is not thread-safe, update may be thread-safe\n            Parallel(n_jobs=N_JOBS)(\n                delayed(add_one_stock_daily_data_wrapper)(\n                    os.path.join(temp_data_path_sz, name + \".csv\"), doc_type, \"SZ\", index, tick_date\n                )\n                for index, name in enumerate(list(sz_files))\n            )\n\n        os.system(f\"rm -f {DATA_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)}\")\n        os.system(f\"rm -rf {DATA_PATH}/{tick_date + '_' + doc_type}\")\n        total_time = time.time() - begin_time\n        with (DATA_FINISH_INFO_PATH / \"data_info_finish_log_{}_{}\".format(doc_type, tick_date)).open(\"w+\") as f:\n            f.write(\"finish: date:{}, consume_time:{}, end_time: {}\".format(tick_date, total_time, time.time()) + \"\\n\")\n\n    except Exception as e:\n        info = traceback.format_exc()\n        print(\"date error:\" + str(e))\n        f = open(os.path.join(LOG_FILE_PATH, \"temp_fail_{0}_{1}_{2}.txt\".format(pid, tick_date, doc_type)), \"a+\")\n        f.write(\"fail:\" + str(tick_date) + \"\\n\" + str(e) + \"\\n\" + 
str(info) + \"\\n\")\n        f.close()\n\n\nclass DSCreator:\n    \"\"\"Dataset creator\"\"\"\n\n    def clear(self):\n        client = MongoClient(ARCTIC_SRV)\n        client.drop_database(\"arctic\")\n\n    def initialize_library(self):\n        arc = Arctic(ARCTIC_SRV)\n        for doc_type in DOC_TYPE:\n            arc.initialize_library(get_library_name(doc_type), lib_type=CHUNK_STORE)\n\n    def _get_empty_folder(self, fp: Path):\n        fp = Path(fp)\n        if fp.exists():\n            shutil.rmtree(fp)\n        fp.mkdir(parents=True, exist_ok=True)\n\n    def import_data(self, doc_type_l=[\"Tick\", \"Transaction\", \"Order\"]):\n        # clear all the old files\n        for fp in LOG_FILE_PATH, DATA_INFO_PATH, DATA_FINISH_INFO_PATH, DATA_PATH:\n            self._get_empty_folder(fp)\n\n        arc = Arctic(ARCTIC_SRV)\n        for doc_type in DOC_TYPE:\n            # arc.initialize_library(get_library_name(doc_type), lib_type=CHUNK_STORE)\n            arc.set_quota(get_library_name(doc_type), MAX_SIZE)\n        arc.reset()\n\n        # doc_type = 'Day'\n        for doc_type in doc_type_l:\n            date_list = list(set([int(path.split(\"_\")[0]) for path in os.listdir(DATABASE_PATH) if doc_type in path]))\n            date_list.sort()\n            date_list = [str(date) for date in date_list]\n\n            f = open(ALL_STOCK_PATH, \"r\")\n            stock_name_list = [lines.split(\"\\t\")[0] for lines in f.readlines()]\n            f.close()\n            stock_name_dict = {\n                \"SH\": [stock_name[2:] for stock_name in stock_name_list if \"SH\" in stock_name],\n                \"SZ\": [stock_name[2:] for stock_name in stock_name_list if \"SZ\" in stock_name],\n            }\n\n            lib_name = get_library_name(doc_type)\n            a = Arctic(ARCTIC_SRV)\n            # a.initialize_library(lib_name, lib_type=CHUNK_STORE)\n\n            stock_name_exist = a[lib_name].list_symbols()\n            lib = a[lib_name]\n            
initialize_count = 0\n            for stock_name in stock_name_list:\n                if stock_name not in stock_name_exist:\n                    initialize_count += 1\n                    # Write a placeholder row for stocks that are not in the library yet\n                    pdf = pd.DataFrame(index=[pd.Timestamp(\"1900-01-01\")])\n                    pdf.index.name = \"date\"  # a column or index named date is necessary\n                    lib.write(stock_name, pdf)\n            print(\"initialize count: {}\".format(initialize_count))\n            print(\"tasks: {}\".format(date_list))\n            a.reset()\n\n            # date_list = [files.split(\"_\")[0] for files in os.listdir(\"./raw_data_price\") if \"tar\" in files]\n            # print(len(date_list))\n            date_list = [\"20201231\"]  # for test: only import this single date\n            Parallel(n_jobs=min(2, len(date_list)))(\n                delayed(add_data)(date, doc_type, stock_name_dict) for date in date_list\n            )\n\n\nif __name__ == \"__main__\":\n    fire.Fire(DSCreator)\n"
  },
  {
    "path": "examples/orderbook_data/example.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom arctic.arctic import Arctic\nimport qlib\nfrom qlib.data import D\nimport unittest\n\n\nclass TestClass(unittest.TestCase):\n    \"\"\"\n    Useful commands\n    - run all tests: pytest examples/orderbook_data/example.py\n    - run a single test:  pytest -s --pdb --disable-warnings examples/orderbook_data/example.py::TestClass::test_basic01\n    \"\"\"\n\n    def setUp(self):\n        \"\"\"\n        Configure for arctic\n        \"\"\"\n        provider_uri = \"~/.qlib/qlib_data/yahoo_cn_1min\"\n        qlib.init(\n            provider_uri=provider_uri,\n            mem_cache_size_limit=1024**3 * 2,\n            mem_cache_type=\"sizeof\",\n            kernels=1,\n            expression_provider={\"class\": \"LocalExpressionProvider\", \"kwargs\": {\"time2idx\": False}},\n            feature_provider={\n                \"class\": \"ArcticFeatureProvider\",\n                \"module_path\": \"qlib.contrib.data.data\",\n                \"kwargs\": {\"uri\": \"127.0.0.1\"},\n            },\n            dataset_provider={\n                \"class\": \"LocalDatasetProvider\",\n                \"kwargs\": {\n                    \"align_time\": False,  # Order book is not fixed, so it can't be align to a shared fixed frequency calendar\n                },\n            },\n        )\n        # self.stocks_list = [\"SH600519\"]\n        self.stocks_list = [\"SZ000725\"]\n\n    def test_basic(self):\n        # NOTE: this data contains a lot of zeros in $askX and $bidX\n        df = D.features(\n            self.stocks_list,\n            fields=[\"$ask1\", \"$ask2\", \"$bid1\", \"$bid2\"],\n            freq=\"ticks\",\n            start_time=\"20201230\",\n            end_time=\"20210101\",\n        )\n        print(df)\n\n    def test_basic_without_time(self):\n        df = D.features(self.stocks_list, fields=[\"$ask1\"], freq=\"ticks\")\n        print(df)\n\n    def 
test_basic01(self):\n        df = D.features(\n            self.stocks_list,\n            fields=[\"TResample($ask1, '1min', 'last')\"],\n            freq=\"ticks\",\n            start_time=\"20201230\",\n            end_time=\"20210101\",\n        )\n        print(df)\n\n    def test_basic02(self):\n        df = D.features(\n            self.stocks_list,\n            fields=[\"$function_code\"],\n            freq=\"transaction\",\n            start_time=\"20201230\",\n            end_time=\"20210101\",\n        )\n        print(df)\n\n    def test_basic03(self):\n        df = D.features(\n            self.stocks_list,\n            fields=[\"$function_code\"],\n            freq=\"order\",\n            start_time=\"20201230\",\n            end_time=\"20210101\",\n        )\n        print(df)\n\n    # Here are some popular expressions for high-frequency\n    # 1) some shared expression\n    expr_sum_buy_ask_1 = \"(TResample($ask1, '1min', 'last') + TResample($bid1, '1min', 'last'))\"\n    total_volume = (\n        \"TResample(\"\n        + \"+\".join([f\"${name}{i}\" for i in range(1, 11) for name in [\"asize\", \"bsize\"]])\n        + \", '1min', 'sum')\"\n    )\n\n    @staticmethod\n    def total_func(name, method):\n        return \"TResample(\" + \"+\".join([f\"${name}{i}\" for i in range(1, 11)]) + \",'1min', '{}')\".format(method)\n\n    def test_exp_01(self):\n        exprs = []\n        names = []\n        for name in [\"asize\", \"bsize\"]:\n            for i in range(1, 11):\n                exprs.append(f\"TResample(${name}{i}, '1min', 'mean') / ({self.total_volume})\")\n                names.append(f\"v_{name}_{i}\")\n        df = D.features(self.stocks_list, fields=exprs, freq=\"ticks\")\n        df.columns = names\n        print(df)\n\n    # 2) some often used papers;\n    def test_exp_02(self):\n        spread_func = (\n            lambda index: f\"2 * TResample($ask{index} - $bid{index}, '1min', 'last') / {self.expr_sum_buy_ask_1}\"\n        )\n       
 mid_func = (\n            lambda index: f\"2 * TResample(($ask{index} + $bid{index})/2, '1min', 'last') / {self.expr_sum_buy_ask_1}\"\n        )\n\n        exprs = []\n        names = []\n        for i in range(1, 11):\n            exprs.extend([spread_func(i), mid_func(i)])\n            names.extend([f\"p_spread_{i}\", f\"p_mid_{i}\"])\n        df = D.features(self.stocks_list, fields=exprs, freq=\"ticks\")\n        df.columns = names\n        print(df)\n\n    def test_exp_03(self):\n        expr3_func1 = (\n            lambda name, index_left, index_right: f\"2 * TResample(Abs(${name}{index_left} - ${name}{index_right}), '1min', 'last') / {self.expr_sum_buy_ask_1}\"\n        )\n        for name in [\"ask\", \"bid\"]:\n            for i in range(1, 10):\n                exprs = [expr3_func1(name, i + 1, i)]\n                names = [f\"p_diff_{name}_{i}_{i+1}\"]\n        exprs.extend([expr3_func1(\"ask\", 10, 1), expr3_func1(\"bid\", 1, 10)])\n        names.extend([\"p_diff_ask_10_1\", \"p_diff_bid_1_10\"])\n        df = D.features(self.stocks_list, fields=exprs, freq=\"ticks\")\n        df.columns = names\n        print(df)\n\n    def test_exp_04(self):\n        exprs = []\n        names = []\n        for name in [\"asize\", \"bsize\"]:\n            exprs.append(f\"(({ self.total_func(name, 'mean')}) / 10) / {self.total_volume}\")\n            names.append(f\"v_avg_{name}\")\n\n        df = D.features(self.stocks_list, fields=exprs, freq=\"ticks\")\n        df.columns = names\n        print(df)\n\n    def test_exp_05(self):\n        exprs = [\n            f\"2 * Sub({ self.total_func('ask', 'last')}, {self.total_func('bid', 'last')})/{self.expr_sum_buy_ask_1}\",\n            f\"Sub({ self.total_func('asize', 'mean')}, {self.total_func('bsize', 'mean')})/{self.total_volume}\",\n        ]\n        names = [\"p_accspread\", \"v_accspread\"]\n\n        df = D.features(self.stocks_list, fields=exprs, freq=\"ticks\")\n        df.columns = names\n        print(df)\n\n  
  #  (p|v)_diff_(ask|bid|asize|bsize)_(time_interval)\n    def test_exp_06(self):\n        t = 3\n        expr6_price_func = (\n            lambda name, index, method: f'2 * (TResample(${name}{index}, \"{t}s\", \"{method}\") - Ref(TResample(${name}{index}, \"{t}s\", \"{method}\"), 1)) / {t}'\n        )\n        exprs = []\n        names = []\n        for i in range(1, 11):\n            for name in [\"bid\", \"ask\"]:\n                exprs.append(\n                    f\"TResample({expr6_price_func(name, i, 'last')}, '1min', 'mean') / {self.expr_sum_buy_ask_1}\"\n                )\n                names.append(f\"p_diff_{name}{i}_{t}s\")\n\n        for i in range(1, 11):\n            for name in [\"asize\", \"bsize\"]:\n                exprs.append(f\"TResample({expr6_price_func(name, i, 'mean')}, '1min', 'mean') / {self.total_volume}\")\n                names.append(f\"v_diff_{name}{i}_{t}s\")\n\n        df = D.features(self.stocks_list, fields=exprs, freq=\"ticks\")\n        df.columns = names\n        print(df)\n\n    # TODOs:\n    # The following expressions may be implemented in the future\n    # expr7_2 = lambda funccode, bsflag, time_interval: \\\n    #     \"TResample(TRolling(TEq(@transaction.function_code,  {}) & TEq(@transaction.bs_flag ,{}), '{}s', 'sum') / \\\n    #     TRolling(@transaction.function_code, '{}s', 'count') , '1min', 'mean')\".format(ord(funccode), bsflag,time_interval,time_interval)\n    # create_dataset(7, \"SH600000\", [expr7_2(\"C\")] + [expr7(funccode, ordercode) for funccode in ['B','S'] for ordercode in ['0','1']])\n    # create_dataset(7,  [\"SH600000\"], [expr7_2(\"C\", 48)] )\n\n    @staticmethod\n    def expr7_init(funccode, ordercode, time_interval):\n        # NOTE: based on order frequency (i.e. freq=\"order\")\n        return f\"Rolling(Eq($function_code,  {ord(funccode)}) & Eq($order_kind ,{ord(ordercode)}), '{time_interval}s', 'sum') / Rolling($function_code, '{time_interval}s', 'count')\"\n\n    # (la|lb|ma|mb|ca|cb)_intensity_(time_interval)\n    def test_exp_07_1(self):\n        # NOTE: based on transaction frequency (i.e. freq=\"transaction\")\n        expr7_3 = (\n            lambda funccode, code, time_interval: f\"TResample(Rolling(Eq($function_code,  {ord(funccode)}) & {code}($ask_order, $bid_order) , '{time_interval}s', 'sum')   / Rolling($function_code, '{time_interval}s', 'count') , '1min', 'mean')\"\n        )\n\n        exprs = [expr7_3(\"C\", \"Gt\", \"3\"), expr7_3(\"C\", \"Lt\", \"3\")]\n        names = [\"ca_intensity_3s\", \"cb_intensity_3s\"]\n\n        df = D.features(self.stocks_list, fields=exprs, freq=\"transaction\")\n        df.columns = names\n        print(df)\n\n    trans_dict = {\"B\": \"a\", \"S\": \"b\", \"0\": \"l\", \"1\": \"m\"}\n\n    def test_exp_07_2(self):\n        # NOTE: based on order frequency (i.e. freq=\"order\")\n        expr7 = (\n            lambda funccode, ordercode, time_interval: f\"TResample({self.expr7_init(funccode, ordercode, time_interval)}, '1min', 'mean')\"\n        )\n\n        exprs = []\n        names = []\n        for funccode in [\"B\", \"S\"]:\n            for ordercode in [\"0\", \"1\"]:\n                exprs.append(expr7(funccode, ordercode, \"3\"))\n                names.append(self.trans_dict[ordercode] + self.trans_dict[funccode] + \"_intensity_3s\")\n        df = D.features(self.stocks_list, fields=exprs, freq=\"order\")\n        df.columns = names\n        print(df)\n\n    @staticmethod\n    def expr7_3_init(funccode, code, time_interval):\n        # NOTE: It depends on transaction frequency\n        return f\"Rolling(Eq($function_code,  {ord(funccode)}) & {code}($ask_order, $bid_order) , '{time_interval}s', 'sum') / Rolling($function_code, '{time_interval}s', 'count')\"\n\n    # 
(la|lb|ma|mb|ca|cb)_relative_intensity_(time_interval_small)_(time_interval_big)\n    def test_exp_08_1(self):\n        expr8_1 = (\n            lambda funccode, ordercode, time_interval_short, time_interval_long: f\"TResample(Gt({self.expr7_init(funccode, ordercode, time_interval_short)},{self.expr7_init(funccode, ordercode, time_interval_long)}), '1min', 'mean')\"\n        )\n\n        exprs = []\n        names = []\n        for funccode in [\"B\", \"S\"]:\n            for ordercode in [\"0\", \"1\"]:\n                exprs.append(expr8_1(funccode, ordercode, \"10\", \"900\"))\n                names.append(self.trans_dict[ordercode] + self.trans_dict[funccode] + \"_relative_intensity_10s_900s\")\n\n        df = D.features(self.stocks_list, fields=exprs, freq=\"order\")\n        df.columns = names\n        print(df)\n\n    def test_exp_08_2(self):\n        # NOTE: It depends on transaction frequency\n        expr8_2 = (\n            lambda funccode, ordercode, time_interval_short, time_interval_long: f\"TResample(Gt({self.expr7_3_init(funccode, ordercode, time_interval_short)},{self.expr7_3_init(funccode, ordercode, time_interval_long)}), '1min', 'mean')\"\n        )\n\n        exprs = [expr8_2(\"C\", \"Gt\", \"10\", \"900\"), expr8_2(\"C\", \"Lt\", \"10\", \"900\")]\n        names = [\"ca_relative_intensity_10s_900s\", \"cb_relative_intensity_10s_900s\"]\n\n        df = D.features(self.stocks_list, fields=exprs, freq=\"transaction\")\n        df.columns = names\n        print(df)\n\n    ## v9(la|lb|ma|mb|ca|cb)_diff_intensity_(time_interval1)_(time_interval2)\n    # 1) calculating the original data\n    # 2) Resample data to 3s and calculate the changing rate\n    # 3) Resample data to 1min\n\n    def test_exp_09_trans(self):\n        exprs = [\n            f'TResample(Div(Sub(TResample({self.expr7_3_init(\"C\", \"Gt\", \"3\")}, \"3s\", \"last\"), Ref(TResample({self.expr7_3_init(\"C\", \"Gt\", \"3\")}, \"3s\",\"last\"), 1)), 3), \"1min\", \"mean\")',\n           
 f'TResample(Div(Sub(TResample({self.expr7_3_init(\"C\", \"Lt\", \"3\")}, \"3s\", \"last\"), Ref(TResample({self.expr7_3_init(\"C\", \"Lt\", \"3\")}, \"3s\",\"last\"), 1)), 3), \"1min\", \"mean\")',\n        ]\n        names = [\"ca_diff_intensity_3s_3s\", \"cb_diff_intensity_3s_3s\"]\n        df = D.features(self.stocks_list, fields=exprs, freq=\"transaction\")\n        df.columns = names\n        print(df)\n\n    def test_exp_09_order(self):\n        exprs = []\n        names = []\n        for funccode in [\"B\", \"S\"]:\n            for ordercode in [\"0\", \"1\"]:\n                exprs.append(\n                    f'TResample(Div(Sub(TResample({self.expr7_init(funccode, ordercode, \"3\")}, \"3s\", \"last\"), Ref(TResample({self.expr7_init(funccode, ordercode, \"3\")},\"3s\", \"last\"), 1)), 3) ,\"1min\", \"mean\")'\n                )\n                names.append(self.trans_dict[ordercode] + self.trans_dict[funccode] + \"_diff_intensity_3s_3s\")\n        df = D.features(self.stocks_list, fields=exprs, freq=\"order\")\n        df.columns = names\n        print(df)\n\n    def test_exp_10(self):\n        exprs = []\n        names = []\n        for i in [5, 10, 30, 60]:\n            exprs.append(\n                f'TResample(Ref(TResample($ask1 + $bid1, \"1s\", \"ffill\"), {-i}) / TResample($ask1 + $bid1, \"1s\", \"ffill\") - 1, \"1min\", \"mean\" )'\n            )\n            names.append(f\"lag_{i}_change_rate\")\n        df = D.features(self.stocks_list, fields=exprs, freq=\"ticks\")\n        df.columns = names\n        print(df)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "examples/portfolio/README.md",
    "content": "# Portfolio Optimization Strategy\n\n## Introduction\n\nIn `qlib/examples/benchmarks` we have various **alpha** models that predict\nthe stock returns. We also use a simple rule based `TopkDropoutStrategy` to\nevaluate the investing performance of these models. However, such a strategy\nis too simple to control the portfolio risk like correlation and volatility.\n\nTo this end, an optimization based strategy should be used to for the\ntrade-off between return and risk. In this doc, we will show how to use\n`EnhancedIndexingStrategy` to maximize portfolio return while minimizing\ntracking error relative to a benchmark.\n\n\n## Preparation\n\nWe use China stock market data for our example.\n\n1. Prepare CSI300 weight:\n\n   ```bash\n   wget https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/csi300_weight.zip\n   unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip\n   rm -f csi300_weight.zip\n   ```\n   NOTE:  We don't find any public free resource to get the weight in the benchmark. To run the example, we manually create this weight data.\n\n2. Prepare risk model data:\n\n   ```bash\n   python prepare_riskdata.py\n   ```\n\nHere we use a **Statistical Risk Model** implemented in `qlib.model.riskmodel`.\nHowever users are strongly recommended to use other risk models for better quality:\n* **Fundamental Risk Model** like MSCI BARRA\n* [Deep Risk Model](https://arxiv.org/abs/2107.05201)\n\n\n## End-to-End Workflow\n\nYou can finish workflow with `EnhancedIndexingStrategy` by running\n`qrun config_enhanced_indexing.yaml`.\n\nIn this config, we mainly changed the strategy section compared to\n`qlib/examples/benchmarks/workflow_config_lightgbm_Alpha158.yaml`.\n"
  },
  {
    "path": "examples/portfolio/config_enhanced_indexing.yaml",
    "content": "qlib_init:\n    provider_uri: \"~/.qlib/qlib_data/cn_data\"\n    region: cn\nmarket: &market csi300\nbenchmark: &benchmark SH000300\ndata_handler_config: &data_handler_config\n    start_time: 2008-01-01\n    end_time: 2020-08-01\n    fit_start_time: 2008-01-01\n    fit_end_time: 2014-12-31\n    instruments: *market\nport_analysis_config: &port_analysis_config\n    strategy:\n        class: EnhancedIndexingStrategy\n        module_path: qlib.contrib.strategy\n        kwargs:\n            model: <MODEL>\n            dataset: <DATASET>\n            riskmodel_root: ./riskdata\n    backtest:\n        start_time: 2017-01-01\n        end_time: 2020-08-01\n        account: 100000000\n        benchmark: *benchmark\n        exchange_kwargs:\n            limit_threshold: 0.095\n            deal_price: close\n            open_cost: 0.0005\n            close_cost: 0.0015\n            min_cost: 5\ntask:\n    model:\n        class: LGBModel\n        module_path: qlib.contrib.model.gbdt\n        kwargs:\n            loss: mse\n            colsample_bytree: 0.8879\n            learning_rate: 0.2\n            subsample: 0.8789\n            lambda_l1: 205.6999\n            lambda_l2: 580.9768\n            max_depth: 8\n            num_leaves: 210\n            num_threads: 20\n    dataset:\n        class: DatasetH\n        module_path: qlib.data.dataset\n        kwargs:\n            handler:\n                class: Alpha158\n                module_path: qlib.contrib.data.handler\n                kwargs: *data_handler_config\n            segments:\n                train: [2008-01-01, 2014-12-31]\n                valid: [2015-01-01, 2016-12-31]\n                test: [2017-01-01, 2020-08-01]\n    record:\n        - class: SignalRecord\n          module_path: qlib.workflow.record_temp\n          kwargs:\n            model: <MODEL>\n            dataset: <DATASET>\n        - class: SigAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs:\n           
 ana_long_short: False\n            ann_scaler: 252\n        - class: PortAnaRecord\n          module_path: qlib.workflow.record_temp\n          kwargs:\n            config: *port_analysis_config\n"
  },
  {
    "path": "examples/portfolio/prepare_riskdata.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport os\nimport numpy as np\nimport pandas as pd\n\nfrom qlib.data import D\nfrom qlib.model.riskmodel import StructuredCovEstimator\n\n\ndef prepare_data(riskdata_root=\"./riskdata\", T=240, start_time=\"2016-01-01\"):\n    universe = D.features(D.instruments(\"csi300\"), [\"$close\"], start_time=start_time).swaplevel().sort_index()\n\n    price_all = (\n        D.features(D.instruments(\"all\"), [\"$close\"], start_time=start_time).squeeze().unstack(level=\"instrument\")\n    )\n\n    # StructuredCovEstimator is a statistical risk model\n    riskmodel = StructuredCovEstimator()\n\n    for i in range(T - 1, len(price_all)):\n        date = price_all.index[i]\n        ref_date = price_all.index[i - T + 1]\n\n        print(date)\n\n        codes = universe.loc[date].index\n        price = price_all.loc[ref_date:date, codes]\n\n        # calculate return and remove extreme return\n        ret = price.pct_change()\n        ret.clip(ret.quantile(0.025), ret.quantile(0.975), axis=1, inplace=True)\n\n        # run risk model\n        F, cov_b, var_u = riskmodel.predict(ret, is_price=False, return_decomposed_components=True)\n\n        # save risk data\n        root = riskdata_root + \"/\" + date.strftime(\"%Y%m%d\")\n        os.makedirs(root, exist_ok=True)\n\n        pd.DataFrame(F, index=codes).to_pickle(root + \"/factor_exp.pkl\")\n        pd.DataFrame(cov_b).to_pickle(root + \"/factor_cov.pkl\")\n        # for specific_risk we follow the convention to save volatility\n        pd.Series(np.sqrt(var_u), index=codes).to_pickle(root + \"/specific_risk.pkl\")\n\n\nif __name__ == \"__main__\":\n    import qlib\n\n    qlib.init(provider_uri=\"~/.qlib/qlib_data/cn_data\")\n\n    prepare_data()\n"
  },
  {
    "path": "examples/rl_order_execution/README.md",
    "content": "# RL Example for Order Execution\n\nThis folder contains an example of Reinforcement Learning (RL) workflows for the order execution scenario, including both training and backtest workflows.\n\n## Data Processing\n\n### Get Data\n\n```\npython -m qlib.cli.data qlib_data --target_dir ./data/bin --region hs300 --interval 5min\n```\n\n### Generate Pickle-Style Data\n\nTo run the code in this example, we need data in pickle format. To generate it, run the following commands (this may take a few minutes):\n\n[//]: # (TODO: Instead of dumping dataframes in different formats &#40;like `_gen_dataset` and `_gen_day_dataset` in `qlib/contrib/data/highfreq_provider.py`&#41;, we encourage implementing different subclasses of `Dataset` and `DataHandler`. This will keep the workflow cleaner and the interfaces more consistent, and move all the complexity into the subclasses.)\n\n```\npython scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml\npython scripts/gen_training_orders.py\npython scripts/merge_orders.py\n```\n\nWhen finished, the structure under `data/` should be:\n\n```\ndata\n├── bin\n├── orders\n└── pickle\n```\n\n## Training\n\nEach training task is specified by a config file. The config file for task `TASKNAME` is `exp_configs/train_TASKNAME.yml`. This example provides two training tasks:\n\n- **PPO**: Method proposed in the IJCAI 2020 paper \"[An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization](https://www.ijcai.org/proceedings/2020/0627.pdf)\".\n- **OPDS**: Method proposed in the AAAI 2021 paper \"[Universal Trading for Order Execution with Oracle Policy Distillation](https://arxiv.org/abs/2103.10860)\".\n\nThe main difference between these two methods lies in their reward functions.
 Please see their config files for details.\n\nTaking OPDS as an example, run the following command to start the training workflow:\n\n```\npython -m qlib.rl.contrib.train_onpolicy --config_path exp_configs/train_opds.yml --run_backtest\n```\n\nMetrics, logs, and checkpoints will be stored under `outputs/opds` (configured by `exp_configs/train_opds.yml`).\n\n## Backtest\n\nOnce the training workflow has completed, the trained model can be used in the backtest workflow. Still taking OPDS as an example, the latest checkpoint of the model can be found at `outputs/opds/checkpoints/latest.pth` after training. To run the backtest workflow:\n\n1. Uncomment the `weight_file` parameter in `exp_configs/backtest_opds.yml` (it is commented out by default). While it is possible to run the backtest workflow without setting a checkpoint, the results would then come from a randomly initialized model and would be meaningless.\n2. Run `python -m qlib.rl.contrib.backtest --config_path exp_configs/backtest_opds.yml`.\n\nThe backtest result is stored in `outputs/checkpoints/backtest_result.csv`.\n\nIn addition to OPDS and PPO, we also provide TWAP ([Time-weighted average price](https://en.wikipedia.org/wiki/Time-weighted_average_price)) as a weak baseline. The config file for TWAP is `exp_configs/backtest_twap.yml`.\n\n### Gap between backtest and training pipeline's testing\n\nNote that the results of the backtest process may differ from the results of the testing process used during training.\nThis is because different simulators are used to simulate market conditions during training and backtesting.\nIn the training pipeline, a simplified simulator called `SingleAssetOrderExecutionSimple` is used for efficiency reasons.\n`SingleAssetOrderExecutionSimple` places no restrictions on trading amounts:\nan order of any amount can be completely executed.\nHowever, during backtesting, a more realistic simulator called `SingleAssetOrderExecution` is used.
\nIt takes practical real-world constraints into account (for example, the trading volume must be a multiple of the smallest trading unit).\nAs a result, the amount of an order that is actually executed during backtesting may differ from the amount expected to be executed.\n\nIf you would like to obtain results that are exactly the same as those obtained during testing in the training pipeline, you can run the training pipeline with only the backtest phase.\nTo do this:\n- Modify the training config and add the path of the checkpoint you want to use (see the example below).\n- Run `python -m qlib.rl.contrib.train_onpolicy --config_path PATH/TO/CONFIG --run_backtest --no_training`\n\n```yaml\n...\npolicy:\n  class: PPO  # PPO, DQN\n  kwargs:\n    lr: 0.0001\n    weight_file: PATH/TO/CHECKPOINT\n  module_path: qlib.rl.order_execution.policy\n...\n```\n\n## Benchmarks (TBD)\n\nTo accurately evaluate the performance of models trained with Reinforcement Learning algorithms, it's best to run experiments multiple times and compute the average performance across all trials. However, given the time-consuming nature of model training, this is not always feasible. An alternative approach is to run each training task only once and select the 10 checkpoints with the highest validation performance to simulate multiple trials. In this example, we use \"Price Advantage (PA)\" as the metric for selecting these checkpoints. The average performance of these 10 checkpoints on the test set is as follows:\n\n| **Model**                   | **PA mean with std.** |\n|-----------------------------|-----------------------|\n| OPDS (with PPO policy)      |  0.4785 ± 0.7815      |\n| OPDS (with DQN policy)      | -0.0114 ± 0.5780      |\n| PPO                         | -1.0935 ± 0.0922      |\n| TWAP                        |   ≈ 0.0 ± 0.0         |\n\nThe table above also includes TWAP as a rule-based baseline.
 The ideal PA of TWAP is 0.0. However, in this example, order execution is divided into two steps: first, the order is split equally across half-hour intervals, and then equally across the five-minute intervals within each half hour. Since trading is forbidden during the last five minutes of the day, this approach differs slightly from a traditional TWAP over a full day (the last \"half hour\" is missing 5 minutes). Therefore, the PA of TWAP is close to, but not necessarily exactly, 0.0. To verify this, you can run a TWAP backtest and check the results.\n"
  },
  {
    "path": "examples/rl_order_execution/exp_configs/backtest_opds.yml",
    "content": "order_file: ./data/orders/test_orders.pkl\nstart_time: \"9:30\"\nend_time: \"14:54\"\ndata_granularity: \"5min\"\nqlib:\n  provider_uri_5min: ./data/bin/\nexchange:\n  limit_threshold: null\n  deal_price: [\"$close\", \"$close\"]\n  volume_threshold: null\nstrategies:\n  1day:\n    class: SAOEIntStrategy\n    kwargs:\n      data_granularity: 5\n      action_interpreter:\n        class: CategoricalActionInterpreter\n        kwargs:\n          max_step: 8\n          values: 4\n        module_path: qlib.rl.order_execution.interpreter\n      network:\n        class: Recurrent\n        kwargs: {}\n        module_path: qlib.rl.order_execution.network\n      policy:\n        class: PPO  # PPO, DQN\n        kwargs:\n          lr: 0.0001\n          # Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use.\n          # weight_file: outputs/opds/checkpoints/latest.pth\n        module_path: qlib.rl.order_execution.policy\n      state_interpreter:\n        class: FullHistoryStateInterpreter\n        kwargs:\n          data_dim: 5\n          data_ticks: 48\n          max_step: 8\n          processed_data_provider:\n            class: HandlerProcessedDataProvider\n            kwargs:\n              data_dir: ./data/pickle/\n              feature_columns_today: [\"$high\", \"$low\", \"$open\", \"$close\", \"$volume\"]\n              feature_columns_yesterday: [\"$high_1\", \"$low_1\", \"$open_1\", \"$close_1\", \"$volume_1\"]\n            module_path: qlib.rl.data.native\n        module_path: qlib.rl.order_execution.interpreter\n    module_path: qlib.rl.order_execution.strategy\n  30min:\n    class: TWAPStrategy\n    kwargs: {}\n    module_path: qlib.contrib.strategy.rule_strategy\nconcurrency: 16\noutput_dir: outputs/opds/\n"
  },
  {
    "path": "examples/rl_order_execution/exp_configs/backtest_ppo.yml",
    "content": "order_file: ./data/orders/test_orders.pkl\nstart_time: \"9:30\"\nend_time: \"14:54\"\ndata_granularity: \"5min\"\nqlib:\n  provider_uri_5min: ./data/bin/\nexchange:\n  limit_threshold: null\n  deal_price: [\"$close\", \"$close\"]\n  volume_threshold: null\nstrategies:\n  1day:\n    class: SAOEIntStrategy\n    kwargs:\n      data_granularity: 5\n      action_interpreter:\n        class: CategoricalActionInterpreter\n        kwargs:\n          max_step: 8\n          values: 4\n        module_path: qlib.rl.order_execution.interpreter\n      network:\n        class: Recurrent\n        kwargs: {}\n        module_path: qlib.rl.order_execution.network\n      policy:\n        class: PPO  # PPO, DQN\n        kwargs:\n          lr: 0.0001\n          # Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use.\n          # weight_file: outputs/ppo/checkpoints/latest.pth\n        module_path: qlib.rl.order_execution.policy\n      state_interpreter:\n        class: FullHistoryStateInterpreter\n        kwargs:\n          data_dim: 5\n          data_ticks: 48\n          max_step: 8\n          processed_data_provider:\n            class: HandlerProcessedDataProvider\n            kwargs:\n              data_dir: ./data/pickle/\n              feature_columns_today: [\"$high\", \"$low\", \"$open\", \"$close\", \"$volume\"]\n              feature_columns_yesterday: [\"$high_1\", \"$low_1\", \"$open_1\", \"$close_1\", \"$volume_1\"]\n            module_path: qlib.rl.data.native\n        module_path: qlib.rl.order_execution.interpreter\n    module_path: qlib.rl.order_execution.strategy\n  30min:\n    class: TWAPStrategy\n    kwargs: {}\n    module_path: qlib.contrib.strategy.rule_strategy\nconcurrency: 16\noutput_dir: outputs/ppo/\n"
  },
  {
    "path": "examples/rl_order_execution/exp_configs/backtest_twap.yml",
    "content": "order_file: ./data/orders/test_orders.pkl\nstart_time: \"9:30\"\nend_time: \"14:54\"\ndata_granularity: \"5min\"\nqlib:\n  provider_uri_5min: ./data/bin/\nexchange:\n  limit_threshold: null\n  deal_price: [\"$close\", \"$close\"]\n  volume_threshold: null\nstrategies:\n  1day:\n    class: TWAPStrategy\n    kwargs: {}\n    module_path: qlib.contrib.strategy.rule_strategy\n  30min:\n    class: TWAPStrategy\n    kwargs: {}\n    module_path: qlib.contrib.strategy.rule_strategy\nconcurrency: 16\noutput_dir: outputs/twap/\n"
  },
  {
    "path": "examples/rl_order_execution/exp_configs/train_opds.yml",
    "content": "simulator:\n  data_granularity: 5\n  time_per_step: 30\n  vol_limit: null\nenv:\n  concurrency: 32\n  parallel_mode: dummy\naction_interpreter:\n  class: CategoricalActionInterpreter\n  kwargs:\n    values: 4\n    max_step: 8\n  module_path: qlib.rl.order_execution.interpreter\nstate_interpreter:\n  class: FullHistoryStateInterpreter\n  kwargs:\n    data_dim: 5\n    data_ticks: 48  # 48 = 240 min / 5 min\n    max_step: 8\n    processed_data_provider:\n      class: HandlerProcessedDataProvider\n      kwargs:\n        data_dir: ./data/pickle/\n        feature_columns_today: [\"$high\", \"$low\", \"$open\", \"$close\", \"$volume\"]\n        feature_columns_yesterday: [\"$high_1\", \"$low_1\", \"$open_1\", \"$close_1\", \"$volume_1\"]\n        backtest: false\n      module_path: qlib.rl.data.native\n  module_path: qlib.rl.order_execution.interpreter\nreward:\n  class: PAPenaltyReward\n  kwargs:\n    penalty: 4.0\n    scale: 0.01\n  module_path: qlib.rl.order_execution.reward\ndata:\n  source:\n    order_dir: ./data/orders\n    feature_root_dir: ./data/pickle/\n    feature_columns_today: [\"$close0\", \"$volume0\"]\n    feature_columns_yesterday: []\n    total_time: 240\n    default_start_time_index: 0\n    default_end_time_index: 235\n    proc_data_dim: 5\n  num_workers: 0\n  queue_size: 20\nnetwork:\n  class: Recurrent\n  module_path: qlib.rl.order_execution.network\npolicy:\n  class: PPO  # PPO, DQN\n  kwargs:\n    lr: 0.0001\n  module_path: qlib.rl.order_execution.policy\nruntime:\n  seed: 42\n  use_cuda: false\ntrainer:\n  max_epoch: 500\n  repeat_per_collect: 25\n  earlystop_patience: 50\n  episode_per_collect: 10000\n  batch_size: 1024\n  val_every_n_epoch: 4\n  checkpoint_path: ./outputs/opds\n  checkpoint_every_n_iters: 1\n"
  },
  {
    "path": "examples/rl_order_execution/exp_configs/train_ppo.yml",
    "content": "simulator:\n  data_granularity: 5\n  time_per_step: 30\n  vol_limit: null\nenv:\n  concurrency: 32\n  parallel_mode: dummy\naction_interpreter:\n  class: CategoricalActionInterpreter\n  kwargs:\n    values: 4\n    max_step: 8\n  module_path: qlib.rl.order_execution.interpreter\nstate_interpreter:\n  class: FullHistoryStateInterpreter\n  kwargs:\n    data_dim: 5\n    data_ticks: 48  # 48 = 240 min / 5 min\n    max_step: 8\n    processed_data_provider:\n      class: HandlerProcessedDataProvider\n      kwargs:\n        data_dir: ./data/pickle/\n        feature_columns_today: [\"$high\", \"$low\", \"$open\", \"$close\", \"$volume\"]\n        feature_columns_yesterday: [\"$high_1\", \"$low_1\", \"$open_1\", \"$close_1\", \"$volume_1\"]\n        backtest: false\n      module_path: qlib.rl.data.native\n  module_path: qlib.rl.order_execution.interpreter\nreward:\n  class: PPOReward\n  kwargs:\n    max_step: 8\n    start_time_index: 0\n    end_time_index: 46  # 46 = (240 - 5) min / 5 min - 1\n  module_path: qlib.rl.order_execution.reward\ndata:\n  source:\n    order_dir: ./data/orders\n    feature_root_dir: ./data/pickle/\n    feature_columns_today: [\"$close0\", \"$volume0\"]\n    feature_columns_yesterday: []\n    total_time: 240\n    default_start_time_index: 0\n    default_end_time_index: 235\n    proc_data_dim: 5\n  num_workers: 0\n  queue_size: 20\nnetwork:\n  class: Recurrent\n  module_path: qlib.rl.order_execution.network\npolicy:\n  class: PPO  # PPO, DQN\n  kwargs:\n    lr: 0.0001\n  module_path: qlib.rl.order_execution.policy\nruntime:\n  seed: 42\n  use_cuda: false\ntrainer:\n  max_epoch: 500\n  repeat_per_collect: 25\n  earlystop_patience: 50\n  episode_per_collect: 10000\n  batch_size: 1024\n  val_every_n_epoch: 4\n  checkpoint_path: ./outputs/ppo\n  checkpoint_every_n_iters: 1\n"
  },
  {
    "path": "examples/rl_order_execution/scripts/gen_pickle_data.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport yaml\nimport argparse\nimport os\nimport shutil\nfrom copy import deepcopy\n\nfrom qlib.contrib.data.highfreq_provider import HighFreqProvider\n\nloader = yaml.FullLoader\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-c\", \"--config\", type=str, default=\"config.yml\")\n    parser.add_argument(\"-d\", \"--dest\", type=str, default=\".\")\n    parser.add_argument(\"-s\", \"--split\", type=str, choices=[\"none\", \"date\", \"stock\", \"both\"], default=\"stock\")\n    args = parser.parse_args()\n\n    conf = yaml.load(open(args.config), Loader=loader)\n\n    for k, v in conf.items():\n        if isinstance(v, dict) and \"path\" in v:\n            v[\"path\"] = os.path.join(args.dest, v[\"path\"])\n    provider = HighFreqProvider(**conf)\n\n    # Gen dataframe\n    if \"feature_conf\" in conf:\n        feature = provider._gen_dataframe(deepcopy(provider.feature_conf))\n    if \"backtest_conf\" in conf:\n        backtest = provider._gen_dataframe(deepcopy(provider.backtest_conf))\n\n    provider.feature_conf[\"path\"] = os.path.splitext(provider.feature_conf[\"path\"])[0] + \"/\"\n    provider.backtest_conf[\"path\"] = os.path.splitext(provider.backtest_conf[\"path\"])[0] + \"/\"\n    # Split by date\n    if args.split == \"date\" or args.split == \"both\":\n        provider._gen_day_dataset(deepcopy(provider.feature_conf), \"feature\")\n        provider._gen_day_dataset(deepcopy(provider.backtest_conf), \"backtest\")\n\n    # Split by stock\n    if args.split == \"stock\" or args.split == \"both\":\n        provider._gen_stock_dataset(deepcopy(provider.feature_conf), \"feature\")\n        provider._gen_stock_dataset(deepcopy(provider.backtest_conf), \"backtest\")\n\n    shutil.rmtree(\"stat/\", ignore_errors=True)\n"
  },
  {
    "path": "examples/rl_order_execution/scripts/gen_training_orders.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport numpy as np\nimport pandas as pd\n\nfrom pathlib import Path\n\nDATA_PATH = Path(os.path.join(\"data\", \"pickle\", \"backtest\"))\nOUTPUT_PATH = Path(os.path.join(\"data\", \"orders\"))\n\n\ndef generate_order(stock: str, start_idx: int, end_idx: int) -> bool:\n    dataset = pd.read_pickle(DATA_PATH / f\"{stock}.pkl\")\n    df = dataset.handler.fetch(level=None).reset_index()\n    if len(df) == 0 or df.isnull().values.any() or min(df[\"$volume0\"]) < 1e-5:\n        return False\n\n    # pandas >= 2.0 requires an explicit unit when casting to datetime64\n    df[\"date\"] = df[\"datetime\"].dt.date.astype(\"datetime64[ns]\")\n    df = df.set_index([\"instrument\", \"datetime\", \"date\"])\n    df = df.groupby(\"date\", group_keys=False).take(range(start_idx, end_idx)).droplevel(level=0)\n\n    order_all = pd.DataFrame(df.groupby(level=(2, 0), group_keys=False).mean().dropna())\n    order_all[\"amount\"] = np.random.lognormal(-3.28, 1.14) * order_all[\"$volume0\"]\n    order_all = order_all[order_all[\"amount\"] > 0.0]\n    order_all[\"order_type\"] = 0\n    order_all = order_all.drop(columns=[\"$volume0\"])\n\n    order_train = order_all[order_all.index.get_level_values(0) <= pd.Timestamp(\"2021-06-30\")]\n    order_test = order_all[order_all.index.get_level_values(0) > pd.Timestamp(\"2021-06-30\")]\n    order_valid = order_test[order_test.index.get_level_values(0) <= pd.Timestamp(\"2021-09-30\")]\n    order_test = order_test[order_test.index.get_level_values(0) > pd.Timestamp(\"2021-09-30\")]\n\n    for order, tag in zip((order_train, order_valid, order_test, order_all), (\"train\", \"valid\", \"test\", \"all\")):\n        path = OUTPUT_PATH / tag\n        os.makedirs(path, exist_ok=True)\n        if len(order) > 0:\n            order.to_pickle(path / f\"{stock}.pkl.target\")\n    return True\n\n\nnp.random.seed(1234)\nfile_list = sorted(os.listdir(DATA_PATH))\nstocks = [f.replace(\".pkl\", \"\") for f
 in file_list]\nnp.random.shuffle(stocks)\n\ncnt = 0\nfor stock in stocks:\n    if generate_order(stock, 0, 240 // 5 - 1):\n        cnt += 1\n        if cnt == 100:\n            break\n"
  },
  {
    "path": "examples/rl_order_execution/scripts/merge_orders.py",
    "content": "import os\nimport pickle\n\nimport pandas as pd\nfrom tqdm import tqdm\n\nfrom qlib.utils.pickle_utils import restricted_pickle_load\n\n# merge the per-stock order files of each split into a single pickle file per split\nfor tag in [\"test\", \"valid\"]:\n    files = os.listdir(os.path.join(\"data/orders/\", tag))\n    dfs = []\n    for f in tqdm(files):\n        with open(os.path.join(\"data/orders/\", tag, f), \"rb\") as fr:\n            df = restricted_pickle_load(fr)\n        df = df.drop([\"$close0\"], axis=1)\n        dfs.append(df)\n\n    total_df = pd.concat(dfs)\n    with open(os.path.join(\"data\", \"orders\", f\"{tag}_orders.pkl\"), \"wb\") as fw:\n        pickle.dump(total_df, fw)\n"
  },
  {
    "path": "examples/rl_order_execution/scripts/pickle_data_config.yml",
    "content": "# start & end time for training/validation/test datasets\nstart_time: !!str &start 2020-01-01\nend_time: !!str &end 2021-12-31\ntrain_end_time: !!str &tend 2021-06-30\nvalid_start_time: !!str &vstart 2021-07-01\nvalid_end_time: !!str &vend 2021-09-30\ntest_start_time: !!str &tstart 2021-10-01\n# the instrument set\ninstruments: &ins csi300s19_22\n# qlib related configuration\nqlib_conf:\n    provider_uri: \n        5min: ./data/bin # path to generated qlib bin\n    redis_port: 233\nfeature_conf:\n    path: ./data/pickle/feature.pkl # output path of feature\n    class: DatasetH\n    module_path: qlib.data.dataset\n    kwargs:\n        handler:\n            class: HighFreqGeneralHandler\n            module_path: qlib.contrib.data.highfreq_handler\n            kwargs:\n                start_time: *start\n                end_time: *end\n                fit_start_time: *start\n                fit_end_time: *tend\n                instruments: *ins\n                day_length: 240 # how many minutes in one trading day\n                freq: 5min\n                columns: [\"$open\", \"$high\", \"$low\", \"$close\"]\n                infer_processors:\n                - class: HighFreqNorm\n                  module_path: qlib.contrib.data.highfreq_processor\n                  kwargs:\n                    feature_save_dir: ./stat/  #  output path of statistics of features (for feature normalization)\n                    norm_groups: \n                        price: 8\n                        volume: 2\n                inst_processors:\n                - class: TimeRangeFlt\n                  module_path: qlib.data.dataset.processor\n                  kwargs:\n                    start_time: \"2020-01-01\"\n                    end_time: \"2021-12-31\"\n                    freq: 5min\n        segments:\n            train: !!python/tuple [*start, *tend]\n            valid: !!python/tuple [*vstart, *vend]\n            test: !!python/tuple [*tstart, 
*end]\nbacktest_conf:\n    path: ./data/pickle/backtest.pkl # output path of backtest\n    class: DatasetH\n    module_path: qlib.data.dataset\n    kwargs:\n        handler:\n            class: HighFreqGeneralBacktestHandler\n            module_path: qlib.contrib.data.highfreq_handler\n            kwargs:\n                start_time: *start\n                end_time: *end\n                instruments: *ins\n                day_length: 240\n                freq: 5min\n                columns: [\"$close\", \"$volume\"]\n                inst_processors:\n                - class: TimeRangeFlt\n                  module_path: qlib.data.dataset.processor\n                  kwargs:\n                    start_time: \"2020-01-01\"\n                    end_time: \"2021-12-31\"\n                    freq: 5min\n        segments:\n            train: !!python/tuple [*start, *tend]\n            valid: !!python/tuple [*vstart, *vend]\n            test: !!python/tuple [*tstart, *end]\nfreq: 5min\n"
  },
  {
    "path": "examples/rolling_process_data/README.md",
    "content": "# Rolling Process Data\n\nThis workflow is an example of `Rolling Process Data`.\n\n## Background\n\nWhen models are trained in a rolling fashion, data also needs to be generated for each rolling window. When the rolling window moves, the training data changes, and so does the processor's learnable state (such as the mean and standard deviation).\n\nTo avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features, which do not depend on the rolling window, and then uses Processors to generate the processed features, which do depend on the rolling window.\n\n\n## Run the Code\n\nRun the example with the following command:\n```bash\npython workflow.py rolling_process\n```"
  },
  {
    "path": "examples/rolling_process_data/rolling_handler.py",
    "content": "from qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.data.dataset.loader import DataLoaderDH\nfrom qlib.contrib.data.handler import check_transform_proc\n\n\nclass RollingDataHandler(DataHandlerLP):\n    def __init__(\n        self,\n        start_time=None,\n        end_time=None,\n        infer_processors=[],\n        learn_processors=[],\n        fit_start_time=None,\n        fit_end_time=None,\n        data_loader_kwargs={},\n    ):\n        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)\n        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)\n\n        data_loader = {\n            \"class\": \"DataLoaderDH\",\n            \"kwargs\": {**data_loader_kwargs},\n        }\n\n        super().__init__(\n            instruments=None,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n            infer_processors=infer_processors,\n            learn_processors=learn_processors,\n        )\n"
  },
  {
    "path": "examples/rolling_process_data/workflow.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nimport qlib\nimport fire\n\nfrom datetime import datetime\nfrom qlib.constant import REG_CN\nfrom qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.utils import init_instance_by_config\nfrom qlib.utils.pickle_utils import restricted_pickle_load\nfrom qlib.tests.data import GetData\n\n\nclass RollingDataWorkflow:\n    MARKET = \"csi300\"\n    start_time = \"2010-01-01\"\n    end_time = \"2019-12-31\"\n    rolling_cnt = 5\n\n    def _init_qlib(self):\n        \"\"\"initialize qlib\"\"\"\n        provider_uri = \"~/.qlib/qlib_data/cn_data\"  # target_dir\n        GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)\n        qlib.init(provider_uri=provider_uri, region=REG_CN)\n\n    def _dump_pre_handler(self, path):\n        handler_config = {\n            \"class\": \"Alpha158\",\n            \"module_path\": \"qlib.contrib.data.handler\",\n            \"kwargs\": {\n                \"start_time\": self.start_time,\n                \"end_time\": self.end_time,\n                \"instruments\": self.MARKET,\n                \"infer_processors\": [],\n                \"learn_processors\": [],\n            },\n        }\n        pre_handler = init_instance_by_config(handler_config)\n        pre_handler.config(dump_all=True)\n        pre_handler.to_pickle(path)\n\n    def _load_pre_handler(self, path):\n        with open(path, \"rb\") as file_dataset:\n            pre_handler = restricted_pickle_load(file_dataset)\n        return pre_handler\n\n    def rolling_process(self):\n        self._init_qlib()\n        self._dump_pre_handler(\"pre_handler.pkl\")\n        pre_handler = self._load_pre_handler(\"pre_handler.pkl\")\n\n        train_start_time = (2010, 1, 1)\n        train_end_time = (2012, 12, 31)\n        valid_start_time = (2013, 1, 1)\n        valid_end_time = (2013, 12, 31)\n        test_start_time = (2014, 1, 1)\n        test_end_time = 
(2014, 12, 31)\n\n        dataset_config = {\n            \"class\": \"DatasetH\",\n            \"module_path\": \"qlib.data.dataset\",\n            \"kwargs\": {\n                \"handler\": {\n                    \"class\": \"RollingDataHandler\",\n                    \"module_path\": \"rolling_handler\",\n                    \"kwargs\": {\n                        \"start_time\": datetime(*train_start_time),\n                        \"end_time\": datetime(*test_end_time),\n                        \"fit_start_time\": datetime(*train_start_time),\n                        \"fit_end_time\": datetime(*train_end_time),\n                        \"infer_processors\": [\n                            {\"class\": \"RobustZScoreNorm\", \"kwargs\": {\"fields_group\": \"feature\"}},\n                        ],\n                        \"learn_processors\": [\n                            {\"class\": \"DropnaLabel\"},\n                            {\"class\": \"CSZScoreNorm\", \"kwargs\": {\"fields_group\": \"label\"}},\n                        ],\n                        \"data_loader_kwargs\": {\n                            \"handler_config\": pre_handler,\n                        },\n                    },\n                },\n                \"segments\": {\n                    \"train\": (datetime(*train_start_time), datetime(*train_end_time)),\n                    \"valid\": (datetime(*valid_start_time), datetime(*valid_end_time)),\n                    \"test\": (datetime(*test_start_time), datetime(*test_end_time)),\n                },\n            },\n        }\n\n        dataset = init_instance_by_config(dataset_config)\n\n        for rolling_offset in range(self.rolling_cnt):\n            print(f\"===========rolling{rolling_offset} start===========\")\n            if rolling_offset:\n                dataset.config(\n                    handler_kwargs={\n                        \"start_time\": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),\n       
                 \"end_time\": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),\n                        \"processor_kwargs\": {\n                            \"fit_start_time\": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),\n                            \"fit_end_time\": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),\n                        },\n                    },\n                    segments={\n                        \"train\": (\n                            datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),\n                            datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),\n                        ),\n                        \"valid\": (\n                            datetime(valid_start_time[0] + rolling_offset, *valid_start_time[1:]),\n                            datetime(valid_end_time[0] + rolling_offset, *valid_end_time[1:]),\n                        ),\n                        \"test\": (\n                            datetime(test_start_time[0] + rolling_offset, *test_start_time[1:]),\n                            datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),\n                        ),\n                    },\n                )\n                dataset.setup_data(\n                    handler_kwargs={\n                        \"init_type\": DataHandlerLP.IT_FIT_SEQ,\n                    }\n                )\n\n            dtrain, dvalid, dtest = dataset.prepare([\"train\", \"valid\", \"test\"])\n            print(dtrain, dvalid, dtest)\n            ## print or dump data\n            print(f\"===========rolling{rolling_offset} end===========\")\n\n\nif __name__ == \"__main__\":\n    fire.Fire(RollingDataWorkflow)\n"
  },
  {
    "path": "examples/run_all_model.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nimport os\nimport sys\nimport fire\nimport time\nimport glob\nimport shutil\nimport signal\nimport inspect\nimport tempfile\nimport functools\nimport statistics\nimport subprocess\nfrom datetime import datetime\nfrom ruamel.yaml import YAML\nfrom pathlib import Path\nfrom operator import xor\nfrom pprint import pprint\n\nimport qlib\nfrom qlib.workflow import R\nfrom qlib.tests.data import GetData\n\n\n# decorator to check the arguments\ndef only_allow_defined_args(function_to_decorate):\n    @functools.wraps(function_to_decorate)\n    def _return_wrapped(*args, **kwargs):\n        \"\"\"Internal wrapper function.\"\"\"\n        argspec = inspect.getfullargspec(function_to_decorate)\n        valid_names = set(argspec.args + argspec.kwonlyargs)\n        if \"self\" in valid_names:\n            valid_names.remove(\"self\")\n        for arg_name in kwargs:\n            if arg_name not in valid_names:\n                raise ValueError(\"Unknown argument seen '%s', expected: [%s]\" % (arg_name, \", \".join(valid_names)))\n        return function_to_decorate(*args, **kwargs)\n\n    return _return_wrapped\n\n\n# function to handle ctrl c (SIGINT) by force-killing the current process\ndef handler(signum, frame):\n    os.system(\"kill -9 %d\" % os.getpid())\n\n\nsignal.signal(signal.SIGINT, handler)\n\n\n# function to calculate the mean and std of a list in the results dictionary\ndef cal_mean_std(results) -> dict:\n    mean_std = dict()\n    for fn in results:\n        mean_std[fn] = dict()\n        for metric in results[fn]:\n            mean = statistics.mean(results[fn][metric]) if len(results[fn][metric]) > 1 else results[fn][metric][0]\n            std = statistics.stdev(results[fn][metric]) if len(results[fn][metric]) > 1 else 0\n            mean_std[fn][metric] = [mean, std]\n    return mean_std\n\n\n# function to create a temporary anaconda environment\ndef create_env():\n    # create env\n    
temp_dir = tempfile.mkdtemp()\n    env_path = Path(temp_dir).absolute()\n    sys.stderr.write(f\"Creating Virtual Environment with path: {env_path}...\\n\")\n    execute(f\"conda create --prefix {env_path} python=3.7 -y\")\n    python_path = env_path / \"bin\" / \"python\"  # TODO: FIX ME!\n    sys.stderr.write(\"\\n\")\n    # get anaconda activate path\n    conda_activate = Path(os.environ[\"CONDA_PREFIX\"]) / \"bin\" / \"activate\"  # TODO: FIX ME!\n    return temp_dir, env_path, python_path, conda_activate\n\n\n# function to execute the cmd\ndef execute(cmd, wait_when_err=False, raise_err=True):\n    print(\"Running CMD:\", cmd)\n    with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:\n        for line in p.stdout:\n            sys.stdout.write(line.split(\"\\b\")[0])\n            if \"\\b\" in line:\n                sys.stdout.flush()\n                time.sleep(0.1)\n                sys.stdout.write(\"\\b\" * 10 + \"\\b\".join(line.split(\"\\b\")[1:-1]))\n\n    if p.returncode != 0:\n        if wait_when_err:\n            input(\"Press Enter to Continue\")\n        if raise_err:\n            raise RuntimeError(f\"Error when executing command: {cmd}\")\n        return p.stderr\n    else:\n        return None\n\n\n# function to get all the model folders under the benchmarks folder\ndef get_all_folders(models, exclude) -> dict:\n    folders = dict()\n    if isinstance(models, str):\n        model_list = models.split(\",\")\n        models = [m.lower().strip(\"[ ]\") for m in model_list]\n    elif isinstance(models, list):\n        models = [m.lower() for m in models]\n    elif models is None:\n        models = [f.name.lower() for f in os.scandir(\"benchmarks\")]\n    else:\n        raise ValueError(\"Input models type is not supported. 
Please provide str or list without space.\")\n    for f in os.scandir(\"benchmarks\"):\n        add = xor(bool(f.name.lower() in models), bool(exclude))\n        if add:\n            path = Path(\"benchmarks\") / f.name\n            folders[f.name] = str(path.resolve())\n    return folders\n\n\n# function to get all the files under the model folder\ndef get_all_files(folder_path, dataset, universe=\"\") -> (str, str):\n    if universe != \"\":\n        universe = f\"_{universe}\"\n    yaml_path = str(Path(f\"{folder_path}\") / f\"*{dataset}{universe}.yaml\")\n    req_path = str(Path(f\"{folder_path}\") / f\"*.txt\")\n    yaml_file = glob.glob(yaml_path)\n    req_file = glob.glob(req_path)\n    if len(yaml_file) == 0:\n        return None, None\n    else:\n        return yaml_file[0], req_file[0]\n\n\n# function to retrieve all the results\ndef get_all_results(folders) -> dict:\n    results = dict()\n    for fn in folders:\n        try:\n            exp = R.get_exp(experiment_name=fn, create=False)\n        except ValueError:\n            # No experiment results\n            continue\n        recorders = exp.list_recorders()\n        result = dict()\n        result[\"annualized_return_with_cost\"] = list()\n        result[\"information_ratio_with_cost\"] = list()\n        result[\"max_drawdown_with_cost\"] = list()\n        result[\"ic\"] = list()\n        result[\"icir\"] = list()\n        result[\"rank_ic\"] = list()\n        result[\"rank_icir\"] = list()\n        for recorder_id in recorders:\n            if recorders[recorder_id].status == \"FINISHED\":\n                recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)\n                metrics = recorder.list_metrics()\n                if \"1day.excess_return_with_cost.annualized_return\" not in metrics:\n                    print(f\"{recorder_id} is skipped due to incomplete result\")\n                    continue\n                
result[\"annualized_return_with_cost\"].append(metrics[\"1day.excess_return_with_cost.annualized_return\"])\n                result[\"information_ratio_with_cost\"].append(metrics[\"1day.excess_return_with_cost.information_ratio\"])\n                result[\"max_drawdown_with_cost\"].append(metrics[\"1day.excess_return_with_cost.max_drawdown\"])\n                result[\"ic\"].append(metrics[\"IC\"])\n                result[\"icir\"].append(metrics[\"ICIR\"])\n                result[\"rank_ic\"].append(metrics[\"Rank IC\"])\n                result[\"rank_icir\"].append(metrics[\"Rank ICIR\"])\n        results[fn] = result\n    return results\n\n\n# function to generate and save markdown table\ndef gen_and_save_md_table(metrics, dataset):\n    table = \"| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |\\n\"\n    table += \"|---|---|---|---|---|---|---|---|---|\\n\"\n    for fn in metrics:\n        ic = metrics[fn][\"ic\"]\n        icir = metrics[fn][\"icir\"]\n        ric = metrics[fn][\"rank_ic\"]\n        ricir = metrics[fn][\"rank_icir\"]\n        ar = metrics[fn][\"annualized_return_with_cost\"]\n        ir = metrics[fn][\"information_ratio_with_cost\"]\n        md = metrics[fn][\"max_drawdown_with_cost\"]\n        table += f\"| {fn} | {dataset} | {ic[0]:5.4f}±{ic[1]:2.2f} | {icir[0]:5.4f}±{icir[1]:2.2f}| {ric[0]:5.4f}±{ric[1]:2.2f} | {ricir[0]:5.4f}±{ricir[1]:2.2f} | {ar[0]:5.4f}±{ar[1]:2.2f} | {ir[0]:5.4f}±{ir[1]:2.2f}| {md[0]:5.4f}±{md[1]:2.2f} |\\n\"\n    pprint(table)\n    with open(\"table.md\", \"w\") as f:\n        f.write(table)\n    return table\n\n\n# read yaml, remove seed kwargs of model, and then save file in the temp_dir\ndef gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):\n    with open(yaml_path, \"r\") as fp:\n        yaml = YAML(typ=\"safe\", pure=True)\n        config = yaml.load(fp)\n    try:\n        del config[\"task\"][\"model\"][\"kwargs\"][\"seed\"]\n    except 
KeyError:\n        # If the key does not exist, use the original yaml\n        # NOTE: this is important if the model must run in its original path (when sys.rel_path is used)\n        return yaml_path\n    else:\n        # otherwise, generate a new yaml without the random seed\n        file_name = yaml_path.split(\"/\")[-1]\n        temp_path = os.path.join(temp_dir, file_name)\n        with open(temp_path, \"w\") as fp:\n            yaml.dump(config, fp)\n        return temp_path\n\n\nclass ModelRunner:\n    def _init_qlib(self, exp_folder_name):\n        # init qlib\n        GetData().qlib_data(exists_skip=True)\n        qlib.init(\n            exp_manager={\n                \"class\": \"MLflowExpManager\",\n                \"module_path\": \"qlib.workflow.expm\",\n                \"kwargs\": {\n                    \"uri\": \"file:\" + str(Path(os.getcwd()).resolve() / exp_folder_name),\n                    \"default_exp_name\": \"Experiment\",\n                },\n            }\n        )\n\n    # function to run all the models\n    @only_allow_defined_args\n    def run(\n        self,\n        times=1,\n        models=None,\n        dataset=\"Alpha360\",\n        universe=\"\",\n        exclude=False,\n        qlib_uri: str = \"git+https://github.com/microsoft/qlib#egg=pyqlib\",\n        exp_folder_name: str = \"run_all_model_records\",\n        wait_before_rm_env: bool = False,\n        wait_when_err: bool = False,\n    ):\n        \"\"\"\n        Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.\n        Any PR to enhance this method is welcome. 
Besides, this script doesn't support running the same model multiple times\n        in parallel; this will be fixed in future development.\n\n        Parameters:\n        -----------\n        times : int\n            determines how many times each model should be run.\n        models : str or list\n            determines the specific model or list of models to run or exclude.\n        exclude : boolean\n            determines whether the given models are excluded or included.\n        dataset : str\n            determines the dataset to be used for each model.\n        universe : str\n            the stock universe of the dataset.\n            default \"\" means no universe suffix is added, i.e. the config matching `*{dataset}.yaml` is used\n        qlib_uri : str\n            the uri to install qlib with pip\n            it could be a remote URI or a local path (NOTE: the local path must be an absolute path)\n        exp_folder_name: str\n            the name of the experiment folder\n        wait_before_rm_env : bool\n            wait before removing the environment.\n        wait_when_err : bool\n            wait when errors are raised while executing commands\n\n        Usage:\n        -------\n        Here are some use cases of the function in bash:\n\n        `run` decides which config to run based on `models`, `dataset` and `universe`.\n        Example 1):\n\n            models=\"lightgbm\", dataset=\"Alpha158\", universe=\"\" will result in running the following config\n            examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml\n\n            models=\"lightgbm\", dataset=\"Alpha158\", universe=\"csi500\" will result in running the following config\n            examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_csi500.yaml\n\n        .. 
code-block:: bash\n\n            # Case 1 - run all models multiple times\n            python run_all_model.py run 3\n\n            # Case 2 - run specific models multiple times\n            python run_all_model.py run 3 mlp\n\n            # Case 3 - run specific models multiple times with a specific dataset\n            python run_all_model.py run 3 mlp Alpha158\n\n            # Case 4 - run all models except the given ones, multiple times\n            python run_all_model.py run 3 [mlp,tft,lstm] --exclude=True\n\n            # Case 5 - run specific models once\n            python run_all_model.py run --models=[mlp,lightgbm]\n\n            # Case 6 - run all models except the given ones, once\n            python run_all_model.py run --models=[mlp,tft,sfm] --exclude=True\n\n            # Case 7 - run lightgbm model on csi500.\n            python run_all_model.py run 3 lightgbm Alpha158 csi500\n\n        \"\"\"\n        self._init_qlib(exp_folder_name)\n\n        # get all folders\n        folders = get_all_folders(models, exclude)\n        # init error messages:\n        errors = dict()\n        # run each model for the given number of iterations\n        for fn in folders:\n            # get all files\n            sys.stderr.write(\"Retrieving files...\\n\")\n            yaml_path, req_path = get_all_files(folders[fn], dataset, universe=universe)\n            if yaml_path is None:\n                sys.stderr.write(f\"There is no {dataset}.yaml file in {folders[fn]}\\n\")\n                continue\n            sys.stderr.write(\"\\n\")\n            # create env by anaconda\n            temp_dir, env_path, python_path, conda_activate = create_env()\n\n            # install requirements.txt\n            sys.stderr.write(\"Installing requirements.txt...\\n\")\n            with open(req_path) as f:\n                content = f.read()\n            if \"torch\" in content:\n                # automatically install pytorch according to nvidia's 
version\n                execute(\n                    f\"{python_path} -m pip install light-the-torch\", wait_when_err=wait_when_err\n                )  # for automatically installing torch according to the nvidia driver\n                execute(\n                    f\"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}\",\n                    wait_when_err=wait_when_err,\n                )\n            else:\n                execute(f\"{python_path} -m pip install -r {req_path}\", wait_when_err=wait_when_err)\n            sys.stderr.write(\"\\n\")\n\n            # read yaml, remove seed kwargs of model, and then save file in the temp_dir\n            yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)\n            # setup gpu for tft\n            if fn == \"TFT\":\n                execute(\n                    f\"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn\",\n                    wait_when_err=wait_when_err,\n                )\n                sys.stderr.write(\"\\n\")\n            # install qlib\n            sys.stderr.write(\"Installing qlib...\\n\")\n            execute(f\"{python_path} -m pip install --upgrade pip\", wait_when_err=wait_when_err)  # TODO: FIX ME!\n            execute(f\"{python_path} -m pip install --upgrade cython\", wait_when_err=wait_when_err)  # TODO: FIX ME!\n            if fn == \"TFT\":\n                execute(\n                    f\"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}\",\n                    wait_when_err=wait_when_err,\n                )  # TODO: FIX ME!\n            else:\n                execute(\n                    f\"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}\",\n                    wait_when_err=wait_when_err,\n                )  # TODO: FIX ME!\n            
sys.stderr.write(\"\\n\")\n            # run workflow_by_config for multiple times\n            for i in range(times):\n                sys.stderr.write(f\"Running the model: {fn} for iteration {i+1}...\\n\")\n                errs = execute(\n                    f\"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}\",\n                    wait_when_err=wait_when_err,\n                )\n                if errs is not None:\n                    _errs = errors.get(fn, {})\n                    _errs.update({i: errs})\n                    errors[fn] = _errs\n                sys.stderr.write(\"\\n\")\n            # remove env\n            sys.stderr.write(f\"Deleting the environment: {env_path}...\\n\")\n            if wait_before_rm_env:\n                input(\"Press Enter to Continue\")\n            shutil.rmtree(env_path)\n        # print errors\n        sys.stderr.write(f\"Here are some of the errors of the models...\\n\")\n        pprint(errors)\n        self._collect_results(exp_folder_name, dataset)\n\n    def _collect_results(self, exp_folder_name, dataset):\n        folders = get_all_folders(exp_folder_name, dataset)\n        # getting all results\n        sys.stderr.write(f\"Retrieving results...\\n\")\n        results = get_all_results(folders)\n        if len(results) > 0:\n            # calculating the mean and std\n            sys.stderr.write(f\"Calculating the mean and std of results...\\n\")\n            results = cal_mean_std(results)\n            # generating md table\n            sys.stderr.write(f\"Generating markdown table...\\n\")\n            gen_and_save_md_table(results, dataset)\n            sys.stderr.write(\"\\n\")\n        sys.stderr.write(\"\\n\")\n        # move results folder\n        shutil.move(exp_folder_name, exp_folder_name + f\"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}\")\n        shutil.move(\"table.md\", f\"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md\")\n\n\nif 
__name__ == \"__main__\":\n    fire.Fire(ModelRunner)  # run all the model\n"
  },
  {
    "path": "examples/workflow_by_code.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\"\"\"\nQlib provides two kinds of interfaces.\n(1) Users can define the quant research workflow with a simple configuration.\n(2) Qlib is designed in a modularized way and supports building research workflows in code, just like building blocks.\n\nThe interface of (1) is `qrun XXX.yaml`.  The interface of (2) is a script like this one, which does nearly the same thing as `qrun XXX.yaml`\n\"\"\"\n\nimport qlib\nfrom qlib.constant import REG_CN\nfrom qlib.utils import init_instance_by_config, flatten_dict\nfrom qlib.workflow import R\nfrom qlib.workflow.record_temp import SignalRecord, PortAnaRecord, SigAnaRecord\nfrom qlib.tests.data import GetData\nfrom qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK\n\nif __name__ == \"__main__\":\n    # use default data\n    provider_uri = \"~/.qlib/qlib_data/cn_data\"  # target_dir\n    GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)\n    qlib.init(provider_uri=provider_uri, region=REG_CN)\n\n    model = init_instance_by_config(CSI300_GBDT_TASK[\"model\"])\n    dataset = init_instance_by_config(CSI300_GBDT_TASK[\"dataset\"])\n\n    port_analysis_config = {\n        \"executor\": {\n            \"class\": \"SimulatorExecutor\",\n            \"module_path\": \"qlib.backtest.executor\",\n            \"kwargs\": {\n                \"time_per_step\": \"day\",\n                \"generate_portfolio_metrics\": True,\n            },\n        },\n        \"strategy\": {\n            \"class\": \"TopkDropoutStrategy\",\n            \"module_path\": \"qlib.contrib.strategy.signal_strategy\",\n            \"kwargs\": {\n                \"signal\": (model, dataset),\n                \"topk\": 50,\n                \"n_drop\": 5,\n            },\n        },\n        \"backtest\": {\n            \"start_time\": \"2017-01-01\",\n            \"end_time\": \"2020-08-01\",\n            \"account\": 100000000,\n            
\"benchmark\": CSI300_BENCH,\n            \"exchange_kwargs\": {\n                \"freq\": \"day\",\n                \"limit_threshold\": 0.095,\n                \"deal_price\": \"close\",\n                \"open_cost\": 0.0005,\n                \"close_cost\": 0.0015,\n                \"min_cost\": 5,\n            },\n        },\n    }\n\n    # NOTE: This line is optional\n    # It demonstrates that the dataset can be used standalone.\n    example_df = dataset.prepare(\"train\")\n    print(example_df.head())\n\n    # start exp\n    with R.start(experiment_name=\"workflow\"):\n        R.log_params(**flatten_dict(CSI300_GBDT_TASK))\n        model.fit(dataset)\n        R.save_objects(**{\"params.pkl\": model})\n\n        # prediction\n        recorder = R.get_recorder()\n        sr = SignalRecord(model, dataset, recorder)\n        sr.generate()\n\n        # Signal Analysis\n        sar = SigAnaRecord(recorder)\n        sar.generate()\n\n        # backtest. If users want to use backtest based on their own prediction,\n        # please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.\n        par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n        par.generate()\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools\", \"setuptools-scm\", \"cython\", \"numpy>=1.24.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nclassifiers = [\n  \"Operating System :: POSIX :: Linux\",\n  \"Operating System :: Microsoft :: Windows\",\n  \"Operating System :: MacOS\",\n  \"License :: OSI Approved :: MIT License\",\n  \"Development Status :: 3 - Alpha\",\n  \"Programming Language :: Python\",\n  \"Programming Language :: Python :: 3\",\n  \"Programming Language :: Python :: 3.8\",\n  \"Programming Language :: Python :: 3.9\",\n  \"Programming Language :: Python :: 3.10\",\n  \"Programming Language :: Python :: 3.11\",\n  \"Programming Language :: Python :: 3.12\",\n]\nname = \"pyqlib\"\ndynamic = [\"version\"]\ndescription = \"A Quantitative-research Platform\"\nrequires-python = \">=3.8.0\"\nreadme = {file = \"README.md\", content-type = \"text/markdown\"}\nlicense = { text = \"MIT\" }\n\ndependencies = [\n  \"pyyaml\",\n  \"numpy\",\n  # Since version 1.1.0, pandas supports the ffill and bfill methods.\n  # Since version 2.1.0, pandas has deprecated the method parameter of the fillna method. 
\n  # qlib has updated the fillna method in PR 1987 and limited the minimum version of pandas.\n  \"pandas>=1.1\",\n  # I encountered an error where set_uri does not work when downloading artifacts in mlflow 3.1.1;\n  # earlier versions of mlflow do not have this problem.\n  # However, when switching to a 2.*.* version, another, even stranger error occurs...\n  \"mlflow\",\n  \"filelock>=3.16.0\",\n  \"redis\",\n  \"dill\",\n  \"fire\",\n  \"ruamel.yaml>=0.17.38\",\n  \"python-redis-lock\",\n  \"tqdm\",\n  \"pymongo\",\n  \"loguru\",\n  \"lightgbm\",\n  \"gym\",\n  \"cvxpy\",\n  \"joblib\",\n  \"matplotlib\",\n  \"jupyter\",\n  \"nbconvert\",\n  \"pyarrow\",\n  \"pydantic-settings\",\n  \"setuptools-scm\",\n]\n\n[project.optional-dependencies]\ndev = [\n  \"pytest\",\n  \"statsmodels\",\n]\n# On macos-13, when using python 3.10 or later,\n# pytorch can't fully support Numpy 2.0 or above, so installing torch\n# limits Numpy to versions below 2.0.\nrl = [\n  \"tianshou<=0.4.10\",\n  \"torch\",\n  \"numpy<2.0.0\",\n]\n\nlint = [\n  \"black\",\n  \"pylint\",\n  \"mypy<1.5.0\",\n  \"flake8\",\n  \"nbqa\",\n]\n# snowballstemmer, a dependency of sphinx, was released on 2025-05-08 with version 3.0.0,\n# which causes errors in the build process. 
So we've limited the version for now.\ndocs = [\n  # After upgrading scipy to version 1.16.0,\n  # we encountered ImportError: cannot import name '_lazywhere', in the build documentation,\n  # so we restricted the version of scipy to: 1.15.3\n  \"scipy<=1.15.3\",\n  \"sphinx\",\n  \"sphinx_rtd_theme\",\n  \"readthedocs_sphinx_ext\",\n  \"snowballstemmer<3.0\",\n]\npackage = [\n  \"twine\",\n  \"build\",\n]\n# test_pit dependency packages\ntest = [\n  \"yahooquery\",\n  \"baostock\",\n]\nanalysis = [\n  \"plotly\",\n  \"statsmodels\",\n]\nclient = [\n  \"python-socketio<6\",\n  \"tables\",\n]\n\n# In the process of releasing a new version, when checking the manylinux package with twine, an error is reported:\n# InvalidDistribution: Invalid distribution metadata: unrecognized or malformed field 'license-file'\n# To solve this problem, we added license-files here. Refs: https://github.com/pypa/twine/issues/1216\n[tool.setuptools]\npackages = [\n  \"qlib\",\n]\nlicense-files = []\n\n[project.scripts]\nqrun = \"qlib.cli.run:run\"\n\n[tool.setuptools_scm]\nlocal_scheme = \"no-local-version\"\nversion_scheme = \"guess-next-dev\"\nwrite_to = \"qlib/_version.py\"\n"
  },
  {
    "path": "qlib/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom pathlib import Path\n\nfrom setuptools_scm import get_version\n\ntry:\n    from ._version import version as __version__\nexcept ImportError:\n    __version__ = get_version(root=\"..\", relative_to=__file__)\n__version__bak = __version__  # This is a backup of the version for QlibConfig.reset_qlib_version\nimport logging\nimport os\nimport platform\nimport re\nimport subprocess\nfrom typing import Union\n\nfrom ruamel.yaml import YAML\n\nfrom .log import get_module_logger\n\n\n# init qlib\ndef init(default_conf=\"client\", **kwargs):\n    \"\"\"\n\n    Parameters\n    ----------\n    default_conf: str\n        the default value is client. Accepted values: client/server.\n    **kwargs :\n        clear_mem_cache: bool\n            the default value is True;\n            Whether to clear the memory cache.\n            It is often used to improve performance when init is called multiple times\n        skip_if_reg: bool\n            the default value is False;\n            When using the recorder, skip_if_reg can be set to True to avoid losing the recorder.\n\n    \"\"\"\n    from .config import C  # pylint: disable=C0415\n    from .data.cache import H  # pylint: disable=C0415\n\n    logger = get_module_logger(\"Initialization\")\n\n    skip_if_reg = kwargs.pop(\"skip_if_reg\", False)\n    if skip_if_reg and C.registered:\n        # if we reinitialize Qlib while running an experiment (`R.start`),\n        # the recorder will be lost\n        logger.warning(\"Skip initialization because `skip_if_reg is True`\")\n        return\n\n    clear_mem_cache = kwargs.pop(\"clear_mem_cache\", True)\n    if clear_mem_cache:\n        H.clear()\n    C.set(default_conf, **kwargs)\n    get_module_logger.setLevel(C.logging_level)\n\n    # mount nfs\n    for _freq, provider_uri in C.provider_uri.items():\n        mount_path = C[\"mount_path\"][_freq]\n        # check path if 
server/local\n        uri_type = C.dpm.get_uri_type(provider_uri)\n        if uri_type == C.LOCAL_URI:\n            if not Path(provider_uri).exists():\n                if C[\"auto_mount\"]:\n                    logger.error(\n                        f\"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist.\"\n                    )\n                else:\n                    logger.warning(f\"auto_mount is False, please make sure {mount_path} is mounted\")\n        elif uri_type == C.NFS_URI:\n            _mount_nfs_uri(provider_uri, C.dpm.get_data_uri(_freq), C[\"auto_mount\"])\n        else:\n            raise NotImplementedError(f\"This type of URI is not supported: {provider_uri}\")\n\n    C.register()\n\n    if \"flask_server\" in C:\n        logger.info(f\"flask_server={C['flask_server']}, flask_port={C['flask_port']}\")\n    logger.info(\"qlib successfully initialized based on %s settings.\" % default_conf)\n    data_path = {_freq: C.dpm.get_data_uri(_freq) for _freq in C.dpm.provider_uri.keys()}\n    logger.info(f\"data_path={data_path}\")\n\n\ndef _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):\n    LOG = get_module_logger(\"mount nfs\", level=logging.INFO)\n    if mount_path is None:\n        raise ValueError(f\"Invalid mount path: {mount_path}!\")\n    if not re.match(r\"^[a-zA-Z0-9.:/\\-_]+$\", provider_uri):\n        raise ValueError(f\"Invalid provider_uri format: {provider_uri}\")\n    # FIXME: the C[\"provider_uri\"] is modified in this function\n    # If it is not modified, we can pass only provider_uri or mount_path instead of C\n    mount_command = [\"sudo\", \"mount.nfs\", provider_uri, mount_path]\n    # If the provider uri looks like `172.23.233.89//data/csdesign`,\n    # it will be treated as an NFS path. 
The client provider will be used.\n    if not auto_mount:  # pylint: disable=R1702\n        if not Path(mount_path).exists():\n            raise FileNotFoundError(\n                f\"Invalid mount path: {mount_path}! Please mount manually: {' '.join(mount_command)} or set the init parameter `auto_mount=True`\"\n            )\n    else:\n        # Determine the system type\n        sys_type = platform.system()\n        if \"windows\" in sys_type.lower():\n            # system: windows\n            try:\n                subprocess.run(\n                    [\"mount\", \"-o\", \"anon\", provider_uri, mount_path],\n                    capture_output=True,\n                    text=True,\n                    check=True,\n                )\n                LOG.info(\"Mount finished.\")\n            except subprocess.CalledProcessError as e:\n                error_output = (e.stdout or \"\") + (e.stderr or \"\")\n                if e.returncode == 85:\n                    LOG.warning(f\"{provider_uri} already mounted at {mount_path}\")\n                elif e.returncode == 53:\n                    raise OSError(\"Network path not found\") from e\n                elif \"error\" in error_output.lower() or \"错误\" in error_output:\n                    raise OSError(\"Invalid mount path\") from e\n                else:\n                    raise OSError(f\"Unknown mount error: {error_output.strip()}\") from e\n        else:\n            # system: linux/Unix/Mac\n            # check mount\n            _remote_uri = provider_uri[:-1] if provider_uri.endswith(\"/\") else provider_uri\n            # `mount a /b/c` is different from `mount a /b/c/`. 
So we convert it into a string to make sure it is handled accurately\n            mount_path = str(mount_path)\n            _mount_path = mount_path[:-1] if mount_path.endswith(\"/\") else mount_path\n            _check_level_num = 2\n            _is_mount = False\n            while _check_level_num:\n                with subprocess.Popen(\n                    [\"mount\"],\n                    text=True,\n                    stdout=subprocess.PIPE,\n                    stderr=subprocess.STDOUT,\n                ) as shell_r:\n                    _command_log = shell_r.stdout.readlines()\n                    _command_log = [line for line in _command_log if _remote_uri in line]\n                if len(_command_log) > 0:\n                    for _c in _command_log:\n                        if isinstance(_c, str):\n                            _temp_mount = _c.split(\" \")[2]\n                        else:\n                            _temp_mount = _c.decode(\"utf-8\").split(\" \")[2]\n                        _temp_mount = _temp_mount[:-1] if _temp_mount.endswith(\"/\") else _temp_mount\n                        if _temp_mount == _mount_path:\n                            _is_mount = True\n                            break\n                if _is_mount:\n                    break\n                _remote_uri = \"/\".join(_remote_uri.split(\"/\")[:-1])\n                _mount_path = \"/\".join(_mount_path.split(\"/\")[:-1])\n                _check_level_num -= 1\n\n            if not _is_mount:\n                try:\n                    Path(mount_path).mkdir(parents=True, exist_ok=True)\n                except Exception as e:\n                    raise OSError(\n                        f\"Failed to create directory {mount_path}, please create {mount_path} manually!\"\n                    ) from e\n\n                # check whether nfs-common is installed (via dpkg)\n                command_res = os.popen(\"dpkg -l | grep nfs-common\")\n                command_res = command_res.readlines()\n                if not command_res:\n                    raise OSError(\"nfs-common is not found, please install it by executing: sudo apt install nfs-common\")\n                # manually mount\n                try:\n                    subprocess.run(mount_command, check=True, capture_output=True, text=True)\n                    LOG.info(\"Mount finished.\")\n                except subprocess.CalledProcessError as e:\n                    # `returncode` is the exit status (e.g. 1), not an `os.system`-style wait status (e.g. 256)\n                    if e.returncode == 1:\n                        raise OSError(\"Mount failed: requires sudo or permission denied\") from e\n                    elif e.returncode == 127:\n                        raise OSError(f\"mount {provider_uri} on {mount_path} error! Command error\") from e\n                    else:\n                        raise OSError(f\"Mount failed: {e.stderr}\") from e\n            else:\n                LOG.warning(f\"{_remote_uri} on {_mount_path} is already mounted\")\n\n\ndef init_from_yaml_conf(conf_path, **kwargs):\n    \"\"\"init_from_yaml_conf\n\n    :param conf_path: A path to the qlib config in yml format (None means an empty config)\n    \"\"\"\n\n    if conf_path is None:\n        config = {}\n    else:\n        with open(conf_path) as f:\n            yaml = YAML(typ=\"safe\", pure=True)\n            config = yaml.load(f)\n    config.update(kwargs)\n    default_conf = config.pop(\"default_conf\", \"client\")\n    init(default_conf, **config)\n\n\ndef get_project_path(config_name=\"config.yaml\", cur_path: Union[Path, str, None] = None) -> Path:\n    \"\"\"\n    Find the project path when users build a project that follows this pattern:\n    - Qlib is a subfolder of the project path\n    - There is a file named `config.yaml` in the project path.\n\n    For example:\n        If your project's file system structure follows this pattern\n\n            <project_path>/\n              - config.yaml\n              - ...some folders...\n                - qlib/\n\n        this function will return <project_path>\n\n        NOTE: symbolic links are not supported here.\n\n\n    This method is often 
used when\n    - users want to use a relative config path instead of hard-coding the qlib config path in code\n\n    Raises\n    ------\n    FileNotFoundError:\n        If the project path is not found\n    \"\"\"\n    if cur_path is None:\n        cur_path = Path(__file__).absolute().resolve()\n    cur_path = Path(cur_path)\n    while True:\n        if (cur_path / config_name).exists():\n            return cur_path\n        if cur_path == cur_path.parent:\n            raise FileNotFoundError(\"We can't find the project path\")\n        cur_path = cur_path.parent\n\n\ndef auto_init(**kwargs):\n    \"\"\"\n    This function will init qlib automatically with the following priority\n    - Find the project configuration and init qlib\n        - The parsing process will be affected by the `conf_type` of the configuration file\n    - Init qlib with default config\n    - Skip initialization if already initialized\n\n    :**kwargs: it may contain the following parameters\n                cur_path: the start path to find the project path\n\n    Here are two examples of the configuration\n\n    Example 1)\n    If you want to create a new project-specific config based on a shared config, you can use `conf_type: ref`\n\n    .. code-block:: yaml\n\n        conf_type: ref\n        qlib_cfg: '<shared_yaml_config_path>'    # this could be null, i.e. no config is referenced from other files\n        # the following configs in `qlib_cfg_update` are project-specific\n        qlib_cfg_update:\n            exp_manager:\n                class: \"MLflowExpManager\"\n                module_path: \"qlib.workflow.expm\"\n                kwargs:\n                    uri: \"file://<your mlflow experiment path>\"\n                    default_exp_name: \"Experiment\"\n\n    Example 2)\n    If you want to create a simple standalone config, you can use the following config (a.k.a. `conf_type: origin`)\n\n    .. code-block:: yaml\n\n        exp_manager:\n            class: \"MLflowExpManager\"\n            module_path: \"qlib.workflow.expm\"\n            kwargs:\n                uri: \"file://<your mlflow experiment path>\"\n                default_exp_name: \"Experiment\"\n\n    \"\"\"\n    kwargs[\"skip_if_reg\"] = kwargs.get(\"skip_if_reg\", True)\n\n    try:\n        pp = get_project_path(cur_path=kwargs.pop(\"cur_path\", None))\n    except FileNotFoundError:\n        init(**kwargs)\n    else:\n        logger = get_module_logger(\"Initialization\")\n        conf_pp = pp / \"config.yaml\"\n        with conf_pp.open() as f:\n            yaml = YAML(typ=\"safe\", pure=True)\n            conf = yaml.load(f)\n\n        conf_type = conf.get(\"conf_type\", \"origin\")\n        if conf_type == \"origin\":\n            # This config type is the same as the original qlib config\n            init_from_yaml_conf(conf_pp, **kwargs)\n        elif conf_type == \"ref\":\n            # This config type is more convenient in the following scenarios\n            # - There is a shared config file, and you don't want to edit it in place.\n            # - The shared config may be updated later, and you don't want to copy it.\n            # - You have some customized config.\n            qlib_conf_path = conf.get(\"qlib_cfg\", None)\n\n            # merge the arguments\n            qlib_conf_update = conf.get(\"qlib_cfg_update\", {})\n            for k in kwargs:\n                if k in qlib_conf_update:\n                    logger.warning(f\"`qlib_conf_update` from conf_pp is overridden by `kwargs` on key '{k}'\")\n            qlib_conf_update.update(kwargs)\n\n            init_from_yaml_conf(qlib_conf_path, **qlib_conf_update)\n        logger.info(f\"Auto load project config: {conf_pp}\")\n"
  },
  {
    "path": "qlib/backtest/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nimport copy\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union\n\nimport pandas as pd\n\nfrom .account import Account\n\nif TYPE_CHECKING:\n    from ..strategy.base import BaseStrategy\n    from .executor import BaseExecutor\n    from .decision import BaseTradeDecision\n\nfrom ..config import C\nfrom ..log import get_module_logger\nfrom ..utils import init_instance_by_config\nfrom .backtest import INDICATOR_METRIC, PORT_METRIC, backtest_loop, collect_data_loop\nfrom .decision import Order\nfrom .exchange import Exchange\nfrom .utils import CommonInfrastructure\n\n# make import more user-friendly by adding `from qlib.backtest import STH`\n\n\nlogger = get_module_logger(\"backtest caller\")\n\n\ndef get_exchange(\n    exchange: Union[str, dict, object, Path] = None,\n    freq: str = \"day\",\n    start_time: Union[pd.Timestamp, str] = None,\n    end_time: Union[pd.Timestamp, str] = None,\n    codes: Union[list, str] = \"all\",\n    subscribe_fields: list = [],\n    open_cost: float = 0.0015,\n    close_cost: float = 0.0025,\n    min_cost: float = 5.0,\n    limit_threshold: Union[Tuple[str, str], float, None] | None = None,\n    deal_price: Union[str, Tuple[str, str], List[str]] | None = None,\n    **kwargs: Any,\n) -> Exchange:\n    \"\"\"get_exchange\n\n    Parameters\n    ----------\n\n    # exchange related arguments\n    exchange: Exchange\n        It could be None or any types that are acceptable by `init_instance_by_config`.\n    freq: str\n        frequency of data.\n    start_time: Union[pd.Timestamp, str]\n        closed start time for backtest.\n    end_time: Union[pd.Timestamp, str]\n        closed end time for backtest.\n    codes: Union[list, str]\n        list stock_id list or a string of instruments (i.e. 
all, csi500, sse50)\n    subscribe_fields: list\n        subscribe fields.\n    open_cost : float\n        open transaction cost. It is a ratio. The cost is proportional to your order's deal amount.\n    close_cost : float\n        close transaction cost. It is a ratio. The cost is proportional to your order's deal amount.\n    min_cost : float\n        min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.\n        e.g. you must pay at least 5 yuan of commission regardless of your order's deal amount.\n    deal_price: Union[str, Tuple[str, str], List[str]]\n                The `deal_price` supports the following two types of input\n                - <deal_price> : str\n                - (<buy_price>, <sell_price>): Tuple[str, str] or List[str]\n\n                <deal_price>, <buy_price> or <sell_price> := <price>\n                <price> := str\n                - for example '$close', '$open', '$vwap' (\"close\" is also OK; `Exchange` will prepend\n                  \"$\" to the expression)\n    limit_threshold : Union[Tuple[str, str], float, None]\n        the limit-move threshold, e.g. 0.1 (10%); long and short share the same limit.\n        If None, `C.limit_threshold` is used.\n\n    Returns\n    -------\n    :class: Exchange\n    an initialized Exchange object\n    \"\"\"\n\n    if limit_threshold is None:\n        limit_threshold = C.limit_threshold\n    if exchange is None:\n        logger.info(\"Create new exchange\")\n\n        exchange = Exchange(\n            freq=freq,\n            start_time=start_time,\n            end_time=end_time,\n            codes=codes,\n            deal_price=deal_price,\n            subscribe_fields=subscribe_fields,\n            limit_threshold=limit_threshold,\n            open_cost=open_cost,\n            close_cost=close_cost,\n            min_cost=min_cost,\n            **kwargs,\n        )\n        return exchange\n    else:\n        return init_instance_by_config(exchange, accept_types=Exchange)\n\n\ndef create_account_instance(\n    start_time: 
Union[pd.Timestamp, str],\n    end_time: Union[pd.Timestamp, str],\n    benchmark: Optional[str],\n    account: Union[float, int, dict],\n    pos_type: str = \"Position\",\n) -> Account:\n    \"\"\"\n    # TODO: it is very strange to pass benchmark_config to the account (maybe for reporting)\n    # There should be a post-step to process the report.\n\n    Parameters\n    ----------\n    start_time\n        start time of the benchmark\n    end_time\n        end time of the benchmark\n    benchmark : Optional[str]\n        the benchmark for reporting; if None, no benchmark config is passed to the account\n    account :   Union[\n                    float,\n                    {\n                        \"cash\": float,\n                        \"stock1\": Union[\n                                        int,    # it is equal to {\"amount\": int}\n                                        {\"amount\": int, \"price\"(optional): float},\n                                  ]\n                    },\n                ]\n        information for describing how to create the account\n        For `float` or `int`:\n            Using Account with only initial cash\n        For `dict`:\n            key \"cash\" means initial cash.\n            key \"stock1\" means the information of the first stock with amount and price (optional).\n            ...\n    pos_type: str\n        Position type.\n    \"\"\"\n    if isinstance(account, (int, float)):\n        init_cash = account\n        position_dict = {}\n    elif isinstance(account, dict):\n        # copy first so that the caller's dict is not mutated by `pop`\n        position_dict = copy.copy(account)\n        init_cash = position_dict.pop(\"cash\")\n    else:\n        raise ValueError(\"account must be in (int, float, dict)\")\n\n    return Account(\n        init_cash=init_cash,\n        position_dict=position_dict,\n        pos_type=pos_type,\n        benchmark_config=(\n            {}\n            if benchmark is None\n            else {\n                \"benchmark\": benchmark,\n                \"start_time\": start_time,\n                \"end_time\": end_time,\n            }\n        ),\n    
)\n\n\ndef get_strategy_executor(\n    start_time: Union[pd.Timestamp, str],\n    end_time: Union[pd.Timestamp, str],\n    strategy: Union[str, dict, object, Path],\n    executor: Union[str, dict, object, Path],\n    benchmark: Optional[str] = \"SH000300\",\n    account: Union[float, int, dict] = 1e9,\n    exchange_kwargs: dict = {},\n    pos_type: str = \"Position\",\n) -> Tuple[BaseStrategy, BaseExecutor]:\n    # NOTE: these are imported locally\n    # - to avoid circular imports\n    # - because the TYPE_CHECKING-only imports at the top are not available at runtime\n    from ..strategy.base import BaseStrategy  # pylint: disable=C0415\n    from .executor import BaseExecutor  # pylint: disable=C0415\n\n    trade_account = create_account_instance(\n        start_time=start_time,\n        end_time=end_time,\n        benchmark=benchmark,\n        account=account,\n        pos_type=pos_type,\n    )\n\n    exchange_kwargs = copy.copy(exchange_kwargs)\n    if \"start_time\" not in exchange_kwargs:\n        exchange_kwargs[\"start_time\"] = start_time\n    if \"end_time\" not in exchange_kwargs:\n        exchange_kwargs[\"end_time\"] = end_time\n    trade_exchange = get_exchange(**exchange_kwargs)\n\n    common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)\n    trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)\n    trade_strategy.reset_common_infra(common_infra)\n    trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor)\n    trade_executor.reset_common_infra(common_infra)\n\n    return trade_strategy, trade_executor\n\n\ndef backtest(\n    start_time: Union[pd.Timestamp, str],\n    end_time: Union[pd.Timestamp, str],\n    strategy: Union[str, dict, object, Path],\n    executor: Union[str, dict, object, Path],\n    benchmark: str = \"SH000300\",\n    account: Union[float, int, dict] = 1e9,\n    exchange_kwargs: dict = {},\n    pos_type: str = \"Position\",\n) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:\n    \"\"\"initialize the strategy and 
executor, then run the backtest for the interaction of the outermost strategy and\n    executor in the nested decision execution\n\n    Parameters\n    ----------\n    start_time : Union[pd.Timestamp, str]\n        closed start time for backtest\n        **NOTE**: This will be applied to the outermost executor's calendar.\n    end_time : Union[pd.Timestamp, str]\n        closed end time for backtest\n        **NOTE**: This will be applied to the outermost executor's calendar.\n        E.g. for Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301\n    strategy : Union[str, dict, object, Path]\n        for initializing the outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more\n        information.\n    executor : Union[str, dict, object, Path]\n        for initializing the outermost executor.\n    benchmark: str\n        the benchmark for reporting.\n    account : Union[float, int, dict]\n        information for describing how to create the account\n        For `float` or `int`:\n            Using Account with only initial cash\n        For `dict`:\n            Using Account with initial cash and positions; please refer to `create_account_instance` for the format\n    exchange_kwargs : dict\n        the kwargs for initializing Exchange\n    pos_type : str\n        the type of Position.\n\n    Returns\n    -------\n    portfolio_dict: PORT_METRIC\n        it records the trading portfolio_metrics information\n    indicator_dict: INDICATOR_METRIC\n        it computes the trading indicator\n        It is organized in a dict format\n\n    \"\"\"\n    trade_strategy, trade_executor = get_strategy_executor(\n        start_time,\n        end_time,\n        strategy,\n        executor,\n        benchmark,\n        account,\n        exchange_kwargs,\n        pos_type=pos_type,\n    )\n    return backtest_loop(start_time, end_time, trade_strategy, trade_executor)\n\n\ndef collect_data(\n    start_time: Union[pd.Timestamp, str],\n    end_time: Union[pd.Timestamp, 
str],\n    strategy: Union[str, dict, object, Path],\n    executor: Union[str, dict, object, Path],\n    benchmark: str = \"SH000300\",\n    account: Union[float, int, dict] = 1e9,\n    exchange_kwargs: dict = {},\n    pos_type: str = \"Position\",\n    return_value: dict | None = None,\n) -> Generator[object, None, None]:\n    \"\"\"initialize the strategy and executor, then collect the trade decision data for RL training\n\n    please refer to the docs of `backtest` for the explanation of the parameters\n\n    Yields\n    ------\n    object\n        trade decision\n    \"\"\"\n    trade_strategy, trade_executor = get_strategy_executor(\n        start_time,\n        end_time,\n        strategy,\n        executor,\n        benchmark,\n        account,\n        exchange_kwargs,\n        pos_type=pos_type,\n    )\n    yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value=return_value)\n\n\ndef format_decisions(\n    decisions: List[BaseTradeDecision],\n) -> Optional[Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]]:\n    \"\"\"\n    format the decisions collected by `qlib.backtest.collect_data`\n    The decisions will be organized into a tree-like structure.\n\n    Parameters\n    ----------\n    decisions : List[BaseTradeDecision]\n        decisions collected by `qlib.backtest.collect_data`\n\n    Returns\n    -------\n    Optional[Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]]:\n\n        the list of decisions reformatted into a more user-friendly format, or None if `decisions` is empty\n        <decisions> := Tuple[<freq>, List[Tuple[<decision>, <sub decisions>]]]\n        - <sub decisions> := `<decisions> in lower level` | None\n        - <freq> := \"day\" | \"30min\" | \"1min\" | ...\n        - <decision> := <instance of BaseTradeDecision>\n    \"\"\"\n    if len(decisions) == 0:\n        return None\n\n    cur_freq = decisions[0].strategy.trade_calendar.get_freq()\n\n    res: Tuple[str, list] = (cur_freq, [])\n    last_dec_idx = 
0\n    for i, dec in enumerate(decisions[1:], 1):\n        if dec.strategy.trade_calendar.get_freq() == cur_freq:\n            res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 : i])))\n            last_dec_idx = i\n    res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 :])))\n    return res\n\n\n__all__ = [\"Order\", \"backtest\", \"get_strategy_executor\"]\n"
  },
  {
    "path": "qlib/backtest/account.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom __future__ import annotations\n\nimport copy\nfrom typing import Dict, List, Optional, Tuple, cast\n\nimport pandas as pd\n\nfrom qlib.utils import init_instance_by_config\n\nfrom .decision import BaseTradeDecision, Order\nfrom .exchange import Exchange\nfrom .high_performance_ds import BaseOrderIndicator\nfrom .position import BasePosition\nfrom .report import Indicator, PortfolioMetrics\n\n\"\"\"\nrtn & earning in the Account\n    rtn:\n        from order's view\n        1.change if any order is executed, sell order or buy order\n        2.change at the end of today,   (today_close - stock_price) * amount\n    earning\n        from value of current position\n        earning will be updated at the end of trade date\n        earning = today_value - pre_value\n    **is consider cost**\n        while earning is the difference of two position value, so it considers cost, it is the true return rate\n        in the specific accomplishment for rtn, it does not consider cost, in other words, rtn - cost = earning\n\n\"\"\"\n\n\nclass AccumulatedInfo:\n    \"\"\"\n    accumulated trading info, including accumulated return/cost/turnover\n    AccumulatedInfo should be shared across different levels\n    \"\"\"\n\n    def __init__(self) -> None:\n        self.reset()\n\n    def reset(self) -> None:\n        self.rtn: float = 0.0  # accumulated return, do not consider cost\n        self.cost: float = 0.0  # accumulated cost\n        self.to: float = 0.0  # accumulated turnover\n\n    def add_return_value(self, value: float) -> None:\n        self.rtn += value\n\n    def add_cost(self, value: float) -> None:\n        self.cost += value\n\n    def add_turnover(self, value: float) -> None:\n        self.to += value\n\n    @property\n    def get_return(self) -> float:\n        return self.rtn\n\n    @property\n    def get_cost(self) -> float:\n        return self.cost\n\n    @property\n  
  def get_turnover(self) -> float:\n        return self.to\n\n\nclass Account:\n    \"\"\"\n    The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in\n    qlib/backtest/executor.py:NestedExecutor\n    Different level of executor has different Account object when calculating metrics. But the position object is\n    shared cross all the Account object.\n    \"\"\"\n\n    def __init__(\n        self,\n        init_cash: float = 1e9,\n        position_dict: dict = {},\n        freq: str = \"day\",\n        benchmark_config: dict = {},\n        pos_type: str = \"Position\",\n        port_metr_enabled: bool = True,\n    ) -> None:\n        \"\"\"the trade account of backtest.\n\n        Parameters\n        ----------\n        init_cash : float, optional\n            initial cash, by default 1e9\n        position_dict : Dict[\n                            stock_id,\n                            Union[\n                                int,  # it is equal to {\"amount\": int}\n                                {\"amount\": int, \"price\"(optional): float},\n                            ]\n                        ]\n            initial stocks with parameters amount and price,\n            if there is no price key in the dict of stocks, it will be filled by _fill_stock_value.\n            by default {}.\n        \"\"\"\n\n        self._pos_type = pos_type\n        self._port_metr_enabled = port_metr_enabled\n        self.benchmark_config: dict = {}  # avoid no attribute error\n        self.init_vars(init_cash, position_dict, freq, benchmark_config)\n\n    def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None:\n        # 1) the following variables are shared by multiple layers\n        # - you will see a shallow copy instead of deepcopy in the NestedExecutor;\n        self.init_cash = init_cash\n        self.current_position: BasePosition = init_instance_by_config(\n            
{\n                \"class\": self._pos_type,\n                \"kwargs\": {\n                    \"cash\": init_cash,\n                    \"position_dict\": position_dict,\n                },\n                \"module_path\": \"qlib.backtest.position\",\n            },\n        )\n        self.accum_info = AccumulatedInfo()\n\n        # 2) the following variables are not shared between layers\n        self.portfolio_metrics: Optional[PortfolioMetrics] = None\n        self.hist_positions: Dict[pd.Timestamp, BasePosition] = {}\n        self.reset(freq=freq, benchmark_config=benchmark_config)\n\n    def is_port_metr_enabled(self) -> bool:\n        \"\"\"\n        Whether portfolio-based metrics are enabled.\n        \"\"\"\n        return self._port_metr_enabled and not self.current_position.skip_update()\n\n    def reset_report(self, freq: str, benchmark_config: dict) -> None:\n        # portfolio related metrics\n        if self.is_port_metr_enabled():\n            # NOTE:\n            # `accum_info` and `current_position` are shared here\n            self.portfolio_metrics = PortfolioMetrics(freq, benchmark_config)\n            self.hist_positions = {}\n\n            # fill stock value\n            # The frequency of the account may not align with the trading frequency.\n            # This may result in obscure bugs when data quality is low.\n            if isinstance(self.benchmark_config, dict) and \"start_time\" in self.benchmark_config:\n                self.current_position.fill_stock_value(self.benchmark_config[\"start_time\"], self.freq)\n\n        # trading related metrics (e.g. 
high-frequency trading)\n        self.indicator = Indicator()\n\n    def reset(\n        self, freq: str | None = None, benchmark_config: dict | None = None, port_metr_enabled: bool | None = None\n    ) -> None:\n        \"\"\"reset the freq and report of the account\n\n        Parameters\n        ----------\n        freq : str, optional\n            frequency of account & report, by default None\n        benchmark_config : dict, optional\n            benchmark config of report, by default None\n        port_metr_enabled : bool, optional\n            whether portfolio-based metrics are enabled, by default None (keep the current setting)\n        \"\"\"\n        if freq is not None:\n            self.freq = freq\n        if benchmark_config is not None:\n            self.benchmark_config = benchmark_config\n        if port_metr_enabled is not None:\n            self._port_metr_enabled = port_metr_enabled\n\n        self.reset_report(self.freq, self.benchmark_config)\n\n    def get_hist_positions(self) -> Dict[pd.Timestamp, BasePosition]:\n        return self.hist_positions\n\n    def get_cash(self) -> float:\n        return self.current_position.get_cash()\n\n    def _update_state_from_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:\n        if self.is_port_metr_enabled():\n            # update turnover\n            self.accum_info.add_turnover(trade_val)\n            # update cost\n            self.accum_info.add_cost(cost)\n\n            # update return from order\n            trade_amount = trade_val / trade_price\n            if order.direction == Order.SELL:  # 0 for sell\n                # when selling stock, the profit comes from the price change\n                profit = trade_val - self.current_position.get_stock_price(order.stock_id) * trade_amount\n                self.accum_info.add_return_value(profit)  # NOTE: cost is not considered here\n\n            elif order.direction == Order.BUY:  # 1 for buy\n                # when buying stock, we also compute a return for the rtn computation\n                # the profit of a buy order makes rtn consistent with 
earning at the end of the bar\n                profit = self.current_position.get_stock_price(order.stock_id) * trade_amount - trade_val\n                self.accum_info.add_return_value(profit)  # NOTE: cost is not considered here\n\n    def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:\n        if self.current_position.skip_update():\n            # TODO: supporting polymorphism for account\n            # updating order for infinite position is meaningless\n            return\n\n        # if the stock is sold out, Position no longer has its price information, so update the account first,\n        # then update the current position\n        # if the stock is bought, it may not be in the current position yet: update the position, then the account\n        # The cost will be subtracted from the cash at the end, so the trading logic can ignore the cost calculation\n        if order.direction == Order.SELL:\n            # sell stock\n            self._update_state_from_order(order, trade_val, cost, trade_price)\n            # update current position\n            # the order may sell all of stock_id\n            self.current_position.update_order(order, trade_val, cost, trade_price)\n        else:\n            # buy stock\n            # deal order, then update state\n            self.current_position.update_order(order, trade_val, cost, trade_price)\n            self._update_state_from_order(order, trade_val, cost, trade_price)\n\n    def update_current_position(\n        self,\n        trade_start_time: pd.Timestamp,\n        trade_end_time: pd.Timestamp,\n        trade_exchange: Exchange,\n    ) -> None:\n        \"\"\"\n        Update the current position to make rtn consistent with earning at the end of the bar, and update the holding\n        bar count of each stock\n        \"\"\"\n        # update price for stock in the position and the profit from changed_price\n        # NOTE: updating position does not only serve portfolio metrics, it also serves the strategy\n        assert self.current_position is not None\n\n        if not self.current_position.skip_update():\n            stock_list = self.current_position.get_stock_list()\n            for code in stock_list:\n                # if suspended, there is no new price to update and the profit is 0\n                if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):\n                    continue\n                bar_close = cast(float, trade_exchange.get_close(code, trade_start_time, trade_end_time))\n                self.current_position.update_stock_price(stock_id=code, price=bar_close)\n            # update holding bar count\n            # NOTE: updating bar_count does not only serve portfolio metrics, it also serves the strategy\n            self.current_position.add_count_all(bar=self.freq)\n\n    def update_portfolio_metrics(self, trade_start_time: pd.Timestamp, trade_end_time: pd.Timestamp) -> None:\n        \"\"\"update portfolio_metrics\"\"\"\n        # calculate earning\n        # account_value - last_account_value\n        # for the first trade date, account_value - init_cash\n        # self.portfolio_metrics.is_empty() indicates the first trade date\n        # get last_account_value, last_total_cost, last_total_turnover\n        assert self.portfolio_metrics is not None\n\n        if self.portfolio_metrics.is_empty():\n            last_account_value = self.init_cash\n            last_total_cost = 0\n            last_total_turnover = 0\n        else:\n            last_account_value = self.portfolio_metrics.get_latest_account_value()\n            last_total_cost = self.portfolio_metrics.get_latest_total_cost()\n            last_total_turnover = self.portfolio_metrics.get_latest_total_turnover()\n\n        # get now_account_value, now_stock_value, now_earning, now_cost, now_turnover\n        now_account_value = self.current_position.calculate_value()\n        now_stock_value = self.current_position.calculate_stock_value()\n        now_earning = now_account_value - 
last_account_value\n        now_cost = self.accum_info.get_cost - last_total_cost\n        now_turnover = self.accum_info.get_turnover - last_total_turnover\n\n        # update portfolio_metrics for today\n        # judge whether trading has begun,\n        # and don't add the initial account state into portfolio_metrics, because there is no excess return on those days.\n        self.portfolio_metrics.update_portfolio_metrics_record(\n            trade_start_time=trade_start_time,\n            trade_end_time=trade_end_time,\n            account_value=now_account_value,\n            cash=self.current_position.position[\"cash\"],\n            return_rate=(now_earning + now_cost) / last_account_value,\n            # return is computed from earning (the position's view) with the cost added back, so return_rate excludes\n            # cost (reported separately via cost_rate); this keeps the same definition as the original backtest in evaluate.py\n            total_turnover=self.accum_info.get_turnover,\n            turnover_rate=now_turnover / last_account_value,\n            total_cost=self.accum_info.get_cost,\n            cost_rate=now_cost / last_account_value,\n            stock_value=now_stock_value,\n        )\n\n    def update_hist_positions(self, trade_start_time: pd.Timestamp) -> None:\n        \"\"\"update the history positions\"\"\"\n        now_account_value = self.current_position.calculate_value()\n        # set now_account_value to position\n        self.current_position.position[\"now_account_value\"] = now_account_value\n        self.current_position.update_weight_all()\n        # update hist_positions\n        # NOTE: use deepcopy so that later updates do not modify the recorded history\n        self.hist_positions[trade_start_time] = copy.deepcopy(self.current_position)\n\n    def update_indicator(\n        self,\n        trade_start_time: pd.Timestamp,\n        trade_exchange: Exchange,\n        atomic: bool,\n        outer_trade_decision: BaseTradeDecision,\n        trade_info: list = [],\n        inner_order_indicators: List[BaseOrderIndicator] = [],\n        
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],\n        indicator_config: dict = {},\n    ) -> None:\n        \"\"\"update trade indicators and order indicators at the end of each bar\"\"\"\n        # TODO: would skipping empty decisions (`outer_trade_decision.empty()`) make it faster?\n\n        # indicator is trading (e.g. high-frequency order execution) related analysis\n        self.indicator.reset()\n\n        # aggregate the information for each order\n        if atomic:\n            self.indicator.update_order_indicators(trade_info)\n        else:\n            self.indicator.agg_order_indicators(\n                inner_order_indicators,\n                decision_list=decision_list,\n                outer_trade_decision=outer_trade_decision,\n                trade_exchange=trade_exchange,\n                indicator_config=indicator_config,\n            )\n\n        # aggregate all the order metrics into a single step\n        self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config)\n\n        # record the metrics\n        self.indicator.record(trade_start_time)\n\n    def update_bar_end(\n        self,\n        trade_start_time: pd.Timestamp,\n        trade_end_time: pd.Timestamp,\n        trade_exchange: Exchange,\n        atomic: bool,\n        outer_trade_decision: BaseTradeDecision,\n        trade_info: list = [],\n        inner_order_indicators: List[BaseOrderIndicator] = [],\n        decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],\n        indicator_config: dict = {},\n    ) -> None:\n        \"\"\"update account at each trading bar step\n\n        Parameters\n        ----------\n        trade_start_time : pd.Timestamp\n            closed start time of step\n        trade_end_time : pd.Timestamp\n            closed end time of step\n        trade_exchange : Exchange\n            trading exchange, used to update the current position\n        atomic : bool\n            whether the trading 
executor is atomic, which means there is no higher-frequency trading executor inside it\n            - if atomic is True, calculate the indicators with trade_info\n            - else, aggregate indicators with inner indicators\n        outer_trade_decision: BaseTradeDecision\n            external trade decision\n        trade_info : List[(Order, float, float, float)], optional\n            trading information, by default []\n            - necessary if atomic is True\n            - list of tuple(order, trade_val, trade_cost, trade_price)\n        inner_order_indicators : List[BaseOrderIndicator], optional\n            indicators of the inner executor, by default []\n            - necessary if atomic is False\n            - used to aggregate outer indicators\n        decision_list : List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]], optional\n            The decision list of the inner level: List[Tuple[<decision>, <start_time>, <end_time>]], by default []\n        indicator_config : dict, optional\n            config of calculating indicators, by default {}\n        \"\"\"\n        if atomic is True and trade_info is None:\n            raise ValueError(\"trade_info is necessary in atomic executor\")\n        elif atomic is False and inner_order_indicators is None:\n            raise ValueError(\"inner_order_indicators is necessary in un-atomic executor\")\n\n        # update current position and hold bar count at each bar end\n        self.update_current_position(trade_start_time, trade_end_time, trade_exchange)\n\n        if self.is_port_metr_enabled():\n            # portfolio_metrics is portfolio-related analysis\n            self.update_portfolio_metrics(trade_start_time, trade_end_time)\n            self.update_hist_positions(trade_start_time)\n\n        # update indicator at each bar end\n        self.update_indicator(\n            trade_start_time=trade_start_time,\n            trade_exchange=trade_exchange,\n            atomic=atomic,\n            
outer_trade_decision=outer_trade_decision,\n            trade_info=trade_info,\n            inner_order_indicators=inner_order_indicators,\n            decision_list=decision_list,\n            indicator_config=indicator_config,\n        )\n\n    def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:\n        \"\"\"get the history portfolio_metrics and positions instance\"\"\"\n        if self.is_port_metr_enabled():\n            assert self.portfolio_metrics is not None\n            _portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()\n            _positions = self.get_hist_positions()\n            return _portfolio_metrics, _positions\n        else:\n            raise ValueError(\"generate_portfolio_metrics should be True if you want to generate portfolio_metrics\")\n\n    def get_trade_indicator(self) -> Indicator:\n        \"\"\"get the trade indicator instance, which has pa/pos/ffr info.\"\"\"\n        return self.indicator\n"
  },
  {
    "path": "qlib/backtest/backtest.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom typing import Dict, TYPE_CHECKING, Generator, Optional, Tuple, Union, cast\n\nimport pandas as pd\n\nfrom qlib.backtest.decision import BaseTradeDecision\nfrom qlib.backtest.report import Indicator\n\nif TYPE_CHECKING:\n    from qlib.strategy.base import BaseStrategy\n    from qlib.backtest.executor import BaseExecutor\n\nfrom tqdm.auto import tqdm\n\nfrom ..utils.time import Freq\n\nPORT_METRIC = Dict[str, Tuple[pd.DataFrame, dict]]\nINDICATOR_METRIC = Dict[str, Tuple[pd.DataFrame, Indicator]]\n\n\ndef backtest_loop(\n    start_time: Union[pd.Timestamp, str],\n    end_time: Union[pd.Timestamp, str],\n    trade_strategy: BaseStrategy,\n    trade_executor: BaseExecutor,\n) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:\n    \"\"\"backtest function for the interaction of the outermost strategy and executor in the nested decision execution\n\n    please refer to the docs of `collect_data_loop`\n\n    Returns\n    -------\n    portfolio_dict: PORT_METRIC\n        it records the trading portfolio_metrics information\n    indicator_dict: INDICATOR_METRIC\n        it computes the trading indicator\n    \"\"\"\n    return_value: dict = {}\n    for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):\n        pass\n\n    portfolio_dict = cast(PORT_METRIC, return_value.get(\"portfolio_dict\"))\n    indicator_dict = cast(INDICATOR_METRIC, return_value.get(\"indicator_dict\"))\n\n    return portfolio_dict, indicator_dict\n\n\ndef collect_data_loop(\n    start_time: Union[pd.Timestamp, str],\n    end_time: Union[pd.Timestamp, str],\n    trade_strategy: BaseStrategy,\n    trade_executor: BaseExecutor,\n    return_value: dict | None = None,\n) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:\n    \"\"\"Generator for collecting the trade decision data for rl training\n\n    Parameters\n   
 ----------\n    start_time : Union[pd.Timestamp, str]\n        closed start time for backtest\n        **NOTE**: This will be applied to the outmost executor's calendar.\n    end_time : Union[pd.Timestamp, str]\n        closed end time for backtest\n        **NOTE**: This will be applied to the outmost executor's calendar.\n        E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301\n    trade_strategy : BaseStrategy\n        the outermost portfolio strategy\n    trade_executor : BaseExecutor\n        the outermost executor\n    return_value : dict\n        used for backtest_loop\n\n    Yields\n    -------\n    object\n        trade decision\n    \"\"\"\n    trade_executor.reset(start_time=start_time, end_time=end_time)\n    trade_strategy.reset(level_infra=trade_executor.get_level_infra())\n\n    with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc=\"backtest loop\") as bar:\n        _execute_result = None\n        while not trade_executor.finished():\n            _trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)\n            _execute_result = yield from trade_executor.collect_data(_trade_decision, level=0)\n            trade_strategy.post_exe_step(_execute_result)\n            bar.update(1)\n        trade_strategy.post_upper_level_exe_step()\n\n    if return_value is not None:\n        all_executors = trade_executor.get_all_executors()\n\n        portfolio_dict: PORT_METRIC = {}\n        indicator_dict: INDICATOR_METRIC = {}\n\n        for executor in all_executors:\n            key = \"{}{}\".format(*Freq.parse(executor.time_per_step))\n            if executor.trade_account.is_port_metr_enabled():\n                portfolio_dict[key] = executor.trade_account.get_portfolio_metrics()\n\n            indicator_df = executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()\n            indicator_obj = 
executor.trade_account.get_trade_indicator()\n            indicator_dict[key] = (indicator_df, indicator_obj)\n\n        return_value.update({\"portfolio_dict\": portfolio_dict, \"indicator_dict\": indicator_dict})\n"
  },
  {
    "path": "qlib/backtest/decision.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom abc import abstractmethod\nfrom datetime import time\nfrom enum import IntEnum\n\n# try to fix circular imports when enabling type hints\nfrom typing import TYPE_CHECKING, Any, ClassVar, Generic, List, Optional, Tuple, TypeVar, Union, cast\n\nfrom qlib.backtest.utils import TradeCalendarManager\nfrom qlib.data.data import Cal\nfrom qlib.log import get_module_logger\nfrom qlib.utils.time import concat_date_time, epsilon_change\n\nif TYPE_CHECKING:\n    from qlib.strategy.base import BaseStrategy\n    from qlib.backtest.exchange import Exchange\n\nfrom dataclasses import dataclass\n\nimport numpy as np\nimport pandas as pd\n\nDecisionType = TypeVar(\"DecisionType\")\n\n\nclass OrderDir(IntEnum):\n    # Order direction\n    SELL = 0\n    BUY = 1\n\n\n@dataclass\nclass Order:\n    \"\"\"\n    stock_id : str\n    amount : float\n    start_time : pd.Timestamp\n        closed start time for order trading\n    end_time : pd.Timestamp\n        closed end time for order trading\n    direction : int\n        Order.SELL for sell; Order.BUY for buy\n    factor : float\n            presents the weight factor assigned in Exchange()\n    \"\"\"\n\n    # 1) time invariant values\n    # - they are set by users and is time-invariant.\n    stock_id: str\n    amount: float  # `amount` is a non-negative and adjusted value\n    direction: OrderDir\n\n    # 2) time variant values:\n    # - Users may want to set these values when using lower level APIs\n    # - If users don't, TradeDecisionWO will help users to set them\n    # The interval of the order which belongs to (NOTE: this is not the expected order dealing range time)\n    start_time: pd.Timestamp\n    end_time: pd.Timestamp\n\n    # 3) results\n    # - users should not care about these values\n    # - they are set by the backtest system after finishing the results.\n    # What the value should be 
about in all kinds of cases\n    # - not tradable: the deal_amount == 0 , factor is None\n    #    - the stock is suspended and the entire order fails. No cost for this order\n    # - dealt or partially dealt: deal_amount >= 0 and factor is not None\n    deal_amount: float = 0.0  # `deal_amount` is a non-negative value\n    factor: Optional[float] = None\n\n    # TODO:\n    # a status field to indicate the dealing result of the order\n\n    # FIXME:\n    # for compatible now.\n    # Please remove them in the future\n    SELL: ClassVar[OrderDir] = OrderDir.SELL\n    BUY: ClassVar[OrderDir] = OrderDir.BUY\n\n    def __post_init__(self) -> None:\n        if self.direction not in {Order.SELL, Order.BUY}:\n            raise NotImplementedError(\"direction not supported, `Order.SELL` for sell, `Order.BUY` for buy\")\n        self.deal_amount = 0.0\n        self.factor = None\n\n    @property\n    def amount_delta(self) -> float:\n        \"\"\"\n        return the delta of amount.\n        - Positive value indicates buying `amount` of share\n        - Negative value indicates selling `amount` of share\n        \"\"\"\n        return self.amount * self.sign\n\n    @property\n    def deal_amount_delta(self) -> float:\n        \"\"\"\n        return the delta of deal_amount.\n        - Positive value indicates buying `deal_amount` of share\n        - Negative value indicates selling `deal_amount` of share\n        \"\"\"\n        return self.deal_amount * self.sign\n\n    @property\n    def sign(self) -> int:\n        \"\"\"\n        return the sign of trading\n        - `+1` indicates buying\n        - `-1` value indicates selling\n        \"\"\"\n        return self.direction * 2 - 1\n\n    @staticmethod\n    def parse_dir(direction: Union[str, int, np.integer, OrderDir, np.ndarray]) -> Union[OrderDir, np.ndarray]:\n        if isinstance(direction, OrderDir):\n            return direction\n        elif isinstance(direction, (int, float, np.integer, np.floating)):\n        
    return Order.BUY if direction > 0 else Order.SELL\n        elif isinstance(direction, str):\n            dl = direction.lower().strip()\n            if dl == \"sell\":\n                return OrderDir.SELL\n            elif dl == \"buy\":\n                return OrderDir.BUY\n            else:\n                raise NotImplementedError(f\"This type of input is not supported\")\n        elif isinstance(direction, np.ndarray):\n            direction_array = direction.copy()\n            direction_array[direction_array > 0] = Order.BUY\n            direction_array[direction_array <= 0] = Order.SELL\n            return direction_array\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n\n    @property\n    def key_by_day(self) -> tuple:\n        \"\"\"A hashable & unique key to identify this order, under the granularity in day.\"\"\"\n        return self.stock_id, self.date, self.direction\n\n    @property\n    def key(self) -> tuple:\n        \"\"\"A hashable & unique key to identify this order.\"\"\"\n        return self.stock_id, self.start_time, self.end_time, self.direction\n\n    @property\n    def date(self) -> pd.Timestamp:\n        \"\"\"Date of the order.\"\"\"\n        return pd.Timestamp(self.start_time.replace(hour=0, minute=0, second=0))\n\n\nclass OrderHelper:\n    \"\"\"\n    Motivation\n    - Make generating order easier\n        - User may have no knowledge about the adjust-factor information about the system.\n        - It involves too much interaction with the exchange when generating orders.\n    \"\"\"\n\n    def __init__(self, exchange: Exchange) -> None:\n        self.exchange = exchange\n\n    @staticmethod\n    def create(\n        code: str,\n        amount: float,\n        direction: OrderDir,\n        start_time: Union[str, pd.Timestamp] = None,\n        end_time: Union[str, pd.Timestamp] = None,\n    ) -> Order:\n        \"\"\"\n        help to create a order\n\n        # TODO: create order 
for unadjusted amount order\n\n        Parameters\n        ----------\n        code : str\n            the id of the instrument\n        amount : float\n            **adjusted trading amount**\n        direction : OrderDir\n            trading direction\n        start_time : Union[str, pd.Timestamp] (optional)\n            The start of the interval that the order belongs to\n        end_time : Union[str, pd.Timestamp] (optional)\n            The end of the interval that the order belongs to\n\n        Returns\n        -------\n        Order:\n            The created order\n        \"\"\"\n        # NOTE: factor is a value that belongs to the results section. Users don't have to care about it when creating orders\n        return Order(\n            stock_id=code,\n            amount=amount,\n            start_time=None if start_time is None else pd.Timestamp(start_time),\n            end_time=None if end_time is None else pd.Timestamp(end_time),\n            direction=direction,\n        )\n\n\nclass TradeRange:\n    @abstractmethod\n    def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:\n        \"\"\"\n        This method will be called in the following way\n\n        The outer strategy gives a decision with `TradeRange`\n        The decision will be checked by the inner decision.\n        The inner decision will pass its trade_calendar as a parameter when getting the trading range\n        - The framework's step is integer-index based.\n\n        Parameters\n        ----------\n        trade_calendar : TradeCalendarManager\n            the trade_calendar is from the inner strategy\n\n        Returns\n        -------\n        Tuple[int, int]:\n            the start index and end index which are tradable\n\n        Raises\n        ------\n        NotImplementedError:\n            Raised when there is no range limitation\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `__call__` method\")\n\n    @abstractmethod\n    def 
clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:\n        \"\"\"\n        Parameters\n        ----------\n        start_time : pd.Timestamp\n        end_time : pd.Timestamp\n            Both sides (start_time, end_time) are closed\n\n        Returns\n        -------\n        Tuple[pd.Timestamp, pd.Timestamp]:\n            The tradable time range.\n            - It is the intersection of [start_time, end_time] and the rule of TradeRange itself\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `clip_time_range` method\")\n\n\nclass IdxTradeRange(TradeRange):\n    def __init__(self, start_idx: int, end_idx: int) -> None:\n        self._start_idx = start_idx\n        self._end_idx = end_idx\n\n    def __call__(self, trade_calendar: TradeCalendarManager | None = None) -> Tuple[int, int]:\n        return self._start_idx, self._end_idx\n\n    def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:\n        raise NotImplementedError\n\n\nclass TradeRangeByTime(TradeRange):\n    \"\"\"This is a helper class for making decisions\"\"\"\n\n    def __init__(self, start_time: str | time, end_time: str | time) -> None:\n        \"\"\"\n        This is a callable class.\n\n        **NOTE**:\n        - It is designed for minute bars in intra-day trading!!!!!\n        - Both start_time and end_time are **closed** in the range\n\n        Parameters\n        ----------\n        start_time : str | time\n            e.g. \"9:30\"\n        end_time : str | time\n            e.g. 
\"14:30\"\n        \"\"\"\n        self.start_time = pd.Timestamp(start_time).time() if isinstance(start_time, str) else start_time\n        self.end_time = pd.Timestamp(end_time).time() if isinstance(end_time, str) else end_time\n        assert self.start_time < self.end_time\n\n    def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:\n        if trade_calendar is None:\n            raise NotImplementedError(\"trade_calendar is necessary for getting TradeRangeByTime.\")\n\n        start_date = trade_calendar.start_time.date()\n        val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time)\n        return trade_calendar.get_range_idx(val_start, val_end)\n\n    def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:\n        start_date = start_time.date()\n        val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time)\n        # NOTE: `end_date` should not be used, because `end_date` is for slicing and may fall on the next day\n        # Assumption: start_time and end_time are for intra-day trading, so it is OK to use only start_date\n        return max(val_start, start_time), min(val_end, end_time)\n\n\nclass BaseTradeDecision(Generic[DecisionType]):\n    \"\"\"\n    Trade decisions are made by a strategy and executed by an executor\n\n    Motivation:\n        Here are several typical scenarios for `BaseTradeDecision`\n\n        Case 1:\n        1. The outer strategy makes a decision. The decision is not available at the start of the current interval\n        2. After a period of time, the decision is updated and becomes available\n        3. The inner strategy tries to get the decision and starts to execute it according to `get_range_limit`\n        Case 2:\n        1. The outer strategy's decision is available at the start of the interval\n        2. 
Same as `case 1.3`\n    \"\"\"\n\n    def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange, None] = None) -> None:\n        \"\"\"\n        Parameters\n        ----------\n        strategy : BaseStrategy\n            The strategy that makes the decision\n        trade_range: Union[Tuple[int, int], TradeRange] (optional)\n            The index range for the underlying strategy.\n\n            Here are examples of trade_range for each type\n\n            1) Tuple[int, int]\n            start_index and end_index of the underlying strategy (both sides are closed)\n\n            2) TradeRange\n\n        \"\"\"\n        self.strategy = strategy\n        self.start_time, self.end_time = strategy.trade_calendar.get_step_time()\n        # upper strategy has no knowledge about the sub executor before `_init_sub_trading`\n        self.total_step: Optional[int] = None\n        if isinstance(trade_range, tuple):\n            # for Tuple[int, int]\n            trade_range = IdxTradeRange(*trade_range)\n        self.trade_range: Optional[TradeRange] = trade_range\n\n    def get_decision(self) -> List[DecisionType]:\n        \"\"\"\n        get the **concrete decision** (e.g. execution orders)\n        This will be called by the inner strategy\n\n        Returns\n        -------\n        List[DecisionType]:\n            The decision result. 
Typically it is a list of orders\n            Example:\n                []:\n                    Decision not available\n                [concrete_decision]:\n                    available\n        \"\"\"\n        raise NotImplementedError(f\"This type of input is not supported\")\n\n    def update(self, trade_calendar: TradeCalendarManager) -> Optional[BaseTradeDecision]:\n        \"\"\"\n        Called at the **start** of each step.\n\n        This function is designed for the following purposes\n        1) Leave a hook for the strategy that makes the `self` decision to update the decision itself\n        2) Update some information from the inner executor calendar\n\n        Parameters\n        ----------\n        trade_calendar : TradeCalendarManager\n            The calendar of the **inner strategy**!!!!!\n\n        Returns\n        -------\n        BaseTradeDecision:\n            The new decision if there is an update. If there are no updates, return None (the previous decision, or unavailability, is kept)\n        \"\"\"\n        # purpose 2)\n        self.total_step = trade_calendar.get_trade_len()\n\n        # purpose 1)\n        return self.strategy.update_trade_decision(self, trade_calendar)\n\n    def _get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:\n        if self.trade_range is not None:\n            return self.trade_range(trade_calendar=cast(TradeCalendarManager, kwargs.get(\"inner_calendar\")))\n        else:\n            raise NotImplementedError(\"The decision didn't provide an index range\")\n\n    def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:\n        \"\"\"\n        return the expected step range for limiting the decision execution time\n        Both left and right are **closed**\n\n        if no trade_range is available, `default_value` will be returned\n\n        It is only used in `NestedExecutor`\n        - The outermost strategy will not follow any range limit (but it may give range_limit)\n        - The innermost strategy's range_limit will be useless because 
atomic executors don't have such\n          features.\n\n        **NOTE**:\n        1) This function must be called after `self.update` in the following cases (ensured by NestedExecutor):\n        - the user relies on the auto-clip feature of `self.update`\n\n        2) This function will be called after _init_sub_trading in NestedExecutor.\n\n        Parameters\n        ----------\n        **kwargs:\n            {\n                \"default_value\": <default_value>, # a dict is used to distinguish \"no value provided\" from \"None provided\"\n                \"inner_calendar\": <trade calendar of inner strategy>\n                # because the range limit will control the step range of the inner strategy, the inner calendar is an\n                # important parameter when trade_range is callable\n            }\n\n        Returns\n        -------\n        Tuple[int, int]:\n\n        Raises\n        ------\n        NotImplementedError:\n            If the following criteria are met\n            1) the decision can't provide a unified start and end\n            2) default_value is not provided\n        \"\"\"\n        try:\n            _start_idx, _end_idx = self._get_range_limit(**kwargs)\n        except NotImplementedError as e:\n            if \"default_value\" in kwargs:\n                return kwargs[\"default_value\"]\n            else:\n                # Default to get full index\n                raise NotImplementedError(f\"The decision didn't provide an index range\") from e\n\n        # clip index\n        if getattr(self, \"total_step\", None) is not None:\n            # if `self.update` has been called,\n            # then _start_idx and _end_idx should be clipped\n            assert self.total_step is not None\n            if _start_idx < 0 or _end_idx >= self.total_step:\n                logger = get_module_logger(\"decision\")\n                logger.warning(\n                    f\"[{_start_idx},{_end_idx}] go beyond the total_step({self.total_step}), it will be clipped.\",\n      
          )\n                _start_idx, _end_idx = max(0, _start_idx), min(self.total_step - 1, _end_idx)\n        return _start_idx, _end_idx\n\n    def get_data_cal_range_limit(self, rtype: str = \"full\", raise_error: bool = False) -> Tuple[int, int]:\n        \"\"\"\n        get the range limit based on data calendar\n\n        NOTE: it is **total** range limit instead of a single step\n\n        The following assumptions are made\n        1) The frequency of the exchange in common_infra is the same as the data calendar\n        2) Users want the index mod by **day** (i.e. 240 min)\n\n        Parameters\n        ----------\n        rtype: str\n            - \"full\": return the full limitation of the decision in the day\n            - \"step\": return the limitation of current step\n\n        raise_error: bool\n            True: raise error if no trade_range is set\n            False: return full trade calendar.\n\n            It is useful in following cases\n            - users want to follow the order specific trading time range when decision level trade range is not\n              available. 
Raising NotImplementedError indicates that the range limit is not available\n\n        Returns\n        -------\n        Tuple[int, int]:\n            the range limit in the data calendar\n\n        Raises\n        ------\n        NotImplementedError:\n            If the following criteria are met\n            1) the decision can't provide a unified start and end\n            2) raise_error is True\n        \"\"\"\n        # potential performance issue\n        day_start = pd.Timestamp(self.start_time.date())\n        day_end = epsilon_change(day_start + pd.Timedelta(days=1))\n        freq = self.strategy.trade_exchange.freq\n        _, _, day_start_idx, day_end_idx = Cal.locate_index(day_start, day_end, freq=freq)\n        if self.trade_range is None:\n            if raise_error:\n                raise NotImplementedError(f\"There is no trade_range in this case\")\n            else:\n                return 0, day_end_idx - day_start_idx\n        else:\n            if rtype == \"full\":\n                val_start, val_end = self.trade_range.clip_time_range(day_start, day_end)\n            elif rtype == \"step\":\n                val_start, val_end = self.trade_range.clip_time_range(self.start_time, self.end_time)\n            else:\n                raise ValueError(f\"This type of input {rtype} is not supported\")\n            _, _, start_idx, end_index = Cal.locate_index(val_start, val_end, freq=freq)\n            return start_idx - day_start_idx, end_index - day_start_idx\n\n    def empty(self) -> bool:\n        for obj in self.get_decision():\n            if isinstance(obj, Order):\n                # Zero-amount orders are treated as empty\n                if obj.amount > 1e-6:\n                    return False\n            else:\n                return True\n        return True\n\n    def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision) -> None:\n        \"\"\"\n        This method will be called on the inner_trade_decision after it is generated.\n  
      `inner_trade_decision` will be changed **inplace**.\n\n        Motivation of `mod_inner_decision`\n        - Leave a hook for the outer decision to affect the decision generated by the inner strategy\n            - e.g. the outermost strategy generates a time range for trading. But the upper layer can only affect the\n              nearest layer in the original design. With `mod_inner_decision`, the decision can be passed through multiple\n              layers\n\n        Parameters\n        ----------\n        inner_trade_decision : BaseTradeDecision\n        \"\"\"\n        # the base class provides a default behaviour to modify inner_trade_decision\n        # trade_range should be propagated when the inner trade_range is not set\n        if inner_trade_decision.trade_range is None:\n            inner_trade_decision.trade_range = self.trade_range\n\n\nclass EmptyTradeDecision(BaseTradeDecision[object]):\n    def get_decision(self) -> List[object]:\n        return []\n\n    def empty(self) -> bool:\n        return True\n\n\nclass TradeDecisionWO(BaseTradeDecision[Order]):\n    \"\"\"\n    Trade Decision (W)ith (O)rder.\n    Besides, the time_range is also included.\n    \"\"\"\n\n    def __init__(\n        self,\n        order_list: List[Order],\n        strategy: BaseStrategy,\n        trade_range: Union[Tuple[int, int], TradeRange, None] = None,\n    ) -> None:\n        super().__init__(strategy, trade_range=trade_range)\n        self.order_list = cast(List[Order], order_list)\n        start, end = strategy.trade_calendar.get_step_time()\n        for o in order_list:\n            assert isinstance(o, Order)\n            if o.start_time is None:\n                o.start_time = start\n            if o.end_time is None:\n                o.end_time = end\n\n    def get_decision(self) -> List[Order]:\n        return self.order_list\n\n    def __repr__(self) -> str:\n        return (\n            f\"class: {self.__class__.__name__}; \"\n            f\"strategy: 
{self.strategy}; \"\n            f\"trade_range: {self.trade_range}; \"\n            f\"order_list[{len(self.order_list)}]\"\n        )\n\n\nclass TradeDecisionWithDetails(TradeDecisionWO):\n    \"\"\"\n    Decision with detail information.\n    Detail information is used to generate execution reports.\n    \"\"\"\n\n    def __init__(\n        self,\n        order_list: List[Order],\n        strategy: BaseStrategy,\n        trade_range: Optional[Tuple[int, int]] = None,\n        details: Optional[Any] = None,\n    ) -> None:\n        super().__init__(order_list, strategy, trade_range)\n\n        self.details = details\n"
  },
  {
    "path": "qlib/backtest/exchange.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom __future__ import annotations\n\nfrom collections import defaultdict\nfrom typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast\n\nfrom ..utils.index_data import IndexData\n\nif TYPE_CHECKING:\n    from .account import Account\n\nimport random\n\nimport numpy as np\nimport pandas as pd\n\nfrom qlib.backtest.position import BasePosition\n\nfrom ..config import C\nfrom ..constant import REG_CN, REG_TW\nfrom ..data.data import D\nfrom ..log import get_module_logger\nfrom .decision import Order, OrderDir, OrderHelper\nfrom .high_performance_ds import BaseQuote, NumpyQuote\n\n\nclass Exchange:\n    # `quote_df` is a pd.DataFrame class that contains basic information for backtesting\n    # After some processing, the data will later be maintained by `quote_cls` object for faster data retrieving.\n    # Some conventions for `quote_df`\n    # - $close is for calculating the total value at end of each day.\n    #   - if $close is None, the stock on that day is regarded as suspended.\n    # - $factor is for rounding to the trading unit;\n    #   - if any $factor is missing when $close exists, trading unit rounding will be disabled\n    quote_df: pd.DataFrame\n\n    def __init__(\n        self,\n        freq: str = \"day\",\n        start_time: Union[pd.Timestamp, str] = None,\n        end_time: Union[pd.Timestamp, str] = None,\n        codes: Union[list, str] = \"all\",\n        deal_price: Union[str, Tuple[str, str], List[str], None] = None,\n        subscribe_fields: list = [],\n        limit_threshold: Union[Tuple[str, str], float, None] = None,\n        volume_threshold: Union[tuple, dict, None] = None,\n        open_cost: float = 0.0015,\n        close_cost: float = 0.0025,\n        min_cost: float = 5.0,\n        impact_cost: float = 0.0,\n        extra_quote: pd.DataFrame = None,\n        quote_cls: Type[BaseQuote] = NumpyQuote,\n        **kwargs: 
Any,\n    ) -> None:\n        \"\"\"__init__\n        :param freq:             frequency of data\n        :param start_time:       closed start time for backtest\n        :param end_time:         closed end time for backtest\n        :param codes:            a list of stock_ids or a string of instruments (e.g. all, csi500, sse50)\n        :param deal_price:      Union[str, Tuple[str, str], List[str]]\n                                The `deal_price` supports the following two types of input\n                                - <deal_price> : str\n                                - (<buy_price>, <sell_price>): Tuple[str] or List[str]\n                                <deal_price>, <buy_price> or <sell_price> := <price>\n                                <price> := str\n                                - for example '$close', '$open', '$vwap' (\"close\" is OK. `Exchange` will prepend\n                                  \"$\" to the expression)\n        :param subscribe_fields: list, subscribe fields. 
These expressions will be added to the query and `self.quote`.\n                                 This is useful when users want to query more fields\n        :param limit_threshold: Union[Tuple[str, str], float, None]\n                                1) `None`: no limitation\n                                2) float, 0.1 for example, default None\n                                3) Tuple[str, str]: (<the expression for buying stock limitation>,\n                                                     <the expression for selling stock limitation>)\n                                                    `False` value indicates the stock is tradable\n                                                    `True` value indicates the stock is limited and not tradable\n        :param volume_threshold: Union[\n                                    Dict[\n                                        \"all\": (\"cum\" or \"current\", limit_str),\n                                        \"buy\": (\"cum\" or \"current\", limit_str),\n                                        \"sell\":(\"cum\" or \"current\", limit_str),\n                                    ],\n                                    (\"cum\" or \"current\", limit_str),\n                                 ]\n                                1) (\"cum\" or \"current\", limit_str) denotes a single volume limit.\n                                    - limit_str is a qlib data expression, which may use your own custom Operators.\n                                    Please refer to qlib/contrib/ops/high_freq.py; it provides some custom operators for\n                                    high frequency, such as DayCumsum. !!!NOTE: if you want to use a custom\n                                    operator, you need to register it in qlib_init.\n                                    - \"cum\" means that this is a cumulative value over time, such as cumulative market\n                                    volume. 
So when it is used as a volume limit, it is necessary to subtract the dealt\n                                    amount.\n                                    - \"current\" means that this is a real-time value and will not accumulate over time,\n                                    so it can be directly used as a capacity limit.\n                                    e.g. (\"cum\", \"0.2 * DayCumsum($volume, '9:45', '14:45')\"), (\"current\", \"$bidV1\")\n                                2) \"all\" means the volume limits apply to both buying and selling.\n                                \"buy\" means the volume limits apply to buying only. \"sell\" means they apply to selling only.\n                                Different volume limits will be aggregated with min(). If volume_threshold is only\n                                (\"cum\" or \"current\", limit_str) instead of a dict, the volume limit applies to\n                                both directions by default. In other words, it is the same as {\"all\": (\"cum\" or \"current\", limit_str)}.\n                                3) e.g. \"volume_threshold\": {\n                                            \"all\": (\"cum\", \"0.2 * DayCumsum($volume, '9:45', '14:45')\"),\n                                            \"buy\": (\"current\", \"$askV1\"),\n                                            \"sell\": (\"current\", \"$bidV1\"),\n                                        }\n        :param open_cost:        cost rate for open, default 0.0015\n        :param close_cost:       cost rate for close, default 0.0025\n        :param trade_unit:       trade unit, 100 for China A market.\n                                 None to disable the trade unit.\n                                 **NOTE**: `trade_unit` is included in the `kwargs`. 
It is necessary because we must\n                                 distinguish `not set` and `disable trade_unit`\n        :param min_cost:         min cost, default 5\n        :param impact_cost:     market impact cost rate (a.k.a. slippage). A recommended value is 0.1.\n        :param extra_quote:     pandas DataFrame consisting of\n                                    columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy'].\n                                            The limit columns indicate whether the target is tradable on a specific day.\n                                            Necessary fields:\n                                                $close is for calculating the total value at the end of each day.\n                                            Optional fields:\n                                                $volume is only necessary when we limit the trade amount or calculate\n                                                PA(vwap) indicator\n                                                $vwap is only necessary when we use the $vwap price as the deal price\n                                                $factor is for rounding to the trading unit\n                                                limit_sell will be set to False by default (False indicates we can sell\n                                                this target on this day).\n                                                limit_buy will be set to False by default (False indicates we can buy\n                                                this target on this day).\n                                    index: MultiIndex(instrument, pd.Datetime)\n        \"\"\"\n        self.freq = freq\n        self.start_time = start_time\n        self.end_time = end_time\n\n        self.trade_unit = kwargs.pop(\"trade_unit\", C.trade_unit)\n        if len(kwargs) > 0:\n            raise ValueError(f\"Got unexpected arguments {kwargs}\")\n\n        if limit_threshold is None:\n     
       limit_threshold = C.limit_threshold\n        if deal_price is None:\n            deal_price = C.deal_price\n\n        # we have some verbose information here. So logging is enabled\n        self.logger = get_module_logger(\"online operator\")\n\n        # TODO: the quote, trade_dates, codes are not necessary.\n        # It is just for performance consideration.\n        self.limit_type = self._get_limit_type(limit_threshold)\n        if limit_threshold is None:\n            if C.region in [REG_CN, REG_TW]:\n                self.logger.warning(f\"limit_threshold not set. The stocks hit the limit may be bought/sold\")\n        elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1:\n            if C.region in [REG_CN, REG_TW]:\n                self.logger.warning(f\"limit_threshold may not be set to a reasonable value\")\n\n        if isinstance(deal_price, str):\n            if deal_price[0] != \"$\":\n                deal_price = \"$\" + deal_price\n            self.buy_price = self.sell_price = deal_price\n        elif isinstance(deal_price, (tuple, list)):\n            self.buy_price, self.sell_price = cast(Tuple[str, str], deal_price)\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n\n        if isinstance(codes, str):\n            codes = D.instruments(codes)\n        self.codes = codes\n        # Necessary fields\n        # $close is for calculating the total value at end of each day.\n        # - if $close is None, the stock on that day is regarded as suspended.\n        # $factor is for rounding to the trading unit\n        # $change is for calculating the limit of the stock\n\n        # 　get volume limit from kwargs\n        self.buy_vol_limit, self.sell_vol_limit, vol_lt_fields = self._get_vol_limit(volume_threshold)\n\n        necessary_fields = {self.buy_price, self.sell_price, \"$close\", \"$change\", \"$factor\", \"$volume\"}\n        if self.limit_type == self.LT_TP_EXP:\n 
           assert isinstance(limit_threshold, tuple)\n            for exp in limit_threshold:\n                necessary_fields.add(exp)\n        all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))\n\n        self.all_fields = all_fields\n\n        self.open_cost = open_cost\n        self.close_cost = close_cost\n        self.min_cost = min_cost\n        self.impact_cost = impact_cost\n\n        self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold\n        self.volume_threshold = volume_threshold\n        self.extra_quote = extra_quote\n        self.get_quote_from_qlib()\n\n        # init quote by quote_df\n        self.quote_cls = quote_cls\n        self.quote: BaseQuote = self.quote_cls(self.quote_df, freq)\n\n    def get_quote_from_qlib(self) -> None:\n        # get stock data from qlib\n        if len(self.codes) == 0:\n            self.codes = D.instruments()\n        self.quote_df = D.features(\n            self.codes,\n            self.all_fields,\n            self.start_time,\n            self.end_time,\n            freq=self.freq,\n            disk_cache=True,\n        )\n        self.quote_df.columns = self.all_fields\n\n        # check buy_price data and sell_price data\n        for attr in (\"buy_price\", \"sell_price\"):\n            pstr = getattr(self, attr)  # price string\n            if self.quote_df[pstr].isna().any():\n                self.logger.warning(\"{} field data contains nan.\".format(pstr))\n\n        # update trade_w_adj_price\n        if (self.quote_df[\"$factor\"].isna() & ~self.quote_df[\"$close\"].isna()).any():\n            # The 'factor.day.bin' file not exists, and `factor` field contains `nan`\n            # Use adjusted price\n            self.trade_w_adj_price = True\n            self.logger.warning(\"factor.day.bin file not exists or factor contains `nan`. 
Order using adjusted_price.\")\n            if self.trade_unit is not None:\n                self.logger.warning(f\"trade unit {self.trade_unit} is not supported in adjusted_price mode.\")\n        else:\n            # The `factor.day.bin` file exists and `factor` is not `nan` wherever `close` exists\n            # Use normal price\n            self.trade_w_adj_price = False\n        # update limit\n        self._update_limit(self.limit_threshold)\n\n        # concat extra_quote\n        if self.extra_quote is not None:\n            # process extra_quote\n            if \"$close\" not in self.extra_quote:\n                raise ValueError(\"$close is necessary in extra_quote\")\n            for attr in \"buy_price\", \"sell_price\":\n                pstr = getattr(self, attr)  # price string\n                if pstr not in self.extra_quote.columns:\n                    self.extra_quote[pstr] = self.extra_quote[\"$close\"]\n                    self.logger.warning(f\"No {pstr} set for extra_quote. Use $close as {pstr}.\")\n            if \"$factor\" not in self.extra_quote.columns:\n                self.extra_quote[\"$factor\"] = 1.0\n                self.logger.warning(\"No $factor set for extra_quote. Use 1.0 as $factor.\")\n            if \"limit_sell\" not in self.extra_quote.columns:\n                self.extra_quote[\"limit_sell\"] = False\n                self.logger.warning(\"No limit_sell set for extra_quote. All stock will be able to be sold.\")\n            if \"limit_buy\" not in self.extra_quote.columns:\n                self.extra_quote[\"limit_buy\"] = False\n                self.logger.warning(\"No limit_buy set for extra_quote. 
All stock will be able to be bought.\")\n            assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {\"$change\"}\n            self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0)\n\n    LT_TP_EXP = \"(exp)\"  # Tuple[str, str]:  the limitation is calculated by a Qlib expression.\n    LT_FLT = \"float\"  # float:  the trading limitation is based on `abs($change) < limit_threshold`\n    LT_NONE = \"none\"  # none:  there is no trading limitation\n\n    def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:\n        \"\"\"get limit type\"\"\"\n        if isinstance(limit_threshold, tuple):\n            return self.LT_TP_EXP\n        elif isinstance(limit_threshold, float):\n            return self.LT_FLT\n        elif limit_threshold is None:\n            return self.LT_NONE\n        else:\n            raise NotImplementedError(f\"This type of `limit_threshold` is not supported\")\n\n    def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None:\n        # $close may contain NaN, the nan indicates that the stock is not tradable at that timestamp\n        suspended = self.quote_df[\"$close\"].isna()\n        # check limit_threshold\n        limit_type = self._get_limit_type(limit_threshold)\n        if limit_type == self.LT_NONE:\n            self.quote_df[\"limit_buy\"] = suspended\n            self.quote_df[\"limit_sell\"] = suspended\n        elif limit_type == self.LT_TP_EXP:\n            # set limit\n            limit_threshold = cast(tuple, limit_threshold)\n            # astype bool is necessary, because quote_df is an expression and could be float\n            self.quote_df[\"limit_buy\"] = self.quote_df[limit_threshold[0]].astype(\"bool\") | suspended\n            self.quote_df[\"limit_sell\"] = self.quote_df[limit_threshold[1]].astype(\"bool\") | suspended\n        elif limit_type == self.LT_FLT:\n            limit_threshold = cast(float, limit_threshold)\n            
self.quote_df[\"limit_buy\"] = self.quote_df[\"$change\"].ge(limit_threshold) | suspended\n            self.quote_df[\"limit_sell\"] = (\n                self.quote_df[\"$change\"].le(-limit_threshold) | suspended\n            )  # pylint: disable=E1130\n\n    @staticmethod\n    def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:\n        \"\"\"\n        preprocess the volume limit.\n        get the fields that need to be fetched from qlib.\n        get the volume limit lists of buying and selling, each composed of all applicable limits.\n        Parameters\n        ----------\n        volume_threshold :\n            please refer to the docstring of `Exchange.__init__`.\n        Returns\n        -------\n        buy_vol_limit: List[Tuple[str]]\n            all volume limits of buying.\n        sell_vol_limit: List[Tuple[str]]\n            all volume limits of selling.\n        fields: set\n            the fields that need to be fetched from qlib.\n        Raises\n        ------\n        AssertionError\n            the format of volume_threshold is not supported.\n        \"\"\"\n        if volume_threshold is None:\n            return None, None, set()\n\n        fields = set()\n        buy_vol_limit = []\n        sell_vol_limit = []\n        if isinstance(volume_threshold, tuple):\n            volume_threshold = {\"all\": volume_threshold}\n\n        assert isinstance(volume_threshold, dict)\n        for key, vol_limit in volume_threshold.items():\n            assert isinstance(vol_limit, tuple)\n            fields.add(vol_limit[1])\n\n            if key in (\"buy\", \"all\"):\n                buy_vol_limit.append(vol_limit)\n            if key in (\"sell\", \"all\"):\n                sell_vol_limit.append(vol_limit)\n\n        return buy_vol_limit, sell_vol_limit, fields\n\n    def check_stock_limit(\n        self,\n        stock_id: str,\n        start_time: pd.Timestamp,\n        end_time: pd.Timestamp,\n        direction: int | None = None,\n    ) -> 
bool:\n        \"\"\"\n        Parameters\n        ----------\n        stock_id : str\n        start_time: pd.Timestamp\n        end_time: pd.Timestamp\n        direction : int, optional\n            trade direction, by default None\n            - if direction is None, check if tradable for buying and selling.\n            - if direction == Order.BUY, check if tradable for buying.\n            - if direction == Order.SELL, check if tradable for selling.\n\n        Returns\n        -------\n        True: the trading of the stock is limited (e.g. it hit the upper/lower price limit), hence the stock is not tradable\n        False: the trading of the stock is not limited, hence the stock may be tradable\n        \"\"\"\n        # NOTE:\n        # **all** is used when checking limitation.\n        # For example, the stock trading is limited in a day only if every minute in that day is limited.\n        if direction is None:\n            # The trading limitation is related to the trading direction\n            # if the direction is not provided, then any limitation from buy or sell will result in trading limitation\n            buy_limit = self.quote.get_data(stock_id, start_time, end_time, field=\"limit_buy\", method=\"all\")\n            sell_limit = self.quote.get_data(stock_id, start_time, end_time, field=\"limit_sell\", method=\"all\")\n            return bool(buy_limit or sell_limit)\n        elif direction == Order.BUY:\n            return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field=\"limit_buy\", method=\"all\"))\n        elif direction == Order.SELL:\n            return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field=\"limit_sell\", method=\"all\"))\n        else:\n            raise ValueError(f\"direction {direction} is not supported!\")\n\n    def check_stock_suspended(\n        self,\n        stock_id: str,\n        start_time: pd.Timestamp,\n        end_time: pd.Timestamp,\n    ) -> bool:\n    
    \"\"\"if stock is suspended(hence not tradable), True will be returned\"\"\"\n        # is suspended\n        if stock_id in self.quote.get_all_stock():\n            # suspended stocks are represented by None $close stock\n            # The $close may contain NaN,\n            close = self.quote.get_data(stock_id, start_time, end_time, \"$close\")\n            if close is None:\n                # if no close record exists\n                return True\n            elif isinstance(close, IndexData):\n                # **any** non-NaN $close represents trading opportunity may exist\n                #  if all returned is nan, then the stock is suspended\n                return cast(bool, cast(IndexData, close).isna().all())\n            else:\n                # it is single value, make sure is not None\n                return np.isnan(close)\n        else:\n            # if the stock is not in the stock list, then it is not tradable and regarded as suspended\n            return True\n\n    def is_stock_tradable(\n        self,\n        stock_id: str,\n        start_time: pd.Timestamp,\n        end_time: pd.Timestamp,\n        direction: int | None = None,\n    ) -> bool:\n        # check if stock can be traded\n        return not (\n            self.check_stock_suspended(stock_id, start_time, end_time)\n            or self.check_stock_limit(stock_id, start_time, end_time, direction)\n        )\n\n    def check_order(self, order: Order) -> bool:\n        # check limit and suspended\n        return self.is_stock_tradable(order.stock_id, order.start_time, order.end_time, order.direction)\n\n    def deal_order(\n        self,\n        order: Order,\n        trade_account: Account | None = None,\n        position: BasePosition | None = None,\n        dealt_order_amount: Dict[str, float] = defaultdict(float),\n    ) -> Tuple[float, float, float]:\n        \"\"\"\n        Deal order when the actual transaction\n        the results section in `Order` will be changed.\n     
   :param order:  Deal the order.\n        :param trade_account: Trade account to be updated after dealing the order.\n        :param position: position to be updated after dealing the order.\n        :param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}\n        :return: trade_val, trade_cost, trade_price\n        \"\"\"\n        # check order first.\n        if not self.check_order(order):\n            order.deal_amount = 0.0\n            # using np.nan instead of None to make it more convenient to show the value in format string\n            self.logger.debug(f\"Order failed due to trading limitation: {order}\")\n            return 0.0, 0.0, np.nan\n\n        if trade_account is not None and position is not None:\n            raise ValueError(\"trade_account and position can only choose one\")\n\n        # NOTE: order will be changed in this function\n        trade_price, trade_val, trade_cost = self._calc_trade_info_by_order(\n            order,\n            trade_account.current_position if trade_account else position,\n            dealt_order_amount,\n        )\n        if trade_val > 1e-5:\n            # If the order can only be deal 0 value. 
Nothing to be updated\n            # Otherwise, it will result in\n            # 1) some stock with 0 value in the position\n            # 2) `trade_unit` of trade_cost will be lost in user account\n            if trade_account:\n                trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)\n            elif position:\n                position.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)\n\n        return trade_val, trade_cost, trade_price\n\n    def get_quote_info(\n        self,\n        stock_id: str,\n        start_time: pd.Timestamp,\n        end_time: pd.Timestamp,\n        field: str,\n        method: str = \"ts_data_last\",\n    ) -> Union[None, int, float, bool, IndexData]:\n        return self.quote.get_data(stock_id, start_time, end_time, field=field, method=method)\n\n    def get_close(\n        self,\n        stock_id: str,\n        start_time: pd.Timestamp,\n        end_time: pd.Timestamp,\n        method: str = \"ts_data_last\",\n    ) -> Union[None, int, float, bool, IndexData]:\n        return self.quote.get_data(stock_id, start_time, end_time, field=\"$close\", method=method)\n\n    def get_volume(\n        self,\n        stock_id: str,\n        start_time: pd.Timestamp,\n        end_time: pd.Timestamp,\n        method: Optional[str] = \"sum\",\n    ) -> Union[None, int, float, bool, IndexData]:\n        \"\"\"get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)\"\"\"\n        return self.quote.get_data(stock_id, start_time, end_time, field=\"$volume\", method=method)\n\n    def get_deal_price(\n        self,\n        stock_id: str,\n        start_time: pd.Timestamp,\n        end_time: pd.Timestamp,\n        direction: OrderDir,\n        method: Optional[str] = \"ts_data_last\",\n    ) -> Union[None, int, float, bool, IndexData]:\n        if direction == OrderDir.SELL:\n            pstr = 
self.sell_price\n        elif direction == OrderDir.BUY:\n            pstr = self.buy_price\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n\n        deal_price = self.quote.get_data(stock_id, start_time, end_time, field=pstr, method=method)\n        if method is not None and (deal_price is None or np.isnan(deal_price) or deal_price <= 1e-08):\n            self.logger.warning(f\"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!\")\n            self.logger.warning(f\"setting deal_price to close price\")\n            deal_price = self.get_close(stock_id, start_time, end_time, method)\n        return deal_price\n\n    def get_factor(\n        self,\n        stock_id: str,\n        start_time: pd.Timestamp,\n        end_time: pd.Timestamp,\n    ) -> Optional[float]:\n        \"\"\"\n        Returns\n        -------\n        Optional[float]:\n            `None`: may be returned if the stock is suspended\n            `float`: return factor if the factor exists\n        \"\"\"\n        assert start_time is not None and end_time is not None, \"the time range must be given\"\n        if stock_id not in self.quote.get_all_stock():\n            return None\n        return self.quote.get_data(stock_id, start_time, end_time, field=\"$factor\", method=\"ts_data_last\")\n\n    def generate_amount_position_from_weight_position(\n        self,\n        weight_position: dict,\n        cash: float,\n        start_time: pd.Timestamp,\n        end_time: pd.Timestamp,\n        direction: OrderDir = OrderDir.BUY,\n    ) -> dict:\n        \"\"\"\n        Generates the target position according to the weight and the cash.\n        NOTE: All the cash will be assigned to the tradable stocks.\n        Parameter:\n        weight_position : dict {stock_id : weight}; allocate cash by weight_position\n            among them, weight must be in this range: 0 < weight < 1\n        cash : cash\n        start_time 
: the start time point of the step\n        end_time : the end time point of the step\n        direction : the direction of the deal price for estimating the amount\n                    # NOTE: this function is used for calculating target position. So the default direction is buy\n        \"\"\"\n\n        # calculate the total weight of tradable value\n        tradable_weight = 0.0\n        for stock_id, wp in weight_position.items():\n            if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):\n                # weight_position must be in the range [0, 1]\n                if wp < 0 or wp > 1:\n                    raise ValueError(\n                        \"weight_position is {}, \" \"weight_position is not in the range of [0, 1].\".format(wp),\n                    )\n                tradable_weight += wp\n\n        if tradable_weight - 1.0 >= 1e-5:\n            raise ValueError(\"tradable_weight is {}, cannot be greater than 1.\".format(tradable_weight))\n\n        amount_dict = {}\n        for stock_id in weight_position:\n            if weight_position[stock_id] > 0.0 and self.is_stock_tradable(\n                stock_id=stock_id,\n                start_time=start_time,\n                end_time=end_time,\n            ):\n                amount_dict[stock_id] = (\n                    cash\n                    * weight_position[stock_id]\n                    / tradable_weight\n                    // self.get_deal_price(\n                        stock_id=stock_id,\n                        start_time=start_time,\n                        end_time=end_time,\n                        direction=direction,\n                    )\n                )\n        return amount_dict\n\n    def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float | None = None) -> float:\n        \"\"\"\n        Calculate the real adjusted deal amount when considering the trading unit\n        :param 
current_amount:\n        :param target_amount:\n        :param factor:\n        :return  real_deal_amount;  Positive deal_amount indicates buying more stock.\n        \"\"\"\n        if current_amount == target_amount:\n            return 0\n        elif current_amount < target_amount:\n            deal_amount = target_amount - current_amount\n            deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)\n            return deal_amount\n        else:\n            if target_amount == 0:\n                return -current_amount\n            else:\n                deal_amount = current_amount - target_amount\n                deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)\n                return -deal_amount\n\n    def generate_order_for_target_amount_position(\n        self,\n        target_position: dict,\n        current_position: dict,\n        start_time: pd.Timestamp,\n        end_time: pd.Timestamp,\n    ) -> List[Order]:\n        \"\"\"\n        Note: some future information is used in this function\n        Parameter:\n        target_position : dict { stock_id : amount }\n        current_position : dict { stock_id : amount}\n        trade_unit : trade_unit\n        down sample : for amount 321 and trade_unit 100, deal_amount is 300\n        deal order on trade_date\n        \"\"\"\n        # split buy and sell for further use\n        buy_order_list = []\n        sell_order_list = []\n        # three parts: kept stock_id, dropped stock_id, new stock_id\n        # handle kept stock_id\n\n        # because the order of the set is not fixed, the trading order of the stock is different, so that the backtest\n        # results of the same parameter are different;\n        # so here we sort stock_id, and then randomly shuffle the order of stock_id\n        # because the same random seed is used, the final stock_id order is fixed\n        sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys())))\n       
 random.seed(0)\n        random.shuffle(sorted_ids)\n        for stock_id in sorted_ids:\n            # Do not generate order for the non-tradable stocks\n            if not self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):\n                continue\n\n            target_amount = target_position.get(stock_id, 0)\n            current_amount = current_position.get(stock_id, 0)\n            factor = self.get_factor(stock_id, start_time=start_time, end_time=end_time)\n\n            deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor)\n            if deal_amount == 0:\n                continue\n            if deal_amount > 0:\n                # buy stock\n                buy_order_list.append(\n                    Order(\n                        stock_id=stock_id,\n                        amount=deal_amount,\n                        direction=Order.BUY,\n                        start_time=start_time,\n                        end_time=end_time,\n                        factor=factor,\n                    ),\n                )\n            else:\n                # sell stock\n                sell_order_list.append(\n                    Order(\n                        stock_id=stock_id,\n                        amount=abs(deal_amount),\n                        direction=Order.SELL,\n                        start_time=start_time,\n                        end_time=end_time,\n                        factor=factor,\n                    ),\n                )\n        # return order_list : buy + sell\n        return sell_order_list + buy_order_list\n\n    def calculate_amount_position_value(\n        self,\n        amount_dict: dict,\n        start_time: pd.Timestamp,\n        end_time: pd.Timestamp,\n        only_tradable: bool = False,\n        direction: OrderDir = OrderDir.SELL,\n    ) -> float:\n        \"\"\"Parameter\n        position : Position()\n        amount_dict : {stock_id : amount}\n        direction : 
the direction of the deal price for estimating the amount\n                    # NOTE:\n                    This function is used for calculating current position value.\n                    So the default direction is sell.\n        \"\"\"\n        value = 0\n        for stock_id in amount_dict:\n            if not only_tradable or (\n                not self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time)\n                and not self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time)\n            ):\n                value += (\n                    self.get_deal_price(\n                        stock_id=stock_id,\n                        start_time=start_time,\n                        end_time=end_time,\n                        direction=direction,\n                    )\n                    * amount_dict[stock_id]\n                )\n        return value\n\n    def _get_factor_or_raise_error(\n        self,\n        factor: float | None = None,\n        stock_id: str | None = None,\n        start_time: pd.Timestamp = None,\n        end_time: pd.Timestamp = None,\n    ) -> float:\n        \"\"\"Please refer to the docs of get_amount_of_trade_unit\"\"\"\n        if factor is None:\n            if stock_id is not None and start_time is not None and end_time is not None:\n                factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)\n            else:\n                raise ValueError(f\"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None\")\n        assert factor is not None\n        return factor\n\n    def get_amount_of_trade_unit(\n        self,\n        factor: float | None = None,\n        stock_id: str | None = None,\n        start_time: pd.Timestamp = None,\n        end_time: pd.Timestamp = None,\n    ) -> Optional[float]:\n        \"\"\"\n        get the trade unit of amount based on **factor**\n        the factor can be given directly or 
calculated from the given time range and stock id.\n        `factor` has higher priority than `stock_id`, `start_time` and `end_time`\n        Parameters\n        ----------\n        factor : float\n            the adjusted factor\n        stock_id : str\n            the id of the stock\n        start_time :\n            the start time of the trading range\n        end_time :\n            the end time of the trading range\n        \"\"\"\n        if not self.trade_w_adj_price and self.trade_unit is not None:\n            factor = self._get_factor_or_raise_error(\n                factor=factor,\n                stock_id=stock_id,\n                start_time=start_time,\n                end_time=end_time,\n            )\n            return self.trade_unit / factor\n        else:\n            return None\n\n    def round_amount_by_trade_unit(\n        self,\n        deal_amount: float,\n        factor: float | None = None,\n        stock_id: str | None = None,\n        start_time: pd.Timestamp = None,\n        end_time: pd.Timestamp = None,\n    ) -> float:\n        \"\"\"Parameters\n        Please refer to the docs of get_amount_of_trade_unit\n        deal_amount : float, adjusted amount\n        factor : float, adjusted factor\n        return : float, the adjusted amount rounded down to a multiple of the trade unit\n        \"\"\"\n        if not self.trade_w_adj_price and self.trade_unit is not None:\n            # the minimal amount is 1. 
Add 0.1 to guard against floating-point precision errors.\n            factor = self._get_factor_or_raise_error(\n                factor=factor,\n                stock_id=stock_id,\n                start_time=start_time,\n                end_time=end_time,\n            )\n            return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor\n        return deal_amount\n\n    def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Optional[float]:\n        \"\"\"parse the volume limits and clip the order to the amount that can actually be executed.\n        NOTE:\n            this function will change the order.deal_amount **inplace**\n            - This will make the order info more accurate\n        Parameters\n        ----------\n        order : Order\n            the order to be executed.\n        dealt_order_amount : dict\n            the dealt order amount dict with the format of {stock_id: float}\n        \"\"\"\n        vol_limit = self.buy_vol_limit if order.direction == Order.BUY else self.sell_vol_limit\n\n        if vol_limit is None:\n            return order.deal_amount\n\n        vol_limit_num: List[float] = []\n        for limit in vol_limit:\n            assert isinstance(limit, tuple)\n            if limit[0] == \"current\":\n                limit_value = self.quote.get_data(\n                    order.stock_id,\n                    order.start_time,\n                    order.end_time,\n                    field=limit[1],\n                    method=\"sum\",\n                )\n                vol_limit_num.append(cast(float, limit_value))\n            elif limit[0] == \"cum\":\n                limit_value = self.quote.get_data(\n                    order.stock_id,\n                    order.start_time,\n                    order.end_time,\n                    field=limit[1],\n                    method=\"ts_data_last\",\n                )\n                vol_limit_num.append(limit_value - 
dealt_order_amount[order.stock_id])\n            else:\n                raise ValueError(f\"{limit[0]} is not supported\")\n        vol_limit_min = min(vol_limit_num)\n        orig_deal_amount = order.deal_amount\n        order.deal_amount = max(min(vol_limit_min, orig_deal_amount), 0)\n        if vol_limit_min < orig_deal_amount:\n            self.logger.debug(f\"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}\")\n\n        return None\n\n    def _get_buy_amount_by_cash_limit(self, trade_price: float, cash: float, cost_ratio: float) -> float:\n        \"\"\"return the real order amount after cash limit for buying.\n        Parameters\n        ----------\n        trade_price : float\n        cash : float\n        cost_ratio : float\n\n        Returns\n        -------\n        float\n            the real order amount after cash limit for buying.\n        \"\"\"\n        max_trade_amount = 0.0\n        if cash >= self.min_cost:\n            # critical_price is the amount of cash at which the proportional fee (cost_ratio * trade value) equals min_cost.\n            critical_price = self.min_cost / cost_ratio + self.min_cost\n            if cash >= critical_price:\n                # the service fee is equal to cost_ratio * trade value\n                max_trade_amount = cash / (1 + cost_ratio) / trade_price\n            else:\n                # the service fee is equal to min_cost\n                max_trade_amount = (cash - self.min_cost) / trade_price\n        return max_trade_amount\n\n    def _calc_trade_info_by_order(\n        self,\n        order: Order,\n        position: Optional[BasePosition],\n        dealt_order_amount: dict,\n    ) -> Tuple[float, float, float]:\n        \"\"\"\n        Calculation of trade info\n        **NOTE**: Order will be changed in this function\n        :param order:\n        :param position: Position\n        :param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}\n        
:return: trade_price, trade_val, trade_cost\n        \"\"\"\n        trade_price = cast(\n            float,\n            self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction),\n        )\n        total_trade_val = cast(float, self.get_volume(order.stock_id, order.start_time, order.end_time)) * trade_price\n        order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)\n        order.deal_amount = order.amount  # set to full amount and clip it step by step\n        # Clipping amount first\n        # - It simulates the exchange directly rejecting (part of) an order because the order is too large\n        # Another choice is placing it after rounding the order\n        # - It simulates that the large order is submitted, but only part of it is dealt, regardless of rounding by trading unit.\n        self._clip_amount_by_volume(order, dealt_order_amount)\n\n        # TODO: the adjusted cost ratio can be overestimated as deal_amount will be clipped in the next steps\n        trade_val = order.deal_amount * trade_price\n        if not total_trade_val or np.isnan(total_trade_val):\n            # TODO: assert trade_val == 0, f\"trade_val != 0, total_trade_val: {total_trade_val}; order info: {order}\"\n            adj_cost_ratio = self.impact_cost\n        else:\n            adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2\n\n        if order.direction == Order.SELL:\n            cost_ratio = self.close_cost + adj_cost_ratio\n            # sell\n            # if we don't know the current position, we choose to sell all\n            # Otherwise, we clip the amount based on the current position\n            if position is not None:\n                # TODO: make the trading shortable\n                current_amount = (\n                    position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0\n                )\n                if not np.isclose(order.deal_amount, 
current_amount):\n                    # when not selling the entire position, rounding is necessary\n                    order.deal_amount = self.round_amount_by_trade_unit(\n                        min(current_amount, order.deal_amount),\n                        order.factor,\n                    )\n\n                # avoid a negative cash balance after paying the trading cost\n                if position.get_cash() + order.deal_amount * trade_price < max(\n                    order.deal_amount * trade_price * cost_ratio,\n                    self.min_cost,\n                ):\n                    order.deal_amount = 0\n                    self.logger.debug(f\"Order clipped due to cash limitation: {order}\")\n\n        elif order.direction == Order.BUY:\n            cost_ratio = self.open_cost + adj_cost_ratio\n            # buy\n            if position is not None:\n                cash = position.get_cash()\n                trade_val = order.deal_amount * trade_price\n                if cash < max(trade_val * cost_ratio, self.min_cost):\n                    # cash cannot cover cost\n                    order.deal_amount = 0\n                    self.logger.debug(f\"Order clipped due to cost higher than cash: {order}\")\n                elif cash < trade_val + max(trade_val * cost_ratio, self.min_cost):\n                    # not enough cash for the full order\n                    max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio)\n                    order.deal_amount = self.round_amount_by_trade_unit(\n                        min(max_buy_amount, order.deal_amount),\n                        order.factor,\n                    )\n                    self.logger.debug(f\"Order clipped due to cash limitation: {order}\")\n                else:\n                    # enough cash for the full order\n                    order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)\n            else:\n                # Unknown amount of money. 
Just round the amount\n                order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)\n\n        else:\n            raise NotImplementedError(\"order direction {} error\".format(order.direction))\n\n        trade_val = order.deal_amount * trade_price\n        trade_cost = max(trade_val * cost_ratio, self.min_cost)\n        if trade_val <= 1e-5:\n            # if dealing is not successful, the trade_cost should be zero.\n            trade_cost = 0\n        return trade_price, trade_val, trade_cost\n\n    def get_order_helper(self) -> OrderHelper:\n        if not hasattr(self, \"_order_helper\"):\n            # cache to avoid recreating the same instance\n            self._order_helper = OrderHelper(self)\n        return self._order_helper\n"
  },
  {
    "path": "qlib/backtest/executor.py",
    "content": "from __future__ import annotations\n\nimport copy\nfrom abc import abstractmethod\nfrom collections import defaultdict\nfrom types import GeneratorType\nfrom typing import Any, Dict, Generator, List, Tuple, Union, cast\n\nimport pandas as pd\n\nfrom qlib.backtest.account import Account\nfrom qlib.backtest.position import BasePosition\nfrom qlib.log import get_module_logger\n\nfrom ..strategy.base import BaseStrategy\nfrom ..utils import init_instance_by_config\nfrom .decision import BaseTradeDecision, Order\nfrom .exchange import Exchange\nfrom .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, get_start_end_idx\n\n\nclass BaseExecutor:\n    \"\"\"Base executor for trading\"\"\"\n\n    def __init__(\n        self,\n        time_per_step: str,\n        start_time: Union[str, pd.Timestamp] = None,\n        end_time: Union[str, pd.Timestamp] = None,\n        indicator_config: dict = {},\n        generate_portfolio_metrics: bool = False,\n        verbose: bool = False,\n        track_data: bool = False,\n        trade_exchange: Exchange | None = None,\n        common_infra: CommonInfrastructure | None = None,\n        settle_type: str = BasePosition.ST_NO,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"\n        Parameters\n        ----------\n        time_per_step : str\n            trade time per trading step, used for generate the trade calendar\n        show_indicator: bool, optional\n            whether to show indicators, :\n            - 'pa', the price advantage\n            - 'pos', the positive rate\n            - 'ffr', the fulfill rate\n        indicator_config: dict, optional\n            config for calculating trade indicator, including the following fields:\n            - 'show_indicator': whether to show indicators, optional, default by False. 
The indicators include\n                - 'pa', the price advantage\n                - 'pos', the positive rate\n                - 'ffr', the fulfill rate\n            - 'pa_config': config for calculating price advantage(pa), optional\n                - 'base_price': the base price against which the trading price advantage is measured, Optional, default by 'twap'\n                    - If 'base_price' is 'twap', the base price is the time weighted average price\n                    - If 'base_price' is 'vwap', the base price is the volume weighted average price\n                - 'weight_method': weighted method when calculating total trading pa by different orders' pa in each\n                    step, optional, default by 'mean'\n                    - If 'weight_method' is 'mean', calculating mean value of different orders' pa\n                    - If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different\n                        orders' pa\n                    - If 'weight_method' is 'value_weighted', calculating value weighted average value of different\n                        orders' pa\n            - 'ffr_config': config for calculating fulfill rate(ffr), optional\n                - 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each\n                    step, optional, default by 'mean'\n                    - If 'weight_method' is 'mean', calculating mean value of different orders' ffr\n                    - If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different\n                        orders' ffr\n                    - If 'weight_method' is 'value_weighted', calculating value weighted average value of different\n                        orders' ffr\n            Example:\n                {\n                    'show_indicator': True,\n                    'pa_config': {\n                        \"agg\": \"twap\",  # \"vwap\"\n              
          \"price\": \"$close\", # defaults to the deal price of the exchange\n                    },\n                    'ffr_config':{\n                        'weight_method': 'value_weighted',\n                    }\n                }\n        generate_portfolio_metrics : bool, optional\n            whether to generate portfolio_metrics, by default False\n        verbose : bool, optional\n            whether to print trading info, by default False\n        track_data : bool, optional\n            whether to generate trade_decision; used when training an RL agent\n            - If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will\n                be generated by `collect_data`\n            - Else, `trade_decision` will not be generated\n\n        trade_exchange : Exchange\n            exchange that provides market info, used to generate portfolio_metrics\n            - If generate_portfolio_metrics is None, trade_exchange will be ignored\n            - Else, if `trade_exchange` is None, self.trade_exchange will be set with common_infra\n\n        common_infra : CommonInfrastructure, optional\n            common infrastructure for backtesting, may include:\n            - trade_account : Account, optional\n                trade account for trading\n            - trade_exchange : Exchange, optional\n                exchange that provides market info\n\n        settle_type : str\n            Please refer to the docs of BasePosition.settle_start\n        \"\"\"\n        self.time_per_step = time_per_step\n        self.indicator_config = indicator_config\n        self.generate_portfolio_metrics = generate_portfolio_metrics\n        self.verbose = verbose\n        self.track_data = track_data\n        self._trade_exchange = trade_exchange\n        self.level_infra = LevelInfrastructure()\n        self.level_infra.reset_infra(common_infra=common_infra, executor=self)\n        self._settle_type = settle_type\n  
      self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra)\n        if common_infra is None:\n            get_module_logger(\"BaseExecutor\").warning(f\"`common_infra` is not set for {self}\")\n\n        # record deal order amount in one day\n        self.dealt_order_amount: Dict[str, float] = defaultdict(float)\n        self.deal_day = None\n\n    def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:\n        \"\"\"\n        reset infrastructure for trading\n            - reset trade_account\n        \"\"\"\n        if not hasattr(self, \"common_infra\"):\n            self.common_infra = common_infra\n        else:\n            self.common_infra.update(common_infra)\n\n        self.level_infra.reset_infra(common_infra=self.common_infra)\n\n        if common_infra.has(\"trade_account\"):\n            # NOTE: there is a trick in the code.\n            # shallow copy is used instead of deepcopy.\n            # 1. So positions are shared\n            # 2. 
Others are not shared, so each level has its own metrics (portfolio and trading metrics)\n            self.trade_account: Account = (\n                copy.copy(common_infra.get(\"trade_account\"))\n                if copy_trade_account\n                else common_infra.get(\"trade_account\")\n            )\n            self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)\n\n    @property\n    def trade_exchange(self) -> Exchange:\n        \"\"\"get trade exchange in a prioritized order\"\"\"\n        return getattr(self, \"_trade_exchange\", None) or self.common_infra.get(\"trade_exchange\")\n\n    @property\n    def trade_calendar(self) -> TradeCalendarManager:\n        \"\"\"\n        Though the trade calendar can be accessed from multiple sources, managing it in a centralized way makes the\n        code simpler\n        \"\"\"\n        return self.level_infra.get(\"trade_calendar\")\n\n    def reset(self, common_infra: CommonInfrastructure | None = None, **kwargs: Any) -> None:\n        \"\"\"\n        - reset `start_time` and `end_time`, used in trade calendar\n        - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, etc.\n        \"\"\"\n\n        if \"start_time\" in kwargs or \"end_time\" in kwargs:\n            start_time = kwargs.get(\"start_time\")\n            end_time = kwargs.get(\"end_time\")\n            self.level_infra.reset_cal(freq=self.time_per_step, start_time=start_time, end_time=end_time)\n        if common_infra is not None:\n            self.reset_common_infra(common_infra)\n\n    def get_level_infra(self) -> LevelInfrastructure:\n        return self.level_infra\n\n    def finished(self) -> bool:\n        return self.trade_calendar.finished()\n\n    def execute(self, trade_decision: BaseTradeDecision, level: int = 0) -> List[object]:\n        \"\"\"execute the trade decision and return the executed result\n\n        NOTE: this function is never used directly in the 
framework. Should we delete it?\n\n        Parameters\n        ----------\n        trade_decision : BaseTradeDecision\n\n        level : int\n            the level of current executor\n\n        Returns\n        ----------\n        execute_result : List[object]\n            the executed result for trade decision\n        \"\"\"\n        return_value: dict = {}\n        for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):\n            pass\n        return cast(list, return_value.get(\"execute_result\"))\n\n    @abstractmethod\n    def _collect_data(\n        self,\n        trade_decision: BaseTradeDecision,\n        level: int = 0,\n    ) -> Union[Generator[Any, Any, Tuple[List[object], dict]], Tuple[List[object], dict]]:\n        \"\"\"\n        Please refer to the doc of collect_data\n        The only difference between `_collect_data` and `collect_data` is that some common steps are moved into\n        collect_data\n\n        Parameters\n        ----------\n        Please refer to the doc of collect_data\n\n\n        Returns\n        -------\n        Tuple[List[object], dict]:\n            (<the executed result for trade decision>, <the extra kwargs for `self.trade_account.update_bar_end`>)\n        \"\"\"\n\n    def collect_data(\n        self,\n        trade_decision: BaseTradeDecision,\n        return_value: dict | None = None,\n        level: int = 0,\n    ) -> Generator[Any, Any, List[object]]:\n        \"\"\"Generator for collecting the trade decision data for RL training\n\n        This function moves the executor one step forward\n\n        Parameters\n        ----------\n        trade_decision : BaseTradeDecision\n\n        level : int\n            the level of current executor. 0 indicates the top level\n\n        return_value : dict\n            a dict that receives the return value (it is updated in place)\n            e.g. 
{\"execute_result\": <the executed result>}\n\n        Returns\n        ----------\n        execute_result : List[object]\n            the executed result for trade decision.\n            ** NOTE!!!! **:\n            1) This is necessary: the return value of the generator is used by NestedExecutor\n            2) Please note the executed results are not merged.\n\n        Yields\n        -------\n        object\n            trade decision\n        \"\"\"\n\n        if self.track_data:\n            yield trade_decision\n\n        atomic = not issubclass(self.__class__, NestedExecutor)  # issubclass(A, A) is True\n\n        if atomic and trade_decision.get_range_limit(default_value=None) is not None:\n            raise ValueError(\"atomic executor doesn't support specifying `range_limit`\")\n\n        if self._settle_type != BasePosition.ST_NO:\n            self.trade_account.current_position.settle_start(self._settle_type)\n\n        obj = self._collect_data(trade_decision=trade_decision, level=level)\n\n        if isinstance(obj, GeneratorType):\n            yield_res = yield from obj\n            assert isinstance(yield_res, tuple) and len(yield_res) == 2\n            res, kwargs = yield_res\n        else:\n            # Some concrete executors don't have inner decisions\n            res, kwargs = obj\n\n        trade_start_time, trade_end_time = self.trade_calendar.get_step_time()\n        # Account will not be changed in this function\n        self.trade_account.update_bar_end(\n            trade_start_time,\n            trade_end_time,\n            self.trade_exchange,\n            atomic=atomic,\n            outer_trade_decision=trade_decision,\n            indicator_config=self.indicator_config,\n            **kwargs,\n        )\n\n        self.trade_calendar.step()\n\n        if self._settle_type != BasePosition.ST_NO:\n            self.trade_account.current_position.settle_commit()\n\n        if return_value is not None:\n            
return_value.update({\"execute_result\": res})\n\n        return res\n\n    def get_all_executors(self) -> List[BaseExecutor]:\n        \"\"\"get all executors\"\"\"\n        return [self]\n\n\nclass NestedExecutor(BaseExecutor):\n    \"\"\"\n    Nested Executor with inner strategy and executor\n    - Each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision`\n        in a higher frequency env.\n    \"\"\"\n\n    def __init__(\n        self,\n        time_per_step: str,\n        inner_executor: Union[BaseExecutor, dict],\n        inner_strategy: Union[BaseStrategy, dict],\n        start_time: Union[str, pd.Timestamp] = None,\n        end_time: Union[str, pd.Timestamp] = None,\n        indicator_config: dict = {},\n        generate_portfolio_metrics: bool = False,\n        verbose: bool = False,\n        track_data: bool = False,\n        skip_empty_decision: bool = True,\n        align_range_limit: bool = True,\n        common_infra: CommonInfrastructure | None = None,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"\n        Parameters\n        ----------\n        inner_executor : BaseExecutor\n            trading env in each trading bar.\n        inner_strategy : BaseStrategy\n            trading strategy in each trading bar\n        skip_empty_decision: bool\n            Whether the executor skips calling the inner loop when the decision is empty.\n            It should be False in the following cases\n            - The decisions may be updated by steps\n            - The inner executor may not follow the decisions from the outer strategy\n        align_range_limit: bool\n            whether to force alignment with the trade_range decision\n            It only applies to nested executors, because range_limit is given by the outer strategy\n        \"\"\"\n        self.inner_executor: BaseExecutor = init_instance_by_config(\n            inner_executor,\n            common_infra=common_infra,\n            accept_types=BaseExecutor,\n        )\n   
     self.inner_strategy: BaseStrategy = init_instance_by_config(\n            inner_strategy,\n            common_infra=common_infra,\n            accept_types=BaseStrategy,\n        )\n\n        self._skip_empty_decision = skip_empty_decision\n        self._align_range_limit = align_range_limit\n\n        super(NestedExecutor, self).__init__(\n            time_per_step=time_per_step,\n            start_time=start_time,\n            end_time=end_time,\n            indicator_config=indicator_config,\n            generate_portfolio_metrics=generate_portfolio_metrics,\n            verbose=verbose,\n            track_data=track_data,\n            common_infra=common_infra,\n            **kwargs,\n        )\n\n    def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:\n        \"\"\"\n        reset infrastructure for trading\n            - reset inner_strategy and inner_executor common infra\n        \"\"\"\n        # NOTE: please refer to the docs of BaseExecutor.reset_common_infra for the meaning of `copy_trade_account`\n\n        # The first level follows the `copy_trade_account` from the upper level\n        super(NestedExecutor, self).reset_common_infra(common_infra, copy_trade_account=copy_trade_account)\n\n        # The lower level has to copy the trade_account\n        self.inner_executor.reset_common_infra(common_infra, copy_trade_account=True)\n        self.inner_strategy.reset_common_infra(common_infra)\n\n    def _init_sub_trading(self, trade_decision: BaseTradeDecision) -> None:\n        trade_start_time, trade_end_time = self.trade_calendar.get_step_time()\n        self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)\n        sub_level_infra = self.inner_executor.get_level_infra()\n        self.level_infra.set_sub_level_infra(sub_level_infra)\n        self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision)\n\n    def 
_update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTradeDecision:\n        # the outer strategy has a chance to update the decision at each iteration\n        updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar)\n        if updated_trade_decision is not None:  # TODO: is it always None for now?\n            trade_decision = updated_trade_decision\n            # NEW UPDATE\n            # create a hook for the inner strategy to update the outer decision\n            trade_decision = self.inner_strategy.alter_outer_trade_decision(trade_decision)\n        return trade_decision\n\n    def _collect_data(\n        self,\n        trade_decision: BaseTradeDecision,\n        level: int = 0,\n    ) -> Generator[Any, Any, Tuple[List[object], dict]]:\n        execute_result = []\n        inner_order_indicators = []\n        decision_list = []\n        # NOTE:\n        # - this is necessary to calculate the steps in the sub level\n        # - more detailed information will be set into trade decision\n        self._init_sub_trading(trade_decision)\n\n        _inner_execute_result = None\n        while not self.inner_executor.finished():\n            trade_decision = self._update_trade_decision(trade_decision)\n\n            if trade_decision.empty() and self._skip_empty_decision:\n                # give the outer strategy one chance to update the decision\n                # - For updating some information in the sub executor (the strategy has no knowledge of the inner\n                #   executor when generating the decision)\n                break\n\n            sub_cal: TradeCalendarManager = self.inner_executor.trade_calendar\n\n            # NOTE: make sure get_start_end_idx is after `self._update_trade_decision`\n            start_idx, end_idx = get_start_end_idx(sub_cal, trade_decision)\n            if not self._align_range_limit or start_idx <= sub_cal.get_trade_step() <= end_idx:\n                # when forced to align with the range limit, skip the steps 
outside the decision range limit\n\n                res = self.inner_strategy.generate_trade_decision(_inner_execute_result)\n\n                # NOTE: !!!!!\n                # the two lines below are for a special case in RL\n                # To solve the conflicts below\n                # - Normally, users create a strategy and embed it into Qlib's executor and simulator interaction\n                #   loop. For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=>\n                #   (inner Qlib Executor)])\n                # - However, an RL-based framework has its own script to run the loop\n                #   For an _RL learning example_, (RL Policy) <=> (RL Env[(inner Qlib Executor)])\n                # To make it possible to run _nested qlib example_ and _RL learning example_ together, the solution\n                # below is proposed\n                # - The entry script follows the _RL learning example_ to be compatible with all kinds of\n                #   RL frameworks\n                # - Each step of (RL Env) will make (inner Qlib Executor) one step forward\n                #     - (inner Qlib Strategy) is a proxy strategy; it hands program control to (RL Env)\n                #       by `yield from` and waits for the action from the policy\n                # So the two lines below implement yielding control\n                if isinstance(res, GeneratorType):\n                    res = yield from res\n\n                _inner_trade_decision: BaseTradeDecision = res\n\n                trade_decision.mod_inner_decision(_inner_trade_decision)  # propagate part of decision information\n\n                # NOTE sub_cal.get_step_time() must be called before collect_data in case of step shifting\n                decision_list.append((_inner_trade_decision, *sub_cal.get_step_time()))\n\n                # NOTE: Trade Calendar will step forward in the following line\n                
_inner_execute_result = yield from self.inner_executor.collect_data(\n                    trade_decision=_inner_trade_decision,\n                    level=level + 1,\n                )\n                assert isinstance(_inner_execute_result, list)\n                self.post_inner_exe_step(_inner_execute_result)\n                execute_result.extend(_inner_execute_result)\n\n                inner_order_indicators.append(\n                    self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True),\n                )\n            else:\n                # do nothing and just step forward\n                sub_cal.step()\n\n        # Let the inner strategy know that the outer level execution is done.\n        self.inner_strategy.post_upper_level_exe_step()\n\n        return execute_result, {\"inner_order_indicators\": inner_order_indicators, \"decision_list\": decision_list}\n\n    def post_inner_exe_step(self, inner_exe_res: List[object]) -> None:\n        \"\"\"\n        A hook for doing something after each step of the inner strategy\n\n        Parameters\n        ----------\n        inner_exe_res :\n            the execution result of the inner task\n        \"\"\"\n        self.inner_strategy.post_exe_step(inner_exe_res)\n\n    def get_all_executors(self) -> List[BaseExecutor]:\n        \"\"\"get all executors, including self and inner_executor.get_all_executors()\"\"\"\n        return [self, *self.inner_executor.get_all_executors()]\n\n\ndef _retrieve_orders_from_decision(trade_decision: BaseTradeDecision) -> List[Order]:\n    \"\"\"\n    IDE-friendly helper function.\n    \"\"\"\n    decisions = trade_decision.get_decision()\n    orders: List[Order] = []\n    for decision in decisions:\n        assert isinstance(decision, Order)\n        orders.append(decision)\n    return orders\n\n\nclass SimulatorExecutor(BaseExecutor):\n    \"\"\"Executor that simulates the real market\"\"\"\n\n    # TODO: TT_SERIAL & TT_PARAL will be replaced by feature fix_pos 
now.\n    # Please remove them in the future.\n\n    # available trade_types\n    TT_SERIAL = \"serial\"\n    # The orders will be executed serially in a sequence\n    # In each trading step, it is possible that users sell instruments first and use the money to buy new instruments\n    TT_PARAL = \"parallel\"\n    # The orders will be executed in parallel\n    # In each trading step, if users try to sell instruments first and buy new instruments with the proceeds, a failure\n    # will occur\n\n    def __init__(\n        self,\n        time_per_step: str,\n        start_time: Union[str, pd.Timestamp] = None,\n        end_time: Union[str, pd.Timestamp] = None,\n        indicator_config: dict = {},\n        generate_portfolio_metrics: bool = False,\n        verbose: bool = False,\n        track_data: bool = False,\n        common_infra: CommonInfrastructure | None = None,\n        trade_type: str = TT_SERIAL,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"\n        Parameters\n        ----------\n        trade_type: str\n            Please refer to the comments on `TT_SERIAL` & `TT_PARAL`\n        \"\"\"\n        super(SimulatorExecutor, self).__init__(\n            time_per_step=time_per_step,\n            start_time=start_time,\n            end_time=end_time,\n            indicator_config=indicator_config,\n            generate_portfolio_metrics=generate_portfolio_metrics,\n            verbose=verbose,\n            track_data=track_data,\n            common_infra=common_infra,\n            **kwargs,\n        )\n\n        self.trade_type = trade_type\n\n    def _get_order_iterator(self, trade_decision: BaseTradeDecision) -> List[Order]:\n        \"\"\"\n        Get the orders in the sequence in which they will be executed.\n\n        Parameters\n        ----------\n        trade_decision : BaseTradeDecision\n            the trade decision given by the strategy\n\n        Returns\n        -------\n        List[Order]:\n            a list of orders arranged according to `self.trade_type`\n        \"\"\"\n        orders = 
_retrieve_orders_from_decision(trade_decision)\n\n        if self.trade_type == self.TT_SERIAL:\n            # Orders will be traded serially, in their original sequence\n            order_it = orders\n        elif self.trade_type == self.TT_PARAL:\n            # NOTE: !!!!!!!\n            # Assumption: there will not be orders in different trading direction in a single step of a strategy !!!!\n            # Parallel trading failures are caused only by conflicts over cash\n            # Therefore, executing the buy orders first makes sure these conflicts happen.\n            # This is equivalent to parallel trading once the orders are sorted by direction\n            order_it = sorted(orders, key=lambda order: -order.direction)\n        else:\n            raise NotImplementedError(f\"trade_type `{self.trade_type}` is not supported\")\n        return order_it\n\n    def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:\n        trade_start_time, _ = self.trade_calendar.get_step_time()\n        execute_result: list = []\n\n        for order in self._get_order_iterator(trade_decision):\n            # Each time we move into a new date, clear `self.dealt_order_amount` since it only maintains intraday\n            # information.\n            now_deal_day = self.trade_calendar.get_step_time()[0].floor(freq=\"D\")\n            if self.deal_day is None or now_deal_day > self.deal_day:\n                self.dealt_order_amount = defaultdict(float)\n                self.deal_day = now_deal_day\n\n            # execute the order.\n            # NOTE: The trade_account will be changed in this function\n            trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(\n                order,\n                trade_account=self.trade_account,\n                dealt_order_amount=self.dealt_order_amount,\n            )\n            execute_result.append((order, trade_val, trade_cost, trade_price))\n\n            self.dealt_order_amount[order.stock_id] += 
order.deal_amount\n\n            if self.verbose:\n                print(\n                    \"[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, \"\n                    \"value {:.2f}, cash {:.2f}.\".format(\n                        trade_start_time,\n                        \"sell\" if order.direction == Order.SELL else \"buy\",\n                        order.stock_id,\n                        trade_price,\n                        order.amount,\n                        order.deal_amount,\n                        order.factor,\n                        trade_val,\n                        self.trade_account.get_cash(),\n                    ),\n                )\n        return execute_result, {\"trade_info\": execute_result}\n"
  },
  {
    "path": "qlib/backtest/high_performance_ds.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nimport inspect\nimport logging\nfrom collections import OrderedDict\nfrom functools import lru_cache\nfrom typing import Any, Callable, Dict, Iterable, List, Optional, Text, Union, cast\n\nimport numpy as np\nimport pandas as pd\n\nimport qlib.utils.index_data as idd\n\nfrom ..log import get_module_logger\nfrom ..utils.index_data import IndexData, SingleData\nfrom ..utils.resam import resam_ts_data, ts_data_last\nfrom ..utils.time import Freq, is_single_value\n\n\nclass BaseQuote:\n    def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:\n        self.logger = get_module_logger(\"online operator\", level=logging.INFO)\n\n    def get_all_stock(self) -> Iterable:\n        \"\"\"return all stock codes\n\n        Return\n        ------\n        Iterable\n            all stock codes\n        \"\"\"\n\n        raise NotImplementedError(f\"Please implement the `get_all_stock` method\")\n\n    def get_data(\n        self,\n        stock_id: str,\n        start_time: Union[pd.Timestamp, str],\n        end_time: Union[pd.Timestamp, str],\n        field: str,\n        method: Optional[str] = None,\n    ) -> Union[None, int, float, bool, IndexData]:\n        \"\"\"get the specified field of the stock data between start_time and end_time,\n           and apply `method` to the data.\n\n           Example:\n            .. 
code-block::\n                                        $close      $volume\n                instrument  datetime\n                SH600000    2010-01-04  86.778313   16162960.0\n                            2010-01-05  87.433578   28117442.0\n                            2010-01-06  85.713585   23632884.0\n                            2010-01-07  83.788803   20813402.0\n                            2010-01-08  84.730675   16044853.0\n\n                SH600655    2010-01-04  2699.567383  158193.328125\n                            2010-01-08  2612.359619   77501.406250\n                            2010-01-11  2712.982422  160852.390625\n                            2010-01-12  2788.688232  164587.937500\n                            2010-01-13  2790.604004  145460.453125\n\n                This function handles two cases:\n\n                1. method is not None. It returns int/float/bool/None.\n                    - It returns None if the method itself returns None\n\n                    print(get_data(stock_id=\"SH600000\", start_time=\"2010-01-04\", end_time=\"2010-01-06\", field=\"$close\", method=\"last\"))\n\n                    85.713585\n\n                2. method is None. 
It returns IndexData.\n                    print(get_data(stock_id=\"SH600000\", start_time=\"2010-01-04\", end_time=\"2010-01-06\", field=\"$close\", method=None))\n\n                    IndexData([86.778313, 87.433578, 85.713585], [2010-01-04, 2010-01-05, 2010-01-06])\n\n        Parameters\n        ----------\n        stock_id: str\n            the id of the stock\n        start_time : Union[pd.Timestamp, str]\n            closed start time for backtest\n        end_time : Union[pd.Timestamp, str]\n            closed end time for backtest\n        field : str\n            the column of data to fetch\n        method : Union[str, None]\n            the method applied to the data.\n            e.g. [None, \"last\", \"all\", \"sum\", \"mean\", \"ts_data_last\"]\n\n        Return\n        ----------\n        Union[None, int, float, bool, IndexData]\n            it returns None in the following cases\n            - There is no stock data in the data source that meets the query criteria.\n            - The `method` returns None\n        \"\"\"\n\n        raise NotImplementedError(f\"Please implement the `get_data` method\")\n\n\nclass PandasQuote(BaseQuote):\n    def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:\n        super().__init__(quote_df=quote_df, freq=freq)\n        quote_dict = {}\n        for stock_id, stock_val in quote_df.groupby(level=\"instrument\", group_keys=False):\n            quote_dict[stock_id] = stock_val.droplevel(level=\"instrument\")\n        self.data = quote_dict\n\n    def get_all_stock(self):\n        return self.data.keys()\n\n    def get_data(self, stock_id, start_time, end_time, field, method=None):\n        if method == \"ts_data_last\":\n            method = ts_data_last\n        stock_data = resam_ts_data(self.data[stock_id][field], start_time, end_time, method=method)\n        if stock_data is None:\n            return None\n        elif isinstance(stock_data, (bool, np.bool_, int, float, np.number)):\n            return stock_data\n        elif 
isinstance(stock_data, pd.Series):\n            return idd.SingleData(stock_data)\n        else:\n            raise ValueError(f\"stock data from resam_ts_data must be a number or a pd.Series\")\n\n\nclass NumpyQuote(BaseQuote):\n    def __init__(self, quote_df: pd.DataFrame, freq: str, region: str = \"cn\") -> None:\n        \"\"\"NumpyQuote\n\n        Parameters\n        ----------\n        quote_df : pd.DataFrame\n            the initial dataframe from qlib.\n        freq : str\n            the frequency of the quote data.\n        region : str\n            the market region, \"cn\" by default.\n\n        The parsed data is stored in `self.data` as Dict(stock_id, idd.MultiData).\n        \"\"\"\n        super().__init__(quote_df=quote_df, freq=freq)\n        quote_dict = {}\n        for stock_id, stock_val in quote_df.groupby(level=\"instrument\", group_keys=False):\n            quote_dict[stock_id] = idd.MultiData(stock_val.droplevel(level=\"instrument\"))\n            quote_dict[stock_id].sort_index()  # To support more flexible slicing, we must sort data first\n        self.data = quote_dict\n\n        n, unit = Freq.parse(freq)\n        if unit in Freq.SUPPORT_CAL_LIST:\n            self.freq = Freq.get_timedelta(1, unit)\n        else:\n            raise ValueError(f\"{freq} is not supported in NumpyQuote\")\n        self.region = region\n\n    def get_all_stock(self):\n        return self.data.keys()\n\n    @lru_cache(maxsize=512)\n    def get_data(self, stock_id, start_time, end_time, field, method=None):\n        # check stock id\n        if stock_id not in self.get_all_stock():\n            return None\n\n        # single data\n        # Without special handling for single-value queries, the lookup would be much slower.\n        if is_single_value(start_time, end_time, self.freq, self.region):\n            # this is a very special case.\n            # skip the aggregating function to speed up the query\n\n            # FIXME:\n            # it will go to the else branch in the following cases:\n            # 1) the day before a holiday when daily trading\n            # 2) the last minute of the day when 
intraday trading\n            try:\n                return self.data[stock_id].loc[start_time, field]\n            except KeyError:\n                return None\n        else:\n            data = self.data[stock_id].loc[start_time:end_time, field]\n            if data.empty:\n                return None\n            if method is not None:\n                data = self._agg_data(data, method)\n            return data\n\n    @staticmethod\n    def _agg_data(data: IndexData, method: str) -> Union[IndexData, np.ndarray, None]:\n        \"\"\"Agg data by specific method.\"\"\"\n        # FIXME: why not call the method of data directly?\n        if method == \"sum\":\n            return np.nansum(data)\n        elif method == \"mean\":\n            return np.nanmean(data)\n        elif method == \"last\":\n            # FIXME: I've never seen that this method was called.\n            # Please merge it with \"ts_data_last\"\n            return data[-1]\n        elif method == \"all\":\n            return data.all()\n        elif method == \"ts_data_last\":\n            valid_data = data.loc[~data.isna().data.astype(bool)]\n            if len(valid_data) == 0:\n                return None\n            else:\n                return valid_data.iloc[-1]\n        else:\n            raise ValueError(f\"{method} is not supported\")\n\n\nclass BaseSingleMetric:\n    \"\"\"\n    The data structure of the single metric.\n    The following methods are used for computing metrics in one indicator.\n    \"\"\"\n\n    def __init__(self, metric: Union[dict, pd.Series]):\n        \"\"\"Single data structure for each metric.\n\n        Parameters\n        ----------\n        metric : Union[dict, pd.Series]\n            keys/index is stock_id, value is the metric value.\n            for example:\n                SH600068    NaN\n                SH600079    1.0\n                SH600266    NaN\n                           ...\n                SZ300692    NaN\n                SZ300719    NaN,\n 
       \"\"\"\n        raise NotImplementedError(f\"Please implement the `__init__` method\")\n\n    def __add__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:\n        raise NotImplementedError(f\"Please implement the `__add__` method\")\n\n    def __radd__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:\n        return self + other\n\n    def __sub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:\n        raise NotImplementedError(f\"Please implement the `__sub__` method\")\n\n    def __rsub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:\n        raise NotImplementedError(f\"Please implement the `__rsub__` method\")\n\n    def __mul__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:\n        raise NotImplementedError(f\"Please implement the `__mul__` method\")\n\n    def __truediv__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:\n        raise NotImplementedError(f\"Please implement the `__truediv__` method\")\n\n    def __eq__(self, other: object) -> BaseSingleMetric:\n        raise NotImplementedError(f\"Please implement the `__eq__` method\")\n\n    def __gt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:\n        raise NotImplementedError(f\"Please implement the `__gt__` method\")\n\n    def __lt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:\n        raise NotImplementedError(f\"Please implement the `__lt__` method\")\n\n    def __len__(self) -> int:\n        raise NotImplementedError(f\"Please implement the `__len__` method\")\n\n    def sum(self) -> float:\n        raise NotImplementedError(f\"Please implement the `sum` method\")\n\n    def mean(self) -> float:\n        raise NotImplementedError(f\"Please implement the `mean` method\")\n\n    def count(self) -> int:\n        \"\"\"Return the count of the single metric, NaN is not included.\"\"\"\n\n        raise 
NotImplementedError(f\"Please implement the `count` method\")\n\n    def abs(self) -> BaseSingleMetric:\n        raise NotImplementedError(f\"Please implement the `abs` method\")\n\n    @property\n    def empty(self) -> bool:\n        \"\"\"If metric is empty, return True.\"\"\"\n\n        raise NotImplementedError(f\"Please implement the `empty` method\")\n\n    def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:\n        \"\"\"Replace np.nan with fill_value in two metrics and add them.\"\"\"\n\n        raise NotImplementedError(f\"Please implement the `add` method\")\n\n    def replace(self, replace_dict: dict) -> BaseSingleMetric:\n        \"\"\"Replace the value of metric according to replace_dict.\"\"\"\n\n        raise NotImplementedError(f\"Please implement the `replace` method\")\n\n    def apply(self, func: Callable) -> BaseSingleMetric:\n        \"\"\"Replace the value of the metric with func(metric).\n        Currently, the func is only qlib/backtest/order/Order.parse_dir.\n        \"\"\"\n\n        raise NotImplementedError(f\"Please implement the 'apply' method\")\n\n\nclass BaseOrderIndicator:\n    \"\"\"\n    The data structure of order indicator.\n    !!!NOTE: There are two ways to organize the data structure. Choose whichever suits the implementation better.\n        1. One way is using BaseSingleMetric to represent each metric. For example, the data\n        structure of PandasOrderIndicator is Dict[str, PandasSingleMetric]. It uses\n        PandasSingleMetric based on pd.Series to represent each metric.\n        2. The other way doesn't use BaseSingleMetric to represent each metric. In this case, the data\n        structure of the order indicator is a whole matrix, so you do not need\n        to inherit from BaseSingleMetric.\n    \"\"\"\n\n    def __init__(self):\n        self.data = {}  # will be created in the subclass\n        self.logger = get_module_logger(\"online operator\")\n\n    def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:\n        \"\"\"assign one metric.\n\n        Parameters\n        ----------\n        col : str\n            the name of the metric.\n        metric : Union[dict, pd.Series]\n            one metric with stock_id index, such as deal_amount, ffr, etc.\n            for example:\n                SH600068    NaN\n                SH600079    1.0\n                SH600266    NaN\n                           ...\n                SZ300692    NaN\n                SZ300719    NaN,\n        \"\"\"\n\n        raise NotImplementedError(f\"Please implement the 'assign' method\")\n\n    def transfer(self, func: Callable, new_col: str = None) -> Optional[BaseSingleMetric]:\n        \"\"\"compute a new metric from existing metrics.\n\n        Parameters\n        ----------\n        func : Callable\n            the function that computes the new metric.\n            each parameter of func is filled with the existing metric of the same name.\n            e.g.\n                def func(pa):\n                    return (pa > 0).sum() / pa.count()\n        new_col : str, optional\n            If new_col is not None, the new metric is stored in the data under that name and None is returned; by default None.\n\n        Return\n        ----------\n        Optional[BaseSingleMetric]\n            the new metric, or None if it was stored under new_col.\n        \"\"\"\n        func_sig = inspect.signature(func).parameters.keys()\n        func_kwargs = {sig: self.data[sig] for sig in func_sig}\n        tmp_metric = func(**func_kwargs)\n        if new_col is not None:\n            self.data[new_col] = tmp_metric\n            return None\n        else:\n            return tmp_metric\n\n    def get_metric_series(self, metric: str) -> pd.Series:\n        \"\"\"return the single metric 
in pd.Series format.\n\n        Parameters\n        ----------\n        metric : str\n            the metric name.\n\n        Return\n        ----------\n        pd.Series\n            the single metric.\n            If the metric is not in the data, an empty pd.Series() is returned.\n        \"\"\"\n\n        raise NotImplementedError(f\"Please implement the 'get_metric_series' method\")\n\n    def get_index_data(self, metric: str) -> SingleData:\n        \"\"\"get one metric in SingleData format\n\n        Parameters\n        ----------\n        metric : str\n            the metric name.\n\n        Return\n        ------\n        SingleData\n            one metric in SingleData format\n        \"\"\"\n\n        raise NotImplementedError(f\"Please implement the 'get_index_data' method\")\n\n    @staticmethod\n    def sum_all_indicators(\n        order_indicator: BaseOrderIndicator,\n        indicators: List[BaseOrderIndicator],\n        metrics: Union[str, List[str]],\n        fill_value: float = 0,\n    ) -> None:\n        \"\"\"sum the same metrics across indicators\n        and assign the results to order_indicator (BaseOrderIndicator).\n        NOTE: indicators could be an empty list when all orders in the lower level fail.\n\n        Parameters\n        ----------\n        order_indicator : BaseOrderIndicator\n            the order indicator to assign.\n        indicators : List[BaseOrderIndicator]\n            the list of all inner indicators.\n        metrics : Union[str, List[str]]\n            the metric(s) to be summed.\n        fill_value : float, optional\n            fill np.nan with value. 
By default 0.\n        \"\"\"\n\n        raise NotImplementedError(f\"Please implement the 'sum_all_indicators' method\")\n\n    def to_series(self) -> Dict[Text, pd.Series]:\n        \"\"\"return the metrics as a dict of pandas Series\n\n        for example: { \"ffr\":\n                SH600068    NaN\n                SH600079    1.0\n                SH600266    NaN\n                           ...\n                SZ300692    NaN\n                SZ300719    NaN,\n                ...\n         }\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `to_series` method\")\n\n\nclass SingleMetric(BaseSingleMetric):\n    def __init__(self, metric):\n        self.metric = metric\n\n    def __add__(self, other):\n        if isinstance(other, (int, float)):\n            return self.__class__(self.metric + other)\n        elif isinstance(other, self.__class__):\n            return self.__class__(self.metric + other.metric)\n        else:\n            return NotImplemented\n\n    def __sub__(self, other):\n        if isinstance(other, (int, float)):\n            return self.__class__(self.metric - other)\n        elif isinstance(other, self.__class__):\n            return self.__class__(self.metric - other.metric)\n        else:\n            return NotImplemented\n\n    def __rsub__(self, other):\n        if isinstance(other, (int, float)):\n            return self.__class__(other - self.metric)\n        elif isinstance(other, self.__class__):\n            return self.__class__(other.metric - self.metric)\n        else:\n            return NotImplemented\n\n    def __mul__(self, other):\n        if isinstance(other, (int, float)):\n            return self.__class__(self.metric * other)\n        elif isinstance(other, self.__class__):\n            return self.__class__(self.metric * other.metric)\n        else:\n            return NotImplemented\n\n    def __truediv__(self, other):\n        if isinstance(other, (int, float)):\n            return 
self.__class__(self.metric / other)\n        elif isinstance(other, self.__class__):\n            return self.__class__(self.metric / other.metric)\n        else:\n            return NotImplemented\n\n    def __eq__(self, other):\n        if isinstance(other, (int, float)):\n            return self.__class__(self.metric == other)\n        elif isinstance(other, self.__class__):\n            return self.__class__(self.metric == other.metric)\n        else:\n            return NotImplemented\n\n    def __gt__(self, other):\n        if isinstance(other, (int, float)):\n            return self.__class__(self.metric > other)\n        elif isinstance(other, self.__class__):\n            return self.__class__(self.metric > other.metric)\n        else:\n            return NotImplemented\n\n    def __lt__(self, other):\n        if isinstance(other, (int, float)):\n            return self.__class__(self.metric < other)\n        elif isinstance(other, self.__class__):\n            return self.__class__(self.metric < other.metric)\n        else:\n            return NotImplemented\n\n    def __len__(self):\n        return len(self.metric)\n\n\nclass PandasSingleMetric(SingleMetric):\n    \"\"\"Each SingleMetric is based on pd.Series.\"\"\"\n\n    def __init__(self, metric: Union[dict, pd.Series] = {}):\n        if isinstance(metric, dict):\n            self.metric = pd.Series(metric)\n        elif isinstance(metric, pd.Series):\n            self.metric = metric\n        else:\n            raise ValueError(f\"metric must be dict or pd.Series\")\n\n    def sum(self):\n        return self.metric.sum()\n\n    def mean(self):\n        return self.metric.mean()\n\n    def count(self):\n        return self.metric.count()\n\n    def abs(self):\n        return self.__class__(self.metric.abs())\n\n    @property\n    def empty(self):\n        return self.metric.empty\n\n    @property\n    def index(self):\n        return list(self.metric.index)\n\n    def add(self, other: 
BaseSingleMetric, fill_value: float = None) -> PandasSingleMetric:\n        other = cast(PandasSingleMetric, other)\n        return self.__class__(self.metric.add(other.metric, fill_value=fill_value))\n\n    def replace(self, replace_dict: dict) -> PandasSingleMetric:\n        return self.__class__(self.metric.replace(replace_dict))\n\n    def apply(self, func: Callable) -> PandasSingleMetric:\n        return self.__class__(self.metric.apply(func))\n\n    def reindex(self, index: Any, fill_value: float) -> PandasSingleMetric:\n        return self.__class__(self.metric.reindex(index, fill_value=fill_value))\n\n    def __repr__(self):\n        return repr(self.metric)\n\n\nclass PandasOrderIndicator(BaseOrderIndicator):\n    \"\"\"\n    The data structure is OrderedDict(str: PandasSingleMetric).\n    Each PandasSingleMetric based on pd.Series is one metric.\n    Str is the name of metric.\n    \"\"\"\n\n    def __init__(self) -> None:\n        super(PandasOrderIndicator, self).__init__()\n        self.data: Dict[str, PandasSingleMetric] = OrderedDict()\n\n    def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:\n        self.data[col] = PandasSingleMetric(metric)\n\n    def get_index_data(self, metric: str) -> SingleData:\n        if metric in self.data:\n            return idd.SingleData(self.data[metric].metric)\n        else:\n            return idd.SingleData()\n\n    def get_metric_series(self, metric: str) -> Union[pd.Series]:\n        if metric in self.data:\n            return self.data[metric].metric\n        else:\n            return pd.Series()\n\n    def to_series(self):\n        return {k: v.metric for k, v in self.data.items()}\n\n    @staticmethod\n    def sum_all_indicators(\n        order_indicator: BaseOrderIndicator,\n        indicators: List[BaseOrderIndicator],\n        metrics: Union[str, List[str]],\n        fill_value: float = 0,\n    ) -> None:\n        if isinstance(metrics, str):\n            metrics = [metrics]\n        for 
metric in metrics:\n            tmp_metric = PandasSingleMetric({})\n            for indicator in indicators:\n                tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)\n            order_indicator.assign(metric, tmp_metric.metric)\n\n    def __repr__(self):\n        return repr(self.data)\n\n\nclass NumpyOrderIndicator(BaseOrderIndicator):\n    \"\"\"\n    The data structure is OrderedDict(str: SingleData).\n    Each idd.SingleData is one metric.\n    Str is the name of metric.\n    \"\"\"\n\n    def __init__(self) -> None:\n        super(NumpyOrderIndicator, self).__init__()\n        self.data: Dict[str, SingleData] = OrderedDict()\n\n    def assign(self, col: str, metric: dict) -> None:\n        self.data[col] = idd.SingleData(metric)\n\n    def get_index_data(self, metric: str) -> SingleData:\n        if metric in self.data:\n            return self.data[metric]\n        else:\n            return idd.SingleData()\n\n    def get_metric_series(self, metric: str) -> Union[pd.Series]:\n        return self.data[metric].to_series()\n\n    def to_series(self) -> Dict[str, pd.Series]:\n        tmp_metric_dict = {}\n        for metric in self.data:\n            tmp_metric_dict[metric] = self.get_metric_series(metric)\n        return tmp_metric_dict\n\n    @staticmethod\n    def sum_all_indicators(\n        order_indicator: BaseOrderIndicator,\n        indicators: List[BaseOrderIndicator],\n        metrics: Union[str, List[str]],\n        fill_value: float = 0,\n    ) -> None:\n        # normalize `metrics` to a list first; otherwise `metrics[0]` below would be the first character of a str\n        if isinstance(metrics, str):\n            metrics = [metrics]\n\n        # get all index(stock_id)\n        stock_set: set = set()\n        for indicator in indicators:\n            # set(np.ndarray.tolist()) is faster than set(np.ndarray)\n            stock_set = stock_set | set(indicator.data[metrics[0]].index.tolist())\n        stocks = sorted(list(stock_set))\n\n        # add metric by index\n        for metric in metrics:\n            order_indicator.data[metric] = 
idd.sum_by_index(\n                [indicator.data[metric] for indicator in indicators],\n                stocks,\n                fill_value,\n            )\n\n    def __repr__(self):\n        return repr(self.data)\n"
  },
  {
    "path": "qlib/backtest/position.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom datetime import timedelta\nfrom typing import Any, Dict, List, Union\n\nimport numpy as np\nimport pandas as pd\n\nfrom ..data.data import D\nfrom .decision import Order\n\n\nclass BasePosition:\n    \"\"\"\n    The Position maintains the position like a dictionary.\n    Please refer to the `Position` class for the details of the position\n    \"\"\"\n\n    def __init__(self, *args: Any, cash: float = 0.0, **kwargs: Any) -> None:\n        self._settle_type = self.ST_NO\n        self.position: dict = {}\n\n    def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:\n        pass\n\n    def skip_update(self) -> bool:\n        \"\"\"\n        Whether the updating operation should be skipped for this position\n        For example, updating is meaningless for InfPosition\n\n        Returns\n        -------\n        bool:\n            whether to skip the updating operation\n        \"\"\"\n        return False\n\n    def check_stock(self, stock_id: str) -> bool:\n        \"\"\"\n        check whether the stock is in the position\n\n        Parameters\n        ----------\n        stock_id : str\n            the id of the stock\n\n        Returns\n        -------\n        bool:\n            whether the stock is in the position\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `check_stock` method\")\n\n    def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:\n        \"\"\"\n        Update the position with the dealing results of an order.\n\n        Parameters\n        ----------\n        order : Order\n            the order used to update the position\n        trade_val : float\n            the trade value (money) of the dealing results\n        cost : float\n            the trade cost of the dealing results\n        trade_price : float\n            the trade price of the dealing results\n        \"\"\"\n        raise 
NotImplementedError(f\"Please implement the `update_order` method\")\n\n    def update_stock_price(self, stock_id: str, price: float) -> None:\n        \"\"\"\n        Update the latest price of the stock.\n        This is useful when settling the balance at the end of each bar\n\n        Parameters\n        ----------\n        stock_id :\n            the id of the stock\n        price : float\n            the price to be updated\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `update_stock_price` method\")\n\n    def calculate_stock_value(self) -> float:\n        \"\"\"\n        calculate the value of all assets except cash in the position\n\n        Returns\n        -------\n        float:\n            the value (money) of all the stocks\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `calculate_stock_value` method\")\n\n    def calculate_value(self) -> float:\n        raise NotImplementedError(f\"Please implement the `calculate_value` method\")\n\n    def get_stock_list(self) -> List[str]:\n        \"\"\"\n        Get the list of stocks in the position.\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `get_stock_list` method\")\n\n    def get_stock_price(self, code: str) -> float:\n        \"\"\"\n        get the latest price of the stock\n\n        Parameters\n        ----------\n        code :\n            the code of the stock\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `get_stock_price` method\")\n\n    def get_stock_amount(self, code: str) -> float:\n        \"\"\"\n        get the amount of the stock\n\n        Parameters\n        ----------\n        code :\n            the code of the stock\n\n        Returns\n        -------\n        float:\n            the amount of the stock\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `get_stock_amount` method\")\n\n    def get_cash(self, include_settle: bool = False) -> float:\n        
\"\"\"\n        Parameters\n        ----------\n        include_settle:\n            will the unsettled(delayed) cash included\n            Default: not include those unavailable cash\n\n        Returns\n        -------\n        float:\n            the available(tradable) cash in position\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `get_cash` method\")\n\n    def get_stock_amount_dict(self) -> dict:\n        \"\"\"\n        generate stock amount dict {stock_id : amount of stock}\n\n        Returns\n        -------\n        Dict:\n            {stock_id : amount of stock}\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `get_stock_amount_dict` method\")\n\n    def get_stock_weight_dict(self, only_stock: bool = False) -> dict:\n        \"\"\"\n        generate stock weight dict {stock_id : value weight of stock in the position}\n        it is meaningful in the beginning or the end of each trade step\n        - During execution of each trading step, the weight may be not consistent with the portfolio value\n\n        Parameters\n        ----------\n        only_stock : bool\n            If only_stock=True, the weight of each stock in total stock will be returned\n            If only_stock=False, the weight of each stock in total assets(stock + cash) will be returned\n\n        Returns\n        -------\n        Dict:\n            {stock_id : value weight of stock in the position}\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `get_stock_weight_dict` method\")\n\n    def add_count_all(self, bar: str) -> None:\n        \"\"\"\n        Will be called at the end of each bar on each level\n\n        Parameters\n        ----------\n        bar :\n            The level to be updated\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `add_count_all` method\")\n\n    def update_weight_all(self) -> None:\n        \"\"\"\n        Updating the position weight;\n\n        # 
TODO: this function is a little weird. The weight data in the position is in an inconsistent state after an order is dealt\n        # and before the weight is updated.\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `update_weight_all` method\")\n\n    ST_CASH = \"cash\"\n    ST_NO = \"None\"  # String is more typehint friendly than None\n\n    def settle_start(self, settle_type: str) -> None:\n        \"\"\"\n        settlement start\n        Together with `settle_commit`, it acts like starting and committing a transaction\n\n        Parameters\n        ----------\n        settle_type : str\n            Whether to delay the settlement in each execution (each execution moves the executor one step forward)\n            - \"cash\": make the cash settlement delayed.\n                - The cash you get can't be used in the current step (e.g. you can't sell a stock to get cash to buy another\n                        stock)\n            - None: no settlement mechanism\n            - TODO: other assets will be supported in the future.\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `settle_start` method\")\n\n    def settle_commit(self) -> None:\n        \"\"\"\n        settlement commit\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `settle_commit` method\")\n\n    def __str__(self) -> str:\n        return self.__dict__.__str__()\n\n    def __repr__(self) -> str:\n        return self.__dict__.__repr__()\n\n\nclass Position(BasePosition):\n    \"\"\"Position\n\n    current state of the position\n    a typical example is :{\n      <instrument_id>: {\n        'count': <how many days the security has been held>,\n        'amount': <the amount of the security>,\n        'price': <the close price of security in the last trading day>,\n        'weight': <the security weight of total position value>,\n      },\n    }\n    \"\"\"\n\n    def __init__(self, cash: float = 0, position_dict: Dict[str, Union[Dict[str, float], float]] = {}) -> None:\n        
\"\"\"Init position by cash and position_dict.\n\n        Parameters\n        ----------\n        cash : float, optional\n            initial cash in account, by default 0\n        position_dict : Dict[\n                            stock_id,\n                            Union[\n                                int,  # it is equal to {\"amount\": int}\n                                {\"amount\": int, \"price\"(optional): float},\n                            ]\n                        ]\n            initial stocks with parameters amount and price,\n            if there is no price key in the dict of stocks, it will be filled by `fill_stock_value`.\n            by default {}.\n        \"\"\"\n        super().__init__()\n\n        # NOTE: The position dict must be copied!!!\n        # Otherwise the caller's `position_dict` (e.g. the shared default `{}`) would be modified in place\n        self.init_cash = cash\n        self.position = position_dict.copy()\n        for stock, value in self.position.items():\n            if isinstance(value, int):\n                self.position[stock] = {\"amount\": value}\n        self.position[\"cash\"] = cash\n\n        # If the stock price information is missing, the account value is not calculated for now\n        try:\n            self.position[\"now_account_value\"] = self.calculate_value()\n        except KeyError:\n            pass\n\n    def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:\n        \"\"\"fill the missing stock prices with the latest close price within the last `last_days` days from qlib.\n\n        Parameters\n        ----------\n        start_time :\n            the start time of the backtest.\n        freq : str\n            Frequency\n        last_days : int, optional\n            the number of days to look back for the latest close price, by default 30.\n        \"\"\"\n        stock_list = []\n        for stock, value in self.position.items():\n            if not isinstance(value, dict):\n                continue\n            if value.get(\"price\", None) is None:\n   
             stock_list.append(stock)\n\n        if len(stock_list) == 0:\n            return\n\n        start_time = pd.Timestamp(start_time)\n        # note that start time is 2020-01-01 00:00:00 if raw start time is \"2020-01-01\"\n        price_end_time = start_time\n        price_start_time = start_time - timedelta(days=last_days)\n        price_df = D.features(\n            stock_list,\n            [\"$close\"],\n            price_start_time,\n            price_end_time,\n            freq=freq,\n            disk_cache=True,\n        ).dropna()\n        # drop the datetime level so that the keys of `price_dict` are the instrument ids\n        price_dict = (\n            price_df.groupby([\"instrument\"], group_keys=False).tail(1).reset_index(level=-1, drop=True)[\"$close\"].to_dict()\n        )\n\n        if len(price_dict) < len(stock_list):\n            lack_stock = set(stock_list) - set(price_dict)\n            raise ValueError(f\"{lack_stock} doesn't have close price in qlib in the latest {last_days} days\")\n\n        for stock in stock_list:\n            self.position[stock][\"price\"] = price_dict[stock]\n        self.position[\"now_account_value\"] = self.calculate_value()\n\n    def _init_stock(self, stock_id: str, amount: float, price: float | None = None) -> None:\n        \"\"\"\n        initialize the stock in the current position\n\n        Parameters\n        ----------\n        stock_id :\n            the id of the stock\n        amount : float\n            the amount of the stock\n        price :\n             the price when buying the init stock\n        \"\"\"\n        self.position[stock_id] = {}\n        self.position[stock_id][\"amount\"] = amount\n        self.position[stock_id][\"price\"] = price\n        self.position[stock_id][\"weight\"] = 0  # update the weight at the end of the trade date\n\n    def _buy_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None:\n        trade_amount = trade_val / trade_price\n        if stock_id not in self.position:\n            self._init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)\n        
else:\n            # the stock already exists, add the amount\n            self.position[stock_id][\"amount\"] += trade_amount\n\n        self.position[\"cash\"] -= trade_val + cost\n\n    def _sell_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None:\n        trade_amount = trade_val / trade_price\n        if stock_id not in self.position:\n            raise KeyError(\"{} not in current position\".format(stock_id))\n        else:\n            if np.isclose(self.position[stock_id][\"amount\"], trade_amount):\n                # Selling all the stocks\n                # we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` considers both\n                # relative and absolute tolerance\n                # Using abs(<the final amount>) <= 1e-5 will result in errors when the amount is large\n                self._del_stock(stock_id)\n            else:\n                # decrease the amount of stock\n                self.position[stock_id][\"amount\"] -= trade_amount\n                # check for overselling\n                if self.position[stock_id][\"amount\"] < -1e-5:\n                    raise ValueError(\n                        \"only have {} {}, require {}\".format(\n                            self.position[stock_id][\"amount\"] + trade_amount,\n                            stock_id,\n                            trade_amount,\n                        ),\n                    )\n\n        new_cash = trade_val - cost\n        if self._settle_type == self.ST_CASH:\n            self.position[\"cash_delay\"] += new_cash\n        elif self._settle_type == self.ST_NO:\n            self.position[\"cash\"] += new_cash\n        else:\n            raise NotImplementedError(f\"Settlement type {self._settle_type} is not supported\")\n\n    def _del_stock(self, stock_id: str) -> None:\n        del self.position[stock_id]\n\n    def check_stock(self, stock_id: str) -> bool:\n        return stock_id in self.position\n\n    def update_order(self, order: 
Order, trade_val: float, cost: float, trade_price: float) -> None:\n        # handle the order; `order` is an `Order` instance defined in decision.py\n        if order.direction == Order.BUY:\n            # BUY\n            self._buy_stock(order.stock_id, trade_val, cost, trade_price)\n        elif order.direction == Order.SELL:\n            # SELL\n            self._sell_stock(order.stock_id, trade_val, cost, trade_price)\n        else:\n            raise NotImplementedError(\"do not support order direction {}\".format(order.direction))\n\n    def update_stock_price(self, stock_id: str, price: float) -> None:\n        self.position[stock_id][\"price\"] = price\n\n    def update_stock_count(self, stock_id: str, bar: str, count: float) -> None:  # TODO: check type of `bar`\n        self.position[stock_id][f\"count_{bar}\"] = count\n\n    def update_stock_weight(self, stock_id: str, weight: float) -> None:\n        self.position[stock_id][\"weight\"] = weight\n\n    def calculate_stock_value(self) -> float:\n        stock_list = self.get_stock_list()\n        value = 0\n        for stock_id in stock_list:\n            value += self.position[stock_id][\"amount\"] * self.position[stock_id][\"price\"]\n        return value\n\n    def calculate_value(self) -> float:\n        value = self.calculate_stock_value()\n        value += self.position[\"cash\"] + self.position.get(\"cash_delay\", 0.0)\n        return value\n\n    def get_stock_list(self) -> List[str]:\n        stock_list = list(set(self.position.keys()) - {\"cash\", \"now_account_value\", \"cash_delay\"})\n        return stock_list\n\n    def get_stock_price(self, code: str) -> float:\n        return self.position[code][\"price\"]\n\n    def get_stock_amount(self, code: str) -> float:\n        return self.position[code][\"amount\"] if code in self.position else 0\n\n    def get_stock_count(self, code: str, bar: str) -> float:\n        \"\"\"the number of `bar`s the stock has been held, it may be used in some special 
strategies\"\"\"\n        if f\"count_{bar}\" in self.position[code]:\n            return self.position[code][f\"count_{bar}\"]\n        else:\n            return 0\n\n    def get_stock_weight(self, code: str) -> float:\n        return self.position[code][\"weight\"]\n\n    def get_cash(self, include_settle: bool = False) -> float:\n        cash = self.position[\"cash\"]\n        if include_settle:\n            cash += self.position.get(\"cash_delay\", 0.0)\n        return cash\n\n    def get_stock_amount_dict(self) -> dict:\n        \"\"\"generate stock amount dict {stock_id : amount of stock}\"\"\"\n        d = {}\n        stock_list = self.get_stock_list()\n        for stock_code in stock_list:\n            d[stock_code] = self.get_stock_amount(code=stock_code)\n        return d\n\n    def get_stock_weight_dict(self, only_stock: bool = False) -> dict:\n        \"\"\"get_stock_weight_dict\n        generate stock weight dict {stock_id : value weight of stock in the position}\n        it is meaningful in the beginning or the end of each trade date\n\n        :param only_stock: If only_stock=True, the weight of each stock in total stock will be returned\n                           If only_stock=False, the weight of each stock in total assets(stock + cash) will be returned\n        \"\"\"\n        if only_stock:\n            position_value = self.calculate_stock_value()\n        else:\n            position_value = self.calculate_value()\n        d = {}\n        stock_list = self.get_stock_list()\n        for stock_code in stock_list:\n            d[stock_code] = self.position[stock_code][\"amount\"] * self.position[stock_code][\"price\"] / position_value\n        return d\n\n    def add_count_all(self, bar: str) -> None:\n        stock_list = self.get_stock_list()\n        for code in stock_list:\n            if f\"count_{bar}\" in self.position[code]:\n                self.position[code][f\"count_{bar}\"] += 1\n            else:\n                
self.position[code][f\"count_{bar}\"] = 1\n\n    def update_weight_all(self) -> None:\n        weight_dict = self.get_stock_weight_dict()\n        for stock_code, weight in weight_dict.items():\n            self.update_stock_weight(stock_code, weight)\n\n    def settle_start(self, settle_type: str) -> None:\n        assert self._settle_type == self.ST_NO, \"Currently, settlement can't be nested!!!!!\"\n        self._settle_type = settle_type\n        if settle_type == self.ST_CASH:\n            self.position[\"cash_delay\"] = 0.0\n\n    def settle_commit(self) -> None:\n        if self._settle_type != self.ST_NO:\n            if self._settle_type == self.ST_CASH:\n                self.position[\"cash\"] += self.position[\"cash_delay\"]\n                del self.position[\"cash_delay\"]\n            else:\n                raise NotImplementedError(f\"Settlement type {self._settle_type} is not supported\")\n            self._settle_type = self.ST_NO\n\n\nclass InfPosition(BasePosition):\n    \"\"\"\n    Position with infinite cash and amount.\n\n    This is useful for generating random orders.\n    \"\"\"\n\n    def skip_update(self) -> bool:\n        \"\"\"Updating state is meaningless for InfPosition\"\"\"\n        return True\n\n    def check_stock(self, stock_id: str) -> bool:\n        # InfPosition always holds every stock\n        return True\n\n    def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:\n        pass\n\n    def update_stock_price(self, stock_id: str, price: float) -> None:\n        pass\n\n    def calculate_stock_value(self) -> float:\n        \"\"\"\n        Returns\n        -------\n        float:\n            infinite stock value\n        \"\"\"\n        return np.inf\n\n    def calculate_value(self) -> float:\n        raise NotImplementedError(f\"InfPosition doesn't support calculating value\")\n\n    def get_stock_list(self) -> List[str]:\n        raise NotImplementedError(f\"InfPosition doesn't support stock list 
position\")\n\n    def get_stock_price(self, code: str) -> float:\n        \"\"\"the price of the inf position is meaningless\"\"\"\n        return np.nan\n\n    def get_stock_amount(self, code: str) -> float:\n        return np.inf\n\n    def get_cash(self, include_settle: bool = False) -> float:\n        return np.inf\n\n    def get_stock_amount_dict(self) -> dict:\n        raise NotImplementedError(f\"InfPosition doesn't support get_stock_amount_dict\")\n\n    def get_stock_weight_dict(self, only_stock: bool = False) -> dict:\n        raise NotImplementedError(f\"InfPosition doesn't support get_stock_weight_dict\")\n\n    def add_count_all(self, bar: str) -> None:\n        raise NotImplementedError(f\"InfPosition doesn't support add_count_all\")\n\n    def update_weight_all(self) -> None:\n        raise NotImplementedError(f\"InfPosition doesn't support update_weight_all\")\n\n    def settle_start(self, settle_type: str) -> None:\n        pass\n\n    def settle_commit(self) -> None:\n        pass\n"
  },
  {
    "path": "qlib/backtest/profit_attribution.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nThis module is not well maintained.\n\"\"\"\n\nimport datetime\nfrom pathlib import Path\n\nimport numpy as np\nimport pandas as pd\n\nfrom ..config import C\nfrom ..data import D\nfrom .position import Position\n\n\ndef get_benchmark_weight(\n    bench,\n    start_date=None,\n    end_date=None,\n    path=None,\n    freq=\"day\",\n):\n    \"\"\"get_benchmark_weight\n\n    get the stock weight distribution of the benchmark\n\n    :param bench: the benchmark code, matched against the `index` column of the weights file\n    :param start_date: the first date to include (inclusive)\n    :param end_date: the last date to include (inclusive)\n    :param path: path of the weights csv file; by default `<data_uri>/raw/AIndexMembers/weights.csv`\n    :param freq: the data frequency, used to locate the default data path\n\n    :return: The weight distribution of the benchmark described by a pandas dataframe\n             Every row corresponds to a trading day.\n             Every column corresponds to a stock.\n             Every cell represents the weight of the stock in the benchmark (the raw percentage divided by 100).\n\n    \"\"\"\n    if not path:\n        path = Path(C.dpm.get_data_uri(freq)).expanduser() / \"raw\" / \"AIndexMembers\" / \"weights.csv\"\n    # TODO: the storage of weights should be implemented in a more elegant way\n    # TODO: The benchmark is not consistent with the filename in instruments.\n    bench_weight_df = pd.read_csv(path, usecols=[\"code\", \"date\", \"index\", \"weight\"])\n    bench_weight_df = bench_weight_df[bench_weight_df[\"index\"] == bench].copy()\n    bench_weight_df[\"date\"] = pd.to_datetime(bench_weight_df[\"date\"])\n    if start_date is not None:\n        bench_weight_df = bench_weight_df[bench_weight_df.date >= start_date]\n    if end_date is not None:\n        bench_weight_df = bench_weight_df[bench_weight_df.date <= end_date]\n    bench_stock_weight = bench_weight_df.pivot_table(index=\"date\", columns=\"code\", values=\"weight\") / 100.0\n    return bench_stock_weight\n\n\ndef get_stock_weight_df(positions):\n    \"\"\"get_stock_weight_df\n    :param positions: The positions from the backtest result, keyed by date.\n    :return:          The weight distribution of the positions (rows: dates, columns: stocks; weights relative to the total stock value)\n    \"\"\"\n    
stock_weight = []\n    index = []\n    for date in sorted(positions.keys()):\n        pos = positions[date]\n        if isinstance(pos, dict):\n            pos = Position(position_dict=pos)\n        index.append(date)\n        stock_weight.append(pos.get_stock_weight_dict(only_stock=True))\n    return pd.DataFrame(stock_weight, index=index)\n\n\ndef decompose_portofolio_weight(stock_weight_df, stock_group_df):\n    \"\"\"decompose_portofolio_weight\n\n    :param stock_weight_df: a pandas dataframe to describe the portfolio by weight.\n                    every row corresponds to a day\n                    every column corresponds to a stock.\n                    Here is an example below.\n                    code        SH600004  SH600006  SH600017  SH600022  SH600026  SH600037  \\\n                    date\n                    2016-01-05  0.001543  0.001570  0.002732  0.001320  0.003000       NaN\n                    2016-01-06  0.001538  0.001569  0.002770  0.001417  0.002945       NaN\n                    ....\n    :param stock_group_df: a pandas dataframe to describe the stock group.\n                    every row corresponds to a day\n                    every column corresponds to a stock.\n                    the value in the cell represents the group id.\n                    Here is an example of stock_group_df for industry. The value is the industry code\n                    instrument  SH600000  SH600004  SH600005  SH600006  SH600007  SH600008  \\\n                    datetime\n                    2016-01-05  801780.0  801170.0  801040.0  801880.0  801180.0  801160.0\n                    2016-01-06  801780.0  801170.0  801040.0  801880.0  801180.0  801160.0\n                    ...\n    :return:        Two dicts will be returned: the group_weight and the stock_weight_in_group.\n                    The key is the group. 
The value is a Series (the weight of the group) or a DataFrame (the weight of each stock within the group)\n    \"\"\"\n    all_group = np.unique(stock_group_df.values.flatten())\n    all_group = all_group[~np.isnan(all_group)]\n\n    group_weight = {}\n    stock_weight_in_group = {}\n    for group_key in all_group:\n        group_mask = stock_group_df == group_key\n        group_weight[group_key] = stock_weight_df[group_mask].sum(axis=1)\n        stock_weight_in_group[group_key] = stock_weight_df[group_mask].divide(group_weight[group_key], axis=0)\n    return group_weight, stock_weight_in_group\n\n\ndef decompose_portofolio(stock_weight_df, stock_group_df, stock_ret_df):\n    \"\"\"\n    :param stock_weight_df: a pandas dataframe to describe the portfolio by weight.\n                    every row corresponds to a day\n                    every column corresponds to a stock.\n                    Here is an example below.\n                    code        SH600004  SH600006  SH600017  SH600022  SH600026  SH600037  \\\n                    date\n                    2016-01-05  0.001543  0.001570  0.002732  0.001320  0.003000       NaN\n                    2016-01-06  0.001538  0.001569  0.002770  0.001417  0.002945       NaN\n                    2016-01-07  0.001555  0.001546  0.002772  0.001393  0.002904       NaN\n                    2016-01-08  0.001564  0.001527  0.002791  0.001506  0.002948       NaN\n                    2016-01-11  0.001597  0.001476  0.002738  0.001493  0.003043       NaN\n                    ....\n\n    :param stock_group_df: a pandas dataframe to describe the stock group.\n                    every row corresponds to a day\n                    every column corresponds to a stock.\n                    the value in the cell represents the group id.\n                    Here is an example of stock_group_df for industry. 
The value is the industry code\n                    instrument  SH600000  SH600004  SH600005  SH600006  SH600007  SH600008  \\\n                    datetime\n                    2016-01-05  801780.0  801170.0  801040.0  801880.0  801180.0  801160.0\n                    2016-01-06  801780.0  801170.0  801040.0  801880.0  801180.0  801160.0\n                    2016-01-07  801780.0  801170.0  801040.0  801880.0  801180.0  801160.0\n                    2016-01-08  801780.0  801170.0  801040.0  801880.0  801180.0  801160.0\n                    2016-01-11  801780.0  801170.0  801040.0  801880.0  801180.0  801160.0\n                    ...\n\n    :param stock_ret_df: a pandas dataframe to describe the stock return.\n                    every row corresponds to a day\n                    every column corresponds to a stock.\n                    the value in the cell represents the return of the stock.\n                    Here is an example of stock_ret_df.\n                    instrument  SH600000  SH600004  SH600005  SH600006  SH600007  SH600008  \\\n                    datetime\n                    2016-01-05  0.007795  0.022070  0.099099  0.024707  0.009473  0.016216\n                    2016-01-06 -0.032597 -0.075205 -0.098361 -0.098985 -0.099707 -0.098936\n                    2016-01-07 -0.001142  0.022544  0.100000  0.004225  0.000651  0.047226\n                    2016-01-08 -0.025157 -0.047244 -0.038567 -0.098177 -0.099609 -0.074408\n                    2016-01-11  0.023460  0.004959 -0.034384  0.018663  0.014461  0.010962\n                    ...\n\n    :return: Two dataframes decomposed from the portfolio: the group weight and the group return.\n             Every row corresponds to a day and every column corresponds to a group.\n    \"\"\"\n    all_group = np.unique(stock_group_df.values.flatten())\n    all_group = all_group[~np.isnan(all_group)]\n\n    group_weight, stock_weight_in_group = decompose_portofolio_weight(stock_weight_df, stock_group_df)\n\n    group_ret = {}\n    for group_key, val in stock_weight_in_group.items():\n        
stock_weight_in_group_start_date = min(val.index)\n        stock_weight_in_group_end_date = max(val.index)\n\n        temp_stock_ret_df = stock_ret_df[\n            (stock_ret_df.index >= stock_weight_in_group_start_date)\n            & (stock_ret_df.index <= stock_weight_in_group_end_date)\n        ]\n\n        group_ret[group_key] = (temp_stock_ret_df * val).sum(axis=1)\n        # If no weight is assigned, then the return of the group will be np.nan\n        group_ret[group_key][group_weight[group_key] == 0.0] = np.nan\n\n    group_weight_df = pd.DataFrame(group_weight)\n    group_ret_df = pd.DataFrame(group_ret)\n    return group_weight_df, group_ret_df\n\n\ndef get_daily_bin_group(bench_values, stock_values, group_n):\n    \"\"\"get_daily_bin_group\n    Split the values of the benchmark stocks on a given day into several bins.\n    Put the stocks into these bins.\n\n    :param bench_values: A series containing the values of the stocks in the benchmark.\n                         The index is the stock code.\n    :param stock_values: A series containing the values of the stocks to be grouped.\n                         The index is the stock code.\n    :param group_n:      The number of bins to produce\n\n    :return:             A series with the same size and index as stock_values.\n                         The value in the series is the group id of the bins.\n                         The No.1 bin contains the biggest values.\n    \"\"\"\n    stock_group = stock_values.copy()\n\n    # get the bin split points based on the daily distribution of the benchmark values\n    split_points = np.percentile(bench_values[~bench_values.isna()], np.linspace(0, 100, group_n + 1))\n    # Modify the biggest upper bound and the smallest lower bound\n    split_points[0], split_points[-1] = -np.inf, np.inf\n    for i, (lb, up) in enumerate(zip(split_points, split_points[1:])):\n        stock_group.loc[stock_values[(stock_values >= lb) & (stock_values < up)].index] = group_n - i\n    return stock_group\n\n\ndef 
get_stock_group(stock_group_field_df, bench_stock_weight_df, group_method, group_n=None):\n    if group_method == \"category\":\n        # use the value of the group field directly as the category\n        return stock_group_field_df\n    elif group_method == \"bins\":\n        assert group_n is not None\n        # place the values into `group_n` bins.\n        # Each bin corresponds to a category.\n        new_stock_group_df = stock_group_field_df.copy().loc[\n            bench_stock_weight_df.index.min() : bench_stock_weight_df.index.max()\n        ]\n        for idx, row in (~bench_stock_weight_df.isna()).iterrows():\n            bench_values = stock_group_field_df.loc[idx, row[row].index]\n            new_stock_group_df.loc[idx] = get_daily_bin_group(\n                bench_values,\n                stock_group_field_df.loc[idx],\n                group_n=group_n,\n            )\n        return new_stock_group_df\n    else:\n        raise ValueError(f\"Unsupported group_method: {group_method}\")\n\n\ndef brinson_pa(\n    positions,\n    bench=\"SH000905\",\n    group_field=\"industry\",\n    group_method=\"category\",\n    group_n=None,\n    deal_price=\"vwap\",\n    freq=\"day\",\n):\n    \"\"\"brinson profit attribution\n\n    :param positions: The positions produced by the backtest, keyed by date\n    :param bench: The benchmark for comparison. TODO: if no benchmark is set, an equal-weighted benchmark is used.\n    :param group_field: The field used to set the group for assets allocation.\n                        `industry` and `market_value` are often used.\n    :param group_method: 'category' or 'bins'. The method used to set the group for assets allocation\n                         `bins` will split the values into `group_n` bins and each bin represents a group\n    :param group_n: The number of bins. 
Only used when group_method == 'bins'.\n\n    :return:\n        A dataframe with four columns: RAA (excess Return of Assets Allocation), RSS (excess Return of Stock Selection), RIN (excess Return of the INteraction of allocation and selection), RTotal (Total excess Return)\n                                        Every row corresponds to a trading day, the value corresponds to the next return for this trading day\n        A dict with the intermediate results of the brinson profit attribution\n    \"\"\"\n    # group_method will decide how to group the group_field.\n    dates = sorted(positions.keys())\n\n    start_date, end_date = min(dates), max(dates)\n\n    bench_stock_weight = get_benchmark_weight(bench, start_date, end_date, freq=freq)\n\n    # Make sure the fields are qlib feature expressions (prefixed with `$`)\n    if not group_field.startswith(\"$\"):\n        group_field = \"$\" + group_field\n    if not deal_price.startswith(\"$\"):\n        deal_price = \"$\" + deal_price\n\n    # FIXME: In the current version, some attributes (such as market_value) of some\n    # suspended stocks are NaN, so we have to fetch more dates to forward-fill the NaN values\n    shift_start_date = start_date - datetime.timedelta(days=250)\n    instruments = D.list_instruments(\n        D.instruments(market=\"all\"),\n        start_time=shift_start_date,\n        end_time=end_date,\n        as_list=True,\n        freq=freq,\n    )\n    stock_df = D.features(\n        instruments,\n        [group_field, deal_price],\n        start_time=shift_start_date,\n        end_time=end_date,\n        freq=freq,\n    )\n    stock_df.columns = [group_field, \"deal_price\"]\n\n    stock_group_field = stock_df[group_field].unstack().T\n    # FIXME: some attributes of some suspended stocks are NaN.\n    stock_group_field = stock_group_field.ffill()\n    stock_group_field = stock_group_field.loc[start_date:end_date]\n\n    stock_group = get_stock_group(stock_group_field, bench_stock_weight, group_method, group_n)\n\n    deal_price_df = stock_df[\"deal_price\"].unstack().T\n    deal_price_df = deal_price_df.ffill()\n\n    # NOTE:\n    # 
The return will be slightly different from the return in the report.\n    # Here the positions are assumed to be adjusted at the end of the trading day,\n    # and the returns are computed from `deal_price`\n    stock_ret = (deal_price_df - deal_price_df.shift(1)) / deal_price_df.shift(1)\n    stock_ret = stock_ret.shift(-1).loc[start_date:end_date]\n\n    port_stock_weight_df = get_stock_weight_df(positions)\n\n    # decompose the portfolio and the benchmark into group weights and group returns\n    port_group_weight_df, port_group_ret_df = decompose_portofolio(port_stock_weight_df, stock_group, stock_ret)\n    bench_group_weight_df, bench_group_ret_df = decompose_portofolio(bench_stock_weight, stock_group, stock_ret)\n\n    # if the group return of the portfolio is NaN (e.g. no weight in that group), replace it with the\n    # group return of the benchmark\n    mod_port_group_ret_df = port_group_ret_df.copy()\n    mod_port_group_ret_df[mod_port_group_ret_df.isna()] = bench_group_ret_df\n\n    Q1 = (bench_group_weight_df * bench_group_ret_df).sum(axis=1)\n    Q2 = (port_group_weight_df * bench_group_ret_df).sum(axis=1)\n    Q3 = (bench_group_weight_df * mod_port_group_ret_df).sum(axis=1)\n    Q4 = (port_group_weight_df * mod_port_group_ret_df).sum(axis=1)\n\n    return (\n        pd.DataFrame(\n            {\n                \"RAA\": Q2 - Q1,  # The excess profit from the assets allocation\n                \"RSS\": Q3 - Q1,  # The excess profit from the stocks selection\n                # The excess profit from the interaction of assets allocation and stocks selection\n                \"RIN\": Q4 - Q3 - Q2 + Q1,\n                \"RTotal\": Q4 - Q1,  # The total excess profit\n            },\n        ),\n        {\n            \"port_group_ret\": port_group_ret_df,\n            \"port_group_weight\": port_group_weight_df,\n            \"bench_group_ret\": bench_group_ret_df,\n            \"bench_group_weight\": bench_group_weight_df,\n            \"stock_group\": stock_group,\n            \"bench_stock_weight\": bench_stock_weight,\n            \"port_stock_weight\": port_stock_weight_df,\n            
\"stock_ret\": stock_ret,\n        },\n    )\n"
  },
  {
    "path": "qlib/backtest/report.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nimport pathlib\nfrom collections import OrderedDict\nfrom typing import Any, Dict, List, Optional, Text, Tuple, Type, Union, cast\n\nimport numpy as np\nimport pandas as pd\n\nimport qlib.utils.index_data as idd\nfrom qlib.backtest.decision import BaseTradeDecision, Order, OrderDir\nfrom qlib.backtest.exchange import Exchange\n\nfrom ..tests.config import CSI300_BENCH\nfrom ..utils.resam import get_higher_eq_freq_feature, resam_ts_data\nfrom .high_performance_ds import BaseOrderIndicator, BaseSingleMetric, NumpyOrderIndicator\n\n\nclass PortfolioMetrics:\n    \"\"\"\n    Motivation:\n        PortfolioMetrics is for supporting portfolio-related metrics.\n\n    Implementation:\n\n        The daily portfolio metrics of the account\n        contain the following columns: return, cost, turnover, account, cash, bench, value.\n        For each step (bar/day/minute), each column represents:\n        - return: the return of the portfolio generated by the strategy **without transaction fee**.\n        - cost: the transaction fee and slippage.\n        - account: the total value of assets (cash and securities are both included) in the user account based on the close price of each step.\n        - cash: the amount of cash in the user's account.\n        - bench: the return of the benchmark\n        - value: the total value of securities/stocks/instruments (cash is excluded).\n\n        update report\n    \"\"\"\n\n    def __init__(self, freq: str = \"day\", benchmark_config: dict = {}) -> None:\n        \"\"\"\n        Parameters\n        ----------\n        freq : str\n            frequency of trading bar, used for updating hold count of trading bar\n        benchmark_config : dict\n            config of the benchmark, which may include the following arguments:\n            - benchmark : Union[str, list, pd.Series]\n                - If `benchmark` is pd.Series, `index` is 
trading date; the value at T is the change from T-1 to T.\n                    example:\n                        print(\n                            D.features(D.instruments('csi500'),\n                            ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()\n                        )\n                            2017-01-04    0.011693\n                            2017-01-05    0.000721\n                            2017-01-06   -0.004322\n                            2017-01-09    0.006874\n                            2017-01-10   -0.003350\n                - If `benchmark` is a list, the daily average change of the stock pool in the list will be used as the\n                    'bench'.\n                - If `benchmark` is a str, the daily change of that instrument will be used as the 'bench'.\n                The default benchmark code is SH000300 (CSI300).\n            - start_time : Union[str, pd.Timestamp], optional\n                - If `benchmark` is pd.Series, it will be ignored\n                - Otherwise, it represents the start time of the benchmark, by default None\n            - end_time : Union[str, pd.Timestamp], optional\n                - If `benchmark` is pd.Series, it will be ignored\n                - Otherwise, it represents the end time of the benchmark, by default None\n\n        \"\"\"\n\n        self.init_vars()\n        self.init_bench(freq=freq, benchmark_config=benchmark_config)\n\n    def init_vars(self) -> None:\n        self.accounts: dict = OrderedDict()  # account position value for each trade time\n        self.returns: dict = OrderedDict()  # daily return rate for each trade time\n        self.total_turnovers: dict = OrderedDict()  # total turnover for each trade time\n        self.turnovers: dict = OrderedDict()  # turnover for each trade time\n        self.total_costs: dict = OrderedDict()  # total trade cost for each trade time\n        self.costs: dict = OrderedDict()  # trade cost rate for each trade time\n        self.values: dict = OrderedDict()  # value for each trade 
time\n        self.cashes: dict = OrderedDict()\n        self.benches: dict = OrderedDict()\n        self.latest_pm_time: Optional[pd.Timestamp] = None\n\n    def init_bench(self, freq: str | None = None, benchmark_config: dict | None = None) -> None:\n        if freq is not None:\n            self.freq = freq\n        self.benchmark_config = benchmark_config\n        self.bench = self._cal_benchmark(self.benchmark_config, self.freq)\n\n    @staticmethod\n    def _cal_benchmark(benchmark_config: Optional[dict], freq: str) -> Optional[pd.Series]:\n        if benchmark_config is None:\n            return None\n        benchmark = benchmark_config.get(\"benchmark\", CSI300_BENCH)\n        if benchmark is None:\n            return None\n\n        if isinstance(benchmark, pd.Series):\n            return benchmark\n        else:\n            start_time = benchmark_config.get(\"start_time\", None)\n            end_time = benchmark_config.get(\"end_time\", None)\n\n            if freq is None:\n                raise ValueError(\"benchmark freq can't be None!\")\n            _codes = benchmark if isinstance(benchmark, (list, dict)) else [benchmark]\n            fields = [\"$close/Ref($close,1)-1\"]\n            _temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq)\n            if len(_temp_result) == 0:\n                raise ValueError(f\"The benchmark {_codes} does not exist. 
Please provide the right benchmark\")\n            return (\n                _temp_result.groupby(level=\"datetime\", group_keys=False)[_temp_result.columns.tolist()[0]]\n                .mean()\n                .fillna(0)\n            )\n\n    def _sample_benchmark(\n        self,\n        bench: pd.Series,\n        trade_start_time: Union[str, pd.Timestamp],\n        trade_end_time: Union[str, pd.Timestamp],\n    ) -> Optional[float]:\n        if self.bench is None:\n            return None\n\n        def cal_change(x):\n            return (x + 1).prod()\n\n        _ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)\n        return 0.0 if _ret is None else _ret - 1\n\n    def is_empty(self) -> bool:\n        return len(self.accounts) == 0\n\n    def get_latest_date(self) -> pd.Timestamp:\n        return self.latest_pm_time\n\n    def get_latest_account_value(self) -> float:\n        return self.accounts[self.latest_pm_time]\n\n    def get_latest_total_cost(self) -> Any:\n        return self.total_costs[self.latest_pm_time]\n\n    def get_latest_total_turnover(self) -> Any:\n        return self.total_turnovers[self.latest_pm_time]\n\n    def update_portfolio_metrics_record(\n        self,\n        trade_start_time: Union[str, pd.Timestamp] = None,\n        trade_end_time: Union[str, pd.Timestamp] = None,\n        account_value: float | None = None,\n        cash: float | None = None,\n        return_rate: float | None = None,\n        total_turnover: float | None = None,\n        turnover_rate: float | None = None,\n        total_cost: float | None = None,\n        cost_rate: float | None = None,\n        stock_value: float | None = None,\n        bench_value: float | None = None,\n    ) -> None:\n        # check data\n        if None in [\n            trade_start_time,\n            account_value,\n            cash,\n            return_rate,\n            total_turnover,\n            turnover_rate,\n            total_cost,\n            
cost_rate,\n            stock_value,\n        ]:\n            raise ValueError(\n                \"None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, \"\n                \"total_cost, cost_rate, stock_value]\",\n            )\n\n        if trade_end_time is None and bench_value is None:\n            raise ValueError(\"Both trade_end_time and bench_value are None; the benchmark is not usable.\")\n        elif bench_value is None:\n            bench_value = self._sample_benchmark(self.bench, trade_start_time, trade_end_time)\n\n        # update pm data\n        self.accounts[trade_start_time] = account_value\n        self.returns[trade_start_time] = return_rate\n        self.total_turnovers[trade_start_time] = total_turnover\n        self.turnovers[trade_start_time] = turnover_rate\n        self.total_costs[trade_start_time] = total_cost\n        self.costs[trade_start_time] = cost_rate\n        self.values[trade_start_time] = stock_value\n        self.cashes[trade_start_time] = cash\n        self.benches[trade_start_time] = bench_value\n        # update pm\n        self.latest_pm_time = trade_start_time\n        # finish pm update in each step\n\n    def generate_portfolio_metrics_dataframe(self) -> pd.DataFrame:\n        pm = pd.DataFrame()\n        pm[\"account\"] = pd.Series(self.accounts)\n        pm[\"return\"] = pd.Series(self.returns)\n        pm[\"total_turnover\"] = pd.Series(self.total_turnovers)\n        pm[\"turnover\"] = pd.Series(self.turnovers)\n        pm[\"total_cost\"] = pd.Series(self.total_costs)\n        pm[\"cost\"] = pd.Series(self.costs)\n        pm[\"value\"] = pd.Series(self.values)\n        pm[\"cash\"] = pd.Series(self.cashes)\n        pm[\"bench\"] = pd.Series(self.benches)\n        pm.index.name = \"datetime\"\n        return pm\n\n    def save_portfolio_metrics(self, path: str) -> None:\n        r = self.generate_portfolio_metrics_dataframe()\n        r.to_csv(path)\n\n    def 
load_portfolio_metrics(self, path: str) -> None:\n        \"\"\"Load portfolio metrics from a CSV file.\n        The file should have columns like\n        columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']\n            :param\n                path: str/ pathlib.Path()\n        \"\"\"\n        with pathlib.Path(path).open(\"rb\") as f:\n            r = pd.read_csv(f, index_col=0)\n        r.index = pd.DatetimeIndex(r.index)\n\n        index = r.index\n        self.init_vars()\n        for trade_start_time in index:\n            self.update_portfolio_metrics_record(\n                trade_start_time=trade_start_time,\n                account_value=r.loc[trade_start_time][\"account\"],\n                cash=r.loc[trade_start_time][\"cash\"],\n                return_rate=r.loc[trade_start_time][\"return\"],\n                total_turnover=r.loc[trade_start_time][\"total_turnover\"],\n                turnover_rate=r.loc[trade_start_time][\"turnover\"],\n                total_cost=r.loc[trade_start_time][\"total_cost\"],\n                cost_rate=r.loc[trade_start_time][\"cost\"],\n                stock_value=r.loc[trade_start_time][\"value\"],\n                bench_value=r.loc[trade_start_time][\"bench\"],\n            )\n\n\nclass Indicator:\n    \"\"\"\n    `Indicator` is implemented in an aggregate way.\n    All the metrics are calculated aggregately.\n    All the metrics are calculated for a single stock at a specific step on a specific level.\n\n    | indicator    | desc.                                                        
|\n    |--------------+--------------------------------------------------------------|\n    | amount       | the *target* amount given by the outer strategy              |\n    | deal_amount  | the real deal amount                                         |\n    | inner_amount | the total *target* amount of inner strategy                  |\n    | trade_price  | the average deal price                                       |\n    | trade_value  | the total trade value                                        |\n    | trade_cost   | the total trade cost (base price needs direction)            |\n    | trade_dir    | the trading direction                                        |\n    | ffr          | full fill rate                                               |\n    | pa           | price advantage                                              |\n    | pos          | win rate                                                     |\n    | base_price   | the price of baseline                                        |\n    | base_volume  | the volume of baseline (for weighted aggregating base_price) |\n\n    **NOTE**:\n    The `base_price` and `base_volume` can't be NaN when there is no trading on that step. 
Otherwise,\n    aggregation will produce wrong results.\n\n    So `base_price` will not be calculated in an aggregate way!!\n\n    \"\"\"\n\n    def __init__(self, order_indicator_cls: Type[BaseOrderIndicator] = NumpyOrderIndicator) -> None:\n        self.order_indicator_cls = order_indicator_cls\n\n        # order indicators are metrics for a single order at a specific step\n        self.order_indicator_his: dict = OrderedDict()\n        self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()\n\n        # trade indicators are metrics for all orders at a specific step\n        self.trade_indicator_his: dict = OrderedDict()\n        self.trade_indicator: Dict[str, Optional[BaseSingleMetric]] = OrderedDict()\n\n        self._trade_calendar = None\n\n    # def reset(self, trade_calendar: TradeCalendarManager):\n    def reset(self) -> None:\n        self.order_indicator = self.order_indicator_cls()\n        self.trade_indicator = OrderedDict()\n        # self._trade_calendar = trade_calendar\n\n    def record(self, trade_start_time: Union[str, pd.Timestamp]) -> None:\n        self.order_indicator_his[trade_start_time] = self.get_order_indicator()\n        self.trade_indicator_his[trade_start_time] = self.get_trade_indicator()\n\n    def _update_order_trade_info(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:\n        amount = dict()\n        deal_amount = dict()\n        trade_price = dict()\n        trade_value = dict()\n        trade_cost = dict()\n        trade_dir = dict()\n        pa = dict()\n\n        for order, _trade_val, _trade_cost, _trade_price in trade_info:\n            amount[order.stock_id] = order.amount_delta\n            deal_amount[order.stock_id] = order.deal_amount_delta\n            trade_price[order.stock_id] = _trade_price\n            trade_value[order.stock_id] = _trade_val * order.sign\n            trade_cost[order.stock_id] = _trade_cost\n            trade_dir[order.stock_id] = order.direction\n            # The PA in 
the innermost layer is meaningless\n            pa[order.stock_id] = 0\n\n        self.order_indicator.assign(\"amount\", amount)\n        self.order_indicator.assign(\"inner_amount\", amount)\n        self.order_indicator.assign(\"deal_amount\", deal_amount)\n        # NOTE: trade_price and baseline price will be the same at the lowest level\n        self.order_indicator.assign(\"trade_price\", trade_price)\n        self.order_indicator.assign(\"trade_value\", trade_value)\n        self.order_indicator.assign(\"trade_cost\", trade_cost)\n        self.order_indicator.assign(\"trade_dir\", trade_dir)\n        self.order_indicator.assign(\"pa\", pa)\n\n    def _update_order_fulfill_rate(self) -> None:\n        def func(deal_amount, amount):\n            # deal_amount is np.nan or None when there is no inner decision, so the fulfill rate is 0.\n            tmp_deal_amount = deal_amount.reindex(amount.index, 0)\n            tmp_deal_amount = tmp_deal_amount.replace({np.nan: 0})\n            return tmp_deal_amount / amount\n\n        self.order_indicator.transfer(func, \"ffr\")\n\n    def update_order_indicators(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:\n        self._update_order_trade_info(trade_info=trade_info)\n        self._update_order_fulfill_rate()\n\n    def _agg_order_trade_info(self, inner_order_indicators: List[BaseOrderIndicator]) -> None:\n        # calculate the total traded value (deal_amount * trade_price) of each inner order indicator.\n        def trade_amount_func(deal_amount, trade_price):\n            return deal_amount * trade_price\n\n        for indicator in inner_order_indicators:\n            indicator.transfer(trade_amount_func, \"trade_price\")\n\n        # sum inner order indicators with the same metric.\n        all_metric = [\"inner_amount\", \"deal_amount\", \"trade_price\", \"trade_value\", \"trade_cost\", \"trade_dir\"]\n        self.order_indicator_cls.sum_all_indicators(\n            self.order_indicator,\n            inner_order_indicators,\n        
    all_metric,\n            fill_value=0,\n        )\n\n        def func(trade_price, deal_amount):\n            # trade_price is np.nan instead of inf when deal_amount is zero.\n            tmp_deal_amount = deal_amount.replace({0: np.nan})\n            return trade_price / tmp_deal_amount\n\n        self.order_indicator.transfer(func, \"trade_price\")\n\n        def func_apply(trade_dir):\n            return trade_dir.apply(Order.parse_dir)\n\n        self.order_indicator.transfer(func_apply, \"trade_dir\")\n\n    def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision) -> None:\n        # NOTE: these indicators are designed for order execution, so the\n        decision: List[Order] = cast(List[Order], outer_trade_decision.get_decision())\n        if len(decision) == 0:\n            self.order_indicator.assign(\"amount\", {})\n        else:\n            self.order_indicator.assign(\"amount\", {order.stock_id: order.amount_delta for order in decision})\n\n    def _get_base_vol_pri(\n        self,\n        inst: str,\n        trade_start_time: pd.Timestamp,\n        trade_end_time: pd.Timestamp,\n        direction: OrderDir,\n        decision: BaseTradeDecision,\n        trade_exchange: Exchange,\n        pa_config: dict = {},\n    ) -> Tuple[Optional[float], Optional[float]]:\n        \"\"\"\n        Get the base volume and price information.\n        All the base price values originate from this function.\n        \"\"\"\n\n        agg = pa_config.get(\"agg\", \"twap\").lower()\n        price = pa_config.get(\"price\", \"deal_price\").lower()\n\n        if decision.trade_range is not None:\n            trade_start_time, trade_end_time = decision.trade_range.clip_time_range(\n                start_time=trade_start_time,\n                end_time=trade_end_time,\n            )\n\n        if price == \"deal_price\":\n            price_s = trade_exchange.get_deal_price(\n                inst,\n                trade_start_time,\n                
trade_end_time,\n                direction=direction,\n                method=None,\n            )\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n\n        # if there is no stock data during the time period\n        if price_s is None:\n            return None, None\n\n        if isinstance(price_s, (int, float, np.number)):\n            price_s = idd.SingleData(price_s, [trade_start_time])\n        elif isinstance(price_s, idd.SingleData):\n            pass\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n\n        # NOTE: there are some zeros in the trading price. These cases are known to be meaningless.\n        # To align with the previous logic, remove them.\n        # remove zero and negative values.\n        assert isinstance(price_s, idd.SingleData)\n        price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]\n        # NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8\n        #   ~(np.nan < 1e-8) -> ~(False)  -> True\n\n        # if price_s is empty\n        if price_s.empty:\n            return None, None\n\n        assert isinstance(price_s, idd.SingleData)\n        if agg == \"vwap\":\n            volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)\n            if isinstance(volume_s, (int, float, np.number)):\n                volume_s = idd.SingleData(volume_s, [trade_start_time])\n            assert isinstance(volume_s, idd.SingleData)\n            volume_s = volume_s.reindex(price_s.index)\n        elif agg == \"twap\":\n            volume_s = idd.SingleData(1, price_s.index)\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n\n        assert isinstance(volume_s, idd.SingleData)\n        base_volume = volume_s.sum()\n        base_price = (price_s * volume_s).sum() / base_volume\n        return base_price, base_volume\n\n    def _agg_base_price(\n        self,\n      
  inner_order_indicators: List[BaseOrderIndicator],\n        decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],\n        trade_exchange: Exchange,\n        pa_config: dict = {},\n    ) -> None:\n        \"\"\"\n        # NOTE:!!!!\n        # Strong assumption!!!!!!\n        # the correctness of the base_price relies on the **same** exchange being used\n\n        Parameters\n        ----------\n        inner_order_indicators : List[BaseOrderIndicator]\n            the indicators of the account of the inner executor\n        decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],\n            a list of decisions corresponding to inner_order_indicators\n        trade_exchange : Exchange\n            for retrieving trading price\n        pa_config : dict\n            For example\n            {\n                \"agg\": \"twap\",  # \"vwap\"\n                \"price\": \"$close\",  # TODO: this is not supported now!!!!!\n                                    # defaults to the deal price of the exchange\n            }\n        \"\"\"\n\n        # TODO: there is potential for optimization here\n        trade_dir = self.order_indicator.get_index_data(\"trade_dir\")\n        if len(trade_dir) > 0:\n            bp_all, bv_all = [], []\n            # <step, inst, (base_volume | base_price)>\n            for oi, (dec, start, end) in zip(inner_order_indicators, decision_list):\n                bp_s = oi.get_index_data(\"base_price\").reindex(trade_dir.index)\n                bv_s = oi.get_index_data(\"base_volume\").reindex(trade_dir.index)\n\n                bp_new, bv_new = {}, {}\n                for pr, v, (inst, direction) in zip(bp_s.data, bv_s.data, zip(trade_dir.index, trade_dir.data)):\n                    if np.isnan(pr):\n                        bp_tmp, bv_tmp = self._get_base_vol_pri(\n                            inst,\n                            start,\n                            end,\n                            decision=dec,\n  
                          direction=direction,\n                            trade_exchange=trade_exchange,\n                            pa_config=pa_config,\n                        )\n                        if (bp_tmp is not None) and (bv_tmp is not None):\n                            bp_new[inst], bv_new[inst] = bp_tmp, bv_tmp\n                    else:\n                        bp_new[inst], bv_new[inst] = pr, v\n\n                bp_new = idd.SingleData(bp_new)\n                bv_new = idd.SingleData(bv_new)\n                bp_all.append(bp_new)\n                bv_all.append(bv_new)\n            bp_all_multi_data = idd.concat(bp_all, axis=1)\n            bv_all_multi_data = idd.concat(bv_all, axis=1)\n\n            base_volume = bv_all_multi_data.sum(axis=1)\n            self.order_indicator.assign(\"base_volume\", base_volume.to_dict())\n            self.order_indicator.assign(\n                \"base_price\",\n                ((bp_all_multi_data * bv_all_multi_data).sum(axis=1) / base_volume).to_dict(),\n            )\n\n    def _agg_order_price_advantage(self) -> None:\n        def if_empty_func(trade_price):\n            return trade_price.empty\n\n        if_empty = self.order_indicator.transfer(if_empty_func)\n        if not if_empty:\n\n            def func(trade_dir, trade_price, base_price):\n                sign = 1 - trade_dir * 2\n                return sign * (trade_price / base_price - 1)\n\n            self.order_indicator.transfer(func, \"pa\")\n        else:\n            self.order_indicator.assign(\"pa\", {})\n\n    def agg_order_indicators(\n        self,\n        inner_order_indicators: List[BaseOrderIndicator],\n        decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],\n        outer_trade_decision: BaseTradeDecision,\n        trade_exchange: Exchange,\n        indicator_config: dict = {},\n    ) -> None:\n        self._agg_order_trade_info(inner_order_indicators)\n        
self._update_trade_amount(outer_trade_decision)\n        self._update_order_fulfill_rate()\n        pa_config = indicator_config.get(\"pa_config\", {})\n        self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config)  # TODO\n        self._agg_order_price_advantage()\n\n    def _cal_trade_fulfill_rate(self, method: str = \"mean\") -> Optional[BaseSingleMetric]:\n        if method == \"mean\":\n            return self.order_indicator.transfer(\n                lambda ffr: ffr.mean(),\n            )\n        elif method == \"amount_weighted\":\n            return self.order_indicator.transfer(\n                lambda ffr, deal_amount: (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum()),\n            )\n        elif method == \"value_weighted\":\n            return self.order_indicator.transfer(\n                lambda ffr, trade_value: (ffr * trade_value.abs()).sum() / (trade_value.abs().sum()),\n            )\n        else:\n            raise ValueError(f\"method {method} is not supported!\")\n\n    def _cal_trade_price_advantage(self, method: str = \"mean\") -> Optional[BaseSingleMetric]:\n        if method == \"mean\":\n            return self.order_indicator.transfer(lambda pa: pa.mean())\n        elif method == \"amount_weighted\":\n            return self.order_indicator.transfer(\n                lambda pa, deal_amount: (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum()),\n            )\n        elif method == \"value_weighted\":\n            return self.order_indicator.transfer(\n                lambda pa, trade_value: (pa * trade_value.abs()).sum() / (trade_value.abs().sum()),\n            )\n        else:\n            raise ValueError(f\"method {method} is not supported!\")\n\n    def _cal_trade_positive_rate(self) -> Optional[BaseSingleMetric]:\n        def func(pa):\n            return (pa > 0).sum() / pa.count()\n\n        return self.order_indicator.transfer(func)\n\n    def _cal_deal_amount(self) 
-> Optional[BaseSingleMetric]:\n        def func(deal_amount):\n            return deal_amount.abs().sum()\n\n        return self.order_indicator.transfer(func)\n\n    def _cal_trade_value(self) -> Optional[BaseSingleMetric]:\n        def func(trade_value):\n            return trade_value.abs().sum()\n\n        return self.order_indicator.transfer(func)\n\n    def _cal_trade_order_count(self) -> Optional[BaseSingleMetric]:\n        def func(amount):\n            return amount.count()\n\n        return self.order_indicator.transfer(func)\n\n    def cal_trade_indicators(\n        self,\n        trade_start_time: Union[str, pd.Timestamp],\n        freq: str,\n        indicator_config: dict = {},\n    ) -> None:\n        show_indicator = indicator_config.get(\"show_indicator\", False)\n        ffr_config = indicator_config.get(\"ffr_config\", {})\n        pa_config = indicator_config.get(\"pa_config\", {})\n        fulfill_rate = self._cal_trade_fulfill_rate(method=ffr_config.get(\"weight_method\", \"mean\"))\n        price_advantage = self._cal_trade_price_advantage(method=pa_config.get(\"weight_method\", \"mean\"))\n        positive_rate = self._cal_trade_positive_rate()\n        deal_amount = self._cal_deal_amount()\n        trade_value = self._cal_trade_value()\n        order_count = self._cal_trade_order_count()\n        self.trade_indicator[\"ffr\"] = fulfill_rate\n        self.trade_indicator[\"pa\"] = price_advantage\n        self.trade_indicator[\"pos\"] = positive_rate\n        self.trade_indicator[\"deal_amount\"] = deal_amount\n        self.trade_indicator[\"value\"] = trade_value\n        self.trade_indicator[\"count\"] = order_count\n        if show_indicator:\n            print(\n                \"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}\".format(\n                    freq,\n                    (\n                        trade_start_time\n                        if isinstance(trade_start_time, str)\n                        else 
trade_start_time.strftime(\"%Y-%m-%d %H:%M:%S\")\n                    ),\n                    fulfill_rate,\n                    price_advantage,\n                    positive_rate,\n                ),\n            )\n\n    def get_order_indicator(self, raw: bool = True) -> Union[BaseOrderIndicator, Dict[Text, pd.Series]]:\n        return self.order_indicator if raw else self.order_indicator.to_series()\n\n    def get_trade_indicator(self) -> Dict[str, Optional[BaseSingleMetric]]:\n        return self.trade_indicator\n\n    def generate_trade_indicators_dataframe(self) -> pd.DataFrame:\n        return pd.DataFrame.from_dict(self.trade_indicator_his, orient=\"index\")\n"
  },
  {
    "path": "qlib/backtest/signal.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport abc\nfrom typing import Dict, List, Text, Tuple, Union\n\nimport pandas as pd\n\nfrom qlib.utils import init_instance_by_config\n\nfrom ..data.dataset import Dataset\nfrom ..data.dataset.utils import convert_index_format\nfrom ..model.base import BaseModel\nfrom ..utils.resam import resam_ts_data\n\n\nclass Signal(metaclass=abc.ABCMeta):\n    \"\"\"\n    Some trading strategies make decisions based on other prediction signals.\n    The signals may come from different sources (e.g. prepared data, online prediction from model and dataset).\n\n    This interface tries to provide a unified interface for those different sources.\n    \"\"\"\n\n    @abc.abstractmethod\n    def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame, None]:\n        \"\"\"\n        get the signal at the end of the decision step (from `start_time` to `end_time`)\n\n        Returns\n        -------\n        Union[pd.Series, pd.DataFrame, None]:\n            returns None if there is no signal on the specific day\n        \"\"\"\n\n\nclass SignalWCache(Signal):\n    \"\"\"\n    Signal With pandas-based Cache\n    SignalWCache stores the prepared signal as an attribute and returns the corresponding signal based on the input query\n    \"\"\"\n\n    def __init__(self, signal: Union[pd.Series, pd.DataFrame]) -> None:\n        \"\"\"\n\n        Parameters\n        ----------\n        signal : Union[pd.Series, pd.DataFrame]\n            The expected format of the signal is like the data below (the order of the index levels is not important and will be\n            automatically adjusted)\n\n                instrument datetime\n                SH600000   2008-01-02  0.079704\n                           2008-01-03  0.120125\n                           2008-01-04  0.878860\n                           2008-01-07  0.505539\n                           2008-01-08  0.395004\n        \"\"\"\n    
    self.signal_cache = convert_index_format(signal, level=\"datetime\")\n\n    def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame]:\n        # the frequency of the signal may not align with the decision frequency of the strategy,\n        # so resampling from the data is necessary.\n        # The latest signal leverages more recent data and is therefore used in trading.\n        signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method=\"last\")\n        return signal\n\n\nclass ModelSignal(SignalWCache):\n    def __init__(self, model: BaseModel, dataset: Dataset) -> None:\n        self.model = model\n        self.dataset = dataset\n        pred_scores = self.model.predict(dataset)\n        if isinstance(pred_scores, pd.DataFrame):\n            pred_scores = pred_scores.iloc[:, 0]\n        super().__init__(pred_scores)\n\n    def _update_model(self) -> None:\n        \"\"\"\n        When using online data, update the model in each bar with the following steps:\n            - update dataset with online data, the dataset should support online update\n            - make the latest prediction scores of the new bar\n            - update the pred score into the latest prediction\n        \"\"\"\n        # TODO: this method is not included in the framework and could be refactored later\n        raise NotImplementedError(\"_update_model is not implemented!\")\n\n\ndef create_signal_from(\n    obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame],\n) -> Signal:\n    \"\"\"\n    create signal from diverse information\n    This method will choose the right method to create a signal based on `obj`\n    Please refer to the code below.\n    \"\"\"\n    if isinstance(obj, Signal):\n        return obj\n    elif isinstance(obj, (tuple, list)):\n        return ModelSignal(*obj)\n    elif isinstance(obj, (dict, str)):\n        return init_instance_by_config(obj)\n    elif 
isinstance(obj, (pd.DataFrame, pd.Series)):\n        return SignalWCache(signal=obj)\n    else:\n        raise NotImplementedError(f\"This type of signal is not supported\")\n"
  },
  {
    "path": "qlib/backtest/utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom abc import abstractmethod\nfrom typing import Any, Set, Tuple, TYPE_CHECKING, Union\n\nimport numpy as np\n\nfrom qlib.utils.time import epsilon_change\n\nif TYPE_CHECKING:\n    from qlib.backtest.decision import BaseTradeDecision\n\nimport warnings\n\nimport pandas as pd\n\nfrom ..data.data import Cal\n\n\nclass TradeCalendarManager:\n    \"\"\"\n    Manager for trading calendar\n        - BaseStrategy and BaseExecutor will use it\n    \"\"\"\n\n    def __init__(\n        self,\n        freq: str,\n        start_time: Union[str, pd.Timestamp] = None,\n        end_time: Union[str, pd.Timestamp] = None,\n        level_infra: LevelInfrastructure | None = None,\n    ) -> None:\n        \"\"\"\n        Parameters\n        ----------\n        freq : str\n            frequency of trading calendar, also trade time per trading step\n        start_time : Union[str, pd.Timestamp], optional\n            closed start of the trading calendar, by default None\n            If `start_time` is None, it must be reset before trading.\n        end_time : Union[str, pd.Timestamp], optional\n            closed end of the trade time range, by default None\n            If `end_time` is None, it must be reset before trading.\n        \"\"\"\n        self.level_infra = level_infra\n        self.reset(freq=freq, start_time=start_time, end_time=end_time)\n\n    def reset(\n        self,\n        freq: str,\n        start_time: Union[str, pd.Timestamp] = None,\n        end_time: Union[str, pd.Timestamp] = None,\n    ) -> None:\n        \"\"\"\n        Please refer to the docs of `__init__`\n\n        Reset the trade calendar\n        - self.trade_len : The total count for trading step\n        - self.trade_step : The number of trading step finished, self.trade_step can be\n            [0, 1, 2, ..., self.trade_len - 1]\n        \"\"\"\n        self.freq = 
freq\n        self.start_time = pd.Timestamp(start_time) if start_time else None\n        self.end_time = pd.Timestamp(end_time) if end_time else None\n\n        _calendar = Cal.calendar(freq=freq, future=True)\n        assert isinstance(_calendar, np.ndarray)\n        self._calendar = _calendar\n        _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, future=True)\n        self.start_index = _start_index\n        self.end_index = _end_index\n        self.trade_len = _end_index - _start_index + 1\n        self.trade_step = 0\n\n    def finished(self) -> bool:\n        \"\"\"\n        Check whether the trading is finished\n        - This should be checked before calling strategy.generate_decisions and executor.execute\n        - If self.trade_step >= self.trade_len, the trading is finished\n        - If self.trade_step < self.trade_len, self.trade_step is the number of finished trading steps\n        \"\"\"\n        return self.trade_step >= self.trade_len\n\n    def step(self) -> None:\n        if self.finished():\n            raise RuntimeError(f\"The calendar is finished, please reset it if you want to call it!\")\n        self.trade_step += 1\n\n    def get_freq(self) -> str:\n        return self.freq\n\n    def get_trade_len(self) -> int:\n        \"\"\"get the total number of trading steps\"\"\"\n        return self.trade_len\n\n    def get_trade_step(self) -> int:\n        return self.trade_step\n\n    def get_step_time(self, trade_step: int | None = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:\n        \"\"\"\n        Get the left and right endpoints of the trade_step'th trading interval\n\n        About the endpoints:\n            - Qlib uses closed intervals in time-series data selection, which behave the same as\n            pandas.Series.loc\n            # - The returned right endpoints should subtract 1 second because of the closed interval representation in\n            #   Qlib.\n            # 
Note: Qlib supports up to minutely decision execution, so 1 second is shorter than any trading time\n            #   interval.\n\n        Parameters\n        ----------\n        trade_step : int, optional\n            the number of finished trading steps, by default None to indicate the current step\n        shift : int, optional\n            the number of bars to shift, by default 0\n\n        Returns\n        -------\n        Tuple[pd.Timestamp, pd.Timestamp]\n            - If shift == 0, return the trading time range\n            - If shift > 0, return the trading time range of the earlier shift bars\n            - If shift < 0, return the trading time range of the later shift bars\n        \"\"\"\n        if trade_step is None:\n            trade_step = self.get_trade_step()\n        calendar_index = self.start_index + trade_step - shift\n        return self._calendar[calendar_index], epsilon_change(self._calendar[calendar_index + 1])\n\n    def get_data_cal_range(self, rtype: str = \"full\") -> Tuple[int, int]:\n        \"\"\"\n        get the calendar range\n        The following assumptions are made:\n        1) The frequency of the exchange in common_infra is the same as the data calendar\n        2) Users want the **data index** mod by **day** (i.e. 
240 min)\n\n        Parameters\n        ----------\n        rtype: str\n            - \"full\": return the full limitation of the decision in the day\n            - \"step\": return the limitation of the current step\n\n        Returns\n        -------\n        Tuple[int, int]:\n        \"\"\"\n        # potential performance issue\n        assert self.level_infra is not None\n\n        day_start = pd.Timestamp(self.start_time.date())\n        day_end = epsilon_change(day_start + pd.Timedelta(days=1))\n        freq = self.level_infra.get(\"common_infra\").get(\"trade_exchange\").freq\n        _, _, day_start_idx, _ = Cal.locate_index(day_start, day_end, freq=freq)\n\n        if rtype == \"full\":\n            _, _, start_idx, end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq)\n        elif rtype == \"step\":\n            _, _, start_idx, end_index = Cal.locate_index(*self.get_step_time(), freq=freq)\n        else:\n            raise ValueError(f\"This type of input {rtype} is not supported\")\n\n        return start_idx - day_start_idx, end_index - day_start_idx\n\n    def get_all_time(self) -> Tuple[pd.Timestamp, pd.Timestamp]:\n        \"\"\"Get the start_time and end_time for trading\"\"\"\n        return self.start_time, self.end_time\n\n    # helper functions\n    def get_range_idx(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[int, int]:\n        \"\"\"\n        get the range index which involves start_time~end_time (both sides are closed)\n\n        Parameters\n        ----------\n        start_time : pd.Timestamp\n        end_time : pd.Timestamp\n\n        Returns\n        -------\n        Tuple[int, int]:\n            the index of the range.  
**the left and right are closed**\n        \"\"\"\n        left = int(np.searchsorted(self._calendar, start_time, side=\"right\") - 1)\n        right = int(np.searchsorted(self._calendar, end_time, side=\"right\") - 1)\n        left -= self.start_index\n        right -= self.start_index\n\n        def clip(idx: int) -> int:\n            return min(max(0, idx), self.trade_len - 1)\n\n        return clip(left), clip(right)\n\n    def __repr__(self) -> str:\n        return (\n            f\"class: {self.__class__.__name__}; \"\n            f\"{self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: \"\n            f\"[{self.trade_step}/{self.trade_len}]\"\n        )\n\n\nclass BaseInfrastructure:\n    def __init__(self, **kwargs: Any) -> None:\n        self.reset_infra(**kwargs)\n\n    @abstractmethod\n    def get_support_infra(self) -> Set[str]:\n        raise NotImplementedError(\"`get_support_infra` is not implemented!\")\n\n    def reset_infra(self, **kwargs: Any) -> None:\n        support_infra = self.get_support_infra()\n        for k, v in kwargs.items():\n            if k in support_infra:\n                setattr(self, k, v)\n            else:\n                warnings.warn(f\"{k} is ignored in `reset_infra`!\")\n\n    def get(self, infra_name: str) -> Any:\n        if hasattr(self, infra_name):\n            return getattr(self, infra_name)\n        else:\n            warnings.warn(f\"infra {infra_name} is not found!\")\n\n    def has(self, infra_name: str) -> bool:\n        return infra_name in self.get_support_infra() and hasattr(self, infra_name)\n\n    def update(self, other: BaseInfrastructure) -> None:\n        support_infra = other.get_support_infra()\n        infra_dict = {_infra: getattr(other, _infra) for _infra in support_infra if hasattr(other, _infra)}\n        self.reset_infra(**infra_dict)\n\n\nclass CommonInfrastructure(BaseInfrastructure):\n    def get_support_infra(self) -> Set[str]:\n        return {\"trade_account\", 
\"trade_exchange\"}\n\n\nclass LevelInfrastructure(BaseInfrastructure):\n    \"\"\"level infrastructure is created by executor, and then shared to strategies on the same level\"\"\"\n\n    def get_support_infra(self) -> Set[str]:\n        \"\"\"\n        Descriptions about the infrastructure\n\n        sub_level_infra:\n        - **NOTE**: this will only work after _init_sub_trading !!!\n        \"\"\"\n        return {\"trade_calendar\", \"sub_level_infra\", \"common_infra\", \"executor\"}\n\n    def reset_cal(\n        self,\n        freq: str,\n        start_time: Union[str, pd.Timestamp, None],\n        end_time: Union[str, pd.Timestamp, None],\n    ) -> None:\n        \"\"\"reset trade calendar manager\"\"\"\n        if self.has(\"trade_calendar\"):\n            self.get(\"trade_calendar\").reset(freq, start_time=start_time, end_time=end_time)\n        else:\n            self.reset_infra(\n                trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self),\n            )\n\n    def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure) -> None:\n        \"\"\"this will make the calendar access easier when crossing multi-levels\"\"\"\n        self.reset_infra(sub_level_infra=sub_level_infra)\n\n\ndef get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Tuple[int, int]:\n    \"\"\"\n    A helper function for getting the decision-level index range limitation for inner strategy\n    - NOTE: this function is not applicable to order-level\n\n    Parameters\n    ----------\n    trade_calendar : TradeCalendarManager\n    outer_trade_decision : BaseTradeDecision\n        the trade decision made by outer strategy\n\n    Returns\n    -------\n    Union[int, int]:\n        start index and end index\n    \"\"\"\n    try:\n        return outer_trade_decision.get_range_limit(inner_calendar=trade_calendar)\n    except NotImplementedError:\n        return 0, 
trade_calendar.get_trade_len() - 1\n"
  },
  {
    "path": "qlib/cli/__init__.py",
    "content": ""
  },
  {
    "path": "qlib/cli/data.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nimport fire\nfrom qlib.tests.data import GetData\n\nif __name__ == \"__main__\":\n    fire.Fire(GetData)\n"
  },
  {
    "path": "qlib/cli/run.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\nimport logging\nimport os\nfrom pathlib import Path\nimport sys\n\nimport fire\nfrom jinja2 import Template, meta\nfrom ruamel.yaml import YAML\n\nimport qlib\nfrom qlib.config import C\nfrom qlib.log import get_module_logger\nfrom qlib.model.trainer import task_train\nfrom qlib.utils import set_log_with_config\nfrom qlib.utils.data import update_config\n\nset_log_with_config(C.logging_config)\nlogger = get_module_logger(\"qrun\", logging.INFO)\n\n\ndef get_path_list(path):\n    if isinstance(path, str):\n        return [path]\n    else:\n        return list(path)\n\n\ndef sys_config(config, config_path):\n    \"\"\"\n    Configure the `sys` section\n\n    Parameters\n    ----------\n    config : dict\n        configuration of the workflow.\n    config_path : str\n        path of the configuration\n    \"\"\"\n    sys_config = config.get(\"sys\", {})\n\n    # abspath\n    for p in get_path_list(sys_config.get(\"path\", [])):\n        sys.path.append(p)\n\n    # relative path to config path\n    for p in get_path_list(sys_config.get(\"rel_path\", [])):\n        sys.path.append(str(Path(config_path).parent.resolve().absolute() / p))\n\n\ndef render_template(config_path: str) -> str:\n    \"\"\"\n    render the template based on the environment\n\n    Parameters\n    ----------\n    config_path : str\n        configuration path\n\n    Returns\n    -------\n    str\n        the rendered content\n    \"\"\"\n    with open(config_path, \"r\") as f:\n        config = f.read()\n    # Set up the Jinja2 environment\n    template = Template(config)\n\n    # Parse the template to find undeclared variables\n    env = template.environment\n    parsed_content = env.parse(config)\n    variables = meta.find_undeclared_variables(parsed_content)\n\n    # Get context from os.environ according to the variables\n    context = {var: os.getenv(var, \"\") for var in variables if var in 
os.environ}\n    logger.info(f\"Render the template with the context: {context}\")\n\n    # Render the template with the context\n    rendered_content = template.render(context)\n    return rendered_content\n\n\n# workflow handler function\ndef workflow(config_path, experiment_name=\"workflow\", uri_folder=\"mlruns\"):\n    \"\"\"\n    This is a Qlib CLI entry point.\n    Users can run the whole quant research workflow defined by a configuration file\n    - the code is located here ``qlib/cli/run.py``\n\n    Users can specify a base config file in their workflow.yml file by adding \"BASE_CONFIG_PATH\".\n    Qlib will load the configuration in BASE_CONFIG_PATH first, and the user only needs to update the custom fields\n    in their own workflow.yml file.\n\n    For example:\n\n        qlib_init:\n            provider_uri: \"~/.qlib/qlib_data/cn_data\"\n            region: cn\n        BASE_CONFIG_PATH: \"workflow_config_lightgbm_Alpha158_csi500.yaml\"\n        market: csi300\n\n    \"\"\"\n    # Render the template\n    rendered_yaml = render_template(config_path)\n    yaml = YAML(typ=\"safe\", pure=True)\n    config = yaml.load(rendered_yaml)\n\n    base_config_path = config.get(\"BASE_CONFIG_PATH\", None)\n    if base_config_path:\n        logger.info(f\"Use BASE_CONFIG_PATH: {base_config_path}\")\n        base_config_path = Path(base_config_path)\n\n        # look for the base config file as given (absolute or relative to the working directory) first,\n        # then relative to the directory of the config file\n        if base_config_path.exists():\n            path = base_config_path\n        else:\n            logger.info(\n                f\"Can't find BASE_CONFIG_PATH based on: {Path.cwd()}, \"\n                f\"trying a path relative to the config path: {Path(config_path).absolute()}\"\n            )\n            relative_path = Path(config_path).absolute().parent.joinpath(base_config_path)\n            if relative_path.exists():\n                path = relative_path\n            else:\n                raise FileNotFoundError(f\"Can't find the BASE_CONFIG file: 
{base_config_path}\")\n\n        with open(path) as fp:\n            yaml = YAML(typ=\"safe\", pure=True)\n            base_config = yaml.load(fp)\n        logger.info(f\"Loaded BASE_CONFIG_PATH successfully: {path.resolve()}\")\n        config = update_config(base_config, config)\n\n    # configure the `sys` section\n    sys_config(config, config_path)\n\n    if \"exp_manager\" in config.get(\"qlib_init\"):\n        qlib.init(**config.get(\"qlib_init\"))\n    else:\n        exp_manager = C[\"exp_manager\"]\n        exp_manager[\"kwargs\"][\"uri\"] = \"file:\" + str(Path(os.getcwd()).resolve() / uri_folder)\n        qlib.init(**config.get(\"qlib_init\"), exp_manager=exp_manager)\n\n    if \"experiment_name\" in config:\n        experiment_name = config[\"experiment_name\"]\n    recorder = task_train(config.get(\"task\"), experiment_name=experiment_name)\n    recorder.save_objects(config=config)\n\n\n# function to run workflow by config\ndef run():\n    fire.Fire(workflow)\n\n\nif __name__ == \"__main__\":\n    run()\n"
  },
  {
    "path": "qlib/config.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nAbout the configs\n=================\n\nThe config will be based on _default_config.\nTwo modes are supported\n- client\n- server\n\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nimport re\nimport copy\nimport logging\nimport platform\nimport multiprocessing\nfrom pathlib import Path\nfrom typing import Callable, Optional, Union\nfrom typing import TYPE_CHECKING\n\nfrom qlib.constant import REG_CN, REG_US, REG_TW\n\nif TYPE_CHECKING:\n    from qlib.utils.time import Freq\n\nfrom pydantic_settings import BaseSettings, SettingsConfigDict\n\n\nclass MLflowSettings(BaseSettings):\n    uri: str = \"file:\" + str(Path(os.getcwd()).resolve() / \"mlruns\")\n    default_exp_name: str = \"Experiment\"\n\n\nclass QSettings(BaseSettings):\n    \"\"\"\n    Qlib's settings.\n    It tries to provide a default settings for most of Qlib's components.\n    But it would be a long journey to provide a comprehensive settings for all of Qlib's components.\n\n    Here is some design guidelines:\n    - The priority of settings is\n        - Actively passed-in settings, like `qlib.init(provider_uri=...)`\n        - The default settings\n            - QSettings tries to provide default settings for most of Qlib's components.\n    \"\"\"\n\n    mlflow: MLflowSettings = MLflowSettings()\n    provider_uri: str = \"~/.qlib/qlib_data/cn_data\"\n\n    model_config = SettingsConfigDict(\n        env_prefix=\"QLIB_\",\n        env_nested_delimiter=\"_\",\n    )\n\n\nQSETTINGS = QSettings()\n\n\nclass Config:\n    def __init__(self, default_conf):\n        self.__dict__[\"_default_config\"] = copy.deepcopy(default_conf)  # avoiding conflicts with __getattr__\n        self.reset()\n\n    def __getitem__(self, key):\n        return self.__dict__[\"_config\"][key]\n\n    def __getattr__(self, attr):\n        if attr in self.__dict__[\"_config\"]:\n            return 
self.__dict__[\"_config\"][attr]\n\n        raise AttributeError(f\"No such `{attr}` in self._config\")\n\n    def get(self, key, default=None):\n        return self.__dict__[\"_config\"].get(key, default)\n\n    def __setitem__(self, key, value):\n        self.__dict__[\"_config\"][key] = value\n\n    def __setattr__(self, attr, value):\n        self.__dict__[\"_config\"][attr] = value\n\n    def __contains__(self, item):\n        return item in self.__dict__[\"_config\"]\n\n    def __getstate__(self):\n        return self.__dict__\n\n    def __setstate__(self, state):\n        self.__dict__.update(state)\n\n    def __str__(self):\n        return str(self.__dict__[\"_config\"])\n\n    def __repr__(self):\n        return str(self.__dict__[\"_config\"])\n\n    def reset(self):\n        self.__dict__[\"_config\"] = copy.deepcopy(self._default_config)\n\n    def update(self, *args, **kwargs):\n        self.__dict__[\"_config\"].update(*args, **kwargs)\n\n    def set_conf_from_C(self, config_c):\n        self.update(**config_c.__dict__[\"_config\"])\n\n    @staticmethod\n    def register_from_C(config, skip_register=True):\n        from .utils import set_log_with_config  # pylint: disable=C0415\n\n        if C.registered and skip_register:\n            return\n\n        C.set_conf_from_C(config)\n        if C.logging_config:\n            set_log_with_config(C.logging_config)\n        C.register()\n\n\n# pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format\nPROTOCOL_VERSION = 4\n\nNUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)\n\nDISK_DATASET_CACHE = \"DiskDatasetCache\"\nSIMPLE_DATASET_CACHE = \"SimpleDatasetCache\"\nDISK_EXPRESSION_CACHE = \"DiskExpressionCache\"\n\nDEPENDENCY_REDIS_CACHE = (DISK_DATASET_CACHE, DISK_EXPRESSION_CACHE)\n\n_default_config = {\n    # data provider config\n    \"calendar_provider\": \"LocalCalendarProvider\",\n    \"instrument_provider\": \"LocalInstrumentProvider\",\n    
\"feature_provider\": \"LocalFeatureProvider\",\n    \"pit_provider\": \"LocalPITProvider\",\n    \"expression_provider\": \"LocalExpressionProvider\",\n    \"dataset_provider\": \"LocalDatasetProvider\",\n    \"provider\": \"LocalProvider\",\n    # config it in qlib.init()\n    # \"provider_uri\" str or dict:\n    #   # str\n    #   \"~/.qlib/stock_data/cn_data\"\n    #   # dict\n    #   {\"day\": \"~/.qlib/stock_data/cn_data\", \"1min\": \"~/.qlib/stock_data/cn_data_1min\"}\n    # NOTE: provider_uri priority:\n    #   1. backend_config: backend_obj[\"kwargs\"][\"provider_uri\"]\n    #   2. backend_config: backend_obj[\"kwargs\"][\"provider_uri_map\"]\n    #   3. qlib.init: provider_uri\n    \"provider_uri\": \"\",\n    # cache\n    \"expression_cache\": None,\n    \"calendar_cache\": None,\n    # for simple dataset cache\n    \"local_cache_path\": None,\n    # kernels can be a fixed value or a callable function lie `def (freq: str) -> int`\n    # If the kernels are arctic_kernels, `min(NUM_USABLE_CPU, 30)` may be a good value\n    \"kernels\": NUM_USABLE_CPU,\n    # pickle.dump protocol version\n    \"dump_protocol_version\": PROTOCOL_VERSION,\n    # How many tasks belong to one process. 
1 is recommended for high-frequency data and None for daily data.\n    \"maxtasksperchild\": None,\n    # If joblib_backend is None, use loky\n    \"joblib_backend\": \"multiprocessing\",\n    \"default_disk_cache\": 1,  # 0:skip/1:use\n    \"mem_cache_size_limit\": 500,\n    \"mem_cache_limit_type\": \"length\",\n    # memory cache expiration time in seconds, only used in 'DatasetURICache' and 'client D.calendar'\n    # default 1 hour\n    \"mem_cache_expire\": 60 * 60,\n    # cache dir name\n    \"dataset_cache_dir_name\": \"dataset_cache\",\n    \"features_cache_dir_name\": \"features_cache\",\n    # redis\n    # in order to use cache\n    \"redis_host\": \"127.0.0.1\",\n    \"redis_port\": 6379,\n    \"redis_task_db\": 1,\n    \"redis_password\": None,\n    # This value can be reset via qlib.init\n    \"logging_level\": logging.INFO,\n    # Global configuration of qlib log\n    # logging_level can control the logging level more finely\n    \"logging_config\": {\n        \"version\": 1,\n        \"formatters\": {\n            \"logger_format\": {\n                \"format\": \"[%(process)s:%(threadName)s](%(asctime)s) %(levelname)s - %(name)s - [%(filename)s:%(lineno)d] - %(message)s\"\n            }\n        },\n        \"filters\": {\n            \"field_not_found\": {\n                \"()\": \"qlib.log.LogFilter\",\n                \"param\": [\".*?WARN: data not found for.*?\"],\n            }\n        },\n        \"handlers\": {\n            \"console\": {\n                \"class\": \"logging.StreamHandler\",\n                \"level\": logging.DEBUG,\n                \"formatter\": \"logger_format\",\n                \"filters\": [\"field_not_found\"],\n            }\n        },\n        # Normally this should be set to `False` to avoid duplicated logging [1].\n        # However, due to a bug in pytest, log messages must propagate to the root logger to be captured by `caplog` [2].\n        # [1] https://github.com/microsoft/qlib/pull/1661\n        # [2] 
https://github.com/pytest-dev/pytest/issues/3697\n        \"loggers\": {\"qlib\": {\"level\": logging.DEBUG, \"handlers\": [\"console\"], \"propagate\": False}},\n        # To let qlib work with other packages, we shouldn't disable existing loggers.\n        # Note that this param defaults to True according to the documentation of the logging module.\n        \"disable_existing_loggers\": False,\n    },\n    # Default config for experiment manager\n    \"exp_manager\": {\n        \"class\": \"MLflowExpManager\",\n        \"module_path\": \"qlib.workflow.expm\",\n        \"kwargs\": {\n            \"uri\": QSETTINGS.mlflow.uri,\n            \"default_exp_name\": QSETTINGS.mlflow.default_exp_name,\n        },\n    },\n    \"pit_record_type\": {\n        \"date\": \"I\",  # uint32\n        \"period\": \"I\",  # uint32\n        \"value\": \"d\",  # float64\n        \"index\": \"I\",  # uint32\n    },\n    \"pit_record_nan\": {\n        \"date\": 0,\n        \"period\": 0,\n        \"value\": float(\"NAN\"),\n        \"index\": 0xFFFFFFFF,\n    },\n    # Default config for MongoDB\n    \"mongo\": {\n        \"task_url\": \"mongodb://localhost:27017/\",\n        \"task_db_name\": \"default_task_db\",\n    },\n    # Minute shift for high-frequency minute data, used in backtests\n    # if min_data_shift == 0, use default market time [9:30, 11:29, 1:00, 2:59]\n    # if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:00, 2:59] - shift*minute\n    \"min_data_shift\": 0,\n}\n\nMODE_CONF = {\n    \"server\": {\n        # config it in qlib.init()\n        \"provider_uri\": \"\",\n        # redis\n        \"redis_host\": \"127.0.0.1\",\n        \"redis_port\": 6379,\n        \"redis_task_db\": 1,\n        # cache\n        \"expression_cache\": DISK_EXPRESSION_CACHE,\n        \"dataset_cache\": DISK_DATASET_CACHE,\n        \"local_cache_path\": Path(\"~/.cache/qlib_simple_cache\").expanduser().resolve(),\n        \"mount_path\": None,\n    },\n    \"client\": {\n        # config it 
in user's own code\n        \"provider_uri\": QSETTINGS.provider_uri,\n        # cache\n        # Use the parameter 'remote' to indicate that the client is using server_cache; write access will then be disabled.\n        # Cache is disabled by default to avoid introducing advanced features to beginners\n        \"dataset_cache\": None,\n        # SimpleDatasetCache directory\n        \"local_cache_path\": Path(\"~/.cache/qlib_simple_cache\").expanduser().resolve(),\n        # client config\n        \"mount_path\": None,\n        \"auto_mount\": False,  # The NFS is already mounted on our server [auto_mount: False].\n        # The NFS should be auto-mounted by qlib on other\n        # servers (such as PAI) [auto_mount: True]\n        \"timeout\": 100,\n        \"logging_level\": logging.INFO,\n        \"region\": REG_CN,\n        # custom operator\n        # each element of custom_ops should be Type[ExpressionOps] or dict\n        # if element of custom_ops is Type[ExpressionOps], it represents the custom operator class\n        # if element of custom_ops is dict, it represents the config of custom operator and should include `class` and `module_path` keys.\n        \"custom_ops\": [],\n    },\n}\n\nHIGH_FREQ_CONFIG = {\n    \"provider_uri\": \"~/.qlib/qlib_data/cn_data_1min\",\n    \"dataset_cache\": None,\n    \"expression_cache\": \"DiskExpressionCache\",\n    \"region\": REG_CN,\n}\n\n_default_region_config = {\n    REG_CN: {\n        \"trade_unit\": 100,\n        \"limit_threshold\": 0.095,\n        \"deal_price\": \"close\",\n    },\n    REG_US: {\n        \"trade_unit\": 1,\n        \"limit_threshold\": None,\n        \"deal_price\": \"close\",\n    },\n    REG_TW: {\n        \"trade_unit\": 1000,\n        \"limit_threshold\": 0.1,\n        \"deal_price\": \"close\",\n    },\n}\n\n\nclass QlibConfig(Config):\n    # URI_TYPE\n    LOCAL_URI = \"local\"\n    NFS_URI = \"nfs\"\n    DEFAULT_FREQ = \"__DEFAULT_FREQ\"\n\n    def __init__(self, default_conf):\n        
super().__init__(default_conf)\n        self._registered = False\n\n    class DataPathManager:\n        \"\"\"\n        Motivation:\n        - get the right path (e.g. data uri) for accessing data based on the given information (e.g. provider_uri, mount_path and frequency)\n        - provide some helper functions to process URIs.\n        \"\"\"\n\n        def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]):\n            \"\"\"\n            The relation of `provider_uri` and `mount_path`\n            - `mount_path` is used only if provider_uri is an NFS path\n            - otherwise, provider_uri will be used for accessing data\n            \"\"\"\n            self.provider_uri = provider_uri\n            self.mount_path = mount_path\n\n        @staticmethod\n        def format_provider_uri(provider_uri: Union[str, dict, Path]) -> dict:\n            if provider_uri is None:\n                raise ValueError(\"provider_uri cannot be None\")\n            if isinstance(provider_uri, (str, dict, Path)):\n                if not isinstance(provider_uri, dict):\n                    provider_uri = {QlibConfig.DEFAULT_FREQ: provider_uri}\n            else:\n                raise TypeError(f\"provider_uri does not support {type(provider_uri)}\")\n            for freq, _uri in provider_uri.items():\n                if QlibConfig.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:\n                    provider_uri[freq] = str(Path(_uri).expanduser().resolve())\n            return provider_uri\n\n        @staticmethod\n        def get_uri_type(uri: Union[str, Path]):\n            uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())\n            is_win = re.match(\"^[a-zA-Z]:.*\", uri) is not None  # such as 'C:\\\\data', 'D:'\n            # such as 'host:/data/'   (users may define a short hostname themselves or use localhost)\n            is_nfs_or_win = re.match(\"^[^/]+:.+\", uri) is not None\n\n            if 
is_nfs_or_win and not is_win:\n                return QlibConfig.NFS_URI\n            else:\n                return QlibConfig.LOCAL_URI\n\n        def get_data_uri(self, freq: Optional[Union[str, Freq]] = None) -> Path:\n            \"\"\"\n            please refer to DataPathManager's __init__ and class doc\n            \"\"\"\n            if freq is not None:\n                freq = str(freq)  # converting Freq to string\n            if freq is None or freq not in self.provider_uri:\n                freq = QlibConfig.DEFAULT_FREQ\n            _provider_uri = self.provider_uri[freq]\n            if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:\n                return Path(_provider_uri)\n            elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI:\n                if \"win\" in platform.system().lower():\n                    # windows, mount_path is the drive\n                    _path = str(self.mount_path[freq])\n                    return Path(f\"{_path}:\\\\\") if \":\" not in _path else Path(_path)\n                return Path(self.mount_path[freq])\n            else:\n                raise NotImplementedError(f\"This type of uri is not supported\")\n\n    def set_mode(self, mode):\n        # raise KeyError\n        self.update(MODE_CONF[mode])\n        # TODO: update region based on kwargs\n\n    def set_region(self, region):\n        # raise KeyError\n        self.update(_default_region_config[region])\n\n    @staticmethod\n    def is_depend_redis(cache_name: str):\n        return cache_name in DEPENDENCY_REDIS_CACHE\n\n    @property\n    def dpm(self):\n        return self.DataPathManager(self[\"provider_uri\"], self[\"mount_path\"])\n\n    def resolve_path(self):\n        # resolve path\n        _mount_path = self[\"mount_path\"]\n        _provider_uri = self.DataPathManager.format_provider_uri(self[\"provider_uri\"])\n        if not isinstance(_mount_path, dict):\n            _mount_path = {_freq: _mount_path for _freq in 
_provider_uri.keys()}\n\n        # check provider_uri and mount_path\n        _miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys())\n        assert len(_miss_freq) == 0, f\"mount_path is missing freq: {_miss_freq}\"\n\n        # resolve\n        for _freq in _provider_uri.keys():\n            # mount_path\n            _mount_path[_freq] = (\n                _mount_path[_freq] if _mount_path[_freq] is None else str(Path(_mount_path[_freq]).expanduser())\n            )\n        self[\"provider_uri\"] = _provider_uri\n        self[\"mount_path\"] = _mount_path\n\n    def set(self, default_conf: str = \"client\", **kwargs):\n        \"\"\"\n        configure qlib based on the input parameters\n\n        The configuration will act like a dictionary.\n\n        Normally, it simply replaces the values according to the keys.\n        However, sometimes it is hard for users to set the config when the configuration is nested and complicated.\n\n        So this API provides some special parameters for users to set the keys in a more convenient way.\n        - region:  REG_CN, REG_US, REG_TW\n            - several region-related configs will be changed\n\n        Parameters\n        ----------\n        default_conf : str\n            the default config template chosen by the user: \"server\", \"client\"\n        \"\"\"\n        from .utils import set_log_with_config, get_module_logger, can_use_cache  # pylint: disable=C0415\n\n        self.reset()\n\n        _logging_config = kwargs.get(\"logging_config\", self.logging_config)\n\n        # set global config\n        if _logging_config:\n            set_log_with_config(_logging_config)\n\n        logger = get_module_logger(\"Initialization\", kwargs.get(\"logging_level\", self.logging_level))\n        logger.info(f\"default_conf: {default_conf}.\")\n\n        self.set_mode(default_conf)\n        self.set_region(kwargs.get(\"region\", self[\"region\"] if \"region\" in self else REG_CN))\n\n        for k, v in kwargs.items():\n 
           if k not in self:\n                logger.warning(\"Unrecognized config %s\" % k)\n            self[k] = v\n\n        self.resolve_path()\n\n        if not (self[\"expression_cache\"] is None and self[\"dataset_cache\"] is None):\n            # check redis\n            if not can_use_cache():\n                log_str = \"\"\n                # check expression cache\n                if self.is_depend_redis(self[\"expression_cache\"]):\n                    log_str += self[\"expression_cache\"]\n                    self[\"expression_cache\"] = None\n                # check dataset cache\n                if self.is_depend_redis(self[\"dataset_cache\"]):\n                    log_str += f\" and {self['dataset_cache']}\" if log_str else self[\"dataset_cache\"]\n                    self[\"dataset_cache\"] = None\n                if log_str:\n                    logger.warning(\n                        f\"redis connection failed(host={self['redis_host']} port={self['redis_port']}), \"\n                        f\"{log_str} will not be used!\"\n                    )\n\n    def register(self):\n        from .utils import init_instance_by_config  # pylint: disable=C0415\n        from .data.ops import register_all_ops  # pylint: disable=C0415\n        from .data.data import register_all_wrappers  # pylint: disable=C0415\n        from .workflow import R, QlibRecorder  # pylint: disable=C0415\n        from .workflow.utils import experiment_exit_handler  # pylint: disable=C0415\n\n        register_all_ops(self)\n        register_all_wrappers(self)\n        # set up QlibRecorder\n        exp_manager = init_instance_by_config(self[\"exp_manager\"])\n        qr = QlibRecorder(exp_manager)\n        R.register(qr)\n        # clean up experiment when python program ends\n        experiment_exit_handler()\n\n        # Supporting user reset qlib version (useful when user want to connect to qlib server with old version)\n        self.reset_qlib_version()\n\n        
self._registered = True\n\n    def reset_qlib_version(self):\n        import qlib  # pylint: disable=C0415\n\n        reset_version = self.get(\"qlib_reset_version\", None)\n        if reset_version is not None:\n            qlib.__version__ = reset_version\n        else:\n            # NOTE: `getattr` is used on purpose: writing `qlib.__version__bak` inside this class body\n            # would be name-mangled by Python to `qlib._QlibConfig__version__bak`\n            qlib.__version__ = getattr(qlib, \"__version__bak\")\n\n    def get_kernels(self, freq: str):\n        \"\"\"get number of processors given frequency\"\"\"\n        if isinstance(self[\"kernels\"], Callable):\n            return self[\"kernels\"](freq)\n        return self[\"kernels\"]\n\n    @property\n    def registered(self):\n        return self._registered\n\n\n# global config\nC = QlibConfig(_default_config)\n"
  },
  {
    "path": "qlib/constant.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# REGION CONST\nfrom typing import TypeVar\n\nimport numpy as np\nimport pandas as pd\n\nREG_CN = \"cn\"\nREG_US = \"us\"\nREG_TW = \"tw\"\n\n# Epsilon for avoiding division by zero.\nEPS = 1e-12\n\n# Infinity in integer\nINF = int(1e18)\nONE_DAY = pd.Timedelta(\"1day\")\nONE_MIN = pd.Timedelta(\"1min\")\nEPS_T = pd.Timedelta(\"1s\")  # use 1 second to exclude the right interval point\nfloat_or_ndarray = TypeVar(\"float_or_ndarray\", float, np.ndarray)\n"
  },
  {
    "path": "qlib/contrib/__init__.py",
    "content": ""
  },
  {
    "path": "qlib/contrib/data/__init__.py",
    "content": ""
  },
  {
    "path": "qlib/contrib/data/data.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# We remove arctic from core framework of Qlib to contrib due to\n# - Arctic has very strict limitation on pandas and numpy version\n#    - https://github.com/man-group/arctic/pull/908\n# - pip fail to computing the right version number!!!!\n#    - Maybe we can solve this problem by poetry\n\n# FIXME: So if you want to use arctic-based provider, please install arctic manually\n# `pip install arctic` may not be enough.\nfrom arctic import Arctic\nimport pandas as pd\nimport pymongo\n\nfrom qlib.data.data import FeatureProvider\n\n\nclass ArcticFeatureProvider(FeatureProvider):\n    def __init__(\n        self, uri=\"127.0.0.1\", retry_time=0, market_transaction_time_list=[(\"09:15\", \"11:30\"), (\"13:00\", \"15:00\")]\n    ):\n        super().__init__()\n        self.uri = uri\n        # TODO:\n        # retry connecting if error occurs\n        # does it real matters?\n        self.retry_time = retry_time\n        # NOTE: this is especially important for TResample operator\n        self.market_transaction_time_list = market_transaction_time_list\n\n    def feature(self, instrument, field, start_index, end_index, freq):\n        field = str(field)[1:]\n        with pymongo.MongoClient(self.uri) as client:\n            # TODO: this will result in frequently connecting the server and performance issue\n            arctic = Arctic(client)\n\n            if freq not in arctic.list_libraries():\n                raise ValueError(\"lib {} not in arctic\".format(freq))\n\n            if instrument not in arctic[freq].list_symbols():\n                # instruments does not exist\n                return pd.Series()\n            else:\n                df = arctic[freq].read(instrument, columns=[field], chunk_range=(start_index, end_index))\n                s = df[field]\n\n                if not s.empty:\n                    s = pd.concat(\n                        [\n                  
          s.between_time(time_tuple[0], time_tuple[1])\n                            for time_tuple in self.market_transaction_time_list\n                        ]\n                    )\n                return s\n"
  },
  {
    "path": "qlib/contrib/data/dataset.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport copy\nimport torch\nimport warnings\nimport numpy as np\nimport pandas as pd\nfrom qlib.utils.data import guess_horizon\nfrom qlib.utils import init_instance_by_config\n\nfrom qlib.data.dataset import DatasetH\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n\ndef _to_tensor(x):\n    if not isinstance(x, torch.Tensor):\n        return torch.tensor(x, dtype=torch.float, device=device)  # pylint: disable=E1101\n    return x\n\n\ndef _create_ts_slices(index, seq_len):\n    \"\"\"\n    create time series slices from pandas index\n\n    Args:\n        index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order\n        seq_len (int): sequence length\n    \"\"\"\n    assert isinstance(index, pd.MultiIndex), \"unsupported index type\"\n    assert seq_len > 0, \"sequence length should be larger than 0\"\n    assert index.is_monotonic_increasing, \"index should be sorted\"\n\n    # number of dates for each instrument\n    sample_count_by_insts = index.to_series().groupby(level=0, group_keys=False).size().values\n\n    # start index for each instrument\n    start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1)\n    start_index_of_insts[0] = 0\n\n    # all the [start, stop) indices of features\n    # features between [start, stop) will be used to predict label at `stop - 1`\n    slices = []\n    for cur_loc, cur_cnt in zip(start_index_of_insts, sample_count_by_insts):\n        for stop in range(1, cur_cnt + 1):\n            end = cur_loc + stop\n            start = max(end - seq_len, 0)\n            slices.append(slice(start, end))\n    slices = np.array(slices, dtype=\"object\")\n\n    assert len(slices) == len(index)  # the i-th slice = index[i]\n\n    return slices\n\n\ndef _get_date_parse_fn(target):\n    \"\"\"get date parse function\n\n    This method is used to parse date arguments as target type.\n\n    Example:\n        
get_date_parse_fn('20120101')('2017-01-01') => '20170101'\n        get_date_parse_fn(20120101)('2017-01-01') => 20170101\n    \"\"\"\n    if isinstance(target, int):\n\n        def _fn(x):\n            return int(str(x).replace(\"-\", \"\")[:8])  # 20200201\n\n    elif isinstance(target, str) and len(target) == 8:\n\n        def _fn(x):\n            return str(x).replace(\"-\", \"\")[:8]  # '20200201'\n\n    else:\n\n        def _fn(x):\n            return x  # '2021-01-01'\n\n    return _fn\n\n\ndef _maybe_padding(x, seq_len, zeros=None):\n    \"\"\"padding 2d <time * feature> data with zeros\n\n    Args:\n        x (np.ndarray): 2d data with shape <time * feature>\n        seq_len (int): target sequence length\n        zeros (np.ndarray): zeros with shape <seq_len * feature>\n    \"\"\"\n    assert seq_len > 0, \"sequence length should be larger than 0\"\n    if zeros is None:\n        zeros = np.zeros((seq_len, x.shape[1]), dtype=np.float32)\n    else:\n        assert len(zeros) >= seq_len, \"zeros matrix is not large enough for padding\"\n    if len(x) != seq_len:  # padding zeros\n        x = np.concatenate([zeros[: seq_len - len(x), : x.shape[1]], x], axis=0)\n    return x\n\n\nclass MTSDatasetH(DatasetH):\n    \"\"\"Memory Augmented Time Series Dataset\n\n    Args:\n        handler (DataHandler): data handler\n        segments (dict): data split segments\n        seq_len (int): time series sequence length\n        horizon (int): label horizon\n        num_states (int): how many memory states to be added\n        memory_mode (str): memory mode (daily or sample)\n        batch_size (int): batch size (<0 will use daily sampling)\n        n_samples (int): number of samples in the same day\n        shuffle (bool): whether shuffle data\n        drop_last (bool): whether drop last batch < batch_size\n        input_size (int): reshape flatten rows as this input_size (backward compatibility)\n    \"\"\"\n\n    def __init__(\n        self,\n        handler,\n        
segments,\n        seq_len=60,\n        horizon=0,\n        num_states=0,\n        memory_mode=\"sample\",\n        batch_size=-1,\n        n_samples=None,\n        shuffle=True,\n        drop_last=False,\n        input_size=None,\n        **kwargs,\n    ):\n        if horizon == 0:\n            # Try to guess horizon\n            if isinstance(handler, (dict, str)):\n                handler = init_instance_by_config(handler)\n            assert \"label\" in getattr(handler.data_loader, \"fields\", None)\n            label = handler.data_loader.fields[\"label\"][0][0]\n            horizon = guess_horizon([label])\n\n        assert num_states == 0 or horizon > 0, \"please specify `horizon` to avoid data leakage\"\n        assert memory_mode in [\"sample\", \"daily\"], \"unsupported memory mode\"\n        assert memory_mode == \"sample\" or batch_size < 0, \"daily memory requires daily sampling (`batch_size < 0`)\"\n        assert batch_size != 0, \"invalid batch size\"\n\n        if batch_size > 0 and n_samples is not None:\n            warnings.warn(\"`n_samples` can only be used for daily sampling (`batch_size < 0`)\")\n\n        self.seq_len = seq_len\n        self.horizon = horizon\n        self.num_states = num_states\n        self.memory_mode = memory_mode\n        self.batch_size = batch_size\n        self.n_samples = n_samples\n        self.shuffle = shuffle\n        self.drop_last = drop_last\n        self.input_size = input_size\n        self.params = (batch_size, n_samples, drop_last, shuffle)  # for train/eval switch\n\n        super().__init__(handler, segments, **kwargs)\n\n    def setup_data(self, handler_kwargs: dict = None, **kwargs):\n        super().setup_data(**kwargs)\n\n        if handler_kwargs is not None:\n            self.handler.setup_data(**handler_kwargs)\n\n        # pre-fetch data and change index to <code, date>\n        # NOTE: we will use inplace sort to reduce memory use\n        try:\n            df = self.handler._learn.copy()  # 
use copy otherwise recorder will fail\n            # FIXME: currently we cannot support switching from `_learn` to `_infer` for inference\n        except Exception:\n            warnings.warn(\"cannot access `_learn`, will load raw data\")\n            df = self.handler._data.copy()\n        df.index = df.index.swaplevel()\n        df.sort_index(inplace=True)\n\n        # convert to numpy\n        self._data = df[\"feature\"].values.astype(\"float32\")\n        np.nan_to_num(self._data, copy=False)  # NOTE: fillna in case users forget using the fillna processor\n        self._label = df[\"label\"].squeeze().values.astype(\"float32\")\n        self._index = df.index\n\n        if self.input_size is not None and self.input_size != self._data.shape[1]:\n            warnings.warn(\"the data has different shape from input_size and the data will be reshaped\")\n            assert self._data.shape[1] % self.input_size == 0, \"data mismatch, please check `input_size`\"\n\n        # create batch slices\n        self._batch_slices = _create_ts_slices(self._index, self.seq_len)\n\n        # create daily slices\n        daily_slices = {date: [] for date in sorted(self._index.unique(level=1))}  # sorted by date\n        for i, (code, date) in enumerate(self._index):\n            daily_slices[date].append(self._batch_slices[i])\n        self._daily_slices = np.array(list(daily_slices.values()), dtype=\"object\")\n        self._daily_index = pd.Series(list(daily_slices.keys()))  # index is the original date index\n\n        # add memory (sample wise and daily)\n        if self.memory_mode == \"sample\":\n            self._memory = np.zeros((len(self._data), self.num_states), dtype=np.float32)\n        elif self.memory_mode == \"daily\":\n            self._memory = np.zeros((len(self._daily_index), self.num_states), dtype=np.float32)\n        else:\n            raise ValueError(f\"invalid memory_mode `{self.memory_mode}`\")\n\n        # padding tensor\n        self._zeros = 
np.zeros((self.seq_len, max(self.num_states, self._data.shape[1])), dtype=np.float32)\n\n    def _prepare_seg(self, slc, **kwargs):\n        fn = _get_date_parse_fn(self._index[0][1])\n        if isinstance(slc, slice):\n            start, stop = slc.start, slc.stop\n        elif isinstance(slc, (list, tuple)):\n            start, stop = slc\n        else:\n            raise NotImplementedError(f\"This type of input is not supported: {type(slc)}\")\n        start_date = pd.Timestamp(fn(start))\n        end_date = pd.Timestamp(fn(stop))\n        obj = copy.copy(self)  # shallow copy\n        # NOTE: `Serializable` excludes private attributes such as `self._data` when copying, so we assign them manually here\n        obj._data = self._data  # reference (no copy)\n        obj._label = self._label\n        obj._index = self._index\n        obj._memory = self._memory\n        obj._zeros = self._zeros\n        # keep only the slices whose dates fall within [start_date, end_date]\n        date_index = self._index.get_level_values(1)\n        obj._batch_slices = self._batch_slices[(date_index >= start_date) & (date_index <= end_date)]\n        mask = (self._daily_index.values >= start_date) & (self._daily_index.values <= end_date)\n        obj._daily_slices = self._daily_slices[mask]\n        obj._daily_index = self._daily_index[mask]\n        return obj\n\n    def restore_index(self, index):\n        return self._index[index]\n\n    def restore_daily_index(self, daily_index):\n        return pd.Index(self._daily_index.loc[daily_index])\n\n    def assign_data(self, index, vals):\n        if self.num_states == 0:\n            raise ValueError(\"cannot assign data as `num_states==0`\")\n        if isinstance(vals, torch.Tensor):\n            vals = vals.detach().cpu().numpy()\n        self._memory[index] = vals\n\n    def clear_memory(self):\n        if self.num_states == 0:\n            raise ValueError(\"cannot clear memory as `num_states==0`\")\n        self._memory[:] = 0\n\n    def train(self):\n        \"\"\"enable training mode\"\"\"\n        self.batch_size, self.n_samples, self.drop_last, self.shuffle = self.params\n\n    def eval(self):
\n        \"\"\"enable evaluation mode\"\"\"\n        self.batch_size = -1\n        self.n_samples = None\n        self.drop_last = False\n        self.shuffle = False\n\n    def _get_slices(self):\n        if self.batch_size < 0:  # daily sampling\n            slices = self._daily_slices.copy()\n            batch_size = -1 * self.batch_size\n        else:  # normal sampling\n            slices = self._batch_slices.copy()\n            batch_size = self.batch_size\n        return slices, batch_size\n\n    def __len__(self):\n        slices, batch_size = self._get_slices()\n        if self.drop_last:\n            return len(slices) // batch_size\n        return (len(slices) + batch_size - 1) // batch_size\n\n    def __iter__(self):\n        slices, batch_size = self._get_slices()\n        indices = np.arange(len(slices))\n        if self.shuffle:\n            np.random.shuffle(indices)\n\n        for i in range(len(indices))[::batch_size]:\n            if self.drop_last and i + batch_size > len(indices):\n                break\n\n            data = []  # store features\n            label = []  # store labels\n            index = []  # store index\n            state = []  # store memory states\n            daily_index = []  # store daily index\n            daily_count = []  # store number of samples for each day\n\n            for j in indices[i : i + batch_size]:\n                # normal sampling: self.batch_size > 0 => slices is a list => slices_subset is a slice\n                # daily sampling: self.batch_size < 0 => slices is a nested list => slices_subset is a list\n                slices_subset = slices[j]\n\n                # daily sampling\n                # each slices_subset contains a list of slices for multiple stocks\n                # NOTE: daily sampling is used in 1) eval mode, 2) train mode with self.batch_size < 0\n                if self.batch_size < 
0:\n                    # store daily index\n                    idx = self._daily_index.index[j]  # daily_index.index is the index of the original data\n                    daily_index.append(idx)\n\n                    # store daily memory if specified\n                    # NOTE: daily memory always requires daily sampling (self.batch_size < 0)\n                    if self.memory_mode == \"daily\":\n                        slc = slice(max(idx - self.seq_len - self.horizon, 0), max(idx - self.horizon, 0))\n                        state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros))\n\n                    # down-sample stocks and store count\n                    if self.n_samples and 0 < self.n_samples < len(slices_subset):  # intraday subsample\n                        slices_subset = np.random.choice(slices_subset, self.n_samples, replace=False)\n                    daily_count.append(len(slices_subset))\n\n                # normal sampling\n                # each slices_subset is a single slice\n                # NOTE: normal sampling is used in train mode with self.batch_size > 0\n                else:\n                    slices_subset = [slices_subset]\n\n                for slc in slices_subset:\n                    # legacy support for Alpha360 data by `input_size`\n                    if self.input_size:\n                        data.append(self._data[slc.stop - 1].reshape(self.input_size, -1).T)\n                    else:\n                        data.append(_maybe_padding(self._data[slc], self.seq_len, self._zeros))\n\n                    if self.memory_mode == \"sample\":\n                        state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros)[: -self.horizon])\n\n                    label.append(self._label[slc.stop - 1])\n                    index.append(slc.stop - 1)\n\n                    # end slices loop\n\n                # end indices batch loop\n\n            # stack the collected samples into tensors\n            data = _to_tensor(np.stack(data))
\n            state = _to_tensor(np.stack(state))\n            label = _to_tensor(np.stack(label))\n            index = np.array(index)\n            daily_index = np.array(daily_index)\n            daily_count = np.array(daily_count)\n\n            # yield -> generator\n            yield {\n                \"data\": data,\n                \"label\": label,\n                \"state\": state,\n                \"index\": index,\n                \"daily_index\": daily_index,\n                \"daily_count\": daily_count,\n            }\n\n        # end indices loop\n"
  },
  {
    "path": "qlib/contrib/data/handler.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom qlib.contrib.data.loader import Alpha158DL, Alpha360DL\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...data.dataset.processor import Processor\nfrom ...utils import get_callable_kwargs\nfrom ...data.dataset import processor as processor_module\nfrom inspect import getfullargspec\n\n\ndef check_transform_proc(proc_l, fit_start_time, fit_end_time):\n    new_l = []\n    for p in proc_l:\n        if not isinstance(p, Processor):\n            klass, pkwargs = get_callable_kwargs(p, processor_module)\n            args = getfullargspec(klass).args\n            if \"fit_start_time\" in args and \"fit_end_time\" in args:\n                assert (\n                    fit_start_time is not None and fit_end_time is not None\n                ), \"Make sure `fit_start_time` and `fit_end_time` are not None.\"\n                pkwargs.update(\n                    {\n                        \"fit_start_time\": fit_start_time,\n                        \"fit_end_time\": fit_end_time,\n                    }\n                )\n            proc_config = {\"class\": klass.__name__, \"kwargs\": pkwargs}\n            if isinstance(p, dict) and \"module_path\" in p:\n                proc_config[\"module_path\"] = p[\"module_path\"]\n            new_l.append(proc_config)\n        else:\n            new_l.append(p)\n    return new_l\n\n\n_DEFAULT_LEARN_PROCESSORS = [\n    {\"class\": \"DropnaLabel\"},\n    {\"class\": \"CSZScoreNorm\", \"kwargs\": {\"fields_group\": \"label\"}},\n]\n_DEFAULT_INFER_PROCESSORS = [\n    {\"class\": \"ProcessInf\", \"kwargs\": {}},\n    {\"class\": \"ZScoreNorm\", \"kwargs\": {}},\n    {\"class\": \"Fillna\", \"kwargs\": {}},\n]\n\n\nclass Alpha360(DataHandlerLP):\n    def __init__(\n        self,\n        instruments=\"csi500\",\n        start_time=None,\n        end_time=None,\n        freq=\"day\",\n        
infer_processors=_DEFAULT_INFER_PROCESSORS,\n        learn_processors=_DEFAULT_LEARN_PROCESSORS,\n        fit_start_time=None,\n        fit_end_time=None,\n        filter_pipe=None,\n        inst_processors=None,\n        **kwargs,\n    ):\n        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)\n        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)\n\n        data_loader = {\n            \"class\": \"QlibDataLoader\",\n            \"kwargs\": {\n                \"config\": {\n                    \"feature\": Alpha360DL.get_feature_config(),\n                    \"label\": kwargs.pop(\"label\", self.get_label_config()),\n                },\n                \"filter_pipe\": filter_pipe,\n                \"freq\": freq,\n                \"inst_processors\": inst_processors,\n            },\n        }\n\n        super().__init__(\n            instruments=instruments,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n            learn_processors=learn_processors,\n            infer_processors=infer_processors,\n            **kwargs,\n        )\n\n    def get_label_config(self):\n        return [\"Ref($close, -2)/Ref($close, -1) - 1\"], [\"LABEL0\"]\n\n\nclass Alpha360vwap(Alpha360):\n    def get_label_config(self):\n        return [\"Ref($vwap, -2)/Ref($vwap, -1) - 1\"], [\"LABEL0\"]\n\n\nclass Alpha158(DataHandlerLP):\n    def __init__(\n        self,\n        instruments=\"csi500\",\n        start_time=None,\n        end_time=None,\n        freq=\"day\",\n        infer_processors=[],\n        learn_processors=_DEFAULT_LEARN_PROCESSORS,\n        fit_start_time=None,\n        fit_end_time=None,\n        process_type=DataHandlerLP.PTYPE_A,\n        filter_pipe=None,\n        inst_processors=None,\n        **kwargs,\n    ):\n        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)\n        
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)\n\n        data_loader = {\n            \"class\": \"QlibDataLoader\",\n            \"kwargs\": {\n                \"config\": {\n                    \"feature\": self.get_feature_config(),\n                    \"label\": kwargs.pop(\"label\", self.get_label_config()),\n                },\n                \"filter_pipe\": filter_pipe,\n                \"freq\": freq,\n                \"inst_processors\": inst_processors,\n            },\n        }\n        super().__init__(\n            instruments=instruments,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n            infer_processors=infer_processors,\n            learn_processors=learn_processors,\n            process_type=process_type,\n            **kwargs,\n        )\n\n    def get_feature_config(self):\n        conf = {\n            \"kbar\": {},\n            \"price\": {\n                \"windows\": [0],\n                \"feature\": [\"OPEN\", \"HIGH\", \"LOW\", \"VWAP\"],\n            },\n            \"rolling\": {},\n        }\n        return Alpha158DL.get_feature_config(conf)\n\n    def get_label_config(self):\n        return [\"Ref($close, -2)/Ref($close, -1) - 1\"], [\"LABEL0\"]\n\n\nclass Alpha158vwap(Alpha158):\n    def get_label_config(self):\n        return [\"Ref($vwap, -2)/Ref($vwap, -1) - 1\"], [\"LABEL0\"]\n"
  },
  {
    "path": "qlib/contrib/data/highfreq_handler.py",
    "content": "from qlib.data.dataset.handler import DataHandler, DataHandlerLP\n\nfrom .handler import check_transform_proc\n\nEPSILON = 1e-4\n\n\nclass HighFreqHandler(DataHandlerLP):\n    def __init__(\n        self,\n        instruments=\"csi300\",\n        start_time=None,\n        end_time=None,\n        infer_processors=[],\n        learn_processors=[],\n        fit_start_time=None,\n        fit_end_time=None,\n        drop_raw=True,\n    ):\n        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)\n        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)\n\n        data_loader = {\n            \"class\": \"QlibDataLoader\",\n            \"kwargs\": {\n                \"config\": self.get_feature_config(),\n                \"swap_level\": False,\n                \"freq\": \"1min\",\n            },\n        }\n        super().__init__(\n            instruments=instruments,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n            infer_processors=infer_processors,\n            learn_processors=learn_processors,\n            drop_raw=drop_raw,\n        )\n\n    def get_feature_config(self):\n        fields = []\n        names = []\n\n        template_if = \"If(IsNull({1}), {0}, {1})\"\n        template_paused = \"Select(Gt($paused_num, 1.001), {0})\"\n\n        def get_normalized_price_feature(price_field, shift=0):\n            # norm with the close price of 237th minute of yesterday.\n            if shift == 0:\n                template_norm = \"{0}/DayLast(Ref({1}, 243))\"\n            else:\n                template_norm = \"Ref({0}, \" + str(shift) + \")/DayLast(Ref({1}, 243))\"\n\n            template_fillnan = \"FFillNan({0})\"\n            # calculate -> ffill -> remove paused\n            feature_ops = template_paused.format(\n                template_fillnan.format(\n                    
template_norm.format(template_if.format(\"$close\", price_field), template_fillnan.format(\"$close\"))\n                )\n            )\n            return feature_ops\n\n        fields += [get_normalized_price_feature(\"$open\", 0)]\n        fields += [get_normalized_price_feature(\"$high\", 0)]\n        fields += [get_normalized_price_feature(\"$low\", 0)]\n        fields += [get_normalized_price_feature(\"$close\", 0)]\n        fields += [get_normalized_price_feature(\"$vwap\", 0)]\n        names += [\"$open\", \"$high\", \"$low\", \"$close\", \"$vwap\"]\n\n        fields += [get_normalized_price_feature(\"$open\", 240)]\n        fields += [get_normalized_price_feature(\"$high\", 240)]\n        fields += [get_normalized_price_feature(\"$low\", 240)]\n        fields += [get_normalized_price_feature(\"$close\", 240)]\n        fields += [get_normalized_price_feature(\"$vwap\", 240)]\n        names += [\"$open_1\", \"$high_1\", \"$low_1\", \"$close_1\", \"$vwap_1\"]\n\n        # calculate and fill nan with 0\n        template_gzero = \"If(Ge({0}, 0), {0}, 0)\"\n        fields += [\n            template_gzero.format(\n                template_paused.format(\n                    \"If(IsNull({0}), 0, {0})\".format(\"{0}/Ref(DayLast(Mean({0}, 7200)), 240)\".format(\"$volume\"))\n                )\n            )\n        ]\n        names += [\"$volume\"]\n\n        fields += [\n            template_gzero.format(\n                template_paused.format(\n                    \"If(IsNull({0}), 0, {0})\".format(\n                        \"Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)\".format(\"$volume\")\n                    )\n                )\n            )\n        ]\n        names += [\"$volume_1\"]\n\n        return fields, names\n\n\nclass HighFreqGeneralHandler(DataHandlerLP):\n    def __init__(\n        self,\n        instruments=\"csi300\",\n        start_time=None,\n        end_time=None,\n        infer_processors=[],\n        learn_processors=[],\n        
fit_start_time=None,\n        fit_end_time=None,\n        drop_raw=True,\n        day_length=240,\n        freq=\"1min\",\n        columns=[\"$open\", \"$high\", \"$low\", \"$close\", \"$vwap\"],\n        inst_processors=None,\n    ):\n        self.day_length = day_length\n        self.columns = columns\n\n        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)\n        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)\n\n        data_loader = {\n            \"class\": \"QlibDataLoader\",\n            \"kwargs\": {\n                \"config\": self.get_feature_config(),\n                \"swap_level\": False,\n                \"freq\": freq,\n                \"inst_processors\": inst_processors,\n            },\n        }\n        super().__init__(\n            instruments=instruments,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n            infer_processors=infer_processors,\n            learn_processors=learn_processors,\n            drop_raw=drop_raw,\n        )\n\n    def get_feature_config(self):\n        fields = []\n        names = []\n\n        template_if = \"If(IsNull({1}), {0}, {1})\"\n        template_paused = f\"Cut({{0}}, {self.day_length * 2}, None)\"\n\n        def get_normalized_price_feature(price_field, shift=0):\n            # norm with the close price of 237th minute of yesterday.\n            if shift == 0:\n                template_norm = f\"{{0}}/DayLast(Ref({{1}}, {self.day_length * 2}))\"\n            else:\n                template_norm = f\"Ref({{0}}, \" + str(shift) + f\")/DayLast(Ref({{1}}, {self.day_length}))\"\n\n            template_fillnan = \"FFillNan({0})\"\n            # calculate -> ffill -> remove paused\n            feature_ops = template_paused.format(\n                template_fillnan.format(\n                    template_norm.format(template_if.format(\"$close\", price_field), 
template_fillnan.format(\"$close\"))\n                )\n            )\n            return feature_ops\n\n        for column_name in self.columns:\n            fields.append(get_normalized_price_feature(column_name, 0))\n            names.append(column_name)\n\n        for column_name in self.columns:\n            fields.append(get_normalized_price_feature(column_name, self.day_length))\n            names.append(column_name + \"_1\")\n\n        # calculate and fill nan with 0\n        fields += [\n            template_paused.format(\n                \"If(IsNull({0}), 0, {0})\".format(\n                    f\"{{0}}/Ref(DayLast(Mean({{0}}, {self.day_length * 30})), {self.day_length})\".format(\"$volume\")\n                )\n            )\n        ]\n        names += [\"$volume\"]\n\n        fields += [\n            template_paused.format(\n                \"If(IsNull({0}), 0, {0})\".format(\n                    f\"Ref({{0}}, {self.day_length})/Ref(DayLast(Mean({{0}}, {self.day_length * 30})), {self.day_length})\".format(\n                        \"$volume\"\n                    )\n                )\n            )\n        ]\n        names += [\"$volume_1\"]\n\n        return fields, names\n\n\nclass HighFreqBacktestHandler(DataHandler):\n    def __init__(\n        self,\n        instruments=\"csi300\",\n        start_time=None,\n        end_time=None,\n    ):\n        data_loader = {\n            \"class\": \"QlibDataLoader\",\n            \"kwargs\": {\n                \"config\": self.get_feature_config(),\n                \"swap_level\": False,\n                \"freq\": \"1min\",\n            },\n        }\n        super().__init__(\n            instruments=instruments,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n        )\n\n    def get_feature_config(self):\n        fields = []\n        names = []\n\n        template_if = \"If(IsNull({1}), {0}, {1})\"\n        template_paused = 
\"Select(Gt($paused_num, 1.001), {0})\"\n        template_fillnan = \"FFillNan({0})\"\n        fields += [\n            template_fillnan.format(template_paused.format(\"$close\")),\n        ]\n        names += [\"$close0\"]\n\n        fields += [\n            template_paused.format(\n                template_if.format(\n                    template_fillnan.format(\"$close\"),\n                    \"$vwap\",\n                )\n            )\n        ]\n        names += [\"$vwap0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$volume\"))]\n        names += [\"$volume0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$factor\"))]\n        names += [\"$factor0\"]\n\n        return fields, names\n\n\nclass HighFreqGeneralBacktestHandler(DataHandler):\n    def __init__(\n        self,\n        instruments=\"csi300\",\n        start_time=None,\n        end_time=None,\n        day_length=240,\n        freq=\"1min\",\n        columns=[\"$close\", \"$vwap\", \"$volume\"],\n        inst_processors=None,\n    ):\n        self.day_length = day_length\n        self.columns = set(columns)\n        data_loader = {\n            \"class\": \"QlibDataLoader\",\n            \"kwargs\": {\n                \"config\": self.get_feature_config(),\n                \"swap_level\": False,\n                \"freq\": freq,\n                \"inst_processors\": inst_processors,\n            },\n        }\n        super().__init__(\n            instruments=instruments,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n        )\n\n    def get_feature_config(self):\n        fields = []\n        names = []\n\n        if \"$close\" in self.columns:\n            template_paused = f\"Cut({{0}}, {self.day_length * 2}, None)\"\n            template_fillnan = \"FFillNan({0})\"\n            template_if = \"If(IsNull({1}), {0}, {1})\"\n            fields += [\n                
template_paused.format(template_fillnan.format(\"$close\")),\n            ]\n            names += [\"$close0\"]\n\n        if \"$vwap\" in self.columns:\n            fields += [\n                template_paused.format(template_if.format(template_fillnan.format(\"$close\"), \"$vwap\")),\n            ]\n            names += [\"$vwap0\"]\n\n        if \"$volume\" in self.columns:\n            fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$volume\"))]\n            names += [\"$volume0\"]\n\n        return fields, names\n\n\nclass HighFreqOrderHandler(DataHandlerLP):\n    def __init__(\n        self,\n        instruments=\"csi300\",\n        start_time=None,\n        end_time=None,\n        infer_processors=[],\n        learn_processors=[],\n        fit_start_time=None,\n        fit_end_time=None,\n        inst_processors=None,\n        drop_raw=True,\n    ):\n        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)\n        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)\n\n        data_loader = {\n            \"class\": \"QlibDataLoader\",\n            \"kwargs\": {\n                \"config\": self.get_feature_config(),\n                \"swap_level\": False,\n                \"freq\": \"1min\",\n                \"inst_processors\": inst_processors,\n            },\n        }\n        super().__init__(\n            instruments=instruments,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n            infer_processors=infer_processors,\n            learn_processors=learn_processors,\n            drop_raw=drop_raw,\n        )\n\n    def get_feature_config(self):\n        fields = []\n        names = []\n\n        template_if = \"If(IsNull({1}), {0}, {1})\"\n        template_ifinf = \"If(IsInf({1}), {0}, {1})\"\n        template_paused = \"Select(Gt($paused_num, 1.001), {0})\"\n\n        def 
get_normalized_price_feature(price_field, shift=0):\n            # norm with the close price of 237th minute of yesterday.\n            if shift == 0:\n                template_norm = \"{0}/DayLast(Ref({1}, 243))\"\n            else:\n                template_norm = \"Ref({0}, \" + str(shift) + \")/DayLast(Ref({1}, 243))\"\n\n            template_fillnan = \"FFillNan({0})\"\n            # calculate -> ffill -> remove paused\n            feature_ops = template_paused.format(\n                template_fillnan.format(\n                    template_norm.format(template_if.format(\"$close\", price_field), template_fillnan.format(\"$close\"))\n                )\n            )\n            return feature_ops\n\n        def get_normalized_vwap_price_feature(price_field, shift=0):\n            # norm with the close price of 237th minute of yesterday.\n            if shift == 0:\n                template_norm = \"{0}/DayLast(Ref({1}, 243))\"\n            else:\n                template_norm = \"Ref({0}, \" + str(shift) + \")/DayLast(Ref({1}, 243))\"\n\n            template_fillnan = \"FFillNan({0})\"\n            # calculate -> ffill -> remove paused\n            feature_ops = template_paused.format(\n                template_fillnan.format(\n                    template_norm.format(\n                        template_if.format(\"$close\", template_ifinf.format(\"$close\", price_field)),\n                        template_fillnan.format(\"$close\"),\n                    )\n                )\n            )\n            return feature_ops\n\n        fields += [get_normalized_price_feature(\"$open\", 0)]\n        fields += [get_normalized_price_feature(\"$high\", 0)]\n        fields += [get_normalized_price_feature(\"$low\", 0)]\n        fields += [get_normalized_price_feature(\"$close\", 0)]\n        fields += [get_normalized_vwap_price_feature(\"$vwap\", 0)]\n        names += [\"$open\", \"$high\", \"$low\", \"$close\", \"$vwap\"]\n\n        fields += 
[get_normalized_price_feature(\"$open\", 240)]\n        fields += [get_normalized_price_feature(\"$high\", 240)]\n        fields += [get_normalized_price_feature(\"$low\", 240)]\n        fields += [get_normalized_price_feature(\"$close\", 240)]\n        fields += [get_normalized_vwap_price_feature(\"$vwap\", 240)]\n        names += [\"$open_1\", \"$high_1\", \"$low_1\", \"$close_1\", \"$vwap_1\"]\n\n        fields += [get_normalized_price_feature(\"$bid\", 0)]\n        fields += [get_normalized_price_feature(\"$ask\", 0)]\n        names += [\"$bid\", \"$ask\"]\n\n        fields += [get_normalized_price_feature(\"$bid\", 240)]\n        fields += [get_normalized_price_feature(\"$ask\", 240)]\n        names += [\"$bid_1\", \"$ask_1\"]\n\n        # calculate and fill nan with 0\n\n        def get_volume_feature(volume_field, shift=0):\n            template_gzero = \"If(Ge({0}, 0), {0}, 0)\"\n            if shift == 0:\n                feature_ops = template_gzero.format(\n                    template_paused.format(\n                        \"If(IsInf({0}), 0, {0})\".format(\n                            \"If(IsNull({0}), 0, {0})\".format(\n                                \"{0}/Ref(DayLast(Mean({0}, 7200)), 240)\".format(volume_field)\n                            )\n                        )\n                    )\n                )\n            else:\n                feature_ops = template_gzero.format(\n                    template_paused.format(\n                        \"If(IsInf({0}), 0, {0})\".format(\n                            \"If(IsNull({0}), 0, {0})\".format(\n                                f\"Ref({{0}}, {shift})/Ref(DayLast(Mean({{0}}, 7200)), 240)\".format(volume_field)\n                            )\n                        )\n                    )\n                )\n            return feature_ops\n\n        fields += [get_volume_feature(\"$volume\", 0)]\n        names += [\"$volume\"]\n\n        fields += [get_volume_feature(\"$volume\", 240)]\n        
names += [\"$volume_1\"]\n\n        fields += [get_volume_feature(\"$bidV\", 0)]\n        fields += [get_volume_feature(\"$bidV1\", 0)]\n        fields += [get_volume_feature(\"$bidV3\", 0)]\n        fields += [get_volume_feature(\"$bidV5\", 0)]\n        fields += [get_volume_feature(\"$askV\", 0)]\n        fields += [get_volume_feature(\"$askV1\", 0)]\n        fields += [get_volume_feature(\"$askV3\", 0)]\n        fields += [get_volume_feature(\"$askV5\", 0)]\n        names += [\"$bidV\", \"$bidV1\", \"$bidV3\", \"$bidV5\", \"$askV\", \"$askV1\", \"$askV3\", \"$askV5\"]\n\n        fields += [get_volume_feature(\"$bidV\", 240)]\n        fields += [get_volume_feature(\"$bidV1\", 240)]\n        fields += [get_volume_feature(\"$bidV3\", 240)]\n        fields += [get_volume_feature(\"$bidV5\", 240)]\n        fields += [get_volume_feature(\"$askV\", 240)]\n        fields += [get_volume_feature(\"$askV1\", 240)]\n        fields += [get_volume_feature(\"$askV3\", 240)]\n        fields += [get_volume_feature(\"$askV5\", 240)]\n        names += [\"$bidV_1\", \"$bidV1_1\", \"$bidV3_1\", \"$bidV5_1\", \"$askV_1\", \"$askV1_1\", \"$askV3_1\", \"$askV5_1\"]\n\n        return fields, names\n\n\nclass HighFreqBacktestOrderHandler(DataHandler):\n    def __init__(\n        self,\n        instruments=\"csi300\",\n        start_time=None,\n        end_time=None,\n    ):\n        data_loader = {\n            \"class\": \"QlibDataLoader\",\n            \"kwargs\": {\n                \"config\": self.get_feature_config(),\n                \"swap_level\": False,\n                \"freq\": \"1min\",\n            },\n        }\n        super().__init__(\n            instruments=instruments,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n        )\n\n    def get_feature_config(self):\n        fields = []\n        names = []\n\n        template_if = \"If(IsNull({1}), {0}, {1})\"\n        template_paused = \"Select(Gt($paused_num, 
1.001), {0})\"\n        template_fillnan = \"FFillNan({0})\"\n        fields += [\n            template_fillnan.format(template_paused.format(\"$close\")),\n        ]\n        names += [\"$close0\"]\n\n        fields += [\n            template_paused.format(\n                template_if.format(\n                    template_fillnan.format(\"$close\"),\n                    \"$vwap\",\n                )\n            )\n        ]\n        names += [\"$vwap0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$volume\"))]\n        names += [\"$volume0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$bid\"))]\n        names += [\"$bid0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$bidV\"))]\n        names += [\"$bidV0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$ask\"))]\n        names += [\"$ask0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$askV\"))]\n        names += [\"$askV0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"($bid + $ask) / 2\"))]\n        names += [\"$median0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$factor\"))]\n        names += [\"$factor0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$downlimitmarket\"))]\n        names += [\"$downlimitmarket0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$uplimitmarket\"))]\n        names += [\"$uplimitmarket0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$highmarket\"))]\n        names += [\"$highmarket0\"]\n\n        fields += [template_paused.format(\"If(IsNull({0}), 0, {0})\".format(\"$lowmarket\"))]\n        names += [\"$lowmarket0\"]\n\n        return fields, names\n"
  },
  {
    "path": "qlib/contrib/data/highfreq_processor.py",
    "content": "import os\n\nimport numpy as np\nimport pandas as pd\nfrom qlib.data.dataset.processor import Processor\nfrom qlib.data.dataset.utils import fetch_df_by_index\nfrom typing import Dict\n\n\nclass HighFreqTrans(Processor):\n    def __init__(self, dtype: str = \"bool\"):\n        self.dtype = dtype\n\n    def fit(self, df_features):\n        pass\n\n    def __call__(self, df_features):\n        if self.dtype == \"bool\":\n            return df_features.astype(np.int8)\n        else:\n            return df_features.astype(np.float32)\n\n\nclass HighFreqNorm(Processor):\n    def __init__(\n        self,\n        fit_start_time: pd.Timestamp,\n        fit_end_time: pd.Timestamp,\n        feature_save_dir: str,\n        norm_groups: Dict[str, int],\n    ):\n        self.fit_start_time = fit_start_time\n        self.fit_end_time = fit_end_time\n        self.feature_save_dir = feature_save_dir\n        self.norm_groups = norm_groups\n\n    def fit(self, df_features) -> None:\n        if os.path.exists(self.feature_save_dir) and len(os.listdir(self.feature_save_dir)) != 0:\n            return\n        os.makedirs(self.feature_save_dir, exist_ok=True)\n        fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level=\"datetime\")\n        del df_features\n        index = 0\n        names = {}\n        for name, dim in self.norm_groups.items():\n            names[name] = slice(index, index + dim)\n            index += dim\n        for name, name_val in names.items():\n            df_values = fetch_df.iloc(axis=1)[name_val].values\n            if name.endswith(\"volume\"):\n                df_values = np.log1p(df_values)\n            self.feature_mean = np.nanmean(df_values)\n            np.save(self.feature_save_dir + name + \"_mean.npy\", self.feature_mean)\n            df_values = df_values - self.feature_mean\n            self.feature_std = np.nanstd(np.absolute(df_values))\n            np.save(self.feature_save_dir + name + 
\"_std.npy\", self.feature_std)\n            df_values = df_values / self.feature_std\n            np.save(self.feature_save_dir + name + \"_vmax.npy\", np.nanmax(df_values))\n            np.save(self.feature_save_dir + name + \"_vmin.npy\", np.nanmin(df_values))\n        return\n\n    def __call__(self, df_features):\n        if \"date\" in df_features.index.names:\n            df_features = df_features.droplevel(\"date\")\n        df_values = df_features.values\n        index = 0\n        names = {}\n        for name, dim in self.norm_groups.items():\n            names[name] = slice(index, index + dim)\n            index += dim\n        for name, name_val in names.items():\n            feature_mean = np.load(self.feature_save_dir + name + \"_mean.npy\")\n            feature_std = np.load(self.feature_save_dir + name + \"_std.npy\")\n\n            if name.endswith(\"volume\"):\n                df_values[:, name_val] = np.log1p(df_values[:, name_val])\n            df_values[:, name_val] -= feature_mean\n            df_values[:, name_val] /= feature_std\n        df_features = pd.DataFrame(data=df_values, index=df_features.index, columns=df_features.columns)\n        return df_features.fillna(0)\n"
  },
  {
    "path": "qlib/contrib/data/highfreq_provider.py",
    "content": "import os\nimport time\nimport datetime\nfrom typing import Optional\n\nimport qlib\nfrom qlib import get_module_logger\nfrom qlib.data import D\nfrom qlib.config import REG_CN\nfrom qlib.utils import init_instance_by_config\nfrom qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.data.data import Cal\nfrom qlib.contrib.ops.high_freq import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut\nimport pickle as pkl\nfrom joblib import Parallel, delayed\n\n\nclass HighFreqProvider:\n    def __init__(\n        self,\n        start_time: str,\n        end_time: str,\n        train_end_time: str,\n        valid_start_time: str,\n        valid_end_time: str,\n        test_start_time: str,\n        qlib_conf: dict,\n        feature_conf: dict,\n        label_conf: Optional[dict] = None,\n        backtest_conf: Optional[dict] = None,\n        freq: str = \"1min\",\n        **kwargs,\n    ) -> None:\n        self.start_time = start_time\n        self.end_time = end_time\n        self.test_start_time = test_start_time\n        self.train_end_time = train_end_time\n        self.valid_start_time = valid_start_time\n        self.valid_end_time = valid_end_time\n        self._init_qlib(qlib_conf)\n        self.feature_conf = feature_conf\n        self.label_conf = label_conf\n        self.backtest_conf = backtest_conf\n        self.qlib_conf = qlib_conf\n        self.logger = get_module_logger(\"HighFreqProvider\")\n        self.freq = freq\n\n    def get_pre_datasets(self):\n        \"\"\"Generate (or reuse cached) training, validation and test features and labels for prediction\n\n        Returns:\n            Tuple[dict, dict]: The paths of the pickled train/valid/test features and of the pickled train/valid/test labels\n        \"\"\"\n\n        dict_feature_path = self.feature_conf[\"path\"]\n        train_feature_path = dict_feature_path[:-4] + \"_train.pkl\"\n        valid_feature_path = dict_feature_path[:-4] + \"_valid.pkl\"\n        test_feature_path = dict_feature_path[:-4] + 
\"_test.pkl\"\n\n        dict_label_path = self.label_conf[\"path\"]\n        train_label_path = dict_label_path[:-4] + \"_train.pkl\"\n        valid_label_path = dict_label_path[:-4] + \"_valid.pkl\"\n        test_label_path = dict_label_path[:-4] + \"_test.pkl\"\n\n        if (\n            not os.path.isfile(train_feature_path)\n            or not os.path.isfile(valid_feature_path)\n            or not os.path.isfile(test_feature_path)\n        ):\n            xtrain, xvalid, xtest = self._gen_data(self.feature_conf)\n            xtrain.to_pickle(train_feature_path)\n            xvalid.to_pickle(valid_feature_path)\n            xtest.to_pickle(test_feature_path)\n            del xtrain, xvalid, xtest\n\n        if (\n            not os.path.isfile(train_label_path)\n            or not os.path.isfile(valid_label_path)\n            or not os.path.isfile(test_label_path)\n        ):\n            ytrain, yvalid, ytest = self._gen_data(self.label_conf)\n            ytrain.to_pickle(train_label_path)\n            yvalid.to_pickle(valid_label_path)\n            ytest.to_pickle(test_label_path)\n            del ytrain, yvalid, ytest\n\n        feature = {\n            \"train\": train_feature_path,\n            \"valid\": valid_feature_path,\n            \"test\": test_feature_path,\n        }\n\n        label = {\n            \"train\": train_label_path,\n            \"valid\": valid_label_path,\n            \"test\": test_label_path,\n        }\n\n        return feature, label\n\n    def get_backtest(self, **kwargs) -> None:\n        self._gen_data(self.backtest_conf)\n\n    def _init_qlib(self, qlib_conf):\n        \"\"\"initialize qlib\"\"\"\n\n        qlib.init(\n            region=REG_CN,\n            auto_mount=False,\n            custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut],\n            expression_cache=None,\n            **qlib_conf,\n        )\n\n    def _prepare_calender_cache(self):\n        \"\"\"preload the calendar for 
cache\"\"\"\n\n        # This code uses the copy-on-write feature of Linux\n        # to avoid recomputing the calendar in every subprocess.\n        # It may speed things up, but may not help on Windows or macOS\n        Cal.calendar(freq=self.freq)\n        get_calendar_day(freq=self.freq)\n\n    def _gen_dataframe(self, config, datasets=[\"train\", \"valid\", \"test\"]):\n        try:\n            path = config.pop(\"path\")\n        except KeyError as e:\n            raise ValueError(\"Must specify the path to save the dataset.\") from e\n        if os.path.isfile(path):\n            start = time.time()\n            self.logger.info(f\"[{__name__}]Dataset exists, load from disk.\")\n\n            # res = dataset.prepare(['train', 'valid', 'test'])\n            with open(path, \"rb\") as f:\n                data = pkl.load(f)\n            if isinstance(data, dict):\n                res = [data[i] for i in datasets]\n            else:\n                res = data.prepare(datasets)\n            self.logger.info(f\"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}\")\n        else:\n            if not os.path.exists(os.path.dirname(path)):\n                os.makedirs(os.path.dirname(path))\n            self.logger.info(f\"[{__name__}]Generating dataset\")\n            start_time = time.time()\n            self._prepare_calender_cache()\n            dataset = init_instance_by_config(config)\n            trainset, validset, testset = dataset.prepare([\"train\", \"valid\", \"test\"])\n            data = {\n                \"train\": trainset,\n                \"valid\": validset,\n                \"test\": testset,\n            }\n            with open(path, \"wb\") as f:\n                pkl.dump(data, f)\n            with open(path[:-4] + \"train.pkl\", \"wb\") as f:\n                pkl.dump(trainset, f)\n            with open(path[:-4] + \"valid.pkl\", \"wb\") as f:\n                pkl.dump(validset, f)\n            with 
open(path[:-4] + \"test.pkl\", \"wb\") as f:\n                pkl.dump(testset, f)\n            res = [data[i] for i in datasets]\n            self.logger.info(f\"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}\")\n        return res\n\n    def _gen_data(self, config, datasets=[\"train\", \"valid\", \"test\"]):\n        try:\n            path = config.pop(\"path\")\n        except KeyError as e:\n            raise ValueError(\"Must specify the path to save the dataset.\") from e\n        if os.path.isfile(path):\n            start = time.time()\n            self.logger.info(f\"[{__name__}]Dataset exists, load from disk.\")\n\n            # res = dataset.prepare(['train', 'valid', 'test'])\n            with open(path, \"rb\") as f:\n                data = pkl.load(f)\n            if isinstance(data, dict):\n                res = [data[i] for i in datasets]\n            else:\n                res = data.prepare(datasets)\n            self.logger.info(f\"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}\")\n        else:\n            if not os.path.exists(os.path.dirname(path)):\n                os.makedirs(os.path.dirname(path))\n            self.logger.info(f\"[{__name__}]Generating dataset\")\n            start_time = time.time()\n            self._prepare_calender_cache()\n            dataset = init_instance_by_config(config)\n            dataset.config(dump_all=True, recursive=True)\n            dataset.to_pickle(path)\n            res = dataset.prepare(datasets)\n            self.logger.info(f\"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}\")\n        return res\n\n    def _gen_dataset(self, config):\n        try:\n            path = config.pop(\"path\")\n        except KeyError as e:\n            raise ValueError(\"Must specify the path to save the dataset.\") from e\n        if os.path.isfile(path):\n            start = time.time()\n            self.logger.info(f\"[{__name__}]Dataset exists, load 
from disk.\")\n\n            with open(path, \"rb\") as f:\n                dataset = pkl.load(f)\n            self.logger.info(f\"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}\")\n        else:\n            start = time.time()\n            if not os.path.exists(os.path.dirname(path)):\n                os.makedirs(os.path.dirname(path))\n            self.logger.info(f\"[{__name__}]Generating dataset\")\n            self._prepare_calender_cache()\n            dataset = init_instance_by_config(config)\n            self.logger.info(f\"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}\")\n            dataset.prepare([\"train\", \"valid\", \"test\"])\n            self.logger.info(f\"[{__name__}]Dataset prepared, time cost: {time.time() - start:.2f}\")\n            dataset.config(dump_all=True, recursive=True)\n            dataset.to_pickle(path)\n        return dataset\n\n    def _gen_day_dataset(self, config, conf_type):\n        try:\n            path = config.pop(\"path\")\n        except KeyError as e:\n            raise ValueError(\"Must specify the path to save the dataset.\") from e\n\n        if os.path.isfile(path + \"tmp_dataset.pkl\"):\n            start = time.time()\n            self.logger.info(f\"[{__name__}]Dataset exists, load from disk.\")\n        else:\n            start = time.time()\n            if not os.path.exists(os.path.dirname(path)):\n                os.makedirs(os.path.dirname(path))\n            self.logger.info(f\"[{__name__}]Generating dataset\")\n            self._prepare_calender_cache()\n            dataset = init_instance_by_config(config)\n            self.logger.info(f\"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}\")\n            dataset.config(dump_all=False, recursive=True)\n            dataset.to_pickle(path + \"tmp_dataset.pkl\")\n\n        with open(path + \"tmp_dataset.pkl\", \"rb\") as f:\n            new_dataset = pkl.load(f)\n\n        time_list = 
D.calendar(start_time=self.start_time, end_time=self.end_time, freq=self.freq)[::240]\n\n        def generate_dataset(times):\n            if os.path.isfile(path + times.strftime(\"%Y-%m-%d\") + \".pkl\"):\n                print(\"exist \" + times.strftime(\"%Y-%m-%d\"))\n                return\n            self._init_qlib(self.qlib_conf)\n            end_times = times + datetime.timedelta(days=1)\n            new_dataset.handler.config(**{\"start_time\": times, \"end_time\": end_times})\n            if conf_type == \"backtest\":\n                new_dataset.handler.setup_data()\n            else:\n                new_dataset.handler.setup_data(init_type=DataHandlerLP.IT_LS)\n            new_dataset.config(dump_all=True, recursive=True)\n            new_dataset.to_pickle(path + times.strftime(\"%Y-%m-%d\") + \".pkl\")\n\n        Parallel(n_jobs=8)(delayed(generate_dataset)(times) for times in time_list)\n\n    def _gen_stock_dataset(self, config, conf_type):\n        try:\n            path = config.pop(\"path\")\n        except KeyError as e:\n            raise ValueError(\"Must specify the path to save the dataset.\") from e\n\n        if os.path.isfile(path + \"tmp_dataset.pkl\"):\n            start = time.time()\n            self.logger.info(f\"[{__name__}]Dataset exists, load from disk.\")\n        else:\n            start = time.time()\n            if not os.path.exists(os.path.dirname(path)):\n                os.makedirs(os.path.dirname(path))\n            self.logger.info(f\"[{__name__}]Generating dataset\")\n            self._prepare_calender_cache()\n            dataset = init_instance_by_config(config)\n            self.logger.info(f\"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}\")\n            dataset.config(dump_all=False, recursive=True)\n            dataset.to_pickle(path + \"tmp_dataset.pkl\")\n\n        with open(path + \"tmp_dataset.pkl\", \"rb\") as f:\n            new_dataset = pkl.load(f)\n\n        instruments = 
D.instruments(market=\"all\")\n        stock_list = D.list_instruments(\n            instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq=self.freq, as_list=True\n        )\n\n        def generate_dataset(stock):\n            if os.path.isfile(path + stock + \".pkl\"):\n                print(\"exist \" + stock)\n                return\n            self._init_qlib(self.qlib_conf)\n            new_dataset.handler.config(**{\"instruments\": [stock]})\n            if conf_type == \"backtest\":\n                new_dataset.handler.setup_data()\n            else:\n                new_dataset.handler.setup_data(init_type=DataHandlerLP.IT_LS)\n            new_dataset.config(dump_all=True, recursive=True)\n            new_dataset.to_pickle(path + stock + \".pkl\")\n\n        Parallel(n_jobs=32)(delayed(generate_dataset)(stock) for stock in stock_list)\n"
  },
  {
    "path": "qlib/contrib/data/loader.py",
    "content": "from qlib.data.dataset.loader import QlibDataLoader\n\n\nclass Alpha360DL(QlibDataLoader):\n    \"\"\"Dataloader to get Alpha360\"\"\"\n\n    def __init__(self, config=None, **kwargs):\n        _config = {\n            \"feature\": self.get_feature_config(),\n        }\n        if config is not None:\n            _config.update(config)\n        super().__init__(config=_config, **kwargs)\n\n    @staticmethod\n    def get_feature_config():\n        # NOTE:\n        # Alpha360 tries to provide a dataset with original price data\n        # the original price data includes the prices and volume in the last 60 days.\n        # To make it easier to learn models from this dataset, all the prices and volume\n        # are normalized by the latest price and volume data (dividing by $close and $volume)\n        # So the latest normalized $close will be 1 (with name CLOSE0), the latest normalized $volume will be 1 (with name VOLUME0)\n        # If further normalization is applied (e.g. 
centralization),  CLOSE0 and VOLUME0 will be 0.\n        fields = []\n        names = []\n\n        for i in range(59, 0, -1):\n            fields += [\"Ref($close, %d)/$close\" % i]\n            names += [\"CLOSE%d\" % i]\n        fields += [\"$close/$close\"]\n        names += [\"CLOSE0\"]\n        for i in range(59, 0, -1):\n            fields += [\"Ref($open, %d)/$close\" % i]\n            names += [\"OPEN%d\" % i]\n        fields += [\"$open/$close\"]\n        names += [\"OPEN0\"]\n        for i in range(59, 0, -1):\n            fields += [\"Ref($high, %d)/$close\" % i]\n            names += [\"HIGH%d\" % i]\n        fields += [\"$high/$close\"]\n        names += [\"HIGH0\"]\n        for i in range(59, 0, -1):\n            fields += [\"Ref($low, %d)/$close\" % i]\n            names += [\"LOW%d\" % i]\n        fields += [\"$low/$close\"]\n        names += [\"LOW0\"]\n        for i in range(59, 0, -1):\n            fields += [\"Ref($vwap, %d)/$close\" % i]\n            names += [\"VWAP%d\" % i]\n        fields += [\"$vwap/$close\"]\n        names += [\"VWAP0\"]\n        for i in range(59, 0, -1):\n            fields += [\"Ref($volume, %d)/($volume+1e-12)\" % i]\n            names += [\"VOLUME%d\" % i]\n        fields += [\"$volume/($volume+1e-12)\"]\n        names += [\"VOLUME0\"]\n\n        return fields, names\n\n\nclass Alpha158DL(QlibDataLoader):\n    \"\"\"Dataloader to get Alpha158\"\"\"\n\n    def __init__(self, config=None, **kwargs):\n        _config = {\n            \"feature\": self.get_feature_config(),\n        }\n        if config is not None:\n            _config.update(config)\n        super().__init__(config=_config, **kwargs)\n\n    @staticmethod\n    def get_feature_config(\n        config={\n            \"kbar\": {},\n            \"price\": {\n                \"windows\": [0],\n                \"feature\": [\"OPEN\", \"HIGH\", \"LOW\", \"VWAP\"],\n            },\n            \"rolling\": {},\n        }\n    ):\n        \"\"\"create factors 
from config\n\n        config = {\n            'kbar': {}, # whether to use some hard-coded kbar features\n            'price': { # whether to use raw price features\n                'windows': [0, 1, 2, 3, 4], # use the price from n days ago\n                'feature': ['OPEN', 'HIGH', 'LOW'] # which price fields to use\n            },\n            'volume': { # whether to use raw volume features\n                'windows': [0, 1, 2, 3, 4], # use the volume from n days ago\n            },\n            'rolling': { # whether to use rolling operator based features\n                'windows': [5, 10, 20, 30, 60], # rolling window sizes\n                'include': ['ROC', 'MA', 'STD'], # rolling operators to use\n                # if include is None, all default operators are used\n                'exclude': ['RANK'], # rolling operators not to use\n            }\n        }\n        \"\"\"\n        fields = []\n        names = []\n        if \"kbar\" in config:\n            fields += [\n                \"($close-$open)/$open\",\n                \"($high-$low)/$open\",\n                \"($close-$open)/($high-$low+1e-12)\",\n                \"($high-Greater($open, $close))/$open\",\n                \"($high-Greater($open, $close))/($high-$low+1e-12)\",\n                \"(Less($open, $close)-$low)/$open\",\n                \"(Less($open, $close)-$low)/($high-$low+1e-12)\",\n                \"(2*$close-$high-$low)/$open\",\n                \"(2*$close-$high-$low)/($high-$low+1e-12)\",\n            ]\n            names += [\n                \"KMID\",\n                \"KLEN\",\n                \"KMID2\",\n                \"KUP\",\n                \"KUP2\",\n                \"KLOW\",\n                \"KLOW2\",\n                \"KSFT\",\n                \"KSFT2\",\n            ]\n        if \"price\" in config:\n            windows = config[\"price\"].get(\"windows\", range(5))\n            feature = config[\"price\"].get(\"feature\", [\"OPEN\", \"HIGH\", \"LOW\", \"CLOSE\", \"VWAP\"])\n   
         for field in feature:\n                field = field.lower()\n                fields += [\"Ref($%s, %d)/$close\" % (field, d) if d != 0 else \"$%s/$close\" % field for d in windows]\n                names += [field.upper() + str(d) for d in windows]\n        if \"volume\" in config:\n            windows = config[\"volume\"].get(\"windows\", range(5))\n            fields += [\"Ref($volume, %d)/($volume+1e-12)\" % d if d != 0 else \"$volume/($volume+1e-12)\" for d in windows]\n            names += [\"VOLUME\" + str(d) for d in windows]\n        if \"rolling\" in config:\n            windows = config[\"rolling\"].get(\"windows\", [5, 10, 20, 30, 60])\n            include = config[\"rolling\"].get(\"include\", None)\n            exclude = config[\"rolling\"].get(\"exclude\", [])\n            # `exclude`: rolling operators to skip\n            # `include`: rolling operators to use (all operators if None)\n\n            def use(x):\n                return x not in exclude and (include is None or x in include)\n\n            # Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf\n            if use(\"ROC\"):\n                # https://www.investopedia.com/terms/r/rateofchange.asp\n                # Rate of change: the close price d days ago, divided by the latest close price to remove the unit\n                fields += [\"Ref($close, %d)/$close\" % d for d in windows]\n                names += [\"ROC%d\" % d for d in windows]\n            if use(\"MA\"):\n                # https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp\n                # Simple Moving Average: the mean close price over the past d days, divided by the latest close price to remove the unit\n                fields += [\"Mean($close, %d)/$close\" % d for d in windows]\n                names += [\"MA%d\" % d for d in windows]\n            if use(\"STD\"):\n                # The standard deviation of close price for 
the past d days, divided by latest close price to remove unit\n                fields += [\"Std($close, %d)/$close\" % d for d in windows]\n                names += [\"STD%d\" % d for d in windows]\n            if use(\"BETA\"):\n                # The slope of the linear regression of close price over the past d days, divided by latest close price to remove unit\n                # For example, if the price increases by 10 dollars per day over the past d days, Slope will be 10 (before dividing by $close).\n                fields += [\"Slope($close, %d)/$close\" % d for d in windows]\n                names += [\"BETA%d\" % d for d in windows]\n            if use(\"RSQR\"):\n                # The R-squared value of the linear regression over the past d days, representing how linear the trend is\n                fields += [\"Rsquare($close, %d)\" % d for d in windows]\n                names += [\"RSQR%d\" % d for d in windows]\n            if use(\"RESI\"):\n                # The residual of the linear regression over the past d days, representing the trend linearity for the past d days.\n                fields += [\"Resi($close, %d)/$close\" % d for d in windows]\n                names += [\"RESI%d\" % d for d in windows]\n            if use(\"MAX\"):\n                # The max high price over the past d days, divided by latest close price to remove unit\n                fields += [\"Max($high, %d)/$close\" % d for d in windows]\n                names += [\"MAX%d\" % d for d in windows]\n            if use(\"LOW\"):\n                # The min low price over the past d days, divided by latest close price to remove unit\n                # NOTE: this operator is selected by the key \"LOW\", but its features are named MIN<d>\n                fields += [\"Min($low, %d)/$close\" % d for d in windows]\n                names += [\"MIN%d\" % d for d in windows]\n            if use(\"QTLU\"):\n                # The 80% quantile of the past d days' close price, divided by latest close price to remove unit\n                # Used with MIN and MAX\n                fields += [\"Quantile($close, %d, 0.8)/$close\" % d for d in windows]\n                names += [\"QTLU%d\" % d for d in 
windows]\n            if use(\"QTLD\"):\n                # The 20% quantile of the past d days' close price, divided by latest close price to remove unit\n                fields += [\"Quantile($close, %d, 0.2)/$close\" % d for d in windows]\n                names += [\"QTLD%d\" % d for d in windows]\n            if use(\"RANK\"):\n                # Get the percentile of the current close price among the past d days' close prices.\n                # Represents the current price level compared to the past d days, adding information beyond the moving average.\n                fields += [\"Rank($close, %d)\" % d for d in windows]\n                names += [\"RANK%d\" % d for d in windows]\n            if use(\"RSV\"):\n                # Represents the position of the close price between the lowest low (support) and the highest high (resistance) of the past d days.\n                fields += [\"($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)\" % (d, d, d) for d in windows]\n                names += [\"RSV%d\" % d for d in windows]\n            if use(\"IMAX\"):\n                # The number of days between current date and previous highest price date.\n                # Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp\n                # The indicator measures the time between highs and the time between lows over a time period.\n                # The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.\n                fields += [\"IdxMax($high, %d)/%d\" % (d, d) for d in windows]\n                names += [\"IMAX%d\" % d for d in windows]\n            if use(\"IMIN\"):\n                # The number of days between current date and previous lowest price date.\n                # Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp\n                # The indicator measures the time between highs and the time between lows over a time period.\n                # The idea is that strong uptrends will regularly see new highs, and 
strong downtrends will regularly see new lows.\n                fields += [\"IdxMin($low, %d)/%d\" % (d, d) for d in windows]\n                names += [\"IMIN%d\" % d for d in windows]\n            if use(\"IMXD\"):\n                # The time period by which the previous lowest-price date occurs after the highest-price date.\n                # A large value suggests downward momentum.\n                fields += [\"(IdxMax($high, %d)-IdxMin($low, %d))/%d\" % (d, d, d) for d in windows]\n                names += [\"IMXD%d\" % d for d in windows]\n            if use(\"CORR\"):\n                # The correlation between the close price level and log-scaled trading volume\n                fields += [\"Corr($close, Log($volume+1), %d)\" % d for d in windows]\n                names += [\"CORR%d\" % d for d in windows]\n            if use(\"CORD\"):\n                # The correlation between price change ratio and volume change ratio\n                fields += [\"Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)\" % d for d in windows]\n                names += [\"CORD%d\" % d for d in windows]\n            if use(\"CNTP\"):\n                # The percentage of days in the past d days on which the price went up.\n                fields += [\"Mean($close>Ref($close, 1), %d)\" % d for d in windows]\n                names += [\"CNTP%d\" % d for d in windows]\n            if use(\"CNTN\"):\n                # The percentage of days in the past d days on which the price went down.\n                fields += [\"Mean($close<Ref($close, 1), %d)\" % d for d in windows]\n                names += [\"CNTN%d\" % d for d in windows]\n            if use(\"CNTD\"):\n                # The difference between the percentage of up days and the percentage of down days in the past d days\n                fields += [\"Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)\" % (d, d) for d in windows]\n                names += [\"CNTD%d\" % d for d in windows]\n            if use(\"SUMP\"):\n                # The total gain / the total absolute price change\n        
        # Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp\n                fields += [\n                    \"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)\" % (d, d)\n                    for d in windows\n                ]\n                names += [\"SUMP%d\" % d for d in windows]\n            if use(\"SUMN\"):\n                # The total loss / the total absolute price change\n                # Can be derived from SUMP by SUMN = 1 - SUMP\n                # Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp\n                fields += [\n                    \"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)\" % (d, d)\n                    for d in windows\n                ]\n                names += [\"SUMN%d\" % d for d in windows]\n            if use(\"SUMD\"):\n                # The difference between total gain and total loss, divided by the total absolute price change\n                # Similar to RSI indicator. 
https://www.investopedia.com/terms/r/rsi.asp\n                fields += [\n                    \"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))\"\n                    \"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)\" % (d, d, d)\n                    for d in windows\n                ]\n                names += [\"SUMD%d\" % d for d in windows]\n            if use(\"VMA\"):\n                # Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average\n                fields += [\"Mean($volume, %d)/($volume+1e-12)\" % d for d in windows]\n                names += [\"VMA%d\" % d for d in windows]\n            if use(\"VSTD\"):\n                # The standard deviation of volume over the past d days, divided by the latest volume.\n                fields += [\"Std($volume, %d)/($volume+1e-12)\" % d for d in windows]\n                names += [\"VSTD%d\" % d for d in windows]\n            if use(\"WVMA\"):\n                # The volume weighted price change volatility\n                fields += [\n                    \"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)\"\n                    % (d, d)\n                    for d in windows\n                ]\n                names += [\"WVMA%d\" % d for d in windows]\n            if use(\"VSUMP\"):\n                # The total volume increase / the total absolute volume change\n                fields += [\n                    \"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)\"\n                    % (d, d)\n                    for d in windows\n                ]\n                names += [\"VSUMP%d\" % d for d in windows]\n            if use(\"VSUMN\"):\n                # The total volume decrease / the total absolute volume change\n                # Can be derived from VSUMP by VSUMN = 1 - VSUMP\n                fields += [\n                    \"Sum(Greater(Ref($volume, 
1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)\"\n                    % (d, d)\n                    for d in windows\n                ]\n                names += [\"VSUMN%d\" % d for d in windows]\n            if use(\"VSUMD\"):\n                # The diff ratio between total volume increase and total volume decrease\n                # RSI indicator for volume\n                fields += [\n                    \"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))\"\n                    \"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)\" % (d, d, d)\n                    for d in windows\n                ]\n                names += [\"VSUMD%d\" % d for d in windows]\n\n        return fields, names\n"
  },
  {
    "path": "qlib/contrib/data/processor.py",
    "content": "import numpy as np\n\nfrom ...log import TimeInspector\nfrom ...data.dataset.processor import Processor, get_group_columns\n\n\nclass ConfigSectionProcessor(Processor):\n    \"\"\"\n    This processor is designed for Alpha158. And will be replaced by simple processors in the future\n    \"\"\"\n\n    def __init__(self, fields_group=None, **kwargs):\n        super().__init__()\n        # Options\n        self.fillna_feature = kwargs.get(\"fillna_feature\", True)\n        self.fillna_label = kwargs.get(\"fillna_label\", True)\n        self.clip_feature_outlier = kwargs.get(\"clip_feature_outlier\", False)\n        self.shrink_feature_outlier = kwargs.get(\"shrink_feature_outlier\", True)\n        self.clip_label_outlier = kwargs.get(\"clip_label_outlier\", False)\n\n        self.fields_group = None\n\n    def __call__(self, df):\n        return self._transform(df)\n\n    def _transform(self, df):\n        def _label_norm(x):\n            x = x - x.mean()  # copy\n            x /= x.std()\n            if self.clip_label_outlier:\n                x.clip(-3, 3, inplace=True)\n            if self.fillna_label:\n                x.fillna(0, inplace=True)\n            return x\n\n        def _feature_norm(x):\n            x = x - x.median()  # copy\n            x /= x.abs().median() * 1.4826\n            if self.clip_feature_outlier:\n                x.clip(-3, 3, inplace=True)\n            if self.shrink_feature_outlier:\n                x.where(x <= 3, 3 + (x - 3).div(x.max() - 3) * 0.5, inplace=True)\n                x.where(x >= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True)\n            if self.fillna_feature:\n                x.fillna(0, inplace=True)\n            return x\n\n        TimeInspector.set_time_mark()\n\n        # Copy the focus part and change it to single level\n        selected_cols = get_group_columns(df, self.fields_group)\n        df_focus = df[selected_cols].copy()\n        if len(df_focus.columns.levels) > 1:\n            
df_focus = df_focus.droplevel(level=0)\n\n        # Label\n        cols = df_focus.columns[df_focus.columns.str.contains(\"^LABEL\")]\n        df_focus[cols] = df_focus[cols].groupby(level=\"datetime\", group_keys=False).apply(_label_norm)\n\n        # Features\n        cols = df_focus.columns[df_focus.columns.str.contains(\"^KLEN|^KLOW|^KUP\")]\n        df_focus[cols] = (\n            df_focus[cols].apply(lambda x: x**0.25).groupby(level=\"datetime\", group_keys=False).apply(_feature_norm)\n        )\n\n        cols = df_focus.columns[df_focus.columns.str.contains(\"^KLOW2|^KUP2\")]\n        df_focus[cols] = (\n            df_focus[cols].apply(lambda x: x**0.5).groupby(level=\"datetime\", group_keys=False).apply(_feature_norm)\n        )\n\n        _cols = [\n            \"KMID\",\n            \"KSFT\",\n            \"OPEN\",\n            \"HIGH\",\n            \"LOW\",\n            \"CLOSE\",\n            \"VWAP\",\n            \"ROC\",\n            \"MA\",\n            \"BETA\",\n            \"RESI\",\n            \"QTLU\",\n            \"QTLD\",\n            \"RSV\",\n            \"SUMP\",\n            \"SUMN\",\n            \"SUMD\",\n            \"VSUMP\",\n            \"VSUMN\",\n            \"VSUMD\",\n        ]\n        pat = \"|\".join([\"^\" + x for x in _cols])\n        cols = df_focus.columns[df_focus.columns.str.contains(pat) & (~df_focus.columns.isin([\"HIGH0\", \"LOW0\"]))]\n        df_focus[cols] = df_focus[cols].groupby(level=\"datetime\", group_keys=False).apply(_feature_norm)\n\n        cols = df_focus.columns[df_focus.columns.str.contains(\"^STD|^VOLUME|^VMA|^VSTD\")]\n        df_focus[cols] = df_focus[cols].apply(np.log).groupby(level=\"datetime\", group_keys=False).apply(_feature_norm)\n\n        cols = df_focus.columns[df_focus.columns.str.contains(\"^RSQR\")]\n        df_focus[cols] = df_focus[cols].fillna(0).groupby(level=\"datetime\", group_keys=False).apply(_feature_norm)\n\n        cols = 
df_focus.columns[df_focus.columns.str.contains(\"^MAX|^HIGH0\")]\n        df_focus[cols] = (\n            df_focus[cols]\n            .apply(lambda x: (x - 1) ** 0.5)\n            .groupby(level=\"datetime\", group_keys=False)\n            .apply(_feature_norm)\n        )\n\n        cols = df_focus.columns[df_focus.columns.str.contains(\"^MIN|^LOW0\")]\n        df_focus[cols] = (\n            df_focus[cols]\n            .apply(lambda x: (1 - x) ** 0.5)\n            .groupby(level=\"datetime\", group_keys=False)\n            .apply(_feature_norm)\n        )\n\n        cols = df_focus.columns[df_focus.columns.str.contains(\"^CORR|^CORD\")]\n        df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level=\"datetime\", group_keys=False).apply(_feature_norm)\n\n        cols = df_focus.columns[df_focus.columns.str.contains(\"^WVMA\")]\n        df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level=\"datetime\", group_keys=False).apply(_feature_norm)\n\n        df[selected_cols] = df_focus.values\n\n        TimeInspector.log_cost_time(\"Finished preprocessing data.\")\n\n        return df\n"
  },
  {
    "path": "qlib/contrib/data/utils/__init__.py",
    "content": ""
  },
  {
    "path": "qlib/contrib/data/utils/sepdf.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport pandas as pd\nfrom typing import Dict, Iterable, Union\n\n\ndef align_index(df_dict, join):\n    res = {}\n    for k, df in df_dict.items():\n        if join is not None and k != join:\n            df = df.reindex(df_dict[join].index)\n        res[k] = df\n    return res\n\n\n# Mocking the pd.DataFrame class\nclass SepDataFrame:\n    \"\"\"\n    (Sep)arate DataFrame\n    We usually concat multiple dataframes to be processed together (such as feature, label, weight, filter).\n    However, they are usually used separately in the end.\n    This results in extra cost for concatenating and splitting data (reshaping and copying data in memory is very expensive)\n\n    SepDataFrame tries to act like a DataFrame whose columns form a MultiIndex\n    \"\"\"\n\n    # TODO:\n    # SepDataFrame tries to behave like a pandas DataFrame, but it is still not the same\n    # Contributions are welcome to make it more complete.\n\n    def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False):\n        \"\"\"\n        initialize the data based on the dataframe dictionary\n\n        Parameters\n        ----------\n        df_dict : Dict[str, pd.DataFrame]\n            dataframe dictionary\n        join : str\n            how to join the data\n            It will reindex the dataframe based on the join key.\n            If join is None, the reindex step will be skipped\n\n        skip_align :\n            in some cases, performance can be improved by skipping index alignment\n        \"\"\"\n        self.join = join\n\n        if skip_align:\n            self._df_dict = df_dict\n        else:\n            self._df_dict = align_index(df_dict, join)\n\n    @property\n    def loc(self):\n        return SDFLoc(self, join=self.join)\n\n    @property\n    def index(self):\n        return self._df_dict[self.join].index\n\n    def apply_each(self, method: str, skip_align=True, 
*args, **kwargs):\n        \"\"\"\n        Assumptions:\n        - inplace methods will return None\n        \"\"\"\n        inplace = False\n        df_dict = {}\n        for k, df in self._df_dict.items():\n            df_dict[k] = getattr(df, method)(*args, **kwargs)\n            if df_dict[k] is None:\n                inplace = True\n        if not inplace:\n            return SepDataFrame(df_dict=df_dict, join=self.join, skip_align=skip_align)\n\n    def sort_index(self, *args, **kwargs):\n        return self.apply_each(\"sort_index\", True, *args, **kwargs)\n\n    def copy(self, *args, **kwargs):\n        return self.apply_each(\"copy\", True, *args, **kwargs)\n\n    def _update_join(self):\n        if self.join not in self:\n            if len(self._df_dict) > 0:\n                self.join = next(iter(self._df_dict.keys()))\n            else:\n                # NOTE: this will change the behavior of previous reindex when all the keys are empty\n                self.join = None\n\n    def __getitem__(self, item):\n        # TODO: behave more like pandas when multiindex\n        return self._df_dict[item]\n\n    def __setitem__(self, item: str, df: Union[pd.DataFrame, pd.Series]):\n        # TODO: consider the join behavior\n        if not isinstance(item, tuple):\n            self._df_dict[item] = df\n        else:\n            # NOTE: corner case of MultiIndex\n            _df_dict_key, *col_name = item\n            col_name = tuple(col_name)\n            if _df_dict_key in self._df_dict:\n                if len(col_name) == 1:\n                    col_name = col_name[0]\n                self._df_dict[_df_dict_key][col_name] = df\n            else:\n                if isinstance(df, pd.Series):\n                    if len(col_name) == 1:\n                        col_name = col_name[0]\n                    self._df_dict[_df_dict_key] = df.to_frame(col_name)\n                else:\n                    df_copy = df.copy()  # avoid changing df\n                  
  df_copy.columns = pd.MultiIndex.from_tuples([(*col_name, *idx) for idx in df.columns.to_list()])\n                    self._df_dict[_df_dict_key] = df_copy\n\n    def __delitem__(self, item: str):\n        del self._df_dict[item]\n        self._update_join()\n\n    def __contains__(self, item):\n        return item in self._df_dict\n\n    def __len__(self):\n        return len(self._df_dict[self.join])\n\n    def droplevel(self, *args, **kwargs):\n        raise NotImplementedError(f\"Please implement the `droplevel` method\")\n\n    @property\n    def columns(self):\n        dfs = []\n        for k, df in self._df_dict.items():\n            df = df.head(0)\n            df.columns = pd.MultiIndex.from_product([[k], df.columns])\n            dfs.append(df)\n        return pd.concat(dfs, axis=1).columns\n\n    # Useless methods\n    @staticmethod\n    def merge(df_dict: Dict[str, pd.DataFrame], join: str):\n        all_df = df_dict[join]\n        for k, df in df_dict.items():\n            if k != join:\n                all_df = all_df.join(df)\n        return all_df\n\n\nclass SDFLoc:\n    \"\"\"Mock Class\"\"\"\n\n    def __init__(self, sdf: SepDataFrame, join):\n        self._sdf = sdf\n        self.axis = None\n        self.join = join\n\n    def __call__(self, axis):\n        self.axis = axis\n        return self\n\n    def __getitem__(self, args):\n        if self.axis == 1:\n            if isinstance(args, str):\n                return self._sdf[args]\n            elif isinstance(args, (tuple, list)):\n                new_df_dict = {k: self._sdf[k] for k in args}\n                return SepDataFrame(new_df_dict, join=self.join if self.join in args else args[0], skip_align=True)\n            else:\n                raise NotImplementedError(f\"This type of input is not supported\")\n        elif self.axis == 0:\n            return SepDataFrame(\n                {k: df.loc(axis=0)[args] for k, df in self._sdf._df_dict.items()}, join=self.join, skip_align=True\n   
         )\n        else:\n            df = self._sdf\n            if isinstance(args, tuple):\n                ax0, *ax1 = args\n                if len(ax1) == 0:\n                    ax1 = None\n                if ax1 is not None:\n                    df = df.loc(axis=1)[ax1]\n                if ax0 is not None:\n                    df = df.loc(axis=0)[ax0]\n                return df\n            else:\n                return df.loc(axis=0)[args]\n\n\n# Patch pandas DataFrame\n# Tricking isinstance to accept SepDataFrame as its subclass\nimport builtins\n\n\ndef _isinstance(instance, cls):\n    if isinstance_orig(instance, SepDataFrame):  # pylint: disable=E0602  # noqa: F821\n        if isinstance(cls, Iterable):\n            for c in cls:\n                if c is pd.DataFrame:\n                    return True\n        elif cls is pd.DataFrame:\n            return True\n    return isinstance_orig(instance, cls)  # pylint: disable=E0602  # noqa: F821\n\n\nbuiltins.isinstance_orig = builtins.isinstance\nbuiltins.isinstance = _isinstance\n\nif __name__ == \"__main__\":\n    sdf = SepDataFrame({}, join=None)\n    print(isinstance(sdf, (pd.DataFrame,)))\n    print(isinstance(sdf, pd.DataFrame))\n"
  },
  {
    "path": "qlib/contrib/eva/__init__.py",
    "content": ""
  },
  {
    "path": "qlib/contrib/eva/alpha.py",
    "content": "\"\"\"\nHere is a batch of evaluation functions.\n\nThe interface should be redesigned carefully in the future.\n\"\"\"\n\nimport pandas as pd\nfrom typing import Tuple\nfrom qlib import get_module_logger\nfrom qlib.utils.paral import complex_parallel, DelayedDict\nfrom joblib import Parallel, delayed\n\n\ndef calc_long_short_prec(\n    pred: pd.Series, label: pd.Series, date_col=\"datetime\", quantile: float = 0.2, dropna=False, is_alpha=False\n) -> Tuple[pd.Series, pd.Series]:\n    \"\"\"\n    calculate the precision for long and short operation\n\n\n    :param pred/label: index is **pd.MultiIndex**, index name is **[datetime, instruments]**; columns names is **[score]**.\n\n            .. code-block:: python\n                                                  score\n                datetime            instrument\n                2020-12-01 09:30:00 SH600068    0.553634\n                                    SH600195    0.550017\n                                    SH600276    0.540321\n                                    SH600584    0.517297\n                                    SH600715    0.544674\n    label :\n        label\n    date_col :\n        date_col\n\n    Returns\n    -------\n    (pd.Series, pd.Series)\n        long precision and short precision in time level\n    \"\"\"\n    if is_alpha:\n        label = label - label.groupby(level=date_col, group_keys=False).mean()\n    if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):\n        raise ValueError(\"Need more instruments to calculate precision\")\n\n    df = pd.DataFrame({\"pred\": pred, \"label\": label})\n    if dropna:\n        df.dropna(inplace=True)\n\n    group = df.groupby(level=date_col, group_keys=False)\n\n    def N(x):\n        return int(len(x) * quantile)\n\n    # find the top/low quantile of prediction and treat them as long and short target\n    long = group.apply(lambda x: x.nlargest(N(x), columns=\"pred\").label)\n    short = group.apply(lambda x: 
x.nsmallest(N(x), columns=\"pred\").label)\n\n    groupll = long.groupby(date_col, group_keys=False)\n    l_dom = groupll.apply(lambda x: x > 0)\n    l_c = groupll.count()\n\n    groups = short.groupby(date_col, group_keys=False)\n    s_dom = groups.apply(lambda x: x < 0)\n    s_c = groups.count()\n    return (l_dom.groupby(date_col, group_keys=False).sum() / l_c), (\n        s_dom.groupby(date_col, group_keys=False).sum() / s_c\n    )\n\n\ndef calc_long_short_return(\n    pred: pd.Series,\n    label: pd.Series,\n    date_col: str = \"datetime\",\n    quantile: float = 0.2,\n    dropna: bool = False,\n) -> Tuple[pd.Series, pd.Series]:\n    \"\"\"\n    calculate long-short return\n\n    Note:\n        `label` must be raw stock returns.\n\n    Parameters\n    ----------\n    pred : pd.Series\n        stock predictions\n    label : pd.Series\n        stock returns\n    date_col : str\n        datetime index name\n    quantile : float\n        long-short quantile\n\n    Returns\n    ----------\n    long_short_r : pd.Series\n        daily long-short returns\n    long_avg_r : pd.Series\n        daily average returns of all stocks\n    \"\"\"\n    df = pd.DataFrame({\"pred\": pred, \"label\": label})\n    if dropna:\n        df.dropna(inplace=True)\n    group = df.groupby(level=date_col, group_keys=False)\n\n    def N(x):\n        return int(len(x) * quantile)\n\n    r_long = group.apply(lambda x: x.nlargest(N(x), columns=\"pred\").label.mean())\n    r_short = group.apply(lambda x: x.nsmallest(N(x), columns=\"pred\").label.mean())\n    r_avg = group.label.mean()\n    return (r_long - r_short) / 2, r_avg\n\n\ndef pred_autocorr(pred: pd.Series, lag=1, inst_col=\"instrument\", date_col=\"datetime\"):\n    \"\"\"pred_autocorr.\n\n    Limitation:\n    - If the datetimes are not densely sequential, the correlation will be calculated based on adjacent dates. 
(some users may expect NaN)\n\n    :param pred: pd.Series with following format\n                instrument  datetime\n                SH600000    2016-01-04   -0.000403\n                            2016-01-05   -0.000753\n                            2016-01-06   -0.021801\n                            2016-01-07   -0.065230\n                            2016-01-08   -0.062465\n    :type pred: pd.Series\n    :param lag: the number of dates between the two cross sections being correlated\n    \"\"\"\n    if isinstance(pred, pd.DataFrame):\n        # log before selecting the column: a Series has no `.columns`\n        get_module_logger(\"pred_autocorr\").warning(f\"Only the first column in {pred.columns} of `pred` is kept\")\n        pred = pred.iloc[:, 0]\n    pred_ustk = pred.sort_index().unstack(inst_col)\n    corr_s = {}\n    for (idx, cur), (_, prev) in zip(pred_ustk.iterrows(), pred_ustk.shift(lag).iterrows()):\n        corr_s[idx] = cur.corr(prev)\n    corr_s = pd.Series(corr_s).sort_index()\n    return corr_s\n\n\ndef pred_autocorr_all(pred_dict, n_jobs=-1, **kwargs):\n    \"\"\"\n    calculate auto correlation for pred_dict\n\n    Parameters\n    ----------\n    pred_dict : dict\n        A dict like {<method_name>:  <prediction>}\n    kwargs :\n        all these arguments will be passed into pred_autocorr\n    \"\"\"\n    ac_dict = {}\n    for k, pred in pred_dict.items():\n        ac_dict[k] = delayed(pred_autocorr)(pred, **kwargs)\n    return complex_parallel(Parallel(n_jobs=n_jobs, verbose=10), ac_dict)\n\n\ndef calc_ic(pred: pd.Series, label: pd.Series, date_col=\"datetime\", dropna=False) -> Tuple[pd.Series, pd.Series]:\n    \"\"\"calc_ic.\n\n    Parameters\n    ----------\n    pred :\n        pred\n    label :\n        label\n    date_col :\n        date_col\n\n    Returns\n    -------\n    (pd.Series, pd.Series)\n        ic and rank ic\n    \"\"\"\n    df = pd.DataFrame({\"pred\": pred, \"label\": label})\n    ic = df.groupby(date_col, group_keys=False).apply(lambda df: df[\"pred\"].corr(df[\"label\"]))\n    ric = df.groupby(date_col, group_keys=False).apply(lambda df: 
df[\"pred\"].corr(df[\"label\"], method=\"spearman\"))\n    if dropna:\n        return ic.dropna(), ric.dropna()\n    else:\n        return ic, ric\n\n\ndef calc_all_ic(pred_dict_all, label, date_col=\"datetime\", dropna=False, n_jobs=-1):\n    \"\"\"calc_all_ic.\n\n    Parameters\n    ----------\n    pred_dict_all :\n        A dict like {<method_name>:  <prediction>}\n    label:\n        A pd.Series of label values\n\n    Returns\n    -------\n    {'Q2+IND_z': {'ic': <ic series like>\n                          2016-01-04   -0.057407\n                          ...\n                          2020-05-28    0.183470\n                          2020-05-29    0.171393\n                  'ric': <rank ic series like>\n                          2016-01-04   -0.040888\n                          ...\n                          2020-05-28    0.236665\n                          2020-05-29    0.183886\n                  }\n    ...}\n    \"\"\"\n    pred_all_ics = {}\n    for k, pred in pred_dict_all.items():\n        pred_all_ics[k] = DelayedDict([\"ic\", \"ric\"], delayed(calc_ic)(pred, label, date_col=date_col, dropna=dropna))\n    pred_all_ics = complex_parallel(Parallel(n_jobs=n_jobs, verbose=10), pred_all_ics)\n    return pred_all_ics\n"
  },
  {
    "path": "qlib/contrib/evaluate.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nimport warnings\nfrom typing import Union, Literal\n\nfrom ..log import get_module_logger\nfrom ..utils import get_date_range\nfrom ..utils.resam import Freq\nfrom ..strategy.base import BaseStrategy\nfrom ..backtest import get_exchange, position, backtest as backtest_func, executor as _executor\n\n\nfrom ..data import D\nfrom ..config import C\nfrom ..data.dataset.utils import get_level_index\n\nlogger = get_module_logger(\"Evaluate\")\n\n\ndef risk_analysis(r, N: int = None, freq: str = \"day\", mode: Literal[\"sum\", \"product\"] = \"sum\"):\n    \"\"\"Risk Analysis\n    NOTE:\n    The default calculation of annualized return differs from the standard (compounded) definition of annualized return.\n    It is implemented this way by design.\n    By default (`mode=\"sum\"`), Qlib accumulates returns by summation instead of multiplication to avoid the cumulative curve being skewed exponentially.\n    All annualized-return calculations in Qlib follow this principle unless `mode=\"product\"` is used.\n\n    Parameters\n    ----------\n    r : pandas.Series\n        daily return series.\n    N: int\n        scaler for annualizing returns and information_ratio (e.g. day: 252, week: 50, month: 12); at least one of `N` and `freq` should exist\n    freq: str\n        analysis frequency used for calculating the scaler; at least one of `N` and `freq` should exist\n    mode: Literal[\"sum\", \"product\"]\n        the method by which returns are accumulated:\n        - \"sum\": Arithmetic accumulation (linear returns).\n        - \"product\": Geometric accumulation (compounded returns).\n    \"\"\"\n\n    def cal_risk_analysis_scaler(freq):\n        _count, _freq = Freq.parse(freq)\n        _freq_scaler = {\n            Freq.NORM_FREQ_MINUTE: 240 * 238,\n            Freq.NORM_FREQ_DAY: 238,\n            Freq.NORM_FREQ_WEEK: 50,\n            Freq.NORM_FREQ_MONTH: 12,\n        
}\n        return _freq_scaler[_freq] / _count\n\n    if N is None and freq is None:\n        raise ValueError(\"at least one of `N` and `freq` should exist\")\n    if N is not None and freq is not None:\n        warnings.warn(\"risk_analysis freq will be ignored\")\n    if N is None:\n        N = cal_risk_analysis_scaler(freq)\n\n    if mode == \"sum\":\n        mean = r.mean()\n        std = r.std(ddof=1)\n        annualized_return = mean * N\n        max_drawdown = (r.cumsum() - r.cumsum().cummax()).min()\n    elif mode == \"product\":\n        cumulative_curve = (1 + r).cumprod()\n        # per-period geometric mean return (annualized separately below)\n        mean = cumulative_curve.iloc[-1] ** (1 / len(r)) - 1\n        # volatility of log returns\n        std = np.log(1 + r).std(ddof=1)\n\n        cumulative_return = cumulative_curve.iloc[-1] - 1\n        annualized_return = (1 + cumulative_return) ** (N / len(r)) - 1\n        # max percentage drawdown from peak cumulative product\n        max_drawdown = (cumulative_curve / cumulative_curve.cummax() - 1).min()\n    else:\n        raise ValueError(f\"risk_analysis accumulation mode {mode} is not supported. 
Expected `sum` or `product`.\")\n\n    information_ratio = mean / std * np.sqrt(N)\n    data = {\n        \"mean\": mean,\n        \"std\": std,\n        \"annualized_return\": annualized_return,\n        \"information_ratio\": information_ratio,\n        \"max_drawdown\": max_drawdown,\n    }\n    res = pd.Series(data).to_frame(\"risk\")\n    return res\n\n\ndef indicator_analysis(df, method=\"mean\"):\n    \"\"\"analyze the statistical time-series indicators of trading\n\n    Parameters\n    ----------\n    df : pandas.DataFrame\n        columns: like ['pa', 'pos', 'ffr', 'count', 'deal_amount', 'value'].\n            Necessary fields:\n                - 'pa' is the price advantage in trade indicators\n                - 'pos' is the positive rate in trade indicators\n                - 'ffr' is the fulfill rate in trade indicators\n                - 'count' is the weight used by method 'mean'; it is always used for 'pos'\n            Optional fields:\n                - 'deal_amount' is the total deal amount, only necessary when method is 'amount_weighted'\n                - 'value' is the total trade value, only necessary when method is 'value_weighted'\n\n        index: Index(datetime)\n    method : str, optional\n        statistics method of pa/ffr, by default \"mean\"\n\n        - if method is 'mean', compute the 'count'-weighted mean of each trade indicator\n        - if method is 'amount_weighted', compute the absolute deal_amount weighted mean of each trade indicator\n        - if method is 'value_weighted', compute the absolute value weighted mean of each trade indicator\n\n        Note: the statistics method of pos is always \"mean\"\n\n    Returns\n    -------\n    pd.DataFrame\n        statistical value of each trade indicator\n    \"\"\"\n    # Build only the weights required by `method` (plus \"mean\", which is always used for pos),\n    # so that the optional columns are read only when they are needed.\n    weight_builders = {\n        \"mean\": lambda: df[\"count\"],\n        \"amount_weighted\": lambda: df[\"deal_amount\"].abs(),\n        \"value_weighted\": lambda: df[\"value\"].abs(),\n    }\n    if method not in weight_builders:\n        raise ValueError(f\"indicator_analysis method {method} is not supported!\")\n    weights_dict = {key: weight_builders[key]() for key in {method, \"mean\"}}\n\n    # statistic 
pa/ffr indicator\n    indicators_df = df[[\"ffr\", \"pa\"]]\n    weights = weights_dict.get(method)\n    res = indicators_df.mul(weights, axis=0).sum() / weights.sum()\n\n    # statistic pos\n    weights = weights_dict.get(\"mean\")\n    res.loc[\"pos\"] = df[\"pos\"].mul(weights).sum() / weights.sum()\n    res = res.to_frame(\"value\")\n    return res\n\n\n# This API is kept for compatibility with legacy code\ndef backtest_daily(\n    start_time: Union[str, pd.Timestamp],\n    end_time: Union[str, pd.Timestamp],\n    strategy: Union[str, dict, BaseStrategy],\n    executor: Union[str, dict, _executor.BaseExecutor] = None,\n    account: Union[float, int, position.Position] = 1e8,\n    benchmark: str = \"SH000300\",\n    exchange_kwargs: dict = None,\n    pos_type: str = \"Position\",\n):\n    \"\"\"initialize the strategy and executor, then execute the backtest at daily frequency\n\n    Parameters\n    ----------\n    start_time : Union[str, pd.Timestamp]\n        closed start time for backtest\n        **NOTE**: This will be applied to the outermost executor's calendar.\n    end_time : Union[str, pd.Timestamp]\n        closed end time for backtest\n        **NOTE**: This will be applied to the outermost executor's calendar.\n        E.g. for Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301\n    strategy : Union[str, dict, BaseStrategy]\n        for initializing the outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more information.\n\n        E.g.\n\n        .. 
code-block:: python\n\n            # dict\n            strategy = {\n                \"class\": \"TopkDropoutStrategy\",\n                \"module_path\": \"qlib.contrib.strategy.signal_strategy\",\n                \"kwargs\": {\n                    \"signal\": (model, dataset),\n                    \"topk\": 50,\n                    \"n_drop\": 5,\n                },\n            }\n            # BaseStrategy\n            pred_score = pd.read_pickle(\"score.pkl\")[\"score\"]\n            STRATEGY_CONFIG = {\n                \"topk\": 50,\n                \"n_drop\": 5,\n                \"signal\": pred_score,\n            }\n            strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)\n            # str example.\n            # 1) specify a pickle object\n            #     - path like 'file:///<path to pickle file>/obj.pkl'\n            # 2) specify a class name\n            #     - \"ClassName\":  getattr(module, \"ClassName\")() will be used.\n            # 3) specify module path with class name\n            #     - \"a.b.c.ClassName\" getattr(<a.b.c.module>, \"ClassName\")() will be used.\n\n    executor : Union[str, dict, BaseExecutor]\n        for initializing the outermost executor. If None, a daily `SimulatorExecutor` that generates portfolio metrics is used.\n    benchmark : str\n        the benchmark for reporting.\n    account : Union[float, int, Position]\n        information describing how to create the account\n\n        For `float` or `int`:\n\n            Using Account with only initial cash\n\n        For `Position`:\n\n            Using Account with a Position\n    exchange_kwargs : dict\n        the kwargs for initializing Exchange; they update the default kwargs shown below\n        E.g.\n\n        .. 
code-block:: python\n\n            exchange_kwargs = {\n                \"freq\": freq,\n                \"limit_threshold\": None, # limit_threshold is None, using C.limit_threshold\n                \"deal_price\": None, # deal_price is None, using C.deal_price\n                \"open_cost\": 0.0005,\n                \"close_cost\": 0.0015,\n                \"min_cost\": 5,\n            }\n\n    pos_type : str\n        the type of Position.\n\n    Returns\n    -------\n    report_normal : pd.DataFrame\n        backtest report\n    positions_normal : dict\n        backtest positions\n\n    \"\"\"\n    freq = \"day\"\n    if executor is None:\n        executor_config = {\n            \"time_per_step\": freq,\n            \"generate_portfolio_metrics\": True,\n        }\n        executor = _executor.SimulatorExecutor(**executor_config)\n    _exchange_kwargs = {\n        \"freq\": freq,\n        \"limit_threshold\": None,\n        \"deal_price\": None,\n        \"open_cost\": 0.0005,\n        \"close_cost\": 0.0015,\n        \"min_cost\": 5,\n    }\n    if exchange_kwargs is not None:\n        _exchange_kwargs.update(exchange_kwargs)\n\n    portfolio_metric_dict, indicator_dict = backtest_func(\n        start_time=start_time,\n        end_time=end_time,\n        strategy=strategy,\n        executor=executor,\n        account=account,\n        benchmark=benchmark,\n        exchange_kwargs=_exchange_kwargs,\n        pos_type=pos_type,\n    )\n    analysis_freq = \"{0}{1}\".format(*Freq.parse(freq))\n\n    report_normal, positions_normal = portfolio_metric_dict.get(analysis_freq)\n\n    return report_normal, positions_normal\n\n\ndef long_short_backtest(\n    pred,\n    topk=50,\n    deal_price=None,\n    shift=1,\n    open_cost=0,\n    close_cost=0,\n    trade_unit=None,\n    limit_threshold=None,\n    min_cost=5,\n    subscribe_fields=[],\n    extract_codes=False,\n):\n    \"\"\"\n    A backtest for a long-short strategy\n\n    :param pred:        The trading 
signal produced on day `T`.\n    :param topk:        The number of securities on each side: the top `topk` by score are held long and the bottom `topk` are held short.\n    :param deal_price:  The price field used to execute trades (defaults to `C.deal_price`).\n    :param shift:       The number of days by which the prediction is shifted. The trading day will be T+1 if shift==1.\n    :param open_cost:   open transaction cost.\n    :param close_cost:  close transaction cost.\n    :param trade_unit:  trading unit, e.g. 100 for China A-shares (defaults to `C.trade_unit`).\n    :param limit_threshold: price limit, e.g. 0.1 (10%); the same limit applies to long and short (defaults to `C.limit_threshold`).\n    :param min_cost:    min transaction cost.\n    :param subscribe_fields: additional fields for the exchange to subscribe to.\n    :param extract_codes:  bool.\n                       whether to pass the codes extracted from `pred` to the exchange.\n                       NOTE: This will be faster with offline qlib.\n    :return:            The result of the backtest, represented as a dict of pd.Series.\n                        { \"long\": long_returns(excess),\n                        \"short\": short_returns(excess),\n                        \"long_short\": long_short_returns}\n    \"\"\"\n    if get_level_index(pred, level=\"datetime\") == 1:\n        pred = pred.swaplevel().sort_index()\n\n    if trade_unit is None:\n        trade_unit = C.trade_unit\n    if limit_threshold is None:\n        limit_threshold = C.limit_threshold\n    if deal_price is None:\n        deal_price = C.deal_price\n    if deal_price[0] != \"$\":\n        deal_price = \"$\" + deal_price\n\n    subscribe_fields = subscribe_fields.copy()\n    profit_str = f\"Ref({deal_price}, -1)/{deal_price} - 1\"\n    subscribe_fields.append(profit_str)\n\n    trade_exchange = get_exchange(\n        pred=pred,\n        deal_price=deal_price,\n        subscribe_fields=subscribe_fields,\n        limit_threshold=limit_threshold,\n        open_cost=open_cost,\n        close_cost=close_cost,\n        min_cost=min_cost,\n        trade_unit=trade_unit,\n        extract_codes=extract_codes,\n        shift=shift,\n    )\n\n    _pred_dates = 
pred.index.get_level_values(level=\"datetime\")\n    predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())\n    trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))\n\n    long_returns = {}\n    short_returns = {}\n    ls_returns = {}\n\n    for pdate, date in zip(predict_dates, trade_dates):\n        score = pred.loc(axis=0)[pdate, :]\n        score = score.reset_index().sort_values(by=\"score\", ascending=False)\n\n        long_stocks = list(score.iloc[:topk][\"instrument\"])\n        short_stocks = list(score.iloc[-topk:][\"instrument\"])\n\n        score = score.set_index([\"datetime\", \"instrument\"]).sort_index()\n\n        long_profit = []\n        short_profit = []\n        all_profit = []\n\n        for stock in long_stocks:\n            if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):\n                continue\n            profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)\n            if np.isnan(profit):\n                long_profit.append(0)\n            else:\n                long_profit.append(profit)\n\n        for stock in short_stocks:\n            if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):\n                continue\n            profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)\n            if np.isnan(profit):\n                short_profit.append(0)\n            else:\n                short_profit.append(profit * -1)\n\n        # `score` is indexed by (datetime, instrument), so iterate over the instrument level\n        for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=\"instrument\")):\n            # exclude suspended stocks\n            if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):\n                continue\n            profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)\n            if 
np.isnan(profit):\n                all_profit.append(0)\n            else:\n                all_profit.append(profit)\n\n        long_returns[date] = np.mean(long_profit) - np.mean(all_profit)\n        short_returns[date] = np.mean(short_profit) + np.mean(all_profit)\n        ls_returns[date] = np.mean(short_profit) + np.mean(long_profit)\n\n    return dict(\n        zip(\n            [\"long\", \"short\", \"long_short\"],\n            map(pd.Series, [long_returns, short_returns, ls_returns]),\n        )\n    )\n\n\ndef t_run():\n    pred_FN = \"./check_pred.csv\"\n    pred: pd.DataFrame = pd.read_csv(pred_FN)\n    pred[\"datetime\"] = pd.to_datetime(pred[\"datetime\"])\n    pred = pred.set_index([pred.columns[0], pred.columns[1]])\n    pred = pred.iloc[:9000]\n    strategy_config = {\n        \"topk\": 50,\n        \"n_drop\": 5,\n        \"signal\": pred,\n    }\n    report_df, positions = backtest_daily(start_time=\"2017-01-01\", end_time=\"2020-08-01\", strategy=strategy_config)\n    print(report_df.head())\n    print(positions.keys())\n    print(positions[list(positions.keys())[0]])\n    return 0\n\n\nif __name__ == \"__main__\":\n    t_run()\n"
  },
  {
    "path": "qlib/contrib/evaluate_portfolio.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nfrom scipy.stats import spearmanr, pearsonr\n\nfrom ..data import D\n\nfrom collections import OrderedDict\n\n\ndef _get_position_value_from_df(evaluate_date, position, close_data_df):\n    \"\"\"Get the position value from an existing close-price DataFrame\n    close_data_df:\n        pd.DataFrame\n        multi-index\n        close_data_df['$close'][stock_id][evaluate_date]: close price for (stock_id, evaluate_date)\n    position:\n        same as in get_position_value()\n    \"\"\"\n    value = 0\n    for stock_id, report in position.items():\n        if stock_id != \"cash\":\n            value += report[\"amount\"] * close_data_df[\"$close\"][stock_id][evaluate_date]\n            # value += report['amount'] * report['price']\n    if \"cash\" in position:\n        value += position[\"cash\"]\n    return value\n\n\ndef get_position_value(evaluate_date, position):\n    \"\"\"Get the value of a position: sum of close * amount, plus cash\n\n    The close price is used.\n\n        positions:\n        {\n            Timestamp('2016-01-05 00:00:00'):\n            {\n                'SH600022':\n                {\n                    'amount':100.00,\n                    'price':12.00\n                },\n\n                'cash':100000.0\n            }\n        }\n\n    It means holding 100.0 shares of 'SH600022' and 100000.0 RMB on '2016-01-05'\n    \"\"\"\n    # load close price for position\n    # position should also consider cash\n    instruments = list(position.keys())\n    instruments = list(set(instruments) - {\"cash\"})  # filter 'cash'\n    fields = [\"$close\"]\n    close_data_df = D.features(\n        instruments,\n        fields,\n        start_time=evaluate_date,\n        end_time=evaluate_date,\n        freq=\"day\",\n        disk_cache=0,\n    )\n    value = _get_position_value_from_df(evaluate_date, 
 position, close_data_df)\n    return value\n\n\ndef get_position_list_value(positions):\n    # collect the instrument list and the date range of all positions\n    instruments = set()\n    for day, position in positions.items():\n        instruments.update(position.keys())\n    instruments = list(set(instruments) - {\"cash\"})  # filter 'cash'\n    instruments.sort()\n    day_list = list(positions.keys())\n    day_list.sort()\n    start_date, end_date = day_list[0], day_list[-1]\n    # load data\n    fields = [\"$close\"]\n    close_data_df = D.features(\n        instruments,\n        fields,\n        start_time=start_date,\n        end_time=end_date,\n        freq=\"day\",\n        disk_cache=0,\n    )\n    # generate value\n    # return dict for time:position_value\n    value_dict = OrderedDict()\n    for day, position in positions.items():\n        value = _get_position_value_from_df(evaluate_date=day, position=position, close_data_df=close_data_df)\n        value_dict[day] = value\n    return value_dict\n\n\ndef get_daily_return_series_from_positions(positions, init_asset_value):\n    \"\"\"Generate a daily return series from the position view\n\n    Parameters\n    ----------\n    positions : dict\n        positions generated by the strategy\n    init_asset_value : float\n        initial asset value\n\n    Returns\n    -------\n    pd.Series\n        daily returns, where return_series[date] is the daily return rate\n    \"\"\"\n    value_dict = get_position_list_value(positions)\n    value_series = pd.Series(value_dict)\n    value_series = value_series.sort_index()  # ensure chronological order\n    return_series = value_series.pct_change()\n    return_series[value_series.index[0]] = (\n        value_series[value_series.index[0]] / init_asset_value - 1\n    )  # update daily return for the first date\n    return return_series\n\n\ndef get_annual_return_from_positions(positions, init_asset_value):\n    \"\"\"Annualized Returns\n\n    p_r = (p_end / p_start)^{(250/n)} - 1\n\n    p_r     annual return\n    p_end   final value\n    p_start init value\n    n       number of trading days in the backtest\n\n   
 \"\"\"\n    date_range_list = sorted(list(positions.keys()))\n    end_time = date_range_list[-1]\n    p_end = get_position_value(end_time, positions[end_time])\n    p_start = init_asset_value\n    n_period = len(date_range_list)\n    annual = pow((p_end / p_start), (250 / n_period)) - 1\n\n    return annual\n\n\ndef get_annaul_return_from_return_series(r, method=\"ci\"):\n    \"\"\"Annualized return from a daily return series (assuming 250 trading days per year)\n\n    Parameters\n    ----------\n    r : pandas.Series\n        daily return series\n    method : str\n        interest calculation method, ci(compound interest)/si(simple interest)\n    \"\"\"\n    mean = r.mean()\n    annual = (1 + mean) ** 250 - 1 if method == \"ci\" else mean * 250\n\n    return annual\n\n\ndef get_sharpe_ratio_from_return_series(r, risk_free_rate=0.00, method=\"ci\"):\n    \"\"\"Annualized Sharpe ratio from a daily return series\n\n    Parameters\n    ----------\n    r : pandas.Series\n        daily return series\n    method : str\n        interest calculation method, ci(compound interest)/si(simple interest)\n    risk_free_rate : float\n        annual risk-free rate, 0.00 by default (e.g. 0.03)\n    \"\"\"\n    std = r.std(ddof=1)\n    annual = get_annaul_return_from_return_series(r, method=method)\n    # the annualized volatility is std * sqrt(250)\n    sharpe = (annual - risk_free_rate) / std / np.sqrt(250)\n\n    return sharpe\n\n\ndef get_max_drawdown_from_series(r):\n    \"\"\"Maximum drawdown of the compounded (cumprod) asset curve\n\n    Parameters\n    ----------\n    r : pandas.Series\n        daily return series\n    \"\"\"\n    # mdd = ((r.cumsum() - r.cumsum().cummax()) / (1 + r.cumsum().cummax())).min()\n\n    mdd = (((1 + r).cumprod() - (1 + r).cumprod().cummax()) / ((1 + r).cumprod().cummax())).min()\n\n    return mdd\n\n\ndef get_turnover_rate():\n    # in backtest\n    pass\n\n\ndef get_beta(r, b):\n    \"\"\"Beta of the strategy relative to the baseline\n\n    Parameters\n    ----------\n    r : pandas.Series\n        daily return series of the strategy\n    b : pandas.Series\n        daily return series of the baseline\n    
\"\"\"\n    # np.cov returns the 2x2 covariance matrix; entry [0, 1] is cov(r, b)\n    cov_r_b = np.cov(r, b)[0, 1]\n    # ddof=1 matches the N - 1 normalization used by np.cov\n    var_b = np.var(b, ddof=1)\n    return cov_r_b / var_b\n\n\ndef get_alpha(r, b, risk_free_rate=0.03):\n    beta = get_beta(r, b)\n    annual_r = get_annaul_return_from_return_series(r)\n    annual_b = get_annaul_return_from_return_series(b)\n\n    alpha = annual_r - risk_free_rate - beta * (annual_b - risk_free_rate)\n\n    return alpha\n\n\ndef get_volatility_from_series(r):\n    return r.std(ddof=1)\n\n\ndef get_rank_ic(a, b):\n    \"\"\"Rank IC\n\n    Parameters\n    ----------\n    a : pandas.Series\n        daily score series of a feature\n    b : pandas.Series\n        daily return series\n\n    \"\"\"\n    return spearmanr(a, b).correlation\n\n\ndef get_normal_ic(a, b):\n    \"\"\"Normal IC (Pearson correlation)\"\"\"\n    return pearsonr(a, b)[0]\n"
  },
  {
    "path": "qlib/contrib/meta/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .data_selection import MetaTaskDS, MetaDatasetDS, MetaModelDS\n\n__all__ = [\"MetaTaskDS\", \"MetaDatasetDS\", \"MetaModelDS\"]\n"
  },
  {
    "path": "qlib/contrib/meta/data_selection/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .dataset import MetaDatasetDS, MetaTaskDS\nfrom .model import MetaModelDS\n\n__all__ = [\"MetaDatasetDS\", \"MetaTaskDS\", \"MetaModelDS\"]\n"
  },
  {
    "path": "qlib/contrib/meta/data_selection/dataset.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport pandas as pd\nimport numpy as np\nfrom copy import deepcopy\nfrom joblib import Parallel, delayed  # pylint: disable=E0401\nfrom typing import Dict, List, Union, Text, Tuple\nfrom qlib.data.dataset.utils import init_task_handler\nfrom qlib.data.dataset import DatasetH\nfrom qlib.contrib.torch import data_to_tensor\nfrom qlib.model.meta.task import MetaTask\nfrom qlib.model.meta.dataset import MetaTaskDataset\nfrom qlib.model.trainer import TrainerR\nfrom qlib.log import get_module_logger\nfrom qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config\nfrom qlib.utils.data import deepcopy_basic_type\nfrom qlib.workflow import R\nfrom qlib.workflow.task.gen import RollingGen, task_generator\nfrom qlib.workflow.task.utils import TimeAdjuster\nfrom tqdm.auto import tqdm\n\n\nclass InternalData:\n    def __init__(self, task_tpl: dict, step: int, exp_name: str):\n        self.task_tpl = task_tpl\n        self.step = step\n        self.exp_name = exp_name\n\n    def setup(self, trainer=TrainerR, trainer_kwargs={}):\n        \"\"\"\n        After running this function, `self.data_ic_df` is set.\n        Each column represents a data segment (the training period of one proxy model).\n        Each row represents a timestamp at which the performance (daily rank IC) of that data segment is measured.\n        For example,\n\n        .. 
code-block:: python\n\n                       2021-06-21 2021-06-04 2021-05-21 2021-05-07 2021-04-20 2021-04-06 2021-03-22 2021-03-08  ...\n                       2021-07-02 2021-06-18 2021-06-03 2021-05-20 2021-05-06 2021-04-19 2021-04-02 2021-03-19  ...\n            datetime                                                                                            ...\n            2018-01-02   0.079782   0.115975   0.070866   0.028849  -0.081170   0.140380   0.063864   0.110987  ...\n            2018-01-03   0.123386   0.107789   0.071037   0.045278  -0.060782   0.167446   0.089779   0.124476  ...\n            2018-01-04   0.140775   0.097206   0.063702   0.042415  -0.078164   0.173218   0.098914   0.114389  ...\n            2018-01-05   0.030320  -0.037209  -0.044536  -0.047267  -0.081888   0.045648   0.059947   0.047652  ...\n            2018-01-08   0.107201   0.009219  -0.015995  -0.036594  -0.086633   0.108965   0.122164   0.108508  ...\n            ...               ...        ...        ...        ...        ...        ...        ...        ...  
...\n\n        \"\"\"\n\n        # 1) prepare the prediction of proxy models\n        perf_task_tpl = deepcopy(self.task_tpl)  # this task is supposed to contain no complicated objects\n        # The only thing we want to save is the prediction\n        perf_task_tpl[\"record\"] = [\"qlib.workflow.record_temp.SignalRecord\"]\n\n        trainer = auto_filter_kwargs(trainer)(experiment_name=self.exp_name, **trainer_kwargs)\n        # NOTE:\n        # The handler is initialized only once.\n        if not trainer.has_worker():\n            self.dh = init_task_handler(perf_task_tpl)\n            self.dh.config(dump_all=False)  # in some cases, the data handler is saved to disk with `dump_all=True`\n        else:\n            self.dh = init_instance_by_config(perf_task_tpl[\"dataset\"][\"kwargs\"][\"handler\"])\n        assert self.dh.dump_all is False  # otherwise, it will save all the detailed data\n\n        seg = perf_task_tpl[\"dataset\"][\"kwargs\"][\"segments\"]\n\n        # We want to split the training time period into small segments.\n        perf_task_tpl[\"dataset\"][\"kwargs\"][\"segments\"] = {\n            \"train\": (DatasetH.get_min_time(seg), DatasetH.get_max_time(seg)),\n            \"test\": (None, None),\n        }\n\n        # NOTE:\n        # we play a trick here:\n        # the training segment is treated as the test segment so that RollingGen creates the rolling tasks\n        rg = RollingGen(step=self.step, test_key=\"train\", train_key=None, task_copy_func=deepcopy_basic_type)\n        gen_task = task_generator(perf_task_tpl, [rg])\n\n        recorders = R.list_recorders(experiment_name=self.exp_name)\n        if len(gen_task) == len(recorders):\n            get_module_logger(\"Internal Data\").info(\"the data has been initialized\")\n        else:\n            # train new models\n            assert 0 == len(recorders), \"An empty experiment is required to set up `InternalData`\"\n            trainer.train(gen_task)\n\n        # 2) extract the similarity matrix\n        
label_df = self.dh.fetch(col_set=\"label\")\n        # compute the daily rank IC of each proxy model's prediction\n        recorders = R.list_recorders(experiment_name=self.exp_name)\n\n        key_l = []\n        ic_l = []\n        for _, rec in tqdm(recorders.items(), desc=\"calc\"):\n            pred = rec.load_object(\"pred.pkl\")\n            task = rec.load_object(\"task\")\n            data_key = task[\"dataset\"][\"kwargs\"][\"segments\"][\"train\"]\n            key_l.append(data_key)\n            ic_l.append(delayed(self._calc_perf)(pred.iloc[:, 0], label_df.iloc[:, 0]))\n\n        ic_l = Parallel(n_jobs=-1)(ic_l)\n        self.data_ic_df = pd.DataFrame(dict(zip(key_l, ic_l)))\n        self.data_ic_df = self.data_ic_df.sort_index().sort_index(axis=1)\n\n        del self.dh  # the handler is no longer needed\n\n    def _calc_perf(self, pred, label):\n        # daily Spearman (rank) correlation between prediction and label\n        df = pd.DataFrame({\"pred\": pred, \"label\": label})\n        df = df.groupby(\"datetime\", group_keys=False).corr(method=\"spearman\")\n        corr = df.loc(axis=0)[:, \"pred\"][\"label\"].droplevel(axis=0, level=-1)\n        return corr\n\n    def update(self):\n        \"\"\"update the data for online trading\"\"\"\n        # TODO:\n        # when new data (including labels) are fully available\n        # - update the prediction\n        # - update the data similarity map (if applicable)\n\n\nclass MetaTaskDS(MetaTask):\n    \"\"\"Meta Task for Data Selection\"\"\"\n\n    def __init__(self, task: dict, meta_info: pd.DataFrame, mode: str = MetaTask.PROC_MODE_FULL, fill_method=\"max\"):\n        \"\"\"\n\n        The description of the processed data\n\n            time_perf: An array with shape  <hist_step_n * step, data pieces>  ->  data piece performance\n\n            time_belong:  An array with shape <sample, data pieces>  -> belong or not (1. 
or 0.)\n            array([[1., 0., 0., ..., 0., 0., 0.],\n                   [1., 0., 0., ..., 0., 0., 0.],\n                   [1., 0., 0., ..., 0., 0., 0.],\n                   ...,\n                   [0., 0., 0., ..., 0., 0., 1.],\n                   [0., 0., 0., ..., 0., 0., 1.],\n                   [0., 0., 0., ..., 0., 0., 1.]])\n\n        Parameters\n        ----------\n        meta_info : pd.DataFrame\n            please refer to the docs of `_prepare_meta_ipt` for a detailed explanation.\n        fill_method : str\n            how NaN values of the normalized meta info are filled: \"max\" (row-wise max), \"maxseg\" (segment-wise) or \"zero\"\n        \"\"\"\n        super().__init__(task, meta_info)\n        self.fill_method = fill_method\n\n        time_perf = self._get_processed_meta_info()\n        self.processed_meta_input = {\"time_perf\": time_perf}\n        # FIXME: memory issue in this step\n        if mode == MetaTask.PROC_MODE_FULL:\n            # process the meta info\n            ds = self.get_dataset()\n\n            # these three lines occupied 70% of the time of initializing MetaTaskDS\n            d_train, d_test = ds.prepare([\"train\", \"test\"], col_set=[\"feature\", \"label\"])\n            prev_size = d_test.shape[0]\n            d_train = d_train.dropna(axis=0)\n            d_test = d_test.dropna(axis=0)\n            if prev_size == 0 or d_test.shape[0] / prev_size <= 0.1:\n                raise ValueError(f\"Most of the samples are dropped. 
Please check this task: {task}\")\n\n            assert (\n                d_test.groupby(\"datetime\", group_keys=False).size().shape[0] >= 5\n            ), \"In this segment, the number of trading dates is less than 5; you'd better check the data.\"\n\n            sample_time_belong = np.zeros((d_train.shape[0], time_perf.shape[1]))\n            for i, col in enumerate(time_perf.columns):\n                # these two lines of code occupied 20% of the time of initializing MetaTaskDS\n                slc = slice(*d_train.index.slice_locs(start=col[0], end=col[1]))\n                sample_time_belong[slc, i] = 1.0\n\n            # Samples that do not belong to exactly one data piece (e.g. the latest month) are also assigned to the last time_perf\n            # Assumption: the latest data performs similarly to the last month\n            sample_time_belong[sample_time_belong.sum(axis=1) != 1, -1] = 1.0\n\n            self.processed_meta_input.update(\n                dict(\n                    X=d_train[\"feature\"],\n                    y=d_train[\"label\"].iloc[:, 0],\n                    X_test=d_test[\"feature\"],\n                    y_test=d_test[\"label\"].iloc[:, 0],\n                    time_belong=sample_time_belong,\n                    test_idx=d_test[\"label\"].index,\n                )\n            )\n\n        # TODO: set device; converting the data format here may be unnecessary.\n        self.processed_meta_input = data_to_tensor(self.processed_meta_input)\n\n    def _get_processed_meta_info(self):\n        meta_info_norm = self.meta_info.sub(self.meta_info.mean(axis=1), axis=0)\n        if self.fill_method.startswith(\"max\"):\n            suffix = self.fill_method[len(\"max\"):]  # drop the \"max\" prefix (str.lstrip would strip characters, not a prefix)\n            if suffix == \"seg\":\n                fill_value = {}\n                for col in meta_info_norm.columns:\n                    fill_value[col] = meta_info_norm.loc[meta_info_norm[col].isna(), :].dropna(axis=1).mean().max()\n                fill_value = pd.Series(fill_value).sort_index()\n                # 
The NaN values are filled segment-wise. Below is an example of fill_value\n                # 2009-01-05  2009-02-06    0.145809\n                # 2009-02-09  2009-03-06    0.148005\n                # 2009-03-09  2009-04-03    0.090385\n                # 2009-04-07  2009-05-05    0.114318\n                # 2009-05-06  2009-06-04    0.119328\n                # ...\n                meta_info_norm = meta_info_norm.fillna(fill_value)\n            else:\n                if len(suffix) > 0:\n                    get_module_logger(\"MetaTaskDS\").warning(\n                        f\"fill_method={self.fill_method}; the suffix after `max` can't be parsed correctly. Please check your parameters.\"\n                    )\n                fill_value = meta_info_norm.max(axis=1)\n                # fill it with the row max to align with the previous implementation\n                # This will magnify the data similarity when data is in daily freq\n\n                # the fill value corresponds to data like this\n                # It gets a performance value for each day.\n                # The performance values are obtained from other models on this day\n                # 2009-01-16    0.276320\n                # 2009-01-19    0.280603\n                #                 ...\n                # 2011-06-27    0.203773\n                meta_info_norm = meta_info_norm.T.fillna(fill_value).T\n        elif self.fill_method == \"zero\":\n            # It will fillna(0.0) at the end.\n            pass\n        else:\n            raise NotImplementedError(f\"fill_method={self.fill_method} is not supported\")\n        meta_info_norm = meta_info_norm.fillna(0.0)  # always fill zero in case of NaN\n        return meta_info_norm\n\n    def get_meta_input(self):\n        return self.processed_meta_input\n\n\nclass MetaDatasetDS(MetaTaskDataset):\n    def __init__(\n        self,\n        *,\n        task_tpl: Union[dict, list],\n        step: int,\n        trunc_days: int = None,\n        rolling_ext_days: int = 0,\n        
exp_name: Union[str, InternalData],\n        segments: Union[Dict[Text, Tuple], float, str],\n        hist_step_n: int = 10,\n        task_mode: str = MetaTask.PROC_MODE_FULL,\n        fill_method: str = \"max\",\n    ):\n        \"\"\"\n        A dataset for the meta model.\n\n        Parameters\n        ----------\n        task_tpl : Union[dict, list]\n            Decides which tasks are used.\n            - dict : the task template; the prepared tasks are generated with `step`, `trunc_days` and `RollingGen`\n            - list : the list of tasks is used directly\n                     the list is supposed to be sorted by time\n        step : int\n            the rolling step\n        trunc_days: int\n            days to be truncated based on the test start\n        rolling_ext_days: int\n            sometimes users want to train meta models for a longer test period but with smaller rolling steps for more task samples.\n            the total length of test periods will be `step + rolling_ext_days`\n\n        exp_name : Union[str, InternalData]\n            Decides which meta_info is used for prediction.\n            - str: the name of the experiment that stores the performance of the data\n            - InternalData: prepared internal data\n        segments: Union[Dict[Text, Tuple], float, str]\n            if segments is a Dict:\n                the segments to divide data\n                both left and right are included\n            if segments is a float:\n                the float is the fraction of meta tasks used for training\n            if segments is a string:\n                tasks whose test period ends before the date `segments` are used for training and the rest for testing,\n                so that the date `segments` falls in the test set\n        hist_step_n: int\n            the number of historical steps of meta information (i.e. 
data similarity information)\n        task_mode : str\n            Please refer to the docs of MetaTask\n        fill_method : str\n            how NaN values in the meta information are filled: \"max\", \"maxseg\" or \"zero\"\n        \"\"\"\n        
super().__init__(segments=segments)\n        if isinstance(exp_name, InternalData):\n            self.internal_data = exp_name\n        else:\n            self.internal_data = InternalData(task_tpl, step=step, exp_name=exp_name)\n            self.internal_data.setup()\n        self.task_tpl = deepcopy(task_tpl)  # FIXME: if the handler is shared, how can we avoid a memory explosion?\n        self.trunc_days = trunc_days\n        self.hist_step_n = hist_step_n\n        self.step = step\n\n        if isinstance(task_tpl, dict):\n            rg = RollingGen(\n                step=step, trunc_days=trunc_days, task_copy_func=deepcopy_basic_type\n            )  # NOTE: trunc_days is very important to avoid leaking future information\n            task_iter = rg(task_tpl)\n            if rolling_ext_days > 0:\n                self.ta = TimeAdjuster(future=True)\n                for t in task_iter:\n                    t[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"] = self.ta.shift(\n                        t[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"], step=rolling_ext_days, rtype=RollingGen.ROLL_EX\n                    )\n            if task_mode == MetaTask.PROC_MODE_FULL:\n                # Only pre-initialize the task when the full task is required:\n                # initialize the handler once and share it.\n                init_task_handler(task_tpl)\n        else:\n            assert isinstance(task_tpl, list)\n            task_iter = task_tpl\n\n        self.task_list = []\n        self.meta_task_l = []\n        logger = get_module_logger(\"MetaDatasetDS\")\n        logger.info(f\"Example task for training meta model: {task_iter[0]}\")\n        for t in tqdm(task_iter, desc=\"creating meta tasks\"):\n            try:\n                self.meta_task_l.append(\n                    MetaTaskDS(t, meta_info=self._prepare_meta_ipt(t), mode=task_mode, fill_method=fill_method)\n                )\n                self.task_list.append(t)\n            except ValueError as e:\n                
logger.warning(f\"ValueError: {e}\")\n        assert len(self.meta_task_l) > 0, \"No meta tasks found. Please check the data and settings\"\n\n    def _prepare_meta_ipt(self, task) -> pd.DataFrame:\n        \"\"\"\n        Please refer to `self.internal_data.setup` for detailed information about `self.internal_data.data_ic_df`\n\n        Column indices in the format below can be sliced successfully by `ic_df.loc[:end, pd.IndexSlice[:, :end]]`\n\n               2021-06-21 2021-06-04 .. 2021-03-22 2021-03-08\n               2021-07-02 2021-06-18 .. 2021-04-02 None\n\n        Returns\n        -------\n            a pd.DataFrame with content similar to the example below.\n            - each column corresponds to a trained model named by its training data range\n            - each row corresponds to a day of data tested by the models of the columns\n            - cells whose rows overlap with the data used to train the column's model are masked\n\n\n                       2009-01-05 2009-02-09 ... 2011-04-27 2011-05-26\n                       2009-02-06 2009-03-06 ... 2011-05-25 2011-06-23\n            datetime                         ...\n            2009-01-13        NaN   0.310639 ...  -0.169057   0.137792\n            2009-01-14        NaN   0.261086 ...  -0.143567   0.082581\n            ...               ...        ... ...        ...        ...\n            2011-06-30  -0.054907  -0.020219 ...  -0.023226        NaN\n            2011-07-01  -0.075762  -0.026626 ...  
-0.003167        NaN\n\n        \"\"\"\n        ic_df = self.internal_data.data_ic_df\n\n        segs = task[\"dataset\"][\"kwargs\"][\"segments\"]\n        end = max(segs[k][1] for k in (\"train\", \"valid\") if k in segs)\n        ic_df_avail = ic_df.loc[:end, pd.IndexSlice[:, :end]]\n\n        # the meta dataset focuses on the **information** rather than on preprocessing\n        # 1) filter out the overlapping info\n        def mask_overlap(s):\n            \"\"\"\n            mask overlapping information\n            data within `self.trunc_days` after the end of the column's training range (`s.name[1]`) contains future info and is also treated as overlapping\n\n            Approximately the diagonal plus a horizon-length band of data is masked.\n            \"\"\"\n            start, end = s.name\n            end = get_date_by_shift(trading_date=end, shift=self.trunc_days - 1, future=True)\n            return s.mask((s.index >= start) & (s.index <= end))\n\n        ic_df_avail = ic_df_avail.apply(mask_overlap)  # apply to each col\n\n        # 2) keep only the latest `step * hist_step_n` rows\n        total_len = self.step * self.hist_step_n\n        if ic_df_avail.shape[0] >= total_len:\n            return ic_df_avail.iloc[-total_len:]\n        else:\n            raise ValueError(\"the history of distribution data is not long enough.\")\n\n    def _prepare_seg(self, segment: Text) -> List[MetaTask]:\n        if isinstance(self.segments, float):\n            train_task_n = int(len(self.meta_task_l) * self.segments)\n            if segment == \"train\":\n                train_tasks = self.meta_task_l[:train_task_n]\n                get_module_logger(\"MetaDatasetDS\").info(f\"The first train meta task: {train_tasks[0]}\")\n                return train_tasks\n            elif segment == \"test\":\n                test_tasks = self.meta_task_l[train_task_n:]\n                get_module_logger(\"MetaDatasetDS\").info(f\"The first test meta task: {test_tasks[0]}\")\n                return test_tasks\n            else:\n                raise 
NotImplementedError(f\"This type of input is not supported\")\n        elif isinstance(self.segments, str):\n            train_tasks = []\n            test_tasks = []\n            for t in self.meta_task_l:\n                test_end = t.task[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"][1]\n                if test_end is None or pd.Timestamp(test_end) < pd.Timestamp(self.segments):\n                    train_tasks.append(t)\n                else:\n                    test_tasks.append(t)\n            get_module_logger(\"MetaDatasetDS\").info(f\"The first train meta task: {train_tasks[0]}\")\n            get_module_logger(\"MetaDatasetDS\").info(f\"The first test meta task: {test_tasks[0]}\")\n            if segment == \"train\":\n                return train_tasks\n            elif segment == \"test\":\n                return test_tasks\n            raise NotImplementedError(f\"This type of input is not supported\")\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n"
  },
  {
    "path": "qlib/contrib/meta/data_selection/model.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport pandas as pd\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch import optim\nfrom tqdm.auto import tqdm\nimport copy\nfrom typing import Union, List\n\nfrom ....model.meta.dataset import MetaTaskDataset\nfrom ....model.meta.model import MetaTaskModel\nfrom ....workflow import R\nfrom .utils import ICLoss\nfrom .dataset import MetaDatasetDS\n\nfrom qlib.log import get_module_logger\nfrom qlib.model.meta.task import MetaTask\nfrom qlib.data.dataset.weight import Reweighter\nfrom qlib.contrib.meta.data_selection.net import PredNet\n\nlogger = get_module_logger(\"data selection\")\n\n\nclass TimeReweighter(Reweighter):\n    def __init__(self, time_weight: pd.Series):\n        self.time_weight = time_weight\n\n    def reweight(self, data: Union[pd.DataFrame, pd.Series]):\n        # TODO: handling TSDataSampler\n        w_s = pd.Series(1.0, index=data.index)\n        for k, w in self.time_weight.items():\n            w_s.loc[slice(*k)] = w\n        logger.info(f\"Reweighting result: {w_s}\")\n        return w_s\n\n\nclass MetaModelDS(MetaTaskModel):\n    \"\"\"\n    The meta-model for meta-learning-based data selection.\n    \"\"\"\n\n    def __init__(\n        self,\n        step,\n        hist_step_n,\n        clip_method=\"tanh\",\n        clip_weight=2.0,\n        criterion=\"ic_loss\",\n        lr=0.0001,\n        max_epoch=100,\n        seed=43,\n        alpha=0.0,\n        loss_skip_thresh=50,\n    ):\n        \"\"\"\n        loss_skip_thresh: int\n            Days with fewer samples than this threshold are skipped when calculating the IC loss.\n        \"\"\"\n        self.step = step\n        self.hist_step_n = hist_step_n\n        self.clip_method = clip_method\n        self.clip_weight = clip_weight\n        self.criterion = criterion\n        self.lr = lr\n        self.max_epoch = max_epoch\n        self.fitted = False\n        self.alpha = alpha\n        
self.loss_skip_thresh = loss_skip_thresh\n        torch.manual_seed(seed)\n\n    def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):\n        if phase == \"train\":\n            self.tn.train()\n            torch.set_grad_enabled(True)\n        else:\n            self.tn.eval()\n            torch.set_grad_enabled(False)\n        running_loss = 0.0\n        pred_y_all = []\n        for task in tqdm(task_list, desc=f\"{phase} Task\", leave=False):\n            meta_input = task.get_meta_input()\n            pred, weights = self.tn(\n                meta_input[\"X\"],\n                meta_input[\"y\"],\n                meta_input[\"time_perf\"],\n                meta_input[\"time_belong\"],\n                meta_input[\"X_test\"],\n                ignore_weight=ignore_weight,\n            )\n            if self.criterion == \"mse\":\n                criterion = nn.MSELoss()\n                loss = criterion(pred, meta_input[\"y_test\"])\n            elif self.criterion == \"ic_loss\":\n                criterion = ICLoss(self.loss_skip_thresh)\n                try:\n                    loss = criterion(pred, meta_input[\"y_test\"], meta_input[\"test_idx\"])\n                except ValueError as e:\n                    get_module_logger(\"MetaModelDS\").warning(f\"Exception `{e}` when calculating IC loss\")\n                    continue\n            else:\n                raise ValueError(f\"Unknown criterion: {self.criterion}\")\n\n            assert not np.isnan(loss.detach().item()), \"NaN loss!\"\n\n            if phase == \"train\":\n                opt.zero_grad()\n                loss.backward()\n                opt.step()\n            elif phase == \"test\":\n                pass\n\n            pred_y_all.append(\n                pd.DataFrame(\n                    {\n                        \"pred\": pd.Series(pred.detach().cpu().numpy(), index=meta_input[\"test_idx\"]),\n                        \"label\": 
pd.Series(meta_input[\"y_test\"].detach().cpu().numpy(), index=meta_input[\"test_idx\"]),\n                    }\n                )\n            )\n            running_loss += loss.detach().item()\n        running_loss = running_loss / len(task_list)\n        loss_l.setdefault(phase, []).append(running_loss)\n\n        pred_y_all = pd.concat(pred_y_all)\n        ic = (\n            pred_y_all.groupby(\"datetime\", group_keys=False)\n            .apply(lambda df: df[\"pred\"].corr(df[\"label\"], method=\"spearman\"))\n            .mean()\n        )\n\n        R.log_metrics(**{f\"loss/{phase}\": running_loss, \"step\": epoch})\n        R.log_metrics(**{f\"ic/{phase}\": ic, \"step\": epoch})\n\n    def fit(self, meta_dataset: MetaDatasetDS):\n        \"\"\"\n        The meta-learning-based data selection interacts directly with the meta-dataset due to the closed-form proxy measurement.\n\n        Parameters\n        ----------\n        meta_dataset : MetaDatasetDS\n            The meta-model takes the meta-dataset for its training process.\n        \"\"\"\n\n        if not self.fitted:\n            for k in set([\"lr\", \"step\", \"hist_step_n\", \"clip_method\", \"clip_weight\", \"criterion\", \"max_epoch\"]):\n                R.log_params(**{k: getattr(self, k)})\n\n        # FIXME: the test tasks are fetched only to check the performance\n        phases = [\"train\", \"test\"]\n        meta_tasks_l = meta_dataset.prepare_tasks(phases)\n\n        if len(meta_tasks_l[1]):\n            R.log_params(\n                **dict(proxy_test_begin=meta_tasks_l[1][0].task[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"])\n            )  # debug: record when the test phase starts\n\n        self.tn = PredNet(\n            step=self.step,\n            hist_step_n=self.hist_step_n,\n            clip_weight=self.clip_weight,\n            clip_method=self.clip_method,\n            alpha=self.alpha,\n        )\n\n        opt = optim.Adam(self.tn.parameters(), lr=self.lr)\n\n        
# evaluate once without weights and once with the initial weights; phases other than \"train\" do not update the model\n        for phase, task_list in zip(phases, meta_tasks_l):\n            self.run_epoch(f\"{phase}_noweight\", task_list, 0, opt, {}, ignore_weight=True)\n            self.run_epoch(f\"{phase}_init\", task_list, 0, opt, {})\n\n        # run training\n        loss_l = {}\n        for epoch in tqdm(range(self.max_epoch), desc=\"epoch\"):\n            for phase, task_list in zip(phases, meta_tasks_l):\n                self.run_epoch(phase, task_list, epoch, opt, loss_l)\n            R.save_objects(**{\"model.pkl\": self.tn})\n        self.fitted = True\n\n    def _prepare_task(self, task: MetaTask) -> dict:\n        meta_ipt = task.get_meta_input()\n        weights = self.tn.twm(meta_ipt[\"time_perf\"])\n\n        weight_s = pd.Series(weights.detach().cpu().numpy(), index=task.meta_info.columns)\n        task = copy.copy(task.task)  # NOTE: this is a shallow copy.\n        task[\"reweighter\"] = TimeReweighter(weight_s)\n        return task\n\n    def inference(self, meta_dataset: MetaTaskDataset) -> List[dict]:\n        res = []\n        for mt in meta_dataset.prepare_tasks(\"test\"):\n            res.append(self._prepare_task(mt))\n        return res\n"
  },
  {
    "path": "qlib/contrib/meta/data_selection/net.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nfrom .utils import preds_to_weight_with_clamp, SingleMetaBase\n\n\nclass TimeWeightMeta(SingleMetaBase):\n    def __init__(self, hist_step_n, clip_weight=None, clip_method=\"clamp\"):\n        # clip_method can be \"tanh\" or \"clamp\"\n        super().__init__(hist_step_n, clip_weight, clip_method)\n        self.linear = nn.Linear(hist_step_n, 1)\n        self.k = nn.Parameter(torch.Tensor([8.0]))\n\n    def forward(self, time_perf, time_belong=None, return_preds=False):\n        hist_step_n = self.linear.in_features\n        # NOTE: the reshape order is very important: rows are grouped into hist_step_n consecutive blocks, and each block is averaged\n        time_perf = time_perf.reshape(hist_step_n, time_perf.shape[0] // hist_step_n, *time_perf.shape[1:])\n        time_perf = torch.mean(time_perf, dim=1, keepdim=False)\n\n        preds = []\n        for i in range(time_perf.shape[1]):\n            preds.append(self.linear(time_perf[:, i]))\n        preds = torch.cat(preds)\n        preds = preds - torch.mean(preds)  # avoid using future information\n        preds = preds * self.k\n        if return_preds:\n            if time_belong is None:\n                return preds\n            else:\n                return time_belong @ preds\n        else:\n            weights = preds_to_weight_with_clamp(preds, self.clip_weight, self.clip_method)\n            if time_belong is None:\n                return weights\n            else:\n                return time_belong @ weights\n\n\nclass PredNet(nn.Module):\n    def __init__(self, step, hist_step_n, clip_weight=None, clip_method=\"tanh\", alpha: float = 0.0):\n        \"\"\"\n        Parameters\n        ----------\n        alpha : float\n            the L2 regularization of the linear sub-model (useful when aligning the meta model with a linear sub-model)\n        \"\"\"\n        
super().__init__()\n        self.step = step\n        self.twm = 
TimeWeightMeta(hist_step_n=hist_step_n, clip_weight=clip_weight, clip_method=clip_method)\n        self.init_paramters(hist_step_n)\n        self.alpha = alpha\n\n    def get_sample_weights(self, X, time_perf, time_belong, ignore_weight=False):\n        weights = torch.from_numpy(np.ones(X.shape[0])).float().to(X.device)\n        if not ignore_weight:\n            if time_perf is not None:\n                weights_t = self.twm(time_perf, time_belong)\n                weights = weights * weights_t\n        return weights\n\n    def forward(self, X, y, time_perf, time_belong, X_test, ignore_weight=False):\n        \"\"\"Please refer to the docs of MetaTaskDS for the description of the variables\"\"\"\n        weights = self.get_sample_weights(X, time_perf, time_belong, ignore_weight=ignore_weight)\n        X_w = X.T * weights.view(1, -1)\n        # closed-form weighted ridge regression: theta = (X^T W X + alpha * I)^-1 X^T W y\n        eye = torch.eye(X_w.shape[0], device=X_w.device, dtype=X_w.dtype)\n        theta = torch.inverse(X_w @ X + self.alpha * eye) @ X_w @ y\n        return X_test @ theta, weights\n\n    def init_paramters(self, hist_step_n):\n        self.twm.linear.weight.data = 1.0 / hist_step_n + self.twm.linear.weight.data * 0.01\n        self.twm.linear.bias.data.fill_(0.0)\n"
  },
  {
    "path": "qlib/contrib/meta/data_selection/utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nfrom qlib.constant import EPS\nfrom qlib.log import get_module_logger\n\n\nclass ICLoss(nn.Module):\n    def __init__(self, skip_size=50):\n        super().__init__()\n        self.skip_size = skip_size\n\n    def forward(self, pred, y, idx):\n        \"\"\"forward.\n        FIXME:\n        - Sometimes the result is slightly different from that of `pandas.corr()`\n        - Part of the gap comes from `std()` being unbiased (dividing by n - 1) while the products are scaled by 1 / n,\n          which scales the IC by (n - 1) / n; numerical precision may also contribute.\n\n        :param pred:\n        :param y:\n        :param idx: an index whose levels are assumed to be (date, inst), sorted by date\n        \"\"\"\n        prev = None\n        diff_point = []\n        for i, (date, inst) in enumerate(idx):\n            if date != prev:\n                diff_point.append(i)\n            prev = date\n        diff_point.append(None)\n        # diff_point now has one more element than the number of days (the trailing None marks the end)\n\n        ic_all = 0.0\n        skip_n = 0\n        for start_i, end_i in zip(diff_point, diff_point[1:]):\n            pred_focus = pred[start_i:end_i]  # TODO: just for fake\n            if pred_focus.shape[0] < self.skip_size:\n                # skip days with too few stocks\n                skip_n += 1\n                continue\n            y_focus = y[start_i:end_i]\n            if pred_focus.std() < EPS or y_focus.std() < EPS:\n                # This often happens at the end of the test data,\n                # usually because of fillna(0.)\n                skip_n += 1\n                continue\n\n            ic_day = torch.dot(\n                (pred_focus - pred_focus.mean()) / np.sqrt(pred_focus.shape[0]) / pred_focus.std(),\n                (y_focus - 
y_focus.mean()) / np.sqrt(y_focus.shape[0]) / y_focus.std(),\n            )\n            ic_all += ic_day\n        if len(diff_point) - 1 - skip_n <= 0:\n            
raise ValueError(\"Not enough data to calculate IC\")\n        if skip_n > 0:\n            get_module_logger(\"ICLoss\").info(\n                f\"{skip_n} days are skipped due to zero std or too few valid samples.\"\n            )\n        ic_mean = ic_all / (len(diff_point) - 1 - skip_n)\n        return -ic_mean  # ic loss\n\n\ndef preds_to_weight_with_clamp(preds, clip_weight=None, clip_method=\"tanh\"):\n    \"\"\"\n    Convert predictions into sample weights and optionally clip them.\n\n    Parameters\n    ----------\n    clip_weight: float\n        The clip threshold.\n    clip_method: str\n        The clip method. Currently available: \"clamp\", \"tanh\", and \"sigmoid\".\n    \"\"\"\n    if clip_weight is not None:\n        if clip_method == \"clamp\":\n            weights = torch.exp(preds)\n            weights = weights.clamp(1.0 / clip_weight, clip_weight)\n        elif clip_method == \"tanh\":\n            weights = torch.exp(torch.tanh(preds) * np.log(clip_weight))\n        elif clip_method == \"sigmoid\":\n            # the weights are normalized so that their mean is 1\n            if clip_weight == 0.0:\n                weights = torch.ones_like(preds)\n            else:\n                sm = nn.Sigmoid()\n                weights = sm(preds) * clip_weight  # TODO: The clip_weight is useless here.\n                weights = weights / torch.sum(weights) * weights.numel()\n        else:\n            raise ValueError(\"Unknown clip_method\")\n    else:\n        weights = torch.exp(preds)\n    return weights\n\n\nclass SingleMetaBase(nn.Module):\n    def __init__(self, hist_n, clip_weight=None, clip_method=\"clamp\"):\n        # clip_method can be tanh, clamp or sigmoid\n        super().__init__()\n        self.clip_weight = clip_weight\n        if clip_method in [\"tanh\", \"clamp\"]:\n            if self.clip_weight is not None and self.clip_weight < 1.0:\n                self.clip_weight = 1 / self.clip_weight\n        self.clip_method = clip_method\n\n    def 
is_enabled(self):\n        if self.clip_weight is None:\n            return True\n        if self.clip_method == \"sigmoid\":\n            if self.clip_weight > 0.0:\n                return True\n        else:\n            if self.clip_weight > 1.0:\n                return True\n        return False\n"
  },
  {
    "path": "qlib/contrib/model/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\ntry:\n    from .catboost_model import CatBoostModel\nexcept ModuleNotFoundError:\n    CatBoostModel = None\n    print(\"ModuleNotFoundError. CatBoostModel is skipped (optional: installing catboost may fix it).\")\ntry:\n    from .double_ensemble import DEnsembleModel\n    from .gbdt import LGBModel\nexcept ModuleNotFoundError:\n    DEnsembleModel, LGBModel = None, None\n    print(\n        \"ModuleNotFoundError. DEnsembleModel and LGBModel are skipped (optional: installing lightgbm may fix it).\"\n    )\ntry:\n    from .xgboost import XGBModel\nexcept ModuleNotFoundError:\n    XGBModel = None\n    print(\"ModuleNotFoundError. XGBModel is skipped (optional: installing xgboost may fix it).\")\ntry:\n    from .linear import LinearModel\nexcept ModuleNotFoundError:\n    LinearModel = None\n    print(\"ModuleNotFoundError. LinearModel is skipped (optional: installing scipy and sklearn may fix it).\")\n# import pytorch models\ntry:\n    from .pytorch_alstm import ALSTM\n    from .pytorch_gats import GATs\n    from .pytorch_gru import GRU\n    from .pytorch_lstm import LSTM\n    from .pytorch_nn import DNNModelPytorch\n    from .pytorch_tabnet import TabnetModel\n    from .pytorch_sfm import SFM_Model\n    from .pytorch_tcn import TCN\n    from .pytorch_add import ADD\n\n    pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN, ADD)\nexcept ModuleNotFoundError:\n    pytorch_classes = ()\n    print(\"ModuleNotFoundError. PyTorch models are skipped (optional: installing pytorch may fix it).\")\n\nall_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes\n"
  },
  {
    "path": "qlib/contrib/model/catboost_model.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nfrom catboost import Pool, CatBoost\nfrom catboost.utils import get_gpu_device_count\n\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...model.interpret.base import FeatureInt\nfrom ...data.dataset.weight import Reweighter\n\n\nclass CatBoostModel(Model, FeatureInt):\n    \"\"\"CatBoost Model\"\"\"\n\n    def __init__(self, loss=\"RMSE\", **kwargs):\n        # There are more options\n        if loss not in {\"RMSE\", \"Logloss\"}:\n            raise NotImplementedError\n        self._params = {\"loss_function\": loss}\n        self._params.update(kwargs)\n        self.model = None\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        num_boost_round=1000,\n        early_stopping_rounds=50,\n        verbose_eval=20,\n        evals_result=dict(),\n        reweighter=None,\n        **kwargs,\n    ):\n        df_train, df_valid = dataset.prepare(\n            [\"train\", \"valid\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        if df_train.empty or df_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n\n        # CatBoost needs 1D array as its label\n        if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:\n            y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)\n        else:\n            raise ValueError(\"CatBoost doesn't support multi-label training\")\n\n        if reweighter is None:\n            w_train = None\n            w_valid = None\n        elif isinstance(reweighter, Reweighter):\n            
w_train = reweighter.reweight(df_train).values\n            w_valid = reweighter.reweight(df_valid).values\n        else:\n            raise ValueError(\"Unsupported reweighter type.\")\n\n        train_pool = Pool(data=x_train, label=y_train_1d, weight=w_train)\n        valid_pool = Pool(data=x_valid, label=y_valid_1d, weight=w_valid)\n\n        # Initialize the catboost model\n        self._params[\"iterations\"] = num_boost_round\n        self._params[\"early_stopping_rounds\"] = early_stopping_rounds\n        self._params[\"verbose_eval\"] = verbose_eval\n        self._params[\"task_type\"] = \"GPU\" if get_gpu_device_count() > 0 else \"CPU\"\n        self.model = CatBoost(self._params, **kwargs)\n\n        # train the model\n        self.model.fit(train_pool, eval_set=valid_pool, use_best_model=True, **kwargs)\n\n        # fill the caller's `evals_result` dict instead of rebinding the local name\n        result = self.model.get_evals_result()\n        evals_result[\"train\"] = list(result[\"learn\"].values())[0]\n        evals_result[\"valid\"] = list(result[\"validation\"].values())[0]\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if self.model is None:\n            raise ValueError(\"model is not fitted yet!\")\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        return pd.Series(self.model.predict(x_test.values), index=x_test.index)\n\n    def get_feature_importance(self, *args, **kwargs) -> pd.Series:\n        \"\"\"get feature importance\n\n        Notes\n        -----\n            parameters references:\n            https://catboost.ai/docs/concepts/python-reference_catboost_get_feature_importance.html#python-reference_catboost_get_feature_importance\n        \"\"\"\n        return pd.Series(\n            data=self.model.get_feature_importance(*args, **kwargs), index=self.model.feature_names_\n        ).sort_values(ascending=False)\n\n\nif __name__ == \"__main__\":\n    cat = CatBoostModel()\n"
  },
  {
    "path": "qlib/contrib/model/double_ensemble.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport lightgbm as lgb\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...model.interpret.base import FeatureInt\nfrom ...log import get_module_logger\n\n\nclass DEnsembleModel(Model, FeatureInt):\n    \"\"\"Double Ensemble Model\"\"\"\n\n    def __init__(\n        self,\n        base_model=\"gbm\",\n        loss=\"mse\",\n        num_models=6,\n        enable_sr=True,\n        enable_fs=True,\n        alpha1=1.0,\n        alpha2=1.0,\n        bins_sr=10,\n        bins_fs=5,\n        decay=None,\n        sample_ratios=None,\n        sub_weights=None,\n        epochs=100,\n        early_stopping_rounds=None,\n        **kwargs,\n    ):\n        self.base_model = base_model  # \"gbm\" or \"mlp\", specifically, we use lgbm for \"gbm\"\n        self.num_models = num_models  # the number of sub-models\n        self.enable_sr = enable_sr\n        self.enable_fs = enable_fs\n        self.alpha1 = alpha1\n        self.alpha2 = alpha2\n        self.bins_sr = bins_sr\n        self.bins_fs = bins_fs\n        self.decay = decay\n        if sample_ratios is None:  # the default values for sample_ratios\n            sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]\n        if sub_weights is None:  # the default values for sub_weights\n            sub_weights = [1] * self.num_models\n        if not len(sample_ratios) == bins_fs:\n            raise ValueError(\"The length of sample_ratios should be equal to bins_fs.\")\n        self.sample_ratios = sample_ratios\n        if not len(sub_weights) == num_models:\n            raise ValueError(\"The length of sub_weights should be equal to num_models.\")\n        self.sub_weights = sub_weights\n        self.epochs = epochs\n        self.logger = get_module_logger(\"DEnsembleModel\")\n        
self.logger.info(\"Double Ensemble Model...\")\n        self.ensemble = []  # the current ensemble model, a list containing all the sub-models\n        self.sub_features = []  # the features for each sub model in the form of pandas.Index\n        self.params = {\"objective\": loss}\n        self.params.update(kwargs)\n        self.loss = loss\n        self.early_stopping_rounds = early_stopping_rounds\n\n    def fit(self, dataset: DatasetH):\n        df_train, df_valid = dataset.prepare(\n            [\"train\", \"valid\"], col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L\n        )\n        if df_train.empty or df_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        # initialize the sample weights\n        N, F = x_train.shape\n        weights = pd.Series(np.ones(N, dtype=float))\n        # initialize the features\n        features = x_train.columns\n        pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index)\n        # train sub-models\n        for k in range(self.num_models):\n            self.sub_features.append(features)\n            self.logger.info(\"Training sub-model: ({}/{})\".format(k + 1, self.num_models))\n            model_k = self.train_submodel(df_train, df_valid, weights, features)\n            self.ensemble.append(model_k)\n            # no further sample re-weighting or feature selection is needed for the last sub-model\n            if k + 1 == self.num_models:\n                break\n\n            self.logger.info(\"Retrieving loss curve and loss values...\")\n            loss_curve = self.retrieve_loss_curve(model_k, df_train, features)\n            pred_k = self.predict_sub(model_k, df_train, features)\n            pred_sub.iloc[:, k] = pred_k\n            pred_ensemble = (pred_sub.iloc[:, : k + 1] * self.sub_weights[0 : k + 1]).sum(axis=1) / np.sum(\n                
self.sub_weights[0 : k + 1]\n            )\n            loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values))\n\n            if self.enable_sr:\n                self.logger.info(\"Sample re-weighting...\")\n                weights = self.sample_reweight(loss_curve, loss_values, k + 1)\n\n            if self.enable_fs:\n                self.logger.info(\"Feature selection...\")\n                features = self.feature_selection(df_train, loss_values)\n\n    def train_submodel(self, df_train, df_valid, weights, features):\n        dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features)\n        evals_result = dict()\n\n        callbacks = [lgb.log_evaluation(20), lgb.record_evaluation(evals_result)]\n        if self.early_stopping_rounds:\n            callbacks.append(lgb.early_stopping(self.early_stopping_rounds))\n            self.logger.info(\"Training with early_stopping...\")\n\n        model = lgb.train(\n            self.params,\n            dtrain,\n            num_boost_round=self.epochs,\n            valid_sets=[dtrain, dvalid],\n            valid_names=[\"train\", \"valid\"],\n            callbacks=callbacks,\n        )\n        evals_result[\"train\"] = list(evals_result[\"train\"].values())[0]\n        evals_result[\"valid\"] = list(evals_result[\"valid\"].values())[0]\n        return model\n\n    def _prepare_data_gbm(self, df_train, df_valid, weights, features):\n        x_train, y_train = df_train[\"feature\"].loc[:, features], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"].loc[:, features], df_valid[\"label\"]\n\n        # LightGBM needs a 1D array as its label\n        if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:\n            y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)\n        else:\n            raise ValueError(\"LightGBM doesn't support multi-label training\")\n\n        dtrain = lgb.Dataset(x_train, label=y_train, 
weight=weights)\n        dvalid = lgb.Dataset(x_valid, label=y_valid)\n        return dtrain, dvalid\n\n    def sample_reweight(self, loss_curve, loss_values, k_th):\n        \"\"\"\n        the SR module of Double Ensemble\n        :param loss_curve: the shape is NxT\n        the loss curve for the previous sub-model, where the element (i, t) is the error on the i-th sample\n        after the t-th iteration in the training of the previous sub-model.\n        :param loss_values: the shape is N\n        the loss of the current ensemble, where the i-th element is the loss on the i-th sample.\n        :param k_th: the index of the current sub-model, starting from 1\n        :return: weights\n        the weights for all the samples.\n        \"\"\"\n        # normalize loss_curve and loss_values with ranking\n        loss_curve_norm = loss_curve.rank(axis=0, pct=True)\n        loss_values_norm = (-loss_values).rank(pct=True)\n\n        # calculate l_start and l_end from loss_curve\n        N, T = loss_curve.shape\n        part = np.maximum(int(T * 0.1), 1)\n        l_start = loss_curve_norm.iloc[:, :part].mean(axis=1)\n        l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1)\n\n        # calculate h-value for each sample\n        h1 = loss_values_norm\n        h2 = (l_end / l_start).rank(pct=True)\n        h = pd.DataFrame({\"h_value\": self.alpha1 * h1 + self.alpha2 * h2})\n\n        # calculate weights\n        h[\"bins\"] = pd.cut(h[\"h_value\"], self.bins_sr)\n        h_avg = h.groupby(\"bins\", group_keys=False, observed=False)[\"h_value\"].mean()\n        weights = pd.Series(np.zeros(N, dtype=float))\n        for b in h_avg.index:\n            weights[h[\"bins\"] == b] = 1.0 / (self.decay**k_th * h_avg[b] + 0.1)\n        return weights\n\n    def feature_selection(self, df_train, loss_values):\n        \"\"\"\n        the FS module of Double Ensemble\n        :param df_train: the shape is NxF\n        :param loss_values: the shape is N\n        the loss of the current ensemble on the i-th 
sample.\n        :return: res_feat: in the form of pandas.Index\n\n        \"\"\"\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        features = x_train.columns\n        N, F = x_train.shape\n        g = pd.DataFrame({\"g_value\": np.zeros(F, dtype=float)})\n        M = len(self.ensemble)\n\n        # shuffle each feature column in turn and calculate its g-value\n        x_train_tmp = x_train.copy()\n        for i_f, feat in enumerate(features):\n            x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values)\n            pred = pd.Series(np.zeros(N), index=x_train_tmp.index)\n            for i_s, submodel in enumerate(self.ensemble):\n                pred += (\n                    pd.Series(\n                        submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index\n                    )\n                    / M\n                )\n            loss_feat = self.get_loss(y_train.values.squeeze(), pred.values)\n            g.loc[i_f, \"g_value\"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7)\n            x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy()\n\n        # a g-value can be NaN, e.g. when a training feature column is all-NaN; replace NaN with 0\n        g[\"g_value\"] = g[\"g_value\"].fillna(0)\n\n        # divide features into bins_fs bins\n        g[\"bins\"] = pd.cut(g[\"g_value\"], self.bins_fs)\n\n        # randomly sample features from bins to construct the new features\n        res_feat = []\n        sorted_bins = sorted(g[\"bins\"].unique(), reverse=True)\n        for i_b, b in enumerate(sorted_bins):\n            b_feat = features[g[\"bins\"] == b]\n            num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat)))\n            res_feat = res_feat + np.random.choice(b_feat, size=num_feat, replace=False).tolist()\n        return pd.Index(set(res_feat))\n\n    def get_loss(self, label, pred):\n        if self.loss == 
\"mse\":\n            return (label - pred) ** 2\n        else:\n            raise ValueError(\"not implemented yet\")\n\n    def retrieve_loss_curve(self, model, df_train, features):\n        if self.base_model == \"gbm\":\n            num_trees = model.num_trees()\n            x_train, y_train = df_train[\"feature\"].loc[:, features], df_train[\"label\"]\n            # LightGBM needs a 1D array as its label\n            if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:\n                y_train = np.squeeze(y_train.values)\n            else:\n                raise ValueError(\"LightGBM doesn't support multi-label training\")\n\n            N = x_train.shape[0]\n            loss_curve = pd.DataFrame(np.zeros((N, num_trees)))\n            pred_tree = np.zeros(N, dtype=float)\n            for i_tree in range(num_trees):\n                pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1)\n                loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree)\n        else:\n            raise ValueError(\"not implemented yet\")\n        return loss_curve\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.ensemble:\n            raise ValueError(\"model is not fitted yet!\")\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)\n        for i_sub, submodel in enumerate(self.ensemble):\n            feat_sub = self.sub_features[i_sub]\n            pred += (\n                pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index)\n                * self.sub_weights[i_sub]\n            )\n        pred = pred / np.sum(self.sub_weights)\n        return pred\n\n    def predict_sub(self, submodel, df_data, features):\n        x_data = df_data[\"feature\"].loc[:, features]\n        pred_sub = pd.Series(submodel.predict(x_data.values), 
index=x_data.index)\n        return pred_sub\n\n    def get_feature_importance(self, *args, **kwargs) -> pd.Series:\n        \"\"\"get feature importance\n\n        Notes\n        -----\n            parameters reference:\n            https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance\n        \"\"\"\n        res = []\n        for _model, _weight in zip(self.ensemble, self.sub_weights):\n            res.append(pd.Series(_model.feature_importance(*args, **kwargs), index=_model.feature_name()) * _weight)\n        return pd.concat(res, axis=1, sort=False).sum(axis=1).sort_values(ascending=False)\n"
  },
  {
    "path": "qlib/contrib/model/gbdt.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport pandas as pd\nimport lightgbm as lgb\nfrom typing import List, Text, Tuple, Union\nfrom ...model.base import ModelFT\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...model.interpret.base import LightGBMFInt\nfrom ...data.dataset.weight import Reweighter\nfrom qlib.workflow import R\n\n\nclass LGBModel(ModelFT, LightGBMFInt):\n    \"\"\"LightGBM Model\"\"\"\n\n    def __init__(self, loss=\"mse\", early_stopping_rounds=50, num_boost_round=1000, **kwargs):\n        if loss not in {\"mse\", \"binary\"}:\n            raise NotImplementedError\n        self.params = {\"objective\": loss, \"verbosity\": -1}\n        self.params.update(kwargs)\n        self.early_stopping_rounds = early_stopping_rounds\n        self.num_boost_round = num_boost_round\n        self.model = None\n\n    def _prepare_data(self, dataset: DatasetH, reweighter=None) -> List[Tuple[lgb.Dataset, str]]:\n        \"\"\"\n        The current version makes validation optional:\n        - the train segment is required;\n        - the valid segment is used only when it exists.\n        \"\"\"\n        ds_l = []\n        assert \"train\" in dataset.segments\n        for key in [\"train\", \"valid\"]:\n            if key in dataset.segments:\n                df = dataset.prepare(key, col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n                if df.empty:\n                    raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n                x, y = df[\"feature\"], df[\"label\"]\n\n                # LightGBM needs a 1D array as its label\n                if y.values.ndim == 2 and y.values.shape[1] == 1:\n                    y = np.squeeze(y.values)\n                else:\n                    raise ValueError(\"LightGBM doesn't support multi-label training\")\n\n                if reweighter is None:\n                    w = None\n        
        elif isinstance(reweighter, Reweighter):\n                    w = reweighter.reweight(df)\n                else:\n                    raise ValueError(\"Unsupported reweighter type.\")\n                ds_l.append((lgb.Dataset(x.values, label=y, weight=w, free_raw_data=False), key))\n        return ds_l\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        num_boost_round=None,\n        early_stopping_rounds=None,\n        verbose_eval=20,\n        evals_result=None,\n        reweighter=None,\n        **kwargs,\n    ):\n        if evals_result is None:\n            evals_result = {}  # avoid a shared mutable default argument\n        ds_l = self._prepare_data(dataset, reweighter)\n        ds, names = list(zip(*ds_l))\n        early_stopping_callback = lgb.early_stopping(\n            self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds\n        )\n        # NOTE: if you encounter an error here, please upgrade LightGBM\n        verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)\n        evals_result_callback = lgb.record_evaluation(evals_result)\n        self.model = lgb.train(\n            self.params,\n            ds[0],  # training dataset\n            num_boost_round=self.num_boost_round if num_boost_round is None else num_boost_round,\n            valid_sets=ds,\n            valid_names=names,\n            callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],\n            **kwargs,\n        )\n        for k in names:\n            for key, val in evals_result[k].items():\n                name = f\"{key}.{k}\"\n                for epoch, m in enumerate(val):\n                    R.log_metrics(**{name.replace(\"@\", \"_\"): m}, step=epoch)\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if self.model is None:\n            raise ValueError(\"model is not fitted yet!\")\n        x_test = dataset.prepare(segment, 
col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        return pd.Series(self.model.predict(x_test.values), index=x_test.index)\n\n    def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20, reweighter=None):\n        \"\"\"\n        finetune model\n\n        Parameters\n        ----------\n        dataset : DatasetH\n            dataset for finetuning\n        num_boost_round : int\n            number of boosting rounds used for finetuning\n        verbose_eval : int\n            period (in boosting rounds) for logging evaluation results\n        reweighter : Reweighter, optional\n            sample reweighter applied to the finetuning data; no reweighting if None\n        \"\"\"\n        # Finetune the existing model by training additional boosting rounds\n        ds_l = self._prepare_data(dataset, reweighter)\n        dtrain, _ = ds_l[0]\n\n        if dtrain.construct().num_data() == 0:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n        verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)\n        self.model = lgb.train(\n            self.params,\n            dtrain,\n            num_boost_round=num_boost_round,\n            init_model=self.model,\n            valid_sets=[dtrain],\n            valid_names=[\"train\"],\n            callbacks=[verbose_eval_callback],\n        )\n"
  },
  {
    "path": "qlib/contrib/model/highfreq_gdbt_model.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport warnings\nimport numpy as np\nimport pandas as pd\nimport lightgbm as lgb\n\nfrom ...model.base import ModelFT\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...model.interpret.base import LightGBMFInt\n\n\nclass HFLGBModel(ModelFT, LightGBMFInt):\n    \"\"\"LightGBM Model for high-frequency prediction\"\"\"\n\n    def __init__(self, loss=\"mse\", **kwargs):\n        if loss not in {\"mse\", \"binary\"}:\n            raise NotImplementedError\n        self.params = {\"objective\": loss, \"verbosity\": -1}\n        self.params.update(kwargs)\n        self.model = None\n\n    def _cal_signal_metrics(self, y_test, l_cut, r_cut):\n        \"\"\"\n        Calculate the signal metrics at the daily level\n        \"\"\"\n        up_pre, down_pre = [], []\n        up_alpha_ll, down_alpha_ll = [], []\n        for date in y_test.index.get_level_values(0).unique():\n            df_res = y_test.loc[date].sort_values(\"pred\")\n            if int(l_cut * len(df_res)) < 10:\n                warnings.warn(\"Warning: threshold is too low or the number of instruments is insufficient\")\n                continue\n            top = df_res.iloc[: int(l_cut * len(df_res))]\n            bottom = df_res.iloc[int(r_cut * len(df_res)) :]\n\n            down_precision = len(top[top[top.columns[0]] < 0]) / (len(top))\n            up_precision = len(bottom[bottom[bottom.columns[0]] > 0]) / (len(bottom))\n\n            down_alpha = top[top.columns[0]].mean()\n            up_alpha = bottom[bottom.columns[0]].mean()\n\n            up_pre.append(up_precision)\n            down_pre.append(down_precision)\n            up_alpha_ll.append(up_alpha)\n            down_alpha_ll.append(down_alpha)\n\n        return (\n            np.array(up_pre).mean(),\n            np.array(down_pre).mean(),\n            np.array(up_alpha_ll).mean(),\n            
np.array(down_alpha_ll).mean(),\n        )\n\n    def hf_signal_test(self, dataset: DatasetH, threhold=0.2):\n        \"\"\"\n        Test the signal on the high-frequency test set\n        \"\"\"\n        if self.model is None:\n            raise ValueError(\"Model hasn't been trained yet\")\n        df_test = dataset.prepare(\"test\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_I)\n        df_test.dropna(inplace=True)\n        x_test, y_test = df_test[\"feature\"], df_test[\"label\"].copy()\n        # Convert label into alpha by subtracting the daily cross-sectional mean\n        y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[y_test.columns[0]].groupby(level=0).transform(\"mean\")\n\n        res = pd.Series(self.model.predict(x_test.values), index=x_test.index)\n        y_test[\"pred\"] = res\n\n        up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold)\n        print(\"===============================\")\n        print(\"High frequency signal test\")\n        print(\"===============================\")\n        print(\"Test set precision: \")\n        print(\"Positive precision: {}, Negative precision: {}\".format(up_p, down_p))\n        print(\"Average alpha in test set: \")\n        print(\"Positive average alpha: {}, Negative average alpha: {}\".format(up_a, down_a))\n\n    def _prepare_data(self, dataset: DatasetH):\n        df_train, df_valid = dataset.prepare(\n            [\"train\", \"valid\"], col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L\n        )\n        if df_train.empty or df_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n        if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:\n            l_name = df_train[\"label\"].columns[0]\n            # Convert label into alpha\n            df_train.loc[:, (\"label\", l_name)] = (\n     
           df_train.loc[:, (\"label\", l_name)]\n                - df_train.loc[:, (\"label\", l_name)].groupby(level=0, group_keys=False).mean()\n            )\n            df_valid.loc[:, (\"label\", l_name)] = (\n                df_valid.loc[:, (\"label\", l_name)]\n                - df_valid.loc[:, (\"label\", l_name)].groupby(level=0, group_keys=False).mean()\n            )\n\n            def mapping_fn(x):\n                return 0 if x < 0 else 1\n\n            df_train[\"label_c\"] = df_train[\"label\"][l_name].apply(mapping_fn)\n            df_valid[\"label_c\"] = df_valid[\"label\"][l_name].apply(mapping_fn)\n            x_train, y_train = df_train[\"feature\"], df_train[\"label_c\"].values\n            x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label_c\"].values\n        else:\n            raise ValueError(\"LightGBM doesn't support multi-label training\")\n\n        dtrain = lgb.Dataset(x_train, label=y_train)\n        dvalid = lgb.Dataset(x_valid, label=y_valid)\n        return dtrain, dvalid\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        num_boost_round=1000,\n        early_stopping_rounds=50,\n        verbose_eval=20,\n        evals_result=None,\n    ):\n        if evals_result is None:\n            evals_result = dict()\n        dtrain, dvalid = self._prepare_data(dataset)\n        early_stopping_callback = lgb.early_stopping(early_stopping_rounds)\n        verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)\n        evals_result_callback = lgb.record_evaluation(evals_result)\n        self.model = lgb.train(\n            self.params,\n            dtrain,\n            num_boost_round=num_boost_round,\n            valid_sets=[dtrain, dvalid],\n            valid_names=[\"train\", \"valid\"],\n            callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],\n        )\n        evals_result[\"train\"] = list(evals_result[\"train\"].values())[0]\n        evals_result[\"valid\"] = 
list(evals_result[\"valid\"].values())[0]\n\n    def predict(self, dataset):\n        if self.model is None:\n            raise ValueError(\"model is not fitted yet!\")\n        x_test = dataset.prepare(\"test\", col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        return pd.Series(self.model.predict(x_test.values), index=x_test.index)\n\n    def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):\n        \"\"\"\n        finetune model\n\n        Parameters\n        ----------\n        dataset : DatasetH\n            dataset for finetuning\n        num_boost_round : int\n            number of boosting rounds used for finetuning\n        verbose_eval : int\n            period (in boosting rounds) for logging evaluation results\n        \"\"\"\n        # Finetune the existing model by training additional boosting rounds\n        dtrain, _ = self._prepare_data(dataset)\n        verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)\n        self.model = lgb.train(\n            self.params,\n            dtrain,\n            num_boost_round=num_boost_round,\n            init_model=self.model,\n            valid_sets=[dtrain],\n            valid_names=[\"train\"],\n            callbacks=[verbose_eval_callback],\n        )\n"
  },
  {
    "path": "qlib/contrib/model/linear.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nfrom qlib.log import get_module_logger\nfrom qlib.data.dataset.weight import Reweighter\nfrom scipy.optimize import nnls\nfrom sklearn.linear_model import LinearRegression, Ridge, Lasso\n\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\n\n\nclass LinearModel(Model):\n    \"\"\"Linear Model\n\n    Solve one of the following regression problems:\n        - `ols`: min_w |y - Xw|^2_2\n        - `nnls`: min_w |y - Xw|^2_2, s.t. w >= 0\n        - `ridge`: min_w |y - Xw|^2_2 + alpha*|w|^2_2\n        - `lasso`: min_w 1/(2n)*|y - Xw|^2_2 + alpha*|w|_1  (n: number of samples)\n    where `w` is the regression coefficient.\n    \"\"\"\n\n    OLS = \"ols\"\n    NNLS = \"nnls\"\n    RIDGE = \"ridge\"\n    LASSO = \"lasso\"\n\n    def __init__(self, estimator=\"ols\", alpha=0.0, fit_intercept=False, include_valid: bool = False):\n        \"\"\"\n        Parameters\n        ----------\n        estimator : str\n            which estimator to use for linear regression\n        alpha : float\n            regularization strength (L2 for `ridge`, L1 for `lasso`)\n        fit_intercept : bool\n            whether to fit an intercept\n        include_valid : bool\n            whether to include the validation data (if it exists) in the training data\n        \"\"\"\n        assert estimator in [self.OLS, self.NNLS, self.RIDGE, self.LASSO], f\"unsupported estimator `{estimator}`\"\n        self.estimator = estimator\n\n        assert alpha == 0 or estimator in [self.RIDGE, self.LASSO], \"alpha is only supported in `ridge`&`lasso`\"\n        self.alpha = alpha\n\n        self.fit_intercept = fit_intercept\n\n        self.coef_ = None\n        self.include_valid = include_valid\n\n    def fit(self, dataset: DatasetH, reweighter: Reweighter = None):\n        df_train = 
dataset.prepare(\"train\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n        if self.include_valid:\n            try:\n                df_valid = dataset.prepare(\"valid\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n                df_train = pd.concat([df_train, df_valid])\n            except KeyError:\n                get_module_logger(\"LinearModel\").info(\"include_valid=True, but valid does not exist\")\n        df_train = df_train.dropna()\n        if df_train.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n        if reweighter is not None:\n            w: pd.Series = reweighter.reweight(df_train)\n            w = w.values\n        else:\n            w = None\n        X, y = df_train[\"feature\"].values, np.squeeze(df_train[\"label\"].values)\n\n        if self.estimator in [self.OLS, self.RIDGE, self.LASSO]:\n            self._fit(X, y, w)\n        elif self.estimator == self.NNLS:\n            self._fit_nnls(X, y, w)\n        else:\n            raise ValueError(f\"unknown estimator `{self.estimator}`\")\n\n        return self\n\n    def _fit(self, X, y, w):\n        if self.estimator == self.OLS:\n            model = LinearRegression(fit_intercept=self.fit_intercept, copy_X=False)\n        else:\n            model = {self.RIDGE: Ridge, self.LASSO: Lasso}[self.estimator](\n                alpha=self.alpha, fit_intercept=self.fit_intercept, copy_X=False\n            )\n        model.fit(X, y, sample_weight=w)\n        self.coef_ = model.coef_\n        self.intercept_ = model.intercept_\n\n    def _fit_nnls(self, X, y, w=None):\n        if w is not None:\n            raise NotImplementedError(\"TODO: support nnls with weight\")  # TODO\n        if self.fit_intercept:\n            X = np.c_[X, np.ones(len(X))]  # NOTE: mem copy\n        coef = nnls(X, y)[0]\n        if self.fit_intercept:\n            self.coef_ = coef[:-1]\n            self.intercept_ = coef[-1]\n 
       else:\n            self.coef_ = coef\n            self.intercept_ = 0.0\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if self.coef_ is None:\n            raise ValueError(\"model is not fitted yet!\")\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        return pd.Series(x_test.values @ self.coef_ + self.intercept_, index=x_test.index)\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_adarnn.py",
    "content": "# Copyright (c) Microsoft Corporation.\nimport os\nfrom torch.utils.data import Dataset, DataLoader\n\nimport copy\nfrom typing import Text, Union\n\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.autograd import Function\nfrom qlib.contrib.model.pytorch_utils import count_parameters\nfrom qlib.data.dataset import DatasetH\nfrom qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.log import get_module_logger\nfrom qlib.model.base import Model\nfrom qlib.utils import get_or_create_path\n\n\nclass ADARNN(Model):\n    \"\"\"ADARNN Model\n\n    Parameters\n    ----------\n    d_feat : int\n        input dimension for each time step\n    metric : str\n        the evaluation metric used for early stopping\n    optimizer : str\n        optimizer name\n    GPU : int\n        the GPU ID used for training; a negative value means CPU\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        pre_epoch=40,\n        dw=0.5,\n        loss_type=\"cosine\",\n        len_seq=60,\n        len_win=0,\n        lr=0.001,\n        metric=\"mse\",\n        batch_size=2000,\n        early_stop=20,\n        loss=\"mse\",\n        optimizer=\"adam\",\n        n_splits=2,\n        GPU=0,\n        seed=None,\n        **_,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"ADARNN\")\n        self.logger.info(\"ADARNN pytorch version...\")\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(GPU)\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.pre_epoch = pre_epoch\n        self.dw = dw\n        self.loss_type = loss_type\n        self.len_seq = len_seq\n        self.len_win = len_win\n        
self.lr = lr\n        self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.n_splits = n_splits\n        self.device = torch.device(\"cuda:%d\" % GPU if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.seed = seed\n\n        self.logger.info(\n            \"ADARNN parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\nvisible_GPU : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                optimizer.lower(),\n                loss_type,\n                GPU,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        n_hiddens = [hidden_size for _ in range(num_layers)]\n        self.model = AdaRNN(\n            use_bottleneck=False,\n            bottleneck_width=64,\n            n_input=d_feat,\n            n_hiddens=n_hiddens,\n            n_output=1,\n            dropout=dropout,\n            model_type=\"AdaRNN\",\n            len_seq=len_seq,\n            trans_loss=loss_type,\n        )\n        self.logger.info(\"model:\\n{:}\".format(self.model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.model)))\n\n        if 
optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None):\n        self.model.train()\n        criterion = nn.MSELoss()\n        dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)\n        out_weight_list = None\n        for data_all in zip(*train_loader_list):\n            #  for data_all in zip(*train_loader_list):\n            self.train_optimizer.zero_grad()\n            list_feat = []\n            list_label = []\n            for data in data_all:\n                # feature :[36, 24, 6]\n                feature, label_reg = data[0].to(self.device).float(), data[1].to(self.device).float()\n                list_feat.append(feature)\n                list_label.append(label_reg)\n            flag = False\n            index = get_index(len(data_all) - 1)\n            for temp_index in index:\n                s1 = temp_index[0]\n                s2 = temp_index[1]\n                if list_feat[s1].shape[0] != list_feat[s2].shape[0]:\n                    flag = True\n                    break\n            if flag:\n                continue\n\n            total_loss = torch.zeros(1).to(self.device)\n            for i, n in enumerate(index):\n                feature_s = list_feat[n[0]]\n                feature_t = list_feat[n[1]]\n                label_reg_s = list_label[n[0]]\n                label_reg_t = list_label[n[1]]\n                feature_all = torch.cat((feature_s, feature_t), 0)\n\n                if epoch < 
self.pre_epoch:\n                    pred_all, loss_transfer, out_weight_list = self.model.forward_pre_train(\n                        feature_all, len_win=self.len_win\n                    )\n                else:\n                    pred_all, loss_transfer, dist, weight_mat = self.model.forward_Boosting(feature_all, weight_mat)\n                    dist_mat = dist_mat + dist\n                pred_s = pred_all[0 : feature_s.size(0)]\n                pred_t = pred_all[feature_s.size(0) :]\n\n                loss_s = criterion(pred_s, label_reg_s)\n                loss_t = criterion(pred_t, label_reg_t)\n\n                total_loss = total_loss + loss_s + loss_t + self.dw * loss_transfer\n            self.train_optimizer.zero_grad()\n            total_loss.backward()\n            torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)\n            self.train_optimizer.step()\n        if epoch >= self.pre_epoch:\n            if epoch > self.pre_epoch:\n                weight_mat = self.model.update_weight_Boosting(weight_mat, dist_old, dist_mat)\n            return weight_mat, dist_mat\n        else:\n            weight_mat = self.transform_type(out_weight_list)\n            return weight_mat, None\n\n    @staticmethod\n    def calc_all_metrics(pred):\n        \"\"\"pred is a pandas DataFrame with two columns: score (prediction) and label (ground truth)\"\"\"\n        res = {}\n        ic = pred.groupby(level=\"datetime\", group_keys=False).apply(lambda x: x.label.corr(x.score))\n        rank_ic = pred.groupby(level=\"datetime\", group_keys=False).apply(\n            lambda x: x.label.corr(x.score, method=\"spearman\")\n        )\n        res[\"ic\"] = ic.mean()\n        res[\"icir\"] = ic.mean() / ic.std()\n        res[\"ric\"] = rank_ic.mean()\n        res[\"ricir\"] = rank_ic.mean() / rank_ic.std()\n        # negated MSE, so that a larger value is better (consistent with early stopping in fit)\n        res[\"mse\"] = -((pred[\"label\"] - pred[\"score\"]) ** 2).mean()\n        res[\"loss\"] = res[\"mse\"]\n        return res\n\n    def test_epoch(self, df):\n        
self.model.eval()\n        preds = self.infer(df[\"feature\"])\n        label = df[\"label\"].squeeze()\n        preds = pd.DataFrame({\"label\": label, \"score\": preds}, index=df.index)\n        metrics = self.calc_all_metrics(preds)\n        return metrics\n\n    def log_metrics(self, mode, metrics):\n        metrics = [\"{}/{}: {:.6f}\".format(k, mode, v) for k, v in metrics.items()]\n        metrics = \", \".join(metrics)\n        self.logger.info(metrics)\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        df_train, df_valid = dataset.prepare(\n            [\"train\", \"valid\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        #  splits = ['2011-06-30']\n        days = df_train.index.get_level_values(level=0).unique()\n        train_splits = np.array_split(days, self.n_splits)\n        train_splits = [df_train[s[0] : s[-1]] for s in train_splits]\n        train_loader_list = [get_stock_loader(df, self.batch_size) for df in train_splits]\n\n        save_path = get_or_create_path(save_path)\n        stop_steps = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n        best_score = -np.inf\n        best_epoch = 0\n        weight_mat, dist_mat = None, None\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            weight_mat, dist_mat = self.train_AdaRNN(train_loader_list, step, dist_mat, weight_mat)\n            self.logger.info(\"evaluating...\")\n            train_metrics = self.test_epoch(df_train)\n            valid_metrics = self.test_epoch(df_valid)\n            self.log_metrics(\"train: \", train_metrics)\n            self.log_metrics(\"valid: \", valid_metrics)\n\n            valid_score = 
valid_metrics[self.metric]\n            train_score = train_metrics[self.metric]\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(valid_score)\n            if valid_score > best_score:\n                best_score = valid_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n        return best_score\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        return self.infer(x_test)\n\n    def infer(self, x_test):\n        index = x_test.index\n        self.model.eval()\n        x_values = x_test.values\n        sample_num = x_values.shape[0]\n        x_values = x_values.reshape(sample_num, self.d_feat, -1).transpose(0, 2, 1)\n        preds = []\n\n        for begin in range(sample_num)[:: self.batch_size]:\n            if sample_num - begin < self.batch_size:\n                end = sample_num\n            else:\n                end = begin + self.batch_size\n\n            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)\n\n            with torch.no_grad():\n                pred = self.model.predict(x_batch).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=index)\n\n    def transform_type(self, init_weight):\n   
     weight = torch.ones(self.num_layers, self.len_seq).to(self.device)\n        for i in range(self.num_layers):\n            for j in range(self.len_seq):\n                weight[i, j] = init_weight[i][j].item()\n        return weight\n\n\nclass data_loader(Dataset):\n    def __init__(self, df):\n        self.df_feature = df[\"feature\"]\n        self.df_label_reg = df[\"label\"]\n        self.df_index = df.index\n        self.df_feature = torch.tensor(\n            self.df_feature.values.reshape(-1, 6, 60).transpose(0, 2, 1), dtype=torch.float32\n        )\n        self.df_label_reg = torch.tensor(self.df_label_reg.values.reshape(-1), dtype=torch.float32)\n\n    def __getitem__(self, index):\n        sample, label_reg = self.df_feature[index], self.df_label_reg[index]\n        return sample, label_reg\n\n    def __len__(self):\n        return len(self.df_feature)\n\n\ndef get_stock_loader(df, batch_size, shuffle=True):\n    train_loader = DataLoader(data_loader(df), batch_size=batch_size, shuffle=shuffle)\n    return train_loader\n\n\ndef get_index(num_domain=2):\n    index = []\n    for i in range(num_domain):\n        for j in range(i + 1, num_domain + 1):\n            index.append((i, j))\n    return index\n\n\nclass AdaRNN(nn.Module):\n    \"\"\"\n    model_type:  'Boosting', 'AdaRNN'\n    \"\"\"\n\n    def __init__(\n        self,\n        use_bottleneck=False,\n        bottleneck_width=256,\n        n_input=128,\n        n_hiddens=[64, 64],\n        n_output=6,\n        dropout=0.0,\n        len_seq=9,\n        model_type=\"AdaRNN\",\n        trans_loss=\"mmd\",\n        GPU=0,\n    ):\n        super(AdaRNN, self).__init__()\n        self.use_bottleneck = use_bottleneck\n        self.n_input = n_input\n        self.num_layers = len(n_hiddens)\n        self.hiddens = n_hiddens\n        self.n_output = n_output\n        self.model_type = model_type\n        self.trans_loss = trans_loss\n        self.len_seq = len_seq\n        self.device = 
torch.device(\"cuda:%d\" % GPU if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        in_size = self.n_input\n\n        features = nn.ModuleList()\n        for hidden in n_hiddens:\n            rnn = nn.GRU(input_size=in_size, num_layers=1, hidden_size=hidden, batch_first=True, dropout=dropout)\n            features.append(rnn)\n            in_size = hidden\n        self.features = nn.Sequential(*features)\n\n        if use_bottleneck is True:  # finance\n            self.bottleneck = nn.Sequential(\n                nn.Linear(n_hiddens[-1], bottleneck_width),\n                nn.Linear(bottleneck_width, bottleneck_width),\n                nn.BatchNorm1d(bottleneck_width),\n                nn.ReLU(),\n                nn.Dropout(),\n            )\n            self.bottleneck[0].weight.data.normal_(0, 0.005)\n            self.bottleneck[0].bias.data.fill_(0.1)\n            self.bottleneck[1].weight.data.normal_(0, 0.005)\n            self.bottleneck[1].bias.data.fill_(0.1)\n            self.fc = nn.Linear(bottleneck_width, n_output)\n            torch.nn.init.xavier_normal_(self.fc.weight)\n        else:\n            self.fc_out = nn.Linear(n_hiddens[-1], self.n_output)\n\n        if self.model_type == \"AdaRNN\":\n            gate = nn.ModuleList()\n            for i in range(len(n_hiddens)):\n                gate_weight = nn.Linear(len_seq * self.hiddens[i] * 2, len_seq)\n                gate.append(gate_weight)\n            self.gate = gate\n\n            bnlst = nn.ModuleList()\n            for i in range(len(n_hiddens)):\n                bnlst.append(nn.BatchNorm1d(len_seq))\n            self.bn_lst = bnlst\n            self.softmax = torch.nn.Softmax(dim=0)\n            self.init_layers()\n\n    def init_layers(self):\n        for i in range(len(self.hiddens)):\n            self.gate[i].weight.data.normal_(0, 0.05)\n            self.gate[i].bias.data.fill_(0.0)\n\n    def forward_pre_train(self, x, len_win=0):\n        out = self.gru_features(x)\n      
  fea = out[0]  # [2N,L,H]\n        if self.use_bottleneck is True:\n            fea_bottleneck = self.bottleneck(fea[:, -1, :])\n            fc_out = self.fc(fea_bottleneck).squeeze()\n        else:\n            fc_out = self.fc_out(fea[:, -1, :]).squeeze()  # [N,]\n\n        out_list_all, out_weight_list = out[1], out[2]\n        out_list_s, out_list_t = self.get_features(out_list_all)\n        loss_transfer = torch.zeros((1,)).to(self.device)\n        for i, n in enumerate(out_list_s):\n            criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])\n            h_start = 0\n            for j in range(h_start, self.len_seq, 1):\n                i_start = j - len_win if j - len_win >= 0 else 0\n                i_end = j + len_win if j + len_win < self.len_seq else self.len_seq - 1\n                for k in range(i_start, i_end + 1):\n                    weight = (\n                        out_weight_list[i][j]\n                        if self.model_type == \"AdaRNN\"\n                        else 1 / (self.len_seq - h_start) * (2 * len_win + 1)\n                    )\n                    loss_transfer = loss_transfer + weight * criterion_transder.compute(\n                        n[:, j, :], out_list_t[i][:, k, :]\n                    )\n        return fc_out, loss_transfer, out_weight_list\n\n    def gru_features(self, x, predict=False):\n        x_input = x\n        out = None\n        out_lis = []\n        out_weight_list = [] if (self.model_type == \"AdaRNN\") else None\n        for i in range(self.num_layers):\n            out, _ = self.features[i](x_input.float())\n            x_input = out\n            out_lis.append(out)\n            if self.model_type == \"AdaRNN\" and predict is False:\n                out_gate = self.process_gate_weight(x_input, i)\n                out_weight_list.append(out_gate)\n        return out, out_lis, out_weight_list\n\n    def process_gate_weight(self, out, index):\n        x_s = out[0 : 
int(out.shape[0] // 2)]\n        x_t = out[out.shape[0] // 2 : out.shape[0]]\n        x_all = torch.cat((x_s, x_t), 2)\n        x_all = x_all.view(x_all.shape[0], -1)\n        weight = torch.sigmoid(self.bn_lst[index](self.gate[index](x_all.float())))\n        weight = torch.mean(weight, dim=0)\n        res = self.softmax(weight).squeeze()\n        return res\n\n    @staticmethod\n    def get_features(output_list):\n        fea_list_src, fea_list_tar = [], []\n        for fea in output_list:\n            fea_list_src.append(fea[0 : fea.size(0) // 2])\n            fea_list_tar.append(fea[fea.size(0) // 2 :])\n        return fea_list_src, fea_list_tar\n\n    # For Boosting-based\n    def forward_Boosting(self, x, weight_mat=None):\n        out = self.gru_features(x)\n        fea = out[0]\n        if self.use_bottleneck:\n            fea_bottleneck = self.bottleneck(fea[:, -1, :])\n            fc_out = self.fc(fea_bottleneck).squeeze()\n        else:\n            fc_out = self.fc_out(fea[:, -1, :]).squeeze()\n\n        out_list_all = out[1]\n        out_list_s, out_list_t = self.get_features(out_list_all)\n        loss_transfer = torch.zeros((1,)).to(self.device)\n        if weight_mat is None:\n            weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).to(self.device)\n        else:\n            weight = weight_mat\n        dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)\n        for i, n in enumerate(out_list_s):\n            criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])\n            for j in range(self.len_seq):\n                loss_trans = criterion_transder.compute(n[:, j, :], out_list_t[i][:, j, :])\n                loss_transfer = loss_transfer + weight[i, j] * loss_trans\n                dist_mat[i, j] = loss_trans\n        return fc_out, loss_transfer, dist_mat, weight\n\n    # For Boosting-based\n    def update_weight_Boosting(self, weight_mat, dist_old, dist_new):\n       
 epsilon = 1e-5\n        dist_old = dist_old.detach()\n        dist_new = dist_new.detach()\n        ind = dist_new > dist_old + epsilon\n        weight_mat[ind] = weight_mat[ind] * (1 + torch.sigmoid(dist_new[ind] - dist_old[ind]))\n        weight_norm = torch.norm(weight_mat, dim=1, p=1)\n        weight_mat = weight_mat / weight_norm.t().unsqueeze(1).repeat(1, self.len_seq)\n        return weight_mat\n\n    def predict(self, x):\n        out = self.gru_features(x, predict=True)\n        fea = out[0]\n        if self.use_bottleneck is True:\n            fea_bottleneck = self.bottleneck(fea[:, -1, :])\n            fc_out = self.fc(fea_bottleneck).squeeze()\n        else:\n            fc_out = self.fc_out(fea[:, -1, :]).squeeze()\n        return fc_out\n\n\nclass TransferLoss:\n    def __init__(self, loss_type=\"cosine\", input_dim=512, GPU=0):\n        \"\"\"\n        Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv\n        \"\"\"\n        self.loss_type = loss_type\n        self.input_dim = input_dim\n        self.device = torch.device(\"cuda:%d\" % GPU if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n\n    def compute(self, X, Y):\n        \"\"\"Compute adaptation loss\n\n        Arguments:\n            X {tensor} -- source matrix\n            Y {tensor} -- target matrix\n\n        Returns:\n            [tensor] -- transfer loss\n        \"\"\"\n        loss = None\n        if self.loss_type in (\"mmd_lin\", \"mmd\"):\n            mmdloss = MMD_loss(kernel_type=\"linear\")\n            loss = mmdloss(X, Y)\n        elif self.loss_type == \"coral\":\n            loss = CORAL(X, Y, self.device)\n        elif self.loss_type in (\"cosine\", \"cos\"):\n            loss = 1 - cosine(X, Y)\n        elif self.loss_type == \"kl\":\n            loss = kl_div(X, Y)\n        elif self.loss_type == \"js\":\n            loss = js(X, Y)\n        elif self.loss_type == \"mine\":\n            mine_model = 
Mine_estimator(input_dim=self.input_dim, hidden_dim=60).to(self.device)\n            loss = mine_model(X, Y)\n        elif self.loss_type == \"adv\":\n            loss = adv(X, Y, self.device, input_dim=self.input_dim, hidden_dim=32)\n        elif self.loss_type == \"mmd_rbf\":\n            mmdloss = MMD_loss(kernel_type=\"rbf\")\n            loss = mmdloss(X, Y)\n        elif self.loss_type == \"pairwise\":\n            pair_mat = pairwise_dist(X, Y)\n            loss = torch.norm(pair_mat)\n\n        return loss\n\n\ndef cosine(source, target):\n    # Average over the batch (dim 0) so each domain is summarized by one feature vector.\n    source, target = source.mean(0), target.mean(0)\n    cos = nn.CosineSimilarity(dim=0)\n    loss = cos(source, target)\n    return loss.mean()\n\n\nclass ReverseLayerF(Function):\n    @staticmethod\n    def forward(ctx, x, alpha):\n        ctx.alpha = alpha\n        return x.view_as(x)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        output = grad_output.neg() * ctx.alpha\n        return output, None\n\n\nclass Discriminator(nn.Module):\n    def __init__(self, input_dim=256, hidden_dim=256):\n        super(Discriminator, self).__init__()\n        self.input_dim = input_dim\n        self.hidden_dim = hidden_dim\n        self.dis1 = nn.Linear(input_dim, hidden_dim)\n        self.dis2 = nn.Linear(hidden_dim, 1)\n\n    def forward(self, x):\n        x = F.relu(self.dis1(x))\n        x = self.dis2(x)\n        x = torch.sigmoid(x)\n        return x\n\n\ndef adv(source, target, device, input_dim=256, hidden_dim=512):\n    domain_loss = nn.BCELoss()\n    # !!! 
Pay attention to .cuda !!!\n    adv_net = Discriminator(input_dim, hidden_dim).to(device)\n    domain_src = torch.ones(len(source)).to(device)\n    domain_tar = torch.zeros(len(target)).to(device)\n    domain_src, domain_tar = domain_src.view(domain_src.shape[0], 1), domain_tar.view(domain_tar.shape[0], 1)\n    reverse_src = ReverseLayerF.apply(source, 1)\n    reverse_tar = ReverseLayerF.apply(target, 1)\n    pred_src = adv_net(reverse_src)\n    pred_tar = adv_net(reverse_tar)\n    loss_s, loss_t = domain_loss(pred_src, domain_src), domain_loss(pred_tar, domain_tar)\n    loss = loss_s + loss_t\n    return loss\n\n\ndef CORAL(source, target, device):\n    d = source.size(1)\n    ns, nt = source.size(0), target.size(0)\n\n    # source covariance\n    tmp_s = torch.ones((1, ns)).to(device) @ source\n    cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)\n\n    # target covariance\n    tmp_t = torch.ones((1, nt)).to(device) @ target\n    ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)\n\n    # frobenius norm\n    loss = (cs - ct).pow(2).sum()\n    loss = loss / (4 * d * d)\n\n    return loss\n\n\nclass MMD_loss(nn.Module):\n    def __init__(self, kernel_type=\"linear\", kernel_mul=2.0, kernel_num=5):\n        super(MMD_loss, self).__init__()\n        self.kernel_num = kernel_num\n        self.kernel_mul = kernel_mul\n        self.fix_sigma = None\n        self.kernel_type = kernel_type\n\n    @staticmethod\n    def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):\n        n_samples = int(source.size()[0]) + int(target.size()[0])\n        total = torch.cat([source, target], dim=0)\n        total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))\n        total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))\n        L2_distance = ((total0 - total1) ** 2).sum(2)\n        if fix_sigma:\n            bandwidth = fix_sigma\n        
else:\n            bandwidth = torch.sum(L2_distance.data) / (n_samples**2 - n_samples)\n        bandwidth /= kernel_mul ** (kernel_num // 2)\n        bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]\n        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]\n        return sum(kernel_val)\n\n    @staticmethod\n    def linear_mmd(X, Y):\n        delta = X.mean(axis=0) - Y.mean(axis=0)\n        loss = delta.dot(delta.T)\n        return loss\n\n    def forward(self, source, target):\n        if self.kernel_type == \"linear\":\n            return self.linear_mmd(source, target)\n        elif self.kernel_type == \"rbf\":\n            batch_size = int(source.size()[0])\n            kernels = self.guassian_kernel(\n                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma\n            )\n            with torch.no_grad():\n                XX = torch.mean(kernels[:batch_size, :batch_size])\n                YY = torch.mean(kernels[batch_size:, batch_size:])\n                XY = torch.mean(kernels[:batch_size, batch_size:])\n                YX = torch.mean(kernels[batch_size:, :batch_size])\n                loss = torch.mean(XX + YY - XY - YX)\n            return loss\n\n\nclass Mine_estimator(nn.Module):\n    def __init__(self, input_dim=2048, hidden_dim=512):\n        super(Mine_estimator, self).__init__()\n        self.mine_model = Mine(input_dim, hidden_dim)\n\n    def forward(self, X, Y):\n        Y_shffle = Y[torch.randperm(len(Y))]\n        loss_joint = self.mine_model(X, Y)\n        loss_marginal = self.mine_model(X, Y_shffle)\n        ret = torch.mean(loss_joint) - torch.log(torch.mean(torch.exp(loss_marginal)))\n        loss = -ret\n        return loss\n\n\nclass Mine(nn.Module):\n    def __init__(self, input_dim=2048, hidden_dim=512):\n        super(Mine, self).__init__()\n        self.fc1_x = nn.Linear(input_dim, hidden_dim)\n        
self.fc1_y = nn.Linear(input_dim, hidden_dim)\n        self.fc2 = nn.Linear(hidden_dim, 1)\n\n    def forward(self, x, y):\n        h1 = F.leaky_relu(self.fc1_x(x) + self.fc1_y(y))\n        h2 = self.fc2(h1)\n        return h2\n\n\ndef pairwise_dist(X, Y):\n    n, d = X.shape\n    m, _ = Y.shape\n    assert d == Y.shape[1]\n    a = X.unsqueeze(1).expand(n, m, d)\n    b = Y.unsqueeze(0).expand(n, m, d)\n    return torch.pow(a - b, 2).sum(2)\n\n\ndef pairwise_dist_np(X, Y):\n    n, d = X.shape\n    m, _ = Y.shape\n    assert d == Y.shape[1]\n    a = np.expand_dims(X, 1)\n    b = np.expand_dims(Y, 0)\n    a = np.tile(a, (1, m, 1))\n    b = np.tile(b, (n, 1, 1))\n    return np.power(a - b, 2).sum(2)\n\n\ndef pa(X, Y):\n    XY = np.dot(X, Y.T)\n    XX = np.sum(np.square(X), axis=1)\n    XX = np.transpose([XX])\n    YY = np.sum(np.square(Y), axis=1)\n    dist = XX + YY - 2 * XY\n\n    return dist\n\n\ndef kl_div(source, target):\n    if len(source) < len(target):\n        target = target[: len(source)]\n    elif len(source) > len(target):\n        source = source[: len(target)]\n    criterion = nn.KLDivLoss(reduction=\"batchmean\")\n    loss = criterion(source.log(), target)\n    return loss\n\n\ndef js(source, target):\n    if len(source) < len(target):\n        target = target[: len(source)]\n    elif len(source) > len(target):\n        source = source[: len(target)]\n    M = 0.5 * (source + target)\n    loss_1, loss_2 = kl_div(source, M), kl_div(target, M)\n    return 0.5 * (loss_1 + loss_2)\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_add.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import division\nfrom __future__ import print_function\n\n\nimport copy\nimport math\nfrom typing import Text, Union\n\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom qlib.contrib.model.pytorch_gru import GRUModel\nfrom qlib.contrib.model.pytorch_lstm import LSTMModel\nfrom qlib.contrib.model.pytorch_utils import count_parameters\nfrom qlib.data.dataset import DatasetH\nfrom qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.log import get_module_logger\nfrom qlib.model.base import Model\nfrom qlib.utils import get_or_create_path\nfrom torch.autograd import Function\n\n\nclass ADD(Model):\n    \"\"\"ADD Model\n\n    Parameters\n    ----------\n     lr : float\n         learning rate\n     d_feat : int\n         input dimensions for each time step\n     metric : str\n         the evaluation metric used in early stop\n     optimizer : str\n         optimizer name\n     GPU : int\n         the GPU ID used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        dec_dropout=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"mse\",\n        batch_size=5000,\n        early_stop=20,\n        base_model=\"GRU\",\n        model_path=None,\n        optimizer=\"adam\",\n        gamma=0.1,\n        gamma_clip=0.4,\n        mu=0.05,\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"ADD\")\n        self.logger.info(\"ADD pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.dec_dropout = dec_dropout\n        self.n_epochs = 
n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.base_model = base_model\n        self.model_path = model_path\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.seed = seed\n\n        self.gamma = gamma\n        self.gamma_clip = gamma_clip\n        self.mu = mu\n\n        self.logger.info(\n            \"ADD parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\ndec_dropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nbase_model : {}\"\n            \"\\nmodel_path : {}\"\n            \"\\ngamma : {}\"\n            \"\\ngamma_clip : {}\"\n            \"\\nmu : {}\"\n            \"\\ndevice : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                dec_dropout,\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                optimizer.lower(),\n                base_model,\n                model_path,\n                gamma,\n                gamma_clip,\n                mu,\n                self.device,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.ADD_model = ADDModel(\n            d_feat=self.d_feat,\n            hidden_size=self.hidden_size,\n            
num_layers=self.num_layers,\n            dropout=self.dropout,\n            dec_dropout=self.dec_dropout,\n            base_model=self.base_model,\n            gamma=self.gamma,\n            gamma_clip=self.gamma_clip,\n        )\n        self.logger.info(\"model:\\n{:}\".format(self.ADD_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.ADD_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.ADD_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.ADD_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.ADD_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def loss_pre_excess(self, pred_excess, label_excess, record=None):\n        mask = ~torch.isnan(label_excess)\n        pre_excess_loss = F.mse_loss(pred_excess[mask], label_excess[mask])\n        if record is not None:\n            record[\"pre_excess_loss\"] = pre_excess_loss.item()\n        return pre_excess_loss\n\n    def loss_pre_market(self, pred_market, label_market, record=None):\n        pre_market_loss = F.cross_entropy(pred_market, label_market)\n        if record is not None:\n            record[\"pre_market_loss\"] = pre_market_loss.item()\n        return pre_market_loss\n\n    def loss_pre(self, pred_excess, label_excess, pred_market, label_market, record=None):\n        pre_loss = self.loss_pre_excess(pred_excess, label_excess, record) + self.loss_pre_market(\n            pred_market, label_market, record\n        )\n        if record is not None:\n            record[\"pre_loss\"] = pre_loss.item()\n        return pre_loss\n\n    def loss_adv_excess(self, adv_excess, label_excess, record=None):\n        mask = 
~torch.isnan(label_excess)\n        adv_excess_loss = F.mse_loss(adv_excess.squeeze()[mask], label_excess[mask])\n        if record is not None:\n            record[\"adv_excess_loss\"] = adv_excess_loss.item()\n        return adv_excess_loss\n\n    def loss_adv_market(self, adv_market, label_market, record=None):\n        adv_market_loss = F.cross_entropy(adv_market, label_market)\n        if record is not None:\n            record[\"adv_market_loss\"] = adv_market_loss.item()\n        return adv_market_loss\n\n    def loss_adv(self, adv_excess, label_excess, adv_market, label_market, record=None):\n        adv_loss = self.loss_adv_excess(adv_excess, label_excess, record) + self.loss_adv_market(\n            adv_market, label_market, record\n        )\n        if record is not None:\n            record[\"adv_loss\"] = adv_loss.item()\n        return adv_loss\n\n    def loss_fn(self, x, preds, label_excess, label_market, record=None):\n        loss = (\n            self.loss_pre(preds[\"excess\"], label_excess, preds[\"market\"], label_market, record)\n            + self.loss_adv(preds[\"adv_excess\"], label_excess, preds[\"adv_market\"], label_market, record)\n            + self.mu * self.loss_rec(x, preds[\"reconstructed_feature\"], record)\n        )\n        if record is not None:\n            record[\"loss\"] = loss.item()\n        return loss\n\n    def loss_rec(self, x, rec_x, record=None):\n        x = x.reshape(len(x), self.d_feat, -1)\n        x = x.permute(0, 2, 1)\n        rec_loss = F.mse_loss(x, rec_x)\n        if record is not None:\n            record[\"rec_loss\"] = rec_loss.item()\n        return rec_loss\n\n    def get_daily_inter(self, df, shuffle=False):\n        # organize the train data into daily batches\n        daily_count = df.groupby(level=0, group_keys=False).size().values\n        daily_index = np.roll(np.cumsum(daily_count), 1)\n        daily_index[0] = 0\n        if shuffle:\n            # shuffle data\n            daily_shuffle = 
list(zip(daily_index, daily_count))\n            np.random.shuffle(daily_shuffle)\n            daily_index, daily_count = zip(*daily_shuffle)\n        return daily_index, daily_count\n\n    def cal_ic_metrics(self, pred, label):\n        metrics = {}\n        metrics[\"mse\"] = -F.mse_loss(pred, label).item()\n        metrics[\"loss\"] = metrics[\"mse\"]\n        pred = pd.Series(pred.cpu().detach().numpy())\n        label = pd.Series(label.cpu().detach().numpy())\n        metrics[\"ic\"] = pred.corr(label)\n        metrics[\"ric\"] = pred.corr(label, method=\"spearman\")\n        return metrics\n\n    def test_epoch(self, data_x, data_y, data_m):\n        x_values = data_x.values\n        y_values = np.squeeze(data_y.values)\n        m_values = np.squeeze(data_m.values.astype(int))\n        self.ADD_model.eval()\n\n        metrics_list = []\n\n        daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)\n\n        for idx, count in zip(daily_index, daily_count):\n            batch = slice(idx, idx + count)\n            feature = torch.from_numpy(x_values[batch]).float().to(self.device)\n            label_excess = torch.from_numpy(y_values[batch]).float().to(self.device)\n            label_market = torch.from_numpy(m_values[batch]).long().to(self.device)\n\n            metrics = {}\n            preds = self.ADD_model(feature)\n            self.loss_fn(feature, preds, label_excess, label_market, metrics)\n            metrics.update(self.cal_ic_metrics(preds[\"excess\"], label_excess))\n            metrics_list.append(metrics)\n        metrics = {}\n        keys = metrics_list[0].keys()\n        for k in keys:\n            vs = [m[k] for m in metrics_list]\n            metrics[k] = sum(vs) / len(vs)\n\n        return metrics\n\n    def train_epoch(self, x_train_values, y_train_values, m_train_values):\n        self.ADD_model.train()\n\n        indices = np.arange(len(x_train_values))\n        np.random.shuffle(indices)\n\n        cur_step = 1\n\n    
    for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n            batch = indices[i : i + self.batch_size]\n            feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)\n            label_excess = torch.from_numpy(y_train_values[batch]).float().to(self.device)\n            label_market = torch.from_numpy(m_train_values[batch]).long().to(self.device)\n\n            preds = self.ADD_model(feature)\n\n            loss = self.loss_fn(feature, preds, label_excess, label_market)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.ADD_model.parameters(), 3.0)\n            self.train_optimizer.step()\n            cur_step += 1\n\n    def log_metrics(self, mode, metrics):\n        metrics = [\"{}/{}: {:.6f}\".format(k, mode, v) for k, v in metrics.items()]\n        metrics = \", \".join(metrics)\n        self.logger.info(metrics)\n\n    def bootstrap_fit(self, x_train, y_train, m_train, x_valid, y_valid, m_valid):\n        stop_steps = 0\n        best_score = -np.inf\n        best_epoch = 0\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n        x_train_values = x_train.values\n        y_train_values = np.squeeze(y_train.values)\n        m_train_values = np.squeeze(m_train.values.astype(int))\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            self.train_epoch(x_train_values, y_train_values, m_train_values)\n            self.logger.info(\"evaluating...\")\n            train_metrics = self.test_epoch(x_train, y_train, m_train)\n            valid_metrics = self.test_epoch(x_valid, y_valid, m_valid)\n            self.log_metrics(\"train\", train_metrics)\n            self.log_metrics(\"valid\", valid_metrics)\n\n            if self.metric in valid_metrics:\n    
            val_score = valid_metrics[self.metric]\n            else:\n                raise ValueError(\"unknown metric name `%s`\" % self.metric)\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.ADD_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n            self.ADD_model.before_adv_excess.step_alpha()\n            self.ADD_model.before_adv_market.step_alpha()\n        self.logger.info(\"bootstrap_fit best score: {:.6f} @ {}\".format(best_score, best_epoch))\n        self.ADD_model.load_state_dict(best_param)\n        return best_score\n\n    def gen_market_label(self, df, raw_label):\n        market_label = raw_label.groupby(\"datetime\", group_keys=False).mean().squeeze()\n        bins = [-np.inf, self.lo, self.hi, np.inf]\n        market_label = pd.cut(market_label, bins, labels=False)\n        market_label.name = (\"market_return\", \"market_return\")\n        df = df.join(market_label)\n        return df\n\n    def fit_thresh(self, train_label):\n        market_label = train_label.groupby(\"datetime\", group_keys=False).mean().squeeze()\n        self.lo, self.hi = market_label.quantile([1 / 3, 2 / 3])\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        label_train, label_valid = dataset.prepare(\n            [\"train\", \"valid\"],\n            col_set=[\"label\"],\n            data_key=DataHandlerLP.DK_R,\n        )\n        self.fit_thresh(label_train)\n        df_train, df_valid = dataset.prepare(\n            [\"train\", \"valid\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        df_train = 
self.gen_market_label(df_train, label_train)\n        df_valid = self.gen_market_label(df_valid, label_valid)\n\n        x_train, y_train, m_train = df_train[\"feature\"], df_train[\"label\"], df_train[\"market_return\"]\n        x_valid, y_valid, m_valid = df_valid[\"feature\"], df_valid[\"label\"], df_valid[\"market_return\"]\n\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n        # load pretrained base_model\n\n        if self.base_model == \"LSTM\":\n            pretrained_model = LSTMModel()\n        elif self.base_model == \"GRU\":\n            pretrained_model = GRUModel()\n        else:\n            raise ValueError(\"unknown base model name `%s`\" % self.base_model)\n\n        if self.model_path is not None:\n            self.logger.info(\"Loading pretrained model...\")\n            pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))\n\n            model_dict = self.ADD_model.enc_excess.state_dict()\n            pretrained_dict = {k: v for k, v in pretrained_model.rnn.state_dict().items() if k in model_dict}\n            model_dict.update(pretrained_dict)\n            self.ADD_model.enc_excess.load_state_dict(model_dict)\n            model_dict = self.ADD_model.enc_market.state_dict()\n            pretrained_dict = {k: v for k, v in pretrained_model.rnn.state_dict().items() if k in model_dict}\n            model_dict.update(pretrained_dict)\n            self.ADD_model.enc_market.load_state_dict(model_dict)\n            self.logger.info(\"Loading pretrained model Done...\")\n\n        self.bootstrap_fit(x_train, y_train, m_train, x_valid, y_valid, m_valid)\n\n        best_param = copy.deepcopy(self.ADD_model.state_dict())\n        save_path = get_or_create_path(save_path)\n        torch.save(best_param, save_path)\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        x_test = 
dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        index = x_test.index\n        self.ADD_model.eval()\n        x_values = x_test.values\n        preds = []\n\n        daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)\n\n        for idx, count in zip(daily_index, daily_count):\n            batch = slice(idx, idx + count)\n            x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)\n\n            with torch.no_grad():\n                pred = self.ADD_model(x_batch)\n                pred = pred[\"excess\"].detach().cpu().numpy()\n\n            preds.append(pred)\n\n        r = pd.Series(np.concatenate(preds), index=index)\n        return r\n\n\nclass ADDModel(nn.Module):\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=1,\n        dropout=0.0,\n        dec_dropout=0.5,\n        base_model=\"GRU\",\n        gamma=0.1,\n        gamma_clip=0.4,\n    ):\n        super().__init__()\n        self.d_feat = d_feat\n        self.base_model = base_model\n        if base_model == \"GRU\":\n            self.enc_excess, self.enc_market = [\n                nn.GRU(\n                    input_size=d_feat,\n                    hidden_size=hidden_size,\n                    num_layers=num_layers,\n                    batch_first=True,\n                    dropout=dropout,\n                )\n                for _ in range(2)\n            ]\n        elif base_model == \"LSTM\":\n            self.enc_excess, self.enc_market = [\n                nn.LSTM(\n                    input_size=d_feat,\n                    hidden_size=hidden_size,\n                    num_layers=num_layers,\n                    batch_first=True,\n                    dropout=dropout,\n                )\n                for _ in range(2)\n            ]\n        else:\n            raise ValueError(\"unknown base model name `%s`\" % base_model)\n        self.dec = Decoder(d_feat, 2 * hidden_size, 
num_layers, dec_dropout, base_model)\n\n        ctx_size = hidden_size * num_layers\n        self.pred_excess, self.adv_excess = [\n            nn.Sequential(nn.Linear(ctx_size, ctx_size), nn.BatchNorm1d(ctx_size), nn.Tanh(), nn.Linear(ctx_size, 1))\n            for _ in range(2)\n        ]\n        self.adv_market, self.pred_market = [\n            nn.Sequential(nn.Linear(ctx_size, ctx_size), nn.BatchNorm1d(ctx_size), nn.Tanh(), nn.Linear(ctx_size, 3))\n            for _ in range(2)\n        ]\n        self.before_adv_market, self.before_adv_excess = [RevGrad(gamma, gamma_clip) for _ in range(2)]\n\n    def forward(self, x):\n        x = x.reshape(len(x), self.d_feat, -1)\n        N = x.shape[0]\n        T = x.shape[-1]\n        x = x.permute(0, 2, 1)\n\n        out, hidden_excess = self.enc_excess(x)\n        out, hidden_market = self.enc_market(x)\n        if self.base_model == \"LSTM\":\n            feature_excess = hidden_excess[0].permute(1, 0, 2).reshape(N, -1)\n            feature_market = hidden_market[0].permute(1, 0, 2).reshape(N, -1)\n        else:\n            feature_excess = hidden_excess.permute(1, 0, 2).reshape(N, -1)\n            feature_market = hidden_market.permute(1, 0, 2).reshape(N, -1)\n        predicts = {}\n        predicts[\"excess\"] = self.pred_excess(feature_excess).squeeze(1)\n        predicts[\"market\"] = self.pred_market(feature_market)\n        predicts[\"adv_market\"] = self.adv_market(self.before_adv_market(feature_excess))\n        # squeeze the [N, 1] head output to [N], the same shape as predicts[\"excess\"]\n        predicts[\"adv_excess\"] = self.adv_excess(self.before_adv_excess(feature_market)).squeeze(1)\n        if self.base_model == \"LSTM\":\n            hidden = [torch.cat([hidden_excess[i], hidden_market[i]], -1) for i in range(2)]\n        else:\n            hidden = torch.cat([hidden_excess, hidden_market], -1)\n        # zero start token for the decoder, shape [N, d_feat]\n        x = torch.zeros_like(x[:, 0, :])\n        reconstructed_feature = []\n        for i in range(T):\n            x, hidden = self.dec(x, hidden)\n            
reconstructed_feature.append(x)\n        reconstructed_feature = torch.stack(reconstructed_feature, 1)\n        predicts[\"reconstructed_feature\"] = reconstructed_feature\n        return predicts\n\n\nclass Decoder(nn.Module):\n    def __init__(self, d_feat=6, hidden_size=128, num_layers=1, dropout=0.5, base_model=\"GRU\"):\n        super().__init__()\n        self.base_model = base_model\n        if base_model == \"GRU\":\n            self.rnn = nn.GRU(\n                input_size=d_feat,\n                hidden_size=hidden_size,\n                num_layers=num_layers,\n                batch_first=True,\n                dropout=dropout,\n            )\n        elif base_model == \"LSTM\":\n            self.rnn = nn.LSTM(\n                input_size=d_feat,\n                hidden_size=hidden_size,\n                num_layers=num_layers,\n                batch_first=True,\n                dropout=dropout,\n            )\n        else:\n            raise ValueError(\"unknown base model name `%s`\" % base_model)\n\n        self.fc = nn.Linear(hidden_size, d_feat)\n\n    def forward(self, x, hidden):\n        x = x.unsqueeze(1)\n        output, hidden = self.rnn(x, hidden)\n        output = output.squeeze(1)\n        pred = self.fc(output)\n        return pred, hidden\n\n\nclass RevGradFunc(Function):\n    @staticmethod\n    def forward(ctx, input_, alpha_):\n        ctx.save_for_backward(input_, alpha_)\n        output = input_\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):  # pragma: no cover\n        grad_input = None\n        _, alpha_ = ctx.saved_tensors\n        if ctx.needs_input_grad[0]:\n            grad_input = -grad_output * alpha_\n        return grad_input, None\n\n\nclass RevGrad(nn.Module):\n    def __init__(self, gamma=0.1, gamma_clip=0.4, *args, **kwargs):\n        \"\"\"\n        A gradient reversal layer.\n        This layer has no trainable parameters. The forward pass is the identity;\n        the backward pass multiplies the incoming gradient by `-alpha`, where `alpha`\n        starts at 0, grows with each `step_alpha()` call, and is capped at `gamma_clip`.\n        
\"\"\"\n        super().__init__(*args, **kwargs)\n\n        self.gamma = gamma\n        self.gamma_clip = torch.tensor(float(gamma_clip), requires_grad=False)\n        self._alpha = torch.tensor(0, requires_grad=False)\n        self._p = 0\n\n    def step_alpha(self):\n        self._p += 1\n        self._alpha = min(\n            self.gamma_clip, torch.tensor(2 / (1 + math.exp(-self.gamma * self._p)) - 1, requires_grad=False)\n        )\n\n    def forward(self, input_):\n        return RevGradFunc.apply(input_, self._alpha)\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_alstm.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\n\n\nclass ALSTM(Model):\n    \"\"\"ALSTM Model\n\n    Parameters\n    ----------\n    d_feat : int\n        input dimension for each time step\n    metric: str\n        the evaluation metric used in early stop\n    optimizer : str\n        optimizer name\n    GPU : int\n        the GPU ID used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        batch_size=2000,\n        early_stop=20,\n        loss=\"mse\",\n        optimizer=\"adam\",\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"ALSTM\")\n        self.logger.info(\"ALSTM pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.seed = seed\n\n        self.logger.info(\n            \"ALSTM parameters setting:\"\n      
      \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\ndevice : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                self.device,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.ALSTM_model = ALSTMModel(\n            d_feat=self.d_feat,\n            hidden_size=self.hidden_size,\n            num_layers=self.num_layers,\n            dropout=self.dropout,\n        )\n        self.logger.info(\"model:\\n{:}\".format(self.ALSTM_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.ALSTM_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.ALSTM_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.ALSTM_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label):\n        loss = (pred - label) ** 2\n        return 
torch.mean(loss)\n\n    def loss_fn(self, pred, label):\n        mask = ~torch.isnan(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def train_epoch(self, x_train, y_train):\n        x_train_values = x_train.values\n        y_train_values = np.squeeze(y_train.values)\n\n        self.ALSTM_model.train()\n\n        indices = np.arange(len(x_train_values))\n        np.random.shuffle(indices)\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            pred = self.ALSTM_model(feature)\n            loss = self.loss_fn(pred, label)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.ALSTM_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_x, data_y):\n        # prepare evaluation data\n        x_values = data_x.values\n        y_values = np.squeeze(data_y.values)\n\n        self.ALSTM_model.eval()\n\n        scores = []\n        losses = []\n\n        indices = np.arange(len(x_values))\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = 
torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            with torch.no_grad():\n                pred = self.ALSTM_model(feature)\n                loss = self.loss_fn(pred, label)\n                losses.append(loss.item())\n\n                score = self.metric_fn(pred, label)\n                scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        df_train, df_valid = dataset.prepare(\n            [\"train\", \"valid\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        if df_train.empty or df_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n\n        save_path = get_or_create_path(save_path)\n        stop_steps = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            self.train_epoch(x_train, y_train)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(x_train, y_train)\n            val_loss, val_score = self.test_epoch(x_valid, y_valid)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = 
val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.ALSTM_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.ALSTM_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        index = x_test.index\n        self.ALSTM_model.eval()\n        x_values = x_test.values\n        sample_num = x_values.shape[0]\n        preds = []\n\n        for begin in range(sample_num)[:: self.batch_size]:\n            if sample_num - begin < self.batch_size:\n                end = sample_num\n            else:\n                end = begin + self.batch_size\n\n            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)\n\n            with torch.no_grad():\n                pred = self.ALSTM_model(x_batch).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=index)\n\n\nclass ALSTMModel(nn.Module):\n    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, rnn_type=\"GRU\"):\n        super().__init__()\n        self.hid_size = hidden_size\n        self.input_size = d_feat\n        self.dropout = dropout\n        self.rnn_type = rnn_type\n        self.rnn_layer = num_layers\n        self._build_model()\n\n    def _build_model(self):\n        try:\n            klass = getattr(nn, self.rnn_type.upper())\n        except 
Exception as e:\n            raise ValueError(\"unknown rnn_type `%s`\" % self.rnn_type) from e\n        self.net = nn.Sequential()\n        self.net.add_module(\"fc_in\", nn.Linear(in_features=self.input_size, out_features=self.hid_size))\n        self.net.add_module(\"act\", nn.Tanh())\n        self.rnn = klass(\n            input_size=self.hid_size,\n            hidden_size=self.hid_size,\n            num_layers=self.rnn_layer,\n            batch_first=True,\n            dropout=self.dropout,\n        )\n        self.fc_out = nn.Linear(in_features=self.hid_size * 2, out_features=1)\n        self.att_net = nn.Sequential()\n        self.att_net.add_module(\n            \"att_fc_in\",\n            nn.Linear(in_features=self.hid_size, out_features=int(self.hid_size / 2)),\n        )\n        self.att_net.add_module(\"att_dropout\", torch.nn.Dropout(self.dropout))\n        self.att_net.add_module(\"att_act\", nn.Tanh())\n        self.att_net.add_module(\n            \"att_fc_out\",\n            nn.Linear(in_features=int(self.hid_size / 2), out_features=1, bias=False),\n        )\n        self.att_net.add_module(\"att_softmax\", nn.Softmax(dim=1))\n\n    def forward(self, inputs):\n        # inputs: [batch_size, input_size*input_day]\n        inputs = inputs.view(len(inputs), self.input_size, -1)\n        inputs = inputs.permute(0, 2, 1)  # [batch, input_size, seq_len] -> [batch, seq_len, input_size]\n        rnn_out, _ = self.rnn(self.net(inputs))  # [batch, seq_len, num_directions * hidden_size]\n        attention_score = self.att_net(rnn_out)  # [batch, seq_len, 1]\n        out_att = torch.mul(rnn_out, attention_score)\n        out_att = torch.sum(out_att, dim=1)  # [batch, hidden_size]\n        out = self.fc_out(\n            torch.cat((rnn_out[:, -1, :], out_att), dim=1)\n        )  # [batch, 2 * hidden_size] -> [batch, 1]\n        return out[..., 0]\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_alstm_ts.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\n\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...model.utils import ConcatDataset\nfrom ...data.dataset.weight import Reweighter\n\n\nclass ALSTM(Model):\n    \"\"\"ALSTM Model\n\n    Parameters\n    ----------\n    d_feat : int\n        input dimension for each time step\n    metric: str\n        the evaluation metric used in early stop\n    optimizer : str\n        optimizer name\n    GPU : int\n        the GPU ID used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        batch_size=2000,\n        early_stop=20,\n        loss=\"mse\",\n        optimizer=\"adam\",\n        n_jobs=10,\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"ALSTM\")\n        self.logger.info(\"ALSTM pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.device = torch.device(\"cuda:%d\" % (GPU) if 
torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.n_jobs = n_jobs\n        self.seed = seed\n\n        self.logger.info(\n            \"ALSTM parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\ndevice : {}\"\n            \"\\nn_jobs : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                self.device,\n                n_jobs,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.ALSTM_model = ALSTMModel(\n            d_feat=self.d_feat,\n            hidden_size=self.hidden_size,\n            num_layers=self.num_layers,\n            dropout=self.dropout,\n        )\n        self.logger.info(\"model:\\n{:}\".format(self.ALSTM_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.ALSTM_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.ALSTM_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        
self.fitted = False\n        self.ALSTM_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label, weight):\n        loss = weight * (pred - label) ** 2\n        return torch.mean(loss)\n\n    def loss_fn(self, pred, label, weight=None):\n        mask = ~torch.isnan(label)\n\n        if weight is None:\n            weight = torch.ones_like(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask], weight[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n        elif self.metric == \"mse\":\n            mask = ~torch.isnan(label)\n            weight = torch.ones_like(label)\n            return -self.mse(pred[mask], label[mask], weight[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def train_epoch(self, data_loader):\n        self.ALSTM_model.train()\n\n        for data, weight in data_loader:\n            feature = data[:, :, 0:-1].to(self.device)\n            label = data[:, -1, -1].to(self.device)\n\n            pred = self.ALSTM_model(feature.float())\n            loss = self.loss_fn(pred, label, weight.to(self.device))\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.ALSTM_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_loader):\n        self.ALSTM_model.eval()\n\n        scores = []\n        losses = []\n\n        for data, weight in data_loader:\n            feature = data[:, :, 0:-1].to(self.device)\n            # feature[torch.isnan(feature)] = 0\n            label = data[:, -1, -1].to(self.device)\n\n            with torch.no_grad():\n                pred = 
self.ALSTM_model(feature.float())\n                loss = self.loss_fn(pred, label, weight.to(self.device))\n                losses.append(loss.item())\n\n                score = self.metric_fn(pred, label)\n                scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset,\n        evals_result=dict(),\n        save_path=None,\n        reweighter=None,\n    ):\n        dl_train = dataset.prepare(\"train\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n        dl_valid = dataset.prepare(\"valid\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n        if dl_train.empty or dl_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        dl_train.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\n        dl_valid.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\n\n        if reweighter is None:\n            wl_train = np.ones(len(dl_train))\n            wl_valid = np.ones(len(dl_valid))\n        elif isinstance(reweighter, Reweighter):\n            wl_train = reweighter.reweight(dl_train)\n            wl_valid = reweighter.reweight(dl_valid)\n        else:\n            raise ValueError(\"Unsupported reweighter type.\")\n\n        train_loader = DataLoader(\n            ConcatDataset(dl_train, wl_train),\n            batch_size=self.batch_size,\n            shuffle=True,\n            num_workers=self.n_jobs,\n            drop_last=True,\n        )\n        valid_loader = DataLoader(\n            ConcatDataset(dl_valid, wl_valid),\n            batch_size=self.batch_size,\n            shuffle=False,\n            num_workers=self.n_jobs,\n            drop_last=True,\n        )\n\n        save_path = get_or_create_path(save_path)\n\n        stop_steps = 0\n        train_loss = 0\n        best_score = -np.inf\n        best_epoch = 0\n        
evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            self.train_epoch(train_loader)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(train_loader)\n            val_loss, val_score = self.test_epoch(valid_loader)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.ALSTM_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.ALSTM_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        dl_test = dataset.prepare(segment, col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_I)\n        dl_test.config(fillna_type=\"ffill+bfill\")\n        test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)\n        self.ALSTM_model.eval()\n        preds = []\n\n        for data in test_loader:\n            feature = data[:, :, 0:-1].to(self.device)\n\n            with torch.no_grad():\n        
        pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=dl_test.get_index())\n\n\nclass ALSTMModel(nn.Module):\n    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, rnn_type=\"GRU\"):\n        super().__init__()\n        self.hid_size = hidden_size\n        self.input_size = d_feat\n        self.dropout = dropout\n        self.rnn_type = rnn_type\n        self.rnn_layer = num_layers\n        self._build_model()\n\n    def _build_model(self):\n        try:\n            klass = getattr(nn, self.rnn_type.upper())\n        except Exception as e:\n            raise ValueError(\"unknown rnn_type `%s`\" % self.rnn_type) from e\n        self.net = nn.Sequential()\n        self.net.add_module(\"fc_in\", nn.Linear(in_features=self.input_size, out_features=self.hid_size))\n        self.net.add_module(\"act\", nn.Tanh())\n        self.rnn = klass(\n            input_size=self.hid_size,\n            hidden_size=self.hid_size,\n            num_layers=self.rnn_layer,\n            batch_first=True,\n            dropout=self.dropout,\n        )\n        self.fc_out = nn.Linear(in_features=self.hid_size * 2, out_features=1)\n        self.att_net = nn.Sequential()\n        self.att_net.add_module(\n            \"att_fc_in\",\n            nn.Linear(in_features=self.hid_size, out_features=int(self.hid_size / 2)),\n        )\n        self.att_net.add_module(\"att_dropout\", torch.nn.Dropout(self.dropout))\n        self.att_net.add_module(\"att_act\", nn.Tanh())\n        self.att_net.add_module(\n            \"att_fc_out\",\n            nn.Linear(in_features=int(self.hid_size / 2), out_features=1, bias=False),\n        )\n        self.att_net.add_module(\"att_softmax\", nn.Softmax(dim=1))\n\n    def forward(self, inputs):\n        rnn_out, _ = self.rnn(self.net(inputs))  # [batch, seq_len, num_directions * hidden_size]\n        attention_score = 
self.att_net(rnn_out)  # [batch, seq_len, 1]\n        out_att = torch.mul(rnn_out, attention_score)\n        out_att = torch.sum(out_att, dim=1)\n        out = self.fc_out(\n            torch.cat((rnn_out[:, -1, :], out_att), dim=1)\n        )  # [batch, seq_len, num_directions * hidden_size] -> [batch, 1]\n        return out[..., 0]\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_gats.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...contrib.model.pytorch_lstm import LSTMModel\nfrom ...contrib.model.pytorch_gru import GRUModel\n\n\nclass GATs(Model):\n    \"\"\"GATs Model\n\n    Parameters\n    ----------\n    lr : float\n        learning rate\n    d_feat : int\n        input dimensions for each time step\n    metric : str\n        the evaluation metric used in early stop\n    optimizer : str\n        optimizer name\n    GPU : int\n        the GPU ID used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        early_stop=20,\n        loss=\"mse\",\n        base_model=\"GRU\",\n        model_path=None,\n        optimizer=\"adam\",\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"GATs\")\n        self.logger.info(\"GATs pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.base_model = base_model\n        self.model_path = model_path\n        
self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.seed = seed\n\n        self.logger.info(\n            \"GATs parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\nbase_model : {}\"\n            \"\\nmodel_path : {}\"\n            \"\\ndevice : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                base_model,\n                model_path,\n                self.device,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.GAT_model = GATModel(\n            d_feat=self.d_feat,\n            hidden_size=self.hidden_size,\n            num_layers=self.num_layers,\n            dropout=self.dropout,\n            base_model=self.base_model,\n        )\n        self.logger.info(\"model:\\n{:}\".format(self.GAT_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.GAT_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.GAT_model.parameters(), lr=self.lr)\n        else:\n            raise 
NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.GAT_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label):\n        loss = (pred - label) ** 2\n        return torch.mean(loss)\n\n    def loss_fn(self, pred, label):\n        mask = ~torch.isnan(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def get_daily_inter(self, df, shuffle=False):\n        # organize the train data into daily batches\n        daily_count = df.groupby(level=0, group_keys=False).size().values\n        daily_index = np.roll(np.cumsum(daily_count), 1)\n        daily_index[0] = 0\n        if shuffle:\n            # shuffle data\n            daily_shuffle = list(zip(daily_index, daily_count))\n            np.random.shuffle(daily_shuffle)\n            daily_index, daily_count = zip(*daily_shuffle)\n        return daily_index, daily_count\n\n    def train_epoch(self, x_train, y_train):\n        x_train_values = x_train.values\n        y_train_values = np.squeeze(y_train.values)\n        self.GAT_model.train()\n\n        # organize the train data into daily batches\n        daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)\n\n        for idx, count in zip(daily_index, daily_count):\n            batch = slice(idx, idx + count)\n            feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)\n            label = torch.from_numpy(y_train_values[batch]).float().to(self.device)\n\n            pred = self.GAT_model(feature)\n            
loss = self.loss_fn(pred, label)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.GAT_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_x, data_y):\n        # prepare training data\n        x_values = data_x.values\n        y_values = np.squeeze(data_y.values)\n\n        self.GAT_model.eval()\n\n        scores = []\n        losses = []\n\n        # organize the test data into daily batches\n        daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)\n\n        for idx, count in zip(daily_index, daily_count):\n            batch = slice(idx, idx + count)\n            feature = torch.from_numpy(x_values[batch]).float().to(self.device)\n            label = torch.from_numpy(y_values[batch]).float().to(self.device)\n\n            pred = self.GAT_model(feature)\n            loss = self.loss_fn(pred, label)\n            losses.append(loss.item())\n\n            score = self.metric_fn(pred, label)\n            scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        df_train, df_valid, df_test = dataset.prepare(\n            [\"train\", \"valid\", \"test\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        if df_train.empty or df_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n\n        save_path = get_or_create_path(save_path)\n        stop_steps = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # load pretrained base_model\n   
     if self.base_model == \"LSTM\":\n            pretrained_model = LSTMModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)\n        elif self.base_model == \"GRU\":\n            pretrained_model = GRUModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)\n        else:\n            raise ValueError(\"unknown base model name `%s`\" % self.base_model)\n\n        if self.model_path is not None:\n            self.logger.info(\"Loading pretrained model...\")\n            pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))\n\n        model_dict = self.GAT_model.state_dict()\n        pretrained_dict = {\n            k: v for k, v in pretrained_model.state_dict().items() if k in model_dict  # pylint: disable=E1135\n        }\n        model_dict.update(pretrained_dict)\n        self.GAT_model.load_state_dict(model_dict)\n        self.logger.info(\"Loading pretrained model Done...\")\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            self.train_epoch(x_train, y_train)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(x_train, y_train)\n            val_loss, val_score = self.test_epoch(x_valid, y_valid)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.GAT_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        
self.GAT_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        x_test = dataset.prepare(segment, col_set=\"feature\")\n        index = x_test.index\n        self.GAT_model.eval()\n        x_values = x_test.values\n        preds = []\n\n        # organize the data into daily batches\n        daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)\n\n        for idx, count in zip(daily_index, daily_count):\n            batch = slice(idx, idx + count)\n            x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)\n\n            with torch.no_grad():\n                pred = self.GAT_model(x_batch).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=index)\n\n\nclass GATModel(nn.Module):\n    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model=\"GRU\"):\n        super().__init__()\n\n        if base_model == \"GRU\":\n            self.rnn = nn.GRU(\n                input_size=d_feat,\n                hidden_size=hidden_size,\n                num_layers=num_layers,\n                batch_first=True,\n                dropout=dropout,\n            )\n        elif base_model == \"LSTM\":\n            self.rnn = nn.LSTM(\n                input_size=d_feat,\n                hidden_size=hidden_size,\n                num_layers=num_layers,\n                batch_first=True,\n                dropout=dropout,\n            )\n        else:\n            raise ValueError(\"unknown base model name `%s`\" % base_model)\n\n        self.hidden_size = hidden_size\n        self.d_feat = d_feat\n        self.transformation = nn.Linear(self.hidden_size, self.hidden_size)\n        self.a = 
nn.Parameter(torch.randn(self.hidden_size * 2, 1))\n        self.a.requires_grad = True\n        self.fc = nn.Linear(self.hidden_size, self.hidden_size)\n        self.fc_out = nn.Linear(hidden_size, 1)\n        self.leaky_relu = nn.LeakyReLU()\n        self.softmax = nn.Softmax(dim=1)\n\n    def cal_attention(self, x, y):\n        x = self.transformation(x)\n        y = self.transformation(y)\n\n        sample_num = x.shape[0]\n        dim = x.shape[1]\n        e_x = x.expand(sample_num, sample_num, dim)\n        e_y = torch.transpose(e_x, 0, 1)\n        attention_in = torch.cat((e_x, e_y), 2).view(-1, dim * 2)\n        self.a_t = torch.t(self.a)\n        attention_out = self.a_t.mm(torch.t(attention_in)).view(sample_num, sample_num)\n        attention_out = self.leaky_relu(attention_out)\n        att_weight = self.softmax(attention_out)\n        return att_weight\n\n    def forward(self, x):\n        # x: [N, F*T]\n        x = x.reshape(len(x), self.d_feat, -1)  # [N, F, T]\n        x = x.permute(0, 2, 1)  # [N, T, F]\n        out, _ = self.rnn(x)\n        hidden = out[:, -1, :]\n        att_weight = self.cal_attention(hidden, hidden)\n        hidden = att_weight.mm(hidden) + hidden\n        hidden = self.fc(hidden)\n        hidden = self.leaky_relu(hidden)\n        return self.fc_out(hidden).squeeze()\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_gats_ts.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import Sampler\n\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...contrib.model.pytorch_lstm import LSTMModel\nfrom ...contrib.model.pytorch_gru import GRUModel\n\n\nclass DailyBatchSampler(Sampler):\n    def __init__(self, data_source):\n        self.data_source = data_source\n        # calculate number of samples in each batch\n        self.daily_count = (\n            pd.Series(index=self.data_source.get_index()).groupby(\"datetime\", group_keys=False).size().values\n        )\n        self.daily_index = np.roll(np.cumsum(self.daily_count), 1)  # calculate begin index of each batch\n        self.daily_index[0] = 0\n\n    def __iter__(self):\n        for idx, count in zip(self.daily_index, self.daily_count):\n            yield np.arange(idx, idx + count)\n\n    def __len__(self):\n        return len(self.data_source)\n\n\nclass GATs(Model):\n    \"\"\"GATs Model\n\n    Parameters\n    ----------\n    lr : float\n        learning rate\n    d_feat : int\n        input dimensions for each time step\n    metric : str\n        the evaluation metric used in early stop\n    optimizer : str\n        optimizer name\n    GPU : int\n        the GPU ID used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=20,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        early_stop=20,\n        loss=\"mse\",\n        base_model=\"GRU\",\n        
model_path=None,\n        optimizer=\"adam\",\n        GPU=0,\n        n_jobs=10,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"GATs\")\n        self.logger.info(\"GATs pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.base_model = base_model\n        self.model_path = model_path\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.n_jobs = n_jobs\n        self.seed = seed\n\n        self.logger.info(\n            \"GATs parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\nbase_model : {}\"\n            \"\\nmodel_path : {}\"\n            \"\\nvisible_GPU : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                base_model,\n                model_path,\n                GPU,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            
torch.manual_seed(self.seed)\n\n        self.GAT_model = GATModel(\n            d_feat=self.d_feat,\n            hidden_size=self.hidden_size,\n            num_layers=self.num_layers,\n            dropout=self.dropout,\n            base_model=self.base_model,\n        )\n        self.logger.info(\"model:\\n{:}\".format(self.GAT_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.GAT_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.GAT_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.GAT_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label):\n        loss = (pred - label) ** 2\n        return torch.mean(loss)\n\n    def loss_fn(self, pred, label):\n        mask = ~torch.isnan(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def get_daily_inter(self, df, shuffle=False):\n        # organize the train data into daily batches\n        daily_count = df.groupby(level=0, group_keys=False).size().values\n        daily_index = np.roll(np.cumsum(daily_count), 1)\n        daily_index[0] = 0\n        if shuffle:\n            # shuffle data\n            daily_shuffle = list(zip(daily_index, daily_count))\n            np.random.shuffle(daily_shuffle)\n            
daily_index, daily_count = zip(*daily_shuffle)\n        return daily_index, daily_count\n\n    def train_epoch(self, data_loader):\n        self.GAT_model.train()\n\n        for data in data_loader:\n            data = data.squeeze()\n            feature = data[:, :, 0:-1].to(self.device)\n            label = data[:, -1, -1].to(self.device)\n\n            pred = self.GAT_model(feature.float())\n            loss = self.loss_fn(pred, label)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.GAT_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_loader):\n        self.GAT_model.eval()\n\n        scores = []\n        losses = []\n\n        for data in data_loader:\n            data = data.squeeze()\n            feature = data[:, :, 0:-1].to(self.device)\n            # feature[torch.isnan(feature)] = 0\n            label = data[:, -1, -1].to(self.device)\n\n            pred = self.GAT_model(feature.float())\n            loss = self.loss_fn(pred, label)\n            losses.append(loss.item())\n\n            score = self.metric_fn(pred, label)\n            scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        dl_train = dataset.prepare(\"train\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n        dl_valid = dataset.prepare(\"valid\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n        if dl_train.empty or dl_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        dl_train.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\n        dl_valid.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\n\n        sampler_train = DailyBatchSampler(dl_train)\n        
sampler_valid = DailyBatchSampler(dl_valid)\n\n        train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True)\n        valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True)\n\n        save_path = get_or_create_path(save_path)\n\n        stop_steps = 0\n        train_loss = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # load pretrained base_model\n        if self.base_model == \"LSTM\":\n            pretrained_model = LSTMModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)\n        elif self.base_model == \"GRU\":\n            pretrained_model = GRUModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)\n        else:\n            raise ValueError(\"unknown base model name `%s`\" % self.base_model)\n\n        if self.model_path is not None:\n            self.logger.info(\"Loading pretrained model...\")\n            pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))\n\n        model_dict = self.GAT_model.state_dict()\n        pretrained_dict = {\n            k: v for k, v in pretrained_model.state_dict().items() if k in model_dict  # pylint: disable=E1135\n        }\n        model_dict.update(pretrained_dict)\n        self.GAT_model.load_state_dict(model_dict)\n        self.logger.info(\"Loading pretrained model Done...\")\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            self.train_epoch(train_loader)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(train_loader)\n            val_loss, val_score = self.test_epoch(valid_loader)\n            
self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.GAT_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.GAT_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        dl_test = dataset.prepare(\"test\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_I)\n        dl_test.config(fillna_type=\"ffill+bfill\")\n        sampler_test = DailyBatchSampler(dl_test)\n        test_loader = DataLoader(dl_test, sampler=sampler_test, num_workers=self.n_jobs)\n        self.GAT_model.eval()\n        preds = []\n\n        for data in test_loader:\n            data = data.squeeze()\n            feature = data[:, :, 0:-1].to(self.device)\n\n            with torch.no_grad():\n                pred = self.GAT_model(feature.float()).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=dl_test.get_index())\n\n\nclass GATModel(nn.Module):\n    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model=\"GRU\"):\n        super().__init__()\n\n        if base_model == \"GRU\":\n            self.rnn = nn.GRU(\n                input_size=d_feat,\n                hidden_size=hidden_size,\n                
num_layers=num_layers,\n                batch_first=True,\n                dropout=dropout,\n            )\n        elif base_model == \"LSTM\":\n            self.rnn = nn.LSTM(\n                input_size=d_feat,\n                hidden_size=hidden_size,\n                num_layers=num_layers,\n                batch_first=True,\n                dropout=dropout,\n            )\n        else:\n            raise ValueError(\"unknown base model name `%s`\" % base_model)\n\n        self.hidden_size = hidden_size\n        self.d_feat = d_feat\n        self.transformation = nn.Linear(self.hidden_size, self.hidden_size)\n        self.a = nn.Parameter(torch.randn(self.hidden_size * 2, 1))\n        self.a.requires_grad = True\n        self.fc = nn.Linear(self.hidden_size, self.hidden_size)\n        self.fc_out = nn.Linear(hidden_size, 1)\n        self.leaky_relu = nn.LeakyReLU()\n        self.softmax = nn.Softmax(dim=1)\n\n    def cal_attention(self, x, y):\n        x = self.transformation(x)\n        y = self.transformation(y)\n\n        sample_num = x.shape[0]\n        dim = x.shape[1]\n        e_x = x.expand(sample_num, sample_num, dim)\n        e_y = torch.transpose(e_x, 0, 1)\n        attention_in = torch.cat((e_x, e_y), 2).view(-1, dim * 2)\n        self.a_t = torch.t(self.a)\n        attention_out = self.a_t.mm(torch.t(attention_in)).view(sample_num, sample_num)\n        attention_out = self.leaky_relu(attention_out)\n        att_weight = self.softmax(attention_out)\n        return att_weight\n\n    def forward(self, x):\n        out, _ = self.rnn(x)\n        hidden = out[:, -1, :]\n        att_weight = self.cal_attention(hidden, hidden)\n        hidden = att_weight.mm(hidden) + hidden\n        hidden = self.fc(hidden)\n        hidden = self.leaky_relu(hidden)\n        return self.fc_out(hidden).squeeze()\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_general_nn.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom torch.utils.data import DataLoader\n\n\nimport numpy as np\nimport pandas as pd\nfrom typing import Union\nimport copy\n\nimport torch\nimport torch.optim as optim\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\n\nfrom qlib.data.dataset.weight import Reweighter\n\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH, TSDatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...utils import (\n    init_instance_by_config,\n    get_or_create_path,\n)\nfrom ...log import get_module_logger\n\nfrom ...model.utils import ConcatDataset\n\n\nclass GeneralPTNN(Model):\n    \"\"\"\n    Motivation:\n        We want to provide a Qlib General Pytorch Model Adaptor\n        You can reuse it for all kinds of Pytorch models.\n        It should include the training and predict process\n\n    Parameters\n    ----------\n    d_feat : int\n        input dimension for each time step\n    metric: str\n        the evaluation metric used in early stop\n    optimizer : str\n        optimizer name\n    GPU : str\n        the GPU ID(s) used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        batch_size=2000,\n        early_stop=20,\n        loss=\"mse\",\n        weight_decay=0.0,\n        optimizer=\"adam\",\n        n_jobs=10,\n        GPU=0,\n        seed=None,\n        pt_model_uri=\"qlib.contrib.model.pytorch_gru_ts.GRUModel\",\n        pt_model_kwargs={\n            \"d_feat\": 6,\n            \"hidden_size\": 64,\n            \"num_layers\": 2,\n            \"dropout\": 0.0,\n        },\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"GeneralPTNN\")\n        self.logger.info(\"GeneralPTNN pytorch version...\")\n\n        # set 
hyper-parameters.\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.weight_decay = weight_decay\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.n_jobs = n_jobs\n        self.seed = seed\n\n        self.pt_model_uri, self.pt_model_kwargs = pt_model_uri, pt_model_kwargs\n        self.dnn_model = init_instance_by_config({\"class\": pt_model_uri, \"kwargs\": pt_model_kwargs})\n\n        self.logger.info(\n            \"GeneralPTNN parameters setting:\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\ndevice : {}\"\n            \"\\nn_jobs : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nweight_decay : {}\"\n            \"\\nseed : {}\"\n            \"\\npt_model_uri: {}\"\n            \"\\npt_model_kwargs: {}\".format(\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                self.device,\n                n_jobs,\n                self.use_gpu,\n                weight_decay,\n                seed,\n                pt_model_uri,\n                pt_model_kwargs,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.logger.info(\"model:\\n{:}\".format(self.dnn_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.dnn_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = 
optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        # === ReduceLROnPlateau learning rate scheduler ===\n        self.lr_scheduler = ReduceLROnPlateau(\n            self.train_optimizer, mode=\"min\", factor=0.5, patience=5, min_lr=1e-6, threshold=1e-5\n        )\n        self.fitted = False\n        self.dnn_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label, weight):\n        loss = weight * (pred - label) ** 2\n        return torch.mean(loss)\n\n    def loss_fn(self, pred, label, weight=None):\n        # Flatten to 1-D so that pred, label and weight align element-wise;\n        # mixing [N] and [N, 1] shapes would silently broadcast to [N, N].\n        pred, label = pred.reshape(-1), label.reshape(-1)\n        mask = ~torch.isnan(label)\n\n        if weight is None:\n            weight = torch.ones_like(label)\n        weight = weight.reshape(-1)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask], weight[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def _get_fl(self, data: torch.Tensor):\n        \"\"\"\n        Get the feature and label from data.\n        - Handles the different data shapes of time-series and tabular data\n\n        Parameters\n        ----------\n        data : torch.Tensor\n            input data, which may be 3-dimensional or 2-dimensional\n            - 3dim: [batch_size, time_step, feature_dim]\n            - 2dim: [batch_size, feature_dim]\n\n        Returns\n        -------\n        Tuple[torch.Tensor, torch.Tensor]\n        \"\"\"\n        if data.dim() == 3:\n            # 
it is a time series dataset\n            feature = data[:, :, 0:-1].to(self.device)\n            label = data[:, -1, -1].to(self.device)\n        elif data.dim() == 2:\n            # it is a tabular dataset\n            feature = data[:, 0:-1].to(self.device)\n            label = data[:, -1].to(self.device)\n        else:\n            raise ValueError(\"Unsupported data shape.\")\n        return feature, label\n\n    def train_epoch(self, data_loader):\n        self.dnn_model.train()\n\n        for data, weight in data_loader:\n            feature, label = self._get_fl(data)\n\n            pred = self.dnn_model(feature.float())\n            loss = self.loss_fn(pred, label, weight.to(self.device))\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.dnn_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_loader):\n        self.dnn_model.eval()\n\n        scores = []\n        losses = []\n\n        for data, weight in data_loader:\n            feature, label = self._get_fl(data)\n\n            with torch.no_grad():\n                pred = self.dnn_model(feature.float())\n                loss = self.loss_fn(pred, label, weight.to(self.device))\n                losses.append(loss.item())\n\n                score = self.metric_fn(pred, label)\n                scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset: Union[DatasetH, TSDatasetH],\n        evals_result=dict(),\n        save_path=None,\n        reweighter=None,\n    ):\n        ists = isinstance(dataset, TSDatasetH)  # is this time series dataset\n\n        dl_train = dataset.prepare(\"train\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n        dl_valid = dataset.prepare(\"valid\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n        self.logger.info(f\"Train samples: 
{len(dl_train)}\")\n        self.logger.info(f\"Valid samples: {len(dl_valid)}\")\n        if dl_train.empty or dl_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        if reweighter is None:\n            wl_train = np.ones(len(dl_train))\n            wl_valid = np.ones(len(dl_valid))\n        elif isinstance(reweighter, Reweighter):\n            wl_train = reweighter.reweight(dl_train)\n            wl_valid = reweighter.reweight(dl_valid)\n        else:\n            raise ValueError(\"Unsupported reweighter type.\")\n\n        # Preprocess the data so that it matches the Dataset interface expected by DataLoader\n        if ists:\n            dl_train.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\n            dl_valid.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\n        else:\n            # For tabular data, convert the DataFrame to a NumPy array so that DataLoader can index it\n            dl_train = dl_train.values\n            dl_valid = dl_valid.values\n\n        train_loader = DataLoader(\n            ConcatDataset(dl_train, wl_train),\n            batch_size=self.batch_size,\n            shuffle=True,\n            num_workers=self.n_jobs,\n            drop_last=True,\n        )\n        valid_loader = DataLoader(\n            ConcatDataset(dl_valid, wl_valid),\n            batch_size=self.batch_size,\n            shuffle=False,\n            num_workers=self.n_jobs,\n            drop_last=True,\n        )\n        del dl_train, dl_valid, wl_train, wl_valid\n\n        save_path = get_or_create_path(save_path)\n\n        stop_steps = 0\n        train_loss = 0\n        best_score = np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", 
step)\n            self.logger.info(\"training...\")\n            self.train_epoch(train_loader)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(train_loader)\n            val_loss, val_score = self.test_epoch(valid_loader)\n            self.logger.info(\"Epoch%d: train %.6f, valid %.6f\" % (step, train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            # current_lr = self.train_optimizer.param_groups[0][\"lr\"]\n            # self.logger.info(\"Current learning rate: %.6e\" % current_lr)\n\n            self.lr_scheduler.step(val_score)\n\n            if step == 0:\n                best_param = copy.deepcopy(self.dnn_model.state_dict())\n            if val_score < best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.dnn_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d epoch\" % (best_score, best_epoch))\n        self.dnn_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(\n        self,\n        dataset: Union[DatasetH, TSDatasetH],\n        batch_size=None,\n        n_jobs=None,\n    ):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        dl_test = dataset.prepare(\"test\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_I)\n        self.logger.info(f\"Test samples: {len(dl_test)}\")\n\n        if isinstance(dataset, TSDatasetH):\n            dl_test.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\n          
  index = dl_test.get_index()\n        else:\n            # For tabular data, convert the DataFrame to a NumPy array so that DataLoader can index it\n            index = dl_test.index\n            dl_test = dl_test.values\n\n        # honor the batch_size / n_jobs arguments, falling back to the values set at construction\n        batch_size = self.batch_size if batch_size is None else batch_size\n        n_jobs = self.n_jobs if n_jobs is None else n_jobs\n        test_loader = DataLoader(dl_test, batch_size=batch_size, num_workers=n_jobs)\n        self.dnn_model.eval()\n        preds = []\n\n        for data in test_loader:\n            feature, _ = self._get_fl(data)\n            feature = feature.to(self.device)\n\n            with torch.no_grad():\n                pred = self.dnn_model(feature.float()).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        preds_concat = np.concatenate(preds)\n        if preds_concat.ndim != 1:\n            preds_concat = preds_concat.ravel()\n\n        return pd.Series(preds_concat, index=index)\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_gru.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import division\nfrom __future__ import print_function\nimport copy\nfrom typing import Text, Union\n\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom qlib.workflow import R\n\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...log import get_module_logger\nfrom ...model.base import Model\nfrom ...utils import get_or_create_path\nfrom .pytorch_utils import count_parameters\n\n\nclass GRU(Model):\n    \"\"\"GRU Model\n\n    Parameters\n    ----------\n    d_feat : int\n        input dimension for each time step\n    metric: str\n        the evaluation metric used in early stop\n    optimizer : str\n        optimizer name\n    GPU : str\n        the GPU ID(s) used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        batch_size=2000,\n        early_stop=20,\n        loss=\"mse\",\n        optimizer=\"adam\",\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"GRU\")\n        self.logger.info(\"GRU pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.seed = seed\n\n        self.logger.info(\n            \"GRU 
parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\nvisible_GPU : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                GPU,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.gru_model = GRUModel(\n            d_feat=self.d_feat,\n            hidden_size=self.hidden_size,\n            num_layers=self.num_layers,\n            dropout=self.dropout,\n        )\n        self.logger.info(\"model:\\n{:}\".format(self.gru_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.gru_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.gru_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.gru_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.gru_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label):\n        loss = (pred - label) ** 2\n      
  return torch.mean(loss)\n\n    def loss_fn(self, pred, label):\n        mask = ~torch.isnan(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def train_epoch(self, x_train, y_train):\n        x_train_values = x_train.values\n        y_train_values = np.squeeze(y_train.values)\n\n        self.gru_model.train()\n\n        indices = np.arange(len(x_train_values))\n        np.random.shuffle(indices)\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            pred = self.gru_model(feature)\n            loss = self.loss_fn(pred, label)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.gru_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_x, data_y):\n        # prepare training data\n        x_values = data_x.values\n        y_values = np.squeeze(data_y.values)\n\n        self.gru_model.eval()\n\n        scores = []\n        losses = []\n\n        indices = np.arange(len(x_values))\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = 
torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            with torch.no_grad():\n                pred = self.gru_model(feature)\n                loss = self.loss_fn(pred, label)\n                losses.append(loss.item())\n\n                score = self.metric_fn(pred, label)\n                scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        # prepare training and validation data\n        dfs = {\n            k: dataset.prepare(\n                k,\n                col_set=[\"feature\", \"label\"],\n                data_key=DataHandlerLP.DK_L,\n            )\n            for k in [\"train\", \"valid\"]\n            if k in dataset.segments\n        }\n        df_train, df_valid = dfs.get(\"train\", pd.DataFrame()), dfs.get(\"valid\", pd.DataFrame())\n\n        # check if training data is empty\n        if df_train.empty:\n            raise ValueError(\"Empty training data from dataset, please check your dataset config.\")\n\n        df_train = df_train.dropna()\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n\n        # check if validation data is provided\n        if not df_valid.empty:\n            df_valid = df_valid.dropna()\n            x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n        else:\n            x_valid, y_valid = None, None\n\n        save_path = get_or_create_path(save_path)\n        stop_steps = 0\n        train_loss = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        best_param = copy.deepcopy(self.gru_model.state_dict())\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            
self.logger.info(\"training...\")\n            self.train_epoch(x_train, y_train)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(x_train, y_train)\n            evals_result[\"train\"].append(train_score)\n\n            # evaluate on validation data if provided\n            if x_valid is not None and y_valid is not None:\n                val_loss, val_score = self.test_epoch(x_valid, y_valid)\n                self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n                evals_result[\"valid\"].append(val_score)\n\n                if val_score > best_score:\n                    best_score = val_score\n                    stop_steps = 0\n                    best_epoch = step\n                    best_param = copy.deepcopy(self.gru_model.state_dict())\n                else:\n                    stop_steps += 1\n                    if stop_steps >= self.early_stop:\n                        self.logger.info(\"early stop\")\n                        break\n\n        if x_valid is None or y_valid is None:\n            # without validation data there is no model selection, so keep the last-epoch weights\n            # instead of restoring the untrained initial ones\n            best_param = copy.deepcopy(self.gru_model.state_dict())\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.gru_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        # Logging\n        rec = R.get_recorder()\n        for k, v_l in evals_result.items():\n            for i, v in enumerate(v_l):\n                rec.log_metrics(step=i, **{k: v})\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        index = x_test.index\n        self.gru_model.eval()\n        x_values = x_test.values\n        sample_num = x_values.shape[0]\n        preds = []\n\n        for begin in range(sample_num)[:: self.batch_size]:\n            if sample_num - 
begin < self.batch_size:\n                end = sample_num\n            else:\n                end = begin + self.batch_size\n\n            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)\n\n            with torch.no_grad():\n                pred = self.gru_model(x_batch).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=index)\n\n\nclass GRUModel(nn.Module):\n    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):\n        super().__init__()\n\n        self.rnn = nn.GRU(\n            input_size=d_feat,\n            hidden_size=hidden_size,\n            num_layers=num_layers,\n            batch_first=True,\n            dropout=dropout,\n        )\n        self.fc_out = nn.Linear(hidden_size, 1)\n\n        self.d_feat = d_feat\n\n    def forward(self, x):\n        # x: [N, F*T]\n        x = x.reshape(len(x), self.d_feat, -1)  # [N, F, T]\n        x = x.permute(0, 2, 1)  # [N, T, F]\n        out, _ = self.rnn(x)\n        return self.fc_out(out[:, -1, :]).squeeze()\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_gru_ts.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\n\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...model.utils import ConcatDataset\nfrom ...data.dataset.weight import Reweighter\n\n\nclass GRU(Model):\n    \"\"\"GRU Model\n\n    Parameters\n    ----------\n    d_feat : int\n        input dimension for each time step\n    metric: str\n        the evaluation metric used in early stop\n    optimizer : str\n        optimizer name\n    GPU : str\n        the GPU ID(s) used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        batch_size=2000,\n        early_stop=20,\n        loss=\"mse\",\n        optimizer=\"adam\",\n        n_jobs=10,\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"GRU\")\n        self.logger.info(\"GRU pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.n_jobs = 
n_jobs\n        self.seed = seed\n\n        self.logger.info(\n            \"GRU parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\ndevice : {}\"\n            \"\\nn_jobs : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                self.device,\n                n_jobs,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.GRU_model = GRUModel(\n            d_feat=self.d_feat,\n            hidden_size=self.hidden_size,\n            num_layers=self.num_layers,\n            dropout=self.dropout,\n        )\n        self.logger.info(\"model:\\n{:}\".format(self.GRU_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.GRU_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.GRU_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.GRU_model.to(self.device)\n\n    @property\n    def 
use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label, weight):\n        loss = weight * (pred - label) ** 2\n        return torch.mean(loss)\n\n    def loss_fn(self, pred, label, weight=None):\n        mask = ~torch.isnan(label)\n\n        if weight is None:\n            weight = torch.ones_like(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask], weight[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def train_epoch(self, data_loader):\n        self.GRU_model.train()\n\n        for data, weight in data_loader:\n            feature = data[:, :, 0:-1].to(self.device)\n            label = data[:, -1, -1].to(self.device)\n\n            pred = self.GRU_model(feature.float())\n            loss = self.loss_fn(pred, label, weight.to(self.device))\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.GRU_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_loader):\n        self.GRU_model.eval()\n\n        scores = []\n        losses = []\n\n        for data, weight in data_loader:\n            feature = data[:, :, 0:-1].to(self.device)\n            # feature[torch.isnan(feature)] = 0\n            label = data[:, -1, -1].to(self.device)\n\n            with torch.no_grad():\n                pred = self.GRU_model(feature.float())\n                loss = self.loss_fn(pred, label, weight.to(self.device))\n                losses.append(loss.item())\n\n                score = self.metric_fn(pred, label)\n                scores.append(score.item())\n\n        return np.mean(losses), 
np.mean(scores)\n\n    def fit(\n        self,\n        dataset,\n        evals_result=dict(),\n        save_path=None,\n        reweighter=None,\n    ):\n        dl_train = dataset.prepare(\"train\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n        dl_valid = dataset.prepare(\"valid\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n        if dl_train.empty or dl_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        dl_train.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\n        dl_valid.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\n\n        if reweighter is None:\n            wl_train = np.ones(len(dl_train))\n            wl_valid = np.ones(len(dl_valid))\n        elif isinstance(reweighter, Reweighter):\n            wl_train = reweighter.reweight(dl_train)\n            wl_valid = reweighter.reweight(dl_valid)\n        else:\n            raise ValueError(\"Unsupported reweighter type.\")\n\n        train_loader = DataLoader(\n            ConcatDataset(dl_train, wl_train),\n            batch_size=self.batch_size,\n            shuffle=True,\n            num_workers=self.n_jobs,\n            drop_last=True,\n        )\n        valid_loader = DataLoader(\n            ConcatDataset(dl_valid, wl_valid),\n            batch_size=self.batch_size,\n            shuffle=False,\n            num_workers=self.n_jobs,\n            drop_last=True,\n        )\n\n        save_path = get_or_create_path(save_path)\n\n        stop_steps = 0\n        train_loss = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n        
    self.train_epoch(train_loader)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(train_loader)\n            val_loss, val_score = self.test_epoch(valid_loader)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.GRU_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.GRU_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        dl_test = dataset.prepare(\"test\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_I)\n        dl_test.config(fillna_type=\"ffill+bfill\")\n        test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)\n        self.GRU_model.eval()\n        preds = []\n\n        for data in test_loader:\n            feature = data[:, :, 0:-1].to(self.device)\n\n            with torch.no_grad():\n                pred = self.GRU_model(feature.float()).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=dl_test.get_index())\n\n\nclass GRUModel(nn.Module):\n    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):\n        super().__init__()\n\n        self.rnn = nn.GRU(\n            
input_size=d_feat,\n            hidden_size=hidden_size,\n            num_layers=num_layers,\n            batch_first=True,\n            dropout=dropout,\n        )\n        self.fc_out = nn.Linear(hidden_size, 1)\n\n        self.d_feat = d_feat\n\n    def forward(self, x):\n        out, _ = self.rnn(x)\n        return self.fc_out(out[:, -1, :]).squeeze()\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_hist.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nimport urllib.request\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...contrib.model.pytorch_lstm import LSTMModel\nfrom ...contrib.model.pytorch_gru import GRUModel\n\n\nclass HIST(Model):\n    \"\"\"HIST Model\n\n    Parameters\n    ----------\n    lr : float\n        learning rate\n    d_feat : int\n        input dimensions for each time step\n    metric : str\n        the evaluation metric used in early stop\n    optimizer : str\n        optimizer name\n    GPU : str\n        the GPU ID(s) used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        early_stop=20,\n        loss=\"mse\",\n        base_model=\"GRU\",\n        model_path=None,\n        stock2concept=None,\n        stock_index=None,\n        optimizer=\"adam\",\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"HIST\")\n        self.logger.info(\"HIST pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = 
loss\n        self.base_model = base_model\n        self.model_path = model_path\n        self.stock2concept = stock2concept\n        self.stock_index = stock_index\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.seed = seed\n\n        self.logger.info(\n            \"HIST parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\nbase_model : {}\"\n            \"\\nmodel_path : {}\"\n            \"\\nstock2concept : {}\"\n            \"\\nstock_index : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                base_model,\n                model_path,\n                stock2concept,\n                stock_index,\n                GPU,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.HIST_model = HISTModel(\n            d_feat=self.d_feat,\n            hidden_size=self.hidden_size,\n            num_layers=self.num_layers,\n            dropout=self.dropout,\n            base_model=self.base_model,\n        )\n        self.logger.info(\"model:\\n{:}\".format(self.HIST_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.HIST_model)))\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = 
optim.Adam(self.HIST_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.HIST_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.HIST_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label):\n        loss = (pred - label) ** 2\n        return torch.mean(loss)\n\n    def loss_fn(self, pred, label):\n        mask = ~torch.isnan(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric == \"ic\":\n            x = pred[mask]\n            y = label[mask]\n\n            vx = x - torch.mean(x)\n            vy = y - torch.mean(y)\n            return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)))\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def get_daily_inter(self, df, shuffle=False):\n        # organize the data into daily batches\n        daily_count = df.groupby(level=0, group_keys=False).size().values\n        daily_index = np.roll(np.cumsum(daily_count), 1)\n        daily_index[0] = 0\n        if shuffle:\n            # shuffle data\n            daily_shuffle = list(zip(daily_index, daily_count))\n            np.random.shuffle(daily_shuffle)\n            daily_index, daily_count = zip(*daily_shuffle)\n        return daily_index, daily_count\n\n    def train_epoch(self, x_train, y_train, stock_index):\n        stock2concept_matrix = np.load(self.stock2concept)\n        x_train_values = 
x_train.values\n        y_train_values = np.squeeze(y_train.values)\n        stock_index = stock_index.values\n        stock_index[np.isnan(stock_index)] = 733\n        stock_index = stock_index.astype(int)\n        self.HIST_model.train()\n\n        # organize the train data into daily batches\n        daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)\n\n        for idx, count in zip(daily_index, daily_count):\n            batch = slice(idx, idx + count)\n            feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)\n            concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index[batch]]).float().to(self.device)\n            label = torch.from_numpy(y_train_values[batch]).float().to(self.device)\n            pred = self.HIST_model(feature, concept_matrix)\n            loss = self.loss_fn(pred, label)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.HIST_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_x, data_y, stock_index):\n        # prepare evaluation data\n        stock2concept_matrix = np.load(self.stock2concept)\n        x_values = data_x.values\n        y_values = np.squeeze(data_y.values)\n        stock_index = stock_index.values\n        stock_index[np.isnan(stock_index)] = 733\n        stock_index = stock_index.astype(int)\n        self.HIST_model.eval()\n\n        scores = []\n        losses = []\n\n        # organize the test data into daily batches\n        daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)\n\n        for idx, count in zip(daily_index, daily_count):\n            batch = slice(idx, idx + count)\n            feature = torch.from_numpy(x_values[batch]).float().to(self.device)\n            concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index[batch]]).float().to(self.device)\n            label = torch.from_numpy(y_values[batch]).float().to(self.device)\n            with torch.no_grad():\n                pred = 
self.HIST_model(feature, concept_matrix)\n                loss = self.loss_fn(pred, label)\n                losses.append(loss.item())\n\n                score = self.metric_fn(pred, label)\n                scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        df_train, df_valid, df_test = dataset.prepare(\n            [\"train\", \"valid\", \"test\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        if df_train.empty or df_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        if not os.path.exists(self.stock2concept):\n            url = \"https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/qlib_csi300_stock2concept.npy\"\n            urllib.request.urlretrieve(url, self.stock2concept)\n\n        stock_index = np.load(self.stock_index, allow_pickle=True).item()\n        df_train[\"stock_index\"] = 733\n        df_train[\"stock_index\"] = df_train.index.get_level_values(\"instrument\").map(stock_index)\n        df_valid[\"stock_index\"] = 733\n        df_valid[\"stock_index\"] = df_valid.index.get_level_values(\"instrument\").map(stock_index)\n\n        x_train, y_train, stock_index_train = df_train[\"feature\"], df_train[\"label\"], df_train[\"stock_index\"]\n        x_valid, y_valid, stock_index_valid = df_valid[\"feature\"], df_valid[\"label\"], df_valid[\"stock_index\"]\n\n        save_path = get_or_create_path(save_path)\n\n        stop_steps = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # load pretrained base_model\n        if self.base_model == \"LSTM\":\n            pretrained_model = LSTMModel()\n        elif self.base_model == \"GRU\":\n            pretrained_model = 
GRUModel()\n        else:\n            raise ValueError(\"unknown base model name `%s`\" % self.base_model)\n\n        if self.model_path is not None:\n            self.logger.info(\"Loading pretrained model...\")\n            pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))\n\n        model_dict = self.HIST_model.state_dict()\n        pretrained_dict = {\n            k: v for k, v in pretrained_model.state_dict().items() if k in model_dict  # pylint: disable=E1135\n        }\n        model_dict.update(pretrained_dict)\n        self.HIST_model.load_state_dict(model_dict)\n        self.logger.info(\"Loading pretrained model Done...\")\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            self.train_epoch(x_train, y_train, stock_index_train)\n\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(x_train, y_train, stock_index_train)\n            val_loss, val_score = self.test_epoch(x_valid, y_valid, stock_index_valid)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.HIST_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.HIST_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n    def predict(self, 
dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        stock2concept_matrix = np.load(self.stock2concept)\n        stock_index = np.load(self.stock_index, allow_pickle=True).item()\n        df_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        df_test[\"stock_index\"] = 733\n        df_test[\"stock_index\"] = df_test.index.get_level_values(\"instrument\").map(stock_index)\n        stock_index_test = df_test[\"stock_index\"].values\n        stock_index_test[np.isnan(stock_index_test)] = 733\n        stock_index_test = stock_index_test.astype(\"int\")\n        df_test = df_test.drop([\"stock_index\"], axis=1)\n        index = df_test.index\n\n        self.HIST_model.eval()\n        x_values = df_test.values\n        preds = []\n\n        # organize the data into daily batches\n        daily_index, daily_count = self.get_daily_inter(df_test, shuffle=False)\n\n        for idx, count in zip(daily_index, daily_count):\n            batch = slice(idx, idx + count)\n            x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)\n            concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index_test[batch]]).float().to(self.device)\n\n            with torch.no_grad():\n                pred = self.HIST_model(x_batch, concept_matrix).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=index)\n\n\nclass HISTModel(nn.Module):\n    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model=\"GRU\"):\n        super().__init__()\n\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n\n        if base_model == \"GRU\":\n            self.rnn = nn.GRU(\n                input_size=d_feat,\n                hidden_size=hidden_size,\n                num_layers=num_layers,\n                batch_first=True,\n        
        dropout=dropout,\n            )\n        elif base_model == \"LSTM\":\n            self.rnn = nn.LSTM(\n                input_size=d_feat,\n                hidden_size=hidden_size,\n                num_layers=num_layers,\n                batch_first=True,\n                dropout=dropout,\n            )\n        else:\n            raise ValueError(\"unknown base model name `%s`\" % base_model)\n\n        self.fc_es = nn.Linear(hidden_size, hidden_size)\n        torch.nn.init.xavier_uniform_(self.fc_es.weight)\n        self.fc_is = nn.Linear(hidden_size, hidden_size)\n        torch.nn.init.xavier_uniform_(self.fc_is.weight)\n\n        self.fc_es_middle = nn.Linear(hidden_size, hidden_size)\n        torch.nn.init.xavier_uniform_(self.fc_es_middle.weight)\n        self.fc_is_middle = nn.Linear(hidden_size, hidden_size)\n        torch.nn.init.xavier_uniform_(self.fc_is_middle.weight)\n\n        self.fc_es_fore = nn.Linear(hidden_size, hidden_size)\n        torch.nn.init.xavier_uniform_(self.fc_es_fore.weight)\n        self.fc_is_fore = nn.Linear(hidden_size, hidden_size)\n        torch.nn.init.xavier_uniform_(self.fc_is_fore.weight)\n        self.fc_indi_fore = nn.Linear(hidden_size, hidden_size)\n        torch.nn.init.xavier_uniform_(self.fc_indi_fore.weight)\n\n        self.fc_es_back = nn.Linear(hidden_size, hidden_size)\n        torch.nn.init.xavier_uniform_(self.fc_es_back.weight)\n        self.fc_is_back = nn.Linear(hidden_size, hidden_size)\n        torch.nn.init.xavier_uniform_(self.fc_is_back.weight)\n        self.fc_indi = nn.Linear(hidden_size, hidden_size)\n        torch.nn.init.xavier_uniform_(self.fc_indi.weight)\n\n        self.leaky_relu = nn.LeakyReLU()\n        self.softmax_s2t = torch.nn.Softmax(dim=0)\n        self.softmax_t2s = torch.nn.Softmax(dim=1)\n\n        self.fc_out_es = nn.Linear(hidden_size, 1)\n        self.fc_out_is = nn.Linear(hidden_size, 1)\n        self.fc_out_indi = nn.Linear(hidden_size, 1)\n        self.fc_out = 
nn.Linear(hidden_size, 1)\n\n    def cal_cos_similarity(self, x, y):  # the 2nd dimensions of x and y must match\n        xy = x.mm(torch.t(y))\n        x_norm = torch.sqrt(torch.sum(x * x, dim=1)).reshape(-1, 1)\n        y_norm = torch.sqrt(torch.sum(y * y, dim=1)).reshape(-1, 1)\n        cos_similarity = xy / (x_norm.mm(torch.t(y_norm)) + 1e-6)\n        return cos_similarity\n\n    def forward(self, x, concept_matrix):\n        device = x.device\n\n        x_hidden = x.reshape(len(x), self.d_feat, -1)  # [N, F, T]\n        x_hidden = x_hidden.permute(0, 2, 1)  # [N, T, F]\n        x_hidden, _ = self.rnn(x_hidden)\n        x_hidden = x_hidden[:, -1, :]\n\n        # Predefined Concept Module\n\n        stock_to_concept = concept_matrix\n\n        stock_to_concept_sum = torch.sum(stock_to_concept, 0).reshape(1, -1).repeat(stock_to_concept.shape[0], 1)\n        stock_to_concept_sum = stock_to_concept_sum.mul(concept_matrix)\n\n        stock_to_concept_sum = stock_to_concept_sum + (\n            torch.ones(stock_to_concept.shape[0], stock_to_concept.shape[1]).to(device)\n        )\n        stock_to_concept = stock_to_concept / stock_to_concept_sum\n        hidden = torch.t(stock_to_concept).mm(x_hidden)\n\n        hidden = hidden[hidden.sum(1) != 0]\n\n        concept_to_stock = self.cal_cos_similarity(x_hidden, hidden)\n        concept_to_stock = self.softmax_t2s(concept_to_stock)\n\n        e_shared_info = concept_to_stock.mm(hidden)\n        e_shared_info = self.fc_es(e_shared_info)\n\n        e_shared_back = self.fc_es_back(e_shared_info)\n        output_es = self.fc_es_fore(e_shared_info)\n        output_es = self.leaky_relu(output_es)\n\n        # Hidden Concept Module\n        i_shared_info = x_hidden - e_shared_back\n        hidden = i_shared_info\n        i_stock_to_concept = self.cal_cos_similarity(i_shared_info, hidden)\n        dim = i_stock_to_concept.shape[0]\n        diag = i_stock_to_concept.diagonal(0)\n        
i_stock_to_concept = i_stock_to_concept * (torch.ones(dim, dim) - torch.eye(dim)).to(device)\n        row = torch.linspace(0, dim - 1, dim).to(device).long()\n        column = i_stock_to_concept.max(1)[1].long()\n        value = i_stock_to_concept.max(1)[0]\n        i_stock_to_concept[row, column] = 10\n        i_stock_to_concept[i_stock_to_concept != 10] = 0\n        i_stock_to_concept[row, column] = value\n        i_stock_to_concept = i_stock_to_concept + torch.diag_embed((i_stock_to_concept.sum(0) != 0).float() * diag)\n        hidden = torch.t(i_shared_info).mm(i_stock_to_concept).t()\n        hidden = hidden[hidden.sum(1) != 0]\n\n        i_concept_to_stock = self.cal_cos_similarity(i_shared_info, hidden)\n        i_concept_to_stock = self.softmax_t2s(i_concept_to_stock)\n        i_shared_info = i_concept_to_stock.mm(hidden)\n        i_shared_info = self.fc_is(i_shared_info)\n\n        i_shared_back = self.fc_is_back(i_shared_info)\n        output_is = self.fc_is_fore(i_shared_info)\n        output_is = self.leaky_relu(output_is)\n\n        # Individual Information Module\n        individual_info = x_hidden - e_shared_back - i_shared_back\n        output_indi = individual_info\n        output_indi = self.fc_indi(output_indi)\n        output_indi = self.leaky_relu(output_indi)\n\n        # Stock Trend Prediction\n        all_info = output_es + output_is + output_indi\n        pred_all = self.fc_out(all_info).squeeze()\n\n        return pred_all\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_igmtf.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...contrib.model.pytorch_lstm import LSTMModel\nfrom ...contrib.model.pytorch_gru import GRUModel\n\n\nclass IGMTF(Model):\n    \"\"\"IGMTF Model\n\n    Parameters\n    ----------\n    d_feat : int\n        input dimension for each time step\n    metric : str\n        the evaluation metric used in early stop\n    optimizer : str\n        optimizer name\n    GPU : int\n        the GPU ID used for training; the CPU is used if it is negative or CUDA is unavailable\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        early_stop=20,\n        loss=\"mse\",\n        base_model=\"GRU\",\n        model_path=None,\n        optimizer=\"adam\",\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"IGMTF\")\n        self.logger.info(\"IGMTF pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.base_model = base_model\n        self.model_path = model_path\n        self.device = 
torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.seed = seed\n\n        self.logger.info(\n            \"IGMTF parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\nbase_model : {}\"\n            \"\\nmodel_path : {}\"\n            \"\\nvisible_GPU : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                base_model,\n                model_path,\n                GPU,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.igmtf_model = IGMTFModel(\n            d_feat=self.d_feat,\n            hidden_size=self.hidden_size,\n            num_layers=self.num_layers,\n            dropout=self.dropout,\n            base_model=self.base_model,\n        )\n        self.logger.info(\"model:\\n{:}\".format(self.igmtf_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.igmtf_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.igmtf_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.igmtf_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer 
{} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.igmtf_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label):\n        loss = (pred - label) ** 2\n        return torch.mean(loss)\n\n    def loss_fn(self, pred, label):\n        mask = ~torch.isnan(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric == \"ic\":\n            x = pred[mask]\n            y = label[mask]\n\n            vx = x - torch.mean(x)\n            vy = y - torch.mean(y)\n            return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)))\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def get_daily_inter(self, df, shuffle=False):\n        # organize the data into daily batches\n        daily_count = df.groupby(level=0, group_keys=False).size().values\n        daily_index = np.roll(np.cumsum(daily_count), 1)\n        daily_index[0] = 0\n        if shuffle:\n            # shuffle data\n            daily_shuffle = list(zip(daily_index, daily_count))\n            np.random.shuffle(daily_shuffle)\n            daily_index, daily_count = zip(*daily_shuffle)\n        return daily_index, daily_count\n\n    def get_train_hidden(self, x_train):\n        x_train_values = x_train.values\n        daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)\n        self.igmtf_model.eval()\n        train_hidden = []\n        train_hidden_day = []\n\n        for idx, count in zip(daily_index, daily_count):\n            batch = slice(idx, idx + count)\n            feature = 
torch.from_numpy(x_train_values[batch]).float().to(self.device)\n            out = self.igmtf_model(feature, get_hidden=True)\n            train_hidden.append(out.detach().cpu())\n            train_hidden_day.append(out.detach().cpu().mean(dim=0).unsqueeze(dim=0))\n\n        train_hidden = np.asarray(train_hidden, dtype=object)\n        train_hidden_day = torch.cat(train_hidden_day)\n\n        return train_hidden, train_hidden_day\n\n    def train_epoch(self, x_train, y_train, train_hidden, train_hidden_day):\n        x_train_values = x_train.values\n        y_train_values = np.squeeze(y_train.values)\n\n        self.igmtf_model.train()\n\n        daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)\n\n        for idx, count in zip(daily_index, daily_count):\n            batch = slice(idx, idx + count)\n            feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)\n            label = torch.from_numpy(y_train_values[batch]).float().to(self.device)\n            pred = self.igmtf_model(feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day)\n            loss = self.loss_fn(pred, label)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.igmtf_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_x, data_y, train_hidden, train_hidden_day):\n        # prepare evaluation data\n        x_values = data_x.values\n        y_values = np.squeeze(data_y.values)\n\n        self.igmtf_model.eval()\n\n        scores = []\n        losses = []\n\n        daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)\n\n        for idx, count in zip(daily_index, daily_count):\n            batch = slice(idx, idx + count)\n            feature = torch.from_numpy(x_values[batch]).float().to(self.device)\n            label = torch.from_numpy(y_values[batch]).float().to(self.device)\n\n            pred = 
self.igmtf_model(feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day)\n            loss = self.loss_fn(pred, label)\n            losses.append(loss.item())\n\n            score = self.metric_fn(pred, label)\n            scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        df_train, df_valid = dataset.prepare(\n            [\"train\", \"valid\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        if df_train.empty or df_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n\n        save_path = get_or_create_path(save_path)\n        stop_steps = 0\n        train_loss = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # load pretrained base_model\n        if self.base_model == \"LSTM\":\n            pretrained_model = LSTMModel()\n        elif self.base_model == \"GRU\":\n            pretrained_model = GRUModel()\n        else:\n            raise ValueError(\"unknown base model name `%s`\" % self.base_model)\n\n        if self.model_path is not None:\n            self.logger.info(\"Loading pretrained model...\")\n            pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))\n\n        model_dict = self.igmtf_model.state_dict()\n        pretrained_dict = {\n            k: v for k, v in pretrained_model.state_dict().items() if k in model_dict  # pylint: disable=E1135\n        }\n        model_dict.update(pretrained_dict)\n        self.igmtf_model.load_state_dict(model_dict)\n        self.logger.info(\"Loading 
pretrained model Done...\")\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            train_hidden, train_hidden_day = self.get_train_hidden(x_train)\n            self.train_epoch(x_train, y_train, train_hidden, train_hidden_day)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(x_train, y_train, train_hidden, train_hidden_day)\n            val_loss, val_score = self.test_epoch(x_valid, y_valid, train_hidden, train_hidden_day)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.igmtf_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.igmtf_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n        x_train = dataset.prepare(\"train\", col_set=\"feature\", data_key=DataHandlerLP.DK_L)\n        train_hidden, train_hidden_day = self.get_train_hidden(x_train)\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        index = x_test.index\n        
self.igmtf_model.eval()\n        x_values = x_test.values\n        preds = []\n\n        daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)\n\n        for idx, count in zip(daily_index, daily_count):\n            batch = slice(idx, idx + count)\n            x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)\n\n            with torch.no_grad():\n                pred = (\n                    self.igmtf_model(x_batch, train_hidden=train_hidden, train_hidden_day=train_hidden_day)\n                    .detach()\n                    .cpu()\n                    .numpy()\n                )\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=index)\n\n\nclass IGMTFModel(nn.Module):\n    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model=\"GRU\"):\n        super().__init__()\n\n        if base_model == \"GRU\":\n            self.rnn = nn.GRU(\n                input_size=d_feat,\n                hidden_size=hidden_size,\n                num_layers=num_layers,\n                batch_first=True,\n                dropout=dropout,\n            )\n        elif base_model == \"LSTM\":\n            self.rnn = nn.LSTM(\n                input_size=d_feat,\n                hidden_size=hidden_size,\n                num_layers=num_layers,\n                batch_first=True,\n                dropout=dropout,\n            )\n        else:\n            raise ValueError(\"unknown base model name `%s`\" % base_model)\n        self.lins = nn.Sequential()\n        for i in range(2):\n            self.lins.add_module(\"linear\" + str(i), nn.Linear(hidden_size, hidden_size))\n            self.lins.add_module(\"leakyrelu\" + str(i), nn.LeakyReLU())\n        self.fc_output = nn.Linear(hidden_size * 2, hidden_size * 2)\n        self.project1 = nn.Linear(hidden_size, hidden_size, bias=False)\n        self.project2 = nn.Linear(hidden_size, hidden_size, bias=False)\n        self.fc_out_pred = 
nn.Linear(hidden_size * 2, 1)\n\n        self.leaky_relu = nn.LeakyReLU()\n        self.d_feat = d_feat\n\n    def cal_cos_similarity(self, x, y):  # the 2nd dimensions of x and y must match\n        xy = x.mm(torch.t(y))\n        x_norm = torch.sqrt(torch.sum(x * x, dim=1)).reshape(-1, 1)\n        y_norm = torch.sqrt(torch.sum(y * y, dim=1)).reshape(-1, 1)\n        cos_similarity = xy / (x_norm.mm(torch.t(y_norm)) + 1e-6)\n        return cos_similarity\n\n    def sparse_dense_mul(self, s, d):\n        i = s._indices()\n        v = s._values()\n        dv = d[i[0, :], i[1, :]]  # get values from relevant entries of dense matrix\n        return torch.sparse_coo_tensor(i, v * dv, s.size())\n\n    def forward(self, x, get_hidden=False, train_hidden=None, train_hidden_day=None, k_day=10, n_neighbor=10):\n        # x: [N, F*T]\n        device = x.device\n        x = x.reshape(len(x), self.d_feat, -1)  # [N, F, T]\n        x = x.permute(0, 2, 1)  # [N, T, F]\n        out, _ = self.rnn(x)\n        out = out[:, -1, :]\n        out = self.lins(out)\n        mini_batch_out = out\n        if get_hidden is True:\n            return mini_batch_out\n\n        mini_batch_out_day = torch.mean(mini_batch_out, dim=0).unsqueeze(0)\n        day_similarity = self.cal_cos_similarity(mini_batch_out_day, train_hidden_day.to(device))\n        day_index = torch.topk(day_similarity, k_day, dim=1)[1]\n        sample_train_hidden = train_hidden[day_index.long().cpu()].squeeze()\n        sample_train_hidden = torch.cat(list(sample_train_hidden)).to(device)\n        sample_train_hidden = self.lins(sample_train_hidden)\n        cos_similarity = self.cal_cos_similarity(self.project1(mini_batch_out), self.project2(sample_train_hidden))\n\n        row = (\n            torch.linspace(0, x.shape[0] - 1, x.shape[0])\n            .reshape([-1, 1])\n            .repeat(1, n_neighbor)\n            .reshape(1, -1)\n            .to(device)\n        )\n        column = torch.topk(cos_similarity, 
n_neighbor, dim=1)[1].reshape(1, -1)\n        mask = torch.sparse_coo_tensor(\n            torch.cat([row, column]),\n            torch.ones([row.shape[1]]).to(device) / n_neighbor,\n            (x.shape[0], sample_train_hidden.shape[0]),\n        )\n        cos_similarity = self.sparse_dense_mul(mask, cos_similarity)\n\n        agg_out = torch.sparse.mm(cos_similarity, self.project2(sample_train_hidden))\n        # out = self.fc_out(out).squeeze()\n        out = self.fc_out_pred(torch.cat([mini_batch_out, agg_out], axis=1)).squeeze()\n        return out\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_krnn.py",
    "content": "# Copyright (c) Microsoft Corporation.\r\n# Licensed under the MIT License.\r\n\r\n\r\nfrom __future__ import division\r\nfrom __future__ import print_function\r\n\r\nimport numpy as np\r\nimport pandas as pd\r\nfrom typing import Text, Union\r\nimport copy\r\nfrom ...utils import get_or_create_path\r\nfrom ...log import get_module_logger\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.optim as optim\r\n\r\nfrom ...model.base import Model\r\nfrom ...data.dataset import DatasetH\r\nfrom ...data.dataset.handler import DataHandlerLP\r\n\r\n########################################################################\r\n########################################################################\r\n########################################################################\r\n\r\n\r\nclass CNNEncoderBase(nn.Module):\r\n    def __init__(self, input_dim, output_dim, kernel_size, device):\r\n        \"\"\"Build a basic CNN encoder\r\n\r\n        Parameters\r\n        ----------\r\n        input_dim : int\r\n            The input dimension\r\n        output_dim : int\r\n            The output dimension\r\n        kernel_size : int\r\n            The size of convolutional kernels\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n        self.input_dim = input_dim\r\n        self.output_dim = output_dim\r\n        self.kernel_size = kernel_size\r\n        self.device = device\r\n\r\n        # set padding to ensure the same length\r\n        # it is correct only when kernel_size is odd, dilation is 1, stride is 1\r\n        self.conv = nn.Conv1d(input_dim, output_dim, kernel_size, padding=(kernel_size - 1) // 2)\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        Parameters\r\n        ----------\r\n        x : torch.Tensor\r\n            input data\r\n\r\n        Returns\r\n        -------\r\n        torch.Tensor\r\n            Updated representations\r\n        \"\"\"\r\n\r\n        # input shape: [batch_size, seq_len*input_dim]\r\n        # 
output shape: [batch_size, conved_seq_len, output_dim]\r\n        x = x.view(x.shape[0], -1, self.input_dim).permute(0, 2, 1).to(self.device)\r\n        y = self.conv(x)  # [batch_size, output_dim, conved_seq_len]\r\n        y = y.permute(0, 2, 1)  # [batch_size, conved_seq_len, output_dim]\r\n\r\n        return y\r\n\r\n\r\nclass KRNNEncoderBase(nn.Module):\r\n    def __init__(self, input_dim, output_dim, dup_num, rnn_layers, dropout, device):\r\n        \"\"\"Build K parallel RNNs\r\n\r\n        Parameters\r\n        ----------\r\n        input_dim : int\r\n            The input dimension\r\n        output_dim : int\r\n            The output dimension\r\n        dup_num : int\r\n            The number of parallel RNNs\r\n        rnn_layers : int\r\n            The number of RNN layers\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n        self.input_dim = input_dim\r\n        self.output_dim = output_dim\r\n        self.dup_num = dup_num\r\n        self.rnn_layers = rnn_layers\r\n        self.dropout = dropout\r\n        self.device = device\r\n\r\n        self.rnn_modules = nn.ModuleList()\r\n        for _ in range(dup_num):\r\n            self.rnn_modules.append(nn.GRU(input_dim, output_dim, num_layers=self.rnn_layers, dropout=dropout))\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        Parameters\r\n        ----------\r\n        x : torch.Tensor\r\n            Input data\r\n\r\n        Returns\r\n        -------\r\n        torch.Tensor\r\n            Updated representations\r\n        \"\"\"\r\n\r\n        # input shape: [batch_size, seq_len, input_dim]\r\n        # output shape: [batch_size, seq_len, output_dim]\r\n        # permute to [seq_len, batch_size, input_dim] (nn.GRU is not batch-first here)\r\n        batch_size, seq_len, input_dim = x.shape\r\n        x = x.permute(1, 0, 2).to(self.device)\r\n\r\n        hids = []\r\n        for rnn in self.rnn_modules:\r\n            h, _ = rnn(x)  # [seq_len, batch_size, output_dim]\r\n            hids.append(h)
\r\n        # [seq_len, batch_size, output_dim, num_dups]\r\n        hids = torch.stack(hids, dim=-1)\r\n        hids = hids.view(seq_len, batch_size, self.output_dim, self.dup_num)\r\n        hids = hids.mean(dim=3)  # average the K parallel RNN outputs\r\n        hids = hids.permute(1, 0, 2)  # [batch_size, seq_len, output_dim]\r\n\r\n        return hids\r\n\r\n\r\nclass CNNKRNNEncoder(nn.Module):\r\n    def __init__(\r\n        self, cnn_input_dim, cnn_output_dim, cnn_kernel_size, rnn_output_dim, rnn_dup_num, rnn_layers, dropout, device\r\n    ):\r\n        \"\"\"Build an encoder composed of CNN and KRNN\r\n\r\n        Parameters\r\n        ----------\r\n        cnn_input_dim : int\r\n            The input dimension of CNN\r\n        cnn_output_dim : int\r\n            The output dimension of CNN\r\n        cnn_kernel_size : int\r\n            The size of convolutional kernels\r\n        rnn_output_dim : int\r\n            The output dimension of KRNN\r\n        rnn_dup_num : int\r\n            The number of parallel duplicates for KRNN\r\n        rnn_layers : int\r\n            The number of RNN layers\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n        self.cnn_encoder = CNNEncoderBase(cnn_input_dim, cnn_output_dim, cnn_kernel_size, device)\r\n        self.krnn_encoder = KRNNEncoderBase(cnn_output_dim, rnn_output_dim, rnn_dup_num, rnn_layers, dropout, device)\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        Parameters\r\n        ----------\r\n        x : torch.Tensor\r\n            Input data\r\n\r\n        Returns\r\n        -------\r\n        torch.Tensor\r\n            Updated representations\r\n        \"\"\"\r\n        cnn_out = self.cnn_encoder(x)\r\n        krnn_out = self.krnn_encoder(cnn_out)\r\n\r\n        return krnn_out\r\n\r\n\r\nclass KRNNModel(nn.Module):\r\n    def __init__(self, fea_dim, cnn_dim, cnn_kernel_size, rnn_dim, rnn_dups, rnn_layers, dropout, device, **params):\r\n        \"\"\"Build a KRNN model
\r\n\r\n        Parameters\r\n        ----------\r\n        fea_dim : int\r\n            The feature dimension\r\n        cnn_dim : int\r\n            The hidden dimension of CNN\r\n        cnn_kernel_size : int\r\n            The size of convolutional kernels\r\n        rnn_dim : int\r\n            The hidden dimension of KRNN\r\n        rnn_dups : int\r\n            The number of parallel duplicates\r\n        rnn_layers : int\r\n            The number of RNN layers\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n        self.encoder = CNNKRNNEncoder(\r\n            cnn_input_dim=fea_dim,\r\n            cnn_output_dim=cnn_dim,\r\n            cnn_kernel_size=cnn_kernel_size,\r\n            rnn_output_dim=rnn_dim,\r\n            rnn_dup_num=rnn_dups,\r\n            rnn_layers=rnn_layers,\r\n            dropout=dropout,\r\n            device=device,\r\n        )\r\n\r\n        self.out_fc = nn.Linear(rnn_dim, 1)\r\n        self.device = device\r\n\r\n    def forward(self, x):\r\n        # x: [batch_size, seq_len*input_dim]; the CNN encoder reshapes it to [batch_size, seq_len, input_dim]\r\n        encode = self.encoder(x)\r\n        out = self.out_fc(encode[:, -1, :]).squeeze().to(self.device)\r\n\r\n        return out\r\n\r\n\r\nclass KRNN(Model):\r\n    \"\"\"KRNN Model\r\n\r\n    Parameters\r\n    ----------\r\n    fea_dim : int\r\n        the feature dimension, i.e. the input dimension for each time step\r\n    metric : str\r\n        the evaluation metric used in early stop\r\n    optimizer : str\r\n        optimizer name\r\n    GPU : int\r\n        the GPU ID used for training; falls back to CPU if CUDA is unavailable or GPU < 0\r\n    \"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n        fea_dim=6,\r\n        cnn_dim=64,\r\n        cnn_kernel_size=3,\r\n        rnn_dim=64,\r\n        rnn_dups=3,\r\n        rnn_layers=2,\r\n        dropout=0,\r\n        n_epochs=200,\r\n        lr=0.001,\r\n        metric=\"\",\r\n        batch_size=2000,\r\n        early_stop=20,\r\n        loss=\"mse\",\r\n        optimizer=\"adam\",\r\n        GPU=0,\r\n        seed=None,\r\n        
**kwargs,\r\n    ):\r\n        # Set logger.\r\n        self.logger = get_module_logger(\"KRNN\")\r\n        self.logger.info(\"KRNN pytorch version...\")\r\n\r\n        # set hyper-parameters.\r\n        self.fea_dim = fea_dim\r\n        self.cnn_dim = cnn_dim\r\n        self.cnn_kernel_size = cnn_kernel_size\r\n        self.rnn_dim = rnn_dim\r\n        self.rnn_dups = rnn_dups\r\n        self.rnn_layers = rnn_layers\r\n        self.dropout = dropout\r\n        self.n_epochs = n_epochs\r\n        self.lr = lr\r\n        self.metric = metric\r\n        self.batch_size = batch_size\r\n        self.early_stop = early_stop\r\n        self.optimizer = optimizer.lower()\r\n        self.loss = loss\r\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\r\n        self.seed = seed\r\n\r\n        self.logger.info(\r\n            \"KRNN parameters setting:\"\r\n            \"\\nfea_dim : {}\"\r\n            \"\\ncnn_dim : {}\"\r\n            \"\\ncnn_kernel_size : {}\"\r\n            \"\\nrnn_dim : {}\"\r\n            \"\\nrnn_dups : {}\"\r\n            \"\\nrnn_layers : {}\"\r\n            \"\\ndropout : {}\"\r\n            \"\\nn_epochs : {}\"\r\n            \"\\nlr : {}\"\r\n            \"\\nmetric : {}\"\r\n            \"\\nbatch_size: {}\"\r\n            \"\\nearly_stop : {}\"\r\n            \"\\noptimizer : {}\"\r\n            \"\\nloss_type : {}\"\r\n            \"\\nvisible_GPU : {}\"\r\n            \"\\nuse_GPU : {}\"\r\n            \"\\nseed : {}\".format(\r\n                fea_dim,\r\n                cnn_dim,\r\n                cnn_kernel_size,\r\n                rnn_dim,\r\n                rnn_dups,\r\n                rnn_layers,\r\n                dropout,\r\n                n_epochs,\r\n                lr,\r\n                metric,\r\n                batch_size,\r\n                early_stop,\r\n                optimizer.lower(),\r\n                loss,\r\n                GPU,\r\n                
self.use_gpu,\r\n                seed,\r\n            )\r\n        )\r\n\r\n        if self.seed is not None:\r\n            np.random.seed(self.seed)\r\n            torch.manual_seed(self.seed)\r\n\r\n        self.krnn_model = KRNNModel(\r\n            fea_dim=self.fea_dim,\r\n            cnn_dim=self.cnn_dim,\r\n            cnn_kernel_size=self.cnn_kernel_size,\r\n            rnn_dim=self.rnn_dim,\r\n            rnn_dups=self.rnn_dups,\r\n            rnn_layers=self.rnn_layers,\r\n            dropout=self.dropout,\r\n            device=self.device,\r\n        )\r\n        if optimizer.lower() == \"adam\":\r\n            self.train_optimizer = optim.Adam(self.krnn_model.parameters(), lr=self.lr)\r\n        elif optimizer.lower() == \"gd\":\r\n            self.train_optimizer = optim.SGD(self.krnn_model.parameters(), lr=self.lr)\r\n        else:\r\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\r\n\r\n        self.fitted = False\r\n        self.krnn_model.to(self.device)\r\n\r\n    @property\r\n    def use_gpu(self):\r\n        return self.device != torch.device(\"cpu\")\r\n\r\n    def mse(self, pred, label):\r\n        loss = (pred - label) ** 2\r\n        return torch.mean(loss)\r\n\r\n    def loss_fn(self, pred, label):\r\n        mask = ~torch.isnan(label)\r\n\r\n        if self.loss == \"mse\":\r\n            return self.mse(pred[mask], label[mask])\r\n\r\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\r\n\r\n    def metric_fn(self, pred, label):\r\n        mask = torch.isfinite(label)\r\n\r\n        if self.metric in (\"\", \"loss\"):\r\n            return -self.loss_fn(pred[mask], label[mask])\r\n\r\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\r\n\r\n    def get_daily_inter(self, df, shuffle=False):\r\n        # organize the train data into daily batches\r\n        daily_count = df.groupby(level=0, group_keys=False).size().values\r\n        daily_index = 
np.roll(np.cumsum(daily_count), 1)\r\n        daily_index[0] = 0\r\n        if shuffle:\r\n            # shuffle data\r\n            daily_shuffle = list(zip(daily_index, daily_count))\r\n            np.random.shuffle(daily_shuffle)\r\n            daily_index, daily_count = zip(*daily_shuffle)\r\n        return daily_index, daily_count\r\n\r\n    def train_epoch(self, x_train, y_train):\r\n        x_train_values = x_train.values\r\n        y_train_values = np.squeeze(y_train.values)\r\n        self.krnn_model.train()\r\n\r\n        indices = np.arange(len(x_train_values))\r\n        np.random.shuffle(indices)\r\n\r\n        for i in range(len(indices))[:: self.batch_size]:\r\n            if len(indices) - i < self.batch_size:\r\n                break\r\n\r\n            feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\r\n            label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\r\n\r\n            pred = self.krnn_model(feature)\r\n            loss = self.loss_fn(pred, label)\r\n\r\n            self.train_optimizer.zero_grad()\r\n            loss.backward()\r\n            torch.nn.utils.clip_grad_value_(self.krnn_model.parameters(), 3.0)\r\n            self.train_optimizer.step()\r\n\r\n    def test_epoch(self, data_x, data_y):\r\n        # prepare evaluation data\r\n        x_values = data_x.values\r\n        y_values = np.squeeze(data_y.values)\r\n\r\n        self.krnn_model.eval()\r\n\r\n        scores = []\r\n        losses = []\r\n\r\n        indices = np.arange(len(x_values))\r\n\r\n        for i in range(len(indices))[:: self.batch_size]:\r\n            if len(indices) - i < self.batch_size:\r\n                break\r\n\r\n            feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)\r\n            label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
\r\n\r\n            # no gradients are needed for evaluation\r\n            with torch.no_grad():\r\n                pred = self.krnn_model(feature)\r\n                loss = self.loss_fn(pred, label)\r\n                losses.append(loss.item())\r\n\r\n                score = self.metric_fn(pred, label)\r\n                scores.append(score.item())\r\n\r\n        return np.mean(losses), np.mean(scores)\r\n\r\n    def fit(\r\n        self,\r\n        dataset: DatasetH,\r\n        evals_result=dict(),\r\n        save_path=None,\r\n    ):\r\n        df_train, df_valid, df_test = dataset.prepare(\r\n            [\"train\", \"valid\", \"test\"],\r\n            col_set=[\"feature\", \"label\"],\r\n            data_key=DataHandlerLP.DK_L,\r\n        )\r\n        if df_train.empty or df_valid.empty:\r\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\r\n\r\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\r\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\r\n\r\n        save_path = get_or_create_path(save_path)\r\n        stop_steps = 0\r\n        train_loss = 0\r\n        best_score = -np.inf\r\n        best_epoch = 0\r\n        evals_result[\"train\"] = []\r\n        evals_result[\"valid\"] = []\r\n\r\n        # train\r\n        self.logger.info(\"training...\")\r\n        self.fitted = True\r\n\r\n        for step in range(self.n_epochs):\r\n            self.logger.info(\"Epoch%d:\", step)\r\n            self.logger.info(\"training...\")\r\n            self.train_epoch(x_train, y_train)\r\n            self.logger.info(\"evaluating...\")\r\n            train_loss, train_score = self.test_epoch(x_train, y_train)\r\n            val_loss, val_score = self.test_epoch(x_valid, y_valid)\r\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\r\n            evals_result[\"train\"].append(train_score)\r\n            evals_result[\"valid\"].append(val_score)\r\n\r\n            if val_score > best_score:\r\n                best_score = val_score\r\n                stop_steps = 0
\r\n                best_epoch = step\r\n                best_param = copy.deepcopy(self.krnn_model.state_dict())\r\n            else:\r\n                stop_steps += 1\r\n                if stop_steps >= self.early_stop:\r\n                    self.logger.info(\"early stop\")\r\n                    break\r\n\r\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\r\n        self.krnn_model.load_state_dict(best_param)\r\n        torch.save(best_param, save_path)\r\n\r\n        if self.use_gpu:\r\n            torch.cuda.empty_cache()\r\n\r\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\r\n        if not self.fitted:\r\n            raise ValueError(\"model is not fitted yet!\")\r\n\r\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\r\n        index = x_test.index\r\n        self.krnn_model.eval()\r\n        x_values = x_test.values\r\n        sample_num = x_values.shape[0]\r\n        preds = []\r\n\r\n        for begin in range(sample_num)[:: self.batch_size]:\r\n            if sample_num - begin < self.batch_size:\r\n                end = sample_num\r\n            else:\r\n                end = begin + self.batch_size\r\n            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)\r\n            with torch.no_grad():\r\n                pred = self.krnn_model(x_batch).detach().cpu().numpy()\r\n            preds.append(pred)\r\n\r\n        return pd.Series(np.concatenate(preds), index=index)\r\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_localformer.py",
    "content": "# Copyright (c) Microsoft Corporation.\r\n# Licensed under the MIT License.\r\n\r\n\r\nfrom __future__ import division\r\nfrom __future__ import print_function\r\n\r\nimport numpy as np\r\nimport pandas as pd\r\nfrom typing import Text, Union\r\nimport copy\r\nimport math\r\nfrom ...utils import get_or_create_path\r\nfrom ...log import get_module_logger\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.optim as optim\r\n\r\nfrom ...model.base import Model\r\nfrom ...data.dataset import DatasetH\r\nfrom ...data.dataset.handler import DataHandlerLP\r\nfrom torch.nn.modules.container import ModuleList\r\n\r\n# qrun examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml\r\n\r\n\r\nclass LocalformerModel(Model):\r\n    def __init__(\r\n        self,\r\n        d_feat: int = 20,\r\n        d_model: int = 64,\r\n        batch_size: int = 2048,\r\n        nhead: int = 2,\r\n        num_layers: int = 2,\r\n        dropout: float = 0,\r\n        n_epochs=100,\r\n        lr=0.0001,\r\n        metric=\"\",\r\n        early_stop=5,\r\n        loss=\"mse\",\r\n        optimizer=\"adam\",\r\n        reg=1e-3,\r\n        n_jobs=10,\r\n        GPU=0,\r\n        seed=None,\r\n        **kwargs,\r\n    ):\r\n        # set hyper-parameters.\r\n        self.d_model = d_model\r\n        self.dropout = dropout\r\n        self.n_epochs = n_epochs\r\n        self.lr = lr\r\n        self.reg = reg\r\n        self.metric = metric\r\n        self.batch_size = batch_size\r\n        self.early_stop = early_stop\r\n        self.optimizer = optimizer.lower()\r\n        self.loss = loss\r\n        self.n_jobs = n_jobs\r\n        self.device = torch.device(\"cuda:%d\" % GPU if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\r\n        self.seed = seed\r\n        self.logger = get_module_logger(\"TransformerModel\")\r\n        self.logger.info(\"Naive Transformer:\" \"\\nbatch_size : {}\" \"\\ndevice : {}\".format(self.batch_size, 
self.device))\r\n\r\n        if self.seed is not None:\r\n            np.random.seed(self.seed)\r\n            torch.manual_seed(self.seed)\r\n\r\n        self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)\r\n        if optimizer.lower() == \"adam\":\r\n            self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)\r\n        elif optimizer.lower() == \"gd\":\r\n            self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)\r\n        else:\r\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\r\n\r\n        self.fitted = False\r\n        self.model.to(self.device)\r\n\r\n    @property\r\n    def use_gpu(self):\r\n        return self.device != torch.device(\"cpu\")\r\n\r\n    def mse(self, pred, label):\r\n        loss = (pred.float() - label.float()) ** 2\r\n        return torch.mean(loss)\r\n\r\n    def loss_fn(self, pred, label):\r\n        mask = ~torch.isnan(label)\r\n\r\n        if self.loss == \"mse\":\r\n            return self.mse(pred[mask], label[mask])\r\n\r\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\r\n\r\n    def metric_fn(self, pred, label):\r\n        mask = torch.isfinite(label)\r\n\r\n        if self.metric in (\"\", \"loss\"):\r\n            return -self.loss_fn(pred[mask], label[mask])\r\n\r\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\r\n\r\n    def train_epoch(self, x_train, y_train):\r\n        x_train_values = x_train.values\r\n        y_train_values = np.squeeze(y_train.values)\r\n\r\n        self.model.train()\r\n\r\n        indices = np.arange(len(x_train_values))\r\n        np.random.shuffle(indices)\r\n\r\n        for i in range(len(indices))[:: self.batch_size]:\r\n            if len(indices) - i < self.batch_size:\r\n                break\r\n\r\n            feature = torch.from_numpy(x_train_values[indices[i : i + 
self.batch_size]]).float().to(self.device)\r\n            label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\r\n\r\n            pred = self.model(feature)\r\n            loss = self.loss_fn(pred, label)\r\n\r\n            self.train_optimizer.zero_grad()\r\n            loss.backward()\r\n            torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)\r\n            self.train_optimizer.step()\r\n\r\n    def test_epoch(self, data_x, data_y):\r\n        # prepare evaluation data\r\n        x_values = data_x.values\r\n        y_values = np.squeeze(data_y.values)\r\n\r\n        self.model.eval()\r\n\r\n        scores = []\r\n        losses = []\r\n\r\n        indices = np.arange(len(x_values))\r\n\r\n        for i in range(len(indices))[:: self.batch_size]:\r\n            if len(indices) - i < self.batch_size:\r\n                break\r\n\r\n            feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)\r\n            label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)\r\n\r\n            with torch.no_grad():\r\n                pred = self.model(feature)\r\n                loss = self.loss_fn(pred, label)\r\n                losses.append(loss.item())\r\n\r\n                score = self.metric_fn(pred, label)\r\n                scores.append(score.item())\r\n\r\n        return np.mean(losses), np.mean(scores)\r\n\r\n    def fit(\r\n        self,\r\n        dataset: DatasetH,\r\n        evals_result=dict(),\r\n        save_path=None,\r\n    ):\r\n        df_train, df_valid, df_test = dataset.prepare(\r\n            [\"train\", \"valid\", \"test\"],\r\n            col_set=[\"feature\", \"label\"],\r\n            data_key=DataHandlerLP.DK_L,\r\n        )\r\n        if df_train.empty or df_valid.empty:\r\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\r\n\r\n        x_train, y_train = 
df_train[\"feature\"], df_train[\"label\"]\r\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\r\n\r\n        save_path = get_or_create_path(save_path)\r\n        stop_steps = 0\r\n        train_loss = 0\r\n        best_score = -np.inf\r\n        best_epoch = 0\r\n        evals_result[\"train\"] = []\r\n        evals_result[\"valid\"] = []\r\n\r\n        # train\r\n        self.logger.info(\"training...\")\r\n        self.fitted = True\r\n\r\n        for step in range(self.n_epochs):\r\n            self.logger.info(\"Epoch%d:\", step)\r\n            self.logger.info(\"training...\")\r\n            self.train_epoch(x_train, y_train)\r\n            self.logger.info(\"evaluating...\")\r\n            train_loss, train_score = self.test_epoch(x_train, y_train)\r\n            val_loss, val_score = self.test_epoch(x_valid, y_valid)\r\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\r\n            evals_result[\"train\"].append(train_score)\r\n            evals_result[\"valid\"].append(val_score)\r\n\r\n            if val_score > best_score:\r\n                best_score = val_score\r\n                stop_steps = 0\r\n                best_epoch = step\r\n                best_param = copy.deepcopy(self.model.state_dict())\r\n            else:\r\n                stop_steps += 1\r\n                if stop_steps >= self.early_stop:\r\n                    self.logger.info(\"early stop\")\r\n                    break\r\n\r\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\r\n        self.model.load_state_dict(best_param)\r\n        torch.save(best_param, save_path)\r\n\r\n        if self.use_gpu:\r\n            torch.cuda.empty_cache()\r\n\r\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\r\n        if not self.fitted:\r\n            raise ValueError(\"model is not fitted yet!\")\r\n\r\n        x_test = dataset.prepare(segment, col_set=\"feature\", 
data_key=DataHandlerLP.DK_I)\r\n        index = x_test.index\r\n        self.model.eval()\r\n        x_values = x_test.values\r\n        sample_num = x_values.shape[0]\r\n        preds = []\r\n\r\n        for begin in range(sample_num)[:: self.batch_size]:\r\n            if sample_num - begin < self.batch_size:\r\n                end = sample_num\r\n            else:\r\n                end = begin + self.batch_size\r\n\r\n            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)\r\n\r\n            with torch.no_grad():\r\n                pred = self.model(x_batch).detach().cpu().numpy()\r\n\r\n            preds.append(pred)\r\n\r\n        return pd.Series(np.concatenate(preds), index=index)\r\n\r\n\r\nclass PositionalEncoding(nn.Module):\r\n    def __init__(self, d_model, max_len=1000):\r\n        super(PositionalEncoding, self).__init__()\r\n        pe = torch.zeros(max_len, d_model)\r\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\r\n        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\r\n        pe[:, 0::2] = torch.sin(position * div_term)\r\n        pe[:, 1::2] = torch.cos(position * div_term)\r\n        pe = pe.unsqueeze(0).transpose(0, 1)\r\n        self.register_buffer(\"pe\", pe)\r\n\r\n    def forward(self, x):\r\n        # [T, N, F]\r\n        return x + self.pe[: x.size(0), :]\r\n\r\n\r\ndef _get_clones(module, N):\r\n    return ModuleList([copy.deepcopy(module) for i in range(N)])\r\n\r\n\r\nclass LocalformerEncoder(nn.Module):\r\n    __constants__ = [\"norm\"]\r\n\r\n    def __init__(self, encoder_layer, num_layers, d_model):\r\n        super(LocalformerEncoder, self).__init__()\r\n        self.layers = _get_clones(encoder_layer, num_layers)\r\n        self.conv = _get_clones(nn.Conv1d(d_model, d_model, 3, 1, 1), num_layers)\r\n        self.num_layers = num_layers\r\n\r\n    def forward(self, src, mask):\r\n        output = src\r\n        out = 
src\r\n\r\n        for i, mod in enumerate(self.layers):\r\n            # [T, N, F] --> [N, T, F] --> [N, F, T]\r\n            out = output.transpose(1, 0).transpose(2, 1)\r\n            out = self.conv[i](out).transpose(2, 1).transpose(1, 0)\r\n\r\n            output = mod(output + out, src_mask=mask)\r\n\r\n        return output + out\r\n\r\n\r\nclass Transformer(nn.Module):\r\n    def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):\r\n        super(Transformer, self).__init__()\r\n        self.rnn = nn.GRU(\r\n            input_size=d_model,\r\n            hidden_size=d_model,\r\n            num_layers=num_layers,\r\n            batch_first=False,\r\n            dropout=dropout,\r\n        )\r\n        self.feature_layer = nn.Linear(d_feat, d_model)\r\n        self.pos_encoder = PositionalEncoding(d_model)\r\n        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)\r\n        self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model)\r\n        self.decoder_layer = nn.Linear(d_model, 1)\r\n        self.device = device\r\n        self.d_feat = d_feat\r\n\r\n    def forward(self, src):\r\n        # src [N, F*T] --> [N, T, F]\r\n        src = src.reshape(len(src), self.d_feat, -1).permute(0, 2, 1)\r\n        src = self.feature_layer(src)\r\n\r\n        # src [N, T, F] --> [T, N, F], [60, 512, 8]\r\n        src = src.transpose(1, 0)  # not batch first\r\n\r\n        mask = None\r\n\r\n        src = self.pos_encoder(src)\r\n        output = self.transformer_encoder(src, mask)  # [60, 512, 8]\r\n\r\n        output, _ = self.rnn(output)\r\n\r\n        # [T, N, F] --> [N, T*F]\r\n        output = self.decoder_layer(output.transpose(1, 0)[:, -1, :])  # [512, 1]\r\n\r\n        return output.squeeze()\r\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_localformer_ts.py",
    "content": "# Copyright (c) Microsoft Corporation.\r\n# Licensed under the MIT License.\r\n\r\n\r\nfrom __future__ import division\r\nfrom __future__ import print_function\r\n\r\nimport numpy as np\r\nimport pandas as pd\r\nimport copy\r\nimport math\r\nfrom ...utils import get_or_create_path\r\nfrom ...log import get_module_logger\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.optim as optim\r\nfrom torch.utils.data import DataLoader\r\n\r\nfrom ...model.base import Model\r\nfrom ...data.dataset import DatasetH\r\nfrom ...data.dataset.handler import DataHandlerLP\r\nfrom torch.nn.modules.container import ModuleList\r\n\r\n\r\nclass LocalformerModel(Model):\r\n    def __init__(\r\n        self,\r\n        d_feat: int = 20,\r\n        d_model: int = 64,\r\n        batch_size: int = 8192,\r\n        nhead: int = 2,\r\n        num_layers: int = 2,\r\n        dropout: float = 0,\r\n        n_epochs=100,\r\n        lr=0.0001,\r\n        metric=\"\",\r\n        early_stop=5,\r\n        loss=\"mse\",\r\n        optimizer=\"adam\",\r\n        reg=1e-3,\r\n        n_jobs=10,\r\n        GPU=0,\r\n        seed=None,\r\n        **kwargs,\r\n    ):\r\n        # set hyper-parameters.\r\n        self.d_model = d_model\r\n        self.dropout = dropout\r\n        self.n_epochs = n_epochs\r\n        self.lr = lr\r\n        self.reg = reg\r\n        self.metric = metric\r\n        self.batch_size = batch_size\r\n        self.early_stop = early_stop\r\n        self.optimizer = optimizer.lower()\r\n        self.loss = loss\r\n        self.n_jobs = n_jobs\r\n        self.device = torch.device(\"cuda:%d\" % GPU if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\r\n        self.seed = seed\r\n        self.logger = get_module_logger(\"TransformerModel\")\r\n        self.logger.info(\r\n            \"Improved Transformer:\" \"\\nbatch_size : {}\" \"\\ndevice : {}\".format(self.batch_size, self.device)\r\n        )\r\n\r\n        if self.seed is not None:\r\n          
  np.random.seed(self.seed)\r\n            torch.manual_seed(self.seed)\r\n\r\n        self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)\r\n        if optimizer.lower() == \"adam\":\r\n            self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)\r\n        elif optimizer.lower() == \"gd\":\r\n            self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)\r\n        else:\r\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\r\n\r\n        self.fitted = False\r\n        self.model.to(self.device)\r\n\r\n    @property\r\n    def use_gpu(self):\r\n        return self.device != torch.device(\"cpu\")\r\n\r\n    def mse(self, pred, label):\r\n        loss = (pred.float() - label.float()) ** 2\r\n        return torch.mean(loss)\r\n\r\n    def loss_fn(self, pred, label):\r\n        mask = ~torch.isnan(label)\r\n\r\n        if self.loss == \"mse\":\r\n            return self.mse(pred[mask], label[mask])\r\n\r\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\r\n\r\n    def metric_fn(self, pred, label):\r\n        mask = torch.isfinite(label)\r\n\r\n        if self.metric in (\"\", \"loss\"):\r\n            return -self.loss_fn(pred[mask], label[mask])\r\n\r\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\r\n\r\n    def train_epoch(self, data_loader):\r\n        self.model.train()\r\n\r\n        for data in data_loader:\r\n            feature = data[:, :, 0:-1].to(self.device)\r\n            label = data[:, -1, -1].to(self.device)\r\n\r\n            pred = self.model(feature.float())  # .float()\r\n            loss = self.loss_fn(pred, label)\r\n\r\n            self.train_optimizer.zero_grad()\r\n            loss.backward()\r\n            torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)\r\n            self.train_optimizer.step()\r\n\r\n    def test_epoch(self, 
data_loader):\r\n        self.model.eval()\r\n\r\n        scores = []\r\n        losses = []\r\n\r\n        for data in data_loader:\r\n            feature = data[:, :, 0:-1].to(self.device)\r\n            label = data[:, -1, -1].to(self.device)\r\n\r\n            with torch.no_grad():\r\n                pred = self.model(feature.float())  # .float()\r\n                loss = self.loss_fn(pred, label)\r\n                losses.append(loss.item())\r\n\r\n                score = self.metric_fn(pred, label)\r\n                scores.append(score.item())\r\n\r\n        return np.mean(losses), np.mean(scores)\r\n\r\n    def fit(\r\n        self,\r\n        dataset: DatasetH,\r\n        evals_result=dict(),\r\n        save_path=None,\r\n    ):\r\n        dl_train = dataset.prepare(\"train\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\r\n        dl_valid = dataset.prepare(\"valid\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\r\n        if dl_train.empty or dl_valid.empty:\r\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\r\n\r\n        dl_train.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\r\n        dl_valid.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\r\n\r\n        train_loader = DataLoader(\r\n            dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True\r\n        )\r\n        valid_loader = DataLoader(\r\n            dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True\r\n        )\r\n\r\n        save_path = get_or_create_path(save_path)\r\n\r\n        stop_steps = 0\r\n        train_loss = 0\r\n        best_score = -np.inf\r\n        best_epoch = 0\r\n        evals_result[\"train\"] = []\r\n        evals_result[\"valid\"] = []\r\n\r\n        # train\r\n        self.logger.info(\"training...\")\r\n        self.fitted = True\r\n\r\n        for step 
in range(self.n_epochs):\r\n            self.logger.info(\"Epoch%d:\", step)\r\n            self.logger.info(\"training...\")\r\n            self.train_epoch(train_loader)\r\n            self.logger.info(\"evaluating...\")\r\n            train_loss, train_score = self.test_epoch(train_loader)\r\n            val_loss, val_score = self.test_epoch(valid_loader)\r\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\r\n            evals_result[\"train\"].append(train_score)\r\n            evals_result[\"valid\"].append(val_score)\r\n\r\n            if val_score > best_score:\r\n                best_score = val_score\r\n                stop_steps = 0\r\n                best_epoch = step\r\n                best_param = copy.deepcopy(self.model.state_dict())\r\n            else:\r\n                stop_steps += 1\r\n                if stop_steps >= self.early_stop:\r\n                    self.logger.info(\"early stop\")\r\n                    break\r\n\r\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\r\n        self.model.load_state_dict(best_param)\r\n        torch.save(best_param, save_path)\r\n\r\n        if self.use_gpu:\r\n            torch.cuda.empty_cache()\r\n\r\n    def predict(self, dataset):\r\n        if not self.fitted:\r\n            raise ValueError(\"model is not fitted yet!\")\r\n\r\n        dl_test = dataset.prepare(\"test\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_I)\r\n        dl_test.config(fillna_type=\"ffill+bfill\")\r\n        test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)\r\n        self.model.eval()\r\n        preds = []\r\n\r\n        for data in test_loader:\r\n            feature = data[:, :, 0:-1].to(self.device)\r\n\r\n            with torch.no_grad():\r\n                pred = self.model(feature.float()).detach().cpu().numpy()\r\n\r\n            preds.append(pred)\r\n\r\n        return 
pd.Series(np.concatenate(preds), index=dl_test.get_index())\r\n\r\n\r\nclass PositionalEncoding(nn.Module):\r\n    def __init__(self, d_model, max_len=1000):\r\n        super(PositionalEncoding, self).__init__()\r\n        pe = torch.zeros(max_len, d_model)\r\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\r\n        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\r\n        pe[:, 0::2] = torch.sin(position * div_term)\r\n        pe[:, 1::2] = torch.cos(position * div_term)\r\n        pe = pe.unsqueeze(0).transpose(0, 1)\r\n        self.register_buffer(\"pe\", pe)\r\n\r\n    def forward(self, x):\r\n        # [T, N, F]\r\n        return x + self.pe[: x.size(0), :]\r\n\r\n\r\ndef _get_clones(module, N):\r\n    return ModuleList([copy.deepcopy(module) for i in range(N)])\r\n\r\n\r\nclass LocalformerEncoder(nn.Module):\r\n    __constants__ = [\"norm\"]\r\n\r\n    def __init__(self, encoder_layer, num_layers, d_model):\r\n        super(LocalformerEncoder, self).__init__()\r\n        self.layers = _get_clones(encoder_layer, num_layers)\r\n        self.conv = _get_clones(nn.Conv1d(d_model, d_model, 3, 1, 1), num_layers)\r\n        self.num_layers = num_layers\r\n\r\n    def forward(self, src, mask):\r\n        output = src\r\n        out = src\r\n\r\n        for i, mod in enumerate(self.layers):\r\n            # [T, N, F] --> [N, T, F] --> [N, F, T]\r\n            out = output.transpose(1, 0).transpose(2, 1)\r\n            out = self.conv[i](out).transpose(2, 1).transpose(1, 0)\r\n\r\n            output = mod(output + out, src_mask=mask)\r\n\r\n        return output + out\r\n\r\n\r\nclass Transformer(nn.Module):\r\n    def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):\r\n        super(Transformer, self).__init__()\r\n        self.rnn = nn.GRU(\r\n            input_size=d_model,\r\n            hidden_size=d_model,\r\n            num_layers=num_layers,\r\n        
    batch_first=False,\r\n            dropout=dropout,\r\n        )\r\n        self.feature_layer = nn.Linear(d_feat, d_model)\r\n        self.pos_encoder = PositionalEncoding(d_model)\r\n        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)\r\n        self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model)\r\n        self.decoder_layer = nn.Linear(d_model, 1)\r\n        self.device = device\r\n        self.d_feat = d_feat\r\n\r\n    def forward(self, src):\r\n        # src [N, T, F], [512, 60, 6]\r\n        src = self.feature_layer(src)  # [512, 60, 8]\r\n\r\n        # src [N, T, F] --> [T, N, F], [60, 512, 8]\r\n        src = src.transpose(1, 0)  # not batch first\r\n\r\n        mask = None\r\n\r\n        src = self.pos_encoder(src)\r\n        output = self.transformer_encoder(src, mask)  # [60, 512, 8]\r\n\r\n        output, _ = self.rnn(output)\r\n\r\n        # [T, N, F] --> [N, T, F] --> last time step [N, F] --> [N, 1]\r\n        output = self.decoder_layer(output.transpose(1, 0)[:, -1, :])  # [512, 1]\r\n\r\n        return output.squeeze()\r\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_lstm.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\n\n\nclass LSTM(Model):\n    \"\"\"LSTM Model\n\n    Parameters\n    ----------\n    d_feat : int\n        input dimension for each time step\n    metric: str\n        the evaluation metric used in early stop\n    optimizer : str\n        optimizer name\n    GPU : str\n        the GPU ID(s) used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        batch_size=2000,\n        early_stop=20,\n        loss=\"mse\",\n        optimizer=\"adam\",\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"LSTM\")\n        self.logger.info(\"LSTM pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.seed = seed\n\n        self.logger.info(\n            \"LSTM parameters setting:\"\n            \"\\nd_feat : {}\"\n            
\"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\nvisible_GPU : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                GPU,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.lstm_model = LSTMModel(\n            d_feat=self.d_feat,\n            hidden_size=self.hidden_size,\n            num_layers=self.num_layers,\n            dropout=self.dropout,\n        )\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.lstm_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.lstm_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.lstm_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label):\n        loss = (pred - label) ** 2\n        return torch.mean(loss)\n\n    def loss_fn(self, pred, label):\n        mask = ~torch.isnan(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask])\n\n        raise 
ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def train_epoch(self, x_train, y_train):\n        x_train_values = x_train.values\n        y_train_values = np.squeeze(y_train.values)\n\n        self.lstm_model.train()\n\n        indices = np.arange(len(x_train_values))\n        np.random.shuffle(indices)\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            pred = self.lstm_model(feature)\n            loss = self.loss_fn(pred, label)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.lstm_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_x, data_y):\n        # prepare training data\n        x_values = data_x.values\n        y_values = np.squeeze(data_y.values)\n\n        self.lstm_model.eval()\n\n        scores = []\n        losses = []\n\n        indices = np.arange(len(x_values))\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            pred = self.lstm_model(feature)\n            loss = self.loss_fn(pred, label)\n            losses.append(loss.item())\n\n            
score = self.metric_fn(pred, label)\n            scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        df_train, df_valid, df_test = dataset.prepare(\n            [\"train\", \"valid\", \"test\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        if df_train.empty or df_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n\n        save_path = get_or_create_path(save_path)\n        stop_steps = 0\n        train_loss = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            self.train_epoch(x_train, y_train)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(x_train, y_train)\n            val_loss, val_score = self.test_epoch(x_valid, y_valid)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.lstm_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    
self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.lstm_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        index = x_test.index\n        self.lstm_model.eval()\n        x_values = x_test.values\n        sample_num = x_values.shape[0]\n        preds = []\n\n        for begin in range(sample_num)[:: self.batch_size]:\n            if sample_num - begin < self.batch_size:\n                end = sample_num\n            else:\n                end = begin + self.batch_size\n            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)\n            with torch.no_grad():\n                pred = self.lstm_model(x_batch).detach().cpu().numpy()\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=index)\n\n\nclass LSTMModel(nn.Module):\n    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):\n        super().__init__()\n\n        self.rnn = nn.LSTM(\n            input_size=d_feat,\n            hidden_size=hidden_size,\n            num_layers=num_layers,\n            batch_first=True,\n            dropout=dropout,\n        )\n        self.fc_out = nn.Linear(hidden_size, 1)\n\n        self.d_feat = d_feat\n\n    def forward(self, x):\n        # x: [N, F*T]\n        x = x.reshape(len(x), self.d_feat, -1)  # [N, F, T]\n        x = x.permute(0, 2, 1)  # [N, T, F]\n        out, _ = self.rnn(x)\n        return self.fc_out(out[:, -1, :]).squeeze()\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_lstm_ts.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\n\nfrom ...model.base import Model\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...model.utils import ConcatDataset\nfrom ...data.dataset.weight import Reweighter\n\n\nclass LSTM(Model):\n    \"\"\"LSTM Model\n\n    Parameters\n    ----------\n    d_feat : int\n        input dimension for each time step\n    metric: str\n        the evaluation metric used in early stop\n    optimizer : str\n        optimizer name\n    GPU : str\n        the GPU ID(s) used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        batch_size=2000,\n        early_stop=20,\n        loss=\"mse\",\n        optimizer=\"adam\",\n        n_jobs=10,\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"LSTM\")\n        self.logger.info(\"LSTM pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.n_jobs = n_jobs\n        self.seed = seed\n\n        
self.logger.info(\n            \"LSTM parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\ndevice : {}\"\n            \"\\nn_jobs : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                self.device,\n                n_jobs,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.LSTM_model = LSTMModel(\n            d_feat=self.d_feat,\n            hidden_size=self.hidden_size,\n            num_layers=self.num_layers,\n            dropout=self.dropout,\n        ).to(self.device)\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.LSTM_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.LSTM_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.LSTM_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label, weight):\n        loss = weight * (pred - label) ** 2\n        return 
torch.mean(loss)\n\n    def loss_fn(self, pred, label, weight):\n        mask = ~torch.isnan(label)\n\n        if weight is None:\n            weight = torch.ones_like(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask], weight[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask], weight=None)\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def train_epoch(self, data_loader):\n        self.LSTM_model.train()\n\n        for data, weight in data_loader:\n            feature = data[:, :, 0:-1].to(self.device)\n            label = data[:, -1, -1].to(self.device)\n\n            pred = self.LSTM_model(feature.float())\n            loss = self.loss_fn(pred, label, weight.to(self.device))\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.LSTM_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_loader):\n        self.LSTM_model.eval()\n\n        scores = []\n        losses = []\n\n        for data, weight in data_loader:\n            feature = data[:, :, 0:-1].to(self.device)\n            # feature[torch.isnan(feature)] = 0\n            label = data[:, -1, -1].to(self.device)\n\n            pred = self.LSTM_model(feature.float())\n            loss = self.loss_fn(pred, label, weight.to(self.device))\n            losses.append(loss.item())\n\n            score = self.metric_fn(pred, label)\n            scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset,\n        evals_result=dict(),\n        save_path=None,\n        reweighter=None,\n    ):\n        dl_train = dataset.prepare(\"train\", col_set=[\"feature\", 
\"label\"], data_key=DataHandlerLP.DK_L)\n        dl_valid = dataset.prepare(\"valid\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n        if dl_train.empty or dl_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        dl_train.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\n        dl_valid.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\n\n        if reweighter is None:\n            wl_train = np.ones(len(dl_train))\n            wl_valid = np.ones(len(dl_valid))\n        elif isinstance(reweighter, Reweighter):\n            wl_train = reweighter.reweight(dl_train)\n            wl_valid = reweighter.reweight(dl_valid)\n        else:\n            raise ValueError(\"Unsupported reweighter type.\")\n\n        train_loader = DataLoader(\n            ConcatDataset(dl_train, wl_train),\n            batch_size=self.batch_size,\n            shuffle=True,\n            num_workers=self.n_jobs,\n            drop_last=True,\n        )\n        valid_loader = DataLoader(\n            ConcatDataset(dl_valid, wl_valid),\n            batch_size=self.batch_size,\n            shuffle=False,\n            num_workers=self.n_jobs,\n            drop_last=True,\n        )\n\n        save_path = get_or_create_path(save_path)\n\n        stop_steps = 0\n        train_loss = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            self.train_epoch(train_loader)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(train_loader)\n            val_loss, val_score = self.test_epoch(valid_loader)\n    
        self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.LSTM_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.LSTM_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        dl_test = dataset.prepare(\"test\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_I)\n        dl_test.config(fillna_type=\"ffill+bfill\")\n        test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)\n        self.LSTM_model.eval()\n        preds = []\n\n        for data in test_loader:\n            feature = data[:, :, 0:-1].to(self.device)\n\n            with torch.no_grad():\n                pred = self.LSTM_model(feature.float()).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=dl_test.get_index())\n\n\nclass LSTMModel(nn.Module):\n    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):\n        super().__init__()\n\n        self.rnn = nn.LSTM(\n            input_size=d_feat,\n            hidden_size=hidden_size,\n            num_layers=num_layers,\n            batch_first=True,\n            dropout=dropout,\n        )\n        self.fc_out = nn.Linear(hidden_size, 
1)\n\n        self.d_feat = d_feat\n\n    def forward(self, x):\n        out, _ = self.rnn(x)\n        return self.fc_out(out[:, -1, :]).squeeze()\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_nn.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\nfrom collections import defaultdict\n\nimport os\nimport gc\nimport numpy as np\nimport pandas as pd\nfrom packaging import version\nfrom typing import Callable, Optional, Text, Union\nfrom sklearn.metrics import roc_auc_score, mean_squared_error\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...data.dataset.weight import Reweighter\nfrom ...utils import (\n    auto_filter_kwargs,\n    init_instance_by_config,\n    unpack_archive_with_buffer,\n    save_multiple_parts_file,\n    get_or_create_path,\n)\nfrom ...log import get_module_logger\nfrom ...workflow import R\nfrom qlib.contrib.meta.data_selection.utils import ICLoss\nfrom torch.nn import DataParallel\n\n\nclass DNNModelPytorch(Model):\n    \"\"\"DNN Model\n    Parameters\n    ----------\n    input_dim : int\n        input dimension\n    output_dim : int\n        output dimension\n    layers : tuple\n        layer sizes\n    lr : float\n        learning rate\n    optimizer : str\n        optimizer name\n    GPU : int\n        the GPU ID used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        lr=0.001,\n        max_steps=300,\n        batch_size=2000,\n        early_stop_rounds=50,\n        eval_steps=20,\n        optimizer=\"gd\",\n        loss=\"mse\",\n        GPU=0,\n        seed=None,\n        weight_decay=0.0,\n        data_parall=False,\n        scheduler: Optional[Union[Callable]] = \"default\",  # when it is Callable, it accept one argument named optimizer\n        init_model=None,\n        eval_train_metric=False,\n        pt_model_uri=\"qlib.contrib.model.pytorch_nn.Net\",\n        pt_model_kwargs={\n            
\"input_dim\": 360,\n            \"layers\": (256,),\n        },\n        valid_key=DataHandlerLP.DK_L,\n        # TODO: Infer Key is a more reasonable key. But it requires more detailed processing on label processing\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"DNNModelPytorch\")\n        self.logger.info(\"DNN pytorch version...\")\n\n        # set hyper-parameters.\n        self.lr = lr\n        self.max_steps = max_steps\n        self.batch_size = batch_size\n        self.early_stop_rounds = early_stop_rounds\n        self.eval_steps = eval_steps\n        self.optimizer = optimizer.lower()\n        self.loss_type = loss\n        if isinstance(GPU, str):\n            self.device = torch.device(GPU)\n        else:\n            self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.seed = seed\n        self.weight_decay = weight_decay\n        self.data_parall = data_parall\n        self.eval_train_metric = eval_train_metric\n        self.valid_key = valid_key\n\n        self.best_step = None\n\n        self.logger.info(\n            \"DNN parameters setting:\"\n            f\"\\nlr : {lr}\"\n            f\"\\nmax_steps : {max_steps}\"\n            f\"\\nbatch_size : {batch_size}\"\n            f\"\\nearly_stop_rounds : {early_stop_rounds}\"\n            f\"\\neval_steps : {eval_steps}\"\n            f\"\\noptimizer : {optimizer}\"\n            f\"\\nloss_type : {loss}\"\n            f\"\\nseed : {seed}\"\n            f\"\\ndevice : {self.device}\"\n            f\"\\nuse_GPU : {self.use_gpu}\"\n            f\"\\nweight_decay : {weight_decay}\"\n            f\"\\nenable data parall : {self.data_parall}\"\n            f\"\\npt_model_uri: {pt_model_uri}\"\n            f\"\\npt_model_kwargs: {pt_model_kwargs}\"\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        if loss not in {\"mse\", 
\"binary\"}:\n            raise NotImplementedError(\"loss {} is not supported!\".format(loss))\n        self._scorer = mean_squared_error if loss == \"mse\" else roc_auc_score\n\n        if init_model is None:\n            self.dnn_model = init_instance_by_config({\"class\": pt_model_uri, \"kwargs\": pt_model_kwargs})\n\n            if self.data_parall:\n                self.dnn_model = DataParallel(self.dnn_model).to(self.device)\n        else:\n            self.dnn_model = init_model\n\n        self.logger.info(\"model:\\n{:}\".format(self.dnn_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.dnn_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        if scheduler == \"default\":\n            # In torch version 2.7.0, the verbose parameter has been removed. 
Reference Link:\n            # https://github.com/pytorch/pytorch/pull/147301/files#diff-036a7470d5307f13c9a6a51c3a65dd014f00ca02f476c545488cd856bea9bcf2L1313\n            if version.parse(str(torch.__version__).split(\"+\", maxsplit=1)[0]) <= version.parse(\"2.6.0\"):\n                # Reduce the learning rate when the loss has stopped decreasing\n                self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(  # pylint: disable=E1123\n                    self.train_optimizer,\n                    mode=\"min\",\n                    factor=0.5,\n                    patience=10,\n                    verbose=True,\n                    threshold=0.0001,\n                    threshold_mode=\"rel\",\n                    cooldown=0,\n                    min_lr=0.00001,\n                    eps=1e-08,\n                )\n            else:\n                self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n                    self.train_optimizer,\n                    mode=\"min\",\n                    factor=0.5,\n                    patience=10,\n                    threshold=0.0001,\n                    threshold_mode=\"rel\",\n                    cooldown=0,\n                    min_lr=0.00001,\n                    eps=1e-08,\n                )\n        elif scheduler is None:\n            self.scheduler = None\n        else:\n            self.scheduler = scheduler(optimizer=self.train_optimizer)\n\n        self.fitted = False\n        self.dnn_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        verbose=True,\n        save_path=None,\n        reweighter=None,\n    ):\n        has_valid = \"valid\" in dataset.segments\n        segments = [\"train\", \"valid\"]\n        vars = [\"x\", \"y\", \"w\"]\n        all_df = defaultdict(dict)  # x_train, x_valid y_train, y_valid w_train, w_valid\n        
all_t = defaultdict(dict)  # tensors\n        for seg in segments:\n            if seg in dataset.segments:\n                # df_train df_valid\n                df = dataset.prepare(\n                    seg, col_set=[\"feature\", \"label\"], data_key=self.valid_key if seg == \"valid\" else DataHandlerLP.DK_L\n                )\n                all_df[\"x\"][seg] = df[\"feature\"]\n                all_df[\"y\"][seg] = df[\"label\"].copy()  # We have to use copy to remove the reference to release mem\n                if reweighter is None:\n                    all_df[\"w\"][seg] = pd.DataFrame(np.ones_like(all_df[\"y\"][seg].values), index=df.index)\n                elif isinstance(reweighter, Reweighter):\n                    all_df[\"w\"][seg] = pd.DataFrame(reweighter.reweight(df))\n                else:\n                    raise ValueError(\"Unsupported reweighter type.\")\n\n                # get tensors\n                for v in vars:\n                    all_t[v][seg] = torch.from_numpy(all_df[v][seg].values).float()\n                    # if seg == \"valid\": # accelerate the eval of validation\n                    all_t[v][seg] = all_t[v][seg].to(self.device)  # This will consume a lot of memory !!!!\n\n                evals_result[seg] = []\n                # free memory\n                del df\n                del all_df[\"x\"]\n                gc.collect()\n\n        save_path = get_or_create_path(save_path)\n        stop_steps = 0\n        train_loss = 0\n        best_loss = np.inf\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n        # return\n        # prepare training data\n        train_num = all_t[\"y\"][\"train\"].shape[0]\n\n        for step in range(1, self.max_steps + 1):\n            if stop_steps >= self.early_stop_rounds:\n                if verbose:\n                    self.logger.info(\"\\tearly stop\")\n                break\n            loss = AverageMeter()\n            
self.dnn_model.train()\n            self.train_optimizer.zero_grad()\n            choice = np.random.choice(train_num, self.batch_size)\n            x_batch_auto = all_t[\"x\"][\"train\"][choice].to(self.device)\n            y_batch_auto = all_t[\"y\"][\"train\"][choice].to(self.device)\n            w_batch_auto = all_t[\"w\"][\"train\"][choice].to(self.device)\n\n            # forward\n            preds = self.dnn_model(x_batch_auto)\n            cur_loss = self.get_loss(preds, w_batch_auto, y_batch_auto, self.loss_type)\n            cur_loss.backward()\n            self.train_optimizer.step()\n            loss.update(cur_loss.item())\n            R.log_metrics(train_loss=loss.avg, step=step)\n\n            # validation\n            train_loss += loss.val\n            # evaluate the model every `eval_steps` steps and at the last step.\n            if step % self.eval_steps == 0 or step == self.max_steps:\n                if has_valid:\n                    stop_steps += 1\n                    train_loss /= self.eval_steps\n\n                    with torch.no_grad():\n                        self.dnn_model.eval()\n\n                        # forward\n                        preds = self._nn_predict(all_t[\"x\"][\"valid\"], return_cpu=False)\n                        cur_loss_val = self.get_loss(preds, all_t[\"w\"][\"valid\"], all_t[\"y\"][\"valid\"], self.loss_type)\n                        loss_val = cur_loss_val.item()\n                        metric_val = (\n                            self.get_metric(\n                                preds.reshape(-1), all_t[\"y\"][\"valid\"].reshape(-1), all_df[\"y\"][\"valid\"].index\n                            )\n                            .detach()\n                            .cpu()\n                            .numpy()\n                            .item()\n                        )\n                        R.log_metrics(val_loss=loss_val, step=step)\n                        
R.log_metrics(val_metric=metric_val, step=step)\n\n                        if self.eval_train_metric:\n                            metric_train = (\n                                self.get_metric(\n                                    self._nn_predict(all_t[\"x\"][\"train\"], return_cpu=False),\n                                    all_t[\"y\"][\"train\"].reshape(-1),\n                                    all_df[\"y\"][\"train\"].index,\n                                )\n                                .detach()\n                                .cpu()\n                                .numpy()\n                                .item()\n                            )\n                            R.log_metrics(train_metric=metric_train, step=step)\n                        else:\n                            metric_train = np.nan\n                    if verbose:\n                        self.logger.info(\n                            f\"[Step {step}]: train_loss {train_loss:.6f}, valid_loss {loss_val:.6f}, train_metric {metric_train:.6f}, valid_metric {metric_val:.6f}\"\n                        )\n                    evals_result[\"train\"].append(train_loss)\n                    evals_result[\"valid\"].append(loss_val)\n                    if loss_val < best_loss:\n                        if verbose:\n                            self.logger.info(\n                                \"\\tvalid loss update from {:.6f} to {:.6f}, save checkpoint.\".format(\n                                    best_loss, loss_val\n                                )\n                            )\n                        best_loss = loss_val\n                        self.best_step = step\n                        R.log_metrics(best_step=self.best_step, step=step)\n                        stop_steps = 0\n                        torch.save(self.dnn_model.state_dict(), save_path)\n                    train_loss = 0\n                    # update learning rate\n                    if self.scheduler is 
not None:\n                        auto_filter_kwargs(self.scheduler.step, warning=False)(metrics=cur_loss_val, epoch=step)\n                    R.log_metrics(lr=self.get_lr(), step=step)\n                else:\n                    # retraining mode\n                    if self.scheduler is not None:\n                        self.scheduler.step(epoch=step)\n\n        if has_valid:\n            # restore the optimal parameters after training\n            self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device))\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def get_lr(self):\n        assert len(self.train_optimizer.param_groups) == 1\n        return self.train_optimizer.param_groups[0][\"lr\"]\n\n    def get_loss(self, pred, w, target, loss_type):\n        pred, w, target = pred.reshape(-1), w.reshape(-1), target.reshape(-1)\n        if loss_type == \"mse\":\n            sqr_loss = torch.mul(pred - target, pred - target)\n            loss = torch.mul(sqr_loss, w).mean()\n            return loss\n        elif loss_type == \"binary\":\n            loss = nn.BCEWithLogitsLoss(weight=w)\n            return loss(pred, target)\n        else:\n            raise NotImplementedError(\"loss {} is not supported!\".format(loss_type))\n\n    def get_metric(self, pred, target, index):\n        # NOTE: the order of the index must follow <datetime, instrument> sorted order\n        return -ICLoss()(pred, target, index)  # pylint: disable=E1130\n\n    def _nn_predict(self, data, return_cpu=True):\n        \"\"\"Reusing predicting NN.\n        Scenarios\n        1) test inference (data may come from CPU and expect the output data is on CPU)\n        2) evaluation on training (data may come from GPU)\n        \"\"\"\n        if not isinstance(data, torch.Tensor):\n            if isinstance(data, pd.DataFrame):\n                data = data.values\n            data = torch.Tensor(data)\n        data = data.to(self.device)\n        preds = 
[]\n        self.dnn_model.eval()\n        with torch.no_grad():\n            batch_size = 8096\n            for i in range(0, len(data), batch_size):\n                x = data[i : i + batch_size]\n                preds.append(self.dnn_model(x.to(self.device)).detach().reshape(-1))\n        if return_cpu:\n            preds = np.concatenate([pr.cpu().numpy() for pr in preds])\n        else:\n            preds = torch.cat(preds, axis=0)\n        return preds\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n        x_test_pd = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        preds = self._nn_predict(x_test_pd)\n        return pd.Series(preds.reshape(-1), index=x_test_pd.index)\n\n    def save(self, filename, **kwargs):\n        with save_multiple_parts_file(filename) as model_dir:\n            model_path = os.path.join(model_dir, os.path.split(model_dir)[-1])\n            # Save model\n            torch.save(self.dnn_model.state_dict(), model_path)\n\n    def load(self, buffer, **kwargs):\n        with unpack_archive_with_buffer(buffer) as model_dir:\n            # Get model name\n            _model_name = os.path.splitext(list(filter(lambda x: x.startswith(\"model.bin\"), os.listdir(model_dir)))[0])[\n                0\n            ]\n            _model_path = os.path.join(model_dir, _model_name)\n            # Load model\n            self.dnn_model.load_state_dict(torch.load(_model_path, map_location=self.device))\n        self.fitted = True\n\n\nclass AverageMeter:\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n    
    self.avg = self.sum / self.count\n\n\nclass Net(nn.Module):\n    def __init__(self, input_dim, output_dim=1, layers=(256,), act=\"LeakyReLU\"):\n        super(Net, self).__init__()\n\n        layers = [input_dim] + list(layers)\n        dnn_layers = []\n        drop_input = nn.Dropout(0.05)\n        dnn_layers.append(drop_input)\n        hidden_units = input_dim\n        for i, (_input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):\n            fc = nn.Linear(_input_dim, hidden_units)\n            if act == \"LeakyReLU\":\n                activation = nn.LeakyReLU(negative_slope=0.1, inplace=False)\n            elif act == \"SiLU\":\n                activation = nn.SiLU()\n            else:\n                raise NotImplementedError(f\"Activation {act} is not supported\")\n            bn = nn.BatchNorm1d(hidden_units)\n            seq = nn.Sequential(fc, bn, activation)\n            dnn_layers.append(seq)\n        drop_input = nn.Dropout(0.05)\n        dnn_layers.append(drop_input)\n        fc = nn.Linear(hidden_units, output_dim)\n        dnn_layers.append(fc)\n        # optimizer  # pylint: disable=W0631\n        self.dnn_layers = nn.ModuleList(dnn_layers)\n        self._weight_init()\n\n    def _weight_init(self):\n        for m in self.modules():\n            if isinstance(m, nn.Linear):\n                nn.init.kaiming_normal_(m.weight, a=0.1, mode=\"fan_in\", nonlinearity=\"leaky_relu\")\n\n    def forward(self, x):\n        cur_output = x\n        for i, now_layer in enumerate(self.dnn_layers):\n            cur_output = now_layer(cur_output)\n        return cur_output\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_sandwich.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom .pytorch_krnn import CNNKRNNEncoder\n\n\nclass SandwichModel(nn.Module):\n    def __init__(\n        self,\n        fea_dim,\n        cnn_dim_1,\n        cnn_dim_2,\n        cnn_kernel_size,\n        rnn_dim_1,\n        rnn_dim_2,\n        rnn_dups,\n        rnn_layers,\n        dropout,\n        device,\n        **params,\n    ):\n        \"\"\"Build a Sandwich model\n\n        Parameters\n        ----------\n        fea_dim : int\n            The feature dimension\n        cnn_dim_1 : int\n            The hidden dimension of the first CNN\n        cnn_dim_2 : int\n            The hidden dimension of the second CNN\n        cnn_kernel_size : int\n            The size of convolutional kernels\n        rnn_dim_1 : int\n            The hidden dimension of the first KRNN\n        rnn_dim_2 : int\n            The hidden dimension of the second KRNN\n        rnn_dups : int\n            The number of parallel duplicates\n        rnn_layers: int\n            The number of RNN layers\n        \"\"\"\n        super().__init__()\n\n        self.first_encoder = CNNKRNNEncoder(\n            cnn_input_dim=fea_dim,\n            cnn_output_dim=cnn_dim_1,\n            cnn_kernel_size=cnn_kernel_size,\n            rnn_output_dim=rnn_dim_1,\n            rnn_dup_num=rnn_dups,\n            rnn_layers=rnn_layers,\n            dropout=dropout,\n            device=device,\n        )\n\n        self.second_encoder = CNNKRNNEncoder(\n            
cnn_input_dim=rnn_dim_1,\n            cnn_output_dim=cnn_dim_2,\n            cnn_kernel_size=cnn_kernel_size,\n            rnn_output_dim=rnn_dim_2,\n            rnn_dup_num=rnn_dups,\n            rnn_layers=rnn_layers,\n            dropout=dropout,\n            device=device,\n        )\n\n        self.out_fc = nn.Linear(rnn_dim_2, 1)\n        self.device = device\n\n    def forward(self, x):\n        # x: [batch_size, node_num, seq_len, input_dim]\n        encode = self.first_encoder(x)\n        encode = self.second_encoder(encode)\n        out = self.out_fc(encode[:, -1, :]).squeeze().to(self.device)\n\n        return out\n\n\nclass Sandwich(Model):\n    \"\"\"Sandwich Model\n\n    Parameters\n    ----------\n    fea_dim : int\n        input dimension for each time step\n    metric : str\n        the evaluation metric used in early stop\n    optimizer : str\n        optimizer name\n    GPU : int\n        the GPU ID used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        fea_dim=6,\n        cnn_dim_1=64,\n        cnn_dim_2=32,\n        cnn_kernel_size=3,\n        rnn_dim_1=16,\n        rnn_dim_2=8,\n        rnn_dups=3,\n        rnn_layers=2,\n        dropout=0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        batch_size=2000,\n        early_stop=20,\n        loss=\"mse\",\n        optimizer=\"adam\",\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"Sandwich\")\n        self.logger.info(\"Sandwich pytorch version...\")\n\n        # set hyper-parameters.\n        self.fea_dim = fea_dim\n        self.cnn_dim_1 = cnn_dim_1\n        self.cnn_dim_2 = cnn_dim_2\n        self.cnn_kernel_size = cnn_kernel_size\n        self.rnn_dim_1 = rnn_dim_1\n        self.rnn_dim_2 = rnn_dim_2\n        self.rnn_dups = rnn_dups\n        self.rnn_layers = rnn_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        
self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.seed = seed\n\n        self.logger.info(\n            \"Sandwich parameters setting:\"\n            \"\\nfea_dim : {}\"\n            \"\\ncnn_dim_1 : {}\"\n            \"\\ncnn_dim_2 : {}\"\n            \"\\ncnn_kernel_size : {}\"\n            \"\\nrnn_dim_1 : {}\"\n            \"\\nrnn_dim_2 : {}\"\n            \"\\nrnn_dups : {}\"\n            \"\\nrnn_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size: {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\nvisible_GPU : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                fea_dim,\n                cnn_dim_1,\n                cnn_dim_2,\n                cnn_kernel_size,\n                rnn_dim_1,\n                rnn_dim_2,\n                rnn_dups,\n                rnn_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                GPU,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.sandwich_model = SandwichModel(\n            fea_dim=self.fea_dim,\n            cnn_dim_1=self.cnn_dim_1,\n            cnn_dim_2=self.cnn_dim_2,\n            cnn_kernel_size=self.cnn_kernel_size,\n            rnn_dim_1=self.rnn_dim_1,\n            rnn_dim_2=self.rnn_dim_2,\n     
       rnn_dups=self.rnn_dups,\n            rnn_layers=self.rnn_layers,\n            dropout=self.dropout,\n            device=self.device,\n        )\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.sandwich_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.sandwich_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.sandwich_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label):\n        loss = (pred - label) ** 2\n        return torch.mean(loss)\n\n    def loss_fn(self, pred, label):\n        mask = ~torch.isnan(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def train_epoch(self, x_train, y_train):\n        x_train_values = x_train.values\n        y_train_values = np.squeeze(y_train.values)\n        self.sandwich_model.train()\n\n        indices = np.arange(len(x_train_values))\n        np.random.shuffle(indices)\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            pred = self.sandwich_model(feature)\n            loss = 
self.loss_fn(pred, label)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.sandwich_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_x, data_y):\n        # prepare training data\n        x_values = data_x.values\n        y_values = np.squeeze(data_y.values)\n\n        self.sandwich_model.eval()\n\n        scores = []\n        losses = []\n\n        indices = np.arange(len(x_values))\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            pred = self.sandwich_model(feature)\n            loss = self.loss_fn(pred, label)\n            losses.append(loss.item())\n\n            score = self.metric_fn(pred, label)\n            scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        df_train, df_valid, df_test = dataset.prepare(\n            [\"train\", \"valid\", \"test\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        if df_train.empty or df_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n\n        save_path = get_or_create_path(save_path)\n        stop_steps = 0\n        train_loss = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # 
train\n        self.logger.info(\"training...\")\n        self.fitted = True\n        # initialize so that load_state_dict below never sees an undefined name (e.g. if every val_score is NaN)\n        best_param = copy.deepcopy(self.sandwich_model.state_dict())\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            self.train_epoch(x_train, y_train)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(x_train, y_train)\n            val_loss, val_score = self.test_epoch(x_valid, y_valid)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.sandwich_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.sandwich_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        index = x_test.index\n        self.sandwich_model.eval()\n        x_values = x_test.values\n        sample_num = x_values.shape[0]\n        preds = []\n\n        for begin in range(sample_num)[:: self.batch_size]:\n            if sample_num - begin < self.batch_size:\n                end = sample_num\n            else:\n                end = begin + self.batch_size\n            x_batch = 
torch.from_numpy(x_values[begin:end]).float().to(self.device)\n            with torch.no_grad():\n                pred = self.sandwich_model(x_batch).detach().cpu().numpy()\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=index)\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_sfm.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\nimport torch.optim as optim\n\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\n\n\nclass SFM_Model(nn.Module):\n    def __init__(\n        self,\n        d_feat=6,\n        output_dim=1,\n        freq_dim=10,\n        hidden_size=64,\n        dropout_W=0.0,\n        dropout_U=0.0,\n        device=\"cpu\",\n    ):\n        super().__init__()\n\n        self.input_dim = d_feat\n        self.output_dim = output_dim\n        self.freq_dim = freq_dim\n        self.hidden_dim = hidden_size\n        self.device = device\n\n        self.W_i = nn.Parameter(init.xavier_uniform_(torch.empty((self.input_dim, self.hidden_dim))))\n        self.U_i = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)))\n        self.b_i = nn.Parameter(torch.zeros(self.hidden_dim))\n\n        self.W_ste = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim)))\n        self.U_ste = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)))\n        self.b_ste = nn.Parameter(torch.ones(self.hidden_dim))\n\n        self.W_fre = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.freq_dim)))\n        self.U_fre = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.freq_dim)))\n        self.b_fre = nn.Parameter(torch.ones(self.freq_dim))\n\n        self.W_c = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim)))\n        self.U_c = 
nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)))\n        self.b_c = nn.Parameter(torch.zeros(self.hidden_dim))\n\n        self.W_o = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim)))\n        self.U_o = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)))\n        self.b_o = nn.Parameter(torch.zeros(self.hidden_dim))\n\n        self.U_a = nn.Parameter(init.orthogonal_(torch.empty(self.freq_dim, 1)))\n        self.b_a = nn.Parameter(torch.zeros(self.hidden_dim))\n\n        self.W_p = nn.Parameter(init.xavier_uniform_(torch.empty(self.hidden_dim, self.output_dim)))\n        self.b_p = nn.Parameter(torch.zeros(self.output_dim))\n\n        self.activation = nn.Tanh()\n        self.inner_activation = nn.Hardsigmoid()\n        self.dropout_W, self.dropout_U = (dropout_W, dropout_U)\n        self.fc_out = nn.Linear(self.output_dim, 1)\n\n        self.states = []\n\n    def forward(self, input):\n        input = input.reshape(len(input), self.input_dim, -1)  # [N, F, T]\n        input = input.permute(0, 2, 1)  # [N, T, F]\n        time_step = input.shape[1]\n\n        for ts in range(time_step):\n            x = input[:, ts, :]\n            if len(self.states) == 0:  # hasn't initialized yet\n                self.init_states(x)\n            self.get_constants(x)\n            p_tm1 = self.states[0]  # noqa: F841\n            h_tm1 = self.states[1]\n            S_re_tm1 = self.states[2]\n            S_im_tm1 = self.states[3]\n            time_tm1 = self.states[4]\n            B_U = self.states[5]\n            B_W = self.states[6]\n            frequency = self.states[7]\n\n            x_i = torch.matmul(x * B_W[0], self.W_i) + self.b_i\n            x_ste = torch.matmul(x * B_W[0], self.W_ste) + self.b_ste\n            x_fre = torch.matmul(x * B_W[0], self.W_fre) + self.b_fre\n            x_c = torch.matmul(x * B_W[0], self.W_c) + self.b_c\n            x_o = torch.matmul(x * B_W[0], self.W_o) 
+ self.b_o\n\n            i = self.inner_activation(x_i + torch.matmul(h_tm1 * B_U[0], self.U_i))\n            ste = self.inner_activation(x_ste + torch.matmul(h_tm1 * B_U[0], self.U_ste))\n            fre = self.inner_activation(x_fre + torch.matmul(h_tm1 * B_U[0], self.U_fre))\n\n            ste = torch.reshape(ste, (-1, self.hidden_dim, 1))\n            fre = torch.reshape(fre, (-1, 1, self.freq_dim))\n\n            f = ste * fre\n\n            c = i * self.activation(x_c + torch.matmul(h_tm1 * B_U[0], self.U_c))\n\n            time = time_tm1 + 1\n\n            omega = torch.tensor(2 * np.pi) * time * frequency\n\n            re = torch.cos(omega)\n            im = torch.sin(omega)\n\n            c = torch.reshape(c, (-1, self.hidden_dim, 1))\n\n            S_re = f * S_re_tm1 + c * re\n            S_im = f * S_im_tm1 + c * im\n\n            A = torch.square(S_re) + torch.square(S_im)\n\n            A = torch.reshape(A, (-1, self.freq_dim)).float()\n            A_a = torch.matmul(A * B_U[0], self.U_a)\n            A_a = torch.reshape(A_a, (-1, self.hidden_dim))\n            a = self.activation(A_a + self.b_a)\n\n            o = self.inner_activation(x_o + torch.matmul(h_tm1 * B_U[0], self.U_o))\n\n            h = o * a\n            p = torch.matmul(h, self.W_p) + self.b_p\n\n            self.states = [p, h, S_re, S_im, time, None, None, None]\n        self.states = []\n        return self.fc_out(p).squeeze()\n\n    def init_states(self, x):\n        reducer_f = torch.zeros((self.hidden_dim, self.freq_dim)).to(self.device)\n        reducer_p = torch.zeros((self.hidden_dim, self.output_dim)).to(self.device)\n\n        init_state_h = torch.zeros(self.hidden_dim).to(self.device)\n        init_state_p = torch.matmul(init_state_h, reducer_p)\n\n        init_state = torch.zeros_like(init_state_h).to(self.device)\n        init_freq = torch.matmul(init_state_h, reducer_f)\n\n        init_state = torch.reshape(init_state, (-1, self.hidden_dim, 1))\n        init_freq = 
torch.reshape(init_freq, (-1, 1, self.freq_dim))\n\n        init_state_S_re = init_state * init_freq\n        init_state_S_im = init_state * init_freq\n\n        init_state_time = torch.tensor(0).to(self.device)\n\n        self.states = [\n            init_state_p,\n            init_state_h,\n            init_state_S_re,\n            init_state_S_im,\n            init_state_time,\n            None,\n            None,\n            None,\n        ]\n\n    def get_constants(self, x):\n        constants = []\n        constants.append([torch.tensor(1.0).to(self.device) for _ in range(6)])\n        constants.append([torch.tensor(1.0).to(self.device) for _ in range(7)])\n        array = np.array([float(ii) / self.freq_dim for ii in range(self.freq_dim)])\n        constants.append(torch.tensor(array).to(self.device))\n\n        self.states[5:] = constants\n\n\nclass SFM(Model):\n    \"\"\"SFM Model\n\n    Parameters\n    ----------\n    input_dim : int\n        input dimension\n    output_dim : int\n        output dimension\n    lr : float\n        learning rate\n    optimizer : str\n        optimizer name\n    GPU : int\n        the GPU ID used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        output_dim=1,\n        freq_dim=10,\n        dropout_W=0.0,\n        dropout_U=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        batch_size=2000,\n        early_stop=20,\n        eval_steps=5,\n        loss=\"mse\",\n        optimizer=\"gd\",\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"SFM\")\n        self.logger.info(\"SFM pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.output_dim = output_dim\n        self.freq_dim = freq_dim\n        self.dropout_W = dropout_W\n        self.dropout_U = dropout_U\n        
self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.eval_steps = eval_steps\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.seed = seed\n\n        self.logger.info(\n            \"SFM parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\noutput_size : {}\"\n            \"\\nfrequency_dimension : {}\"\n            \"\\ndropout_W: {}\"\n            \"\\ndropout_U: {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\neval_steps : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\ndevice : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                output_dim,\n                freq_dim,\n                dropout_W,\n                dropout_U,\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                eval_steps,\n                optimizer.lower(),\n                loss,\n                self.device,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.sfm_model = SFM_Model(\n            d_feat=self.d_feat,\n            output_dim=self.output_dim,\n            hidden_size=self.hidden_size,\n            freq_dim=self.freq_dim,\n            dropout_W=self.dropout_W,\n            dropout_U=self.dropout_U,\n            device=self.device,\n        )\n       
 self.logger.info(\"model:\\n{:}\".format(self.sfm_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.sfm_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.sfm_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.sfm_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n        self.sfm_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def test_epoch(self, data_x, data_y):\n        # prepare training data\n        x_values = data_x.values\n        y_values = np.squeeze(data_y.values)\n\n        self.sfm_model.eval()\n\n        scores = []\n        losses = []\n\n        indices = np.arange(len(x_values))\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            pred = self.sfm_model(feature)\n            loss = self.loss_fn(pred, label)\n            losses.append(loss.item())\n\n            score = self.metric_fn(pred, label)\n            scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def train_epoch(self, x_train, y_train):\n        x_train_values = x_train.values\n        y_train_values = np.squeeze(y_train.values)\n\n        self.sfm_model.train()\n\n        indices = np.arange(len(x_train_values))\n        np.random.shuffle(indices)\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                
break\n\n            feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            pred = self.sfm_model(feature)\n            loss = self.loss_fn(pred, label)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.sfm_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        df_train, df_valid = dataset.prepare(\n            [\"train\", \"valid\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        if df_train.empty or df_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n\n        save_path = get_or_create_path(save_path)\n        stop_steps = 0\n        train_loss = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            self.train_epoch(x_train, y_train)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(x_train, y_train)\n            val_loss, val_score = self.test_epoch(x_valid, y_valid)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            
evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.sfm_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.sfm_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def mse(self, pred, label):\n        loss = (pred - label) ** 2\n        return torch.mean(loss)\n\n    def loss_fn(self, pred, label):\n        mask = ~torch.isnan(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        index = x_test.index\n        self.sfm_model.eval()\n        x_values = x_test.values\n        sample_num = x_values.shape[0]\n        preds = []\n\n        for begin in range(sample_num)[:: self.batch_size]:\n            if sample_num - begin < self.batch_size:\n                end = sample_num\n            else:\n                end = begin + self.batch_size\n\n            x_batch = 
torch.from_numpy(x_values[begin:end]).float().to(self.device)\n\n            with torch.no_grad():\n                pred = self.sfm_model(x_batch).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=index)\n\n\nclass AverageMeter:\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_tabnet.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.autograd import Function\n\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\n\n\nclass TabnetModel(Model):\n    def __init__(\n        self,\n        d_feat=158,\n        out_dim=64,\n        final_out_dim=1,\n        batch_size=4096,\n        n_d=64,\n        n_a=64,\n        n_shared=2,\n        n_ind=2,\n        n_steps=5,\n        n_epochs=100,\n        pretrain_n_epochs=50,\n        relax=1.3,\n        vbs=2048,\n        seed=993,\n        optimizer=\"adam\",\n        loss=\"mse\",\n        metric=\"\",\n        early_stop=20,\n        GPU=0,\n        pretrain_loss=\"custom\",\n        ps=0.3,\n        lr=0.01,\n        pretrain=True,\n        pretrain_file=None,\n    ):\n        \"\"\"\n        TabNet model for Qlib\n\n        Args:\n        ps: probability used to sample the Bernoulli feature mask during pretraining\n        \"\"\"\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.out_dim = out_dim\n        self.final_out_dim = final_out_dim\n        self.lr = lr\n        self.batch_size = batch_size\n        self.optimizer = optimizer.lower()\n        self.pretrain_loss = pretrain_loss\n        self.seed = seed\n        self.ps = ps\n        self.n_epochs = n_epochs\n        self.logger = get_module_logger(\"TabNet\")\n        self.pretrain_n_epochs = pretrain_n_epochs\n        self.device = torch.device(\"cuda:%s\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.loss = loss\n        self.metric 
= metric\n        self.early_stop = early_stop\n        self.pretrain = pretrain\n        self.pretrain_file = get_or_create_path(pretrain_file)\n        self.logger.info(\n            \"TabNet:\"\n            \"\\nbatch_size : {}\"\n            \"\\nvirtual bs : {}\"\n            \"\\ndevice : {}\"\n            \"\\npretrain: {}\".format(self.batch_size, vbs, self.device, self.pretrain)\n        )\n        self.fitted = False\n        np.random.seed(self.seed)\n        torch.manual_seed(self.seed)\n\n        self.tabnet_model = TabNet(\n            inp_dim=self.d_feat,\n            out_dim=self.out_dim,\n            n_d=n_d,\n            n_a=n_a,\n            n_shared=n_shared,\n            n_ind=n_ind,\n            n_steps=n_steps,\n            relax=relax,\n            vbs=vbs,\n        ).to(self.device)\n        self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps).to(self.device)\n        self.logger.info(\"model:\\n{:}\\n{:}\".format(self.tabnet_model, self.tabnet_decoder))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters([self.tabnet_model, self.tabnet_decoder])))\n\n        if optimizer.lower() == \"adam\":\n            self.pretrain_optimizer = optim.Adam(\n                list(self.tabnet_model.parameters()) + list(self.tabnet_decoder.parameters()), lr=self.lr\n            )\n            self.train_optimizer = optim.Adam(self.tabnet_model.parameters(), lr=self.lr)\n\n        elif optimizer.lower() == \"gd\":\n            self.pretrain_optimizer = optim.SGD(\n                list(self.tabnet_model.parameters()) + list(self.tabnet_decoder.parameters()), lr=self.lr\n            )\n            self.train_optimizer = optim.SGD(self.tabnet_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def pretrain_fn(self, dataset: DatasetH, pretrain_file=\"./pretrain/best.model\"):\n        get_or_create_path(pretrain_file)\n\n        [df_train, df_valid] = dataset.prepare(\n            [\"pretrain\", 
\"pretrain_validation\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n\n        df_train.fillna(df_train.mean(), inplace=True)\n        df_valid.fillna(df_valid.mean(), inplace=True)\n\n        x_train = df_train[\"feature\"]\n        x_valid = df_valid[\"feature\"]\n\n        # Early stop setup\n        stop_steps = 0\n        train_loss = 0\n        best_loss = np.inf\n\n        for epoch_idx in range(self.pretrain_n_epochs):\n            self.logger.info(\"epoch: %s\" % (epoch_idx))\n            self.logger.info(\"pre-training...\")\n            self.pretrain_epoch(x_train)\n            self.logger.info(\"evaluating...\")\n            train_loss = self.pretrain_test_epoch(x_train)\n            valid_loss = self.pretrain_test_epoch(x_valid)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_loss, valid_loss))\n\n            if valid_loss < best_loss:\n                self.logger.info(\"Save Model...\")\n                torch.save(self.tabnet_model.state_dict(), pretrain_file)\n                best_loss = valid_loss\n                stop_steps = 0\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        if self.pretrain:\n            # pretrain the encoder, then load the best pretrained weights\n            self.logger.info(\"Pretrain...\")\n            self.pretrain_fn(dataset, self.pretrain_file)\n            self.logger.info(\"Load Pretrain model\")\n            self.tabnet_model.load_state_dict(torch.load(self.pretrain_file, map_location=self.device))\n\n        # adding one more linear layer to fit the final output dimension\n        self.tabnet_model = FinetuneModel(self.out_dim, self.final_out_dim, self.tabnet_model).to(self.device)\n        df_train, df_valid = 
dataset.prepare(\n            [\"train\", \"valid\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        if df_train.empty or df_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n        df_train.fillna(df_train.mean(), inplace=True)\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n        save_path = get_or_create_path(save_path)\n\n        stop_steps = 0\n        train_loss = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for epoch_idx in range(self.n_epochs):\n            self.logger.info(\"epoch: %s\" % (epoch_idx))\n            self.logger.info(\"training...\")\n            self.train_epoch(x_train, y_train)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(x_train, y_train)\n            valid_loss, val_score = self.test_epoch(x_valid, y_valid)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = epoch_idx\n                best_param = copy.deepcopy(self.tabnet_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.tabnet_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if 
self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        index = x_test.index\n        self.tabnet_model.eval()\n        x_values = torch.from_numpy(x_test.values)\n        x_values[torch.isnan(x_values)] = 0\n        sample_num = x_values.shape[0]\n        preds = []\n\n        for begin in range(sample_num)[:: self.batch_size]:\n            if sample_num - begin < self.batch_size:\n                end = sample_num\n            else:\n                end = begin + self.batch_size\n\n            x_batch = x_values[begin:end].float().to(self.device)\n            priors = torch.ones(end - begin, self.d_feat).to(self.device)\n\n            with torch.no_grad():\n                pred = self.tabnet_model(x_batch, priors).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=index)\n\n    def test_epoch(self, data_x, data_y):\n        # prepare training data\n        x_values = torch.from_numpy(data_x.values)\n        y_values = torch.from_numpy(np.squeeze(data_y.values))\n        x_values[torch.isnan(x_values)] = 0\n        y_values[torch.isnan(y_values)] = 0\n        self.tabnet_model.eval()\n\n        scores = []\n        losses = []\n\n        indices = np.arange(len(x_values))\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n            feature = x_values[indices[i : i + self.batch_size]].float().to(self.device)\n            label = y_values[indices[i : i + self.batch_size]].float().to(self.device)\n            priors = torch.ones(self.batch_size, self.d_feat).to(self.device)\n            with torch.no_grad():\n                pred = 
self.tabnet_model(feature, priors)\n                loss = self.loss_fn(pred, label)\n                losses.append(loss.item())\n\n                score = self.metric_fn(pred, label)\n                scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def train_epoch(self, x_train, y_train):\n        x_train_values = torch.from_numpy(x_train.values)\n        y_train_values = torch.from_numpy(np.squeeze(y_train.values))\n        x_train_values[torch.isnan(x_train_values)] = 0\n        y_train_values[torch.isnan(y_train_values)] = 0\n        self.tabnet_model.train()\n\n        indices = np.arange(len(x_train_values))\n        np.random.shuffle(indices)\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = x_train_values[indices[i : i + self.batch_size]].float().to(self.device)\n            label = y_train_values[indices[i : i + self.batch_size]].float().to(self.device)\n            priors = torch.ones(self.batch_size, self.d_feat).to(self.device)\n            pred = self.tabnet_model(feature, priors)\n            loss = self.loss_fn(pred, label)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.tabnet_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def pretrain_epoch(self, x_train):\n        train_set = torch.from_numpy(x_train.values)\n        train_set[torch.isnan(train_set)] = 0\n        indices = np.arange(len(train_set))\n        np.random.shuffle(indices)\n\n        self.tabnet_model.train()\n        self.tabnet_decoder.train()\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            S_mask = torch.bernoulli(torch.empty(self.batch_size, self.d_feat).fill_(self.ps))\n            x_train_values = train_set[indices[i : i + 
self.batch_size]] * (1 - S_mask)\n            y_train_values = train_set[indices[i : i + self.batch_size]] * (S_mask)\n\n            S_mask = S_mask.to(self.device)\n            feature = x_train_values.float().to(self.device)\n            label = y_train_values.float().to(self.device)\n            priors = 1 - S_mask\n            vec, sparse_loss = self.tabnet_model(feature, priors)\n            f = self.tabnet_decoder(vec)\n            loss = self.pretrain_loss_fn(label, f, S_mask)\n\n            self.pretrain_optimizer.zero_grad()\n            loss.backward()\n            self.pretrain_optimizer.step()\n\n    def pretrain_test_epoch(self, x_train):\n        train_set = torch.from_numpy(x_train.values)\n        train_set[torch.isnan(train_set)] = 0\n        indices = np.arange(len(train_set))\n\n        self.tabnet_model.eval()\n        self.tabnet_decoder.eval()\n\n        losses = []\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            S_mask = torch.bernoulli(torch.empty(self.batch_size, self.d_feat).fill_(self.ps))\n            x_train_values = train_set[indices[i : i + self.batch_size]] * (1 - S_mask)\n            y_train_values = train_set[indices[i : i + self.batch_size]] * (S_mask)\n\n            feature = x_train_values.float().to(self.device)\n            label = y_train_values.float().to(self.device)\n            S_mask = S_mask.to(self.device)\n            priors = 1 - S_mask\n            with torch.no_grad():\n                vec, sparse_loss = self.tabnet_model(feature, priors)\n                f = self.tabnet_decoder(vec)\n\n                loss = self.pretrain_loss_fn(label, f, S_mask)\n            losses.append(loss.item())\n\n        return np.mean(losses)\n\n    def pretrain_loss_fn(self, f_hat, f, S):\n        \"\"\"\n        Pretrain loss function defined in the original paper, read \"Tabular self-supervised learning\" in 
https://arxiv.org/pdf/1908.07442.pdf\n        \"\"\"\n        down_mean = torch.mean(f, dim=0)\n        down = torch.sqrt(torch.sum(torch.square(f - down_mean), dim=0))\n        up = (f_hat - f) * S\n        return torch.sum(torch.square(up / down))\n\n    def loss_fn(self, pred, label):\n        mask = ~torch.isnan(label)\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask])\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def mse(self, pred, label):\n        loss = (pred - label) ** 2\n        return torch.mean(loss)\n\n\nclass FinetuneModel(nn.Module):\n    \"\"\"\n    FinetuneModel: wraps a pretrained encoder and appends a linear output layer\n    \"\"\"\n\n    def __init__(self, input_dim, output_dim, trained_model):\n        super().__init__()\n        self.model = trained_model\n        self.fc = nn.Linear(input_dim, output_dim)\n\n    def forward(self, x, priors):\n        return self.fc(self.model(x, priors)[0]).squeeze()  # take the vec out\n\n\nclass DecoderStep(nn.Module):\n    def __init__(self, inp_dim, out_dim, shared, n_ind, vbs):\n        super().__init__()\n        self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs)\n        self.fc = nn.Linear(out_dim, out_dim)\n\n    def forward(self, x):\n        x = self.fea_tran(x)\n        return self.fc(x)\n\n\nclass TabNet_Decoder(nn.Module):\n    def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps):\n        \"\"\"\n        TabNet decoder that is used in pre-training\n        \"\"\"\n        super().__init__()\n        self.out_dim = out_dim\n        if n_shared > 0:\n            self.shared = nn.ModuleList()\n            self.shared.append(nn.Linear(inp_dim, 2 * out_dim))\n            for x in range(n_shared - 1):\n                self.shared.append(nn.Linear(out_dim, 2 * out_dim))  # preset the linear function we will use\n        else:\n            self.shared = None\n        self.n_steps = n_steps\n        self.steps = nn.ModuleList()\n        for x in range(n_steps):\n            self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs))\n\n    def forward(self, x):\n        out = torch.zeros(x.size(0), self.out_dim).to(x.device)\n        for step in self.steps:\n            out += step(x)\n        return out\n\n\nclass TabNet(nn.Module):\n    def __init__(self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024):\n        \"\"\"\n        TabNet AKA the original encoder\n\n        Args:\n            n_d: dimension of the features used to calculate the final results\n            n_a: dimension of the features input to the attention transformer of the next step\n            n_shared: number of shared steps in the feature transformer (optional)\n            n_ind: number of independent steps in the feature transformer\n            n_steps: number of decision steps in TabNet\n            relax: relaxation coefficient controlling how often a feature can be reused across steps\n            vbs: virtual batch size used by Ghost Batch Normalization\n        \"\"\"\n        super().__init__()\n\n        # build the linear layers shared by every feature transformer\n        if n_shared > 0:\n            self.shared = nn.ModuleList()\n            self.shared.append(nn.Linear(inp_dim, 2 * (n_d + n_a)))\n            for x in range(n_shared - 1):\n                self.shared.append(nn.Linear(n_d + n_a, 2 * (n_d + n_a)))  # preset the linear function we will use\n        else:\n            self.shared = None\n\n        self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs)\n        self.steps = nn.ModuleList()\n        for x in range(n_steps - 1):\n            self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs))\n        self.fc = nn.Linear(n_d, out_dim)\n        self.bn = 
nn.BatchNorm1d(inp_dim, momentum=0.01)\n        self.n_d = n_d\n\n    def forward(self, x, priors):\n        assert not torch.isnan(x).any()\n        x = self.bn(x)\n        x_a = self.first_step(x)[:, self.n_d :]\n        sparse_loss = []\n        out = torch.zeros(x.size(0), self.n_d).to(x.device)\n        for step in self.steps:\n            x_te, loss = step(x, x_a, priors)\n            out += F.relu(x_te[:, : self.n_d])  # split the feature from feat_transformer\n            x_a = x_te[:, self.n_d :]\n            sparse_loss.append(loss)\n        return self.fc(out), sum(sparse_loss)\n\n\nclass GBN(nn.Module):\n    \"\"\"\n    Ghost Batch Normalization: applies batch normalization separately to\n    virtual batches (chunks of roughly `vbs` samples) of the input\n\n    Args:\n        vbs: virtual batch size\n    \"\"\"\n\n    def __init__(self, inp, vbs=1024, momentum=0.01):\n        super().__init__()\n        self.bn = nn.BatchNorm1d(inp, momentum=momentum)\n        self.vbs = vbs\n\n    def forward(self, x):\n        if x.size(0) <= self.vbs:  # can not be chunked\n            return self.bn(x)\n        else:\n            chunk = torch.chunk(x, x.size(0) // self.vbs, 0)\n            res = [self.bn(y) for y in chunk]\n            return torch.cat(res, 0)\n\n\nclass GLU(nn.Module):\n    \"\"\"\n    GLU block: a linear layer, ghost batch normalization, then a gated linear unit\n\n    Args:\n        vbs: virtual batch size\n    \"\"\"\n\n    def __init__(self, inp_dim, out_dim, fc=None, vbs=1024):\n        super().__init__()\n        if fc:\n            self.fc = fc\n        else:\n            self.fc = nn.Linear(inp_dim, out_dim * 2)\n        self.bn = GBN(out_dim * 2, vbs=vbs)\n        self.od = out_dim\n\n    def forward(self, x):\n        x = self.bn(self.fc(x))\n        return torch.mul(x[:, : self.od], torch.sigmoid(x[:, self.od :]))\n\n\nclass AttentionTransformer(nn.Module):\n    \"\"\"\n    Args:\n        relax: relaxation coefficient. The larger it is, the more often the same\n        feature can be reused across decision steps; when it is set to 1, each\n        feature can be used only once\n    \"\"\"\n\n    def __init__(self, d_a, inp_dim, relax, vbs=1024):\n        super().__init__()\n        self.fc = nn.Linear(d_a, inp_dim)\n        self.bn = GBN(inp_dim, vbs=vbs)\n        self.r = relax\n\n    # a: attention features from the previous decision step\n    def forward(self, a, priors):\n        a = self.bn(self.fc(a))\n        mask = SparsemaxFunction.apply(a * priors)\n        # NOTE: this only rebinds the local name; the caller's priors tensor is not updated\n        priors = priors * (self.r - mask)  # updating the prior\n        return mask\n\n\nclass FeatureTransformer(nn.Module):\n    def __init__(self, inp_dim, out_dim, shared, n_ind, vbs):\n        super().__init__()\n        first = True\n        self.shared = nn.ModuleList()\n        if shared:\n            self.shared.append(GLU(inp_dim, out_dim, shared[0], vbs=vbs))\n            first = False\n            for fc in shared[1:]:\n                self.shared.append(GLU(out_dim, out_dim, fc, vbs=vbs))\n        else:\n            self.shared = None\n        self.independ = nn.ModuleList()\n        if first:\n            self.independ.append(GLU(inp_dim, out_dim, vbs=vbs))\n        for x in range(first, n_ind):\n            self.independ.append(GLU(out_dim, out_dim, vbs=vbs))\n        self.scale = float(np.sqrt(0.5))\n\n    def forward(self, x):\n        if self.shared:\n            x = self.shared[0](x)\n            for glu in self.shared[1:]:\n                x = torch.add(x, glu(x))\n                x = x * self.scale\n        for glu in self.independ:\n            x = torch.add(x, glu(x))\n            x = x * self.scale\n        return x\n\n\nclass DecisionStep(nn.Module):\n    \"\"\"\n    One decision step of TabNet\n    \"\"\"\n\n    def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs):\n        super().__init__()\n        self.atten_tran = AttentionTransformer(n_a, inp_dim, relax, vbs)\n        self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs)\n\n    def forward(self, x, a, priors):\n        mask 
= self.atten_tran(a, priors)\n        sparse_loss = ((-1) * mask * torch.log(mask + 1e-10)).mean()\n        x = self.fea_tran(x * mask)\n        return x, sparse_loss\n\n\ndef make_ix_like(input, dim=0):\n    d = input.size(dim)\n    rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype)\n    view = [1] * input.dim()\n    view[0] = -1\n    return rho.view(view).transpose(0, dim)\n\n\nclass SparsemaxFunction(Function):\n    \"\"\"\n    Sparsemax activation: a sparse alternative to softmax that can output exact zeros\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, input, dim=-1):\n        ctx.dim = dim\n        max_val, _ = input.max(dim=dim, keepdim=True)\n        input = input - max_val  # same numerical stability trick as for softmax; avoids mutating the caller's tensor\n        tau, supp_size = SparsemaxFunction.threshold_and_support(input, dim=dim)\n        output = torch.clamp(input - tau, min=0)\n        ctx.save_for_backward(supp_size, output)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        supp_size, output = ctx.saved_tensors\n        dim = ctx.dim\n        grad_input = grad_output.clone()\n        grad_input[output == 0] = 0\n\n        v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze()\n        v_hat = v_hat.unsqueeze(dim)\n        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)\n        return grad_input, None\n\n    @staticmethod\n    def threshold_and_support(input, dim=-1):\n        input_srt, _ = torch.sort(input, descending=True, dim=dim)\n        input_cumsum = input_srt.cumsum(dim) - 1\n        rhos = make_ix_like(input, dim)\n        support = rhos * input_srt > input_cumsum\n\n        support_size = support.sum(dim=dim).unsqueeze(dim)\n        tau = input_cumsum.gather(dim, support_size - 1)\n        tau /= support_size.to(input.dtype)\n        return tau, support_size\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_tcn.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nfrom typing import Text, Union\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom .tcn import TemporalConvNet\n\n\nclass TCN(Model):\n    \"\"\"TCN Model\n\n    Parameters\n    ----------\n    d_feat : int\n        input dimension for each time step\n    n_chans : int\n        number of channels\n    metric : str\n        the evaluation metric used for early stopping\n    optimizer : str\n        optimizer name\n    GPU : int\n        the GPU ID used for training; a negative value selects the CPU\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        n_chans=128,\n        kernel_size=5,\n        num_layers=5,\n        dropout=0.5,\n        n_epochs=200,\n        lr=0.0001,\n        metric=\"\",\n        batch_size=2000,\n        early_stop=20,\n        loss=\"mse\",\n        optimizer=\"adam\",\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"TCN\")\n        self.logger.info(\"TCN pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.n_chans = n_chans\n        self.kernel_size = kernel_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() 
and GPU >= 0 else \"cpu\")\n        self.seed = seed\n\n        self.logger.info(\n            \"TCN parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nn_chans : {}\"\n            \"\\nkernel_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\nvisible_GPU : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                n_chans,\n                kernel_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                GPU,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.tcn_model = TCNModel(\n            num_input=self.d_feat,\n            output_size=1,\n            num_channels=[self.n_chans] * self.num_layers,\n            kernel_size=self.kernel_size,\n            dropout=self.dropout,\n        )\n        self.logger.info(\"model:\\n{:}\".format(self.tcn_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.tcn_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.tcn_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.tcn_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\n\n        self.fitted = False\n 
       self.tcn_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label):\n        loss = (pred - label) ** 2\n        return torch.mean(loss)\n\n    def loss_fn(self, pred, label):\n        mask = ~torch.isnan(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def train_epoch(self, x_train, y_train):\n        x_train_values = x_train.values\n        y_train_values = np.squeeze(y_train.values)\n\n        self.tcn_model.train()\n\n        indices = np.arange(len(x_train_values))\n        np.random.shuffle(indices)\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            pred = self.tcn_model(feature)\n            loss = self.loss_fn(pred, label)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.tcn_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_x, data_y):\n        x_values = data_x.values\n        y_values = np.squeeze(data_y.values)\n\n        self.tcn_model.eval()\n\n        scores = []\n        losses = []\n\n        indices = np.arange(len(x_values))\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n     
           break\n\n            feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            with torch.no_grad():\n                pred = self.tcn_model(feature)\n                loss = self.loss_fn(pred, label)\n                losses.append(loss.item())\n\n                score = self.metric_fn(pred, label)\n                scores.append(score.item())\n\n        return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        df_train, df_valid, df_test = dataset.prepare(\n            [\"train\", \"valid\", \"test\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n\n        save_path = get_or_create_path(save_path)\n        stop_steps = 0\n        train_loss = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            self.train_epoch(x_train, y_train)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(x_train, y_train)\n            val_loss, val_score = self.test_epoch(x_valid, y_valid)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                
best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.tcn_model.state_dict())\n            else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.tcn_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        index = x_test.index\n        self.tcn_model.eval()\n        x_values = x_test.values\n        sample_num = x_values.shape[0]\n        preds = []\n\n        for begin in range(sample_num)[:: self.batch_size]:\n            if sample_num - begin < self.batch_size:\n                end = sample_num\n            else:\n                end = begin + self.batch_size\n\n            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)\n\n            with torch.no_grad():\n                pred = self.tcn_model(x_batch).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=index)\n\n\nclass TCNModel(nn.Module):\n    def __init__(self, num_input, output_size, num_channels, kernel_size, dropout):\n        super().__init__()\n        self.num_input = num_input\n        self.tcn = TemporalConvNet(num_input, num_channels, kernel_size, dropout=dropout)\n        self.linear = nn.Linear(num_channels[-1], output_size)\n\n    def forward(self, x):\n        x = x.reshape(x.shape[0], self.num_input, -1)\n        output = self.tcn(x)\n        output = 
self.linear(output[:, :, -1])\n        return output.squeeze(-1)  # (batch, 1) -> (batch,), keeping the batch dim even when batch == 1\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_tcn_ts.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nimport copy\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\n\nfrom .pytorch_utils import count_parameters\nfrom ...model.base import Model\nfrom ...data.dataset.handler import DataHandlerLP\nfrom .tcn import TemporalConvNet\n\n\nclass TCN(Model):\n    \"\"\"TCN Model\n\n    Parameters\n    ----------\n    d_feat : int\n        input dimension for each time step\n    metric : str\n        the evaluation metric used for early stopping\n    optimizer : str\n        optimizer name\n    GPU : int\n        the GPU ID used for training; a negative value selects the CPU\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        n_chans=128,\n        kernel_size=5,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        lr=0.001,\n        metric=\"\",\n        batch_size=2000,\n        early_stop=20,\n        loss=\"mse\",\n        optimizer=\"adam\",\n        n_jobs=10,\n        GPU=0,\n        seed=None,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"TCN\")\n        self.logger.info(\"TCN pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.n_chans = n_chans\n        self.kernel_size = kernel_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.lr = lr\n        self.metric = metric\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.optimizer = optimizer.lower()\n        self.loss = loss\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\n        self.n_jobs = n_jobs\n 
       self.seed = seed\n\n        self.logger.info(\n            \"TCN parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nn_chans : {}\"\n            \"\\nkernel_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nlr : {}\"\n            \"\\nmetric : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\noptimizer : {}\"\n            \"\\nloss_type : {}\"\n            \"\\ndevice : {}\"\n            \"\\nn_jobs : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                n_chans,\n                kernel_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                lr,\n                metric,\n                batch_size,\n                early_stop,\n                optimizer.lower(),\n                loss,\n                self.device,\n                n_jobs,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n            torch.manual_seed(self.seed)\n\n        self.TCN_model = TCNModel(\n            num_input=self.d_feat,\n            output_size=1,\n            num_channels=[self.n_chans] * self.num_layers,\n            kernel_size=self.kernel_size,\n            dropout=self.dropout,\n        )\n        self.logger.info(\"model:\\n{:}\".format(self.TCN_model))\n        self.logger.info(\"model size: {:.4f} MB\".format(count_parameters(self.TCN_model)))\n\n        if optimizer.lower() == \"adam\":\n            self.train_optimizer = optim.Adam(self.TCN_model.parameters(), lr=self.lr)\n        elif optimizer.lower() == \"gd\":\n            self.train_optimizer = optim.SGD(self.TCN_model.parameters(), lr=self.lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not 
supported!\".format(optimizer))\n\n        self.fitted = False\n        self.TCN_model.to(self.device)\n\n    @property\n    def use_gpu(self):\n        return self.device != torch.device(\"cpu\")\n\n    def mse(self, pred, label):\n        loss = (pred - label) ** 2\n        return torch.mean(loss)\n\n    def loss_fn(self, pred, label):\n        mask = ~torch.isnan(label)\n\n        if self.loss == \"mse\":\n            return self.mse(pred[mask], label[mask])\n\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\n\n    def metric_fn(self, pred, label):\n        mask = torch.isfinite(label)\n\n        if self.metric in (\"\", \"loss\"):\n            return -self.loss_fn(pred[mask], label[mask])\n\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\n\n    def train_epoch(self, data_loader):\n        self.TCN_model.train()\n\n        for data in data_loader:\n            data = torch.transpose(data, 1, 2)\n            feature = data[:, 0:-1, :].to(self.device)\n            label = data[:, -1, -1].to(self.device)\n\n            pred = self.TCN_model(feature.float())\n            loss = self.loss_fn(pred, label)\n\n            self.train_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.TCN_model.parameters(), 3.0)\n            self.train_optimizer.step()\n\n    def test_epoch(self, data_loader):\n        self.TCN_model.eval()\n\n        scores = []\n        losses = []\n\n        for data in data_loader:\n            data = torch.transpose(data, 1, 2)\n            feature = data[:, 0:-1, :].to(self.device)\n            # feature[torch.isnan(feature)] = 0\n            label = data[:, -1, -1].to(self.device)\n\n            with torch.no_grad():\n                pred = self.TCN_model(feature.float())\n                loss = self.loss_fn(pred, label)\n                losses.append(loss.item())\n\n                score = self.metric_fn(pred, label)\n                scores.append(score.item())\n\n     
   return np.mean(losses), np.mean(scores)\n\n    def fit(\n        self,\n        dataset,\n        evals_result=dict(),\n        save_path=None,\n    ):\n        dl_train = dataset.prepare(\"train\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n        dl_valid = dataset.prepare(\"valid\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n\n        # process nan brought by dataloader\n        dl_train.config(fillna_type=\"ffill+bfill\")\n        # process nan brought by dataloader\n        dl_valid.config(fillna_type=\"ffill+bfill\")\n\n        train_loader = DataLoader(\n            dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True\n        )\n        valid_loader = DataLoader(\n            dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True\n        )\n\n        save_path = get_or_create_path(save_path)\n\n        stop_steps = 0\n        train_loss = 0\n        best_score = -np.inf\n        best_epoch = 0\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n\n        # train\n        self.logger.info(\"training...\")\n        self.fitted = True\n\n        for step in range(self.n_epochs):\n            self.logger.info(\"Epoch%d:\", step)\n            self.logger.info(\"training...\")\n            self.train_epoch(train_loader)\n            self.logger.info(\"evaluating...\")\n            train_loss, train_score = self.test_epoch(train_loader)\n            val_loss, val_score = self.test_epoch(valid_loader)\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\n            evals_result[\"train\"].append(train_score)\n            evals_result[\"valid\"].append(val_score)\n\n            if val_score > best_score:\n                best_score = val_score\n                stop_steps = 0\n                best_epoch = step\n                best_param = copy.deepcopy(self.TCN_model.state_dict())\n      
      else:\n                stop_steps += 1\n                if stop_steps >= self.early_stop:\n                    self.logger.info(\"early stop\")\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.TCN_model.load_state_dict(best_param)\n        torch.save(best_param, save_path)\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n    def predict(self, dataset):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        dl_test = dataset.prepare(\"test\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_I)\n        dl_test.config(fillna_type=\"ffill+bfill\")\n        test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)\n        self.TCN_model.eval()\n        preds = []\n\n        for data in test_loader:\n            # (batch, step_len, features + label) -> (batch, features, step_len), matching train_epoch\n            feature = torch.transpose(data, 1, 2)[:, 0:-1, :].to(self.device)\n\n            with torch.no_grad():\n                pred = self.TCN_model(feature.float()).detach().cpu().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=dl_test.get_index())\n\n\nclass TCNModel(nn.Module):\n    def __init__(self, num_input, output_size, num_channels, kernel_size, dropout):\n        super().__init__()\n        self.num_input = num_input\n        self.tcn = TemporalConvNet(num_input, num_channels, kernel_size, dropout=dropout)\n        self.linear = nn.Linear(num_channels[-1], output_size)\n\n    def forward(self, x):\n        output = self.tcn(x)\n        output = self.linear(output[:, :, -1])\n        return output.squeeze(-1)  # (batch, 1) -> (batch,), keeping the batch dim even when batch == 1\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_tcts.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\nimport copy\nimport random\nfrom ...utils import get_or_create_path\nfrom ...log import get_module_logger\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\n\n\nclass TCTS(Model):\n    \"\"\"TCTS Model\n\n    Parameters\n    ----------\n    d_feat : int\n        input dimension for each time step\n    fore_optimizer : str\n        optimizer name for the forecasting model (\"adam\" or \"gd\")\n    weight_optimizer : str\n        optimizer name for the weight model (\"adam\" or \"gd\")\n    GPU : int\n        the GPU ID used for training\n    \"\"\"\n\n    def __init__(\n        self,\n        d_feat=6,\n        hidden_size=64,\n        num_layers=2,\n        dropout=0.0,\n        n_epochs=200,\n        batch_size=2000,\n        early_stop=20,\n        loss=\"mse\",\n        fore_optimizer=\"adam\",\n        weight_optimizer=\"adam\",\n        input_dim=360,\n        output_dim=5,\n        fore_lr=5e-7,\n        weight_lr=5e-7,\n        steps=3,\n        GPU=0,\n        target_label=0,\n        mode=\"soft\",\n        seed=None,\n        lowest_valid_performance=0.993,\n        **kwargs,\n    ):\n        # Set logger.\n        self.logger = get_module_logger(\"TCTS\")\n        self.logger.info(\"TCTS pytorch version...\")\n\n        # set hyper-parameters.\n        self.d_feat = d_feat\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.n_epochs = n_epochs\n        self.batch_size = batch_size\n        self.early_stop = early_stop\n        self.loss = loss\n        self.device = torch.device(\"cuda:%d\" % (GPU) if torch.cuda.is_available() else \"cpu\")\n        self.use_gpu = torch.cuda.is_available()\n        self.seed = 
seed\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n        self.fore_lr = fore_lr\n        self.weight_lr = weight_lr\n        self.steps = steps\n        self.target_label = target_label\n        self.mode = mode\n        self.lowest_valid_performance = lowest_valid_performance\n        self._fore_optimizer = fore_optimizer\n        self._weight_optimizer = weight_optimizer\n\n        self.logger.info(\n            \"TCTS parameters setting:\"\n            \"\\nd_feat : {}\"\n            \"\\nhidden_size : {}\"\n            \"\\nnum_layers : {}\"\n            \"\\ndropout : {}\"\n            \"\\nn_epochs : {}\"\n            \"\\nbatch_size : {}\"\n            \"\\nearly_stop : {}\"\n            \"\\ntarget_label : {}\"\n            \"\\nmode : {}\"\n            \"\\nloss_type : {}\"\n            \"\\nvisible_GPU : {}\"\n            \"\\nuse_GPU : {}\"\n            \"\\nseed : {}\".format(\n                d_feat,\n                hidden_size,\n                num_layers,\n                dropout,\n                n_epochs,\n                batch_size,\n                early_stop,\n                target_label,\n                mode,\n                loss,\n                GPU,\n                self.use_gpu,\n                seed,\n            )\n        )\n\n    def loss_fn(self, pred, label, weight):\n        if self.mode == \"hard\":\n            loc = torch.argmax(weight, 1)\n            loss = (pred - label[np.arange(weight.shape[0]), loc]) ** 2\n            return torch.mean(loss)\n\n        elif self.mode == \"soft\":\n            loss = (pred - label.transpose(0, 1)) ** 2\n            return torch.mean(loss * weight.transpose(0, 1))\n\n        else:\n            raise NotImplementedError(\"mode {} is not supported!\".format(self.mode))\n\n    def train_epoch(self, x_train, y_train, x_valid, y_valid):\n        x_train_values = x_train.values\n        y_train_values = np.squeeze(y_train.values)\n\n        indices = 
np.arange(len(x_train_values))\n        np.random.shuffle(indices)\n\n        task_embedding = torch.zeros([self.batch_size, self.output_dim])\n        task_embedding[:, self.target_label] = 1\n        task_embedding = task_embedding.to(self.device)\n\n        init_fore_model = copy.deepcopy(self.fore_model)\n        for p in init_fore_model.parameters():\n            p.requires_grad = False\n\n        self.fore_model.train()\n        self.weight_model.train()\n\n        for p in self.weight_model.parameters():\n            p.requires_grad = False\n        for p in self.fore_model.parameters():\n            p.requires_grad = True\n\n        for _ in range(self.steps):\n            for i in range(len(indices))[:: self.batch_size]:\n                if len(indices) - i < self.batch_size:\n                    break\n\n                feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n                label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n                init_pred = init_fore_model(feature)\n                pred = self.fore_model(feature)\n                dis = init_pred - label.transpose(0, 1)\n                weight_feature = torch.cat(\n                    (feature, dis.transpose(0, 1), label, init_pred.view(-1, 1), task_embedding), 1\n                )\n                weight = self.weight_model(weight_feature)\n\n                loss = self.loss_fn(pred, label, weight)\n\n                self.fore_optimizer.zero_grad()\n                loss.backward()\n                torch.nn.utils.clip_grad_value_(self.fore_model.parameters(), 3.0)\n                self.fore_optimizer.step()\n\n        x_valid_values = x_valid.values\n        y_valid_values = np.squeeze(y_valid.values)\n\n        indices = np.arange(len(x_valid_values))\n        np.random.shuffle(indices)\n        for p in self.weight_model.parameters():\n            p.requires_grad = True\n        
for p in self.fore_model.parameters():\n            p.requires_grad = False\n\n        # freeze the forecasting model and update the weight model on validation data\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = torch.from_numpy(x_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = torch.from_numpy(y_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            pred = self.fore_model(feature)\n            dis = pred - label.transpose(0, 1)\n            weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1), task_embedding), 1)\n            weight = self.weight_model(weight_feature)\n            loc = torch.argmax(weight, 1)\n            valid_loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2)\n            loss = torch.mean(valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]))\n\n            self.weight_optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_value_(self.weight_model.parameters(), 3.0)\n            self.weight_optimizer.step()\n\n    def test_epoch(self, data_x, data_y):\n        # prepare evaluation data\n        x_values = data_x.values\n        y_values = np.squeeze(data_y.values)\n\n        self.fore_model.eval()\n\n        losses = []\n\n        indices = np.arange(len(x_values))\n\n        for i in range(len(indices))[:: self.batch_size]:\n            if len(indices) - i < self.batch_size:\n                break\n\n            feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)\n            label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)\n\n            pred = self.fore_model(feature)\n            loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2)\n            losses.append(loss.item())\n\n        return 
np.mean(losses)\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        verbose=True,\n        save_path=None,\n    ):\n        df_train, df_valid, df_test = dataset.prepare(\n            [\"train\", \"valid\", \"test\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        if df_train.empty or df_valid.empty:\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\n\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n        x_test, y_test = df_test[\"feature\"], df_test[\"label\"]\n\n        if save_path is None:\n            save_path = get_or_create_path(save_path)\n        best_loss = np.inf\n        while best_loss > self.lowest_valid_performance:\n            if best_loss < np.inf:\n                print(\"Failed! Start retraining.\")\n                self.seed = random.randint(0, 1000)  # reset random seed\n\n            if self.seed is not None:\n                np.random.seed(self.seed)\n                torch.manual_seed(self.seed)\n\n            best_loss = self.training(\n                x_train, y_train, x_valid, y_valid, x_test, y_test, verbose=verbose, save_path=save_path\n            )\n\n    def training(\n        self,\n        x_train,\n        y_train,\n        x_valid,\n        y_valid,\n        x_test,\n        y_test,\n        verbose=True,\n        save_path=None,\n    ):\n        self.fore_model = GRUModel(\n            d_feat=self.d_feat,\n            hidden_size=self.hidden_size,\n            num_layers=self.num_layers,\n            dropout=self.dropout,\n        )\n        self.weight_model = MLPModel(\n            d_feat=self.input_dim + 3 * self.output_dim + 1,\n            hidden_size=self.hidden_size,\n            num_layers=self.num_layers,\n            dropout=self.dropout,\n            output_dim=self.output_dim,\n        )\n        if 
self._fore_optimizer.lower() == \"adam\":\n            self.fore_optimizer = optim.Adam(self.fore_model.parameters(), lr=self.fore_lr)\n        elif self._fore_optimizer.lower() == \"gd\":\n            self.fore_optimizer = optim.SGD(self.fore_model.parameters(), lr=self.fore_lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(self._fore_optimizer))\n        if self._weight_optimizer.lower() == \"adam\":\n            self.weight_optimizer = optim.Adam(self.weight_model.parameters(), lr=self.weight_lr)\n        elif self._weight_optimizer.lower() == \"gd\":\n            self.weight_optimizer = optim.SGD(self.weight_model.parameters(), lr=self.weight_lr)\n        else:\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(self._weight_optimizer))\n\n        self.fitted = False\n        self.fore_model.to(self.device)\n        self.weight_model.to(self.device)\n\n        best_loss = np.inf\n        best_epoch = 0\n        stop_round = 0\n\n        for epoch in range(self.n_epochs):\n            print(\"Epoch:\", epoch)\n\n            print(\"training...\")\n            self.train_epoch(x_train, y_train, x_valid, y_valid)\n            print(\"evaluating...\")\n            val_loss = self.test_epoch(x_valid, y_valid)\n            test_loss = self.test_epoch(x_test, y_test)\n\n            if verbose:\n                print(\"valid %.6f, test %.6f\" % (val_loss, test_loss))\n\n            if val_loss < best_loss:\n                best_loss = val_loss\n                stop_round = 0\n                best_epoch = epoch\n                torch.save(copy.deepcopy(self.fore_model.state_dict()), save_path + \"_fore_model.bin\")\n                torch.save(copy.deepcopy(self.weight_model.state_dict()), save_path + \"_weight_model.bin\")\n\n            else:\n                stop_round += 1\n                if stop_round >= self.early_stop:\n                    print(\"early stop\")\n                    
break\n\n        print(\"best loss:\", best_loss, \"@\", best_epoch)\n        best_param = torch.load(save_path + \"_fore_model.bin\", map_location=self.device)\n        self.fore_model.load_state_dict(best_param)\n        best_param = torch.load(save_path + \"_weight_model.bin\", map_location=self.device)\n        self.weight_model.load_state_dict(best_param)\n        self.fitted = True\n\n        if self.use_gpu:\n            torch.cuda.empty_cache()\n\n        return best_loss\n\n    def predict(self, dataset):\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        x_test = dataset.prepare(\"test\", col_set=\"feature\")\n        index = x_test.index\n        self.fore_model.eval()\n        x_values = x_test.values\n        sample_num = x_values.shape[0]\n        preds = []\n\n        for begin in range(sample_num)[:: self.batch_size]:\n            if sample_num - begin < self.batch_size:\n                end = sample_num\n            else:\n                end = begin + self.batch_size\n\n            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)\n\n            with torch.no_grad():\n                if self.use_gpu:\n                    pred = self.fore_model(x_batch).detach().cpu().numpy()\n                else:\n                    pred = self.fore_model(x_batch).detach().numpy()\n\n            preds.append(pred)\n\n        return pd.Series(np.concatenate(preds), index=index)\n\n\nclass MLPModel(nn.Module):\n    def __init__(self, d_feat, hidden_size=256, num_layers=3, dropout=0.0, output_dim=1):\n        super().__init__()\n\n        self.mlp = nn.Sequential()\n        self.softmax = nn.Softmax(dim=1)\n\n        for i in range(num_layers):\n            if i > 0:\n                self.mlp.add_module(\"drop_%d\" % i, nn.Dropout(dropout))\n            self.mlp.add_module(\"fc_%d\" % i, nn.Linear(d_feat if i == 0 else hidden_size, hidden_size))\n            self.mlp.add_module(\"relu_%d\" % i, 
nn.ReLU())\n\n        self.mlp.add_module(\"fc_out\", nn.Linear(hidden_size, output_dim))\n\n    def forward(self, x):\n        # feature\n        # [N, F]\n        out = self.mlp(x).squeeze()\n        out = self.softmax(out)\n        return out\n\n\nclass GRUModel(nn.Module):\n    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):\n        super().__init__()\n\n        self.rnn = nn.GRU(\n            input_size=d_feat,\n            hidden_size=hidden_size,\n            num_layers=num_layers,\n            batch_first=True,\n            dropout=dropout,\n        )\n        self.fc_out = nn.Linear(hidden_size, 1)\n\n        self.d_feat = d_feat\n\n    def forward(self, x):\n        # x: [N, F*T]\n        x = x.reshape(len(x), self.d_feat, -1)  # [N, F, T]\n        x = x.permute(0, 2, 1)  # [N, T, F]\n        out, _ = self.rnn(x)\n        return self.fc_out(out[:, -1, :]).squeeze()\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_tra.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport io\nimport os\nimport copy\nimport math\nimport json\nimport numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\n\ntry:\n    from torch.utils.tensorboard import SummaryWriter\nexcept ImportError:\n    SummaryWriter = None\n\nfrom tqdm import tqdm\n\nfrom qlib.constant import EPS\nfrom qlib.log import get_module_logger\nfrom qlib.model.base import Model\nfrom qlib.contrib.data.dataset import MTSDatasetH\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n\nclass TRAModel(Model):\n    \"\"\"\n    TRA Model\n\n    Args:\n        model_config (dict): model config (will be used by RNN or Transformer)\n        tra_config (dict): TRA config (will be used by TRA)\n        model_type (str): which backbone model to use (RNN/Transformer)\n        lr (float): learning rate\n        n_epochs (int): total number of epochs\n        early_stop (int): number of epochs without validation improvement before early stopping\n        update_freq (int): gradient update frequency (number of batches accumulated per optimizer step)\n        max_steps_per_epoch (int): maximum number of steps in one epoch\n        lamb (float): regularization parameter\n        rho (float): exponential decay rate for `lamb`\n        alpha (float): fusion parameter for calculating transport loss matrix\n        seed (int): random seed\n        logdir (str): local log directory\n        eval_train (bool): whether to evaluate the train set between epochs\n        eval_test (bool): whether to evaluate the test set between epochs\n        pretrain (bool): whether to pretrain the backbone model before training TRA.\n            Note that only TRA will be optimized after pretraining\n        init_state (str): path to the initial model state\n        freeze_model (bool): whether to freeze the backbone model parameters\n        freeze_predictors (bool): whether to freeze the predictors' parameters\n 
       transport_method (str): transport method, can be none/router/oracle\n        memory_mode (str): memory mode, the same argument for MTSDatasetH\n    \"\"\"\n\n    def __init__(\n        self,\n        model_config,\n        tra_config,\n        model_type=\"RNN\",\n        lr=1e-3,\n        n_epochs=500,\n        early_stop=50,\n        update_freq=1,\n        max_steps_per_epoch=None,\n        lamb=0.0,\n        rho=0.99,\n        alpha=1.0,\n        seed=None,\n        logdir=None,\n        eval_train=False,\n        eval_test=False,\n        pretrain=False,\n        init_state=None,\n        reset_router=False,\n        freeze_model=False,\n        freeze_predictors=False,\n        transport_method=\"none\",\n        memory_mode=\"sample\",\n    ):\n        self.logger = get_module_logger(\"TRA\")\n\n        assert memory_mode in [\"sample\", \"daily\"], \"invalid memory mode\"\n        assert transport_method in [\"none\", \"router\", \"oracle\"], f\"invalid transport method {transport_method}\"\n        assert transport_method == \"none\" or tra_config[\"num_states\"] > 1, \"optimal transport requires `num_states` > 1\"\n        assert (\n            memory_mode != \"daily\" or tra_config[\"src_info\"] == \"TPE\"\n        ), \"daily transport can only support TPE as `src_info`\"\n\n        if transport_method == \"router\" and not eval_train:\n            self.logger.warning(\"`eval_train` will be ignored when using TRA.router\")\n\n        if seed is not None:\n            np.random.seed(seed)\n            torch.manual_seed(seed)\n\n        self.model_config = model_config\n        self.tra_config = tra_config\n        self.model_type = model_type\n        self.lr = lr\n        self.n_epochs = n_epochs\n        self.early_stop = early_stop\n        self.update_freq = update_freq\n        self.max_steps_per_epoch = max_steps_per_epoch\n        self.lamb = lamb\n        self.rho = rho\n        self.alpha = alpha\n        self.seed = seed\n        
self.logdir = logdir\n        self.eval_train = eval_train\n        self.eval_test = eval_test\n        self.pretrain = pretrain\n        self.init_state = init_state\n        self.reset_router = reset_router\n        self.freeze_model = freeze_model\n        self.freeze_predictors = freeze_predictors\n        self.transport_method = transport_method\n        self.use_daily_transport = memory_mode == \"daily\"\n        self.transport_fn = transport_daily if self.use_daily_transport else transport_sample\n\n        self._writer = None\n        if self.logdir is not None:\n            if os.path.exists(self.logdir):\n                self.logger.warning(f\"logdir {self.logdir} is not empty\")\n            os.makedirs(self.logdir, exist_ok=True)\n            if SummaryWriter is not None:\n                self._writer = SummaryWriter(log_dir=self.logdir)\n\n        self._init_model()\n\n    def _init_model(self):\n        self.logger.info(\"init TRAModel...\")\n\n        self.model = eval(self.model_type)(**self.model_config).to(device)\n        print(self.model)\n\n        self.tra = TRA(self.model.output_size, **self.tra_config).to(device)\n        print(self.tra)\n\n        if self.init_state:\n            self.logger.warning(f\"load state dict from `init_state`\")\n            state_dict = torch.load(self.init_state, map_location=\"cpu\")\n            self.model.load_state_dict(state_dict[\"model\"])\n            res = load_state_dict_unsafe(self.tra, state_dict[\"tra\"])\n            self.logger.warning(str(res))\n\n        if self.reset_router:\n            self.logger.warning(f\"reset TRA.router parameters\")\n            self.tra.fc.reset_parameters()\n            self.tra.router.reset_parameters()\n\n        if self.freeze_model:\n            self.logger.warning(f\"freeze model parameters\")\n            for param in self.model.parameters():\n                param.requires_grad_(False)\n\n        if self.freeze_predictors:\n            
self.logger.warning(f\"freeze TRA.predictors parameters\")\n            for param in self.tra.predictors.parameters():\n                param.requires_grad_(False)\n\n        self.logger.info(\"# model params: %d\" % sum(p.numel() for p in self.model.parameters() if p.requires_grad))\n        self.logger.info(\"# tra params: %d\" % sum(p.numel() for p in self.tra.parameters() if p.requires_grad))\n\n        self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)\n\n        self.fitted = False\n        self.global_step = -1\n\n    def train_epoch(self, epoch, data_set, is_pretrain=False):\n        self.model.train()\n        self.tra.train()\n        data_set.train()\n        self.optimizer.zero_grad()\n\n        P_all = []\n        prob_all = []\n        choice_all = []\n        max_steps = len(data_set)\n        if self.max_steps_per_epoch is not None:\n            if epoch == 0 and self.max_steps_per_epoch < max_steps:\n                self.logger.info(f\"max steps updated from {max_steps} to {self.max_steps_per_epoch}\")\n            max_steps = min(self.max_steps_per_epoch, max_steps)\n\n        cur_step = 0\n        total_loss = 0\n        total_count = 0\n        for batch in tqdm(data_set, total=max_steps):\n            cur_step += 1\n            if cur_step > max_steps:\n                break\n\n            if not is_pretrain:\n                self.global_step += 1\n\n            data, state, label, count = batch[\"data\"], batch[\"state\"], batch[\"label\"], batch[\"daily_count\"]\n            index = batch[\"daily_index\"] if self.use_daily_transport else batch[\"index\"]\n\n            with torch.set_grad_enabled(not self.freeze_model):\n                hidden = self.model(data)\n\n            all_preds, choice, prob = self.tra(hidden, state)\n\n            if is_pretrain or self.transport_method != \"none\":\n                # NOTE: use oracle transport for pre-training\n                loss, pred, L, P = 
self.transport_fn(\n                    all_preds,\n                    label,\n                    choice,\n                    prob,\n                    state.mean(dim=1),\n                    count,\n                    self.transport_method if not is_pretrain else \"oracle\",\n                    self.alpha,\n                    training=True,\n                )\n                data_set.assign_data(index, L)  # save loss to memory\n                if self.use_daily_transport:  # only save for daily transport\n                    P_all.append(pd.DataFrame(P.detach().cpu().numpy(), index=index))\n                    prob_all.append(pd.DataFrame(prob.detach().cpu().numpy(), index=index))\n                    choice_all.append(pd.DataFrame(choice.detach().cpu().numpy(), index=index))\n                decay = self.rho ** (self.global_step // 100)  # decay every 100 steps\n                lamb = 0 if is_pretrain else self.lamb * decay\n                reg = prob.log().mul(P).sum(dim=1).mean()  # train router to predict TO assignment\n                if self._writer is not None and not is_pretrain:\n                    self._writer.add_scalar(\"training/router_loss\", -reg.item(), self.global_step)\n                    self._writer.add_scalar(\"training/reg_loss\", loss.item(), self.global_step)\n                    self._writer.add_scalar(\"training/lamb\", lamb, self.global_step)\n                    if not self.use_daily_transport:\n                        P_mean = P.mean(axis=0).detach()\n                        self._writer.add_scalar(\"training/P\", P_mean.max() / P_mean.min(), self.global_step)\n                loss = loss - lamb * reg\n            else:\n                pred = all_preds.mean(dim=1)\n                loss = loss_fn(pred, label)\n\n            (loss / self.update_freq).backward()\n            if cur_step % self.update_freq == 0:\n                self.optimizer.step()\n                self.optimizer.zero_grad()\n\n            if self._writer is 
not None and not is_pretrain:\n                self._writer.add_scalar(\"training/total_loss\", loss.item(), self.global_step)\n\n            total_loss += loss.item()\n            total_count += 1\n\n        if self.use_daily_transport and len(P_all) > 0:\n            P_all = pd.concat(P_all, axis=0)\n            prob_all = pd.concat(prob_all, axis=0)\n            choice_all = pd.concat(choice_all, axis=0)\n            P_all.index = data_set.restore_daily_index(P_all.index)\n            prob_all.index = P_all.index\n            choice_all.index = P_all.index\n            if self._writer is not None and not is_pretrain:\n                self._writer.add_image(\"P\", plot(P_all), epoch, dataformats=\"HWC\")\n                self._writer.add_image(\"prob\", plot(prob_all), epoch, dataformats=\"HWC\")\n                self._writer.add_image(\"choice\", plot(choice_all), epoch, dataformats=\"HWC\")\n\n        total_loss /= total_count\n\n        if self._writer is not None and not is_pretrain:\n            self._writer.add_scalar(\"training/loss\", total_loss, epoch)\n\n        return total_loss\n\n    def test_epoch(self, epoch, data_set, return_pred=False, prefix=\"test\", is_pretrain=False):\n        self.model.eval()\n        self.tra.eval()\n        data_set.eval()\n\n        preds = []\n        probs = []\n        P_all = []\n        metrics = []\n        for batch in tqdm(data_set):\n            data, state, label, count = batch[\"data\"], batch[\"state\"], batch[\"label\"], batch[\"daily_count\"]\n            index = batch[\"daily_index\"] if self.use_daily_transport else batch[\"index\"]\n\n            with torch.no_grad():\n                hidden = self.model(data)\n                all_preds, choice, prob = self.tra(hidden, state)\n\n            if is_pretrain or self.transport_method != \"none\":\n                loss, pred, L, P = self.transport_fn(\n                    all_preds,\n                    label,\n                    choice,\n                    prob,\n                    
state.mean(dim=1),\n                    count,\n                    self.transport_method if not is_pretrain else \"oracle\",\n                    self.alpha,\n                    training=False,\n                )\n                data_set.assign_data(index, L)  # save loss to memory\n                if P is not None and return_pred:\n                    P_all.append(pd.DataFrame(P.cpu().numpy(), index=index))\n            else:\n                pred = all_preds.mean(dim=1)\n\n            X = np.c_[pred.cpu().numpy(), label.cpu().numpy(), all_preds.cpu().numpy()]\n            columns = [\"score\", \"label\"] + [\"score_%d\" % d for d in range(all_preds.shape[1])]\n            pred = pd.DataFrame(X, index=batch[\"index\"], columns=columns)\n\n            metrics.append(evaluate(pred))\n\n            if return_pred:\n                preds.append(pred)\n                if prob is not None:\n                    columns = [\"prob_%d\" % d for d in range(all_preds.shape[1])]\n                    probs.append(pd.DataFrame(prob.cpu().numpy(), index=index, columns=columns))\n\n        metrics = pd.DataFrame(metrics)\n        metrics = {\n            \"MSE\": metrics.MSE.mean(),\n            \"MAE\": metrics.MAE.mean(),\n            \"IC\": metrics.IC.mean(),\n            \"ICIR\": metrics.IC.mean() / metrics.IC.std(),\n        }\n\n        if self._writer is not None and epoch >= 0 and not is_pretrain:\n            for key, value in metrics.items():\n                self._writer.add_scalar(prefix + \"/\" + key, value, epoch)\n\n        if return_pred:\n            preds = pd.concat(preds, axis=0)\n            preds.index = data_set.restore_index(preds.index)\n            preds.index = preds.index.swaplevel()\n            preds.sort_index(inplace=True)\n\n            if probs:\n                probs = pd.concat(probs, axis=0)\n                if self.use_daily_transport:\n                    probs.index = data_set.restore_daily_index(probs.index)\n                else:\n    
                probs.index = data_set.restore_index(probs.index)\n                    probs.index = probs.index.swaplevel()\n                    probs.sort_index(inplace=True)\n\n            if len(P_all):\n                P_all = pd.concat(P_all, axis=0)\n                if self.use_daily_transport:\n                    P_all.index = data_set.restore_daily_index(P_all.index)\n                else:\n                    P_all.index = data_set.restore_index(P_all.index)\n                    P_all.index = P_all.index.swaplevel()\n                    P_all.sort_index(inplace=True)\n\n        return metrics, preds, probs, P_all\n\n    def _fit(self, train_set, valid_set, test_set, evals_result, is_pretrain=True):\n        best_score = -1\n        best_epoch = 0\n        stop_rounds = 0\n        best_params = {\n            \"model\": copy.deepcopy(self.model.state_dict()),\n            \"tra\": copy.deepcopy(self.tra.state_dict()),\n        }\n        # train\n        if not is_pretrain and self.transport_method != \"none\":\n            self.logger.info(\"init memory...\")\n            self.test_epoch(-1, train_set)\n\n        for epoch in range(self.n_epochs):\n            self.logger.info(\"Epoch %d:\", epoch)\n\n            self.logger.info(\"training...\")\n            self.train_epoch(epoch, train_set, is_pretrain=is_pretrain)\n\n            self.logger.info(\"evaluating...\")\n            # NOTE: during evaluating, the whole memory will be refreshed\n            if not is_pretrain and (self.transport_method == \"router\" or self.eval_train):\n                train_set.clear_memory()  # NOTE: clear the shared memory\n                train_metrics = self.test_epoch(epoch, train_set, is_pretrain=is_pretrain, prefix=\"train\")[0]\n                evals_result[\"train\"].append(train_metrics)\n                self.logger.info(\"train metrics: %s\" % train_metrics)\n\n            valid_metrics = self.test_epoch(epoch, valid_set, is_pretrain=is_pretrain, 
prefix=\"valid\")[0]\n            evals_result[\"valid\"].append(valid_metrics)\n            self.logger.info(\"valid metrics: %s\" % valid_metrics)\n\n            if self.eval_test:\n                test_metrics = self.test_epoch(epoch, test_set, is_pretrain=is_pretrain, prefix=\"test\")[0]\n                evals_result[\"test\"].append(test_metrics)\n                self.logger.info(\"test metrics: %s\" % test_metrics)\n\n            if valid_metrics[\"IC\"] > best_score:\n                best_score = valid_metrics[\"IC\"]\n                stop_rounds = 0\n                best_epoch = epoch\n                best_params = {\n                    \"model\": copy.deepcopy(self.model.state_dict()),\n                    \"tra\": copy.deepcopy(self.tra.state_dict()),\n                }\n                if self.logdir is not None:\n                    torch.save(best_params, self.logdir + \"/model.bin\")\n            else:\n                stop_rounds += 1\n                if stop_rounds >= self.early_stop:\n                    self.logger.info(\"early stop @ %s\" % epoch)\n                    break\n\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\n        self.model.load_state_dict(best_params[\"model\"])\n        self.tra.load_state_dict(best_params[\"tra\"])\n\n        return best_score\n\n    def fit(self, dataset, evals_result=dict()):\n        assert isinstance(dataset, MTSDatasetH), \"TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`\"\n\n        train_set, valid_set, test_set = dataset.prepare([\"train\", \"valid\", \"test\"])\n\n        self.fitted = True\n        self.global_step = -1\n\n        evals_result[\"train\"] = []\n        evals_result[\"valid\"] = []\n        evals_result[\"test\"] = []\n\n        if self.pretrain:\n            self.logger.info(\"pretraining...\")\n            self.optimizer = optim.Adam(\n                list(self.model.parameters()) + list(self.tra.predictors.parameters()), 
lr=self.lr\n            )\n            self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True)\n\n            # reset optimizer\n            self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)\n\n        self.logger.info(\"training...\")\n        best_score = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=False)\n\n        self.logger.info(\"inference\")\n        train_metrics, train_preds, train_probs, train_P = self.test_epoch(-1, train_set, return_pred=True)\n        self.logger.info(\"train metrics: %s\" % train_metrics)\n\n        valid_metrics, valid_preds, valid_probs, valid_P = self.test_epoch(-1, valid_set, return_pred=True)\n        self.logger.info(\"valid metrics: %s\" % valid_metrics)\n\n        test_metrics, test_preds, test_probs, test_P = self.test_epoch(-1, test_set, return_pred=True)\n        self.logger.info(\"test metrics: %s\" % test_metrics)\n\n        if self.logdir:\n            self.logger.info(\"save model & pred to local directory\")\n\n            pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(\n                self.logdir + \"/logs.csv\", index=False\n            )\n\n            torch.save({\"model\": self.model.state_dict(), \"tra\": self.tra.state_dict()}, self.logdir + \"/model.bin\")\n\n            train_preds.to_pickle(self.logdir + \"/train_pred.pkl\")\n            valid_preds.to_pickle(self.logdir + \"/valid_pred.pkl\")\n            test_preds.to_pickle(self.logdir + \"/test_pred.pkl\")\n\n            if len(train_probs):\n                train_probs.to_pickle(self.logdir + \"/train_prob.pkl\")\n                valid_probs.to_pickle(self.logdir + \"/valid_prob.pkl\")\n                test_probs.to_pickle(self.logdir + \"/test_prob.pkl\")\n\n            if len(train_P):\n                train_P.to_pickle(self.logdir + \"/train_P.pkl\")\n                valid_P.to_pickle(self.logdir + 
\"/valid_P.pkl\")\n                test_P.to_pickle(self.logdir + \"/test_P.pkl\")\n\n            info = {\n                \"config\": {\n                    \"model_config\": self.model_config,\n                    \"tra_config\": self.tra_config,\n                    \"model_type\": self.model_type,\n                    \"lr\": self.lr,\n                    \"n_epochs\": self.n_epochs,\n                    \"early_stop\": self.early_stop,\n                    \"max_steps_per_epoch\": self.max_steps_per_epoch,\n                    \"lamb\": self.lamb,\n                    \"rho\": self.rho,\n                    \"alpha\": self.alpha,\n                    \"seed\": self.seed,\n                    \"logdir\": self.logdir,\n                    \"pretrain\": self.pretrain,\n                    \"init_state\": self.init_state,\n                    \"transport_method\": self.transport_method,\n                    \"use_daily_transport\": self.use_daily_transport,\n                },\n                \"best_eval_metric\": -best_score,  # NOTE: -1 for minimize\n                \"metrics\": {\"train\": train_metrics, \"valid\": valid_metrics, \"test\": test_metrics},\n            }\n            with open(self.logdir + \"/info.json\", \"w\") as f:\n                json.dump(info, f)\n\n    def predict(self, dataset, segment=\"test\"):\n        assert isinstance(dataset, MTSDatasetH), \"TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`\"\n\n        if not self.fitted:\n            raise ValueError(\"model is not fitted yet!\")\n\n        test_set = dataset.prepare(segment)\n\n        metrics, preds, _, _ = self.test_epoch(-1, test_set, return_pred=True)\n        self.logger.info(\"test metrics: %s\" % metrics)\n\n        return preds\n\n\nclass RNN(nn.Module):\n    \"\"\"RNN Model\n\n    Args:\n        input_size (int): input size (# features)\n        hidden_size (int): hidden size\n        num_layers (int): number of hidden layers\n        rnn_arch (str): rnn 
architecture\n        use_attn (bool): whether to use an attention layer.\n            We use concat attention as in https://github.com/fulifeng/Adv-ALSTM/\n        dropout (float): dropout rate\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size=16,\n        hidden_size=64,\n        num_layers=2,\n        rnn_arch=\"GRU\",\n        use_attn=True,\n        dropout=0.0,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.rnn_arch = rnn_arch\n        self.use_attn = use_attn\n\n        if hidden_size < input_size:\n            # compress the input features down to hidden_size before the RNN\n            self.input_proj = nn.Linear(input_size, hidden_size)\n        else:\n            self.input_proj = None\n\n        self.rnn = getattr(nn, rnn_arch)(\n            input_size=min(input_size, hidden_size),\n            hidden_size=hidden_size,\n            num_layers=num_layers,\n            batch_first=True,\n            dropout=dropout,\n        )\n\n        if self.use_attn:\n            self.W = nn.Linear(hidden_size, hidden_size)\n            self.u = nn.Linear(hidden_size, 1, bias=False)\n            self.output_size = hidden_size * 2\n        else:\n            self.output_size = hidden_size\n\n    def forward(self, x):\n        if self.input_proj is not None:\n            x = self.input_proj(x)\n\n        rnn_out, last_out = self.rnn(x)\n        if self.rnn_arch == \"LSTM\":\n            last_out = last_out[0]\n        last_out = last_out.mean(dim=0)\n\n        if self.use_attn:\n            latent = self.W(rnn_out).tanh()\n            scores = self.u(latent).softmax(dim=1)\n            att_out = (rnn_out * scores).sum(dim=1)\n            last_out = torch.cat([last_out, att_out], dim=1)\n\n        return last_out\n\n\nclass PositionalEncoding(nn.Module):\n    # reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html\n    def __init__(self, d_model, 
dropout=0.1, max_len=5000):\n        super(PositionalEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n        pe = torch.zeros(max_len, d_model)\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0).transpose(0, 1)\n        self.register_buffer(\"pe\", pe)\n\n    def forward(self, x):\n        x = x + self.pe[: x.size(0), :]\n        return self.dropout(x)\n\n\nclass Transformer(nn.Module):\n    \"\"\"Transformer Model\n\n    Args:\n        input_size (int): input size (# features)\n        hidden_size (int): hidden size\n        num_layers (int): number of transformer layers\n        num_heads (int): number of heads in transformer\n        dropout (float): dropout rate\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size=16,\n        hidden_size=64,\n        num_layers=2,\n        num_heads=2,\n        dropout=0.0,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n\n        self.input_proj = nn.Linear(input_size, hidden_size)\n\n        self.pe = PositionalEncoding(input_size, dropout)\n        layer = nn.TransformerEncoderLayer(\n            nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4\n        )\n        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)\n\n        self.output_size = hidden_size\n\n    def forward(self, x):\n        x = x.permute(1, 0, 2).contiguous()  # the first dim need to be time\n        x = self.pe(x)\n\n        x = self.input_proj(x)\n        out = self.encoder(x)\n\n        return out[-1]\n\n\nclass TRA(nn.Module):\n    
\"\"\"Temporal Routing Adaptor (TRA)\n\n    TRA takes historical prediction errors & latent representation as inputs,\n    then routes the input sample to a specific predictor for training & inference.\n\n    Args:\n        input_size (int): input size (RNN/Transformer's hidden size)\n        num_states (int): number of latent states (i.e., trading patterns)\n            If `num_states=1`, then TRA falls back to traditional methods\n        hidden_size (int): hidden size of the router\n        tau (float): gumbel softmax temperature\n        src_info (str): information for the router\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size,\n        num_states=1,\n        hidden_size=8,\n        rnn_arch=\"GRU\",\n        num_layers=1,\n        dropout=0.0,\n        tau=1.0,\n        src_info=\"LR_TPE\",\n    ):\n        super().__init__()\n\n        assert src_info in [\"LR\", \"TPE\", \"LR_TPE\"], \"invalid `src_info`\"\n\n        self.num_states = num_states\n        self.tau = tau\n        self.rnn_arch = rnn_arch\n        self.src_info = src_info\n\n        self.predictors = nn.Linear(input_size, num_states)\n\n        if self.num_states > 1:\n            if \"TPE\" in src_info:\n                self.router = getattr(nn, rnn_arch)(\n                    input_size=num_states,\n                    hidden_size=hidden_size,\n                    num_layers=num_layers,\n                    batch_first=True,\n                    dropout=dropout,\n                )\n                self.fc = nn.Linear(hidden_size + input_size if \"LR\" in src_info else hidden_size, num_states)\n            else:\n                self.fc = nn.Linear(input_size, num_states)\n\n    def reset_parameters(self):\n        for child in self.children():\n            child.reset_parameters()\n\n    def forward(self, hidden, hist_loss):\n        preds = self.predictors(hidden)\n\n        if self.num_states == 1:  # no need for router when having only one prediction\n            return 
preds, None, None\n\n        if \"TPE\" in self.src_info:\n            out = self.router(hist_loss)[1]  # TPE\n            if self.rnn_arch == \"LSTM\":\n                out = out[0]\n            out = out.mean(dim=0)\n            if \"LR\" in self.src_info:\n                out = torch.cat([hidden, out], dim=-1)  # LR_TPE\n        else:\n            out = hidden  # LR\n\n        out = self.fc(out)\n\n        choice = F.gumbel_softmax(out, dim=-1, tau=self.tau, hard=True)\n        prob = torch.softmax(out / self.tau, dim=-1)\n\n        return preds, choice, prob\n\n\ndef evaluate(pred):\n    pred = pred.rank(pct=True)  # transform into percentiles\n    score = pred.score\n    label = pred.label\n    diff = score - label\n    MSE = (diff**2).mean()\n    MAE = (diff.abs()).mean()\n    IC = score.corr(label, method=\"spearman\")\n    return {\"MSE\": MSE, \"MAE\": MAE, \"IC\": IC}\n\n\ndef shoot_infs(inp_tensor):\n    \"\"\"Replaces inf by maximum of tensor\"\"\"\n    mask_inf = torch.isinf(inp_tensor)\n    ind_inf = torch.nonzero(mask_inf, as_tuple=False)\n    if len(ind_inf) > 0:\n        for ind in ind_inf:\n            if len(ind) == 2:\n                inp_tensor[ind[0], ind[1]] = 0\n            elif len(ind) == 1:\n                inp_tensor[ind[0]] = 0\n        m = torch.max(inp_tensor)\n        for ind in ind_inf:\n            if len(ind) == 2:\n                inp_tensor[ind[0], ind[1]] = m\n            elif len(ind) == 1:\n                inp_tensor[ind[0]] = m\n    return inp_tensor\n\n\ndef sinkhorn(Q, n_iters=3, epsilon=0.1):\n    # epsilon should be adjusted according to logits value's scale\n    with torch.no_grad():\n        Q = torch.exp(Q / epsilon)\n        Q = shoot_infs(Q)\n        for i in range(n_iters):\n            Q /= Q.sum(dim=0, keepdim=True)\n            Q /= Q.sum(dim=1, keepdim=True)\n    return Q\n\n\ndef loss_fn(pred, label):\n    mask = ~torch.isnan(label)\n    if len(pred.shape) == 2:\n        label = label[:, None]\n    return 
(pred[mask] - label[mask]).pow(2).mean(dim=0)\n\n\ndef minmax_norm(x):\n    xmin = x.min(dim=-1, keepdim=True).values\n    xmax = x.max(dim=-1, keepdim=True).values\n    mask = (xmin == xmax).squeeze()\n    x = (x - xmin) / (xmax - xmin + EPS)\n    x[mask] = 1\n    return x\n\n\ndef transport_sample(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):\n    \"\"\"\n    sample-wise transport\n\n    Args:\n        all_preds (torch.Tensor): predictions from all predictors, [sample x states]\n        label (torch.Tensor): label, [sample]\n        choice (torch.Tensor): gumbel softmax choice, [sample x states]\n        prob (torch.Tensor): router predicted probability, [sample x states]\n        hist_loss (torch.Tensor): history loss matrix, [sample x states]\n        count (list): sample counts for each day, empty list for sample-wise transport\n        transport_method (str): transportation method, \"oracle\" or \"router\"\n        alpha (float): fusion parameter for calculating transport loss matrix\n        training (bool): whether in training (True) or inference (False) mode\n    \"\"\"\n    assert all_preds.shape == choice.shape\n    assert len(all_preds) == len(label)\n    assert transport_method in [\"oracle\", \"router\"]\n\n    all_loss = torch.zeros_like(all_preds)\n    mask = ~torch.isnan(label)\n    all_loss[mask] = (all_preds[mask] - label[mask, None]).pow(2)  # [sample x states]\n\n    L = minmax_norm(all_loss.detach())\n    Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha)  # add hist loss for transport\n    Lh = minmax_norm(Lh)\n    P = sinkhorn(-Lh)\n    del Lh\n\n    if transport_method == \"router\":\n        if training:\n            pred = (all_preds * choice).sum(dim=1)  # gumbel softmax\n        else:\n            pred = all_preds[range(len(all_preds)), prob.argmax(dim=-1)]  # argmax\n    else:\n        pred = (all_preds * P).sum(dim=1)\n\n    if transport_method == \"router\":\n        loss = loss_fn(pred, label)\n    else:\n        loss = (all_loss * 
P).sum(dim=1).mean()\n\n    return loss, pred, L, P\n\n\ndef transport_daily(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):\n    \"\"\"\n    daily transport\n\n    Args:\n        all_preds (torch.Tensor): predictions from all predictors, [sample x states]\n        label (torch.Tensor): label, [sample]\n        choice (torch.Tensor): gumbel softmax choice, [days x states]\n        prob (torch.Tensor): router predicted probability, [days x states]\n        hist_loss (torch.Tensor): history loss matrix, [days x states]\n        count (list): sample counts for each day, [days]\n        transport_method (str): transportation method, \"oracle\" or \"router\"\n        alpha (float): fusion parameter for calculating transport loss matrix\n        training (bool): whether in training (True) or inference (False) mode\n    \"\"\"\n    assert len(prob) == len(count)\n    assert len(all_preds) == sum(count)\n    assert transport_method in [\"oracle\", \"router\"]\n\n    all_loss = []  # loss of all predictions\n    start = 0\n    for i, cnt in enumerate(count):\n        slc = slice(start, start + cnt)  # samples from the i-th day\n        start += cnt\n        tloss = loss_fn(all_preds[slc], label[slc])  # loss of the i-th day\n        all_loss.append(tloss)\n    all_loss = torch.stack(all_loss, dim=0)  # [days x states]\n\n    L = minmax_norm(all_loss.detach())\n    Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha)  # add hist loss for transport\n    Lh = minmax_norm(Lh)\n    P = sinkhorn(-Lh)\n    del Lh\n\n    pred = []\n    start = 0\n    for i, cnt in enumerate(count):\n        slc = slice(start, start + cnt)  # samples from the i-th day\n        start += cnt\n        if transport_method == \"router\":\n            if training:\n                tpred = all_preds[slc] @ choice[i]  # gumbel softmax\n            else:\n                tpred = all_preds[slc][:, prob[i].argmax(dim=-1)]  # argmax\n        else:\n            tpred = all_preds[slc] @ P[i]\n        
pred.append(tpred)\n    pred = torch.cat(pred, dim=0)  # [samples]\n\n    if transport_method == \"router\":\n        loss = loss_fn(pred, label)\n    else:\n        loss = (all_loss * P).sum(dim=1).mean()\n\n    return loss, pred, L, P\n\n\ndef load_state_dict_unsafe(model, state_dict):\n    \"\"\"\n    Load a state dict into the provided model without raising on mismatches;\n    missing keys, unexpected keys and error messages are collected and returned instead.\n    \"\"\"\n\n    missing_keys = []\n    unexpected_keys = []\n    error_msgs = []\n\n    # copy state_dict so _load_from_state_dict can modify it\n    metadata = getattr(state_dict, \"_metadata\", None)\n    state_dict = state_dict.copy()\n    if metadata is not None:\n        state_dict._metadata = metadata\n\n    def load(module, prefix=\"\"):\n        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})\n        module._load_from_state_dict(\n            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs\n        )\n        for name, child in module._modules.items():\n            if child is not None:\n                load(child, prefix + name + \".\")\n\n    load(model)\n    load = None  # break load->load reference cycle\n\n    return {\"unexpected_keys\": unexpected_keys, \"missing_keys\": missing_keys, \"error_msgs\": error_msgs}\n\n\ndef plot(P):\n    assert isinstance(P, pd.DataFrame)\n\n    fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n    P.plot.area(ax=axes[0], xlabel=\"\")\n    P.idxmax(axis=1).value_counts().sort_index().plot.bar(ax=axes[1], xlabel=\"\")\n    plt.tight_layout()\n\n    with io.BytesIO() as buf:\n        plt.savefig(buf, format=\"png\")\n        buf.seek(0)\n        img = plt.imread(buf)\n        plt.close()\n\n    return np.uint8(img * 255)\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_transformer.py",
    "content": "# Copyright (c) Microsoft Corporation.\r\n# Licensed under the MIT License.\r\n\r\n\r\nfrom __future__ import division\r\nfrom __future__ import print_function\r\n\r\nimport numpy as np\r\nimport pandas as pd\r\nfrom typing import Text, Union\r\nimport copy\r\nimport math\r\nfrom ...utils import get_or_create_path\r\nfrom ...log import get_module_logger\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.optim as optim\r\n\r\nfrom ...model.base import Model\r\nfrom ...data.dataset import DatasetH\r\nfrom ...data.dataset.handler import DataHandlerLP\r\n\r\n# qrun examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml ”\r\n\r\n\r\nclass TransformerModel(Model):\r\n    def __init__(\r\n        self,\r\n        d_feat: int = 20,\r\n        d_model: int = 64,\r\n        batch_size: int = 2048,\r\n        nhead: int = 2,\r\n        num_layers: int = 2,\r\n        dropout: float = 0,\r\n        n_epochs=100,\r\n        lr=0.0001,\r\n        metric=\"\",\r\n        early_stop=5,\r\n        loss=\"mse\",\r\n        optimizer=\"adam\",\r\n        reg=1e-3,\r\n        n_jobs=10,\r\n        GPU=0,\r\n        seed=None,\r\n        **kwargs,\r\n    ):\r\n        # set hyper-parameters.\r\n        self.d_model = d_model\r\n        self.dropout = dropout\r\n        self.n_epochs = n_epochs\r\n        self.lr = lr\r\n        self.reg = reg\r\n        self.metric = metric\r\n        self.batch_size = batch_size\r\n        self.early_stop = early_stop\r\n        self.optimizer = optimizer.lower()\r\n        self.loss = loss\r\n        self.n_jobs = n_jobs\r\n        self.device = torch.device(\"cuda:%d\" % GPU if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\r\n        self.seed = seed\r\n        self.logger = get_module_logger(\"TransformerModel\")\r\n        self.logger.info(\"Naive Transformer:\" \"\\nbatch_size : {}\" \"\\ndevice : {}\".format(self.batch_size, self.device))\r\n\r\n        if self.seed is not None:\r\n            
np.random.seed(self.seed)\r\n            torch.manual_seed(self.seed)\r\n\r\n        self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)\r\n        if optimizer.lower() == \"adam\":\r\n            self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)\r\n        elif optimizer.lower() == \"gd\":\r\n            self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)\r\n        else:\r\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\r\n\r\n        self.fitted = False\r\n        self.model.to(self.device)\r\n\r\n    @property\r\n    def use_gpu(self):\r\n        return self.device != torch.device(\"cpu\")\r\n\r\n    def mse(self, pred, label):\r\n        loss = (pred.float() - label.float()) ** 2\r\n        return torch.mean(loss)\r\n\r\n    def loss_fn(self, pred, label):\r\n        mask = ~torch.isnan(label)\r\n\r\n        if self.loss == \"mse\":\r\n            return self.mse(pred[mask], label[mask])\r\n\r\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\r\n\r\n    def metric_fn(self, pred, label):\r\n        mask = torch.isfinite(label)\r\n\r\n        if self.metric in (\"\", \"loss\"):\r\n            return -self.loss_fn(pred[mask], label[mask])\r\n\r\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\r\n\r\n    def train_epoch(self, x_train, y_train):\r\n        x_train_values = x_train.values\r\n        y_train_values = np.squeeze(y_train.values)\r\n\r\n        self.model.train()\r\n\r\n        indices = np.arange(len(x_train_values))\r\n        np.random.shuffle(indices)\r\n\r\n        for i in range(len(indices))[:: self.batch_size]:\r\n            if len(indices) - i < self.batch_size:\r\n                break\r\n\r\n            feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\r\n            label = 
torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)\r\n\r\n            pred = self.model(feature)\r\n            loss = self.loss_fn(pred, label)\r\n\r\n            self.train_optimizer.zero_grad()\r\n            loss.backward()\r\n            torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)\r\n            self.train_optimizer.step()\r\n\r\n    def test_epoch(self, data_x, data_y):\r\n        # prepare evaluation data\r\n        x_values = data_x.values\r\n        y_values = np.squeeze(data_y.values)\r\n\r\n        self.model.eval()\r\n\r\n        scores = []\r\n        losses = []\r\n\r\n        indices = np.arange(len(x_values))\r\n\r\n        for i in range(len(indices))[:: self.batch_size]:\r\n            if len(indices) - i < self.batch_size:\r\n                break\r\n\r\n            feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)\r\n            label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)\r\n\r\n            with torch.no_grad():\r\n                pred = self.model(feature)\r\n                loss = self.loss_fn(pred, label)\r\n                losses.append(loss.item())\r\n\r\n                score = self.metric_fn(pred, label)\r\n                scores.append(score.item())\r\n\r\n        return np.mean(losses), np.mean(scores)\r\n\r\n    def fit(\r\n        self,\r\n        dataset: DatasetH,\r\n        evals_result=dict(),\r\n        save_path=None,\r\n    ):\r\n        df_train, df_valid, df_test = dataset.prepare(\r\n            [\"train\", \"valid\", \"test\"],\r\n            col_set=[\"feature\", \"label\"],\r\n            data_key=DataHandlerLP.DK_L,\r\n        )\r\n        if df_train.empty or df_valid.empty:\r\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\r\n\r\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\r\n        x_valid, y_valid 
= df_valid[\"feature\"], df_valid[\"label\"]\r\n\r\n        save_path = get_or_create_path(save_path)\r\n        stop_steps = 0\r\n        train_loss = 0\r\n        best_score = -np.inf\r\n        best_epoch = 0\r\n        evals_result[\"train\"] = []\r\n        evals_result[\"valid\"] = []\r\n\r\n        # train\r\n        self.logger.info(\"training...\")\r\n        self.fitted = True\r\n\r\n        for step in range(self.n_epochs):\r\n            self.logger.info(\"Epoch%d:\", step)\r\n            self.logger.info(\"training...\")\r\n            self.train_epoch(x_train, y_train)\r\n            self.logger.info(\"evaluating...\")\r\n            train_loss, train_score = self.test_epoch(x_train, y_train)\r\n            val_loss, val_score = self.test_epoch(x_valid, y_valid)\r\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\r\n            evals_result[\"train\"].append(train_score)\r\n            evals_result[\"valid\"].append(val_score)\r\n\r\n            if val_score > best_score:\r\n                best_score = val_score\r\n                stop_steps = 0\r\n                best_epoch = step\r\n                best_param = copy.deepcopy(self.model.state_dict())\r\n            else:\r\n                stop_steps += 1\r\n                if stop_steps >= self.early_stop:\r\n                    self.logger.info(\"early stop\")\r\n                    break\r\n\r\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\r\n        self.model.load_state_dict(best_param)\r\n        torch.save(best_param, save_path)\r\n\r\n        if self.use_gpu:\r\n            torch.cuda.empty_cache()\r\n\r\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\r\n        if not self.fitted:\r\n            raise ValueError(\"model is not fitted yet!\")\r\n\r\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\r\n        index = x_test.index\r\n        
self.model.eval()\r\n        x_values = x_test.values\r\n        sample_num = x_values.shape[0]\r\n        preds = []\r\n\r\n        for begin in range(sample_num)[:: self.batch_size]:\r\n            if sample_num - begin < self.batch_size:\r\n                end = sample_num\r\n            else:\r\n                end = begin + self.batch_size\r\n\r\n            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)\r\n\r\n            with torch.no_grad():\r\n                pred = self.model(x_batch).detach().cpu().numpy()\r\n\r\n            preds.append(pred)\r\n\r\n        return pd.Series(np.concatenate(preds), index=index)\r\n\r\n\r\nclass PositionalEncoding(nn.Module):\r\n    def __init__(self, d_model, max_len=1000):\r\n        super(PositionalEncoding, self).__init__()\r\n        pe = torch.zeros(max_len, d_model)\r\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\r\n        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\r\n        pe[:, 0::2] = torch.sin(position * div_term)\r\n        pe[:, 1::2] = torch.cos(position * div_term)\r\n        pe = pe.unsqueeze(0).transpose(0, 1)\r\n        self.register_buffer(\"pe\", pe)\r\n\r\n    def forward(self, x):\r\n        # [T, N, F]\r\n        return x + self.pe[: x.size(0), :]\r\n\r\n\r\nclass Transformer(nn.Module):\r\n    def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):\r\n        super(Transformer, self).__init__()\r\n        self.feature_layer = nn.Linear(d_feat, d_model)\r\n        self.pos_encoder = PositionalEncoding(d_model)\r\n        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)\r\n        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)\r\n        self.decoder_layer = nn.Linear(d_model, 1)\r\n        self.device = device\r\n        self.d_feat = d_feat\r\n\r\n    def forward(self, 
src):\r\n        # src [N, F*T] --> [N, T, F]\r\n        src = src.reshape(len(src), self.d_feat, -1).permute(0, 2, 1)\r\n        src = self.feature_layer(src)\r\n\r\n        # src [N, T, F] --> [T, N, F], [60, 512, 8]\r\n        src = src.transpose(1, 0)  # not batch first\r\n\r\n        mask = None\r\n\r\n        src = self.pos_encoder(src)\r\n        output = self.transformer_encoder(src, mask)  # [60, 512, 8]\r\n\r\n        # [T, N, F] --> [N, T, F], then keep only the last time step: [N, F]\r\n        output = self.decoder_layer(output.transpose(1, 0)[:, -1, :])  # [512, 1]\r\n\r\n        return output.squeeze()\r\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_transformer_ts.py",
    "content": "# Copyright (c) Microsoft Corporation.\r\n# Licensed under the MIT License.\r\n\r\n\r\nfrom __future__ import division\r\nfrom __future__ import print_function\r\n\r\nimport numpy as np\r\nimport pandas as pd\r\nimport copy\r\nimport math\r\nfrom ...utils import get_or_create_path\r\nfrom ...log import get_module_logger\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.optim as optim\r\nfrom torch.utils.data import DataLoader\r\n\r\nfrom ...model.base import Model\r\nfrom ...data.dataset import DatasetH\r\nfrom ...data.dataset.handler import DataHandlerLP\r\n\r\n\r\nclass TransformerModel(Model):\r\n    def __init__(\r\n        self,\r\n        d_feat: int = 20,\r\n        d_model: int = 64,\r\n        batch_size: int = 8192,\r\n        nhead: int = 2,\r\n        num_layers: int = 2,\r\n        dropout: float = 0,\r\n        n_epochs=100,\r\n        lr=0.0001,\r\n        metric=\"\",\r\n        early_stop=5,\r\n        loss=\"mse\",\r\n        optimizer=\"adam\",\r\n        reg=1e-3,\r\n        n_jobs=10,\r\n        GPU=0,\r\n        seed=None,\r\n        **kwargs,\r\n    ):\r\n        # set hyper-parameters.\r\n        self.d_model = d_model\r\n        self.dropout = dropout\r\n        self.n_epochs = n_epochs\r\n        self.lr = lr\r\n        self.reg = reg\r\n        self.metric = metric\r\n        self.batch_size = batch_size\r\n        self.early_stop = early_stop\r\n        self.optimizer = optimizer.lower()\r\n        self.loss = loss\r\n        self.n_jobs = n_jobs\r\n        self.device = torch.device(\"cuda:%d\" % GPU if torch.cuda.is_available() and GPU >= 0 else \"cpu\")\r\n        self.seed = seed\r\n        self.logger = get_module_logger(\"TransformerModel\")\r\n        self.logger.info(\"Naive Transformer:\" \"\\nbatch_size : {}\" \"\\ndevice : {}\".format(self.batch_size, self.device))\r\n\r\n        if self.seed is not None:\r\n            np.random.seed(self.seed)\r\n            torch.manual_seed(self.seed)\r\n\r\n     
   self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)\r\n        if optimizer.lower() == \"adam\":\r\n            self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)\r\n        elif optimizer.lower() == \"gd\":\r\n            self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)\r\n        else:\r\n            raise NotImplementedError(\"optimizer {} is not supported!\".format(optimizer))\r\n\r\n        self.fitted = False\r\n        self.model.to(self.device)\r\n\r\n    @property\r\n    def use_gpu(self):\r\n        return self.device != torch.device(\"cpu\")\r\n\r\n    def mse(self, pred, label):\r\n        loss = (pred.float() - label.float()) ** 2\r\n        return torch.mean(loss)\r\n\r\n    def loss_fn(self, pred, label):\r\n        mask = ~torch.isnan(label)\r\n\r\n        if self.loss == \"mse\":\r\n            return self.mse(pred[mask], label[mask])\r\n\r\n        raise ValueError(\"unknown loss `%s`\" % self.loss)\r\n\r\n    def metric_fn(self, pred, label):\r\n        mask = torch.isfinite(label)\r\n\r\n        if self.metric in (\"\", \"loss\"):\r\n            return -self.loss_fn(pred[mask], label[mask])\r\n\r\n        raise ValueError(\"unknown metric `%s`\" % self.metric)\r\n\r\n    def train_epoch(self, data_loader):\r\n        self.model.train()\r\n\r\n        for data in data_loader:\r\n            feature = data[:, :, 0:-1].to(self.device)\r\n            label = data[:, -1, -1].to(self.device)\r\n\r\n            pred = self.model(feature.float())  # .float()\r\n            loss = self.loss_fn(pred, label)\r\n\r\n            self.train_optimizer.zero_grad()\r\n            loss.backward()\r\n            torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)\r\n            self.train_optimizer.step()\r\n\r\n    def test_epoch(self, data_loader):\r\n        self.model.eval()\r\n\r\n        scores = []\r\n        losses = []\r\n\r\n 
       for data in data_loader:\r\n            feature = data[:, :, 0:-1].to(self.device)\r\n            label = data[:, -1, -1].to(self.device)\r\n\r\n            with torch.no_grad():\r\n                pred = self.model(feature.float())  # .float()\r\n                loss = self.loss_fn(pred, label)\r\n                losses.append(loss.item())\r\n\r\n                score = self.metric_fn(pred, label)\r\n                scores.append(score.item())\r\n\r\n        return np.mean(losses), np.mean(scores)\r\n\r\n    def fit(\r\n        self,\r\n        dataset: DatasetH,\r\n        evals_result=dict(),\r\n        save_path=None,\r\n    ):\r\n        dl_train = dataset.prepare(\"train\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\r\n        dl_valid = dataset.prepare(\"valid\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\r\n\r\n        if dl_train.empty or dl_valid.empty:\r\n            raise ValueError(\"Empty data from dataset, please check your dataset config.\")\r\n\r\n        dl_train.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\r\n        dl_valid.config(fillna_type=\"ffill+bfill\")  # process nan brought by dataloader\r\n\r\n        train_loader = DataLoader(\r\n            dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True\r\n        )\r\n        valid_loader = DataLoader(\r\n            dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True\r\n        )\r\n\r\n        save_path = get_or_create_path(save_path)\r\n\r\n        stop_steps = 0\r\n        train_loss = 0\r\n        best_score = -np.inf\r\n        best_epoch = 0\r\n        evals_result[\"train\"] = []\r\n        evals_result[\"valid\"] = []\r\n\r\n        # train\r\n        self.logger.info(\"training...\")\r\n        self.fitted = True\r\n\r\n        for step in range(self.n_epochs):\r\n            self.logger.info(\"Epoch%d:\", step)\r\n            
self.logger.info(\"training...\")\r\n            self.train_epoch(train_loader)\r\n            self.logger.info(\"evaluating...\")\r\n            train_loss, train_score = self.test_epoch(train_loader)\r\n            val_loss, val_score = self.test_epoch(valid_loader)\r\n            self.logger.info(\"train %.6f, valid %.6f\" % (train_score, val_score))\r\n            evals_result[\"train\"].append(train_score)\r\n            evals_result[\"valid\"].append(val_score)\r\n\r\n            if val_score > best_score:\r\n                best_score = val_score\r\n                stop_steps = 0\r\n                best_epoch = step\r\n                best_param = copy.deepcopy(self.model.state_dict())\r\n            else:\r\n                stop_steps += 1\r\n                if stop_steps >= self.early_stop:\r\n                    self.logger.info(\"early stop\")\r\n                    break\r\n\r\n        self.logger.info(\"best score: %.6lf @ %d\" % (best_score, best_epoch))\r\n        self.model.load_state_dict(best_param)\r\n        torch.save(best_param, save_path)\r\n\r\n        if self.use_gpu:\r\n            torch.cuda.empty_cache()\r\n\r\n    def predict(self, dataset):\r\n        if not self.fitted:\r\n            raise ValueError(\"model is not fitted yet!\")\r\n\r\n        dl_test = dataset.prepare(\"test\", col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_I)\r\n        dl_test.config(fillna_type=\"ffill+bfill\")\r\n        test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)\r\n        self.model.eval()\r\n        preds = []\r\n\r\n        for data in test_loader:\r\n            feature = data[:, :, 0:-1].to(self.device)\r\n\r\n            with torch.no_grad():\r\n                pred = self.model(feature.float()).detach().cpu().numpy()\r\n\r\n            preds.append(pred)\r\n\r\n        return pd.Series(np.concatenate(preds), index=dl_test.get_index())\r\n\r\n\r\nclass PositionalEncoding(nn.Module):\r\n    def 
__init__(self, d_model, max_len=1000):\r\n        super(PositionalEncoding, self).__init__()\r\n        pe = torch.zeros(max_len, d_model)\r\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\r\n        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\r\n        pe[:, 0::2] = torch.sin(position * div_term)\r\n        pe[:, 1::2] = torch.cos(position * div_term)\r\n        pe = pe.unsqueeze(0).transpose(0, 1)\r\n        self.register_buffer(\"pe\", pe)\r\n\r\n    def forward(self, x):\r\n        # [T, N, F]\r\n        return x + self.pe[: x.size(0), :]\r\n\r\n\r\nclass Transformer(nn.Module):\r\n    def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):\r\n        super(Transformer, self).__init__()\r\n        self.feature_layer = nn.Linear(d_feat, d_model)\r\n        self.pos_encoder = PositionalEncoding(d_model)\r\n        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)\r\n        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)\r\n        self.decoder_layer = nn.Linear(d_model, 1)\r\n        self.device = device\r\n        self.d_feat = d_feat\r\n\r\n    def forward(self, src):\r\n        # src [N, T, F], [512, 60, 6]\r\n        src = self.feature_layer(src)  # [512, 60, 8]\r\n\r\n        # src [N, T, F] --> [T, N, F], [60, 512, 8]\r\n        src = src.transpose(1, 0)  # not batch first\r\n\r\n        mask = None\r\n\r\n        src = self.pos_encoder(src)\r\n        output = self.transformer_encoder(src, mask)  # [60, 512, 8]\r\n\r\n        # [T, N, F] --> [N, T, F], then keep only the last time step: [N, F]\r\n        output = self.decoder_layer(output.transpose(1, 0)[:, -1, :])  # [512, 1]\r\n\r\n        return output.squeeze()\r\n"
  },
  {
    "path": "qlib/contrib/model/pytorch_utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.nn as nn\n\n\ndef count_parameters(models_or_parameters, unit=\"m\"):\n    \"\"\"\n    This function is to obtain the storage size unit of a (or multiple) models.\n\n    Parameters\n    ----------\n    models_or_parameters : PyTorch model(s) or a list of parameters.\n    unit : the storage size unit.\n\n    Returns\n    -------\n    The number of parameters of the given model(s) or parameters.\n    \"\"\"\n    if isinstance(models_or_parameters, nn.Module):\n        counts = sum(v.numel() for v in models_or_parameters.parameters())\n    elif isinstance(models_or_parameters, nn.Parameter):\n        counts = models_or_parameters.numel()\n    elif isinstance(models_or_parameters, (list, tuple)):\n        return sum(count_parameters(x, unit) for x in models_or_parameters)\n    else:\n        counts = sum(v.numel() for v in models_or_parameters)\n    unit = unit.lower()\n    if unit in (\"kb\", \"k\"):\n        counts /= 2**10\n    elif unit in (\"mb\", \"m\"):\n        counts /= 2**20\n    elif unit in (\"gb\", \"g\"):\n        counts /= 2**30\n    elif unit is not None:\n        raise ValueError(\"Unknown unit: {:}\".format(unit))\n    return counts\n"
  },
  {
    "path": "qlib/contrib/model/tcn.py",
    "content": "# MIT License\n# Copyright (c) 2018 CMU Locus Lab\nimport torch.nn as nn\nfrom torch.nn.utils import weight_norm\n\n\nclass Chomp1d(nn.Module):\n    def __init__(self, chomp_size):\n        super(Chomp1d, self).__init__()\n        self.chomp_size = chomp_size\n\n    def forward(self, x):\n        return x[:, :, : -self.chomp_size].contiguous()\n\n\nclass TemporalBlock(nn.Module):\n    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):\n        super(TemporalBlock, self).__init__()\n        self.conv1 = weight_norm(\n            nn.Conv1d(n_inputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)\n        )\n        self.chomp1 = Chomp1d(padding)\n        self.relu1 = nn.ReLU()\n        self.dropout1 = nn.Dropout(dropout)\n\n        self.conv2 = weight_norm(\n            nn.Conv1d(n_outputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)\n        )\n        self.chomp2 = Chomp1d(padding)\n        self.relu2 = nn.ReLU()\n        self.dropout2 = nn.Dropout(dropout)\n\n        self.net = nn.Sequential(\n            self.conv1, self.chomp1, self.relu1, self.dropout1, self.conv2, self.chomp2, self.relu2, self.dropout2\n        )\n        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None\n        self.relu = nn.ReLU()\n        self.init_weights()\n\n    def init_weights(self):\n        self.conv1.weight.data.normal_(0, 0.01)\n        self.conv2.weight.data.normal_(0, 0.01)\n        if self.downsample is not None:\n            self.downsample.weight.data.normal_(0, 0.01)\n\n    def forward(self, x):\n        out = self.net(x)\n        res = x if self.downsample is None else self.downsample(x)\n        return self.relu(out + res)\n\n\nclass TemporalConvNet(nn.Module):\n    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):\n        super(TemporalConvNet, self).__init__()\n        layers = []\n        
num_levels = len(num_channels)\n        for i in range(num_levels):\n            dilation_size = 2**i\n            in_channels = num_inputs if i == 0 else num_channels[i - 1]\n            out_channels = num_channels[i]\n            layers += [\n                TemporalBlock(\n                    in_channels,\n                    out_channels,\n                    kernel_size,\n                    stride=1,\n                    dilation=dilation_size,\n                    padding=(kernel_size - 1) * dilation_size,\n                    dropout=dropout,\n                )\n            ]\n\n        self.network = nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.network(x)\n"
  },
  {
    "path": "qlib/contrib/model/xgboost.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport pandas as pd\nimport xgboost as xgb\nfrom typing import Text, Union\nfrom ...model.base import Model\nfrom ...data.dataset import DatasetH\nfrom ...data.dataset.handler import DataHandlerLP\nfrom ...model.interpret.base import FeatureInt\nfrom ...data.dataset.weight import Reweighter\n\n\nclass XGBModel(Model, FeatureInt):\n    \"\"\"XGBModel Model\"\"\"\n\n    def __init__(self, **kwargs):\n        self._params = {}\n        self._params.update(kwargs)\n        self.model = None\n\n    def fit(\n        self,\n        dataset: DatasetH,\n        num_boost_round=1000,\n        early_stopping_rounds=50,\n        verbose_eval=20,\n        evals_result=dict(),\n        reweighter=None,\n        **kwargs,\n    ):\n        df_train, df_valid = dataset.prepare(\n            [\"train\", \"valid\"],\n            col_set=[\"feature\", \"label\"],\n            data_key=DataHandlerLP.DK_L,\n        )\n        x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n        x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n\n        # Lightgbm need 1D array as its label\n        if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:\n            y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)\n        else:\n            raise ValueError(\"XGBoost doesn't support multi-label training\")\n\n        if reweighter is None:\n            w_train = None\n            w_valid = None\n        elif isinstance(reweighter, Reweighter):\n            w_train = reweighter.reweight(df_train)\n            w_valid = reweighter.reweight(df_valid)\n        else:\n            raise ValueError(\"Unsupported reweighter type.\")\n\n        dtrain = xgb.DMatrix(x_train.values, label=y_train_1d, weight=w_train)\n        dvalid = xgb.DMatrix(x_valid.values, label=y_valid_1d, weight=w_valid)\n        self.model = xgb.train(\n      
      self._params,\n            dtrain=dtrain,\n            num_boost_round=num_boost_round,\n            evals=[(dtrain, \"train\"), (dvalid, \"valid\")],\n            early_stopping_rounds=early_stopping_rounds,\n            verbose_eval=verbose_eval,\n            evals_result=evals_result,\n            **kwargs,\n        )\n        evals_result[\"train\"] = list(evals_result[\"train\"].values())[0]\n        evals_result[\"valid\"] = list(evals_result[\"valid\"].values())[0]\n\n    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = \"test\"):\n        if self.model is None:\n            raise ValueError(\"model is not fitted yet!\")\n        x_test = dataset.prepare(segment, col_set=\"feature\", data_key=DataHandlerLP.DK_I)\n        return pd.Series(self.model.predict(xgb.DMatrix(x_test)), index=x_test.index)\n\n    def get_feature_importance(self, *args, **kwargs) -> pd.Series:\n        \"\"\"get feature importance\n\n        Notes\n        -------\n            parameters reference:\n                https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.Booster.get_score\n        \"\"\"\n        return pd.Series(self.model.get_score(*args, **kwargs)).sort_values(ascending=False)\n"
  },
  {
    "path": "qlib/contrib/online/__init__.py",
    "content": "# pylint: skip-file\n# flake8: noqa\n\n'''\nTODO:\n\n- The online module requires the model to implement the following method:\n    def get_data_with_date(self, date, **kwargs):\n        \"\"\"\n        Called by the online module.\n        Returns the data used to predict the label (score) of stocks at date.\n\n        :param\n            date: pd.Timestamp\n                predict date\n        :return:\n            data: the input data used to predict the label (score) of stocks at predict date.\n        \"\"\"\n        raise NotImplementedError(\"get_data_with_date for this model is not implemented.\")\n\n'''\n"
  },
  {
    "path": "qlib/contrib/online/manager.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# pylint: skip-file\n# flake8: noqa\n\nimport pathlib\nimport pandas as pd\nimport shutil\nfrom ruamel.yaml import YAML\nfrom ...backtest.account import Account\nfrom .user import User\nfrom .utils import load_instance, save_instance\nfrom ...utils import init_instance_by_config\n\n\nclass UserManager:\n    def __init__(self, user_data_path, save_report=True):\n        \"\"\"\n        This module is designed to manage the users in the online system.\n        All users' data are assumed to be saved in user_data_path.\n            Parameter\n                user_data_path : string\n                    data path where all users' data are saved\n\n        variables:\n            data_path : string\n                data path where all users' data are saved\n            users_file : string\n                path of the file recording the add_date of each user\n            save_report : bool\n                whether to save the report after each trading process\n            users : dict{}\n                [user_id]->User()\n                a dict holding the User() instance of each user_id\n            user_record : pd.Dataframe\n                user_id(string), add_date(string)\n                the add_date of each user\n        \"\"\"\n        self.data_path = pathlib.Path(user_data_path)\n        self.users_file = self.data_path / \"users.csv\"\n        self.save_report = save_report\n        self.users = {}\n        self.user_record = None\n\n    def load_users(self):\n        \"\"\"\n        load all users' data into manager\n        \"\"\"\n        self.users = {}\n        self.user_record = pd.read_csv(self.users_file, index_col=0)\n        for user_id in self.user_record.index:\n            self.users[user_id] = self.load_user(user_id)\n\n    def load_user(self, user_id):\n        \"\"\"\n        return an instance of User() representing a user to be processed\n     
       Parameter\n                user_id : string\n            :return\n                user : User()\n        \"\"\"\n        account_path = self.data_path / user_id\n        strategy_file = self.data_path / user_id / \"strategy_{}.pickle\".format(user_id)\n        model_file = self.data_path / user_id / \"model_{}.pickle\".format(user_id)\n        cur_user_list = list(self.users)\n        if user_id in cur_user_list:\n            raise ValueError(\"User {} has been loaded\".format(user_id))\n        else:\n            trade_account = Account(0)\n            trade_account.load_account(account_path)\n            strategy = load_instance(strategy_file)\n            model = load_instance(model_file)\n            user = User(account=trade_account, strategy=strategy, model=model)\n            return user\n\n    def save_user_data(self, user_id):\n        \"\"\"\n        save an instance of User() to user data path\n            Parameter\n                user_id : string\n        \"\"\"\n        if user_id not in self.users:\n            raise ValueError(\"Cannot find user {}\".format(user_id))\n        self.users[user_id].account.save_account(self.data_path / user_id)\n        save_instance(\n            self.users[user_id].strategy,\n            self.data_path / user_id / \"strategy_{}.pickle\".format(user_id),\n        )\n        save_instance(\n            self.users[user_id].model,\n            self.data_path / user_id / \"model_{}.pickle\".format(user_id),\n        )\n\n    def add_user(self, user_id, config_file, add_date):\n        \"\"\"\n        add the new user {user_id} into user data\n        will create a new folder named \"{user_id}\" in user data path\n            Parameter\n                user_id : string\n                config_file : str/pathlib.Path()\n                   path of config file\n                add_date : pd.Timestamp\n                   date on which the user is added\n        \"\"\"\n        config_file = pathlib.Path(config_file)\n        if not config_file.exists():\n            raise 
ValueError(\"Cannot find config file {}\".format(config_file))\n        user_path = self.data_path / user_id\n        if user_path.exists():\n            raise ValueError(\"User data for {} already exists\".format(user_id))\n\n        with config_file.open(\"r\") as fp:\n            yaml = YAML(typ=\"safe\", pure=True)\n            config = yaml.load(fp)\n        # load model\n        model = init_instance_by_config(config[\"model\"])\n\n        # load strategy\n        strategy = init_instance_by_config(config[\"strategy\"])\n        init_args = strategy.get_init_args_from_model(model, add_date)\n        strategy.init(**init_args)\n\n        # init Account\n        trade_account = Account(init_cash=config[\"init_cash\"])\n\n        # save user\n        user_path.mkdir()\n        save_instance(model, self.data_path / user_id / \"model_{}.pickle\".format(user_id))\n        save_instance(strategy, self.data_path / user_id / \"strategy_{}.pickle\".format(user_id))\n        trade_account.save_account(self.data_path / user_id)\n        user_record = pd.read_csv(self.users_file, index_col=0)\n        user_record.loc[user_id] = [add_date]\n        user_record.to_csv(self.users_file)\n\n    def remove_user(self, user_id):\n        \"\"\"\n        remove user {user_id} in current user dataset\n        will delete the folder \"{user_id}\" in user data path\n            :param\n                user_id : string\n        \"\"\"\n        user_path = self.data_path / user_id\n        if not user_path.exists():\n            raise ValueError(\"Cannot find user data {}\".format(user_id))\n        shutil.rmtree(user_path)\n        user_record = pd.read_csv(self.users_file, index_col=0)\n        user_record.drop([user_id], inplace=True)\n        user_record.to_csv(self.users_file)\n"
  },
  {
    "path": "qlib/contrib/online/online_model.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# pylint: skip-file\n# flake8: noqa\n\nimport random\nimport pandas as pd\nfrom ...data import D\nfrom ..model.base import Model\n\n\nclass ScoreFileModel(Model):\n    \"\"\"\n    This model will load a score file, and return score at date exists in score file.\n    \"\"\"\n\n    def __init__(self, score_path):\n        pred_test = pd.read_csv(score_path, index_col=[0, 1], parse_dates=True, infer_datetime_format=True)\n        self.pred = pred_test\n\n    def get_data_with_date(self, date, **kwargs):\n        score = self.pred.loc(axis=0)[:, date]  # (stock_id, trade_date) multi_index, score in pdate\n        score_series = score.reset_index(level=\"datetime\", drop=True)[\n            \"score\"\n        ]  # pd.Series ; index:stock_id, data: score\n        return score_series\n\n    def predict(self, x_test, **kwargs):\n        return x_test\n\n    def score(self, x_test, **kwargs):\n        return\n\n    def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):\n        return\n\n    def save(self, fname, **kwargs):\n        return\n"
  },
  {
    "path": "qlib/contrib/online/operator.py",
    "content": "# Copyright (c) Microsoft Corporation.\r\n# Licensed under the MIT License.\r\n\r\n# pylint: skip-file\r\n# flake8: noqa\r\n\r\nimport fire\r\nimport pandas as pd\r\nimport pathlib\r\nimport qlib\r\nimport logging\r\n\r\nfrom ...data import D\r\nfrom ...log import get_module_logger\r\nfrom ...utils import get_pre_trading_date, is_tradable_date\r\nfrom ..evaluate import risk_analysis\r\nfrom ..backtest.backtest import update_account\r\n\r\nfrom .manager import UserManager\r\nfrom .utils import prepare\r\nfrom .utils import create_user_folder\r\nfrom .executor import load_order_list, save_order_list\r\nfrom .executor import SimulatorExecutor\r\nfrom .executor import save_score_series, load_score_series\r\n\r\n\r\nclass Operator:\r\n    def __init__(self, client: str):\r\n        \"\"\"\r\n        Parameters\r\n        ----------\r\n            client: str\r\n                The qlib client config file(.yaml)\r\n        \"\"\"\r\n        self.logger = get_module_logger(\"online operator\", level=logging.INFO)\r\n        self.client = client\r\n\r\n    @staticmethod\r\n    def init(client, path, date=None):\r\n        \"\"\"Initial UserManager(), get predict date and trade date\r\n        Parameters\r\n        ----------\r\n            client: str\r\n                The qlib client config file(.yaml)\r\n            path : str\r\n                Path to save user account.\r\n            date : str (YYYY-MM-DD)\r\n                Trade date, when the generated order list will be traded.\r\n        Return\r\n        ----------\r\n            um: UserManager()\r\n            pred_date: pd.Timestamp\r\n            trade_date: pd.Timestamp\r\n        \"\"\"\r\n        qlib.init_from_yaml_conf(client)\r\n        um = UserManager(user_data_path=pathlib.Path(path))\r\n        um.load_users()\r\n        if not date:\r\n            trade_date, pred_date = None, None\r\n        else:\r\n            trade_date = pd.Timestamp(date)\r\n            if not 
is_tradable_date(trade_date):\r\n                raise ValueError(\"trade date {} is not a tradable date\".format(trade_date.date()))\r\n            pred_date = get_pre_trading_date(trade_date, future=True)\r\n        return um, pred_date, trade_date\r\n\r\n    def add_user(self, id, config, path, date):\r\n        \"\"\"Add a new user to the folder used by the 'online' module.\r\n\r\n        Parameters\r\n        ----------\r\n        id : str\r\n            User id, should be unique.\r\n        config : str\r\n            The file path (yaml) of user config\r\n        path : str\r\n            Path to save user account.\r\n        date : str (YYYY-MM-DD)\r\n            The date on which the user account is added.\r\n        \"\"\"\r\n        create_user_folder(path)\r\n        qlib.init_from_yaml_conf(self.client)\r\n        um = UserManager(user_data_path=path)\r\n        add_date = D.calendar(end_time=date)[-1]\r\n        if not is_tradable_date(add_date):\r\n            raise ValueError(\"add date {} is not a tradable date\".format(add_date.date()))\r\n        um.add_user(user_id=id, config_file=config, add_date=add_date)\r\n\r\n    def remove_user(self, id, path):\r\n        \"\"\"Remove user from folder used in 'online' module.\r\n\r\n        Parameters\r\n        ----------\r\n        id : str\r\n            User id, should be unique.\r\n        path : str\r\n            Path to save user account.\r\n        \"\"\"\r\n        um = UserManager(user_data_path=path)\r\n        um.remove_user(user_id=id)\r\n\r\n    def generate(self, date, path):\r\n        \"\"\"Generate order list that will be traded at 'date'.\r\n\r\n        Parameters\r\n        ----------\r\n        date : str (YYYY-MM-DD)\r\n            Trade date, when the generated order list will be traded.\r\n        path : str\r\n            Path to save user account.\r\n        \"\"\"\r\n        um, pred_date, trade_date = self.init(self.client, path, date)\r\n        for user_id, user in um.users.items():\r\n        
    dates, trade_exchange = prepare(um, pred_date, user_id)\r\n            # get and save the score at predict date\r\n            input_data = user.model.get_data_with_date(pred_date)\r\n            score_series = user.model.predict(input_data)\r\n            save_score_series(score_series, (pathlib.Path(path) / user_id), trade_date)\r\n\r\n            # update strategy (and model)\r\n            user.strategy.update(score_series, pred_date, trade_date)\r\n\r\n            # generate and save order list\r\n            order_list = user.strategy.generate_trade_decision(\r\n                score_series=score_series,\r\n                current=user.account.current_position,\r\n                trade_exchange=trade_exchange,\r\n                trade_date=trade_date,\r\n            )\r\n            save_order_list(\r\n                order_list=order_list,\r\n                user_path=(pathlib.Path(path) / user_id),\r\n                trade_date=trade_date,\r\n            )\r\n            self.logger.info(\"Generate order list at {} for {}\".format(trade_date, user_id))\r\n            um.save_user_data(user_id)\r\n\r\n    def execute(self, date, exchange_config, path):\r\n        \"\"\"Execute the orderlist at 'date'.\r\n\r\n        Parameters\r\n        ----------\r\n           date : str (YYYY-MM-DD)\r\n               Trade date, that the generated order list will be traded.\r\n           exchange_config: str\r\n               The file path (yaml) of exchange config\r\n           path : str\r\n               Path to save user account.\r\n        \"\"\"\r\n        um, pred_date, trade_date = self.init(self.client, path, date)\r\n        for user_id, user in um.users.items():\r\n            dates, trade_exchange = prepare(um, trade_date, user_id, exchange_config)\r\n            executor = SimulatorExecutor(trade_exchange=trade_exchange)\r\n            if str(dates[0].date()) != str(pred_date.date()):\r\n                raise ValueError(\r\n                    \"The 
account data is not newest! last trading date {}, today {}\".format(\r\n                        dates[0].date(), trade_date.date()\r\n                    )\r\n                )\r\n\r\n            # load and execute the order list\r\n            # will not modify the trade_account after executing\r\n            order_list = load_order_list(user_path=(pathlib.Path(path) / user_id), trade_date=trade_date)\r\n            trade_info = executor.execute(order_list=order_list, trade_account=user.account, trade_date=trade_date)\r\n            executor.save_executed_file_from_trade_info(\r\n                trade_info=trade_info,\r\n                user_path=(pathlib.Path(path) / user_id),\r\n                trade_date=trade_date,\r\n            )\r\n            self.logger.info(\"execute order list at {} for {}\".format(trade_date.date(), user_id))\r\n\r\n    def update(self, date, path, type=\"SIM\"):\r\n        \"\"\"Update account at 'date'.\r\n\r\n        Parameters\r\n        ----------\r\n        date : str (YYYY-MM-DD)\r\n            Trade date, on which the generated order list is traded.\r\n        path : str\r\n            Path to save user account.\r\n        type : str\r\n            which executor was used to execute the order list\r\n            'SIM': SimulatorExecutor()\r\n        \"\"\"\r\n        if type not in [\"SIM\", \"YC\"]:\r\n            raise ValueError(\"type is invalid, {}\".format(type))\r\n        um, pred_date, trade_date = self.init(self.client, path, date)\r\n        for user_id, user in um.users.items():\r\n            dates, trade_exchange = prepare(um, trade_date, user_id)\r\n            if type == \"SIM\":\r\n                executor = SimulatorExecutor(trade_exchange=trade_exchange)\r\n            else:\r\n                raise ValueError(\"not found executor\")\r\n            # dates[0] is the last_trading_date\r\n            if str(dates[0].date()) > str(pred_date.date()):\r\n                raise ValueError(\r\n                  
  \"The account data is not newest! last trading date {}, today {}\".format(\r\n                        dates[0].date(), trade_date.date()\r\n                    )\r\n                )\r\n            # load trade info and update account\r\n            trade_info = executor.load_trade_info_from_executed_file(\r\n                user_path=(pathlib.Path(path) / user_id), trade_date=trade_date\r\n            )\r\n            score_series = load_score_series((pathlib.Path(path) / user_id), trade_date)\r\n            update_account(user.account, trade_info, trade_exchange, trade_date)\r\n\r\n            portfolio_metrics = user.account.portfolio_metrics.generate_portfolio_metrics_dataframe()\r\n            self.logger.info(portfolio_metrics)\r\n            um.save_user_data(user_id)\r\n            self.logger.info(\"Update account state {} for {}\".format(trade_date, user_id))\r\n\r\n    def simulate(self, id, config, exchange_config, start, end, path, bench=\"SH000905\"):\r\n        \"\"\"Run the ( generate_trade_decision -> execute_order_list -> update_account) process everyday\r\n            from start date to end date.\r\n\r\n        Parameters\r\n        ----------\r\n        id : str\r\n            user id, need to be unique\r\n        config : str\r\n            The file path (yaml) of user config\r\n        exchange_config: str\r\n            The file path (yaml) of exchange config\r\n        start : str \"YYYY-MM-DD\"\r\n            The start date to run the online simulate\r\n        end : str \"YYYY-MM-DD\"\r\n            The end date to run the online simulate\r\n        path : str\r\n            Path to save user account.\r\n        bench : str\r\n            The benchmark that our result compared with.\r\n            'SH000905' for csi500, 'SH000300' for csi300\r\n        \"\"\"\r\n        # Clear the current user if exists, then add a new user.\r\n        create_user_folder(path)\r\n        um = self.init(self.client, path, None)[0]\r\n        start_date, 
end_date = pd.Timestamp(start), pd.Timestamp(end)\r\n        try:\r\n            um.remove_user(user_id=id)\r\n        except BaseException:\r\n            pass\r\n        um.add_user(user_id=id, config_file=config, add_date=pd.Timestamp(start_date))\r\n\r\n        # Do the online simulate\r\n        um.load_users()\r\n        user = um.users[id]\r\n        dates, trade_exchange = prepare(um, end_date, id, exchange_config)\r\n        executor = SimulatorExecutor(trade_exchange=trade_exchange)\r\n        for pred_date, trade_date in zip(dates[:-2], dates[1:-1]):\r\n            user_path = pathlib.Path(path) / id\r\n\r\n            # 1. load and save score_series\r\n            input_data = user.model.get_data_with_date(pred_date)\r\n            score_series = user.model.predict(input_data)\r\n            save_score_series(score_series, (pathlib.Path(path) / id), trade_date)\r\n\r\n            # 2. update strategy (and model)\r\n            user.strategy.update(score_series, pred_date, trade_date)\r\n\r\n            # 3. generate and save order list\r\n            order_list = user.strategy.generate_trade_decision(\r\n                score_series=score_series,\r\n                current=user.account.current_position,\r\n                trade_exchange=trade_exchange,\r\n                trade_date=trade_date,\r\n            )\r\n            save_order_list(order_list=order_list, user_path=user_path, trade_date=trade_date)\r\n\r\n            # 4. auto execute order list\r\n            order_list = load_order_list(user_path=user_path, trade_date=trade_date)\r\n            trade_info = executor.execute(trade_account=user.account, order_list=order_list, trade_date=trade_date)\r\n            executor.save_executed_file_from_trade_info(\r\n                trade_info=trade_info, user_path=user_path, trade_date=trade_date\r\n            )\r\n            # 5. 
update account state\r\n            trade_info = executor.load_trade_info_from_executed_file(user_path=user_path, trade_date=trade_date)\r\n            update_account(user.account, trade_info, trade_exchange, trade_date)\r\n        portfolio_metrics = user.account.portfolio_metrics.generate_portfolio_metrics_dataframe()\r\n        self.logger.info(portfolio_metrics)\r\n        um.save_user_data(id)\r\n        self.show(id, path, bench)\r\n\r\n    def show(self, id, path, bench=\"SH000905\"):\r\n        \"\"\"show the latest report (mean, std, information_ratio, annualized_return)\r\n\r\n        Parameters\r\n        ----------\r\n        id : str\r\n            User id, should be unique.\r\n        path : str\r\n            Path to save user account.\r\n        bench : str\r\n            The benchmark that the result is compared with.\r\n            'SH000905' for csi500, 'SH000300' for csi300\r\n        \"\"\"\r\n        um = self.init(self.client, path, None)[0]\r\n        if id not in um.users:\r\n            raise ValueError(\"Cannot find user {}\".format(id))\r\n        bench = D.features([bench], [\"$change\"]).loc[bench, \"$change\"]\r\n        portfolio_metrics = um.users[id].account.portfolio_metrics.generate_portfolio_metrics_dataframe()\r\n        portfolio_metrics[\"bench\"] = bench\r\n        analysis_result = {}\r\n        r = (portfolio_metrics[\"return\"] - portfolio_metrics[\"bench\"]).dropna()\r\n        analysis_result[\"excess_return_without_cost\"] = risk_analysis(r)\r\n        r = (portfolio_metrics[\"return\"] - portfolio_metrics[\"bench\"] - portfolio_metrics[\"cost\"]).dropna()\r\n        analysis_result[\"excess_return_with_cost\"] = risk_analysis(r)\r\n        print(\"Result:\")\r\n        print(\"excess_return_without_cost:\")\r\n        print(analysis_result[\"excess_return_without_cost\"])\r\n        print(\"excess_return_with_cost:\")\r\n        print(analysis_result[\"excess_return_with_cost\"])\r\n\r\n\r\ndef run():\r\n    
fire.Fire(Operator)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    run()\r\n"
  },
  {
    "path": "qlib/contrib/online/user.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# pylint: skip-file\n# flake8: noqa\n\nimport logging\n\nfrom ...log import get_module_logger\nfrom ..evaluate import risk_analysis\nfrom ...data import D\n\n\nclass User:\n    def __init__(self, account, strategy, model, verbose=False):\n        \"\"\"\n        A user in the online system, consisting of three modules: account, strategy and model.\n            Parameter\n                account : Account()\n                strategy :\n                    a strategy instance\n                model :\n                    a model instance\n                verbose : bool\n                    Whether to print the info during the process\n        \"\"\"\n        self.logger = get_module_logger(\"User\", level=logging.INFO)\n        self.account = account\n        self.strategy = strategy\n        self.model = model\n        self.verbose = verbose\n\n    def init_state(self, date):\n        \"\"\"\n        initialize the state at the beginning of each trading date\n            Parameter\n                date : pd.Timestamp\n        \"\"\"\n        self.account.init_state(today=date)\n        self.strategy.init_state(trade_date=date, model=self.model, account=self.account)\n        return\n\n    def get_latest_trading_date(self):\n        \"\"\"\n        return the latest trading date of this user, or None if the user has not traded yet\n            :return\n                date : string (e.g. '2018-10-08')\n        \"\"\"\n        if not self.account.last_trade_date:\n            return None\n        return str(self.account.last_trade_date.date())\n\n    def showReport(self, benchmark=\"SH000905\"):\n        \"\"\"\n        show the latest report (mean, std, information_ratio, annualized_return)\n            Parameter\n                benchmark : string\n              
      benchmark to compare against, 'SH000905' for csi500\n        \"\"\"\n        bench = D.features([benchmark], [\"$change\"], disk_cache=True).loc[benchmark, \"$change\"]\n        portfolio_metrics = self.account.portfolio_metrics.generate_portfolio_metrics_dataframe()\n        portfolio_metrics[\"bench\"] = bench\n        analysis_result = {\"pred\": {}, \"excess_return_without_cost\": {}, \"excess_return_with_cost\": {}}\n        r = (portfolio_metrics[\"return\"] - portfolio_metrics[\"bench\"]).dropna()\n        analysis_result[\"excess_return_without_cost\"][0] = risk_analysis(r)\n        r = (portfolio_metrics[\"return\"] - portfolio_metrics[\"bench\"] - portfolio_metrics[\"cost\"]).dropna()\n        analysis_result[\"excess_return_with_cost\"][0] = risk_analysis(r)\n        self.logger.info(\"Result of portfolio:\")\n        self.logger.info(\"excess_return_without_cost:\")\n        self.logger.info(analysis_result[\"excess_return_without_cost\"][0])\n        self.logger.info(\"excess_return_with_cost:\")\n        self.logger.info(analysis_result[\"excess_return_with_cost\"][0])\n        return portfolio_metrics\n"
  },
  {
    "path": "qlib/contrib/online/utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# pylint: skip-file\n# flake8: noqa\n\nimport pathlib\nimport pickle\nimport pandas as pd\nfrom ruamel.yaml import YAML\nfrom ...data import D\nfrom ...config import C\nfrom ...log import get_module_logger\nfrom ...utils import get_next_trading_date\nfrom ...utils.pickle_utils import restricted_pickle_load\nfrom ...backtest.exchange import Exchange\n\nlog = get_module_logger(\"utils\")\n\n\ndef load_instance(file_path):\n    \"\"\"\n    load a pickle file\n        Parameter\n           file_path : string / pathlib.Path()\n                path of file to be loaded\n        :return\n            An instance loaded from file\n    \"\"\"\n    file_path = pathlib.Path(file_path)\n    if not file_path.exists():\n        raise ValueError(\"Cannot find file {}\".format(file_path))\n    with file_path.open(\"rb\") as fr:\n        instance = restricted_pickle_load(fr)\n    return instance\n\n\ndef save_instance(instance, file_path):\n    \"\"\"\n    save(dump) an instance to a pickle file\n        Parameter\n            instance :\n                data to be dumped\n            file_path : string / pathlib.Path()\n                path of file to be dumped\n    \"\"\"\n    file_path = pathlib.Path(file_path)\n    with file_path.open(\"wb\") as fr:\n        pickle.dump(instance, fr, C.dump_protocol_version)\n\n\ndef create_user_folder(path):\n    path = pathlib.Path(path)\n    if path.exists():\n        return\n    path.mkdir(parents=True)\n    head = pd.DataFrame(columns=(\"user_id\", \"add_date\"))\n    head.to_csv(path / \"users.csv\", index=None)\n\n\ndef prepare(um, today, user_id, exchange_config=None):\n    \"\"\"\n    1. Get the dates on which user {user_id} needs to trade up to today.\n        dates[0] is the latest trading date of User{user_id};\n        if User{user_id} has not traded before, dates[0] is the init date of User{user_id}.\n    2. 
Set up the exchange with the exchange_config file\n\n        Parameter\n            um : UserManager()\n            today : pd.Timestamp()\n            user_id : str\n            exchange_config : str, optional\n                path of the exchange config file (yaml)\n        :return\n            dates : list of pd.Timestamp\n            trade_exchange : Exchange()\n    \"\"\"\n    # get latest trading date for {user_id}\n    # if None, the user has not traded yet, so use the user's add date as the last trading date\n    latest_trading_date = um.users[user_id].get_latest_trading_date()\n    if not latest_trading_date:\n        latest_trading_date = um.user_record.loc[user_id].iloc[0]\n\n    if str(today.date()) < latest_trading_date:\n        log.warning(\"user_id:{}, last trading date {} after today {}\".format(user_id, latest_trading_date, today))\n        return [pd.Timestamp(latest_trading_date)], None\n\n    dates = D.calendar(\n        start_time=pd.Timestamp(latest_trading_date),\n        end_time=pd.Timestamp(today),\n        future=True,\n    )\n    dates = list(dates)\n    dates.append(get_next_trading_date(dates[-1], future=True))\n    if exchange_config:\n        with pathlib.Path(exchange_config).open(\"r\") as fp:\n            yaml = YAML(typ=\"safe\", pure=True)\n            exchange_paras = yaml.load(fp)\n    else:\n        exchange_paras = {}\n    trade_exchange = Exchange(trade_dates=dates, **exchange_paras)\n    return dates, trade_exchange\n"
  },
  {
    "path": "qlib/contrib/ops/__init__.py",
    "content": ""
  },
  {
    "path": "qlib/contrib/ops/high_freq.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport numpy as np\nimport pandas as pd\nfrom datetime import datetime\n\nfrom qlib.data.cache import H\nfrom qlib.data.data import Cal\nfrom qlib.data.ops import ElemOperator, PairOperator\nfrom qlib.utils.time import time_to_day_index\n\n\ndef get_calendar_day(freq=\"1min\", future=False):\n    \"\"\"\n    Load High-Freq Calendar Date Using Memcache.\n    !!!NOTE: Loading the calendar is quite slow. Loading it before starting multiprocessing will make it faster.\n\n    Parameters\n    ----------\n    freq : str\n        frequency of the calendar file to read.\n    future : bool\n        whether to include future trading days.\n\n    Returns\n    -------\n    _calendar:\n        array of date.\n    \"\"\"\n    flag = f\"{freq}_future_{future}_day\"\n    if flag in H[\"c\"]:\n        _calendar = H[\"c\"][flag]\n    else:\n        _calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future))))\n        H[\"c\"][flag] = _calendar\n    return _calendar\n\n\ndef get_calendar_minute(freq=\"day\", future=False):\n    \"\"\"Load High-Freq Calendar Minute Using Memcache\"\"\"\n    # use a distinct cache key so this does not collide with get_calendar_day's cached dates\n    flag = f\"{freq}_future_{future}_minute\"\n    if flag in H[\"c\"]:\n        _calendar = H[\"c\"][flag]\n    else:\n        _calendar = np.array(list(map(lambda x: x.minute // 30, Cal.load_calendar(freq, future))))\n        H[\"c\"][flag] = _calendar\n    return _calendar\n\n\nclass DayCumsum(ElemOperator):\n    \"\"\"DayCumsum Operator during start time and end time.\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    start : str\n        the start time of backtest in one day.\n        !!!NOTE: \"9:30\" means the time period of (9:30, 9:31) is in transaction.\n    end : str\n        the end time of backtest in one day.\n        !!!NOTE: \"14:59\" means the time period of (14:59, 15:00) is in transaction,\n                but (15:00, 15:01) is not.\n     
   So start=\"9:30\" and end=\"14:59\" means trading all day.\n\n    Returns\n    ----------\n    feature:\n        a series of that each value equals the cumsum value during start time and end time.\n        Otherwise, the value is zero.\n    \"\"\"\n\n    def __init__(self, feature, start: str = \"9:30\", end: str = \"14:59\", data_granularity: int = 1):\n        self.feature = feature\n        self.start = datetime.strptime(start, \"%H:%M\")\n        self.end = datetime.strptime(end, \"%H:%M\")\n\n        self.morning_open = datetime.strptime(\"9:30\", \"%H:%M\")\n        self.morning_close = datetime.strptime(\"11:30\", \"%H:%M\")\n        self.noon_open = datetime.strptime(\"13:00\", \"%H:%M\")\n        self.noon_close = datetime.strptime(\"15:00\", \"%H:%M\")\n\n        self.data_granularity = data_granularity\n        self.start_id = time_to_day_index(self.start) // self.data_granularity\n        self.end_id = time_to_day_index(self.end) // self.data_granularity\n        assert 240 % self.data_granularity == 0\n\n    def period_cusum(self, df):\n        df = df.copy()\n        assert len(df) == 240 // self.data_granularity\n        df.iloc[0 : self.start_id] = 0\n        df = df.cumsum()\n        df.iloc[self.end_id + 1 : 240 // self.data_granularity] = 0\n        return df\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        _calendar = get_calendar_day(freq=freq)\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return series.groupby(_calendar[series.index], group_keys=False).transform(self.period_cusum)\n\n\nclass DayLast(ElemOperator):\n    \"\"\"DayLast Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    feature:\n        a series of that each value equals the last value of its day\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        _calendar = 
get_calendar_day(freq=freq)\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return series.groupby(_calendar[series.index], group_keys=False).transform(\"last\")\n\n\nclass FFillNan(ElemOperator):\n    \"\"\"FFillNan Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    feature:\n        the feature with NaN values forward-filled\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return series.ffill()\n\n\nclass BFillNan(ElemOperator):\n    \"\"\"BFillNan Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    feature:\n        the feature with NaN values backward-filled\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return series.bfill()\n\n\nclass Date(ElemOperator):\n    \"\"\"Date Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    feature:\n        a series in which each value is the date corresponding to feature.index\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        _calendar = get_calendar_day(freq=freq)\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return pd.Series(_calendar[series.index], index=series.index)\n\n\nclass Select(PairOperator):\n    \"\"\"Select Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance, select condition\n    feature_right : Expression\n        feature instance, select value\n\n    Returns\n    ----------\n    feature:\n        the values of feature_right where the condition (feature_left) is True\n\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series_condition = self.feature_left.load(instrument, start_index, end_index, freq)\n        series_feature = self.feature_right.load(instrument, start_index, end_index, freq)\n        return series_feature.loc[series_condition]\n\n\nclass IsNull(ElemOperator):\n    \"\"\"IsNull Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    feature:\n        A series indicating whether the feature is NaN\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return series.isnull()\n\n\nclass IsInf(ElemOperator):\n    \"\"\"IsInf Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    feature:\n        A series indicating whether the feature is inf\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return np.isinf(series)\n\n\nclass Cut(ElemOperator):\n    \"\"\"Cut Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    left : int\n        left > 0, delete the first `left` elements of feature (default is None, which means 0)\n    right : int\n        right < 0, delete the last `-right` elements of feature (default is None, which means 0)\n\n    Returns\n    ----------\n    feature:\n        A series with the first `left` and last `-right` elements deleted from the feature.\n        Note: the elements are deleted from the raw data, not the sliced data\n    \"\"\"\n\n    def __init__(self, feature, left=None, right=None):\n        self.left = left\n        self.right = right\n        if (self.left is not None and self.left <= 0) or (self.right is not None and self.right >= 0):\n            raise ValueError(\"Cut operator left should be > 0 and right should be < 0\")\n\n        super(Cut, self).__init__(feature)\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return series.iloc[self.left : self.right]\n\n    def get_extended_window_size(self):\n        ll = 0 if self.left is None else self.left\n        rr = 0 if self.right is None else abs(self.right)\n        lft_etd, rght_etd = self.feature.get_extended_window_size()\n        lft_etd = lft_etd + ll\n        rght_etd = rght_etd + rr\n        return lft_etd, rght_etd\n"
  },
  {
    "path": "qlib/contrib/report/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nGRAPH_NAME_LIST = [\n    \"analysis_position.report_graph\",\n    \"analysis_position.score_ic_graph\",\n    \"analysis_position.cumulative_return_graph\",\n    \"analysis_position.risk_analysis_graph\",\n    \"analysis_position.rank_label_graph\",\n    \"analysis_model.model_performance_graph\",\n]\n"
  },
  {
    "path": "qlib/contrib/report/analysis_model/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .analysis_model_performance import model_performance_graph\n\n__all__ = [\"model_performance_graph\"]\n"
  },
  {
    "path": "qlib/contrib/report/analysis_model/analysis_model_performance.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom functools import partial\n\nimport pandas as pd\n\nimport plotly.graph_objs as go\n\nimport statsmodels.api as sm\nimport matplotlib.pyplot as plt\n\nfrom scipy import stats\n\nfrom typing import Sequence\nfrom qlib.typehint import Literal\n\nfrom ..graph import ScatterGraph, SubplotsGraph, BarGraph, HeatmapGraph\nfrom ..utils import guess_plotly_rangebreaks\n\n\ndef _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int = 5, **kwargs) -> tuple:\n    \"\"\"\n\n    :param pred_label: DataFrame with a `datetime` index level and `score` and `label` columns\n    :param reverse: if True, multiply `score` by -1 (in place) before grouping\n    :param N: number of groups\n    :return: (cumulative return figure by group, distribution figure of long-short and long-average returns)\n    \"\"\"\n    if reverse:\n        pred_label[\"score\"] *= -1\n\n    pred_label = pred_label.sort_values(\"score\", ascending=False)\n\n    # Group1 ~ GroupN only consider rows whose score is not NaN\n    pred_label_drop = pred_label.dropna(subset=[\"score\"])\n\n    # Group\n    t_df = pd.DataFrame(\n        {\n            \"Group%d\"\n            % (i + 1): pred_label_drop.groupby(level=\"datetime\", group_keys=False)[\"label\"].apply(\n                lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean()  # pylint: disable=W0640\n            )\n            for i in range(N)\n        }\n    )\n    t_df.index = pd.to_datetime(t_df.index)\n\n    # Long-Short\n    t_df[\"long-short\"] = t_df[\"Group1\"] - t_df[\"Group%d\" % N]\n\n    # Long-Average\n    t_df[\"long-average\"] = t_df[\"Group1\"] - pred_label.groupby(level=\"datetime\", group_keys=False)[\"label\"].mean()\n\n    t_df = t_df.dropna(how=\"all\")  # drop days that do not contain labels\n    # Cumulative Return By Group\n    group_scatter_figure = ScatterGraph(\n        t_df.cumsum(),\n        layout=dict(\n            title=\"Cumulative Return\",\n            xaxis=dict(tickangle=45, rangebreaks=kwargs.get(\"rangebreaks\", guess_plotly_rangebreaks(t_df.index))),\n        ),\n    ).figure\n\n    t_df = t_df.loc[:, [\"long-short\", \"long-average\"]]\n    
_bin_size = float(((t_df.max() - t_df.min()) / 20).min())\n    group_hist_figure = SubplotsGraph(\n        t_df,\n        kind_map=dict(kind=\"DistplotGraph\", kwargs=dict(bin_size=_bin_size)),\n        subplots_kwargs=dict(\n            rows=1,\n            cols=2,\n            print_grid=False,\n            subplot_titles=[\"long-short\", \"long-average\"],\n        ),\n    ).figure\n\n    return group_scatter_figure, group_hist_figure\n\n\ndef _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:\n    \"\"\"\n\n    :param data: series whose quantiles are compared against `dist`\n    :param dist: theoretical distribution, `scipy.stats.norm` by default\n    :return: plotly figure with the Q-Q scatter and the 45-degree reference line\n    \"\"\"\n    # NOTE: plotly.tools.mpl_to_plotly is not actively maintained and raises errors with newer versions of matplotlib,\n    # ref: https://github.com/plotly/plotly.py/issues/2913#issuecomment-730071567\n    # so it is not used here, for better compatibility across matplotlib versions\n    _plt_fig = sm.qqplot(data.dropna(), dist=dist, fit=True, line=\"45\")\n    plt.close(_plt_fig)\n    qqplot_data = _plt_fig.gca().lines\n    fig = go.Figure()\n\n    fig.add_trace(\n        {\n            \"type\": \"scatter\",\n            \"x\": qqplot_data[0].get_xdata(),\n            \"y\": qqplot_data[0].get_ydata(),\n            \"mode\": \"markers\",\n            \"marker\": {\"color\": \"#19d3f3\"},\n        }\n    )\n\n    fig.add_trace(\n        {\n            \"type\": \"scatter\",\n            \"x\": qqplot_data[1].get_xdata(),\n            \"y\": qqplot_data[1].get_ydata(),\n            \"mode\": \"lines\",\n            \"line\": {\"color\": \"#636efa\"},\n        }\n    )\n    del qqplot_data\n    return fig\n\n\ndef _pred_ic(\n    pred_label: pd.DataFrame = None, methods: Sequence[Literal[\"IC\", \"Rank IC\"]] = (\"IC\", \"Rank IC\"), **kwargs\n) -> tuple:\n    \"\"\"\n\n    :param pred_label: pd.DataFrame\n    must contain one column of realized return named `label` and one column of predicted score named `score`.\n    :param methods: Sequence[Literal[\"IC\", \"Rank IC\"]]\n    IC series to plot.\n    IC is the cross-sectional Pearson correlation between label and score.\n    Rank IC is the cross-sectional Spearman correlation between label and score.\n    For the monthly IC heatmap, the IC histogram and the IC Q-Q plot, only the first method in `methods` is plotted.\n    :return: (IC bar figure, monthly IC heatmap figure, IC histogram and Q-Q figure)\n    \"\"\"\n    _methods_mapping = {\"IC\": \"pearson\", \"Rank IC\": \"spearman\"}\n\n    def _corr_series(x, method):\n        return x[\"label\"].corr(x[\"score\"], method=method)\n\n    ic_df = pd.concat(\n        [\n            pred_label.groupby(level=\"datetime\", group_keys=False)\n            .apply(partial(_corr_series, method=_methods_mapping[m]))\n            .rename(m)\n            for m in methods\n        ],\n        axis=1,\n    )\n    _ic = ic_df.iloc(axis=1)[0]\n\n    _index = _ic.index.get_level_values(0).astype(\"str\").str.replace(\"-\", \"\").str.slice(0, 6)\n    _monthly_ic = _ic.groupby(_index, group_keys=False).mean()\n    _monthly_ic.index = pd.MultiIndex.from_arrays(\n        [_monthly_ic.index.str.slice(0, 4), _monthly_ic.index.str.slice(4, 6)],\n        names=[\"year\", \"month\"],\n    )\n\n    # fill in the missing months so that every year has 12 entries in the heatmap\n    _month_list = pd.date_range(\n        start=pd.Timestamp(f\"{_index.min()[:4]}0101\"),\n        end=pd.Timestamp(f\"{_index.max()[:4]}1231\"),\n        freq=\"1M\",\n    )\n    _years = []\n    _month = []\n    for _date in _month_list:\n        _date = _date.strftime(\"%Y%m%d\")\n        _years.append(_date[:4])\n        _month.append(_date[4:6])\n\n    fill_index = pd.MultiIndex.from_arrays([_years, _month], names=[\"year\", \"month\"])\n\n    _monthly_ic = _monthly_ic.reindex(fill_index)\n\n    ic_bar_figure = ic_figure(ic_df, kwargs.get(\"show_nature_day\", False))\n\n    ic_heatmap_figure = HeatmapGraph(\n        _monthly_ic.unstack(),\n        layout=dict(title=\"Monthly IC\", xaxis=dict(dtick=1), yaxis=dict(tickformat=\"04d\", dtick=1)),\n        graph_kwargs=dict(xtype=\"array\", ytype=\"array\"),\n    ).figure\n\n    dist = stats.norm\n    _qqplot_fig 
= _plot_qq(_ic, dist)\n\n    if isinstance(dist, stats.norm.__class__):\n        dist_name = \"Normal\"\n    else:\n        dist_name = \"Unknown\"\n\n    _ic_df = _ic.to_frame(\"IC\")\n    _bin_size = ((_ic_df.max() - _ic_df.min()) / 20).min()\n    _sub_graph_data = [\n        (\n            \"IC\",\n            dict(\n                row=1,\n                col=1,\n                name=\"\",\n                kind=\"DistplotGraph\",\n                graph_kwargs=dict(bin_size=_bin_size),\n            ),\n        ),\n        (_qqplot_fig, dict(row=1, col=2)),\n    ]\n    ic_hist_figure = SubplotsGraph(\n        _ic_df.dropna(),\n        kind_map=dict(kind=\"HistogramGraph\", kwargs=dict()),\n        subplots_kwargs=dict(\n            rows=1,\n            cols=2,\n            print_grid=False,\n            subplot_titles=[\"IC\", \"IC %s Dist. Q-Q\" % dist_name],\n        ),\n        sub_graph_data=_sub_graph_data,\n        layout=dict(\n            yaxis2=dict(title=\"Observed Quantile\"),\n            xaxis2=dict(title=f\"{dist_name} Distribution Quantile\"),\n        ),\n    ).figure\n\n    return ic_bar_figure, ic_heatmap_figure, ic_hist_figure\n\n\ndef _pred_autocorr(pred_label: pd.DataFrame, lag=1, **kwargs) -> tuple:\n    pred = pred_label.copy()\n    pred[\"score_last\"] = pred.groupby(level=\"instrument\", group_keys=False)[\"score\"].shift(lag)\n    ac = pred.groupby(level=\"datetime\", group_keys=False).apply(\n        lambda x: x[\"score\"].rank(pct=True).corr(x[\"score_last\"].rank(pct=True))\n    )\n    _df = ac.to_frame(\"value\")\n    ac_figure = ScatterGraph(\n        _df,\n        layout=dict(\n            title=\"Auto Correlation\",\n            xaxis=dict(tickangle=45, rangebreaks=kwargs.get(\"rangebreaks\", guess_plotly_rangebreaks(_df.index))),\n        ),\n    ).figure\n    return (ac_figure,)\n\n\ndef _pred_turnover(pred_label: pd.DataFrame, N=5, lag=1, **kwargs) -> tuple:\n    pred = pred_label.copy()\n    pred[\"score_last\"] = 
pred.groupby(level=\"instrument\", group_keys=False)[\"score\"].shift(lag)\n    # turnover of the top (bottom) 1/N group: 1 - the share of the group ranked by the current score\n    # that also belongs to the group ranked by the lagged score\n    top = pred.groupby(level=\"datetime\", group_keys=False).apply(\n        lambda x: 1\n        - x.nlargest(len(x) // N, columns=\"score\").index.isin(x.nlargest(len(x) // N, columns=\"score_last\").index).sum()\n        / (len(x) // N)\n    )\n    bottom = pred.groupby(level=\"datetime\", group_keys=False).apply(\n        lambda x: 1\n        - x.nsmallest(len(x) // N, columns=\"score\")\n        .index.isin(x.nsmallest(len(x) // N, columns=\"score_last\").index)\n        .sum()\n        / (len(x) // N)\n    )\n    r_df = pd.DataFrame(\n        {\n            \"Top\": top,\n            \"Bottom\": bottom,\n        }\n    )\n    turnover_figure = ScatterGraph(\n        r_df,\n        layout=dict(\n            title=\"Top-Bottom Turnover\",\n            xaxis=dict(tickangle=45, rangebreaks=kwargs.get(\"rangebreaks\", guess_plotly_rangebreaks(r_df.index))),\n        ),\n    ).figure\n    return (turnover_figure,)\n\n\ndef ic_figure(ic_df: pd.DataFrame, show_nature_day=True, **kwargs) -> go.Figure:\n    r\"\"\"IC figure\n\n    :param ic_df: IC DataFrame\n    :param show_nature_day: whether to display non-trading (natural) days on the x-axis\n    :param \\*\\*kwargs: contains some parameters to control plot style in plotly. 
Currently supports\n       - `rangebreaks`: https://plotly.com/python/time-series/#Hiding-Weekends-and-Holidays\n    :return: plotly.graph_objs.Figure\n    \"\"\"\n    if show_nature_day:\n        date_index = pd.date_range(ic_df.index.min(), ic_df.index.max())\n        ic_df = ic_df.reindex(date_index)\n    ic_bar_figure = BarGraph(\n        ic_df,\n        layout=dict(\n            title=\"Information Coefficient (IC)\",\n            xaxis=dict(tickangle=45, rangebreaks=kwargs.get(\"rangebreaks\", guess_plotly_rangebreaks(ic_df.index))),\n        ),\n    ).figure\n    return ic_bar_figure\n\n\ndef model_performance_graph(\n    pred_label: pd.DataFrame,\n    lag: int = 1,\n    N: int = 5,\n    reverse=False,\n    rank=False,\n    graph_names: list = [\"group_return\", \"pred_ic\", \"pred_autocorr\"],\n    show_notebook: bool = True,\n    show_nature_day: bool = False,\n    **kwargs,\n) -> [list, tuple]:\n    r\"\"\"Model performance\n\n    :param pred_label: index is **pd.MultiIndex**, index names are **[instrument, datetime]**; column names are **[score, label]**.\n           The label is usually the same as the label used for model training (e.g. \"Ref($close, -2)/Ref($close, -1) - 1\").\n\n\n            .. code-block:: python\n\n                instrument  datetime        score       label\n                SH600004    2017-12-11  -0.013502       -0.013502\n                                2017-12-12  -0.072367       -0.072367\n                                2017-12-13  -0.068605       -0.068605\n                                2017-12-14  0.012440        0.012440\n                                2017-12-15  -0.102778       -0.102778\n\n\n    :param lag: `pred.groupby(level='instrument', group_keys=False)['score'].shift(lag)`. 
It is only used when computing the auto-correlation and turnover graphs.\n    :param N: group number, default 5.\n    :param reverse: if `True`, `pred['score'] *= -1`.\n    :param rank: currently not used by any of the graphs; `pred_ic` plots both IC and Rank IC.\n    :param graph_names: graph names; default ['group_return', 'pred_ic', 'pred_autocorr'] ('pred_turnover' is also available).\n    :param show_notebook: whether to display the graphs in a notebook, the default is `True`.\n    :param show_nature_day: whether to display non-trading (natural) days on the x-axis of the IC graph.\n    :param \\*\\*kwargs: contains some parameters to control plot style in plotly. Currently supports\n       - `rangebreaks`: https://plotly.com/python/time-series/#Hiding-Weekends-and-Holidays\n    :return: if show_notebook is True, display in notebook; else return a list of `plotly.graph_objs.Figure`.\n    \"\"\"\n    figure_list = []\n    for graph_name in graph_names:\n        fun_res = eval(f\"_{graph_name}\")(\n            pred_label=pred_label, lag=lag, N=N, reverse=reverse, rank=rank, show_nature_day=show_nature_day, **kwargs\n        )\n        figure_list += fun_res\n\n    if show_notebook:\n        BarGraph.show_graph_in_notebook(figure_list)\n    else:\n        return figure_list\n"
  },
  {
    "path": "qlib/contrib/report/analysis_position/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .cumulative_return import cumulative_return_graph\nfrom .score_ic import score_ic_graph\nfrom .report import report_graph\nfrom .rank_label import rank_label_graph\nfrom .risk_analysis import risk_analysis_graph\n\n__all__ = [\"cumulative_return_graph\", \"score_ic_graph\", \"report_graph\", \"rank_label_graph\", \"risk_analysis_graph\"]\n"
  },
  {
    "path": "qlib/contrib/report/analysis_position/cumulative_return.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport copy\nfrom typing import Iterable\n\nimport pandas as pd\nimport plotly.graph_objs as go\n\nfrom ..graph import BaseGraph, SubplotsGraph\n\nfrom ..analysis_position.parse_position import get_position_data\n\n\ndef _get_cum_return_data_with_position(\n    position: dict,\n    report_normal: pd.DataFrame,\n    label_data: pd.DataFrame,\n    start_date=None,\n    end_date=None,\n):\n    \"\"\"\n\n    :param position:\n    :param report_normal:\n    :param label_data:\n    :param start_date:\n    :param end_date:\n    :return:\n    \"\"\"\n    _cumulative_return_df = get_position_data(\n        position=position,\n        report_normal=report_normal,\n        label_data=label_data,\n        start_date=start_date,\n        end_date=end_date,\n    ).copy()\n\n    _cumulative_return_df[\"label\"] = _cumulative_return_df[\"label\"] - _cumulative_return_df[\"bench\"]\n    _cumulative_return_df = _cumulative_return_df.dropna()\n    df_gp = _cumulative_return_df.groupby(level=\"datetime\", group_keys=False)\n    result_list = []\n    for gp in df_gp:\n        date = gp[0]\n        day_df = gp[1]\n\n        _hold_df = day_df[day_df[\"status\"] == 0]\n        _buy_df = day_df[day_df[\"status\"] == 1]\n        _sell_df = day_df[day_df[\"status\"] == -1]\n\n        hold_value = (_hold_df[\"label\"] * _hold_df[\"weight\"]).sum()\n        hold_weight = _hold_df[\"weight\"].sum()\n        hold_mean = (hold_value / hold_weight) if hold_weight else 0\n\n        sell_value = (_sell_df[\"label\"] * _sell_df[\"weight\"]).sum()\n        sell_weight = _sell_df[\"weight\"].sum()\n        sell_mean = (sell_value / sell_weight) if sell_weight else 0\n\n        buy_value = (_buy_df[\"label\"] * _buy_df[\"weight\"]).sum()\n        buy_weight = _buy_df[\"weight\"].sum()\n        buy_mean = (buy_value / buy_weight) if buy_weight else 0\n\n        result_list.append(\n            dict(\n             
   hold_value=hold_value,\n                hold_mean=hold_mean,\n                hold_weight=hold_weight,\n                buy_value=buy_value,\n                buy_mean=buy_mean,\n                buy_weight=buy_weight,\n                sell_value=sell_value,\n                sell_mean=sell_mean,\n                sell_weight=sell_weight,\n                buy_minus_sell_value=buy_value - sell_value,\n                buy_minus_sell_mean=buy_mean - sell_mean,\n                buy_plus_sell_weight=buy_weight + sell_weight,\n                date=date,\n            )\n        )\n\n    r_df = pd.DataFrame(data=result_list)\n    r_df[\"cum_hold\"] = r_df[\"hold_mean\"].cumsum()\n    r_df[\"cum_buy\"] = r_df[\"buy_mean\"].cumsum()\n    r_df[\"cum_sell\"] = r_df[\"sell_mean\"].cumsum()\n    r_df[\"cum_buy_minus_sell\"] = r_df[\"buy_minus_sell_mean\"].cumsum()\n    return r_df\n\n\ndef _get_figure_with_position(\n    position: dict,\n    report_normal: pd.DataFrame,\n    label_data: pd.DataFrame,\n    start_date=None,\n    end_date=None,\n) -> Iterable[go.Figure]:\n    \"\"\"Get average analysis figures\n\n    :param position: position\n    :param report_normal:\n    :param label_data:\n    :param start_date:\n    :param end_date:\n    :return:\n    \"\"\"\n\n    cum_return_df = _get_cum_return_data_with_position(position, report_normal, label_data, start_date, end_date)\n    cum_return_df = cum_return_df.set_index(\"date\")\n    # FIXME: support HIGH-FREQ\n    cum_return_df.index = cum_return_df.index.strftime(\"%Y-%m-%d\")\n\n    # Create figures\n    for _t_name in [\"buy\", \"sell\", \"buy_minus_sell\", \"hold\"]:\n        sub_graph_data = [\n            (\n                \"cum_{}\".format(_t_name),\n                dict(row=1, col=1, graph_kwargs={\"mode\": \"lines+markers\", \"xaxis\": \"x3\"}),\n            ),\n            (\n                \"{}_weight\".format(_t_name.replace(\"minus\", \"plus\") if \"minus\" in _t_name else _t_name),\n                dict(row=2, 
col=1),\n            ),\n            (\n                \"{}_value\".format(_t_name),\n                dict(row=1, col=2, kind=\"HistogramGraph\", graph_kwargs={}),\n            ),\n        ]\n\n        _default_xaxis = dict(showline=False, zeroline=True, tickangle=45)\n        _default_yaxis = dict(zeroline=True, showline=True, showticklabels=True)\n        sub_graph_layout = dict(\n            xaxis1=dict(**_default_xaxis, type=\"category\", showticklabels=False),\n            xaxis3=dict(**_default_xaxis, type=\"category\"),\n            xaxis2=_default_xaxis,\n            yaxis1=dict(**_default_yaxis, title=_t_name),\n            yaxis2=_default_yaxis,\n            yaxis3=_default_yaxis,\n        )\n\n        mean_value = cum_return_df[\"{}_value\".format(_t_name)].mean()\n        layout = dict(\n            height=500,\n            title=f\"{_t_name}(the red line in the histogram on the right represents the average)\",\n            shapes=[\n                {\n                    \"type\": \"line\",\n                    \"xref\": \"x2\",\n                    \"yref\": \"paper\",\n                    \"x0\": mean_value,\n                    \"y0\": 0,\n                    \"x1\": mean_value,\n                    \"y1\": 1,\n                    # NOTE: 'fillcolor': '#d3d3d3', 'opacity': 0.3,\n                    \"line\": {\"color\": \"red\", \"width\": 1},\n                },\n            ],\n        )\n\n        kind_map = dict(kind=\"ScatterGraph\", kwargs=dict(mode=\"lines+markers\"))\n        specs = [\n            [{\"rowspan\": 1}, {\"rowspan\": 2}],\n            [{\"rowspan\": 1}, None],\n        ]\n        subplots_kwargs = dict(\n            vertical_spacing=0.01,\n            rows=2,\n            cols=2,\n            row_width=[1, 2],\n            column_width=[3, 1],\n            print_grid=False,\n            specs=specs,\n        )\n        yield SubplotsGraph(\n            cum_return_df,\n            layout=layout,\n            
kind_map=kind_map,\n            sub_graph_layout=sub_graph_layout,\n            sub_graph_data=sub_graph_data,\n            subplots_kwargs=subplots_kwargs,\n        ).figure\n\n\ndef cumulative_return_graph(\n    position: dict,\n    report_normal: pd.DataFrame,\n    label_data: pd.DataFrame,\n    show_notebook=True,\n    start_date=None,\n    end_date=None,\n) -> Iterable[go.Figure]:\n    \"\"\"Backtest buy, sell, and holding cumulative return graph\n\n        Example:\n\n\n            .. code-block:: python\n\n                from qlib.data import D\n                from qlib.contrib.evaluate import risk_analysis, backtest, long_short_backtest\n                from qlib.contrib.strategy import TopkDropoutStrategy\n\n                # backtest parameters\n                bparas = {}\n                bparas['limit_threshold'] = 0.095\n                bparas['account'] = 1000000000\n\n                sparas = {}\n                sparas['topk'] = 50\n                sparas['n_drop'] = 5\n                strategy = TopkDropoutStrategy(**sparas)\n\n                report_normal_df, positions = backtest(pred_df, strategy, **bparas)\n\n                pred_df_dates = pred_df.index.get_level_values(level='datetime')\n                features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())\n                features_df.columns = ['label']\n\n                qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df)\n\n\n        Graph desc:\n\n            - Axis X: Trading day.\n            - Axis Y:\n            - Above axis Y: `(((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum()`.\n            - Below axis Y: Daily weight sum.\n            - In the **sell** graph, `y < 0` stands for profit; in other cases, `y > 0` stands for profit.\n            - In the **buy_minus_sell** graph, the **y** value of the **weight** graph at the bottom is `buy_weight + 
sell_weight`.\n            - In each graph, the **red line** in the histogram on the right represents the average.\n\n    :param position: position data\n    :param report_normal: backtest report DataFrame; must contain a `bench` column\n\n\n            .. code-block:: python\n\n                                return      cost        bench       turnover\n                date\n                2017-01-04  0.003421    0.000864    0.011693    0.576325\n                2017-01-05  0.000508    0.000447    0.000721    0.227882\n                2017-01-06  -0.003321   0.000212    -0.004322   0.102765\n                2017-01-09  0.006753    0.000212    0.006874    0.105864\n                2017-01-10  -0.000416   0.000440    -0.003350   0.208396\n\n\n    :param label_data: `D.features` result; index is `pd.MultiIndex` with level names [`instrument`, `datetime`]; it must have a single column, which is renamed to `label`.\n\n        **The label at T is the change from T to T+1**; it is recommended to compute it from ``close``, e.g. `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`\n\n\n            .. code-block:: python\n\n                                                label\n                instrument  datetime\n                SH600004        2017-12-11  -0.013502\n                                2017-12-12  -0.072367\n                                2017-12-13  -0.068605\n                                2017-12-14  0.012440\n                                2017-12-15  -0.102778\n\n\n    :param show_notebook: if True, show the graphs in the notebook; otherwise return the figures\n    :param start_date: start date\n    :param end_date: end date\n    :return: if show_notebook is False, an iterable (generator) of `plotly.graph_objs.Figure`\n    \"\"\"\n    position = copy.deepcopy(position)\n    report_normal = report_normal.copy()\n    label_data.columns = [\"label\"]\n    _figures = _get_figure_with_position(position, report_normal, label_data, start_date, end_date)\n    if show_notebook:\n        BaseGraph.show_graph_in_notebook(_figures)\n    else:\n        return _figures\n"
  },
  {
    "path": "qlib/contrib/report/analysis_position/parse_position.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport pandas as pd\n\n\nfrom ....backtest.profit_attribution import get_stock_weight_df\n\n\ndef parse_position(position: dict = None) -> pd.DataFrame:\n    \"\"\"Parse position dict to position DataFrame\n\n    :param position: position data\n    :return: position DataFrame;\n\n\n        .. code-block:: python\n\n            position_df = parse_position(positions)\n            print(position_df.head())\n            # status: 0-hold, -1-sell, 1-buy\n\n                                        amount      cash      count    price status weight\n            instrument  datetime\n            SZ000547    2017-01-04  44.154290   211405.285654   1   205.189575  1   0.031255\n            SZ300202    2017-01-04  60.638845   211405.285654   1   154.356506  1   0.032290\n            SH600158    2017-01-04  46.531681   211405.285654   1   153.895142  1   0.024704\n            SH600545    2017-01-04  197.173093  211405.285654   1   48.607037   1   0.033063\n            SZ000930    2017-01-04  103.938300  211405.285654   1   80.759453   1   0.028958\n\n\n    \"\"\"\n\n    position_weight_df = get_stock_weight_df(position)\n    # If the day does not exist, use the last weight\n    position_weight_df.ffill(inplace=True)\n\n    previous_data = {\"date\": None, \"code_list\": []}\n\n    result_df = pd.DataFrame()\n    for _trading_date, _value in position.items():\n        _value = _value.position\n        # pd_date type: pd.Timestamp\n        _cash = _value.pop(\"cash\")\n        for _item in [\"now_account_value\"]:\n            if _item in _value:\n                _value.pop(_item)\n\n        _trading_day_df = pd.DataFrame.from_dict(_value, orient=\"index\")\n        _trading_day_df[\"weight\"] = position_weight_df.loc[_trading_date]\n        _trading_day_df[\"cash\"] = _cash\n        _trading_day_df[\"date\"] = _trading_date\n        # status: 0-hold, -1-sell, 1-buy\n        
_trading_day_df[\"status\"] = 0\n\n        # T not exist, T-1 exist, T sell\n        _cur_day_sell = set(previous_data[\"code_list\"]) - set(_trading_day_df.index)\n        # T exist, T-1 not exist, T buy\n        _cur_day_buy = set(_trading_day_df.index) - set(previous_data[\"code_list\"])\n\n        # Trading day buy\n        _trading_day_df.loc[_trading_day_df.index.isin(_cur_day_buy), \"status\"] = 1\n\n        # Trading day sell\n        if not result_df.empty:\n            _trading_day_sell_df = result_df.loc[\n                (result_df[\"date\"] == previous_data[\"date\"]) & (result_df.index.isin(_cur_day_sell))\n            ].copy()\n            if not _trading_day_sell_df.empty:\n                _trading_day_sell_df[\"status\"] = -1\n                _trading_day_sell_df[\"date\"] = _trading_date\n                _trading_day_df = pd.concat([_trading_day_df, _trading_day_sell_df], sort=False)\n\n        result_df = pd.concat([result_df, _trading_day_df], sort=True)\n\n        previous_data = dict(\n            date=_trading_date,\n            code_list=_trading_day_df[_trading_day_df[\"status\"] != -1].index,\n        )\n\n    result_df.reset_index(inplace=True)\n    result_df.rename(columns={\"date\": \"datetime\", \"index\": \"instrument\"}, inplace=True)\n    return result_df.set_index([\"instrument\", \"datetime\"])\n\n\ndef _add_label_to_position(position_df: pd.DataFrame, label_data: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"Concat position with custom label\n\n    :param position_df: position DataFrame\n    :param label_data:\n    :return: concat result\n    \"\"\"\n\n    _start_time = position_df.index.get_level_values(level=\"datetime\").min()\n    _end_time = position_df.index.get_level_values(level=\"datetime\").max()\n    label_data = label_data.loc(axis=0)[:, pd.to_datetime(_start_time) :]\n    _result_df = pd.concat([position_df, label_data], axis=1, sort=True).reindex(label_data.index)\n    _result_df = 
_result_df.loc[_result_df.index.get_level_values(1) <= _end_time]\n    return _result_df\n\n\ndef _add_bench_to_position(position_df: pd.DataFrame = None, bench: pd.Series = None) -> pd.DataFrame:\n    \"\"\"Concat position with bench\n\n    :param position_df: position DataFrame\n    :param bench: the bench column of the report normal data\n    :return: concat result\n    \"\"\"\n    _temp_df = position_df.reset_index(level=\"instrument\")\n    # FIXME: bench is shifted by -1, so each trading day is aligned with the benchmark change of the next trading day.\n    _temp_df[\"bench\"] = bench.shift(-1)\n    res_df = _temp_df.set_index([\"instrument\", _temp_df.index])\n    return res_df\n\n\ndef _calculate_label_rank(df: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"Calculate the daily label rank ratio, the mean rank ratio per trading status, and the excess return\n\n    :param df: position DataFrame with label and status columns\n    :return: df with rank_ratio, rank_label_mean and excess_return columns added\n    \"\"\"\n    _label_name = \"label\"\n\n    def _calculate_day_value(g_df: pd.DataFrame):\n        g_df = g_df.copy()\n        g_df[\"rank_ratio\"] = g_df[_label_name].rank(ascending=False) / len(g_df) * 100\n\n        # Sell: -1, Hold: 0, Buy: 1\n        for i in [-1, 0, 1]:\n            g_df.loc[g_df[\"status\"] == i, \"rank_label_mean\"] = g_df[g_df[\"status\"] == i][\"rank_ratio\"].mean()\n\n        g_df[\"excess_return\"] = g_df[_label_name] - g_df[_label_name].mean()\n        return g_df\n\n    return df.groupby(level=\"datetime\", group_keys=False).apply(_calculate_day_value)\n\n\ndef get_position_data(\n    position: dict,\n    label_data: pd.DataFrame,\n    report_normal: pd.DataFrame = None,\n    calculate_label_rank=False,\n    start_date=None,\n    end_date=None,\n) -> pd.DataFrame:\n    \"\"\"Concat position data with label data and, optionally, report_normal\n\n    :param position: position data\n    :param report_normal: report normal data, must contain a 'bench' column\n    :param label_data: label DataFrame indexed by [instrument, datetime]\n    :param calculate_label_rank: whether to add the rank_ratio, rank_label_mean and excess_return columns\n    :param start_date: start date\n    :param end_date: end date\n    :return: concat result,\n        columns: ['amount', 'cash', 'count', 'price', 'status', 'weight', 
'label',\n                    'rank_ratio', 'rank_label_mean', 'excess_return', 'score', 'bench']\n        index: ['instrument', 'datetime']\n    \"\"\"\n    _position_df = parse_position(position)\n\n    # Add the label_data columns (e.g. label)\n    _position_df = _add_label_to_position(_position_df, label_data)\n\n    if calculate_label_rank:\n        # Add rank_ratio, rank_label_mean, and excess_return columns\n        _position_df = _calculate_label_rank(_position_df)\n\n    if report_normal is not None:\n        # Add bench field\n        _position_df = _add_bench_to_position(_position_df, report_normal[\"bench\"])\n\n    _date_list = _position_df.index.get_level_values(level=\"datetime\")\n    start_date = _date_list.min() if start_date is None else start_date\n    end_date = _date_list.max() if end_date is None else end_date\n    _position_df = _position_df.loc[(start_date <= _date_list) & (_date_list <= end_date)]\n    return _position_df\n"
  },
  {
    "path": "qlib/contrib/report/analysis_position/rank_label.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport copy\nfrom typing import Iterable\n\nimport pandas as pd\nimport plotly.graph_objs as go\n\nfrom ..graph import ScatterGraph\nfrom ..analysis_position.parse_position import get_position_data\n\n\ndef _get_figure_with_position(\n    position: dict, label_data: pd.DataFrame, start_date=None, end_date=None\n) -> Iterable[go.Figure]:\n    \"\"\"Get figures of the daily average label rank ratio of held, bought, and sold stocks\n\n    :param position: position data\n    :param label_data: label DataFrame\n    :param start_date: start date\n    :param end_date: end date\n    :return: one figure per trading status (Hold, Buy, Sell)\n    \"\"\"\n    _position_df = get_position_data(\n        position,\n        label_data,\n        calculate_label_rank=True,\n        start_date=start_date,\n        end_date=end_date,\n    )\n\n    res_dict = dict()\n    _pos_gp = _position_df.groupby(level=1, group_keys=False)\n    for _item in _pos_gp:\n        _date = _item[0]\n        _day_df = _item[1]\n\n        _day_value = res_dict.setdefault(_date, {})\n        for _i, _name in {0: \"Hold\", 1: \"Buy\", -1: \"Sell\"}.items():\n            _temp_df = _day_df[_day_df[\"status\"] == _i]\n            if _temp_df.empty:\n                _day_value[_name] = 0\n            else:\n                _day_value[_name] = _temp_df[\"rank_label_mean\"].values[0]\n\n    _res_df = pd.DataFrame.from_dict(res_dict, orient=\"index\")\n    # FIXME: support HIGH-FREQ\n    _res_df.index = _res_df.index.strftime(\"%Y-%m-%d\")\n    for _col in _res_df.columns:\n        yield ScatterGraph(\n            _res_df.loc[:, [_col]],\n            layout=dict(\n                title=_col,\n                xaxis=dict(type=\"category\", tickangle=45),\n                yaxis=dict(title=\"label-rank-ratio: %\"),\n            ),\n            graph_kwargs=dict(mode=\"lines+markers\"),\n        ).figure\n\n\ndef rank_label_graph(\n    position: dict,\n    label_data: pd.DataFrame,\n    start_date=None,\n    end_date=None,\n    show_notebook=True,\n) -> 
Iterable[go.Figure]:\n    \"\"\"Rank percentage of the stocks bought, sold, and held on each trading day.\n    Plots the daily average rank ratio (similar to **sell_df['label'].rank(ascending=False) / len(sell_df)**) for each trading status.\n\n        Example:\n\n\n            .. code-block:: python\n\n                from qlib.data import D\n                from qlib.contrib.evaluate import backtest\n                from qlib.contrib.strategy import TopkDropoutStrategy\n                from qlib.contrib.report import analysis_position\n\n                # backtest parameters\n                bparas = {}\n                bparas['limit_threshold'] = 0.095\n                bparas['account'] = 1000000000\n\n                sparas = {}\n                sparas['topk'] = 50\n                sparas['n_drop'] = 230\n                strategy = TopkDropoutStrategy(**sparas)\n\n                _, positions = backtest(pred_df, strategy, **bparas)\n\n                pred_df_dates = pred_df.index.get_level_values(level='datetime')\n                features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max())\n                features_df.columns = ['label']\n\n                analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())\n\n\n    :param position: position data; **qlib.backtest.backtest** result.\n    :param label_data: **D.features** result; index is **pd.MultiIndex** with index names **[instrument, datetime]**; the column name is **[label]**.\n\n        **The label at T is the change from T to T+1**; using ``close`` is recommended, e.g. `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`.\n\n\n            .. 
code-block:: python\n\n                                                label\n                instrument  datetime\n                SH600004        2017-12-11  -0.013502\n                                2017-12-12  -0.072367\n                                2017-12-13  -0.068605\n                                2017-12-14  0.012440\n                                2017-12-15  -0.102778\n\n\n    :param start_date: start date\n    :param end_date: end date\n    :param show_notebook: **True** or **False**. If True, show graph in notebook, else return figures.\n    :return: None if show_notebook is True; otherwise an iterable of **plotly.graph_objs.Figure**.\n    \"\"\"\n    position = copy.deepcopy(position)\n    # copy before renaming so the caller's label_data is not modified in place\n    label_data = label_data.copy()\n    label_data.columns = [\"label\"]\n    _figures = _get_figure_with_position(position, label_data, start_date, end_date)\n    if show_notebook:\n        ScatterGraph.show_graph_in_notebook(_figures)\n    else:\n        return _figures\n"
  },
  {
    "path": "qlib/contrib/report/analysis_position/report.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport pandas as pd\n\nfrom ..graph import SubplotsGraph, BaseGraph\n\n\ndef _calculate_maximum(df: pd.DataFrame, is_ex: bool = False):\n    \"\"\"\n\n    :param df:\n    :param is_ex:\n    :return:\n    \"\"\"\n    if is_ex:\n        end_date = df[\"cum_ex_return_wo_cost_mdd\"].idxmin()\n        start_date = df.loc[df.index <= end_date][\"cum_ex_return_wo_cost\"].idxmax()\n    else:\n        end_date = df[\"return_wo_mdd\"].idxmin()\n        start_date = df.loc[df.index <= end_date][\"cum_return_wo_cost\"].idxmax()\n    return start_date, end_date\n\n\ndef _calculate_mdd(series):\n    \"\"\"\n    Calculate mdd\n\n    :param series:\n    :return:\n    \"\"\"\n    return series - series.cummax()\n\n\ndef _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"\n\n    :param df:\n    :return:\n    \"\"\"\n    index_names = df.index.names\n    df.index = df.index.strftime(\"%Y-%m-%d\")\n\n    report_df = pd.DataFrame()\n\n    report_df[\"cum_bench\"] = df[\"bench\"].cumsum()\n    report_df[\"cum_return_wo_cost\"] = df[\"return\"].cumsum()\n    report_df[\"cum_return_w_cost\"] = (df[\"return\"] - df[\"cost\"]).cumsum()\n    # report_df['cum_return'] - report_df['cum_return'].cummax()\n    report_df[\"return_wo_mdd\"] = _calculate_mdd(report_df[\"cum_return_wo_cost\"])\n    report_df[\"return_w_cost_mdd\"] = _calculate_mdd((df[\"return\"] - df[\"cost\"]).cumsum())\n\n    report_df[\"cum_ex_return_wo_cost\"] = (df[\"return\"] - df[\"bench\"]).cumsum()\n    report_df[\"cum_ex_return_w_cost\"] = (df[\"return\"] - df[\"bench\"] - df[\"cost\"]).cumsum()\n    report_df[\"cum_ex_return_wo_cost_mdd\"] = _calculate_mdd((df[\"return\"] - df[\"bench\"]).cumsum())\n    report_df[\"cum_ex_return_w_cost_mdd\"] = _calculate_mdd((df[\"return\"] - df[\"cost\"] - df[\"bench\"]).cumsum())\n    # return_wo_mdd , return_w_cost_mdd,  cum_ex_return_wo_cost_mdd, cum_ex_return_w\n\n    
report_df[\"turnover\"] = df[\"turnover\"]\n    report_df.sort_index(ascending=True, inplace=True)\n\n    report_df.index.names = index_names\n    return report_df\n\n\ndef _report_figure(df: pd.DataFrame) -> [list, tuple]:\n    \"\"\"\n\n    :param df:\n    :return:\n    \"\"\"\n\n    # Get data\n    report_df = _calculate_report_data(df)\n\n    # Maximum Drawdown\n    max_start_date, max_end_date = _calculate_maximum(report_df)\n    ex_max_start_date, ex_max_end_date = _calculate_maximum(report_df, True)\n\n    index_name = report_df.index.name\n    _temp_df = report_df.reset_index()\n    _temp_df.loc[-1] = 0\n    _temp_df = _temp_df.shift(1)\n    _temp_df.loc[0, index_name] = \"T0\"\n    _temp_df.set_index(index_name, inplace=True)\n    _temp_df.iloc[0] = 0\n    report_df = _temp_df\n\n    # Create figure\n    _default_kind_map = dict(kind=\"ScatterGraph\", kwargs={\"mode\": \"lines+markers\"})\n    _temp_fill_args = {\"fill\": \"tozeroy\", \"mode\": \"lines+markers\"}\n    _column_row_col_dict = [\n        (\"cum_bench\", dict(row=1, col=1)),\n        (\"cum_return_wo_cost\", dict(row=1, col=1)),\n        (\"cum_return_w_cost\", dict(row=1, col=1)),\n        (\"return_wo_mdd\", dict(row=2, col=1, graph_kwargs=_temp_fill_args)),\n        (\"return_w_cost_mdd\", dict(row=3, col=1, graph_kwargs=_temp_fill_args)),\n        (\"cum_ex_return_wo_cost\", dict(row=4, col=1)),\n        (\"cum_ex_return_w_cost\", dict(row=4, col=1)),\n        (\"turnover\", dict(row=5, col=1)),\n        (\"cum_ex_return_w_cost_mdd\", dict(row=6, col=1, graph_kwargs=_temp_fill_args)),\n        (\"cum_ex_return_wo_cost_mdd\", dict(row=7, col=1, graph_kwargs=_temp_fill_args)),\n    ]\n\n    _subplot_layout = dict()\n    for i in range(1, 8):\n        # yaxis\n        _subplot_layout.update({\"yaxis{}\".format(i): dict(zeroline=True, showline=True, showticklabels=True)})\n        _show_line = i == 7\n        _subplot_layout.update({\"xaxis{}\".format(i): dict(showline=_show_line, 
type=\"category\", tickangle=45)})\n\n    _layout_style = dict(\n        height=1200,\n        title=\" \",\n        shapes=[\n            {\n                \"type\": \"rect\",\n                \"xref\": \"x\",\n                \"yref\": \"paper\",\n                \"x0\": max_start_date,\n                \"y0\": 0.55,\n                \"x1\": max_end_date,\n                \"y1\": 1,\n                \"fillcolor\": \"#d3d3d3\",\n                \"opacity\": 0.3,\n                \"line\": {\n                    \"width\": 0,\n                },\n            },\n            {\n                \"type\": \"rect\",\n                \"xref\": \"x\",\n                \"yref\": \"paper\",\n                \"x0\": ex_max_start_date,\n                \"y0\": 0,\n                \"x1\": ex_max_end_date,\n                \"y1\": 0.55,\n                \"fillcolor\": \"#d3d3d3\",\n                \"opacity\": 0.3,\n                \"line\": {\n                    \"width\": 0,\n                },\n            },\n        ],\n    )\n\n    _subplot_kwargs = dict(\n        shared_xaxes=True,\n        vertical_spacing=0.01,\n        rows=7,\n        cols=1,\n        row_width=[1, 1, 1, 3, 1, 1, 3],\n        print_grid=False,\n    )\n    figure = SubplotsGraph(\n        df=report_df,\n        layout=_layout_style,\n        sub_graph_data=_column_row_col_dict,\n        subplots_kwargs=_subplot_kwargs,\n        kind_map=_default_kind_map,\n        sub_graph_layout=_subplot_layout,\n    ).figure\n    return (figure,)\n\n\ndef report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list, tuple]:\n    \"\"\"display backtest report\n\n        Example:\n\n\n            .. 
code-block:: python\n\n                import qlib\n                import pandas as pd\n                from qlib.utils.time import Freq\n                from qlib.utils import flatten_dict\n                from qlib.backtest import backtest, executor\n                from qlib.contrib.evaluate import risk_analysis\n                from qlib.contrib.strategy import TopkDropoutStrategy\n\n                # init qlib\n                qlib.init(provider_uri=<qlib data dir>)\n\n                CSI300_BENCH = \"SH000300\"\n                FREQ = \"day\"\n                STRATEGY_CONFIG = {\n                    \"topk\": 50,\n                    \"n_drop\": 5,\n                    # pred_score, pd.Series\n                    \"signal\": pred_score,\n                }\n\n                EXECUTOR_CONFIG = {\n                    \"time_per_step\": \"day\",\n                    \"generate_portfolio_metrics\": True,\n                }\n\n                backtest_config = {\n                    \"start_time\": \"2017-01-01\",\n                    \"end_time\": \"2020-08-01\",\n                    \"account\": 100000000,\n                    \"benchmark\": CSI300_BENCH,\n                    \"exchange_kwargs\": {\n                        \"freq\": FREQ,\n                        \"limit_threshold\": 0.095,\n                        \"deal_price\": \"close\",\n                        \"open_cost\": 0.0005,\n                        \"close_cost\": 0.0015,\n                        \"min_cost\": 5,\n                    },\n                }\n\n                # strategy object\n                strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)\n                # executor object\n                executor_obj = executor.SimulatorExecutor(**EXECUTOR_CONFIG)\n                # backtest\n                portfolio_metric_dict, indicator_dict = backtest(executor=executor_obj, strategy=strategy_obj, **backtest_config)\n                analysis_freq = 
\"{0}{1}\".format(*Freq.parse(FREQ))\n                # backtest info\n                report_normal_df, positions_normal = portfolio_metric_dict.get(analysis_freq)\n\n                from qlib.contrib.report import analysis_position\n                analysis_position.report_graph(report_normal_df)\n\n    :param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**.\n\n\n            .. code-block:: python\n\n                            return      cost        bench       turnover\n                date\n                2017-01-04  0.003421    0.000864    0.011693    0.576325\n                2017-01-05  0.000508    0.000447    0.000721    0.227882\n                2017-01-06  -0.003321   0.000212    -0.004322   0.102765\n                2017-01-09  0.006753    0.000212    0.006874    0.105864\n                2017-01-10  -0.000416   0.000440    -0.003350   0.208396\n\n\n    :param show_notebook: whether to display the figures in a notebook; the default is **True**.\n    :return: if show_notebook is True, display in notebook; else return a tuple of **plotly.graph_objs.Figure**.\n    \"\"\"\n    report_df = report_df.copy()\n    fig_list = _report_figure(report_df)\n    if show_notebook:\n        BaseGraph.show_graph_in_notebook(fig_list)\n    else:\n        return fig_list\n"
  },
  {
    "path": "qlib/contrib/report/analysis_position/risk_analysis.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom typing import Iterable\n\nimport pandas as pd\n\nimport plotly.graph_objs as py\n\nfrom ...evaluate import risk_analysis\n\nfrom ..graph import SubplotsGraph, ScatterGraph\n\n\ndef _get_risk_analysis_data_with_report(\n    report_normal_df: pd.DataFrame,\n    # report_long_short_df: pd.DataFrame,\n    date: pd.Timestamp,\n) -> pd.DataFrame:\n    \"\"\"Get risk analysis data with report\n\n    :param report_normal_df: report data\n    # :param report_long_short_df: report data\n    :param date: timestamp assigned to the date column of the result\n    :return:\n    \"\"\"\n\n    analysis = dict()\n    # if not report_long_short_df.empty:\n    #     analysis[\"pred_long\"] = risk_analysis(report_long_short_df[\"long\"])\n    #     analysis[\"pred_short\"] = risk_analysis(report_long_short_df[\"short\"])\n    #     analysis[\"pred_long_short\"] = risk_analysis(report_long_short_df[\"long_short\"])\n\n    if not report_normal_df.empty:\n        analysis[\"excess_return_without_cost\"] = risk_analysis(report_normal_df[\"return\"] - report_normal_df[\"bench\"])\n        analysis[\"excess_return_with_cost\"] = risk_analysis(\n            report_normal_df[\"return\"] - report_normal_df[\"bench\"] - report_normal_df[\"cost\"]\n        )\n    analysis_df = pd.concat(analysis)  # type: pd.DataFrame\n    analysis_df[\"date\"] = date\n    return analysis_df\n\n\ndef _get_all_risk_analysis(risk_df: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"Reshape risk_df into one row per return series and one column per risk metric, dropping the mean column\n\n    :param risk_df: risk data\n    :return:\n    \"\"\"\n    if risk_df is None:\n        return pd.DataFrame()\n    risk_df = risk_df.unstack()\n    risk_df.columns = risk_df.columns.droplevel(0)\n    return risk_df.drop(\"mean\", axis=1)\n\n\ndef _get_monthly_risk_analysis_with_report(report_normal_df: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"Get monthly analysis data\n\n    :param report_normal_df:\n    # :param report_long_short_df:\n    :return:\n    \"\"\"\n\n 
   # Group by month\n    report_normal_gp = report_normal_df.groupby(\n        [report_normal_df.index.year, report_normal_df.index.month], group_keys=False\n    )\n    # report_long_short_gp = report_long_short_df.groupby(\n    #     [report_long_short_df.index.year, report_long_short_df.index.month], group_keys=False\n    # )\n\n    gp_month = sorted(set(report_normal_gp.size().index))\n\n    _monthly_df = pd.DataFrame()\n    for gp_m in gp_month:\n        _m_report_normal = report_normal_gp.get_group(gp_m)\n        # _m_report_long_short = report_long_short_gp.get_group(gp_m)\n\n        if len(_m_report_normal) < 3:\n            # Months with fewer than 3 trading days are not displayed\n            # FIXME: skipping such a month leaves a gap in the graph\n            continue\n        month_days = pd.Timestamp(year=gp_m[0], month=gp_m[1], day=1).days_in_month\n        _temp_df = _get_risk_analysis_data_with_report(\n            _m_report_normal,\n            # _m_report_long_short,\n            pd.Timestamp(year=gp_m[0], month=gp_m[1], day=month_days),\n        )\n        _monthly_df = pd.concat([_monthly_df, _temp_df], sort=False)\n\n    return _monthly_df\n\n\ndef _get_monthly_analysis_with_feature(monthly_df: pd.DataFrame, feature: str = \"annualized_return\") -> pd.DataFrame:\n    \"\"\"\n\n    :param monthly_df:\n    :param feature:\n    :return:\n    \"\"\"\n    _monthly_df_gp = monthly_df.reset_index().groupby([\"level_1\"], group_keys=False)\n\n    _name_df = _monthly_df_gp.get_group(feature).set_index([\"level_0\", \"level_1\"])\n    _temp_df = _name_df.pivot_table(index=\"date\", values=[\"risk\"], columns=_name_df.index)\n    _temp_df.columns = map(lambda x: \"_\".join(x[-1]), _temp_df.columns)\n    _temp_df.index = _temp_df.index.strftime(\"%Y-%m\")\n\n    return _temp_df\n\n\ndef _get_risk_analysis_figure(analysis_df: pd.DataFrame) -> Iterable[py.Figure]:\n    \"\"\"Get analysis graph figure\n\n    :param 
analysis_df:\n    :return:\n    \"\"\"\n    if analysis_df is None:\n        return []\n\n    _figure = SubplotsGraph(\n        _get_all_risk_analysis(analysis_df),\n        kind_map=dict(kind=\"BarGraph\", kwargs={}),\n        subplots_kwargs={\"rows\": 1, \"cols\": 4},\n    ).figure\n    return (_figure,)\n\n\ndef _get_monthly_risk_analysis_figure(report_normal_df: pd.DataFrame) -> Iterable[py.Figure]:\n    \"\"\"Get analysis monthly graph figure\n\n    :param report_normal_df:\n    :param report_long_short_df:\n    :return:\n    \"\"\"\n\n    # if report_normal_df is None and report_long_short_df is None:\n    #     return []\n    if report_normal_df is None:\n        return []\n\n    # if report_normal_df is None:\n    #     report_normal_df = pd.DataFrame(index=report_long_short_df.index)\n\n    # if report_long_short_df is None:\n    #     report_long_short_df = pd.DataFrame(index=report_normal_df.index)\n\n    _monthly_df = _get_monthly_risk_analysis_with_report(\n        report_normal_df=report_normal_df,\n        # report_long_short_df=report_long_short_df,\n    )\n\n    for _feature in [\"annualized_return\", \"max_drawdown\", \"information_ratio\", \"std\"]:\n        _temp_df = _get_monthly_analysis_with_feature(_monthly_df, _feature)\n        yield ScatterGraph(\n            _temp_df,\n            layout=dict(title=_feature, xaxis=dict(type=\"category\", tickangle=45)),\n            graph_kwargs={\"mode\": \"lines+markers\"},\n        ).figure\n\n\ndef risk_analysis_graph(\n    analysis_df: pd.DataFrame = None,\n    report_normal_df: pd.DataFrame = None,\n    report_long_short_df: pd.DataFrame = None,\n    show_notebook: bool = True,\n) -> Iterable[py.Figure]:\n    \"\"\"Generate analysis graph and monthly analysis\n\n        Example:\n\n\n            .. 
code-block:: python\n\n                import qlib\n                import pandas as pd\n                from qlib.utils.time import Freq\n                from qlib.utils import flatten_dict\n                from qlib.backtest import backtest, executor\n                from qlib.contrib.evaluate import risk_analysis\n                from qlib.contrib.strategy import TopkDropoutStrategy\n\n                # init qlib\n                qlib.init(provider_uri=<qlib data dir>)\n\n                CSI300_BENCH = \"SH000300\"\n                FREQ = \"day\"\n                STRATEGY_CONFIG = {\n                    \"topk\": 50,\n                    \"n_drop\": 5,\n                    # pred_score, pd.Series\n                    \"signal\": pred_score,\n                }\n\n                EXECUTOR_CONFIG = {\n                    \"time_per_step\": \"day\",\n                    \"generate_portfolio_metrics\": True,\n                }\n\n                backtest_config = {\n                    \"start_time\": \"2017-01-01\",\n                    \"end_time\": \"2020-08-01\",\n                    \"account\": 100000000,\n                    \"benchmark\": CSI300_BENCH,\n                    \"exchange_kwargs\": {\n                        \"freq\": FREQ,\n                        \"limit_threshold\": 0.095,\n                        \"deal_price\": \"close\",\n                        \"open_cost\": 0.0005,\n                        \"close_cost\": 0.0015,\n                        \"min_cost\": 5,\n                    },\n                }\n\n                # strategy object\n                strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)\n                # executor object\n                executor_obj = executor.SimulatorExecutor(**EXECUTOR_CONFIG)\n                # backtest\n                portfolio_metric_dict, indicator_dict = backtest(executor=executor_obj, strategy=strategy_obj, **backtest_config)\n                analysis_freq = 
\"{0}{1}\".format(*Freq.parse(FREQ))\n                # backtest info\n                report_normal_df, positions_normal = portfolio_metric_dict.get(analysis_freq)\n                analysis = dict()\n                analysis[\"excess_return_without_cost\"] = risk_analysis(\n                    report_normal_df[\"return\"] - report_normal_df[\"bench\"], freq=analysis_freq\n                )\n                analysis[\"excess_return_with_cost\"] = risk_analysis(\n                    report_normal_df[\"return\"] - report_normal_df[\"bench\"] - report_normal_df[\"cost\"], freq=analysis_freq\n                )\n\n                analysis_df = pd.concat(analysis)  # type: pd.DataFrame\n                from qlib.contrib.report import analysis_position\n                analysis_position.risk_analysis_graph(analysis_df, report_normal_df)\n\n\n\n    :param analysis_df: analysis data, index is **pd.MultiIndex**; the column name is **[risk]**.\n\n\n            .. code-block:: python\n\n                                                                  risk\n                excess_return_without_cost mean               0.000692\n                                           std                0.005374\n                                           annualized_return  0.174495\n                                           information_ratio  2.045576\n                                           max_drawdown      -0.079103\n                excess_return_with_cost    mean               0.000499\n                                           std                0.005372\n                                           annualized_return  0.125625\n                                           information_ratio  1.473152\n                                           max_drawdown      -0.088263\n\n\n    :param report_normal_df: **df.index.name** must be **date**, df.columns must contain **return**, **turnover**, **cost**, **bench**.\n\n\n            .. 
code-block:: python\n\n                            return      cost        bench       turnover\n                date\n                2017-01-04  0.003421    0.000864    0.011693    0.576325\n                2017-01-05  0.000508    0.000447    0.000721    0.227882\n                2017-01-06  -0.003321   0.000212    -0.004322   0.102765\n                2017-01-09  0.006753    0.000212    0.006874    0.105864\n                2017-01-10  -0.000416   0.000440    -0.003350   0.208396\n\n\n    :param report_long_short_df: **df.index.name** must be **date**, df.columns contain **long**, **short**, **long_short**.\n\n\n            .. code-block:: python\n\n                            long        short       long_short\n                date\n                2017-01-04  -0.001360   0.001394    0.000034\n                2017-01-05  0.002456    0.000058    0.002514\n                2017-01-06  0.000120    0.002739    0.002859\n                2017-01-09  0.001436    0.001838    0.003273\n                2017-01-10  0.000824    -0.001944   -0.001120\n\n\n    :param show_notebook: Whether to display graphics in a notebook, default **True**.\n        If True, show graph in notebook\n        If False, return graph figure\n    :return:\n    \"\"\"\n    _figure_list = list(_get_risk_analysis_figure(analysis_df)) + list(\n        _get_monthly_risk_analysis_figure(\n            report_normal_df,\n            # report_long_short_df,\n        )\n    )\n    if show_notebook:\n        ScatterGraph.show_graph_in_notebook(_figure_list)\n    else:\n        return _figure_list\n"
  },
  {
    "path": "qlib/contrib/report/analysis_position/score_ic.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport pandas as pd\n\nfrom ..graph import ScatterGraph\nfrom ..utils import guess_plotly_rangebreaks\n\n\ndef _get_score_ic(pred_label: pd.DataFrame):\n    \"\"\"Calculate the daily IC (Pearson) and rank IC (Spearman) between score and label\n\n    :param pred_label: DataFrame with score and label columns, indexed by [instrument, datetime]\n    :return: DataFrame with ic and rank_ic columns, indexed by datetime\n    \"\"\"\n    concat_data = pred_label.copy()\n    concat_data.dropna(axis=0, how=\"any\", inplace=True)\n    _ic = concat_data.groupby(level=\"datetime\", group_keys=False).apply(lambda x: x[\"label\"].corr(x[\"score\"]))\n    _rank_ic = concat_data.groupby(level=\"datetime\", group_keys=False).apply(\n        lambda x: x[\"label\"].corr(x[\"score\"], method=\"spearman\")\n    )\n    return pd.DataFrame({\"ic\": _ic, \"rank_ic\": _rank_ic})\n\n\ndef score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True, **kwargs) -> [list, tuple]:\n    \"\"\"Plot the daily score IC and rank IC\n\n        Example:\n\n\n            .. code-block:: python\n\n                import pandas as pd\n                from qlib.data import D\n                from qlib.contrib.report import analysis_position\n                pred_df_dates = pred_df.index.get_level_values(level='datetime')\n                features_df = D.features(D.instruments('csi500'), ['Ref($close, -2)/Ref($close, -1)-1'], pred_df_dates.min(), pred_df_dates.max())\n                features_df.columns = ['label']\n                pred_label = pd.concat([features_df, pred_df], axis=1, sort=True).reindex(features_df.index)\n                analysis_position.score_ic_graph(pred_label)\n\n\n    :param pred_label: index is **pd.MultiIndex** with index names **[instrument, datetime]**; column names are **[score, label]**.\n\n\n            .. 
code-block:: python\n\n                instrument  datetime        score         label\n                SH600004  2017-12-11     -0.013502       -0.013502\n                            2017-12-12   -0.072367       -0.072367\n                            2017-12-13   -0.068605       -0.068605\n                            2017-12-14    0.012440        0.012440\n                            2017-12-15   -0.102778       -0.102778\n\n\n    :param show_notebook: whether to display graphics in notebook, the default is **True**.\n    :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list.\n    \"\"\"\n    _ic_df = _get_score_ic(pred_label)\n\n    _figure = ScatterGraph(\n        _ic_df,\n        layout=dict(\n            title=\"Score IC\",\n            xaxis=dict(tickangle=45, rangebreaks=kwargs.get(\"rangebreaks\", guess_plotly_rangebreaks(_ic_df.index))),\n        ),\n        graph_kwargs={\"mode\": \"lines+markers\"},\n    ).figure\n    if show_notebook:\n        ScatterGraph.show_graph_in_notebook([_figure])\n    else:\n        return (_figure,)\n"
  },
  {
    "path": "qlib/contrib/report/data/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nThis module is designed for data analysis.\n\n\"\"\"\n"
  },
  {
    "path": "qlib/contrib/report/data/ana.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nHere we have a comprehensive set of analysis classes.\n\nHere is an example.\n\n.. code-block:: python\n\n    from qlib.contrib.report.data.ana import FeaMeanStd\n    fa = FeaMeanStd(ret_df)\n    fa.plot_all(wspace=0.3, sub_figsize=(12, 3), col_n=5)\n\n\"\"\"\n\nimport pandas as pd\nimport numpy as np\nfrom qlib.contrib.report.data.base import FeaAnalyser\nfrom qlib.contrib.report.utils import sub_fig_generator\nfrom qlib.utils.paral import datetime_groupby_apply\nfrom qlib.contrib.eva.alpha import pred_autocorr_all\nfrom loguru import logger\nimport seaborn as sns\n\nDT_COL_NAME = \"datetime\"\n\n\nclass CombFeaAna(FeaAnalyser):\n    \"\"\"\n    Combine the sub feature analysers and plot them in a single graph\n    \"\"\"\n\n    def __init__(self, dataset: pd.DataFrame, *fea_ana_cls):\n        if len(fea_ana_cls) <= 1:\n            raise NotImplementedError(f\"This type of input is not supported\")\n        self._fea_ana_l = [fcls(dataset) for fcls in fea_ana_cls]\n        super().__init__(dataset=dataset)\n\n    def skip(self, col):\n        return np.all(list(map(lambda fa: fa.skip(col), self._fea_ana_l)))\n\n    def calc_stat_values(self):\n        \"\"\"Feature statistics are computed by the underlying analysers\"\"\"\n\n    def plot_all(self, *args, **kwargs):\n        ax_gen = iter(sub_fig_generator(row_n=len(self._fea_ana_l), *args, **kwargs))\n\n        for col in self._dataset:\n            if not self.skip(col):\n                axes = next(ax_gen)\n                for fa, ax in zip(self._fea_ana_l, axes):\n                    if not fa.skip(col):\n                        fa.plot_single(col, ax)\n                    ax.set_xlabel(\"\")\n                    ax.set_title(\"\")\n                axes[0].set_title(col)\n\n\nclass NumFeaAnalyser(FeaAnalyser):\n    def skip(self, col):\n        is_obj = np.issubdtype(self._dataset[col], np.dtype(\"O\"))\n  
      if is_obj:\n            logger.info(f\"{col} is not numeric and is skipped\")\n        return is_obj\n\n\nclass ValueCNT(FeaAnalyser):\n    def __init__(self, dataset: pd.DataFrame, ratio=False):\n        self.ratio = ratio\n        super().__init__(dataset)\n\n    def calc_stat_values(self):\n        self._val_cnt = {}\n        for col, item in self._dataset.items():\n            if not super().skip(col):\n                self._val_cnt[col] = item.groupby(DT_COL_NAME, group_keys=False).apply(lambda s: len(s.unique()))\n        self._val_cnt = pd.DataFrame(self._val_cnt)\n        if self.ratio:\n            self._val_cnt = self._val_cnt.div(self._dataset.groupby(DT_COL_NAME, group_keys=False).size(), axis=0)\n\n        # TODO: transfer this feature to other analysers\n        ymin, ymax = self._val_cnt.min().min(), self._val_cnt.max().max()\n        self.ylim = (ymin - 0.05 * (ymax - ymin), ymax + 0.05 * (ymax - ymin))\n\n    def plot_single(self, col, ax):\n        self._val_cnt[col].plot(ax=ax, title=col, ylim=self.ylim)\n        ax.set_xlabel(\"\")\n\n\nclass FeaDistAna(NumFeaAnalyser):\n    def plot_single(self, col, ax):\n        sns.histplot(self._dataset[col], ax=ax, kde=False, bins=100)\n        ax.set_xlabel(\"\")\n        ax.set_title(col)\n\n\nclass FeaInfAna(NumFeaAnalyser):\n    def calc_stat_values(self):\n        self._inf_cnt = {}\n        for col, item in self._dataset.items():\n            if not super().skip(col):\n                self._inf_cnt[col] = item.apply(np.isinf).astype(int).groupby(DT_COL_NAME, group_keys=False).sum()\n        self._inf_cnt = pd.DataFrame(self._inf_cnt)\n\n    def skip(self, col):\n        return (col not in self._inf_cnt) or (self._inf_cnt[col].sum() == 0)\n\n    def plot_single(self, col, ax):\n        self._inf_cnt[col].plot(ax=ax, title=col)\n        ax.set_xlabel(\"\")\n\n\nclass FeaNanAna(FeaAnalyser):\n    def calc_stat_values(self):\n        self._nan_cnt = self._dataset.isna().groupby(DT_COL_NAME, 
group_keys=False).sum()\n\n    def skip(self, col):\n        return (col not in self._nan_cnt) or (self._nan_cnt[col].sum() == 0)\n\n    def plot_single(self, col, ax):\n        self._nan_cnt[col].plot(ax=ax, title=col)\n        ax.set_xlabel(\"\")\n\n\nclass FeaNanAnaRatio(FeaAnalyser):\n    def calc_stat_values(self):\n        self._nan_cnt = self._dataset.isna().groupby(DT_COL_NAME, group_keys=False).sum()\n        self._total_cnt = self._dataset.groupby(DT_COL_NAME, group_keys=False).size()\n\n    def skip(self, col):\n        return (col not in self._nan_cnt) or (self._nan_cnt[col].sum() == 0)\n\n    def plot_single(self, col, ax):\n        (self._nan_cnt[col] / self._total_cnt).plot(ax=ax, title=col)\n        ax.set_xlabel(\"\")\n\n\nclass FeaACAna(FeaAnalyser):\n    \"\"\"Analyse the auto-correlation of features\"\"\"\n\n    def calc_stat_values(self):\n        self._fea_corr = pred_autocorr_all(self._dataset.to_dict(\"series\"))\n        df = pd.DataFrame(self._fea_corr)\n        ymin, ymax = df.min().min(), df.max().max()\n        self.ylim = (ymin - 0.05 * (ymax - ymin), ymax + 0.05 * (ymax - ymin))\n\n    def plot_single(self, col, ax):\n        self._fea_corr[col].plot(ax=ax, title=col, ylim=self.ylim)\n        ax.set_xlabel(\"\")\n\n\nclass FeaSkewTurt(NumFeaAnalyser):\n    def calc_stat_values(self):\n        self._skew = datetime_groupby_apply(self._dataset, \"skew\")\n        self._kurt = datetime_groupby_apply(self._dataset, pd.DataFrame.kurt)\n\n    def plot_single(self, col, ax):\n        self._skew[col].plot(ax=ax, label=\"skew\")\n        ax.set_xlabel(\"\")\n        ax.set_ylabel(\"skew\")\n        ax.legend()\n\n        right_ax = ax.twinx()\n\n        self._kurt[col].plot(ax=right_ax, label=\"kurt\", color=\"green\")\n        right_ax.set_xlabel(\"\")\n        right_ax.set_ylabel(\"kurt\")\n        right_ax.grid(None)  # set the grid to None to avoid two layers of grid lines\n\n        h1, l1 = ax.get_legend_handles_labels()\n        h2, l2 = 
right_ax.get_legend_handles_labels()\n\n        ax.legend().set_visible(False)\n        right_ax.legend(h1 + h2, l1 + l2)\n        ax.set_title(col)\n\n\nclass FeaMeanStd(NumFeaAnalyser):\n    def calc_stat_values(self):\n        self._std = self._dataset.groupby(DT_COL_NAME, group_keys=False).std()\n        self._mean = self._dataset.groupby(DT_COL_NAME, group_keys=False).mean()\n\n    def plot_single(self, col, ax):\n        self._mean[col].plot(ax=ax, label=\"mean\")\n        ax.set_xlabel(\"\")\n        ax.set_ylabel(\"mean\")\n        ax.legend()\n        ax.tick_params(axis=\"x\", rotation=90)\n\n        right_ax = ax.twinx()\n\n        self._std[col].plot(ax=right_ax, label=\"std\", color=\"green\")\n        right_ax.set_xlabel(\"\")\n        right_ax.set_ylabel(\"std\")\n        right_ax.tick_params(axis=\"x\", rotation=90)\n        right_ax.grid(None)  # set the grid to None to avoid two layers of grid lines\n\n        h1, l1 = ax.get_legend_handles_labels()\n        h2, l2 = right_ax.get_legend_handles_labels()\n\n        ax.legend().set_visible(False)\n        right_ax.legend(h1 + h2, l1 + l2)\n        ax.set_title(col)\n\n\nclass RawFeaAna(FeaAnalyser):\n    \"\"\"\n    Motivation:\n    - display the values without further analysis\n    \"\"\"\n\n    def calc_stat_values(self):\n        ymin, ymax = self._dataset.min().min(), self._dataset.max().max()\n        self.ylim = (ymin - 0.05 * (ymax - ymin), ymax + 0.05 * (ymax - ymin))\n\n    def plot_single(self, col, ax):\n        self._dataset[col].plot(ax=ax, title=col, ylim=self.ylim)\n        ax.set_xlabel(\"\")\n"
  },
  {
    "path": "qlib/contrib/report/data/base.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nThis module is responsible for analysing data\n\nAssumptions\n- The analysers analyse each feature individually\n\n\"\"\"\n\nimport pandas as pd\nfrom qlib.log import TimeInspector\nfrom qlib.contrib.report.utils import sub_fig_generator\n\n\nclass FeaAnalyser:\n    def __init__(self, dataset: pd.DataFrame):\n        \"\"\"\n\n        Parameters\n        ----------\n        dataset : pd.DataFrame\n\n            The dataset often has multiple columns. Each column corresponds to one subfigure.\n            There must be a `datetime` level in the index.\n            Aggregation over datetime is used to produce more summarized metrics over time.\n            Here is an example of data:\n\n            .. code-block::\n\n                                            return\n                datetime   instrument\n                2007-02-06 equity_tpx     0.010087\n                           equity_spx     0.000786\n        \"\"\"\n        self._dataset = dataset\n        with TimeInspector.logt(\"calc_stat_values\"):\n            self.calc_stat_values()\n\n    def calc_stat_values(self):\n        pass\n\n    def plot_single(self, col, ax):\n        raise NotImplementedError(f\"This type of input is not supported\")\n\n    def skip(self, col):\n        return False\n\n    def plot_all(self, *args, **kwargs):\n        ax_gen = iter(sub_fig_generator(*args, **kwargs))\n        for col in self._dataset:\n            if not self.skip(col):\n                ax = next(ax_gen)\n                self.plot_single(col, ax)\n"
  },
  {
    "path": "qlib/contrib/report/graph.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport math\nimport importlib\nfrom typing import Iterable\n\nimport pandas as pd\n\nimport plotly.offline as py\nimport plotly.graph_objs as go\n\nfrom plotly.subplots import make_subplots\nfrom plotly.figure_factory import create_distplot\n\n\nclass BaseGraph:\n    _name = None\n\n    def __init__(\n        self, df: pd.DataFrame = None, layout: dict = None, graph_kwargs: dict = None, name_dict: dict = None, **kwargs\n    ):\n        \"\"\"\n\n        :param df:\n        :param layout:\n        :param graph_kwargs:\n        :param name_dict:\n        :param kwargs:\n            layout: dict\n                go.Layout parameters\n            graph_kwargs: dict\n                Graph parameters, eg: go.Bar(**graph_kwargs)\n        \"\"\"\n        self._df = df\n\n        self._layout = dict() if layout is None else layout\n        self._graph_kwargs = dict() if graph_kwargs is None else graph_kwargs\n        self._name_dict = name_dict\n\n        self.data = None\n\n        self._init_parameters(**kwargs)\n        self._init_data()\n\n    def _init_data(self):\n        \"\"\"\n\n        :return:\n        \"\"\"\n        if self._df.empty:\n            raise ValueError(\"df is empty.\")\n\n        self.data = self._get_data()\n\n    def _init_parameters(self, **kwargs):\n        \"\"\"\n\n        :param kwargs\n        \"\"\"\n\n        # Instantiate graphics parameters\n        self._graph_type = self._name.lower().capitalize()\n\n        # Displayed column name\n        if self._name_dict is None:\n            self._name_dict = {_item: _item for _item in self._df.columns}\n\n    @staticmethod\n    def get_instance_with_graph_parameters(graph_type: str = None, **kwargs):\n        \"\"\"\n\n        :param graph_type:\n        :param kwargs:\n        :return:\n        \"\"\"\n        try:\n            _graph_module = importlib.import_module(\"plotly.graph_objs\")\n            
_graph_class = getattr(_graph_module, graph_type)\n        except AttributeError:\n            _graph_module = importlib.import_module(\"qlib.contrib.report.graph\")\n            _graph_class = getattr(_graph_module, graph_type)\n        return _graph_class(**kwargs)\n\n    @staticmethod\n    def show_graph_in_notebook(figure_list: Iterable[go.Figure] = None):\n        \"\"\"\n\n        :param figure_list:\n        :return:\n        \"\"\"\n        py.init_notebook_mode()\n        for _fig in figure_list:\n            # NOTE: displays figures: https://plotly.com/python/renderers/\n            # default: plotly_mimetype+notebook\n            # support renderers: import plotly.io as pio; print(pio.renderers)\n            renderer = None\n            try:\n                # in notebook\n                _ipykernel = str(type(get_ipython()))\n                if \"google.colab\" in _ipykernel:\n                    renderer = \"colab\"\n            except NameError:\n                pass\n\n            _fig.show(renderer=renderer)\n\n    def _get_layout(self) -> go.Layout:\n        \"\"\"\n\n        :return:\n        \"\"\"\n        return go.Layout(**self._layout)\n\n    def _get_data(self) -> list:\n        \"\"\"\n\n        :return:\n        \"\"\"\n\n        _data = [\n            self.get_instance_with_graph_parameters(\n                graph_type=self._graph_type, x=self._df.index, y=self._df[_col], name=_name, **self._graph_kwargs\n            )\n            for _col, _name in self._name_dict.items()\n        ]\n        return _data\n\n    @property\n    def figure(self) -> go.Figure:\n        \"\"\"\n\n        :return:\n        \"\"\"\n        _figure = go.Figure(data=self.data, layout=self._get_layout())\n        # NOTE: Use the default theme from plotly version 3.x, template=None\n        _figure[\"layout\"].update(template=None)\n        return _figure\n\n\nclass ScatterGraph(BaseGraph):\n    _name = \"scatter\"\n\n\nclass BarGraph(BaseGraph):\n    _name = 
\"bar\"\n\n\nclass DistplotGraph(BaseGraph):\n    _name = \"distplot\"\n\n    def _get_data(self):\n        \"\"\"\n\n        :return:\n        \"\"\"\n        _t_df = self._df.dropna()\n        _data_list = [_t_df[_col] for _col in self._name_dict]\n        _label_list = list(self._name_dict.values())\n        _fig = create_distplot(_data_list, _label_list, show_rug=False, **self._graph_kwargs)\n\n        return _fig[\"data\"]\n\n\nclass HeatmapGraph(BaseGraph):\n    _name = \"heatmap\"\n\n    def _get_data(self):\n        \"\"\"\n\n        :return:\n        \"\"\"\n        _data = [\n            self.get_instance_with_graph_parameters(\n                graph_type=self._graph_type,\n                x=self._df.columns,\n                y=self._df.index,\n                z=self._df.values.tolist(),\n                **self._graph_kwargs,\n            )\n        ]\n        return _data\n\n\nclass HistogramGraph(BaseGraph):\n    _name = \"histogram\"\n\n    def _get_data(self):\n        \"\"\"\n\n        :return:\n        \"\"\"\n        _data = [\n            self.get_instance_with_graph_parameters(\n                graph_type=self._graph_type, x=self._df[_col], name=_name, **self._graph_kwargs\n            )\n            for _col, _name in self._name_dict.items()\n        ]\n        return _data\n\n\nclass SubplotsGraph:\n    \"\"\"Create subplots same as df.plot(subplots=True)\n\n    Simple package for `plotly.tools.subplots`\n    \"\"\"\n\n    def __init__(\n        self,\n        df: pd.DataFrame = None,\n        kind_map: dict = None,\n        layout: dict = None,\n        sub_graph_layout: dict = None,\n        sub_graph_data: list = None,\n        subplots_kwargs: dict = None,\n        **kwargs,\n    ):\n        \"\"\"\n\n        :param df: pd.DataFrame\n\n        :param kind_map: dict, subplots graph kind and kwargs\n            eg: dict(kind='ScatterGraph', kwargs=dict())\n\n        :param layout: `go.Layout` parameters\n\n        :param sub_graph_layout: 
Layout of each graphic, similar to 'layout'\n\n        :param sub_graph_data: Instantiation parameters for each sub-graphic\n            eg: [(column_name, instance_parameters), ]\n\n            column_name: str or go.Figure\n\n            Instance_parameters:\n\n                - row: int, the row where the graph is located\n\n                - col: int, the col where the graph is located\n\n                - name: str, show name, default column_name in 'df'\n\n                - kind: str, graph kind, default `kind` param, eg: bar, scatter, ...\n\n                - graph_kwargs: dict, graph kwargs, default {}, used in `go.Bar(**graph_kwargs)`\n\n        :param subplots_kwargs: `plotly.tools.make_subplots` original parameters\n\n                - shared_xaxes: bool, default False\n\n                - shared_yaxes: bool, default False\n\n                - vertical_spacing: float, default 0.3 / rows\n\n                - subplot_titles: list, default []\n                    If `sub_graph_data` is None, will generate 'subplot_titles' according to `df.columns`,\n                    this field will be discarded\n\n\n                - specs: list, see `make_subplots` docs\n\n                - rows: int, Number of rows in the subplot grid, default 1\n                    If `sub_graph_data` is None, will generate 'rows' according to `df`, this field will be discarded\n\n                - cols: int, Number of cols in the subplot grid, default 1\n                    If `sub_graph_data` is None, will generate 'cols' according to `df`, this field will be discarded\n\n\n        :param kwargs:\n\n        \"\"\"\n\n        self._df = df\n        self._layout = layout\n        self._sub_graph_layout = sub_graph_layout\n\n        self._kind_map = kind_map\n        if self._kind_map is None:\n            self._kind_map = dict(kind=\"ScatterGraph\", kwargs=dict())\n\n        self._subplots_kwargs = subplots_kwargs\n        if self._subplots_kwargs is None:\n            
self._init_subplots_kwargs()\n\n        self.__cols = self._subplots_kwargs.get(\"cols\", 2)  # pylint: disable=W0238\n        self.__rows = self._subplots_kwargs.get(  # pylint: disable=W0238\n            \"rows\", math.ceil(len(self._df.columns) / self.__cols)\n        )\n\n        self._sub_graph_data = sub_graph_data\n        if self._sub_graph_data is None:\n            self._init_sub_graph_data()\n\n        self._init_figure()\n\n    def _init_sub_graph_data(self):\n        \"\"\"\n\n        :return:\n        \"\"\"\n        self._sub_graph_data = []\n        self._subplot_titles = []\n\n        for i, column_name in enumerate(self._df.columns):\n            row = math.ceil((i + 1) / self.__cols)\n            _temp = (i + 1) % self.__cols\n            col = _temp if _temp else self.__cols\n            res_name = column_name.replace(\"_\", \" \")\n            _temp_row_data = (\n                column_name,\n                dict(\n                    row=row,\n                    col=col,\n                    name=res_name,\n                    kind=self._kind_map[\"kind\"],\n                    graph_kwargs=self._kind_map[\"kwargs\"],\n                ),\n            )\n            self._sub_graph_data.append(_temp_row_data)\n            self._subplot_titles.append(res_name)\n\n    def _init_subplots_kwargs(self):\n        \"\"\"\n\n        :return:\n        \"\"\"\n        # Default cols, rows\n        _cols = 2\n        _rows = math.ceil(len(self._df.columns) / 2)\n        self._subplots_kwargs = dict()\n        self._subplots_kwargs[\"rows\"] = _rows\n        self._subplots_kwargs[\"cols\"] = _cols\n        self._subplots_kwargs[\"shared_xaxes\"] = False\n        self._subplots_kwargs[\"shared_yaxes\"] = False\n        self._subplots_kwargs[\"vertical_spacing\"] = 0.3 / _rows\n        self._subplots_kwargs[\"print_grid\"] = False\n        self._subplots_kwargs[\"subplot_titles\"] = self._df.columns.tolist()\n\n    def _init_figure(self):\n        
\"\"\"\n\n        :return:\n        \"\"\"\n        self._figure = make_subplots(**self._subplots_kwargs)\n\n        for column_name, column_map in self._sub_graph_data:\n            if isinstance(column_name, go.Figure):\n                _graph_obj = column_name\n            elif isinstance(column_name, str):\n                temp_name = column_map.get(\"name\", column_name.replace(\"_\", \" \"))\n                kind = column_map.get(\"kind\", self._kind_map.get(\"kind\", \"ScatterGraph\"))\n                _graph_kwargs = column_map.get(\"graph_kwargs\", self._kind_map.get(\"kwargs\", {}))\n                _graph_obj = BaseGraph.get_instance_with_graph_parameters(\n                    kind,\n                    **dict(\n                        df=self._df.loc[:, [column_name]],\n                        name_dict={column_name: temp_name},\n                        graph_kwargs=_graph_kwargs,\n                    ),\n                )\n            else:\n                raise TypeError()\n\n            row = column_map[\"row\"]\n            col = column_map[\"col\"]\n\n            _graph_data = getattr(_graph_obj, \"data\")\n            # for _item in _graph_data:\n            #     _item.pop('xaxis', None)\n            #     _item.pop('yaxis', None)\n\n            for _g_obj in _graph_data:\n                self._figure.add_trace(_g_obj, row=row, col=col)\n\n        if self._sub_graph_layout is not None:\n            for k, v in self._sub_graph_layout.items():\n                self._figure[\"layout\"][k].update(v)\n\n        # NOTE: Use the default theme from plotly version 3.x: template=None\n        self._figure[\"layout\"].update(template=None)\n        self._figure[\"layout\"].update(self._layout)\n\n    @property\n    def figure(self):\n        return self._figure\n"
  },
  {
    "path": "qlib/contrib/report/utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport matplotlib.pyplot as plt\nimport pandas as pd\n\n\ndef sub_fig_generator(sub_figsize=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):\n    \"\"\"sub_fig_generator.\n    It returns a generator; each row contains <col_n> subgraphs.\n\n    FIXME: Known limitation:\n    - The last row will not be plotted automatically, please plot it outside the function\n\n    Parameters\n    ----------\n    sub_figsize :\n        the figure size of each subgraph in <col_n> * <row_n> subgraphs\n    col_n :\n        the number of subgraphs in each row; a new figure is generated after <col_n> columns of subgraphs have been yielded.\n    row_n :\n        the number of subgraphs in each column\n    wspace :\n        the width of the space for subgraphs in each row\n    hspace :\n        the height of blank space for subgraphs in each column\n        You can try 0.3 if you feel it is too crowded\n\n    Returns\n    -------\n    Each iteration yields one column of axes with <row_n> elements (squeezed, so a single Axes when row_n == 1).\n    \"\"\"\n    assert col_n > 1\n\n    while True:\n        fig, axes = plt.subplots(\n            row_n, col_n, figsize=(sub_figsize[0] * col_n, sub_figsize[1] * row_n), sharex=sharex, sharey=sharey\n        )\n        plt.subplots_adjust(wspace=wspace, hspace=hspace)\n        axes = axes.reshape(row_n, col_n)\n\n        for col in range(col_n):\n            res = axes[:, col].squeeze()\n            if res.size == 1:\n                res = res.item()\n            yield res\n        plt.show()\n\n\ndef guess_plotly_rangebreaks(dt_index: pd.DatetimeIndex):\n    \"\"\"\n    This function `guesses` the rangebreaks required to remove gaps in datetime index.\n    It basically calculates the difference between a `continuous` datetime index and the given index.\n\n    For more details on `rangebreaks` params in plotly, see\n    
https://plotly.com/python/reference/layout/xaxis/#layout-xaxis-rangebreaks\n\n    Parameters\n    ----------\n    dt_index: pd.DatetimeIndex\n    The datetimes of the data.\n\n    Returns\n    -------\n    the `rangebreaks` to be passed into plotly axis.\n\n    \"\"\"\n    dt_idx = dt_index.sort_values()\n    gaps = dt_idx[1:] - dt_idx[:-1]\n    min_gap = gaps.min()\n    gaps_to_break = {}\n    for gap, d in zip(gaps, dt_idx[:-1]):\n        if gap > min_gap:\n            gaps_to_break.setdefault(gap - min_gap, []).append(d + min_gap)\n    return [dict(values=v, dvalue=int(k.total_seconds() * 1000)) for k, v in gaps_to_break.items()]\n"
  },
  {
    "path": "qlib/contrib/rolling/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nThe difference between me and the scripts in examples/benchmarks/benchmarks_dynamic:\n- This module only focuses on providing a general rolling implementation.\n  Anything specific to the benchmarks is placed in examples/benchmarks/benchmarks_dynamic\n\"\"\"\n"
  },
  {
    "path": "qlib/contrib/rolling/__main__.py",
    "content": "import fire\nfrom qlib import auto_init\nfrom qlib.contrib.rolling.base import Rolling\nfrom qlib.utils.mod import find_all_classes\n\nif __name__ == \"__main__\":\n    sub_commands = {}\n    for cls in find_all_classes(\"qlib.contrib.rolling\", Rolling):\n        sub_commands[cls.__module__.split(\".\")[-1]] = cls\n    # The sub_commands will look like\n    # {'base': <class 'qlib.contrib.rolling.base.Rolling'>, ...}\n    # So you can run it with commands like the one below\n    # - `python -m qlib.contrib.rolling base --conf_path <path to the yaml> run`\n    # - `base` can be replaced with other module names (e.g. `ddgda`)\n    auto_init()\n    fire.Fire(sub_commands)\n"
  },
  {
    "path": "qlib/contrib/rolling/base.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom copy import deepcopy\nfrom pathlib import Path\nfrom ruamel.yaml import YAML\nfrom typing import List, Optional, Union\n\nimport fire\nimport pandas as pd\n\nfrom qlib import auto_init\nfrom qlib.log import get_module_logger\nfrom qlib.model.ens.ensemble import RollingEnsemble\nfrom qlib.model.trainer import TrainerR\nfrom qlib.utils import get_cls_kwargs, init_instance_by_config\nfrom qlib.utils.data import update_config\nfrom qlib.workflow import R\nfrom qlib.workflow.record_temp import SignalRecord\nfrom qlib.workflow.task.collect import RecorderCollector\nfrom qlib.workflow.task.gen import RollingGen, task_generator\nfrom qlib.workflow.task.utils import replace_task_handler_with_cache\n\n\nclass Rolling:\n    \"\"\"\n    The motivation of the Rolling module\n    - It only focuses on turning a specific task into rolling tasks **offline**\n    - To make the implementation easier, the following factors are ignored:\n        - The tasks may depend on each other (e.g. time series).\n\n    Related modules and how they differ from me:\n    - MetaController: It learns how to handle a task (e.g. learning to learn).\n        - But rolling is about how to split a single task into tasks in time series and run them.\n    - OnlineStrategy: It focuses on serving a model; the model can be updated over time.\n        - Rolling is much simpler and is only for testing rolling models offline. It intentionally does not share the interface with OnlineStrategy.\n\n    The rolling code shared between me and the above modules is at the `task_generator` & `RollingGen` level.\n    But the purposes differ, so the other parts are not shared.\n\n\n    .. 
code-block:: shell\n\n        # here is a typical use case of the module.\n        python -m qlib.contrib.rolling.base --conf_path <path to the yaml> run\n\n    **NOTE**\n    before running the example, please clean your previous results with the following command\n    - `rm -r mlruns`\n    - Because it is very hard to permanently delete an experiment (it will be moved into .trash, and an error will be raised when creating an experiment with the same name).\n\n    \"\"\"\n\n    def __init__(\n        self,\n        conf_path: Union[str, Path],\n        exp_name: Optional[str] = None,\n        horizon: Optional[int] = 20,\n        step: int = 20,\n        h_path: Optional[str] = None,\n        train_start: Optional[str] = None,\n        test_end: Optional[str] = None,\n        task_ext_conf: Optional[dict] = None,\n        rolling_exp: Optional[str] = None,\n    ) -> None:\n        \"\"\"\n        Parameters\n        ----------\n        conf_path : str\n            Path to the config for rolling.\n        exp_name : Optional[str]\n            The exp name of the outputs (Output is a record which contains the concatenated predictions of rolling records).\n        horizon : Optional[int]\n            The horizon of the prediction target.\n            This is used to override the prediction horizon of the config file.\n        step : int\n            The rolling step; it is passed to `RollingGen` as its `step`.\n        h_path : Optional[str]\n            Another data source that is dumped as a handler. It will override the data handler section in the config.\n            If it is not given, it will create a customized cache for the handler when `enable_handler_cache=True`\n        test_end : Optional[str]\n            the test end for the data. It is typically used together with the handler.\n            You can do the same thing with task_ext_conf in a more complicated way\n        train_start : Optional[str]\n            the train start for the data.  
It is typically used together with the handler.\n            You can do the same thing with task_ext_conf in a more complicated way\n        task_ext_conf : Optional[dict]\n            some options to update the task config.\n        rolling_exp : Optional[str]\n            The name of the experiment for rolling.\n            It will contain many records in one experiment. Each record corresponds to a specific rolling.\n            Please note that it is different from the final experiment\n        \"\"\"\n        self.logger = get_module_logger(\"Rolling\")\n        self.conf_path = Path(conf_path)\n        self.exp_name = exp_name\n        self._rid = None  # the final combined recorder id in `exp_name`\n\n        self.step = step\n        assert horizon is not None, \"Current version does not support extracting horizon from the underlying dataset\"\n        self.horizon = horizon\n        if rolling_exp is None:\n            datetime_suffix = pd.Timestamp.now().strftime(\"%Y%m%d%H%M%S\")\n            self.rolling_exp = f\"rolling_models_{datetime_suffix}\"\n        else:\n            self.rolling_exp = rolling_exp\n            self.logger.warning(\n                \"Using a user-specified name for rolling models, so the experiment name may be duplicated. 
\"\n                \"Please manually remove the experiment for rolling models with a command like `rm -r mlruns`.\"\n                \" Otherwise, it will prevent creating an experiment with the same name\"\n            )\n        self.train_start = train_start\n        self.test_end = test_end\n        self.task_ext_conf = task_ext_conf\n        self.h_path = h_path\n\n        # FIXME:\n        # - the qlib_init section will be ignored by me.\n        # - So we have to design a priority mechanism to solve this issue.\n\n    def _raw_conf(self) -> dict:\n        with self.conf_path.open(\"r\") as f:\n            yaml = YAML(typ=\"safe\", pure=True)\n            return yaml.load(f)\n\n    def _replace_handler_with_cache(self, task: dict):\n        \"\"\"\n        The data processing in the original rolling is slow, so the handler is replaced\n        with the handler dumped at `h_path` (if given) or with a cached handler.\n        \"\"\"\n        if self.h_path is not None:\n            h_path = Path(self.h_path)\n            task[\"dataset\"][\"kwargs\"][\"handler\"] = f\"file://{h_path}\"\n        else:\n            task = replace_task_handler_with_cache(task, self.conf_path.parent)\n        return task\n\n    def _update_start_end_time(self, task: dict):\n        if self.train_start is not None:\n            seg = task[\"dataset\"][\"kwargs\"][\"segments\"][\"train\"]\n            task[\"dataset\"][\"kwargs\"][\"segments\"][\"train\"] = pd.Timestamp(self.train_start), seg[1]\n\n        if self.test_end is not None:\n            seg = task[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"]\n            task[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"] = seg[0], pd.Timestamp(self.test_end)\n        return task\n\n    def basic_task(self, enable_handler_cache: Optional[bool] = True):\n        \"\"\"\n        The basic task may not be exactly the same as the config from `conf_path` in __init__ because\n        - some parameters could be overridden by parameters from __init__\n        - user could 
implement a subclass to change it for higher performance\n        \"\"\"\n        task: dict = self._raw_conf()[\"task\"]\n        task = deepcopy(task)\n\n        # modify dataset horizon\n        # NOTE:\n        # It assumes that the label can be modified in the handler's kwargs\n        # But this is not always valid. It is only valid for the predefined datasets `Alpha158` & `Alpha360`\n        if self.horizon is None:\n            # TODO:\n            # - get horizon automatically from the expression!!!!\n            raise NotImplementedError(f\"This type of input is not supported\")\n        else:\n            if enable_handler_cache and self.h_path is not None:\n                self.logger.info(\"Failed to override the horizon due to data handler cache\")\n            else:\n                self.logger.info(\"The prediction horizon is overridden\")\n                if isinstance(task[\"dataset\"][\"kwargs\"][\"handler\"], dict):\n                    task[\"dataset\"][\"kwargs\"][\"handler\"][\"kwargs\"][\"label\"] = [\n                        \"Ref($close, -{}) / Ref($close, -1) - 1\".format(self.horizon + 1)\n                    ]\n                else:\n                    self.logger.warning(\"Tried to automatically configure the label but failed.\")\n\n        if self.h_path is not None or enable_handler_cache:\n            # if we already have provided data source or we want to create one\n            task = self._replace_handler_with_cache(task)\n        task = self._update_start_end_time(task)\n\n        if self.task_ext_conf is not None:\n            task = update_config(task, self.task_ext_conf)\n        self.logger.info(task)\n        return task\n\n    def run_basic_task(self):\n        \"\"\"\n        Run the basic task without rolling.\n        This is for fast testing during model tuning.\n        \"\"\"\n        task = self.basic_task()\n        print(task)\n        trainer = TrainerR(experiment_name=self.exp_name)\n        trainer([task])\n\n    def 
get_task_list(self) -> List[dict]:\n        \"\"\"return a batch of tasks for rolling.\"\"\"\n        task = self.basic_task()\n        task_l = task_generator(\n            task, RollingGen(step=self.step, trunc_days=self.horizon + 1)\n        )  # the last `horizon + 1` days are truncated to avoid information leakage\n        for t in task_l:\n            # When rolling tasks, no further analysis is needed.\n            # Analysis is postponed to the final ensemble.\n            t[\"record\"] = [\"qlib.workflow.record_temp.SignalRecord\"]\n        return task_l\n\n    def _train_rolling_tasks(self):\n        task_l = self.get_task_list()\n        self.logger.info(\"Deleting previous Rolling results\")\n        try:\n            # TODO: mlflow does not support permanently deleting an experiment\n            # it will be moved to .trash and prevents creating experiments with the same name\n            R.delete_exp(experiment_name=self.rolling_exp)  # We should remove the rolling experiments.\n        except ValueError:\n            self.logger.info(\"No previous rolling results\")\n        trainer = TrainerR(experiment_name=self.rolling_exp)\n        trainer(task_l)\n\n    def _ens_rolling(self):\n        rc = RecorderCollector(\n            experiment=self.rolling_exp,\n            artifacts_key=[\"pred\", \"label\"],\n            process_list=[RollingEnsemble()],\n            # rec_key_func=lambda rec: (self.COMB_EXP, rec.info[\"id\"]),\n            artifacts_path={\"pred\": \"pred.pkl\", \"label\": \"label.pkl\"},\n        )\n        res = rc()\n        with R.start(experiment_name=self.exp_name):\n            R.log_params(exp_name=self.rolling_exp)\n            R.save_objects(**{\"pred.pkl\": res[\"pred\"], \"label.pkl\": res[\"label\"]})\n            self._rid = R.get_recorder().id\n\n    def _update_rolling_rec(self):\n        \"\"\"\n        Evaluate the combined rolling results\n        \"\"\"\n        rec = R.get_recorder(experiment_name=self.exp_name, 
recorder_id=self._rid)\n        # Follow the original analyser\n        records = self._raw_conf()[\"task\"].get(\"record\", [])\n        if isinstance(records, dict):  # wrap a single record config in a list\n            records = [records]\n        for record in records:\n            if issubclass(get_cls_kwargs(record)[0], SignalRecord):\n                # skip the signal record.\n                continue\n            r = init_instance_by_config(\n                record,\n                recorder=rec,\n                default_module=\"qlib.workflow.record_temp\",\n            )\n            r.generate()\n        print(f\"Your evaluation results can be found in the experiment named `{self.exp_name}`.\")\n\n    def run(self):\n        # the results will be saved in mlruns.\n        # 1) each rolling task is saved in rolling_models\n        self._train_rolling_tasks()\n        # 2) combined rolling tasks and evaluation results are saved in rolling\n        self._ens_rolling()\n        self._update_rolling_rec()\n\n\nif __name__ == \"__main__\":\n    auto_init()\n    fire.Fire(Rolling)\n"
  },
  {
    "path": "qlib/contrib/rolling/ddgda.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom pathlib import Path\nimport pickle\nfrom typing import Optional, Union\n\nimport pandas as pd\nimport yaml\n\nfrom qlib.contrib.meta.data_selection.dataset import InternalData, MetaDatasetDS\nfrom qlib.contrib.meta.data_selection.model import MetaModelDS\nfrom qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.model.meta.task import MetaTask\nfrom qlib.model.trainer import TrainerR\nfrom qlib.typehint import Literal\nfrom qlib.utils import init_instance_by_config\nfrom qlib.utils.pickle_utils import restricted_pickle_load\nfrom qlib.workflow import R\nfrom qlib.workflow.task.utils import replace_task_handler_with_cache\n\nfrom .base import Rolling\n\n# The LGBM config is used for feature importance & similarity\nLGBM_MODEL = \"\"\"\nclass: LGBModel\nmodule_path: qlib.contrib.model.gbdt\nkwargs:\n    loss: mse\n    colsample_bytree: 0.8879\n    learning_rate: 0.2\n    subsample: 0.8789\n    lambda_l1: 205.6999\n    lambda_l2: 580.9768\n    max_depth: 8\n    num_leaves: 210\n    num_threads: 20\n\"\"\"\n# convert the yaml to dict\nLGBM_MODEL = yaml.load(LGBM_MODEL, Loader=yaml.FullLoader)\n\nLINEAR_MODEL = \"\"\"\nclass: LinearModel\nmodule_path: qlib.contrib.model.linear\nkwargs:\n    estimator: ridge\n    alpha: 0.05\n\"\"\"\nLINEAR_MODEL = yaml.load(LINEAR_MODEL, Loader=yaml.FullLoader)\n\nPROC_ARGS = \"\"\"\ninfer_processors:\n    - class: RobustZScoreNorm\n      kwargs:\n          fields_group: feature\n          clip_outlier: true\n    - class: Fillna\n      kwargs:\n          fields_group: feature\nlearn_processors:\n    - class: DropnaLabel\n    - class: CSRankNorm\n      kwargs:\n          fields_group: label\n\"\"\"\nPROC_ARGS = yaml.load(PROC_ARGS, Loader=yaml.FullLoader)\n\nUTIL_MODEL_TYPE = Literal[\"linear\", \"gbdt\"]\n\n\nclass DDGDA(Rolling):\n    \"\"\"\n    A rolling workflow based on DDG-DA\n\n    **NOTE**\n    before running the example, please clean your 
previous results with the following command\n    - `rm -r mlruns`\n    \"\"\"\n\n    def __init__(\n        self,\n        sim_task_model: UTIL_MODEL_TYPE = \"gbdt\",\n        meta_1st_train_end: Optional[str] = None,\n        alpha: float = 0.01,\n        loss_skip_thresh: int = 50,\n        fea_imp_n: Optional[int] = 30,\n        meta_data_proc: Optional[str] = \"V01\",\n        segments: Union[float, str] = 0.62,\n        hist_step_n: int = 30,\n        working_dir: Optional[Union[str, Path]] = None,\n        **kwargs,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        sim_task_model: Literal[\"linear\", \"gbdt\"] = \"gbdt\"\n            The model for calculating similarity between data.\n        meta_1st_train_end: Optional[str]\n            The end datetime of the training segment of the first meta task (\"2010-12-31\" if None)\n        alpha: float\n            The L2 regularization strength for ridge.\n            The `alpha` is only passed to MetaModelDS (it is currently not passed to sim_task_model).\n        loss_skip_thresh: int\n            The threshold for skipping the loss calculation for each day. 
If the number of items on a day is less than it, the loss on that day is skipped.\n        fea_imp_n: Optional[int]\n            The number of most important features (ranked by LightGBM feature importance) kept for the proxy model.\n            If None, all features are kept.\n        meta_data_proc : Optional[str]\n            How we process the meta dataset for learning the meta model.\n        segments : Union[float, str]\n            if segments is a float:\n                The ratio of training data in the meta task dataset\n            if segments is a string:\n                it will try its best to put its data in training and ensure that the date `segments` is in the test set\n        hist_step_n: int\n            Passed as `hist_step_n` to both MetaDatasetDS and MetaModelDS.\n        working_dir: Optional[Union[str, Path]]\n            The directory for intermediate files (e.g. `handler_proxy.pkl`).\n            Defaults to the parent directory of the config file.\n        \"\"\"\n        # NOTE:\n        # the horizon must match the meaning in the base task template\n        self.meta_exp_name = \"DDG-DA\"\n        self.sim_task_model: UTIL_MODEL_TYPE = sim_task_model  # The model to capture the distribution of data.\n        self.alpha = alpha\n        self.meta_1st_train_end = meta_1st_train_end\n        super().__init__(**kwargs)\n        self.working_dir = self.conf_path.parent if working_dir is None else Path(working_dir)\n        self.proxy_hd = self.working_dir / \"handler_proxy.pkl\"\n        self.fea_imp_n = fea_imp_n\n        self.meta_data_proc = meta_data_proc\n        self.loss_skip_thresh = loss_skip_thresh\n        self.segments = segments\n        self.hist_step_n = hist_step_n\n\n    def _adjust_task(self, task: dict, astype: UTIL_MODEL_TYPE):\n        \"\"\"\n        Based on the original task, we need to make some extra adjustments.\n\n        For example:\n        - GBDT for calculating feature importance\n        - Linear or GBDT for calculating similarity\n        - Dataset (well processed) aligned to the linear model for meta learning\n\n        So we may need to change the dataset and model for each special purpose while the other settings remain the same.\n        \"\"\"\n        # NOTE: here is just for aligning with previous implementation\n        # It is not necessary for the current implementation\n        handler = task[\"dataset\"].setdefault(\"kwargs\", {}).setdefault(\"handler\", {})\n        if astype == \"gbdt\":\n            
task[\"model\"] = LGBM_MODEL\n            if isinstance(handler, dict):\n                # We don't need preprocessing when using a GBDT model\n                for k in [\"infer_processors\", \"learn_processors\"]:\n                    if k in handler.setdefault(\"kwargs\", {}):\n                        handler[\"kwargs\"].pop(k)\n        elif astype == \"linear\":\n            task[\"model\"] = LINEAR_MODEL\n            if isinstance(handler, dict):\n                handler.setdefault(\"kwargs\", {}).update(PROC_ARGS)\n            else:\n                self.logger.warning(\"The handler can't be adjusted.\")\n        else:\n            raise ValueError(f\"astype not supported: {astype}\")\n        return task\n\n    def _get_feature_importance(self):\n        # this must be LightGBM, because we need its feature importance\n        task = self.basic_task(enable_handler_cache=False)\n        task = self._adjust_task(task, astype=\"gbdt\")\n        task = replace_task_handler_with_cache(task, self.working_dir)\n\n        with R.start(experiment_name=\"feature_importance\"):\n            model = init_instance_by_config(task[\"model\"])\n            dataset = init_instance_by_config(task[\"dataset\"])\n            model.fit(dataset)\n\n        fi = model.get_feature_importance()\n        # The model uses numpy arrays instead of a DataFrame to train LightGBM,\n        # so we must use the following extra steps to map the importance back to feature names\n        df = dataset.prepare(segments=slice(None), col_set=\"feature\", data_key=DataHandlerLP.DK_R)\n        cols = df.columns\n        fi_named = {cols[int(k.split(\"_\")[1])]: imp for k, imp in fi.to_dict().items()}\n\n        return pd.Series(fi_named)\n\n    def _dump_data_for_proxy_model(self):\n        \"\"\"\n        Dump data for training the meta model.\n        The meta model will be trained upon the proxy forecasting model.\n        This dataset is for the proxy forecasting model.\n        \"\"\"\n\n        # NOTE: adjusting to 
`self.sim_task_model` is just for aligning with the previous implementation.\n        # In the previous version, the data for the proxy model was processed in the same way as for sim_task_model\n        task = self._adjust_task(self.basic_task(enable_handler_cache=False), self.sim_task_model)\n        task = replace_task_handler_with_cache(task, self.working_dir)\n        # if self.meta_data_proc is not None:\n        # else:\n        #     # Otherwise, we don't need further processing\n        #     task = self.basic_task()\n\n        dataset = init_instance_by_config(task[\"dataset\"])\n        prep_ds = dataset.prepare(slice(None), col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L)\n\n        feature_df = prep_ds[\"feature\"]\n        label_df = prep_ds[\"label\"]\n\n        if self.fea_imp_n is not None:\n            fi = self._get_feature_importance()\n            col_selected = fi.nlargest(self.fea_imp_n)\n            feature_selected = feature_df.loc[:, col_selected.index]\n        else:\n            feature_selected = feature_df\n\n        if self.meta_data_proc == \"V01\":\n            feature_selected = feature_selected.groupby(\"datetime\", group_keys=False).apply(\n                lambda df: (df - df.mean()).div(df.std())\n            )\n            feature_selected = feature_selected.fillna(0.0)\n\n        df_all = {\n            \"label\": label_df.reindex(feature_selected.index),\n            \"feature\": feature_selected,\n        }\n        df_all = pd.concat(df_all, axis=1)\n        df_all.to_pickle(self.working_dir / \"fea_label_df.pkl\")\n\n        # dump data in handler format for aligning the interface\n        handler = DataHandlerLP(\n            data_loader={\n                \"class\": \"qlib.data.dataset.loader.StaticDataLoader\",\n                \"kwargs\": {\"config\": self.working_dir / \"fea_label_df.pkl\"},\n            }\n        )\n        handler.to_pickle(self.working_dir / self.proxy_hd, dump_all=True)\n\n    @property\n    def 
_internal_data_path(self):\n        return self.working_dir / f\"internal_data_s{self.step}.pkl\"\n\n    def _dump_meta_ipt(self):\n        \"\"\"\n        Dump data for training the meta model.\n        This function dumps the input data for the meta model.\n        \"\"\"\n        # According to the experiments, the choice of the model type is very important for achieving good results\n        sim_task = self._adjust_task(self.basic_task(enable_handler_cache=False), astype=self.sim_task_model)\n        sim_task = replace_task_handler_with_cache(sim_task, self.working_dir)\n\n        if self.sim_task_model == \"gbdt\":\n            sim_task[\"model\"].setdefault(\"kwargs\", {}).update({\"early_stopping_rounds\": None, \"num_boost_round\": 150})\n\n        exp_name_sim = f\"data_sim_s{self.step}\"\n\n        internal_data = InternalData(sim_task, self.step, exp_name=exp_name_sim)\n        internal_data.setup(trainer=TrainerR)\n\n        with self._internal_data_path.open(\"wb\") as f:\n            pickle.dump(internal_data, f)\n\n    def _train_meta_model(self, fill_method=\"max\"):\n        \"\"\"\n        Train a meta model based on a simplified linear proxy model.\n        \"\"\"\n\n        # 1) leverage the simplified proxy forecasting model to train the meta model.\n        # - Only the dataset part is important, in current version of meta model will integrate the\n\n        # NOTE:\n        # - The train_start for training the meta model does not necessarily align with the final rolling,\n        #   but please select an appropriate time to make sure the final rolling tasks are not leaked into the training data.\n        # - The test_start is automatically set to the day after train_end.  
Validation is ignored.\n        train_start = \"2008-01-01\" if self.train_start is None else self.train_start\n        train_end = \"2010-12-31\" if self.meta_1st_train_end is None else self.meta_1st_train_end\n        test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime(\"%Y-%m-%d\")\n        proxy_forecast_model_task = {\n            # \"model\": \"qlib.contrib.model.linear.LinearModel\",\n            \"dataset\": {\n                \"class\": \"qlib.data.dataset.DatasetH\",\n                \"kwargs\": {\n                    \"handler\": f\"file://{(self.working_dir / self.proxy_hd).absolute()}\",\n                    \"segments\": {\n                        \"train\": (train_start, train_end),\n                        \"test\": (test_start, self.basic_task()[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"][1]),\n                    },\n                },\n            },\n            # \"record\": [\"qlib.workflow.record_temp.SignalRecord\"]\n        }\n        # the proxy_forecast_model_task will be used to create meta tasks.\n        # With the default train_end, the test date of the first task will be 2011-01-01. Each test segment will be about 20 days\n        # The tasks include all training tasks and test tasks.\n\n        # 2) prepare the meta dataset\n        kwargs = dict(\n            task_tpl=proxy_forecast_model_task,\n            step=self.step,\n            segments=self.segments,  # keep test period consistent with the dataset yaml\n            trunc_days=1 + self.horizon,\n            hist_step_n=self.hist_step_n,\n            fill_method=fill_method,\n            rolling_ext_days=0,\n        )\n        # NOTE:\n        # the input of the meta model (internal data) is shared between the proxy model and the final forecasting model,\n        # but their task test segments are not aligned! 
It worked in my previous experiments,\n        # so the misalignment should not affect the effectiveness of the method.\n        with self._internal_data_path.open(\"rb\") as f:\n            internal_data = restricted_pickle_load(f)\n\n        md = MetaDatasetDS(exp_name=internal_data, **kwargs)\n\n        # 3) train and log the meta model\n        with R.start(experiment_name=self.meta_exp_name):\n            R.log_params(**kwargs)\n            mm = MetaModelDS(\n                step=self.step,\n                hist_step_n=kwargs[\"hist_step_n\"],\n                lr=0.001,\n                max_epoch=30,\n                seed=43,\n                alpha=self.alpha,\n                loss_skip_thresh=self.loss_skip_thresh,\n            )\n            mm.fit(md)\n            R.save_objects(model=mm)\n\n    @property\n    def _task_path(self):\n        return self.working_dir / f\"tasks_s{self.step}.pkl\"\n\n    def get_task_list(self):\n        \"\"\"\n        Leverage the meta model for inference:\n        - Given\n            - baseline tasks\n            - input for the meta model (internal data)\n            - meta model (its knowledge learned on the proxy forecasting model is expected to transfer to the normal forecasting model)\n        \"\"\"\n        # 1) get meta model\n        exp = R.get_exp(experiment_name=self.meta_exp_name)\n        rec = exp.list_recorders(rtype=exp.RT_L)[0]\n        meta_model: MetaModelDS = rec.load_object(\"model\")\n\n        # 2)\n        # Transfer the knowledge of the meta model to the final forecasting tasks.\n        # Create a MetaTaskDataset for the final forecasting tasks.\n        # Its settings must align with those of the MetaTaskDataset used to train the meta model.\n\n        # 2.1) get previous config\n        param = rec.list_params()\n        trunc_days = int(param[\"trunc_days\"])\n        step = int(param[\"step\"])\n        hist_step_n = int(param[\"hist_step_n\"])\n        fill_method = param.get(\"fill_method\", \"max\")\n\n        
task_l = super().get_task_list()\n\n        # 2.2) create the meta dataset for the final tasks\n        kwargs = dict(\n            task_tpl=task_l,\n            step=step,\n            segments=0.0,  # all the tasks are for testing\n            trunc_days=trunc_days,\n            hist_step_n=hist_step_n,\n            fill_method=fill_method,\n            task_mode=MetaTask.PROC_MODE_TRANSFER,\n        )\n\n        with self._internal_data_path.open(\"rb\") as f:\n            internal_data = restricted_pickle_load(f)\n        mds = MetaDatasetDS(exp_name=internal_data, **kwargs)\n\n        # 3) the meta model makes inference and produces new qlib tasks\n        new_tasks = meta_model.inference(mds)\n        with self._task_path.open(\"wb\") as f:\n            pickle.dump(new_tasks, f)\n        return new_tasks\n\n    def run(self):\n        # prepare the meta model for rolling ---------\n        # 1) file: handler_proxy.pkl (self.proxy_hd)\n        self._dump_data_for_proxy_model()\n        # 2)\n        # file: internal_data_s20.pkl\n        # mlflow: data_sim_s20, models for calculating meta_ipt\n        self._dump_meta_ipt()\n        # 3) meta model will be stored in `DDG-DA`\n        self._train_meta_model()\n\n        # Run rolling --------------------------------\n        # 4) new_tasks are saved in \"tasks_s20.pkl\" (reweighter is added)\n        # - the meta inference is done when calling `get_task_list`\n        # 5) load the saved tasks and train the models\n        super().run()\n"
  },
  {
    "path": "qlib/contrib/strategy/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom .signal_strategy import (\n    TopkDropoutStrategy,\n    WeightStrategyBase,\n    EnhancedIndexingStrategy,\n)\n\nfrom .rule_strategy import (\n    TWAPStrategy,\n    SBBStrategyBase,\n    SBBStrategyEMA,\n)\n\nfrom .cost_control import SoftTopkStrategy\n\n__all__ = [\n    \"TopkDropoutStrategy\",\n    \"WeightStrategyBase\",\n    \"EnhancedIndexingStrategy\",\n    \"TWAPStrategy\",\n    \"SBBStrategyBase\",\n    \"SBBStrategyEMA\",\n    \"SoftTopkStrategy\",\n]\n"
  },
  {
    "path": "qlib/contrib/strategy/cost_control.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .order_generator import OrderGenWInteract\nfrom .signal_strategy import WeightStrategyBase\n\n\nclass SoftTopkStrategy(WeightStrategyBase):\n    def __init__(\n        self,\n        model=None,\n        dataset=None,\n        topk=None,\n        order_generator_cls_or_obj=OrderGenWInteract,\n        max_sold_weight=1.0,\n        trade_impact_limit=None,\n        risk_degree=0.95,\n        buy_method=\"first_fill\",\n        **kwargs,\n    ):\n        \"\"\"\n        Refactored SoftTopkStrategy with a budget-constrained rebalancing engine.\n\n        Parameters\n        ----------\n        topk : int\n            The number of top-N stocks to be held in the portfolio.\n        trade_impact_limit : float\n            Maximum weight change for each stock in one trade. If None, it falls back to max_sold_weight.\n        max_sold_weight : float\n            Backward-compatible alias for trade_impact_limit. 
Use 1.0 to effectively disable the limit.\n        risk_degree : float\n            The target percentage of total value to be invested.\n        \"\"\"\n        super(SoftTopkStrategy, self).__init__(\n            model=model, dataset=dataset, order_generator_cls_or_obj=order_generator_cls_or_obj, **kwargs\n        )\n\n        self.topk = topk\n        self.trade_impact_limit = trade_impact_limit if trade_impact_limit is not None else max_sold_weight\n        self.risk_degree = risk_degree\n        self.buy_method = buy_method\n\n    def get_risk_degree(self, trade_step=None):\n        return self.risk_degree\n\n    def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time, **kwargs):\n        \"\"\"\n        Generates target position using Proportional Budget Allocation.\n        Ensures deterministic sells and synchronized buys under impact limits.\n        \"\"\"\n\n        if self.topk is None or self.topk <= 0:\n            return {}\n\n        def apply_impact_limit(weight):\n            return weight if self.trade_impact_limit is None else min(weight, self.trade_impact_limit)\n\n        ideal_per_stock = self.risk_degree / self.topk\n        ideal_list = score.sort_values(ascending=False).iloc[: self.topk].index.tolist()\n\n        cur_weights = current.get_stock_weight_dict(only_stock=True)\n        initial_total_weight = sum(cur_weights.values())\n\n        # --- Case A: Cold Start ---\n        if not cur_weights:\n            fill = apply_impact_limit(ideal_per_stock)\n            return {code: fill for code in ideal_list}\n\n        # --- Case B: Rebalancing ---\n        all_tickers = set(cur_weights.keys()) | set(ideal_list)\n        next_weights = {t: cur_weights.get(t, 0.0) for t in all_tickers}\n\n        # Phase 1: Deterministic Sell Phase\n        released_cash = 0.0\n        for t in list(next_weights.keys()):\n            cur = next_weights[t]\n            if cur <= 1e-8:\n                continue\n\n            
if t not in ideal_list:\n                sell = apply_impact_limit(cur)\n                next_weights[t] -= sell\n                released_cash += sell\n            elif cur > ideal_per_stock + 1e-8:\n                excess = cur - ideal_per_stock\n                sell = apply_impact_limit(excess)\n                next_weights[t] -= sell\n                released_cash += sell\n\n        # Phase 2: Budget Calculation\n        # Budget = weight released by sells + remaining room below the target risk degree\n        total_budget = released_cash + (self.risk_degree - initial_total_weight)\n\n        # Phase 3: Proportional Buy Allocation\n        if total_budget > 1e-8:\n            shortfalls = {\n                t: (ideal_per_stock - next_weights.get(t, 0.0))\n                for t in ideal_list\n                if next_weights.get(t, 0.0) < ideal_per_stock - 1e-8\n            }\n\n            if shortfalls:\n                total_shortfall = sum(shortfalls.values())\n                # Cap the amount to spend so that it does not exceed total_shortfall\n                available_to_spend = min(total_budget, total_shortfall)\n\n                for t, shortfall in shortfalls.items():\n                    # Each stock gets a share of the budget proportional to its distance to target\n                    share_of_budget = (shortfall / total_shortfall) * available_to_spend\n\n                    # Capped by impact limit\n                    max_buy_cap = apply_impact_limit(shortfall)\n\n                    next_weights[t] += min(share_of_budget, max_buy_cap)\n\n        return {k: v for k, v in next_weights.items() if v > 1e-8}\n"
  },
  {
    "path": "qlib/contrib/strategy/optimizer/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .base import BaseOptimizer\nfrom .optimizer import PortfolioOptimizer\nfrom .enhanced_indexing import EnhancedIndexingOptimizer\n\n__all__ = [\"BaseOptimizer\", \"PortfolioOptimizer\", \"EnhancedIndexingOptimizer\"]\n"
  },
  {
    "path": "qlib/contrib/strategy/optimizer/base.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport abc\n\n\nclass BaseOptimizer(abc.ABC):\n    \"\"\"Construct a portfolio with an optimization-related method\"\"\"\n\n    @abc.abstractmethod\n    def __call__(self, *args, **kwargs) -> object:\n        \"\"\"Generate an optimized portfolio allocation\"\"\"\n"
  },
  {
    "path": "qlib/contrib/strategy/optimizer/enhanced_indexing.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport cvxpy as cp\n\nfrom typing import Union, Optional, Dict, Any, List\n\nfrom qlib.log import get_module_logger\nfrom .base import BaseOptimizer\n\nlogger = get_module_logger(\"EnhancedIndexingOptimizer\")\n\n\nclass EnhancedIndexingOptimizer(BaseOptimizer):\n    \"\"\"\n    Portfolio Optimizer for Enhanced Indexing\n\n    Notations:\n        w0: current holding weights\n        wb: benchmark weight\n        r: expected return\n        F: factor exposure\n        cov_b: factor covariance\n        var_u: residual variance (diagonal)\n        lamb: risk aversion parameter\n        delta: total turnover limit\n        b_dev: benchmark deviation limit\n        f_dev: factor deviation limit\n\n    Also denote:\n        d = w - wb: benchmark deviation\n        v = d @ F: factor deviation\n\n    The optimization problem for enhanced indexing:\n        max_w  d @ r - lamb * (v @ cov_b @ v + var_u @ d**2)\n        s.t.   
w >= 0\n               sum(w) == 1\n               sum(|w - w0|) <= delta\n               d >= -b_dev\n               d <= b_dev\n               v >= -f_dev\n               v <= f_dev\n    \"\"\"\n\n    def __init__(\n        self,\n        lamb: float = 1,\n        delta: Optional[float] = 0.2,\n        b_dev: Optional[float] = 0.01,\n        f_dev: Optional[Union[List[float], np.ndarray]] = None,\n        scale_return: bool = True,\n        epsilon: float = 5e-5,\n        solver_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        \"\"\"\n        Args:\n            lamb (float): risk aversion parameter (larger `lamb` means more focus on risk)\n            delta (float): total turnover limit (None disables the turnover constraint)\n            b_dev (float): benchmark deviation limit (None disables it)\n            f_dev (float or list): factor deviation limit (None disables it)\n            scale_return (bool): whether to scale return to match estimated volatility\n            epsilon (float): minimum weight\n            solver_kwargs (dict): kwargs for cvxpy solver\n        \"\"\"\n\n        assert lamb >= 0, \"risk aversion parameter `lamb` should be non-negative\"\n        self.lamb = lamb\n\n        assert delta is None or delta >= 0, \"turnover limit `delta` should be non-negative\"\n        self.delta = delta\n\n        assert b_dev is None or b_dev >= 0, \"benchmark deviation limit `b_dev` should be non-negative\"\n        self.b_dev = b_dev\n\n        if isinstance(f_dev, float):\n            assert f_dev >= 0, \"factor deviation limit `f_dev` should be non-negative\"\n        elif f_dev is not None:\n            f_dev = np.array(f_dev)\n            assert all(f_dev >= 0), \"factor deviation limit `f_dev` should be non-negative\"\n        self.f_dev = f_dev\n\n        self.scale_return = scale_return\n        self.epsilon = epsilon\n        self.solver_kwargs = {} if solver_kwargs is None else solver_kwargs\n\n    def __call__(\n        self,\n        r: np.ndarray,\n        F: np.ndarray,\n        cov_b: np.ndarray,\n        var_u: np.ndarray,\n        w0: np.ndarray,\n        wb: np.ndarray,\n        mfh: 
Optional[np.ndarray] = None,\n        mfs: Optional[np.ndarray] = None,\n    ) -> np.ndarray:\n        \"\"\"\n        Args:\n            r (np.ndarray): expected returns\n            F (np.ndarray): factor exposure\n            cov_b (np.ndarray): factor covariance\n            var_u (np.ndarray): residual variance\n            w0 (np.ndarray): current holding weights\n            wb (np.ndarray): benchmark weights\n            mfh (np.ndarray): mask of stocks forced to hold (weights fixed at w0)\n            mfs (np.ndarray): mask of stocks forced to sell (weights fixed at 0; overrides mfh)\n\n        Returns:\n            np.ndarray: optimized portfolio allocation\n        \"\"\"\n        # scale return to match volatility\n        if self.scale_return:\n            r = r / r.std()\n            r *= np.sqrt(np.mean(np.diag(F @ cov_b @ F.T) + var_u))\n\n        # target weight\n        w = cp.Variable(len(r), nonneg=True)\n        w.value = wb  # for warm start\n\n        # precompute exposure\n        d = w - wb  # benchmark exposure\n        v = d @ F  # factor exposure\n\n        # objective\n        ret = d @ r  # excess return\n        risk = cp.quad_form(v, cov_b) + var_u @ (d**2)  # tracking error\n        obj = cp.Maximize(ret - self.lamb * risk)\n\n        # weight bounds\n        lb = np.zeros_like(wb)\n        ub = np.ones_like(wb)\n\n        # benchmark deviation bounds\n        if self.b_dev is not None:\n            lb = np.maximum(lb, wb - self.b_dev)\n            ub = np.minimum(ub, wb + self.b_dev)\n\n        # force holding\n        if mfh is not None:\n            lb[mfh] = w0[mfh]\n            ub[mfh] = w0[mfh]\n\n        # force selling\n        # NOTE: this will override mfh\n        if mfs is not None:\n            lb[mfs] = 0\n            ub[mfs] = 0\n\n        # constraints\n        # TODO: currently we assume full investment in the stocks;\n        # in the future we should support holding cash as an asset\n        cons = [cp.sum(w) == 1, w >= lb, w <= ub]\n\n        # factor deviation\n        if self.f_dev is not None:\n          
  cons.extend([v >= -self.f_dev, v <= self.f_dev])  # pylint: disable=E1130\n\n        # total turnover constraint\n        t_cons = []\n        if self.delta is not None:\n            if w0 is not None and w0.sum() > 0:\n                t_cons.extend([cp.norm(w - w0, 1) <= self.delta])\n\n        # optimize\n        # trial 1: use all constraints\n        success = False\n        try:\n            prob = cp.Problem(obj, cons + t_cons)\n            prob.solve(solver=cp.ECOS, warm_start=True, **self.solver_kwargs)\n            assert prob.status == \"optimal\"\n            success = True\n        except Exception as e:\n            logger.warning(f\"trial 1 failed {e} (status: {prob.status})\")\n\n        # trial 2: remove turnover constraint\n        if not success and len(t_cons):\n            logger.info(\"try removing turnover constraint as the last optimization failed\")\n            try:\n                w.value = wb\n                prob = cp.Problem(obj, cons)\n                prob.solve(solver=cp.ECOS, warm_start=True, **self.solver_kwargs)\n                assert prob.status in [\"optimal\", \"optimal_inaccurate\"]\n                success = True\n            except Exception as e:\n                logger.warning(f\"trial 2 failed {e} (status: {prob.status})\")\n\n        # return current weight if not success\n        if not success:\n            logger.warning(\"optimization failed, will return current holding weight\")\n            return w0\n\n        if prob.status == \"optimal_inaccurate\":\n            logger.warning(f\"the optimization is inaccurate\")\n\n        # remove small weight\n        w = np.asarray(w.value)\n        w[w < self.epsilon] = 0\n        w /= w.sum()\n\n        return w\n"
  },
  {
    "path": "qlib/contrib/strategy/optimizer/optimizer.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nimport warnings\nimport numpy as np\nimport pandas as pd\nimport scipy.optimize as so\nfrom typing import Optional, Union, Callable, List\n\nfrom .base import BaseOptimizer\n\n\nclass PortfolioOptimizer(BaseOptimizer):\n    \"\"\"Portfolio Optimizer\n\n    The following optimization algorithms are supported:\n        - `gmv`: Global Minimum Variance Portfolio\n        - `mvo`: Mean Variance Optimized Portfolio\n        - `rp`: Risk Parity\n        - `inv`: Inverse Volatility\n\n    Note:\n        This optimizer always assumes full investment and no-shorting.\n    \"\"\"\n\n    OPT_GMV = \"gmv\"\n    OPT_MVO = \"mvo\"\n    OPT_RP = \"rp\"\n    OPT_INV = \"inv\"\n\n    def __init__(\n        self,\n        method: str = \"inv\",\n        lamb: float = 0,\n        delta: float = 0,\n        alpha: float = 0.0,\n        scale_return: bool = True,\n        tol: float = 1e-8,\n    ):\n        \"\"\"\n        Args:\n            method (str): portfolio optimization method\n            lamb (float): risk aversion parameter (larger `lamb` means more focus on return)\n            delta (float): turnover rate limit\n            alpha (float): l2 norm regularizer\n            scale_return (bool): whether to scale the expected return `r` to match the volatility of the covariance matrix\n            tol (float): tolerance for optimization termination\n        \"\"\"\n        assert method in [self.OPT_GMV, self.OPT_MVO, self.OPT_RP, self.OPT_INV], f\"method `{method}` is not supported\"\n        self.method = method\n\n        assert lamb >= 0, \"risk aversion parameter `lamb` should be non-negative\"\n        self.lamb = lamb\n\n        assert delta >= 0, \"turnover limit `delta` should be non-negative\"\n        self.delta = delta\n\n        assert alpha >= 0, \"l2 norm regularizer `alpha` should be non-negative\"\n        self.alpha = alpha\n\n        self.tol = tol\n        self.scale_return = scale_return\n\n    
def __call__(\n        self,\n        S: Union[np.ndarray, pd.DataFrame],\n        r: Optional[Union[np.ndarray, pd.Series]] = None,\n        w0: Optional[Union[np.ndarray, pd.Series]] = None,\n    ) -> Union[np.ndarray, pd.Series]:\n        \"\"\"\n        Args:\n            S (np.ndarray or pd.DataFrame): covariance matrix\n            r (np.ndarray or pd.Series): expected return\n            w0 (np.ndarray or pd.Series): initial weights (for turnover control)\n\n        Returns:\n            np.ndarray or pd.Series: optimized portfolio allocation\n        \"\"\"\n        # transform dataframe into array\n        index = None\n        if isinstance(S, pd.DataFrame):\n            index = S.index\n            S = S.values\n\n        # transform return\n        if r is not None:\n            assert len(r) == len(S), \"`r` has mismatched shape\"\n            if isinstance(r, pd.Series):\n                assert r.index.equals(index), \"`r` has mismatched index\"\n                r = r.values\n\n        # transform initial weights\n        if w0 is not None:\n            assert len(w0) == len(S), \"`w0` has mismatched shape\"\n            if isinstance(w0, pd.Series):\n                assert w0.index.equals(index), \"`w0` has mismatched index\"\n                w0 = w0.values\n\n        # scale return to match volatility\n        if r is not None and self.scale_return:\n            r = r / r.std()\n            r *= np.sqrt(np.mean(np.diag(S)))\n\n        # optimize\n        w = self._optimize(S, r, w0)\n\n        # restore index if needed\n        if index is not None:\n            w = pd.Series(w, index=index)\n\n        return w\n\n    def _optimize(self, S: np.ndarray, r: Optional[np.ndarray] = None, w0: Optional[np.ndarray] = None) -> np.ndarray:\n        # inverse volatility\n        if self.method == self.OPT_INV:\n            if r is not None:\n                warnings.warn(\"`r` is set but will not be used for `inv` portfolio\")\n            if w0 is not 
None:\n                warnings.warn(\"`w0` is set but will not be used for `inv` portfolio\")\n            return self._optimize_inv(S)\n\n        # global minimum variance\n        if self.method == self.OPT_GMV:\n            if r is not None:\n                warnings.warn(\"`r` is set but will not be used for `gmv` portfolio\")\n            return self._optimize_gmv(S, w0)\n\n        # mean-variance\n        if self.method == self.OPT_MVO:\n            return self._optimize_mvo(S, r, w0)\n\n        # risk parity\n        if self.method == self.OPT_RP:\n            if r is not None:\n                warnings.warn(\"`r` is set but will not be used for `rp` portfolio\")\n            return self._optimize_rp(S, w0)\n\n    def _optimize_inv(self, S: np.ndarray) -> np.ndarray:\n        \"\"\"Inverse volatility\"\"\"\n        vola = np.diag(S) ** 0.5\n        w = 1 / vola\n        w /= w.sum()\n        return w\n\n    def _optimize_gmv(self, S: np.ndarray, w0: Optional[np.ndarray] = None) -> np.ndarray:\n        \"\"\"optimize global minimum variance portfolio\n\n        This method solves the following optimization problem\n            min_w w' S w\n            s.t. w >= 0, sum(w) == 1\n        where `S` is the covariance matrix.\n        \"\"\"\n        return self._solve(len(S), self._get_objective_gmv(S), *self._get_constrains(w0))\n\n    def _optimize_mvo(\n        self, S: np.ndarray, r: Optional[np.ndarray] = None, w0: Optional[np.ndarray] = None\n    ) -> np.ndarray:\n        \"\"\"optimize mean-variance portfolio\n\n        This method solves the following optimization problem\n            min_w   - w' r + lamb * w' S w\n            s.t.   
w >= 0, sum(w) == 1\n        where `S` is the covariance matrix, `r` is the expected return vector,\n        and `lamb` is the risk aversion parameter.\n        \"\"\"\n        return self._solve(len(S), self._get_objective_mvo(S, r), *self._get_constrains(w0))\n\n    def _optimize_rp(self, S: np.ndarray, w0: Optional[np.ndarray] = None) -> np.ndarray:\n        \"\"\"optimize risk parity portfolio\n\n        This method solves the following optimization problem\n            min_w sum_i [w_i - (w' S w) / ((S w)_i * N)]**2\n            s.t. w >= 0, sum(w) == 1\n        where `S` is the covariance matrix and `N` is the number of stocks.\n        \"\"\"\n        return self._solve(len(S), self._get_objective_rp(S), *self._get_constrains(w0))\n\n    def _get_objective_gmv(self, S: np.ndarray) -> Callable:\n        \"\"\"global minimum variance optimization objective\n\n        Optimization objective\n            min_w w' S w\n        \"\"\"\n\n        def func(x):\n            return x @ S @ x\n\n        return func\n\n    def _get_objective_mvo(self, S: np.ndarray, r: Optional[np.ndarray] = None) -> Callable:\n        \"\"\"mean-variance optimization objective\n\n        Optimization objective\n            min_w - w' r + lamb * w' S w\n        \"\"\"\n\n        def func(x):\n            risk = x @ S @ x\n            ret = x @ r\n            return -ret + self.lamb * risk\n\n        return func\n\n    def _get_objective_rp(self, S: np.ndarray) -> Callable:\n        \"\"\"risk-parity optimization objective\n\n        Optimization objective\n            min_w sum_i [w_i - (w' S w) / ((S w)_i * N)]**2\n        \"\"\"\n\n        def func(x):\n            N = len(x)\n            Sx = S @ x\n            xSx = x @ Sx\n            return np.sum((x - xSx / Sx / N) ** 2)\n\n        return func\n\n    def _get_constrains(self, w0: Optional[np.ndarray] = None):\n        \"\"\"optimization constraints\n\n        Defines the following constraints:\n            - no shorting and no leverage: 0 <= w <= 
1\n            - full investment: sum(w) == 1\n            - turnover constraint: |w - w0| <= delta\n        \"\"\"\n\n        # no shorting and no leverage\n        bounds = so.Bounds(0.0, 1.0)\n\n        # full investment constraint\n        cons = [{\"type\": \"eq\", \"fun\": lambda x: np.sum(x) - 1}]  # == 0\n\n        # turnover constraint\n        if w0 is not None:\n            cons.append({\"type\": \"ineq\", \"fun\": lambda x: self.delta - np.sum(np.abs(x - w0))})  # >= 0\n\n        return bounds, cons\n\n    def _solve(self, n: int, obj: Callable, bounds: so.Bounds, cons: List) -> np.ndarray:\n        \"\"\"solve optimization\n\n        Args:\n            n (int): number of parameters\n            obj (callable): optimization objective\n            bounds (Bounds): bounds of parameters\n            cons (list): optimization constraints\n        \"\"\"\n        # add l2 regularization\n        wrapped_obj = obj\n        if self.alpha > 0:\n\n            def opt_obj(x):\n                return obj(x) + self.alpha * np.sum(np.square(x))\n\n            wrapped_obj = opt_obj\n\n        # solve, starting from the equal-weight portfolio\n        x0 = np.ones(n) / n\n        sol = so.minimize(wrapped_obj, x0, bounds=bounds, constraints=cons, tol=self.tol)\n        if not sol.success:\n            warnings.warn(f\"optimization did not succeed (status {sol.status}: {sol.message})\")\n\n        return sol.x\n"
  },
  {
    "path": "qlib/contrib/strategy/order_generator.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nThis order generator is for strategies based on WeightStrategyBase\n\"\"\"\n\nfrom ...backtest.position import Position\nfrom ...backtest.exchange import Exchange\n\nimport pandas as pd\nimport copy\n\n\nclass OrderGenerator:\n    def generate_order_list_from_target_weight_position(\n        self,\n        current: Position,\n        trade_exchange: Exchange,\n        target_weight_position: dict,\n        risk_degree: float,\n        pred_start_time: pd.Timestamp,\n        pred_end_time: pd.Timestamp,\n        trade_start_time: pd.Timestamp,\n        trade_end_time: pd.Timestamp,\n    ) -> list:\n        \"\"\"generate_order_list_from_target_weight_position\n\n        :param current: The current position\n        :type current: Position\n        :param trade_exchange:\n        :type trade_exchange: Exchange\n        :param target_weight_position: {stock_id : weight}\n        :type target_weight_position: dict\n        :param risk_degree:\n        :type risk_degree: float\n        :param pred_start_time:\n        :type pred_start_time: pd.Timestamp\n        :param pred_end_time:\n        :type pred_end_time: pd.Timestamp\n        :param trade_start_time:\n        :type trade_start_time: pd.Timestamp\n        :param trade_end_time:\n        :type trade_end_time: pd.Timestamp\n\n        :rtype: list\n        \"\"\"\n        raise NotImplementedError()\n\n\nclass OrderGenWInteract(OrderGenerator):\n    \"\"\"Order Generator With Interact\"\"\"\n\n    def generate_order_list_from_target_weight_position(\n        self,\n        current: Position,\n        trade_exchange: Exchange,\n        target_weight_position: dict,\n        risk_degree: float,\n        pred_start_time: pd.Timestamp,\n        pred_end_time: pd.Timestamp,\n        trade_start_time: pd.Timestamp,\n        trade_end_time: pd.Timestamp,\n    ) -> list:\n        
\"\"\"generate_order_list_from_target_weight_position\n\n        No adjustment is made for non-tradable shares.\n        All the tradable value is assigned to the tradable stocks according to their weights.\n        If interact == True, the price at the trade date is used to generate the order list;\n        otherwise, only prices before the trade date are used to generate the order list.\n\n        :param current:\n        :type current: Position\n        :param trade_exchange:\n        :type trade_exchange: Exchange\n        :param target_weight_position:\n        :type target_weight_position: dict\n        :param risk_degree:\n        :type risk_degree: float\n        :param pred_start_time:\n        :type pred_start_time: pd.Timestamp\n        :param pred_end_time:\n        :type pred_end_time: pd.Timestamp\n        :param trade_start_time:\n        :type trade_start_time: pd.Timestamp\n        :param trade_end_time:\n        :type trade_end_time: pd.Timestamp\n\n        :rtype: list\n        \"\"\"\n        if target_weight_position is None:\n            return []\n\n        # calculate current_tradable_value\n        current_amount_dict = current.get_stock_amount_dict()\n\n        current_total_value = trade_exchange.calculate_amount_position_value(\n            amount_dict=current_amount_dict,\n            start_time=trade_start_time,\n            end_time=trade_end_time,\n            only_tradable=False,\n        )\n        current_tradable_value = trade_exchange.calculate_amount_position_value(\n            amount_dict=current_amount_dict,\n            start_time=trade_start_time,\n            end_time=trade_end_time,\n            only_tradable=True,\n        )\n        # add cash\n        current_tradable_value += current.get_cash()\n\n        reserved_cash = (1.0 - risk_degree) * (current_total_value + current.get_cash())\n        current_tradable_value -= reserved_cash\n\n        if current_tradable_value < 0:\n            # if selling all the tradable stocks still cannot meet the 
reserved\n            # value, just sell all the tradable stocks (non-tradable ones are kept)\n            target_amount_dict = copy.deepcopy(current_amount_dict)\n            for stock_id in list(target_amount_dict.keys()):\n                if trade_exchange.is_stock_tradable(stock_id, start_time=trade_start_time, end_time=trade_end_time):\n                    del target_amount_dict[stock_id]\n        else:\n            # reserve room for transaction costs, using the larger of the open and close cost rates\n            current_tradable_value /= 1 + max(trade_exchange.close_cost, trade_exchange.open_cost)\n\n            # strategy 1: generate amount_position from weight_position\n            # using the Exchange API\n            target_amount_dict = trade_exchange.generate_amount_position_from_weight_position(\n                weight_position=target_weight_position,\n                cash=current_tradable_value,\n                start_time=trade_start_time,\n                end_time=trade_end_time,\n            )\n        order_list = trade_exchange.generate_order_for_target_amount_position(\n            target_position=target_amount_dict,\n            current_position=current_amount_dict,\n            start_time=trade_start_time,\n            end_time=trade_end_time,\n        )\n        return order_list\n\n\nclass OrderGenWOInteract(OrderGenerator):\n    \"\"\"Order Generator Without Interact\"\"\"\n\n    def generate_order_list_from_target_weight_position(\n        self,\n        current: Position,\n        trade_exchange: Exchange,\n        target_weight_position: dict,\n        risk_degree: float,\n        pred_start_time: pd.Timestamp,\n        pred_end_time: pd.Timestamp,\n        trade_start_time: pd.Timestamp,\n        trade_end_time: pd.Timestamp,\n    ) -> list:\n        \"\"\"generate_order_list_from_target_weight_position\n\n        Generate the order list directly, without using the information (e.g. 
whether the stock can be traded, or the accurate trade price)\n        at the trade date.\n        Converting a target weight position into an order list requires the price of each target stock at the trade date,\n        but that value is unavailable\n        without interacting with the exchange, so we use the close price at pred_date or the price recorded\n        in the current position.\n\n        :param current:\n        :type current: Position\n        :param trade_exchange:\n        :type trade_exchange: Exchange\n        :param target_weight_position:\n        :type target_weight_position: dict\n        :param risk_degree:\n        :type risk_degree: float\n        :param pred_start_time:\n        :type pred_start_time: pd.Timestamp\n        :param pred_end_time:\n        :type pred_end_time: pd.Timestamp\n        :param trade_start_time:\n        :type trade_start_time: pd.Timestamp\n        :param trade_end_time:\n        :type trade_end_time: pd.Timestamp\n\n        :rtype: list of generated orders\n        \"\"\"\n        if target_weight_position is None:\n            return []\n\n        risk_total_value = risk_degree * current.calculate_value()\n\n        current_stock = current.get_stock_list()\n        amount_dict = {}\n        for stock_id in target_weight_position:\n            # The current rule ignores stocks that are not held unless they are tradable at both the prediction date and the trade date\n            if trade_exchange.is_stock_tradable(\n                stock_id=stock_id, start_time=trade_start_time, end_time=trade_end_time\n            ) and trade_exchange.is_stock_tradable(\n                stock_id=stock_id, start_time=pred_start_time, end_time=pred_end_time\n            ):\n                amount_dict[stock_id] = (\n                    risk_total_value\n                    * target_weight_position[stock_id]\n                    / trade_exchange.get_close(stock_id, start_time=pred_start_time, end_time=pred_end_time)\n                )\n                # TODO: Qlib uses None to represent trading 
suspension.\n                #  So last close price can't be the estimated trading price.\n                # Maybe a close price with forward fill will be a better solution.\n            elif stock_id in current_stock:\n                amount_dict[stock_id] = (\n                    risk_total_value * target_weight_position[stock_id] / current.get_stock_price(stock_id)\n                )\n            else:\n                continue\n        order_list = trade_exchange.generate_order_for_target_amount_position(\n            target_position=amount_dict,\n            current_position=current.get_stock_amount_dict(),\n            start_time=trade_start_time,\n            end_time=trade_end_time,\n        )\n        return order_list\n"
  },
  {
    "path": "qlib/contrib/strategy/rule_strategy.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom pathlib import Path\nimport warnings\nimport numpy as np\nimport pandas as pd\nfrom typing import IO, List, Tuple, Union\nfrom qlib.data.dataset.utils import convert_index_format\n\nfrom qlib.utils import lazy_sort_index\n\nfrom ...utils.resam import resam_ts_data, ts_data_last\nfrom ...data.data import D\nfrom ...strategy.base import BaseStrategy\nfrom ...backtest.decision import BaseTradeDecision, Order, TradeDecisionWO, TradeRange\nfrom ...backtest.exchange import Exchange, OrderHelper\nfrom ...backtest.utils import CommonInfrastructure, LevelInfrastructure\nfrom qlib.utils.file import get_io_object\nfrom qlib.backtest.utils import get_start_end_idx\n\n\nclass TWAPStrategy(BaseStrategy):\n    \"\"\"TWAP Strategy for trading\n\n    NOTE:\n        - This TWAP strategy uses ceiling rounding when trading. This makes the TWAP trading strategy place orders\n          earlier when the total number of trade units is less than the number of trading steps\n    \"\"\"\n\n    def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        outer_trade_decision : BaseTradeDecision, optional\n        \"\"\"\n\n        super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)\n        if outer_trade_decision is not None:\n            self.trade_amount_remain = {}\n            for order in outer_trade_decision.get_decision():\n                self.trade_amount_remain[order.stock_id] = order.amount\n\n    def generate_trade_decision(self, execute_result=None):\n        # NOTE:  corner cases!!!\n        # - If rounding up, please don't sell the amount that belongs to the next step\n        #   - coordinating the amount across steps is hard within the same level. It\n        #     is easier to handle in the upper level\n\n        # strategy is not available. 
Give an empty decision\n        if len(self.outer_trade_decision.get_decision()) == 0:\n            return TradeDecisionWO(order_list=[], strategy=self)\n\n        # get the number of finished trading steps; trade_step can be [0, 1, 2, ..., trade_len - 1]\n        trade_step = self.trade_calendar.get_trade_step()\n        # get the total number of trading steps\n        start_idx, end_idx = get_start_end_idx(self.trade_calendar, self.outer_trade_decision)\n        trade_len = end_idx - start_idx + 1\n\n        if trade_step < start_idx or trade_step > end_idx:\n            # It is not time to start trading or trading has ended.\n            return TradeDecisionWO(order_list=[], strategy=self)\n\n        rel_trade_step = trade_step - start_idx  # trade_step relative to start_idx (the number of steps that have already passed)\n\n        # update the order amount\n        if execute_result is not None:\n            for order, _, _, _ in execute_result:\n                self.trade_amount_remain[order.stock_id] -= order.deal_amount\n\n        trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)\n        order_list = []\n        for order in self.outer_trade_decision.get_decision():\n            # Don't peek at future information, so we use check_stock_suspended instead of is_stock_tradable\n            # why this is necessary:\n            # - if a stock is suspended, its quote values are NaN. 
The following code will raise an error when\n            # encountering a NaN factor\n            if self.trade_exchange.check_stock_suspended(\n                stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time\n            ):\n                continue\n\n            # the expected trade amount after the current step\n            amount_expect = order.amount / trade_len * (rel_trade_step + 1)\n\n            # remaining amount\n            amount_remain = self.trade_amount_remain[order.stock_id]\n\n            # the amount that has already been traded\n            amount_finished = order.amount - amount_remain\n\n            # the expected amount for the current step\n            amount_delta = amount_expect - amount_finished\n\n            _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(\n                stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time\n            )\n\n            # round amount_delta to whole trade units and clip it by the remaining amount\n            # NOTE: this could be more than expected.\n            if _amount_trade_unit is None:\n                # divide the order into equal parts, and trade one part\n                amount_delta_target = amount_delta\n            else:\n                amount_delta_target = min(\n                    np.round(amount_delta / _amount_trade_unit) * _amount_trade_unit, amount_remain\n                )\n\n            # handle the last step to make sure the whole remaining amount is traded\n            # necessity: the remaining amount may not round to a whole unit (e.g. 
remainder < 0.5 unit)\n            if rel_trade_step == trade_len - 1:\n                amount_delta_target = amount_remain\n\n            if amount_delta_target > 1e-5:\n                _order = Order(\n                    stock_id=order.stock_id,\n                    amount=amount_delta_target,\n                    start_time=trade_start_time,\n                    end_time=trade_end_time,\n                    direction=order.direction,  # 1 for buy\n                )\n                order_list.append(_order)\n        return TradeDecisionWO(order_list=order_list, strategy=self)\n\n\nclass SBBStrategyBase(BaseStrategy):\n    \"\"\"\n    (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.\n    \"\"\"\n\n    TREND_MID = 0\n    TREND_SHORT = 1\n    TREND_LONG = 2\n\n    # TODO:\n    # 1. Supporting leverage the get_range_limit result from the decision\n    # 2. Supporting alter_outer_trade_decision\n    # 3. Supporting checking the availability of trade decision\n\n    def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        outer_trade_decision : BaseTradeDecision, optional\n        \"\"\"\n        super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)\n        if outer_trade_decision is not None:\n            self.trade_trend = {}\n            self.trade_amount = {}\n            # init the trade amount of each order and the predicted trade trend\n            for order in outer_trade_decision.get_decision():\n                self.trade_trend[order.stock_id] = self.TREND_MID\n                self.trade_amount[order.stock_id] = order.amount\n\n    def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):\n        raise NotImplementedError(\"pred_price_trend method is not implemented!\")\n\n    def generate_trade_decision(self, execute_result=None):\n        # get the number of finished trading steps; trade_step 
can be [0, 1, 2, ..., trade_len - 1]\n        trade_step = self.trade_calendar.get_trade_step()\n        # get the total number of trading steps\n        trade_len = self.trade_calendar.get_trade_len()\n\n        # update the order amount\n        if execute_result is not None:\n            for order, _, _, _ in execute_result:\n                self.trade_amount[order.stock_id] -= order.deal_amount\n\n        trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)\n        pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)\n        order_list = []\n        # for each order in self.outer_trade_decision\n        for order in self.outer_trade_decision.get_decision():\n            # get the price trend\n            if trade_step % 2 == 0:\n                # in the first of two adjacent bars, predict the price trend\n                _pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)\n            else:\n                # in the second of two adjacent bars, use the trend predicted in the first one\n                _pred_trend = self.trade_trend[order.stock_id]\n            # if not tradable, continue\n            if not self.trade_exchange.is_stock_tradable(\n                stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time\n            ):\n                if trade_step % 2 == 0:\n                    self.trade_trend[order.stock_id] = _pred_trend\n                continue\n            # get amount of one trade unit\n            _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(\n                stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time\n            )\n            if _pred_trend == self.TREND_MID:\n                _order_amount = None\n                # without considering trade unit\n                if _amount_trade_unit is None:\n                    # divide the order into equal parts, and trade one 
part\n                    _order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)\n                # considering trade unit\n                else:\n                    # divide the order into equal parts, and trade one part\n                    # calculate the total count of trade units to trade\n                    trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)\n                    # calculate the amount of one part, ceil the amount\n                    # floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))\n                    _order_amount = (\n                        (trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit\n                    )\n                if order.direction == order.SELL:\n                    # sell all the remaining amount at the last step (or when the rounded amount is zero)\n                    if self.trade_amount[order.stock_id] > 1e-5 and (\n                        _order_amount < 1e-5 or trade_step == trade_len - 1\n                    ):\n                        _order_amount = self.trade_amount[order.stock_id]\n\n                _order_amount = min(_order_amount, self.trade_amount[order.stock_id])\n\n                if _order_amount > 1e-5:\n                    _order = Order(\n                        stock_id=order.stock_id,\n                        amount=_order_amount,\n                        start_time=trade_start_time,\n                        end_time=trade_end_time,\n                        direction=order.direction,\n                    )\n                    order_list.append(_order)\n\n            else:\n                _order_amount = None\n                # without considering trade unit\n                if _amount_trade_unit is None:\n                    # N trading steps left: divide the order into N + 1 parts, and trade 2 parts\n                    _order_amount = 2 * self.trade_amount[order.stock_id] / 
(trade_len - trade_step + 1)\n                # considering trade unit\n                else:\n                    # calculate how many trade units to trade\n                    trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)\n                    # N trading steps left: divide the order into N + 1 parts, and trade 2 parts\n                    _order_amount = (\n                        (trade_unit_cnt + trade_len - trade_step)\n                        // (trade_len - trade_step + 1)\n                        * 2\n                        * _amount_trade_unit\n                    )\n                if order.direction == order.SELL:\n                    # sell all the remaining amount at the last step (or when the rounded amount is zero)\n                    if self.trade_amount[order.stock_id] > 1e-5 and (\n                        _order_amount < 1e-5 or trade_step == trade_len - 1\n                    ):\n                        _order_amount = self.trade_amount[order.stock_id]\n\n                _order_amount = min(_order_amount, self.trade_amount[order.stock_id])\n\n                if _order_amount > 1e-5:\n                    if trade_step % 2 == 0:\n                        # in the first one of two adjacent bars\n                        # if look short on the price, sell the stock more\n                        # if look long on the price, buy the stock more\n                        if (\n                            _pred_trend == self.TREND_SHORT\n                            and order.direction == order.SELL\n                            or _pred_trend == self.TREND_LONG\n                            and order.direction == order.BUY\n                        ):\n                            _order = Order(\n                                stock_id=order.stock_id,\n                                amount=_order_amount,\n                                start_time=trade_start_time,\n                                end_time=trade_end_time,\n                                
direction=order.direction,  # 1 for buy\n  
                          )\n                            order_list.append(_order)\n                    else:\n                        # in the second one of two adjacent bars\n                        # if look short on the price, buy the stock more\n                        # if look long on the price, sell the stock more\n                        if (\n                            _pred_trend == self.TREND_SHORT\n                            and order.direction == order.BUY\n                            or _pred_trend == self.TREND_LONG\n                            and order.direction == order.SELL\n                        ):\n                            _order = Order(\n                                stock_id=order.stock_id,\n                                amount=_order_amount,\n                                start_time=trade_start_time,\n                                end_time=trade_end_time,\n                                direction=order.direction,  # 1 for buy\n                            )\n                            order_list.append(_order)\n\n            if trade_step % 2 == 0:\n                # in the first one of two adjacent bars, store the trend for the second one to use\n                self.trade_trend[order.stock_id] = _pred_trend\n\n        return TradeDecisionWO(order_list, self)\n\n\nclass SBBStrategyEMA(SBBStrategyBase):\n    \"\"\"\n    (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA) signal.\n    \"\"\"\n\n    # TODO:\n    # 1. Supporting leverage the get_range_limit result from the decision\n    # 2. Supporting alter_outer_trade_decision\n    # 3. 
Supporting checking the availability of trade decision\n\n    def __init__(\n        self,\n        outer_trade_decision: BaseTradeDecision = None,\n        instruments: Union[List, str] = \"csi300\",\n        freq: str = \"day\",\n        trade_exchange: Exchange = None,\n        level_infra: LevelInfrastructure = None,\n        common_infra: CommonInfrastructure = None,\n        **kwargs,\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        instruments : Union[List, str], optional\n            instruments of EMA signal, by default \"csi300\"\n        freq : str, optional\n            freq of EMA signal, by default \"day\"\n            Note: `freq` may be different from `time_per_step`\n        \"\"\"\n        if instruments is None:\n            warnings.warn(\"`instruments` is not set, will load all stocks\")\n            self.instruments = \"all\"\n        elif isinstance(instruments, str):\n            self.instruments = D.instruments(instruments)\n        elif isinstance(instruments, List):\n            self.instruments = instruments\n        self.freq = freq\n        super(SBBStrategyEMA, self).__init__(\n            outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs\n        )\n\n    def _reset_signal(self):\n        trade_len = self.trade_calendar.get_trade_len()\n        fields = [\"EMA($close, 10)-EMA($close, 20)\"]\n        signal_start_time, _ = self.trade_calendar.get_step_time(trade_step=0, shift=1)\n        _, signal_end_time = self.trade_calendar.get_step_time(trade_step=trade_len - 1, shift=1)\n        signal_df = D.features(\n            self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq\n        )\n        signal_df.columns = [\"signal\"]\n        self.signal = {}\n\n        if not signal_df.empty:\n            for stock_id, stock_val in signal_df.groupby(level=\"instrument\", group_keys=False):\n                self.signal[stock_id] = 
stock_val[\"signal\"].droplevel(level=\"instrument\")\n\n    def reset_level_infra(self, level_infra):\n        \"\"\"\n        reset level-shared infra\n        - After resetting the trade calendar, the signal will be recomputed\n        \"\"\"\n        super().reset_level_infra(level_infra)\n        self._reset_signal()\n\n    def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):\n        # if no signal, return mid trend\n        if stock_id not in self.signal:\n            return self.TREND_MID\n        else:\n            _sample_signal = resam_ts_data(\n                self.signal[stock_id],\n                pred_start_time,\n                pred_end_time,\n                method=ts_data_last,\n            )\n            # if EMA signal is 0, None or NaN, return mid trend\n            if _sample_signal is None or np.isnan(_sample_signal) or _sample_signal == 0:\n                return self.TREND_MID\n            # if EMA signal > 0, return long trend\n            elif _sample_signal > 0:\n                return self.TREND_LONG\n            # if EMA signal < 0, return short trend\n            else:\n                return self.TREND_SHORT\n\n\nclass ACStrategy(BaseStrategy):\n    # TODO:\n    # 1. Supporting leverage the get_range_limit result from the decision\n    # 2. Supporting alter_outer_trade_decision\n    # 3. 
Supporting checking the availability of trade decision\n    def __init__(\n        self,\n        lamb: float = 1e-6,\n        eta: float = 2.5e-6,\n        window_size: int = 20,\n        outer_trade_decision: BaseTradeDecision = None,\n        instruments: Union[List, str] = \"csi300\",\n        freq: str = \"day\",\n        trade_exchange: Exchange = None,\n        level_infra: LevelInfrastructure = None,\n        common_infra: CommonInfrastructure = None,\n        **kwargs,\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        lamb : float, optional\n            risk-aversion coefficient, by default 1e-6\n        eta : float, optional\n            temporary market impact coefficient, by default 2.5e-6\n        window_size : int, optional\n            number of bars (at `freq`) used to estimate the volatility, by default 20\n        instruments : Union[List, str], optional\n            instruments used to compute the volatility signal, by default \"csi300\"\n        freq : str, optional\n            frequency of the volatility signal, by default \"day\"\n            Note: `freq` may be different from `time_per_step`\n        \"\"\"\n        self.lamb = lamb\n        self.eta = eta\n        self.window_size = window_size\n        if instruments is None:\n            warnings.warn(\"`instruments` is not set, will load all stocks\")\n            self.instruments = \"all\"\n        elif isinstance(instruments, str):\n            self.instruments = D.instruments(instruments)\n        elif isinstance(instruments, List):\n            self.instruments = instruments\n        self.freq = freq\n        super(ACStrategy, self).__init__(\n            outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs\n        )\n\n    def _reset_signal(self):\n        trade_len = self.trade_calendar.get_trade_len()\n        fields = [\n            f\"Power(Sum(Power(Log($close/Ref($close, 1)), 2), {self.window_size})/{self.window_size - 1}-Power(Sum(Log($close/Ref($close, 1)), {self.window_size}), 2)/({self.window_size}*{self.window_size - 1}), 0.5)\"\n        ]\n        signal_start_time, _ = self.trade_calendar.get_step_time(trade_step=0, shift=1)\n        _, signal_end_time = self.trade_calendar.get_step_time(trade_step=trade_len - 1, shift=1)\n        signal_df = D.features(\n            self.instruments, fields, start_time=signal_start_time, 
end_time=signal_end_time, freq=self.freq\n        )\n        signal_df.columns = [\"volatility\"]\n        self.signal = {}\n\n        if not signal_df.empty:\n            for stock_id, stock_val in signal_df.groupby(level=\"instrument\", group_keys=False):\n                self.signal[stock_id] = stock_val[\"volatility\"].droplevel(level=\"instrument\")\n\n    def reset_level_infra(self, level_infra):\n        \"\"\"\n        reset level-shared infra\n        - After resetting the trade calendar, the signal will be recomputed\n        \"\"\"\n        super().reset_level_infra(level_infra)\n        self._reset_signal()\n\n    def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        outer_trade_decision : BaseTradeDecision, optional\n        \"\"\"\n        super(ACStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)\n        if outer_trade_decision is not None:\n            self.trade_amount = {}\n            # init the remaining trade amount of each order\n            for order in outer_trade_decision.get_decision():\n                self.trade_amount[order.stock_id] = order.amount\n\n    def generate_trade_decision(self, execute_result=None):\n        # get the number of trading steps finished, trade_step can be [0, 1, 2, ..., trade_len - 1]\n        trade_step = self.trade_calendar.get_trade_step()\n        # get the total count of trading steps\n        trade_len = self.trade_calendar.get_trade_len()\n\n        # update the remaining amount of each order\n        if execute_result is not None:\n            for order, _, _, _ in execute_result:\n                self.trade_amount[order.stock_id] -= order.deal_amount\n\n        trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)\n        pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)\n        order_list = []\n        for order in 
self.outer_trade_decision.get_decision():\n            # if not tradable, continue\n            if not self.trade_exchange.is_stock_tradable(\n                stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time\n            ):\n                continue\n            _order_amount = None\n            # considering trade unit\n\n            sig_sam = (\n                resam_ts_data(self.signal[order.stock_id], pred_start_time, pred_end_time, method=ts_data_last)\n                if order.stock_id in self.signal\n                else None\n            )\n\n            if sig_sam is None or np.isnan(sig_sam):\n                # no signal, TWAP\n                _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(\n                    stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time\n                )\n                if _amount_trade_unit is None:\n                    # divide the order into equal parts, and trade one part\n                    _order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)\n                else:\n                    # divide the order into equal parts, and trade one part\n                    # calculate the total count of trade units to trade\n                    trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)\n                    # calculate the amount of one part, ceil the amount\n                    # floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))\n                    _order_amount = (\n                        (trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit\n                    )\n            else:\n                # VA strategy\n                kappa_tild = self.lamb / self.eta * sig_sam * sig_sam\n                kappa = np.arccosh(kappa_tild / 2 + 1)\n                amount_ratio = (\n        
            np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1))\n                ) / np.sinh(kappa * trade_len)\n                _order_amount = order.amount * amount_ratio\n                _order_amount = self.trade_exchange.round_amount_by_trade_unit(\n                    _order_amount, stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time\n                )\n\n            if order.direction == order.SELL:\n                # sell the whole remaining amount if the planned amount is ~0 or at the last step\n                if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or trade_step == trade_len - 1):\n                    _order_amount = self.trade_amount[order.stock_id]\n\n            _order_amount = min(_order_amount, self.trade_amount[order.stock_id])\n\n            if _order_amount > 1e-5:\n                _order = Order(\n                    stock_id=order.stock_id,\n                    amount=_order_amount,\n                    start_time=trade_start_time,\n                    end_time=trade_end_time,\n                    direction=order.direction,  # 1 for buy\n                    factor=order.factor,\n                )\n                order_list.append(_order)\n        return TradeDecisionWO(order_list, self)\n\n\nclass RandomOrderStrategy(BaseStrategy):\n    def __init__(\n        self,\n        trade_range: Union[Tuple[int, int], TradeRange],  # The range is closed on both left and right.\n        sample_ratio: float = 1.0,\n        volume_ratio: float = 0.01,\n        market: str = \"all\",\n        direction: int = Order.BUY,\n        *args,\n        **kwargs,\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        trade_range : Tuple\n            please refer to the `trade_range` parameter of BaseStrategy\n        sample_ratio : float\n            the fraction of stocks in the pool sampled to generate orders\n        volume_ratio : float\n            the ratio of the total volume of a specific 
day\n        market : str\n            stock pool for sampling\n        \"\"\"\n\n        super().__init__(*args, **kwargs)\n        self.sample_ratio = sample_ratio\n        self.volume_ratio = volume_ratio\n        self.market = market\n        self.direction = direction\n        exch: Exchange = self.common_infra.get(\"trade_exchange\")\n        # TODO: this can't be used online\n        self.volume = D.features(\n            D.instruments(market), [\"Mean(Ref($volume, 1), 10)\"], start_time=exch.start_time, end_time=exch.end_time\n        )\n        self.volume_df = self.volume.iloc[:, 0].unstack()\n        self.trade_range = trade_range\n\n    def generate_trade_decision(self, execute_result=None):\n        trade_step = self.trade_calendar.get_trade_step()\n        step_time_start, step_time_end = self.trade_calendar.get_step_time(trade_step)\n\n        order_list = []\n        if step_time_start in self.volume_df:\n            for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items():\n                order_list.append(\n                    self.common_infra.get(\"trade_exchange\")\n                    .get_order_helper()\n                    .create(\n                        code=stock_id,\n                        amount=volume * self.volume_ratio,\n                        direction=self.direction,\n                    )\n                )\n        return TradeDecisionWO(order_list, self, self.trade_range)\n\n\nclass FileOrderStrategy(BaseStrategy):\n    \"\"\"\n    Motivation:\n    - This class provides an interface for users to read orders from csv files.\n    \"\"\"\n\n    def __init__(\n        self,\n        file: Union[IO, str, Path, pd.DataFrame],\n        trade_range: Union[Tuple[int, int], TradeRange] = None,\n        *args,\n        **kwargs,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        file : Union[IO, str, Path, pd.DataFrame]\n            this parameter specifies the info 
of expected orders\n\n            Here is an example of the content:\n\n            1) Amount (**adjusted**) based strategy\n\n                datetime,instrument,amount,direction\n                20200102,  SH600519,  1000,     sell\n                20200103,  SH600519,  1000,      buy\n                20200106,  SH600519,  1000,     sell\n\n        trade_range : Tuple[int, int]\n            the intra-day time index range of the orders\n            both the left and right bounds are closed.\n\n            If you want to get the trade_range within a day\n            - `qlib/utils/time.py:def get_day_min_idx_range` can help you create the index range more easily\n            # TODO: this is a trade_range level limitation. We'll implement a more detailed limitation later.\n\n        \"\"\"\n        super().__init__(*args, **kwargs)\n        if isinstance(file, pd.DataFrame):\n            self.order_df = file\n        else:\n            with get_io_object(file) as f:\n                self.order_df = pd.read_csv(f, dtype={\"datetime\": str})\n\n        self.order_df[\"datetime\"] = self.order_df[\"datetime\"].apply(pd.Timestamp)\n        self.order_df = self.order_df.set_index([\"datetime\", \"instrument\"])\n\n        # make sure the datetime is the first level for fast indexing\n        self.order_df = lazy_sort_index(convert_index_format(self.order_df, level=\"datetime\"))\n        self.trade_range = trade_range\n\n    def generate_trade_decision(self, execute_result=None) -> TradeDecisionWO:\n        \"\"\"\n        Parameters\n        ----------\n        execute_result :\n            execute_result will be ignored in FileOrderStrategy\n        \"\"\"\n        oh: OrderHelper = self.common_infra.get(\"trade_exchange\").get_order_helper()\n        start, _ = self.trade_calendar.get_step_time()\n        # CONVERSION: the bar is indexed by the time\n        try:\n            df = self.order_df.loc(axis=0)[start]\n        except KeyError:\n            return TradeDecisionWO([], self)\n   
     else:\n            order_list = []\n            for idx, row in df.iterrows():\n                order_list.append(\n                    oh.create(\n                        code=idx,\n                        amount=row[\"amount\"],\n                        direction=Order.parse_dir(row[\"direction\"]),\n                    )\n                )\n            return TradeDecisionWO(order_list, self, self.trade_range)\n"
  },
  {
    "path": "qlib/contrib/strategy/signal_strategy.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport os\nimport copy\nimport warnings\nimport numpy as np\nimport pandas as pd\n\nfrom typing import Dict, List, Text, Tuple, Union\nfrom abc import ABC\n\nfrom qlib.data import D\nfrom qlib.data.dataset import Dataset\nfrom qlib.model.base import BaseModel\nfrom qlib.strategy.base import BaseStrategy\nfrom qlib.backtest.position import Position\nfrom qlib.backtest.signal import Signal, create_signal_from\nfrom qlib.backtest.decision import Order, OrderDir, TradeDecisionWO\nfrom qlib.log import get_module_logger\nfrom qlib.utils import get_pre_trading_date, load_dataset\nfrom qlib.contrib.strategy.order_generator import OrderGenerator, OrderGenWOInteract\nfrom qlib.contrib.strategy.optimizer import EnhancedIndexingOptimizer\n\n\nclass BaseSignalStrategy(BaseStrategy, ABC):\n    def __init__(\n        self,\n        *,\n        signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] = None,\n        model=None,\n        dataset=None,\n        risk_degree: float = 0.95,\n        trade_exchange=None,\n        level_infra=None,\n        common_infra=None,\n        **kwargs,\n    ):\n        \"\"\"\n        Parameters\n        -----------\n        signal :\n            the information to describe a signal. 
Please refer to the docs of `qlib.backtest.signal.create_signal_from`\n            the decision of the strategy will be based on the given signal\n        risk_degree : float\n            position percentage of total value.\n        trade_exchange : Exchange\n            exchange that provides market info, used to deal orders and generate reports\n            - If `trade_exchange` is None, self.trade_exchange will be set with common_infra\n            - It allows different trade_exchanges to be used in different executions.\n            - For example:\n                - In daily execution, both the daily and minutely exchanges are usable, but the daily exchange is recommended because it runs faster.\n                - In minutely execution, the daily exchange is not usable, so only the minutely exchange can be used.\n\n        \"\"\"\n        super().__init__(level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs)\n\n        self.risk_degree = risk_degree\n\n        # This is trying to be compatible with previous versions of qlib task config\n        if model is not None and dataset is not None:\n            warnings.warn(\"`model` `dataset` is deprecated; use `signal`.\", DeprecationWarning)\n            signal = model, dataset\n\n        self.signal: Signal = create_signal_from(signal)\n\n    def get_risk_degree(self, trade_step=None):\n        \"\"\"get_risk_degree\n        Return the proportion of your total value you will use in investment.\n        A dynamic risk_degree will result in market timing.\n        \"\"\"\n        # By default, it uses 95% of your total value (risk_degree=0.95)\n        return self.risk_degree\n\n\nclass TopkDropoutStrategy(BaseSignalStrategy):\n    # TODO:\n    # 1. Supporting leverage the get_range_limit result from the decision\n    # 2. Supporting alter_outer_trade_decision\n    # 3. Supporting checking the availability of trade decision\n    # 4. 
Regenerate results with forbid_all_trade_at_limit set to false and flip the default to false, as it is consistent with reality.\n    def __init__(\n        self,\n        *,\n        topk,\n        n_drop,\n        method_sell=\"bottom\",\n        method_buy=\"top\",\n        hold_thresh=1,\n        only_tradable=False,\n        forbid_all_trade_at_limit=True,\n        **kwargs,\n    ):\n        \"\"\"\n        Parameters\n        -----------\n        topk : int\n            the number of stocks in the portfolio.\n        n_drop : int\n            number of stocks to be replaced on each trading date.\n        method_sell : str\n            dropout method_sell, random/bottom.\n        method_buy : str\n            dropout method_buy, random/top.\n        hold_thresh : int\n            minimum holding days\n            before selling a stock, the strategy checks current.get_stock_count(order.stock_id) >= self.hold_thresh.\n        only_tradable : bool\n            whether the strategy only considers tradable stocks when buying and selling.\n\n            if only_tradable:\n\n                the strategy makes decisions using the tradable state of each stock and avoids buying or selling untradable ones.\n\n            else:\n\n                the strategy makes buy/sell decisions without checking the tradable state of the stock.\n        forbid_all_trade_at_limit : bool\n            whether to forbid all trades when limit_up or limit_down is reached.\n\n            if forbid_all_trade_at_limit:\n\n                the strategy will not trade at all when the price reaches limit up/down, not even selling at limit up or\n                buying at limit down, though both are allowed in reality.\n\n            else:\n\n                the strategy will sell at limit up and buy at limit down.\n        \"\"\"\n        super().__init__(**kwargs)\n        self.topk = topk\n        self.n_drop = n_drop\n        self.method_sell = method_sell\n        self.method_buy = method_buy\n        self.hold_thresh = hold_thresh\n        
self.only_tradable = only_tradable\n        self.forbid_all_trade_at_limit = forbid_all_trade_at_limit\n\n    def generate_trade_decision(self, execute_result=None):\n        # get the number of trading steps finished, trade_step can be [0, 1, 2, ..., trade_len - 1]\n        trade_step = self.trade_calendar.get_trade_step()\n        trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)\n        pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)\n        pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)\n        # NOTE: the current version of topk dropout strategy can't handle pd.DataFrame (multiple signals)\n        # So it only uses the first column of the signal\n        if isinstance(pred_score, pd.DataFrame):\n            pred_score = pred_score.iloc[:, 0]\n        if pred_score is None:\n            return TradeDecisionWO([], self)\n        if self.only_tradable:\n            # If the strategy only considers tradable stocks when making decisions,\n            # it needs the following helpers to filter stocks\n            def get_first_n(li, n, reverse=False):\n                cur_n = 0\n                res = []\n                for si in reversed(li) if reverse else li:\n                    if self.trade_exchange.is_stock_tradable(\n                        stock_id=si, start_time=trade_start_time, end_time=trade_end_time\n                    ):\n                        res.append(si)\n                        cur_n += 1\n                        if cur_n >= n:\n                            break\n                return res[::-1] if reverse else res\n\n            def get_last_n(li, n):\n                return get_first_n(li, n, reverse=True)\n\n            def filter_stock(li):\n                return [\n                    si\n                    for si in li\n                    if self.trade_exchange.is_stock_tradable(\n                        stock_id=si, 
start_time=trade_start_time, end_time=trade_end_time\n                    )\n                ]\n\n        else:\n            # Otherwise, the strategy makes decisions without the stock tradable info\n            def get_first_n(li, n):\n                return list(li)[:n]\n\n            def get_last_n(li, n):\n                return list(li)[-n:]\n\n            def filter_stock(li):\n                return li\n\n        current_temp: Position = copy.deepcopy(self.trade_position)\n        # generate order lists for this rebalance date\n        sell_order_list = []\n        buy_order_list = []\n        # load cash and current holdings\n        cash = current_temp.get_cash()\n        current_stock_list = current_temp.get_stock_list()\n        # last position (sorted by score)\n        last = pred_score.reindex(current_stock_list).sort_values(ascending=False).index\n        # Candidate new stocks to buy today (**at most** this many will be bought)\n        if self.method_buy == \"top\":\n            today = get_first_n(\n                pred_score[~pred_score.index.isin(last)].sort_values(ascending=False).index,\n                self.n_drop + self.topk - len(last),\n            )\n        elif self.method_buy == \"random\":\n            topk_candi = get_first_n(pred_score.sort_values(ascending=False).index, self.topk)\n            candi = list(filter(lambda x: x not in last, topk_candi))\n            n = self.n_drop + self.topk - len(last)\n            try:\n                today = np.random.choice(candi, n, replace=False)\n            except ValueError:\n                today = candi\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n        # combine (new stocks + last stocks); we will drop stocks from this list\n        # to avoid dropping a higher-score stock while buying a lower-score one.\n        comb = pred_score.reindex(last.union(pd.Index(today))).sort_values(ascending=False).index\n\n        # Get the stock list we really want to sell (After filtering the case that 
we sell high and buy low)\n        if self.method_sell == \"bottom\":\n            sell = last[last.isin(get_last_n(comb, self.n_drop))]\n        elif self.method_sell == \"random\":\n            candi = filter_stock(last)\n            try:\n                sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else [])\n            except ValueError:  # Not enough candidates\n                sell = candi\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n\n        # Get the stock list we really want to buy\n        buy = today[: len(sell) + self.topk - len(last)]\n        for code in current_stock_list:\n            if not self.trade_exchange.is_stock_tradable(\n                stock_id=code,\n                start_time=trade_start_time,\n                end_time=trade_end_time,\n                direction=None if self.forbid_all_trade_at_limit else OrderDir.SELL,\n            ):\n                continue\n            if code in sell:\n                # check hold limit\n                time_per_step = self.trade_calendar.get_freq()\n                if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh:\n                    continue\n                # sell order\n                sell_amount = current_temp.get_stock_amount(code=code)\n                # sell_amount = self.trade_exchange.round_amount_by_trade_unit(sell_amount, factor)\n                sell_order = Order(\n                    stock_id=code,\n                    amount=sell_amount,\n                    start_time=trade_start_time,\n                    end_time=trade_end_time,\n                    direction=Order.SELL,  # 0 for sell, 1 for buy\n                )\n                # check whether the order is executable\n                if self.trade_exchange.check_order(sell_order):\n                    sell_order_list.append(sell_order)\n                    trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(\n     
                   sell_order, position=current_temp\n                    )\n                    # update cash\n                    cash += trade_val - trade_cost\n        # buy new stock\n        # note: current_temp has been changed by the sell orders above\n        # current_stock_list = current_temp.get_stock_list()\n        value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0\n\n        # open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not\n        # consider it; as the aim of the demo is to reproduce the same strategy as evaluate.py, this line is commented out\n        # value = value / (1+self.trade_exchange.open_cost) # set open_cost limit\n        for code in buy:\n            # check whether the stock is tradable (e.g. not suspended)\n            if not self.trade_exchange.is_stock_tradable(\n                stock_id=code,\n                start_time=trade_start_time,\n                end_time=trade_end_time,\n                direction=None if self.forbid_all_trade_at_limit else OrderDir.BUY,\n            ):\n                continue\n            # buy order\n            buy_price = self.trade_exchange.get_deal_price(\n                stock_id=code, start_time=trade_start_time, end_time=trade_end_time, direction=OrderDir.BUY\n            )\n            buy_amount = value / buy_price\n            factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)\n            buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor)\n            buy_order = Order(\n                stock_id=code,\n                amount=buy_amount,\n                start_time=trade_start_time,\n                end_time=trade_end_time,\n                direction=Order.BUY,  # 1 for buy\n            )\n            buy_order_list.append(buy_order)\n        return TradeDecisionWO(sell_order_list + buy_order_list, self)\n\n\nclass WeightStrategyBase(BaseSignalStrategy):\n    # TODO:\n    # 1. 
Supporting leverage the get_range_limit result from the decision\n    # 2. Supporting alter_outer_trade_decision\n    # 3. Supporting checking the availability of trade decision\n    def __init__(\n        self,\n        *,\n        order_generator_cls_or_obj=OrderGenWOInteract,\n        **kwargs,\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        order_generator_cls_or_obj : Union[type, OrderGenerator]\n            the order generator; if a class is given, it is instantiated without arguments\n        signal :\n            the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from`\n            the decision of the strategy will be based on the given signal\n        trade_exchange : Exchange\n            exchange that provides market info, used to deal orders and generate reports\n\n            - If `trade_exchange` is None, self.trade_exchange will be set with common_infra\n            - It allows different trade_exchanges to be used in different executions.\n            - For example:\n\n                - In daily execution, both the daily and minutely exchanges are usable, but the daily exchange is recommended because it runs faster.\n                - In minutely execution, the daily exchange is not usable, so only the minutely exchange can be used.\n        \"\"\"\n        super().__init__(**kwargs)\n\n        if isinstance(order_generator_cls_or_obj, type):\n            self.order_generator: OrderGenerator = order_generator_cls_or_obj()\n        else:\n            self.order_generator: OrderGenerator = order_generator_cls_or_obj\n\n    def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):\n        \"\"\"\n        Generate the target position from the score for this date and the current position. Cash is not considered in the position.\n\n        Parameters\n        -----------\n        score : pd.Series\n            pred score for this trade date, index is stock_id, contains 'score' column.\n        current : Position()\n            current position.\n        trade_start_time: pd.Timestamp\n        trade_end_time: pd.Timestamp\n        \"\"\"\n        raise 
NotImplementedError()\n\n    def generate_trade_decision(self, execute_result=None):\n        # generate_trade_decision\n        # generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list\n\n        # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]\n        trade_step = self.trade_calendar.get_trade_step()\n        trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)\n        pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)\n        pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)\n        if pred_score is None:\n            return TradeDecisionWO([], self)\n        current_temp = copy.deepcopy(self.trade_position)\n        assert isinstance(current_temp, Position)  # Avoid InfPosition\n\n        target_weight_position = self.generate_target_weight_position(\n            score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time\n        )\n        order_list = self.order_generator.generate_order_list_from_target_weight_position(\n            current=current_temp,\n            trade_exchange=self.trade_exchange,\n            risk_degree=self.get_risk_degree(trade_step),\n            target_weight_position=target_weight_position,\n            pred_start_time=pred_start_time,\n            pred_end_time=pred_end_time,\n            trade_start_time=trade_start_time,\n            trade_end_time=trade_end_time,\n        )\n        return TradeDecisionWO(order_list, self)\n\n\nclass EnhancedIndexingStrategy(WeightStrategyBase):\n    \"\"\"Enhanced Indexing Strategy\n\n    Enhanced indexing combines the arts of active management and passive management,\n    with the aim of outperforming a benchmark index (e.g., S&P 500) in terms of\n    portfolio return while controlling the risk exposure (a.k.a. 
tracking error).\n\n    Users need to prepare their risk model data as follows:\n\n    .. code-block:: text\n\n        ├── /path/to/riskmodel\n        ├──── 20210101\n        ├────── factor_exp.{csv|pkl|h5}\n        ├────── factor_cov.{csv|pkl|h5}\n        ├────── specific_risk.{csv|pkl|h5}\n        ├────── blacklist.{csv|pkl|h5}  # optional\n\n    The risk model data can be obtained from a risk data provider. You can also use\n    `qlib.model.riskmodel.structured.StructuredCovEstimator` to prepare this data.\n\n    Args:\n        riskmodel_root (str): root directory of the risk model data\n        market (str): benchmark name; its constituent weights are read from the ``$<market>_weight`` feature\n        name_mapping (dict): alternative file names\n        optimizer_kwargs (dict): keyword arguments passed to ``EnhancedIndexingOptimizer``\n        verbose (bool): whether to log details of each optimization\n    \"\"\"\n\n    FACTOR_EXP_NAME = \"factor_exp.pkl\"\n    FACTOR_COV_NAME = \"factor_cov.pkl\"\n    SPECIFIC_RISK_NAME = \"specific_risk.pkl\"\n    BLACKLIST_NAME = \"blacklist.pkl\"\n\n    def __init__(\n        self,\n        *,\n        riskmodel_root,\n        market=\"csi500\",\n        turn_limit=None,\n        name_mapping={},\n        optimizer_kwargs={},\n        verbose=False,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.logger = get_module_logger(\"EnhancedIndexingStrategy\")\n\n        self.riskmodel_root = riskmodel_root\n        self.market = market\n        self.turn_limit = turn_limit\n\n        self.factor_exp_path = name_mapping.get(\"factor_exp\", self.FACTOR_EXP_NAME)\n        self.factor_cov_path = name_mapping.get(\"factor_cov\", self.FACTOR_COV_NAME)\n        self.specific_risk_path = name_mapping.get(\"specific_risk\", self.SPECIFIC_RISK_NAME)\n        self.blacklist_path = name_mapping.get(\"blacklist\", self.BLACKLIST_NAME)\n\n        self.optimizer = EnhancedIndexingOptimizer(**optimizer_kwargs)\n\n        self.verbose = verbose\n\n        self._riskdata_cache = {}\n\n    def get_risk_data(self, date):\n        if date in self._riskdata_cache:\n            return self._riskdata_cache[date]\n\n        root = self.riskmodel_root + \"/\" + date.strftime(\"%Y%m%d\")\n        if not 
os.path.exists(root):\n            return None\n\n        factor_exp = load_dataset(root + \"/\" + self.factor_exp_path, index_col=[0])\n        factor_cov = load_dataset(root + \"/\" + self.factor_cov_path, index_col=[0])\n        specific_risk = load_dataset(root + \"/\" + self.specific_risk_path, index_col=[0])\n\n        if not factor_exp.index.equals(specific_risk.index):\n            # NOTE: for stocks missing specific_risk, we always assume they have the highest volatility\n            specific_risk = specific_risk.reindex(factor_exp.index, fill_value=specific_risk.max())\n\n        universe = factor_exp.index.tolist()\n\n        blacklist = []\n        if os.path.exists(root + \"/\" + self.blacklist_path):\n            blacklist = load_dataset(root + \"/\" + self.blacklist_path).index.tolist()\n\n        self._riskdata_cache[date] = factor_exp.values, factor_cov.values, specific_risk.values, universe, blacklist\n\n        return self._riskdata_cache[date]\n\n    def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):\n        trade_date = trade_start_time\n        pre_date = get_pre_trading_date(trade_date, future=True)  # previous trade date\n\n        # load risk data\n        outs = self.get_risk_data(pre_date)\n        if outs is None:\n            self.logger.warning(f\"no risk data for {pre_date:%Y-%m-%d}, skip optimization\")\n            return None\n        factor_exp, factor_cov, specific_risk, universe, blacklist = outs\n\n        # transform score\n        # NOTE: for stocks missing score, we always assume they have the lowest score\n        score = score.reindex(universe).fillna(score.min()).values\n\n        # get current weight\n        # NOTE: if a stock is not in universe, its current weight will be zero\n        cur_weight = current.get_stock_weight_dict(only_stock=False)\n        cur_weight = np.array([cur_weight.get(stock, 0) for stock in universe])\n        assert all(cur_weight >= 0), \"current weight 
has negative values\"\n        cur_weight = cur_weight / self.get_risk_degree(trade_date)  # holdings sum to at most risk_degree; rescale so they sum to at most 1\n        if cur_weight.sum() > 1 and self.verbose:\n            self.logger.warning(f\"previous total holdings exceed risk degree (current: {cur_weight.sum()})\")\n\n        # load benchmark weight\n        bench_weight = D.features(\n            D.instruments(\"all\"), [f\"${self.market}_weight\"], start_time=pre_date, end_time=pre_date\n        ).squeeze()\n        bench_weight.index = bench_weight.index.droplevel(level=\"datetime\")\n        bench_weight = bench_weight.reindex(universe).fillna(0).values\n\n        # whether each stock is tradable\n        # NOTE: currently we use the previous trading day's volume to check tradability\n        tradable = D.features(D.instruments(\"all\"), [\"$volume\"], start_time=pre_date, end_time=pre_date).squeeze()\n        tradable.index = tradable.index.droplevel(level=\"datetime\")\n        tradable = tradable.reindex(universe).gt(0).values\n        mask_force_hold = ~tradable\n\n        # blacklisted stocks must be sold\n        mask_force_sell = np.array([stock in blacklist for stock in universe], dtype=bool)\n\n        # optimize\n        weight = self.optimizer(\n            r=score,\n            F=factor_exp,\n            cov_b=factor_cov,\n            var_u=specific_risk**2,\n            w0=cur_weight,\n            wb=bench_weight,\n            mfh=mask_force_hold,\n            mfs=mask_force_sell,\n        )\n\n        target_weight_position = {stock: w for stock, w in zip(universe, weight) if w > 0}\n\n        if self.verbose:\n            self.logger.info(\"trade date: {:%Y-%m-%d}\".format(trade_date))\n            self.logger.info(\"number of holding stocks: {}\".format(len(target_weight_position)))\n            self.logger.info(\"total holding weight: {:.6f}\".format(weight.sum()))\n\n        return target_weight_position\n"
  },
  {
    "path": "qlib/contrib/torch.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nThis module is not a necessary part of Qlib.\nIt only provides some tools for convenience\nand should not be imported into the core part of Qlib.\n\"\"\"\n\nimport torch\nimport numpy as np\nimport pandas as pd\n\n\ndef data_to_tensor(data, device=\"cpu\", raise_error=False):\n    if isinstance(data, torch.Tensor):\n        if device == \"cpu\":\n            return data.cpu()\n        else:\n            return data.to(device)\n    if isinstance(data, (pd.DataFrame, pd.Series)):\n        return data_to_tensor(torch.from_numpy(data.values).float(), device)\n    elif isinstance(data, np.ndarray):\n        return data_to_tensor(torch.from_numpy(data).float(), device)\n    elif isinstance(data, (tuple, list)):\n        # propagate raise_error so unsupported nested elements are handled consistently\n        return [data_to_tensor(i, device, raise_error) for i in data]\n    elif isinstance(data, dict):\n        return {k: data_to_tensor(v, device, raise_error) for k, v in data.items()}\n    else:\n        if raise_error:\n            raise ValueError(f\"Unsupported data type: {type(data)}.\")\n        else:\n            return data\n"
  },
  {
    "path": "qlib/contrib/tuner/__init__.py",
    "content": "# pylint: skip-file\n# flake8: noqa\n"
  },
  {
    "path": "qlib/contrib/tuner/config.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# pylint: skip-file\n# flake8: noqa\n\nimport copy\nimport os\nfrom ruamel.yaml import YAML\n\n\nclass TunerConfigManager:\n    def __init__(self, config_path):\n        if not config_path:\n            raise ValueError(\"Config path is invalid.\")\n        self.config_path = config_path\n\n        with open(config_path) as fp:\n            yaml = YAML(typ=\"safe\", pure=True)\n            config = yaml.load(fp)\n        self.config = copy.deepcopy(config)\n\n        self.pipeline_ex_config = PipelineExperimentConfig(config.get(\"experiment\", dict()), self)\n        self.pipeline_config = config.get(\"tuner_pipeline\", list())\n        self.optim_config = OptimizationConfig(config.get(\"optimization_criteria\", dict()), self)\n\n        self.time_config = config.get(\"time_period\", dict())\n        self.data_config = config.get(\"data\", dict())\n        self.backtest_config = config.get(\"backtest\", dict())\n        self.qlib_client_config = config.get(\"qlib_client\", dict())\n\n\nclass PipelineExperimentConfig:\n    def __init__(self, config, TUNER_CONFIG_MANAGER):\n        \"\"\"\n        :param config:  The config dict for tuner experiment\n        :param TUNER_CONFIG_MANAGER:   The tuner config manager\n        \"\"\"\n        self.name = config.get(\"name\", \"tuner_experiment\")\n        # The dir of the config\n        self.global_dir = config.get(\"dir\", os.path.dirname(TUNER_CONFIG_MANAGER.config_path))\n        # The dir of the result of tuner experiment\n        self.tuner_ex_dir = config.get(\"tuner_ex_dir\", os.path.join(self.global_dir, self.name))\n        if not os.path.exists(self.tuner_ex_dir):\n            os.makedirs(self.tuner_ex_dir)\n        # The dir of the results of all estimator experiments\n        self.estimator_ex_dir = config.get(\"estimator_ex_dir\", os.path.join(self.tuner_ex_dir, \"estimator_experiment\"))\n        if not 
os.path.exists(self.estimator_ex_dir):\n            os.makedirs(self.estimator_ex_dir)\n        # Get the tuner type\n        self.tuner_module_path = config.get(\"tuner_module_path\", \"qlib.contrib.tuner.tuner\")\n        self.tuner_class = config.get(\"tuner_class\", \"QLibTuner\")\n        # Save the tuner experiment config for later review\n        tuner_ex_config_path = os.path.join(self.tuner_ex_dir, \"tuner_config.yaml\")\n        with open(tuner_ex_config_path, \"w\") as fp:\n            # the YAML instance in TunerConfigManager.__init__ is local to that method, so create one here\n            yaml = YAML(typ=\"safe\", pure=True)\n            yaml.dump(TUNER_CONFIG_MANAGER.config, fp)\n\n\nclass OptimizationConfig:\n    def __init__(self, config, TUNER_CONFIG_MANAGER):\n        self.report_type = config.get(\"report_type\", \"pred_long\")\n        if self.report_type not in [\n            \"pred_long\",\n            \"pred_long_short\",\n            \"pred_short\",\n            \"excess_return_without_cost\",\n            \"excess_return_with_cost\",\n            \"model\",\n        ]:\n            raise ValueError(\n                \"report_type should be one of pred_long, pred_long_short, pred_short, excess_return_without_cost, excess_return_with_cost and model\"\n            )\n\n        self.report_factor = config.get(\"report_factor\", \"information_ratio\")\n        if self.report_factor not in [\n            \"annualized_return\",\n            \"information_ratio\",\n            \"max_drawdown\",\n            \"mean\",\n            \"std\",\n            \"model_score\",\n            \"model_pearsonr\",\n        ]:\n            raise ValueError(\n                \"report_factor should be one of annualized_return, information_ratio, max_drawdown, mean, std, model_pearsonr and model_score\"\n            )\n\n        self.optim_type = config.get(\"optim_type\", \"max\")\n        if self.optim_type not in [\"min\", \"max\", \"correlation\"]:\n            raise ValueError(\"optim_type should be min, max or correlation\")\n"
  },
  {
    "path": "qlib/contrib/tuner/launcher.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# pylint: skip-file\n# flake8: noqa\n\n# coding=utf-8\n\nimport argparse\nimport importlib\nimport os\nimport yaml\n\nfrom .config import TunerConfigManager\n\nargs_parser = argparse.ArgumentParser(prog=\"tuner\")\nargs_parser.add_argument(\n    \"-c\",\n    \"--config_path\",\n    required=True,\n    type=str,\n    help=\"config path indicates where to load yaml config.\",\n)\n\nargs = args_parser.parse_args()\n\nTUNER_CONFIG_MANAGER = TunerConfigManager(args.config_path)\n\n\ndef run():\n    # 1. Get pipeline class.\n    tuner_pipeline_class = getattr(importlib.import_module(\".pipeline\", package=\"qlib.contrib.tuner\"), \"Pipeline\")\n    # 2. Init tuner pipeline.\n    tuner_pipeline = tuner_pipeline_class(TUNER_CONFIG_MANAGER)\n    # 3. Begin to tune\n    tuner_pipeline.run()\n"
  },
  {
    "path": "qlib/contrib/tuner/pipeline.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# pylint: skip-file\n# flake8: noqa\n\nimport os\nimport json\nimport logging\nimport importlib\nfrom abc import abstractmethod\n\nfrom ...log import get_module_logger, TimeInspector\nfrom ...utils import get_module_by_module_path\n\n\nclass Pipeline:\n    GLOBAL_BEST_PARAMS_NAME = \"global_best_params.json\"\n\n    def __init__(self, tuner_config_manager):\n        self.logger = get_module_logger(\"Pipeline\", sh_level=logging.INFO)\n\n        self.tuner_config_manager = tuner_config_manager\n\n        self.pipeline_ex_config = tuner_config_manager.pipeline_ex_config\n        self.optim_config = tuner_config_manager.optim_config\n        self.time_config = tuner_config_manager.time_config\n        self.pipeline_config = tuner_config_manager.pipeline_config\n        self.data_config = tuner_config_manager.data_config\n        self.backtest_config = tuner_config_manager.backtest_config\n        self.qlib_client_config = tuner_config_manager.qlib_client_config\n\n        self.global_best_res = None\n        self.global_best_params = None\n        self.best_tuner_index = None\n\n    def run(self):\n        TimeInspector.set_time_mark()\n        for tuner_index, tuner_config in enumerate(self.pipeline_config):\n            tuner = self.init_tuner(tuner_index, tuner_config)\n            tuner.tune()\n            if self.global_best_res is None or self.global_best_res > tuner.best_res:\n                self.global_best_res = tuner.best_res\n                self.global_best_params = tuner.best_params\n                self.best_tuner_index = tuner_index\n        TimeInspector.log_cost_time(\"Finished tuner pipeline.\")\n\n        self.save_tuner_exp_info()\n\n    def init_tuner(self, tuner_index, tuner_config):\n        \"\"\"\n        Implement this method to build the tuner by config\n        return: tuner\n        \"\"\"\n        # 1. 
Add experiment config in tuner_config\n        tuner_config[\"experiment\"] = {\n            \"name\": \"estimator_experiment_{}\".format(tuner_index),\n            \"id\": tuner_index,\n            \"dir\": self.pipeline_ex_config.estimator_ex_dir,\n            \"observer_type\": \"file_storage\",\n        }\n        tuner_config[\"qlib_client\"] = self.qlib_client_config\n        # 2. Add data config in tuner_config\n        tuner_config[\"data\"] = self.data_config\n        # 3. Add backtest config in tuner_config\n        tuner_config[\"backtest\"] = self.backtest_config\n        # 4. Update trainer in tuner_config\n        tuner_config[\"trainer\"].update({\"args\": self.time_config})\n\n        # 5. Import Tuner class\n        tuner_module = get_module_by_module_path(self.pipeline_ex_config.tuner_module_path)\n        tuner_class = getattr(tuner_module, self.pipeline_ex_config.tuner_class)\n        # 6. Return the specific tuner\n        return tuner_class(tuner_config, self.optim_config)\n\n    def save_tuner_exp_info(self):\n        TimeInspector.set_time_mark()\n        save_path = os.path.join(self.pipeline_ex_config.tuner_ex_dir, Pipeline.GLOBAL_BEST_PARAMS_NAME)\n        with open(save_path, \"w\") as fp:\n            json.dump(self.global_best_params, fp)\n        TimeInspector.log_cost_time(\"Finished saving global best tuner parameters.\")\n\n        self.logger.info(\"Best Tuner id: {}.\".format(self.best_tuner_index))\n        self.logger.info(\"Global best parameters: {}.\".format(self.global_best_params))\n        self.logger.info(\"You can check the best parameters at {}.\".format(save_path))\n"
  },
  {
    "path": "qlib/contrib/tuner/space.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# pylint: skip-file\n# flake8: noqa\n\nfrom hyperopt import hp\n\nTopkAmountStrategySpace = {\n    \"topk\": hp.choice(\"topk\", [30, 35, 40]),\n    \"buffer_margin\": hp.choice(\"buffer_margin\", [200, 250, 300]),\n}\n\nQLibDataLabelSpace = {\n    \"labels\": hp.choice(\n        \"labels\",\n        [[\"Ref($vwap, -2)/Ref($vwap, -1) - 1\"], [\"Ref($close, -5)/$close - 1\"]],\n    )\n}\n"
  },
  {
    "path": "qlib/contrib/tuner/tuner.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# pylint: skip-file\n# flake8: noqa\n\nimport os\nimport yaml\nimport json\nimport copy\nimport logging\nimport importlib\nimport subprocess\nimport pandas as pd\nimport numpy as np\n\nfrom abc import abstractmethod\n\nfrom ...log import get_module_logger, TimeInspector\nfrom ...utils.pickle_utils import restricted_pickle_load\nfrom hyperopt import fmin, tpe\nfrom hyperopt import STATUS_OK, STATUS_FAIL\n\n\nclass Tuner:\n    def __init__(self, tuner_config, optim_config):\n        self.logger = get_module_logger(\"Tuner\", sh_level=logging.INFO)\n\n        self.tuner_config = tuner_config\n        self.optim_config = optim_config\n\n        self.max_evals = self.tuner_config.get(\"max_evals\", 10)\n        self.ex_dir = os.path.join(\n            self.tuner_config[\"experiment\"][\"dir\"],\n            self.tuner_config[\"experiment\"][\"name\"],\n        )\n\n        self.best_params = None\n        self.best_res = None\n\n        self.space = self.setup_space()\n\n    def tune(self):\n        TimeInspector.set_time_mark()\n        fmin(\n            fn=self.objective,\n            space=self.space,\n            algo=tpe.suggest,\n            max_evals=self.max_evals,\n            show_progressbar=False,\n        )\n        self.logger.info(\"Local best params: {} \".format(self.best_params))\n        TimeInspector.log_cost_time(\n            \"Finished searching best parameters in Tuner {}.\".format(self.tuner_config[\"experiment\"][\"id\"])\n        )\n\n        self.save_local_best_params()\n\n    @abstractmethod\n    def objective(self, params):\n        \"\"\"\n        Implement this method to give an optimization factor using parameters in space.\n        :return: {'loss': a factor for optimization, float type,\n                  'status': the status of this evaluation step, STATUS_OK or STATUS_FAIL}.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def 
setup_space(self):\n        \"\"\"\n        Implement this method to set up the search space of the tuner.\n        :return: search space, dict type.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def save_local_best_params(self):\n        \"\"\"\n        Implement this method to save the best parameters of this tuner.\n        \"\"\"\n        pass\n\n\nclass QLibTuner(Tuner):\n    ESTIMATOR_CONFIG_NAME = \"estimator_config.yaml\"\n    EXP_INFO_NAME = \"exp_info.json\"\n    EXP_RESULT_DIR = \"sacred/{}\"\n    EXP_RESULT_NAME = \"analysis.pkl\"\n    LOCAL_BEST_PARAMS_NAME = \"local_best_params.json\"\n\n    def objective(self, params):\n        # 1. Set up a config for a specific estimator process\n        estimator_path = self.setup_estimator_config(params)\n        self.logger.info(\"Searching params: {} \".format(params))\n\n        # 2. Run the estimator program in a subprocess; this call blocks until the subprocess finishes\n        sub_fails = subprocess.call(\"estimator -c {}\".format(estimator_path), shell=True)\n        if sub_fails:\n            # If this subprocess failed, ignore this evaluation step\n            self.logger.info(\"Estimator experiment failed with these search parameters\")\n            return {\"loss\": np.nan, \"status\": STATUS_FAIL}\n\n        # 3. Fetch the result of the subprocess and check whether it is NaN\n        res = self.fetch_result()\n        if np.isnan(res):\n            status = STATUS_FAIL\n        else:\n            status = STATUS_OK\n\n        # 4. Save the best score and params\n        if self.best_res is None or self.best_res > res:\n            self.best_res = res\n            self.best_params = params\n\n        # 5. Return the result as the optimization objective\n        return {\"loss\": res, \"status\": status}\n\n    def fetch_result(self):\n        # 1. 
Get experiment information\n        exp_info_path = os.path.join(self.ex_dir, QLibTuner.EXP_INFO_NAME)\n        with open(exp_info_path) as fp:\n            exp_info = json.load(fp)\n        estimator_ex_id = exp_info[\"id\"]\n\n        # 2. Return model result if needed\n        if self.optim_config.report_type == \"model\":\n            if self.optim_config.report_factor == \"model_score\":\n                # if the estimator experiment uses multi-label training, users need to process the scores themselves\n                # By default, the average score is returned\n                return np.mean(exp_info[\"performance\"][\"model_score\"])\n            elif self.optim_config.report_factor == \"model_pearsonr\":\n                # pearsonr is a correlation coefficient, 1 is the best\n                return np.abs(exp_info[\"performance\"][\"model_pearsonr\"] - 1)\n\n        # 3. Get backtest results\n        exp_result_dir = os.path.join(self.ex_dir, QLibTuner.EXP_RESULT_DIR.format(estimator_ex_id))\n        exp_result_path = os.path.join(exp_result_dir, QLibTuner.EXP_RESULT_NAME)\n        with open(exp_result_path, \"rb\") as fp:\n            analysis_df = restricted_pickle_load(fp)\n\n        # 4. 
Get the backtest factor the user wants to optimize; if it should be maximized, negate it because the tuner minimizes the loss\n        res = analysis_df.loc[self.optim_config.report_type].loc[self.optim_config.report_factor]\n        if self.optim_config.optim_type == \"min\":\n            return res.values[0]\n        elif self.optim_config.optim_type == \"max\":\n            return -res.values[0]\n        else:\n            # self.optim_config.optim_type == 'correlation'\n            return np.abs(res.values[0] - 1)\n\n    def setup_estimator_config(self, params):\n        estimator_config = copy.deepcopy(self.tuner_config)\n        estimator_config[\"model\"].update({\"args\": params[\"model_space\"]})\n        estimator_config[\"strategy\"].update({\"args\": params[\"strategy_space\"]})\n        if params.get(\"data_label_space\", None) is not None:\n            estimator_config[\"data\"][\"args\"].update(params[\"data_label_space\"])\n\n        estimator_path = os.path.join(\n            self.tuner_config[\"experiment\"].get(\"dir\", \"../\"),\n            QLibTuner.ESTIMATOR_CONFIG_NAME,\n        )\n\n        with open(estimator_path, \"w\") as fp:\n            yaml.dump(estimator_config, fp)\n\n        return estimator_path\n\n    def setup_space(self):\n        # 1. Set up model space\n        model_space_name = self.tuner_config[\"model\"].get(\"space\", None)\n        if model_space_name is None:\n            raise ValueError(\"Please give the search space of the model.\")\n        model_space = getattr(\n            importlib.import_module(\".space\", package=\"qlib.contrib.tuner\"),\n            model_space_name,\n        )\n\n        # 2. 
Set up strategy space\n        strategy_space_name = self.tuner_config[\"strategy\"].get(\"space\", None)\n        if strategy_space_name is None:\n            raise ValueError(\"Please give the search space of the strategy.\")\n        strategy_space = getattr(\n            importlib.import_module(\".space\", package=\"qlib.contrib.tuner\"),\n            strategy_space_name,\n        )\n\n        # 3. Set up data label space if given\n        if self.tuner_config.get(\"data_label\", None) is not None:\n            data_label_space_name = self.tuner_config[\"data_label\"].get(\"space\", None)\n            if data_label_space_name is not None:\n                data_label_space = getattr(\n                    importlib.import_module(\".space\", package=\"qlib.contrib.tuner\"),\n                    data_label_space_name,\n                )\n        else:\n            data_label_space_name = None\n\n        # 4. Combine the search spaces\n        space = dict()\n        space.update({\"model_space\": model_space})\n        space.update({\"strategy_space\": strategy_space})\n        if data_label_space_name is not None:\n            space.update({\"data_label_space\": data_label_space})\n\n        return space\n\n    def save_local_best_params(self):\n        TimeInspector.set_time_mark()\n        local_best_params_path = os.path.join(self.ex_dir, QLibTuner.LOCAL_BEST_PARAMS_NAME)\n        with open(local_best_params_path, \"w\") as fp:\n            json.dump(self.best_params, fp)\n        TimeInspector.log_cost_time(\n            \"Finished saving local best tuner parameters to: {} .\".format(local_best_params_path)\n        )\n"
  },
  {
    "path": "qlib/contrib/workflow/__init__.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\nfrom .record_temp import MultiSegRecord\nfrom .record_temp import SignalMseRecord\n\n__all__ = [\"MultiSegRecord\", \"SignalMseRecord\"]\n"
  },
  {
    "path": "qlib/contrib/workflow/record_temp.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nimport logging\nimport pandas as pd\nimport numpy as np\nfrom sklearn.metrics import mean_squared_error\nfrom typing import Dict, Text, Any\n\nfrom ...contrib.eva.alpha import calc_ic\nfrom ...workflow.record_temp import RecordTemp\nfrom ...workflow.record_temp import SignalRecord\nfrom ...data import dataset as qlib_dataset\nfrom ...log import get_module_logger\n\nlogger = get_module_logger(\"workflow\", logging.INFO)\n\n\nclass MultiSegRecord(RecordTemp):\n    \"\"\"\n    This is the multi-segment signal record class that generates the signal prediction.\n    This class inherits the ``RecordTemp`` class.\n    \"\"\"\n\n    def __init__(self, model, dataset, recorder=None):\n        super().__init__(recorder=recorder)\n        if not isinstance(dataset, qlib_dataset.DatasetH):\n            raise ValueError(\"The dataset must be a DatasetH, but got {:}\".format(type(dataset)))\n        self.model = model\n        self.dataset = dataset\n\n    def generate(self, segments: Dict[Text, Any], save: bool = False):\n        for key, segment in segments.items():\n            predics = self.model.predict(self.dataset, segment)\n            if isinstance(predics, pd.Series):\n                predics = predics.to_frame(\"score\")\n            labels = self.dataset.prepare(\n                segments=segment, col_set=\"label\", data_key=qlib_dataset.handler.DataHandlerLP.DK_R\n            )\n            # Compute the IC and Rank IC\n            ic, ric = calc_ic(predics.iloc[:, 0], labels.iloc[:, 0])\n            results = {\"all-IC\": ic, \"mean-IC\": ic.mean(), \"all-Rank-IC\": ric, \"mean-Rank-IC\": ric.mean()}\n            logger.info(\"--- Results for {:} ({:}) ---\".format(key, segment))\n            ic_x100, ric_x100 = ic * 100, ric * 100\n            logger.info(\"IC: {:.4f}%\".format(ic_x100.mean()))\n            # ICIR is a ratio (mean / std), so it is not reported as a percentage\n            logger.info(\"ICIR: {:.4f}\".format(ic_x100.mean() 
/ ic_x100.std()))\n            logger.info(\"Rank IC: {:.4f}%\".format(ric_x100.mean()))\n            logger.info(\"Rank ICIR: {:.4f}\".format(ric_x100.mean() / ric_x100.std()))\n\n            if save:\n                save_name = \"results-{:}.pkl\".format(key)\n                self.save(**{save_name: results})\n                logger.info(\n                    \"The record '{:}' has been saved as the artifact of the Experiment {:}\".format(\n                        save_name, self.recorder.experiment_id\n                    )\n                )\n\n\nclass SignalMseRecord(RecordTemp):\n    \"\"\"\n    This is the Signal MSE Record class that computes the mean squared error (MSE).\n    This class inherits the ``RecordTemp`` class.\n    \"\"\"\n\n    artifact_path = \"sig_analysis\"\n    depend_cls = SignalRecord\n\n    def __init__(self, recorder, **kwargs):\n        super().__init__(recorder=recorder, **kwargs)\n\n    def generate(self):\n        self.check()\n\n        pred = self.load(\"pred.pkl\")\n        label = self.load(\"label.pkl\")\n        # keep only the entries whose label is not NaN\n        masks = ~np.isnan(label.values)\n        mse = mean_squared_error(pred.values[masks], label.values[masks])\n        metrics = {\"MSE\": mse, \"RMSE\": np.sqrt(mse)}\n        objects = {\"mse.pkl\": mse, \"rmse.pkl\": np.sqrt(mse)}\n        self.recorder.log_metrics(**metrics)\n        self.save(**objects)\n        logger.info(\"The evaluation results in SignalMseRecord are {:}\".format(metrics))\n\n    def list(self):\n        return [\"mse.pkl\", \"rmse.pkl\"]\n"
  },
  {
    "path": "qlib/data/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom .data import (\n    D,\n    CalendarProvider,\n    InstrumentProvider,\n    FeatureProvider,\n    ExpressionProvider,\n    DatasetProvider,\n    LocalCalendarProvider,\n    LocalInstrumentProvider,\n    LocalFeatureProvider,\n    LocalPITProvider,\n    LocalExpressionProvider,\n    LocalDatasetProvider,\n    ClientCalendarProvider,\n    ClientInstrumentProvider,\n    ClientDatasetProvider,\n    BaseProvider,\n    LocalProvider,\n    ClientProvider,\n)\n\nfrom .cache import (\n    ExpressionCache,\n    DatasetCache,\n    DiskExpressionCache,\n    DiskDatasetCache,\n    SimpleDatasetCache,\n    DatasetURICache,\n    MemoryCalendarCache,\n)\n\n__all__ = [\n    \"D\",\n    \"CalendarProvider\",\n    \"InstrumentProvider\",\n    \"FeatureProvider\",\n    \"ExpressionProvider\",\n    \"DatasetProvider\",\n    \"LocalCalendarProvider\",\n    \"LocalInstrumentProvider\",\n    \"LocalFeatureProvider\",\n    \"LocalPITProvider\",\n    \"LocalExpressionProvider\",\n    \"LocalDatasetProvider\",\n    \"ClientCalendarProvider\",\n    \"ClientInstrumentProvider\",\n    \"ClientDatasetProvider\",\n    \"BaseProvider\",\n    \"LocalProvider\",\n    \"ClientProvider\",\n    \"ExpressionCache\",\n    \"DatasetCache\",\n    \"DiskExpressionCache\",\n    \"DiskDatasetCache\",\n    \"SimpleDatasetCache\",\n    \"DatasetURICache\",\n    \"MemoryCalendarCache\",\n]\n"
  },
  {
    "path": "qlib/data/_libs/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n"
  },
  {
    "path": "qlib/data/_libs/expanding.pyx",
    "content": "# cython: profile=False\n# cython: boundscheck=False, wraparound=False, cdivision=True\ncimport cython\ncimport numpy as np\nimport numpy as np\n\nfrom libc.math cimport sqrt, isnan, NAN\nfrom libcpp.vector cimport vector\n\n\ncdef class Expanding:\n    \"\"\"1-D array expanding\"\"\"\n    cdef vector[double] barv\n    cdef int na_count\n    def __init__(self):\n        self.na_count = 0\n\n    cdef double update(self, double val):\n        pass\n\n\ncdef class Mean(Expanding):\n    \"\"\"1-D array expanding mean\"\"\"\n    cdef double vsum\n    def __init__(self):\n        super(Mean, self).__init__()\n        self.vsum = 0\n\n    cdef double update(self, double val):\n        self.barv.push_back(val)\n        if isnan(val):\n            self.na_count += 1\n        else:\n            self.vsum += val\n        return self.vsum / (self.barv.size() - self.na_count)\n\n\ncdef class Slope(Expanding):\n    \"\"\"1-D array expanding slope\"\"\"\n    cdef double x_sum\n    cdef double x2_sum\n    cdef double y_sum\n    cdef double xy_sum\n    def __init__(self):\n        super(Slope, self).__init__()\n        self.x_sum  = 0\n        self.x2_sum = 0\n        self.y_sum  = 0\n        self.xy_sum = 0\n\n    cdef double update(self, double val):\n        self.barv.push_back(val)\n        cdef size_t size = self.barv.size()\n        if isnan(val):\n            self.na_count += 1\n        else:\n            self.x_sum  += size\n            self.x2_sum += size * size\n            self.y_sum  += val\n            self.xy_sum += size * val\n        cdef int N = size - self.na_count\n        return (N*self.xy_sum - self.x_sum*self.y_sum) / \\\n            (N*self.x2_sum - self.x_sum*self.x_sum)\n\n\ncdef class Resi(Expanding):\n    \"\"\"1-D array expanding residuals\"\"\"\n    cdef double x_sum\n    cdef double x2_sum\n    cdef double y_sum\n    cdef double xy_sum\n    def __init__(self):\n        super(Resi, self).__init__()\n        self.x_sum  = 0\n        
self.x2_sum = 0\n        self.y_sum  = 0\n        self.xy_sum = 0\n\n    cdef double update(self, double val):\n        self.barv.push_back(val)\n        cdef size_t size = self.barv.size()\n        if isnan(val):\n            self.na_count += 1\n        else:\n            self.x_sum  += size\n            self.x2_sum += size * size\n            self.y_sum  += val\n            self.xy_sum += size * val\n        cdef int N = size - self.na_count\n        slope = (N*self.xy_sum - self.x_sum*self.y_sum) / \\\n                (N*self.x2_sum - self.x_sum*self.x_sum)\n        x_mean = self.x_sum / N\n        y_mean = self.y_sum / N\n        interp = y_mean - slope*x_mean\n        return val - (slope*size + interp)\n\n\ncdef class Rsquare(Expanding):\n    \"\"\"1-D array expanding rsquare\"\"\"\n    cdef double x_sum\n    cdef double x2_sum\n    cdef double y_sum\n    cdef double y2_sum\n    cdef double xy_sum\n    def __init__(self):\n        super(Rsquare, self).__init__()\n        self.x_sum  = 0\n        self.x2_sum = 0\n        self.y_sum  = 0\n        self.y2_sum = 0\n        self.xy_sum = 0\n\n    cdef double update(self, double val):\n        self.barv.push_back(val)\n        cdef size_t size = self.barv.size()\n        if isnan(val):\n            self.na_count += 1\n        else:\n            self.x_sum  += size\n            self.x2_sum += size * size\n            self.y_sum  += val\n            self.y2_sum += val * val\n            self.xy_sum += size * val\n        cdef int N = size - self.na_count\n        cdef double rvalue = (N*self.xy_sum - self.x_sum*self.y_sum) / \\\n            sqrt((N*self.x2_sum - self.x_sum*self.x_sum) * (N*self.y2_sum - self.y_sum*self.y_sum))\n        return rvalue * rvalue\n\n\ncdef np.ndarray[double, ndim=1] expanding(Expanding r, np.ndarray a):\n    cdef int  i\n    cdef int  N = len(a)\n    cdef np.ndarray[double, ndim=1] ret = np.empty(N)\n    for i in range(N):\n        ret[i] = r.update(a[i])\n    return ret\n\ndef 
expanding_mean(np.ndarray a):\n    cdef Mean r = Mean()\n    return expanding(r, a)\n\ndef expanding_slope(np.ndarray a):\n    cdef Slope r = Slope()\n    return expanding(r, a)\n\ndef expanding_rsquare(np.ndarray a):\n    cdef Rsquare r = Rsquare()\n    return expanding(r, a)\n\ndef expanding_resi(np.ndarray a):\n    cdef Resi r = Resi()\n    return expanding(r, a)\n"
  },
  {
    "path": "qlib/data/_libs/rolling.pyx",
    "content": "# cython: profile=False\n# cython: boundscheck=False, wraparound=False, cdivision=True\ncimport cython\ncimport numpy as np\nimport numpy as np\n\nfrom libc.math cimport sqrt, isnan, NAN\nfrom libcpp.deque cimport deque\n\n\ncdef class Rolling:\n    \"\"\"1-D array rolling\"\"\"\n    cdef int window\n    cdef deque[double] barv\n    cdef int na_count\n    def __init__(self, int window):\n        self.window = window\n        self.na_count = window\n        cdef int i\n        for i in range(window):\n            self.barv.push_back(NAN)\n\n    cdef double update(self, double val):\n        pass\n\n\ncdef class Mean(Rolling):\n    \"\"\"1-D array rolling mean\"\"\"\n    cdef double vsum\n    def __init__(self, int window):\n        super(Mean, self).__init__(window)\n        self.vsum = 0\n        \n    cdef double update(self, double val):\n        self.barv.push_back(val)\n        if not isnan(self.barv.front()):\n            self.vsum -= self.barv.front()\n        else:\n            self.na_count -= 1\n        self.barv.pop_front()\n        if isnan(val):\n            self.na_count += 1\n            # return NAN\n        else:\n            self.vsum += val\n        return self.vsum / (self.window - self.na_count)\n\n\ncdef class Slope(Rolling):\n    \"\"\"1-D array rolling slope\"\"\"\n    cdef double i_sum # can be used as i2_sum\n    cdef double x_sum\n    cdef double x2_sum\n    cdef double y_sum\n    cdef double xy_sum\n    def __init__(self, int window):\n        super(Slope, self).__init__(window)\n        self.i_sum  = 0\n        self.x_sum  = 0\n        self.x2_sum = 0\n        self.y_sum  = 0\n        self.xy_sum = 0\n\n    cdef double update(self, double val):\n        self.barv.push_back(val)\n        self.xy_sum = self.xy_sum - self.y_sum\n        self.x2_sum = self.x2_sum + self.i_sum - 2*self.x_sum\n        self.x_sum = self.x_sum - self.i_sum\n        cdef double _val\n        _val = self.barv.front()\n        if not isnan(_val):\n  
          self.i_sum -= 1\n            self.y_sum -= _val\n        else:\n            self.na_count -= 1\n        self.barv.pop_front()\n        if isnan(val):\n            self.na_count += 1\n            # return NAN\n        else:\n            self.i_sum  += 1\n            self.x_sum  += self.window\n            self.x2_sum += self.window * self.window\n            self.y_sum  += val\n            self.xy_sum += self.window * val\n        cdef int N = self.window - self.na_count\n        return (N*self.xy_sum - self.x_sum*self.y_sum) / \\\n            (N*self.x2_sum - self.x_sum*self.x_sum)\n\n    \ncdef class Resi(Rolling):\n    \"\"\"1-D array rolling residuals\"\"\"\n    cdef double i_sum # can be used as i2_sum\n    cdef double x_sum\n    cdef double x2_sum\n    cdef double y_sum\n    cdef double xy_sum\n    def __init__(self, int window):\n        super(Resi, self).__init__(window)\n        self.i_sum  = 0\n        self.x_sum  = 0\n        self.x2_sum = 0\n        self.y_sum  = 0\n        self.xy_sum = 0\n\n    cdef double update(self, double val):\n        self.barv.push_back(val)\n        self.xy_sum = self.xy_sum - self.y_sum\n        self.x2_sum = self.x2_sum + self.i_sum - 2*self.x_sum\n        self.x_sum = self.x_sum - self.i_sum\n        cdef double _val\n        _val = self.barv.front()\n        if not isnan(_val):\n            self.i_sum -= 1\n            self.y_sum -= _val\n        else:\n            self.na_count -= 1\n        self.barv.pop_front()\n        if isnan(val):\n            self.na_count += 1\n            # return NAN\n        else:\n            self.i_sum  += 1\n            self.x_sum  += self.window\n            self.x2_sum += self.window * self.window\n            self.y_sum  += val\n            self.xy_sum += self.window * val\n        cdef int N = self.window - self.na_count\n        slope = (N*self.xy_sum - self.x_sum*self.y_sum) / \\\n                (N*self.x2_sum - self.x_sum*self.x_sum)\n        x_mean = self.x_sum / N\n        
y_mean = self.y_sum / N\n        interp = y_mean - slope*x_mean\n        return val - (slope*self.window + interp)\n\n    \ncdef class Rsquare(Rolling):\n    \"\"\"1-D array rolling rsquare\"\"\"\n    cdef double i_sum\n    cdef double x_sum\n    cdef double x2_sum\n    cdef double y_sum\n    cdef double y2_sum\n    cdef double xy_sum\n    def __init__(self, int window):\n        super(Rsquare, self).__init__(window)\n        self.i_sum  = 0\n        self.x_sum  = 0\n        self.x2_sum = 0\n        self.y_sum  = 0\n        self.y2_sum = 0\n        self.xy_sum = 0\n\n    cdef double update(self, double val):\n        self.barv.push_back(val)\n        self.xy_sum = self.xy_sum - self.y_sum\n        self.x2_sum = self.x2_sum + self.i_sum - 2*self.x_sum\n        self.x_sum = self.x_sum - self.i_sum\n        cdef double _val\n        _val = self.barv.front()\n        if not isnan(_val):\n            self.i_sum  -= 1\n            self.y_sum  -= _val\n            self.y2_sum -= _val * _val\n        else:\n            self.na_count -= 1\n        self.barv.pop_front()\n        if isnan(val):\n            self.na_count += 1\n            # return NAN\n        else:\n            self.i_sum  += 1\n            self.x_sum  += self.window\n            self.x2_sum += self.window * self.window\n            self.y_sum  += val\n            self.y2_sum += val * val\n            self.xy_sum += self.window * val\n        cdef int N = self.window - self.na_count\n        cdef double rvalue\n        rvalue = (N*self.xy_sum - self.x_sum*self.y_sum) / \\\n            sqrt((N*self.x2_sum - self.x_sum*self.x_sum) * (N*self.y2_sum - self.y_sum*self.y_sum))\n        return rvalue * rvalue\n\n    \ncdef np.ndarray[double, ndim=1] rolling(Rolling r, np.ndarray a):\n    cdef int  i\n    cdef int  N = len(a)\n    cdef np.ndarray[double, ndim=1] ret = np.empty(N)\n    for i in range(N):\n        ret[i] = r.update(a[i])\n    return ret\n\ndef rolling_mean(np.ndarray a, int window):\n    cdef Mean r = 
Mean(window)\n    return rolling(r, a)\n\ndef rolling_slope(np.ndarray a, int window):\n    cdef Slope r = Slope(window)\n    return rolling(r, a)\n\ndef rolling_rsquare(np.ndarray a, int window):\n    cdef Rsquare r = Rsquare(window)\n    return rolling(r, a)\n\ndef rolling_resi(np.ndarray a, int window):\n    cdef Resi r = Resi(window)\n    return rolling(r, a)\n"
  },
  {
    "path": "qlib/data/base.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport abc\nimport pandas as pd\nfrom ..log import get_module_logger\n\n\nclass Expression(abc.ABC):\n    \"\"\"\n    Expression base class\n\n    Expression is designed to handle the calculation of data with the format below:\n    data with two dimensions for each instrument,\n\n    - feature\n    - time: it could be observation time or period time.\n\n        - period time is designed for the Point-in-time database. For example, the period time may be 2014Q4, and its value can be observed multiple times (different values may be observed at different times due to amendments).\n    \"\"\"\n\n    def __str__(self):\n        return type(self).__name__\n\n    def __repr__(self):\n        return str(self)\n\n    def __gt__(self, other):\n        from .ops import Gt  # pylint: disable=C0415\n\n        return Gt(self, other)\n\n    def __ge__(self, other):\n        from .ops import Ge  # pylint: disable=C0415\n\n        return Ge(self, other)\n\n    def __lt__(self, other):\n        from .ops import Lt  # pylint: disable=C0415\n\n        return Lt(self, other)\n\n    def __le__(self, other):\n        from .ops import Le  # pylint: disable=C0415\n\n        return Le(self, other)\n\n    def __eq__(self, other):\n        from .ops import Eq  # pylint: disable=C0415\n\n        return Eq(self, other)\n\n    def __ne__(self, other):\n        from .ops import Ne  # pylint: disable=C0415\n\n        return Ne(self, other)\n\n    def __add__(self, other):\n        from .ops import Add  # pylint: disable=C0415\n\n        return Add(self, other)\n\n    def __radd__(self, other):\n        from .ops import Add  # pylint: disable=C0415\n\n        return Add(other, self)\n\n    def __sub__(self, other):\n        from .ops import Sub  # pylint: disable=C0415\n\n        return Sub(self, other)\n\n    def __rsub__(self, other):\n      
  from .ops import Sub  # pylint: disable=C0415\n\n        return Sub(other, self)\n\n    def __mul__(self, other):\n        from .ops import Mul  # pylint: disable=C0415\n\n        return Mul(self, other)\n\n    def __rmul__(self, other):\n        from .ops import Mul  # pylint: disable=C0415\n\n        return Mul(self, other)\n\n    def __div__(self, other):\n        from .ops import Div  # pylint: disable=C0415\n\n        return Div(self, other)\n\n    def __rdiv__(self, other):\n        from .ops import Div  # pylint: disable=C0415\n\n        return Div(other, self)\n\n    def __truediv__(self, other):\n        from .ops import Div  # pylint: disable=C0415\n\n        return Div(self, other)\n\n    def __rtruediv__(self, other):\n        from .ops import Div  # pylint: disable=C0415\n\n        return Div(other, self)\n\n    def __pow__(self, other):\n        from .ops import Power  # pylint: disable=C0415\n\n        return Power(self, other)\n\n    def __rpow__(self, other):\n        from .ops import Power  # pylint: disable=C0415\n\n        return Power(other, self)\n\n    def __and__(self, other):\n        from .ops import And  # pylint: disable=C0415\n\n        return And(self, other)\n\n    def __rand__(self, other):\n        from .ops import And  # pylint: disable=C0415\n\n        return And(other, self)\n\n    def __or__(self, other):\n        from .ops import Or  # pylint: disable=C0415\n\n        return Or(self, other)\n\n    def __ror__(self, other):\n        from .ops import Or  # pylint: disable=C0415\n\n        return Or(other, self)\n\n    def load(self, instrument, start_index, end_index, *args):\n        \"\"\"load feature\n        This function is responsible for loading feature/expression based on the expression engine.\n\n        The concrete implementation will be separated into two parts:\n\n        1) caching data and handling errors.\n\n            - This part is shared by all the expressions and implemented in Expression\n        2) processing 
and calculating data based on the specific expression.\n\n            - This part is different in each expression and implemented in each expression\n\n        Expression Engine is shared by different data.\n        Different data will have different extra information for `args`.\n\n        Parameters\n        ----------\n        instrument : str\n            instrument code.\n        start_index : str\n            feature start index [in calendar].\n        end_index : str\n            feature end index [in calendar].\n\n        *args may contain the following information:\n        1) if it is used in basic expression engine data, it contains the following arguments\n            freq: str\n                feature frequency.\n\n        2) if it is used in PIT data, it contains the following arguments\n            cur_pit:\n                it is designed for the point-in-time data.\n            period: int\n                This is used to query a specific period.\n                The period is represented with int in Qlib. (e.g. 202001 may represent the first quarter in 2020)\n\n        Returns\n        ----------\n        pd.Series\n            feature series: The index of the series is the calendar index\n        \"\"\"\n        from .cache import H  # pylint: disable=C0415\n\n        # cache\n        cache_key = str(self), instrument, start_index, end_index, *args\n        if cache_key in H[\"f\"]:\n            return H[\"f\"][cache_key]\n        if start_index is not None and end_index is not None and start_index > end_index:\n            raise ValueError(\"Invalid index range: {} {}\".format(start_index, end_index))\n        try:\n            series = self._load_internal(instrument, start_index, end_index, *args)\n        except Exception as e:\n            get_module_logger(\"data\").debug(\n                f\"Loading data error: instrument={instrument}, expression={str(self)}, \"\n                f\"start_index={start_index}, end_index={end_index}, args={args}. 
\"\n                f\"error info: {str(e)}\"\n            )\n            raise\n        series.name = str(self)\n        H[\"f\"][cache_key] = series\n        return series\n\n    @abc.abstractmethod\n    def _load_internal(self, instrument, start_index, end_index, *args) -> pd.Series:\n        raise NotImplementedError(\"This function must be implemented in your newly defined feature\")\n\n    @abc.abstractmethod\n    def get_longest_back_rolling(self):\n        \"\"\"Get the longest length of historical data the feature has accessed\n\n        This is designed for getting the needed range of the data to calculate\n        the features in a specific range up front. However, situations like\n        Ref(Ref($close, -1), 1) cannot be handled correctly.\n\n        So this will only be used for detecting the length of historical data needed.\n        \"\"\"\n        # TODO: forward operator like Ref($close, -1) is not supported yet.\n        raise NotImplementedError(\"This function must be implemented in your newly defined feature\")\n\n    @abc.abstractmethod\n    def get_extended_window_size(self):\n        \"\"\"get_extend_window_size\n\n        To calculate this Operator in range[start_index, end_index],\n        we have to get the *leaf feature* in\n        range[start_index - lft_etd, end_index + rght_etd].\n\n        Returns\n        ----------\n        (int, int)\n            lft_etd, rght_etd\n        \"\"\"\n        raise NotImplementedError(\"This function must be implemented in your newly defined feature\")\n\n\nclass Feature(Expression):\n    \"\"\"Static Expression\n\n    This kind of feature will load data from the provider\n    \"\"\"\n\n    def __init__(self, name=None):\n        if name:\n            self._name = name\n        else:\n            self._name = type(self).__name__\n\n    def __str__(self):\n        return \"$\" + self._name\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        # load\n        from .data import 
FeatureD  # pylint: disable=C0415\n\n        return FeatureD.feature(instrument, str(self), start_index, end_index, freq)\n\n    def get_longest_back_rolling(self):\n        return 0\n\n    def get_extended_window_size(self):\n        return 0, 0\n\n\nclass PFeature(Feature):\n    def __str__(self):\n        return \"$$\" + self._name\n\n    def _load_internal(self, instrument, start_index, end_index, cur_time, period=None):\n        from .data import PITD  # pylint: disable=C0415\n\n        return PITD.period_feature(instrument, str(self), start_index, end_index, cur_time, period)\n\n\nclass ExpressionOps(Expression):\n    \"\"\"Operator Expression\n\n    This kind of feature will use operator for feature\n    construction on the fly.\n    \"\"\"\n"
  },
  {
    "path": "qlib/data/cache.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport sys\nimport stat\nimport time\nimport pickle\nimport traceback\nimport redis_lock\nimport contextlib\nimport abc\nfrom pathlib import Path\nimport numpy as np\nimport pandas as pd\nfrom typing import Union, Iterable\nfrom collections import OrderedDict\n\nfrom ..config import C\nfrom ..utils import (\n    hash_args,\n    get_redis_connection,\n    read_bin,\n    parse_field,\n    remove_fields_space,\n    normalize_cache_fields,\n    normalize_cache_instruments,\n)\nfrom ..utils.pickle_utils import restricted_pickle_load\n\nfrom ..log import get_module_logger\nfrom .base import Feature\nfrom .ops import Operators  # pylint: disable=W0611  # noqa: F401\n\n\nclass QlibCacheException(RuntimeError):\n    pass\n\n\nclass MemCacheUnit(abc.ABC):\n    \"\"\"Memory Cache Unit.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        self.size_limit = kwargs.pop(\"size_limit\", 0)\n        self._size = 0\n        self.od = OrderedDict()\n\n    def __setitem__(self, key, value):\n        # TODO: thread safe? __setitem__ failure might cause inconsistent size?\n\n        # precalculate the size after od.__setitem__\n        self._adjust_size(key, value)\n\n        self.od.__setitem__(key, value)\n\n        # move the key to the end to mark it as the most recently used\n        self.od.move_to_end(key)\n\n        if self.limited:\n            # pop the oldest items beyond size limit\n            while self._size > self.size_limit:\n                self.popitem(last=False)\n\n    def __getitem__(self, key):\n        v = self.od.__getitem__(key)\n        self.od.move_to_end(key)\n        return v\n\n    def __contains__(self, key):\n        return key in self.od\n\n    def __len__(self):\n        return self.od.__len__()\n\n    def __repr__(self):\n        return f\"{self.__class__.__name__}<size_limit:{self.size_limit if 
self.limited else 'no limit'} total_size:{self._size}>\\n{self.od.__repr__()}\"\n\n    def set_limit_size(self, limit):\n        self.size_limit = limit\n\n    @property\n    def limited(self):\n        \"\"\"whether memory cache is limited\"\"\"\n        return self.size_limit > 0\n\n    @property\n    def total_size(self):\n        return self._size\n\n    def clear(self):\n        self._size = 0\n        self.od.clear()\n\n    def popitem(self, last=True):\n        k, v = self.od.popitem(last=last)\n        self._size -= self._get_value_size(v)\n\n        return k, v\n\n    def pop(self, key):\n        v = self.od.pop(key)\n        self._size -= self._get_value_size(v)\n\n        return v\n\n    def _adjust_size(self, key, value):\n        if key in self.od:\n            self._size -= self._get_value_size(self.od[key])\n\n        self._size += self._get_value_size(value)\n\n    @abc.abstractmethod\n    def _get_value_size(self, value):\n        raise NotImplementedError\n\n\nclass MemCacheLengthUnit(MemCacheUnit):\n    def __init__(self, size_limit=0):\n        super().__init__(size_limit=size_limit)\n\n    def _get_value_size(self, value):\n        return 1\n\n\nclass MemCacheSizeofUnit(MemCacheUnit):\n    def __init__(self, size_limit=0):\n        super().__init__(size_limit=size_limit)\n\n    def _get_value_size(self, value):\n        return sys.getsizeof(value)\n\n\nclass MemCache:\n    \"\"\"Memory cache.\"\"\"\n\n    def __init__(self, mem_cache_size_limit=None, limit_type=\"length\"):\n        \"\"\"\n\n        Parameters\n        ----------\n        mem_cache_size_limit:\n            cache max size.\n        limit_type:\n            length or sizeof; length (call fun: len), sizeof (call fun: sys.getsizeof).\n        \"\"\"\n\n        size_limit = C.mem_cache_size_limit if mem_cache_size_limit is None else mem_cache_size_limit\n        limit_type = C.mem_cache_limit_type if limit_type is None else limit_type\n\n        if limit_type == \"length\":\n           
 klass = MemCacheLengthUnit\n        elif limit_type == \"sizeof\":\n            klass = MemCacheSizeofUnit\n        else:\n            raise ValueError(f\"limit_type must be length or sizeof, your limit_type is {limit_type}\")\n\n        self.__calendar_mem_cache = klass(size_limit)\n        self.__instrument_mem_cache = klass(size_limit)\n        self.__feature_mem_cache = klass(size_limit)\n\n    def __getitem__(self, key):\n        if key == \"c\":\n            return self.__calendar_mem_cache\n        elif key == \"i\":\n            return self.__instrument_mem_cache\n        elif key == \"f\":\n            return self.__feature_mem_cache\n        else:\n            raise KeyError(\"Unknown memcache unit\")\n\n    def clear(self):\n        self.__calendar_mem_cache.clear()\n        self.__instrument_mem_cache.clear()\n        self.__feature_mem_cache.clear()\n\n\nclass MemCacheExpire:\n    CACHE_EXPIRE = C.mem_cache_expire\n\n    @staticmethod\n    def set_cache(mem_cache, key, value):\n        \"\"\"set cache\n\n        :param mem_cache: MemCache attribute('c'/'i'/'f').\n        :param key: cache key.\n        :param value: cache value.\n        \"\"\"\n        mem_cache[key] = value, time.time()\n\n    @staticmethod\n    def get_cache(mem_cache, key):\n        \"\"\"get mem cache\n\n        :param mem_cache: MemCache attribute('c'/'i'/'f').\n        :param key: cache key.\n        :return: cache value; if cache not exist, return None.\n        \"\"\"\n        value = None\n        expire = False\n        if key in mem_cache:\n            value, latest_time = mem_cache[key]\n            expire = (time.time() - latest_time) > MemCacheExpire.CACHE_EXPIRE\n        return value, expire\n\n\nclass CacheUtils:\n    LOCK_ID = \"QLIB\"\n\n    @staticmethod\n    def organize_meta_file():\n        pass\n\n    @staticmethod\n    def reset_lock():\n        r = get_redis_connection()\n        redis_lock.reset_all(r)\n\n    @staticmethod\n    def visit(cache_path: 
Union[str, Path]):\n        # FIXME: Because read_lock is no longer acquired when reading the cache, concurrent processes may encounter read/write errors here\n        try:\n            cache_path = Path(cache_path)\n            meta_path = cache_path.with_suffix(\".meta\")\n            with meta_path.open(\"rb\") as f:\n                d = restricted_pickle_load(f)\n            with meta_path.open(\"wb\") as f:\n                try:\n                    d[\"meta\"][\"last_visit\"] = str(time.time())\n                    d[\"meta\"][\"visits\"] = d[\"meta\"][\"visits\"] + 1\n                except KeyError as key_e:\n                    raise KeyError(\"Unknown meta keyword\") from key_e\n                pickle.dump(d, f, protocol=C.dump_protocol_version)\n        except Exception as e:\n            get_module_logger(\"CacheUtils\").warning(f\"visit {cache_path} cache error: {e}\")\n\n    @staticmethod\n    def acquire(lock, lock_name):\n        try:\n            lock.acquire()\n        except redis_lock.AlreadyAcquired as lock_acquired:\n            raise QlibCacheException(\n                f\"\"\"The key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock already exists in your redis db.\n                    You can use the following commands to clear your redis keys and rerun your commands:\n                    $ redis-cli\n                    > select {C.redis_task_db}\n                    > del \"lock:{repr(lock_name)[1:-1]}-wlock\"\n                    > quit\n                    If the issue is not resolved, use \"keys *\" to find if multiple keys exist. 
If so, try using \"flushall\" to clear all the keys.\n                \"\"\"\n            ) from lock_acquired\n\n    @staticmethod\n    @contextlib.contextmanager\n    def reader_lock(redis_t, lock_name: str):\n        current_cache_rlock = redis_lock.Lock(redis_t, f\"{lock_name}-rlock\")\n        current_cache_wlock = redis_lock.Lock(redis_t, f\"{lock_name}-wlock\")\n        lock_reader = f\"{lock_name}-reader\"\n        # make sure only one reader is entering\n        current_cache_rlock.acquire(timeout=60)\n        try:\n            current_cache_readers = redis_t.get(lock_reader)\n            if current_cache_readers is None or int(current_cache_readers) == 0:\n                CacheUtils.acquire(current_cache_wlock, lock_name)\n            redis_t.incr(lock_reader)\n        finally:\n            current_cache_rlock.release()\n        try:\n            yield\n        finally:\n            # make sure only one reader is leaving\n            current_cache_rlock.acquire(timeout=60)\n            try:\n                redis_t.decr(lock_reader)\n                if int(redis_t.get(lock_reader)) == 0:\n                    redis_t.delete(lock_reader)\n                    current_cache_wlock.reset()\n            finally:\n                current_cache_rlock.release()\n\n    @staticmethod\n    @contextlib.contextmanager\n    def writer_lock(redis_t, lock_name):\n        current_cache_wlock = redis_lock.Lock(redis_t, f\"{lock_name}-wlock\", id=CacheUtils.LOCK_ID)\n        CacheUtils.acquire(current_cache_wlock, lock_name)\n        try:\n            yield\n        finally:\n            current_cache_wlock.release()\n\n\nclass BaseProviderCache:\n    \"\"\"Provider cache base class\"\"\"\n\n    def __init__(self, provider):\n        self.provider = provider\n        self.logger = get_module_logger(self.__class__.__name__)\n\n    def __getattr__(self, attr):\n        return getattr(self.provider, attr)\n\n    @staticmethod\n    def check_cache_exists(cache_path: Union[str, 
Path], suffix_list: Iterable = (\".index\", \".meta\")) -> bool:\n        cache_path = Path(cache_path)\n        for p in [cache_path] + [cache_path.with_suffix(_s) for _s in suffix_list]:\n            if not p.exists():\n                return False\n        return True\n\n    @staticmethod\n    def clear_cache(cache_path: Union[str, Path]):\n        # accept both str and Path, as the type hint allows\n        cache_path = Path(cache_path)\n        for p in [\n            cache_path,\n            cache_path.with_suffix(\".meta\"),\n            cache_path.with_suffix(\".index\"),\n        ]:\n            if p.exists():\n                p.unlink()\n\n    @staticmethod\n    def get_cache_dir(dir_name: str, freq: str = None) -> Path:\n        cache_dir = Path(C.dpm.get_data_uri(freq)).joinpath(dir_name)\n        cache_dir.mkdir(parents=True, exist_ok=True)\n        return cache_dir\n\n\nclass ExpressionCache(BaseProviderCache):\n    \"\"\"Expression cache mechanism base class.\n\n    This class is used to wrap an expression provider with a self-defined expression cache mechanism.\n\n    .. note:: Override the `_uri` and `_expression` methods to create your own expression cache mechanism.\n    \"\"\"\n\n    def expression(self, instrument, field, start_time, end_time, freq):\n        \"\"\"Get expression data.\n\n        .. 
note:: Same interface as `expression` method in expression provider\n        \"\"\"\n        try:\n            return self._expression(instrument, field, start_time, end_time, freq)\n        except NotImplementedError:\n            return self.provider.expression(instrument, field, start_time, end_time, freq)\n\n    def _uri(self, instrument, field, start_time, end_time, freq):\n        \"\"\"Get expression cache file uri.\n\n        Override this method to define how to get expression cache file uri corresponding to users' own cache mechanism.\n        \"\"\"\n        raise NotImplementedError(\"Implement this function to match your own cache mechanism\")\n\n    def _expression(self, instrument, field, start_time, end_time, freq):\n        \"\"\"Get expression data using cache.\n\n        Override this method to define how to get expression data corresponding to users' own cache mechanism.\n        \"\"\"\n        raise NotImplementedError(\"Implement this method if you want to use expression cache\")\n\n    def update(self, cache_uri: Union[str, Path], freq: str = \"day\"):\n        \"\"\"Update expression cache to latest calendar.\n\n        Override this method to define how to update expression cache corresponding to users' own cache mechanism.\n\n        Parameters\n        ----------\n        cache_uri : str or Path\n            the complete uri of expression cache file (include dir path).\n        freq : str\n\n        Returns\n        -------\n        int\n            0(successful update)/ 1(no need to update)/ 2(update failure).\n        \"\"\"\n        raise NotImplementedError(\"Implement this method if you want to make expression cache up to date\")\n\n\nclass DatasetCache(BaseProviderCache):\n    \"\"\"Dataset cache mechanism base class.\n\n    This class is used to wrap dataset provider with self-defined dataset cache mechanism.\n\n    .. 
note:: Override the `_uri` and `_dataset` methods to create your own dataset cache mechanism.\n    \"\"\"\n\n    HDF_KEY = \"df\"\n\n    def dataset(\n        self, instruments, fields, start_time=None, end_time=None, freq=\"day\", disk_cache=1, inst_processors=[]\n    ):\n        \"\"\"Get feature dataset.\n\n        .. note:: Same interface as `dataset` method in dataset provider\n\n        .. note:: The server uses redis_lock to make sure\n            read-write conflicts are not triggered,\n            but client readers are not considered.\n        \"\"\"\n        if disk_cache == 0:\n            # skip cache\n            return self.provider.dataset(\n                instruments, fields, start_time, end_time, freq, inst_processors=inst_processors\n            )\n        else:\n            # use and replace cache\n            try:\n                return self._dataset(\n                    instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors\n                )\n            except NotImplementedError:\n                return self.provider.dataset(\n                    instruments, fields, start_time, end_time, freq, inst_processors=inst_processors\n                )\n\n    def _uri(self, instruments, fields, start_time, end_time, freq, **kwargs):\n        \"\"\"Get dataset cache file uri.\n\n        Override this method to define how to get dataset cache file uri corresponding to users' own cache mechanism.\n        \"\"\"\n        raise NotImplementedError(\"Implement this function to match your own cache mechanism\")\n\n    def _dataset(\n        self, instruments, fields, start_time=None, end_time=None, freq=\"day\", disk_cache=1, inst_processors=[]\n    ):\n        \"\"\"Get feature dataset using cache.\n\n        Override this method to define how to get feature dataset corresponding to users' own cache mechanism.\n        \"\"\"\n        raise NotImplementedError(\"Implement this method if you want to use dataset 
feature cache\")\n\n    def _dataset_uri(\n        self, instruments, fields, start_time=None, end_time=None, freq=\"day\", disk_cache=1, inst_processors=[]\n    ):\n        \"\"\"Get the uri of a feature dataset using the cache.\n        Specifically:\n            disk_cache=1 means using the dataset cache and returning the uri of the cache file.\n            disk_cache=0 means the client knows the path of the expression cache;\n                         the server checks whether the cache exists (generating it if not), and the client loads the data by itself.\n        Override this method to define how to get the feature dataset uri for your own cache mechanism.\n        \"\"\"\n        raise NotImplementedError(\n            \"Implement this method if you want to use the dataset feature cache as a cache file for the client\"\n        )\n\n    def update(self, cache_uri: Union[str, Path], freq: str = \"day\"):\n        \"\"\"Update the dataset cache to the latest calendar.\n\n        Override this method to define how to update the dataset cache for your own cache mechanism.\n\n        Parameters\n        ----------\n        cache_uri : str or Path\n            the complete uri of the dataset cache file (including the directory path).\n        freq : str\n\n        Returns\n        -------\n        int\n            0 (successful update) / 1 (no need to update) / 2 (update failure)\n        \"\"\"\n        raise NotImplementedError(\"Implement this method if you want to keep the dataset cache up to date\")\n\n    @staticmethod\n    def cache_to_origin_data(data, fields):\n        \"\"\"Convert cache data back to the original data.\n\n        :param data: pd.DataFrame, cache data.\n        :param fields: feature fields.\n        :return: pd.DataFrame.\n        \"\"\"\n        not_space_fields = remove_fields_space(fields)\n        data = data.loc[:, not_space_fields]\n        # restore the original field names as column names\n        data.columns = [str(i) for i in fields]\n        return data\n\n    @staticmethod\n    def normalize_uri_args(instruments, fields, freq):\n      
  \"\"\"normalize uri args\"\"\"\n        instruments = normalize_cache_instruments(instruments)\n        fields = normalize_cache_fields(fields)\n        freq = freq.lower()\n\n        return instruments, fields, freq\n\n\nclass DiskExpressionCache(ExpressionCache):\n    \"\"\"Prepared cache mechanism for server.\"\"\"\n\n    def __init__(self, provider, **kwargs):\n        super(DiskExpressionCache, self).__init__(provider)\n        self.r = get_redis_connection()\n        # remote==True means the client is using this module, so writing is not allowed.\n        self.remote = kwargs.get(\"remote\", False)\n\n    def get_cache_dir(self, freq: str = None) -> Path:\n        return super(DiskExpressionCache, self).get_cache_dir(C.features_cache_dir_name, freq)\n\n    def _uri(self, instrument, field, start_time, end_time, freq):\n        field = remove_fields_space(field)\n        instrument = str(instrument).lower()\n        return hash_args(instrument, field, freq)\n\n    def _expression(self, instrument, field, start_time=None, end_time=None, freq=\"day\"):\n        _cache_uri = self._uri(instrument=instrument, field=field, start_time=None, end_time=None, freq=freq)\n        _instrument_dir = self.get_cache_dir(freq).joinpath(instrument.lower())\n        cache_path = _instrument_dir.joinpath(_cache_uri)\n        # get calendar\n        from .data import Cal  # pylint: disable=C0415\n\n        _calendar = Cal.calendar(freq=freq)\n\n        _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False)\n\n        if self.check_cache_exists(cache_path, suffix_list=[\".meta\"]):\n            \"\"\"\n            In most cases, we do not need reader_lock,\n            because updating data is rare compared to reading data.\n\n            \"\"\"\n            # FIXME: Removing the reader lock may result in conflicts.\n            # with CacheUtils.reader_lock(self.r, 'expression-%s' % _cache_uri):\n\n            # 
modify expression cache meta file\n            try:\n                # FIXME: Multiple readers may result in an incorrect visit count\n                if not self.remote:\n                    CacheUtils.visit(cache_path)\n                series = read_bin(cache_path, start_index, end_index)\n                return series\n            except Exception:\n                series = None\n                self.logger.error(\"reading %s file error : %s\" % (cache_path, traceback.format_exc()))\n            return series\n        else:\n            # normalize field\n            field = remove_fields_space(field)\n            # cache unavailable, generate the cache\n            _instrument_dir.mkdir(parents=True, exist_ok=True)\n            if not isinstance(eval(parse_field(field)), Feature):\n                # The expression is not a raw feature (not a Feature\n                # instance), so generate an expression cache for it\n                series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq)\n                if not series.empty:\n                    # Only non-empty expressions are cached; an empty series is returned without generating a cache.\n                    with CacheUtils.writer_lock(self.r, f\"{str(C.dpm.get_data_uri(freq))}:expression-{_cache_uri}\"):\n                        self.gen_expression_cache(\n                            expression_data=series,\n                            cache_path=cache_path,\n                            instrument=instrument,\n                            field=field,\n                            freq=freq,\n                            last_update=str(_calendar[-1]),\n                        )\n                    return series.loc[start_index:end_index]\n                else:\n                    return series\n            else:\n                # If the expression is a raw feature (such as $close, $open), read it directly without caching\n                return self.provider.expression(instrument, field, start_time, end_time, freq)\n\n   
 def gen_expression_cache(self, expression_data, cache_path, instrument, field, freq, last_update):\n        \"\"\"Save the expression cache as a .bin file, in the same format as feature data.\"\"\"\n        # Make sure the cache works correctly even if the directory is deleted\n        # while running\n        meta = {\n            \"info\": {\"instrument\": instrument, \"field\": field, \"freq\": freq, \"last_update\": last_update},\n            \"meta\": {\"last_visit\": time.time(), \"visits\": 1},\n        }\n        self.logger.debug(f\"generating expression cache: {meta}\")\n        self.clear_cache(cache_path)\n        meta_path = cache_path.with_suffix(\".meta\")\n\n        with meta_path.open(\"wb\") as f:\n            pickle.dump(meta, f, protocol=C.dump_protocol_version)\n        meta_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)\n        df = expression_data.to_frame()\n\n        # the first value is the start index of the series; the rest are the values\n        r = np.hstack([df.index[0], expression_data]).astype(\"<f\")\n        r.tofile(str(cache_path))\n\n    def update(self, sid, cache_uri, freq: str = \"day\"):\n        cp_cache_uri = self.get_cache_dir(freq).joinpath(sid).joinpath(cache_uri)\n        meta_path = cp_cache_uri.with_suffix(\".meta\")\n        if not self.check_cache_exists(cp_cache_uri, suffix_list=[\".meta\"]):\n            self.logger.info(f\"The cache {cp_cache_uri} is corrupted. 
It will be removed\")\n            self.clear_cache(cp_cache_uri)\n            return 2\n\n        with CacheUtils.writer_lock(self.r, f\"{str(C.dpm.get_data_uri())}:expression-{cache_uri}\"):\n            with meta_path.open(\"rb\") as f:\n                d = restricted_pickle_load(f)\n            instrument = d[\"info\"][\"instrument\"]\n            field = d[\"info\"][\"field\"]\n            freq = d[\"info\"][\"freq\"]\n            last_update_time = d[\"info\"][\"last_update\"]\n\n            # get newest calendar\n            from .data import Cal, ExpressionD  # pylint: disable=C0415\n\n            whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq)\n            # calendar since last updated.\n            new_calendar = Cal.calendar(start_time=last_update_time, end_time=None, freq=freq)\n\n            # get append data\n            if len(new_calendar) <= 1:\n                # The calendar since the last update contains only the last updated date itself.\n                # No further update is needed.\n                return 1\n            else:\n                # get the data needed after the historical data are removed.\n                # The start index of new data\n                current_index = len(whole_calendar) - len(new_calendar) + 1\n\n                # The existing data length\n                size_bytes = os.path.getsize(cp_cache_uri)\n                ele_size = np.dtype(\"<f\").itemsize\n                assert size_bytes % ele_size == 0\n                ele_n = size_bytes // ele_size - 1\n\n                expr = ExpressionD.get_expression_instance(field)\n                lft_etd, rght_etd = expr.get_extended_window_size()\n                # The expression uses future data up to rght_etd periods ahead,\n                # so the last rght_etd values should be removed and recomputed.\n                # At most `ele_n` periods of data can be removed.\n                remove_n = min(rght_etd, ele_n)\n                assert new_calendar[1] == 
whole_calendar[current_index]\n                data = self.provider.expression(\n                    instrument, field, whole_calendar[current_index - remove_n], new_calendar[-1], freq\n                )\n                with open(cp_cache_uri, \"ab\") as f:\n                    data = np.array(data).astype(\"<f\")\n                    # Truncate the last remove_n values, then append the recomputed data\n                    f.truncate(size_bytes - ele_size * remove_n)\n                    f.write(data)\n                # update meta file\n                d[\"info\"][\"last_update\"] = str(new_calendar[-1])\n                with meta_path.open(\"wb\") as f:\n                    pickle.dump(d, f, protocol=C.dump_protocol_version)\n        return 0\n\n\nclass DiskDatasetCache(DatasetCache):\n    \"\"\"Prepared cache mechanism for server.\"\"\"\n\n    def __init__(self, provider, **kwargs):\n        super(DiskDatasetCache, self).__init__(provider)\n        self.r = get_redis_connection()\n        self.remote = kwargs.get(\"remote\", False)\n\n    @staticmethod\n    def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):\n        return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)\n\n    def get_cache_dir(self, freq: str = None) -> Path:\n        return super(DiskDatasetCache, self).get_cache_dir(C.dataset_cache_dir_name, freq)\n\n    @classmethod\n    def read_data_from_cache(cls, cache_path: Union[str, Path], start_time, end_time, fields):\n        \"\"\"read_data_from_cache\n\n        Read data from the dataset cache on disk.\n\n        :param cache_path: the path of the dataset cache file.\n        :param start_time: the start of the time range to read.\n        :param end_time: the end of the time range to read.\n        :param fields: The columns of the dataset cache are stored in sorted order. 
So the columns are rearranged to match the order of `fields`.\n        :return: pd.DataFrame with the requested fields.\n        \"\"\"\n\n        im = DiskDatasetCache.IndexManager(cache_path)\n        index_data = im.get_index(start_time, end_time)\n        if index_data.shape[0] > 0:\n            start, stop = (\n                index_data[\"start\"].iloc[0].item(),\n                index_data[\"end\"].iloc[-1].item(),\n            )\n        else:\n            start = stop = 0\n\n        with pd.HDFStore(cache_path, mode=\"r\") as store:\n            if \"/{}\".format(im.KEY) in store.keys():\n                df = store.select(key=im.KEY, start=start, stop=stop)\n                df = df.swaplevel(\"datetime\", \"instrument\").sort_index()\n                # map the space-free cached column names back to the original fields\n                df = cls.cache_to_origin_data(df, fields)\n\n            else:\n                df = pd.DataFrame(columns=fields)\n        return df\n\n    def _dataset(\n        self, instruments, fields, start_time=None, end_time=None, freq=\"day\", disk_cache=0, inst_processors=[]\n    ):\n        if disk_cache == 0:\n            # In this case, the dataset cache is configured but will not be used.\n            return self.provider.dataset(\n                instruments, fields, start_time, end_time, freq, inst_processors=inst_processors\n            )\n        # FIXME: When a cache generated after resampling is read again and truncated at end_time, the resulting data is incomplete\n        if inst_processors:\n            raise ValueError(\n                f\"{self.__class__.__name__} does not support inst_processor. 
\"\n                f\"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`\"\n            )\n        _cache_uri = self._uri(\n            instruments=instruments,\n            fields=fields,\n            start_time=None,\n            end_time=None,\n            freq=freq,\n            disk_cache=disk_cache,\n            inst_processors=inst_processors,\n        )\n\n        cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)\n\n        features = pd.DataFrame()\n        gen_flag = False\n\n        if self.check_cache_exists(cache_path):\n            if disk_cache == 1:\n                # use cache\n                with CacheUtils.reader_lock(self.r, f\"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}\"):\n                    CacheUtils.visit(cache_path)\n                    features = self.read_data_from_cache(cache_path, start_time, end_time, fields)\n            elif disk_cache == 2:\n                gen_flag = True\n        else:\n            gen_flag = True\n\n        if gen_flag:\n            # cache unavailable, generate the cache\n            with CacheUtils.writer_lock(self.r, f\"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}\"):\n                features = self.gen_dataset_cache(\n                    cache_path=cache_path,\n                    instruments=instruments,\n                    fields=fields,\n                    freq=freq,\n                    inst_processors=inst_processors,\n                )\n            if not features.empty:\n                features = features.sort_index().loc(axis=0)[:, start_time:end_time]\n        return features\n\n    def _dataset_uri(\n        self, instruments, fields, start_time=None, end_time=None, freq=\"day\", disk_cache=0, inst_processors=[]\n    ):\n        if disk_cache == 0:\n            # In this case, server only checks the expression cache.\n            # The client will load the cache data by itself.\n            from .data import LocalDatasetProvider  # pylint: 
disable=C0415\n\n            LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq)\n            return \"\"\n        # FIXME: When a cache generated after resampling is read again and truncated at end_time, the resulting data is incomplete\n        if inst_processors:\n            raise ValueError(\n                f\"{self.__class__.__name__} does not support inst_processor. \"\n                f\"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`\"\n            )\n        _cache_uri = self._uri(\n            instruments=instruments,\n            fields=fields,\n            start_time=None,\n            end_time=None,\n            freq=freq,\n            disk_cache=disk_cache,\n            inst_processors=inst_processors,\n        )\n        cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)\n\n        if self.check_cache_exists(cache_path):\n            self.logger.debug(f\"The dataset cache {cache_path} already exists. Return the uri directly\")\n            with CacheUtils.reader_lock(self.r, f\"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}\"):\n                CacheUtils.visit(cache_path)\n            return _cache_uri\n        else:\n            # cache unavailable, generate the cache\n            with CacheUtils.writer_lock(self.r, f\"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}\"):\n                self.gen_dataset_cache(\n                    cache_path=cache_path,\n                    instruments=instruments,\n                    fields=fields,\n                    freq=freq,\n                    inst_processors=inst_processors,\n                )\n            return _cache_uri\n\n    class IndexManager:\n        \"\"\"\n        Locking is not handled in this class. 
Please acquire the lock outside this class.\n        This class is a proxy for the index data on disk.\n        \"\"\"\n\n        KEY = \"df\"\n\n        def __init__(self, cache_path: Union[str, Path]):\n            self.index_path = cache_path.with_suffix(\".index\")\n            self._data = None\n            self.logger = get_module_logger(self.__class__.__name__)\n\n        def get_index(self, start_time=None, end_time=None):\n            # TODO: read the index from the disk faster.\n            if self._data is None:\n                self.sync_from_disk()\n            return self._data.loc[start_time:end_time].copy()\n\n        def sync_to_disk(self):\n            if self._data is None:\n                raise ValueError(\"No data to sync to disk.\")\n            self._data.sort_index(inplace=True)\n            self._data.to_hdf(self.index_path, key=self.KEY, mode=\"w\", format=\"table\")\n            # The index should be readable for all users\n            self.index_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)\n\n        def sync_from_disk(self):\n            # Open the store explicitly: calling read_hdf on the file path may not close the file promptly\n            with pd.HDFStore(self.index_path, mode=\"r\") as store:\n                if \"/{}\".format(self.KEY) in store.keys():\n                    self._data = pd.read_hdf(store, key=self.KEY)\n                else:\n                    self._data = pd.DataFrame()\n\n        def update(self, data, sync=True):\n            self._data = data.astype(np.int32).copy()\n            if sync:\n                self.sync_to_disk()\n\n        def append_index(self, data, to_disk=True):\n            data = data.astype(np.int32).copy()\n            data.sort_index(inplace=True)\n            self._data = pd.concat([self._data, data])\n            if to_disk:\n                with pd.HDFStore(self.index_path) as store:\n                    store.append(self.KEY, data, append=True)\n\n        @staticmethod\n        def 
build_index_from_data(data, start_index=0):\n            if data.empty:\n                return pd.DataFrame()\n            # count the rows of each datetime; cumulative sums give the [start, end) row range of each date\n            line_data = data.groupby(\"datetime\", group_keys=False).size()\n            line_data.sort_index(inplace=True)\n            index_end = line_data.cumsum()\n            index_start = index_end.shift(1, fill_value=0)\n\n            index_data = pd.DataFrame()\n            index_data[\"start\"] = index_start\n            index_data[\"end\"] = index_end\n            index_data += start_index\n            return index_data\n\n    def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=[]):\n        \"\"\"gen_dataset_cache\n\n        .. note:: This function does not acquire the cache read/write lock. Please\n            acquire the lock outside this function.\n\n        The cache consists of 3 parts (each followed by a typical filename).\n\n        - index : cache/d41366901e25de3ec47297f12e2ba11d.index\n\n            - The content of the file may be in the following format (pandas.DataFrame)\n\n                .. code-block:: python\n\n                                        start end\n                    1999-11-10 00:00:00     0   1\n                    1999-11-11 00:00:00     1   2\n                    1999-11-12 00:00:00     2   3\n                    ...\n\n                .. note:: The start is closed. 
The end is open.\n\n            - Each line contains two elements <start_index, end_index> with a timestamp as its index.\n            - It indicates the `start_index` (included) and `end_index` (excluded) of the data for `timestamp`\n\n        - meta data: cache/d41366901e25de3ec47297f12e2ba11d.meta\n\n        - data     : cache/d41366901e25de3ec47297f12e2ba11d\n\n            - This is an HDF file sorted by datetime\n\n        :param cache_path:  The path to store the cache.\n        :param instruments:  The instruments to store the cache.\n        :param fields:  The fields to store the cache.\n        :param freq:  The freq to store the cache.\n        :param inst_processors:  Instrument processors.\n\n        :return: pd.DataFrame; the fields of the returned DataFrame are consistent with the `fields` parameter.\n        \"\"\"\n        # get calendar\n        from .data import Cal  # pylint: disable=C0415\n\n        cache_path = Path(cache_path)\n        _calendar = Cal.calendar(freq=freq)\n        self.logger.debug(f\"Generating dataset cache {cache_path}\")\n        # Make sure the cache works correctly even if the directory is deleted\n        # while running\n        self.clear_cache(cache_path)\n\n        features = self.provider.dataset(\n            instruments, fields, _calendar[0], _calendar[-1], freq, inst_processors=inst_processors\n        )\n\n        if features.empty:\n            return features\n\n        # swap the index levels to (datetime, instrument) and sort\n        features = features.swaplevel(\"instrument\", \"datetime\").sort_index()\n\n        # write cache data\n        with pd.HDFStore(str(cache_path.with_suffix(\".data\"))) as store:\n            cache_to_orig_map = dict(zip(remove_fields_space(features.columns), features.columns))\n            orig_to_cache_map = dict(zip(features.columns, remove_fields_space(features.columns)))\n            cache_features = features[list(cache_to_orig_map.values())].rename(columns=orig_to_cache_map)\n            # cache 
columns\n            cache_columns = sorted(cache_features.columns)\n            cache_features = cache_features.loc[:, cache_columns]\n            cache_features = cache_features.loc[:, ~cache_features.columns.duplicated()]\n            store.append(DatasetCache.HDF_KEY, cache_features, append=False)\n        # write meta file\n        meta = {\n            \"info\": {\n                \"instruments\": instruments,\n                \"fields\": list(cache_features.columns),\n                \"freq\": freq,\n                \"last_update\": str(_calendar[-1]),  # The last date covered by the cache\n                \"inst_processors\": inst_processors,  # The instrument processors used to generate the cache\n            },\n            \"meta\": {\"last_visit\": time.time(), \"visits\": 1},\n        }\n        with cache_path.with_suffix(\".meta\").open(\"wb\") as f:\n            pickle.dump(meta, f, protocol=C.dump_protocol_version)\n        cache_path.with_suffix(\".meta\").chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)\n        # write index file\n        im = DiskDatasetCache.IndexManager(cache_path)\n        index_data = im.build_index_from_data(features)\n        im.update(index_data)\n\n        # rename the file after the cache has been generated\n        # this doesn't work well on Windows, but our server does not use Windows\n        # for now\n        cache_path.with_suffix(\".data\").rename(cache_path)\n        # the returned features keep the original field names\n        return features.swaplevel(\"datetime\", \"instrument\")\n\n    def update(self, cache_uri, freq: str = \"day\"):\n        cp_cache_uri = self.get_cache_dir(freq).joinpath(cache_uri)\n        meta_path = cp_cache_uri.with_suffix(\".meta\")\n        if not self.check_cache_exists(cp_cache_uri):\n            self.logger.info(f\"The cache {cp_cache_uri} is corrupted. 
It will be removed\")\n            self.clear_cache(cp_cache_uri)\n            return 2\n\n        im = DiskDatasetCache.IndexManager(cp_cache_uri)\n        with CacheUtils.writer_lock(self.r, f\"{str(C.dpm.get_data_uri())}:dataset-{cache_uri}\"):\n            with meta_path.open(\"rb\") as f:\n                d = restricted_pickle_load(f)\n            instruments = d[\"info\"][\"instruments\"]\n            fields = d[\"info\"][\"fields\"]\n            freq = d[\"info\"][\"freq\"]\n            last_update_time = d[\"info\"][\"last_update\"]\n            inst_processors = d[\"info\"].get(\"inst_processors\", [])\n            index_data = im.get_index()\n\n            self.logger.debug(\"Updating dataset: {}\".format(d))\n            from .data import Inst  # pylint: disable=C0415\n\n            if Inst.get_inst_type(instruments) == Inst.DICT:\n                self.logger.info(f\"The file {cache_uri} uses dict instruments. Skip updating\")\n                return 1\n\n            # get newest calendar\n            from .data import Cal  # pylint: disable=C0415\n\n            whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq)\n            # The calendar since last updated\n            new_calendar = Cal.calendar(start_time=last_update_time, end_time=None, freq=freq)\n\n            # get append data\n            if len(new_calendar) <= 1:\n                # The calendar since the last update contains only the last updated date itself.\n                # No further update is needed.\n                return 1\n            else:\n                # get the data needed after the historical data are removed.\n                # The start index of new data\n                current_index = len(whole_calendar) - len(new_calendar) + 1\n\n                # To avoid recursive import\n                from .data import ExpressionD  # pylint: disable=C0415\n\n                # The maximal extended window size over all fields\n                lft_etd = rght_etd = 0\n                for field in fields:\n                
    expr = ExpressionD.get_expression_instance(field)\n                    l, r = expr.get_extended_window_size()\n                    lft_etd = max(lft_etd, l)\n                    rght_etd = max(rght_etd, r)\n                # remove the periods that need to be recomputed.\n                if index_data.empty:\n                    # There is no existing data for this dataset, so nothing needs to be removed\n                    rm_n_period = rm_lines = 0\n                else:\n                    rm_n_period = min(rght_etd, index_data.shape[0])\n                    rm_lines = (\n                        (index_data[\"end\"] - index_data[\"start\"])\n                        .loc[whole_calendar[current_index - rm_n_period] :]\n                        .sum()\n                        .item()\n                    )\n\n                data = self.provider.dataset(\n                    instruments,\n                    fields,\n                    whole_calendar[current_index - rm_n_period],\n                    new_calendar[-1],\n                    freq,\n                    inst_processors=inst_processors,\n                )\n\n                if not data.empty:\n                    data.reset_index(inplace=True)\n                    data.set_index([\"datetime\", \"instrument\"], inplace=True)\n                    data.sort_index(inplace=True)\n                else:\n                    return 0  # No data to update cache\n\n                store = pd.HDFStore(cp_cache_uri)\n                # FIXME:\n                # Because the feature cache is stored as .bin files,\n                # the series read from features are all float32.\n                # However, the first dataset cache is calculated based on the\n                # raw data. 
So the data type may be float64.\n                # Mismatched data types cause appending to fail, so cast the new data to the stored schema\n                if \"/{}\".format(DatasetCache.HDF_KEY) in store.keys():\n                    schema = store.select(DatasetCache.HDF_KEY, start=0, stop=0)\n                    for col, dtype in schema.dtypes.items():\n                        data[col] = data[col].astype(dtype)\n                if rm_lines > 0:\n                    store.remove(key=im.KEY, start=-rm_lines)\n                store.append(DatasetCache.HDF_KEY, data)\n                store.close()\n\n                # update index file\n                new_index_data = im.build_index_from_data(\n                    data.loc(axis=0)[whole_calendar[current_index] :, :],\n                    start_index=0 if index_data.empty else index_data[\"end\"].iloc[-1],\n                )\n                im.append_index(new_index_data)\n\n                # update meta file\n                d[\"info\"][\"last_update\"] = str(new_calendar[-1])\n                with meta_path.open(\"wb\") as f:\n                    pickle.dump(d, f, protocol=C.dump_protocol_version)\n                return 0\n\n\nclass SimpleDatasetCache(DatasetCache):\n    \"\"\"Simple dataset cache that can be used locally or on the client.\"\"\"\n\n    def __init__(self, provider):\n        super(SimpleDatasetCache, self).__init__(provider)\n        try:\n            self.local_cache_path: Path = Path(C[\"local_cache_path\"]).expanduser().resolve()\n        except (KeyError, TypeError):\n            self.logger.error(\"Assign a local_cache_path in config if you want to use this cache mechanism\")\n            raise\n        self.logger.info(\n            f\"DatasetCache directory: {self.local_cache_path}, \"\n            f\"modify the cache directory via the local_cache_path in the config\"\n        )\n\n    def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):\n        instruments, 
fields, freq = self.normalize_uri_args(instruments, fields, freq)\n        return hash_args(\n            instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path), inst_processors\n        )\n\n    def _dataset(\n        self, instruments, fields, start_time=None, end_time=None, freq=\"day\", disk_cache=1, inst_processors=[]\n    ):\n        if disk_cache == 0:\n            # In this case, the dataset cache is configured but will not be used.\n            # Forward inst_processors so they are not silently dropped.\n            return self.provider.dataset(\n                instruments, fields, start_time, end_time, freq, inst_processors=inst_processors\n            )\n        self.local_cache_path.mkdir(exist_ok=True, parents=True)\n        cache_file = self.local_cache_path.joinpath(\n            self._uri(\n                instruments, fields, start_time, end_time, freq, disk_cache=disk_cache, inst_processors=inst_processors\n            )\n        )\n        gen_flag = False\n\n        if cache_file.exists():\n            if disk_cache == 1:\n                # use cache\n                df = pd.read_pickle(cache_file)\n                return self.cache_to_origin_data(df, fields)\n            elif disk_cache == 2:\n                # replace cache\n                gen_flag = True\n        else:\n            gen_flag = True\n\n        if gen_flag:\n            data = self.provider.dataset(\n                instruments, normalize_cache_fields(fields), start_time, end_time, freq, inst_processors=inst_processors\n            )\n            data.to_pickle(cache_file)\n            return self.cache_to_origin_data(data, fields)\n\n\nclass DatasetURICache(DatasetCache):\n    \"\"\"Prepared cache mechanism for server.\"\"\"\n\n    def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):\n        return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)\n\n    def dataset(\n        self, instruments, fields, start_time=None, end_time=None, freq=\"day\", disk_cache=0, inst_processors=[]\n    
):\n        if \"local\" in C.dataset_provider.lower():\n            # use LocalDatasetProvider\n            return self.provider.dataset(\n                instruments, fields, start_time, end_time, freq, inst_processors=inst_processors\n            )\n\n        if disk_cache == 0:\n            # do not use the dataset cache; load data from the remote expression cache directly\n            return self.provider.dataset(\n                instruments,\n                fields,\n                start_time,\n                end_time,\n                freq,\n                disk_cache,\n                return_uri=False,\n                inst_processors=inst_processors,\n            )\n        # FIXME: When a cache generated after resampling is read again and truncated at end_time, the resulting data is incomplete\n        if inst_processors:\n            raise ValueError(\n                f\"{self.__class__.__name__} does not support inst_processor. \"\n                f\"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`\"\n            )\n        # use ClientDatasetProvider\n        feature_uri = self._uri(\n            instruments, fields, None, None, freq, disk_cache=disk_cache, inst_processors=inst_processors\n        )\n        value, expire = MemCacheExpire.get_cache(H[\"f\"], feature_uri)\n        mnt_feature_uri = C.dpm.get_data_uri(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)\n        if value is None or expire or not mnt_feature_uri.exists():\n            df, uri = self.provider.dataset(\n                instruments,\n                fields,\n                start_time,\n                end_time,\n                freq,\n                disk_cache,\n                return_uri=True,\n                inst_processors=inst_processors,\n            )\n            # cache uri\n            MemCacheExpire.set_cache(H[\"f\"], uri, uri)\n            # cache DataFrame\n            # HZ['f'][uri] = df.copy()\n            
get_module_logger(\"cache\").debug(f\"get feature from {C.dataset_provider}\")\n        else:\n            df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)\n            get_module_logger(\"cache\").debug(\"get feature from uri cache\")\n\n        return df\n\n\nclass CalendarCache(BaseProviderCache):\n    pass\n\n\nclass MemoryCalendarCache(CalendarCache):\n    def calendar(self, start_time=None, end_time=None, freq=\"day\", future=False):\n        uri = self._uri(start_time, end_time, freq, future)\n        result, expire = MemCacheExpire.get_cache(H[\"c\"], uri)\n        if result is None or expire:\n            result = self.provider.calendar(start_time, end_time, freq, future)\n            MemCacheExpire.set_cache(H[\"c\"], uri, result)\n\n            get_module_logger(\"data\").debug(f\"get calendar from {C.calendar_provider}\")\n        else:\n            get_module_logger(\"data\").debug(\"get calendar from local cache\")\n\n        return result\n\n\nH = MemCache()\n"
  },
  {
    "path": "qlib/data/client.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division, print_function\n\nimport json\n\nimport socketio\n\nimport qlib\n\nfrom ..log import get_module_logger\n\n\nclass Client:\n    \"\"\"A client class\n\n    Provide the connection tool functions for ClientProvider.\n    \"\"\"\n\n    def __init__(self, host, port):\n        super(Client, self).__init__()\n        self.sio = socketio.Client()\n        self.server_host = host\n        self.server_port = port\n        self.logger = get_module_logger(self.__class__.__name__)\n        # bind connect/disconnect callbacks\n        self.sio.on(\n            \"connect\",\n            lambda: self.logger.debug(\"Connect to server {}\".format(self.sio.connection_url)),\n        )\n        self.sio.on(\"disconnect\", lambda: self.logger.debug(\"Disconnect from server!\"))\n\n    def connect_server(self):\n        \"\"\"Connect to server.\"\"\"\n        try:\n            self.sio.connect(f\"ws://{self.server_host}:{self.server_port}\")\n        except socketio.exceptions.ConnectionError:\n            self.logger.error(\"Cannot connect to server - check your network or server status\")\n\n    def disconnect(self):\n        \"\"\"Disconnect from server.\"\"\"\n        try:\n            self.sio.eio.disconnect(True)\n        except Exception as e:\n            self.logger.error(\"Cannot disconnect from server : %s\" % e)\n\n    def send_request(self, request_type, request_content, msg_queue, msg_proc_func=None):\n        \"\"\"Send a certain request to server.\n\n        Parameters\n        ----------\n        request_type : str\n            type of proposed request, 'calendar'/'instrument'/'feature'.\n        request_content : dict\n            records the information of the request.\n        msg_proc_func : func\n            the function to process the message when receiving response, should have arg `*args`.\n        msg_queue: Queue\n            The 
queue to pass the message after callback.\n        \"\"\"\n        head_info = {\"version\": qlib.__version__}\n\n        def request_callback(*args):\n            \"\"\"callback_wrapper\n\n            :param *args: args[0] is the response content\n            \"\"\"\n            # args[0] is the response content\n            self.logger.debug(\"receive data and enter queue\")\n            msg = dict(args[0])\n            if msg[\"detailed_info\"] is not None:\n                if msg[\"status\"] != 0:\n                    self.logger.error(msg[\"detailed_info\"])\n                else:\n                    self.logger.info(msg[\"detailed_info\"])\n            if msg[\"status\"] != 0:\n                ex = ValueError(f\"Bad response(status=={msg['status']}), detailed info: {msg['detailed_info']}\")\n                msg_queue.put(ex)\n            else:\n                if msg_proc_func is not None:\n                    try:\n                        ret = msg_proc_func(msg[\"result\"])\n                    except Exception as e:\n                        self.logger.exception(\"Error when processing message.\")\n                        ret = e\n                else:\n                    ret = msg[\"result\"]\n                msg_queue.put(ret)\n            self.disconnect()\n            self.logger.debug(\"disconnected\")\n\n        self.logger.debug(\"try connecting\")\n        self.connect_server()\n        self.logger.debug(\"connected\")\n        # `default=str` lets json serialize parameters of special types (such as\n        # pd.Timestamp) as strings\n        request_content = {\"head\": head_info, \"body\": json.dumps(request_content, default=str)}\n        self.sio.on(request_type + \"_response\", request_callback)\n        self.logger.debug(\"try sending\")\n        self.sio.emit(request_type + \"_request\", request_content)\n        self.sio.wait()\n"
  },
  {
    "path": "qlib/data/data.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport re\nimport abc\nimport copy\nimport queue\nimport bisect\nimport numpy as np\nimport pandas as pd\nfrom typing import List, Union, Optional\n\n# For supporting multiprocessing in outer code, joblib is used\nfrom joblib import delayed\n\nfrom .cache import H\nfrom ..config import C\nfrom .inst_processor import InstProcessor\n\nfrom ..log import get_module_logger\nfrom .cache import DiskDatasetCache\nfrom ..utils import (\n    Wrapper,\n    init_instance_by_config,\n    register_wrapper,\n    get_module_by_module_path,\n    parse_field,\n    hash_args,\n    normalize_cache_fields,\n    code_to_fname,\n    time_to_slc_point,\n    read_period_data,\n    get_period_list,\n)\nfrom ..utils.paral import ParallelExt\nfrom .ops import Operators  # pylint: disable=W0611  # noqa: F401\n\n\nclass ProviderBackendMixin:\n    \"\"\"\n    This helper class tries to make providers based on a storage backend more convenient.\n    It is not necessary to inherit this class if the provider doesn't rely on backend storage.\n    \"\"\"\n\n    def get_default_backend(self):\n        backend = {}\n        provider_name: str = re.findall(\"[A-Z][^A-Z]*\", self.__class__.__name__)[-2]\n        # set default storage class\n        backend.setdefault(\"class\", f\"File{provider_name}Storage\")\n        # set default storage module\n        backend.setdefault(\"module_path\", \"qlib.data.storage.file_storage\")\n        return backend\n\n    def backend_obj(self, **kwargs):\n        backend = self.backend if self.backend else self.get_default_backend()\n        backend = copy.deepcopy(backend)\n        backend.setdefault(\"kwargs\", {}).update(**kwargs)\n        return init_instance_by_config(backend)\n\n\nclass CalendarProvider(abc.ABC):\n    \"\"\"Calendar provider base class\n\n    Provide calendar data.\n    
\"\"\"\n\n    def calendar(self, start_time=None, end_time=None, freq=\"day\", future=False):\n        \"\"\"Get calendar of certain market in given time range.\n\n        Parameters\n        ----------\n        start_time : str\n            start of the time range.\n        end_time : str\n            end of the time range.\n        freq : str\n            time frequency, available: year/quarter/month/week/day.\n        future : bool\n            whether including future trading day.\n\n        Returns\n        ----------\n        list\n            calendar list\n        \"\"\"\n        _calendar, _calendar_index = self._get_calendar(freq, future)\n        if start_time == \"None\":\n            start_time = None\n        if end_time == \"None\":\n            end_time = None\n        # strip\n        if start_time:\n            start_time = pd.Timestamp(start_time)\n            if start_time > _calendar[-1]:\n                return np.array([])\n        else:\n            start_time = _calendar[0]\n        if end_time:\n            end_time = pd.Timestamp(end_time)\n            if end_time < _calendar[0]:\n                return np.array([])\n        else:\n            end_time = _calendar[-1]\n        _, _, si, ei = self.locate_index(start_time, end_time, freq, future)\n        return _calendar[si : ei + 1]\n\n    def locate_index(\n        self, start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], freq: str, future: bool = False\n    ):\n        \"\"\"Locate the start time index and end time index in a calendar under certain frequency.\n\n        Parameters\n        ----------\n        start_time : pd.Timestamp\n            start of the time range.\n        end_time : pd.Timestamp\n            end of the time range.\n        freq : str\n            time frequency, available: year/quarter/month/week/day.\n        future : bool\n            whether including future trading day.\n\n        Returns\n        -------\n        pd.Timestamp\n        
    the real start time.\n        pd.Timestamp\n            the real end time.\n        int\n            the index of start time.\n        int\n            the index of end time.\n        \"\"\"\n        start_time = pd.Timestamp(start_time)\n        end_time = pd.Timestamp(end_time)\n        calendar, calendar_index = self._get_calendar(freq=freq, future=future)\n        if start_time not in calendar_index:\n            try:\n                start_time = calendar[bisect.bisect_left(calendar, start_time)]\n            except IndexError as index_e:\n                raise IndexError(\n                    \"`start_time` uses a future date, if you want to get future trading days, you can use: `future=True`\"\n                ) from index_e\n        start_index = calendar_index[start_time]\n        if end_time not in calendar_index:\n            end_time = calendar[bisect.bisect_right(calendar, end_time) - 1]\n        end_index = calendar_index[end_time]\n        return start_time, end_time, start_index, end_index\n\n    def _get_calendar(self, freq, future):\n        \"\"\"Load calendar using memcache.\n\n        Parameters\n        ----------\n        freq : str\n            frequency of read calendar file.\n        future : bool\n            whether including future trading day.\n\n        Returns\n        -------\n        list\n            list of timestamps.\n        dict\n            dict composed by timestamp as key and index as value for fast search.\n        \"\"\"\n        flag = f\"{freq}_future_{future}\"\n        if flag not in H[\"c\"]:\n            _calendar = np.array(self.load_calendar(freq, future))\n            _calendar_index = {x: i for i, x in enumerate(_calendar)}  # for fast search\n            H[\"c\"][flag] = _calendar, _calendar_index\n        return H[\"c\"][flag]\n\n    def _uri(self, start_time, end_time, freq, future=False):\n        \"\"\"Get the uri of calendar generation task.\"\"\"\n        return hash_args(start_time, end_time, freq, 
future)\n\n    def load_calendar(self, freq, future):\n        \"\"\"Load original calendar timestamp from file.\n\n        Parameters\n        ----------\n        freq : str\n            frequency of read calendar file.\n        future: bool\n\n        Returns\n        ----------\n        list\n            list of timestamps\n        \"\"\"\n        raise NotImplementedError(\"Subclass of CalendarProvider must implement `load_calendar` method\")\n\n\nclass InstrumentProvider(abc.ABC):\n    \"\"\"Instrument provider base class\n\n    Provide instrument data.\n    \"\"\"\n\n    @staticmethod\n    def instruments(market: Union[List, str] = \"all\", filter_pipe: Union[List, None] = None):\n        \"\"\"Get the general config dictionary for a base market adding several dynamic filters.\n\n        Parameters\n        ----------\n        market : Union[List, str]\n            str:\n                market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500.\n            list:\n                [\"ID1\", \"ID2\"]. A list of stocks\n        filter_pipe : list\n            the list of dynamic filters.\n\n        Returns\n        ----------\n        dict: if isinstance(market, str)\n            dict of stockpool config.\n\n            {`market` => base market name, `filter_pipe` => list of filters}\n\n            example :\n\n            .. 
code-block::\n\n                {'market': 'csi500',\n                'filter_pipe': [{'filter_type': 'ExpressionDFilter',\n                'rule_expression': '$open<40',\n                'filter_start_time': None,\n                'filter_end_time': None,\n                'keep': False},\n                {'filter_type': 'NameDFilter',\n                'name_rule_re': 'SH[0-9]{4}55',\n                'filter_start_time': None,\n                'filter_end_time': None}]}\n\n        list: if isinstance(market, list)\n            just return the original list directly.\n            NOTE: this will make the instruments compatible with more cases. The user code will be simpler.\n        \"\"\"\n        if isinstance(market, list):\n            return market\n        from .filter import SeriesDFilter  # pylint: disable=C0415\n\n        if filter_pipe is None:\n            filter_pipe = []\n        config = {\"market\": market, \"filter_pipe\": []}\n        # the order of the filters will affect the result, so we need to keep\n        # the order\n        for filter_t in filter_pipe:\n            if isinstance(filter_t, dict):\n                _config = filter_t\n            elif isinstance(filter_t, SeriesDFilter):\n                _config = filter_t.to_config()\n            else:\n                raise TypeError(\n                    f\"Unsupported filter types: {type(filter_t)}! 
Filter only supports dict or isinstance(filter, SeriesDFilter)\"\n                )\n            config[\"filter_pipe\"].append(_config)\n        return config\n\n    @abc.abstractmethod\n    def list_instruments(self, instruments, start_time=None, end_time=None, freq=\"day\", as_list=False):\n        \"\"\"List the instruments based on a certain stockpool config.\n\n        Parameters\n        ----------\n        instruments : dict\n            stockpool config.\n        start_time : str\n            start of the time range.\n        end_time : str\n            end of the time range.\n        as_list : bool\n            return instruments as list or dict.\n\n        Returns\n        -------\n        dict or list\n            instruments list or dictionary with time spans\n        \"\"\"\n        raise NotImplementedError(\"Subclass of InstrumentProvider must implement `list_instruments` method\")\n\n    def _uri(self, instruments, start_time=None, end_time=None, freq=\"day\", as_list=False):\n        return hash_args(instruments, start_time, end_time, freq, as_list)\n\n    # instruments type\n    LIST = \"LIST\"\n    DICT = \"DICT\"\n    CONF = \"CONF\"\n\n    @classmethod\n    def get_inst_type(cls, inst):\n        if \"market\" in inst:\n            return cls.CONF\n        if isinstance(inst, dict):\n            return cls.DICT\n        if isinstance(inst, (list, tuple, pd.Index, np.ndarray)):\n            return cls.LIST\n        raise ValueError(f\"Unknown instrument type {inst}\")\n\n\nclass FeatureProvider(abc.ABC):\n    \"\"\"Feature provider class\n\n    Provide feature data.\n    \"\"\"\n\n    @abc.abstractmethod\n    def feature(self, instrument, field, start_time, end_time, freq):\n        \"\"\"Get feature data.\n\n        Parameters\n        ----------\n        instrument : str\n            a certain instrument.\n        field : str\n            a certain field of feature.\n        start_time : str\n            start of the time range.\n        
end_time : str\n            end of the time range.\n        freq : str\n            time frequency, available: year/quarter/month/week/day.\n\n        Returns\n        -------\n        pd.Series\n            data of a certain feature\n        \"\"\"\n        raise NotImplementedError(\"Subclass of FeatureProvider must implement `feature` method\")\n\n\nclass PITProvider(abc.ABC):\n    @abc.abstractmethod\n    def period_feature(\n        self,\n        instrument,\n        field,\n        start_index: int,\n        end_index: int,\n        cur_time: pd.Timestamp,\n        period: Optional[int] = None,\n    ) -> pd.Series:\n        \"\"\"\n        Get the historical period data series between `start_index` and `end_index`\n\n        Parameters\n        ----------\n        start_index: int\n            start_index is an index relative to the latest period before cur_time\n\n        end_index: int\n            end_index is an index relative to the latest period before cur_time\n            In most cases, start_index and end_index will be non-positive values.\n            For example, if start_index == -3, end_index == 0 and the current period index is cur_idx,\n            then the data between [start_index + cur_idx, end_index + cur_idx] will be retrieved.\n\n        period: int\n            This is used for querying a specific period.\n            The period is represented with int in Qlib. (e.g. 
202001 may represent the first quarter in 2020)\n            NOTE: `period` will override `start_index` and `end_index`\n\n        Returns\n        -------\n        pd.Series\n            The index will be integers indicating the periods of the data\n            A typical example will be\n            TODO\n\n        Raises\n        ------\n        FileNotFoundError\n            This exception will be raised if the queried data do not exist.\n        \"\"\"\n        raise NotImplementedError(\"Please implement the `period_feature` method\")\n\n\nclass ExpressionProvider(abc.ABC):\n    \"\"\"Expression provider class\n\n    Provide Expression data.\n    \"\"\"\n\n    def __init__(self):\n        self.expression_instance_cache = {}\n\n    def get_expression_instance(self, field):\n        try:\n            if field in self.expression_instance_cache:\n                expression = self.expression_instance_cache[field]\n            else:\n                expression = eval(parse_field(field))\n                self.expression_instance_cache[field] = expression\n        except NameError as e:\n            get_module_logger(\"data\").exception(\n                \"ERROR: field [%s] contains invalid operator/variable [%s]\" % (str(field), str(e).split()[1])\n            )\n            raise\n        except SyntaxError:\n            get_module_logger(\"data\").exception(\"ERROR: field [%s] contains invalid syntax\" % str(field))\n            raise\n        return expression\n\n    @abc.abstractmethod\n    def expression(self, instrument, field, start_time=None, end_time=None, freq=\"day\") -> pd.Series:\n        \"\"\"Get Expression data.\n\n        The responsibilities of `expression`:\n        - parse the `field` and `load` the corresponding data.\n        - When loading the data, it should handle the time dependency of the data. 
`get_expression_instance` is commonly used in this method\n\n        Parameters\n        ----------\n        instrument : str\n            a certain instrument.\n        field : str\n            a certain field of feature.\n        start_time : str\n            start of the time range.\n        end_time : str\n            end of the time range.\n        freq : str\n            time frequency, available: year/quarter/month/week/day.\n\n        Returns\n        -------\n        pd.Series\n            data of a certain expression\n\n            The data can have one of two formats:\n\n            1) expression with datetime index\n\n            2) expression with integer index\n\n                - because the datetime is not as good as\n        \"\"\"\n        raise NotImplementedError(\"Subclass of ExpressionProvider must implement `expression` method\")\n\n\nclass DatasetProvider(abc.ABC):\n    \"\"\"Dataset provider class\n\n    Provide Dataset data.\n    \"\"\"\n\n    @abc.abstractmethod\n    def dataset(self, instruments, fields, start_time=None, end_time=None, freq=\"day\", inst_processors=[]):\n        \"\"\"Get dataset data.\n\n        Parameters\n        ----------\n        instruments : list or dict\n            list/dict of instruments or dict of stockpool config.\n        fields : list\n            list of feature instances.\n        start_time : str\n            start of the time range.\n        end_time : str\n            end of the time range.\n        freq : str\n            time frequency.\n        inst_processors : Iterable[Union[dict, InstProcessor]]\n            the operations performed on each instrument\n\n        Returns\n        ----------\n        pd.DataFrame\n            a pandas dataframe with <instrument, datetime> index.\n        \"\"\"\n        raise NotImplementedError(\"Subclass of DatasetProvider must implement `dataset` method\")\n\n    def _uri(\n        self,\n        instruments,\n        fields,\n        start_time=None,\n        
end_time=None,\n        freq=\"day\",\n        disk_cache=1,\n        inst_processors=[],\n        **kwargs,\n    ):\n        \"\"\"Get task uri, used when generating rabbitmq task in qlib_server\n\n        Parameters\n        ----------\n        instruments : list or dict\n            list/dict of instruments or dict of stockpool config.\n        fields : list\n            list of feature instances.\n        start_time : str\n            start of the time range.\n        end_time : str\n            end of the time range.\n        freq : str\n            time frequency.\n        disk_cache : int\n            whether to skip(0)/use(1)/replace(2) disk_cache.\n\n        \"\"\"\n        # TODO: qlib-server support inst_processors\n        return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache, inst_processors)\n\n    @staticmethod\n    def get_instruments_d(instruments, freq):\n        \"\"\"\n        Parse different types of input instruments to output instruments_d\n        Wrong format of input instruments will lead to exception.\n\n        \"\"\"\n        if isinstance(instruments, dict):\n            if \"market\" in instruments:\n                # dict of stockpool config\n                instruments_d = Inst.list_instruments(instruments=instruments, freq=freq, as_list=False)\n            else:\n                # dict of instruments and timestamp\n                instruments_d = instruments\n        elif isinstance(instruments, (list, tuple, pd.Index, np.ndarray)):\n            # list or tuple of a group of instruments\n            instruments_d = list(instruments)\n        else:\n            raise ValueError(\"Unsupported input type for param `instrument`\")\n        return instruments_d\n\n    @staticmethod\n    def get_column_names(fields):\n        \"\"\"\n        Get column names from input fields\n\n        \"\"\"\n        if len(fields) == 0:\n            raise ValueError(\"fields cannot be empty\")\n        column_names = 
[str(f) for f in fields]\n        return column_names\n\n    @staticmethod\n    def parse_fields(fields):\n        # parse and check the input fields\n        return [ExpressionD.get_expression_instance(f) for f in fields]\n\n    @staticmethod\n    def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=[]):\n        \"\"\"\n        Load and process the data, return the data set.\n        - default using multi-kernel method.\n\n        \"\"\"\n        normalize_column_names = normalize_cache_fields(column_names)\n        # One process for one task, so that the memory will be freed quicker.\n        workers = max(min(C.get_kernels(freq), len(instruments_d)), 1)\n\n        # create iterator\n        if isinstance(instruments_d, dict):\n            it = instruments_d.items()\n        else:\n            it = zip(instruments_d, [None] * len(instruments_d))\n\n        inst_l = []\n        task_l = []\n        for inst, spans in it:\n            inst_l.append(inst)\n            task_l.append(\n                delayed(DatasetProvider.inst_calculator)(\n                    inst, start_time, end_time, freq, normalize_column_names, spans, C, inst_processors\n                )\n            )\n\n        data = dict(\n            zip(\n                inst_l,\n                ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)(task_l),\n            )\n        )\n\n        new_data = dict()\n        for inst in sorted(data.keys()):\n            if len(data[inst]) > 0:\n                # NOTE: Python version >= 3.6; in versions after python3.6, dict will always guarantee the insertion order\n                new_data[inst] = data[inst]\n\n        if len(new_data) > 0:\n            data = pd.concat(new_data, names=[\"instrument\"], sort=False)\n            data = DiskDatasetCache.cache_to_origin_data(data, column_names)\n        else:\n            data = pd.DataFrame(\n                
index=pd.MultiIndex.from_arrays([[], []], names=(\"instrument\", \"datetime\")),\n                columns=column_names,\n                dtype=np.float32,\n            )\n\n        return data\n\n    @staticmethod\n    def inst_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]):\n        \"\"\"\n        Calculate the expressions for **one** instrument, return a df result.\n        If the expression has been calculated before, load from cache.\n\n        return value: A data frame with index 'datetime' and other data columns.\n\n        \"\"\"\n        # FIXME: Windows OS or MacOS using spawn: https://docs.python.org/3.8/library/multiprocessing.html?highlight=spawn#contexts-and-start-methods\n        # NOTE: This call keeps the code compatible with Windows, where worker processes are started with spawn\n        C.register_from_C(g_config)\n\n        obj = dict()\n        for field in column_names:\n            # The client does not have an expression provider; the data will be loaded from cache using a static method.\n            obj[field] = ExpressionD.expression(inst, field, start_time, end_time, freq)\n\n        data = pd.DataFrame(obj)\n        if not data.empty and not np.issubdtype(data.index.dtype, np.dtype(\"M\")):\n            # If the underlying provider returns data without a datetime index, convert the index to datetime\n            _calendar = Cal.calendar(freq=freq)\n            data.index = _calendar[data.index.values.astype(int)]\n        data.index.names = [\"datetime\"]\n\n        if not data.empty and spans is not None:\n            mask = np.zeros(len(data), dtype=bool)\n            for begin, end in spans:\n                mask |= (data.index >= begin) & (data.index <= end)\n            data = data[mask]\n\n        for _processor in inst_processors:\n            if _processor:\n                _processor_obj = init_instance_by_config(_processor, accept_types=InstProcessor)\n                data = _processor_obj(data, 
instrument=inst)\n        return data\n\n\nclass LocalCalendarProvider(CalendarProvider, ProviderBackendMixin):\n    \"\"\"Local calendar data provider class\n\n    Provide calendar data from local data source.\n    \"\"\"\n\n    def __init__(self, remote=False, backend={}):\n        super().__init__()\n        self.remote = remote\n        self.backend = backend\n\n    def load_calendar(self, freq, future):\n        \"\"\"Load original calendar timestamp from file.\n\n        Parameters\n        ----------\n        freq : str\n            frequency of read calendar file.\n        future: bool\n        Returns\n        ----------\n        list\n            list of timestamps\n        \"\"\"\n        try:\n            backend_obj = self.backend_obj(freq=freq, future=future).data\n        except ValueError:\n            if future:\n                get_module_logger(\"data\").warning(\n                    f\"load calendar error: freq={freq}, future={future}; return current calendar!\"\n                )\n                get_module_logger(\"data\").warning(\n                    \"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md\"\n                )\n                backend_obj = self.backend_obj(freq=freq, future=False).data\n            else:\n                raise\n\n        return [pd.Timestamp(x) for x in backend_obj]\n\n\nclass LocalInstrumentProvider(InstrumentProvider, ProviderBackendMixin):\n    \"\"\"Local instrument data provider class\n\n    Provide instrument data from local data source.\n    \"\"\"\n\n    def __init__(self, backend={}) -> None:\n        super().__init__()\n        self.backend = backend\n\n    def _load_instruments(self, market, freq):\n        return self.backend_obj(market=market, freq=freq).data\n\n    def list_instruments(self, instruments, start_time=None, end_time=None, freq=\"day\", as_list=False):\n        market = 
instruments[\"market\"]\n        if market in H[\"i\"]:\n            _instruments = H[\"i\"][market]\n        else:\n            _instruments = self._load_instruments(market, freq=freq)\n            H[\"i\"][market] = _instruments\n        # strip\n        # use calendar boundary\n        cal = Cal.calendar(freq=freq)\n        start_time = pd.Timestamp(start_time or cal[0])\n        end_time = pd.Timestamp(end_time or cal[-1])\n        _instruments_filtered = {\n            inst: list(\n                filter(\n                    lambda x: x[0] <= x[1],\n                    [(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans],\n                )\n            )\n            for inst, spans in _instruments.items()\n        }\n        _instruments_filtered = {key: value for key, value in _instruments_filtered.items() if value}\n        # filter\n        filter_pipe = instruments[\"filter_pipe\"]\n        for filter_config in filter_pipe:\n            from . 
import filter as F  # pylint: disable=C0415\n\n            filter_t = getattr(F, filter_config[\"filter_type\"]).from_config(filter_config)\n            _instruments_filtered = filter_t(_instruments_filtered, start_time, end_time, freq)\n        # as list\n        if as_list:\n            return list(_instruments_filtered)\n        return _instruments_filtered\n\n\nclass LocalFeatureProvider(FeatureProvider, ProviderBackendMixin):\n    \"\"\"Local feature data provider class\n\n    Provide feature data from local data source.\n    \"\"\"\n\n    def __init__(self, remote=False, backend={}):\n        super().__init__()\n        self.remote = remote\n        self.backend = backend\n\n    def feature(self, instrument, field, start_index, end_index, freq):\n        # strip the leading `$` from the field name\n        field = str(field)[1:]\n        instrument = code_to_fname(instrument)\n        return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]\n\n\nclass LocalPITProvider(PITProvider):\n    # TODO: Add PIT backend file storage\n    # NOTE: This class is not multi-threading-safe!!!!\n\n    def period_feature(self, instrument, field, start_index, end_index, cur_time, period=None):\n        if not isinstance(cur_time, pd.Timestamp):\n            raise ValueError(\n                f\"Expected pd.Timestamp for `cur_time`, got '{cur_time}'. Advice: you can't query PIT data directly (e.g. '$$roewa_q'); you must use the `P` operator to convert data to each day (e.g. 
'P($$roewa_q)')\"\n            )\n\n        assert end_index <= 0  # PIT doesn't support querying future data\n\n        DATA_RECORDS = [\n            (\"date\", C.pit_record_type[\"date\"]),\n            (\"period\", C.pit_record_type[\"period\"]),\n            (\"value\", C.pit_record_type[\"value\"]),\n            (\"_next\", C.pit_record_type[\"index\"]),\n        ]\n        VALUE_DTYPE = C.pit_record_type[\"value\"]\n\n        field = str(field).lower()[2:]\n        instrument = code_to_fname(instrument)\n\n        # {For acceleration\n        # start_index, end_index, cur_index = kwargs[\"info\"]\n        # if cur_index == start_index:\n        #     if not hasattr(self, \"all_fields\"):\n        #         self.all_fields = []\n        #     self.all_fields.append(field)\n        #     if not hasattr(self, \"period_index\"):\n        #         self.period_index = {}\n        #     if field not in self.period_index:\n        #         self.period_index[field] = {}\n        # For acceleration}\n\n        if not field.endswith(\"_q\") and not field.endswith(\"_a\"):\n            raise ValueError(\"period field must end with '_q' or '_a'\")\n        quarterly = field.endswith(\"_q\")\n        index_path = C.dpm.get_data_uri() / \"financial\" / instrument.lower() / f\"{field}.index\"\n        data_path = C.dpm.get_data_uri() / \"financial\" / instrument.lower() / f\"{field}.data\"\n        if not (index_path.exists() and data_path.exists()):\n            raise FileNotFoundError(\"No file is found.\")\n        # NOTE: The most significant performance loss is here.\n        # Does the acceleration, which makes the program complicated, really matter?\n        # - It makes the interface parameters complicated\n        # - It does not perform optimally (if we put all the pieces together, we may achieve higher performance)\n        #    - If we design it carefully, a single pass is enough to get the historical evolution of the data.\n        # So I 
decided to deprecate the previous implementation and keep the logic of the program simple\n        # Instead, I'll add a cache for the index file.\n        data = np.fromfile(data_path, dtype=DATA_RECORDS)\n\n        # find all revision periods before `cur_time`\n        cur_time_int = int(cur_time.year) * 10000 + int(cur_time.month) * 100 + int(cur_time.day)\n        loc = np.searchsorted(data[\"date\"], cur_time_int, side=\"right\")\n        if loc <= 0:\n            return pd.Series(dtype=C.pit_record_type[\"value\"])\n        last_period = data[\"period\"][:loc].max()  # return the latest quarter\n        first_period = data[\"period\"][:loc].min()\n        period_list = get_period_list(first_period, last_period, quarterly)\n        if period is not None:\n            # NOTE: `period` has higher priority than `start_index` & `end_index`\n            if period not in period_list:\n                return pd.Series(dtype=C.pit_record_type[\"value\"])\n            else:\n                period_list = [period]\n        else:\n            period_list = period_list[max(0, len(period_list) + start_index - 1) : len(period_list) + end_index]\n        value = np.full((len(period_list),), np.nan, dtype=VALUE_DTYPE)\n        for i, p in enumerate(period_list):\n            # last_period_index = self.period_index[field].get(period)  # For acceleration\n            value[i], now_period_index = read_period_data(\n                index_path, data_path, p, cur_time_int, quarterly  # , last_period_index  # For acceleration\n            )\n            # self.period_index[field].update({period: now_period_index})  # For acceleration\n        # NOTE: the index is period_list, so it may result in unexpected values (e.g. 
nan)\n        # when calculating across different features if only some of the financial indicators have been published\n        series = pd.Series(value, index=period_list, dtype=VALUE_DTYPE)\n\n        # {For acceleration\n        # if cur_index == end_index:\n        #     self.all_fields.remove(field)\n        #     if not len(self.all_fields):\n        #         del self.all_fields\n        #         del self.period_index\n        # For acceleration}\n\n        return series\n\n\nclass LocalExpressionProvider(ExpressionProvider):\n    \"\"\"Local expression data provider class\n\n    Provide expression data from local data source.\n    \"\"\"\n\n    def __init__(self, time2idx=True):\n        super().__init__()\n        self.time2idx = time2idx\n\n    def expression(self, instrument, field, start_time=None, end_time=None, freq=\"day\"):\n        expression = self.get_expression_instance(field)\n        start_time = time_to_slc_point(start_time)\n        end_time = time_to_slc_point(end_time)\n\n        # Two kinds of queries are supported\n        # - Index-based expression: this may save a lot of memory because the datetime index is not saved on the disk\n        # - Data with datetime index expression: this will make it more convenient to integrate with some existing databases\n        if self.time2idx:\n            _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq=freq, future=False)\n            lft_etd, rght_etd = expression.get_extended_window_size()\n            query_start, query_end = max(0, start_index - lft_etd), end_index + rght_etd\n        else:\n            start_index, end_index = query_start, query_end = start_time, end_time\n\n        try:\n            series = expression.load(instrument, query_start, query_end, freq)\n        except Exception as e:\n            get_module_logger(\"data\").debug(\n                f\"Loading expression error: \"\n                f\"instrument={instrument}, field=({field}), 
start_time={start_time}, end_time={end_time}, freq={freq}. \"\n                f\"error info: {str(e)}\"\n            )\n            raise\n        # Ensure that each column type is consistent\n        # FIXME:\n        # 1) The stock data is currently float. If there are other types of data, this part needs to be re-implemented.\n        # 2) The precision should be configurable\n        try:\n            series = series.astype(np.float32)\n        except ValueError:\n            pass\n        except TypeError:\n            pass\n        if not series.empty:\n            series = series.loc[start_index:end_index]\n        return series\n\n\nclass LocalDatasetProvider(DatasetProvider):\n    \"\"\"Local dataset data provider class\n\n    Provide dataset data from local data source.\n    \"\"\"\n\n    def __init__(self, align_time: bool = True):\n        \"\"\"\n        Parameters\n        ----------\n        align_time : bool\n            Whether to align the time to the calendar.\n            The frequency is flexible in some datasets and can't be aligned.\n            For data with a fixed frequency and a shared calendar, aligning the data to the calendar provides the following benefits:\n\n            - Align queries to the same parameters, so the cache can be shared.\n        \"\"\"\n        super().__init__()\n        self.align_time = align_time\n\n    def dataset(\n        self,\n        instruments,\n        fields,\n        start_time=None,\n        end_time=None,\n        freq=\"day\",\n        inst_processors=[],\n    ):\n        instruments_d = self.get_instruments_d(instruments, freq)\n        column_names = self.get_column_names(fields)\n        if self.align_time:\n            # NOTE: if the frequency is a fixed value,\n            # align the data to fixed calendar points\n            cal = Cal.calendar(start_time, end_time, freq)\n            if len(cal) == 0:\n                return pd.DataFrame(\n                    index=pd.MultiIndex.from_arrays([[], []], 
names=(\"instrument\", \"datetime\")), columns=column_names\n                )\n            start_time = cal[0]\n            end_time = cal[-1]\n        data = self.dataset_processor(\n            instruments_d, column_names, start_time, end_time, freq, inst_processors=inst_processors\n        )\n\n        return data\n\n    @staticmethod\n    def multi_cache_walker(instruments, fields, start_time=None, end_time=None, freq=\"day\"):\n        \"\"\"\n        This method is used to prepare the expression cache for the client.\n        Then the client will load the data from expression cache by itself.\n\n        \"\"\"\n        instruments_d = DatasetProvider.get_instruments_d(instruments, freq)\n        column_names = DatasetProvider.get_column_names(fields)\n        cal = Cal.calendar(start_time, end_time, freq)\n        if len(cal) == 0:\n            return\n        start_time = cal[0]\n        end_time = cal[-1]\n        workers = max(min(C.kernels, len(instruments_d)), 1)\n\n        ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)(\n            delayed(LocalDatasetProvider.cache_walker)(inst, start_time, end_time, freq, column_names)\n            for inst in instruments_d\n        )\n\n    @staticmethod\n    def cache_walker(inst, start_time, end_time, freq, column_names):\n        \"\"\"\n        If the expressions of one instrument haven't been calculated before,\n        calculate it and write it into expression cache.\n\n        \"\"\"\n        for field in column_names:\n            ExpressionD.expression(inst, field, start_time, end_time, freq)\n\n\nclass ClientCalendarProvider(CalendarProvider):\n    \"\"\"Client calendar data provider class\n\n    Provide calendar data by requesting data from server as a client.\n    \"\"\"\n\n    def __init__(self):\n        self.conn = None\n        self.queue = queue.Queue()\n\n    def set_conn(self, conn):\n        self.conn = conn\n\n    def calendar(self, start_time=None, 
end_time=None, freq=\"day\", future=False):\n        self.conn.send_request(\n            request_type=\"calendar\",\n            request_content={\"start_time\": str(start_time), \"end_time\": str(end_time), \"freq\": freq, \"future\": future},\n            msg_queue=self.queue,\n            msg_proc_func=lambda response_content: [pd.Timestamp(c) for c in response_content],\n        )\n        result = self.queue.get(timeout=C[\"timeout\"])\n        return result\n\n\nclass ClientInstrumentProvider(InstrumentProvider):\n    \"\"\"Client instrument data provider class\n\n    Provide instrument data by requesting data from server as a client.\n    \"\"\"\n\n    def __init__(self):\n        self.conn = None\n        self.queue = queue.Queue()\n\n    def set_conn(self, conn):\n        self.conn = conn\n\n    def list_instruments(self, instruments, start_time=None, end_time=None, freq=\"day\", as_list=False):\n        def inst_msg_proc_func(response_content):\n            if isinstance(response_content, dict):\n                instrument = {\n                    i: [(pd.Timestamp(s), pd.Timestamp(e)) for s, e in t] for i, t in response_content.items()\n                }\n            else:\n                instrument = response_content\n            return instrument\n\n        self.conn.send_request(\n            request_type=\"instrument\",\n            request_content={\n                \"instruments\": instruments,\n                \"start_time\": str(start_time),\n                \"end_time\": str(end_time),\n                \"freq\": freq,\n                \"as_list\": as_list,\n            },\n            msg_queue=self.queue,\n            msg_proc_func=inst_msg_proc_func,\n        )\n        result = self.queue.get(timeout=C[\"timeout\"])\n        if isinstance(result, Exception):\n            raise result\n        get_module_logger(\"data\").debug(\"get result\")\n        return result\n\n\nclass ClientDatasetProvider(DatasetProvider):\n    \"\"\"Client dataset 
data provider class\n\n    Provide dataset data by requesting data from server as a client.\n    \"\"\"\n\n    def __init__(self):\n        self.conn = None\n\n    def set_conn(self, conn):\n        self.conn = conn\n        self.queue = queue.Queue()\n\n    def dataset(\n        self,\n        instruments,\n        fields,\n        start_time=None,\n        end_time=None,\n        freq=\"day\",\n        disk_cache=0,\n        return_uri=False,\n        inst_processors=[],\n    ):\n        if Inst.get_inst_type(instruments) == Inst.DICT:\n            get_module_logger(\"data\").warning(\n                \"Getting features from a dict of instruments is not recommended because the features will not be \"\n                \"cached! \"\n                \"The dict of instruments will be cleaned every day.\"\n            )\n\n        if disk_cache == 0:\n            \"\"\"\n            Call the server to generate the expression cache.\n            Then load the data from the expression cache directly.\n            - default using multi-kernel method.\n\n            \"\"\"\n            self.conn.send_request(\n                request_type=\"feature\",\n                request_content={\n                    \"instruments\": instruments,\n                    \"fields\": fields,\n                    \"start_time\": start_time,\n                    \"end_time\": end_time,\n                    \"freq\": freq,\n                    \"disk_cache\": 0,\n                },\n                msg_queue=self.queue,\n            )\n            feature_uri = self.queue.get(timeout=C[\"timeout\"])\n            if isinstance(feature_uri, Exception):\n                raise feature_uri\n            else:\n                instruments_d = self.get_instruments_d(instruments, freq)\n                column_names = self.get_column_names(fields)\n                cal = Cal.calendar(start_time, end_time, freq)\n                if len(cal) == 0:\n                    return pd.DataFrame(\n              
          index=pd.MultiIndex.from_arrays([[], []], names=(\"instrument\", \"datetime\")),\n                        columns=column_names,\n                    )\n                start_time = cal[0]\n                end_time = cal[-1]\n\n                data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors)\n                if return_uri:\n                    return data, feature_uri\n                else:\n                    return data\n        else:\n            \"\"\"\n            Call the server to generate the data-set cache, get the uri of the cache file.\n            Then load the data from the file on NFS directly.\n            - using single-process implementation.\n\n            \"\"\"\n            # TODO: support inst_processors, need to change the code of qlib-server at the same time\n            # FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date\n            if inst_processors:\n                raise ValueError(\n                    f\"{self.__class__.__name__} does not support inst_processor. 
\"\n                    f\"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`\"\n                )\n            self.conn.send_request(\n                request_type=\"feature\",\n                request_content={\n                    \"instruments\": instruments,\n                    \"fields\": fields,\n                    \"start_time\": start_time,\n                    \"end_time\": end_time,\n                    \"freq\": freq,\n                    \"disk_cache\": 1,\n                },\n                msg_queue=self.queue,\n            )\n            # - Done in callback\n            feature_uri = self.queue.get(timeout=C[\"timeout\"])\n            if isinstance(feature_uri, Exception):\n                raise feature_uri\n            get_module_logger(\"data\").debug(\"get result\")\n            try:\n                # pre-mounted NFS, used for demo\n                mnt_feature_uri = C.dpm.get_data_uri(freq).joinpath(C.dataset_cache_dir_name, feature_uri)\n                df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)\n                get_module_logger(\"data\").debug(\"finish slicing data\")\n                if return_uri:\n                    return df, feature_uri\n                return df\n            except AttributeError as attribute_e:\n                raise IOError(\"Unable to fetch instruments from remote server!\") from attribute_e\n\n\nclass BaseProvider:\n    \"\"\"Local provider class\n    It is a set of interfaces that allow users to access data.\n    Because PITD is not exposed publicly to users, it is not included in the interface.\n\n    It is kept for compatibility with the old qlib provider.\n    \"\"\"\n\n    def calendar(self, start_time=None, end_time=None, freq=\"day\", future=False):\n        return Cal.calendar(start_time, end_time, freq, future=future)\n\n    def instruments(self, market=\"all\", filter_pipe=None, start_time=None, end_time=None):\n        if start_time is not None 
or end_time is not None:\n            get_module_logger(\"Provider\").warning(\n                \"The instruments correspond to a stock pool. \"\n                \"Parameters `start_time` and `end_time` do not take effect now.\"\n            )\n        return InstrumentProvider.instruments(market, filter_pipe)\n\n    def list_instruments(self, instruments, start_time=None, end_time=None, freq=\"day\", as_list=False):\n        return Inst.list_instruments(instruments, start_time, end_time, freq, as_list)\n\n    def features(\n        self,\n        instruments,\n        fields,\n        start_time=None,\n        end_time=None,\n        freq=\"day\",\n        disk_cache=None,\n        inst_processors=[],\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        disk_cache : int\n            whether to skip(0)/use(1)/replace(2) disk_cache\n\n\n        This function will try to use the cache method, which has a keyword `disk_cache`,\n        and will fall back to the provider method if a TypeError is raised because the DatasetD instance\n        is a provider class.\n        \"\"\"\n        disk_cache = C.default_disk_cache if disk_cache is None else disk_cache\n        fields = list(fields)  # In case of tuple.\n        try:\n            return DatasetD.dataset(\n                instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors\n            )\n        except TypeError:\n            return DatasetD.dataset(instruments, fields, start_time, end_time, freq, inst_processors=inst_processors)\n\n\nclass LocalProvider(BaseProvider):\n    def _uri(self, type, **kwargs):\n        \"\"\"_uri\n        The server hopes to get the uri of the request. The uri will be decided\n        by the data provider. 
For example, different cache layers have different uris.\n\n        :param type: The type of resource for the uri\n        :param **kwargs:\n        \"\"\"\n        if type == \"calendar\":\n            return Cal._uri(**kwargs)\n        elif type == \"instrument\":\n            return Inst._uri(**kwargs)\n        elif type == \"feature\":\n            return DatasetD._uri(**kwargs)\n\n    def features_uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1):\n        \"\"\"features_uri\n\n        Return the uri of the generated cache of features/dataset\n\n        :param disk_cache:\n        :param instruments:\n        :param fields:\n        :param start_time:\n        :param end_time:\n        :param freq:\n        \"\"\"\n        return DatasetD._dataset_uri(instruments, fields, start_time, end_time, freq, disk_cache)\n\n\nclass ClientProvider(BaseProvider):\n    \"\"\"Client Provider\n\n    Request data from the server as a client. It can propose the following requests:\n\n        - Calendar : Directly respond with a list of calendars\n        - Instruments (without filter): Directly respond with a list/dict of instruments\n        - Instruments (with filters):  Respond with a list/dict of instruments\n        - Features : Respond with a cache uri\n\n    The general workflow is described as follows:\n    When the user uses the client provider to propose a request, the client provider will connect to the server and send the request. The client will start to wait for the response. The response will be made instantly, indicating whether the cache is available. The waiting procedure will terminate only when the client gets the response saying `feature_available` is true.\n    `BUG` : Every time we make a request for certain data, we need to connect to the server, wait for the response and disconnect from it. We can't make a sequence of requests within one connection. 
You can refer to https://python-socketio.readthedocs.io/en/latest/client.html for documentation of python-socketIO client.\n    \"\"\"\n\n    def __init__(self):\n        def is_instance_of_provider(instance: object, cls: type):\n            if isinstance(instance, Wrapper):\n                p = getattr(instance, \"_provider\", None)\n\n                return False if p is None else isinstance(p, cls)\n\n            return isinstance(instance, cls)\n\n        from .client import Client  # pylint: disable=C0415\n\n        self.client = Client(C.flask_server, C.flask_port)\n        self.logger = get_module_logger(self.__class__.__name__)\n        if is_instance_of_provider(Cal, ClientCalendarProvider):\n            Cal.set_conn(self.client)\n        if is_instance_of_provider(Inst, ClientInstrumentProvider):\n            Inst.set_conn(self.client)\n        if hasattr(DatasetD, \"provider\"):\n            DatasetD.provider.set_conn(self.client)\n        else:\n            DatasetD.set_conn(self.client)\n\n\nimport sys\n\nif sys.version_info >= (3, 9):\n    from typing import Annotated\n\n    CalendarProviderWrapper = Annotated[CalendarProvider, Wrapper]\n    InstrumentProviderWrapper = Annotated[InstrumentProvider, Wrapper]\n    FeatureProviderWrapper = Annotated[FeatureProvider, Wrapper]\n    PITProviderWrapper = Annotated[PITProvider, Wrapper]\n    ExpressionProviderWrapper = Annotated[ExpressionProvider, Wrapper]\n    DatasetProviderWrapper = Annotated[DatasetProvider, Wrapper]\n    BaseProviderWrapper = Annotated[BaseProvider, Wrapper]\nelse:\n    CalendarProviderWrapper = CalendarProvider\n    InstrumentProviderWrapper = InstrumentProvider\n    FeatureProviderWrapper = FeatureProvider\n    PITProviderWrapper = PITProvider\n    ExpressionProviderWrapper = ExpressionProvider\n    DatasetProviderWrapper = DatasetProvider\n    BaseProviderWrapper = BaseProvider\n\nCal: CalendarProviderWrapper = Wrapper()\nInst: InstrumentProviderWrapper = Wrapper()\nFeatureD: 
FeatureProviderWrapper = Wrapper()\nPITD: PITProviderWrapper = Wrapper()\nExpressionD: ExpressionProviderWrapper = Wrapper()\nDatasetD: DatasetProviderWrapper = Wrapper()\nD: BaseProviderWrapper = Wrapper()\n\n\ndef register_all_wrappers(C):\n    \"\"\"register_all_wrappers\"\"\"\n    logger = get_module_logger(\"data\")\n    module = get_module_by_module_path(\"qlib.data\")\n\n    _calendar_provider = init_instance_by_config(C.calendar_provider, module)\n    if getattr(C, \"calendar_cache\", None) is not None:\n        _calendar_provider = init_instance_by_config(C.calendar_cache, module, provider=_calendar_provider)\n    register_wrapper(Cal, _calendar_provider, \"qlib.data\")\n    logger.debug(f\"registering Cal {C.calendar_provider}-{C.calendar_cache}\")\n\n    _instrument_provider = init_instance_by_config(C.instrument_provider, module)\n    register_wrapper(Inst, _instrument_provider, \"qlib.data\")\n    logger.debug(f\"registering Inst {C.instrument_provider}\")\n\n    if getattr(C, \"feature_provider\", None) is not None:\n        feature_provider = init_instance_by_config(C.feature_provider, module)\n        register_wrapper(FeatureD, feature_provider, \"qlib.data\")\n        logger.debug(f\"registering FeatureD {C.feature_provider}\")\n\n    if getattr(C, \"pit_provider\", None) is not None:\n        pit_provider = init_instance_by_config(C.pit_provider, module)\n        register_wrapper(PITD, pit_provider, \"qlib.data\")\n        logger.debug(f\"registering PITD {C.pit_provider}\")\n\n    if getattr(C, \"expression_provider\", None) is not None:\n        # This provider is unnecessary in client provider\n        _eprovider = init_instance_by_config(C.expression_provider, module)\n        if getattr(C, \"expression_cache\", None) is not None:\n            _eprovider = init_instance_by_config(C.expression_cache, module, provider=_eprovider)\n        register_wrapper(ExpressionD, _eprovider, \"qlib.data\")\n        logger.debug(f\"registering ExpressionD 
{C.expression_provider}-{C.expression_cache}\")\n\n    _dprovider = init_instance_by_config(C.dataset_provider, module)\n    if getattr(C, \"dataset_cache\", None) is not None:\n        _dprovider = init_instance_by_config(C.dataset_cache, module, provider=_dprovider)\n    register_wrapper(DatasetD, _dprovider, \"qlib.data\")\n    logger.debug(f\"registering DatasetD {C.dataset_provider}-{C.dataset_cache}\")\n\n    register_wrapper(D, C.provider, \"qlib.data\")\n    logger.debug(f\"registering D {C.provider}\")\n"
  },
  {
    "path": "qlib/data/dataset/__init__.py",
    "content": "from ...utils.serial import Serializable\nfrom typing import Callable, Union, List, Tuple, Dict, Text, Optional\nfrom ...utils import init_instance_by_config, np_ffill, time_to_slc_point\nfrom ...log import get_module_logger\nfrom .handler import DataHandler, DataHandlerLP\nfrom copy import copy, deepcopy\nfrom inspect import getfullargspec\nimport pandas as pd\nimport numpy as np\nimport bisect\nfrom ...utils import lazy_sort_index\nfrom .utils import get_level_index\n\n\nclass Dataset(Serializable):\n    \"\"\"\n    Prepare data for model training and inference.\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        \"\"\"\n        init is designed to finish the following steps:\n\n        - init the sub-instance and the state of the dataset (info to prepare the data)\n            - The names of the essential state for preparing data should not start with '_' so that they can be saved on disk when serializing.\n\n        - setup data\n            - The data-related attributes' names should start with '_' so that they will not be saved on disk when serializing.\n\n        The data could specify the info to calculate the essential data for preparation\n        \"\"\"\n        self.setup_data(**kwargs)\n        super().__init__()\n\n    def config(self, **kwargs):\n        \"\"\"\n        config is designed to configure the parameters that cannot be learned from the data\n        \"\"\"\n        super().config(**kwargs)\n\n    def setup_data(self, **kwargs):\n        \"\"\"\n        Setup the data.\n\n        We split out the setup_data function for the following situation:\n\n        - User has a Dataset object with learned status on disk.\n\n        - User loads the Dataset object from the disk.\n\n        - User calls `setup_data` to load new data.\n\n        - User prepares data for the model based on the previous status.\n        \"\"\"\n\n    def prepare(self, **kwargs) -> object:\n        \"\"\"\n        The type of dataset depends on the model. 
(It could be pd.DataFrame, pytorch.DataLoader, etc.)\n        The parameters should specify the scope for the prepared data\n        The method should:\n        - process the data\n\n        - return the processed data\n\n        Returns\n        -------\n        object:\n            return the object\n        \"\"\"\n\n\nclass DatasetH(Dataset):\n    \"\"\"\n    Dataset with Data(H)andler\n\n    Users should try to put the data preprocessing functions into the handler.\n    Only the following data processing functions should be placed in Dataset:\n\n    - The processing is related to a specific model.\n\n    - The processing is related to data splitting.\n    \"\"\"\n\n    def __init__(\n        self,\n        handler: Union[Dict, DataHandler],\n        segments: Dict[Text, Tuple],\n        fetch_kwargs: Dict = {},\n        **kwargs,\n    ):\n        \"\"\"\n        Setup the underlying data.\n\n        Parameters\n        ----------\n        handler : Union[dict, DataHandler]\n            handler could be:\n\n            - instance of `DataHandler`\n\n            - config of `DataHandler`.  Please refer to `DataHandler`\n\n        segments : dict\n            Describe the options to segment the data.\n            Here are some examples:\n\n            .. 
code-block::\n\n                1) 'segments': {\n                        'train': (\"2008-01-01\", \"2014-12-31\"),\n                        'valid': (\"2017-01-01\", \"2020-08-01\",),\n                        'test': (\"2015-01-01\", \"2016-12-31\",),\n                    }\n                2) 'segments': {\n                        'insample': (\"2008-01-01\", \"2014-12-31\"),\n                        'outsample': (\"2017-01-01\", \"2020-08-01\",),\n                    }\n        \"\"\"\n        self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)\n        self.segments = segments.copy()\n        self.fetch_kwargs = copy(fetch_kwargs)\n        super().__init__(**kwargs)\n\n    def config(self, handler_kwargs: dict = None, **kwargs):\n        \"\"\"\n        Initialize the DatasetH\n\n        Parameters\n        ----------\n        handler_kwargs : dict\n            Config of DataHandler, which could include the following arguments:\n\n            - arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.\n\n        kwargs : dict\n            Config of DatasetH, such as\n\n            - segments : dict\n                Config of segments which is same as 'segments' in self.__init__\n\n        \"\"\"\n        if handler_kwargs is not None:\n            self.handler.config(**handler_kwargs)\n        if \"segments\" in kwargs:\n            self.segments = deepcopy(kwargs.pop(\"segments\"))\n        super().config(**kwargs)\n\n    def setup_data(self, handler_kwargs: dict = None, **kwargs):\n        \"\"\"\n        Setup the Data\n\n        Parameters\n        ----------\n        handler_kwargs : dict\n            init arguments of DataHandler, which could include the following arguments:\n\n            - init_type : Init Type of Handler\n\n            - enable_cache : whether to enable cache\n\n        \"\"\"\n        super().setup_data(**kwargs)\n        if handler_kwargs is not None:\n            
self.handler.setup_data(**handler_kwargs)\n\n    def __repr__(self):\n        return \"{name}(handler={handler}, segments={segments})\".format(\n            name=self.__class__.__name__, handler=self.handler, segments=self.segments\n        )\n\n    def _prepare_seg(self, slc, **kwargs):\n        \"\"\"\n        Given a query, retrieve the corresponding data\n\n        Parameters\n        ----------\n        slc : please refer to the docs of `prepare`\n                NOTE: it may not be an instance of slice. It may be a segment of `segments` from `def prepare`\n        \"\"\"\n        if hasattr(self, \"fetch_kwargs\"):\n            return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)\n        else:\n            return self.handler.fetch(slc, **kwargs)\n\n    def prepare(\n        self,\n        segments: Union[List[Text], Tuple[Text], Text, slice, pd.Index],\n        col_set=DataHandler.CS_ALL,\n        data_key=DataHandlerLP.DK_I,\n        **kwargs,\n    ) -> Union[List[pd.DataFrame], pd.DataFrame]:\n        \"\"\"\n        Prepare the data for learning and inference.\n\n        Parameters\n        ----------\n        segments : Union[List[Text], Tuple[Text], Text, slice]\n            Describe the scope of the data to be prepared\n            Here are some examples:\n\n            - 'train'\n\n            - ['train', 'valid']\n\n        col_set : str\n            The col_set will be passed to self.handler when fetching data.\n            TODO: make it automatic:\n\n            - select DK_I for test data\n            - select DK_L for training data.\n        data_key : str\n            The data to fetch:  DK_*\n            Default is DK_I, which indicates fetching data for **inference**.\n\n        kwargs :\n            The parameters that kwargs may contain:\n                flt_col : str\n                    It only exists in TSDatasetH and can be used to add a column of data (True or False) to filter data.\n                    This parameter is only supported 
when it is an instance of TSDatasetH.\n\n        Returns\n        -------\n        Union[List[pd.DataFrame], pd.DataFrame]:\n\n        Raises\n        ------\n        NotImplementedError:\n        \"\"\"\n        seg_kwargs = {\"col_set\": col_set, \"data_key\": data_key}\n        seg_kwargs.update(kwargs)\n\n        # Conflicts may happen here\n        # - The fetched data and the segment key may both be strings\n        # To resolve the conflict\n        # - The segment name has higher priority\n\n        # 1) Use it as a segment name first\n        # 1.1) directly fetch a split like \"train\" \"valid\" \"test\"\n        if isinstance(segments, str) and segments in self.segments:\n            return self._prepare_seg(self.segments[segments], **seg_kwargs)\n\n        # 1.2) fetch multiple splits like [\"train\", \"valid\"] [\"train\", \"valid\", \"test\"]\n        if isinstance(segments, (list, tuple)) and all(seg in self.segments for seg in segments):\n            return [self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments]\n\n        # 2) Otherwise, pass it directly to prepare a single seg\n        return self._prepare_seg(segments, **seg_kwargs)\n\n    # helper functions\n    @staticmethod\n    def get_min_time(segments):\n        return DatasetH._get_extrema(segments, 0, (lambda a, b: a > b))\n\n    @staticmethod\n    def get_max_time(segments):\n        return DatasetH._get_extrema(segments, 1, (lambda a, b: a < b))\n\n    @staticmethod\n    def _get_extrema(segments, idx: int, cmp: Callable, key_func=pd.Timestamp):\n        \"\"\"It acts like sorting and returns the extreme value selected by `cmp` (min or max), or None if any segment is unbounded\"\"\"\n        candidate = None\n        for _, seg in segments.items():\n            point = seg[idx]\n            if point is None:\n                # None indicates unbounded, return directly\n                return None\n            elif candidate is None or cmp(key_func(candidate), key_func(point)):\n                candidate = point\n        return 
candidate\n\n\nclass TSDataSampler:\n    \"\"\"\n    (T)ime-(S)eries DataSampler\n    This is the result of TSDatasetH\n\n    It works like `torch.utils.data.Dataset` and provides a very convenient interface for constructing a time-series\n    dataset based on tabular data.\n    - On the time step dimension, a smaller index indicates historical data and a larger index indicates future\n      data.\n\n    If users have further requirements for processing data, they could process it based on `TSDataSampler` or create\n    more powerful subclasses.\n\n    Known Issues:\n    - For performance reasons, this Sampler converts the dataframe into arrays. This could result\n      in a different data type\n\n\n    Indices design:\n        TSDataSampler has an index mechanism to help users query time-series data efficiently.\n\n        The definition of related variables:\n            data_arr: np.ndarray\n                The original data. It contains all the original data.\n                Queries are often for the time series of a specific stock.\n                To leverage this characteristic and speed up querying, the multi-index of data_arr is rearranged in (instrument, datetime) order\n\n            data_index: pd.MultiIndex with index order <instrument, datetime>\n                It has the same length as `idx_map`. Their elements are expected to be aligned.\n\n            idx_map: np.ndarray\n                It is the indexable data. It originates from data_arr, and is then filtered by 1) `start` and `end`  2) `flt_data`\n                    The extra data in data_arr is useful in the following cases\n                    1) creating meaningful time series data before `start` instead of padding them with zeros\n                    2) some data are excluded by `flt_data` (e.g. no <X, y> sample pair for that index). 
but they are still used in the time series of X\n\n                Finally, it will look like:\n\n                array([[  0,   0],\n                       [  1,   0],\n                       [  2,   0],\n                       ...,\n                       [241, 348],\n                       [242, 348],\n                       [243, 348]], dtype=int32)\n\n                It lists all indexable data (some data used only in historical time series may not be indexable); the values are the corresponding row and col in idx_df\n            idx_df: pd.DataFrame\n                It aims to map the <datetime, instrument> key to the original position in data_arr\n\n                For example, it may look like (NOTE: the indices of an instrument's time series are contiguous in memory)\n\n                    instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015  ...\n                    datetime\n                    2017-01-03        0      242      473      717      NaN      974  ...\n                    2017-01-04        1      243      474      718      NaN      975  ...\n                    2017-01-05        2      244      475      719      NaN      976  ...\n                    2017-01-06        3      245      476      720      NaN      977  ...\n\n            With these two indices (idx_map, idx_df) and the original data (data_arr), we can make the following queries fast (implemented in __getitem__)\n            (1) Get the i-th indexable sample (time series):   (indexable sample index) -> [idx_map] -> (row col) -> [idx_df] -> (index in data_arr)\n            (2) Get the specific sample by <datetime, instrument>:  (<datetime, instrument>, i.e. 
<row, col>) -> [idx_df] -> (index in data_arr)\n            (3) Get the indices of a time series:   (get the <row, col>, refer to (1), (2)) -> [idx_df] -> (all indices in data_arr for the time series)\n    \"\"\"\n\n    # Please refer to the docstring of TSDataSampler for the definition of the following attributes\n    data_arr: np.ndarray\n    data_index: pd.MultiIndex\n    idx_map: np.ndarray\n    idx_df: pd.DataFrame\n\n    def __init__(\n        self,\n        data: pd.DataFrame,\n        start,\n        end,\n        step_len: int,\n        fillna_type: str = \"none\",\n        dtype=None,\n        flt_data=None,\n    ):\n        \"\"\"\n        Build a dataset which looks like torch.utils.data.Dataset.\n\n        Parameters\n        ----------\n        data : pd.DataFrame\n            The raw tabular data whose index order is <\"datetime\", \"instrument\">\n        start :\n            The indexable start time\n        end :\n            The indexable end time\n        step_len : int\n            The length of the time-series step\n        fillna_type : str\n            How qlib handles a sample if there is no sample on a specific date.\n            none:\n                fill with np.nan\n            ffill:\n                ffill with previous sample\n            ffill+bfill:\n                ffill with previous samples first and bfill with later samples second\n        dtype :\n            The dtype of the numpy array that `data` is converted into (None keeps the default)\n        flt_data : pd.Series\n            a column of data (True or False) to filter data. 
Its index order is <\"datetime\", \"instrument\">\n            This feature is essential because:\n            - We want some samples excluded due to label-based filtering, but we can't filter them out at the beginning because their features are still needed in the time series of other samples.\n            None:\n                keep all data\n\n        \"\"\"\n        self.start = start\n        self.end = end\n        self.step_len = step_len\n        self.fillna_type = fillna_type\n        assert get_level_index(data, \"datetime\") == 0\n        self.data = data.swaplevel().sort_index().copy()\n        data.drop(\n            data.columns, axis=1, inplace=True\n        )  # `data` is no longer needed since it has been copied into a level-swapped dataframe; drop its columns explicitly to free memory and avoid three big dataframes in memory (data, self.data, self.data_arr)\n\n        kwargs = {\"object\": self.data}\n        if dtype is not None:\n            kwargs[\"dtype\"] = dtype\n\n        self.data_arr = np.array(**kwargs)  # Indexing a numpy.array is much faster than DataFrame.values!\n        # NOTE:\n        # - append a last line full of NaN for better performance in `__getitem__`\n        # - Keeping the same dtype results in better performance\n        self.data_arr = np.append(\n            self.data_arr,\n            np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),\n            axis=0,\n        )\n        self.nan_idx = len(self.data_arr) - 1  # The last line is all NaN; setting it to -1 can cause bug #1716\n\n        # the data type will be changed\n        # The index of usable data is between start_idx and end_idx\n        self.idx_df, self.idx_map = self.build_index(self.data)\n        self.data_index = deepcopy(self.data.index)\n\n        if flt_data is not None:\n            if isinstance(flt_data, pd.DataFrame):\n                assert len(flt_data.columns) == 1\n                flt_data = flt_data.iloc[:, 0]\n            # NOTE: bool(np.nan) 
is True!\n            # make sure reindex comes first. Otherwise extra NaN may appear.\n            flt_data = flt_data.swaplevel()\n            flt_data = flt_data.reindex(self.data_index).fillna(False).astype(bool)\n            self.flt_data = flt_data.values\n            self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)\n            self.data_index = self.data_index[np.where(self.flt_data)[0]]\n        self.idx_map = self.idx_map2arr(self.idx_map)\n        self.idx_map, self.data_index = self.slice_idx_map_and_data_index(\n            self.idx_map, self.idx_df, self.data_index, start, end\n        )\n\n        self.idx_arr = np.array(self.idx_df.values, dtype=np.float64)  # for better performance\n        del self.data  # save memory\n\n    @staticmethod\n    def slice_idx_map_and_data_index(\n        idx_map,\n        idx_df,\n        data_index,\n        start,\n        end,\n    ):\n        assert (\n            len(idx_map) == data_index.shape[0]\n        )  # make sure idx_map and data_index are aligned so that the indices of idx_map can be used on data_index\n\n        start_row_idx, end_row_idx = idx_df.index.slice_locs(start=time_to_slc_point(start), end=time_to_slc_point(end))\n\n        time_flter_idx = (idx_map[:, 0] < end_row_idx) & (idx_map[:, 0] >= start_row_idx)\n        return idx_map[time_flter_idx], data_index[time_flter_idx]\n\n    @staticmethod\n    def idx_map2arr(idx_map):\n        # PyTorch data samplers have better memory control without large dicts or lists\n        # - https://github.com/pytorch/pytorch/issues/13243\n        # - https://github.com/airctic/icevision/issues/613\n        # So we convert the dict into an int array.\n        # The arr_map is expected to behave the same as idx_map\n\n        dtype = np.int32\n        # set an out-of-bounds index to indicate non-existing entries\n        no_existing_idx = (np.iinfo(dtype).max, np.iinfo(dtype).max)\n\n        max_idx = max(idx_map.keys())\n        arr_map = []\n        for i in 
range(max_idx + 1):\n            arr_map.append(idx_map.get(i, no_existing_idx))\n        arr_map = np.array(arr_map, dtype=dtype)\n        return arr_map\n\n    @staticmethod\n    def flt_idx_map(flt_data, idx_map):\n        idx = 0\n        new_idx_map = {}\n        for i, exist in enumerate(flt_data):\n            if exist:\n                new_idx_map[idx] = idx_map[i]\n                idx += 1\n        return new_idx_map\n\n    def get_index(self):\n        \"\"\"\n        Get the pandas index of the data. It will be useful in the following scenarios\n        - A special sampler is used (e.g. the user wants to sample day by day)\n        \"\"\"\n        return self.data_index.swaplevel()  # to align with the index level order of the original data received by __init__\n\n    def config(self, **kwargs):\n        # Config the attributes\n        for k, v in kwargs.items():\n            setattr(self, k, v)\n\n    @staticmethod\n    def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:\n        \"\"\"\n        Build the index relation of the data\n\n        Parameters\n        ----------\n        data : pd.DataFrame\n            A DataFrame with index in order <instrument, datetime>\n\n                                      RSQR5     RESI5     WVMA5    LABEL0\n            instrument datetime\n            SH600000   2017-01-03  0.016389  0.461632 -1.154788 -0.048056\n                       2017-01-04  0.884545 -0.110597 -1.059332 -0.030139\n                       2017-01-05  0.507540 -0.535493 -1.099665 -0.644983\n                       2017-01-06 -1.267771 -0.669685 -1.636733  0.295366\n                       2017-01-09  0.339346  0.074317 -0.984989  0.765540\n\n        Returns\n        -------\n        Tuple[pd.DataFrame, dict]:\n            1) the first element:  reshape the original index into a <datetime(row), instrument(column)> 2D dataframe\n                instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015  ...\n                datetime\n                
2017-01-03        0      242      473      717      NaN      974  ...\n                2017-01-04        1      243      474      718      NaN      975  ...\n                2017-01-05        2      244      475      719      NaN      976  ...\n                2017-01-06        3      245      476      720      NaN      977  ...\n            2) the second element:  {<original index>: <row, col>}\n        \"\"\"\n        # use dtype=object in case pandas converts int to float\n        idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)\n        idx_df = lazy_sort_index(idx_df.unstack())\n        # NOTE: the correctness of `__getitem__` depends on columns sorted here\n        idx_df = lazy_sort_index(idx_df, axis=1).T\n\n        idx_map = {}\n        for i, (_, row) in enumerate(idx_df.iterrows()):\n            for j, real_idx in enumerate(row):\n                if not np.isnan(real_idx):\n                    idx_map[real_idx] = (i, j)\n        return idx_df, idx_map\n\n    @property\n    def empty(self):\n        return len(self) == 0\n\n    def _get_indices(self, row: int, col: int) -> np.array:\n        \"\"\"\n        get the time-series indices of self.data_arr from the row and col indices of self.idx_df\n\n        Parameters\n        ----------\n        row : int\n            the row in self.idx_df\n        col : int\n            the col in self.idx_df\n\n        Returns\n        -------\n        np.array:\n            The indices of the data in self.data_arr\n        \"\"\"\n        indices = self.idx_arr[max(row - self.step_len + 1, 0) : row + 1, col]\n\n        if len(indices) < self.step_len:\n            indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices])\n\n        if self.fillna_type == \"ffill\":\n            indices = np_ffill(indices)\n        elif self.fillna_type == \"ffill+bfill\":\n            indices = np_ffill(np_ffill(indices)[::-1])[::-1]\n        else:\n            assert self.fillna_type == \"none\"\n        
return indices\n\n    def _get_row_col(self, idx) -> Tuple[int]:\n        \"\"\"\n        get the row index and col index of a given sample index in self.idx_df\n\n        Parameters\n        ----------\n        idx :\n            the input of `__getitem__`\n\n        Returns\n        -------\n        Tuple[int]:\n            the row and col index\n        \"\"\"\n        # Get the right row number `i` and col number `j` in idx_df\n        if isinstance(idx, (int, np.integer)):\n            real_idx = idx\n            if 0 <= real_idx < len(self.idx_map):\n                i, j = self.idx_map[real_idx]  # TODO: The performance of this line is not good\n            else:\n                raise KeyError(f\"{real_idx} is out of [0, {len(self.idx_map)})\")\n        elif isinstance(idx, tuple):\n            # <TSDataSampler object>[\"datetime\", \"instruments\"]\n            date, inst = idx\n            date = pd.Timestamp(date)\n            i = bisect.bisect_right(self.idx_df.index, date) - 1\n            # NOTE: This relies on the idx_df columns sorted in `__init__`\n            j = bisect.bisect_left(self.idx_df.columns, inst)\n        else:\n            raise NotImplementedError(f\"This type of input is not supported: {type(idx)}\")\n        return i, j\n\n    def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]):\n        \"\"\"\n        # We have two methods to get the time series of a sample\n        tsds is an instance of TSDataSampler\n\n        # 1) sample by int index directly\n        tsds[len(tsds) - 1]\n\n        # 2) sample by <datetime,instrument> index\n        tsds['2016-12-31', \"SZ300315\"]\n\n        # The return value will be similar to the data retrieved by the following code\n        df.loc(axis=0)['2015-01-01':'2016-12-31', \"SZ300315\"].iloc[-30:]\n\n        Parameters\n        ----------\n        idx : Union[int, Tuple[object, str], List[int]]\n        \"\"\"\n        # Multi-index type\n        mtit = (list, np.ndarray)\n        if isinstance(idx, 
mtit):\n            indices = [self._get_indices(*self._get_row_col(i)) for i in idx]\n            indices = np.concatenate(indices)\n        else:\n            indices = self._get_indices(*self._get_row_col(idx))\n\n        # 1) for better performance, use the last NaN line to pad missing dates\n        # 2) In case of precision problems, we use np.float64. # TODO: I'm not sure whether np.float64 will result in\n        # precision problems. At least it has not caused any problems in my tests\n        indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int)\n\n        if (np.diff(indices) == 1).all():  # slicing instead of indexing for speed.\n            data = self.data_arr[indices[0] : indices[-1] + 1]\n        else:\n            data = self.data_arr[indices]\n        if isinstance(idx, mtit):\n            # if we get multiple indexes, an additional dimension should be added.\n            # <sample_idx, step_idx, feature_idx>\n            data = data.reshape(-1, self.step_len, *data.shape[1:])\n        return data\n\n    def __len__(self):\n        return len(self.idx_map)\n\n\nclass TSDatasetH(DatasetH):\n    \"\"\"\n    (T)ime-(S)eries Dataset (H)andler\n\n\n    Convert the tabular data to time-series data\n\n    Requirements analysis\n\n    The typical workflow of a user to get time-series data for a sample\n    - process features\n    - slice proper data from data handler:  dimension of sample <feature, >\n    - Build relation of samples by <time, instrument> index\n        - Be able to sample time series of data <timestep, feature>\n        - It will be better if the interface is like \"torch.utils.data.Dataset\"\n    - Users could build customized batches based on the data\n        - The dimension of a batch of data <batch_idx, feature, timestep>\n    \"\"\"\n\n    DEFAULT_STEP_LEN = 30\n\n    def __init__(self, step_len=DEFAULT_STEP_LEN, flt_col: Optional[str] = None, **kwargs):\n        self.step_len = step_len\n        
self.flt_col = flt_col\n        super().__init__(**kwargs)\n\n    def config(self, **kwargs):\n        if \"step_len\" in kwargs:\n            self.step_len = kwargs.pop(\"step_len\")\n        super().config(**kwargs)\n\n    def setup_data(self, **kwargs):\n        super().setup_data(**kwargs)\n        # make sure the calendar is updated to the latest when loading data from a new config\n        cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values(\"datetime\").unique()\n        self.cal = sorted(cal)\n\n    @staticmethod\n    def _extend_slice(slc: slice, cal: list, step_len: int) -> slice:\n        # The dataset decides how to slice data (get more data for time series).\n        start, end = slc.start, slc.stop\n        start_idx = bisect.bisect_left(cal, pd.Timestamp(start))\n        pad_start_idx = max(0, start_idx - step_len)\n        pad_start = cal[pad_start_idx]\n        return slice(pad_start, end)\n\n    def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:\n        \"\"\"\n        Splitting out _prepare_raw_seg leaves a hook for data preprocessing before creating the processed data\n        NOTE: TSDatasetH only supports slc segments on datetime !!!\n        \"\"\"\n        dtype = kwargs.pop(\"dtype\", None)\n        if not isinstance(slc, slice):\n            slc = slice(*slc)\n        if (flt_col := kwargs.pop(\"flt_col\", None)) is None:\n            flt_col = self.flt_col\n\n        # TSDatasetH will retrieve more data for complete time series\n        ext_slice = self._extend_slice(slc, self.cal, self.step_len)\n        data = super()._prepare_seg(ext_slice, **kwargs)\n\n        flt_kwargs = deepcopy(kwargs)\n        if flt_col is not None:\n            flt_kwargs[\"col_set\"] = flt_col\n            flt_data = super()._prepare_seg(ext_slice, **flt_kwargs)\n            assert len(flt_data.columns) == 1\n        else:\n            flt_data = None\n\n        tsds = TSDataSampler(\n            data=data,\n            start=slc.start,\n  
          end=slc.stop,\n            step_len=self.step_len,\n            dtype=dtype,\n            flt_data=flt_data,\n        )\n        return tsds\n\n\n__all__ = [\"Optional\", \"Dataset\", \"DatasetH\"]\n"
  },
  {
    "path": "qlib/data/dataset/handler.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# coding=utf-8\nfrom abc import abstractmethod\nimport warnings\nfrom typing import Callable, Union, Tuple, List, Iterator, Optional\n\nimport pandas as pd\n\nfrom qlib.typehint import Literal\nfrom ...log import get_module_logger, TimeInspector\nfrom ...utils import init_instance_by_config\nfrom ...utils.serial import Serializable\nfrom .utils import fetch_df_by_index, fetch_df_by_col\nfrom ...utils import lazy_sort_index\nfrom .loader import DataLoader\n\nfrom . import processor as processor_module\nfrom . import loader as data_loader_module\n\nDATA_KEY_TYPE = Literal[\"raw\", \"infer\", \"learn\"]\n\n\nclass DataHandlerABC(Serializable):\n    \"\"\"\n    Interface for data handler.\n\n    This class does not assume the internal data structure of the data handler.\n    It only defines the interface for external users (uses DataFrame as the internal data structure).\n\n    In the future, the data handler's more detailed implementation should be refactored. Here are some guidelines:\n\n    It covers several components:\n\n    - [data loader] -> internal representation of the data -> data preprocessing -> interface adaptor for the fetch interface\n    - The workflow to combine them all:\n      The workflow may be very complicated. 
DataHandlerLP is one such practice, but it can't satisfy all the requirements.\n      So leaving users the flexibility to implement the workflow is a more reasonable choice.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):  # pylint: disable=W0246\n        \"\"\"\n        We should define how to get ready for the fetching.\n        \"\"\"\n        super().__init__(*args, **kwargs)\n\n    CS_ALL = \"__all\"  # return all columns with single-level index column\n    CS_RAW = \"__raw\"  # return raw data with multi-level index column\n\n    # data key\n    DK_R: DATA_KEY_TYPE = \"raw\"\n    DK_I: DATA_KEY_TYPE = \"infer\"\n    DK_L: DATA_KEY_TYPE = \"learn\"\n\n    @abstractmethod\n    def fetch(\n        self,\n        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),\n        level: Union[str, int] = \"datetime\",\n        col_set: Union[str, List[str]] = CS_ALL,\n        data_key: DATA_KEY_TYPE = DK_I,\n    ) -> pd.DataFrame:\n        pass\n\n\nclass DataHandler(DataHandlerABC):\n    \"\"\"\n    The motivation of DataHandler:\n\n    - It provides an implementation of DataHandlerABC that we implement with:\n        - Handling responses with an internally loaded DataFrame\n        - The DataFrame is loaded by a data loader.\n\n    The steps to use a handler\n    1. initialize the data handler  (called by `init`).\n    2. use the data.\n\n\n    The data handler tries to maintain data with a 2-level index:\n    `datetime` & `instruments`.\n\n    Any order of the index levels can be supported (the order will be inferred from the data).\n    The order  <`datetime`, `instruments`> will be used when the dataframe index names are missing.\n\n    Example of the data:\n    The multi-index of the columns is optional.\n\n    .. 
code-block:: text\n\n                                feature                                                            label\n                                $close     $volume  Ref($close, 1)  Mean($close, 3)  $high-$low  LABEL0\n        datetime   instrument\n        2010-01-04 SH600000    81.807068  17145150.0       83.737389        83.016739    2.741058  0.0032\n                   SH600004    13.313329  11800983.0       13.313329        13.317701    0.183632  0.0042\n                   SH600005    37.796539  12231662.0       38.258602        37.919757    0.970325  0.0289\n\n\n    Tips for improving the performance of the data handler\n    - Fetching data with `col_set=CS_RAW` will return the raw data and may prevent pandas from copying the data when calling `loc`\n    \"\"\"\n\n    _data: pd.DataFrame  # underlying data.\n\n    def __init__(\n        self,\n        instruments=None,\n        start_time=None,\n        end_time=None,\n        data_loader: Union[dict, str, DataLoader] = None,\n        init_data=True,\n        fetch_orig=True,\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        instruments :\n            The stock list to retrieve.\n        start_time :\n            start_time of the original data.\n        end_time :\n            end_time of the original data.\n        data_loader : Union[dict, str, DataLoader]\n            data loader to load the data.\n        init_data :\n            initialize the original data in the constructor.\n        fetch_orig : bool\n            Return the original data instead of a copy if possible.\n        \"\"\"\n\n        # Setup data loader\n        assert data_loader is not None  # data_loader defaults to None only so that start_time and end_time can have None defaults\n\n        # which data source to load data from\n        self.data_loader = init_instance_by_config(\n            data_loader,\n            None if (isinstance(data_loader, dict) and \"module_path\" in data_loader) else data_loader_module,\n            
accept_types=DataLoader,\n        )\n\n        # what data to load from the data source\n        # For IDE auto-completion.\n        self.instruments = instruments\n        self.start_time = start_time\n        self.end_time = end_time\n\n        self.fetch_orig = fetch_orig\n        if init_data:\n            with TimeInspector.logt(\"Init data\"):\n                self.setup_data()\n        super().__init__()\n\n    def config(self, **kwargs):\n        \"\"\"\n        Configure the data, i.e. what data to load from the data source.\n\n        This method will be used when loading a pickled handler from a dataset.\n        The data will be initialized with a different time range.\n\n        \"\"\"\n        attr_list = {\"instruments\", \"start_time\", \"end_time\"}\n        for k, v in kwargs.items():\n            if k in attr_list:\n                setattr(self, k, v)\n\n        for attr in attr_list:\n            if attr in kwargs:\n                kwargs.pop(attr)\n\n        super().config(**kwargs)\n\n    def setup_data(self, enable_cache: bool = False):\n        \"\"\"\n        Set up the data, in case initialization runs multiple times\n\n        It is responsible for maintaining the following variable\n        1) self._data\n\n        Parameters\n        ----------\n        enable_cache : bool\n            default value is False:\n\n            - if `enable_cache` == True:\n\n                the processed data will be saved on disk, and the handler will load the cached data from disk directly\n                when we call `init` next time\n        \"\"\"\n        # Setup data.\n        # _data may have a multi-level column index. 
The outer level indicates the feature set name\n        with TimeInspector.logt(\"Loading data\"):\n            # make sure the fetch method is based on an index-sorted pd.DataFrame\n            self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time))\n        # TODO: cache\n\n    def fetch(\n        self,\n        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),\n        level: Union[str, int] = \"datetime\",\n        col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL,\n        data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I,\n        squeeze: bool = False,\n        proc_func: Optional[Callable] = None,\n    ) -> pd.DataFrame:\n        \"\"\"\n        fetch data from the underlying data source\n\n        Design motivation:\n        - Provide a unified interface for underlying data.\n        - Potential to make the interface more friendly.\n        - Users can improve performance when fetching data in this extra layer\n\n        Parameters\n        ----------\n        selector : Union[pd.Timestamp, slice, str, pd.Index]\n            describes how to select data by index\n            It can be categorized as follows\n\n            - fetch a single index\n            - fetch a range of indexes\n\n                - a slice range\n                - pd.Index for specific indexes\n\n            The following conflicts may occur\n\n            - Does [\"20200101\", \"20210101\"] mean selecting this slice or these two days?\n\n                - the slice has higher priority\n\n        level : Union[str, int]\n            which index level to select the data from\n\n        col_set : Union[str, List[str]]\n\n            - if isinstance(col_set, str):\n\n                select a set of meaningful pd.Index columns (e.g. 
features, columns)\n\n                - if col_set == CS_RAW:\n\n                    the raw dataset will be returned.\n\n            - if isinstance(col_set, List[str]):\n\n                select several sets of meaningful columns; the returned data has multiple levels\n\n        proc_func: Callable\n\n            - Provides a hook for processing data before fetching\n            - An example to explain the necessity of the hook:\n\n                - A Dataset learned some processors to process data that depends on data segmentation\n                - It will apply them every time it prepares data.\n                - The learned processors require the dataframe to keep the same format when fitting and applying\n                - However, the data format will change according to the parameters.\n                - So the processors should be applied to the underlying data.\n\n        squeeze : bool\n            whether to squeeze columns and index\n\n        Returns\n        -------\n        pd.DataFrame.\n        \"\"\"\n        # DataHandler is an example with only one dataframe, so data_key is not used.\n        _ = data_key  # avoid linting errors (e.g., unused-argument)\n        return self._fetch_data(\n            data_storage=self._data,\n            selector=selector,\n            level=level,\n            col_set=col_set,\n            squeeze=squeeze,\n            proc_func=proc_func,\n        )\n\n    def _fetch_data(\n        self,\n        data_storage,\n        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),\n        level: Union[str, int] = \"datetime\",\n        col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL,\n        squeeze: bool = False,\n        proc_func: Callable = None,\n    ):\n        # This method is extracted for sharing in subclasses\n        from .storage import BaseHandlerStorage  # pylint: disable=C0415\n\n        # The following conflicts may occur\n        # - Does [\"20200101\", \"20210101\"] mean selecting 
this slice or these two days?\n        # To solve this issue\n        #   - the slice has higher priority (except when level is None)\n        if isinstance(selector, (tuple, list)) and level is not None:\n            # when level is None, the argument will be passed in directly\n            # we don't have to convert it into a slice\n            try:\n                selector = slice(*selector)\n            except (TypeError, ValueError):\n                get_module_logger(\"DataHandlerLP\").info(\"Failed to convert the query to a slice. It will be used directly\")\n\n        if isinstance(data_storage, pd.DataFrame):\n            data_df = data_storage\n            if proc_func is not None:\n                # FIXME: fetching by time first will be more friendly to `proc_func`\n                # Copy in case `proc_func` changes the data inplace.\n                data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())\n                data_df = fetch_df_by_col(data_df, col_set)\n            else:\n                # Fetching columns first will be more friendly to SepDataFrame\n                data_df = fetch_df_by_col(data_df, col_set)\n                data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)\n        elif isinstance(data_storage, BaseHandlerStorage):\n            if proc_func is not None:\n                raise ValueError(f\"proc_func is not supported by the storage {type(data_storage)}\")\n            data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)\n        else:\n            raise TypeError(f\"data_storage should be pd.DataFrame|BaseHandlerStorage, not {type(data_storage)}\")\n\n        if squeeze:\n            # squeeze columns\n            data_df = data_df.squeeze()\n            # squeeze index\n            if isinstance(selector, (str, pd.Timestamp)):\n                data_df = data_df.reset_index(level=level, drop=True)\n        
return data_df\n\n    def get_cols(self, col_set=DataHandlerABC.CS_ALL) -> list:\n        \"\"\"\n        get the column names\n\n        Parameters\n        ----------\n        col_set : str\n            select a set of meaningful columns.(e.g. features, columns)\n\n        Returns\n        -------\n        list:\n            list of column names\n        \"\"\"\n        df = self._data.head()\n        df = fetch_df_by_col(df, col_set)\n        return df.columns.to_list()\n\n    def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice:\n        \"\"\"\n        get range selector by number of periods\n\n        Args:\n            cur_date (pd.Timestamp or str): current date\n            periods (int): number of periods\n        \"\"\"\n        trading_dates = self._data.index.unique(level=\"datetime\")\n        cur_loc = trading_dates.get_loc(cur_date)\n        pre_loc = cur_loc - periods + 1\n        if pre_loc < 0:\n            warnings.warn(\"`periods` is too large. 
The first date will be returned.\")\n            pre_loc = 0\n        ref_date = trading_dates[pre_loc]\n        return slice(ref_date, cur_date)\n\n    def get_range_iterator(\n        self, periods: int, min_periods: Optional[int] = None, **kwargs\n    ) -> Iterator[Tuple[pd.Timestamp, pd.DataFrame]]:\n        \"\"\"\n        get an iterator of sliced data with given periods\n\n        Args:\n            periods (int): number of periods.\n            min_periods (int): minimum periods for sliced dataframe.\n            kwargs (dict): will be passed to `self.fetch`.\n        \"\"\"\n        trading_dates = self._data.index.unique(level=\"datetime\")\n        if min_periods is None:\n            min_periods = periods\n        for cur_date in trading_dates[min_periods:]:\n            selector = self.get_range_selector(cur_date, periods)\n            yield cur_date, self.fetch(selector, **kwargs)\n\n\nclass DataHandlerLP(DataHandler):\n    \"\"\"\n    Motivation:\n    - For cases where we want to use different processor workflows for learning and inference.\n\n\n    DataHandler with **(L)earnable (P)rocessor**\n\n    This handler will produce three pieces of data in pd.DataFrame format.\n\n    - DK_R / self._data: the raw data loaded from the loader\n    - DK_I / self._infer: the data processed for inference\n    - DK_L / self._learn: the data processed for learning models.\n\n    Some examples of why learning and inference may need different processor workflows:\n\n    - The instrument universe for learning and inference may be different.\n    - The processing of some samples may rely on the label (for example, samples that hit the limit may need extra processing or be dropped).\n\n        - These processors only apply to the learning phase.\n\n    Tips for data handler\n\n    - To reduce the memory cost\n\n        - `drop_raw=True`: the raw data (`self._data`) will be dropped after processing;\n\n    - Please note processed data like `self._infer` or 
`self._learn` are different concepts from `segments` in Qlib's `Dataset` like \"train\" and \"test\"\n\n        - Processed data like `self._infer` or `self._learn` are underlying data processed with different processors\n        - `segments` in Qlib's `Dataset` like \"train\" and \"test\" are simply the time segmentations when querying data (\"train\" is often before \"test\" in time series).\n        - For example, you can query `data._infer` processed by `infer_processors` in the \"train\" time segmentation.\n    \"\"\"\n\n    # based on `self._data`, _infer and _learn are generated after processors\n    _infer: pd.DataFrame  # data for inference\n    _learn: pd.DataFrame  # data for learning models\n\n    # map data_key to attribute name\n    ATTR_MAP = {DataHandler.DK_R: \"_data\", DataHandler.DK_I: \"_infer\", DataHandler.DK_L: \"_learn\"}\n\n    # process type\n    PTYPE_I = \"independent\"\n    # - self._infer will be processed by shared_processors + infer_processors\n    # - self._learn will be processed by shared_processors + learn_processors\n\n    # NOTE:\n    PTYPE_A = \"append\"\n\n    # - self._infer will be processed by shared_processors + infer_processors\n    # - self._learn will be processed by shared_processors + infer_processors + learn_processors\n    #   - (i.e. self._infer further processed by learn_processors)\n\n    def __init__(\n        self,\n        instruments=None,\n        start_time=None,\n        end_time=None,\n        data_loader: Union[dict, str, DataLoader] = None,\n        infer_processors: List = [],\n        learn_processors: List = [],\n        shared_processors: List = [],\n        process_type=PTYPE_A,\n        drop_raw=False,\n        **kwargs,\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        infer_processors : list\n            - list of <description info> of processors to generate data for inference\n\n            - example of <description info>:\n\n            .. 
code-block::\n\n                1) classname & kwargs:\n                    {\n                        \"class\": \"MinMaxNorm\",\n                        \"kwargs\": {\n                            \"fit_start_time\": \"20080101\",\n                            \"fit_end_time\": \"20121231\"\n                        }\n                    }\n                2) Only classname:\n                    \"DropnaFeature\"\n                3) object instance of Processor\n\n        learn_processors : list\n            similar to infer_processors, but for generating data for learning models\n\n        process_type: str\n            PTYPE_I = 'independent'\n\n            - self._infer will be processed by infer_processors\n\n            - self._learn will be processed by learn_processors\n\n            PTYPE_A = 'append'\n\n            - self._infer will be processed by infer_processors\n\n            - self._learn will be processed by infer_processors + learn_processors\n\n              - (e.g. self._infer processed by learn_processors )\n        drop_raw: bool\n            Whether to drop the raw data\n        \"\"\"\n\n        # Setup preprocessor\n        self.infer_processors = []  # for lint\n        self.learn_processors = []  # for lint\n        self.shared_processors = []  # for lint\n        for pname in \"infer_processors\", \"learn_processors\", \"shared_processors\":\n            for proc in locals()[pname]:\n                getattr(self, pname).append(\n                    init_instance_by_config(\n                        proc,\n                        None if (isinstance(proc, dict) and \"module_path\" in proc) else processor_module,\n                        accept_types=processor_module.Processor,\n                    )\n                )\n\n        self.process_type = process_type\n        self.drop_raw = drop_raw\n        super().__init__(instruments, start_time, end_time, data_loader, **kwargs)\n\n    def get_all_processors(self):\n        return 
self.shared_processors + self.infer_processors + self.learn_processors\n\n    def fit(self):\n        \"\"\"\n        fit the processors on the raw data without processing the data\n        \"\"\"\n        for proc in self.get_all_processors():\n            with TimeInspector.logt(f\"{proc.__class__.__name__}\"):\n                proc.fit(self._data)\n\n    def fit_process_data(self):\n        \"\"\"\n        fit and process data\n\n        The input of the `fit` will be the output of the previous processor\n        \"\"\"\n        self.process_data(with_fit=True)\n\n    @staticmethod\n    def _run_proc_l(\n        df: pd.DataFrame, proc_l: List[processor_module.Processor], with_fit: bool, check_for_infer: bool\n    ) -> pd.DataFrame:\n        for proc in proc_l:\n            if check_for_infer and not proc.is_for_infer():\n                raise TypeError(\"Only processors usable for inference can be used in `infer_processors` or `shared_processors`\")\n            with TimeInspector.logt(f\"{proc.__class__.__name__}\"):\n                if with_fit:\n                    proc.fit(df)\n                df = proc(df)\n        return df\n\n    @staticmethod\n    def _is_proc_readonly(proc_l: List[processor_module.Processor]):\n        \"\"\"\n        NOTE: it will return True if `len(proc_l) == 0`\n        \"\"\"\n        for p in proc_l:\n            if not p.readonly():\n                return False\n        return True\n\n    def process_data(self, with_fit: bool = False):\n        \"\"\"\n        Process the data. Run `processor.fit` if necessary\n\n        Notation: (data)  [processor]\n\n        # data processing flow of self.process_type == DataHandlerLP.PTYPE_I\n\n        .. 
code-block:: text\n\n            (self._data)-[shared_processors]-(_shared_df)-[learn_processors]-(_learn_df)\n                                                   \\\\\n                                                    -[infer_processors]-(_infer_df)\n\n        # data processing flow of self.process_type == DataHandlerLP.PTYPE_A\n\n        .. code-block:: text\n\n            (self._data)-[shared_processors]-(_shared_df)-[infer_processors]-(_infer_df)-[learn_processors]-(_learn_df)\n\n        Parameters\n        ----------\n        with_fit : bool\n            The input of the `fit` will be the output of the previous processor\n        \"\"\"\n        # shared data processors\n        # 1) assign\n        _shared_df = self._data\n        if not self._is_proc_readonly(self.shared_processors):  # avoid modifying the original data\n            _shared_df = _shared_df.copy()\n        # 2) process\n        _shared_df = self._run_proc_l(_shared_df, self.shared_processors, with_fit=with_fit, check_for_infer=True)\n\n        # data for inference\n        # 1) assign\n        _infer_df = _shared_df\n        if not self._is_proc_readonly(self.infer_processors):  # avoid modifying the original data\n            _infer_df = _infer_df.copy()\n        # 2) process\n        _infer_df = self._run_proc_l(_infer_df, self.infer_processors, with_fit=with_fit, check_for_infer=True)\n\n        self._infer = _infer_df\n\n        # data for learning\n        # 1) assign\n        if self.process_type == DataHandlerLP.PTYPE_I:\n            _learn_df = _shared_df\n        elif self.process_type == DataHandlerLP.PTYPE_A:\n            # based on `infer_df` and append the processor\n            _learn_df = _infer_df\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n        if not self._is_proc_readonly(self.learn_processors):  # avoid modifying the original  data\n            _learn_df = _learn_df.copy()\n        # 2) process\n        _learn_df = 
self._run_proc_l(_learn_df, self.learn_processors, with_fit=with_fit, check_for_infer=False)\n\n        self._learn = _learn_df\n\n        if self.drop_raw:\n            del self._data\n\n    def config(self, processor_kwargs: dict = None, **kwargs):\n        \"\"\"\n        configuration of data.\n        # what data to be loaded from data source\n\n        This method will be used when loading pickled handler from dataset.\n        The data will be initialized with different time ranges.\n\n        \"\"\"\n        super().config(**kwargs)\n        if processor_kwargs is not None:\n            for processor in self.get_all_processors():\n                processor.config(**processor_kwargs)\n\n    # init type\n    IT_FIT_SEQ = \"fit_seq\"  # the input of `fit` will be the output of the previous processor\n    IT_FIT_IND = \"fit_ind\"  # the input of `fit` will be the original df\n    IT_LS = \"load_state\"  # The state of the object has been loaded by pickle\n\n    def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):\n        \"\"\"\n        Set up the data in case initialization runs multiple times\n\n        Parameters\n        ----------\n        init_type : str\n            The type `IT_*` listed above.\n        enable_cache : bool\n            default value is false:\n\n            - if `enable_cache` == True:\n\n                the processed data will be saved on disk, and handler will load the cached data from the disk directly\n                when we call `init` next time\n        \"\"\"\n        # init raw data\n        super().setup_data(**kwargs)\n\n        with TimeInspector.logt(\"fit & process data\"):\n            if init_type == DataHandlerLP.IT_FIT_IND:\n                self.fit()\n                self.process_data()\n            elif init_type == DataHandlerLP.IT_LS:\n                self.process_data()\n            elif init_type == DataHandlerLP.IT_FIT_SEQ:\n                self.fit_process_data()\n            else:\n
                raise NotImplementedError(f\"init_type={init_type} is not supported\")\n\n        # TODO: Be able to cache handler data. Save the memory for data processing\n\n    def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> pd.DataFrame:\n        if data_key == self.DK_R and self.drop_raw:\n            raise AttributeError(\n                \"DataHandlerLP has no attribute _data, please set drop_raw = False if you want to use raw data\"\n            )\n        df = getattr(self, self.ATTR_MAP[data_key])\n        return df\n\n    def fetch(\n        self,\n        selector: Union[pd.Timestamp, slice, str] = slice(None, None),\n        level: Union[str, int] = \"datetime\",\n        col_set=DataHandler.CS_ALL,\n        data_key: DATA_KEY_TYPE = DataHandler.DK_I,\n        squeeze: bool = False,\n        proc_func: Callable = None,\n    ) -> pd.DataFrame:\n        \"\"\"\n        fetch data from underlying data source\n\n        Parameters\n        ----------\n        selector : Union[pd.Timestamp, slice, str]\n            describe how to select data by index.\n        level : Union[str, int]\n            which index level to select the data.\n        col_set : str\n            select a set of meaningful columns.(e.g. 
features, columns).\n        data_key : str\n            the data to fetch:  DK_*.\n        proc_func: Callable\n            please refer to the doc of DataHandler.fetch\n\n        Returns\n        -------\n        pd.DataFrame:\n        \"\"\"\n\n        return self._fetch_data(\n            data_storage=self._get_df_by_key(data_key),\n            selector=selector,\n            level=level,\n            col_set=col_set,\n            squeeze=squeeze,\n            proc_func=proc_func,\n        )\n\n    def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> list:\n        \"\"\"\n        get the column names\n\n        Parameters\n        ----------\n        col_set : str\n            select a set of meaningful columns.(e.g. features, columns).\n        data_key : DATA_KEY_TYPE\n            the data to fetch:  DK_*.\n\n        Returns\n        -------\n        list:\n            list of column names\n        \"\"\"\n        df = self._get_df_by_key(data_key).head()\n        df = fetch_df_by_col(df, col_set)\n        return df.columns.to_list()\n\n    @classmethod\n    def cast(cls, handler: \"DataHandlerLP\") -> \"DataHandlerLP\":\n        \"\"\"\n        Motivation\n\n        - A user creates a datahandler in his customized package. 
Then he wants to share the processed handler with\n          other users without introducing the package dependency and complicated data processing logic.\n        - This method makes it possible by casting the handler to DataHandlerLP and keeping only the processed data\n\n        Parameters\n        ----------\n        handler : DataHandlerLP\n            A subclass of DataHandlerLP\n\n        Returns\n        -------\n        DataHandlerLP:\n            the converted processed data\n        \"\"\"\n        new_hd: DataHandlerLP = object.__new__(DataHandlerLP)\n        new_hd.from_cast = True  # add a mark for the cast instance\n\n        for key in list(DataHandlerLP.ATTR_MAP.values()) + [\n            \"instruments\",\n            \"start_time\",\n            \"end_time\",\n            \"fetch_orig\",\n            \"drop_raw\",\n        ]:\n            setattr(new_hd, key, getattr(handler, key, None))\n        return new_hd\n\n    @classmethod\n    def from_df(cls, df: pd.DataFrame) -> \"DataHandlerLP\":\n        \"\"\"\n        Motivation:\n        - When a user wants to create a data handler quickly.\n\n        The created data handler will have only one shared DataFrame without processors.\n        After creating the handler, users often want to dump it for reuse.\n        Here is a typical use case:\n\n        .. code-block:: python\n\n            from qlib.data.dataset import DataHandlerLP\n            dh = DataHandlerLP.from_df(df)\n            dh.to_pickle(fname, dump_all=True)\n\n        TODO:\n        - The StaticDataLoader is quite slow. It doesn't have to copy the data again...\n\n        \"\"\"\n        loader = data_loader_module.StaticDataLoader(df)\n        return cls(data_loader=loader)\n"
  },
  {
    "path": "qlib/data/dataset/loader.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport abc\nfrom pathlib import Path\nimport warnings\nimport pandas as pd\n\nfrom typing import Tuple, Union, List, Dict\n\nfrom qlib.data import D\nfrom qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point\nfrom qlib.utils.pickle_utils import restricted_pickle_load\nfrom qlib.log import get_module_logger\nfrom qlib.utils.serial import Serializable\n\n\nclass DataLoader(abc.ABC):\n    \"\"\"\n    DataLoader is designed for loading raw data from original data source.\n    \"\"\"\n\n    @abc.abstractmethod\n    def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:\n        \"\"\"\n        load the data as pd.DataFrame.\n\n        Example of the data (The multi-index of the columns is optional.):\n\n            .. code-block:: text\n\n                                        feature                                                             label\n                                        $close     $volume     Ref($close, 1)  Mean($close, 3)  $high-$low  LABEL0\n                datetime    instrument\n                2010-01-04  SH600000    81.807068  17145150.0       83.737389        83.016739    2.741058  0.0032\n                            SH600004    13.313329  11800983.0       13.313329        13.317701    0.183632  0.0042\n                            SH600005    37.796539  12231662.0       38.258602        37.919757    0.970325  0.0289\n\n\n        Parameters\n        ----------\n        instruments : str or dict\n            it can either be the market name or the config file of instruments generated by InstrumentProvider.\n            If the value of instruments is None, it means that no filtering is done.\n        start_time : str\n            start of the time range.\n        end_time : str\n            end of the time range.\n\n        Returns\n        -------\n        pd.DataFrame:\n            data load from the under 
layer source\n\n        Raises\n        ------\n        KeyError:\n            if the instruments filter is not supported, raise KeyError\n        \"\"\"\n\n\nclass DLWParser(DataLoader):\n    \"\"\"\n    (D)ata(L)oader (W)ith (P)arser for features and names\n\n    This class is extracted so that QlibDataLoader and other data loaders (such as QdbDataLoader) can share the field-parsing logic.\n    \"\"\"\n\n    def __init__(self, config: Union[list, tuple, dict]):\n        \"\"\"\n        Parameters\n        ----------\n        config : Union[list, tuple, dict]\n            Config will be used to describe the fields and column names\n\n            .. code-block::\n\n                <config> := {\n                    \"group_name1\": <fields_info1>\n                    \"group_name2\": <fields_info2>\n                }\n                or\n                <config> := <fields_info>\n\n                <fields_info> := [\"expr\", ...] | ([\"expr\", ...], [\"col_name\", ...])\n                # NOTE: list and tuple are treated the same when parsing\n        \"\"\"\n        self.is_group = isinstance(config, dict)\n\n        if self.is_group:\n            self.fields = {grp: self._parse_fields_info(fields_info) for grp, fields_info in config.items()}\n        else:\n            self.fields = self._parse_fields_info(config)\n\n    def _parse_fields_info(self, fields_info: Union[list, tuple]) -> Tuple[list, list]:\n        if not isinstance(fields_info, (list, tuple)):\n            raise TypeError(f\"Unsupported type: {type(fields_info)}\")\n\n        if len(fields_info) == 0:\n            raise ValueError(\"The size of fields must be greater than 0\")\n\n        if isinstance(fields_info[0], str):\n            exprs = names = fields_info\n        elif isinstance(fields_info[0], (list, tuple)):\n            exprs, names = fields_info\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n        return exprs, names\n\n    @abc.abstractmethod\n
    def load_group_df(\n        self,\n        instruments,\n        exprs: list,\n        names: list,\n        start_time: Union[str, pd.Timestamp] = None,\n        end_time: Union[str, pd.Timestamp] = None,\n        gp_name: str = None,\n    ) -> pd.DataFrame:\n        \"\"\"\n        load the dataframe for a specific group\n\n        Parameters\n        ----------\n        instruments :\n            the instruments.\n        exprs : list\n            the expressions to describe the content of the data.\n        names : list\n            the column names of the data.\n\n        Returns\n        -------\n        pd.DataFrame:\n            the queried dataframe.\n        \"\"\"\n\n    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:\n        if self.is_group:\n            df = pd.concat(\n                {\n                    grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)\n                    for grp, (exprs, names) in self.fields.items()\n                },\n                axis=1,\n            )\n        else:\n            exprs, names = self.fields\n            df = self.load_group_df(instruments, exprs, names, start_time, end_time)\n        return df\n\n\nclass QlibDataLoader(DLWParser):\n    \"\"\"Load data from Qlib's data source with `D.features`. 
The fields can be defined by config\"\"\"\n\n    def __init__(\n        self,\n        config: Union[list, tuple, dict],\n        filter_pipe: List = None,\n        swap_level: bool = True,\n        freq: Union[str, dict] = \"day\",\n        inst_processors: Union[dict, list] = None,\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        config : Union[list, tuple, dict]\n            Please refer to the doc of DLWParser\n        filter_pipe :\n            Filter pipe for the instruments\n        swap_level :\n            Whether to swap the levels of the MultiIndex (to <datetime, instrument>)\n        freq : dict or str\n            If type(config) == dict and type(freq) == str, load config data using freq.\n            If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]\n        inst_processors: dict | list\n            If inst_processors is not None and type(config) == dict; load config[<group_name>] data using inst_processors[<group_name>]\n            If inst_processors is a list, then it will be applied to all groups.\n        \"\"\"\n        self.filter_pipe = filter_pipe\n        self.swap_level = swap_level\n        self.freq = freq\n\n        # sample\n        self.inst_processors = inst_processors if inst_processors is not None else {}\n        assert isinstance(\n            self.inst_processors, (dict, list)\n        ), f\"inst_processors(={self.inst_processors}) must be dict or list\"\n\n        super().__init__(config)\n\n        if self.is_group:\n            # check sample config\n            if isinstance(freq, dict):\n                for _gp in config.keys():\n                    if _gp not in freq:\n                        raise ValueError(f\"freq(={freq}) missing group(={_gp})\")\n                assert (\n                    self.inst_processors\n                ), f\"freq(={self.freq}), inst_processors(={self.inst_processors}) cannot be None/empty\"\n\n    def load_group_df(\n        self,\n
        instruments,\n        exprs: list,\n        names: list,\n        start_time: Union[str, pd.Timestamp] = None,\n        end_time: Union[str, pd.Timestamp] = None,\n        gp_name: str = None,\n    ) -> pd.DataFrame:\n        if instruments is None:\n            warnings.warn(\"`instruments` is not set, all stocks will be loaded\")\n            instruments = \"all\"\n        if isinstance(instruments, str):\n            instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)\n        elif self.filter_pipe is not None:\n            warnings.warn(\"`filter_pipe` is not None, but it is only applied when `instruments` is a market name (str)\")\n\n        freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq\n        inst_processors = (\n            self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, [])\n        )\n        df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors)\n        df.columns = names\n        if self.swap_level:\n            df = df.swaplevel().sort_index()  # NOTE: if swaplevel, return <datetime, instrument>\n        return df\n\n\nclass StaticDataLoader(DataLoader, Serializable):\n    \"\"\"\n    DataLoader that supports loading data from file or as provided.\n    \"\"\"\n\n    include_attr = [\"_config\"]\n\n    def __init__(self, config: Union[dict, str, pd.DataFrame], join=\"outer\"):\n        \"\"\"\n        Parameters\n        ----------\n        config : Union[dict, str, pd.DataFrame]\n            {fields_group: <path or object>}, a path to a pickle or parquet file, or a DataFrame\n        join : str\n            How to align different dataframes (passed to `pd.concat`)\n        \"\"\"\n        self._config = config  # using \"_\" to avoid conflict with the method `config` of Serializable\n        self.join = join\n        self._data = None\n\n    def __getstate__(self) -> dict:\n        # avoid pickling `self._data`\n        return {k: v for k, v in self.__dict__.items() if not k.startswith(\"_\")}\n\n    def load(self, 
instruments=None, start_time=None, end_time=None) -> pd.DataFrame:\n        self._maybe_load_raw_data()\n\n        # 1) Filter by instruments\n        if instruments is None:\n            df = self._data\n        else:\n            df = self._data.loc(axis=0)[:, instruments]\n\n        # 2) Filter by datetime\n        if start_time is None and end_time is None:\n            return df  # NOTE: avoid copy by loc\n        # pd.Timestamp(None) is NaT, and indexing with NaT does not fetch the expected data, so None is kept as-is.\n        start_time = time_to_slc_point(start_time)\n        end_time = time_to_slc_point(end_time)\n        return df.loc[start_time:end_time]\n\n    def _maybe_load_raw_data(self):\n        if self._data is not None:\n            return\n        if isinstance(self._config, dict):\n            self._data = pd.concat(\n                {fields_group: load_dataset(path_or_obj) for fields_group, path_or_obj in self._config.items()},\n                axis=1,\n                join=self.join,\n            )\n            self._data.sort_index(inplace=True)\n        elif isinstance(self._config, (str, Path)):\n            if str(self._config).strip().endswith(\".parquet\"):\n                self._data = pd.read_parquet(self._config, engine=\"pyarrow\")\n            else:\n                with Path(self._config).open(\"rb\") as f:\n                    self._data = restricted_pickle_load(f)\n        elif isinstance(self._config, pd.DataFrame):\n            self._data = self._config\n\n\nclass NestedDataLoader(DataLoader):\n    \"\"\"\n    Combine multiple DataLoaders into one.\n    \"\"\"\n\n    def __init__(self, dataloader_l: List[Dict], join=\"left\") -> None:\n        \"\"\"\n\n        Parameters\n        ----------\n        dataloader_l : list[dict]\n            A list of data loaders (DataLoader instances or their configs), for example\n\n            .. 
code-block:: python\n\n                nd = NestedDataLoader(\n                    dataloader_l=[\n                        {\n                            \"class\": \"qlib.contrib.data.loader.Alpha158DL\",\n                        }, {\n                            \"class\": \"qlib.contrib.data.loader.Alpha360DL\",\n                            \"kwargs\": {\n                                \"config\": {\n                                    \"label\": ( [\"Ref($close, -2)/Ref($close, -1) - 1\"], [\"LABEL0\"])\n                                }\n                            }\n                        }\n                    ]\n                )\n        join : str\n            passed to `pd.merge` as `how` when merging the loaded dataframes.\n        \"\"\"\n        super().__init__()\n        self.data_loader_l = [\n            (dl if isinstance(dl, DataLoader) else init_instance_by_config(dl)) for dl in dataloader_l\n        ]\n        self.join = join\n\n    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:\n        df_full = None\n        for dl in self.data_loader_l:\n            try:\n                df_current = dl.load(instruments, start_time, end_time)\n            except KeyError:\n                warnings.warn(\n                    \"The value of `instruments` cannot be processed by this loader; falling back to `instruments=None` to load all the data.\"\n                )\n                df_current = dl.load(instruments=None, start_time=start_time, end_time=end_time)\n            if df_full is None:\n                df_full = df_current\n            else:\n                # columns from later loaders override duplicated columns from earlier ones\n                current_columns = df_current.columns.tolist()\n                full_columns = df_full.columns.tolist()\n                columns_to_drop = [col for col in current_columns if col in full_columns]\n                # avoid `inplace=True`: `df_full` may be the cached data of a loader (e.g. StaticDataLoader)\n                df_full = df_full.drop(columns=columns_to_drop)\n                df_full = pd.merge(df_full, df_current, left_index=True, right_index=True, how=self.join)\n
        return df_full.sort_index(axis=1)\n\n\nclass DataLoaderDH(DataLoader):\n    \"\"\"DataLoaderDH\n    DataLoader based on (D)ata (H)andler\n    It is designed to load data from multiple data handlers\n    - If you just want to load data from a single data handler, you can write it in a single data handler\n\n    TODO: What makes this module not that easy to use.\n\n    - For online scenarios\n\n        - The underlying data handler should be configured, but the data loader doesn't provide such an interface or hook.\n    \"\"\"\n\n    def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):\n        \"\"\"\n        Parameters\n        ----------\n        handler_config : dict\n            handler_config will be used to describe the handlers\n\n            .. code-block::\n\n                <handler_config> := {\n                    \"group_name1\": <handler>\n                    \"group_name2\": <handler>\n                }\n                or\n                <handler_config> := <handler>\n                <handler> := DataHandler Instance | DataHandler Config\n\n        fetch_kwargs : dict\n            fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc.\n\n        is_group: bool\n            is_group will be used to describe whether the keys of handler_config are group names\n\n        \"\"\"\n        from qlib.data.dataset.handler import DataHandler  # pylint: disable=C0415\n\n        if is_group:\n            self.handlers = {\n                grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()\n            }\n        else:\n            self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler)\n\n        self.is_group = is_group\n        self.fetch_kwargs = {\"col_set\": DataHandler.CS_RAW}\n        self.fetch_kwargs.update(fetch_kwargs)\n\n    def load(self, instruments=None, start_time=None, end_time=None) -> 
pd.DataFrame:\n        if instruments is not None:\n            get_module_logger(self.__class__.__name__).warning(f\"instruments[{instruments}] is ignored\")\n\n        if self.is_group:\n            df = pd.concat(\n                {\n                    grp: dh.fetch(selector=slice(start_time, end_time), level=\"datetime\", **self.fetch_kwargs)\n                    for grp, dh in self.handlers.items()\n                },\n                axis=1,\n            )\n        else:\n            df = self.handlers.fetch(selector=slice(start_time, end_time), level=\"datetime\", **self.fetch_kwargs)\n        return df\n"
  },
  {
    "path": "qlib/data/dataset/processor.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport abc\nfrom typing import Union, Text, Optional\nimport numpy as np\nimport pandas as pd\n\nfrom qlib.utils.data import robust_zscore, zscore\nfrom ...constant import EPS\nfrom .utils import fetch_df_by_index\nfrom ...utils.serial import Serializable\nfrom ...utils.paral import datetime_groupby_apply\nfrom qlib.data.inst_processor import InstProcessor\nfrom qlib.data import D\n\n\ndef get_group_columns(df: pd.DataFrame, group: Union[Text, None]):\n    \"\"\"\n    get a group of columns from multi-index columns DataFrame\n\n    Parameters\n    ----------\n    df : pd.DataFrame\n        with multi of columns.\n    group : str\n        the name of the feature group, i.e. the first level value of the group index.\n    \"\"\"\n    if group is None:\n        return df.columns\n    else:\n        return df.columns[df.columns.get_loc(group)]\n\n\nclass Processor(Serializable):\n    def fit(self, df: pd.DataFrame = None):\n        \"\"\"\n        learn data processing parameters\n\n        Parameters\n        ----------\n        df : pd.DataFrame\n            When we fit and process data with processor one by one. The fit function reiles on the output of previous\n            processor, i.e. `df`.\n\n        \"\"\"\n\n    @abc.abstractmethod\n    def __call__(self, df: pd.DataFrame):\n        \"\"\"\n        process the data\n\n        NOTE: **The processor could change the content of `df` inplace !!!!! 
**\n        Users should keep a copy of the data outside\n\n        Parameters\n        ----------\n        df : pd.DataFrame\n            The raw_df of the handler or the result of the previous processor.\n        \"\"\"\n\n    def is_for_infer(self) -> bool:\n        \"\"\"\n        Is this processor usable for inference?\n        Some processors are not usable for inference.\n\n        Returns\n        -------\n        bool:\n            whether it is usable for inference.\n        \"\"\"\n        return True\n\n    def readonly(self) -> bool:\n        \"\"\"\n        Does the processor treat the input data as readonly (i.e. it does not write to the input data) when processing?\n\n        Knowing the readonly information helps the Handler avoid unnecessary copies\n        \"\"\"\n        return False\n\n    def config(self, **kwargs):\n        attr_list = {\"fit_start_time\", \"fit_end_time\"}\n        for k, v in kwargs.items():\n            if k in attr_list and hasattr(self, k):\n                setattr(self, k, v)\n\n        for attr in attr_list:\n            if attr in kwargs:\n                kwargs.pop(attr)\n        super().config(**kwargs)\n\n\nclass DropnaProcessor(Processor):\n    def __init__(self, fields_group=None):\n        self.fields_group = fields_group\n\n    def __call__(self, df):\n        return df.dropna(subset=get_group_columns(df, self.fields_group))\n\n    def readonly(self):\n        return True\n\n\nclass DropnaLabel(DropnaProcessor):\n    def __init__(self, fields_group=\"label\"):\n        super().__init__(fields_group=fields_group)\n\n    def is_for_infer(self) -> bool:\n        \"\"\"The samples are dropped according to label. 
So it is not usable for inference\"\"\"\n        return False\n\n\nclass DropCol(Processor):\n    def __init__(self, col_list=[]):\n        self.col_list = col_list\n\n    def __call__(self, df):\n        if isinstance(df.columns, pd.MultiIndex):\n            mask = df.columns.get_level_values(-1).isin(self.col_list)\n        else:\n            mask = df.columns.isin(self.col_list)\n        return df.loc[:, ~mask]\n\n    def readonly(self):\n        return True\n\n\nclass FilterCol(Processor):\n    def __init__(self, fields_group=\"feature\", col_list=[]):\n        self.fields_group = fields_group\n        self.col_list = col_list\n\n    def __call__(self, df):\n        cols = get_group_columns(df, self.fields_group)\n        all_cols = df.columns\n        diff_cols = np.setdiff1d(all_cols.get_level_values(-1), cols.get_level_values(-1))\n        self.col_list = np.union1d(diff_cols, self.col_list)\n        mask = df.columns.get_level_values(-1).isin(self.col_list)\n        return df.loc[:, mask]\n\n    def readonly(self):\n        return True\n\n\nclass TanhProcess(Processor):\n    \"\"\"Use tanh to process noise data\"\"\"\n\n    def __call__(self, df):\n        def tanh_denoise(data):\n            mask = data.columns.get_level_values(1).str.contains(\"LABEL\")\n            col = df.columns[~mask]\n            data[col] = data[col] - 1\n            data[col] = np.tanh(data[col])\n\n            return data\n\n        return tanh_denoise(df)\n\n\nclass ProcessInf(Processor):\n    \"\"\"Process infinity\"\"\"\n\n    def __call__(self, df):\n        def replace_inf(data):\n            def process_inf(df):\n                for col in df.columns:\n                    # FIXME: Such behavior is very weird\n                    df[col] = df[col].replace([np.inf, -np.inf], df[col][~np.isinf(df[col])].mean())\n                return df\n\n            data = datetime_groupby_apply(data, process_inf)\n            data.sort_index(inplace=True)\n            return data\n\n       
 return replace_inf(df)\n\n\nclass Fillna(Processor):\n    \"\"\"Process NaN\"\"\"\n\n    def __init__(self, fields_group=None, fill_value=0):\n        self.fields_group = fields_group\n        self.fill_value = fill_value\n\n    def __call__(self, df):\n        if self.fields_group is None:\n            df.fillna(self.fill_value, inplace=True)\n        else:\n            # this implementation is extremely slow\n            # df.fillna({col: self.fill_value for col in cols}, inplace=True)\n            df[self.fields_group] = df[self.fields_group].fillna(self.fill_value)\n        return df\n\n\nclass MinMaxNorm(Processor):\n    def __init__(self, fit_start_time, fit_end_time, fields_group=None):\n        # NOTE: correctly set the `fit_start_time` and `fit_end_time` is very important !!!\n        # `fit_end_time` **must not** include any information from the test data!!!\n        self.fit_start_time = fit_start_time\n        self.fit_end_time = fit_end_time\n        self.fields_group = fields_group\n\n    def fit(self, df: pd.DataFrame = None):\n        df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level=\"datetime\")\n        cols = get_group_columns(df, self.fields_group)\n        self.min_val = np.nanmin(df[cols].values, axis=0)\n        self.max_val = np.nanmax(df[cols].values, axis=0)\n        self.ignore = self.min_val == self.max_val\n        # To improve the speed, we set the value of `min_val` to `0` for the columns that do not need to be processed,\n        # and the value of `max_val` to `1`, when using `(x - min_val) / (max_val - min_val)` for uniform calculation,\n        # the columns that do not need to be processed will be calculated by `(x - 0) / (1 - 0)`,\n        # as you can see, the columns that do not need to be processed, will not be affected.\n        for _i, _con in enumerate(self.ignore):\n            if _con:\n                self.min_val[_i] = 0\n                self.max_val[_i] = 1\n        self.cols = cols\n\n 
   def __call__(self, df):\n        def normalize(x, min_val=self.min_val, max_val=self.max_val):\n            return (x - min_val) / (max_val - min_val)\n\n        df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)\n        return df\n\n\nclass ZScoreNorm(Processor):\n    \"\"\"ZScore Normalization\"\"\"\n\n    def __init__(self, fit_start_time, fit_end_time, fields_group=None):\n        # NOTE: correctly set the `fit_start_time` and `fit_end_time` is very important !!!\n        # `fit_end_time` **must not** include any information from the test data!!!\n        self.fit_start_time = fit_start_time\n        self.fit_end_time = fit_end_time\n        self.fields_group = fields_group\n\n    def fit(self, df: pd.DataFrame = None):\n        df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level=\"datetime\")\n        cols = get_group_columns(df, self.fields_group)\n        self.mean_train = np.nanmean(df[cols].values, axis=0)\n        self.std_train = np.nanstd(df[cols].values, axis=0)\n        self.ignore = self.std_train == 0\n        # To improve the speed, we set the value of `std_train` to `1` for the columns that do not need to be processed,\n        # and the value of `mean_train` to `0`, when using `(x - mean_train) / std_train` for uniform calculation,\n        # the columns that do not need to be processed will be calculated by `(x - 0) / 1`,\n        # as you can see, the columns that do not need to be processed, will not be affected.\n        for _i, _con in enumerate(self.ignore):\n            if _con:\n                self.std_train[_i] = 1\n                self.mean_train[_i] = 0\n        self.cols = cols\n\n    def __call__(self, df):\n        def normalize(x, mean_train=self.mean_train, std_train=self.std_train):\n            return (x - mean_train) / std_train\n\n        df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)\n        return df\n\n\nclass RobustZScoreNorm(Processor):\n    \"\"\"Robust ZScore 
Normalization\n\n    Use robust statistics for Z-Score normalization:\n        mean(x) = median(x)\n        std(x) = MAD(x) * 1.4826\n\n    Reference:\n        https://en.wikipedia.org/wiki/Median_absolute_deviation.\n    \"\"\"\n\n    def __init__(self, fit_start_time, fit_end_time, fields_group=None, clip_outlier=True):\n        # NOTE: correctly set the `fit_start_time` and `fit_end_time` is very important !!!\n        # `fit_end_time` **must not** include any information from the test data!!!\n        self.fit_start_time = fit_start_time\n        self.fit_end_time = fit_end_time\n        self.fields_group = fields_group\n        self.clip_outlier = clip_outlier\n\n    def fit(self, df: pd.DataFrame = None):\n        df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level=\"datetime\")\n        self.cols = get_group_columns(df, self.fields_group)\n        X = df[self.cols].values\n        self.mean_train = np.nanmedian(X, axis=0)\n        self.std_train = np.nanmedian(np.abs(X - self.mean_train), axis=0)\n        self.std_train += EPS\n        self.std_train *= 1.4826\n\n    def __call__(self, df):\n        X = df[self.cols]\n        X -= self.mean_train\n        X /= self.std_train\n        if self.clip_outlier:\n            X = np.clip(X, -3, 3)\n        df[self.cols] = X\n        return df\n\n\nclass CSZScoreNorm(Processor):\n    \"\"\"Cross Sectional ZScore Normalization\"\"\"\n\n    def __init__(self, fields_group=None, method=\"zscore\"):\n        self.fields_group = fields_group\n        if method == \"zscore\":\n            self.zscore_func = zscore\n        elif method == \"robust\":\n            self.zscore_func = robust_zscore\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n\n    def __call__(self, df):\n        # try not modify original dataframe\n        if not isinstance(self.fields_group, list):\n            self.fields_group = [self.fields_group]\n        # depress warning 
by references:\n        # https://stackoverflow.com/questions/20625582/how-to-deal-with-settingwithcopywarning-in-pandas\n        # https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html#getting-and-setting-options\n        with pd.option_context(\"mode.chained_assignment\", None):\n            for g in self.fields_group:\n                cols = get_group_columns(df, g)\n                df[cols] = df[cols].groupby(\"datetime\", group_keys=False).apply(self.zscore_func)\n        return df\n\n\nclass CSRankNorm(Processor):\n    \"\"\"\n    Cross Sectional Rank Normalization.\n    \"Cross Sectional\" is often used to describe data operations.\n    Operations across different stocks are often called Cross Sectional Operations.\n\n    For example, CSRankNorm is an operation that groups the data by day and ranks `across` all the stocks on each day.\n\n    Explanation of 3.46 & 0.5\n\n    .. code-block:: python\n\n        import numpy as np\n        import pandas as pd\n        x = np.random.random(10000)  # for any variable\n        x_rank = pd.Series(x).rank(pct=True)  # once converted to ranks, it will be uniformly distributed\n        x_rank_norm = (x_rank - x_rank.mean()) / x_rank.std()  # Normally, we give it zero mean and unit std, like a standard normal distribution\n\n        x_rank.mean()   # accounts for 0.5\n        1 / x_rank.std()  # accounts for 3.46\n\n    \"\"\"\n\n    def __init__(self, fields_group=None):\n        self.fields_group = fields_group\n\n    def __call__(self, df):\n        # try not to modify the original dataframe\n        cols = get_group_columns(df, self.fields_group)\n        t = df[cols].groupby(\"datetime\", group_keys=False).rank(pct=True)\n        t -= 0.5\n        t *= 3.46  # NOTE: towards unit std\n        df[cols] = t\n        return df\n\n\nclass CSZFillna(Processor):\n    \"\"\"Cross Sectional Fill Nan\"\"\"\n\n    def __init__(self, fields_group=None):\n        self.fields_group = fields_group\n\n    def 
__call__(self, df):\n        cols = get_group_columns(df, self.fields_group)\n        df[cols] = df[cols].groupby(\"datetime\", group_keys=False).apply(lambda x: x.fillna(x.mean()))\n        return df\n\n\nclass HashStockFormat(Processor):\n    \"\"\"Convert the storage of df into the hashing stock format\"\"\"\n\n    def __call__(self, df: pd.DataFrame):\n        from .storage import HashingStockStorage  # pylint: disable=C0415\n\n        return HashingStockStorage.from_df(df)\n\n\nclass TimeRangeFlt(InstProcessor):\n    \"\"\"\n    This is a filter on stocks.\n    Only keep the data of stocks that exist from start_time to end_time (existence in between is not checked.)\n    WARNING: It may induce leakage!!!\n    \"\"\"\n\n    def __init__(\n        self,\n        start_time: Optional[Union[pd.Timestamp, str]] = None,\n        end_time: Optional[Union[pd.Timestamp, str]] = None,\n        freq: str = \"day\",\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        start_time : Optional[Union[pd.Timestamp, str]]\n            The data must start no later than `start_time`\n            None indicates data will not be filtered based on `start_time`\n        end_time : Optional[Union[pd.Timestamp, str]]\n            The data must end no earlier than `end_time`\n            None indicates data will not be filtered based on `end_time`\n        freq : str\n            The frequency of the calendar\n        \"\"\"\n        # Align to calendar before filtering\n        cal = D.calendar(start_time=start_time, end_time=end_time, freq=freq)\n        self.start_time = None if start_time is None else cal[0]\n        self.end_time = None if end_time is None else cal[-1]\n\n    def __call__(self, df: pd.DataFrame, instrument, *args, **kwargs):\n        if (\n            df.empty\n            or (self.start_time is None or df.index.min() <= self.start_time)\n            and (self.end_time is None or df.index.max() >= self.end_time)\n        ):\n            return df\n        return df.head(0)\n"
  },
  {
    "path": "qlib/data/dataset/storage.py",
    "content": "from abc import abstractmethod\nimport pandas as pd\nimport numpy as np\n\nfrom .handler import DataHandler\nfrom typing import Union, List\nfrom qlib.log import get_module_logger\n\nfrom .utils import get_level_index, fetch_df_by_index, fetch_df_by_col\n\n\nclass BaseHandlerStorage:\n    \"\"\"\n    Base data storage for datahandler\n    - pd.DataFrame is the default data storage format in Qlib datahandler\n    - If users want to use custom data storage, they should define a subclass of BaseHandlerStorage and implement the following method\n    \"\"\"\n\n    @abstractmethod\n    def fetch(\n        self,\n        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),\n        level: Union[str, int] = \"datetime\",\n        col_set: Union[str, List[str]] = DataHandler.CS_ALL,\n        fetch_orig: bool = True,\n    ) -> pd.DataFrame:\n        \"\"\"fetch data from the data storage\n\n        Parameters\n        ----------\n        selector : Union[pd.Timestamp, slice, str]\n            describe how to select data by index\n        level : Union[str, int]\n            which index level to apply the selector to\n            - if level is None, apply selector to df directly\n        col_set : Union[str, List[str]]\n            - if isinstance(col_set, str):\n                select a set of meaningful columns.(e.g. 
features, columns)\n                if col_set == DataHandler.CS_RAW:\n                    the raw dataset will be returned.\n            - if isinstance(col_set, List[str]):\n                select several sets of meaningful columns, the returned data has multiple levels\n        fetch_orig : bool\n            Return the original data instead of a copy if possible.\n\n        Returns\n        -------\n        pd.DataFrame\n            the dataframe fetched\n        \"\"\"\n        raise NotImplementedError(\"fetch method is not implemented!\")\n\n\nclass NaiveDFStorage(BaseHandlerStorage):\n    \"\"\"Naive data storage for datahandler\n    - NaiveDFStorage is a naive data storage for datahandler\n    - NaiveDFStorage takes a pandas.DataFrame as input and provides an interface for fetching data\n    \"\"\"\n\n    def __init__(self, df: pd.DataFrame):\n        self.df = df\n\n    def fetch(\n        self,\n        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),\n        level: Union[str, int] = \"datetime\",\n        col_set: Union[str, List[str]] = DataHandler.CS_ALL,\n        fetch_orig: bool = True,\n    ) -> pd.DataFrame:\n        # The following conflict may occur\n        # - Does [\"20200101\", \"20210101\"] mean selecting this slice or these two days?\n        # To solve this issue\n        #   - slices have higher priority (except when level is None)\n        if isinstance(selector, (tuple, list)) and level is not None:\n            # when level is None, the argument will be passed in directly\n            # so we don't have to convert it into a slice\n            try:\n                selector = slice(*selector)\n            except ValueError:\n                get_module_logger(\"DataHandlerLP\").info(f\"Failed to convert the query to a slice. 
It will be used directly\")\n\n        data_df = self.df\n        data_df = fetch_df_by_col(data_df, col_set)\n        data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=fetch_orig)\n        return data_df\n\n\nclass HashingStockStorage(BaseHandlerStorage):\n    \"\"\"Hashing data storage for datahandler\n    - The default data storage pandas.DataFrame is too slow when randomly accessing one stock's data\n    - HashingStockStorage hashes the multiple stocks' data (pandas.DataFrame) by the key `stock_id`.\n    - HashingStockStorage hashes the pandas.DataFrame into a dict whose key is the stock_id (str) and whose value is that stock's data (pandas.DataFrame); it has the following format:\n        {\n            stock1_id: stock1_data,\n            stock2_id: stock2_data,\n            ...\n            stockn_id: stockn_data,\n        }\n    - With the `fetch` method, users can access any stock's data at a much lower time cost than with the default data storage\n    \"\"\"\n\n    def __init__(self, df):\n        self.hash_df = dict()\n        self.stock_level = get_level_index(df, \"instrument\")\n        for k, v in df.groupby(level=\"instrument\", group_keys=False):\n            self.hash_df[k] = v\n        self.columns = df.columns\n\n    @staticmethod\n    def from_df(df):\n        return HashingStockStorage(df)\n\n    def _fetch_hash_df_by_stock(self, selector, level):\n        \"\"\"fetch the data with stock selector\n\n        Parameters\n        ----------\n        selector : Union[pd.Timestamp, slice, str]\n            describe how to select data by index\n        level : Union[str, int]\n            which index level to apply the selector to\n            - if level is None, apply selector to df directly\n            - the `_fetch_hash_df_by_stock` will parse the stock selector in arg `selector`\n\n        Returns\n        -------\n        Tuple[Dict, Any]\n            The dict whose key is stock_id and whose value is the stock's data, together with the time selector\n        \"\"\"\n\n        stock_selector = slice(None)\n        
time_selector = slice(None)  # by default, do not filter by time.\n\n        if level is None:\n            # For directly applying.\n            if isinstance(selector, tuple) and self.stock_level < len(selector):\n                # full selector format\n                stock_selector = selector[self.stock_level]\n                time_selector = selector[1 - self.stock_level]\n            elif isinstance(selector, (list, str)) and self.stock_level == 0:\n                # only stock selector\n                stock_selector = selector\n        elif level in (\"instrument\", self.stock_level):\n            if isinstance(selector, tuple):\n                # NOTE: How could the stock level selector be a tuple?\n                stock_selector = selector[0]\n                raise TypeError(\n                    \"Selecting the stock level with a tuple does not make sense, so this case raises an error.\"\n                )\n            elif isinstance(selector, (list, str)):\n                stock_selector = selector\n\n        if not isinstance(stock_selector, (list, str)) and stock_selector != slice(None):\n            raise TypeError(f\"stock selector must be type str|list, or slice(None), rather than {stock_selector}\")\n\n        if stock_selector == slice(None):\n            return self.hash_df, time_selector\n\n        if isinstance(stock_selector, str):\n            stock_selector = [stock_selector]\n\n        select_dict = dict()\n        for each_stock in sorted(stock_selector):\n            if each_stock in self.hash_df:\n                select_dict[each_stock] = self.hash_df[each_stock]\n        return select_dict, time_selector\n\n    def fetch(\n        self,\n        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),\n        level: Union[str, int] = \"datetime\",\n        col_set: Union[str, List[str]] = DataHandler.CS_ALL,\n        fetch_orig: bool = True,\n    ) -> pd.DataFrame:\n        fetch_stock_df_list, 
time_selector = self._fetch_hash_df_by_stock(selector=selector, level=level)\n        fetch_stock_df_list = list(fetch_stock_df_list.values())\n        for _index, stock_df in enumerate(fetch_stock_df_list):\n            fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)\n            fetch_index_df = fetch_df_by_index(\n                df=fetch_col_df, selector=time_selector, level=\"datetime\", fetch_orig=fetch_orig\n            )\n            fetch_stock_df_list[_index] = fetch_index_df\n        if len(fetch_stock_df_list) == 0:\n            index_names = (\"instrument\", \"datetime\") if self.stock_level == 0 else (\"datetime\", \"instrument\")\n            return pd.DataFrame(\n                index=pd.MultiIndex.from_arrays([[], []], names=index_names), columns=self.columns, dtype=np.float32\n            )\n        elif len(fetch_stock_df_list) == 1:\n            return fetch_stock_df_list[0]\n        else:\n            # NOTE: `~True` is -2 (truthy), so a boolean `not` is required here\n            return pd.concat(fetch_stock_df_list, sort=False, copy=not fetch_orig)\n"
  },
  {
    "path": "qlib/data/dataset/utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom __future__ import annotations\nimport pandas as pd\nfrom typing import Union, List, TYPE_CHECKING\nfrom qlib.utils import init_instance_by_config\n\nif TYPE_CHECKING:\n    from qlib.data.dataset import DataHandler\n\n\ndef get_level_index(df: pd.DataFrame, level: Union[str, int]) -> int:\n    \"\"\"\n\n    get the level index of `df` given `level`\n\n    Parameters\n    ----------\n    df : pd.DataFrame\n        data\n    level : Union[str, int]\n        index level\n\n    Returns\n    -------\n    int:\n        The level index in the MultiIndex\n    \"\"\"\n    if isinstance(level, str):\n        try:\n            return df.index.names.index(level)\n        except (AttributeError, ValueError):\n            # NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument')\n            return (\"datetime\", \"instrument\").index(level)\n    elif isinstance(level, int):\n        return level\n    else:\n        raise NotImplementedError(f\"This type of input is not supported\")\n\n\ndef fetch_df_by_index(\n    df: pd.DataFrame,\n    selector: Union[pd.Timestamp, slice, str, list, pd.Index],\n    level: Union[str, int],\n    fetch_orig=True,\n) -> pd.DataFrame:\n    \"\"\"\n    fetch data from `df` with `selector` and `level`\n\n    The selector is assumed to be well processed.\n    `fetch_df_by_index` is only responsible for getting the right level\n\n    Parameters\n    ----------\n    selector : Union[pd.Timestamp, slice, str, list]\n        selector\n    level : Union[int, str]\n        the level to apply the selector to\n\n    Returns\n    -------\n    Data of the given index.\n    \"\"\"\n    # level = None -> use selector directly\n    if level is None or isinstance(selector, pd.MultiIndex):\n        return df.loc(axis=0)[selector]\n    # Try to get the right index\n    idx_slc = (selector, slice(None, None))\n    if 
get_level_index(df, level) == 1:\n        idx_slc = idx_slc[1], idx_slc[0]\n    if fetch_orig:\n        for slc in idx_slc:\n            if slc != slice(None, None):\n                return df.loc[pd.IndexSlice[idx_slc],]  # noqa: E231\n        else:  # pylint: disable=W0120\n            return df\n    else:\n        return df.loc[pd.IndexSlice[idx_slc],]  # noqa: E231\n\n\ndef fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame:\n    from .handler import DataHandler  # pylint: disable=C0415\n\n    if not isinstance(df.columns, pd.MultiIndex) or col_set == DataHandler.CS_RAW:\n        return df\n    elif col_set == DataHandler.CS_ALL:\n        return df.droplevel(axis=1, level=0)\n    else:\n        return df.loc(axis=1)[col_set]\n\n\ndef convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = \"datetime\") -> Union[pd.DataFrame, pd.Series]:\n    \"\"\"\n    Convert the format of df.MultiIndex according to the following rules:\n        - If `level` is the first level of df.MultiIndex, do nothing\n        - If `level` is the second level of df.MultiIndex, swap the levels of the index.\n\n    NOTE:\n        the number of levels of df.MultiIndex should be 2\n\n    Parameters\n    ----------\n    df : Union[pd.DataFrame, pd.Series]\n        raw DataFrame/Series\n    level : str, optional\n        the level that will be converted to the first one, by default \"datetime\"\n\n    Returns\n    -------\n    Union[pd.DataFrame, pd.Series]\n        converted DataFrame/Series\n    \"\"\"\n\n    if get_level_index(df, level=level) == 1:\n        df = df.swaplevel().sort_index()\n    return df\n\n\ndef init_task_handler(task: dict) -> DataHandler:\n    \"\"\"\n    initialize the handler part of the task **inplace**\n\n    Parameters\n    ----------\n    task : dict\n        the task to be handled\n\n    Returns\n    -------\n    DataHandler:\n        the initialized handler (it also replaces the handler config in `task`)\n    \"\"\"\n    # avoid recursive import\n    from .handler import 
DataHandler  # pylint: disable=C0415\n\n    h_conf = task[\"dataset\"][\"kwargs\"].get(\"handler\")\n    if h_conf is not None:\n        handler = init_instance_by_config(h_conf, accept_types=DataHandler)\n        task[\"dataset\"][\"kwargs\"][\"handler\"] = handler\n        return handler\n    else:\n        raise ValueError(\"The task does not contain a handler part.\")\n"
  },
  {
    "path": "qlib/data/dataset/weight.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nclass Reweighter:\n    def __init__(self, *args, **kwargs):\n        \"\"\"\n        To initialize the Reweighter, users should provide specific methods to let reweighter do the reweighting (such as sample-wise, rule-based).\n        \"\"\"\n        raise NotImplementedError()\n\n    def reweight(self, data: object) -> object:\n        \"\"\"\n        Get weights for data\n\n        Parameters\n        ----------\n        data : object\n            The input data.\n            The first dimension is the index of samples\n\n        Returns\n        -------\n        object:\n            the weights info for the data\n        \"\"\"\n        raise NotImplementedError(f\"This type of input is not supported\")\n"
  },
  {
    "path": "qlib/data/filter.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import print_function\nfrom abc import abstractmethod\n\nimport re\nimport pandas as pd\nimport numpy as np\nimport abc\n\nfrom .data import Cal, DatasetD\n\n\nclass BaseDFilter(abc.ABC):\n    \"\"\"Dynamic Instruments Filter Abstract class\n\n    Users can override this class to construct their own filter\n\n    Override __init__ to input the filter rules\n\n    Override filter_main to apply the rules to filter instruments\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    @staticmethod\n    def from_config(config):\n        \"\"\"Construct an instance from config dict.\n\n        Parameters\n        ----------\n        config : dict\n            dict of config parameters.\n        \"\"\"\n        raise NotImplementedError(\"Subclass of BaseDFilter must reimplement `from_config` method\")\n\n    @abstractmethod\n    def to_config(self):\n        \"\"\"Convert the filter instance to a config dict.\n\n        Returns\n        ----------\n        dict\n            return the dict of config parameters.\n        \"\"\"\n        raise NotImplementedError(\"Subclass of BaseDFilter must reimplement `to_config` method\")\n\n\nclass SeriesDFilter(BaseDFilter):\n    \"\"\"Dynamic Instruments Filter Abstract class to filter a series of certain features\n\n    Filters should provide parameters:\n\n    - filter start time\n    - filter end time\n    - filter rule\n\n    Override __init__ to assign a certain rule to filter the series.\n\n    Override _getFilterSeries to use the rule to filter the series and get a dict of {inst => series}, or override filter_main for more advanced series filter rules\n    \"\"\"\n\n    def __init__(self, fstart_time=None, fend_time=None, keep=False):\n        \"\"\"Init function for filter base class.\n            Filter a set of instruments based on a certain rule within a certain period assigned by fstart_time and fend_time.\n\n  
      Parameters\n        ----------\n        fstart_time: str\n            the time for the filter rule to start filter the instruments.\n        fend_time: str\n            the time for the filter rule to stop filter the instruments.\n        keep: bool\n            whether to keep the instruments of which features don't exist in the filter time span.\n        \"\"\"\n        super(SeriesDFilter, self).__init__()\n        self.filter_start_time = pd.Timestamp(fstart_time) if fstart_time else None\n        self.filter_end_time = pd.Timestamp(fend_time) if fend_time else None\n        self.keep = keep\n\n    def _getTimeBound(self, instruments):\n        \"\"\"Get time bound for all instruments.\n\n        Parameters\n        ----------\n        instruments: dict\n            the dict of instruments in the form {instrument_name => list of timestamp tuple}.\n\n        Returns\n        ----------\n        pd.Timestamp, pd.Timestamp\n            the lower time bound and upper time bound of all the instruments.\n        \"\"\"\n        trange = Cal.calendar(freq=self.filter_freq)\n        ubound, lbound = trange[0], trange[-1]\n        for _, timestamp in instruments.items():\n            if timestamp:\n                lbound = timestamp[0][0] if timestamp[0][0] < lbound else lbound\n                ubound = timestamp[-1][-1] if timestamp[-1][-1] > ubound else ubound\n        return lbound, ubound\n\n    def _toSeries(self, time_range, target_timestamp):\n        \"\"\"Convert the target timestamp to a pandas series of bool value within a time range.\n            Make the time inside the target_timestamp range TRUE, others FALSE.\n\n        Parameters\n        ----------\n        time_range : D.calendar\n            the time range of the instruments.\n        target_timestamp : list\n            the list of tuple (timestamp, timestamp).\n\n        Returns\n        ----------\n        pd.Series\n            the series of bool value for an instrument.\n        \"\"\"\n   
     # Construct a whole dict of {date => bool}\n        timestamp_series = {timestamp: False for timestamp in time_range}\n        # Convert to pd.Series\n        timestamp_series = pd.Series(timestamp_series)\n        # Fill the date within target_timestamp with TRUE\n        for start, end in target_timestamp:\n            timestamp_series[Cal.calendar(start_time=start, end_time=end, freq=self.filter_freq)] = True\n        return timestamp_series\n\n    def _filterSeries(self, timestamp_series, filter_series):\n        \"\"\"Filter the timestamp series with filter series by using element-wise AND operation of the two series.\n\n        Parameters\n        ----------\n        timestamp_series : pd.Series\n            the series of bool value indicating existing time.\n        filter_series : pd.Series\n            the series of bool value indicating filter feature.\n\n        Returns\n        ----------\n        pd.Series\n            the series of bool value indicating whether the date satisfies the filter condition and exists in target timestamp.\n        \"\"\"\n        fstart, fend = list(filter_series.keys())[0], list(filter_series.keys())[-1]\n        filter_series = filter_series.astype(\"bool\")  # Make sure the filter_series is boolean\n        timestamp_series[fstart:fend] = timestamp_series[fstart:fend] & filter_series\n        return timestamp_series\n\n    def _toTimestamp(self, timestamp_series):\n        \"\"\"Convert the timestamp series to a list of tuple (timestamp, timestamp) indicating a continuous range of TRUE.\n\n        Parameters\n        ----------\n        timestamp_series: pd.Series\n            the series of bool value after being filtered.\n\n        Returns\n        ----------\n        list\n            the list of tuple (timestamp, timestamp).\n        \"\"\"\n        # sort the timestamp_series according to the timestamps (sort_index returns a new Series)\n        timestamp_series = timestamp_series.sort_index()\n        timestamp = []\n        _lbool = None\n        _ltime = None\n 
       _cur_start = None\n        for _ts, _bool in timestamp_series.items():\n            # NaN is likely to appear when the filter series has no bool value\n            # for a timestamp, so we treat NaN as False\n            if np.isnan(_bool):\n                _bool = False\n            if _lbool is None:\n                _cur_start = _ts\n                _lbool = _bool\n                _ltime = _ts\n                continue\n            if (_lbool, _bool) == (True, False):\n                if _cur_start:\n                    timestamp.append((_cur_start, _ltime))\n            elif (_lbool, _bool) == (False, True):\n                _cur_start = _ts\n            _lbool = _bool\n            _ltime = _ts\n        if _lbool:\n            timestamp.append((_cur_start, _ltime))\n        return timestamp\n\n    def __call__(self, instruments, start_time=None, end_time=None, freq=\"day\"):\n        \"\"\"Call this filter to get the filtered instruments.\"\"\"\n        self.filter_freq = freq\n        return self.filter_main(instruments, start_time, end_time)\n\n    @abstractmethod\n    def _getFilterSeries(self, instruments, fstart, fend):\n        \"\"\"Get filter series based on the rules assigned during the initialization and the input time range.\n\n        Parameters\n        ----------\n        instruments : dict\n            the dict of instruments to be filtered.\n        fstart : pd.Timestamp\n            start time of filter.\n        fend : pd.Timestamp\n            end time of filter.\n\n        .. 
note:: fstart/fend indicate the intersection of the instruments' start/end time and the filter start/end time.\n\n        Returns\n        ----------\n        pd.DataFrame\n            a series of {pd.Timestamp => bool}.\n        \"\"\"\n        raise NotImplementedError(\"Subclass of SeriesDFilter must reimplement `_getFilterSeries` method\")\n\n    def filter_main(self, instruments, start_time=None, end_time=None):\n        \"\"\"Implement this method to filter the instruments.\n\n        Parameters\n        ----------\n        instruments: dict\n            input instruments to be filtered.\n        start_time: str\n            start of the time range.\n        end_time: str\n            end of the time range.\n\n        Returns\n        ----------\n        dict\n            filtered instruments, same structure as input instruments.\n        \"\"\"\n        lbound, ubound = self._getTimeBound(instruments)\n        start_time = pd.Timestamp(start_time or lbound)\n        end_time = pd.Timestamp(end_time or ubound)\n        _instruments_filtered = {}\n        _all_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=self.filter_freq)\n        _filter_calendar = Cal.calendar(\n            start_time=self.filter_start_time and max(self.filter_start_time, _all_calendar[0]) or _all_calendar[0],\n            end_time=self.filter_end_time and min(self.filter_end_time, _all_calendar[-1]) or _all_calendar[-1],\n            freq=self.filter_freq,\n        )\n        _all_filter_series = self._getFilterSeries(instruments, _filter_calendar[0], _filter_calendar[-1])\n        for inst, timestamp in instruments.items():\n            # Construct a whole map of dates\n            _timestamp_series = self._toSeries(_all_calendar, timestamp)\n            # Get filter series\n            if inst in _all_filter_series:\n                _filter_series = _all_filter_series[inst]\n            else:\n                if self.keep:\n                    _filter_series = 
pd.Series({timestamp: True for timestamp in _filter_calendar})\n                else:\n                    _filter_series = pd.Series({timestamp: False for timestamp in _filter_calendar})\n            # Calculate bool value within the range of filter\n            _timestamp_series = self._filterSeries(_timestamp_series, _filter_series)\n            # Reform the map to (start_timestamp, end_timestamp) format\n            _timestamp = self._toTimestamp(_timestamp_series)\n            # Remove empty timestamp\n            if _timestamp:\n                _instruments_filtered[inst] = _timestamp\n        return _instruments_filtered\n\n\nclass NameDFilter(SeriesDFilter):\n    \"\"\"Name dynamic instrument filter\n\n    Filter the instruments based on a regulated name format.\n\n    A name rule regular expression is required.\n    \"\"\"\n\n    def __init__(self, name_rule_re, fstart_time=None, fend_time=None):\n        \"\"\"Init function for name filter class\n\n        Parameters\n        ----------\n        name_rule_re: str\n            regular expression for the name rule.\n        \"\"\"\n        super(NameDFilter, self).__init__(fstart_time, fend_time)\n        self.name_rule_re = name_rule_re\n\n    def _getFilterSeries(self, instruments, fstart, fend):\n        all_filter_series = {}\n        filter_calendar = Cal.calendar(start_time=fstart, end_time=fend, freq=self.filter_freq)\n        for inst, timestamp in instruments.items():\n            if re.match(self.name_rule_re, inst):\n                _filter_series = pd.Series({timestamp: True for timestamp in filter_calendar})\n            else:\n                _filter_series = pd.Series({timestamp: False for timestamp in filter_calendar})\n            all_filter_series[inst] = _filter_series\n        return all_filter_series\n\n    @staticmethod\n    def from_config(config):\n        return NameDFilter(\n            name_rule_re=config[\"name_rule_re\"],\n            fstart_time=config[\"filter_start_time\"],\n 
           fend_time=config[\"filter_end_time\"],\n        )\n\n    def to_config(self):\n        return {\n            \"filter_type\": \"NameDFilter\",\n            \"name_rule_re\": self.name_rule_re,\n            \"filter_start_time\": str(self.filter_start_time) if self.filter_start_time else self.filter_start_time,\n            \"filter_end_time\": str(self.filter_end_time) if self.filter_end_time else self.filter_end_time,\n        }\n\n\nclass ExpressionDFilter(SeriesDFilter):\n    \"\"\"Expression dynamic instrument filter\n\n    Filter the instruments based on a certain expression.\n\n    An expression rule indicating a certain feature field is required.\n\n    Examples\n    ----------\n    - *basic features filter* : rule_expression = '$close/$open>5'\n    - *cross-sectional features filter* : rule_expression = '$rank($close)<10'\n    - *time-sequence features filter* : rule_expression = '$Ref($close, 3)>100'\n    \"\"\"\n\n    def __init__(self, rule_expression, fstart_time=None, fend_time=None, keep=False):\n        \"\"\"Init function for expression filter class\n\n        Parameters\n        ----------\n        fstart_time: str\n            filter the feature starting from this time.\n        fend_time: str\n            filter the feature ending by this time.\n        rule_expression: str\n            an input expression for the rule.\n        \"\"\"\n        super(ExpressionDFilter, self).__init__(fstart_time, fend_time, keep=keep)\n        self.rule_expression = rule_expression\n\n    def _getFilterSeries(self, instruments, fstart, fend):\n        # do not use dataset cache\n        try:\n            _features = DatasetD.dataset(\n                instruments,\n                [self.rule_expression],\n                fstart,\n                fend,\n                freq=self.filter_freq,\n                disk_cache=0,\n            )\n        except TypeError:\n            # use LocalDatasetProvider\n            _features = 
DatasetD.dataset(instruments, [self.rule_expression], fstart, fend, freq=self.filter_freq)\n        rule_expression_field_name = list(_features.keys())[0]\n        all_filter_series = _features[rule_expression_field_name]\n        return all_filter_series\n\n    @staticmethod\n    def from_config(config):\n        return ExpressionDFilter(\n            rule_expression=config[\"rule_expression\"],\n            fstart_time=config[\"filter_start_time\"],\n            fend_time=config[\"filter_end_time\"],\n            keep=config[\"keep\"],\n        )\n\n    def to_config(self):\n        return {\n            \"filter_type\": \"ExpressionDFilter\",\n            \"rule_expression\": self.rule_expression,\n            \"filter_start_time\": str(self.filter_start_time) if self.filter_start_time else self.filter_start_time,\n            \"filter_end_time\": str(self.filter_end_time) if self.filter_end_time else self.filter_end_time,\n            \"keep\": self.keep,\n        }\n"
  },
  {
    "path": "qlib/data/inst_processor.py",
    "content": "import abc\nimport json\nimport pandas as pd\n\n\nclass InstProcessor:\n    @abc.abstractmethod\n    def __call__(self, df: pd.DataFrame, instrument, *args, **kwargs):\n        \"\"\"\n        process the data\n\n        NOTE: **The processor may change the content of `df` in place!**\n        Users should keep a copy of the data outside.\n\n        Parameters\n        ----------\n        df : pd.DataFrame\n            The raw_df of handler or result from previous processor.\n        \"\"\"\n\n    def __str__(self):\n        return f\"{self.__class__.__name__}:{json.dumps(self.__dict__, sort_keys=True, default=str)}\"\n"
  },
  {
    "path": "qlib/data/ops.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport pandas as pd\n\nfrom typing import Union, List, Type\nfrom scipy.stats import percentileofscore\nfrom .base import Expression, ExpressionOps, Feature, PFeature\nfrom ..log import get_module_logger\nfrom ..utils import get_callable_kwargs\n\ntry:\n    from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi\n    from ._libs.expanding import expanding_slope, expanding_rsquare, expanding_resi\nexcept ImportError:\n    print(\n        \"#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####\"\n    )\n    raise\nexcept ValueError:\n    print(\"!!!!!!!! An error occurred when importing operators implemented based on Cython.!!!!!!!!\")\n    print(\"!!!!!!!! They will be disabled. Please upgrade your numpy to enable them     !!!!!!!!\")\n    # We catch this error because some platforms can't upgrade their packages (e.g. 
Kaggle)\n    # https://www.kaggle.com/general/293387\n    # https://www.kaggle.com/product-feedback/98562\n\n\nnp.seterr(invalid=\"ignore\")\n\n\n#################### Element-Wise Operator ####################\nclass ElemOperator(ExpressionOps):\n    \"\"\"Element-wise Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    Expression\n        feature operation output\n    \"\"\"\n\n    def __init__(self, feature):\n        self.feature = feature\n\n    def __str__(self):\n        return \"{}({})\".format(type(self).__name__, self.feature)\n\n    def get_longest_back_rolling(self):\n        return self.feature.get_longest_back_rolling()\n\n    def get_extended_window_size(self):\n        return self.feature.get_extended_window_size()\n\n\nclass ChangeInstrument(ElemOperator):\n    \"\"\"Change Instrument Operator\n    In some cases, one may want to change to another instrument when calculating, for example, to\n    calculate the beta of a stock with respect to a market index.\n    This would require changing the calculation of features from the stock (original instrument) to\n    the index (reference instrument).\n    Parameters\n    ----------\n    instrument: new instrument on which the downstream operations should be performed,\n                e.g., SH000300 (CSI300 index) or ^GSPC (SP500 index).\n\n    feature: the feature to be calculated for the new instrument.\n    Returns\n    ----------\n    Expression\n        feature operation output\n    \"\"\"\n\n    def __init__(self, instrument, feature):\n        self.instrument = instrument\n        self.feature = feature\n\n    def __str__(self):\n        return \"{}('{}',{})\".format(type(self).__name__, self.instrument, self.feature)\n\n    def load(self, instrument, start_index, end_index, *args):\n        # the first `instrument` is ignored\n        return super().load(self.instrument, start_index, end_index, *args)\n\n    def 
_load_internal(self, instrument, start_index, end_index, *args):\n        return self.feature.load(instrument, start_index, end_index, *args)\n\n\nclass NpElemOperator(ElemOperator):\n    \"\"\"Numpy Element-wise Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    func : str\n        numpy feature operation method\n\n    Returns\n    ----------\n    Expression\n        feature operation output\n    \"\"\"\n\n    def __init__(self, feature, func):\n        self.func = func\n        super(NpElemOperator, self).__init__(feature)\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n        return getattr(np, self.func)(series)\n\n\nclass Abs(NpElemOperator):\n    \"\"\"Feature Absolute Value\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    Expression\n        a feature instance with absolute output\n    \"\"\"\n\n    def __init__(self, feature):\n        super(Abs, self).__init__(feature, \"abs\")\n\n\nclass Sign(NpElemOperator):\n    \"\"\"Feature Sign\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    Expression\n        a feature instance with sign\n    \"\"\"\n\n    def __init__(self, feature):\n        super(Sign, self).__init__(feature, \"sign\")\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        \"\"\"\n        To avoid error raised by bool type input, we transform the data into float32.\n        \"\"\"\n        series = self.feature.load(instrument, start_index, end_index, *args)\n        # TODO:  More precision types should be configurable\n        series = series.astype(np.float32)\n        return getattr(np, self.func)(series)\n\n\nclass Log(NpElemOperator):\n    \"\"\"Feature Log\n\n    Parameters\n    ----------\n    feature : 
Expression\n        feature instance\n\n    Returns\n    ----------\n    Expression\n        a feature instance with log\n    \"\"\"\n\n    def __init__(self, feature):\n        super(Log, self).__init__(feature, \"log\")\n\n\nclass Mask(NpElemOperator):\n    \"\"\"Feature Mask\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    instrument : str\n        instrument mask\n\n    Returns\n    ----------\n    Expression\n        a feature instance with masked instrument\n    \"\"\"\n\n    def __init__(self, feature, instrument):\n        super(Mask, self).__init__(feature, \"mask\")\n        self.instrument = instrument\n\n    def __str__(self):\n        return \"{}({},{})\".format(type(self).__name__, self.feature, self.instrument.lower())\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        return self.feature.load(self.instrument, start_index, end_index, *args)\n\n\nclass Not(NpElemOperator):\n    \"\"\"Not Operator\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        feature elementwise not output\n    \"\"\"\n\n    def __init__(self, feature):\n        super(Not, self).__init__(feature, \"bitwise_not\")\n\n\n#################### Pair-Wise Operator ####################\nclass PairOperator(ExpressionOps):\n    \"\"\"Pair-wise operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance or numeric value\n    feature_right : Expression\n        feature instance or numeric value\n\n    Returns\n    ----------\n    Feature:\n        two features' operation output\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        self.feature_left = feature_left\n        self.feature_right = feature_right\n\n    def __str__(self):\n        return \"{}({},{})\".format(type(self).__name__, self.feature_left, self.feature_right)\n\n    def 
get_longest_back_rolling(self):\n        if isinstance(self.feature_left, (Expression,)):\n            left_br = self.feature_left.get_longest_back_rolling()\n        else:\n            left_br = 0\n\n        if isinstance(self.feature_right, (Expression,)):\n            right_br = self.feature_right.get_longest_back_rolling()\n        else:\n            right_br = 0\n        return max(left_br, right_br)\n\n    def get_extended_window_size(self):\n        if isinstance(self.feature_left, (Expression,)):\n            ll, lr = self.feature_left.get_extended_window_size()\n        else:\n            ll, lr = 0, 0\n\n        if isinstance(self.feature_right, (Expression,)):\n            rl, rr = self.feature_right.get_extended_window_size()\n        else:\n            rl, rr = 0, 0\n        return max(ll, rl), max(lr, rr)\n\n\nclass NpPairOperator(PairOperator):\n    \"\"\"Numpy Pair-wise operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance or numeric value\n    feature_right : Expression\n        feature instance or numeric value\n    func : str\n        operator function\n\n    Returns\n    ----------\n    Feature:\n        two features' operation output\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right, func):\n        self.func = func\n        super(NpPairOperator, self).__init__(feature_left, feature_right)\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        assert any(\n            [isinstance(self.feature_left, (Expression,)), isinstance(self.feature_right, (Expression,))]\n        ), \"at least one of the two inputs must be an Expression instance\"\n        if isinstance(self.feature_left, (Expression,)):\n            series_left = self.feature_left.load(instrument, start_index, end_index, *args)\n        else:\n            series_left = self.feature_left  # numeric value\n        if isinstance(self.feature_right, (Expression,)):\n            series_right = self.feature_right.load(instrument, 
start_index, end_index, *args)\n        else:\n            series_right = self.feature_right\n        check_length = isinstance(series_left, (np.ndarray, pd.Series)) and isinstance(\n            series_right, (np.ndarray, pd.Series)\n        )\n        if check_length:\n            warning_info = (\n                f\"Loading {instrument}: {str(self)}; np.{self.func}(series_left, series_right), \"\n                f\"The length of series_left and series_right is different: ({len(series_left)}, {len(series_right)}), \"\n                f\"series_left is {str(self.feature_left)}, series_right is {str(self.feature_right)}. Please check the data\"\n            )\n        else:\n            warning_info = (\n                f\"Loading {instrument}: {str(self)}; np.{self.func}(series_left, series_right), \"\n                f\"series_left is {str(self.feature_left)}, series_right is {str(self.feature_right)}. Please check the data\"\n            )\n        try:\n            res = getattr(np, self.func)(series_left, series_right)\n        except ValueError as e:\n            get_module_logger(\"ops\").debug(warning_info)\n            raise ValueError(f\"{str(e)}. 
\\n\\t{warning_info}\") from e\n        else:\n            if check_length and len(series_left) != len(series_right):\n                get_module_logger(\"ops\").debug(warning_info)\n        return res\n\n\nclass Power(NpPairOperator):\n    \"\"\"Power Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        The bases in feature_left raised to the exponents in feature_right\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Power, self).__init__(feature_left, feature_right, \"power\")\n\n\nclass Add(NpPairOperator):\n    \"\"\"Add Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        two features' sum\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Add, self).__init__(feature_left, feature_right, \"add\")\n\n\nclass Sub(NpPairOperator):\n    \"\"\"Subtract Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        two features' subtraction\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Sub, self).__init__(feature_left, feature_right, \"subtract\")\n\n\nclass Mul(NpPairOperator):\n    \"\"\"Multiply Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        two features' product\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Mul, self).__init__(feature_left, feature_right, \"multiply\")\n\n\nclass Div(NpPairOperator):\n    
\"\"\"Division Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        two features' division\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Div, self).__init__(feature_left, feature_right, \"divide\")\n\n\nclass Greater(NpPairOperator):\n    \"\"\"Greater Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        greater elements taken from the input two features\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Greater, self).__init__(feature_left, feature_right, \"maximum\")\n\n\nclass Less(NpPairOperator):\n    \"\"\"Less Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        smaller elements taken from the input two features\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Less, self).__init__(feature_left, feature_right, \"minimum\")\n\n\nclass Gt(NpPairOperator):\n    \"\"\"Greater Than Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        bool series indicate `left > right`\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Gt, self).__init__(feature_left, feature_right, \"greater\")\n\n\nclass Ge(NpPairOperator):\n    \"\"\"Greater Equal Than Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    
Returns\n    ----------\n    Feature:\n        bool series indicate `left >= right`\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Ge, self).__init__(feature_left, feature_right, \"greater_equal\")\n\n\nclass Lt(NpPairOperator):\n    \"\"\"Less Than Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        bool series indicate `left < right`\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Lt, self).__init__(feature_left, feature_right, \"less\")\n\n\nclass Le(NpPairOperator):\n    \"\"\"Less Equal Than Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        bool series indicate `left <= right`\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Le, self).__init__(feature_left, feature_right, \"less_equal\")\n\n\nclass Eq(NpPairOperator):\n    \"\"\"Equal Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        bool series indicate `left == right`\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Eq, self).__init__(feature_left, feature_right, \"equal\")\n\n\nclass Ne(NpPairOperator):\n    \"\"\"Not Equal Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        bool series indicate `left != right`\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Ne, self).__init__(feature_left, feature_right, 
\"not_equal\")\n\n\nclass And(NpPairOperator):\n    \"\"\"And Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        two features' row by row & output\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(And, self).__init__(feature_left, feature_right, \"bitwise_and\")\n\n\nclass Or(NpPairOperator):\n    \"\"\"Or Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    ----------\n    Feature:\n        two features' row by row | outputs\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right):\n        super(Or, self).__init__(feature_left, feature_right, \"bitwise_or\")\n\n\n#################### Triple-wise Operator ####################\nclass If(ExpressionOps):\n    \"\"\"If Operator\n\n    Parameters\n    ----------\n    condition : Expression\n        feature instance with bool values as condition\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n    \"\"\"\n\n    def __init__(self, condition, feature_left, feature_right):\n        self.condition = condition\n        self.feature_left = feature_left\n        self.feature_right = feature_right\n\n    def __str__(self):\n        return \"If({},{},{})\".format(self.condition, self.feature_left, self.feature_right)\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series_cond = self.condition.load(instrument, start_index, end_index, *args)\n        if isinstance(self.feature_left, (Expression,)):\n            series_left = self.feature_left.load(instrument, start_index, end_index, *args)\n        else:\n            series_left = self.feature_left\n        if isinstance(self.feature_right, 
(Expression,)):\n            series_right = self.feature_right.load(instrument, start_index, end_index, *args)\n        else:\n            series_right = self.feature_right\n        series = pd.Series(np.where(series_cond, series_left, series_right), index=series_cond.index)\n        return series\n\n    def get_longest_back_rolling(self):\n        if isinstance(self.feature_left, (Expression,)):\n            left_br = self.feature_left.get_longest_back_rolling()\n        else:\n            left_br = 0\n\n        if isinstance(self.feature_right, (Expression,)):\n            right_br = self.feature_right.get_longest_back_rolling()\n        else:\n            right_br = 0\n\n        if isinstance(self.condition, (Expression,)):\n            c_br = self.condition.get_longest_back_rolling()\n        else:\n            c_br = 0\n        return max(left_br, right_br, c_br)\n\n    def get_extended_window_size(self):\n        if isinstance(self.feature_left, (Expression,)):\n            ll, lr = self.feature_left.get_extended_window_size()\n        else:\n            ll, lr = 0, 0\n\n        if isinstance(self.feature_right, (Expression,)):\n            rl, rr = self.feature_right.get_extended_window_size()\n        else:\n            rl, rr = 0, 0\n\n        if isinstance(self.condition, (Expression,)):\n            cl, cr = self.condition.get_extended_window_size()\n        else:\n            cl, cr = 0, 0\n        return max(ll, rl, cl), max(lr, rr, cr)\n\n\n#################### Rolling ####################\n# NOTE: methods like `rolling.mean` are optimized with cython,\n# and are much faster than `rolling.apply(np.mean)`\n\n\nclass Rolling(ExpressionOps):\n    \"\"\"Rolling Operator\n    The meaning of rolling and expanding is the same as in pandas.\n    When the window is set to 0, the behaviour of the operator should follow `expanding`.\n    When 0 < N < 1, it computes an exponentially weighted mean with `alpha=N`.\n    Otherwise, it follows `rolling`.\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : 
int\n        rolling window size\n    func : str\n        rolling method\n\n    Returns\n    ----------\n    Expression\n        rolling outputs\n    \"\"\"\n\n    def __init__(self, feature, N, func):\n        self.feature = feature\n        self.N = N\n        self.func = func\n\n    def __str__(self):\n        return \"{}({},{})\".format(type(self).__name__, self.feature, self.N)\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n        # NOTE: remove all null check,\n        # now it's user's responsibility to decide whether use features in null days\n        # isnull = series.isnull() # NOTE: isnull = NaN, inf is not null\n        if isinstance(self.N, int) and self.N == 0:\n            series = getattr(series.expanding(min_periods=1), self.func)()\n        elif isinstance(self.N, float) and 0 < self.N < 1:\n            series = series.ewm(alpha=self.N, min_periods=1).mean()\n        else:\n            series = getattr(series.rolling(self.N, min_periods=1), self.func)()\n            # series.iloc[:self.N-1] = np.nan\n        # series[isnull] = np.nan\n        return series\n\n    def get_longest_back_rolling(self):\n        if self.N == 0:\n            return np.inf\n        if 0 < self.N < 1:\n            return int(np.log(1e-6) / np.log(1 - self.N))  # (1 - N)**window == 1e-6\n        return self.feature.get_longest_back_rolling() + self.N - 1\n\n    def get_extended_window_size(self):\n        if self.N == 0:\n            # FIXME: How to make this accurate and efficiently? 
Or  should we\n            # remove such support for N == 0?\n            get_module_logger(self.__class__.__name__).warning(\"The Rolling(ATTR, 0) will not be accurately calculated\")\n            return self.feature.get_extended_window_size()\n        elif 0 < self.N < 1:\n            lft_etd, rght_etd = self.feature.get_extended_window_size()\n            size = int(np.log(1e-6) / np.log(1 - self.N))\n            lft_etd = max(lft_etd + size - 1, lft_etd)\n            return lft_etd, rght_etd\n        else:\n            lft_etd, rght_etd = self.feature.get_extended_window_size()\n            lft_etd = max(lft_etd + self.N - 1, lft_etd)\n            return lft_etd, rght_etd\n\n\nclass Ref(Rolling):\n    \"\"\"Feature Reference\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        N = 0, retrieve the first data; N > 0, retrieve data of N periods ago; N < 0, future data\n\n    Returns\n    ----------\n    Expression\n        a feature instance with target reference\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Ref, self).__init__(feature, N, \"ref\")\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n        # N = 0, return first day\n        if series.empty:\n            return series  # Pandas bug, see: https://github.com/pandas-dev/pandas/issues/21049\n        elif self.N == 0:\n            series = pd.Series(series.iloc[0], index=series.index)\n        else:\n            series = series.shift(self.N)  # copy\n        return series\n\n    def get_longest_back_rolling(self):\n        if self.N == 0:\n            return np.inf\n        return self.feature.get_longest_back_rolling() + self.N\n\n    def get_extended_window_size(self):\n        if self.N == 0:\n            get_module_logger(self.__class__.__name__).warning(\"The Ref(ATTR, 0) will not be accurately calculated\")\n            
return self.feature.get_extended_window_size()\n        else:\n            lft_etd, rght_etd = self.feature.get_extended_window_size()\n            lft_etd = max(lft_etd + self.N, lft_etd)\n            rght_etd = max(rght_etd - self.N, rght_etd)\n            return lft_etd, rght_etd\n\n\nclass Mean(Rolling):\n    \"\"\"Rolling Mean (MA)\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling average\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Mean, self).__init__(feature, N, \"mean\")\n\n\nclass Sum(Rolling):\n    \"\"\"Rolling Sum\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling sum\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Sum, self).__init__(feature, N, \"sum\")\n\n\nclass Std(Rolling):\n    \"\"\"Rolling Std\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling std\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Std, self).__init__(feature, N, \"std\")\n\n\nclass Var(Rolling):\n    \"\"\"Rolling Variance\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling variance\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Var, self).__init__(feature, N, \"var\")\n\n\nclass Skew(Rolling):\n    \"\"\"Rolling Skewness\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    
Returns\n    ----------\n    Expression\n        a feature instance with rolling skewness\n    \"\"\"\n\n    def __init__(self, feature, N):\n        if N != 0 and N < 3:\n            raise ValueError(\"The rolling window size of Skewness operation should be >= 3\")\n        super(Skew, self).__init__(feature, N, \"skew\")\n\n\nclass Kurt(Rolling):\n    \"\"\"Rolling Kurtosis\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling kurtosis\n    \"\"\"\n\n    def __init__(self, feature, N):\n        if N != 0 and N < 4:\n            raise ValueError(\"The rolling window size of Kurtosis operation should be >= 4\")\n        super(Kurt, self).__init__(feature, N, \"kurt\")\n\n\nclass Max(Rolling):\n    \"\"\"Rolling Max\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling max\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Max, self).__init__(feature, N, \"max\")\n\n\nclass IdxMax(Rolling):\n    \"\"\"Rolling Max Index\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling max index\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(IdxMax, self).__init__(feature, N, \"idxmax\")\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n        if self.N == 0:\n            series = series.expanding(min_periods=1).apply(lambda x: x.argmax() + 1, raw=True)\n        else:\n            series = series.rolling(self.N, min_periods=1).apply(lambda x: x.argmax() + 1, raw=True)\n  
      return series\n\n\nclass Min(Rolling):\n    \"\"\"Rolling Min\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling min\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Min, self).__init__(feature, N, \"min\")\n\n\nclass IdxMin(Rolling):\n    \"\"\"Rolling Min Index\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling min index\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(IdxMin, self).__init__(feature, N, \"idxmin\")\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n        if self.N == 0:\n            series = series.expanding(min_periods=1).apply(lambda x: x.argmin() + 1, raw=True)\n        else:\n            series = series.rolling(self.N, min_periods=1).apply(lambda x: x.argmin() + 1, raw=True)\n        return series\n\n\nclass Quantile(Rolling):\n    \"\"\"Rolling Quantile\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling quantile\n    \"\"\"\n\n    def __init__(self, feature, N, qscore):\n        super(Quantile, self).__init__(feature, N, \"quantile\")\n        self.qscore = qscore\n\n    def __str__(self):\n        return \"{}({},{},{})\".format(type(self).__name__, self.feature, self.N, self.qscore)\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n        if self.N == 0:\n            series = 
series.expanding(min_periods=1).quantile(self.qscore)\n        else:\n            series = series.rolling(self.N, min_periods=1).quantile(self.qscore)\n        return series\n\n\nclass Med(Rolling):\n    \"\"\"Rolling Median\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling median\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Med, self).__init__(feature, N, \"median\")\n\n\nclass Mad(Rolling):\n    \"\"\"Rolling Mean Absolute Deviation\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling mean absolute deviation\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Mad, self).__init__(feature, N, \"mad\")\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n        # TODO: implement in Cython\n\n        def mad(x):\n            x1 = x[~np.isnan(x)]\n            return np.mean(np.abs(x1 - x1.mean()))\n\n        if self.N == 0:\n            series = series.expanding(min_periods=1).apply(mad, raw=True)\n        else:\n            series = series.rolling(self.N, min_periods=1).apply(mad, raw=True)\n        return series\n\n\nclass Rank(Rolling):\n    \"\"\"Rolling Rank (Percentile)\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling rank\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Rank, self).__init__(feature, N, \"rank\")\n\n    # for compatibility with Python 3.7, which doesn't support pandas 1.4.0+ (the first version implementing Rolling.rank)\n    
def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n\n        rolling_or_expanding = series.expanding(min_periods=1) if self.N == 0 else series.rolling(self.N, min_periods=1)\n        if hasattr(rolling_or_expanding, \"rank\"):\n            return rolling_or_expanding.rank(pct=True)\n\n        def rank(x):\n            if np.isnan(x[-1]):\n                return np.nan\n            x1 = x[~np.isnan(x)]\n            if x1.shape[0] == 0:\n                return np.nan\n            return percentileofscore(x1, x1[-1]) / 100\n\n        return rolling_or_expanding.apply(rank, raw=True)\n\n\nclass Count(Rolling):\n    \"\"\"Rolling Count\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling count of non-NaN elements\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Count, self).__init__(feature, N, \"count\")\n\n\nclass Delta(Rolling):\n    \"\"\"Rolling Delta\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with end minus start in rolling window\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Delta, self).__init__(feature, N, \"delta\")\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n        if self.N == 0:\n            series = series - series.iloc[0]\n        else:\n            series = series - series.shift(self.N)\n        return series\n\n\n# TODO:\n# support pair-wise rolling like `Slope(A, B, N)`\nclass Slope(Rolling):\n    \"\"\"Rolling Slope\n    This operator calculates the slope between `idx` and 
`feature`.\n    (e.g. [<feature_t1>, <feature_t2>, <feature_t3>] and [1, 2, 3])\n\n    Usage Example:\n    - \"Slope($close, %d)/$close\"\n\n    # TODO:\n    # Some users may want pair-wise rolling like `Slope(A, B, N)`\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with linear regression slope of given window\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Slope, self).__init__(feature, N, \"slope\")\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n        if self.N == 0:\n            series = pd.Series(expanding_slope(series.values), index=series.index)\n        else:\n            series = pd.Series(rolling_slope(series.values, self.N), index=series.index)\n        return series\n\n\nclass Rsquare(Rolling):\n    \"\"\"Rolling R-value Square\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with linear regression r-value square of given window\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Rsquare, self).__init__(feature, N, \"rsquare\")\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        _series = self.feature.load(instrument, start_index, end_index, *args)\n        if self.N == 0:\n            series = pd.Series(expanding_rsquare(_series.values), index=_series.index)\n        else:\n            series = pd.Series(rolling_rsquare(_series.values, self.N), index=_series.index)\n            series.loc[np.isclose(_series.rolling(self.N, min_periods=1).std(), 0, atol=2e-05)] = np.nan\n        return series\n\n\nclass Resi(Rolling):\n    \"\"\"Rolling Regression Residuals\n\n    Parameters\n  
  ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with regression residuals of given window\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(Resi, self).__init__(feature, N, \"resi\")\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n        if self.N == 0:\n            series = pd.Series(expanding_resi(series.values), index=series.index)\n        else:\n            series = pd.Series(rolling_resi(series.values, self.N), index=series.index)\n        return series\n\n\nclass WMA(Rolling):\n    \"\"\"Rolling WMA\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with weighted moving average output\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(WMA, self).__init__(feature, N, \"wma\")\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n        # TODO: implement in Cython\n\n        def weighted_mean(x):\n            w = np.arange(len(x)) + 1\n            w = w / w.sum()\n            # w already sums to 1, so the weighted mean is the (NaN-skipping) sum of w * x\n            return np.nansum(w * x)\n\n        if self.N == 0:\n            series = series.expanding(min_periods=1).apply(weighted_mean, raw=True)\n        else:\n            series = series.rolling(self.N, min_periods=1).apply(weighted_mean, raw=True)\n        return series\n\n\nclass EMA(Rolling):\n    \"\"\"Rolling Exponential Moving Average (EMA)\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n    N : int, float\n        rolling window size (used as the span); if 0 < N < 1, it is used as the smoothing factor alpha\n\n    Returns\n    ----------\n    Expression\n        a feature instance with exponential moving average over the 
given window\n    \"\"\"\n\n    def __init__(self, feature, N):\n        super(EMA, self).__init__(feature, N, \"ema\")\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n\n        def exp_weighted_mean(x):\n            a = 1 - 2 / (1 + len(x))\n            w = a ** np.arange(len(x))[::-1]\n            w /= w.sum()\n            return np.nansum(w * x)\n\n        if self.N == 0:\n            series = series.expanding(min_periods=1).apply(exp_weighted_mean, raw=True)\n        elif 0 < self.N < 1:\n            series = series.ewm(alpha=self.N, min_periods=1).mean()\n        else:\n            series = series.ewm(span=self.N, min_periods=1).mean()\n        return series\n\n\n#################### Pair-Wise Rolling ####################\nclass PairRolling(ExpressionOps):\n    \"\"\"Pair Rolling Operator\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling output of two input features\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right, N, func):\n        # TODO: in what case will a const be passed into `__init__` as `feature_left` or `feature_right`\n        self.feature_left = feature_left\n        self.feature_right = feature_right\n        self.N = N\n        self.func = func\n\n    def __str__(self):\n        return \"{}({},{},{})\".format(type(self).__name__, self.feature_left, self.feature_right, self.N)\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        assert any(\n            [isinstance(self.feature_left, Expression), isinstance(self.feature_right, Expression)]\n        ), \"at least one of the two inputs must be an Expression instance\"\n\n        if isinstance(self.feature_left, Expression):\n            
series_left = self.feature_left.load(instrument, start_index, end_index, *args)\n        else:\n            series_left = self.feature_left  # numeric value\n        if isinstance(self.feature_right, Expression):\n            series_right = self.feature_right.load(instrument, start_index, end_index, *args)\n        else:\n            series_right = self.feature_right\n\n        if self.N == 0:\n            series = getattr(series_left.expanding(min_periods=1), self.func)(series_right)\n        else:\n            series = getattr(series_left.rolling(self.N, min_periods=1), self.func)(series_right)\n        return series\n\n    def get_longest_back_rolling(self):\n        if self.N == 0:\n            return np.inf\n        if isinstance(self.feature_left, Expression):\n            left_br = self.feature_left.get_longest_back_rolling()\n        else:\n            left_br = 0\n\n        if isinstance(self.feature_right, Expression):\n            right_br = self.feature_right.get_longest_back_rolling()\n        else:\n            right_br = 0\n        return max(left_br, right_br)\n\n    def get_extended_window_size(self):\n        if isinstance(self.feature_left, Expression):\n            ll, lr = self.feature_left.get_extended_window_size()\n        else:\n            ll, lr = 0, 0\n        if isinstance(self.feature_right, Expression):\n            rl, rr = self.feature_right.get_extended_window_size()\n        else:\n            rl, rr = 0, 0\n        if self.N == 0:\n            get_module_logger(self.__class__.__name__).warning(\n                \"The PairRolling(ATTR, 0) will not be accurately calculated\"\n            )\n            return -np.inf, max(lr, rr)\n        else:\n            return max(ll, rl) + self.N - 1, max(lr, rr)\n\n\nclass Corr(PairRolling):\n    \"\"\"Rolling Correlation\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n    N : int\n        
rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling correlation of two input features\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right, N):\n        super(Corr, self).__init__(feature_left, feature_right, N, \"corr\")\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        res: pd.Series = super(Corr, self)._load_internal(instrument, start_index, end_index, *args)\n\n        # NOTE: Load uses MemCache, so calling load again will not cause performance degradation\n        series_left = self.feature_left.load(instrument, start_index, end_index, *args)\n        series_right = self.feature_right.load(instrument, start_index, end_index, *args)\n        res.loc[\n            np.isclose(series_left.rolling(self.N, min_periods=1).std(), 0, atol=2e-05)\n            | np.isclose(series_right.rolling(self.N, min_periods=1).std(), 0, atol=2e-05)\n        ] = np.nan\n        return res\n\n\nclass Cov(PairRolling):\n    \"\"\"Rolling Covariance\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n    N : int\n        rolling window size\n\n    Returns\n    ----------\n    Expression\n        a feature instance with rolling covariance of two input features\n    \"\"\"\n\n    def __init__(self, feature_left, feature_right, N):\n        super(Cov, self).__init__(feature_left, feature_right, N, \"cov\")\n\n\n#################### Operators which only support data with a time index ####################\n# Convention\n# - The name of the operators in this section will start with \"T\"\n\n\nclass TResample(ElemOperator):\n    def __init__(self, feature, freq, func):\n        \"\"\"\n        Resample the data to the target frequency.\n        The pandas `resample` function is used.\n\n        - After resampling, each timestamp is at the start of its time span.\n\n        Parameters\n        
----------\n        feature : Expression\n            An expression for calculating the feature\n        freq : str\n            The target frequency; it is passed to the pandas `resample` method\n        func : str\n            The name of the resampler method used to aggregate the values (e.g. \"sum\", \"mean\", \"last\")\n        \"\"\"\n        self.feature = feature\n        self.freq = freq\n        self.func = func\n\n    def __str__(self):\n        return \"{}({},{})\".format(type(self).__name__, self.feature, self.freq)\n\n    def _load_internal(self, instrument, start_index, end_index, *args):\n        series = self.feature.load(instrument, start_index, end_index, *args)\n\n        if series.empty:\n            return series\n        else:\n            if self.func == \"sum\":\n                return getattr(series.resample(self.freq), self.func)(min_count=1)\n            else:\n                return getattr(series.resample(self.freq), self.func)()\n\n\nTOpsList = [TResample]\nOpsList = [\n    ChangeInstrument,\n    Rolling,\n    Ref,\n    Max,\n    Min,\n    Sum,\n    Mean,\n    Std,\n    Var,\n    Skew,\n    Kurt,\n    Med,\n    Mad,\n    Slope,\n    Rsquare,\n    Resi,\n    Rank,\n    Quantile,\n    Count,\n    EMA,\n    WMA,\n    Corr,\n    Cov,\n    Delta,\n    Abs,\n    Sign,\n    Log,\n    Power,\n    Add,\n    Sub,\n    Mul,\n    Div,\n    Greater,\n    Less,\n    And,\n    Or,\n    Not,\n    Gt,\n    Ge,\n    Lt,\n    Le,\n    Eq,\n    Ne,\n    Mask,\n    IdxMax,\n    IdxMin,\n    If,\n    Feature,\n    PFeature,\n] + [TResample]\n\n\nclass OpsWrapper:\n    \"\"\"Ops Wrapper\"\"\"\n\n    def __init__(self):\n        self._ops = {}\n\n    def reset(self):\n        self._ops = {}\n\n    def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]):\n        \"\"\"Register operators\n\n        Parameters\n        ----------\n        ops_list : List[Union[Type[ExpressionOps], dict]]\n            - if type(ops_list) is 
List[Type[ExpressionOps]], each element of ops_list represents the operator class, which should be the subclass of `ExpressionOps`.\n            - if type(ops_list) is List[dict], each element of ops_list represents the config of an operator, which has the following format:\n\n                .. code-block:: text\n\n                    {\n                        \"class\": class_name,\n                        \"module_path\": path,\n                    }\n\n                Note: `class` should be the class name of the operator, and `module_path` should be a Python module or a file path.\n        \"\"\"\n        for _operator in ops_list:\n            if isinstance(_operator, dict):\n                _ops_class, _ = get_callable_kwargs(_operator)\n            else:\n                _ops_class = _operator\n\n            if not issubclass(_ops_class, (Expression,)):\n                raise TypeError(\"operator must be a subclass of Expression, not {}\".format(_ops_class))\n\n            if _ops_class.__name__ in self._ops:\n                get_module_logger(self.__class__.__name__).warning(\n                    \"The custom operator [{}] will override the qlib default definition\".format(_ops_class.__name__)\n                )\n            self._ops[_ops_class.__name__] = _ops_class\n\n    def __getattr__(self, key):\n        if key not in self._ops:\n            raise AttributeError(\"The operator [{0}] is not registered\".format(key))\n        return self._ops[key]\n\n\nOperators = OpsWrapper()\n\n\ndef register_all_ops(C):\n    \"\"\"Register all operators\"\"\"\n    logger = get_module_logger(\"ops\")\n\n    from qlib.data.pit import P, PRef  # pylint: disable=C0415\n\n    Operators.reset()\n    Operators.register(OpsList + [P, PRef])\n\n    if getattr(C, \"custom_ops\", None) is not None:\n        Operators.register(C.custom_ops)\n        logger.debug(\"register custom operator {}\".format(C.custom_ops))\n"
  },
  {
    "path": "qlib/data/pit.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nQlib follow the logic below to supporting point-in-time database\n\nFor each stock, the format of its data is <observe_time, feature>. Expression Engine support calculation on such format of data\n\nTo calculate the feature value f_t at a specific observe time t,  data with format <period_time, feature> will be used.\nFor example, the average earning of last 4 quarters (period_time) on 20190719 (observe_time)\n\nThe calculation of both <period_time, feature> and <observe_time, feature> data rely on expression engine. It consists of 2 phases.\n1) calculation <period_time, feature> at each observation time t and it will collasped into a point (just like a normal feature)\n2) concatenate all th collasped data, we will get data with format <observe_time, feature>.\nQlib will use the operator `P` to perform the collapse.\n\"\"\"\n\nimport numpy as np\nimport pandas as pd\nfrom qlib.data.ops import ElemOperator\nfrom qlib.log import get_module_logger\nfrom .data import Cal\n\n\nclass P(ElemOperator):\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        _calendar = Cal.calendar(freq=freq)\n        resample_data = np.empty(end_index - start_index + 1, dtype=\"float32\")\n\n        for cur_index in range(start_index, end_index + 1):\n            cur_time = _calendar[cur_index]\n            # To load expression accurately, more historical data are required\n            start_ws, end_ws = self.feature.get_extended_window_size()\n            if end_ws > 0:\n                raise ValueError(\n                    \"PIT database does not support referring to future period (e.g. 
expressions like `Ref('$$roewa_q', -1)` are not supported)\"\n                )\n\n            # The calculated value will always be the last element, so the end_offset is zero.\n            try:\n                s = self._load_feature(instrument, -start_ws, 0, cur_time)\n                resample_data[cur_index - start_index] = s.iloc[-1] if len(s) > 0 else np.nan\n            except FileNotFoundError:\n                get_module_logger(\"base\").warning(f\"WARN: period data not found for {str(self)}\")\n                return pd.Series(dtype=\"float32\", name=str(self))\n\n        resample_series = pd.Series(\n            resample_data, index=pd.RangeIndex(start_index, end_index + 1), dtype=\"float32\", name=str(self)\n        )\n        return resample_series\n\n    def _load_feature(self, instrument, start_index, end_index, cur_time):\n        return self.feature.load(instrument, start_index, end_index, cur_time)\n\n    def get_longest_back_rolling(self):\n        # The period data collapses into a normal feature, so no extension or look-back is needed\n        return 0\n\n    def get_extended_window_size(self):\n        # The period data collapses into a normal feature, so no extension or look-back is needed\n        return 0, 0\n\n\nclass PRef(P):\n    def __init__(self, feature, period):\n        super().__init__(feature)\n        self.period = period\n\n    def __str__(self):\n        return f\"{super().__str__()}[{self.period}]\"\n\n    def _load_feature(self, instrument, start_index, end_index, cur_time):\n        return self.feature.load(instrument, start_index, end_index, cur_time, self.period)\n"
  },
  {
    "path": "qlib/data/storage/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT\n\n__all__ = [\"CalendarStorage\", \"InstrumentStorage\", \"FeatureStorage\", \"CalVT\", \"InstVT\", \"InstKT\"]\n"
  },
  {
    "path": "qlib/data/storage/file_storage.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport struct\nfrom pathlib import Path\nfrom typing import Iterable, Union, Dict, Mapping, Tuple, List\n\nimport numpy as np\nimport pandas as pd\n\nfrom qlib.utils.time import Freq\nfrom qlib.utils.resam import resam_calendar\nfrom qlib.config import C\nfrom qlib.data.cache import H\nfrom qlib.log import get_module_logger\nfrom qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT\n\nlogger = get_module_logger(\"file_storage\")\n\n\nclass FileStorageMixin:\n    \"\"\"FileStorageMixin, applicable to FileXXXStorage\n    Subclasses need to have provider_uri, freq, storage_name, file_name attributes\n\n    \"\"\"\n\n    # NOTE: provider_uri priority:\n    #   1. self._provider_uri : if provider_uri is provided.\n    #   2. provider_uri in qlib.config.C\n\n    @property\n    def provider_uri(self):\n        return C[\"provider_uri\"] if getattr(self, \"_provider_uri\", None) is None else self._provider_uri\n\n    @property\n    def dpm(self):\n        return (\n            C.dpm\n            if getattr(self, \"_provider_uri\", None) is None\n            else C.DataPathManager(self._provider_uri, C.mount_path)\n        )\n\n    @property\n    def support_freq(self) -> List[str]:\n        _v = \"_support_freq\"\n        if hasattr(self, _v):\n            return getattr(self, _v)\n        if len(self.provider_uri) == 1 and C.DEFAULT_FREQ in self.provider_uri:\n            freq_l = filter(\n                lambda _freq: not _freq.endswith(\"_future\"),\n                map(lambda x: x.stem, self.dpm.get_data_uri(C.DEFAULT_FREQ).joinpath(\"calendars\").glob(\"*.txt\")),\n            )\n        else:\n            freq_l = self.provider_uri.keys()\n        freq_l = [Freq(freq) for freq in freq_l]\n        setattr(self, _v, freq_l)\n        return freq_l\n\n    @property\n    def uri(self) -> Path:\n        if self.freq not in 
self.support_freq:\n            raise ValueError(f\"{self.storage_name}: {self.provider_uri} does not contain data for {self.freq}\")\n        return self.dpm.get_data_uri(self.freq).joinpath(f\"{self.storage_name}s\", self.file_name)\n\n    def check(self):\n        \"\"\"check self.uri\n\n        Raises\n        -------\n        ValueError\n        \"\"\"\n        if not self.uri.exists():\n            raise ValueError(f\"{self.storage_name} does not exist: {self.uri}\")\n\n\nclass FileCalendarStorage(FileStorageMixin, CalendarStorage):\n    def __init__(self, freq: str, future: bool, provider_uri: dict = None, **kwargs):\n        super(FileCalendarStorage, self).__init__(freq, future, **kwargs)\n        self.future = future\n        self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)\n        self.enable_read_cache = True  # TODO: make it configurable\n        self.region = C[\"region\"]\n\n    @property\n    def file_name(self) -> str:\n        return f\"{self._freq_file}_future.txt\" if self.future else f\"{self._freq_file}.txt\".lower()\n\n    @property\n    def _freq_file(self) -> str:\n        \"\"\"the freq to read from file\"\"\"\n        if not hasattr(self, \"_freq_file_cache\"):\n            freq = Freq(self.freq)\n            if freq not in self.support_freq:\n                # NOTE: uri\n                #   1. 
If `uri` does not exist\n                #       - Get the `min_uri` of the closest `freq` under the same \"directory\" as the `uri`\n                #       - Read data from `min_uri` and resample to `freq`\n\n                freq = Freq.get_recent_freq(freq, self.support_freq)\n                if freq is None:\n                    raise ValueError(f\"can't find a freq from {self.support_freq} that can resample to {self.freq}!\")\n            self._freq_file_cache = freq\n        return self._freq_file_cache\n\n    def _read_calendar(self) -> List[CalVT]:\n        # NOTE:\n        # if we want to accelerate partial reading calendar\n        # we can add parameters like `skip_rows: int = 0, n_rows: int = None` to the interface.\n        # Currently, it is not supported for the txt-based calendar\n\n        if not self.uri.exists():\n            self._write_calendar(values=[])\n\n        with self.uri.open(\"r\") as fp:\n            res = []\n            for line in fp.readlines():\n                line = line.strip()\n                if len(line) > 0:\n                    res.append(line)\n            return res\n\n    def _write_calendar(self, values: Iterable[CalVT], mode: str = \"wb\"):\n        with self.uri.open(mode=mode) as fp:\n            np.savetxt(fp, values, fmt=\"%s\", encoding=\"utf-8\")\n\n    @property\n    def uri(self) -> Path:\n        return self.dpm.get_data_uri(self._freq_file).joinpath(f\"{self.storage_name}s\", self.file_name)\n\n    @property\n    def data(self) -> List[CalVT]:\n        self.check()\n        # If cache is enabled, then return cache directly\n        if self.enable_read_cache:\n            key = \"orig_file\" + str(self.uri)\n            if key not in H[\"c\"]:\n                H[\"c\"][key] = self._read_calendar()\n            _calendar = H[\"c\"][key]\n        else:\n            _calendar = self._read_calendar()\n        if Freq(self._freq_file) != Freq(self.freq):\n            _calendar = resam_calendar(\n                
np.array(list(map(pd.Timestamp, _calendar))), self._freq_file, self.freq, self.region\n            )\n        return _calendar\n\n    def _get_storage_freq(self) -> List[str]:\n        return sorted(set(map(lambda x: x.stem.split(\"_\")[0], self.uri.parent.glob(\"*.txt\"))))\n\n    def extend(self, values: Iterable[CalVT]) -> None:\n        self._write_calendar(values, mode=\"ab\")\n\n    def clear(self) -> None:\n        self._write_calendar(values=[])\n\n    def index(self, value: CalVT) -> int:\n        self.check()\n        # compare element-wise on an array: comparing a plain list with a str is always False\n        calendar = np.array(self._read_calendar())\n        positions = np.flatnonzero(calendar == value)\n        if len(positions) == 0:\n            # mirror `list.index`, which raises ValueError for a missing value\n            raise ValueError(f\"{value} is not in calendar\")\n        return int(positions[0])\n\n    def insert(self, index: int, value: CalVT):\n        calendar = self._read_calendar()\n        calendar = np.insert(calendar, index, value)\n        self._write_calendar(values=calendar)\n\n    def remove(self, value: CalVT) -> None:\n        self.check()\n        index = self.index(value)\n        calendar = self._read_calendar()\n        calendar = np.delete(calendar, index)\n        self._write_calendar(values=calendar)\n\n    def __setitem__(self, i: Union[int, slice], values: Union[CalVT, Iterable[CalVT]]) -> None:\n        calendar = self._read_calendar()\n        calendar[i] = values\n        self._write_calendar(values=calendar)\n\n    def __delitem__(self, i: Union[int, slice]) -> None:\n        self.check()\n        calendar = self._read_calendar()\n        calendar = np.delete(calendar, i)\n        self._write_calendar(values=calendar)\n\n    def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]:\n        self.check()\n        return self._read_calendar()[i]\n\n    def __len__(self) -> int:\n        return len(self.data)\n\n\nclass FileInstrumentStorage(FileStorageMixin, InstrumentStorage):\n    INSTRUMENT_SEP = 
\"\\t\"\n    INSTRUMENT_START_FIELD = \"start_datetime\"\n    INSTRUMENT_END_FIELD = \"end_datetime\"\n    SYMBOL_FIELD_NAME = \"instrument\"\n\n    def __init__(self, market: str, freq: str, provider_uri: dict 
= None, **kwargs):\n        super(FileInstrumentStorage, self).__init__(market, freq, **kwargs)\n        self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)\n        self.file_name = f\"{market.lower()}.txt\"\n\n    def _read_instrument(self) -> Dict[InstKT, InstVT]:\n        if not self.uri.exists():\n            self._write_instrument()\n\n        _instruments = dict()\n        df = pd.read_csv(\n            self.uri,\n            sep=self.INSTRUMENT_SEP,\n            usecols=[0, 1, 2],\n            names=[self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],\n            dtype={self.SYMBOL_FIELD_NAME: str},\n            parse_dates=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],\n        )\n        for row in df.itertuples(index=False):\n            _instruments.setdefault(row[0], []).append((row[1], row[2]))\n        return _instruments\n\n    def _write_instrument(self, data: Dict[InstKT, InstVT] = None) -> None:\n        if not data:\n            with self.uri.open(\"w\") as _:\n                pass\n            return\n\n        res = []\n        for inst, v_list in data.items():\n            _df = pd.DataFrame(v_list, columns=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD])\n            _df[self.SYMBOL_FIELD_NAME] = inst\n            res.append(_df)\n\n        df = pd.concat(res, sort=False)\n        # write once, in the column order that `_read_instrument` expects (symbol, start, end)\n        df.loc[:, [self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]].to_csv(\n            self.uri, header=False, sep=self.INSTRUMENT_SEP, index=False, encoding=\"utf-8\"\n        )\n\n    def clear(self) -> None:\n        self._write_instrument(data={})\n\n    @property\n    def data(self) -> Dict[InstKT, InstVT]:\n        self.check()\n        return self._read_instrument()\n\n    def __setitem__(self, k: InstKT, v: InstVT) -> None:\n        inst = self._read_instrument()\n        
inst[k] = v\n        self._write_instrument(inst)\n\n    def __delitem__(self, k: InstKT) -> None:\n        self.check()\n        inst = self._read_instrument()\n        del inst[k]\n        self._write_instrument(inst)\n\n    def __getitem__(self, k: InstKT) -> InstVT:\n        self.check()\n        return self._read_instrument()[k]\n\n    def update(self, *args, **kwargs) -> None:\n        if len(args) > 1:\n            raise TypeError(f\"update expected at most 1 argument, got {len(args)}\")\n        inst = self._read_instrument()\n        if args:\n            other = args[0]  # type: dict\n            if isinstance(other, Mapping):\n                for key in other:\n                    inst[key] = other[key]\n            elif hasattr(other, \"keys\"):\n                for key in other.keys():\n                    inst[key] = other[key]\n            else:\n                for key, value in other:\n                    inst[key] = value\n        for key, value in kwargs.items():\n            inst[key] = value\n\n        self._write_instrument(inst)\n\n    def __len__(self) -> int:\n        return len(self.data)\n\n\nclass FileFeatureStorage(FileStorageMixin, FeatureStorage):\n    def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict = None, **kwargs):\n        super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs)\n        self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)\n        self.file_name = f\"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin\"\n\n    def clear(self):\n        with self.uri.open(\"wb\") as _:\n            pass\n\n    @property\n    def data(self) -> pd.Series:\n        return self[:]\n\n    def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None:\n        if len(data_array) == 0:\n            logger.info(\n                \"len(data_array) == 0, write is ignored; \"\n                \"if you need to clear the 
FeatureStorage, please execute: FeatureStorage.clear\"\n            )\n            return\n        if not self.uri.exists() or self.uri.stat().st_size == 0:\n            # write (the file is missing, or was emptied by `clear`)\n            index = 0 if index is None else index\n            with self.uri.open(\"wb\") as fp:\n                np.hstack([index, data_array]).astype(\"<f\").tofile(fp)\n        else:\n            if index is None or index > self.end_index:\n                # append\n                index = 0 if index is None else index\n                with self.uri.open(\"ab+\") as fp:\n                    np.hstack([[np.nan] * (index - self.end_index - 1), data_array]).astype(\"<f\").tofile(fp)\n            else:\n                # rewrite\n                with self.uri.open(\"rb+\") as fp:\n                    _old_data = np.fromfile(fp, dtype=\"<f\")\n                    _old_index = int(_old_data[0])\n                    _old_df = pd.DataFrame(\n                        _old_data[1:], index=range(_old_index, _old_index + len(_old_data) - 1), columns=[\"old\"]\n                    )\n                    fp.seek(0)\n                    _new_df = pd.DataFrame(data_array, index=range(index, index + len(data_array)), columns=[\"new\"])\n                    _df = pd.concat([_old_df, _new_df], sort=False, axis=1)\n                    _df = _df.reindex(range(_df.index.min(), _df.index.max() + 1))\n                    # positions covered by `data_array` take the new value (np.nan included), the rest keep the old one\n                    _values = _df[\"new\"].where(_df.index.isin(_new_df.index), _df[\"old\"]).values\n                    # re-write the start-index header in front of the merged values\n                    np.hstack([_df.index.min(), _values]).astype(\"<f\").tofile(fp)\n\n    @property\n    def start_index(self) -> Union[int, None]:\n        if not self.uri.exists() or self.uri.stat().st_size == 0:\n            return None\n        with self.uri.open(\"rb\") as fp:\n            index = int(np.frombuffer(fp.read(4), dtype=\"<f\")[0])\n        return 
index\n\n    @property\n    def end_index(self) -> Union[int, None]:\n        start_index = self.start_index\n        if start_index is None:\n            return None\n        # The next data appending index point will be `end_index + 1`\n        return start_index + len(self) - 1\n\n    def __getitem__(self, i: Union[int, slice]) -> 
Union[Tuple[int, float], pd.Series]:\n        if not self.uri.exists() or self.uri.stat().st_size == 0:\n            if isinstance(i, int):\n                return None, None\n            elif isinstance(i, slice):\n                return pd.Series(dtype=np.float32)\n            else:\n                raise TypeError(f\"type(i) = {type(i)}\")\n\n        storage_start_index = self.start_index\n        storage_end_index = self.end_index\n        with self.uri.open(\"rb\") as fp:\n            if isinstance(i, int):\n                if not storage_start_index <= i <= storage_end_index:\n                    raise IndexError(f\"{i}: index out of range [{storage_start_index}, {storage_end_index}]\")\n                fp.seek(4 * (i - storage_start_index) + 4)\n                return i, struct.unpack(\"<f\", fp.read(4))[0]\n            elif isinstance(i, slice):\n                start_index = storage_start_index if i.start is None else i.start\n                end_index = storage_end_index if i.stop is None else i.stop - 1\n                si = max(start_index, storage_start_index)\n                if si > end_index:\n                    return pd.Series(dtype=np.float32)\n                fp.seek(4 * (si - storage_start_index) + 4)\n                # read `count` little-endian float32 values (4 bytes each)\n                count = end_index - si + 1\n                data = np.frombuffer(fp.read(4 * count), dtype=\"<f\")\n                return pd.Series(data, index=pd.RangeIndex(si, si + len(data)))\n            else:\n                raise TypeError(f\"type(i) = {type(i)}\")\n\n    def __len__(self) -> int:\n        self.check()\n        # the first 4 bytes hold the start index; an emptied file has length 0, not -1\n        return max(self.uri.stat().st_size // 4 - 1, 0)\n"
  },
  {
    "path": "qlib/data/storage/storage.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport re\nfrom typing import Iterable, overload, Tuple, List, Text, Union, Dict\n\nimport numpy as np\nimport pandas as pd\nfrom qlib.log import get_module_logger\n\n# calendar value type\nCalVT = str\n\n# instrument value\nInstVT = List[Tuple[CalVT, CalVT]]\n# instrument key\nInstKT = Text\n\nlogger = get_module_logger(\"storage\")\n\n\"\"\"\nIf the user is only using it in `qlib`, you can customize Storage to implement only the following methods:\n\nclass UserCalendarStorage(CalendarStorage):\n\n    @property\n    def data(self) -> Iterable[CalVT]:\n        '''get all data\n\n        Raises\n        ------\n        ValueError\n            If the data(storage) does not exist, raise ValueError\n        '''\n        raise NotImplementedError(\"Subclass of CalendarStorage must implement `data` method\")\n\n\nclass UserInstrumentStorage(InstrumentStorage):\n\n    @property\n    def data(self) -> Dict[InstKT, InstVT]:\n        '''get all data\n\n        Raises\n        ------\n        ValueError\n            If the data(storage) does not exist, raise ValueError\n        '''\n        raise NotImplementedError(\"Subclass of InstrumentStorage must implement `data` method\")\n\n\nclass UserFeatureStorage(FeatureStorage):\n\n    def __getitem__(self, s: slice) -> pd.Series:\n        '''x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]\n\n        Returns\n        -------\n            pd.Series(values, index=pd.RangeIndex(start, len(values))\n\n        Notes\n        -------\n        if data(storage) does not exist:\n            if isinstance(i, int):\n                return (None, None)\n            if isinstance(i,  slice):\n                # return empty pd.Series\n                return pd.Series(dtype=np.float32)\n        '''\n        raise NotImplementedError(\n            \"Subclass of FeatureStorage must implement `__getitem__(s: slice)` 
method\"\n        )\n\n\n\"\"\"\n\n\nclass BaseStorage:\n    @property\n    def storage_name(self) -> str:\n        return re.findall(\"[A-Z][^A-Z]*\", self.__class__.__name__)[-2].lower()\n\n\nclass CalendarStorage(BaseStorage):\n    \"\"\"\n    The behavior of CalendarStorage's methods and List's methods of the same name remain consistent\n    \"\"\"\n\n    def __init__(self, freq: str, future: bool, **kwargs):\n        self.freq = freq\n        self.future = future\n        self.kwargs = kwargs\n\n    @property\n    def data(self) -> Iterable[CalVT]:\n        \"\"\"get all data\n\n        Raises\n        ------\n        ValueError\n            If the data(storage) does not exist, raise ValueError\n        \"\"\"\n        raise NotImplementedError(\"Subclass of CalendarStorage must implement `data` method\")\n\n    def clear(self) -> None:\n        raise NotImplementedError(\"Subclass of CalendarStorage must implement `clear` method\")\n\n    def extend(self, iterable: Iterable[CalVT]) -> None:\n        raise NotImplementedError(\"Subclass of CalendarStorage must implement `extend` method\")\n\n    def index(self, value: CalVT) -> int:\n        \"\"\"\n        Raises\n        ------\n        ValueError\n            If the data(storage) does not exist, raise ValueError\n        \"\"\"\n        raise NotImplementedError(\"Subclass of CalendarStorage must implement `index` method\")\n\n    def insert(self, index: int, value: CalVT) -> None:\n        raise NotImplementedError(\"Subclass of CalendarStorage must implement `insert` method\")\n\n    def remove(self, value: CalVT) -> None:\n        raise NotImplementedError(\"Subclass of CalendarStorage must implement `remove` method\")\n\n    @overload\n    def __setitem__(self, i: int, value: CalVT) -> None:\n        \"\"\"x.__setitem__(i, o) <==> (x[i] = o)\"\"\"\n\n    @overload\n    def __setitem__(self, s: slice, value: Iterable[CalVT]) -> None:\n        \"\"\"x.__setitem__(s, o) <==> (x[s] = o)\"\"\"\n\n    def 
__setitem__(self, i, value) -> None:\n        raise NotImplementedError(\n            \"Subclass of CalendarStorage must implement `__setitem__(i: int, o: CalVT)`/`__setitem__(s: slice, o: Iterable[CalVT])`  method\"\n        )\n\n    @overload\n    def __delitem__(self, i: int) -> None:\n        \"\"\"x.__delitem__(i) <==> del x[i]\"\"\"\n\n    @overload\n    def __delitem__(self, i: slice) -> None:\n        \"\"\"x.__delitem__(slice(start: int, stop: int, step: int)) <==> del x[start:stop:step]\"\"\"\n\n    def __delitem__(self, i) -> None:\n        \"\"\"\n        Raises\n        ------\n        ValueError\n            If the data(storage) does not exist, raise ValueError\n        \"\"\"\n        raise NotImplementedError(\n            \"Subclass of CalendarStorage must implement `__delitem__(i: int)`/`__delitem__(s: slice)`  method\"\n        )\n\n    @overload\n    def __getitem__(self, s: slice) -> Iterable[CalVT]:\n        \"\"\"x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]\"\"\"\n\n    @overload\n    def __getitem__(self, i: int) -> CalVT:\n        \"\"\"x.__getitem__(i) <==> x[i]\"\"\"\n\n    def __getitem__(self, i) -> CalVT:\n        \"\"\"\n\n        Raises\n        ------\n        ValueError\n            If the data(storage) does not exist, raise ValueError\n\n        \"\"\"\n        raise NotImplementedError(\n            \"Subclass of CalendarStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)`  method\"\n        )\n\n    def __len__(self) -> int:\n        \"\"\"\n\n        Raises\n        ------\n        ValueError\n            If the data(storage) does not exist, raise ValueError\n\n        \"\"\"\n        raise NotImplementedError(\"Subclass of CalendarStorage must implement `__len__`  method\")\n\n\nclass InstrumentStorage(BaseStorage):\n    def __init__(self, market: str, freq: str, **kwargs):\n        self.market = market\n        self.freq = freq\n        self.kwargs = kwargs\n\n    @property\n  
  def data(self) -> Dict[InstKT, InstVT]:\n        \"\"\"get all data\n\n        Raises\n        ------\n        ValueError\n            If the data(storage) does not exist, raise ValueError\n        \"\"\"\n        raise NotImplementedError(\"Subclass of InstrumentStorage must implement `data` method\")\n\n    def clear(self) -> None:\n        raise NotImplementedError(\"Subclass of InstrumentStorage must implement `clear` method\")\n\n    def update(self, *args, **kwargs) -> None:\n        \"\"\"D.update([E, ]**F) -> None.  Update D from mapping/iterable E and F.\n\n        Notes\n        ------\n            If E present and has a .keys() method, does:     for k in E: D[k] = E[k]\n\n            If E present and lacks .keys() method, does:     for (k, v) in E: D[k] = v\n\n            In either case, this is followed by: for k, v in F.items(): D[k] = v\n\n        \"\"\"\n        raise NotImplementedError(\"Subclass of InstrumentStorage must implement `update` method\")\n\n    def __setitem__(self, k: InstKT, v: InstVT) -> None:\n        \"\"\"Set self[key] to value.\"\"\"\n        raise NotImplementedError(\"Subclass of InstrumentStorage must implement `__setitem__` method\")\n\n    def __delitem__(self, k: InstKT) -> None:\n        \"\"\"Delete self[key].\n\n        Raises\n        ------\n        ValueError\n            If the data(storage) does not exist, raise ValueError\n        \"\"\"\n        raise NotImplementedError(\"Subclass of InstrumentStorage must implement `__delitem__` method\")\n\n    def __getitem__(self, k: InstKT) -> InstVT:\n        \"\"\"x.__getitem__(k) <==> x[k]\"\"\"\n        raise NotImplementedError(\"Subclass of InstrumentStorage must implement `__getitem__` method\")\n\n    def __len__(self) -> int:\n        \"\"\"\n\n        Raises\n        ------\n        ValueError\n            If the data(storage) does not exist, raise ValueError\n\n        \"\"\"\n        raise NotImplementedError(\"Subclass of InstrumentStorage must implement 
`__len__`  method\")\n\n\nclass FeatureStorage(BaseStorage):\n    def __init__(self, instrument: str, field: str, freq: str, **kwargs):\n        self.instrument = instrument\n        self.field = field\n        self.freq = freq\n        self.kwargs = kwargs\n\n    @property\n    def data(self) -> pd.Series:\n        \"\"\"get all data\n\n        Notes\n        ------\n        if data(storage) does not exist, return empty pd.Series: `return pd.Series(dtype=np.float32)`\n        \"\"\"\n        raise NotImplementedError(\"Subclass of FeatureStorage must implement `data` method\")\n\n    @property\n    def start_index(self) -> Union[int, None]:\n        \"\"\"get FeatureStorage start index\n\n        Notes\n        -----\n        If the data(storage) does not exist, return None\n        \"\"\"\n        raise NotImplementedError(\"Subclass of FeatureStorage must implement `start_index` method\")\n\n    @property\n    def end_index(self) -> Union[int, None]:\n        \"\"\"get FeatureStorage end index\n\n        Notes\n        -----\n        The  right index of the data range (both sides are closed)\n\n            The next  data appending point will be  `end_index + 1`\n\n        If the data(storage) does not exist, return None\n        \"\"\"\n        raise NotImplementedError(\"Subclass of FeatureStorage must implement `end_index` method\")\n\n    def clear(self) -> None:\n        raise NotImplementedError(\"Subclass of FeatureStorage must implement `clear` method\")\n\n    def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None):\n        \"\"\"Write data_array to FeatureStorage starting from index.\n\n        Notes\n        ------\n            If index is None, append data_array to feature.\n\n            If len(data_array) == 0; return\n\n            If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan\n\n        Examples\n        ---------\n            .. 
code-block::\n\n                feature:\n                    3   4\n                    4   5\n                    5   6\n\n\n            >>> self.write([6, 7], index=6)\n\n                feature:\n                    3   4\n                    4   5\n                    5   6\n                    6   6\n                    7   7\n\n            >>> self.write([8], index=9)\n\n                feature:\n                    3   4\n                    4   5\n                    5   6\n                    6   6\n                    7   7\n                    8   np.nan\n                    9   8\n\n            >>> self.write([1, np.nan], index=3)\n\n                feature:\n                    3   1\n                    4   np.nan\n                    5   6\n                    6   6\n                    7   7\n                    8   np.nan\n                    9   8\n\n        \"\"\"\n        raise NotImplementedError(\"Subclass of FeatureStorage must implement `write` method\")\n\n    def rebase(self, start_index: int = None, end_index: int = None):\n        \"\"\"Rebase the start_index and end_index of the FeatureStorage.\n\n        start_index and end_index are closed intervals: [start_index, end_index]\n\n        Examples\n        ---------\n\n            .. 
code-block::\n\n                    feature:\n                        3   4\n                        4   5\n                        5   6\n\n\n                >>> self.rebase(start_index=4)\n\n                    feature:\n                        4   5\n                        5   6\n\n                >>> self.rebase(start_index=3)\n\n                    feature:\n                        3   np.nan\n                        4   5\n                        5   6\n\n                >>> self.write([3], index=3)\n\n                    feature:\n                        3   3\n                        4   5\n                        5   6\n\n                >>> self.rebase(end_index=4)\n\n                    feature:\n                        3   3\n                        4   5\n\n                >>> self.write([6, 7, 8], index=4)\n\n                    feature:\n                        3   3\n                        4   6\n                        5   7\n                        6   8\n\n                >>> self.rebase(start_index=4, end_index=5)\n\n                    feature:\n                        4   6\n                        5   7\n\n        \"\"\"\n        storage_si = self.start_index\n        storage_ei = self.end_index\n        if storage_si is None or storage_ei is None:\n            raise ValueError(\"storage.start_index or storage.end_index is None, storage may not exist\")\n\n        start_index = storage_si if start_index is None else start_index\n        end_index = storage_ei if end_index is None else end_index\n\n        if start_index is None or end_index is None:\n            logger.warning(\"both start_index and end_index are None, or storage does not exist; rebase is ignored\")\n            return\n\n        if start_index < 0 or end_index < 0:\n            logger.warning(\"start_index or end_index cannot be less than 0\")\n            return\n        if start_index > end_index:\n            logger.warning(\n                
f\"start_index({start_index}) > end_index({end_index}), rebase is ignored; \"\n                f\"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear\"\n            )\n            return\n\n        if start_index <= storage_si:\n            self.write([np.nan] * (storage_si - start_index), start_index)\n        else:\n            self.rewrite(self[start_index:].values, start_index)\n\n        if end_index >= self.end_index:\n            self.write([np.nan] * (end_index - self.end_index))\n        else:\n            self.rewrite(self[: end_index + 1].values, start_index)\n\n    def rewrite(self, data: Union[List, np.ndarray, Tuple], index: int):\n        \"\"\"overwrite all data in FeatureStorage with data\n\n        Parameters\n        ----------\n        data: Union[List, np.ndarray, Tuple]\n            data\n        index: int\n            data start index\n        \"\"\"\n        self.clear()\n        self.write(data, index)\n\n    @overload\n    def __getitem__(self, s: slice) -> pd.Series:\n        \"\"\"x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]\n\n        Returns\n        -------\n            pd.Series(values, index=pd.RangeIndex(start, start + len(values)))\n        \"\"\"\n\n    @overload\n    def __getitem__(self, i: int) -> Tuple[int, float]:\n        \"\"\"x.__getitem__(y) <==> x[y]\"\"\"\n\n    def __getitem__(self, i) -> Union[Tuple[int, float], pd.Series]:\n        \"\"\"x.__getitem__(y) <==> x[y]\n\n        Notes\n        -------\n        if data(storage) does not exist:\n            if isinstance(i, int):\n                return (None, None)\n            if isinstance(i, slice):\n                # return empty pd.Series\n                return pd.Series(dtype=np.float32)\n        \"\"\"\n        raise NotImplementedError(\n            \"Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method\"\n        )\n\n    def __len__(self) -> int:\n        \"\"\"\n\n  
      Raises\n        ------\n        ValueError\n            If the data(storage) does not exist, raise ValueError\n\n        \"\"\"\n        raise NotImplementedError(\"Subclass of FeatureStorage must implement `__len__`  method\")\n"
  },
  {
    "path": "qlib/log.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nimport logging\nfrom typing import Optional, Text, Dict, Any\nimport re\nfrom logging import config as logging_config\nfrom time import time\nfrom contextlib import contextmanager\n\nfrom .config import C\n\n\nclass MetaLogger(type):\n    def __new__(mcs, name, bases, attrs):  # pylint: disable=C0204\n        wrapper_dict = logging.Logger.__dict__.copy()\n        for key, val in wrapper_dict.items():\n            if key not in attrs and key != \"__reduce__\":\n                attrs[key] = val\n        return type.__new__(mcs, name, bases, attrs)\n\n\nclass QlibLogger(metaclass=MetaLogger):\n    \"\"\"\n    Customized logger for Qlib.\n    \"\"\"\n\n    def __init__(self, module_name):\n        self.module_name = module_name\n        # this feature name conflicts with the attribute with Logger\n        # rename it to avoid some corner cases that result in comparing `str` and `int`\n        self.__level = 0\n\n    @property\n    def logger(self):\n        logger = logging.getLogger(self.module_name)\n        logger.setLevel(self.__level)\n        return logger\n\n    def setLevel(self, level):\n        self.__level = level\n\n    def __getattr__(self, name):\n        # During unpickling, python will call __getattr__. 
Use this line to avoid maximum recursion error.\n        if name in {\"__setstate__\"}:\n            raise AttributeError\n        return self.logger.__getattribute__(name)\n\n\nclass _QLibLoggerManager:\n    def __init__(self):\n        self._loggers = {}\n\n    def setLevel(self, level):\n        for logger in self._loggers.values():\n            logger.setLevel(level)\n\n    def __call__(self, module_name, level: Optional[int] = None) -> QlibLogger:\n        \"\"\"\n        Get a logger for a specific module.\n\n        :param module_name: str\n            Logic module name.\n        :param level: int\n        :return: Logger\n            Logger object.\n        \"\"\"\n        if level is None:\n            level = C.logging_level\n\n        if not module_name.startswith(\"qlib.\"):\n            # Add a prefix of qlib. when the requested ``module_name`` doesn't start with ``qlib.``.\n            # If the module_name is already qlib.xxx, we do not format here. Otherwise, it will become qlib.qlib.xxx.\n            module_name = \"qlib.{}\".format(module_name)\n\n        # Get logger.\n        module_logger = self._loggers.setdefault(module_name, QlibLogger(module_name))\n        module_logger.setLevel(level)\n        return module_logger\n\n\nget_module_logger = _QLibLoggerManager()\n\n\nclass TimeInspector:\n    timer_logger = get_module_logger(\"timer\")\n\n    time_marks = []\n\n    @classmethod\n    def set_time_mark(cls):\n        \"\"\"\n        Set a time mark with current time, and this time mark will push into a stack.\n        :return: float\n            A timestamp for current time.\n        \"\"\"\n        _time = time()\n        cls.time_marks.append(_time)\n        return _time\n\n    @classmethod\n    def pop_time_mark(cls):\n        \"\"\"\n        Pop last time mark from stack.\n        \"\"\"\n        return cls.time_marks.pop()\n\n    @classmethod\n    def get_cost_time(cls):\n        \"\"\"\n        Get last time mark from stack, calculate 
time diff with current time.\n        :return: float\n            Time diff calculated by last time mark with current time.\n        \"\"\"\n        cost_time = time() - cls.time_marks.pop()\n        return cost_time\n\n    @classmethod\n    def log_cost_time(cls, info=\"Done\"):\n        \"\"\"\n        Get last time mark from stack, calculate time diff with current time, and log time diff and info.\n        :param info: str\n            Info that will be logged into stdout.\n        \"\"\"\n        cost_time = time() - cls.time_marks.pop()\n        cls.timer_logger.info(\"Time cost: {0:.3f}s | {1}\".format(cost_time, info))\n\n    @classmethod\n    @contextmanager\n    def logt(cls, name=\"\", show_start=False):\n        \"\"\"logt.\n        Log the time of the inside code\n\n        Parameters\n        ----------\n        name :\n            name\n        show_start :\n            show_start\n        \"\"\"\n        if show_start:\n            cls.timer_logger.info(f\"{name} Begin\")\n        cls.set_time_mark()\n        try:\n            yield None\n        finally:\n            pass\n        cls.log_cost_time(info=f\"{name} Done\")\n\n\ndef set_log_with_config(log_config: Dict[Text, Any]):\n    \"\"\"set log with config\n\n    :param log_config:\n    :return:\n    \"\"\"\n    logging_config.dictConfig(log_config)\n\n\nclass LogFilter(logging.Filter):\n    def __init__(self, param=None):\n        super().__init__()\n        self.param = param\n\n    @staticmethod\n    def match_msg(filter_str, msg):\n        match = False\n        try:\n            if re.match(filter_str, msg):\n                match = True\n        except Exception:\n            pass\n        return match\n\n    def filter(self, record):\n        allow = True\n        if isinstance(self.param, str):\n            allow = not self.match_msg(self.param, record.msg)\n        elif isinstance(self.param, list):\n            allow = not any(self.match_msg(p, record.msg) for p in self.param)\n        
return allow\n\n\ndef set_global_logger_level(level: int, return_orig_handler_level: bool = False):\n    \"\"\"set qlib.xxx logger handlers level\n\n    Parameters\n    ----------\n    level: int\n        logger level\n\n    return_orig_handler_level: bool\n        return origin handler level map\n\n    Examples\n    ---------\n\n        .. code-block:: python\n\n            import qlib\n            import logging\n            from qlib.log import get_module_logger, set_global_logger_level\n            qlib.init()\n\n            tmp_logger_01 = get_module_logger(\"tmp_logger_01\", level=logging.INFO)\n            tmp_logger_01.info(\"1. tmp_logger_01 info show\")\n\n            global_level = logging.WARNING + 1\n            set_global_logger_level(global_level)\n            tmp_logger_02 = get_module_logger(\"tmp_logger_02\", level=logging.INFO)\n            tmp_logger_02.log(msg=\"2. tmp_logger_02 log show\", level=global_level)\n\n            tmp_logger_01.info(\"3. tmp_logger_01 info do not show\")\n\n    \"\"\"\n    _handler_level_map = {}\n    qlib_logger = logging.root.manager.loggerDict.get(\"qlib\", None)  # pylint: disable=E1101\n    if qlib_logger is not None:\n        for _handler in qlib_logger.handlers:\n            _handler_level_map[_handler] = _handler.level\n            _handler.level = level\n    return _handler_level_map if return_orig_handler_level else None\n\n\n@contextmanager\ndef set_global_logger_level_cm(level: int):\n    \"\"\"set qlib.xxx logger handlers level to use contextmanager\n\n    Parameters\n    ----------\n    level: int\n        logger level\n\n    Examples\n    ---------\n\n        .. code-block:: python\n\n            import qlib\n            import logging\n            from qlib.log import get_module_logger, set_global_logger_level_cm\n            qlib.init()\n\n            tmp_logger_01 = get_module_logger(\"tmp_logger_01\", level=logging.INFO)\n            tmp_logger_01.info(\"1. 
tmp_logger_01 info show\")\n\n            global_level = logging.WARNING + 1\n            with set_global_logger_level_cm(global_level):\n                tmp_logger_02 = get_module_logger(\"tmp_logger_02\", level=logging.INFO)\n                tmp_logger_02.log(msg=\"2. tmp_logger_02 log show\", level=global_level)\n                tmp_logger_01.info(\"3. tmp_logger_01 info do not show\")\n\n            tmp_logger_01.info(\"4. tmp_logger_01 info show\")\n\n    \"\"\"\n    _handler_level_map = set_global_logger_level(level, return_orig_handler_level=True)\n    try:\n        yield\n    finally:\n        for _handler, _level in _handler_level_map.items():\n            _handler.level = _level\n"
  },
  {
    "path": "qlib/model/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport warnings\n\nfrom .base import Model\n\n__all__ = [\"Model\", \"warnings\"]\n"
  },
  {
    "path": "qlib/model/base.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport abc\nfrom typing import Text, Union\nfrom ..utils.serial import Serializable\nfrom ..data.dataset import Dataset\nfrom ..data.dataset.weight import Reweighter\n\n\nclass BaseModel(Serializable, metaclass=abc.ABCMeta):\n    \"\"\"Modeling things\"\"\"\n\n    @abc.abstractmethod\n    def predict(self, *args, **kwargs) -> object:\n        \"\"\"Make predictions after modeling things\"\"\"\n\n    def __call__(self, *args, **kwargs) -> object:\n        \"\"\"Leverage Python syntactic sugar so that a model can be called like a function\"\"\"\n        return self.predict(*args, **kwargs)\n\n\nclass Model(BaseModel):\n    \"\"\"Learnable Models\"\"\"\n\n    def fit(self, dataset: Dataset, reweighter: Reweighter):\n        \"\"\"\n        Learn the model from the given dataset.\n\n        .. note::\n\n            The attribute names of the learned model should `not` start with '_', so that the model can be\n            dumped to disk.\n\n        The following code example shows how to retrieve `x_train`, `y_train` and `w_train` from the `dataset`:\n\n            .. 
code-block:: Python\n\n                import numpy as np\n                import pandas as pd\n                from qlib.data.dataset.handler import DataHandlerLP\n\n                # get features and labels\n                df_train, df_valid = dataset.prepare(\n                    [\"train\", \"valid\"], col_set=[\"feature\", \"label\"], data_key=DataHandlerLP.DK_L\n                )\n                x_train, y_train = df_train[\"feature\"], df_train[\"label\"]\n                x_valid, y_valid = df_valid[\"feature\"], df_valid[\"label\"]\n\n                # get weights (fall back to uniform weights if there is no weight column)\n                try:\n                    wdf_train, wdf_valid = dataset.prepare([\"train\", \"valid\"], col_set=[\"weight\"],\n                                                           data_key=DataHandlerLP.DK_L)\n                    w_train, w_valid = wdf_train[\"weight\"], wdf_valid[\"weight\"]\n                except KeyError:\n                    w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)\n                    w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)\n\n        Parameters\n        ----------\n        dataset : Dataset\n            the dataset that generates the processed data for model training.\n\n        reweighter : Reweighter\n            the reweighter used to compute the sample weights for training.\n\n        \"\"\"\n        raise NotImplementedError()\n\n    @abc.abstractmethod\n    def predict(self, dataset: Dataset, segment: Union[Text, slice] = \"test\") -> object:\n        \"\"\"Make predictions on the given dataset.\n\n        Parameters\n        ----------\n        dataset : Dataset\n            the dataset that generates the processed data for prediction.\n\n        segment : Text or slice\n            the dataset uses this segment to prepare data. 
(default=test)\n\n        Returns\n        -------\n        Prediction results of a certain type, such as `pandas.Series`.\n        \"\"\"\n        raise NotImplementedError()\n\n\nclass ModelFT(Model):\n    \"\"\"Model (F)ine(t)unable\"\"\"\n\n    @abc.abstractmethod\n    def finetune(self, dataset: Dataset):\n        \"\"\"Finetune the model on the given dataset.\n\n        A typical use case of finetuning a model with `qlib.workflow.R`:\n\n        .. code-block:: python\n\n            from qlib.workflow import R\n\n            # start an experiment to train the initial model\n            with R.start(experiment_name=\"init models\"):\n                model.fit(dataset)\n                R.save_objects(init_model=model)\n                rid = R.get_recorder().id\n\n            # finetune the model based on the previously trained model\n            with R.start(experiment_name=\"finetune model\"):\n                recorder = R.get_recorder(recorder_id=rid, experiment_name=\"init models\")\n                model = recorder.load_object(\"init_model\")\n                # extra keyword arguments such as `num_boost_round` depend on the concrete model (e.g. LightGBM)\n                model.finetune(dataset, num_boost_round=10)\n\n\n        Parameters\n        ----------\n        dataset : Dataset\n            the dataset that generates the processed data for finetuning.\n        \"\"\"\n        raise NotImplementedError()\n"
  },
  {
    "path": "qlib/model/ens/__init__.py",
    "content": ""
  },
  {
    "path": "qlib/model/ens/ensemble.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nThe ensemble module merges objects into an ensemble. For example, if there are predictions from many submodels, we may need to merge them into an ensemble prediction.\n\"\"\"\n\nfrom typing import Union\nimport pandas as pd\nfrom qlib.utils import FLATTEN_TUPLE, flatten_dict\nfrom qlib.log import get_module_logger\n\n\nclass Ensemble:\n    \"\"\"Merge the ensemble_dict into an ensemble object.\n\n    For example: {Rollinga_b: object, Rollingb_c: object} -> object\n\n    When calling this class:\n\n        Args:\n            ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging\n\n        Returns:\n            object: the ensemble object\n    \"\"\"\n\n    def __call__(self, ensemble_dict: dict, *args, **kwargs):\n        raise NotImplementedError(\"Please implement the `__call__` method.\")\n\n\nclass SingleKeyEnsemble(Ensemble):\n    \"\"\"\n    Extract the value if the dict has exactly one key, which makes the result more readable.\n    {Only key: Only value} -> Only value\n\n    If the dict has more or fewer than one key, it is returned unchanged.\n    This can also be applied recursively to make nested dicts more readable.\n\n    NOTE: It runs recursively by default.\n\n    When calling this class:\n\n        Args:\n            ensemble_dict (dict): the dict. 
The key of the dict will be ignored.\n\n        Returns:\n            dict: the readable dict.\n    \"\"\"\n\n    def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object:\n        if not isinstance(ensemble_dict, dict):\n            return ensemble_dict\n        if recursion:\n            tmp_dict = {}\n            for k, v in ensemble_dict.items():\n                tmp_dict[k] = self(v, recursion)\n            ensemble_dict = tmp_dict\n        keys = list(ensemble_dict.keys())\n        if len(keys) == 1:\n            ensemble_dict = ensemble_dict[keys[0]]\n        return ensemble_dict\n\n\nclass RollingEnsemble(Ensemble):\n    \"\"\"Merge a dict of rolling dataframes such as `prediction` or `IC` into an ensemble.\n\n    NOTE: The values of the dict must be pd.DataFrame objects with a \"datetime\" index level.\n\n    When calling this class:\n\n        Args:\n            ensemble_dict (dict): a dict like {\"A\": pd.DataFrame, \"B\": pd.DataFrame}.\n            The key of the dict will be ignored.\n\n        Returns:\n            pd.DataFrame: the complete result of rolling.\n    \"\"\"\n\n    def __call__(self, ensemble_dict: dict) -> pd.DataFrame:\n        get_module_logger(\"RollingEnsemble\").info(f\"keys in group: {list(ensemble_dict.keys())}\")\n        artifact_list = list(ensemble_dict.values())\n        artifact_list.sort(key=lambda x: x.index.get_level_values(\"datetime\").min())\n        artifact = pd.concat(artifact_list)\n        # If there are duplicated predictions, keep the latest one (from the artifact that starts later)\n        artifact = artifact[~artifact.index.duplicated(keep=\"last\")]\n        artifact = artifact.sort_index()\n        return artifact\n\n\nclass AverageEnsemble(Ensemble):\n    \"\"\"\n    Average and standardize a dict of same-shape dataframes such as `prediction` or `IC` into an ensemble.\n\n    NOTE: The values of the dict must be pd.DataFrame objects with a \"datetime\" index level. 
If it is a nested dict, it is flattened first.\n\n    When calling this class:\n\n        Args:\n            ensemble_dict (dict): a dict like {\"A\": pd.DataFrame, \"B\": pd.DataFrame}.\n            The key of the dict will be ignored.\n\n        Returns:\n            pd.Series: the complete result of averaging and standardizing.\n    \"\"\"\n\n    def __call__(self, ensemble_dict: dict) -> pd.Series:\n        \"\"\"Usage example::\n\n            from qlib.model.ens.ensemble import AverageEnsemble\n            pred_res[\"new_key_name\"] = AverageEnsemble()(predict_dict)\n\n        Parameters\n        ----------\n        ensemble_dict : dict\n            the (possibly nested) dict to ensemble\n\n        Returns\n        -------\n        pd.Series\n            the ensembling result: the values are standardized per datetime and then averaged across members\n        \"\"\"\n        # need to flatten the nested dict\n        ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)\n        get_module_logger(\"AverageEnsemble\").info(f\"keys in group: {list(ensemble_dict.keys())}\")\n        values = list(ensemble_dict.values())\n        # NOTE: this changes the type of the underlying data\n        # from pd.DataFrame to pd.Series\n        results = pd.concat(values, axis=1)\n        results = results.groupby(\"datetime\", group_keys=False).apply(lambda df: (df - df.mean()) / df.std())\n        results = results.mean(axis=1)\n        results = results.sort_index()\n        return results\n"
  },
  {
    "path": "qlib/model/ens/group.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nGroup can group a set of objects based on `group_func` and change them to a dict.\nAfter grouping, we provide a method to reduce them.\n\nFor example:\n\ngroup: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}\nreduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}\n\n\"\"\"\n\nfrom qlib.model.ens.ensemble import Ensemble, RollingEnsemble\nfrom typing import Callable\nfrom joblib import Parallel, delayed\n\n\nclass Group:\n    \"\"\"Group the objects based on dict\"\"\"\n\n    def __init__(self, group_func=None, ens: Ensemble = None):\n        \"\"\"\n        Init Group.\n\n        Args:\n            group_func (Callable, optional): a function that takes the ungrouped dict and returns the grouped dict.\n\n                For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}\n\n                Defaults to None.\n\n            ens (Ensemble, optional): if not None, ensemble each grouped value after grouping. Defaults to None.\n        \"\"\"\n        self._group_func = group_func\n        self._ens_func = ens\n\n    def group(self, *args, **kwargs) -> dict:\n        \"\"\"\n        Group a set of objects and change them to a dict.\n\n        For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}\n\n        Returns:\n            dict: grouped dict\n        \"\"\"\n        if isinstance(getattr(self, \"_group_func\", None), Callable):\n            return self._group_func(*args, **kwargs)\n        else:\n            raise NotImplementedError(\"Please specify a valid `group_func`.\")\n\n    def reduce(self, *args, **kwargs) -> dict:\n        \"\"\"\n        Reduce grouped dict.\n\n        For example: {(A,B): {C1: object, C2: object}} -> {(A,B): object}\n\n        Returns:\n            dict: reduced dict\n        \"\"\"\n        if isinstance(getattr(self, \"_ens_func\", None), Callable):\n            
return self._ens_func(*args, **kwargs)\n        else:\n            raise NotImplementedError(\"Please specify a valid `ens`.\")\n\n    def __call__(self, ungrouped_dict: dict, n_jobs: int = 1, verbose: int = 0, *args, **kwargs) -> dict:\n        \"\"\"\n        Group the ungrouped_dict into different groups and reduce each group.\n\n        Args:\n            ungrouped_dict (dict): the ungrouped dict waiting for grouping, like {name: things}\n            n_jobs (int): the number of parallel jobs used to reduce the groups.\n            verbose (int): the verbosity level passed to `joblib.Parallel`.\n\n        Returns:\n            dict: grouped_dict like {G1: object, G2: object}\n        \"\"\"\n\n        # NOTE: multiprocessing may raise an error if the objects use `Serializable`,\n        # because `Serializable` changes the behavior of pickle\n        grouped_dict = self.group(ungrouped_dict, *args, **kwargs)\n\n        key_l = []\n        job_l = []\n        for key, value in grouped_dict.items():\n            key_l.append(key)\n            job_l.append(delayed(Group.reduce)(self, value))\n        return dict(zip(key_l, Parallel(n_jobs=n_jobs, verbose=verbose)(job_l)))\n\n\nclass RollingGroup(Group):\n    \"\"\"Group the rolling dict\"\"\"\n\n    def group(self, rolling_dict: dict) -> dict:\n        \"\"\"Given a rolling dict like {(A,B,R): things}, return a grouped dict like {(A,B): {R: things}}\n\n        NOTE: This assumes that the rolling key is the last element of the key tuple, because the rolling results always need to be ensembled first.\n\n        Args:\n            rolling_dict (dict): a rolling dict. 
Every key must be a tuple; otherwise a `TypeError` is raised.\n\n        Returns:\n            dict: grouped dict\n        \"\"\"\n        grouped_dict = {}\n        for key, values in rolling_dict.items():\n            if isinstance(key, tuple):\n                grouped_dict.setdefault(key[:-1], {})[key[-1]] = values\n            else:\n                raise TypeError(f\"Expected `tuple` type, but got a value `{key}`\")\n        return grouped_dict\n\n    def __init__(self, ens=RollingEnsemble()):\n        super().__init__(ens=ens)\n"
  },
  {
    "path": "qlib/model/interpret/__init__.py",
    "content": ""
  },
  {
    "path": "qlib/model/interpret/base.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\n\"\"\"\nInterfaces to interpret models\n\"\"\"\n\nimport pandas as pd\nfrom abc import abstractmethod\n\n\nclass FeatureInt:\n    \"\"\"Feature (Int)erpreter\"\"\"\n\n    @abstractmethod\n    def get_feature_importance(self) -> pd.Series:\n        \"\"\"get feature importance\n\n        Returns\n        -------\n        pd.Series\n            The index is the feature name.\n\n            The greater the value, the higher the importance.\n        \"\"\"\n\n\nclass LightGBMFInt(FeatureInt):\n    \"\"\"LightGBM (F)eature (Int)erpreter\"\"\"\n\n    def __init__(self):\n        self.model = None\n\n    def get_feature_importance(self, *args, **kwargs) -> pd.Series:\n        \"\"\"get feature importance, sorted in descending order\n\n        Notes\n        -----\n            `args` and `kwargs` are passed to `lightgbm.Booster.feature_importance`; see:\n            https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance\n        \"\"\"\n        return pd.Series(\n            self.model.feature_importance(*args, **kwargs), index=self.model.feature_name()\n        ).sort_values(  # pylint: disable=E1101\n            ascending=False\n        )\n"
  },
  {
    "path": "qlib/model/meta/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .task import MetaTask\nfrom .dataset import MetaTaskDataset\n\n__all__ = [\"MetaTask\", \"MetaTaskDataset\"]\n"
  },
  {
    "path": "qlib/model/meta/dataset.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport abc\nfrom qlib.model.meta.task import MetaTask\nfrom typing import Dict, Union, List, Tuple, Text\nfrom ...utils.serial import Serializable\n\n\nclass MetaTaskDataset(Serializable, metaclass=abc.ABCMeta):\n    \"\"\"\n    A dataset that fetches data at the meta level.\n\n    A meta-dataset is responsible for\n\n    - taking tasks (e.g. Qlib tasks) as input and preparing meta-tasks\n\n        - a meta-task contains more information than a normal task (e.g. input data for the meta-model)\n\n    The learnt patterns could be transferred to other meta-datasets. The following case should be supported:\n\n    - A meta-model trained on meta-dataset A and then applied to meta-dataset B\n\n        - Some patterns are shared between meta-datasets A and B, so the meta-input on meta-dataset A is used when the meta-model is applied on meta-dataset B\n    \"\"\"\n\n    def __init__(self, segments: Union[Dict[Text, Tuple], float], *args, **kwargs):\n        \"\"\"\n        The meta-dataset maintains a list of meta-tasks when it is initialized.\n\n        `segments` indicates how the data is divided.\n\n        The duty of the `__init__` function of MetaTaskDataset:\n\n        - initialize the tasks\n        \"\"\"\n        super().__init__(*args, **kwargs)\n        self.segments = segments\n\n    def prepare_tasks(self, segments: Union[List[Text], Text], *args, **kwargs) -> List[MetaTask]:\n        \"\"\"\n        Prepare the data in each meta-task so that it is ready for training.\n\n        The following code example shows how to retrieve a list of meta-tasks from the `meta_dataset`:\n\n            .. 
code-block:: Python\n\n                # get the train segment and the test segment, both of them are lists\n                train_meta_tasks, test_meta_tasks = meta_dataset.prepare_tasks([\"train\", \"test\"])\n\n        Parameters\n        ----------\n        segments: Union[List[Text], Tuple[Text], Text]\n            the name(s) of the segment(s) to prepare\n\n        Returns\n        -------\n        list:\n            A list of the prepared data of each meta-task for training the meta-model. For multiple segments [seg1, seg2, ... , segN], the returned list will be [[tasks in seg1], [tasks in seg2], ... , [tasks in segN]].\n            For a single segment name, the list of tasks in that segment is returned.\n            Each task is a MetaTask.\n        \"\"\"\n        if isinstance(segments, (list, tuple)):\n            return [self._prepare_seg(seg) for seg in segments]\n        elif isinstance(segments, str):\n            return self._prepare_seg(segments)\n        else:\n            raise NotImplementedError(\"This type of input is not supported\")\n\n    @abc.abstractmethod\n    def _prepare_seg(self, segment: Text):\n        \"\"\"\n        Prepare the meta-tasks of a single segment for training.\n\n        Parameters\n        ----------\n        segment : Text\n            the name of the segment\n        \"\"\"\n"
  },
  {
    "path": "qlib/model/meta/model.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport abc\nfrom typing import List\n\nfrom .dataset import MetaTaskDataset\n\n\nclass MetaModel(metaclass=abc.ABCMeta):\n    \"\"\"\n    The meta-model guiding the model learning.\n\n    The word `Guiding` can be categorized into two types based on the stage of model learning:\n\n    - The definition of learning tasks: please refer to the docs of `MetaTaskModel`\n    - Controlling the learning process of models: please refer to the docs of `MetaGuideModel`\n    \"\"\"\n\n    @abc.abstractmethod\n    def fit(self, *args, **kwargs):\n        \"\"\"\n        The training process of the meta-model.\n        \"\"\"\n\n    @abc.abstractmethod\n    def inference(self, *args, **kwargs) -> object:\n        \"\"\"\n        The inference process of the meta-model.\n\n        Returns\n        -------\n        object:\n            Some information to guide the model learning\n        \"\"\"\n\n\nclass MetaTaskModel(MetaModel):\n    \"\"\"\n    This type of meta-model deals with base task definitions. The meta-model creates tasks for training new base forecasting models after it is trained. 
`prepare_tasks` directly modifies the task definitions.\n    \"\"\"\n\n    def fit(self, meta_dataset: MetaTaskDataset):\n        \"\"\"\n        The MetaTaskModel is expected to get the prepared MetaTasks from meta_dataset\n        and then learn knowledge from the meta-tasks.\n        \"\"\"\n        raise NotImplementedError(\"Please implement the `fit` method\")\n\n    def inference(self, meta_dataset: MetaTaskDataset) -> List[dict]:\n        \"\"\"\n        MetaTaskModel makes inference on the meta_dataset.\n        The MetaTaskModel is expected to get the prepared MetaTasks from meta_dataset.\n        Then it creates modified tasks in Qlib format, which can be executed by the Qlib trainer.\n\n        Returns\n        -------\n        List[dict]:\n            A list of modified task definitions.\n\n        \"\"\"\n        raise NotImplementedError(\"Please implement the `inference` method\")\n\n\nclass MetaGuideModel(MetaModel):\n    \"\"\"\n    This type of meta-model aims to guide the training process of the base model. The meta-model interacts with the base forecasting models during their training process.\n    \"\"\"\n\n    @abc.abstractmethod\n    def fit(self, *args, **kwargs):\n        pass\n\n    @abc.abstractmethod\n    def inference(self, *args, **kwargs):\n        pass\n"
  },
  {
    "path": "qlib/model/meta/task.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom qlib.data.dataset import Dataset\nfrom ...utils import init_instance_by_config\n\n\nclass MetaTask:\n    \"\"\"\n    A single meta-task; a meta-dataset contains a list of them.\n    It serves as a component, as in MetaDatasetDS.\n\n    The data processing differs by mode:\n\n    - the processed input may be different between training and testing\n\n        - When training, the X, y, X_test, y_test in training tasks are necessary (# PROC_MODE_FULL #)\n          but not necessary in test tasks. (# PROC_MODE_TEST #)\n        - When the meta model can be transferred to other datasets, only meta_info is necessary  (# PROC_MODE_TRANSFER #)\n    \"\"\"\n\n    PROC_MODE_FULL = \"full\"\n    PROC_MODE_TEST = \"test\"\n    PROC_MODE_TRANSFER = \"transfer\"\n\n    def __init__(self, task: dict, meta_info: object, mode: str = PROC_MODE_FULL):\n        \"\"\"\n        The `__init__` func is responsible for\n\n        - storing the task\n        - storing the original input data (`meta_info`)\n        - processing the input data for the meta-model\n\n        Parameters\n        ----------\n        task : dict\n            the task to be enhanced by the meta-model\n\n        meta_info : object\n            the input for the meta-model\n\n        mode : str\n            the processing mode, one of PROC_MODE_FULL, PROC_MODE_TEST or PROC_MODE_TRANSFER\n        \"\"\"\n        self.task = task\n        self.meta_info = meta_info  # the original meta input information, it will be processed later\n        self.mode = mode\n\n    def get_dataset(self) -> Dataset:\n        \"\"\"\n        Instantiate the dataset from the `dataset` config of the task\n        \"\"\"\n        return init_instance_by_config(self.task[\"dataset\"], accept_types=Dataset)\n\n    def get_meta_input(self) -> object:\n        \"\"\"\n        Return the **processed** meta_info\n        \"\"\"\n        return self.meta_info\n\n    def __repr__(self):\n        return f\"MetaTask(task={self.task}, meta_info={self.meta_info})\"\n"
  },
  {
    "path": "qlib/model/riskmodel/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .base import RiskModel\nfrom .poet import POETCovEstimator\nfrom .shrink import ShrinkCovEstimator\nfrom .structured import StructuredCovEstimator\n\n__all__ = [\n    \"RiskModel\",\n    \"POETCovEstimator\",\n    \"ShrinkCovEstimator\",\n    \"StructuredCovEstimator\",\n]\n"
  },
  {
    "path": "qlib/model/riskmodel/base.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport inspect\nimport numpy as np\nimport pandas as pd\nfrom typing import Union\n\nfrom qlib.model.base import BaseModel\n\n\nclass RiskModel(BaseModel):\n    \"\"\"Risk Model\n\n    A risk model is used to estimate the covariance matrix of stock returns.\n    \"\"\"\n\n    MASK_NAN = \"mask\"\n    FILL_NAN = \"fill\"\n    IGNORE_NAN = \"ignore\"\n\n    def __init__(self, nan_option: str = \"ignore\", assume_centered: bool = False, scale_return: bool = True):\n        \"\"\"\n        Args:\n            nan_option (str): nan handling option (`ignore`/`mask`/`fill`).\n            assume_centered (bool): whether the data is assumed to be centered.\n            scale_return (bool): whether to scale returns as percentages (i.e. multiply them by 100).\n        \"\"\"\n        # nan\n        assert nan_option in [\n            self.MASK_NAN,\n            self.FILL_NAN,\n            self.IGNORE_NAN,\n        ], f\"`nan_option={nan_option}` is not supported\"\n        self.nan_option = nan_option\n\n        self.assume_centered = assume_centered\n        self.scale_return = scale_return\n\n    def predict(\n        self,\n        X: Union[pd.Series, pd.DataFrame, np.ndarray],\n        return_corr: bool = False,\n        is_price: bool = True,\n        return_decomposed_components=False,\n    ) -> Union[pd.DataFrame, np.ndarray, tuple]:\n        \"\"\"\n        Args:\n            X (pd.Series, pd.DataFrame or np.ndarray): data from which to estimate the covariance,\n                with variables as columns and observations as rows.\n            return_corr (bool): whether to return the correlation matrix.\n            is_price (bool): whether `X` contains prices (if not, `X` is assumed to contain stock returns).\n            return_decomposed_components (bool): whether to return the decomposed components of the covariance matrix.\n\n        Returns:\n            pd.DataFrame, np.ndarray or tuple: estimated covariance (or correlation), or the decomposed\n                components `(F, cov_b, var_u)` if `return_decomposed_components` is True.\n        
\"\"\"\n        assert (\n            not return_corr or not return_decomposed_components\n        ), \"Can only return either the correlation matrix or the decomposed components.\"\n\n        # transform input into 2D array\n        if not isinstance(X, (pd.Series, pd.DataFrame)):\n            columns = None\n        else:\n            if isinstance(X.index, pd.MultiIndex):\n                if isinstance(X, pd.DataFrame):\n                    X = X.iloc[:, 0].unstack(level=\"instrument\")  # always use the first column\n                else:\n                    X = X.unstack(level=\"instrument\")\n            else:\n                # X is a 2D DataFrame\n                pass\n            columns = X.columns  # will be used to restore dataframe\n            X = X.values\n\n        # calculate pct_change\n        if is_price:\n            X = X[1:] / X[:-1] - 1  # NOTE: this results in `n - 1` rows\n\n        # scale return (not in place, so that the caller's array is not modified)\n        if self.scale_return:\n            X = X * 100\n\n        # handle nan and center the data\n        X = self._preprocess(X)\n\n        # return decomposed components if needed\n        if return_decomposed_components:\n            assert (\n                \"return_decomposed_components\" in inspect.getfullargspec(self._predict).args\n            ), \"This risk model does not support returning the decomposed components of the covariance matrix\"\n\n            F, cov_b, var_u = self._predict(X, return_decomposed_components=True)  # pylint: disable=E1123\n            return F, cov_b, var_u\n\n        # estimate covariance\n        S = self._predict(X)\n\n        # return correlation if needed\n        if return_corr:\n            vola = np.sqrt(np.diag(S))\n            corr = S / np.outer(vola, vola)\n            if columns is None:\n                return corr\n            return pd.DataFrame(corr, index=columns, columns=columns)\n\n        # return covariance\n        if columns is None:\n            return S\n        return pd.DataFrame(S, index=columns, 
columns=columns)\n\n    def _predict(self, X: np.ndarray) -> np.ndarray:\n        \"\"\"covariance estimation implementation\n\n        This method should be overridden by child classes.\n\n        By default, this method implements the empirical covariance estimation. It assumes that `X` has\n        already been centered (see `_preprocess`) and normalizes by the number of samples `N` rather than `N - 1`.\n\n        Args:\n            X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows).\n\n        Returns:\n            np.ndarray: covariance matrix.\n        \"\"\"\n        xTx = np.asarray(X.T.dot(X))\n        N = len(X)\n        if isinstance(X, np.ma.MaskedArray):\n            M = 1 - X.mask\n            N = M.T.dot(M)  # each pair of variables may have a different number of valid samples\n        return xTx / N\n\n    def _preprocess(self, X: np.ndarray) -> Union[np.ndarray, np.ma.MaskedArray]:\n        \"\"\"handle nan and center the data (unless `assume_centered` is True)\n\n        Note:\n            if `nan_option='mask'` then the returned array will be `np.ma.MaskedArray`.\n        \"\"\"\n        # handle nan\n        if self.nan_option == self.FILL_NAN:\n            X = np.nan_to_num(X)\n        elif self.nan_option == self.MASK_NAN:\n            X = np.ma.masked_invalid(X)\n        # center the data\n        if not self.assume_centered:\n            X = X - np.nanmean(X, axis=0)\n        return X\n"
  },
  {
    "path": "qlib/model/riskmodel/poet.py",
    "content": "import numpy as np\n\nfrom qlib.model.riskmodel import RiskModel\n\n\nclass POETCovEstimator(RiskModel):\n    \"\"\"Principal Orthogonal Complement Thresholding Estimator (POET)\n\n    Reference:\n        [1] Fan, J., Liao, Y., & Mincheva, M. (2013). Large covariance estimation by thresholding principal orthogonal complements.\n            Journal of the Royal Statistical Society. Series B: Statistical Methodology, 75(4), 603–680. https://doi.org/10.1111/rssb.12016\n        [2] http://econweb.rutgers.edu/yl1114/papers/poet/POET.m\n    \"\"\"\n\n    THRESH_SOFT = \"soft\"\n    THRESH_HARD = \"hard\"\n    THRESH_SCAD = \"scad\"\n\n    def __init__(self, num_factors: int = 0, thresh: float = 1.0, thresh_method: str = \"soft\", **kwargs):\n        \"\"\"\n        Args:\n            num_factors (int): number of factors (if set to zero, no factor model will be used).\n            thresh (float): the non-negative constant for thresholding.\n            thresh_method (str): thresholding method, which can be\n                - 'soft': soft thresholding.\n                - 'hard': hard thresholding.\n                - 'scad': SCAD (smoothly clipped absolute deviation) thresholding.\n            kwargs: see `RiskModel` for more information.\n        \"\"\"\n        super().__init__(**kwargs)\n\n        assert num_factors >= 0, \"`num_factors` requires a non-negative integer\"\n        self.num_factors = num_factors\n\n        assert thresh >= 0, \"`thresh` requires a non-negative float number\"\n        self.thresh = thresh\n\n        assert thresh_method in [\n            self.THRESH_HARD,\n            self.THRESH_SOFT,\n            self.THRESH_SCAD,\n        ], \"`thresh_method` should be `soft`/`hard`/`scad`\"\n        self.thresh_method = thresh_method\n\n    def _predict(self, X: np.ndarray) -> np.ndarray:\n        Y = X.T  # NOTE: transpose (variables as rows) to match the reference POET implementation\n        p, n = Y.shape\n\n        if self.num_factors > 0:\n            Dd, V = np.linalg.eig(Y.T.dot(Y))\n            V = V[:, 
np.argsort(Dd)]\n            F = V[:, -self.num_factors :][:, ::-1] * np.sqrt(n)\n            LamPCA = Y.dot(F) / n\n            uhat = np.asarray(Y - LamPCA.dot(F.T))\n            Lowrank = np.asarray(LamPCA.dot(LamPCA.T))\n            rate = 1 / np.sqrt(p) + np.sqrt(np.log(p) / n)\n        else:\n            uhat = np.asarray(Y)\n            rate = np.sqrt(np.log(p) / n)\n            Lowrank = 0\n\n        lamb = rate * self.thresh\n        SuPCA = uhat.dot(uhat.T) / n\n        SuDiag = np.diag(np.diag(SuPCA))\n        R = np.linalg.inv(SuDiag**0.5).dot(SuPCA).dot(np.linalg.inv(SuDiag**0.5))\n\n        if self.thresh_method == self.THRESH_HARD:\n            M = R * (np.abs(R) > lamb)\n        elif self.thresh_method == self.THRESH_SOFT:\n            res = np.abs(R) - lamb\n            res = (res + np.abs(res)) / 2\n            M = np.sign(R) * res\n        else:\n            M1 = (np.abs(R) < 2 * lamb) * np.sign(R) * (np.abs(R) - lamb) * (np.abs(R) > lamb)\n            M2 = (np.abs(R) < 3.7 * lamb) * (np.abs(R) >= 2 * lamb) * (2.7 * R - 3.7 * np.sign(R) * lamb) / 1.7\n            M3 = (np.abs(R) >= 3.7 * lamb) * R\n            M = M1 + M2 + M3\n\n        Rthresh = M - np.diag(np.diag(M)) + np.eye(p)\n        SigmaU = (SuDiag**0.5).dot(Rthresh).dot(SuDiag**0.5)\n        SigmaY = SigmaU + Lowrank\n\n        return SigmaY\n"
  },
  {
    "path": "qlib/model/riskmodel/shrink.py",
    "content": "import numpy as np\nfrom typing import Union\n\nfrom qlib.model.riskmodel import RiskModel\n\n\nclass ShrinkCovEstimator(RiskModel):\n    \"\"\"Shrinkage Covariance Estimator\n\n    This estimator shrinks the sample covariance matrix towards\n    a structured target matrix:\n        S_hat = (1 - alpha) * S + alpha * F\n    where `alpha` is the shrinking parameter and `F` is the shrinking target.\n\n    The following shrinking parameters (`alpha`) are supported:\n        - `lw` [1][2][3]: use the Ledoit-Wolf shrinking parameter.\n        - `oas` [4]: use the Oracle Approximating Shrinkage (OAS) shrinking parameter.\n        - float: directly specify the shrinking parameter, which should be in [0, 1].\n\n    The following shrinking targets (`F`) are supported:\n        - `const_var` [1][4][5]: assume stocks have the same constant variance and zero correlation.\n        - `const_corr` [2][6]: assume stocks have different variances but equal pairwise correlation.\n        - `single_factor` [3][7]: use a single factor (market) model as the shrinking target.\n        - np.ndarray: provide the shrinking target directly.\n\n    Note:\n        - The optimal shrinking parameter depends on the selection of the shrinking target.\n            Currently, `oas` only supports the `const_var` target, and `lw` does not support an np.ndarray target.\n        - Remember to set `nan_option` to `fill` or `mask` if your data has missing values.\n\n    References:\n        [1] Ledoit, O., & Wolf, M. (2004). A well-conditioned estimator for large-dimensional covariance matrices.\n            Journal of Multivariate Analysis, 88(2), 365–411. https://doi.org/10.1016/S0047-259X(03)00096-4\n        [2] Ledoit, O., & Wolf, M. (2004). Honey, I shrunk the sample covariance matrix.\n            Journal of Portfolio Management, 30(4), 110–119. https://doi.org/10.3905/jpm.2004.110\n        [3] Ledoit, O., & Wolf, M. (2003). 
Improved estimation of the covariance matrix of stock returns\n            with an application to portfolio selection.\n            Journal of Empirical Finance, 10(5), 603–621. https://doi.org/10.1016/S0927-5398(03)00007-0\n        [4] Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O. (2010). Shrinkage algorithms for MMSE covariance\n            estimation. IEEE Transactions on Signal Processing, 58(10), 5016–5029.\n            https://doi.org/10.1109/TSP.2010.2053029\n        [5] https://www.econ.uzh.ch/dam/jcr:ffffffff-935a-b0d6-0000-00007f64e5b9/cov1para.m.zip\n        [6] https://www.econ.uzh.ch/dam/jcr:ffffffff-935a-b0d6-ffff-ffffde5e2d4e/covCor.m.zip\n        [7] https://www.econ.uzh.ch/dam/jcr:ffffffff-935a-b0d6-0000-0000648dfc98/covMarket.m.zip\n    \"\"\"\n\n    SHR_LW = \"lw\"\n    SHR_OAS = \"oas\"\n\n    TGT_CONST_VAR = \"const_var\"\n    TGT_CONST_CORR = \"const_corr\"\n    TGT_SINGLE_FACTOR = \"single_factor\"\n\n    def __init__(self, alpha: Union[str, float] = 0.0, target: Union[str, np.ndarray] = \"const_var\", **kwargs):\n        \"\"\"\n        Args:\n            alpha (str or float): shrinking parameter in [0, 1], or the name of its estimator (`lw`/`oas`)\n            target (str or np.ndarray): shrinking target (`const_var`/`const_corr`/`single_factor`), or the target matrix itself\n            kwargs: see `RiskModel` for more information\n        \"\"\"\n        super().__init__(**kwargs)\n\n        # alpha\n        if isinstance(alpha, str):\n            assert alpha in [self.SHR_LW, self.SHR_OAS], f\"shrinking method `{alpha}` is not supported\"\n        elif isinstance(alpha, (int, float, np.integer, np.floating)):\n            assert 0 <= alpha <= 1, \"`alpha` should be in [0, 1]\"\n        else:\n            raise TypeError(\"invalid argument type for `alpha`\")\n        self.alpha = alpha\n\n        # target\n        if isinstance(target, str):\n            assert target in [\n                self.TGT_CONST_VAR,\n                self.TGT_CONST_CORR,\n                self.TGT_SINGLE_FACTOR,\n            ], 
f\"shrinking target `{target}` is not supported\"\n        elif isinstance(target, np.ndarray):\n            pass\n        else:\n            raise TypeError(\"invalid argument type for `target`\")\n        # check the type first: comparing an ndarray with a string is elementwise\n        if alpha == self.SHR_OAS and not (isinstance(target, str) and target == self.TGT_CONST_VAR):\n            raise NotImplementedError(\"currently `oas` can only support `const_var` as target\")\n        if alpha == self.SHR_LW and isinstance(target, np.ndarray):\n            raise NotImplementedError(\"currently `lw` does not support a custom `np.ndarray` target\")\n        self.target = target\n\n    def _predict(self, X: np.ndarray) -> np.ndarray:\n        # sample covariance\n        S = super()._predict(X)\n\n        # shrinking target\n        F = self._get_shrink_target(X, S)\n\n        # get shrinking parameter\n        alpha = self._get_shrink_param(X, S, F)\n\n        # shrink covariance out of place, so that a user-provided target array is not modified\n        if alpha > 0:\n            S = (1 - alpha) * S + alpha * F\n\n        return S\n\n    def _get_shrink_target(self, X: np.ndarray, S: np.ndarray) -> np.ndarray:\n        \"\"\"get shrinking target `F`\"\"\"\n        if isinstance(self.target, np.ndarray):\n            return self.target\n        if self.target == self.TGT_CONST_VAR:\n            return self._get_shrink_target_const_var(X, S)\n        if self.target == self.TGT_CONST_CORR:\n            return self._get_shrink_target_const_corr(X, S)\n        if self.target == self.TGT_SINGLE_FACTOR:\n            return self._get_shrink_target_single_factor(X, S)\n        raise ValueError(f\"shrinking target `{self.target}` is not supported\")\n\n    def _get_shrink_target_const_var(self, X: np.ndarray, S: np.ndarray) -> np.ndarray:\n        \"\"\"get shrinking target with constant variance\n\n        This target assumes zero pairwise correlation and constant variance.\n        The constant variance is estimated by averaging the sample variances of all variables.\n        \"\"\"\n        n = len(S)\n        F = np.eye(n)\n        np.fill_diagonal(F, np.mean(np.diag(S)))\n        return F\n\n    def _get_shrink_target_const_corr(self, X: np.ndarray, S: np.ndarray) -> np.ndarray:\n        \"\"\"get shrinking target with constant correlation\n\n        This target assumes constant pairwise correlation but keeps the sample variances.\n        
The constant correlation is estimated by averaging all pairwise correlations.\n        \"\"\"\n        n = len(S)\n        var = np.diag(S)\n        sqrt_var = np.sqrt(var)\n        covar = np.outer(sqrt_var, sqrt_var)\n        r_bar = (np.sum(S / covar) - n) / (n * (n - 1))\n        F = r_bar * covar\n        np.fill_diagonal(F, var)\n        return F\n\n    def _get_shrink_target_single_factor(self, X: np.ndarray, S: np.ndarray) -> np.ndarray:\n        \"\"\"get shrinking target with a single factor model, using the cross-sectional mean return as the market factor\"\"\"\n        X_mkt = np.nanmean(X, axis=1)\n        cov_mkt = np.asarray(X.T.dot(X_mkt) / len(X))\n        var_mkt = np.asarray(X_mkt.dot(X_mkt) / len(X))\n        F = np.outer(cov_mkt, cov_mkt) / var_mkt\n        np.fill_diagonal(F, np.diag(S))\n        return F\n\n    def _get_shrink_param(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float:\n        \"\"\"get shrinking parameter `alpha`\n\n        Note:\n            The Ledoit-Wolf shrinking parameter has a separate estimator for each of the three built-in targets.\n        \"\"\"\n        if self.alpha == self.SHR_OAS:\n            return self._get_shrink_param_oas(X, S, F)\n        elif self.alpha == self.SHR_LW:\n            if self.target == self.TGT_CONST_VAR:\n                return self._get_shrink_param_lw_const_var(X, S, F)\n            if self.target == self.TGT_CONST_CORR:\n                return self._get_shrink_param_lw_const_corr(X, S, F)\n            if self.target == self.TGT_SINGLE_FACTOR:\n                return self._get_shrink_param_lw_single_factor(X, S, F)\n        return self.alpha\n\n    def _get_shrink_param_oas(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float:\n        \"\"\"Oracle Approximating Shrinkage Estimator\n\n        This method uses the following formula to estimate the `alpha`\n        parameter for the shrink covariance estimator:\n            A = (1 - 2 / p) * trace(S^2) + trace^2(S)\n            B = (n + 1 - 2 / p) * (trace(S^2) - trace^2(S) / p)\n            alpha = A 
/ B\n        where `n`, `p` are the numbers of observations and variables respectively.\n        Following [4], the estimate is capped at 1.\n        \"\"\"\n        trS2 = np.sum(S**2)  # equals trace(S^2) since S is symmetric\n        tr2S = np.trace(S) ** 2\n\n        n, p = X.shape\n\n        A = (1 - 2 / p) * trS2 + tr2S\n        B = (n + 1 - 2 / p) * (trS2 - tr2S / p)\n        # B == 0 only when S is proportional to the identity, i.e. S already equals the `const_var` target\n        alpha = min(1.0, A / B) if B > 0 else 1.0\n\n        return alpha\n\n    def _get_shrink_param_lw_const_var(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float:\n        \"\"\"Ledoit-Wolf Shrinkage Estimator (Constant Variance)\n\n        This method shrinks the covariance matrix towards the constant variance target.\n        \"\"\"\n        t, n = X.shape\n\n        y = X**2\n        phi = np.sum(y.T.dot(y) / t - S**2)\n\n        gamma = np.linalg.norm(S - F, \"fro\") ** 2\n\n        kappa = phi / gamma\n        alpha = max(0, min(1, kappa / t))\n\n        return alpha\n\n    def _get_shrink_param_lw_const_corr(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float:\n        \"\"\"Ledoit-Wolf Shrinkage Estimator (Constant Correlation)\n\n        This method shrinks the covariance matrix towards the constant correlation target.\n        \"\"\"\n        t, n = X.shape\n\n        var = np.diag(S)\n        sqrt_var = np.sqrt(var)\n        r_bar = (np.sum(S / np.outer(sqrt_var, sqrt_var)) - n) / (n * (n - 1))\n\n        y = X**2\n        phi_mat = y.T.dot(y) / t - S**2\n        phi = np.sum(phi_mat)\n\n        theta_mat = (X**3).T.dot(X) / t - var[:, None] * S\n        np.fill_diagonal(theta_mat, 0)\n        rho = np.sum(np.diag(phi_mat)) + r_bar * np.sum(np.outer(1 / sqrt_var, sqrt_var) * theta_mat)\n\n        gamma = np.linalg.norm(S - F, \"fro\") ** 2\n\n        kappa = (phi - rho) / gamma\n        alpha = max(0, min(1, kappa / t))\n\n        return alpha\n\n    def _get_shrink_param_lw_single_factor(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float:\n        \"\"\"Ledoit-Wolf Shrinkage Estimator (Single Factor Model)\n\n        This method shrinks the covariance matrix towards the 
single factor model target.\n        \"\"\"\n        t, n = X.shape\n\n        X_mkt = np.nanmean(X, axis=1)\n        cov_mkt = np.asarray(X.T.dot(X_mkt) / len(X))\n        var_mkt = np.asarray(X_mkt.dot(X_mkt) / len(X))\n\n        y = X**2\n        phi = np.sum(y.T.dot(y)) / t - np.sum(S**2)\n\n        rdiag = np.sum(y**2) / t - np.sum(np.diag(S) ** 2)\n        z = X * X_mkt[:, None]\n        v1 = y.T.dot(z) / t - cov_mkt[:, None] * S\n        roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt\n        v3 = z.T.dot(z) / t - var_mkt * S\n        roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2\n        roff = 2 * roff1 - roff3\n        rho = rdiag + roff\n\n        gamma = np.linalg.norm(S - F, \"fro\") ** 2\n\n        kappa = (phi - rho) / gamma\n        alpha = max(0, min(1, kappa / t))\n\n        return alpha\n"
  },
  {
    "path": "qlib/model/riskmodel/structured.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nfrom typing import Union\nfrom sklearn.decomposition import PCA, FactorAnalysis\n\nfrom qlib.model.riskmodel import RiskModel\n\n\nclass StructuredCovEstimator(RiskModel):\n    \"\"\"Structured Covariance Estimator\n\n    This estimator assumes observations can be predicted by multiple factors\n        X = B @ F.T + U\n    where `X` contains observations (rows) of multiple variables (columns),\n    `F` contains factor exposures (columns) for all variables (rows),\n    `B` is the regression coefficients matrix for all observations (rows) on\n    all factors (columns), and `U` is the residual matrix with the same shape as `X`.\n\n    Therefore, the structured covariance can be estimated by\n        cov(X.T) = F @ cov(B.T) @ F.T + diag(var(U))\n\n    In the finance domain, there are mainly three methods to design `F` [1][2]:\n        - Statistical Risk Model (SRM): latent factors extracted by statistical models (e.g. principal components)\n        - Fundamental Risk Model (FRM): human designed factors\n        - Deep Risk Model (DRM): neural network designed factors (like a blend of SRM & FRM)\n\n    In this implementation we use latent factor models to specify `F`.\n    Specifically, the following two latent factor models are supported:\n        - `pca`: Principal Component Analysis\n        - `fa`: Factor Analysis\n\n    Reference:\n        [1] Fan, J., Liao, Y., & Liu, H. (2016). An overview of the estimation of large covariance and\n            precision matrices. Econometrics Journal, 19(1), C1–C32. https://doi.org/10.1111/ectj.12061\n        [2] Lin, H., Zhou, D., Liu, W., & Bian, J. (2021). Deep Risk Model: A Deep Learning Solution for\n            Mining Latent Risk Factors to Improve Covariance Matrix Estimation. 
arXiv preprint arXiv:2107.05201.\n    \"\"\"\n\n    FACTOR_MODEL_PCA = \"pca\"\n    FACTOR_MODEL_FA = \"fa\"\n    DEFAULT_NAN_OPTION = \"fill\"\n\n    def __init__(self, factor_model: str = \"pca\", num_factors: int = 10, **kwargs):\n        \"\"\"\n        Args:\n            factor_model (str): the latent factor model used to estimate the structured covariance (`pca`/`fa`).\n            num_factors (int): number of components to keep.\n            kwargs: see `RiskModel` for more information (only `nan_option=\"fill\"` is supported)\n        \"\"\"\n        if \"nan_option\" in kwargs:\n            assert kwargs[\"nan_option\"] in [self.DEFAULT_NAN_OPTION], \"nan_option={} is not supported\".format(\n                kwargs[\"nan_option\"]\n            )\n        else:\n            kwargs[\"nan_option\"] = self.DEFAULT_NAN_OPTION\n\n        super().__init__(**kwargs)\n\n        assert factor_model in [\n            self.FACTOR_MODEL_PCA,\n            self.FACTOR_MODEL_FA,\n        ], \"factor_model={} is not supported\".format(factor_model)\n        self.solver = PCA if factor_model == self.FACTOR_MODEL_PCA else FactorAnalysis\n\n        self.num_factors = num_factors\n\n    def _predict(self, X: np.ndarray, return_decomposed_components=False) -> Union[np.ndarray, tuple]:\n        \"\"\"\n        covariance estimation implementation\n\n        Args:\n            X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows).\n            return_decomposed_components (bool): whether to return the decomposed components of the covariance matrix.\n\n        Returns:\n            tuple or np.ndarray: `(F, cov_b, var_u)` if `return_decomposed_components` is True, otherwise the covariance matrix.\n        \"\"\"\n\n        model = self.solver(n_components=self.num_factors, random_state=0).fit(X)\n\n        F = model.components_.T  # variables x factors\n        B = model.transform(X)  # observations x factors\n        U = X - B @ F.T\n        cov_b = np.cov(B.T)  # factors x factors\n        var_u = np.var(U, axis=0)  # residual variances (diagonal)\n\n        if 
return_decomposed_components:\n            return F, cov_b, var_u\n\n        cov_x = F @ cov_b @ F.T + np.diag(var_u)\n\n        return cov_x\n"
  },
  {
    "path": "qlib/model/trainer.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nThe Trainer will train a list of tasks and return a list of model recorders.\nEach Trainer has two steps: ``train`` (make model recorders) and ``end_train`` (modify model recorders).\n\nThere is also a concept called ``DelayTrainer``, which can be used in online simulation for parallel training.\nIn ``DelayTrainer``, the first step only saves some necessary info to the model recorders, and the second step, which runs at the end, can perform concurrent and time-consuming operations such as model fitting.\n\n``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest one, and ``TrainerRM`` is based on TaskManager, which helps manage the task lifecycle automatically.\n\"\"\"\n\nimport socket\nfrom typing import Callable, List, Optional\n\nfrom tqdm.auto import tqdm\n\nfrom qlib.config import C\nfrom qlib.data.dataset import Dataset\nfrom qlib.data.dataset.weight import Reweighter\nfrom qlib.log import get_module_logger\nfrom qlib.model.base import Model\nfrom qlib.utils import (\n    auto_filter_kwargs,\n    fill_placeholder,\n    flatten_dict,\n    init_instance_by_config,\n)\nfrom qlib.utils.paral import call_in_subproc\nfrom qlib.workflow import R\nfrom qlib.workflow.recorder import Recorder\nfrom qlib.workflow.task.manage import TaskManager, run_task\n\n\ndef _log_task_info(task_config: dict):\n    R.log_params(**flatten_dict(task_config))\n    R.save_objects(**{\"task\": task_config})  # keep the original format and datatype\n    R.set_tags(**{\"hostname\": socket.gethostname()})\n\n\ndef _exe_task(task_config: dict):\n    rec = R.get_recorder()\n    # model & dataset initialization\n    model: Model = init_instance_by_config(task_config[\"model\"], accept_types=Model)\n    dataset: Dataset = init_instance_by_config(task_config[\"dataset\"], accept_types=Dataset)\n    reweighter: Reweighter = task_config.get(\"reweighter\", None)\n    # model 
training\n    auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter)\n    R.save_objects(**{\"params.pkl\": model})\n    # this dataset is saved for online inference, so the concrete data should not be dumped\n    dataset.config(dump_all=False, recursive=True)\n    R.save_objects(**{\"dataset\": dataset})\n    # fill placeholders\n    placeholder_value = {\"<MODEL>\": model, \"<DATASET>\": dataset}\n    task_config = fill_placeholder(task_config, placeholder_value)\n    # generate records: prediction, backtest, and analysis\n    records = task_config.get(\"record\", [])\n    if isinstance(records, dict):  # a single record config may be given as a dict\n        records = [records]\n    for record in records:\n        # Some records require the `model` and `dataset` parameters.\n        # Try to pass them to the initialization function automatically\n        # to make defining tasks easier\n        r = init_instance_by_config(\n            record,\n            recorder=rec,\n            default_module=\"qlib.workflow.record_temp\",\n            try_kwargs={\"model\": model, \"dataset\": dataset},\n        )\n        r.generate()\n\n\ndef begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:\n    \"\"\"\n    Begin task training: start a recorder and save the task config.\n\n    Args:\n        task_config (dict): the config of a task\n        experiment_name (str): the name of the experiment\n        recorder_name (str): the given name will be the recorder name. 
None for using rid.\n\n    Returns:\n        Recorder: the model recorder\n    \"\"\"\n    with R.start(experiment_name=experiment_name, recorder_name=recorder_name):\n        _log_task_info(task_config)\n        return R.get_recorder()\n\n\ndef end_task_train(rec: Recorder, experiment_name: str) -> Recorder:\n    \"\"\"\n    Finish task training with real model fitting and saving.\n\n    Args:\n        rec (Recorder): the recorder to be resumed\n        experiment_name (str): the name of the experiment\n\n    Returns:\n        Recorder: the model recorder\n    \"\"\"\n    with R.start(experiment_name=experiment_name, recorder_id=rec.info[\"id\"], resume=True):\n        task_config = R.load_object(\"task\")\n        _exe_task(task_config)\n    return rec\n\n\ndef task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:\n    \"\"\"\n    Task-based training, which runs both steps (saving the task config and executing the task) in a single recorder.\n\n    Parameters\n    ----------\n    task_config : dict\n        The config of a task.\n    experiment_name : str\n        The name of the experiment.\n    recorder_name : str\n        The name of the recorder. None for using rid.\n\n    Returns\n    -------\n    Recorder\n        The instance of the recorder.\n    \"\"\"\n    with R.start(experiment_name=experiment_name, recorder_name=recorder_name):\n        _log_task_info(task_config)\n        _exe_task(task_config)\n        return R.get_recorder()\n\n\nclass Trainer:\n    \"\"\"\n    The trainer can train a list of models.\n    There are Trainer and DelayTrainer, which can be distinguished by when they finish the real training.\n    \"\"\"\n\n    def __init__(self):\n        self.delay = False\n\n    def train(self, tasks: list, *args, **kwargs) -> list:\n        \"\"\"\n        Given a list of task definitions, begin training, and return the models.\n\n        For Trainer, it finishes real training in this method.\n        For DelayTrainer, it only does some preparation in this method.\n\n        Args:\n            tasks: a list 
of tasks\n\n        Returns:\n            list: a list of models\n        \"\"\"\n        raise NotImplementedError(\"Please implement the `train` method.\")\n\n    def end_train(self, models: list, *args, **kwargs) -> list:\n        \"\"\"\n        Given a list of models, finish any remaining work at the end of training if needed.\n        The models may be Recorders, txt files, databases, and so on.\n\n        For Trainer, it does some finishing touches in this method.\n        For DelayTrainer, it finishes real training in this method.\n\n        Args:\n            models: a list of models\n\n        Returns:\n            list: a list of models\n        \"\"\"\n        # do nothing if all work was finished in the `train` method\n        return models\n\n    def is_delay(self) -> bool:\n        \"\"\"\n        Whether the Trainer delays the real training until `end_train`.\n\n        Returns:\n            bool: True if it is a DelayTrainer\n        \"\"\"\n        return self.delay\n\n    def __call__(self, *args, **kwargs) -> list:\n        return self.end_train(self.train(*args, **kwargs))\n\n    def has_worker(self) -> bool:\n        \"\"\"\n        Some trainers have a backend worker to support parallel training.\n        This method tells whether the worker is enabled.\n\n        Returns\n        -------\n        bool:\n            if the worker is enabled\n\n        \"\"\"\n        return False\n\n    def worker(self):\n        \"\"\"\n        start the worker\n\n        Raises\n        ------\n        NotImplementedError:\n            If the worker is not supported\n        \"\"\"\n        raise NotImplementedError(\"Please implement the `worker` method\")\n\n\nclass TrainerR(Trainer):\n    \"\"\"\n    Trainer based on (R)ecorder.\n    It trains a list of tasks sequentially and returns a list of model recorders.\n\n    Assumption: models are defined by `task` and the results will be saved to `Recorder`.\n    \"\"\"\n\n    # These tags help distinguish whether the Recorder has finished 
training\n    STATUS_KEY = \"train_status\"\n    STATUS_BEGIN = \"begin_task_train\"\n    STATUS_END = \"end_task_train\"\n\n    def __init__(\n        self,\n        experiment_name: Optional[str] = None,\n        train_func: Callable = task_train,\n        call_in_subproc: bool = False,\n        default_rec_name: Optional[str] = None,\n    ):\n        \"\"\"\n        Init TrainerR.\n\n        Args:\n            experiment_name (str, optional): the default name of experiment.\n            train_func (Callable, optional): default training method. Defaults to `task_train`.\n            call_in_subproc (bool): run each training task in a subprocess to force memory release\n            default_rec_name (str, optional): the default recorder name. None for using rid.\n        \"\"\"\n        super().__init__()\n        self.experiment_name = experiment_name\n        self.default_rec_name = default_rec_name\n        self.train_func = train_func\n        self._call_in_subproc = call_in_subproc\n\n    def train(\n        self, tasks: list, train_func: Optional[Callable] = None, experiment_name: Optional[str] = None, **kwargs\n    ) -> List[Recorder]:\n        \"\"\"\n        Train the given `tasks` and return a list of trained Recorders in the same order as `tasks`.\n\n        Args:\n            tasks (list): a list of definitions based on `task` dict\n            train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. 
None for the default training method.\n            experiment_name (str): the experiment name, None for using the default name.\n            kwargs: the params for train_func.\n\n        Returns:\n            List[Recorder]: a list of Recorders\n        \"\"\"\n        if isinstance(tasks, dict):\n            tasks = [tasks]\n        if len(tasks) == 0:\n            return []\n        if train_func is None:\n            train_func = self.train_func\n        if experiment_name is None:\n            experiment_name = self.experiment_name\n        if self._call_in_subproc:\n            get_module_logger(\"TrainerR\").info(\"running models in sub process (for forcing release memory).\")\n            # wrap only once; wrapping inside the loop would nest the subprocess calls\n            train_func = call_in_subproc(train_func, C)\n        recs = []\n        for task in tqdm(tasks, desc=\"train tasks\"):\n            rec = train_func(task, experiment_name, recorder_name=self.default_rec_name, **kwargs)\n            rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})\n            recs.append(rec)\n        return recs\n\n    def end_train(self, models: list, **kwargs) -> List[Recorder]:\n        \"\"\"\n        Set the STATUS_END tag on the recorders.\n\n        Args:\n            models (list): a list of trained recorders.\n\n        Returns:\n            List[Recorder]: the same list as the input.\n        \"\"\"\n        if isinstance(models, Recorder):\n            models = [models]\n        for rec in models:\n            rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})\n        return models\n\n\nclass DelayTrainerR(TrainerR):\n    \"\"\"\n    A delayed implementation based on TrainerR, which means the `train` method may only do some preparation and the `end_train` method can do the real model fitting.\n    \"\"\"\n\n    def __init__(\n        self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train, **kwargs\n    ):\n        \"\"\"\n        Init DelayTrainerR.\n\n        Args:\n            experiment_name (str): the 
default name of experiment.\n            train_func (Callable, optional): default train method. Defaults to `begin_task_train`.\n            end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.\n        \"\"\"\n        super().__init__(experiment_name, train_func, **kwargs)\n        self.end_train_func = end_train_func\n        self.delay = True\n\n    def end_train(self, models, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:\n        \"\"\"\n        Given a list of Recorders, finish their training and return them.\n        This method performs the real data loading and model fitting.\n\n        Args:\n            models (list): a list of Recorders to which the tasks have been saved\n            end_train_func (Callable, optional): the end_train method which needs at least `recorders` and `experiment_name`. Defaults to None for using self.end_train_func.\n            experiment_name (str): the experiment name, None for using the default name.\n            kwargs: the params for end_train_func.\n\n        Returns:\n            List[Recorder]: a list of Recorders\n        \"\"\"\n        if isinstance(models, Recorder):\n            models = [models]\n        if end_train_func is None:\n            end_train_func = self.end_train_func\n        if experiment_name is None:\n            experiment_name = self.experiment_name\n        for rec in models:\n            # skip recorders that have already finished training\n            if rec.list_tags().get(self.STATUS_KEY) == self.STATUS_END:\n                continue\n            end_train_func(rec, experiment_name, **kwargs)\n            rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})\n        return models\n\n\nclass TrainerRM(Trainer):\n    \"\"\"\n    Trainer based on (R)ecorder and Task(M)anager.\n    It can train a list of tasks and return a list of model recorders, optionally in parallel across multiple processes.\n\n    Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager\n    \"\"\"\n\n    # 
These tags help distinguish whether the Recorder has finished training\n    STATUS_KEY = \"train_status\"\n    STATUS_BEGIN = \"begin_task_train\"\n    STATUS_END = \"end_task_train\"\n\n    # This tag is the _id in TaskManager to distinguish tasks.\n    TM_ID = \"_id in TaskManager\"\n\n    def __init__(\n        self,\n        experiment_name: str = None,\n        task_pool: str = None,\n        train_func=task_train,\n        skip_run_task: bool = False,\n        default_rec_name: Optional[str] = None,\n    ):\n        \"\"\"\n        Init TrainerRM.\n\n        Args:\n            experiment_name (str): the default name of experiment.\n            task_pool (str): task pool name in TaskManager. None for using the same name as experiment_name.\n            train_func (Callable, optional): default training method. Defaults to `task_train`.\n            skip_run_task (bool):\n                If True, `train` does not call `run_task` itself and the tasks are only run by workers (see `worker`).\n                Otherwise, `train` runs the tasks directly.\n            default_rec_name (str, optional): the default recorder name. None for using rid.\n        \"\"\"\n\n        super().__init__()\n        self.experiment_name = experiment_name\n        self.task_pool = task_pool\n        self.train_func = train_func\n        self.skip_run_task = skip_run_task\n        self.default_rec_name = default_rec_name\n\n    def train(\n        self,\n        tasks: list,\n        train_func: Callable = None,\n        experiment_name: str = None,\n        before_status: str = TaskManager.STATUS_WAITING,\n        after_status: str = TaskManager.STATUS_DONE,\n        default_rec_name: Optional[str] = None,\n        **kwargs,\n    ) -> List[Recorder]:\n        \"\"\"\n        Train the given `tasks` and return a list of trained Recorders. 
The order of the returned Recorders matches the order of `tasks`.\n\n        This method runs in a single process by default, but TaskManager offers a convenient way to parallelize training.\n        Users can customize their train_func to realize multiple processes or even multiple machines.\n\n        Args:\n            tasks (list): a list of definitions based on `task` dict\n            train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. None for the default training method.\n            experiment_name (str): the experiment name, None for using the default name.\n            before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.\n            after_status (str): the tasks will be set to after_status once trained. Can be STATUS_DONE, STATUS_PART_DONE.\n            default_rec_name (str, optional): the recorder name passed to train_func. None for using self.default_rec_name.\n            kwargs: the params for train_func.\n\n        Returns:\n            List[Recorder]: a list of Recorders\n        \"\"\"\n        if isinstance(tasks, dict):\n            tasks = [tasks]\n        if len(tasks) == 0:\n            return []\n        if train_func is None:\n            train_func = self.train_func\n        if experiment_name is None:\n            experiment_name = self.experiment_name\n        if default_rec_name is None:\n            default_rec_name = self.default_rec_name\n        task_pool = self.task_pool\n        if task_pool is None:\n            task_pool = experiment_name\n        tm = TaskManager(task_pool=task_pool)\n        _id_list = tm.create_task(tasks)  # all tasks will be saved to MongoDB\n        query = {\"_id\": {\"$in\": _id_list}}\n        if not self.skip_run_task:\n            run_task(\n                train_func,\n                task_pool,\n                query=query,  # only train these tasks\n                experiment_name=experiment_name,\n                before_status=before_status,\n                after_status=after_status,\n                recorder_name=default_rec_name,\n                **kwargs,\n       
     )\n\n        if not self.is_delay():\n            tm.wait(query=query)\n\n        recs = []\n        for _id in _id_list:\n            rec = tm.re_query(_id)[\"res\"]\n            rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})\n            rec.set_tags(**{self.TM_ID: _id})\n            recs.append(rec)\n        return recs\n\n    def end_train(self, recs: list, **kwargs) -> List[Recorder]:\n        \"\"\"\n        Set the STATUS_END tag on the recorders.\n\n        Args:\n            recs (list): a list of trained recorders.\n\n        Returns:\n            List[Recorder]: the same list as the input.\n        \"\"\"\n        if isinstance(recs, Recorder):\n            recs = [recs]\n        for rec in recs:\n            rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})\n        return recs\n\n    def worker(\n        self,\n        train_func: Callable = None,\n        experiment_name: str = None,\n    ):\n        \"\"\"\n        The multiprocessing counterpart of `train`. It can share the same task_pool with `train` and can run in other processes or on other machines.\n\n        Args:\n            train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. 
None for the default training method.\n            experiment_name (str): the experiment name, None to use the default name.\n        \"\"\"\n        if train_func is None:\n            train_func = self.train_func\n        if experiment_name is None:\n            experiment_name = self.experiment_name\n        task_pool = self.task_pool\n        if task_pool is None:\n            task_pool = experiment_name\n        run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)\n\n    def has_worker(self) -> bool:\n        return True\n\n\nclass DelayTrainerRM(TrainerRM):\n    \"\"\"\n    A delayed implementation based on TrainerRM: the `train` method may only do some preparation, while the `end_train` method does the real model fitting.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        experiment_name: str = None,\n        task_pool: str = None,\n        train_func=begin_task_train,\n        end_train_func=end_task_train,\n        skip_run_task: bool = False,\n        **kwargs,\n    ):\n        \"\"\"\n        Init DelayTrainerRM.\n\n        Args:\n            experiment_name (str): the default name of the experiment.\n            task_pool (str): task pool name in TaskManager. None to use the same name as experiment_name.\n            train_func (Callable, optional): default train method. Defaults to `begin_task_train`.\n            end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.\n            skip_run_task (bool):\n                If skip_run_task == True:\n                run_task is only executed in the worker. Otherwise, `end_train` also executes run_task itself.\n                E.g. 
Starting the trainer on a CPU VM and then waiting for the tasks to finish on GPU VMs.\n        \"\"\"\n        super().__init__(experiment_name, task_pool, train_func, **kwargs)\n        self.end_train_func = end_train_func\n        self.delay = True\n        self.skip_run_task = skip_run_task\n\n    def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:\n        \"\"\"\n        Same as `train` of TrainerRM, except that after_status will be STATUS_PART_DONE.\n\n        Args:\n            tasks (list): a list of definitions based on `task` dict\n            train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. Defaults to None, which uses self.train_func.\n            experiment_name (str): the experiment name, None to use the default name.\n\n        Returns:\n            List[Recorder]: a list of Recorders\n        \"\"\"\n        if isinstance(tasks, dict):\n            tasks = [tasks]\n        if len(tasks) == 0:\n            return []\n        _skip_run_task = self.skip_run_task\n        self.skip_run_task = False  # The task preparation can't be skipped\n        try:\n            res = super().train(\n                tasks,\n                train_func=train_func,\n                experiment_name=experiment_name,\n                after_status=TaskManager.STATUS_PART_DONE,\n                **kwargs,\n            )\n        finally:\n            # Restore the flag even if training raises\n            self.skip_run_task = _skip_run_task\n        return res\n\n    def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:\n        \"\"\"\n        Given a list of Recorders, return a list of trained Recorders.\n        This method finishes the real data loading and model fitting.\n\n        Args:\n            recs (list): a list of Recorders to which the tasks have been saved.\n            end_train_func (Callable, optional): the end_train method which needs at least `recorders` and `experiment_name`. 
Defaults to None, which uses self.end_train_func.\n            experiment_name (str): the experiment name, None to use the default name.\n            kwargs: the params for end_train_func.\n\n        Returns:\n            List[Recorder]: a list of Recorders\n        \"\"\"\n        if isinstance(recs, Recorder):\n            recs = [recs]\n        if end_train_func is None:\n            end_train_func = self.end_train_func\n        if experiment_name is None:\n            experiment_name = self.experiment_name\n        task_pool = self.task_pool\n        if task_pool is None:\n            task_pool = experiment_name\n        _id_list = []\n        for rec in recs:\n            _id_list.append(rec.list_tags()[self.TM_ID])\n\n        query = {\"_id\": {\"$in\": _id_list}}\n        if not self.skip_run_task:\n            run_task(\n                end_train_func,\n                task_pool,\n                query=query,  # only train these tasks\n                experiment_name=experiment_name,\n                before_status=TaskManager.STATUS_PART_DONE,\n                **kwargs,\n            )\n\n        TaskManager(task_pool=task_pool).wait(query=query)\n\n        for rec in recs:\n            rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})\n        return recs\n\n    def worker(self, end_train_func=None, experiment_name: str = None):\n        \"\"\"\n        The multiprocessing method for `end_train`. It can share the same task_pool with `end_train` and can run in other processes or on other machines.\n\n        Args:\n            end_train_func (Callable, optional): the end_train method which needs at least `recorders` and `experiment_name`. 
Defaults to None, which uses self.end_train_func.\n            experiment_name (str): the experiment name, None to use the default name.\n        \"\"\"\n        if end_train_func is None:\n            end_train_func = self.end_train_func\n        if experiment_name is None:\n            experiment_name = self.experiment_name\n        task_pool = self.task_pool\n        if task_pool is None:\n            task_pool = experiment_name\n        run_task(\n            end_train_func,\n            task_pool=task_pool,\n            experiment_name=experiment_name,\n            before_status=TaskManager.STATUS_PART_DONE,\n        )\n\n    def has_worker(self) -> bool:\n        return True\n"
  },
  {
    "path": "qlib/model/utils.py",
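    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom torch.utils.data import Dataset\n\n\nclass ConcatDataset(Dataset):\n    \"\"\"Zip several datasets together: item ``i`` is the tuple of the ``i``-th items of all datasets.\n\n    Unlike ``torch.utils.data.ConcatDataset``, which chains datasets end to end, the length here is\n    that of the shortest dataset.\n    \"\"\"\n\n    def __init__(self, *datasets):\n        self.datasets = datasets\n\n    def __getitem__(self, i):\n        return tuple(d[i] for d in self.datasets)\n\n    def __len__(self):\n        return min(len(d) for d in self.datasets)\n\n\nclass IndexSampler:\n    \"\"\"Wrap a sampler so that indexing returns ``(sampler[i], i)``, i.e. each sample together with its index.\"\"\"\n\n    def __init__(self, sampler):\n        self.sampler = sampler\n\n    def __getitem__(self, i: int):\n        return self.sampler[i], i\n\n    def __len__(self):\n        return len(self.sampler)\n"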
  },
  {
    "path": "qlib/rl/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .interpreter import Interpreter, StateInterpreter, ActionInterpreter\nfrom .reward import Reward, RewardCombination\nfrom .simulator import Simulator\n\n__all__ = [\"Interpreter\", \"StateInterpreter\", \"ActionInterpreter\", \"Reward\", \"RewardCombination\", \"Simulator\"]\n"
  },
  {
    "path": "qlib/rl/aux_info.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Generic, Optional, TypeVar\n\nfrom qlib.typehint import final\n\nfrom .simulator import StateType\n\nif TYPE_CHECKING:\n    from .utils.env_wrapper import EnvWrapper\n\n\n__all__ = [\"AuxiliaryInfoCollector\"]\n\nAuxInfoType = TypeVar(\"AuxInfoType\")\n\n\nclass AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):\n    \"\"\"Override this class to collect customized auxiliary information from the environment.\"\"\"\n\n    env: Optional[EnvWrapper] = None\n\n    @final\n    def __call__(self, simulator_state: StateType) -> AuxInfoType:\n        return self.collect(simulator_state)\n\n    def collect(self, simulator_state: StateType) -> AuxInfoType:\n        \"\"\"Override this for customized auxiliary info.\n        This is usually useful in multi-agent RL.\n\n        Parameters\n        ----------\n        simulator_state\n            Retrieved with ``simulator.get_state()``.\n\n        Returns\n        -------\n        Auxiliary information.\n        \"\"\"\n        raise NotImplementedError(\"collect is not implemented!\")\n"
  },
  {
    "path": "qlib/rl/contrib/__init__.py",
    "content": ""
  },
  {
    "path": "qlib/rl/contrib/backtest.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom __future__ import annotations\n\nimport argparse\nimport copy\nimport os\nimport pickle\nfrom collections import defaultdict\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Tuple, Union, cast\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom joblib import Parallel, delayed\n\nfrom qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor\nfrom qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime\nfrom qlib.backtest.executor import SimulatorExecutor\nfrom qlib.backtest.high_performance_ds import BaseOrderIndicator\nfrom qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile\nfrom qlib.rl.contrib.utils import read_order_file\nfrom qlib.rl.data.integration import init_qlib\nfrom qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution\nfrom qlib.typehint import Literal\n\n\ndef _get_multi_level_executor_config(\n    strategy_config: dict,\n    cash_limit: float | None = None,\n    generate_report: bool = False,\n    data_granularity: str = \"1min\",\n) -> dict:\n    executor_config = {\n        \"class\": \"SimulatorExecutor\",\n        \"module_path\": \"qlib.backtest.executor\",\n        \"kwargs\": {\n            \"time_per_step\": data_granularity,\n            \"verbose\": False,\n            \"trade_type\": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL,\n            \"generate_report\": generate_report,\n            \"track_data\": True,\n        },\n    }\n\n    freqs = list(strategy_config.keys())\n    freqs.sort(key=pd.Timedelta)\n    for freq in freqs:\n        executor_config = {\n            \"class\": \"NestedExecutor\",\n            \"module_path\": \"qlib.backtest.executor\",\n            \"kwargs\": {\n                \"time_per_step\": freq,\n                \"inner_strategy\": strategy_config[freq],\n  
              \"inner_executor\": executor_config,\n                \"track_data\": True,\n            },\n        }\n\n    return executor_config\n\n\ndef _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:\n    record_list = []\n    for time, value_dict in indicator.items():\n        if isinstance(value_dict, BaseOrderIndicator):\n            # HACK: for qlib v0.8\n            value_dict = value_dict.to_series()\n        try:\n            value_dict = copy.deepcopy(value_dict)\n            if value_dict[\"ffr\"].empty:\n                continue\n        except Exception:\n            value_dict = {k: v for k, v in value_dict.items() if k != \"pa\"}\n        value_dict = pd.DataFrame(value_dict)\n        value_dict[\"datetime\"] = time\n        record_list.append(value_dict)\n\n    if not record_list:\n        return None\n\n    records: pd.DataFrame = pd.concat(record_list, axis=0).reset_index().rename(columns={\"index\": \"instrument\"})\n    records = records.set_index([\"instrument\", \"datetime\"])\n    return records\n\n\ndef _generate_report(\n    decisions: List[BaseTradeDecision],\n    report_indicators: List[INDICATOR_METRIC],\n) -> Dict[str, Tuple[pd.DataFrame, pd.DataFrame]]:\n    \"\"\"Generate backtest reports.\n\n    Parameters\n    ----------\n    decisions:\n        List of trade decisions.\n    report_indicators\n        List of indicator reports.\n\n    Returns\n    -------\n    Dict mapping each frequency to a tuple of (concatenated indicator DataFrame, order indicator history\n    DataFrame joined with the decision details of that frequency, if any).\n    \"\"\"\n    indicator_dict: Dict[str, List[pd.DataFrame]] = defaultdict(list)\n    indicator_his: Dict[str, List[dict]] = defaultdict(list)\n\n    for report_indicator in report_indicators:\n        for key, (indicator_df, indicator_obj) in report_indicator.items():\n            indicator_dict[key].append(indicator_df)\n            indicator_his[key].append(indicator_obj.order_indicator_his)\n\n    report = {}\n    decision_details = pd.concat([getattr(d, \"details\") for d in decisions if hasattr(d, \"details\")])\n    for key in indicator_dict:\n      
  cur_dict = pd.concat(indicator_dict[key])\n        cur_his = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key]])\n        cur_details = decision_details[decision_details.freq == key].set_index([\"instrument\", \"datetime\"])\n        if len(cur_details) > 0:\n            cur_details.pop(\"freq\")\n            cur_his = cur_his.join(cur_details, how=\"outer\")\n\n        report[key] = (cur_dict, cur_his)\n\n    return report\n\n\ndef single_with_simulator(\n    backtest_config: dict,\n    orders: pd.DataFrame,\n    split: Literal[\"stock\", \"day\"] = \"stock\",\n    cash_limit: float | None = None,\n    generate_report: bool = False,\n) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:\n    \"\"\"Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.\n    A new simulator will be created and used for every single-day order.\n\n    Parameters\n    ----------\n    backtest_config:\n        Backtest config\n    orders:\n        Orders to be executed. Example format:\n                 datetime instrument  amount  direction\n            0  2020-06-01       INST   600.0          0\n            1  2020-06-02       INST   700.0          1\n            ...\n    split\n        Method to split orders. If it is \"stock\", split orders by stock. If it is \"day\", split orders by date.\n    cash_limit\n        Limitation of cash.\n    generate_report\n        Whether to generate reports.\n\n    Returns\n    -------\n        If generate_report is True, return execution records and the generated report. 
Otherwise, return only records.\n    \"\"\"\n    init_qlib(backtest_config[\"qlib\"])\n\n    stocks = orders.instrument.unique().tolist()\n\n    reports = []\n    decisions = []\n    for _, row in orders.iterrows():\n        date = pd.Timestamp(row[\"datetime\"])\n        start_time = pd.Timestamp(backtest_config[\"start_time\"]).replace(year=date.year, month=date.month, day=date.day)\n        end_time = pd.Timestamp(backtest_config[\"end_time\"]).replace(year=date.year, month=date.month, day=date.day)\n        order = Order(\n            stock_id=row[\"instrument\"],\n            amount=row[\"amount\"],\n            direction=OrderDir(row[\"direction\"]),\n            start_time=start_time,\n            end_time=end_time,\n        )\n\n        executor_config = _get_multi_level_executor_config(\n            strategy_config=backtest_config[\"strategies\"],\n            cash_limit=cash_limit,\n            generate_report=generate_report,\n            data_granularity=backtest_config[\"data_granularity\"],\n        )\n\n        exchange_config = copy.deepcopy(backtest_config[\"exchange\"])\n        exchange_config.update(\n            {\n                \"codes\": stocks,\n                \"freq\": backtest_config[\"data_granularity\"],\n            }\n        )\n\n        simulator = SingleAssetOrderExecution(\n            order=order,\n            executor_config=executor_config,\n            exchange_config=exchange_config,\n            qlib_config=None,\n            cash_limit=None,\n        )\n\n        reports.append(simulator.report_dict)\n        decisions += simulator.decisions\n\n    indicator_1day_objs = [report[\"indicator_dict\"][\"1day\"][1] for report in reports]\n    indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}\n    records = _convert_indicator_to_dataframe(indicator_info)\n    assert records is None or not np.isnan(records[\"ffr\"]).any()\n\n    if generate_report:\n        _report = 
_generate_report(decisions, [report[\"indicator\"] for report in reports])\n\n        if split == \"stock\":\n            stock_id = orders.iloc[0].instrument\n            report = {stock_id: _report}\n        else:\n            day = orders.iloc[0].datetime\n            report = {day: _report}\n\n        return records, report\n    else:\n        return records\n\n\ndef single_with_collect_data_loop(\n    backtest_config: dict,\n    orders: pd.DataFrame,\n    split: Literal[\"stock\", \"day\"] = \"stock\",\n    cash_limit: float | None = None,\n    generate_report: bool = False,\n) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:\n    \"\"\"Run backtest in a single thread with collect_data_loop.\n\n    Parameters\n    ----------\n    backtest_config:\n        Backtest config\n    orders:\n        Orders to be executed. Example format:\n                 datetime instrument  amount  direction\n            0  2020-06-01       INST   600.0          0\n            1  2020-06-02       INST   700.0          1\n            ...\n    split\n        Method to split orders. If it is \"stock\", split orders by stock. If it is \"day\", split orders by date.\n    cash_limit\n        Limitation of cash.\n    generate_report\n        Whether to generate reports.\n\n    Returns\n    -------\n        If generate_report is True, return execution records and the generated report. 
Otherwise, return only records.\n    \"\"\"\n\n    init_qlib(backtest_config[\"qlib\"])\n\n    trade_start_time = orders[\"datetime\"].min()\n    trade_end_time = orders[\"datetime\"].max()\n    stocks = orders.instrument.unique().tolist()\n\n    strategy_config = {\n        \"class\": \"FileOrderStrategy\",\n        \"module_path\": \"qlib.contrib.strategy.rule_strategy\",\n        \"kwargs\": {\n            \"file\": orders,\n            \"trade_range\": TradeRangeByTime(\n                pd.Timestamp(backtest_config[\"start_time\"]).time(),\n                pd.Timestamp(backtest_config[\"end_time\"]).time(),\n            ),\n        },\n    }\n\n    executor_config = _get_multi_level_executor_config(\n        strategy_config=backtest_config[\"strategies\"],\n        cash_limit=cash_limit,\n        generate_report=generate_report,\n        data_granularity=backtest_config[\"data_granularity\"],\n    )\n\n    exchange_config = copy.deepcopy(backtest_config[\"exchange\"])\n    exchange_config.update(\n        {\n            \"codes\": stocks,\n            \"freq\": backtest_config[\"data_granularity\"],\n        }\n    )\n\n    strategy, executor = get_strategy_executor(\n        start_time=pd.Timestamp(trade_start_time),\n        end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),\n        strategy=strategy_config,\n        executor=executor_config,\n        benchmark=None,\n        account=cash_limit if cash_limit is not None else int(1e12),\n        exchange_kwargs=exchange_config,\n        pos_type=\"Position\" if cash_limit is not None else \"InfPosition\",\n    )\n\n    report_dict: dict = {}\n    decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))\n\n    indicator_dict = cast(INDICATOR_METRIC, report_dict.get(\"indicator_dict\"))\n    records = _convert_indicator_to_dataframe(indicator_dict[\"1day\"][1].order_indicator_his)\n    assert records is None or not np.isnan(records[\"ffr\"]).any()\n\n    if 
generate_report:\n        _report = _generate_report(decisions, [indicator_dict])\n        if split == \"stock\":\n            stock_id = orders.iloc[0].instrument\n            report = {stock_id: _report}\n        else:\n            day = orders.iloc[0].datetime\n            report = {day: _report}\n        return records, report\n    else:\n        return records\n\n\ndef backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame:\n    order_df = read_order_file(backtest_config[\"order_file\"])\n\n    cash_limit = backtest_config[\"exchange\"].pop(\"cash_limit\")\n    generate_report = backtest_config.pop(\"generate_report\")\n\n    stock_pool = order_df[\"instrument\"].unique().tolist()\n    stock_pool.sort()\n\n    single = single_with_simulator if with_simulator else single_with_collect_data_loop\n    mp_config = {\"n_jobs\": backtest_config[\"concurrency\"], \"verbose\": 10, \"backend\": \"multiprocessing\"}\n    torch.set_num_threads(1)  # https://github.com/pytorch/pytorch/issues/17199\n    res = Parallel(**mp_config)(\n        delayed(single)(\n            backtest_config=backtest_config,\n            orders=order_df[order_df[\"instrument\"] == stock].copy(),\n            split=\"stock\",\n            cash_limit=cash_limit,\n            generate_report=generate_report,\n        )\n        for stock in stock_pool\n    )\n\n    output_path = Path(backtest_config[\"output_dir\"])\n    # Create the output directory before report.pkl or backtest_result.csv is written into it\n    if not output_path.exists():\n        os.makedirs(output_path)\n\n    if generate_report:\n        with (output_path / \"report.pkl\").open(\"wb\") as f:\n            report = {}\n            for r in res:\n                report.update(r[1])\n            pickle.dump(report, f)\n        res = pd.concat([r[0] for r in res], axis=0)\n    else:\n        res = pd.concat(res)\n\n    if \"pa\" in res.columns:\n        res[\"pa\"] = res[\"pa\"] * 10000.0  # align with training metrics\n    res.to_csv(output_path / \"backtest_result.csv\")\n    return res\n\n\nif __name__ 
== \"__main__\":\n    import warnings\n\n    warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n    warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--config_path\", type=str, required=True, help=\"Path to the config file\")\n    parser.add_argument(\"--use_simulator\", action=\"store_true\", help=\"Whether to use the simulator as the backend\")\n    parser.add_argument(\n        \"--n_jobs\",\n        type=int,\n        required=False,\n        help=\"The number of jobs for running the backtest in parallel (1 for a single process)\",\n    )\n    args = parser.parse_args()\n\n    config = get_backtest_config_fromfile(args.config_path)\n    if args.n_jobs is not None:\n        config[\"concurrency\"] = args.n_jobs\n\n    backtest(\n        backtest_config=config,\n        with_simulator=args.use_simulator,\n    )\n"
  },
  {
    "path": "qlib/rl/contrib/naive_config_parser.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport platform\nimport shutil\nimport sys\nimport tempfile\nfrom importlib import import_module\nfrom ruamel.yaml import YAML\n\nDELETE_KEY = \"_delete_\"\n\n\ndef merge_a_into_b(a: dict, b: dict) -> dict:\n    b = b.copy()\n    for k, v in a.items():\n        if isinstance(v, dict) and k in b:\n            v.pop(DELETE_KEY, False)\n            b[k] = merge_a_into_b(v, b[k])\n        else:\n            b[k] = v\n    return b\n\n\ndef check_file_exist(filename: str, msg_tmpl: str = 'file \"{}\" does not exist') -> None:\n    if not os.path.isfile(filename):\n        raise FileNotFoundError(msg_tmpl.format(filename))\n\n\ndef parse_backtest_config(path: str) -> dict:\n    abs_path = os.path.abspath(path)\n    check_file_exist(abs_path)\n\n    file_ext_name = os.path.splitext(abs_path)[1]\n    if file_ext_name not in (\".py\", \".json\", \".yaml\", \".yml\"):\n        raise IOError(\"Only py/yml/yaml/json type are supported now!\")\n\n    with tempfile.TemporaryDirectory() as tmp_config_dir:\n        with tempfile.NamedTemporaryFile(dir=tmp_config_dir, suffix=file_ext_name) as tmp_config_file:\n            if platform.system() == \"Windows\":\n                tmp_config_file.close()\n\n            tmp_config_name = os.path.basename(tmp_config_file.name)\n            shutil.copyfile(abs_path, tmp_config_file.name)\n\n            if abs_path.endswith(\".py\"):\n                tmp_module_name = os.path.splitext(tmp_config_name)[0]\n                sys.path.insert(0, tmp_config_dir)\n                module = import_module(tmp_module_name)\n                sys.path.pop(0)\n\n                config = {k: v for k, v in module.__dict__.items() if not k.startswith(\"__\")}\n\n                del sys.modules[tmp_module_name]\n            else:\n                with open(tmp_config_file.name) as input_stream:\n                    yaml = YAML(typ=\"safe\", pure=True)\n        
            config = yaml.load(input_stream)\n\n    if \"_base_\" in config:\n        base_file_name = config.pop(\"_base_\")\n        if not isinstance(base_file_name, list):\n            base_file_name = [base_file_name]\n\n        for f in base_file_name:\n            base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f))\n            config = merge_a_into_b(a=config, b=base_config)\n\n    return config\n\n\ndef _convert_all_list_to_tuple(config: dict) -> dict:\n    for k, v in config.items():\n        if isinstance(v, list):\n            config[k] = tuple(v)\n        elif isinstance(v, dict):\n            config[k] = _convert_all_list_to_tuple(v)\n    return config\n\n\ndef get_backtest_config_fromfile(path: str) -> dict:\n    backtest_config = parse_backtest_config(path)\n\n    exchange_config_default = {\n        \"open_cost\": 0.0005,\n        \"close_cost\": 0.0015,\n        \"min_cost\": 5.0,\n        \"trade_unit\": 100.0,\n        \"cash_limit\": None,\n    }\n    backtest_config[\"exchange\"] = merge_a_into_b(a=backtest_config[\"exchange\"], b=exchange_config_default)\n    backtest_config[\"exchange\"] = _convert_all_list_to_tuple(backtest_config[\"exchange\"])\n\n    backtest_config_default = {\n        \"debug_single_stock\": None,\n        \"debug_single_day\": None,\n        \"concurrency\": -1,\n        \"multiplier\": 1.0,\n        \"output_dir\": \"outputs_backtest/\",\n        \"generate_report\": False,\n        \"data_granularity\": \"1min\",\n    }\n    backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)\n\n    return backtest_config\n"
  },
  {
    "path": "qlib/rl/contrib/train_onpolicy.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom __future__ import annotations\n\nimport argparse\nimport os\nimport random\nimport sys\nimport warnings\nfrom pathlib import Path\nfrom ruamel.yaml import YAML\nfrom typing import cast, List, Optional\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom qlib.backtest import Order\nfrom qlib.backtest.decision import OrderDir\nfrom qlib.constant import ONE_MIN\nfrom qlib.rl.data.native import load_handler_intraday_processed_data\nfrom qlib.rl.interpreter import ActionInterpreter, StateInterpreter\nfrom qlib.rl.order_execution import SingleAssetOrderExecutionSimple\nfrom qlib.rl.reward import Reward\nfrom qlib.rl.trainer import Checkpoint, backtest, train\nfrom qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter\nfrom qlib.rl.utils.log import CsvWriter\nfrom qlib.utils import init_instance_by_config\nfrom tianshou.policy import BasePolicy\nfrom torch.utils.data import Dataset\n\n\ndef seed_everything(seed: int) -> None:\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n\n\ndef _read_orders(order_dir: Path) -> pd.DataFrame:\n    if os.path.isfile(order_dir):\n        return pd.read_pickle(order_dir)\n    else:\n        orders = []\n        for file in order_dir.iterdir():\n            order_data = pd.read_pickle(file)\n            orders.append(order_data)\n        return pd.concat(orders)\n\n\nclass LazyLoadDataset(Dataset):\n    def __init__(\n        self,\n        data_dir: str,\n        order_file_path: Path,\n        default_start_time_index: int,\n        default_end_time_index: int,\n    ) -> None:\n        self._default_start_time_index = default_start_time_index\n        self._default_end_time_index = default_end_time_index\n\n        self._order_df = _read_orders(order_file_path).reset_index()\n        self._ticks_index: 
Optional[pd.DatetimeIndex] = None\n        self._data_dir = Path(data_dir)\n\n    def __len__(self) -> int:\n        return len(self._order_df)\n\n    def __getitem__(self, index: int) -> Order:\n        row = self._order_df.iloc[index]\n        date = pd.Timestamp(str(row[\"date\"]))\n\n        if self._ticks_index is None:\n            # TODO: We only load the ticks index once, assuming that the ticks indexes of different dates\n            # TODO: in one experiment are all the same. If that assumption does not hold, we need to load the ticks index\n            # TODO: of all dates.\n\n            data = load_handler_intraday_processed_data(\n                data_dir=self._data_dir,\n                stock_id=row[\"instrument\"],\n                date=date,\n                feature_columns_today=[],\n                feature_columns_yesterday=[],\n                backtest=True,\n                index_only=True,\n            )\n            self._ticks_index = [t - date for t in data.today.index]\n\n        order = Order(\n            stock_id=row[\"instrument\"],\n            amount=row[\"amount\"],\n            direction=OrderDir(int(row[\"order_type\"])),\n            start_time=date + self._ticks_index[self._default_start_time_index],\n            end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,\n        )\n\n        return order\n\n\ndef train_and_test(\n    env_config: dict,\n    simulator_config: dict,\n    trainer_config: dict,\n    data_config: dict,\n    state_interpreter: StateInterpreter,\n    action_interpreter: ActionInterpreter,\n    policy: BasePolicy,\n    reward: Reward,\n    run_training: bool,\n    run_backtest: bool,\n) -> None:\n    order_root_path = Path(data_config[\"source\"][\"order_dir\"])\n\n    data_granularity = simulator_config.get(\"data_granularity\", 1)\n\n    def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:\n        return SingleAssetOrderExecutionSimple(\n            
order=order,\n            data_dir=data_config[\"source\"][\"feature_root_dir\"],\n            feature_columns_today=data_config[\"source\"][\"feature_columns_today\"],\n            feature_columns_yesterday=data_config[\"source\"][\"feature_columns_yesterday\"],\n            data_granularity=data_granularity,\n            ticks_per_step=simulator_config[\"time_per_step\"],\n            vol_threshold=simulator_config[\"vol_limit\"],\n        )\n\n    assert data_config[\"source\"][\"default_start_time_index\"] % data_granularity == 0\n    assert data_config[\"source\"][\"default_end_time_index\"] % data_granularity == 0\n\n    if run_training:\n        train_dataset, valid_dataset = [\n            LazyLoadDataset(\n                data_dir=data_config[\"source\"][\"feature_root_dir\"],\n                order_file_path=order_root_path / tag,\n                default_start_time_index=data_config[\"source\"][\"default_start_time_index\"] // data_granularity,\n                default_end_time_index=data_config[\"source\"][\"default_end_time_index\"] // data_granularity,\n            )\n            for tag in (\"train\", \"valid\")\n        ]\n\n        callbacks: List[Callback] = []\n        if \"checkpoint_path\" in trainer_config:\n            callbacks.append(MetricsWriter(dirpath=Path(trainer_config[\"checkpoint_path\"])))\n            callbacks.append(\n                Checkpoint(\n                    dirpath=Path(trainer_config[\"checkpoint_path\"]) / \"checkpoints\",\n                    every_n_iters=trainer_config.get(\"checkpoint_every_n_iters\", 1),\n                    save_latest=\"copy\",\n                ),\n            )\n        if \"earlystop_patience\" in trainer_config:\n            callbacks.append(\n                EarlyStopping(\n                    patience=trainer_config[\"earlystop_patience\"],\n                    monitor=\"val/pa\",\n                )\n            )\n\n        train(\n            simulator_fn=_simulator_factory_simple,\n     
       state_interpreter=state_interpreter,\n            action_interpreter=action_interpreter,\n            policy=policy,\n            reward=reward,\n            initial_states=cast(List[Order], train_dataset),\n            trainer_kwargs={\n                \"max_iters\": trainer_config[\"max_epoch\"],\n                \"finite_env_type\": env_config[\"parallel_mode\"],\n                \"concurrency\": env_config[\"concurrency\"],\n                \"val_every_n_iters\": trainer_config.get(\"val_every_n_epoch\", None),\n                \"callbacks\": callbacks,\n            },\n            vessel_kwargs={\n                \"episode_per_iter\": trainer_config[\"episode_per_collect\"],\n                \"update_kwargs\": {\n                    \"batch_size\": trainer_config[\"batch_size\"],\n                    \"repeat\": trainer_config[\"repeat_per_collect\"],\n                },\n                \"val_initial_states\": valid_dataset,\n            },\n        )\n\n    if run_backtest:\n        test_dataset = LazyLoadDataset(\n            data_dir=data_config[\"source\"][\"feature_root_dir\"],\n            order_file_path=order_root_path / \"test\",\n            default_start_time_index=data_config[\"source\"][\"default_start_time_index\"] // data_granularity,\n            default_end_time_index=data_config[\"source\"][\"default_end_time_index\"] // data_granularity,\n        )\n\n        backtest(\n            simulator_fn=_simulator_factory_simple,\n            state_interpreter=state_interpreter,\n            action_interpreter=action_interpreter,\n            initial_states=test_dataset,\n            policy=policy,\n            logger=CsvWriter(Path(trainer_config[\"checkpoint_path\"])),\n            reward=reward,\n            finite_env_type=env_config[\"parallel_mode\"],\n            concurrency=env_config[\"concurrency\"],\n        )\n\n\ndef main(config: dict, run_training: bool, run_backtest: bool) -> None:\n    if not run_training and not 
run_backtest:\n        warnings.warn(\"Skip the entire job since training and backtest are both skipped.\")\n        return\n\n    if \"seed\" in config[\"runtime\"]:\n        seed_everything(config[\"runtime\"][\"seed\"])\n\n    for extra_module_path in config[\"env\"].get(\"extra_module_paths\", []):\n        sys.path.append(extra_module_path)\n\n    state_interpreter: StateInterpreter = init_instance_by_config(config[\"state_interpreter\"])\n    action_interpreter: ActionInterpreter = init_instance_by_config(config[\"action_interpreter\"])\n    reward: Reward = init_instance_by_config(config[\"reward\"])\n\n    additional_policy_kwargs = {\n        \"obs_space\": state_interpreter.observation_space,\n        \"action_space\": action_interpreter.action_space,\n    }\n\n    # Create torch network\n    if \"network\" in config:\n        if \"kwargs\" not in config[\"network\"]:\n            config[\"network\"][\"kwargs\"] = {}\n        config[\"network\"][\"kwargs\"].update({\"obs_space\": state_interpreter.observation_space})\n        additional_policy_kwargs[\"network\"] = init_instance_by_config(config[\"network\"])\n\n    # Create policy\n    if \"kwargs\" not in config[\"policy\"]:\n        config[\"policy\"][\"kwargs\"] = {}\n    config[\"policy\"][\"kwargs\"].update(additional_policy_kwargs)\n    policy: BasePolicy = init_instance_by_config(config[\"policy\"])\n\n    use_cuda = config[\"runtime\"].get(\"use_cuda\", False)\n    if use_cuda:\n        policy.cuda()\n\n    train_and_test(\n        env_config=config[\"env\"],\n        simulator_config=config[\"simulator\"],\n        data_config=config[\"data\"],\n        trainer_config=config[\"trainer\"],\n        action_interpreter=action_interpreter,\n        state_interpreter=state_interpreter,\n        policy=policy,\n        reward=reward,\n        run_training=run_training,\n        run_backtest=run_backtest,\n    )\n\n\nif __name__ == \"__main__\":\n    warnings.filterwarnings(\"ignore\", 
category=DeprecationWarning)\n    warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--config_path\", type=str, required=True, help=\"Path to the config file\")\n    parser.add_argument(\"--no_training\", action=\"store_true\", help=\"Skip training workflow.\")\n    parser.add_argument(\"--run_backtest\", action=\"store_true\", help=\"Run backtest workflow.\")\n    args = parser.parse_args()\n\n    with open(args.config_path, \"r\") as input_stream:\n        yaml = YAML(typ=\"safe\", pure=True)\n        config = yaml.load(input_stream)\n\n    main(config, run_training=not args.no_training, run_backtest=args.run_backtest)\n"
  },
  {
    "path": "qlib/rl/contrib/utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\n\nimport pandas as pd\n\n\ndef read_order_file(order_file: Path | pd.DataFrame) -> pd.DataFrame:\n    if isinstance(order_file, pd.DataFrame):\n        return order_file\n\n    order_file = Path(order_file)\n\n    if order_file.suffix == \".pkl\":\n        order_df = pd.read_pickle(order_file).reset_index()\n    elif order_file.suffix == \".csv\":\n        order_df = pd.read_csv(order_file)\n    else:\n        raise TypeError(f\"Unsupported order file type: {order_file}\")\n\n    if \"date\" in order_df.columns:\n        # legacy dataframe columns\n        order_df = order_df.rename(columns={\"date\": \"datetime\", \"order_type\": \"direction\"})\n    order_df[\"datetime\"] = order_df[\"datetime\"].astype(str)\n\n    return order_df\n"
  },
  {
    "path": "qlib/rl/data/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"Common utilities to handle ad-hoc-styled data.\n\nMost of these snippets comes from research project (paper code).\nPlease take caution when using them in production.\n\"\"\"\n"
  },
  {
    "path": "qlib/rl/data/base.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom __future__ import annotations\n\nfrom abc import abstractmethod\n\nimport pandas as pd\n\n\nclass BaseIntradayBacktestData:\n    \"\"\"\n    Raw market data that is often used in backtesting (thus called BacktestData).\n\n    Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest\n    data type.\n    \"\"\"\n\n    @abstractmethod\n    def __repr__(self) -> str:\n        raise NotImplementedError\n\n    @abstractmethod\n    def __len__(self) -> int:\n        raise NotImplementedError\n\n    @abstractmethod\n    def get_deal_price(self) -> pd.Series:\n        raise NotImplementedError\n\n    @abstractmethod\n    def get_volume(self) -> pd.Series:\n        raise NotImplementedError\n\n    @abstractmethod\n    def get_time_index(self) -> pd.DatetimeIndex:\n        raise NotImplementedError\n\n\nclass BaseIntradayProcessedData:\n    \"\"\"Processed market data after data cleanup and feature engineering.\n\n    It contains both processed data for \"today\" and \"yesterday\", as some algorithms\n    might use the market information of the previous day to assist decision making.\n    \"\"\"\n\n    today: pd.DataFrame\n    \"\"\"Processed data for \"today\".\n    Number of records must be ``time_length``, and columns must be ``feature_dim``.\"\"\"\n\n    yesterday: pd.DataFrame\n    \"\"\"Processed data for \"yesterday\".\n    Number of records must be ``time_length``, and columns must be ``feature_dim``.\"\"\"\n\n\nclass ProcessedDataProvider:\n    \"\"\"Provider of processed data\"\"\"\n\n    def get_data(\n        self,\n        stock_id: str,\n        date: pd.Timestamp,\n        feature_dim: int,\n        time_index: pd.Index,\n    ) -> BaseIntradayProcessedData:\n        raise NotImplementedError\n"
  },
  {
    "path": "qlib/rl/data/integration.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nTODO: This file is used to integrate NeuTrader with Qlib to run the existing projects.\nTODO: The implementation here is kind of adhoc. It is better to design a more uniformed & general implementation.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\n\nimport qlib\nfrom qlib.constant import REG_CN\nfrom qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select\n\n\ndef init_qlib(qlib_config: dict) -> None:\n    \"\"\"Initialize necessary resource to launch the workflow, including data direction, feature columns, etc..\n\n    Parameters\n    ----------\n    qlib_config:\n        Qlib configuration.\n\n        Example::\n\n            {\n                \"provider_uri_day\": DATA_ROOT_DIR / \"qlib_1d\",\n                \"provider_uri_1min\": DATA_ROOT_DIR / \"qlib_1min\",\n                \"feature_root_dir\": DATA_ROOT_DIR / \"qlib_handler_stock\",\n                \"feature_columns_today\": [\n                    \"$open\", \"$high\", \"$low\", \"$close\", \"$vwap\", \"$bid\", \"$ask\", \"$volume\",\n                    \"$bidV\", \"$bidV1\", \"$bidV3\", \"$bidV5\", \"$askV\", \"$askV1\", \"$askV3\", \"$askV5\",\n                ],\n                \"feature_columns_yesterday\": [\n                    \"$open_1\", \"$high_1\", \"$low_1\", \"$close_1\", \"$vwap_1\", \"$bid_1\", \"$ask_1\", \"$volume_1\",\n                    \"$bidV_1\", \"$bidV1_1\", \"$bidV3_1\", \"$bidV5_1\", \"$askV_1\", \"$askV1_1\", \"$askV3_1\", \"$askV5_1\",\n                ],\n            }\n    \"\"\"\n\n    def _convert_to_path(path: str | Path) -> Path:\n        return path if isinstance(path, Path) else Path(path)\n\n    provider_uri_map = {}\n    for granularity in [\"1min\", \"5min\", \"day\"]:\n        if f\"provider_uri_{granularity}\" in qlib_config:\n            provider_uri_map[f\"{granularity}\"] = 
_convert_to_path(qlib_config[f\"provider_uri_{granularity}\"]).as_posix()\n\n    qlib.init(\n        region=REG_CN,\n        auto_mount=False,\n        custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum],\n        expression_cache=None,\n        calendar_provider={\n            \"class\": \"LocalCalendarProvider\",\n            \"module_path\": \"qlib.data.data\",\n            \"kwargs\": {\n                \"backend\": {\n                    \"class\": \"FileCalendarStorage\",\n                    \"module_path\": \"qlib.data.storage.file_storage\",\n                    \"kwargs\": {\"provider_uri_map\": provider_uri_map},\n                },\n            },\n        },\n        feature_provider={\n            \"class\": \"LocalFeatureProvider\",\n            \"module_path\": \"qlib.data.data\",\n            \"kwargs\": {\n                \"backend\": {\n                    \"class\": \"FileFeatureStorage\",\n                    \"module_path\": \"qlib.data.storage.file_storage\",\n                    \"kwargs\": {\"provider_uri_map\": provider_uri_map},\n                },\n            },\n        },\n        provider_uri=provider_uri_map,\n        kernels=1,\n        redis_port=-1,\n        clear_mem_cache=False,  # init_qlib may be called multiple times, so keep the cache to improve performance\n    )\n"
  },
  {
    "path": "qlib/rl/data/native.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom __future__ import annotations\n\nimport os\nfrom pathlib import Path\nfrom typing import List, cast\n\nimport cachetools\nimport pandas as pd\n\nfrom qlib.backtest import Exchange, Order\nfrom qlib.backtest.decision import TradeRange, TradeRangeByTime\nfrom qlib.constant import EPS_T\nfrom qlib.utils.pickle_utils import restricted_pickle_load\n\nfrom .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider\n\n\ndef get_ticks_slice(\n    ticks_index: pd.DatetimeIndex,\n    start: pd.Timestamp,\n    end: pd.Timestamp,\n    include_end: bool = False,\n) -> pd.DatetimeIndex:\n    if not include_end:\n        end = end - EPS_T\n    return ticks_index[ticks_index.slice_indexer(start, end)]\n\n\nclass IntradayBacktestData(BaseIntradayBacktestData):\n    \"\"\"Backtest data for Qlib simulator\"\"\"\n\n    def __init__(\n        self,\n        order: Order,\n        exchange: Exchange,\n        ticks_index: pd.DatetimeIndex,\n        ticks_for_order: pd.DatetimeIndex,\n    ) -> None:\n        self._order = order\n        self._exchange = exchange\n        self._start_time = ticks_for_order[0]\n        self._end_time = ticks_for_order[-1]\n        self.ticks_index = ticks_index\n        self.ticks_for_order = ticks_for_order\n\n        self._deal_price = cast(\n            pd.Series,\n            self._exchange.get_deal_price(\n                self._order.stock_id,\n                self._start_time,\n                self._end_time,\n                direction=self._order.direction,\n                method=None,\n            ),\n        )\n        self._volume = cast(\n            pd.Series,\n            self._exchange.get_volume(\n                self._order.stock_id,\n                self._start_time,\n                self._end_time,\n                method=None,\n            ),\n        )\n\n    def __repr__(self) -> str:\n        return (\n         
   f\"Order: {self._order}, Exchange: {self._exchange}, \"\n            f\"Start time: {self._start_time}, End time: {self._end_time}\"\n        )\n\n    def __len__(self) -> int:\n        return len(self._deal_price)\n\n    def get_deal_price(self) -> pd.Series:\n        return self._deal_price\n\n    def get_volume(self) -> pd.Series:\n        return self._volume\n\n    def get_time_index(self) -> pd.DatetimeIndex:\n        return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])\n\n\nclass DataframeIntradayBacktestData(BaseIntradayBacktestData):\n    \"\"\"Backtest data from dataframe\"\"\"\n\n    def __init__(self, df: pd.DataFrame, price_column: str = \"$close0\", volume_column: str = \"$volume0\") -> None:\n        self.df = df\n        self.price_column = price_column\n        self.volume_column = volume_column\n\n    def __repr__(self) -> str:\n        with pd.option_context(\"memory_usage\", False, \"display.max_info_columns\", 1, \"display.large_repr\", \"info\"):\n            return f\"{self.__class__.__name__}({self.df})\"\n\n    def __len__(self) -> int:\n        return len(self.df)\n\n    def get_deal_price(self) -> pd.Series:\n        return self.df[self.price_column]\n\n    def get_volume(self) -> pd.Series:\n        return self.df[self.volume_column]\n\n    def get_time_index(self) -> pd.DatetimeIndex:\n        return cast(pd.DatetimeIndex, self.df.index)\n\n\n@cachetools.cached(  # type: ignore\n    cache=cachetools.LRUCache(100),\n    key=lambda order, _, __: order.key_by_day,\n)\ndef load_backtest_data(\n    order: Order,\n    trade_exchange: Exchange,\n    trade_range: TradeRange,\n) -> IntradayBacktestData:\n    ticks_index = pd.DatetimeIndex(trade_exchange.quote_df.reset_index()[\"datetime\"])\n    ticks_index = ticks_index[order.start_time <= ticks_index]\n    ticks_index = ticks_index[ticks_index <= order.end_time]\n\n    if isinstance(trade_range, TradeRangeByTime):\n        ticks_for_order = get_ticks_slice(\n          
  ticks_index,\n            trade_range.start_time,\n            trade_range.end_time,\n            include_end=True,\n        )\n    else:\n        ticks_for_order = None  # FIXME: implement this logic\n\n    backtest_data = IntradayBacktestData(\n        order=order,\n        exchange=trade_exchange,\n        ticks_index=ticks_index,\n        ticks_for_order=ticks_for_order,\n    )\n    return backtest_data\n\n\nclass HandlerIntradayProcessedData(BaseIntradayProcessedData):\n    \"\"\"Subclass of BaseIntradayProcessedData. Used to handle handler-style (bin format) data.\"\"\"\n\n    def __init__(\n        self,\n        data_dir: Path,\n        stock_id: str,\n        date: pd.Timestamp,\n        feature_columns_today: List[str],\n        feature_columns_yesterday: List[str],\n        backtest: bool = False,\n        index_only: bool = False,\n    ) -> None:\n        def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:\n            df = df.reset_index()\n            if \"instrument\" in df.columns:\n                df = df.drop(columns=[\"instrument\"])\n            return df.set_index([\"datetime\"])\n\n        path = os.path.join(data_dir, \"backtest\" if backtest else \"feature\", f\"{stock_id}.pkl\")\n        start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)\n        with open(path, \"rb\") as fstream:\n            dataset = restricted_pickle_load(fstream)\n        data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)\n\n        if index_only:\n            self.today = _drop_stock_id(data[[]])\n            self.yesterday = _drop_stock_id(data[[]])\n        else:\n            self.today = _drop_stock_id(data[feature_columns_today])\n            self.yesterday = _drop_stock_id(data[feature_columns_yesterday])\n\n    def __repr__(self) -> str:\n        with pd.option_context(\"memory_usage\", False, \"display.max_info_columns\", 1, \"display.large_repr\", \"info\"):\n        
    return f\"{self.__class__.__name__}({self.today}, {self.yesterday})\"\n\n\n@cachetools.cached(  # type: ignore\n    cache=cachetools.LRUCache(100),  # 100 * 50K = 5MB\n    key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: (\n        stock_id,\n        date,\n        backtest,\n        index_only,\n    ),\n)\ndef load_handler_intraday_processed_data(\n    data_dir: Path,\n    stock_id: str,\n    date: pd.Timestamp,\n    feature_columns_today: List[str],\n    feature_columns_yesterday: List[str],\n    backtest: bool = False,\n    index_only: bool = False,\n) -> HandlerIntradayProcessedData:\n    return HandlerIntradayProcessedData(\n        data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only\n    )\n\n\nclass HandlerProcessedDataProvider(ProcessedDataProvider):\n    def __init__(\n        self,\n        data_dir: str,\n        feature_columns_today: List[str],\n        feature_columns_yesterday: List[str],\n        backtest: bool = False,\n    ) -> None:\n        super().__init__()\n\n        self.data_dir = Path(data_dir)\n        self.feature_columns_today = feature_columns_today\n        self.feature_columns_yesterday = feature_columns_yesterday\n        self.backtest = backtest\n\n    def get_data(\n        self,\n        stock_id: str,\n        date: pd.Timestamp,\n        feature_dim: int,\n        time_index: pd.Index,\n    ) -> BaseIntradayProcessedData:\n        return load_handler_intraday_processed_data(\n            self.data_dir,\n            stock_id,\n            date,\n            self.feature_columns_today,\n            self.feature_columns_yesterday,\n            backtest=self.backtest,\n            index_only=False,\n        )\n"
  },
  {
    "path": "qlib/rl/data/pickle_styled.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"This module contains utilities to read financial data from pickle-styled files.\n\nThis is the format used in `OPD paper <https://seqml.github.io/opd/>`__. NOT the standard data format in qlib.\n\nThe data here are all wrapped with ``@lru_cache``, which saves the expensive IO cost to repetitively read the data.\nWe also encourage users to use ``get_xxx_yyy`` rather than ``XxxYyy`` (although they are the same thing),\nbecause ``get_xxx_yyy`` is cache-optimized.\n\nNote that these pickle files are dumped with Python 3.8. Python lower than 3.7 might not be able to load them.\nSee `PEP 574 <https://peps.python.org/pep-0574/>`__ for details.\n\nThis file shows resemblence to qlib.backtest.high_performance_ds. We might merge those two in future.\n\"\"\"\n\n# TODO: merge with qlib/backtest/high_performance_ds.py\n\nfrom __future__ import annotations\n\nfrom functools import lru_cache\nfrom pathlib import Path\nfrom typing import List, Sequence, cast\n\nimport cachetools\nimport numpy as np\nimport pandas as pd\nfrom cachetools.keys import hashkey\n\nfrom qlib.backtest.decision import Order, OrderDir\nfrom qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider\nfrom qlib.typehint import Literal\n\nDealPriceType = Literal[\"bid_or_ask\", \"bid_or_ask_fill\", \"close\"]\n\"\"\"Several ad-hoc deal price.\n``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``.\n``bid_or_ask_fill``: Based on ``bid_or_ask``. 
If price is 0, use another price (``$ask0`` / ``$bid0``) instead.\n``close``: Use close price (``$close0``) as deal price.\n\"\"\"\n\n\ndef _infer_processed_data_column_names(shape: int) -> List[str]:\n    if shape == 16:\n        return [\n            \"$open\",\n            \"$high\",\n            \"$low\",\n            \"$close\",\n            \"$vwap\",\n            \"$bid\",\n            \"$ask\",\n            \"$volume\",\n            \"$bidV\",\n            \"$bidV1\",\n            \"$bidV3\",\n            \"$bidV5\",\n            \"$askV\",\n            \"$askV1\",\n            \"$askV3\",\n            \"$askV5\",\n        ]\n    if shape == 6:\n        return [\"$high\", \"$low\", \"$open\", \"$close\", \"$vwap\", \"$volume\"]\n    elif shape == 5:\n        return [\"$high\", \"$low\", \"$open\", \"$close\", \"$volume\"]\n    raise ValueError(f\"Unrecognized data shape: {shape}\")\n\n\ndef _find_pickle(filename_without_suffix: Path) -> Path:\n    suffix_list = [\".pkl\", \".pkl.backtest\"]\n    paths: List[Path] = []\n    for suffix in suffix_list:\n        path = filename_without_suffix.parent / (filename_without_suffix.name + suffix)\n        if path.exists():\n            paths.append(path)\n    if not paths:\n        raise FileNotFoundError(f\"No file starting with '{filename_without_suffix}' found\")\n    if len(paths) > 1:\n        raise ValueError(f\"Multiple paths are found with prefix '{filename_without_suffix}': {paths}\")\n    return paths[0]\n\n\n@lru_cache(maxsize=10)  # 10 * 40M = 400MB\ndef _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:\n    df = pd.read_pickle(_find_pickle(filename_without_suffix))\n    index_cols = df.index.names\n\n    df = df.reset_index()\n    for date_col_name in [\"date\", \"datetime\"]:\n        if date_col_name in df:\n            df[date_col_name] = pd.to_datetime(df[date_col_name])\n    df = df.set_index(index_cols)\n\n    return df\n\n\nclass SimpleIntradayBacktestData(BaseIntradayBacktestData):\n   
 \"\"\"Backtest data for simple simulator\"\"\"\n\n    def __init__(\n        self,\n        data_dir: Path | str,\n        stock_id: str,\n        date: pd.Timestamp,\n        deal_price: DealPriceType = \"close\",\n        order_dir: int | None = None,\n    ) -> None:\n        super(SimpleIntradayBacktestData, self).__init__()\n\n        backtest = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)\n        backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]\n\n        # No longer need for pandas >= 1.4\n        # backtest = backtest.droplevel([0, 2])\n\n        self.data: pd.DataFrame = backtest\n        self.deal_price_type: DealPriceType = deal_price\n        self.order_dir = order_dir\n\n    def __repr__(self) -> str:\n        with pd.option_context(\"memory_usage\", False, \"display.max_info_columns\", 1, \"display.large_repr\", \"info\"):\n            return f\"{self.__class__.__name__}({self.data})\"\n\n    def __len__(self) -> int:\n        return len(self.data)\n\n    def get_deal_price(self) -> pd.Series:\n        \"\"\"Return a pandas series that can be indexed with time.\n        See :attribute:`DealPriceType` for details.\"\"\"\n        if self.deal_price_type in (\"bid_or_ask\", \"bid_or_ask_fill\"):\n            if self.order_dir is None:\n                raise ValueError(\"Order direction cannot be none when deal_price_type is not close.\")\n            if self.order_dir == OrderDir.SELL:\n                col = \"$bid0\"\n            else:  # BUY\n                col = \"$ask0\"\n        elif self.deal_price_type == \"close\":\n            col = \"$close0\"\n        else:\n            raise ValueError(f\"Unsupported deal_price_type: {self.deal_price_type}\")\n        price = self.data[col]\n\n        if self.deal_price_type == \"bid_or_ask_fill\":\n            if self.order_dir == OrderDir.SELL:\n                fill_col = \"$ask0\"\n            else:\n                fill_col = \"$bid0\"\n            price 
= price.replace(0, np.nan).fillna(self.data[fill_col])\n\n        return price\n\n    def get_volume(self) -> pd.Series:\n        \"\"\"Return a volume series that can be indexed with time.\"\"\"\n        return self.data[\"$volume0\"]\n\n    def get_time_index(self) -> pd.DatetimeIndex:\n        return cast(pd.DatetimeIndex, self.data.index)\n\n\nclass PickleIntradayProcessedData(BaseIntradayProcessedData):\n    \"\"\"Subclass of BaseIntradayProcessedData. Used to handle pickle-styled data.\"\"\"\n\n    def __init__(\n        self,\n        data_dir: Path | str,\n        stock_id: str,\n        date: pd.Timestamp,\n        feature_dim: int,\n        time_index: pd.Index,\n    ) -> None:\n        proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)\n\n        # We have to infer the column names here because,\n        # unfortunately, they are not included in the original data.\n        cnames = _infer_processed_data_column_names(feature_dim)\n\n        time_length: int = len(time_index)\n\n        try:\n            # new data format\n            proc = proc.loc[pd.IndexSlice[stock_id, :, date]]\n            assert len(proc) == time_length and len(proc.columns) == feature_dim * 2\n            proc_today = proc[cnames]\n            proc_yesterday = proc[[f\"{c}_1\" for c in cnames]].rename(columns=lambda c: c[:-2])\n        except (IndexError, KeyError):\n            # legacy data\n            proc = proc.loc[pd.IndexSlice[stock_id, date]]\n            assert time_length * feature_dim * 2 == len(proc)\n            proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim))\n            proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim))\n            proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames)\n            proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames)\n\n        self.today: pd.DataFrame = proc_today\n 
       self.yesterday: pd.DataFrame = proc_yesterday\n        assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim\n        assert len(self.today) == len(self.yesterday) == time_length\n\n    def __repr__(self) -> str:\n        with pd.option_context(\"memory_usage\", False, \"display.max_info_columns\", 1, \"display.large_repr\", \"info\"):\n            return f\"{self.__class__.__name__}({self.today}, {self.yesterday})\"\n\n\n@lru_cache(maxsize=100)  # 100 * 50K = 5MB\ndef load_simple_intraday_backtest_data(\n    data_dir: Path,\n    stock_id: str,\n    date: pd.Timestamp,\n    deal_price: DealPriceType = \"close\",\n    order_dir: int | None = None,\n) -> SimpleIntradayBacktestData:\n    return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)\n\n\n@cachetools.cached(  # type: ignore\n    cache=cachetools.LRUCache(100),  # 100 * 50K = 5MB\n    key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),\n)\ndef load_pickle_intraday_processed_data(\n    data_dir: Path,\n    stock_id: str,\n    date: pd.Timestamp,\n    feature_dim: int,\n    time_index: pd.Index,\n) -> BaseIntradayProcessedData:\n    return PickleIntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)\n\n\nclass PickleProcessedDataProvider(ProcessedDataProvider):\n    def __init__(self, data_dir: Path) -> None:\n        super().__init__()\n\n        self._data_dir = data_dir\n\n    def get_data(\n        self,\n        stock_id: str,\n        date: pd.Timestamp,\n        feature_dim: int,\n        time_index: pd.Index,\n    ) -> BaseIntradayProcessedData:\n        return load_pickle_intraday_processed_data(\n            data_dir=self._data_dir,\n            stock_id=stock_id,\n            date=date,\n            feature_dim=feature_dim,\n            time_index=time_index,\n        )\n\n\ndef load_orders(\n    order_path: Path,\n    start_time: pd.Timestamp = None,\n    end_time: pd.Timestamp = 
None,\n) -> Sequence[Order]:\n    \"\"\"Load orders, and set start time and end time for the orders.\"\"\"\n\n    start_time = start_time or pd.Timestamp(\"0:00:00\")\n    end_time = end_time or pd.Timestamp(\"23:59:59\")\n\n    if order_path.is_file():\n        order_df = pd.read_pickle(order_path)\n    else:\n        order_df = []\n        for file in order_path.iterdir():\n            order_data = pd.read_pickle(file)\n            order_df.append(order_data)\n        order_df = pd.concat(order_df)\n\n    order_df = order_df.reset_index()\n\n    # Legacy-style orders have \"date\" instead of \"datetime\"\n    if \"date\" in order_df.columns:\n        order_df = order_df.rename(columns={\"date\": \"datetime\"})\n\n    # Sometimes \"datetime\" values are str rather than Timestamp\n    order_df[\"datetime\"] = pd.to_datetime(order_df[\"datetime\"])\n\n    orders: List[Order] = []\n\n    for _, row in order_df.iterrows():\n        # filter out orders with amount <= 0\n        if row[\"amount\"] <= 0:\n            continue\n        orders.append(\n            Order(\n                row[\"instrument\"],\n                row[\"amount\"],\n                OrderDir(int(row[\"order_type\"])),\n                row[\"datetime\"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),\n                row[\"datetime\"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),\n            ),\n        )\n\n    return orders\n"
  },
  {
    "path": "qlib/rl/interpreter.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom typing import Any, Generic, TypeVar\n\nimport gym\nimport numpy as np\nfrom gym import spaces\n\nfrom qlib.typehint import final\nfrom .simulator import ActType, StateType\n\nObsType = TypeVar(\"ObsType\")\nPolicyActType = TypeVar(\"PolicyActType\")\n\n\nclass Interpreter:\n    \"\"\"Interpreter is a media between states produced by simulators and states needed by RL policies.\n    Interpreters are two-way:\n\n    1. From simulator state to policy state (aka observation), see :class:`StateInterpreter`.\n    2. From policy action to action accepted by simulator, see :class:`ActionInterpreter`.\n\n    Inherit one of the two sub-classes to define your own interpreter.\n    This super-class is only used for isinstance check.\n\n    Interpreters are recommended to be stateless, meaning that storing temporary information with ``self.xxx``\n    in interpreter is anti-pattern. 
In the future, we might support registering some interpreter-related\n    states by calling ``self.env.register_state()``, but it is not planned for the first iteration.\n    \"\"\"\n\n\nclass StateInterpreter(Generic[StateType, ObsType], Interpreter):\n    \"\"\"State interpreter that interprets the execution result of the qlib executor into an RL env state\"\"\"\n\n    @property\n    def observation_space(self) -> gym.Space:\n        raise NotImplementedError()\n\n    @final  # must not be overridden\n    def __call__(self, simulator_state: StateType) -> ObsType:\n        obs = self.interpret(simulator_state)\n        self.validate(obs)\n        return obs\n\n    def validate(self, obs: ObsType) -> None:\n        \"\"\"Validate whether an observation belongs to the pre-defined observation space.\"\"\"\n        _gym_space_contains(self.observation_space, obs)\n\n    def interpret(self, simulator_state: StateType) -> ObsType:\n        \"\"\"Interpret the state of the simulator.\n\n        Parameters\n        ----------\n        simulator_state\n            Retrieved with ``simulator.get_state()``.\n\n        Returns\n        -------\n        State needed by the policy. 
Should conform with the state space defined in ``observation_space``.\n        \"\"\"\n        raise NotImplementedError(\"interpret is not implemented!\")\n\n\nclass ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):\n    \"\"\"Action Interpreter that interpret rl agent action into qlib orders\"\"\"\n\n    @property\n    def action_space(self) -> gym.Space:\n        raise NotImplementedError()\n\n    @final  # no overridden\n    def __call__(self, simulator_state: StateType, action: PolicyActType) -> ActType:\n        self.validate(action)\n        obs = self.interpret(simulator_state, action)\n        return obs\n\n    def validate(self, action: PolicyActType) -> None:\n        \"\"\"Validate whether an action belongs to the pre-defined action space.\"\"\"\n        _gym_space_contains(self.action_space, action)\n\n    def interpret(self, simulator_state: StateType, action: PolicyActType) -> ActType:\n        \"\"\"Convert the policy action to simulator action.\n\n        Parameters\n        ----------\n        simulator_state\n            Retrieved with ``simulator.get_state()``.\n        action\n            Raw action given by policy.\n\n        Returns\n        -------\n        The action needed by simulator,\n        \"\"\"\n        raise NotImplementedError(\"interpret is not implemented!\")\n\n\ndef _gym_space_contains(space: gym.Space, x: Any) -> None:\n    \"\"\"Strengthened version of gym.Space.contains.\n    Giving more diagnostic information on why validation fails.\n\n    Throw exception rather than returning true or false.\n    \"\"\"\n    if isinstance(space, spaces.Dict):\n        if not isinstance(x, dict) or len(x) != len(space):\n            raise GymSpaceValidationError(\"Sample must be a dict with same length as space.\", space, x)\n        for k, subspace in space.spaces.items():\n            if k not in x:\n                raise GymSpaceValidationError(f\"Key {k} not found in sample.\", space, x)\n            try:\n   
             _gym_space_contains(subspace, x[k])\n            except GymSpaceValidationError as e:\n                raise GymSpaceValidationError(f\"Subspace of key {k} validation error.\", space, x) from e\n\n    elif isinstance(space, spaces.Tuple):\n        if isinstance(x, (list, np.ndarray)):\n            x = tuple(x)  # Promote list and ndarray to tuple for contains check\n        if not isinstance(x, tuple) or len(x) != len(space):\n            raise GymSpaceValidationError(\"Sample must be a tuple with same length as space.\", space, x)\n        for i, (subspace, part) in enumerate(zip(space, x)):\n            try:\n                _gym_space_contains(subspace, part)\n            except GymSpaceValidationError as e:\n                raise GymSpaceValidationError(f\"Subspace of index {i} validation error.\", space, x) from e\n\n    else:\n        if not space.contains(x):\n            raise GymSpaceValidationError(\"Validation error reported by gym.\", space, x)\n\n\nclass GymSpaceValidationError(Exception):\n    def __init__(self, message: str, space: gym.Space, x: Any) -> None:\n        self.message = message\n        self.space = space\n        self.x = x\n\n    def __str__(self) -> str:\n        return f\"{self.message}\\n  Space: {self.space}\\n  Sample: {self.x}\"\n"
  },
  {
    "path": "qlib/rl/order_execution/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nCurrently it supports single-asset order execution.\nMulti-asset is on the way.\n\"\"\"\n\nfrom .interpreter import (\n    FullHistoryStateInterpreter,\n    CurrentStepStateInterpreter,\n    CategoricalActionInterpreter,\n    TwapRelativeActionInterpreter,\n)\nfrom .network import Recurrent\nfrom .policy import AllOne, PPO\nfrom .reward import PAPenaltyReward\nfrom .simulator_simple import SingleAssetOrderExecutionSimple\nfrom .state import SAOEMetrics, SAOEState\nfrom .strategy import SAOEStateAdapter, SAOEStrategy, ProxySAOEStrategy, SAOEIntStrategy\n\n__all__ = [\n    \"FullHistoryStateInterpreter\",\n    \"CurrentStepStateInterpreter\",\n    \"CategoricalActionInterpreter\",\n    \"TwapRelativeActionInterpreter\",\n    \"Recurrent\",\n    \"AllOne\",\n    \"PPO\",\n    \"PAPenaltyReward\",\n    \"SingleAssetOrderExecutionSimple\",\n    \"SAOEStateAdapter\",\n    \"SAOEMetrics\",\n    \"SAOEState\",\n    \"SAOEStrategy\",\n    \"ProxySAOEStrategy\",\n    \"SAOEIntStrategy\",\n]\n"
  },
  {
    "path": "qlib/rl/order_execution/interpreter.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nimport math\nfrom typing import Any, List, Optional, cast\n\nimport numpy as np\nimport pandas as pd\nfrom gym import spaces\n\nfrom qlib.constant import EPS\nfrom qlib.rl.data.base import ProcessedDataProvider\nfrom qlib.rl.interpreter import ActionInterpreter, StateInterpreter\nfrom qlib.rl.order_execution.state import SAOEState\nfrom qlib.typehint import TypedDict\n\n__all__ = [\n    \"FullHistoryStateInterpreter\",\n    \"CurrentStepStateInterpreter\",\n    \"CategoricalActionInterpreter\",\n    \"TwapRelativeActionInterpreter\",\n    \"FullHistoryObs\",\n]\n\nfrom qlib.utils import init_instance_by_config\n\n\ndef canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict:\n    \"\"\"To 32-bit numeric types. Recursively.\"\"\"\n    if isinstance(value, pd.DataFrame):\n        return value.to_numpy()\n    if isinstance(value, (float, np.floating)) or (isinstance(value, np.ndarray) and value.dtype.kind == \"f\"):\n        return np.array(value, dtype=np.float32)\n    elif isinstance(value, (int, bool, np.integer)) or (isinstance(value, np.ndarray) and value.dtype.kind == \"i\"):\n        return np.array(value, dtype=np.int32)\n    elif isinstance(value, dict):\n        return {k: canonicalize(v) for k, v in value.items()}\n    else:\n        return value\n\n\nclass FullHistoryObs(TypedDict):\n    data_processed: Any\n    data_processed_prev: Any\n    acquiring: Any\n    cur_tick: Any\n    cur_step: Any\n    num_step: Any\n    target: Any\n    position: Any\n    position_history: Any\n\n\nclass DummyStateInterpreter(StateInterpreter[SAOEState, dict]):\n    \"\"\"Dummy interpreter for policies that do not need inputs (for example, AllOne).\"\"\"\n\n    def interpret(self, state: SAOEState) -> dict:\n        # TODO: A fake state, used to pass `check_nan_observation`. 
Find a better way in the future.\n        return {\"DUMMY\": _to_int32(1)}\n\n    @property\n    def observation_space(self) -> spaces.Dict:\n        return spaces.Dict({\"DUMMY\": spaces.Box(-np.inf, np.inf, shape=(), dtype=np.int32)})\n\n\nclass FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):\n    \"\"\"The observation of the full history, including today (up to this moment) and yesterday.\n\n    Parameters\n    ----------\n    max_step\n        Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps.\n    data_ticks\n        Equal to the total number of records. For example, in per-minute SAOE,\n        the total number of ticks is the length of the day in minutes.\n    data_dim\n        Number of dimensions in the data.\n    processed_data_provider\n        Provider of the processed data.\n    \"\"\"\n\n    def __init__(\n        self,\n        max_step: int,\n        data_ticks: int,\n        data_dim: int,\n        processed_data_provider: dict | ProcessedDataProvider,\n    ) -> None:\n        super().__init__()\n\n        self.max_step = max_step\n        self.data_ticks = data_ticks\n        self.data_dim = data_dim\n        self.processed_data_provider: ProcessedDataProvider = init_instance_by_config(\n            processed_data_provider,\n            accept_types=ProcessedDataProvider,\n        )\n\n    def interpret(self, state: SAOEState) -> FullHistoryObs:\n        processed = self.processed_data_provider.get_data(\n            stock_id=state.order.stock_id,\n            date=pd.Timestamp(state.order.start_time.date()),\n            feature_dim=self.data_dim,\n            time_index=state.ticks_index,\n        )\n\n        position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32)\n        position_history[0] = state.order.amount\n        position_history[1 : len(state.history_steps) + 1] = state.history_steps[\"position\"].to_numpy()\n\n        # The ``min`` calls and slices here make sure that indices fit into the range,\n        # even after the final step of the simulator (in the done step),\n        # so that the network in the policy always receives valid inputs.\n        return cast(\n            FullHistoryObs,\n            canonicalize(\n                {\n                    \"data_processed\": np.array(self._mask_future_info(processed.today, state.cur_time)),\n                    \"data_processed_prev\": np.array(processed.yesterday),\n                    \"acquiring\": _to_int32(state.order.direction == state.order.BUY),\n                    \"cur_tick\": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)),\n                    \"cur_step\": _to_int32(min(state.cur_step, self.max_step - 1)),\n                    \"num_step\": _to_int32(self.max_step),\n                    \"target\": _to_float32(state.order.amount),\n                    \"position\": _to_float32(state.position),\n                    \"position_history\": _to_float32(position_history[: self.max_step]),\n                },\n            ),\n        )\n\n    @property\n    def observation_space(self) -> spaces.Dict:\n        space = {\n            \"data_processed\": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),\n            \"data_processed_prev\": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),\n            \"acquiring\": spaces.Discrete(2),\n            \"cur_tick\": spaces.Box(0, self.data_ticks - 1, shape=(), dtype=np.int32),\n            \"cur_step\": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),\n            # TODO: support arbitrary length index\n            \"num_step\": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),\n            \"target\": spaces.Box(-EPS, np.inf, shape=()),\n            \"position\": spaces.Box(-EPS, np.inf, shape=()),\n            \"position_history\": spaces.Box(-EPS, np.inf, shape=(self.max_step,)),\n        }\n        return spaces.Dict(space)\n\n    @staticmethod\n    def _mask_future_info(arr: pd.DataFrame, current: pd.Timestamp) -> pd.DataFrame:\n        arr = arr.copy(deep=True)\n        arr.loc[current:] = 0.0  # mask out data after this moment (inclusive)\n        return arr\n\n\nclass CurrentStateObs(TypedDict):\n    acquiring: bool\n    cur_step: int\n    num_step: int\n    target: float\n    position: float\n\n\nclass CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):\n    \"\"\"The observation of the current step.\n\n    Used when the policy only depends on the latest state, not on the history.\n    The key list is not exhaustive. You can add more keys if your policy needs more information.\n    \"\"\"\n\n    def __init__(self, max_step: int) -> None:\n        super().__init__()\n\n        self.max_step = max_step\n\n    @property\n    def observation_space(self) -> spaces.Dict:\n        space = {\n            \"acquiring\": spaces.Discrete(2),\n            \"cur_step\": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),\n            \"num_step\": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),\n            \"target\": spaces.Box(-EPS, np.inf, shape=()),\n            \"position\": spaces.Box(-EPS, np.inf, shape=()),\n        }\n        return spaces.Dict(space)\n\n    def interpret(self, state: SAOEState) -> CurrentStateObs:\n        assert state.cur_step <= self.max_step\n        obs = CurrentStateObs(\n            acquiring=state.order.direction == state.order.BUY,\n            cur_step=state.cur_step,\n            num_step=self.max_step,\n            target=state.order.amount,\n            position=state.position,\n        )\n        return obs\n\n\nclass CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):\n    \"\"\"Convert a discrete policy action to a continuous action, which is then multiplied by ``order.amount``.\n\n    Parameters\n    ----------\n    values\n        It can be a list of length $L$: $[a_1, a_2, \\\\ldots, a_L]$.\n        Then when the policy gives decision $x$, $a_x$ times the order amount (capped at the remaining position)\n        is the output.\n        It can also be an integer $n$, in which case the list of length $n+1$ is auto-generated,\n        i.e., $[0, 1/n, 2/n, \\\\ldots, n/n]$. For example, ``values=4`` gives ``[0.0, 0.25, 0.5, 0.75, 1.0]``.\n    max_step\n        Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps.\n        If given, the output at the last step is the whole remaining position, regardless of the policy decision.\n    \"\"\"\n\n    def __init__(self, values: int | List[float], max_step: Optional[int] = None) -> None:\n        super().__init__()\n\n        if isinstance(values, int):\n            values = [i / values for i in range(0, values + 1)]\n        self.action_values = values\n        self.max_step = max_step\n\n    @property\n    def action_space(self) -> spaces.Discrete:\n        return spaces.Discrete(len(self.action_values))\n\n    def interpret(self, state: SAOEState, action: int) -> float:\n        assert 0 <= action < len(self.action_values)\n        if self.max_step is not None and state.cur_step >= self.max_step - 1:\n            return state.position\n        else:\n            return min(state.position, state.order.amount * self.action_values[action])\n\n\nclass TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):\n    \"\"\"Convert a continuous ratio to a deal amount.\n\n    The ratio is relative to TWAP on the remainder of the day.\n    For example, if there are 5 steps left and the remaining position is 300,\n    a TWAP strategy trades 60 in each step.\n    When this interpreter receives action $a$, its output is $60 \\\\cdot a$, capped at the remaining position.\n    \"\"\"\n\n    @property\n    def action_space(self) -> spaces.Box:\n        return spaces.Box(0, np.inf, shape=(), dtype=np.float32)\n\n    def interpret(self, state: SAOEState, action: float) -> float:\n        estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step)\n        twap_volume = state.position / (estimated_total_steps - state.cur_step)\n        return min(state.position, twap_volume * action)\n\n\ndef _to_int32(val):\n    return np.array(int(val), dtype=np.int32)\n\n\ndef _to_float32(val):\n    return np.array(val, dtype=np.float32)\n"
  },
  {
    "path": "qlib/rl/order_execution/network.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom typing import List, Tuple, cast\n\nimport torch\nimport torch.nn as nn\nfrom tianshou.data import Batch\n\nfrom qlib.typehint import Literal\n\nfrom .interpreter import FullHistoryObs\n\n__all__ = [\"Recurrent\"]\n\n\nclass Recurrent(nn.Module):\n    \"\"\"The network architecture proposed in `OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_.\n\n    At every time step, the input of the policy network is divided into two parts:\n    the public variables and the private variables, which are handled by ``raw_rnn``\n    and ``pri_rnn`` in this network, respectively.\n\n    One minor difference is that, in this implementation, we don't assume the direction to be fixed.\n    Thus, another ``dire_fc`` is added to produce an extra direction-related feature.\n    \"\"\"\n\n    def __init__(\n        self,\n        obs_space: FullHistoryObs,\n        hidden_dim: int = 64,\n        output_dim: int = 32,\n        rnn_type: Literal[\"rnn\", \"lstm\", \"gru\"] = \"gru\",\n        rnn_num_layers: int = 1,\n    ) -> None:\n        super().__init__()\n\n        self.hidden_dim = hidden_dim\n        self.output_dim = output_dim\n        self.num_sources = 3\n\n        rnn_classes = {\"rnn\": nn.RNN, \"lstm\": nn.LSTM, \"gru\": nn.GRU}\n\n        self.rnn_class = rnn_classes[rnn_type]\n        self.rnn_layers = rnn_num_layers\n\n        self.raw_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)\n        self.prev_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)\n        self.pri_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)\n\n        self.raw_fc = nn.Sequential(nn.Linear(obs_space[\"data_processed\"].shape[-1], hidden_dim), nn.ReLU())\n        self.pri_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU())\n        self.dire_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU())\n\n        self._init_extra_branches()\n\n        self.fc = nn.Sequential(\n            nn.Linear(hidden_dim * self.num_sources, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, output_dim),\n            nn.ReLU(),\n        )\n\n    def _init_extra_branches(self) -> None:\n        pass\n\n    def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]:\n        bs, _, data_dim = obs[\"data_processed\"].size()\n        data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs[\"data_processed\"]), 1)\n        cur_step = obs[\"cur_step\"].long()\n        cur_tick = obs[\"cur_tick\"].long()\n        bs_indices = torch.arange(bs, device=device)\n\n        position = obs[\"position_history\"] / obs[\"target\"].unsqueeze(-1)  # [bs, num_step]\n        steps = (\n            torch.arange(position.size(-1), device=device).unsqueeze(0).repeat(bs, 1).float()\n            / obs[\"num_step\"].unsqueeze(-1).float()\n        )  # [bs, num_step]\n        priv = torch.stack((position.float(), steps), -1)\n\n        data_in = self.raw_fc(data)\n        data_out, _ = self.raw_rnn(data_in)\n        # As the data is padded with one row of zeros in front, index cur_tick selects the output at the last elapsed tick\n        data_out_slice = data_out[bs_indices, cur_tick]\n\n        priv_in = self.pri_fc(priv)\n        priv_out = self.pri_rnn(priv_in)[0]\n        priv_out = priv_out[bs_indices, cur_step]\n\n        sources = [data_out_slice, priv_out]\n\n        dir_out = self.dire_fc(torch.stack((obs[\"acquiring\"], 1 - obs[\"acquiring\"]), -1).float())\n        sources.append(dir_out)\n\n        return sources, data_out\n\n    def forward(self, batch: Batch) -> torch.Tensor:\n        \"\"\"\n        Input should be a dict containing at least:\n\n        - data_processed: [N, T, C]\n        - cur_step: [N]  (int)\n        - cur_tick: [N]  (int)\n        - position_history: [N, S]  (S is the number of steps)\n        - target: [N]\n        - num_step: [N]  (int)\n        - acquiring: [N]  (0 or 1)\n        \"\"\"\n\n        inp = cast(FullHistoryObs, batch)\n        device = inp[\"data_processed\"].device\n\n        sources, _ = self._source_features(inp, device)\n        assert len(sources) == self.num_sources\n\n        out = torch.cat(sources, -1)\n        return self.fc(out)\n\n\nclass Attention(nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super().__init__()\n        self.q_net = nn.Linear(in_dim, out_dim)\n        self.k_net = nn.Linear(in_dim, out_dim)\n        self.v_net = nn.Linear(in_dim, out_dim)\n\n    def forward(self, Q, K, V):\n        q = self.q_net(Q)\n        k = self.k_net(K)\n        v = self.v_net(V)\n\n        attn = torch.einsum(\"ijk,ilk->ijl\", q, k)\n        attn = attn.to(Q.device)\n        attn_prob = torch.softmax(attn, dim=-1)\n\n        attn_vec = torch.einsum(\"ijk,ikl->ijl\", attn_prob, v)\n\n        return attn_vec\n"
  },
  {
    "path": "qlib/rl/order_execution/policy.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import Any, Dict, Generator, Iterable, Optional, OrderedDict, Tuple, cast\n\nimport gym\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom gym.spaces import Discrete\nfrom tianshou.data import Batch, ReplayBuffer, to_torch\nfrom tianshou.policy import BasePolicy, PPOPolicy, DQNPolicy\n\nfrom qlib.rl.trainer.trainer import Trainer\n\n__all__ = [\"AllOne\", \"PPO\", \"DQN\"]\n\n\n# baselines #\n\n\nclass NonLearnablePolicy(BasePolicy):\n    \"\"\"Tianshou's BasePolicy with empty ``learn`` and ``process_fn``.\n\n    This could be moved outside in future.\n    \"\"\"\n\n    def __init__(self, obs_space: gym.Space, action_space: gym.Space) -> None:\n        super().__init__()\n\n    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:\n        return {}\n\n    def process_fn(\n        self,\n        batch: Batch,\n        buffer: ReplayBuffer,\n        indices: np.ndarray,\n    ) -> Batch:\n        return Batch({})\n\n\nclass AllOne(NonLearnablePolicy):\n    \"\"\"Forward returns a batch full of 1.\n\n    Useful when implementing some baselines (e.g., TWAP).\n    \"\"\"\n\n    def __init__(self, obs_space: gym.Space, action_space: gym.Space, fill_value: float | int = 1.0) -> None:\n        super().__init__(obs_space, action_space)\n\n        self.fill_value = fill_value\n\n    def forward(\n        self,\n        batch: Batch,\n        state: dict | Batch | np.ndarray = None,\n        **kwargs: Any,\n    ) -> Batch:\n        return Batch(act=np.full(len(batch), self.fill_value), state=state)\n\n\n# ppo #\n\n\nclass PPOActor(nn.Module):\n    def __init__(self, extractor: nn.Module, action_dim: int) -> None:\n        super().__init__()\n        self.extractor = extractor\n        self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), 
nn.Softmax(dim=-1))\n\n    def forward(\n        self,\n        obs: torch.Tensor,\n        state: torch.Tensor = None,\n        info: dict = {},\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        feature = self.extractor(to_torch(obs, device=auto_device(self)))\n        out = self.layer_out(feature)\n        return out, state\n\n\nclass PPOCritic(nn.Module):\n    def __init__(self, extractor: nn.Module) -> None:\n        super().__init__()\n        self.extractor = extractor\n        self.value_out = nn.Linear(cast(int, extractor.output_dim), 1)\n\n    def forward(\n        self,\n        obs: torch.Tensor,\n        state: torch.Tensor = None,\n        info: dict = {},\n    ) -> torch.Tensor:\n        feature = self.extractor(to_torch(obs, device=auto_device(self)))\n        return self.value_out(feature).squeeze(dim=-1)\n\n\nclass PPO(PPOPolicy):\n    \"\"\"A wrapper of tianshou's PPOPolicy.\n\n    Differences:\n\n    - Auto-creates the actor and critic networks. Supports discrete action spaces only.\n    - Deduplicates the parameters shared by the actor and critic networks\n      (not sure whether this is included in the latest tianshou or not).\n    - Supports a ``weight_file`` for loading a checkpoint.\n    - Some parameters' default values differ from the original.\n    \"\"\"\n\n    def __init__(\n        self,\n        network: nn.Module,\n        obs_space: gym.Space,\n        action_space: gym.Space,\n        lr: float,\n        weight_decay: float = 0.0,\n        discount_factor: float = 1.0,\n        max_grad_norm: float = 100.0,\n        reward_normalization: bool = True,\n        eps_clip: float = 0.3,\n        value_clip: bool = True,\n        vf_coef: float = 1.0,\n        gae_lambda: float = 1.0,\n        max_batch_size: int = 256,\n        deterministic_eval: bool = True,\n        weight_file: Optional[Path] = None,\n    ) -> None:\n        assert isinstance(action_space, Discrete)\n        actor = PPOActor(network, action_space.n)\n        critic = PPOCritic(network)\n        optimizer = torch.optim.Adam(\n            chain_dedup(actor.parameters(), critic.parameters()),\n            lr=lr,\n            weight_decay=weight_decay,\n        )\n        super().__init__(\n            actor,\n            critic,\n            optimizer,\n            torch.distributions.Categorical,\n            discount_factor=discount_factor,\n            max_grad_norm=max_grad_norm,\n            reward_normalization=reward_normalization,\n            eps_clip=eps_clip,\n            value_clip=value_clip,\n            vf_coef=vf_coef,\n            gae_lambda=gae_lambda,\n            max_batchsize=max_batch_size,\n            deterministic_eval=deterministic_eval,\n            observation_space=obs_space,\n            action_space=action_space,\n        )\n        if weight_file is not None:\n            set_weight(self, Trainer.get_policy_state_dict(weight_file))\n\n\nDQNModel = PPOActor  # Reuse PPOActor.\n\n\nclass DQN(DQNPolicy):\n    \"\"\"A wrapper of tianshou's DQNPolicy.\n\n    Differences:\n\n    - Auto-creates the model network. Supports discrete action spaces only.\n    - Supports a ``weight_file`` for loading a checkpoint.\n    \"\"\"\n\n    def __init__(\n        self,\n        network: nn.Module,\n        obs_space: gym.Space,\n        action_space: gym.Space,\n        lr: float,\n        weight_decay: float = 0.0,\n        discount_factor: float = 0.99,\n        estimation_step: int = 1,\n        target_update_freq: int = 0,\n        reward_normalization: bool = False,\n        is_double: bool = True,\n        clip_loss_grad: bool = False,\n        weight_file: Optional[Path] = None,\n    ) -> None:\n        assert isinstance(action_space, Discrete)\n\n        model = DQNModel(network, action_space.n)\n        optimizer = torch.optim.Adam(\n            model.parameters(),\n            lr=lr,\n            weight_decay=weight_decay,\n        )\n\n        super().__init__(\n            model,\n            optimizer,\n            discount_factor=discount_factor,\n            estimation_step=estimation_step,\n            target_update_freq=target_update_freq,\n            reward_normalization=reward_normalization,\n            is_double=is_double,\n            clip_loss_grad=clip_loss_grad,\n        )\n        if weight_file is not None:\n            set_weight(self, Trainer.get_policy_state_dict(weight_file))\n\n\n# utilities: these should be put in a separate (common) file. #\n\n\ndef auto_device(module: nn.Module) -> torch.device:\n    for param in module.parameters():\n        return param.device\n    return torch.device(\"cpu\")  # fall back to CPU if the module has no parameters\n\n\ndef set_weight(policy: nn.Module, loaded_weight: OrderedDict) -> None:\n    try:\n        policy.load_state_dict(loaded_weight)\n    except RuntimeError:\n        # try again by loading the converted weight\n        # https://github.com/thu-ml/tianshou/issues/468\n        for k in list(loaded_weight):\n            loaded_weight[\"_actor_critic.\" + k] = loaded_weight[k]\n        policy.load_state_dict(loaded_weight)\n\n\ndef chain_dedup(*iterables: Iterable) -> Generator[Any, None, None]:\n    \"\"\"Chain the iterables, yielding each item only at its first occurrence.\"\"\"\n    seen = set()\n    for iterable in iterables:\n        for i in iterable:\n            if i not in seen:\n                seen.add(i)\n                yield i\n"
  },
  {
    "path": "qlib/rl/order_execution/reward.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom typing import cast\n\nimport numpy as np\n\nfrom qlib.backtest.decision import OrderDir\nfrom qlib.rl.order_execution.state import SAOEMetrics, SAOEState\nfrom qlib.rl.reward import Reward\n\n__all__ = [\"PAPenaltyReward\"]\n\n\nclass PAPenaltyReward(Reward[SAOEState]):\n    \"\"\"Encourage higher PAs, but penalize stacking all the amounts within a very short time.\n    Formally, for each time step :math:`t`, the reward is\n    :math:`(PA_t * vol_t / target - penalty * \\\\sum_i (vol_{t,i} / target)^2) * scale`,\n    where :math:`vol_{t,i}` is the amount traded at tick :math:`i` within step :math:`t`.\n\n    Parameters\n    ----------\n    penalty\n        The penalty for large volume in a short time.\n    scale\n        The weight used to scale up or down the reward.\n    \"\"\"\n\n    def __init__(self, penalty: float = 100.0, scale: float = 1.0) -> None:\n        self.penalty = penalty\n        self.scale = scale\n\n    def reward(self, simulator_state: SAOEState) -> float:\n        whole_order = simulator_state.order.amount\n        assert whole_order > 0\n        last_step = cast(SAOEMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict())\n        pa = last_step[\"pa\"] * last_step[\"amount\"] / whole_order\n\n        # Inspect the \"break-down\" of the latest step: trading amount at every tick\n        last_step_breakdown = simulator_state.history_exec.loc[last_step[\"datetime\"] :]\n        penalty = -self.penalty * ((last_step_breakdown[\"amount\"] / whole_order) ** 2).sum()\n\n        reward = pa + penalty\n\n        # Raise an error in case of NaN or infinity\n        assert not (np.isnan(reward) or np.isinf(reward)), f\"Invalid reward for simulator state: {simulator_state}\"\n\n        self.log(\"reward/pa\", pa)\n        self.log(\"reward/penalty\", penalty)\n        return reward * self.scale\n\n\nclass PPOReward(Reward[SAOEState]):\n    \"\"\"Reward proposed by the paper \"An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization\".\n\n    Parameters\n    ----------\n    max_step\n        Maximum number of steps.\n    start_time_index\n        First time index that is allowed to trade.\n    end_time_index\n        Last time index that is allowed to trade.\n    \"\"\"\n\n    def __init__(self, max_step: int, start_time_index: int = 0, end_time_index: int = 239) -> None:\n        self.max_step = max_step\n        self.start_time_index = start_time_index\n        self.end_time_index = end_time_index\n\n    def reward(self, simulator_state: SAOEState) -> float:\n        # A non-zero reward is only given at the last step or once the position is (almost) fully dealt\n        if simulator_state.cur_step == self.max_step - 1 or simulator_state.position < 1e-6:\n            if simulator_state.history_exec[\"deal_amount\"].sum() == 0.0:\n                # Nothing was dealt: fall back to the unweighted average market price\n                vwap_price = cast(\n                    float,\n                    np.average(simulator_state.history_exec[\"market_price\"]),\n                )\n            else:\n                vwap_price = cast(\n                    float,\n                    np.average(\n                        simulator_state.history_exec[\"market_price\"],\n                        weights=simulator_state.history_exec[\"deal_amount\"],\n                    ),\n                )\n            twap_price = simulator_state.backtest_data.get_deal_price().mean()\n\n            # ratio > 1 means the execution price beat TWAP (higher when selling, lower when buying)\n            if simulator_state.order.direction == OrderDir.SELL:\n                ratio = vwap_price / twap_price if twap_price != 0 else 1.0\n            else:\n                ratio = twap_price / vwap_price if vwap_price != 0 else 1.0\n            if ratio < 1.0:\n                return -1.0\n            elif ratio < 1.1:\n                return 0.0\n            else:\n                return 1.0\n        else:\n            return 0.0\n"
  },
  {
    "path": "qlib/rl/order_execution/simulator_qlib.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom typing import Generator, List, Optional\n\nimport pandas as pd\n\nfrom qlib.backtest import collect_data_loop, get_strategy_executor\nfrom qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime\nfrom qlib.backtest.executor import NestedExecutor\nfrom qlib.rl.data.integration import init_qlib\nfrom qlib.rl.simulator import Simulator\nfrom .state import SAOEState\nfrom .strategy import SAOEStateAdapter, SAOEStrategy\n\n\nclass SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):\n    \"\"\"Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools.\n\n    Parameters\n    ----------\n    order\n        The seed to start an SAOE simulator is an order.\n    executor_config\n        Executor configuration\n    exchange_config\n        Exchange configuration\n    qlib_config\n        Configuration used to initialize Qlib. 
If it is None, Qlib will not be initialized.\n    cash_limit:\n        Cash limit.\n    \"\"\"\n\n    def __init__(\n        self,\n        order: Order,\n        executor_config: dict,\n        exchange_config: dict,\n        qlib_config: dict | None = None,\n        cash_limit: float | None = None,\n    ) -> None:\n        super().__init__(initial=order)\n\n        assert order.start_time.date() == order.end_time.date(), \"Start date and end date must be the same.\"\n\n        strategy_config = {\n            \"class\": \"SingleOrderStrategy\",\n            \"module_path\": \"qlib.rl.strategy.single_order\",\n            \"kwargs\": {\n                \"order\": order,\n                \"trade_range\": TradeRangeByTime(order.start_time.time(), order.end_time.time()),\n            },\n        }\n\n        self._collect_data_loop: Optional[Generator] = None\n        self.reset(order, strategy_config, executor_config, exchange_config, qlib_config, cash_limit)\n\n    def reset(\n        self,\n        order: Order,\n        strategy_config: dict,\n        executor_config: dict,\n        exchange_config: dict,\n        qlib_config: dict | None = None,\n        cash_limit: Optional[float] = None,\n    ) -> None:\n        if qlib_config is not None:\n            init_qlib(qlib_config)\n\n        strategy, self._executor = get_strategy_executor(\n            start_time=order.date,\n            end_time=order.date + pd.DateOffset(1),\n            strategy=strategy_config,\n            executor=executor_config,\n            benchmark=order.stock_id,\n            account=cash_limit if cash_limit is not None else int(1e12),\n            exchange_kwargs=exchange_config,\n            pos_type=\"Position\" if cash_limit is not None else \"InfPosition\",\n        )\n\n        assert isinstance(self._executor, NestedExecutor)\n\n        self.report_dict: dict = {}\n        self.decisions: List[BaseTradeDecision] = []\n        self._collect_data_loop = collect_data_loop(\n         
   start_time=order.date,\n            end_time=order.date,\n            trade_strategy=strategy,\n            trade_executor=self._executor,\n            return_value=self.report_dict,\n        )\n        assert isinstance(self._collect_data_loop, Generator)\n\n        self.step(action=None)\n\n        self._order = order\n\n    def _get_adapter(self) -> SAOEStateAdapter:\n        return self._last_yielded_saoe_strategy.adapter_dict[self._order.key_by_day]\n\n    @property\n    def twap_price(self) -> float:\n        return self._get_adapter().twap_price\n\n    def _iter_strategy(self, action: Optional[float] = None) -> SAOEStrategy:\n        \"\"\"Iterate the _collect_data_loop until we get the next yielded SAOEStrategy.\"\"\"\n        assert self._collect_data_loop is not None\n\n        obj = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)\n        while not isinstance(obj, SAOEStrategy):\n            if isinstance(obj, BaseTradeDecision):\n                self.decisions.append(obj)\n            obj = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)\n        assert isinstance(obj, SAOEStrategy)\n        return obj\n\n    def step(self, action: Optional[float]) -> None:\n        \"\"\"Execute one step of SAOE.\n\n        Parameters\n        ----------\n        action (float):\n            The amount you wish to deal. The simulator doesn't guarantee that the full amount will be successfully dealt.\n        \"\"\"\n\n        assert not self.done(), \"Simulator is already done!\"\n\n        try:\n            self._last_yielded_saoe_strategy = self._iter_strategy(action=action)\n        except StopIteration:\n            pass\n\n        assert self._executor is not None\n\n    def get_state(self) -> SAOEState:\n        return self._get_adapter().saoe_state\n\n    def done(self) -> bool:\n        return self._executor.finished()\n"
  },
  {
    "path": "qlib/rl/order_execution/simulator_simple.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom typing import Any, cast, List, Optional\n\nimport numpy as np\nimport pandas as pd\n\nfrom pathlib import Path\nfrom qlib.backtest.decision import Order, OrderDir\nfrom qlib.constant import EPS, EPS_T, float_or_ndarray\nfrom qlib.rl.data.base import BaseIntradayBacktestData\nfrom qlib.rl.data.native import DataframeIntradayBacktestData, load_handler_intraday_processed_data\nfrom qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data\nfrom qlib.rl.simulator import Simulator\nfrom qlib.rl.utils import LogLevel\nfrom .state import SAOEMetrics, SAOEState\n\n__all__ = [\"SingleAssetOrderExecutionSimple\"]\n\n\nclass SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):\n    \"\"\"Single-asset order execution (SAOE) simulator.\n\n    As there's no \"calendar\" in the simple simulator, ticks are used to trade.\n    A tick is a record (a line) in the pickle-styled data file.\n    Each tick is considered an individual trading opportunity.\n    If such fine granularity is not needed, use ``ticks_per_step`` to\n    group multiple ticks into each step.\n\n    In each step, the traded amount is split \"equally\" across the ticks,\n    then bounded by the maximum execution volume (i.e., ``vol_threshold``).\n    In the last step, the simulator tries to execute all the remaining amount.\n\n    Parameters\n    ----------\n    order\n        The seed to start an SAOE simulator is an order.\n    data_dir\n        Path to load backtest data.\n    feature_columns_today\n        Columns of today's features.\n    feature_columns_yesterday\n        Columns of yesterday's features.\n    data_granularity\n        Number of ticks between consecutive data entries.\n    ticks_per_step\n        How many ticks per step.\n    vol_threshold\n        Maximum execution volume, expressed as a fraction of the market volume.\n    \"\"\"\n\n    
history_exec: pd.DataFrame\n    \"\"\"All execution history at every possible time tick. See :class:`SAOEMetrics` for available columns.\n    Index is ``datetime``.\n    \"\"\"\n\n    history_steps: pd.DataFrame\n    \"\"\"Positions at each step. The position before the first step is also recorded.\n    See :class:`SAOEMetrics` for available columns.\n    Index is ``datetime``, which is the **starting** time of each step.\"\"\"\n\n    metrics: Optional[SAOEMetrics]\n    \"\"\"Metrics. Only available when done.\"\"\"\n\n    twap_price: float\n    \"\"\"This price is used to compute price advantage.\n    It's defined as the average price in the period from the order's start time to its end time.\"\"\"\n\n    ticks_index: pd.DatetimeIndex\n    \"\"\"All available ticks for the day (not restricted to order).\"\"\"\n\n    ticks_for_order: pd.DatetimeIndex\n    \"\"\"Ticks that are available for trading (sliced by order).\"\"\"\n\n    def __init__(\n        self,\n        order: Order,\n        data_dir: Path,\n        feature_columns_today: List[str] = [],\n        feature_columns_yesterday: List[str] = [],\n        data_granularity: int = 1,\n        ticks_per_step: int = 30,\n        vol_threshold: Optional[float] = None,\n    ) -> None:\n        super().__init__(initial=order)\n\n        assert ticks_per_step % data_granularity == 0\n\n        self.order = order\n        self.data_dir = data_dir\n        self.feature_columns_today = feature_columns_today\n        self.feature_columns_yesterday = feature_columns_yesterday\n        self.ticks_per_step: int = ticks_per_step // data_granularity\n        self.vol_threshold = vol_threshold\n\n        self.backtest_data = self.get_backtest_data()\n        self.ticks_index = self.backtest_data.get_time_index()\n\n        # Get time index available for trading\n        self.ticks_for_order = self._get_ticks_slice(self.order.start_time, self.order.end_time)\n\n        self.cur_time = self.ticks_for_order[0]\n        self.cur_step = 0\n  
      # NOTE: astype(float) is necessary in some systems.\n        # this will align the precision with `.to_numpy()` in `_split_exec_vol`\n        self.twap_price = float(self.backtest_data.get_deal_price().loc[self.ticks_for_order].astype(float).mean())\n\n        self.position = order.amount\n\n        metric_keys = list(SAOEMetrics.__annotations__.keys())  # pylint: disable=no-member\n        # NOTE: can empty dataframe contain index?\n        self.history_exec = pd.DataFrame(columns=metric_keys).set_index(\"datetime\")\n        self.history_steps = pd.DataFrame(columns=metric_keys).set_index(\"datetime\")\n        self.metrics = None\n\n        self.market_price: Optional[np.ndarray] = None\n        self.market_vol: Optional[np.ndarray] = None\n        self.market_vol_limit: Optional[np.ndarray] = None\n\n    def get_backtest_data(self) -> BaseIntradayBacktestData:\n        try:\n            data = load_handler_intraday_processed_data(\n                data_dir=self.data_dir,\n                stock_id=self.order.stock_id,\n                date=pd.Timestamp(self.order.start_time.date()),\n                feature_columns_today=self.feature_columns_today,\n                feature_columns_yesterday=self.feature_columns_yesterday,\n                backtest=True,\n                index_only=False,\n            )\n            return DataframeIntradayBacktestData(data.today)\n        except (AttributeError, FileNotFoundError):\n            # TODO: For compatibility with older versions of test scripts (tests/rl/test_saoe_simple.py)\n            # TODO: In the future, we should modify the data format used by the test script,\n            # TODO: and then delete this branch.\n            return load_simple_intraday_backtest_data(\n                self.data_dir / \"backtest\",\n                self.order.stock_id,\n                pd.Timestamp(self.order.start_time.date()),\n                \"close\",\n                self.order.direction,\n            )\n\n    def 
step(self, amount: float) -> None:\n        \"\"\"Execute one step of SAOE.\n\n        Parameters\n        ----------\n        amount\n            The amount you wish to deal. The simulator doesn't guarantee that the full amount will be successfully dealt.\n        \"\"\"\n\n        assert not self.done()\n\n        self.market_price = self.market_vol = None  # avoid misuse\n        exec_vol = self._split_exec_vol(amount)\n        assert self.market_price is not None\n        assert self.market_vol is not None\n\n        ticks_position = self.position - np.cumsum(exec_vol)\n\n        self.position -= exec_vol.sum()\n        if abs(self.position) < 1e-6:\n            self.position = 0.0\n        if self.position < -EPS or (exec_vol < -EPS).any():\n            raise ValueError(f\"Execution volume is invalid: {exec_vol} (position = {self.position})\")\n\n        # Get time index available for this step\n        time_index = self._get_ticks_slice(self.cur_time, self._next_time())\n\n        self.history_exec = self._dataframe_append(\n            self.history_exec,\n            SAOEMetrics(\n                # It should have the same keys as SAOEMetrics,\n                # but the values do not necessarily have the annotated type.\n                # Some values could be vectorized (e.g., exec_vol).\n                stock_id=self.order.stock_id,\n                datetime=time_index,\n                direction=self.order.direction,\n                market_volume=self.market_vol,\n                market_price=self.market_price,\n                amount=exec_vol,\n                inner_amount=exec_vol,\n                deal_amount=exec_vol,\n                trade_price=self.market_price,\n                trade_value=self.market_price * exec_vol,\n                position=ticks_position,\n                ffr=exec_vol / self.order.amount,\n                pa=price_advantage(self.market_price, self.twap_price, self.order.direction),\n            ),\n        )\n\n        
self.history_steps = self._dataframe_append(\n            self.history_steps,\n            [self._metrics_collect(self.cur_time, self.market_vol, self.market_price, amount, exec_vol)],\n        )\n\n        if self.done():\n            if self.env is not None:\n                self.env.logger.add_any(\"history_steps\", self.history_steps, loglevel=LogLevel.DEBUG)\n                self.env.logger.add_any(\"history_exec\", self.history_exec, loglevel=LogLevel.DEBUG)\n\n            self.metrics = self._metrics_collect(\n                self.ticks_index[0],  # start time\n                self.history_exec[\"market_volume\"],\n                self.history_exec[\"market_price\"],\n                self.history_steps[\"amount\"].sum(),\n                self.history_exec[\"deal_amount\"],\n            )\n\n            # NOTE (yuge): It looks to me that it's the \"correct\" decision to\n            # put all the logs here, because only components like simulators themselves\n            # know what could appear in the logs and in what format.\n            # But I admit it's not necessarily the most convenient way.\n            # I'll revisit this when we have a second environment.\n            # Maybe with an API like self.logger.enable_auto_log()?\n\n            if self.env is not None:\n                for key, value in self.metrics.items():\n                    if isinstance(value, float):\n                        self.env.logger.add_scalar(key, value)\n                    else:\n                        self.env.logger.add_any(key, value)\n\n        self.cur_time = self._next_time()\n        self.cur_step += 1\n\n    def get_state(self) -> SAOEState:\n        return SAOEState(\n            order=self.order,\n            cur_time=self.cur_time,\n            cur_step=self.cur_step,\n            position=self.position,\n            history_exec=self.history_exec,\n            history_steps=self.history_steps,\n            metrics=self.metrics,\n    
        backtest_data=self.backtest_data,\n            ticks_per_step=self.ticks_per_step,\n            ticks_index=self.ticks_index,\n            ticks_for_order=self.ticks_for_order,\n        )\n\n    def done(self) -> bool:\n        return self.position < EPS or self.cur_time >= self.order.end_time\n\n    def _next_time(self) -> pd.Timestamp:\n        \"\"\"The \"current time\" (``cur_time``) for the next step.\"\"\"\n        # Look up the next time on the time index\n        current_loc = self.ticks_index.get_loc(self.cur_time)\n        next_loc = current_loc + self.ticks_per_step\n\n        # Align the next location to a multiple of ticks_per_step.\n        # This makes sure that, as long as the number of morning ticks is a multiple of ticks_per_step,\n        # no step spans both the morning and the afternoon sessions.\n        next_loc = next_loc - next_loc % self.ticks_per_step\n\n        if next_loc < len(self.ticks_index) and self.ticks_index[next_loc] < self.order.end_time:\n            return self.ticks_index[next_loc]\n        else:\n            return self.order.end_time\n\n    def _cur_duration(self) -> pd.Timedelta:\n        \"\"\"The \"duration\" of this step (the step that is about to happen).\"\"\"\n        return self._next_time() - self.cur_time\n\n    def _split_exec_vol(self, exec_vol_sum: float) -> np.ndarray:\n        \"\"\"\n        Split the volume in each step into minutes, considering possible constraints.\n        This follows the TWAP strategy.\n        \"\"\"\n        next_time = self._next_time()\n\n        # get the backtest data for the next interval\n        self.market_vol = self.backtest_data.get_volume().loc[self.cur_time : next_time - EPS_T].to_numpy()\n        self.market_price = self.backtest_data.get_deal_price().loc[self.cur_time : next_time - EPS_T].to_numpy()\n\n        assert self.market_vol is not None and self.market_price is not None\n\n        # split the volume equally into each minute\n        exec_vol = np.repeat(exec_vol_sum / len(self.market_price), 
len(self.market_price))\n\n        # apply the volume threshold\n        market_vol_limit = self.vol_threshold * self.market_vol if self.vol_threshold is not None else np.inf\n        exec_vol = np.minimum(exec_vol, market_vol_limit)  # type: ignore\n\n        # Execute all the remaining order amount at the last moment.\n        if next_time >= self.order.end_time:\n            exec_vol[-1] += self.position - exec_vol.sum()\n            exec_vol = np.minimum(exec_vol, market_vol_limit)  # type: ignore\n\n        return exec_vol\n\n    def _metrics_collect(\n        self,\n        datetime: pd.Timestamp,\n        market_vol: np.ndarray,\n        market_price: np.ndarray,\n        amount: float,  # amount intended to trade\n        exec_vol: np.ndarray,\n    ) -> SAOEMetrics:\n        assert len(market_vol) == len(market_price) == len(exec_vol)\n\n        if np.abs(np.sum(exec_vol)) < EPS:\n            exec_avg_price = 0.0\n        else:\n            exec_avg_price = cast(float, np.average(market_price, weights=exec_vol))  # could be nan\n            if hasattr(exec_avg_price, \"item\"):  # could be a numpy scalar\n                exec_avg_price = exec_avg_price.item()  # type: ignore\n\n        return SAOEMetrics(\n            stock_id=self.order.stock_id,\n            datetime=datetime,\n            direction=self.order.direction,\n            market_volume=market_vol.sum(),\n            market_price=market_price.mean(),\n            amount=amount,\n            inner_amount=exec_vol.sum(),\n            deal_amount=exec_vol.sum(),  # in this simulator, there are no other restrictions\n            trade_price=exec_avg_price,\n            trade_value=float(np.sum(market_price * exec_vol)),\n            position=self.position,\n            ffr=float(exec_vol.sum() / self.order.amount),\n            pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction),\n        )\n\n    def _get_ticks_slice(self, start: pd.Timestamp, end: pd.Timestamp, include_end: bool = 
False) -> pd.DatetimeIndex:\n        if not include_end:\n            end = end - EPS_T\n        return self.ticks_index[self.ticks_index.slice_indexer(start, end)]\n\n    @staticmethod\n    def _dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:\n        # dataframe.append is deprecated\n        other_df = pd.DataFrame(other).set_index(\"datetime\")\n        other_df.index.name = \"datetime\"\n        return pd.concat([df, other_df], axis=0)\n\n\ndef price_advantage(\n    exec_price: float_or_ndarray,\n    baseline_price: float,\n    direction: OrderDir | int,\n) -> float_or_ndarray:\n    if baseline_price == 0:  # something is wrong with data. Should be nan here\n        if isinstance(exec_price, float):\n            return 0.0\n        else:\n            return np.zeros_like(exec_price)\n    if direction == OrderDir.BUY:\n        res = (1 - exec_price / baseline_price) * 10000\n    elif direction == OrderDir.SELL:\n        res = (exec_price / baseline_price - 1) * 10000\n    else:\n        raise ValueError(f\"Unexpected order direction: {direction}\")\n    res_wo_nan: np.ndarray = np.nan_to_num(res, nan=0.0)\n    if res_wo_nan.size == 1:\n        return res_wo_nan.item()\n    else:\n        return cast(float_or_ndarray, res_wo_nan)\n"
  },
  {
    "path": "qlib/rl/order_execution/state.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nimport typing\nfrom typing import NamedTuple, Optional\n\nimport numpy as np\nimport pandas as pd\nfrom qlib.backtest import Order\nfrom qlib.typehint import TypedDict\n\nif typing.TYPE_CHECKING:\n    from qlib.rl.data.base import BaseIntradayBacktestData\n\n\nclass SAOEMetrics(TypedDict):\n    \"\"\"Metrics for SAOE accumulated for a \"period\".\n    It could be accumulated for a day, or a period of time (e.g., 30min), or calculated separately for every minute.\n\n    Warnings\n    --------\n    The type hints are for single elements. In many cases, they can be vectorized.\n    For example, ``market_volume`` could be a list of floats (or an ndarray) rather than a single float.\n    \"\"\"\n\n    stock_id: str\n    \"\"\"Stock ID of this record.\"\"\"\n    datetime: pd.Timestamp | pd.DatetimeIndex\n    \"\"\"Datetime of this record (this is the index in the dataframe).\"\"\"\n    direction: int\n    \"\"\"Direction of the order. 0 for sell, 1 for buy.\"\"\"\n\n    # Market information.\n    market_volume: np.ndarray | float\n    \"\"\"Total market volume traded in the period.\"\"\"\n    market_price: np.ndarray | float\n    \"\"\"Deal price. If it's a period of time, this is the average market deal price.\"\"\"\n\n    # Strategy records.\n\n    amount: np.ndarray | float\n    \"\"\"Total amount (volume) the strategy intends to trade.\"\"\"\n    inner_amount: np.ndarray | float\n    \"\"\"Total amount that the lower-level strategy intends to trade\n    (might be larger than amount, e.g., to ensure ffr).\"\"\"\n\n    deal_amount: np.ndarray | float\n    \"\"\"Amount that successfully takes effect (must not exceed inner_amount).\"\"\"\n    trade_price: np.ndarray | float\n    \"\"\"The average deal price for this strategy.\"\"\"\n    trade_value: np.ndarray | float\n    \"\"\"Total worth of trading. 
In the simple simulation, trade_value = deal_amount * price.\"\"\"\n    position: np.ndarray | float\n    \"\"\"Position left after this \"period\".\"\"\"\n\n    # Accumulated metrics\n\n    ffr: np.ndarray | float\n    \"\"\"Fill rate: the fraction of the daily order that has been completed.\"\"\"\n\n    pa: np.ndarray | float\n    \"\"\"Price advantage compared to the baseline (i.e., trading at the baseline market price).\n    The baseline is the trade price when using the TWAP strategy to execute this order.\n    Please note that there could be data leakage here.\n    Unit is BP (basis point, 1/10000).\"\"\"\n\n\nclass SAOEState(NamedTuple):\n    \"\"\"Data structure holding a state for the SAOE simulator.\"\"\"\n\n    order: Order\n    \"\"\"The order we are dealing with.\"\"\"\n    cur_time: pd.Timestamp\n    \"\"\"Current time, e.g., 9:30.\"\"\"\n    cur_step: int\n    \"\"\"Current step, e.g., 0.\"\"\"\n    position: float\n    \"\"\"Current remaining volume to execute.\"\"\"\n    history_exec: pd.DataFrame\n    \"\"\"See :attr:`SingleAssetOrderExecution.history_exec`.\"\"\"\n    history_steps: pd.DataFrame\n    \"\"\"See :attr:`SingleAssetOrderExecution.history_steps`.\"\"\"\n\n    metrics: Optional[SAOEMetrics]\n    \"\"\"Daily metrics, only available when the trading is in the \"done\" state.\"\"\"\n\n    backtest_data: BaseIntradayBacktestData\n    \"\"\"Backtest data is included in the state.\n    At the moment, only the time index of this data is needed.\n    I include the full data so that algorithms (e.g., VWAP) that rely on the raw data can be implemented.\n    Interpreters can use this as they wish, but they should be careful not to leak future data.\n    \"\"\"\n\n    ticks_per_step: int\n    \"\"\"How many ticks for each step.\"\"\"\n    ticks_index: pd.DatetimeIndex\n    \"\"\"Trading ticks for the whole day, NOT sliced by order (defined in data). 
e.g., [9:30, 9:31, ..., 14:59].\"\"\"\n    ticks_for_order: pd.DatetimeIndex\n    \"\"\"Trading ticks sliced by order, e.g., [9:45, 9:46, ..., 14:44].\"\"\"\n"
  },
  {
    "path": "qlib/rl/order_execution/strategy.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nimport collections\nfrom types import GeneratorType\nfrom typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union\n\nimport warnings\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom tianshou.data import Batch\nfrom tianshou.policy import BasePolicy\n\nfrom qlib.backtest import CommonInfrastructure, Order\nfrom qlib.backtest.decision import BaseTradeDecision, TradeDecisionWithDetails, TradeDecisionWO, TradeRange\nfrom qlib.backtest.exchange import Exchange\nfrom qlib.backtest.executor import BaseExecutor\nfrom qlib.backtest.utils import LevelInfrastructure, get_start_end_idx\nfrom qlib.constant import EPS, ONE_MIN, REG_CN\nfrom qlib.rl.data.native import IntradayBacktestData, load_backtest_data\nfrom qlib.rl.interpreter import ActionInterpreter, StateInterpreter\nfrom qlib.rl.order_execution.state import SAOEMetrics, SAOEState\nfrom qlib.rl.order_execution.utils import dataframe_append, price_advantage\nfrom qlib.strategy.base import RLStrategy\nfrom qlib.utils import init_instance_by_config\nfrom qlib.utils.index_data import IndexData\nfrom qlib.utils.time import get_day_min_idx_range\n\n\ndef _get_all_timestamps(\n    start: pd.Timestamp,\n    end: pd.Timestamp,\n    granularity: pd.Timedelta = ONE_MIN,\n    include_end: bool = True,\n) -> pd.DatetimeIndex:\n    ret = []\n    while start <= end:\n        ret.append(start)\n        start += granularity\n\n    if ret[-1] > end:\n        ret.pop()\n    if ret[-1] == end and not include_end:\n        ret.pop()\n    return pd.DatetimeIndex(ret)\n\n\ndef fill_missing_data(\n    original_data: np.ndarray,\n    fill_method: Callable = np.nanmedian,\n) -> np.ndarray:\n    \"\"\"Fill missing data.\n\n    Parameters\n    ----------\n    original_data\n        Original data, which may contain missing values (NaN).\n    fill_method\n        Method used to fill the missing 
data.\n\n    Returns\n    -------\n        The filled data.\n    \"\"\"\n    return np.nan_to_num(original_data, nan=fill_method(original_data))\n\n\nclass SAOEStateAdapter:\n    \"\"\"\n    Maintains the state of the environment. SAOEStateAdapter accepts execution results and updates its internal state\n    based on them, together with additional information acquired from the executor and the exchange. For example,\n    it gets the dealt order amount from the execution results, and gets the corresponding market price / volume from\n    the exchange.\n\n    Example usage::\n\n        adapter = SAOEStateAdapter(...)\n        adapter.update(...)\n        state = adapter.saoe_state\n    \"\"\"\n\n    def __init__(\n        self,\n        order: Order,\n        trade_decision: BaseTradeDecision,\n        executor: BaseExecutor,\n        exchange: Exchange,\n        ticks_per_step: int,\n        backtest_data: IntradayBacktestData,\n        data_granularity: int = 1,\n    ) -> None:\n        self.position = order.amount\n        self.order = order\n        self.executor = executor\n        self.exchange = exchange\n        self.backtest_data = backtest_data\n        self.start_idx, _ = get_start_end_idx(self.executor.trade_calendar, trade_decision)\n\n        self.twap_price = self.backtest_data.get_deal_price().mean()\n\n        metric_keys = list(SAOEMetrics.__annotations__.keys())  # pylint: disable=no-member\n        self.history_exec = pd.DataFrame(columns=metric_keys).set_index(\"datetime\")\n        self.history_steps = pd.DataFrame(columns=metric_keys).set_index(\"datetime\")\n        self.metrics: Optional[SAOEMetrics] = None\n\n        self.cur_time = max(backtest_data.ticks_for_order[0], order.start_time)\n        self.ticks_per_step = ticks_per_step\n        self.data_granularity = data_granularity\n        assert self.ticks_per_step % self.data_granularity == 0\n\n    def _next_time(self) -> pd.Timestamp:\n        current_loc = 
self.backtest_data.ticks_index.get_loc(self.cur_time)\n        next_loc = current_loc + (self.ticks_per_step // self.data_granularity)\n        next_loc = next_loc - next_loc % (self.ticks_per_step // self.data_granularity)\n        if (\n            next_loc < len(self.backtest_data.ticks_index)\n            and self.backtest_data.ticks_index[next_loc] < self.order.end_time\n        ):\n            return self.backtest_data.ticks_index[next_loc]\n        else:\n            return self.order.end_time\n\n    def update(\n        self,\n        execute_result: list,\n        last_step_range: Tuple[int, int],\n    ) -> None:\n        last_step_size = last_step_range[1] - last_step_range[0] + 1\n        start_time = self.backtest_data.ticks_index[last_step_range[0]]\n        end_time = self.backtest_data.ticks_index[last_step_range[1]]\n\n        exec_vol = np.zeros(last_step_size)\n        for order, _, __, ___ in execute_result:\n            idx, _ = get_day_min_idx_range(order.start_time, order.end_time, f\"{self.data_granularity}min\", REG_CN)\n            exec_vol[idx - last_step_range[0]] = order.deal_amount\n\n        if exec_vol.sum() > self.position and exec_vol.sum() > 0.0:\n            if exec_vol.sum() > self.position + 1.0:\n                warnings.warn(\n                    f\"Sum of execution volume is {exec_vol.sum()} which is larger than \"\n                    f\"position + 1.0 = {self.position} + 1.0 = {self.position + 1.0}. 
\"\n                    f\"All execution volume is scaled down linearly to ensure that their sum does not exceed position.\"\n                )\n            exec_vol *= self.position / (exec_vol.sum())\n\n        market_volume = cast(\n            IndexData,\n            self.exchange.get_volume(\n                self.order.stock_id,\n                pd.Timestamp(start_time),\n                pd.Timestamp(end_time),\n                method=None,\n            ),\n        )\n        market_price = cast(\n            IndexData,\n            self.exchange.get_deal_price(\n                self.order.stock_id,\n                pd.Timestamp(start_time),\n                pd.Timestamp(end_time),\n                method=None,\n                direction=self.order.direction,\n            ),\n        )\n        market_price = fill_missing_data(np.array(market_price, dtype=float).reshape(-1))\n        market_volume = fill_missing_data(np.array(market_volume, dtype=float).reshape(-1))\n\n        assert market_price.shape == market_volume.shape == exec_vol.shape\n\n        # Get data from the current level executor's indicator\n        current_trade_account = self.executor.trade_account\n        current_df = current_trade_account.get_trade_indicator().generate_trade_indicators_dataframe()\n        self.history_exec = dataframe_append(\n            self.history_exec,\n            self._collect_multi_order_metric(\n                order=self.order,\n                datetime=_get_all_timestamps(\n                    start_time, end_time, include_end=True, granularity=ONE_MIN * self.data_granularity\n                ),\n                market_vol=market_volume,\n                market_price=market_price,\n                exec_vol=exec_vol,\n                pa=current_df.iloc[-1][\"pa\"],\n            ),\n        )\n\n        self.history_steps = dataframe_append(\n            self.history_steps,\n            [\n                self._collect_single_order_metric(\n                    
self.order,\n                    self.cur_time,\n                    market_volume,\n                    market_price,\n                    exec_vol.sum(),\n                    exec_vol,\n                ),\n            ],\n        )\n\n        # Update the position last: the metrics above are computed from the position before this step\n        self.position -= exec_vol.sum()\n\n        self.cur_time = self._next_time()\n\n    def generate_metrics_after_done(self) -> None:\n        \"\"\"Generate metrics once the upper-level execution is done.\"\"\"\n\n        self.metrics = self._collect_single_order_metric(\n            self.order,\n            self.backtest_data.ticks_index[0],  # start time\n            self.history_exec[\"market_volume\"],\n            self.history_exec[\"market_price\"],\n            self.history_steps[\"amount\"].sum(),\n            self.history_exec[\"deal_amount\"],\n        )\n\n    def _collect_multi_order_metric(\n        self,\n        order: Order,\n        datetime: pd.DatetimeIndex,\n        market_vol: np.ndarray,\n        market_price: np.ndarray,\n        exec_vol: np.ndarray,\n        pa: float,\n    ) -> SAOEMetrics:\n        return SAOEMetrics(\n            # It should have the same keys as SAOEMetrics,\n            # but the values do not necessarily have the annotated type.\n            # Some values could be vectorized (e.g., exec_vol).\n            stock_id=order.stock_id,\n            datetime=datetime,\n            direction=order.direction,\n            market_volume=market_vol,\n            market_price=market_price,\n            amount=exec_vol,\n            inner_amount=exec_vol,\n            deal_amount=exec_vol,\n            trade_price=market_price,\n            trade_value=market_price * exec_vol,\n            position=self.position - np.cumsum(exec_vol),\n            ffr=exec_vol / order.amount,\n            pa=pa,\n        )\n\n    def _collect_single_order_metric(\n        self,\n        order: Order,\n        datetime: pd.Timestamp,\n        market_vol: np.ndarray,\n        
market_price: np.ndarray,\n        amount: float,  # amount intended to trade\n        exec_vol: np.ndarray,\n    ) -> SAOEMetrics:\n        assert len(market_vol) == len(market_price) == len(exec_vol)\n\n        if np.abs(np.sum(exec_vol)) < EPS:\n            exec_avg_price = 0.0\n        else:\n            exec_avg_price = cast(float, np.average(market_price, weights=exec_vol))  # could be nan\n            if hasattr(exec_avg_price, \"item\"):  # could be a numpy scalar\n                exec_avg_price = exec_avg_price.item()  # type: ignore\n\n        exec_sum = exec_vol.sum()\n        return SAOEMetrics(\n            stock_id=order.stock_id,\n            datetime=datetime,\n            direction=order.direction,\n            market_volume=market_vol.sum(),\n            market_price=market_price.mean() if len(market_price) > 0 else np.nan,\n            amount=amount,\n            inner_amount=exec_sum,\n            deal_amount=exec_sum,  # in this simulator, there are no other restrictions\n            trade_price=exec_avg_price,\n            trade_value=float(np.sum(market_price * exec_vol)),\n            position=self.position - exec_sum,\n            ffr=float(exec_sum / order.amount),\n            pa=price_advantage(exec_avg_price, self.twap_price, order.direction),\n        )\n\n    @property\n    def saoe_state(self) -> SAOEState:\n        return SAOEState(\n            order=self.order,\n            cur_time=self.cur_time,\n            cur_step=self.executor.trade_calendar.get_trade_step() - self.start_idx,\n            position=self.position,\n            history_exec=self.history_exec,\n            history_steps=self.history_steps,\n            metrics=self.metrics,\n            backtest_data=self.backtest_data,\n            ticks_per_step=self.ticks_per_step,\n            ticks_index=self.backtest_data.ticks_index,\n            ticks_for_order=self.backtest_data.ticks_for_order,\n        )\n\n\nclass SAOEStrategy(RLStrategy):\n    \"\"\"RL-based 
strategies that use SAOEState as state.\"\"\"\n\n    def __init__(\n        self,\n        policy: BasePolicy,\n        outer_trade_decision: BaseTradeDecision | None = None,\n        level_infra: LevelInfrastructure | None = None,\n        common_infra: CommonInfrastructure | None = None,\n        data_granularity: int = 1,\n        **kwargs: Any,\n    ) -> None:\n        super(SAOEStrategy, self).__init__(\n            policy=policy,\n            outer_trade_decision=outer_trade_decision,\n            level_infra=level_infra,\n            common_infra=common_infra,\n            **kwargs,\n        )\n\n        self._data_granularity = data_granularity\n        self.adapter_dict: Dict[tuple, SAOEStateAdapter] = {}\n        self._last_step_range = (0, 0)\n\n    def _create_qlib_backtest_adapter(\n        self,\n        order: Order,\n        trade_decision: BaseTradeDecision,\n        trade_range: TradeRange,\n    ) -> SAOEStateAdapter:\n        backtest_data = load_backtest_data(order, self.trade_exchange, trade_range)\n\n        return SAOEStateAdapter(\n            order=order,\n            trade_decision=trade_decision,\n            executor=self.executor,\n            exchange=self.trade_exchange,\n            ticks_per_step=int(pd.Timedelta(self.trade_calendar.get_freq()) / ONE_MIN),\n            backtest_data=backtest_data,\n            data_granularity=self._data_granularity,\n        )\n\n    def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:\n        super(SAOEStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)\n\n        self.adapter_dict = {}\n        self._last_step_range = (0, 0)\n\n        if outer_trade_decision is not None and not outer_trade_decision.empty():\n            trade_range = outer_trade_decision.trade_range\n            assert trade_range is not None\n\n            self.adapter_dict = {}\n            for decision in outer_trade_decision.get_decision():\n                
order = cast(Order, decision)\n                self.adapter_dict[order.key_by_day] = self._create_qlib_backtest_adapter(\n                    order, outer_trade_decision, trade_range\n                )\n\n    def get_saoe_state_by_order(self, order: Order) -> SAOEState:\n        return self.adapter_dict[order.key_by_day].saoe_state\n\n    def post_upper_level_exe_step(self) -> None:\n        for adapter in self.adapter_dict.values():\n            adapter.generate_metrics_after_done()\n\n    def post_exe_step(self, execute_result: Optional[list]) -> None:\n        last_step_length = self._last_step_range[1] - self._last_step_range[0]\n        if last_step_length <= 0:\n            assert not execute_result\n            return\n\n        results = collections.defaultdict(list)\n        if execute_result is not None:\n            for e in execute_result:\n                results[e[0].key_by_day].append(e)\n\n        for key, adapter in self.adapter_dict.items():\n            adapter.update(results[key], self._last_step_range)\n\n    def generate_trade_decision(\n        self,\n        execute_result: list | None = None,\n    ) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:\n        \"\"\"\n        For SAOEStrategy, we need to update the `self._last_step_range` every time a decision is generated.\n        This operation should be invisible to developers, so we implement it in `generate_trade_decision()`.\n        The concrete logic to generate decisions should be implemented in `_generate_trade_decision()`.\n        In other words, all subclasses of `SAOEStrategy` should override `_generate_trade_decision()` instead of\n        `generate_trade_decision()`.\n        \"\"\"\n        self._last_step_range = self.get_data_cal_avail_range(rtype=\"step\")\n\n        decision = self._generate_trade_decision(execute_result)\n        if isinstance(decision, GeneratorType):\n            decision = yield from decision\n\n        return decision\n\n    def 
_generate_trade_decision(\n        self,\n        execute_result: list | None = None,\n    ) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:\n        raise NotImplementedError\n\n\nclass ProxySAOEStrategy(SAOEStrategy):\n    \"\"\"Proxy strategy that uses SAOEState. It is called a 'proxy' strategy because it does not make any decisions\n    by itself. Instead, when the strategy is required to generate a decision, it will yield the environment's\n    information and let the outside agents make the decision. Please refer to `_generate_trade_decision` for\n    more details.\n    \"\"\"\n\n    def __init__(\n        self,\n        outer_trade_decision: BaseTradeDecision | None = None,\n        level_infra: LevelInfrastructure | None = None,\n        common_infra: CommonInfrastructure | None = None,\n        **kwargs: Any,\n    ) -> None:\n        super().__init__(None, outer_trade_decision, level_infra, common_infra, **kwargs)\n\n    def _generate_trade_decision(self, execute_result: list | None = None) -> Generator[Any, Any, BaseTradeDecision]:\n        # Once the following line is executed, this ProxySAOEStrategy (self) will be yielded to the outside\n        # of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`,\n        # the sent value will be captured by `exec_vol`. 
In this way, the outside policy can communicate with the inner\n        # level strategy.\n        exec_vol = yield self\n\n        oh = self.trade_exchange.get_order_helper()\n        order = oh.create(self._order.stock_id, exec_vol, self._order.direction)\n\n        return TradeDecisionWO([order], self)\n\n    def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:\n        super().reset(outer_trade_decision=outer_trade_decision, **kwargs)\n\n        assert isinstance(outer_trade_decision, TradeDecisionWO)\n        if outer_trade_decision is not None:\n            order_list = outer_trade_decision.order_list\n            assert len(order_list) == 1\n            self._order = order_list[0]\n\n\nclass SAOEIntStrategy(SAOEStrategy):\n    \"\"\"(SAOE)state based strategy with (Int)preters.\"\"\"\n\n    def __init__(\n        self,\n        policy: dict | BasePolicy,\n        state_interpreter: dict | StateInterpreter,\n        action_interpreter: dict | ActionInterpreter,\n        network: dict | torch.nn.Module | None = None,\n        outer_trade_decision: BaseTradeDecision | None = None,\n        level_infra: LevelInfrastructure | None = None,\n        common_infra: CommonInfrastructure | None = None,\n        **kwargs: Any,\n    ) -> None:\n        super(SAOEIntStrategy, self).__init__(\n            policy=policy,\n            outer_trade_decision=outer_trade_decision,\n            level_infra=level_infra,\n            common_infra=common_infra,\n            **kwargs,\n        )\n\n        self._state_interpreter: StateInterpreter = init_instance_by_config(\n            state_interpreter,\n            accept_types=StateInterpreter,\n        )\n        self._action_interpreter: ActionInterpreter = init_instance_by_config(\n            action_interpreter,\n            accept_types=ActionInterpreter,\n        )\n\n        if isinstance(policy, dict):\n            assert network is not None\n\n            if 
isinstance(network, dict):\n                network[\"kwargs\"].update(\n                    {\n                        \"obs_space\": self._state_interpreter.observation_space,\n                    }\n                )\n                network_inst = init_instance_by_config(network)\n            else:\n                network_inst = network\n\n            policy[\"kwargs\"].update(\n                {\n                    \"obs_space\": self._state_interpreter.observation_space,\n                    \"action_space\": self._action_interpreter.action_space,\n                    \"network\": network_inst,\n                }\n            )\n            self._policy = init_instance_by_config(policy)\n        elif isinstance(policy, BasePolicy):\n            self._policy = policy\n        else:\n            raise ValueError(f\"Unsupported policy type: {type(policy)}.\")\n\n        if self._policy is not None:\n            self._policy.eval()\n\n    def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:\n        super().reset(outer_trade_decision=outer_trade_decision, **kwargs)\n\n    def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:\n        assert hasattr(self.outer_trade_decision, \"order_list\")\n\n        trade_details = []\n        for a, v, o in zip(act, exec_vols, getattr(self.outer_trade_decision, \"order_list\")):\n            trade_details.append(\n                {\n                    \"instrument\": o.stock_id,\n                    \"datetime\": self.trade_calendar.get_step_time()[0],\n                    \"freq\": self.trade_calendar.get_freq(),\n                    \"rl_exec_vol\": v,\n                }\n            )\n            if a is not None:\n                trade_details[-1][\"rl_action\"] = a\n        return pd.DataFrame.from_records(trade_details)\n\n    def _generate_trade_decision(self, execute_result: list | None = None) -> BaseTradeDecision:\n        states = 
[]\n        obs_batch = []\n        for decision in self.outer_trade_decision.get_decision():\n            order = cast(Order, decision)\n            state = self.get_saoe_state_by_order(order)\n\n            states.append(state)\n            obs_batch.append({\"obs\": self._state_interpreter.interpret(state)})\n\n        with torch.no_grad():\n            policy_out = self._policy(Batch(obs_batch))\n        act = policy_out.act.numpy() if torch.is_tensor(policy_out.act) else policy_out.act\n        exec_vols = [self._action_interpreter.interpret(s, a) for s, a in zip(states, act)]\n\n        oh = self.trade_exchange.get_order_helper()\n        order_list = []\n        for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):\n            if exec_vol != 0:\n                order = cast(Order, decision)\n                order_list.append(oh.create(order.stock_id, exec_vol, order.direction))\n\n        return TradeDecisionWithDetails(\n            order_list=order_list,\n            strategy=self,\n            details=self._generate_trade_details(act, exec_vols),\n        )\n"
  },
  {
    "path": "qlib/rl/order_execution/utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom typing import Any, cast\n\nimport numpy as np\nimport pandas as pd\n\nfrom qlib.backtest.decision import OrderDir\nfrom qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor\nfrom qlib.constant import float_or_ndarray\n\n\ndef dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:\n    # dataframe.append is deprecated\n    other_df = pd.DataFrame(other).set_index(\"datetime\")\n    other_df.index.name = \"datetime\"\n\n    res = pd.concat([df, other_df], axis=0)\n    return res\n\n\ndef price_advantage(\n    exec_price: float_or_ndarray,\n    baseline_price: float,\n    direction: OrderDir | int,\n) -> float_or_ndarray:\n    if baseline_price == 0:  # something is wrong with data. Should be nan here\n        if isinstance(exec_price, float):\n            return 0.0\n        else:\n            return np.zeros_like(exec_price)\n    if direction == OrderDir.BUY:\n        res = (1 - exec_price / baseline_price) * 10000\n    elif direction == OrderDir.SELL:\n        res = (exec_price / baseline_price - 1) * 10000\n    else:\n        raise ValueError(f\"Unexpected order direction: {direction}\")\n    res_wo_nan: np.ndarray = np.nan_to_num(res, nan=0.0)\n    if res_wo_nan.size == 1:\n        return res_wo_nan.item()\n    else:\n        return cast(float_or_ndarray, res_wo_nan)\n\n\ndef get_simulator_executor(executor: BaseExecutor) -> SimulatorExecutor:\n    while isinstance(executor, NestedExecutor):\n        executor = executor.inner_executor\n    assert isinstance(executor, SimulatorExecutor)\n    return executor\n"
  },
  {
    "path": "qlib/rl/reward.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, TypeVar\n\nfrom qlib.typehint import final\n\nif TYPE_CHECKING:\n    from .utils.env_wrapper import EnvWrapper\n\nSimulatorState = TypeVar(\"SimulatorState\")\n\n\nclass Reward(Generic[SimulatorState]):\n    \"\"\"\n    Reward calculation component that takes a single argument: the state of the simulator. Returns a real number: the reward.\n\n    Subclasses should implement ``reward(simulator_state)`` to define their own reward calculation recipe.\n    \"\"\"\n\n    env: Optional[EnvWrapper] = None\n\n    @final\n    def __call__(self, simulator_state: SimulatorState) -> float:\n        return self.reward(simulator_state)\n\n    def reward(self, simulator_state: SimulatorState) -> float:\n        \"\"\"Implement this method for your own reward.\"\"\"\n        raise NotImplementedError(\"Implement reward calculation recipe in `reward()`.\")\n\n    def log(self, name: str, value: Any) -> None:\n        assert self.env is not None\n        self.env.logger.add_scalar(name, value)\n\n\nclass RewardCombination(Reward):\n    \"\"\"Combination of multiple rewards.\"\"\"\n\n    def __init__(self, rewards: Dict[str, Tuple[Reward, float]]) -> None:\n        self.rewards = rewards\n\n    def reward(self, simulator_state: Any) -> float:\n        total_reward = 0.0\n        for name, (reward_fn, weight) in self.rewards.items():\n            rew = reward_fn(simulator_state) * weight\n            total_reward += rew\n            self.log(name, rew)\n        return total_reward\n\n\n# TODO:\n# reward_factory is disabled for now\n\n# _RegistryConfigReward = RegistryConfig[REWARDS]\n\n\n# @configclass\n# class _WeightedRewardConfig:\n#     weight: float\n#     reward: _RegistryConfigReward\n\n\n# RewardConfig = Union[_RegistryConfigReward, Dict[str, Union[_RegistryConfigReward, 
_WeightedRewardConfig]]]\n\n\n# def reward_factory(reward_config: RewardConfig) -> Reward:\n#     \"\"\"\n#     Use this factory to instantiate the reward from config.\n#     Simply using ``reward_config.build()`` might not work because reward can have complex combinations.\n#     \"\"\"\n#     if isinstance(reward_config, dict):\n#         # as reward combination\n#         rewards = {}\n#         for name, rew in reward_config.items():\n#             if not isinstance(rew, _WeightedRewardConfig):\n#                 # default weight is 1.\n#                 rew = _WeightedRewardConfig(weight=1., reward=rew)\n#             # no recursive build in this step\n#             rewards[name] = (rew.reward.build(), rew.weight)\n#         return RewardCombination(rewards)\n#     else:\n#         # single reward\n#         return reward_config.build()\n"
  },
  {
    "path": "qlib/rl/seed.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"Defines a set of initial state definitions and state-set definitions.\n\nWith single-asset order execution only, the only seed is order.\n\"\"\"\n\nfrom typing import TypeVar\n\nInitialStateType = TypeVar(\"InitialStateType\")\n\"\"\"Type of data that creates the simulator.\"\"\"\n"
  },
  {
    "path": "qlib/rl/simulator.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar\n\nfrom .seed import InitialStateType\n\nif TYPE_CHECKING:\n    from .utils.env_wrapper import EnvWrapper\n\nStateType = TypeVar(\"StateType\")\n\"\"\"StateType stores all the useful data in the simulation process\n(as well as utilities to generate/retrieve data when needed).\"\"\"\n\nActType = TypeVar(\"ActType\")\n\"\"\"This ActType is the type of action at the simulator end.\"\"\"\n\n\nclass Simulator(Generic[InitialStateType, StateType, ActType]):\n    \"\"\"\n    Simulator that resets with ``__init__``, and transitions with ``step(action)``.\n\n    To make the data-flow clear, we impose the following restrictions on Simulator:\n\n    1. The only way to modify the inner status of a simulator is by using ``step(action)``.\n    2. External modules can *read* the status of a simulator by using ``simulator.get_state()``,\n       and check whether the simulator is in the ending state by calling ``simulator.done()``.\n\n    A simulator is bound to three types:\n\n    - *InitialStateType* that is the type of the data used to create the simulator.\n    - *StateType* that is the type of the **status** (state) of the simulator.\n    - *ActType* that is the type of the **action**, which is the input received in each step.\n\n    Different simulators might share the same StateType, for example, when they deal with the same task\n    but with different simulation implementations. With the same type, they can safely share other components in the MDP.\n\n    Simulators are ephemeral. 
The lifecycle of a simulator starts with an initial state, and ends with the trajectory.\n    In other words, when the trajectory ends, the simulator is recycled.\n    If simulators need to share context across trajectories (e.g., for speed-up purposes),\n    this could be done by accessing the weak reference to the environment wrapper.\n\n    Attributes\n    ----------\n    env\n        A reference to the env-wrapper, which could be useful in some corner cases.\n        Simulators are discouraged from using this, because it is error-prone.\n    \"\"\"\n\n    env: Optional[EnvWrapper] = None\n\n    def __init__(self, initial: InitialStateType, **kwargs: Any) -> None:\n        pass\n\n    def step(self, action: ActType) -> None:\n        \"\"\"Receives an action of ActType.\n\n        Simulator should update its internal state, and return None.\n        The updated state can be retrieved with ``simulator.get_state()``.\n        \"\"\"\n        raise NotImplementedError()\n\n    def get_state(self) -> StateType:\n        raise NotImplementedError()\n\n    def done(self) -> bool:\n        \"\"\"Check whether the simulator is in a \"done\" state.\n        When the simulator is in a \"done\" state,\n        it should no longer receive any ``step`` request.\n        As simulators are ephemeral, to reset the simulator,\n        the old one should be destroyed and a new simulator can be created.\n        \"\"\"\n        raise NotImplementedError()\n"
  },
  {
    "path": "qlib/rl/strategy/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom .single_order import SingleOrderStrategy\n\n__all__ = [\"SingleOrderStrategy\"]\n"
  },
  {
    "path": "qlib/rl/strategy/single_order.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom qlib.backtest import Order\nfrom qlib.backtest.decision import OrderHelper, TradeDecisionWO, TradeRange\nfrom qlib.strategy.base import BaseStrategy\n\n\nclass SingleOrderStrategy(BaseStrategy):\n    \"\"\"Strategy used to generate a trade decision with exactly one order.\"\"\"\n\n    def __init__(\n        self,\n        order: Order,\n        trade_range: TradeRange | None = None,\n    ) -> None:\n        super().__init__()\n\n        self._order = order\n        self._trade_range = trade_range\n\n    def generate_trade_decision(self, execute_result: list | None = None) -> TradeDecisionWO:\n        oh: OrderHelper = self.common_infra.get(\"trade_exchange\").get_order_helper()\n        order_list = [\n            oh.create(\n                code=self._order.stock_id,\n                amount=self._order.amount,\n                direction=self._order.direction,\n            ),\n        ]\n        return TradeDecisionWO(order_list, self, self._trade_range)\n"
  },
  {
    "path": "qlib/rl/trainer/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"Train, test, inference utilities.\"\"\"\n\nfrom .api import backtest, train\nfrom .callbacks import Checkpoint, EarlyStopping, MetricsWriter\nfrom .trainer import Trainer\nfrom .vessel import TrainingVessel, TrainingVesselBase\n\n__all__ = [\n    \"Trainer\",\n    \"TrainingVessel\",\n    \"TrainingVesselBase\",\n    \"Checkpoint\",\n    \"EarlyStopping\",\n    \"MetricsWriter\",\n    \"train\",\n    \"backtest\",\n]\n"
  },
  {
    "path": "qlib/rl/trainer/api.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nfrom typing import Any, Callable, Dict, List, Sequence, cast\n\nfrom tianshou.policy import BasePolicy\n\nfrom qlib.rl.interpreter import ActionInterpreter, StateInterpreter\nfrom qlib.rl.reward import Reward\nfrom qlib.rl.simulator import InitialStateType, Simulator\nfrom qlib.rl.utils import FiniteEnvType, LogWriter\n\nfrom .trainer import Trainer\nfrom .vessel import TrainingVessel\n\n\ndef train(\n    simulator_fn: Callable[[InitialStateType], Simulator],\n    state_interpreter: StateInterpreter,\n    action_interpreter: ActionInterpreter,\n    initial_states: Sequence[InitialStateType],\n    policy: BasePolicy,\n    reward: Reward,\n    vessel_kwargs: Dict[str, Any],\n    trainer_kwargs: Dict[str, Any],\n) -> None:\n    \"\"\"Train a policy with the parallelism provided by RL framework.\n\n    Experimental API. Parameters might change shortly.\n\n    Parameters\n    ----------\n    simulator_fn\n        Callable receiving initial seed, returning a simulator.\n    state_interpreter\n        Interprets the state of simulators.\n    action_interpreter\n        Interprets the policy actions.\n    initial_states\n        Initial states to iterate over. 
Every state will be run exactly once.\n    policy\n        Policy to train against.\n    reward\n        Reward function.\n    vessel_kwargs\n        Keyword arguments passed to :class:`TrainingVessel`, like ``episode_per_iter``.\n    trainer_kwargs\n        Keyword arguments passed to :class:`Trainer`, like ``finite_env_type``, ``concurrency``.\n    \"\"\"\n\n    vessel = TrainingVessel(\n        simulator_fn=simulator_fn,\n        state_interpreter=state_interpreter,\n        action_interpreter=action_interpreter,\n        policy=policy,\n        train_initial_states=initial_states,\n        reward=reward,  # ignore none\n        **vessel_kwargs,\n    )\n    trainer = Trainer(**trainer_kwargs)\n    trainer.fit(vessel)\n\n\ndef backtest(\n    simulator_fn: Callable[[InitialStateType], Simulator],\n    state_interpreter: StateInterpreter,\n    action_interpreter: ActionInterpreter,\n    initial_states: Sequence[InitialStateType],\n    policy: BasePolicy,\n    logger: LogWriter | List[LogWriter],\n    reward: Reward | None = None,\n    finite_env_type: FiniteEnvType = \"subproc\",\n    concurrency: int = 2,\n) -> None:\n    \"\"\"Backtest with the parallelism provided by RL framework.\n\n    Experimental API. Parameters might change shortly.\n\n    Parameters\n    ----------\n    simulator_fn\n        Callable receiving initial seed, returning a simulator.\n    state_interpreter\n        Interprets the state of simulators.\n    action_interpreter\n        Interprets the policy actions.\n    initial_states\n        Initial states to iterate over. Every state will be run exactly once.\n    policy\n        Policy to test against.\n    logger\n        Logger to record the backtest results. Logger must be present because\n        without logger, all information will be lost.\n    reward\n        Optional reward function. 
For backtest, this is for testing the rewards\n        and logging them only.\n    finite_env_type\n        Type of finite env implementation.\n    concurrency\n        Parallel workers.\n    \"\"\"\n\n    vessel = TrainingVessel(\n        simulator_fn=simulator_fn,\n        state_interpreter=state_interpreter,\n        action_interpreter=action_interpreter,\n        policy=policy,\n        test_initial_states=initial_states,\n        reward=cast(Reward, reward),  # ignore none\n    )\n    trainer = Trainer(\n        finite_env_type=finite_env_type,\n        concurrency=concurrency,\n        loggers=logger,\n    )\n    trainer.test(vessel)\n"
  },
  {
    "path": "qlib/rl/trainer/callbacks.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"Callbacks to insert customized recipes during the training.\nMimics the hooks of Keras / PyTorch-Lightning, but tailored for the context of RL.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport copy\nimport os\nimport shutil\nimport time\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Any, List, TYPE_CHECKING\n\nimport numpy as np\nimport pandas as pd\nimport torch\n\nfrom qlib.log import get_module_logger\nfrom qlib.typehint import Literal\n\nif TYPE_CHECKING:\n    from .trainer import Trainer\n    from .vessel import TrainingVesselBase\n\n_logger = get_module_logger(__name__)\n\n\nclass Callback:\n    \"\"\"Base class of all callbacks.\"\"\"\n\n    def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        \"\"\"Called before the whole fit process begins.\"\"\"\n\n    def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        \"\"\"Called after the whole fit process ends.\"\"\"\n\n    def on_train_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        \"\"\"Called when each collect for training begins.\"\"\"\n\n    def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        \"\"\"Called when the training ends.\n        To access all outputs produced during training, cache the data in either the trainer or the vessel,\n        and post-process them in this hook.\n        \"\"\"\n\n    def on_validate_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        \"\"\"Called when every run for validation begins.\"\"\"\n\n    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        \"\"\"Called when the validation ends.\"\"\"\n\n    def on_test_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        \"\"\"Called when every run of testing begins.\"\"\"\n\n    def 
on_test_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        \"\"\"Called when the testing ends.\"\"\"\n\n    def on_iter_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        \"\"\"Called when every iteration (i.e., collect) starts.\"\"\"\n\n    def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        \"\"\"Called at the end of every iteration.\n        This is called **after** the bump of ``current_iter``,\n        when the previous iteration is considered complete.\n        \"\"\"\n\n    def state_dict(self) -> Any:\n        \"\"\"Get a state dict of the callback for pause and resume.\"\"\"\n\n    def load_state_dict(self, state_dict: Any) -> None:\n        \"\"\"Resume the callback from a saved state dict.\"\"\"\n\n\nclass EarlyStopping(Callback):\n    \"\"\"Stop training when a monitored metric has stopped improving.\n\n    The early-stopping callback will be triggered each time validation ends.\n    It will examine the metrics produced in validation,\n    and get the metric with name ``monitor`` (``monitor`` is ``reward`` by default),\n    to check whether it's no longer increasing / decreasing.\n    It takes ``min_delta`` and ``patience`` into account if applicable.\n    If the metric is found to be no longer increasing / decreasing,\n    ``trainer.should_stop`` will be set to true,\n    and the training terminates.\n\n    Implementation reference: https://github.com/keras-team/keras/blob/v2.9.0/keras/callbacks.py#L1744-L1893\n    \"\"\"\n\n    def __init__(\n        self,\n        monitor: str = \"reward\",\n        min_delta: float = 0.0,\n        patience: int = 0,\n        mode: Literal[\"min\", \"max\"] = \"max\",\n        baseline: float | None = None,\n        restore_best_weights: bool = False,\n    ):\n        super().__init__()\n\n        self.monitor = monitor\n        self.patience = patience\n        self.baseline = baseline\n        self.min_delta = abs(min_delta)\n        
self.restore_best_weights = restore_best_weights\n        self.best_weights: Any | None = None\n\n        if mode not in [\"min\", \"max\"]:\n            raise ValueError(\"Unsupported earlystopping mode: \" + mode)\n\n        if mode == \"min\":\n            self.monitor_op = np.less\n        elif mode == \"max\":\n            self.monitor_op = np.greater\n\n        if self.monitor_op == np.greater:\n            self.min_delta *= 1\n        else:\n            self.min_delta *= -1\n\n    def state_dict(self) -> dict:\n        return {\"wait\": self.wait, \"best\": self.best, \"best_weights\": self.best_weights, \"best_iter\": self.best_iter}\n\n    def load_state_dict(self, state_dict: dict) -> None:\n        self.wait = state_dict[\"wait\"]\n        self.best = state_dict[\"best\"]\n        self.best_weights = state_dict[\"best_weights\"]\n        self.best_iter = state_dict[\"best_iter\"]\n\n    def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        # Allow instances to be re-used\n        self.wait = 0\n        self.best = np.inf if self.monitor_op == np.less else -np.inf\n        self.best_weights = None\n        self.best_iter = 0\n\n    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        current = self.get_monitor_value(trainer)\n        if current is None:\n            return\n        if self.restore_best_weights and self.best_weights is None:\n            # Restore the weights after first iteration if no progress is ever made.\n            self.best_weights = copy.deepcopy(vessel.state_dict())\n\n        self.wait += 1\n        if self._is_improvement(current, self.best):\n            self.best = current\n            self.best_iter = trainer.current_iter\n            if self.restore_best_weights:\n                self.best_weights = copy.deepcopy(vessel.state_dict())\n            # Only restart wait if we beat both the baseline and our previous best.\n            if self.baseline is None 
or self._is_improvement(current, self.baseline):\n                self.wait = 0\n\n        msg = (\n            f\"#{trainer.current_iter} current reward: {current:.4f}, best reward: {self.best:.4f} in #{self.best_iter}\"\n        )\n        _logger.info(msg)\n\n        # Only check after the first epoch.\n        if self.wait >= self.patience and trainer.current_iter > 0:\n            trainer.should_stop = True\n            _logger.info(\"On iteration %d: early stopping\", trainer.current_iter + 1)\n            if self.restore_best_weights and self.best_weights is not None:\n                _logger.info(\"Restoring model weights from the end of the best iteration: %d\", self.best_iter + 1)\n                vessel.load_state_dict(self.best_weights)\n\n    def get_monitor_value(self, trainer: Trainer) -> Any:\n        monitor_value = trainer.metrics.get(self.monitor)\n        if monitor_value is None:\n            _logger.warning(\n                \"Early stopping conditioned on metric `%s` which is not available. 
Available metrics are: %s\",\n                self.monitor,\n                \",\".join(list(trainer.metrics.keys())),\n            )\n        return monitor_value\n\n    def _is_improvement(self, monitor_value, reference_value):\n        return self.monitor_op(monitor_value - self.min_delta, reference_value)\n\n\nclass MetricsWriter(Callback):\n    \"\"\"Dump training metrics to file.\"\"\"\n\n    def __init__(self, dirpath: Path) -> None:\n        self.dirpath = dirpath\n        self.dirpath.mkdir(exist_ok=True, parents=True)\n        self.train_records: List[dict] = []\n        self.valid_records: List[dict] = []\n\n    def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        self.train_records.append({k: v for k, v in trainer.metrics.items() if not k.startswith(\"val/\")})\n        pd.DataFrame.from_records(self.train_records).to_csv(self.dirpath / \"train_result.csv\", index=True)\n\n    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        self.valid_records.append({k: v for k, v in trainer.metrics.items() if k.startswith(\"val/\")})\n        pd.DataFrame.from_records(self.valid_records).to_csv(self.dirpath / \"validation_result.csv\", index=True)\n\n\nclass Checkpoint(Callback):\n    \"\"\"Save checkpoints periodically for persistence and recovery.\n\n    Reference: https://github.com/PyTorchLightning/pytorch-lightning/blob/bfa8b7be/pytorch_lightning/callbacks/model_checkpoint.py\n\n    Parameters\n    ----------\n    dirpath\n        Directory to save the checkpoint file.\n    filename\n        Checkpoint filename. 
Can contain named formatting options to be auto-filled.\n        For example: ``{iter:03d}-{reward:.2f}.pth``.\n        Supported argument names are:\n\n        - iter (int)\n        - metrics in ``trainer.metrics``\n        - time string, in the format of ``%Y%m%d%H%M%S``\n    save_latest\n        Save the latest checkpoint in ``latest.pth``.\n        If ``link``, ``latest.pth`` will be created as a softlink.\n        If ``copy``, ``latest.pth`` will be stored as an individual copy.\n        Set to ``None`` to disable this.\n    every_n_iters\n        Checkpoints are saved at the end of every n iterations of training,\n        after validation if applicable.\n    time_interval\n        Save a checkpoint at the end of an iteration once at least this many seconds\n        have passed since the last checkpoint.\n    save_on_fit_end\n        Save one last checkpoint at the end of fit.\n        Do nothing if a checkpoint is already saved there.\n    \"\"\"\n\n    def __init__(\n        self,\n        dirpath: Path,\n        filename: str = \"{iter:03d}.pth\",\n        save_latest: Literal[\"link\", \"copy\"] | None = \"link\",\n        every_n_iters: int | None = None,\n        time_interval: int | None = None,\n        save_on_fit_end: bool = True,\n    ):\n        self.dirpath = Path(dirpath)\n        self.filename = filename\n        self.save_latest = save_latest\n        self.every_n_iters = every_n_iters\n        self.time_interval = time_interval\n        self.save_on_fit_end = save_on_fit_end\n\n        self._last_checkpoint_name: str | None = None\n        self._last_checkpoint_iter: int | None = None\n        self._last_checkpoint_time: float | None = None\n\n    def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        if self.save_on_fit_end and (trainer.current_iter != self._last_checkpoint_iter):\n            self._save_checkpoint(trainer)\n\n    def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:\n        should_save_ckpt = False\n        if self.every_n_iters is not None and 
(trainer.current_iter + 1) % self.every_n_iters == 0:\n            should_save_ckpt = True\n        if self.time_interval is not None and (\n            self._last_checkpoint_time is None or (time.time() - self._last_checkpoint_time) >= self.time_interval\n        ):\n            should_save_ckpt = True\n        if should_save_ckpt:\n            self._save_checkpoint(trainer)\n\n    def _save_checkpoint(self, trainer: Trainer) -> None:\n        self.dirpath.mkdir(exist_ok=True, parents=True)\n        self._last_checkpoint_name = self._new_checkpoint_name(trainer)\n        self._last_checkpoint_iter = trainer.current_iter\n        self._last_checkpoint_time = time.time()\n        torch.save(trainer.state_dict(), self.dirpath / self._last_checkpoint_name)\n\n        latest_pth = self.dirpath / \"latest.pth\"\n\n        # Remove first before saving\n        if self.save_latest and (latest_pth.exists() or os.path.islink(latest_pth)):\n            latest_pth.unlink()\n\n        if self.save_latest == \"link\":\n            latest_pth.symlink_to(self.dirpath / self._last_checkpoint_name)\n        elif self.save_latest == \"copy\":\n            shutil.copyfile(self.dirpath / self._last_checkpoint_name, latest_pth)\n\n    def _new_checkpoint_name(self, trainer: Trainer) -> str:\n        return self.filename.format(\n            iter=trainer.current_iter, time=datetime.now().strftime(\"%Y%m%d%H%M%S\"), **trainer.metrics\n        )\n"
  },
  {
    "path": "qlib/rl/trainer/trainer.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nimport collections\nimport copy\nfrom contextlib import AbstractContextManager, contextmanager\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast\n\nimport torch\n\nfrom qlib.log import get_module_logger\nfrom qlib.rl.simulator import InitialStateType\nfrom qlib.rl.utils import EnvWrapper, FiniteEnvType, LogBuffer, LogCollector, LogLevel, LogWriter, vectorize_env\nfrom qlib.rl.utils.finite_env import FiniteVectorEnv\nfrom qlib.typehint import Literal\n\nfrom .callbacks import Callback\nfrom .vessel import TrainingVesselBase\n\n_logger = get_module_logger(__name__)\n\n\nT = TypeVar(\"T\")\n\n\nclass Trainer:\n    \"\"\"\n    Utility to train a policy on a particular task.\n\n    Unlike traditional DL trainers, the iteration of this trainer is a \"collect\",\n    rather than an \"epoch\" or a \"mini-batch\".\n    In each collect, :class:`Collector` collects a number of policy-env interactions, and accumulates\n    them into a replay buffer. 
This buffer is used as the \"data\" to train the policy.\n    At the end of each collect, the policy is *updated* several times.\n\n    The API has some resemblance to `PyTorch Lightning <https://pytorch-lightning.readthedocs.io/>`__,\n    but it's essentially different because this trainer is built for RL applications, and thus\n    most configurations are in an RL context.\n    We are still looking for ways to incorporate existing trainer libraries, because building a trainer\n    as powerful as those libraries looks like a big effort, and also, that's not our primary goal.\n\n    It's also essentially different from\n    `tianshou's built-in trainers <https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html>`__,\n    as it's far more complicated than those.\n\n    Parameters\n    ----------\n    max_iters\n        Maximum iterations before stopping.\n    val_every_n_iters\n        Perform validation every n iterations (i.e., training collects).\n    loggers\n        Loggers to record the backtest results. 
A logger must be present because\n        without a logger, all information will be lost.\n    callbacks\n        Callbacks invoked at each hook of the training / testing process.\n    finite_env_type\n        Type of finite env implementation.\n    concurrency\n        Number of parallel workers.\n    fast_dev_run\n        Create a subset for debugging.\n        How this is implemented depends on the implementation of training vessel.\n        For :class:`~qlib.rl.vessel.TrainingVessel`, if greater than zero,\n        a random subset sized ``fast_dev_run`` will be used\n        instead of ``train_initial_states`` and ``val_initial_states``.\n    \"\"\"\n\n    should_stop: bool\n    \"\"\"Set to stop the training.\"\"\"\n\n    metrics: dict\n    \"\"\"Numeric metrics produced in train/val/test.\n    In the middle of training / validation, metrics will be of the latest episode.\n    When each iteration of training / validation finishes, metrics will be the aggregation\n    of all episodes encountered in this iteration.\n\n    Cleared on every new iteration of training.\n\n    In fit, validation metrics will be prefixed with ``val/``.\n    \"\"\"\n\n    current_iter: int\n    \"\"\"Current iteration (collect) of training.\"\"\"\n\n    loggers: List[LogWriter]\n    \"\"\"A list of log writers.\"\"\"\n\n    def __init__(\n        self,\n        *,\n        max_iters: int | None = None,\n        val_every_n_iters: int | None = None,\n        loggers: LogWriter | List[LogWriter] | None = None,\n        callbacks: List[Callback] | None = None,\n        finite_env_type: FiniteEnvType = \"subproc\",\n        concurrency: int = 2,\n        fast_dev_run: int | None = None,\n    ):\n        self.max_iters = max_iters\n        self.val_every_n_iters = val_every_n_iters\n\n        if isinstance(loggers, list):\n            self.loggers = loggers\n        elif isinstance(loggers, LogWriter):\n            self.loggers = [loggers]\n        else:\n            self.loggers = []\n\n        self.loggers.append(LogBuffer(self._metrics_callback, loglevel=self._min_loglevel()))\n\n       
 self.callbacks: List[Callback] = callbacks if callbacks is not None else []\n        self.finite_env_type = finite_env_type\n        self.concurrency = concurrency\n        self.fast_dev_run = fast_dev_run\n\n        self.current_stage: Literal[\"train\", \"val\", \"test\"] = \"train\"\n\n        self.vessel: TrainingVesselBase = cast(TrainingVesselBase, None)\n\n    def initialize(self):\n        \"\"\"Initialize the whole training process.\n\n        The states here should be synchronized with state_dict.\n        \"\"\"\n        self.should_stop = False\n        self.current_iter = 0\n        self.current_episode = 0\n        self.current_stage = \"train\"\n\n    def initialize_iter(self):\n        \"\"\"Initialize one iteration / collect.\"\"\"\n        self.metrics = {}\n\n    def state_dict(self) -> dict:\n        \"\"\"Put every state of the current training into a dict, on a best-effort basis.\n\n        It doesn't try to handle all the possible kinds of states in the middle of one training collect.\n        In most cases, at the end of each iteration, things should usually be correct.\n\n        Note that it's also intended behavior that replay buffer data in the collector will be lost.\n        \"\"\"\n        return {\n            \"vessel\": self.vessel.state_dict(),\n            \"callbacks\": {name: callback.state_dict() for name, callback in self.named_callbacks().items()},\n            \"loggers\": {name: logger.state_dict() for name, logger in self.named_loggers().items()},\n            \"should_stop\": self.should_stop,\n            \"current_iter\": self.current_iter,\n            \"current_episode\": self.current_episode,\n            \"current_stage\": self.current_stage,\n            \"metrics\": self.metrics,\n        }\n\n    @staticmethod\n    def get_policy_state_dict(ckpt_path: Path) -> OrderedDict:\n        state_dict = torch.load(ckpt_path, map_location=\"cpu\")\n        if \"vessel\" in state_dict:\n            state_dict = 
state_dict[\"vessel\"][\"policy\"]\n        return state_dict\n\n    def load_state_dict(self, state_dict: dict) -> None:\n        \"\"\"Load all states into current trainer.\"\"\"\n        self.vessel.load_state_dict(state_dict[\"vessel\"])\n        for name, callback in self.named_callbacks().items():\n            callback.load_state_dict(state_dict[\"callbacks\"][name])\n        for name, logger in self.named_loggers().items():\n            logger.load_state_dict(state_dict[\"loggers\"][name])\n        self.should_stop = state_dict[\"should_stop\"]\n        self.current_iter = state_dict[\"current_iter\"]\n        self.current_episode = state_dict[\"current_episode\"]\n        self.current_stage = state_dict[\"current_stage\"]\n        self.metrics = state_dict[\"metrics\"]\n\n    def named_callbacks(self) -> Dict[str, Callback]:\n        \"\"\"Retrieve a collection of callbacks where each one has a name.\n        Useful when saving checkpoints.\n        \"\"\"\n        return _named_collection(self.callbacks)\n\n    def named_loggers(self) -> Dict[str, LogWriter]:\n        \"\"\"Retrieve a collection of loggers where each one has a name.\n        Useful when saving checkpoints.\n        \"\"\"\n        return _named_collection(self.loggers)\n\n    def fit(self, vessel: TrainingVesselBase, ckpt_path: Path | None = None) -> None:\n        \"\"\"Train the RL policy upon the defined simulator.\n\n        Parameters\n        ----------\n        vessel\n            A bundle of all elements used in training.\n        ckpt_path\n            Load a pre-trained / paused training checkpoint.\n        \"\"\"\n        self.vessel = vessel\n        vessel.assign_trainer(self)\n\n        if ckpt_path is not None:\n            _logger.info(\"Resuming states from %s\", str(ckpt_path))\n            self.load_state_dict(torch.load(ckpt_path, weights_only=False))\n        else:\n            self.initialize()\n\n        self._call_callback_hooks(\"on_fit_start\")\n\n        while 
not self.should_stop:\n            msg = f\"\\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\\tTrain iteration {self.current_iter + 1}/{self.max_iters}\"\n            _logger.info(msg)\n\n            self.initialize_iter()\n\n            self._call_callback_hooks(\"on_iter_start\")\n\n            self.current_stage = \"train\"\n            self._call_callback_hooks(\"on_train_start\")\n\n            # TODO\n            # Add a feature that supports reloading the training environment every few iterations.\n            with _wrap_context(vessel.train_seed_iterator()) as iterator:\n                vector_env = self.venv_from_iterator(iterator)\n                self.vessel.train(vector_env)\n                del vector_env  # FIXME: Explicitly delete this object to avoid memory leak.\n\n            self._call_callback_hooks(\"on_train_end\")\n\n            if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:\n                # Implementation of validation loop\n                self.current_stage = \"val\"\n                self._call_callback_hooks(\"on_validate_start\")\n                with _wrap_context(vessel.val_seed_iterator()) as iterator:\n                    vector_env = self.venv_from_iterator(iterator)\n                    self.vessel.validate(vector_env)\n                    del vector_env  # FIXME: Explicitly delete this object to avoid memory leak.\n\n                self._call_callback_hooks(\"on_validate_end\")\n\n            # This iteration is considered complete.\n            # Bumping the current iteration counter.\n            self.current_iter += 1\n\n            if self.max_iters is not None and self.current_iter >= self.max_iters:\n                self.should_stop = True\n\n            self._call_callback_hooks(\"on_iter_end\")\n\n        self._call_callback_hooks(\"on_fit_end\")\n\n    def test(self, vessel: TrainingVesselBase) -> None:\n        \"\"\"Test the RL policy against the simulator.\n\n       
 The simulator will be fed with data generated in ``test_seed_iterator``.\n\n        Parameters\n        ----------\n        vessel\n            A bundle of all related elements.\n        \"\"\"\n        self.vessel = vessel\n        vessel.assign_trainer(self)\n\n        self.initialize_iter()\n\n        self.current_stage = \"test\"\n        self._call_callback_hooks(\"on_test_start\")\n        with _wrap_context(vessel.test_seed_iterator()) as iterator:\n            vector_env = self.venv_from_iterator(iterator)\n            self.vessel.test(vector_env)\n            del vector_env  # FIXME: Explicitly delete this object to avoid memory leak.\n        self._call_callback_hooks(\"on_test_end\")\n\n    def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:\n        \"\"\"Create a vectorized environment from iterator and the training vessel.\"\"\"\n\n        def env_factory():\n            # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),\n            # and could be thread unsafe.\n            # I'm not sure whether it's a design flaw.\n            # I'll rethink about this when designing the trainer.\n\n            if self.finite_env_type == \"dummy\":\n                # We could only experience the \"threading-unsafe\" problem in dummy.\n                state = copy.deepcopy(self.vessel.state_interpreter)\n                action = copy.deepcopy(self.vessel.action_interpreter)\n                rew = copy.deepcopy(self.vessel.reward)\n            else:\n                state = self.vessel.state_interpreter\n                action = self.vessel.action_interpreter\n                rew = self.vessel.reward\n\n            return EnvWrapper(\n                self.vessel.simulator_fn,\n                state,\n                action,\n                iterator,\n                rew,\n                logger=LogCollector(min_loglevel=self._min_loglevel()),\n            )\n\n        return vectorize_env(\n 
           env_factory,\n            self.finite_env_type,\n            self.concurrency,\n            self.loggers,\n        )\n\n    def _metrics_callback(self, on_episode: bool, on_collect: bool, log_buffer: LogBuffer) -> None:\n        if on_episode:\n            # Update the global counter.\n            self.current_episode = log_buffer.global_episode\n            metrics = log_buffer.episode_metrics()\n        elif on_collect:\n            # Update the latest metrics.\n            metrics = log_buffer.collect_metrics()\n        if self.current_stage == \"val\":\n            metrics = {\"val/\" + name: value for name, value in metrics.items()}\n        self.metrics.update(metrics)\n\n    def _call_callback_hooks(self, hook_name: str, *args: Any, **kwargs: Any) -> None:\n        for callback in self.callbacks:\n            fn = getattr(callback, hook_name)\n            fn(self, self.vessel, *args, **kwargs)\n\n    def _min_loglevel(self):\n        if not self.loggers:\n            return LogLevel.PERIODIC\n        else:\n            # To save bandwidth\n            return min(lg.loglevel for lg in self.loggers)\n\n\n@contextmanager\ndef _wrap_context(obj):\n    \"\"\"Make any object a (possibly dummy) context manager.\"\"\"\n\n    if isinstance(obj, AbstractContextManager):\n        # obj has __enter__ and __exit__\n        with obj as ctx:\n            yield ctx\n    else:\n        yield obj\n\n\ndef _named_collection(seq: Sequence[T]) -> Dict[str, T]:\n    \"\"\"Convert a list into a dict, where each item is named with its type.\"\"\"\n    res = {}\n    retry_cnt: collections.Counter = collections.Counter()\n    for item in seq:\n        typename = type(item).__name__.lower()\n        key = typename if retry_cnt[typename] == 0 else f\"{typename}{retry_cnt[typename]}\"\n        retry_cnt[typename] += 1\n        res[key] = item\n    return res\n"
  },
  {
    "path": "qlib/rl/trainer/vessel.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nimport weakref\nfrom typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast\n\nimport numpy as np\nfrom tianshou.data import Collector, VectorReplayBuffer\nfrom tianshou.env import BaseVectorEnv\nfrom tianshou.policy import BasePolicy\n\nfrom qlib.constant import INF\nfrom qlib.log import get_module_logger\nfrom qlib.rl.interpreter import ActionInterpreter, ActType, ObsType, PolicyActType, StateInterpreter, StateType\nfrom qlib.rl.reward import Reward\nfrom qlib.rl.simulator import InitialStateType, Simulator\nfrom qlib.rl.utils import DataQueue\nfrom qlib.rl.utils.finite_env import FiniteVectorEnv\n\nif TYPE_CHECKING:\n    from .trainer import Trainer\n\n\nT = TypeVar(\"T\")\n_logger = get_module_logger(__name__)\n\n\nclass SeedIteratorNotAvailable(BaseException):\n    pass\n\n\nclass TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]):\n    \"\"\"A ship that contains the simulator, interpreters, and policy, and is sent to the trainer.\n    This class controls the algorithm-related parts of training, while the trainer is responsible for the runtime part.\n\n    The ship also defines the most important logic of the core training part,\n    and (optionally) some callbacks to insert customized logic at specific events.\n    \"\"\"\n\n    simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]]\n    state_interpreter: StateInterpreter[StateType, ObsType]\n    action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType]\n    policy: BasePolicy\n    reward: Reward\n    trainer: Trainer\n\n    def assign_trainer(self, trainer: Trainer) -> None:\n        self.trainer = weakref.proxy(trainer)  # type: ignore\n\n    def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:\n       
 \"\"\"Override this to create a seed iterator for training.\n        If the iterable is a context manager, the whole training will be invoked in the with-block,\n        and the iterator will be automatically closed after the training is done.\"\"\"\n        raise SeedIteratorNotAvailable(\"Seed iterator for training is not available.\")\n\n    def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:\n        \"\"\"Override this to create a seed iterator for validation.\"\"\"\n        raise SeedIteratorNotAvailable(\"Seed iterator for validation is not available.\")\n\n    def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:\n        \"\"\"Override this to create a seed iterator for testing.\"\"\"\n        raise SeedIteratorNotAvailable(\"Seed iterator for testing is not available.\")\n\n    def train(self, vector_env: BaseVectorEnv) -> Dict[str, Any]:\n        \"\"\"Implement this to train one iteration. 
In RL, one iteration usually refers to one collect.\"\"\"\n        raise NotImplementedError()\n\n    def validate(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:\n        \"\"\"Implement this to validate the policy once.\"\"\"\n        raise NotImplementedError()\n\n    def test(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:\n        \"\"\"Implement this to evaluate the policy on test environment once.\"\"\"\n        raise NotImplementedError()\n\n    def log(self, name: str, value: Any) -> None:\n        # FIXME: this is a workaround to make the log at least show somewhere.\n        # Need a refactor in logger to formalize this.\n        if isinstance(value, (np.ndarray, list)):\n            value = np.mean(value)\n        _logger.info(f\"[Iter {self.trainer.current_iter + 1}] {name} = {value}\")\n\n    def log_dict(self, data: Dict[str, Any]) -> None:\n        for name, value in data.items():\n            self.log(name, value)\n\n    def state_dict(self) -> Dict:\n        \"\"\"Return a checkpoint of current vessel state.\"\"\"\n        return {\"policy\": self.policy.state_dict()}\n\n    def load_state_dict(self, state_dict: Dict) -> None:\n        \"\"\"Restore a checkpoint from a previously saved state dict.\"\"\"\n        self.policy.load_state_dict(state_dict[\"policy\"])\n\n\nclass TrainingVessel(TrainingVesselBase):\n    \"\"\"The default implementation of training vessel.\n\n    ``__init__`` accepts a sequence of initial states so that iterator can be created.\n    ``train``, ``validate``, ``test`` each do one collect (and also update in train).\n    By default, the train initial states will be repeated infinitely during training,\n    and collector will control the number of episodes for each iteration.\n    In validation and testing, the val / test initial states will be used exactly once.\n\n    Extra hyper-parameters (only used in train) include:\n\n    - ``buffer_size``: Size of replay buffer.\n    - ``episode_per_iter``: Episodes per 
collect at training. Can be overridden by fast dev run.\n    - ``update_kwargs``: Keyword arguments appearing in ``policy.update``.\n      For example, ``dict(repeat=10, batch_size=64)``.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]],\n        state_interpreter: StateInterpreter[StateType, ObsType],\n        action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],\n        policy: BasePolicy,\n        reward: Reward,\n        train_initial_states: Sequence[InitialStateType] | None = None,\n        val_initial_states: Sequence[InitialStateType] | None = None,\n        test_initial_states: Sequence[InitialStateType] | None = None,\n        buffer_size: int = 20000,\n        episode_per_iter: int = 1000,\n        update_kwargs: Dict[str, Any] = cast(Dict[str, Any], None),\n    ):\n        self.simulator_fn = simulator_fn  # type: ignore\n        self.state_interpreter = state_interpreter\n        self.action_interpreter = action_interpreter\n        self.policy = policy\n        self.reward = reward\n        self.train_initial_states = train_initial_states\n        self.val_initial_states = val_initial_states\n        self.test_initial_states = test_initial_states\n        self.buffer_size = buffer_size\n        self.episode_per_iter = episode_per_iter\n        self.update_kwargs = update_kwargs or {}\n\n    def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:\n        if self.train_initial_states is not None:\n            _logger.info(\"Training initial states collection size: %d\", len(self.train_initial_states))\n            # Implement fast_dev_run here.\n            train_initial_states = self._random_subset(\"train\", self.train_initial_states, self.trainer.fast_dev_run)\n            return DataQueue(train_initial_states, repeat=-1, shuffle=True)\n        return 
super().train_seed_iterator()\n\n    def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:\n        if self.val_initial_states is not None:\n            _logger.info(\"Validation initial states collection size: %d\", len(self.val_initial_states))\n            val_initial_states = self._random_subset(\"val\", self.val_initial_states, self.trainer.fast_dev_run)\n            return DataQueue(val_initial_states, repeat=1)\n        return super().val_seed_iterator()\n\n    def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:\n        if self.test_initial_states is not None:\n            _logger.info(\"Testing initial states collection size: %d\", len(self.test_initial_states))\n            test_initial_states = self._random_subset(\"test\", self.test_initial_states, self.trainer.fast_dev_run)\n            return DataQueue(test_initial_states, repeat=1)\n        return super().test_seed_iterator()\n\n    def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:\n        \"\"\"Create a collector and collect ``episode_per_iter`` episodes.\n        Update the policy on the collected replay buffer.\n        \"\"\"\n        self.policy.train()\n\n        with vector_env.collector_guard():\n            collector = Collector(\n                self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)), exploration_noise=True\n            )\n\n            # Number of episodes collected in each training iteration can be overridden by fast dev run.\n            if self.trainer.fast_dev_run is not None:\n                episodes = self.trainer.fast_dev_run\n            else:\n                episodes = self.episode_per_iter\n\n            col_result = collector.collect(n_episode=episodes)\n            update_result = self.policy.update(sample_size=0, buffer=collector.buffer, **self.update_kwargs)\n            res = {**col_result, **update_result}\n         
   self.log_dict(res)\n            return res\n\n    def validate(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:\n        self.policy.eval()\n\n        with vector_env.collector_guard():\n            test_collector = Collector(self.policy, vector_env)\n            res = test_collector.collect(n_step=INF * len(vector_env))\n            self.log_dict(res)\n            return res\n\n    def test(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:\n        self.policy.eval()\n\n        with vector_env.collector_guard():\n            test_collector = Collector(self.policy, vector_env)\n            res = test_collector.collect(n_step=INF * len(vector_env))\n            self.log_dict(res)\n            return res\n\n    @staticmethod\n    def _random_subset(name: str, collection: Sequence[T], size: int | None) -> Sequence[T]:\n        if size is None:\n            # Size = None -> original collection\n            return collection\n        order = np.random.permutation(len(collection))\n        res = [collection[o] for o in order[:size]]\n        _logger.info(\n            \"Fast running in development mode. Cut %s initial states from %d to %d.\",\n            name,\n            len(collection),\n            len(res),\n        )\n        return res\n"
  },
  {
    "path": "qlib/rl/utils/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .data_queue import DataQueue\nfrom .env_wrapper import EnvWrapper, EnvWrapperStatus\nfrom .finite_env import FiniteEnvType, vectorize_env\nfrom .log import ConsoleWriter, CsvWriter, LogBuffer, LogCollector, LogLevel, LogWriter\n\n__all__ = [\n    \"LogLevel\",\n    \"DataQueue\",\n    \"EnvWrapper\",\n    \"FiniteEnvType\",\n    \"LogCollector\",\n    \"LogWriter\",\n    \"vectorize_env\",\n    \"ConsoleWriter\",\n    \"CsvWriter\",\n    \"EnvWrapperStatus\",\n    \"LogBuffer\",\n]\n"
  },
  {
    "path": "qlib/rl/utils/data_queue.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nimport multiprocessing\nfrom multiprocessing.sharedctypes import Synchronized\nimport os\nimport threading\nimport time\nimport warnings\nfrom queue import Empty\nfrom typing import Any, Generator, Generic, Sequence, TypeVar, cast\n\nfrom qlib.log import get_module_logger\n\n_logger = get_module_logger(__name__)\n\nT = TypeVar(\"T\")\n\n__all__ = [\"DataQueue\"]\n\n\nclass DataQueue(Generic[T]):\n    \"\"\"Main process (producer) produces data and stores them in a queue.\n    Sub-processes (consumers) can retrieve the data-points from the queue.\n    Data-points are generated via reading items from ``dataset``.\n\n    :class:`DataQueue` is ephemeral. You must create a new DataQueue\n    when the ``repeat`` is exhausted.\n\n    See the documentation of :class:`qlib.rl.utils.FiniteVectorEnv` for more background.\n\n    Parameters\n    ----------\n    dataset\n        The dataset to read data from. Must implement ``__len__`` and ``__getitem__``.\n    repeat\n        Number of times to iterate over the data-points. Use ``-1`` to iterate forever.\n    shuffle\n        If ``shuffle`` is true, the items will be read in random order.\n    producer_num_workers\n        Number of concurrent workers for data loading.\n    queue_maxsize\n        Maximum number of items in the queue before ``put`` blocks.\n        If ``0``, it is set automatically to the CPU count (or ``1`` if the CPU count is unavailable).\n\n    Examples\n    --------\n    >>> data_queue = DataQueue(my_dataset)\n    >>> with data_queue:\n    ...     ...\n\n    In worker:\n\n    >>> for data in data_queue:\n    ...     
print(data)\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset: Sequence[T],\n        repeat: int = 1,\n        shuffle: bool = True,\n        producer_num_workers: int = 0,\n        queue_maxsize: int = 0,\n    ) -> None:\n        if queue_maxsize == 0:\n            if os.cpu_count() is not None:\n                queue_maxsize = cast(int, os.cpu_count())\n                _logger.info(f\"Automatically set data queue maxsize to {queue_maxsize} to avoid overwhelming.\")\n            else:\n                queue_maxsize = 1\n                _logger.warning(f\"CPU count not available. Setting queue maxsize to 1.\")\n\n        self.dataset: Sequence[T] = dataset\n        self.repeat: int = repeat\n        self.shuffle: bool = shuffle\n        self.producer_num_workers: int = producer_num_workers\n\n        self._activated: bool = False\n        self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)\n        # Mypy 0.981 brought '\"SynchronizedBase[Any]\" has no attribute \"value\"  [attr-defined]' bug.\n        # Therefore, add this type casting to pass Mypy checking.\n        self._done = cast(Synchronized, multiprocessing.Value(\"i\", 0))\n\n    def __enter__(self) -> DataQueue:\n        self.activate()\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        self.cleanup()\n\n    def cleanup(self) -> None:\n        with self._done.get_lock():\n            self._done.value += 1\n        for repeat in range(500):\n            if repeat >= 1:\n                warnings.warn(f\"After {repeat} cleanup, the queue is still not empty.\", category=RuntimeWarning)\n            while not self._queue.empty():\n                try:\n                    self._queue.get(block=False)\n                except Empty:\n                    pass\n            # Sometimes when the queue gets emptied, more data have already been sent,\n            # and they are on the way into the queue.\n            # If these data didn't 
get consumed, it will jam the queue and make the process hang.\n            # We wait a second here for potential data arriving, and check again (for ``repeat`` times).\n            time.sleep(1.0)\n            if self._queue.empty():\n                break\n        _logger.debug(f\"Remaining items in queue collection done. Empty: {self._queue.empty()}\")\n\n    def get(self, block: bool = True) -> Any:\n        if not hasattr(self, \"_first_get\"):\n            self._first_get = True\n        if self._first_get:\n            timeout = 5.0\n            self._first_get = False\n        else:\n            timeout = 0.5\n        while True:\n            try:\n                return self._queue.get(block=block, timeout=timeout)\n            except Empty:\n                if self._done.value:\n                    raise StopIteration  # pylint: disable=raise-missing-from\n\n    def put(self, obj: Any, block: bool = True, timeout: int | None = None) -> None:\n        self._queue.put(obj, block=block, timeout=timeout)\n\n    def mark_as_done(self) -> None:\n        with self._done.get_lock():\n            self._done.value = 1\n\n    def done(self) -> int:\n        return self._done.value\n\n    def activate(self) -> DataQueue:\n        if self._activated:\n            raise ValueError(\"DataQueue can not activate twice.\")\n        thread = threading.Thread(target=self._producer, daemon=True)\n        thread.start()\n        self._activated = True\n        return self\n\n    def __del__(self) -> None:\n        _logger.debug(f\"__del__ of {__name__}.DataQueue\")\n        self.cleanup()\n\n    def __iter__(self) -> Generator[Any, None, None]:\n        if not self._activated:\n            raise ValueError(\n                \"Need to call activate() to launch a daemon worker \"\n                \"to produce data into data queue before using it. 
\"\n                \"You probably have forgotten to use the DataQueue in a with block.\",\n            )\n        return self._consumer()\n\n    def _consumer(self) -> Generator[Any, None, None]:\n        while True:\n            try:\n                yield self.get()\n            except StopIteration:\n                _logger.debug(\"Data consumer timed-out from get.\")\n                return\n\n    def _producer(self) -> None:\n        # pytorch dataloader is used here only because we need its sampler and multi-processing\n        from torch.utils.data import DataLoader, Dataset  # pylint: disable=import-outside-toplevel\n\n        try:\n            dataloader = DataLoader(\n                cast(Dataset[T], self.dataset),\n                batch_size=None,\n                num_workers=self.producer_num_workers,\n                shuffle=self.shuffle,\n                collate_fn=lambda t: t,  # identity collate fn\n            )\n            repeat = 10**18 if self.repeat == -1 else self.repeat\n            for _rep in range(repeat):\n                for data in dataloader:\n                    if self._done.value:\n                        # Already done.\n                        return\n                    self._queue.put(data)\n                _logger.debug(f\"Dataloader loop done. Repeat {_rep}.\")\n        finally:\n            self.mark_as_done()\n"
  },
  {
    "path": "qlib/rl/utils/env_wrapper.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import annotations\n\nimport weakref\nfrom typing import Any, Callable, cast, Dict, Generic, Iterable, Iterator, Optional, Tuple\n\nimport gym\nfrom gym import Space\n\nfrom qlib.rl.aux_info import AuxiliaryInfoCollector\nfrom qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, StateInterpreter\nfrom qlib.rl.reward import Reward\nfrom qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType\nfrom qlib.typehint import TypedDict\nfrom .finite_env import generate_nan_observation\nfrom .log import LogCollector, LogLevel\n\n__all__ = [\"InfoDict\", \"EnvWrapperStatus\", \"EnvWrapper\"]\n\n# in this case, there won't be any seed for simulator\nSEED_INTERATOR_MISSING = \"_missing_\"\n\n\nclass InfoDict(TypedDict):\n    \"\"\"The type of dict that is used in the 4th return value of ``env.step()``.\"\"\"\n\n    aux_info: dict\n    \"\"\"Any information; its content depends on the auxiliary info collector.\"\"\"\n    log: Dict[str, Any]\n    \"\"\"Collected by LogCollector.\"\"\"\n\n\nclass EnvWrapperStatus(TypedDict):\n    \"\"\"\n    This is the status data structure used in EnvWrapper.\n    The fields here are in the semantics of RL.\n    For example, ``obs`` means the observation fed into policy.\n    ``action`` means the raw action returned by policy.\n    \"\"\"\n\n    cur_step: int\n    done: bool\n    initial_state: Optional[Any]\n    obs_history: list\n    action_history: list\n    reward_history: list\n\n\nclass EnvWrapper(\n    gym.Env[ObsType, PolicyActType],\n    Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType],\n):\n    \"\"\"Qlib-based RL environment, subclassing ``gym.Env``.\n    A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.\n\n    This is what the framework of simulator - interpreter - policy looks like in RL training.\n    All the components other than policy need 
to be assembled into a single object called \"environment\".\n    The \"environment\" is replicated into multiple workers, and (at least in tianshou's implementation),\n    one single policy (agent) plays against a batch of environments.\n\n    Parameters\n    ----------\n    simulator_fn\n        A callable that is the simulator factory.\n        When ``seed_iterator`` is present, the factory should take one argument,\n        that is, the seed (aka initial state).\n        Otherwise, it should take zero arguments.\n    state_interpreter\n        State-observation converter.\n    action_interpreter\n        Policy-simulator action converter.\n    seed_iterator\n        An iterable of seeds. With the help of :class:`qlib.rl.utils.DataQueue`,\n        environment workers in different processes can share one ``seed_iterator``.\n    reward_fn\n        A callable that accepts the StateType and returns a float (at least in the single-agent case).\n    aux_info_collector\n        Collects auxiliary information. Could be useful in MARL.\n    logger\n        Log collector that collects the logs. The collected logs are sent back to the main process,\n        via the return value of ``env.step()``.\n\n    Attributes\n    ----------\n    status : EnvWrapperStatus\n        Status indicator. 
All terms are in *RL language*.\n        It can be used if users care about data on the RL side.\n        Can be None when no trajectory is available.\n    \"\"\"\n\n    simulator: Simulator[InitialStateType, StateType, ActType]\n    seed_iterator: str | Iterator[InitialStateType] | None\n\n    def __init__(\n        self,\n        simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]],\n        state_interpreter: StateInterpreter[StateType, ObsType],\n        action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],\n        seed_iterator: Optional[Iterable[InitialStateType]],\n        reward_fn: Reward | None = None,\n        aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,\n        logger: LogCollector | None = None,\n    ) -> None:\n        # Assign weak reference to wrapper.\n        #\n        # Use weak reference here, because:\n        # 1. Logically, the other components should be able to live without an env_wrapper.\n        #    For example, they might live in a strategy_wrapper in the future.\n        #    Therefore injecting a \"hard\" attribute called \"env\" is not appropriate.\n        # 2. When the environment gets destroyed, it should be truly destroyed.\n        #    We don't want it to silently live inside some interpreters.\n        # 3. Avoid circular references.\n        # 4. 
When the components get serialized, we can throw away the env without any burden.\n        #    (though this part is not implemented yet)\n        for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]:\n            if obj is not None:\n                obj.env = weakref.proxy(self)  # type: ignore\n\n        self.simulator_fn = simulator_fn\n        self.state_interpreter = state_interpreter\n        self.action_interpreter = action_interpreter\n\n        if seed_iterator is None:\n            # In this case, there won't be any seed for the simulator.\n            # We can't set it to None because None actually means something else.\n            # If `seed_iterator` is None, it means that it's exhausted.\n            self.seed_iterator = SEED_INTERATOR_MISSING\n        else:\n            self.seed_iterator = iter(seed_iterator)\n        self.reward_fn = reward_fn\n\n        self.aux_info_collector = aux_info_collector\n        self.logger: LogCollector = logger or LogCollector()\n        self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)\n\n    @property\n    def action_space(self) -> Space:\n        return self.action_interpreter.action_space\n\n    @property\n    def observation_space(self) -> Space:\n        return self.state_interpreter.observation_space\n\n    def reset(self, **kwargs: Any) -> ObsType:\n        \"\"\"\n        Try to get a state from the state queue, and initialize the simulator with this state.\n        If the queue is exhausted, generate an invalid (NaN) observation.\n        \"\"\"\n\n        try:\n            if self.seed_iterator is None:\n                raise RuntimeError(\"You are trying to get a state from a dead environment wrapper.\")\n\n            # TODO: simulator/observation might need the seed to prefetch something,\n            # as only the seed has the ability to do the work beforehand\n\n            # NOTE: though the logger is reset here, logs in this function won't work,\n            # because we can't send them 
outside.\n            # See https://github.com/thu-ml/tianshou/issues/605\n            self.logger.reset()\n\n            if self.seed_iterator is SEED_INTERATOR_MISSING:\n                # no initial state\n                initial_state = None\n                self.simulator = cast(Callable[[], Simulator], self.simulator_fn)()\n            else:\n                initial_state = next(cast(Iterator[InitialStateType], self.seed_iterator))\n                self.simulator = self.simulator_fn(initial_state)\n\n            self.status = EnvWrapperStatus(\n                cur_step=0,\n                done=False,\n                initial_state=initial_state,\n                obs_history=[],\n                action_history=[],\n                reward_history=[],\n            )\n\n            self.simulator.env = cast(EnvWrapper, weakref.proxy(self))\n\n            sim_state = self.simulator.get_state()\n            obs = self.state_interpreter(sim_state)\n\n            self.status[\"obs_history\"].append(obs)\n\n            return obs\n\n        except StopIteration:\n            # The environment should be recycled because it's in a dead state.\n            self.seed_iterator = None\n            return generate_nan_observation(self.observation_space)\n\n    def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, float, bool, InfoDict]:\n        \"\"\"Environment step.\n\n        See the code along with comments to get a sequence of things happening here.\n        \"\"\"\n\n        if self.seed_iterator is None:\n            raise RuntimeError(\"State queue is already exhausted, but the environment is still receiving action.\")\n\n        # Clear the logged information from last step\n        self.logger.reset()\n\n        # Action is what we have got from policy\n        self.status[\"action_history\"].append(policy_action)\n        action = self.action_interpreter(self.simulator.get_state(), policy_action)\n\n        # This update must be after 
action interpreter and before simulator.\n        self.status[\"cur_step\"] += 1\n\n        # Use the converted action to update the simulator\n        self.simulator.step(action)\n\n        # Update \"done\" first, as this status might be used by reward_fn later\n        done = self.simulator.done()\n        self.status[\"done\"] = done\n\n        # Get state and calculate observation\n        sim_state = self.simulator.get_state()\n        obs = self.state_interpreter(sim_state)\n        self.status[\"obs_history\"].append(obs)\n\n        # Reward and extra info\n        if self.reward_fn is not None:\n            rew = self.reward_fn(sim_state)\n        else:\n            # No reward. Treated as 0.\n            rew = 0.0\n        self.status[\"reward_history\"].append(rew)\n\n        if self.aux_info_collector is not None:\n            aux_info = self.aux_info_collector(sim_state)\n        else:\n            aux_info = {}\n\n        # Final logging stuff: RL-specific logs\n        if done:\n            self.logger.add_scalar(\"steps_per_episode\", self.status[\"cur_step\"])\n        self.logger.add_scalar(\"reward\", rew)\n        self.logger.add_any(\"obs\", obs, loglevel=LogLevel.DEBUG)\n        self.logger.add_any(\"policy_act\", policy_action, loglevel=LogLevel.DEBUG)\n\n        info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)\n        return obs, rew, done, info_dict\n\n    def render(self, mode: str = \"human\") -> None:\n        raise NotImplementedError(\"Render is not implemented in EnvWrapper.\")\n"
  },
  {
    "path": "qlib/rl/utils/finite_env.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nThis is to support finite env in vector env.\nSee https://github.com/thu-ml/tianshou/issues/322 for details.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport copy\nimport warnings\nfrom contextlib import contextmanager\nfrom typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union, cast\n\nimport gym\nimport numpy as np\nfrom tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv\n\nfrom qlib.typehint import Literal\n\nfrom .log import LogWriter\n\n__all__ = [\n    \"generate_nan_observation\",\n    \"check_nan_observation\",\n    \"FiniteVectorEnv\",\n    \"FiniteDummyVectorEnv\",\n    \"FiniteSubprocVectorEnv\",\n    \"FiniteShmemVectorEnv\",\n    \"FiniteEnvType\",\n    \"vectorize_env\",\n]\n\nFiniteEnvType = Literal[\"dummy\", \"subproc\", \"shmem\"]\nT = Union[dict, list, tuple, np.ndarray]\n\n\ndef fill_invalid(obj: int | float | bool | T) -> T:\n    if isinstance(obj, (int, float, bool)):\n        return fill_invalid(np.array(obj))\n    if hasattr(obj, \"dtype\"):\n        if isinstance(obj, np.ndarray):\n            if np.issubdtype(obj.dtype, np.floating):\n                return np.full_like(obj, np.nan)\n            return np.full_like(obj, np.iinfo(obj.dtype).max)\n        # handle the corner case where numpy scalars are not supported by tianshou's sharray\n        return fill_invalid(np.array(obj))\n    elif isinstance(obj, dict):\n        return {k: fill_invalid(v) for k, v in obj.items()}\n    elif isinstance(obj, list):\n        return [fill_invalid(v) for v in obj]\n    elif isinstance(obj, tuple):\n        return tuple(fill_invalid(v) for v in obj)\n    raise ValueError(f\"Unsupported value to fill with invalid: {obj}\")\n\n\ndef is_invalid(arr: int | float | bool | T) -> bool:\n    if isinstance(arr, np.ndarray):\n        if np.issubdtype(arr.dtype, np.floating):\n            return 
np.isnan(arr).all()\n        return cast(bool, cast(np.ndarray, np.iinfo(arr.dtype).max == arr).all())\n    if isinstance(arr, dict):\n        return all(is_invalid(o) for o in arr.values())\n    if isinstance(arr, (list, tuple)):\n        return all(is_invalid(o) for o in arr)\n    if isinstance(arr, (int, float, bool, np.number)):\n        return is_invalid(np.array(arr))\n    return True\n\n\ndef generate_nan_observation(obs_space: gym.Space) -> Any:\n    \"\"\"The NaN observation that indicates the environment received no seed.\n\n    We assume that obs is complex and contains at least one float-like component.\n    Otherwise this logic doesn't work.\n    \"\"\"\n\n    sample = obs_space.sample()\n    sample = fill_invalid(sample)\n    return sample\n\n\ndef check_nan_observation(obs: Any) -> bool:\n    \"\"\"Check whether obs is generated by :func:`generate_nan_observation`.\"\"\"\n    return is_invalid(obs)\n\n\nclass FiniteVectorEnv(BaseVectorEnv):\n    \"\"\"To allow the parallel env workers to consume a single DataQueue until it's exhausted.\n\n    See `tianshou issue #322 <https://github.com/thu-ml/tianshou/issues/322>`_.\n\n    The requirement is to ensure that every possible seed (stored in :class:`qlib.rl.utils.DataQueue` in our case)\n    is consumed by exactly one environment. 
This is not possible with tianshou's native VectorEnv and Collector,\n    because tianshou is unaware of this \"exactly one\" constraint, and might launch extra workers.\n\n    Consider a corner case, where concurrency is 2, but there is only one seed in DataQueue.\n    According to the logic in ``collect``, the reset of both workers must be called,\n    and the returned results of both workers are collected, regardless of what they are.\n    The problem is that one of the reset results must be invalid or repeated,\n    because there's only one seed in the queue, and the collector isn't aware of this situation.\n\n    Luckily, we can hack the vector env, and make a protocol between the single env and the vector env.\n    The single environment (:class:`qlib.rl.utils.EnvWrapper` in our case) is responsible for\n    reading from the queue, and generating a special observation when the queue is exhausted. The special obs\n    is called \"nan observation\", because simply using None causes problems in shared-memory vector env.\n    :class:`FiniteVectorEnv` then reads the observations from all workers, and selects the non-nan\n    observations. It also maintains an ``_alive_env_ids`` set to track which workers should never be\n    called again. When all the environments are exhausted, it raises a StopIteration exception.\n\n    There are two ways to use this vector env in a collector:\n\n    1. If the data queue is finite (usually during inference), the collector should collect an \"infinite\" number of\n       episodes, until the vector env exhausts by itself.\n    2. If the data queue is infinite (usually in training), the collector can set the number of episodes / steps.\n       In this case, data would be randomly ordered, and some repetitions wouldn't matter.\n\n    One extra function of this vector env is that it has a logger that explicitly collects logs\n    from child workers. 
See :class:`qlib.rl.utils.LogWriter`.\n    \"\"\"\n\n    _logger: list[LogWriter]\n\n    def __init__(\n        self, logger: LogWriter | list[LogWriter] | None, env_fns: list[Callable[..., gym.Env]], **kwargs: Any\n    ) -> None:\n        super().__init__(env_fns, **kwargs)\n\n        if isinstance(logger, list):\n            self._logger = logger\n        elif isinstance(logger, LogWriter):\n            self._logger = [logger]\n        else:\n            self._logger = []\n        self._alive_env_ids: Set[int] = set()\n        self._reset_alive_envs()\n        self._default_obs = self._default_info = self._default_rew = None\n        self._zombie = False\n\n        self._collector_guarded: bool = False\n\n    def _reset_alive_envs(self) -> None:\n        if not self._alive_env_ids:\n            # starting or running out\n            self._alive_env_ids = set(range(self.env_num))\n\n    # to work around tianshou's buffer and batch\n    def _set_default_obs(self, obs: Any) -> None:\n        if obs is not None and self._default_obs is None:\n            self._default_obs = copy.deepcopy(obs)\n\n    def _set_default_info(self, info: Any) -> None:\n        if info is not None and self._default_info is None:\n            self._default_info = copy.deepcopy(info)\n\n    def _set_default_rew(self, rew: Any) -> None:\n        if rew is not None and self._default_rew is None:\n            self._default_rew = copy.deepcopy(rew)\n\n    def _get_default_obs(self) -> Any:\n        return copy.deepcopy(self._default_obs)\n\n    def _get_default_info(self) -> Any:\n        return copy.deepcopy(self._default_info)\n\n    def _get_default_rew(self) -> Any:\n        return copy.deepcopy(self._default_rew)\n\n    # END\n\n    @staticmethod\n    def _postproc_env_obs(obs: Any) -> Optional[Any]:\n        # reserved for shmem vector env to restore empty observation\n        if obs is None or check_nan_observation(obs):\n            return None\n        return obs\n\n    
@contextmanager\n    def collector_guard(self) -> Generator[FiniteVectorEnv, None, None]:\n        \"\"\"Guard the collector. It is recommended to guard every ``collect()`` call.\n\n        This guard serves two purposes.\n\n        1. Catch and ignore the StopIteration exception, which is the stopping signal\n           thrown by FiniteEnv to let tianshou know that ``collector.collect()`` should exit.\n        2. Notify the loggers when the collect is ready / done.\n\n        Examples\n        --------\n        >>> with finite_env.collector_guard():\n        ...     collector.collect(n_episode=INF)\n        \"\"\"\n        self._collector_guarded = True\n\n        for logger in self._logger:\n            logger.on_env_all_ready()\n\n        try:\n            yield self\n        except StopIteration:\n            pass\n        finally:\n            self._collector_guarded = False\n\n        # Finally, trigger the loggers\n        for logger in self._logger:\n            logger.on_env_all_done()\n\n    def reset(\n        self,\n        id: int | List[int] | np.ndarray | None = None,\n    ) -> np.ndarray:\n        assert not self._zombie\n\n        # Check whether it's guarded by collector_guard()\n        if not self._collector_guarded:\n            warnings.warn(\n                \"Collector is not guarded by FiniteEnv. 
\"\n                \"This may cause unexpected problems, like unexpected StopIteration exception, \"\n                \"or missing logs.\",\n                RuntimeWarning,\n            )\n\n        wrapped_id = self._wrap_id(id)\n        self._reset_alive_envs()\n\n        # ask super to reset alive envs and remap to current index\n        request_id = [i for i in wrapped_id if i in self._alive_env_ids]\n        obs = [None] * len(wrapped_id)\n        id2idx = {i: k for k, i in enumerate(wrapped_id)}\n        if request_id:\n            for i, o in zip(request_id, super().reset(request_id)):\n                obs[id2idx[i]] = self._postproc_env_obs(o)\n\n        for i, o in zip(wrapped_id, obs):\n            if o is None and i in self._alive_env_ids:\n                self._alive_env_ids.remove(i)\n\n        # logging\n        for i, o in zip(wrapped_id, obs):\n            if i in self._alive_env_ids:\n                for logger in self._logger:\n                    logger.on_env_reset(i, obs)\n\n        # fill empty observation with default(fake) observation\n        for o in obs:\n            self._set_default_obs(o)\n        for i, o in enumerate(obs):\n            if o is None:\n                obs[i] = self._get_default_obs()\n\n        if not self._alive_env_ids:\n            # comment this line so that the env becomes indispensable\n            # self.reset()\n            self._zombie = True\n            raise StopIteration\n\n        return np.stack(obs)\n\n    def step(\n        self,\n        action: np.ndarray,\n        id: int | List[int] | np.ndarray | None = None,\n    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:\n        assert not self._zombie\n        wrapped_id = self._wrap_id(id)\n        id2idx = {i: k for k, i in enumerate(wrapped_id)}\n        request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id))\n        result = [[None, None, False, None] for _ in range(len(wrapped_id))]\n\n        # ask super to step alive 
envs and remap to current index\n        if request_id:\n            valid_act = np.stack([action[id2idx[i]] for i in request_id])\n            for i, r in zip(request_id, zip(*super().step(valid_act, request_id))):\n                result[id2idx[i]] = list(r)\n                result[id2idx[i]][0] = self._postproc_env_obs(result[id2idx[i]][0])\n\n        # logging\n        for i, r in zip(wrapped_id, result):\n            if i in self._alive_env_ids:\n                for logger in self._logger:\n                    logger.on_env_step(i, *r)\n\n        # fill empty observation/info with default(fake)\n        for _, r, ___, i in result:\n            self._set_default_info(i)\n            self._set_default_rew(r)\n        for i, r in enumerate(result):\n            if r[0] is None:\n                result[i][0] = self._get_default_obs()\n            if r[1] is None:\n                result[i][1] = self._get_default_rew()\n            if r[3] is None:\n                result[i][3] = self._get_default_info()\n\n        ret = list(map(np.stack, zip(*result)))\n        return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret)\n\n\nclass FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):\n    pass\n\n\nclass FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv):\n    pass\n\n\nclass FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):\n    pass\n\n\ndef vectorize_env(\n    env_factory: Callable[..., gym.Env],\n    env_type: FiniteEnvType,\n    concurrency: int,\n    logger: LogWriter | List[LogWriter],\n) -> FiniteVectorEnv:\n    \"\"\"Helper function to create a vector env. 
Can be used to replace the usual VectorEnv.\n\n    For example, once you wrote: ::\n\n        DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])\n\n    Now you can replace it with: ::\n\n        vectorize_env(lambda: gym.make(task), \"dummy\", env_num, my_logger)\n\n    By doing such a replacement, you get two additional features (compared to a normal VectorEnv):\n\n    1. The vector env will check for NaN observations and kill the worker when one is found.\n       See :class:`FiniteVectorEnv` for why we need this.\n    2. A logger to explicitly collect logs from environment workers.\n\n    Parameters\n    ----------\n    env_factory\n        Callable to instantiate one single ``gym.Env``.\n        All concurrent workers will have the same ``env_factory``.\n    env_type\n        dummy or subproc or shmem. Corresponding to\n        `parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.\n    concurrency\n        Number of concurrent environment workers.\n    logger\n        Log writers.\n\n    Warnings\n    --------\n    Please do not use a lambda expression here for ``env_factory`` as it may create incorrectly-shared instances.\n\n    Don't do: ::\n\n        vectorize_env(lambda: EnvWrapper(...), ...)\n\n    Please do: ::\n\n        def env_factory(): ...\n        vectorize_env(env_factory, ...)\n    \"\"\"\n    env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = {\n        \"dummy\": FiniteDummyVectorEnv,\n        \"subproc\": FiniteSubprocVectorEnv,\n        \"shmem\": FiniteShmemVectorEnv,\n    }\n\n    finite_env_cls = env_type_cls_mapping[env_type]\n\n    return finite_env_cls(logger, [env_factory for _ in range(concurrency)])\n"
  },
  {
    "path": "qlib/rl/utils/log.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"Distributed logger for RL.\n\n:class:`LogCollector` runs in every environment worker. It collects log info from simulator states,\nand adds them (as a dict) to the info returned at each step.\n\n:class:`LogWriter` runs in the central worker. It decodes the dict collected by :class:`LogCollector`\nin each worker, and writes them to the console, log files, or tensorboard...\n\nThe two modules communicate through the \"log\" field in the \"info\" returned by ``env.step()``.\n\"\"\"\n\n# NOTE: This file contains many hardcoded / ad-hoc rules.\n# Refactoring it will be one of the future tasks.\n\nfrom __future__ import annotations\n\nimport logging\nfrom collections import defaultdict\nfrom enum import IntEnum\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any, Callable, Dict, Generic, List, Sequence, Set, Tuple, TypeVar\n\nimport numpy as np\nimport pandas as pd\n\nfrom qlib.log import get_module_logger\n\nif TYPE_CHECKING:\n    from .env_wrapper import InfoDict\n\n\n__all__ = [\"LogCollector\", \"LogWriter\", \"LogLevel\", \"LogBuffer\", \"ConsoleWriter\", \"CsvWriter\"]\n\nObsType = TypeVar(\"ObsType\")\nActType = TypeVar(\"ActType\")\n\n\nclass LogLevel(IntEnum):\n    \"\"\"Log-levels for RL training.\n    The behavior of handling each log level depends on the implementation of :class:`LogWriter`.\n    \"\"\"\n\n    DEBUG = 10\n    \"\"\"If you only want to see the metric in debug mode.\"\"\"\n    PERIODIC = 20\n    \"\"\"If you want to see the metric periodically.\"\"\"\n    # FIXME: I haven't given much thought to this. 
Let's hold it for one iteration.\n\n    INFO = 30\n    \"\"\"Important log messages.\"\"\"\n    CRITICAL = 40\n    \"\"\"LogWriter should always handle CRITICAL messages\"\"\"\n\n\nclass LogCollector:\n    \"\"\"Logs are first collected in each environment worker,\n    and then aggregated to stream at the central thread in vector env.\n\n    In :class:`LogCollector`, every metric is added to a dict, which needs to be ``reset()`` at each step.\n    The dict is sent via the ``info`` in ``env.step()``, and decoded by the :class:`LogWriter` at vector env.\n\n    ``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipe.\n    \"\"\"\n\n    _logged: Dict[str, Tuple[int, Any]]\n    _min_loglevel: int\n\n    def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:\n        self._min_loglevel = int(min_loglevel)\n\n    def reset(self) -> None:\n        \"\"\"Clear all collected contents.\"\"\"\n        self._logged = {}\n\n    def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None:\n        if name in self._logged:\n            raise ValueError(f\"A metric with {name} is already added. 
Please change a name or reset the log collector.\")\n        self._logged[name] = (int(loglevel), metric)\n\n    def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:\n        \"\"\"Add a string with name into logged contents.\"\"\"\n        if loglevel < self._min_loglevel:\n            return\n        if not isinstance(string, str):\n            raise TypeError(f\"{string} is not a string.\")\n        self._add_metric(name, string, loglevel)\n\n    def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:\n        \"\"\"Add a scalar with name into logged contents.\n        Scalar will be converted into a float.\n        \"\"\"\n        if loglevel < self._min_loglevel:\n            return\n\n        if hasattr(scalar, \"item\"):\n            # could be single-item number\n            scalar = scalar.item()\n        if not isinstance(scalar, (float, int)):\n            raise TypeError(f\"{scalar} is not and can not be converted into float or integer.\")\n        scalar = float(scalar)\n        self._add_metric(name, scalar, loglevel)\n\n    def add_array(\n        self,\n        name: str,\n        array: np.ndarray | pd.DataFrame | pd.Series,\n        loglevel: int | LogLevel = LogLevel.PERIODIC,\n    ) -> None:\n        \"\"\"Add an array with name into logging.\"\"\"\n        if loglevel < self._min_loglevel:\n            return\n\n        if not isinstance(array, (np.ndarray, pd.DataFrame, pd.Series)):\n            raise TypeError(f\"{array} is not one of ndarray, DataFrame and Series.\")\n        self._add_metric(name, array, loglevel)\n\n    def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:\n        \"\"\"Log something with any type.\n\n        As it's an \"any\" object, the only LogWriter accepting it is pickle.\n        Therefore, pickle must be able to serialize it.\n        \"\"\"\n        if loglevel < self._min_loglevel:\n 
           return\n\n        # FIXME: detect and rescue object that could be scalar or array\n\n        self._add_metric(name, obj, loglevel)\n\n    def logs(self) -> Dict[str, np.ndarray]:\n        return {key: np.asanyarray(value, dtype=\"object\") for key, value in self._logged.items()}\n\n\nclass LogWriter(Generic[ObsType, ActType]):\n    \"\"\"Base class for log writers, triggered at every reset and step by finite env.\n\n    What to do with a specific log depends on the implementation of the subclass of :class:`LogWriter`.\n    The general principle is that it should handle logs above its loglevel (inclusive),\n    and discard logs that are not acceptable. For instance, console loggers obviously can't handle an image.\n    \"\"\"\n\n    episode_count: int\n    \"\"\"Counter of episodes.\"\"\"\n\n    step_count: int\n    \"\"\"Counter of steps.\"\"\"\n\n    global_step: int\n    \"\"\"Counter of steps. Won't be cleared in ``clear``.\"\"\"\n\n    global_episode: int\n    \"\"\"Counter of episodes. 
Won't be cleared in ``clear``.\"\"\"\n\n    active_env_ids: Set[int]\n    \"\"\"Active environment ids in vector env.\"\"\"\n\n    episode_lengths: Dict[int, int]\n    \"\"\"Map from environment id to episode length.\"\"\"\n\n    episode_rewards: Dict[int, List[float]]\n    \"\"\"Map from environment id to the list of step rewards in the current episode.\"\"\"\n\n    episode_logs: Dict[int, list]\n    \"\"\"Map from environment id to episode logs.\"\"\"\n\n    def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:\n        self.loglevel = loglevel\n\n        self.global_step = 0\n        self.global_episode = 0\n\n        # Information and logs of one episode are stored here.\n        # This assumes that an episode is short enough to fit into memory.\n        self.episode_lengths = dict()\n        self.episode_rewards = dict()\n        self.episode_logs = dict()\n\n        self.clear()\n\n    def clear(self):\n        \"\"\"Clear all the metrics for a fresh start,\n        making the logger instance reusable.\n        \"\"\"\n        self.episode_count = self.step_count = 0\n        self.active_env_ids = set()\n\n    def state_dict(self) -> dict:\n        \"\"\"Save the states of the logger to a dict.\"\"\"\n        return {\n            \"episode_count\": self.episode_count,\n            \"step_count\": self.step_count,\n            \"global_step\": self.global_step,\n            \"global_episode\": self.global_episode,\n            \"active_env_ids\": self.active_env_ids,\n            \"episode_lengths\": self.episode_lengths,\n            \"episode_rewards\": self.episode_rewards,\n            \"episode_logs\": self.episode_logs,\n        }\n\n    def load_state_dict(self, state_dict: dict) -> None:\n        \"\"\"Load the states of the current logger from a dict.\"\"\"\n        self.episode_count = state_dict[\"episode_count\"]\n        self.step_count = state_dict[\"step_count\"]\n        self.global_step = state_dict[\"global_step\"]\n        self.global_episode = 
state_dict[\"global_episode\"]\n\n        # These are runtime states.\n        # Though they are loaded, I don't think it really helps.\n        self.active_env_ids = state_dict[\"active_env_ids\"]\n        self.episode_lengths = state_dict[\"episode_lengths\"]\n        self.episode_rewards = state_dict[\"episode_rewards\"]\n        self.episode_logs = state_dict[\"episode_logs\"]\n\n    @staticmethod\n    def aggregation(array: Sequence[Any], name: str | None = None) -> Any:\n        \"\"\"Aggregation function from step-wise to episode-wise.\n\n        If it's a sequence of floats, take the mean.\n        Otherwise, take the first element.\n\n        If a name is specified:\n\n        - if it's ``reward``, the reduction will be sum.\n        \"\"\"\n        assert len(array) > 0, \"The aggregated array must not be empty.\"\n        if all(isinstance(v, float) for v in array):\n            if name == \"reward\":\n                return np.sum(array)\n            return np.mean(array)\n        else:\n            return array[0]\n\n    def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:\n        \"\"\"This is triggered at the end of each trajectory.\n\n        Parameters\n        ----------\n        length\n            Length of this trajectory.\n        rewards\n            A list of rewards at each step of this episode.\n        contents\n            Logged contents for every step.\n        \"\"\"\n\n    def log_step(self, reward: float, contents: Dict[str, Any]) -> None:\n        \"\"\"This is triggered at each step.\n\n        Parameters\n        ----------\n        reward\n            Reward for this step.\n        contents\n            Logged contents for this step.\n        \"\"\"\n\n    def on_env_step(self, env_id: int, obs: ObsType, rew: float, done: bool, info: InfoDict) -> None:\n        \"\"\"Callback for finite env, on each step.\"\"\"\n\n        # Update counter\n        self.global_step += 1\n        
self.step_count += 1\n\n        self.active_env_ids.add(env_id)\n        self.episode_lengths[env_id] += 1\n        # TODO: reward can be a list of list for MARL\n        self.episode_rewards[env_id].append(rew)\n\n        values: Dict[str, Any] = {}\n\n        for key, (loglevel, value) in info[\"log\"].items():\n            if loglevel >= self.loglevel:  # FIXME: this is actually incorrect (see last FIXME)\n                values[key] = value\n        self.episode_logs[env_id].append(values)\n\n        self.log_step(rew, values)\n\n        if done:\n            # Update counter\n            self.global_episode += 1\n            self.episode_count += 1\n\n            self.log_episode(self.episode_lengths[env_id], self.episode_rewards[env_id], self.episode_logs[env_id])\n\n    def on_env_reset(self, env_id: int, _: ObsType) -> None:\n        \"\"\"Callback for finite env.\n\n        Reset episode statistics. Nothing task-specific is logged here because of\n        `a limitation of tianshou <https://github.com/thu-ml/tianshou/issues/605>`__.\n        \"\"\"\n        self.episode_lengths[env_id] = 0\n        self.episode_rewards[env_id] = []\n        self.episode_logs[env_id] = []\n\n    def on_env_all_ready(self) -> None:\n        \"\"\"When all environments are ready to run.\n        Usually, loggers should be reset here.\n        \"\"\"\n        self.clear()\n\n    def on_env_all_done(self) -> None:\n        \"\"\"All done. 
Time for cleanup.\"\"\"\n\n\nclass LogBuffer(LogWriter):\n    \"\"\"Keep all numbers in memory.\n\n    Objects that can't be aggregated, such as strings, tensors and images, can't be stored in the buffer.\n    To persist them, please use :class:`PickleWriter`.\n\n    Every time the log buffer receives a new metric, the callback is triggered,\n    which is useful when tracking metrics inside a trainer.\n\n    Parameters\n    ----------\n    callback\n        A callback receiving three arguments:\n\n        - on_episode: Whether it's called at the end of an episode\n        - on_collect: Whether it's called at the end of a collect\n        - log_buffer: the :class:`LogBuffer` object\n\n        No return value is expected.\n    \"\"\"\n\n    # FIXME: needs a metric count\n\n    def __init__(self, callback: Callable[[bool, bool, LogBuffer], None], loglevel: int | LogLevel = LogLevel.PERIODIC):\n        super().__init__(loglevel)\n        self.callback = callback\n\n    def state_dict(self) -> dict:\n        return {\n            **super().state_dict(),\n            \"latest_metrics\": self._latest_metrics,\n            \"aggregated_metrics\": self._aggregated_metrics,\n        }\n\n    def load_state_dict(self, state_dict: dict) -> None:\n        self._latest_metrics = state_dict[\"latest_metrics\"]\n        self._aggregated_metrics = state_dict[\"aggregated_metrics\"]\n        return super().load_state_dict(state_dict)\n\n    def clear(self):\n        super().clear()\n        self._latest_metrics: dict[str, float] | None = None\n        self._aggregated_metrics: dict[str, float] = defaultdict(float)\n\n    def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:\n        # FIXME Dup of ConsoleWriter\n        episode_wise_contents: dict[str, list] = defaultdict(list)\n        for step_contents in contents:\n            for name, value in step_contents.items():\n                # FIXME This could be false-negative for some numpy types\n    
            if isinstance(value, float):\n                    episode_wise_contents[name].append(value)\n\n        logs: dict[str, float] = {}\n        for name, values in episode_wise_contents.items():\n            logs[name] = self.aggregation(values, name)  # type: ignore\n            self._aggregated_metrics[name] += logs[name]\n\n        self._latest_metrics = logs\n\n        self.callback(True, False, self)\n\n    def on_env_all_done(self) -> None:\n        # This happens when collect exits\n        self.callback(False, True, self)\n\n    def episode_metrics(self) -> dict[str, float]:\n        \"\"\"Retrieve the numeric metrics of the latest episode.\"\"\"\n        if self._latest_metrics is None:\n            raise ValueError(\"No episode metrics available yet.\")\n        return self._latest_metrics\n\n    def collect_metrics(self) -> dict[str, float]:\n        \"\"\"Retrieve the aggregated metrics of the latest collect.\"\"\"\n        return {name: value / self.episode_count for name, value in self._aggregated_metrics.items()}\n\n\nclass ConsoleWriter(LogWriter):\n    \"\"\"Write log messages to console periodically.\n\n    It tracks an average meter for each metric, which is the average value since last ``clear()`` till now.\n    The display format for each metric is ``<name> <latest_value> (<average_value>)``.\n\n    Non-single-number metrics are auto skipped.\n    \"\"\"\n\n    prefix: str\n    \"\"\"Prefix can be set via ``writer.prefix``.\"\"\"\n\n    def __init__(\n        self,\n        log_every_n_episode: int = 20,\n        total_episodes: int | None = None,\n        float_format: str = \":.4f\",\n        counter_format: str = \":4d\",\n        loglevel: int | LogLevel = LogLevel.PERIODIC,\n    ) -> None:\n        super().__init__(loglevel)\n        # TODO: support log_every_n_step\n        self.log_every_n_episode = log_every_n_episode\n        self.total_episodes = total_episodes\n\n        self.counter_format = counter_format\n        
self.float_format = float_format\n\n        self.prefix = \"\"\n\n        self.console_logger = get_module_logger(__name__, level=logging.INFO)\n\n    # FIXME: save & reload\n\n    def clear(self) -> None:\n        super().clear()\n        # Clear average meters\n        self.metric_counts: Dict[str, int] = defaultdict(int)\n        self.metric_sums: Dict[str, float] = defaultdict(float)\n\n    def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:\n        # Aggregate step-wise to episode-wise\n        episode_wise_contents: Dict[str, list] = defaultdict(list)\n\n        for step_contents in contents:\n            for name, value in step_contents.items():\n                if isinstance(value, float):\n                    episode_wise_contents[name].append(value)\n\n        # Generate log contents and track them in average-meter.\n        # This should be done at every step, regardless of periodic or not.\n        logs: Dict[str, float] = {}\n        for name, values in episode_wise_contents.items():\n            logs[name] = self.aggregation(values, name)  # type: ignore\n\n        for name, value in logs.items():\n            self.metric_counts[name] += 1\n            self.metric_sums[name] += value\n\n        if self.episode_count % self.log_every_n_episode == 0 or self.episode_count == self.total_episodes:\n            # Only log periodically or at the end\n            self.console_logger.info(self.generate_log_message(logs))\n\n    def generate_log_message(self, logs: Dict[str, float]) -> str:\n        if self.prefix:\n            msg_prefix = self.prefix + \" \"\n        else:\n            msg_prefix = \"\"\n        if self.total_episodes is None:\n            msg_prefix += \"[Step {\" + self.counter_format + \"}]\"\n        else:\n            msg_prefix += \"[{\" + self.counter_format + \"}/\" + str(self.total_episodes) + \"]\"\n        msg_prefix = msg_prefix.format(self.episode_count)\n\n        msg = \"\"\n       
 for name, value in logs.items():\n            # Double-space as delimiter\n            format_template = r\"  {} {\" + self.float_format + \"} ({\" + self.float_format + \"})\"\n            msg += format_template.format(name, value, self.metric_sums[name] / self.metric_counts[name])\n\n        msg = msg_prefix + \" \" + msg\n\n        return msg\n\n\nclass CsvWriter(LogWriter):\n    \"\"\"Dump all episode metrics to a ``result.csv``.\n\n    This is not the correct implementation. It's only used for first iteration.\n    \"\"\"\n\n    SUPPORTED_TYPES = (float, str, pd.Timestamp)\n\n    all_records: List[Dict[str, Any]]\n\n    # FIXME: save & reload\n\n    def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:\n        super().__init__(loglevel)\n        self.output_dir = output_dir\n        self.output_dir.mkdir(exist_ok=True)\n\n    def clear(self) -> None:\n        super().clear()\n        self.all_records = []\n\n    def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:\n        # FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup\n        episode_wise_contents: Dict[str, list] = defaultdict(list)\n\n        for step_contents in contents:\n            for name, value in step_contents.items():\n                if isinstance(value, self.SUPPORTED_TYPES):\n                    episode_wise_contents[name].append(value)\n\n        logs: Dict[str, float] = {}\n        for name, values in episode_wise_contents.items():\n            logs[name] = self.aggregation(values, name)  # type: ignore\n\n        self.all_records.append(logs)\n\n    def on_env_all_done(self) -> None:\n        # FIXME: this is temporary\n        pd.DataFrame.from_records(self.all_records).to_csv(self.output_dir / \"result.csv\", index=False)\n\n\n# The following are not implemented yet.\n\n\nclass PickleWriter(LogWriter):\n    \"\"\"Dump logs to pickle files.\"\"\"\n\n\nclass 
TensorboardWriter(LogWriter):\n    \"\"\"Write logs to event files that can be visualized with tensorboard.\"\"\"\n\n\nclass MlflowWriter(LogWriter):\n    \"\"\"Add logs to mlflow.\"\"\"\n"
  },
  {
    "path": "qlib/strategy/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n"
  },
  {
    "path": "qlib/strategy/base.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom __future__ import annotations\n\nfrom abc import ABCMeta, abstractmethod\nfrom typing import Any, Generator, Optional, TYPE_CHECKING, Union\n\nif TYPE_CHECKING:\n    from qlib.backtest.exchange import Exchange\n    from qlib.backtest.position import BasePosition\n    from qlib.backtest.executor import BaseExecutor\n\nfrom typing import Tuple\n\nfrom ..backtest.decision import BaseTradeDecision\nfrom ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager\nfrom ..rl.interpreter import ActionInterpreter, StateInterpreter\nfrom ..utils import init_instance_by_config\n\n__all__ = [\"BaseStrategy\", \"RLStrategy\", \"RLIntStrategy\"]\n\n\nclass BaseStrategy:\n    \"\"\"Base strategy for trading\"\"\"\n\n    def __init__(\n        self,\n        outer_trade_decision: BaseTradeDecision = None,\n        level_infra: LevelInfrastructure = None,\n        common_infra: CommonInfrastructure = None,\n        trade_exchange: Exchange = None,\n    ) -> None:\n        \"\"\"\n        Parameters\n        ----------\n        outer_trade_decision : BaseTradeDecision, optional\n            the trade decision of the outer strategy on which this strategy relies, and it will be traded in\n            [start_time, end_time], by default None\n\n            - If the strategy is used to split trade decision, it will be used\n            - If the strategy is used for portfolio management, it can be ignored\n        level_infra : LevelInfrastructure, optional\n            level shared infrastructure for backtesting, including trade calendar\n        common_infra : CommonInfrastructure, optional\n            common infrastructure for backtesting, including trade_account, trade_exchange, etc.\n\n        trade_exchange : Exchange\n            exchange that provides market info, used to deal orders and generate reports\n\n            - If `trade_exchange` is None, 
self.trade_exchange will be taken from common_infra\n            - It allows different trade_exchanges to be used in different executions.\n            - For example:\n\n                - In daily execution, both the daily and the minutely exchange are usable, but the daily exchange is\n                  recommended because it runs faster.\n                - In minutely execution, the daily exchange is not usable; only the minutely exchange is recommended.\n        \"\"\"\n\n        self._reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)\n        self._trade_exchange = trade_exchange\n\n    @property\n    def executor(self) -> BaseExecutor:\n        return self.level_infra.get(\"executor\")\n\n    @property\n    def trade_calendar(self) -> TradeCalendarManager:\n        return self.level_infra.get(\"trade_calendar\")\n\n    @property\n    def trade_position(self) -> BasePosition:\n        return self.common_infra.get(\"trade_account\").current_position\n\n    @property\n    def trade_exchange(self) -> Exchange:\n        \"\"\"get the trade exchange in a prioritized order: the explicitly passed one first, then the one in common_infra\"\"\"\n        return getattr(self, \"_trade_exchange\", None) or self.common_infra.get(\"trade_exchange\")\n\n    def reset_level_infra(self, level_infra: LevelInfrastructure) -> None:\n        if not hasattr(self, \"level_infra\"):\n            self.level_infra = level_infra\n        else:\n            self.level_infra.update(level_infra)\n\n    def reset_common_infra(self, common_infra: CommonInfrastructure) -> None:\n        if not hasattr(self, \"common_infra\"):\n            self.common_infra: CommonInfrastructure = common_infra\n        else:\n            self.common_infra.update(common_infra)\n\n    def reset(\n        self,\n        level_infra: LevelInfrastructure = None,\n        common_infra: CommonInfrastructure = None,\n        outer_trade_decision: BaseTradeDecision = None,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        - reset 
`level_infra`, used to reset trade calendar, etc.\n        - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, etc.\n        - reset `outer_trade_decision`, used to make split decisions\n\n        **NOTE**:\n        splitting this function into `reset` and `_reset` makes the following cases more convenient\n        1. Users want to initialize their strategy by overriding `reset`, but they don't want to affect the `_reset`\n        called during initialization\n        \"\"\"\n        self._reset(\n            level_infra=level_infra,\n            common_infra=common_infra,\n            outer_trade_decision=outer_trade_decision,\n        )\n\n    def _reset(\n        self,\n        level_infra: LevelInfrastructure = None,\n        common_infra: CommonInfrastructure = None,\n        outer_trade_decision: BaseTradeDecision = None,\n    ):\n        \"\"\"\n        Please refer to the docs of `reset`\n        \"\"\"\n        if level_infra is not None:\n            self.reset_level_infra(level_infra)\n\n        if common_infra is not None:\n            self.reset_common_infra(common_infra)\n\n        if outer_trade_decision is not None:\n            self.outer_trade_decision = outer_trade_decision\n\n    @abstractmethod\n    def generate_trade_decision(\n        self,\n        execute_result: list = None,\n    ) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:\n        \"\"\"Generate trade decision in each trading bar\n\n        Parameters\n        ----------\n        execute_result : List[object], optional\n            the executed result for trade decision, by default None\n\n            - When `generate_trade_decision` is called for the first time, `execute_result` could be None\n        \"\"\"\n        raise NotImplementedError(\"generate_trade_decision is not implemented!\")\n\n    # helper methods: not necessary but for convenience\n    def get_data_cal_avail_range(self, rtype: str = \"full\") -> Tuple[int, int]:\n        \"\"\"\n        return 
data calendar's available decision range for `self` strategy\n        the range considers the following factors\n        - the data calendar in the charge of `self` strategy\n        - the trading range limitation from the decision of the outer strategy\n\n\n        related methods\n        - TradeCalendarManager.get_data_cal_range\n        - BaseTradeDecision.get_data_cal_range_limit\n\n        Parameters\n        ----------\n        rtype: str\n            - \"full\": return the available data index range of the strategy from `start_time` to `end_time`\n            - \"step\": return the available data index range of the strategy of the current step\n\n        Returns\n        -------\n        Tuple[int, int]:\n            the available range; both sides are closed\n        \"\"\"\n        cal_range = self.trade_calendar.get_data_cal_range(rtype=rtype)\n        if self.outer_trade_decision is None:\n            raise ValueError(f\"There is no limitation for strategy {self}\")\n        range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype)\n        return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])\n\n    \"\"\"\n    The following methods are used to do cross-level communications in nested execution.\n    You do not need to care about them if you are implementing a single-level execution.\n    \"\"\"\n\n    @staticmethod\n    def update_trade_decision(\n        trade_decision: BaseTradeDecision,\n        trade_calendar: TradeCalendarManager,\n    ) -> Optional[BaseTradeDecision]:\n        \"\"\"\n        update trade decision in each step of inner execution; this method enables all orders\n\n        Parameters\n        ----------\n        trade_decision : BaseTradeDecision\n            the trade decision that will be updated\n        trade_calendar : TradeCalendarManager\n            The calendar of the **inner strategy**!!!!!\n\n        Returns\n        -------\n            BaseTradeDecision:\n        \"\"\"\n        # default to return 
None, which indicates that the trade decision is not changed\n        return None\n\n    def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:\n        \"\"\"\n        A method for updating the outer_trade_decision.\n        The outer strategy may change its decision during updating.\n\n        Parameters\n        ----------\n        outer_trade_decision : BaseTradeDecision\n            the decision updated by the outer strategy\n\n        Returns\n        -------\n            BaseTradeDecision\n        \"\"\"\n        # default to reset the decision directly\n        # NOTE: normally, users should update the strategy in response to the change of the outer decision\n        return outer_trade_decision\n\n    def post_upper_level_exe_step(self) -> None:\n        \"\"\"\n        A hook for doing something after the upper level executor has finished its execution (for example, finalizing\n        the metrics collection).\n        \"\"\"\n\n    def post_exe_step(self, execute_result: Optional[list]) -> None:\n        \"\"\"\n        A hook for doing something after the corresponding executor has finished its execution.\n\n        Parameters\n        ----------\n        execute_result :\n            the execution result\n        \"\"\"\n\n\nclass RLStrategy(BaseStrategy, metaclass=ABCMeta):\n    \"\"\"RL-based strategy\"\"\"\n\n    def __init__(\n        self,\n        policy,\n        outer_trade_decision: BaseTradeDecision = None,\n        level_infra: LevelInfrastructure = None,\n        common_infra: CommonInfrastructure = None,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Parameters\n        ----------\n        policy :\n            RL policy for generating actions\n        \"\"\"\n        super(RLStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)\n        self.policy = policy\n\n\nclass RLIntStrategy(RLStrategy, metaclass=ABCMeta):\n    \"\"\"(RL)-based (Strategy) with (Int)erpreter\"\"\"\n\n    def 
__init__(\n        self,\n        policy,\n        state_interpreter: dict | StateInterpreter,\n        action_interpreter: dict | ActionInterpreter,\n        outer_trade_decision: BaseTradeDecision = None,\n        level_infra: LevelInfrastructure = None,\n        common_infra: CommonInfrastructure = None,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Parameters\n        ----------\n        state_interpreter : Union[dict, StateInterpreter]\n            interpreter that interprets the qlib execute result into rl env state\n        action_interpreter : Union[dict, ActionInterpreter]\n            interpreter that interprets the rl agent action into qlib order list\n        start_time : Union[str, pd.Timestamp], optional\n            start time of trading, by default None\n        end_time : Union[str, pd.Timestamp], optional\n            end time of trading, by default None\n        \"\"\"\n        super(RLIntStrategy, self).__init__(policy, outer_trade_decision, level_infra, common_infra, **kwargs)\n\n        self.policy = policy\n        self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)\n        self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)\n\n    def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:\n        _interpret_state = self.state_interpreter.interpret(execute_result=execute_result)\n        _action = self.policy.step(_interpret_state)\n        _trade_decision = self.action_interpreter.interpret(action=_action)\n        return _trade_decision\n"
  },
  {
    "path": "qlib/tests/__init__.py",
    "content": "from typing import Union, List, Dict, Tuple\nimport unittest\nimport pandas as pd\nimport numpy as np\nimport io\n\nfrom .data import GetData\nfrom .. import init\nfrom ..constant import REG_CN, REG_TW\nfrom qlib.data.filter import NameDFilter\nfrom qlib.data import D\nfrom qlib.data.data import Cal, DatasetD\nfrom qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT\n\n\nclass TestAutoData(unittest.TestCase):\n    _setup_kwargs = {}\n    provider_uri = \"~/.qlib/qlib_data/cn_data_simple\"  # target_dir\n    provider_uri_1day = \"~/.qlib/qlib_data/cn_data\"  # target_dir\n    provider_uri_1min = \"~/.qlib/qlib_data/cn_data_1min\"\n\n    @classmethod\n    def setUpClass(cls, enable_1d_type=\"simple\", enable_1min=False) -> None:\n        # use default data\n\n        if enable_1d_type == \"simple\":\n            provider_uri_day = cls.provider_uri\n            name_day = \"qlib_data_simple\"\n        elif enable_1d_type == \"full\":\n            provider_uri_day = cls.provider_uri_1day\n            name_day = \"qlib_data\"\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n\n        GetData().qlib_data(\n            name=name_day,\n            region=REG_CN,\n            interval=\"1d\",\n            target_dir=provider_uri_day,\n            delete_old=False,\n            exists_skip=True,\n        )\n\n        if enable_1min:\n            GetData().qlib_data(\n                name=\"qlib_data\",\n                region=REG_CN,\n                interval=\"1min\",\n                target_dir=cls.provider_uri_1min,\n                delete_old=False,\n                exists_skip=True,\n            )\n\n        provider_uri_map = {\"1min\": cls.provider_uri_1min, \"day\": provider_uri_day}\n        init(\n            provider_uri=provider_uri_map,\n            region=REG_CN,\n            expression_cache=None,\n            dataset_cache=None,\n            
**cls._setup_kwargs,\n        )\n\n\nclass TestOperatorData(TestAutoData):\n    @classmethod\n    def setUpClass(cls, enable_1d_type=\"simple\", enable_1min=False) -> None:\n        # use default data\n        super().setUpClass(enable_1d_type, enable_1min)\n        nameDFilter = NameDFilter(name_rule_re=\"SH600110\")\n        instruments = D.instruments(\"csi300\", filter_pipe=[nameDFilter])\n        start_time = \"2005-01-04\"\n        end_time = \"2005-12-31\"\n        freq = \"day\"\n\n        instruments_d = DatasetD.get_instruments_d(instruments, freq)\n        cls.instruments_d = instruments_d\n        cal = Cal.calendar(start_time, end_time, freq)\n        cls.cal = cal\n        cls.start_time = cal[0]\n        cls.end_time = cal[-1]\n        cls.inst = list(instruments_d.keys())[0]\n        cls.spans = list(instruments_d.values())[0]\n\n\nMOCK_DATA = \"\"\"\nid,symbol,datetime,interval,volume,open,high,low,close\n20275,0050,2022-01-03 00:00:00,day,6761.0,146.0,147.35,146.0,146.4\n20276,0050,2022-01-04 00:00:00,day,9608.0,147.7,149.6,147.7,149.6\n20277,0050,2022-01-05 00:00:00,day,11387.0,150.1,150.55,149.1,149.3\n20278,0050,2022-01-06 00:00:00,day,8611.0,148.3,148.75,147.0,147.9\n20279,0050,2022-01-07 00:00:00,day,6954.0,148.3,149.0,146.5,146.6\n20280,0050,2022-01-10 00:00:00,day,15684.0,146.0,147.8,145.4,147.55\n20281,0050,2022-01-11 00:00:00,day,17741.0,147.6,148.5,146.7,148.3\n20282,0050,2022-01-12 00:00:00,day,10134.0,149.35,149.6,148.7,149.55\n20283,0050,2022-01-13 00:00:00,day,7431.0,149.55,150.45,149.55,150.3\n20284,0050,2022-01-14 00:00:00,day,10091.0,150.8,151.2,149.05,150.3\n20285,0050,2022-01-17 00:00:00,day,6899.0,151.1,152.4,151.1,152.0\n20286,0050,2022-01-18 00:00:00,day,14360.0,152.2,152.25,150.15,150.3\n20287,0050,2022-01-19 00:00:00,day,14654.0,149.0,149.65,148.25,148.5\n20288,0050,2022-01-20 00:00:00,day,16201.0,148.5,149.2,147.6,149.1\n20289,0050,2022-01-21 00:00:00,day,29848.0,143.9,143.95,142.3,142.65\n20290,0050,2022-01-24 
00:00:00,day,13143.0,142.1,144.0,141.7,144.0\n20291,0050,2022-01-25 00:00:00,day,23982.0,142.55,142.55,141.25,141.65\n20292,0050,2022-01-26 00:00:00,day,17729.0,141.15,142.2,141.05,141.55\n8547,1101,2021-12-01 00:00:00,day,16119.0,46.0,46.85,46.0,46.6\n8548,1101,2021-12-02 00:00:00,day,14521.0,46.6,46.7,46.3,46.3\n8549,1101,2021-12-03 00:00:00,day,14357.0,46.55,46.85,46.4,46.4\n8550,1101,2021-12-06 00:00:00,day,15115.0,46.45,47.35,46.4,47.3\n8551,1101,2021-12-07 00:00:00,day,13117.0,47.35,47.55,46.9,47.55\n8552,1101,2021-12-08 00:00:00,day,10329.0,47.75,47.8,47.5,47.7\n8553,1101,2021-12-09 00:00:00,day,9300.0,47.8,47.85,47.1,47.4\n8554,1101,2021-12-10 00:00:00,day,9919.0,47.4,47.6,47.1,47.3\n8555,1101,2021-12-13 00:00:00,day,7784.0,47.3,47.75,47.1,47.1\n8556,1101,2021-12-14 00:00:00,day,9373.0,47.05,47.2,46.95,47.0\n8557,1101,2021-12-15 00:00:00,day,11189.0,47.0,47.3,46.8,46.95\n8558,1101,2021-12-16 00:00:00,day,7516.0,47.0,47.15,46.8,46.9\n8559,1101,2021-12-17 00:00:00,day,18502.0,46.95,47.6,46.9,47.45\n8560,1101,2021-12-20 00:00:00,day,11309.0,47.45,47.5,47.1,47.4\n8561,1101,2021-12-21 00:00:00,day,5666.0,47.4,47.45,47.1,47.25\n8562,1101,2021-12-22 00:00:00,day,5460.0,47.4,47.45,47.2,47.4\n8563,1101,2021-12-23 00:00:00,day,9371.0,47.3,47.7,47.3,47.7\n8564,1101,2021-12-24 00:00:00,day,5980.0,47.75,47.95,47.75,47.9\n8565,1101,2021-12-27 00:00:00,day,5709.0,47.9,48.1,47.9,48.1\n8566,1101,2021-12-28 00:00:00,day,7777.0,48.1,48.15,47.95,48.15\n8567,1101,2021-12-29 00:00:00,day,5309.0,48.15,48.25,48.05,48.15\n8568,1101,2021-12-30 00:00:00,day,4616.0,48.15,48.2,48.0,48.0\n8569,1101,2022-01-03 00:00:00,day,12350.0,48.05,48.15,47.35,47.45\n8570,1101,2022-01-04 00:00:00,day,11439.0,47.5,47.6,47.0,47.3\n8571,1101,2022-01-05 00:00:00,day,9692.0,47.1,47.3,47.0,47.15\n8572,1101,2022-01-06 00:00:00,day,12361.0,47.3,47.6,47.15,47.6\n8573,1101,2022-01-07 00:00:00,day,10921.0,47.6,47.65,47.2,47.45\n8574,1101,2022-01-10 
00:00:00,day,11925.0,47.45,47.5,47.0,47.3\n8575,1101,2022-01-11 00:00:00,day,11047.0,47.1,47.5,47.1,47.5\n8576,1101,2022-01-12 00:00:00,day,10817.0,47.5,47.5,47.1,47.5\n8577,1101,2022-01-13 00:00:00,day,13849.0,47.5,47.95,47.4,47.95\n8578,1101,2022-01-14 00:00:00,day,9460.0,47.85,47.85,47.45,47.6\n8579,1101,2022-01-17 00:00:00,day,9057.0,47.55,47.7,47.35,47.6\n8580,1101,2022-01-18 00:00:00,day,8089.0,47.6,47.75,47.45,47.75\n8581,1101,2022-01-19 00:00:00,day,5110.0,47.6,47.7,47.5,47.6\n8582,1101,2022-01-20 00:00:00,day,6327.0,47.55,47.7,47.45,47.5\n8583,1101,2022-01-21 00:00:00,day,9470.0,47.5,47.65,47.15,47.4\n8584,1101,2022-01-24 00:00:00,day,5475.0,47.1,47.3,47.0,47.15\n8585,1101,2022-01-25 00:00:00,day,16153.0,47.0,47.05,46.6,46.8\n8586,1101,2022-01-26 00:00:00,day,7772.0,46.7,47.0,46.55,46.85\n8587,1101,2022-02-07 00:00:00,day,17031.0,46.55,47.1,46.0,47.1\n8588,1101,2022-02-08 00:00:00,day,9741.0,47.1,47.25,46.9,46.95\n8589,1101,2022-02-09 00:00:00,day,7968.0,46.95,47.3,46.9,47.3\n8590,1101,2022-02-10 00:00:00,day,7479.0,47.15,47.55,47.05,47.55\n8591,1101,2022-02-11 00:00:00,day,6841.0,47.3,47.55,47.15,47.55\n8592,1101,2022-02-14 00:00:00,day,9136.0,47.2,47.3,46.95,47.15\n8593,1101,2022-02-15 00:00:00,day,5444.0,47.05,47.1,46.8,47.0\n8594,1101,2022-02-16 00:00:00,day,8751.0,47.0,47.15,47.0,47.0\n8595,1101,2022-02-17 00:00:00,day,10662.0,47.15,47.55,47.1,47.45\n8596,1101,2022-02-18 00:00:00,day,8781.0,47.25,47.55,47.2,47.45\n8597,1101,2022-02-21 00:00:00,day,8201.0,47.35,47.75,47.15,47.6\n8598,1101,2022-02-22 00:00:00,day,10655.0,47.4,47.7,47.1,47.7\n8599,1101,2022-02-23 00:00:00,day,8040.0,47.7,47.85,47.45,47.65\n8600,1101,2022-02-24 00:00:00,day,13124.0,47.5,47.5,47.1,47.3\n8601,1101,2022-02-25 00:00:00,day,14556.0,47.2,47.5,46.9,47.35\n\"\"\"\n\nMOCK_DF = pd.read_csv(io.StringIO(MOCK_DATA), header=0, dtype={\"symbol\": str})\n\n\nclass MockStorageBase:\n    def __init__(self, **kwargs):\n        self.df = MOCK_DF\n\n\nclass 
MockCalendarStorage(MockStorageBase, CalendarStorage):\n    def __init__(self, **kwargs):\n        super().__init__()\n        self._data = sorted(self.df[\"datetime\"].unique())\n\n    @property\n    def data(self) -> List[CalVT]:\n        return self._data\n\n    def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]:\n        return self.data[i]\n\n    def __len__(self) -> int:\n        return len(self.data)\n\n\nclass MockInstrumentStorage(MockStorageBase, InstrumentStorage):\n    def __init__(self, **kwargs):\n        super().__init__()\n        instruments = {}\n        for symbol, group in self.df.groupby(by=\"symbol\", group_keys=False):\n            start = group[\"datetime\"].iloc[0]\n            end = group[\"datetime\"].iloc[-1]\n            instruments[symbol] = [(start, end)]\n        self._data = instruments\n\n    @property\n    def data(self) -> Dict[InstKT, InstVT]:\n        return self._data\n\n    def __getitem__(self, k: InstKT) -> InstVT:\n        return self.data[k]\n\n    def __len__(self) -> int:\n        return len(self.data)\n\n\nclass MockFeatureStorage(MockStorageBase, FeatureStorage):\n    def __init__(self, instrument: str, field: str, freq: str, db_region: str = None, **kwargs):  # type: ignore\n        super().__init__(instrument=instrument, field=field, freq=freq, db_region=db_region, **kwargs)\n        self.field = field\n        calendar = sorted(self.df[\"datetime\"].unique())\n        df_calendar = pd.DataFrame(calendar, columns=[\"datetime\"]).set_index(\"datetime\")\n        df = self.df[self.df[\"symbol\"] == instrument]\n        data_dt_field = \"datetime\"\n        cal_df = df_calendar[\n            (df_calendar.index >= df[data_dt_field].min()) & (df_calendar.index <= df[data_dt_field].max())\n        ]\n        df = df.set_index(data_dt_field)\n        df_data = df.reindex(cal_df.index)\n        date_index = df_calendar.index.get_loc(df_data.index.min())  # type: ignore\n        
df_data.reset_index(inplace=True)\n        df_data.index += date_index\n        self._data = df_data\n\n    @property\n    def data(self) -> pd.Series:\n        return self._data[self.field]\n\n    @property\n    def start_index(self) -> Union[int, None]:\n        if self._data.empty:\n            return None\n        return self._data.index[0]\n\n    @property\n    def end_index(self) -> Union[int, None]:\n        if self._data.empty:\n            return None\n        # The next  data appending index point will be  `end_index + 1`\n        return self._data.index[-1]\n\n    def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]:\n        df = self._data\n        storage_start_index = df.index[0]\n        storage_end_index = df.index[-1]\n        if isinstance(i, int):\n            if storage_start_index > i or i > storage_end_index:\n                raise IndexError(f\"{i}: start index is {storage_start_index}\")\n            data = self.data[i]\n            return i, data\n        elif isinstance(i, slice):\n            start_index = storage_start_index if i.start is None else i.start\n            end_index = storage_end_index if i.stop is None else i.stop\n            si = max(start_index, storage_start_index)\n            if si > end_index or self.field not in df.columns:\n                return pd.Series(dtype=np.float32)  # type: ignore\n            data = df[self.field].tolist()\n            result = data[si - storage_start_index : end_index - storage_start_index]\n            return pd.Series(result, index=pd.RangeIndex(si, si + len(result)))  # type: ignore\n        else:\n            raise TypeError(f\"type(i) = {type(i)}\")\n\n    def __len__(self) -> int:\n        return len(self.data)\n\n\nclass TestMockData(unittest.TestCase):\n    _setup_kwargs = {\n        \"calendar_provider\": {\n            \"class\": \"LocalCalendarProvider\",\n            \"module_path\": \"qlib.data.data\",\n            \"kwargs\": {\"backend\": 
{\"class\": \"MockCalendarStorage\", \"module_path\": \"qlib.tests\"}},\n        },\n        \"instrument_provider\": {\n            \"class\": \"LocalInstrumentProvider\",\n            \"module_path\": \"qlib.data.data\",\n            \"kwargs\": {\"backend\": {\"class\": \"MockInstrumentStorage\", \"module_path\": \"qlib.tests\"}},\n        },\n        \"feature_provider\": {\n            \"class\": \"LocalFeatureProvider\",\n            \"module_path\": \"qlib.data.data\",\n            \"kwargs\": {\"backend\": {\"class\": \"MockFeatureStorage\", \"module_path\": \"qlib.tests\"}},\n        },\n    }\n\n    @classmethod\n    def setUpClass(cls) -> None:\n        provider_uri = \"Not necessary.\"\n        init(region=REG_TW, provider_uri=provider_uri, expression_cache=None, dataset_cache=None, **cls._setup_kwargs)\n"
  },
  {
    "path": "qlib/tests/config.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nCSI300_MARKET = \"csi300\"\nCSI100_MARKET = \"csi100\"\n\nCSI300_BENCH = \"SH000300\"\n\nDATASET_ALPHA158_CLASS = \"Alpha158\"\nDATASET_ALPHA360_CLASS = \"Alpha360\"\n\n###################################\n# config\n###################################\n\n\nGBDT_MODEL = {\n    \"class\": \"LGBModel\",\n    \"module_path\": \"qlib.contrib.model.gbdt\",\n    \"kwargs\": {\n        \"loss\": \"mse\",\n        \"colsample_bytree\": 0.8879,\n        \"learning_rate\": 0.0421,\n        \"subsample\": 0.8789,\n        \"lambda_l1\": 205.6999,\n        \"lambda_l2\": 580.9768,\n        \"max_depth\": 8,\n        \"num_leaves\": 210,\n        \"num_threads\": 20,\n    },\n}\n\n\nSA_RC = {\n    \"class\": \"SigAnaRecord\",\n    \"module_path\": \"qlib.workflow.record_temp\",\n}\n\n\nRECORD_CONFIG = [\n    {\n        \"class\": \"SignalRecord\",\n        \"module_path\": \"qlib.workflow.record_temp\",\n        \"kwargs\": {\n            \"dataset\": \"<DATASET>\",\n            \"model\": \"<MODEL>\",\n        },\n    },\n    SA_RC,\n]\n\n\ndef get_data_handler_config(\n    start_time=\"2008-01-01\",\n    end_time=\"2020-08-01\",\n    fit_start_time=\"<dataset.kwargs.segments.train.0>\",\n    fit_end_time=\"<dataset.kwargs.segments.train.1>\",\n    instruments=CSI300_MARKET,\n):\n    return {\n        \"start_time\": start_time,\n        \"end_time\": end_time,\n        \"fit_start_time\": fit_start_time,\n        \"fit_end_time\": fit_end_time,\n        \"instruments\": instruments,\n    }\n\n\ndef get_dataset_config(\n    dataset_class=DATASET_ALPHA158_CLASS,\n    train=(\"2008-01-01\", \"2014-12-31\"),\n    valid=(\"2015-01-01\", \"2016-12-31\"),\n    test=(\"2017-01-01\", \"2020-08-01\"),\n    handler_kwargs={\"instruments\": CSI300_MARKET},\n):\n    return {\n        \"class\": \"DatasetH\",\n        \"module_path\": \"qlib.data.dataset\",\n        \"kwargs\": {\n            
\"handler\": {\n                \"class\": dataset_class,\n                \"module_path\": \"qlib.contrib.data.handler\",\n                \"kwargs\": get_data_handler_config(**handler_kwargs),\n            },\n            \"segments\": {\n                \"train\": train,\n                \"valid\": valid,\n                \"test\": test,\n            },\n        },\n    }\n\n\ndef get_gbdt_task(dataset_kwargs={}, handler_kwargs={\"instruments\": CSI300_MARKET}):\n    return {\n        \"model\": GBDT_MODEL,\n        \"dataset\": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),\n    }\n\n\ndef get_record_lgb_config(dataset_kwargs={}, handler_kwargs={\"instruments\": CSI300_MARKET}):\n    return {\n        \"model\": {\n            \"class\": \"LGBModel\",\n            \"module_path\": \"qlib.contrib.model.gbdt\",\n        },\n        \"dataset\": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),\n        \"record\": RECORD_CONFIG,\n    }\n\n\ndef get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={\"instruments\": CSI300_MARKET}):\n    return {\n        \"model\": {\n            \"class\": \"XGBModel\",\n            \"module_path\": \"qlib.contrib.model.xgboost\",\n        },\n        \"dataset\": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),\n        \"record\": RECORD_CONFIG,\n    }\n\n\nCSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={\"instruments\": CSI300_MARKET})\nCSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={\"instruments\": CSI300_MARKET})\n\nCSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={\"instruments\": CSI100_MARKET})\nCSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={\"instruments\": CSI100_MARKET})\n\n# use for rolling_online_managment.py\nROLLING_HANDLER_CONFIG = {\n    \"start_time\": \"2013-01-01\",\n    \"end_time\": \"2020-09-25\",\n    \"fit_start_time\": \"2013-01-01\",\n    \"fit_end_time\": \"2014-12-31\",\n 
   \"instruments\": CSI100_MARKET,\n}\nROLLING_DATASET_CONFIG = {\n    \"train\": (\"2013-01-01\", \"2014-12-31\"),\n    \"valid\": (\"2015-01-01\", \"2015-12-31\"),\n    \"test\": (\"2016-01-01\", \"2020-07-10\"),\n}\nCSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING = get_record_xgboost_config(\n    dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG\n)\nCSI100_RECORD_LGB_TASK_CONFIG_ROLLING = get_record_lgb_config(\n    dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG\n)\n\n# use for online_management_simulate.py\nONLINE_HANDLER_CONFIG = {\n    \"start_time\": \"2018-01-01\",\n    \"end_time\": \"2018-10-31\",\n    \"fit_start_time\": \"2018-01-01\",\n    \"fit_end_time\": \"2018-03-31\",\n    \"instruments\": CSI100_MARKET,\n}\nONLINE_DATASET_CONFIG = {\n    \"train\": (\"2018-01-01\", \"2018-03-31\"),\n    \"valid\": (\"2018-04-01\", \"2018-05-31\"),\n    \"test\": (\"2018-06-01\", \"2018-09-10\"),\n}\nCSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE = get_record_xgboost_config(\n    dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG\n)\nCSI100_RECORD_LGB_TASK_CONFIG_ONLINE = get_record_lgb_config(\n    dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG\n)\n"
  },
  {
    "path": "qlib/tests/data.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport re\nimport sys\nimport qlib\nimport shutil\nimport zipfile\nimport requests\nimport datetime\nfrom tqdm import tqdm\nfrom pathlib import Path\nfrom loguru import logger\nfrom qlib.utils import exists_qlib_data\n\n\nclass GetData:\n    REMOTE_URL = \"https://github.com/SunsetWolf/qlib_dataset/releases/download\"\n\n    def __init__(self, delete_zip_file=False):\n        \"\"\"\n\n        Parameters\n        ----------\n        delete_zip_file : bool, optional\n            Whether to delete the zip file, value from True or False, by default False\n        \"\"\"\n        self.delete_zip_file = delete_zip_file\n\n    def merge_remote_url(self, file_name: str):\n        \"\"\"\n        Generate download links.\n\n        Parameters\n        ----------\n        file_name: str\n            The name of the file to be downloaded.\n            The file name can be accompanied by a version number, (e.g.: v2/qlib_data_simple_cn_1d_latest.zip),\n            if no version number is attached, it will be downloaded from v0 by default.\n        \"\"\"\n        return f\"{self.REMOTE_URL}/{file_name}\" if \"/\" in file_name else f\"{self.REMOTE_URL}/v0/{file_name}\"\n\n    def download(self, url: str, target_path: [Path, str]):\n        \"\"\"\n        Download a file from the specified url.\n\n        Parameters\n        ----------\n        url: str\n            The url of the data.\n        target_path: str\n            The location where the data is saved, including the file name.\n        \"\"\"\n        file_name = str(target_path).rsplit(\"/\", maxsplit=1)[-1]\n        resp = requests.get(url, stream=True, timeout=60)\n        resp.raise_for_status()\n        if resp.status_code != 200:\n            raise requests.exceptions.HTTPError()\n\n        chunk_size = 1024\n        logger.warning(\n            f\"The data for the example is collected from Yahoo Finance. 
Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)\"\n        )\n        logger.info(f\"{os.path.basename(file_name)} downloading......\")\n        with tqdm(total=int(resp.headers.get(\"Content-Length\", 0))) as p_bar:\n            with target_path.open(\"wb\") as fp:\n                for chunk in resp.iter_content(chunk_size=chunk_size):\n                    fp.write(chunk)\n                    p_bar.update(len(chunk))\n\n    def download_data(self, file_name: str, target_dir: [Path, str], delete_old: bool = True):\n        \"\"\"\n        Download the specified file to the target folder.\n\n        Parameters\n        ----------\n        target_dir: str\n            data save directory\n        file_name: str\n            dataset name, needs to end with .zip, value from [rl_data.zip, csv_data_cn.zip, ...]\n            may contain folder names, for example: v2/qlib_data_simple_cn_1d_latest.zip\n        delete_old: bool\n            delete an existing directory, by default True\n\n        Examples\n        ---------\n        # get rl data\n        python get_data.py download_data --file_name rl_data.zip --target_dir ~/.qlib/qlib_data/rl_data\n        When this command is run, the data will be downloaded from this link: https://qlibpublic.blob.core.windows.net/data/default/stock_data/rl_data.zip?{token}\n\n        # get cn csv data\n        python get_data.py download_data --file_name csv_data_cn.zip --target_dir ~/.qlib/csv_data/cn_data\n        When this command is run, the data will be downloaded from this link: https://qlibpublic.blob.core.windows.net/data/default/stock_data/csv_data_cn.zip?{token}\n        -------\n\n        \"\"\"\n        target_dir = Path(target_dir).expanduser()\n        target_dir.mkdir(exist_ok=True, parents=True)\n        # saved file name\n        _target_file_name = datetime.datetime.now().strftime(\"%Y%m%d%H%M%S\") + \"_\" + 
os.path.basename(file_name)\n        target_path = target_dir.joinpath(_target_file_name)\n\n        url = self.merge_remote_url(file_name)\n        self.download(url=url, target_path=target_path)\n\n        self._unzip(target_path, target_dir, delete_old)\n        if self.delete_zip_file:\n            target_path.unlink()\n\n    def check_dataset(self, file_name: str):\n        url = self.merge_remote_url(file_name)\n        resp = requests.get(url, stream=True, timeout=60)\n        status = True\n        if resp.status_code == 404:\n            status = False\n        return status\n\n    @staticmethod\n    def _unzip(file_path: [Path, str], target_dir: [Path, str], delete_old: bool = True):\n        file_path = Path(file_path)\n        target_dir = Path(target_dir)\n        if delete_old:\n            logger.warning(\n                f\"will delete the old qlib data directory(features, instruments, calendars, features_cache, dataset_cache): {target_dir}\"\n            )\n            GetData._delete_qlib_data(target_dir)\n        logger.info(f\"{file_path} unzipping......\")\n        with zipfile.ZipFile(str(file_path.resolve()), \"r\") as zp:\n            for _file in tqdm(zp.namelist()):\n                zp.extract(_file, str(target_dir.resolve()))\n\n    @staticmethod\n    def _delete_qlib_data(file_dir: Path):\n        rm_dirs = []\n        for _name in [\"features\", \"calendars\", \"instruments\", \"features_cache\", \"dataset_cache\"]:\n            _p = file_dir.joinpath(_name)\n            if _p.exists():\n                rm_dirs.append(str(_p.resolve()))\n        if rm_dirs:\n            flag = input(\n                f\"Will be deleted: \"\n                f\"\\n\\t{rm_dirs}\"\n                f\"\\nIf you do not need to delete {file_dir}, please change the <--target_dir>\"\n                f\"\\nAre you sure you want to delete, yes(Y/y), no (N/n):\"\n            )\n            if str(flag) not in [\"Y\", \"y\"]:\n                sys.exit()\n            
for _p in rm_dirs:\n                logger.warning(f\"delete: {_p}\")\n                shutil.rmtree(_p)\n\n    def qlib_data(\n        self,\n        name=\"qlib_data\",\n        target_dir=\"~/.qlib/qlib_data/cn_data\",\n        version=None,\n        interval=\"1d\",\n        region=\"cn\",\n        delete_old=True,\n        exists_skip=False,\n    ):\n        \"\"\"download qlib data from remote\n\n        Parameters\n        ----------\n        target_dir: str\n            data save directory\n        name: str\n            dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data\n        version: str\n            data version, value from [v1, ...], by default None (use script to specify version)\n        interval: str\n            data freq, value from [1d], by default 1d\n        region: str\n            data region, value from [cn, us], by default cn\n        delete_old: bool\n            delete an existing directory, by default True\n        exists_skip: bool\n            skip the download if qlib data already exists in target_dir, by default False\n\n        Examples\n        ---------\n        # get 1d data\n        python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn\n        When this command is run, the data will be downloaded from this link: https://qlibpublic.blob.core.windows.net/data/default/stock_data/v2/qlib_data_cn_1d_latest.zip?{token}\n\n        # get 1min data\n        python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --interval 1min --region cn\n        When this command is run, the data will be downloaded from this link: https://qlibpublic.blob.core.windows.net/data/default/stock_data/v2/qlib_data_cn_1min_latest.zip?{token}\n        -------\n\n        \"\"\"\n        if exists_skip and exists_qlib_data(target_dir):\n            logger.warning(\n                f\"Data already exists: {target_dir}, the data download will be skipped\\n\"\n                f\"\\tIf 
downloading is required: `exists_skip=False` or `change target_dir`\"\n            )\n            return\n\n        qlib_version = \".\".join(re.findall(r\"(\\d+)\\.+\", qlib.__version__))\n\n        def _get_file_name_with_version(qlib_version, dataset_version):\n            dataset_version = \"v2\" if dataset_version is None else dataset_version\n            file_name_with_version = f\"{dataset_version}/{name}_{region.lower()}_{interval.lower()}_{qlib_version}.zip\"\n            return file_name_with_version\n\n        file_name = _get_file_name_with_version(qlib_version, dataset_version=version)\n        if not self.check_dataset(file_name):\n            file_name = _get_file_name_with_version(\"latest\", dataset_version=version)\n        self.download_data(file_name.lower(), target_dir, delete_old)\n"
  },
  {
    "path": "qlib/typehint.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"Commonly used types.\"\"\"\n\nimport sys\nfrom typing import Union\nfrom pathlib import Path\n\n__all__ = [\"Literal\", \"TypedDict\", \"final\"]\n\nif sys.version_info >= (3, 8):\n    from typing import Literal, TypedDict, final  # type: ignore  # pylint: disable=no-name-in-module\nelse:\n    from typing_extensions import Literal, TypedDict, final\n\n\nclass InstDictConf(TypedDict):\n    \"\"\"\n    InstDictConf  is a Dict-based config to describe an instance\n\n        case 1)\n        {\n            'class': 'ClassName',\n            'kwargs': dict, #  It is optional. {} will be used if not given\n            'model_path': path, # It is optional if module is given in the class\n        }\n        case 2)\n        {\n            'class': <The class it self>,\n            'kwargs': dict, #  It is optional. {} will be used if not given\n        }\n    \"\"\"\n\n    # class: str  # because class is a keyword of Python. We have to comment it\n    kwargs: dict  # It is optional. 
{} will be used if not given\n    module_path: str  # It is optional if module is given in the class\n\n\nInstConf = Union[InstDictConf, str, object, Path]\n\"\"\"\nInstConf is a type to describe an instance; it will be passed into init_instance_by_config for Qlib\n\n    config : Union[str, dict, object, Path]\n\n        InstDictConf example.\n            please refer to the docs of InstDictConf\n\n        str example.\n            1) specify a pickle object\n                - path like 'file:///<path to pickle file>/obj.pkl'\n            2) specify a class name\n                - \"ClassName\":  getattr(module, \"ClassName\")() will be used.\n            3) specify module path with class name\n                - \"a.b.c.ClassName\" getattr(<a.b.c.module>, \"ClassName\")() will be used.\n\n        object example:\n            instance of accept_types\n\n        Path example:\n            specify a pickle object\n                - it will be treated like 'file:///<path to pickle file>/obj.pkl'\n\"\"\"\n"
  },
  {
    "path": "qlib/utils/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n# TODO: this utils covers too much utilities, please seperat it into sub modules\n\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport re\nimport copy\nimport json\nimport redis\nimport bisect\nimport struct\nimport difflib\nimport inspect\nimport hashlib\nimport datetime\nimport requests\nimport collections\nimport numpy as np\nimport pandas as pd\nfrom pathlib import Path\nfrom typing import List, Union, Optional, Callable\nfrom packaging import version\nfrom ruamel.yaml import YAML\nfrom .file import (\n    get_or_create_path,\n    save_multiple_parts_file,\n    unpack_archive_with_buffer,\n    get_tmp_file_with_buffer,\n)\nfrom ..config import C\nfrom ..log import get_module_logger, set_log_with_config\n\nlog = get_module_logger(\"utils\")\n# MultiIndex.is_lexsorted() is a deprecated method in Pandas 1.3.0.\nis_deprecated_lexsorted_pandas = version.parse(pd.__version__) > version.parse(\"1.3.0\")\n\n\n#################### Server ####################\ndef get_redis_connection():\n    \"\"\"get redis connection instance.\"\"\"\n    return redis.StrictRedis(\n        host=C.redis_host,\n        port=C.redis_port,\n        db=C.redis_task_db,\n        password=C.redis_password,\n    )\n\n\n#################### Data ####################\ndef read_bin(file_path: Union[str, Path], start_index, end_index):\n    file_path = Path(file_path.expanduser().resolve())\n    with file_path.open(\"rb\") as f:\n        # read start_index\n        ref_start_index = int(np.frombuffer(f.read(4), dtype=\"<f\")[0])\n        si = max(ref_start_index, start_index)\n        if si > end_index:\n            return pd.Series(dtype=np.float32)\n        # calculate offset\n        f.seek(4 * (si - ref_start_index) + 4)\n        # read nbytes\n        count = end_index - si + 1\n        data = np.frombuffer(f.read(4 * count), dtype=\"<f\")\n        series = 
pd.Series(data, index=pd.RangeIndex(si, si + len(data)))\n    return series\n\n\ndef get_period_list(first: int, last: int, quarterly: bool) -> List[int]:\n    \"\"\"\n    This method is used in the PIT database.\n    It returns all the possible values between `first` and `last` (both `first` and `last` are included)\n\n    Parameters\n    ----------\n    quarterly : bool\n        whether to return quarterly indices or yearly indices.\n\n    Returns\n    -------\n    List[int]\n        the possible index between [first, last]\n    \"\"\"\n\n    if not quarterly:\n        assert all(1900 <= x <= 2099 for x in (first, last)), \"invalid arguments\"\n        return list(range(first, last + 1))\n    else:\n        assert all(190000 <= x <= 209904 for x in (first, last)), \"invalid arguments\"\n        res = []\n        for year in range(first // 100, last // 100 + 1):\n            for q in range(1, 5):\n                period = year * 100 + q\n                if first <= period <= last:\n                    res.append(year * 100 + q)\n        return res\n\n\ndef get_period_offset(first_year, period, quarterly):\n    if quarterly:\n        offset = (period // 100 - first_year) * 4 + period % 100 - 1\n    else:\n        offset = period - first_year\n    return offset\n\n\ndef read_period_data(\n    index_path,\n    data_path,\n    period,\n    cur_date_int: int,\n    quarterly,\n    last_period_index: int = None,\n):\n    \"\"\"\n    At `cur_date`(e.g. 20190102), read the information at `period`(e.g. 201803).\n    Only updates published on or before cur_date will be used.\n\n    Parameters\n    ----------\n    period: int\n        date period represented by integer, e.g. 201901 corresponds to the first quarter in 2019\n    cur_date_int: int\n        date represented by integer, e.g. 
20190102\n    last_period_index: int\n        it is an optional parameter; it is designed to avoid repeatedly accessing the .index data of the PIT database when\n        sequentially observing the data (because the latest index of a specific period of data certainly appears after the one in the last observation).\n\n    Returns\n    -------\n    the queried value and the byte index of the last record read\n    \"\"\"\n    DATA_DTYPE = \"\".join(\n        [\n            C.pit_record_type[\"date\"],\n            C.pit_record_type[\"period\"],\n            C.pit_record_type[\"value\"],\n            C.pit_record_type[\"index\"],\n        ]\n    )\n\n    PERIOD_DTYPE = C.pit_record_type[\"period\"]\n    INDEX_DTYPE = C.pit_record_type[\"index\"]\n\n    NAN_VALUE = C.pit_record_nan[\"value\"]\n    NAN_INDEX = C.pit_record_nan[\"index\"]\n\n    # find the first index of linked revisions\n    if last_period_index is None:\n        with open(index_path, \"rb\") as fi:\n            (first_year,) = struct.unpack(PERIOD_DTYPE, fi.read(struct.calcsize(PERIOD_DTYPE)))\n            all_periods = np.fromfile(fi, dtype=INDEX_DTYPE)\n        offset = get_period_offset(first_year, period, quarterly)\n        _next = all_periods[offset]\n    else:\n        _next = last_period_index\n\n    # load data following the `_next` link\n    prev_value = NAN_VALUE\n    prev_next = _next\n\n    with open(data_path, \"rb\") as fd:\n        while _next != NAN_INDEX:\n            fd.seek(_next)\n            date, period, value, new_next = struct.unpack(DATA_DTYPE, fd.read(struct.calcsize(DATA_DTYPE)))\n            if date > cur_date_int:\n                break\n            prev_next = _next\n            _next = new_next\n            prev_value = value\n    return prev_value, prev_next\n\n\ndef np_ffill(arr: np.array):\n    \"\"\"\n    forward fill a 1D numpy array\n\n    Parameters\n    ----------\n    arr : np.array\n        Input numpy 1D array\n    \"\"\"\n    mask = np.isnan(arr.astype(float))  # np.isnan only works 
on np.float\n    # get fill index\n    idx = np.where(~mask, np.arange(mask.shape[0]), 0)\n    np.maximum.accumulate(idx, out=idx)\n    return arr[idx]\n\n\n#################### Search ####################\ndef lower_bound(data, val, level=0):\n    \"\"\"multi fields list lower bound.\n\n    for single field list use `bisect.bisect_left` instead\n    \"\"\"\n    left = 0\n    right = len(data)\n    while left < right:\n        mid = (left + right) // 2\n        if val <= data[mid][level]:\n            right = mid\n        else:\n            left = mid + 1\n    return left\n\n\ndef upper_bound(data, val, level=0):\n    \"\"\"multi fields list upper bound.\n\n    for single field list use `bisect.bisect_right` instead\n    \"\"\"\n    left = 0\n    right = len(data)\n    while left < right:\n        mid = (left + right) // 2\n        if val >= data[mid][level]:\n            left = mid + 1\n        else:\n            right = mid\n    return left\n\n\n#################### HTTP ####################\ndef requests_with_retry(url, retry=5, **kwargs):\n    while retry > 0:\n        retry -= 1\n        try:\n            res = requests.get(url, timeout=1, **kwargs)\n            assert res.status_code in {200, 206}\n            return res\n        except AssertionError:\n            continue\n        except Exception as e:\n            log.warning(\"exception encountered {}\".format(e))\n            continue\n    raise TimeoutError(\"ERROR: requests failed!\")\n\n\n#################### Parse ####################\ndef parse_config(config):\n    # Only str configs need parsing; any other object is returned as-is\n    if not isinstance(config, str):\n        return config\n    # Check whether config is a file path\n    yaml = YAML(typ=\"safe\", pure=True)\n    if os.path.exists(config):\n        with open(config, \"r\") as f:\n            return yaml.load(f)\n    # Check whether the str can be parsed\n    try:\n        return yaml.load(config)\n    except BaseException as 
base_exp:\n        raise ValueError(\"cannot parse config!\") from base_exp\n\n\n#################### Other ####################\ndef drop_nan_by_y_index(x, y, weight=None):\n    # x, y, weight: DataFrame\n    # Keep only the rows in which no column of y is NaN.\n    mask = ~y.isna().any(axis=1)\n    # Get related rows from x, y, weight.\n    x = x[mask]\n    y = y[mask]\n    if weight is not None:\n        weight = weight[mask]\n    return x, y, weight\n\n\ndef hash_args(*args):\n    # json.dumps will keep the dict keys always sorted.\n    string = json.dumps(args, sort_keys=True, default=str)  # frozenset\n    return hashlib.md5(string.encode()).hexdigest()\n\n\ndef parse_field(field):\n    # Following patterns will be matched:\n    # - $close -> Feature(\"close\")\n    # - $close5 -> Feature(\"close5\")\n    # - $open+$close -> Feature(\"open\")+Feature(\"close\")\n    # TODO: this may be used in the future if we want to support the computation of different frequency data\n    # - $close@5min -> Feature(\"close\", \"5min\")\n\n    if not isinstance(field, str):\n        field = str(field)\n    # Chinese punctuation regex:\n    # \\u3001 -> 、\n    # \\uff1a -> ：\n    # \\uff08 -> (\n    # \\uff09 -> )\n    chinese_punctuation_regex = r\"\\u3001\\uff1a\\uff08\\uff09\"\n    for pattern, new in [\n        (\n            rf\"\\$\\$([\\w{chinese_punctuation_regex}]+)\",\n            r'PFeature(\"\\1\")',\n        ),  # $$ must be before $\n        (rf\"\\$([\\w{chinese_punctuation_regex}]+)\", r'Feature(\"\\1\")'),\n        (r\"(\\w+\\s*)\\(\", r\"Operators.\\1(\"),\n    ]:  # Features  # Operators\n        field = re.sub(pattern, new, field)\n    return field\n\n\ndef compare_dict_value(src_data: dict, dst_data: dict):\n    \"\"\"Compare dict value\n\n    :param src_data:\n    :param dst_data:\n    :return:\n    \"\"\"\n\n    class DateEncoder(json.JSONEncoder):\n        # FIXME: This class can only be accurate to the day. 
If it is a minute,\n        # there may be a bug\n        def default(self, o):\n            if isinstance(o, (datetime.datetime, datetime.date)):\n                return o.strftime(\"%Y-%m-%d %H:%M:%S\")\n            return json.JSONEncoder.default(self, o)\n\n    src_data = json.dumps(src_data, indent=4, sort_keys=True, cls=DateEncoder)\n    dst_data = json.dumps(dst_data, indent=4, sort_keys=True, cls=DateEncoder)\n    # diff line by line; passing the raw strings would make ndiff compare them character by character\n    diff = difflib.ndiff(src_data.splitlines(), dst_data.splitlines())\n    changes = [line for line in diff if line.startswith(\"+ \") or line.startswith(\"- \")]\n    return changes\n\n\ndef remove_repeat_field(fields):\n    \"\"\"remove repeated fields, keeping the order of their first occurrence\n\n    :param fields: list; features fields\n    :return: list\n    \"\"\"\n    fields = copy.deepcopy(fields)\n    _fields = set(fields)\n    return sorted(_fields, key=fields.index)\n\n\ndef remove_fields_space(fields: [list, str, tuple]):\n    \"\"\"remove spaces from fields\n\n    :param fields: features fields\n    :return: list or str\n    \"\"\"\n    if isinstance(fields, str):\n        return fields.replace(\" \", \"\")\n    return [i.replace(\" \", \"\") if isinstance(i, str) else str(i) for i in fields]\n\n\ndef normalize_cache_fields(fields: [list, tuple]):\n    \"\"\"normalize cache fields\n\n    :param fields: features fields\n    :return: list\n    \"\"\"\n    return sorted(remove_repeat_field(remove_fields_space(fields)))\n\n\ndef normalize_cache_instruments(instruments):\n    \"\"\"normalize cache instruments\n\n    :return: list or dict\n    \"\"\"\n    if isinstance(instruments, (list, tuple, pd.Index, np.ndarray)):\n        instruments = sorted(list(instruments))\n    else:\n        # dict type stockpool\n        if \"market\" in instruments:\n            pass\n        else:\n            instruments = {k: sorted(v) for k, v in instruments.items()}\n    return instruments\n\n\ndef is_tradable_date(cur_date):\n    \"\"\"judge whether `cur_date` is a trading date\n\n    Parameters\n    ----------\n    cur_date : pandas.Timestamp\n        current 
date\n    \"\"\"\n    from ..data import D  # pylint: disable=C0415\n\n    return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date())\n\n\ndef get_date_range(trading_date, left_shift=0, right_shift=0, future=False):\n    \"\"\"get the trading calendar between `trading_date` shifted by `left_shift` and by `right_shift`\n\n    Parameters\n    ----------\n    trading_date: pd.Timestamp\n    left_shift: int\n    right_shift: int\n    future: bool\n\n    \"\"\"\n\n    from ..data import D  # pylint: disable=C0415\n\n    start = get_date_by_shift(trading_date, left_shift, future=future)\n    end = get_date_by_shift(trading_date, right_shift, future=future)\n\n    calendar = D.calendar(start, end, future=future)\n    return calendar\n\n\ndef get_date_by_shift(\n    trading_date,\n    shift,\n    future=False,\n    clip_shift=True,\n    freq=\"day\",\n    align: Optional[str] = None,\n):\n    \"\"\"get the trading date that is `shift` trading days away from `trading_date`\n        e.g. : shift == 1,  return next trading date\n               shift == -1, return previous trading date\n\n    Parameters\n    ----------\n    trading_date : pandas.Timestamp\n        current date\n    shift : int\n    clip_shift: bool\n        if True, clip the shifted position into the calendar range; otherwise raise IndexError when it is out of range\n    align : Optional[str]\n        When align is None, this function will raise ValueError if `trading_date` is not a trading date.\n        When align is \"left\"/\"right\" and `trading_date` is not a trading date, it will first align `trading_date` to the nearest trading date on its left/right, and then shift\n\n    \"\"\"\n    from qlib.data import D  # pylint: disable=C0415\n\n    cal = D.calendar(future=future, freq=freq)\n    trading_date = pd.to_datetime(trading_date)\n    if align is None:\n        if trading_date not in list(cal):\n            raise ValueError(\"{} is not a trading day!\".format(str(trading_date)))\n        _index = bisect.bisect_left(cal, trading_date)\n    elif align == \"left\":\n        _index = bisect.bisect_right(cal, trading_date) - 1\n    elif align == \"right\":\n        _index = bisect.bisect_left(cal, trading_date)\n    else:\n  
      raise ValueError(f\"align with value `{align}` is not supported\")\n    shift_index = _index + shift\n    if shift_index < 0 or shift_index >= len(cal):\n        if clip_shift:\n            shift_index = np.clip(shift_index, 0, len(cal) - 1)\n        else:\n            raise IndexError(f\"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range\")\n    return cal[shift_index]\n\n\ndef get_next_trading_date(trading_date, future=False):\n    \"\"\"get next trading date\n\n    Parameters\n    ----------\n    trading_date : pandas.Timestamp\n        current date\n    \"\"\"\n    return get_date_by_shift(trading_date, 1, future=future)\n\n\ndef get_pre_trading_date(trading_date, future=False):\n    \"\"\"get previous trading date\n\n    Parameters\n    ----------\n    trading_date : pandas.Timestamp\n        current date\n    \"\"\"\n    return get_date_by_shift(trading_date, -1, future=future)\n\n\ndef transform_end_date(end_date=None, freq=\"day\"):\n    \"\"\"handle the end date given in various formats\n\n    If end_date is -1, None, or later than the last trading date, the last trading date is returned.\n    Otherwise, end_date is returned unchanged\n\n    Parameters\n    ----------\n    end_date: str\n        end trading date\n    freq : str\n        frequency of the calendar\n    \"\"\"\n    from ..data import D  # pylint: disable=C0415\n\n    last_date = D.calendar(freq=freq)[-1]\n    if end_date is None or (str(end_date) == \"-1\") or (pd.Timestamp(last_date) < pd.Timestamp(end_date)):\n        log.warning(\n            \"\\nInfo: the end_date in the configuration file is {}, \"\n            \"so the default last date {} is used.\".format(end_date, last_date)\n        )\n        end_date = last_date\n    return end_date\n\n\ndef get_date_in_file_name(file_name):\n    \"\"\"Get the date (YYYY-MM-DD) written in the file name\n\n    Parameters\n    ----------\n    file_name : str\n\n    Returns\n    -------\n    str\n        the date in 'YYYY-MM-DD' format\n    \"\"\"\n    pattern = \"[0-9]{4}-[0-9]{2}-[0-9]{2}\"\n    
date = re.search(pattern, str(file_name)).group()\n    return date\n\n\ndef split_pred(pred, number=None, split_date=None):\n    \"\"\"split the score file into two parts\n\n    Parameters\n    ----------\n        pred : pd.DataFrame (index:<instrument, datetime>)\n            A score file of stocks\n        number: the number of dates for pred_left\n        split_date: the last date of the pred_left\n\n    Returns\n    -------\n        pred_left : pd.DataFrame (index:<instrument, datetime>)\n            The first part of original score file\n        pred_right : pd.DataFrame (index:<instrument, datetime>)\n            The second part of original score file\n    \"\"\"\n    if number is None and split_date is None:\n        raise ValueError(\"`number` and `split_date` cannot both be None\")\n    dates = sorted(pred.index.get_level_values(\"datetime\").unique())\n    dates = list(map(pd.Timestamp, dates))\n    if split_date is None:\n        date_left_end = dates[number - 1]\n        date_right_begin = dates[number]\n        date_left_start = None\n    else:\n        split_date = pd.Timestamp(split_date)\n        date_left_end = split_date\n        date_right_begin = split_date + pd.Timedelta(days=1)\n        if number is None:\n            date_left_start = None\n        else:\n            end_idx = bisect.bisect_right(dates, split_date)\n            date_left_start = dates[end_idx - number]\n    pred_temp = pred.sort_index()\n    pred_left = pred_temp.loc(axis=0)[:, date_left_start:date_left_end]\n    pred_right = pred_temp.loc(axis=0)[:, date_right_begin:]\n    return pred_left, pred_right\n\n\ndef time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]:\n    \"\"\"\n    Time slicing in Qlib or Pandas is a frequently-used action.\n    However, users often represent time in all kinds of formats.\n    This function converts these inputs into a uniform format which is friendly to time slicing.\n\n    Parameters\n    
----------\n    t : Union[None, str, pd.Timestamp]\n        original time\n\n    Returns\n    -------\n    Union[None, pd.Timestamp]:\n    \"\"\"\n    if t is None:\n        # None represents unbounded in Qlib or Pandas(e.g. df.loc[slice(None, \"20210303\")]).\n        return t\n    else:\n        return pd.Timestamp(t)\n\n\ndef can_use_cache():\n    res = True\n    r = get_redis_connection()\n    try:\n        r.client()\n    except redis.exceptions.ConnectionError:\n        res = False\n    finally:\n        r.close()\n    return res\n\n\ndef exists_qlib_data(qlib_dir):\n    qlib_dir = Path(qlib_dir).expanduser()\n    if not qlib_dir.exists():\n        return False\n\n    calendars_dir = qlib_dir.joinpath(\"calendars\")\n    instruments_dir = qlib_dir.joinpath(\"instruments\")\n    features_dir = qlib_dir.joinpath(\"features\")\n    # check dir\n    for _dir in [calendars_dir, instruments_dir, features_dir]:\n        if not (_dir.exists() and list(_dir.iterdir())):\n            return False\n    # check calendar bin\n    for _calendar in calendars_dir.iterdir():\n        if (\"_future\" not in _calendar.name) and (\n            not list(features_dir.rglob(f\"*.{_calendar.name.split('.')[0]}.bin\"))\n        ):\n            return False\n\n    # check instruments\n    code_names = set(map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir()))\n    _instrument = instruments_dir.joinpath(\"all.txt\")\n    # Removed two possible ticker names \"NA\" and \"NULL\" from the default na_values list for column 0\n    miss_code = set(\n        pd.read_csv(\n            _instrument,\n            sep=\"\\t\",\n            header=None,\n            keep_default_na=False,\n            na_values={\n                0: [\n                    \" \",\n                    \"#N/A\",\n                    \"#N/A N/A\",\n                    \"#NA\",\n                    \"-1.#IND\",\n                    \"-1.#QNAN\",\n                    \"-NaN\",\n                    
\"-nan\",\n                    \"1.#IND\",\n                    \"1.#QNAN\",\n                    \"<NA>\",\n                    \"N/A\",\n                    \"NaN\",\n                    \"None\",\n                    \"n/a\",\n                    \"nan\",\n                    \"null \",\n                ]\n            },\n        )\n        .loc[:, 0]\n        .apply(str.lower)\n    ) - set(code_names)\n    if miss_code and any(map(lambda x: \"sht\" not in x, miss_code)):\n        return False\n\n    return True\n\n\ndef check_qlib_data(qlib_config):\n    inst_dir = Path(qlib_config[\"provider_uri\"]).joinpath(\"instruments\")\n    for _p in inst_dir.glob(\"*.txt\"):\n        assert len(pd.read_csv(_p, sep=\"\\t\", nrows=0, header=None).columns) == 3, (\n            f\"\\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:\"\n            f\"\\n\\tIf you are using the data provided by qlib: \"\n            f\"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset\"\n            f\"\\n\\tIf you are using your own data, please dump the data again: \"\n            f\"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format\"\n        )\n\n\ndef lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:\n    \"\"\"\n    make the df index sorted\n\n    df.sort_index() will take a lot of time even when `df.is_lexsorted() == True`\n    This function could avoid such case\n\n    Parameters\n    ----------\n    df : pd.DataFrame\n\n    Returns\n    -------\n    pd.DataFrame:\n        sorted dataframe\n    \"\"\"\n    idx = df.index if axis == 0 else df.columns\n    if (\n        not idx.is_monotonic_increasing\n        or not is_deprecated_lexsorted_pandas\n        and isinstance(idx, pd.MultiIndex)\n        and not idx.is_lexsorted()\n    ):  # this case is for the old version\n        return df.sort_index(axis=axis)\n    else:\n        return df\n\n\nFLATTEN_TUPLE = \"_FLATTEN_TUPLE\"\n\n\ndef 
flatten_dict(d, parent_key=\"\", sep=\".\") -> dict:\n    \"\"\"\n    Flatten a nested dict.\n\n        >>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]})\n        {'a': 1, 'c.a': 2, 'c.b.x': 5, 'c.b.y': 10, 'd': [1, 2, 3]}\n\n        >>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}, sep=FLATTEN_TUPLE)\n        {'a': 1, ('c', 'a'): 2, ('c', 'b', 'x'): 5, ('c', 'b', 'y'): 10, 'd': [1, 2, 3]}\n\n    Args:\n        d (dict): the dict to be flattened\n        parent_key (str, optional): the parent key, used as a prefix of the new keys. Defaults to \"\".\n        sep (str, optional): the separator used to join keys into strings. Use FLATTEN_TUPLE to join them into tuples instead.\n\n    Returns:\n        dict: the flattened dict\n    \"\"\"\n    items = []\n    for k, v in d.items():\n        if sep == FLATTEN_TUPLE:\n            if isinstance(parent_key, tuple):\n                # extend the tuple so that deeper keys stay flat, e.g. ('c', 'b', 'x')\n                new_key = parent_key + (k,)\n            else:\n                new_key = (parent_key, k) if parent_key else k\n        else:\n            new_key = parent_key + sep + k if parent_key else k\n        if isinstance(v, collections.abc.MutableMapping):\n            items.extend(flatten_dict(v, new_key, sep=sep).items())\n        else:\n            items.append((new_key, v))\n    return dict(items)\n\n\ndef get_item_from_obj(config: dict, name_path: str) -> object:\n    \"\"\"\n    Follow the name_path to get values from config\n    For example:\n    If we follow the example in the Parameters section,\n        Timestamp('2008-01-02 00:00:00') will be returned\n\n    Parameters\n    ----------\n    config : dict\n        e.g.\n        {'dataset': {'class': 'DatasetH',\n          'kwargs': {'handler': {'class': 'Alpha158',\n                                 'kwargs': {'end_time': '2020-08-01',\n                                            'fit_end_time': '<dataset.kwargs.segments.train.1>',\n                                            'fit_start_time': '<dataset.kwargs.segments.train.0>',\n                                            'instruments': 'csi100',\n                  
                          'start_time': '2008-01-01'},\n                                 'module_path': 'qlib.contrib.data.handler'},\n                     'segments': {'test': (Timestamp('2017-01-03 00:00:00'),\n                                           Timestamp('2019-04-08 00:00:00')),\n                                  'train': (Timestamp('2008-01-02 00:00:00'),\n                                            Timestamp('2014-12-31 00:00:00')),\n                                  'valid': (Timestamp('2015-01-05 00:00:00'),\n                                            Timestamp('2016-12-30 00:00:00'))}}\n        }}\n    name_path : str\n        e.g.\n        \"dataset.kwargs.segments.train.1\"\n\n    Returns\n    -------\n    object\n        the retrieved object\n    \"\"\"\n    cur_cfg = config\n    for k in name_path.split(\".\"):\n        if isinstance(cur_cfg, dict):\n            cur_cfg = cur_cfg[k]  # may raise KeyError\n        elif k.isdigit():\n            cur_cfg = cur_cfg[int(k)]  # may raise IndexError\n        else:\n            raise ValueError(f\"Error when getting {k} from cur_cfg\")\n    return cur_cfg\n\n\ndef fill_placeholder(config: dict, config_extend: dict):\n    \"\"\"\n    Detect placeholders in config and fill them with config_extend.\n    Each item of the dict must be a single value (int, str, etc.), a dict or a list. Tuples are not supported.\n    There are two types of variables:\n    - user-defined variables :\n        e.g. when config_extend is `{\"<MODEL>\": model, \"<DATASET>\": dataset}`, \"<MODEL>\" and \"<DATASET>\" in `config` will be replaced with `model` and `dataset`\n    - variables extracted from `config` :\n        e.g. 
the variables like \"<dataset.kwargs.segments.train.0>\" will be replaced with the corresponding values from `config`\n\n    Parameters\n    ----------\n    config : dict\n        the parameter dict to be filled (it is modified in place)\n    config_extend : dict\n        the values of all user-defined placeholders\n\n    Returns\n    -------\n    dict\n        the filled parameter dict\n    \"\"\"\n    # check the format of config_extend\n    for placeholder in config_extend.keys():\n        assert re.match(r\"<[^<>]+>\", placeholder)\n\n    # breadth-first traversal over the nested lists and dicts\n    top = 0\n    tail = 1\n    item_queue = [config]\n\n    def try_replace_placeholder(value):\n        if value in config_extend.keys():\n            value = config_extend[value]\n        else:\n            m = re.match(r\"<(?P<name_path>[^<>]+)>\", value)\n            if m is not None:\n                try:\n                    value = get_item_from_obj(config, m.groupdict()[\"name_path\"])\n                except (KeyError, ValueError, IndexError):\n                    get_module_logger(\"fill_placeholder\").info(\n                        f\"{value} looks like a placeholder, but it does not match any given value\"\n                    )\n        return value\n\n    item_keys = None\n    while top < tail:\n        now_item = item_queue[top]\n        top += 1\n        if isinstance(now_item, list):\n            item_keys = range(len(now_item))\n        elif isinstance(now_item, dict):\n            item_keys = now_item.keys()\n        for key in item_keys:  # noqa\n            if isinstance(now_item[key], (list, dict)):\n                item_queue.append(now_item[key])\n                tail += 1\n            elif isinstance(now_item[key], str):\n                # If it is a string, try to replace the placeholder with its value\n                now_item[key] = try_replace_placeholder(now_item[key])\n    return config\n\n\ndef auto_filter_kwargs(func: Callable, warning=True) -> Callable:\n    \"\"\"\n    this will work like a decorator function\n\n    The decorated function will ignore 
and give a warning when a keyword argument is not acceptable\n\n    For example, if you have a function `f` which may optionally consume the keyword `bar`,\n    then you can call it by `auto_filter_kwargs(f)(bar=3)`, which will automatically filter out\n    `bar` when f does not need bar\n\n    Parameters\n    ----------\n    func : Callable\n        The original function\n    warning : bool\n        whether to log a warning for each ignored keyword argument\n\n    Returns\n    -------\n    Callable:\n        the new callable function\n    \"\"\"\n\n    def _func(*args, **kwargs):\n        spec = inspect.getfullargspec(func)\n        new_kwargs = {}\n        for k, v in kwargs.items():\n            # drop `k` if `func` accepts neither variable keyword arguments (`**kwargs`) nor a named or keyword-only argument `k`\n            if spec.varkw is None and k not in spec.args and k not in spec.kwonlyargs:\n                if warning:\n                    log.warning(f\"The parameter `{k}` with value `{v}` is ignored.\")\n            else:\n                new_kwargs[k] = v\n        return func(*args, **new_kwargs)\n\n    return _func\n\n\n#################### Wrapper #####################\nclass Wrapper:\n    \"\"\"Wrapper class for anything that needs to be set up during qlib.init\"\"\"\n\n    def __init__(self):\n        self._provider = None\n\n    def register(self, provider):\n        self._provider = provider\n\n    def __repr__(self):\n        return \"{name}(provider={provider})\".format(name=self.__class__.__name__, provider=self._provider)\n\n    def __getattr__(self, key):\n        if self.__dict__.get(\"_provider\", None) is None:\n            raise AttributeError(\"Please run qlib.init() before using qlib\")\n        return getattr(self._provider, key)\n\n\ndef register_wrapper(wrapper, cls_or_obj, module_path=None):\n    \"\"\"register_wrapper\n\n    :param wrapper: A wrapper.\n    :param cls_or_obj:  A class or class name or object instance.\n    :param module_path: the module to look up `cls_or_obj` in when it is a class name.\n    \"\"\"\n    if isinstance(cls_or_obj, str):\n        module = get_module_by_module_path(module_path)\n        cls_or_obj = getattr(module, 
cls_or_obj)\n    obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj\n    wrapper.register(obj)\n\n\ndef load_dataset(path_or_obj, index_col=[0, 1]):\n    \"\"\"load a dataset from a DataFrame or from a .h5, .pkl or .csv file\"\"\"\n    if isinstance(path_or_obj, pd.DataFrame):\n        return path_or_obj\n    if not os.path.exists(path_or_obj):\n        raise ValueError(f\"file {path_or_obj} doesn't exist\")\n    _, extension = os.path.splitext(path_or_obj)\n    if extension == \".h5\":\n        return pd.read_hdf(path_or_obj)\n    elif extension == \".pkl\":\n        return pd.read_pickle(path_or_obj)\n    elif extension == \".csv\":\n        return pd.read_csv(path_or_obj, parse_dates=True, index_col=index_col)\n    raise ValueError(f\"unsupported file type `{extension}`\")\n\n\ndef code_to_fname(code: str):\n    \"\"\"stock code to file name\n\n    Parameters\n    ----------\n    code: str\n    \"\"\"\n    # NOTE: On Windows, the following names are reserved for I/O devices, so files with these names cannot be created\n    # reference: https://superuser.com/questions/86999/why-cant-i-name-a-folder-or-file-con-in-windows\n    replace_names = [\"CON\", \"PRN\", \"AUX\", \"NUL\"]\n    replace_names += [f\"COM{i}\" for i in range(10)]\n    replace_names += [f\"LPT{i}\" for i in range(10)]\n\n    prefix = \"_qlib_\"\n    if str(code).upper() in replace_names:\n        code = prefix + str(code)\n\n    return code\n\n\ndef fname_to_code(fname: str):\n    \"\"\"file name to stock code\n\n    Parameters\n    ----------\n    fname: str\n    \"\"\"\n\n    prefix = \"_qlib_\"\n    if fname.startswith(prefix):\n        # remove the prefix itself; `str.lstrip` removes a set of characters and would turn \"_qlib_lpt1\" into \"pt1\"\n        fname = fname[len(prefix):]\n    return fname\n\n\nfrom .mod import (\n    get_module_by_module_path,\n    split_module_path,\n    get_callable_kwargs,\n    get_cls_kwargs,\n    init_instance_by_config,\n    class_casting,\n)\n\n__all__ = [\n    \"get_or_create_path\",\n    \"save_multiple_parts_file\",\n    \"unpack_archive_with_buffer\",\n    
\"get_tmp_file_with_buffer\",\n    \"set_log_with_config\",\n    \"init_instance_by_config\",\n    \"get_module_by_module_path\",\n    \"split_module_path\",\n    \"get_callable_kwargs\",\n    \"get_cls_kwargs\",\n    \"init_instance_by_config\",\n    \"class_casting\",\n]\n"
  },
  {
    "path": "qlib/utils/data.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nThis module covers some utility functions that operate on data or basic object\n\"\"\"\n\nfrom copy import deepcopy\nfrom typing import List, Union\n\nimport numpy as np\nimport pandas as pd\n\nfrom qlib.data.data import DatasetProvider\n\n\ndef robust_zscore(x: pd.Series, zscore=False):\n    \"\"\"Robust ZScore Normalization\n\n    Use robust statistics for Z-Score normalization:\n        mean(x) = median(x)\n        std(x) = MAD(x) * 1.4826\n\n    Reference:\n        https://en.wikipedia.org/wiki/Median_absolute_deviation.\n    \"\"\"\n    x = x - x.median()\n    mad = x.abs().median()\n    x = np.clip(x / mad / 1.4826, -3, 3)\n    if zscore:\n        x -= x.mean()\n        x /= x.std()\n    return x\n\n\ndef zscore(x: Union[pd.Series, pd.DataFrame]):\n    return (x - x.mean()).div(x.std())\n\n\ndef deepcopy_basic_type(obj: object) -> object:\n    \"\"\"\n    deepcopy an object without copy the complicated objects.\n        This is useful when you want to generate Qlib tasks and share the handler\n\n    NOTE:\n    - This function can't handle recursive objects!!!!!\n\n    Parameters\n    ----------\n    obj : object\n        the object to be copied\n\n    Returns\n    -------\n    object:\n        The copied object\n    \"\"\"\n    if isinstance(obj, tuple):\n        return tuple(deepcopy_basic_type(i) for i in obj)\n    elif isinstance(obj, list):\n        return list(deepcopy_basic_type(i) for i in obj)\n    elif isinstance(obj, dict):\n        return {k: deepcopy_basic_type(v) for k, v in obj.items()}\n    else:\n        return obj\n\n\nS_DROP = \"__DROP__\"  # this is a symbol which indicates drop the value\n\n\ndef update_config(base_config: dict, ext_config: Union[dict, List[dict]]):\n    \"\"\"\n    supporting adding base config based on the ext_config\n\n    >>> bc = {\"a\": \"xixi\"}\n    >>> ec = {\"b\": \"haha\"}\n    >>> new_bc = update_config(bc, ec)\n  
  >>> print(new_bc)\n    {'a': 'xixi', 'b': 'haha'}\n    >>> print(bc)  # base config should not be changed\n    {'a': 'xixi'}\n    >>> print(update_config(bc, {\"b\": S_DROP}))\n    {'a': 'xixi'}\n    >>> print(update_config(new_bc, {\"b\": S_DROP}))\n    {'a': 'xixi'}\n    \"\"\"\n\n    base_config = deepcopy(base_config)  # avoid modifying the caller's base config\n\n    for ec in ext_config if isinstance(ext_config, (list, tuple)) else [ext_config]:\n        for key in ec:\n            if key not in base_config:\n                # the key is not in the base config: add it\n                # ADD if not drop\n                if ec[key] != S_DROP:\n                    base_config[key] = ec[key]\n\n            else:\n                if isinstance(base_config[key], dict) and isinstance(ec[key], dict):\n                    # Recursive\n                    # both values are dicts: update the nested dict\n                    base_config[key] = update_config(base_config[key], ec[key])\n                elif ec[key] == S_DROP:\n                    # DROP\n                    del base_config[key]\n                else:\n                    # REPLACE\n                    # at least one of them is not a dict: replace\n                    base_config[key] = ec[key]\n    return base_config\n\n\ndef guess_horizon(label: List):\n    \"\"\"\n    Try to guess the horizon by parsing the label expression (the right-side extended window size of the first label)\n    \"\"\"\n    expr = DatasetProvider.parse_fields(label)[0]\n    lft_etd, rght_etd = expr.get_extended_window_size()\n    return rght_etd\n"
  },
  {
    "path": "qlib/utils/exceptions.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\n# Base exception class\nclass QlibException(Exception):\n    pass\n\n\nclass RecorderInitializationError(QlibException):\n    \"\"\"Error type for re-initialization when starting an experiment\"\"\"\n\n\nclass LoadObjectError(QlibException):\n    \"\"\"Error type for Recorder when can not load object\"\"\"\n\n\nclass ExpAlreadyExistError(Exception):\n    \"\"\"Experiment already exists\"\"\"\n"
  },
  {
    "path": "qlib/utils/file.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport shutil\nimport tempfile\nimport contextlib\nfrom typing import Optional, Text, IO, Union\nfrom pathlib import Path\n\nfrom qlib.log import get_module_logger\n\nlog = get_module_logger(\"utils.file\")\n\n\ndef get_or_create_path(path: Optional[Text] = None, return_dir: bool = False):\n    \"\"\"Create or get a file or directory given the path and return_dir.\n\n    Parameters\n    ----------\n    path: a string indicates the path or None indicates creating a temporary path.\n    return_dir: if True, create and return a directory; otherwise c&r a file.\n\n    \"\"\"\n    if path:\n        if return_dir and not os.path.exists(path):\n            os.makedirs(path)\n        elif not return_dir:  # return a file, thus we need to create its parent directory\n            xpath = os.path.abspath(os.path.join(path, \"..\"))\n            if not os.path.exists(xpath):\n                os.makedirs(xpath)\n    else:\n        temp_dir = os.path.expanduser(\"~/tmp\")\n        if not os.path.exists(temp_dir):\n            os.makedirs(temp_dir)\n        if return_dir:\n            _, path = tempfile.mkdtemp(dir=temp_dir)\n        else:\n            _, path = tempfile.mkstemp(dir=temp_dir)\n    return path\n\n\n@contextlib.contextmanager\ndef save_multiple_parts_file(filename, format=\"gztar\"):\n    \"\"\"Save multiple parts file\n\n    Implementation process:\n        1. get the absolute path to 'filename'\n        2. create a 'filename' directory\n        3. user does something with file_path('filename/')\n        4. remove 'filename' directory\n        5. 
make_archive 'filename' directory, and rename 'archive file' to filename\n\n    :param filename: result model path\n    :param format: archive format: one of \"zip\", \"tar\", \"gztar\", \"bztar\", or \"xztar\"\n    :return: real model path\n\n    Usage::\n\n        >>> # The following code will create an archive file('~/tmp/test_file') containing 'test_doc_i'(i is 0-10) files.\n        >>> with save_multiple_parts_file('~/tmp/test_file') as filename_dir:\n        ...   for i in range(10):\n        ...       temp_path = os.path.join(filename_dir, 'test_doc_{}'.format(str(i)))\n        ...       with open(temp_path) as fp:\n        ...           fp.write(str(i))\n        ...\n\n    \"\"\"\n\n    if filename.startswith(\"~\"):\n        filename = os.path.expanduser(filename)\n\n    file_path = os.path.abspath(filename)\n\n    # Create model dir\n    if os.path.exists(file_path):\n        raise FileExistsError(\"ERROR: file exists: {}, cannot be create the directory.\".format(file_path))\n\n    os.makedirs(file_path)\n\n    # return model dir\n    yield file_path\n\n    # filename dir to filename.tar.gz file\n    tar_file = shutil.make_archive(file_path, format=format, root_dir=file_path)\n\n    # Remove filename dir\n    if os.path.exists(file_path):\n        shutil.rmtree(file_path)\n\n    # filename.tar.gz rename to filename\n    os.rename(tar_file, file_path)\n\n\n@contextlib.contextmanager\ndef unpack_archive_with_buffer(buffer, format=\"gztar\"):\n    \"\"\"Unpack archive with archive buffer\n    After the call is finished, the archive file and directory will be deleted.\n\n    Implementation process:\n        1. create 'tempfile' in '~/tmp/' and directory\n        2. 'buffer' write to 'tempfile'\n        3. unpack archive file('tempfile')\n        4. user does something with file_path('tempfile/')\n        5. 
remove 'tempfile' and 'tempfile directory'\n\n    :param buffer: bytes\n    :param format: archive format: one of \"zip\", \"tar\", \"gztar\", \"bztar\", or \"xztar\"\n    :return: unpack archive directory path\n\n    Usage::\n\n        >>> # The following code is to print all the file names in 'test_unpack.tar.gz'\n        >>> with open('test_unpack.tar.gz') as fp:\n        ...     buffer = fp.read()\n        ...\n        >>> with unpack_archive_with_buffer(buffer) as temp_dir:\n        ...     for f_n in os.listdir(temp_dir):\n        ...         print(f_n)\n        ...\n\n    \"\"\"\n    temp_dir = os.path.expanduser(\"~/tmp\")\n    if not os.path.exists(temp_dir):\n        os.makedirs(temp_dir)\n    with tempfile.NamedTemporaryFile(\"wb\", delete=False, dir=temp_dir) as fp:\n        fp.write(buffer)\n        file_path = fp.name\n\n    try:\n        tar_file = file_path + \".tar.gz\"\n        os.rename(file_path, tar_file)\n        # Create dir\n        os.makedirs(file_path)\n        shutil.unpack_archive(tar_file, format=format, extract_dir=file_path)\n\n        # Return temp dir\n        yield file_path\n\n    except Exception as e:\n        log.error(str(e))\n    finally:\n        # Remove temp tar file\n        if os.path.exists(tar_file):\n            os.unlink(tar_file)\n\n        # Remove temp model dir\n        if os.path.exists(file_path):\n            shutil.rmtree(file_path)\n\n\n@contextlib.contextmanager\ndef get_tmp_file_with_buffer(buffer):\n    temp_dir = os.path.expanduser(\"~/tmp\")\n    if not os.path.exists(temp_dir):\n        os.makedirs(temp_dir)\n    with tempfile.NamedTemporaryFile(\"wb\", delete=True, dir=temp_dir) as fp:\n        fp.write(buffer)\n        file_path = fp.name\n        yield file_path\n\n\n@contextlib.contextmanager\ndef get_io_object(file: Union[IO, str, Path], *args, **kwargs) -> IO:\n    \"\"\"\n    providing a easy interface to get an IO object\n\n    Parameters\n    ----------\n    file : Union[IO, str, Path]\n      
  a object representing the file\n\n    Returns\n    -------\n    IO:\n        a IO-like object\n\n    Raises\n    ------\n    NotImplementedError:\n    \"\"\"\n    if isinstance(file, IO):\n        yield file\n    else:\n        if isinstance(file, str):\n            file = Path(file)\n        if not isinstance(file, Path):\n            raise NotImplementedError(f\"This type[{type(file)}] of input is not supported\")\n        with file.open(*args, **kwargs) as f:\n            yield f\n"
  },
  {
    "path": "qlib/utils/index_data.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nMotivation of index_data\n- Pandas has a lot of user-friendly interfaces. However, integrating too many features into a single tool brings a lot of overhead and makes it much slower than numpy.\n    Some users just want a simple numpy dataframe with indices and don't want such a complicated tool.\n    Such users are the target of `index_data`\n\n`index_data` tries to behave like pandas (some APIs differ because we try to be simpler and more intuitive) without compromising performance. It provides the basic numpy data and a simple indexing feature. If users call APIs that may compromise performance, index_data will raise errors.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Dict, Tuple, Union, Callable, List\nimport bisect\n\nimport numpy as np\nimport pandas as pd\n\n\ndef concat(data_list: List[SingleData], axis=0) -> MultiData:\n    \"\"\"concat all SingleData by index.\n    TODO: now just for SingleData.\n\n    Parameters\n    ----------\n    data_list : List[SingleData]\n        the list of all SingleData to concat.\n    axis : int\n        the axis along which to concat; only axis == 1 is implemented.\n\n    Returns\n    -------\n    MultiData\n        the MultiData with ndim == 2\n    \"\"\"\n    if axis == 0:\n        raise NotImplementedError(f\"please implement this func when axis == 0\")\n    elif axis == 1:\n        # get all index and row\n        all_index = set()\n        for index_data in data_list:\n            all_index = all_index | set(index_data.index)\n        all_index = list(all_index)\n        all_index.sort()\n        all_index_map = dict(zip(all_index, range(len(all_index))))\n\n        # concat all\n        tmp_data = np.full((len(all_index), len(data_list)), np.nan)\n        for data_id, index_data in enumerate(data_list):\n            assert isinstance(index_data, SingleData)\n            now_data_map = [all_index_map[index] for index in index_data.index]\n            tmp_data[now_data_map, 
data_id] = index_data.data\n        return MultiData(tmp_data, all_index)\n    else:\n        raise ValueError(f\"axis must be 0 or 1\")\n\n\ndef sum_by_index(data_list: List[SingleData], new_index: list, fill_value=0) -> SingleData:\n    \"\"\"sum all SingleData by new index.\n\n    Parameters\n    ----------\n    data_list : List[SingleData]\n        the list of all SingleData to sum.\n    new_index : list\n        the index of the new SingleData.\n    fill_value : float\n        the value used for missing entries and in place of np.nan.\n\n    Returns\n    -------\n    SingleData\n        the SingleData with new_index and values after sum.\n    \"\"\"\n    data_list = [data.to_dict() for data in data_list]\n    data_sum = {}\n    for id in new_index:\n        item_sum = 0\n        for data in data_list:\n            if id in data and not np.isnan(data[id]):\n                item_sum += data[id]\n            else:\n                item_sum += fill_value\n        data_sum[id] = item_sum\n    return SingleData(data_sum)\n\n\nclass Index:\n    \"\"\"\n    This is for indexing (rows or columns)\n\n    Read-only operations have higher priority than others.\n    So this class is designed in a **read-only** way to share data among queries.\n    Modifications will result in a new Index.\n\n    NOTE: the indexing has the following flaws\n    - duplicated index values are not well supported (only the first appearance will be considered)\n    - The order of the index is not considered!!!! 
So the slicing will not behave like pandas when indexings are ordered\n    \"\"\"\n\n    def __init__(self, idx_list: Union[List, pd.Index, \"Index\", int]):\n        self.idx_list: np.ndarray = None  # using array type for index list will make things easier\n        if isinstance(idx_list, Index):\n            # Fast read-only copy\n            self.idx_list = idx_list.idx_list\n            self.index_map = idx_list.index_map\n            self._is_sorted = idx_list._is_sorted\n        elif isinstance(idx_list, int):\n            self.index_map = self.idx_list = np.arange(idx_list)\n            self._is_sorted = True\n        else:\n            # Check if all elements in idx_list are of the same type\n            if not all(isinstance(x, type(idx_list[0])) for x in idx_list):\n                raise TypeError(\"All elements in idx_list must be of the same type\")\n            # Check if all elements in idx_list are of the same datetime64 precision\n            if isinstance(idx_list[0], np.datetime64) and not all(x.dtype == idx_list[0].dtype for x in idx_list):\n                raise TypeError(\"All elements in idx_list must be of the same datetime64 precision\")\n            self.idx_list = np.array(idx_list)\n            # NOTE: only the first appearance is indexed\n            self.index_map = dict(zip(self.idx_list, range(len(self))))\n            self._is_sorted = False\n\n    def __getitem__(self, i: int):\n        return self.idx_list[i]\n\n    def _convert_type(self, item):\n        \"\"\"\n\n        After a user creates indices with type A, the user may query data with other types that carry the same info.\n            This method tries to convert the type so that the query stays sane rather than strictly raising KeyError\n\n        Parameters\n        ----------\n        item :\n            The item to query index\n        \"\"\"\n\n        if self.idx_list.dtype.type is np.datetime64:\n            if isinstance(item, pd.Timestamp):\n                # This happens often when 
creating index based on pandas.DatetimeIndex and query with pd.Timestamp\n                return item.to_numpy().astype(self.idx_list.dtype)\n            elif isinstance(item, np.datetime64):\n                # This happens often when creating index based on np.datetime64 and query with another precision\n                return item.astype(self.idx_list.dtype)\n            # NOTE: It is hard to consider every case at first.\n            # We just try to cover part of cases to make it more user-friendly\n        return item\n\n    def index(self, item) -> int:\n        \"\"\"\n        Given the index value, get the integer index\n\n        Parameters\n        ----------\n        item :\n            The item to query\n\n        Returns\n        -------\n        int:\n            The index of the item\n\n        Raises\n        ------\n        KeyError:\n            If the query item does not exist\n        \"\"\"\n        try:\n            return self.index_map[self._convert_type(item)]\n        except IndexError as index_e:\n            raise KeyError(f\"{item} can't be found in {self}\") from index_e\n\n    def __or__(self, other: \"Index\"):\n        return Index(idx_list=list(set(self.idx_list) | set(other.idx_list)))\n\n    def __eq__(self, other: \"Index\"):\n        # NOTE:  np.nan is not supported in the index\n        if self.idx_list.shape != other.idx_list.shape:\n            return False\n        return (self.idx_list == other.idx_list).all()\n\n    def __len__(self):\n        return len(self.idx_list)\n\n    def is_sorted(self):\n        return self._is_sorted\n\n    def sort(self) -> Tuple[\"Index\", np.ndarray]:\n        \"\"\"\n        sort the index\n\n        Returns\n        -------\n        Tuple[\"Index\", np.ndarray]:\n            the sorted Index and the changed index\n        \"\"\"\n        sorted_idx = np.argsort(self.idx_list)\n        idx = Index(self.idx_list[sorted_idx])\n        idx._is_sorted = True\n        return idx, sorted_idx\n\n  
  def tolist(self):\n        \"\"\"return the index with the format of list.\"\"\"\n        return self.idx_list.tolist()\n\n\nclass LocIndexer:\n    \"\"\"\n    `Indexer` will behave like the `LocIndexer` in Pandas\n\n    Read-only operations has higher priorities than others.\n    So this class is designed in a read-only way to shared data for queries.\n    Modifications will results in new Index.\n    \"\"\"\n\n    def __init__(self, index_data: \"IndexData\", indices: List[Index], int_loc: bool = False):\n        self._indices: List[Index] = indices\n        self._bind_id = index_data  # bind index data\n        self._int_loc = int_loc\n        assert self._bind_id.data.ndim == len(self._indices)\n\n    @staticmethod\n    def proc_idx_l(indices: List[Union[List, pd.Index, Index]], data_shape: Tuple = None) -> List[Index]:\n        \"\"\"process the indices from user and output a list of `Index`\"\"\"\n        res = []\n        for i, idx in enumerate(indices):\n            res.append(Index(data_shape[i] if len(idx) == 0 else idx))\n        return res\n\n    def _slc_convert(self, index: Index, indexing: slice) -> slice:\n        \"\"\"\n        convert value-based indexing to integer-based indexing.\n\n        Parameters\n        ----------\n        index : Index\n            index data.\n        indexing : slice\n            value based indexing data with slice type for indexing.\n\n        Returns\n        -------\n        slice:\n            the integer based slicing\n        \"\"\"\n        if index.is_sorted():\n            int_start = None if indexing.start is None else bisect.bisect_left(index, indexing.start)\n            int_stop = None if indexing.stop is None else bisect.bisect_right(index, indexing.stop)\n        else:\n            int_start = None if indexing.start is None else index.index(indexing.start)\n            int_stop = None if indexing.stop is None else index.index(indexing.stop) + 1\n        return slice(int_start, int_stop)\n\n    def 
__getitem__(self, indexing):\n        \"\"\"\n\n        Parameters\n        ----------\n        indexing :\n            query for data\n\n        Raises\n        ------\n        KeyError:\n            If the non-slice index is queried but does not exist, `KeyError` is raised.\n        \"\"\"\n        # 1) convert slices to int loc\n        if not isinstance(indexing, tuple):\n            # NOTE: tuple is not supported for indexing\n            indexing = (indexing,)\n\n        # TODO: create a subclass for single value query\n        assert len(indexing) <= len(self._indices)\n\n        int_indexing = []\n        for dim, index in enumerate(self._indices):\n            if dim < len(indexing):\n                _indexing = indexing[dim]\n                if not self._int_loc:  # type converting is only necessary when it is not `iloc`\n                    if isinstance(_indexing, slice):\n                        _indexing = self._slc_convert(index, _indexing)\n                    elif isinstance(_indexing, (IndexData, np.ndarray)):\n                        if isinstance(_indexing, IndexData):\n                            _indexing = _indexing.data\n                        assert _indexing.ndim == 1\n                        if _indexing.dtype != bool:\n                            _indexing = np.array(list(index.index(i) for i in _indexing))\n                    else:\n                        _indexing = index.index(_indexing)\n            else:\n                # Default to select all when user input is not given\n                _indexing = slice(None)\n            int_indexing.append(_indexing)\n\n        # 2) select data and index\n        new_data = self._bind_id.data[tuple(int_indexing)]\n        # return directly if it is scalar\n        if new_data.ndim == 0:\n            return new_data\n        # otherwise we go on to the index part\n        new_indices = [idx[indexing] for idx, indexing in zip(self._indices, int_indexing)]\n\n        # 3) squash dimensions\n   
     new_indices = [\n            idx for idx in new_indices if isinstance(idx, np.ndarray) and idx.ndim > 0\n        ]  # squash the zero dim indexing\n\n        if new_data.ndim == 1:\n            cls = SingleData\n        elif new_data.ndim == 2:\n            cls = MultiData\n        else:\n            raise ValueError(\"Not supported\")\n        return cls(new_data, *new_indices)\n\n\nclass BinaryOps:\n    def __init__(self, method_name):\n        self.method_name = method_name\n\n    def __get__(self, obj, *args):\n        # bind object\n        self.obj = obj\n        return self\n\n    def __call__(self, other):\n        self_data_method = getattr(self.obj.data, self.method_name)\n\n        if isinstance(other, (int, float, np.number)):\n            return self.obj.__class__(self_data_method(other), *self.obj.indices)\n        elif isinstance(other, self.obj.__class__):\n            other_aligned = self.obj._align_indices(other)\n            return self.obj.__class__(self_data_method(other_aligned.data), *self.obj.indices)\n        else:\n            return NotImplemented\n\n\ndef index_data_ops_creator(*args, **kwargs):\n    \"\"\"\n    meta class for auto generating operations for index data.\n    \"\"\"\n    for method_name in [\"__add__\", \"__sub__\", \"__rsub__\", \"__mul__\", \"__truediv__\", \"__eq__\", \"__gt__\", \"__lt__\"]:\n        args[2][method_name] = BinaryOps(method_name=method_name)\n    return type(*args)\n\n\nclass IndexData(metaclass=index_data_ops_creator):\n    \"\"\"\n    Base data structure of SingleData and MultiData.\n\n    NOTE:\n    - For performance issue, only **np.floating** is supported in the underlayer data !!!\n    - Boolean based on np.floating is also supported. Here are some examples\n\n    .. code-block:: python\n\n        np.array([ np.nan]).any() -> True\n        np.array([ np.nan]).all() -> True\n        np.array([1. , 0.]).any() -> True\n        np.array([1. 
, 0.]).all() -> False\n    \"\"\"\n\n    loc_idx_cls = LocIndexer\n\n    def __init__(self, data: np.ndarray, *indices: Union[List, pd.Index, Index]):\n        self.data = data\n        self.indices = indices\n\n        # get the expected data shape\n        # - The index has higher priority\n        self.data = np.array(data)\n\n        expected_dim = max(self.data.ndim, len(indices))\n\n        data_shape = []\n        for i in range(expected_dim):\n            idx_l = indices[i] if len(indices) > i else []\n            if len(idx_l) == 0:\n                data_shape.append(self.data.shape[i])\n            else:\n                data_shape.append(len(idx_l))\n        data_shape = tuple(data_shape)\n\n        # broadcast the data to expected shape\n        if self.data.shape != data_shape:\n            self.data = np.broadcast_to(self.data, data_shape)\n\n        self.data = self.data.astype(np.float64)\n        # Please notice following cases when converting the type\n        # - np.array([None, 1]).astype(np.float64) -> array([nan,  1.])\n\n        # create index from user's index data.\n        self.indices: List[Index] = self.loc_idx_cls.proc_idx_l(indices, data_shape)\n\n        for dim in range(expected_dim):\n            assert self.data.shape[dim] == len(self.indices[dim])\n\n        self.ndim = expected_dim\n\n    # indexing related methods\n    @property\n    def loc(self):\n        return self.loc_idx_cls(index_data=self, indices=self.indices)\n\n    @property\n    def iloc(self):\n        return self.loc_idx_cls(index_data=self, indices=self.indices, int_loc=True)\n\n    @property\n    def index(self):\n        return self.indices[0]\n\n    @property\n    def columns(self):\n        return self.indices[1]\n\n    def __getitem__(self, args):\n        # NOTE: this tries to behave like a numpy array to be compatible with numpy aggregating function like nansum and nanmean\n        return self.iloc[args]\n\n    def _align_indices(self, other: \"IndexData\") 
-> \"IndexData\":\n        \"\"\"\n        Align all indices of `other` to `self` before performing the arithmetic operations.\n        This function will return a new IndexData rather than changing data in `other` inplace\n\n        Parameters\n        ----------\n        other : \"IndexData\"\n            the index in `other` is to be changed\n\n        Returns\n        -------\n        IndexData:\n            the data in `other` with index aligned to `self`\n        \"\"\"\n        raise NotImplementedError(f\"please implement _align_indices func\")\n\n    def sort_index(self, axis=0, inplace=True):\n        assert inplace, \"Only support sorting inplace now\"\n        self.indices[axis], sorted_idx = self.indices[axis].sort()\n        self.data = np.take(self.data, sorted_idx, axis=axis)\n\n    # The code below could be simpler like methods in __getattribute__\n    def __invert__(self):\n        return self.__class__(~self.data.astype(bool), *self.indices)\n\n    def abs(self):\n        \"\"\"get the abs of data except np.nan.\"\"\"\n        tmp_data = np.absolute(self.data)\n        return self.__class__(tmp_data, *self.indices)\n\n    def replace(self, to_replace: Dict[np.number, np.number]):\n        assert isinstance(to_replace, dict)\n        tmp_data = self.data.copy()\n        for num in to_replace:\n            if num in tmp_data:\n                tmp_data[self.data == num] = to_replace[num]\n        return self.__class__(tmp_data, *self.indices)\n\n    def apply(self, func: Callable):\n        \"\"\"apply a function to data.\"\"\"\n        tmp_data = func(self.data)\n        return self.__class__(tmp_data, *self.indices)\n\n    def __len__(self):\n        \"\"\"the length of the data.\n\n        Returns\n        -------\n        int\n            the length of the data.\n        \"\"\"\n        return len(self.data)\n\n    def sum(self, axis=None, dtype=None, out=None):\n        assert out is None and dtype is None, \"`out` is just for compatible with 
numpy's aggregating function\"\n        # FIXME: weird logic and not general\n        if axis is None:\n            return np.nansum(self.data)\n        elif axis == 0:\n            tmp_data = np.nansum(self.data, axis=0)\n            return SingleData(tmp_data, self.columns)\n        elif axis == 1:\n            tmp_data = np.nansum(self.data, axis=1)\n            return SingleData(tmp_data, self.index)\n        else:\n            raise ValueError(f\"axis must be None, 0 or 1\")\n\n    def mean(self, axis=None, dtype=None, out=None):\n        assert out is None and dtype is None, \"`out` is just for compatible with numpy's aggregating function\"\n        # FIXME: weird logic and not general\n        if axis is None:\n            return np.nanmean(self.data)\n        elif axis == 0:\n            tmp_data = np.nanmean(self.data, axis=0)\n            return SingleData(tmp_data, self.columns)\n        elif axis == 1:\n            tmp_data = np.nanmean(self.data, axis=1)\n            return SingleData(tmp_data, self.index)\n        else:\n            raise ValueError(f\"axis must be None, 0 or 1\")\n\n    def isna(self):\n        return self.__class__(np.isnan(self.data), *self.indices)\n\n    def fillna(self, value=0.0, inplace: bool = False):\n        if inplace:\n            self.data = np.nan_to_num(self.data, nan=value)\n        else:\n            return self.__class__(np.nan_to_num(self.data, nan=value), *self.indices)\n\n    def count(self):\n        return len(self.data[~np.isnan(self.data)])\n\n    def all(self):\n        if None in self.data:\n            return self.data[self.data is not None].all()\n        else:\n            return self.data.all()\n\n    @property\n    def empty(self):\n        return len(self.data) == 0\n\n    @property\n    def values(self):\n        return self.data\n\n\nclass SingleData(IndexData):\n    def __init__(\n        self, data: Union[int, float, np.number, list, dict, pd.Series] = [], index: Union[List, pd.Index, Index] = 
[]\n    ):\n        \"\"\"A data structure of index and numpy data.\n        It's used to replace pd.Series due to high-speed.\n\n        Parameters\n        ----------\n        data : Union[int, float, np.number, list, dict, pd.Series]\n            the input data\n        index : Union[list, pd.Index]\n            the index of data.\n            empty list indicates that auto filling the index to the length of data\n        \"\"\"\n        # for special data type\n        if isinstance(data, dict):\n            assert len(index) == 0\n            if len(data) > 0:\n                index, data = zip(*data.items())\n            else:\n                index, data = [], []\n        elif isinstance(data, pd.Series):\n            assert len(index) == 0\n            index, data = data.index, data.values\n        elif isinstance(data, (int, float, np.number)):\n            data = [data]\n        super().__init__(data, index)\n        assert self.ndim == 1\n\n    def _align_indices(self, other):\n        if self.index == other.index:\n            return other\n        elif set(self.index) == set(other.index):\n            return other.reindex(self.index)\n        else:\n            raise ValueError(\n                f\"The indexes of self and other do not meet the requirements of the four arithmetic operations\"\n            )\n\n    def reindex(self, index: Index, fill_value=np.nan) -> SingleData:\n        \"\"\"reindex data and fill the missing value with np.nan.\n\n        Parameters\n        ----------\n        new_index : list\n            new index\n        fill_value:\n            what value to fill if index is missing\n\n        Returns\n        -------\n        SingleData\n            reindex data\n        \"\"\"\n        # TODO: This method can be more general\n        if self.index == index:\n            return self\n        tmp_data = np.full(len(index), fill_value, dtype=np.float64)\n        for index_id, index_item in enumerate(index):\n            try:\n     
           tmp_data[index_id] = self.loc[index_item]\n            except KeyError:\n                pass\n        return SingleData(tmp_data, index)\n\n    def add(self, other: SingleData, fill_value=0):\n        # TODO: add and __add__ are a little confusing.\n        # This could be a more general\n        common_index = self.index | other.index\n        common_index, _ = common_index.sort()\n        tmp_data1 = self.reindex(common_index, fill_value)\n        tmp_data2 = other.reindex(common_index, fill_value)\n        return tmp_data1.fillna(fill_value) + tmp_data2.fillna(fill_value)\n\n    def to_dict(self):\n        \"\"\"convert SingleData to dict.\n\n        Returns\n        -------\n        dict\n            data with the dict format.\n        \"\"\"\n        return dict(zip(self.index, self.data.tolist()))\n\n    def to_series(self):\n        return pd.Series(self.data, index=self.index)\n\n    def __repr__(self) -> str:\n        return str(pd.Series(self.data, index=self.index.tolist()))\n\n\nclass MultiData(IndexData):\n    def __init__(\n        self,\n        data: Union[int, float, np.number, list] = [],\n        index: Union[List, pd.Index, Index] = [],\n        columns: Union[List, pd.Index, Index] = [],\n    ):\n        \"\"\"A data structure of index and numpy data.\n        It's used to replace pd.DataFrame due to high-speed.\n\n        Parameters\n        ----------\n        data : Union[list, np.ndarray]\n            the dim of data must be 2.\n        index : Union[List, pd.Index, Index]\n            the index of data.\n        columns: Union[List, pd.Index, Index]\n            the columns of data.\n        \"\"\"\n        if isinstance(data, pd.DataFrame):\n            index, columns, data = data.index, data.columns, data.values\n        super().__init__(data, index, columns)\n        assert self.ndim == 2\n\n    def _align_indices(self, other):\n        if self.indices == other.indices:\n            return other\n        else:\n            
raise ValueError(\n                f\"The indexes of self and other do not meet the requirements of the four arithmetic operations\"\n            )\n\n    def __repr__(self) -> str:\n        return str(pd.DataFrame(self.data, index=self.index.tolist(), columns=self.columns.tolist()))\n"
  },
  {
    "path": "qlib/utils/mod.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nAll module-related classes and functions, e.g.:\n- importing a module or class\n- walking a module\n- operations on classes or modules...\n\"\"\"\n\nimport contextlib\nimport importlib\nimport importlib.util\nimport os\nfrom pathlib import Path\nimport pkgutil\nimport re\nimport sys\nfrom types import ModuleType\nfrom typing import Any, Dict, List, Tuple, Union\nfrom urllib.parse import urlparse\n\nfrom qlib.typehint import InstConf\nfrom qlib.utils.pickle_utils import restricted_pickle_load\n\n\ndef get_module_by_module_path(module_path: Union[str, ModuleType]):\n    \"\"\"Load a module by its path\n\n    :param module_path: a module object, a dotted module name, or a path to a .py file\n    :return: the loaded module\n    :raises: ModuleNotFoundError\n    \"\"\"\n    if module_path is None:\n        raise ModuleNotFoundError(\"None is passed in as module_path\")\n\n    if isinstance(module_path, ModuleType):\n        module = module_path\n    else:\n        if module_path.endswith(\".py\"):\n            module_name = re.sub(\"^[^a-zA-Z_]+\", \"\", re.sub(\"[^0-9a-zA-Z_]\", \"\", module_path[:-3].replace(\"/\", \"_\")))\n            module_spec = importlib.util.spec_from_file_location(module_name, module_path)\n            module = importlib.util.module_from_spec(module_spec)\n            sys.modules[module_name] = module\n            module_spec.loader.exec_module(module)\n        else:\n            module = importlib.import_module(module_path)\n    return module\n\n\ndef split_module_path(module_path: str) -> Tuple[str, str]:\n    \"\"\"\n    Split a dotted path into the module path and the final attribute name.\n\n    Parameters\n    ----------\n    module_path : str\n        e.g. \"a.b.c.ClassName\"\n\n    Returns\n    -------\n    Tuple[str, str]\n        e.g. 
(\"a.b.c\", \"ClassName\")\n    \"\"\"\n    *m_path, cls = module_path.split(\".\")\n    m_path = \".\".join(m_path)\n    return m_path, cls\n\n\ndef get_callable_kwargs(config: InstConf, default_module: Union[str, ModuleType] = None) -> Tuple[type, dict]:\n    \"\"\"\n    extract class/func and kwargs from config info\n\n    Parameters\n    ----------\n    config : [dict, str]\n        the config of the callable;\n        please refer to the doc of init_instance_by_config\n\n    default_module : Python module or str\n        the Python module (or its name) from which to load the class\n        This function will load the class from config['module_path'] first.\n        If config['module_path'] doesn't exist, it will load the class from default_module.\n\n    Returns\n    -------\n    Tuple[type, dict]:\n        the class/func object and its arguments.\n\n    Raises\n    ------\n        ModuleNotFoundError\n    \"\"\"\n    if isinstance(config, dict):\n        key = \"class\" if \"class\" in config else \"func\"\n        if isinstance(config[key], str):\n            # 1) get module and class\n            # - case 1): \"a.b.c.ClassName\"\n            # - case 2): {\"class\": \"ClassName\", \"module_path\": \"a.b.c\"}\n            m_path, cls = split_module_path(config[key])\n            if m_path == \"\":\n                m_path = config.get(\"module_path\", default_module)\n            module = get_module_by_module_path(m_path)\n\n            # 2) get callable\n            _callable = getattr(module, cls)  # may raise AttributeError\n        else:\n            _callable = config[key]  # the class type itself is passed in\n        kwargs = config.get(\"kwargs\", {})\n    elif isinstance(config, str):\n        # a.b.c.ClassName\n        m_path, cls = split_module_path(config)\n        module = get_module_by_module_path(default_module if m_path == \"\" else m_path)\n\n        _callable = getattr(module, cls)\n        kwargs = {}\n    else:\n        raise NotImplementedError(f\"This type[{type(config)}] of 
input is not supported\")\n    return _callable, kwargs\n\n\nget_cls_kwargs = get_callable_kwargs  # NOTE: this is for compatibility with previous versions\n\n\ndef init_instance_by_config(\n    config: InstConf,\n    default_module=None,\n    accept_types: Union[type, Tuple[type]] = (),\n    try_kwargs: Dict = {},\n    **kwargs,\n) -> Any:\n    \"\"\"\n    get an initialized instance from the config\n\n    Parameters\n    ----------\n    config : InstConf\n        a dict such as {\"class\": \"ClassName\", \"module_path\": \"a.b.c\", \"kwargs\": {}}, a string such as \"a.b.c.ClassName\",\n        or a pickled object given as a \"file://\" URI string or a Path\n\n    default_module : Python module\n        Optional. It should be a python module.\n        NOTE: `default_module` is only a fallback; the \"module_path\" in the config (or a full dotted class path) takes precedence.\n\n        This function will load the class from config['module_path'] first.\n        If config['module_path'] doesn't exist, it will load the class from default_module.\n\n    accept_types: Union[type, Tuple[type]]\n        Optional. If the config is an instance of the specified type(s), return the config directly.\n        This will be passed into the second parameter of isinstance.\n\n    try_kwargs: Dict\n        Try to pass the kwargs in `try_kwargs` when initializing the instance.\n        If a TypeError occurs, it will fall back to initialization without try_kwargs.\n\n    Returns\n    -------\n    object:\n        An initialized object based on the config info\n    \"\"\"\n    if isinstance(config, accept_types):\n        return config\n\n    if isinstance(config, (str, Path)):\n        if isinstance(config, str):\n            # path like 'file:///<path to pickle file>/obj.pkl'\n            pr = urlparse(config)\n            if pr.scheme == \"file\":\n                # To enable relative path like file://data/a/b/c.pkl.  
pr.netloc will be data\n                path = pr.path\n                if pr.netloc != \"\":\n                    path = path.lstrip(\"/\")\n\n                pr_path = os.path.join(pr.netloc, path) if bool(pr.path) else pr.netloc\n                with open(os.path.normpath(pr_path), \"rb\") as f:\n                    return restricted_pickle_load(f)\n        else:\n            with config.open(\"rb\") as f:\n                return restricted_pickle_load(f)\n\n    klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)\n\n    try:\n        return klass(**cls_kwargs, **try_kwargs, **kwargs)\n    except TypeError:\n        # TypeError for handling errors like\n        # 1: `XXX() got multiple values for keyword argument 'YYY'`\n        # 2: `XXX() got an unexpected keyword argument 'YYY'`\n        return klass(**cls_kwargs, **kwargs)\n\n\n@contextlib.contextmanager\ndef class_casting(obj: object, cls: type):\n    \"\"\"\n    Python doesn't provide a downcasting mechanism.\n    We use the trick of temporarily reassigning `obj.__class__` to downcast the object\n\n    Parameters\n    ----------\n    obj : object\n        the object to be cast\n    cls : type\n        the target class type\n    \"\"\"\n    orig_cls = obj.__class__\n    obj.__class__ = cls\n    try:\n        yield\n    finally:\n        # restore the original class even if the with-block raises\n        obj.__class__ = orig_cls\n\n\ndef find_all_classes(module_path: Union[str, ModuleType], cls: type) -> List[type]:\n    \"\"\"\n    Find all the classes recursively that inherit from `cls` in a given module.\n    - `cls` itself is also included if it is found in the module\n\n        >>> from qlib.data.dataset.handler import DataHandler\n        >>> find_all_classes(\"qlib.contrib.data.handler\", DataHandler)\n        [<class 'qlib.contrib.data.handler.Alpha158'>, <class 'qlib.contrib.data.handler.Alpha158vwap'>, <class 'qlib.contrib.data.handler.Alpha360'>, <class 'qlib.contrib.data.handler.Alpha360vwap'>, <class 'qlib.data.dataset.handler.DataHandlerLP'>]\n\n    TODO:\n    - skip import error\n\n    \"\"\"\n    if isinstance(module_path, 
ModuleType):\n        mod = module_path\n    else:\n        mod = importlib.import_module(module_path)\n\n    cls_list = []\n\n    def _append_cls(obj):\n        # Leverage the closure to reuse code; skip non-subclasses and duplicates\n        if isinstance(obj, type) and issubclass(obj, cls) and obj not in cls_list:\n            cls_list.append(obj)\n\n    for attr in dir(mod):\n        _append_cls(getattr(mod, attr))\n\n    if hasattr(mod, \"__path__\"):\n        # if the module is a package, search its submodules recursively\n        for _, modname, _ in pkgutil.iter_modules(mod.__path__):\n            sub_mod = importlib.import_module(f\"{mod.__package__}.{modname}\")\n            for m_cls in find_all_classes(sub_mod, cls):\n                _append_cls(m_cls)\n    return cls_list\n"
  },
  {
    "path": "qlib/utils/objm.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport os\nimport pickle\nimport tempfile\nfrom pathlib import Path\n\nfrom qlib.config import C\nfrom qlib.utils.pickle_utils import restricted_pickle_load\n\n\nclass ObjManager:\n    def save_obj(self, obj: object, name: str):\n        \"\"\"\n        save obj as name\n\n        Parameters\n        ----------\n        obj : object\n            object to be saved\n        name : str\n            name of the object\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `save_obj` method\")\n\n    def save_objs(self, obj_name_l):\n        \"\"\"\n        save objects\n\n        Parameters\n        ----------\n        obj_name_l : list of <obj, name>\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `save_objs` method\")\n\n    def load_obj(self, name: str) -> object:\n        \"\"\"\n        load object by name\n\n        Parameters\n        ----------\n        name : str\n            the name of the object\n\n        Returns\n        -------\n        object:\n            loaded object\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `load_obj` method\")\n\n    def exists(self, name: str) -> bool:\n        \"\"\"\n        check whether the object named `name` exists\n\n        Parameters\n        ----------\n        name : str\n            name of the object\n\n        Returns\n        -------\n        bool:\n            If the object exists\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `exists` method\")\n\n    def list(self) -> list:\n        \"\"\"\n        list the objects\n\n        Returns\n        -------\n        list:\n            the list of returned objects\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `list` method\")\n\n    def remove(self, fname=None):\n        \"\"\"remove objects.\n\n        Parameters\n        ----------\n        fname :\n            if the file name 
is provided, that specific file is removed;\n            otherwise, all the objects will be removed.\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `remove` method\")\n\n\nclass FileManager(ObjManager):\n    \"\"\"\n    Use the file system to manage objects\n    \"\"\"\n\n    def __init__(self, path=None):\n        if path is None:\n            self.path = Path(self.create_path())\n        else:\n            self.path = Path(path).resolve()\n\n    def create_path(self) -> str:\n        try:\n            return tempfile.mkdtemp(prefix=str(C[\"file_manager_path\"]) + os.sep)\n        except AttributeError as attribute_e:\n            raise NotImplementedError(\n                f\"If path is not given, the `create_path` function should be implemented\"\n            ) from attribute_e\n\n    def save_obj(self, obj, name):\n        with (self.path / name).open(\"wb\") as f:\n            pickle.dump(obj, f, protocol=C.dump_protocol_version)\n\n    def save_objs(self, obj_name_l):\n        for obj, name in obj_name_l:\n            self.save_obj(obj, name)\n\n    def load_obj(self, name):\n        with (self.path / name).open(\"rb\") as f:\n            return restricted_pickle_load(f)\n\n    def exists(self, name):\n        return (self.path / name).exists()\n\n    def list(self):\n        return list(self.path.iterdir())\n\n    def remove(self, fname=None):\n        if fname is None:\n            for fp in self.path.glob(\"*\"):\n                fp.unlink()\n            self.path.rmdir()\n        else:\n            (self.path / fname).unlink()\n"
  },
  {
    "path": "qlib/utils/paral.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport threading\nfrom functools import partial\nfrom threading import Thread\nfrom typing import Callable, Text, Union\n\nimport joblib\nfrom joblib import Parallel, delayed\nfrom joblib._parallel_backends import MultiprocessingBackend\nimport pandas as pd\n\nfrom queue import Empty, Queue\nimport concurrent.futures\n\nfrom qlib.config import C, QlibConfig\n\n\nclass ParallelExt(Parallel):\n    def __init__(self, *args, **kwargs):\n        maxtasksperchild = kwargs.pop(\"maxtasksperchild\", None)\n        super(ParallelExt, self).__init__(*args, **kwargs)\n        if isinstance(self._backend, MultiprocessingBackend):\n            # 2025-05-04 joblib released version 1.5.0, in which _backend_args was removed and replaced by _backend_kwargs.\n            # Ref: https://github.com/joblib/joblib/pull/1525/files#diff-e4dff8042ce45b443faf49605b75a58df35b8c195978d4a57f4afa695b406bdc\n            if joblib.__version__ < \"1.5.0\":\n                self._backend_args[\"maxtasksperchild\"] = maxtasksperchild  # pylint: disable=E1101\n            else:\n                self._backend_kwargs[\"maxtasksperchild\"] = maxtasksperchild  # pylint: disable=E1101\n\n\ndef datetime_groupby_apply(\n    df, apply_func: Union[Callable, Text], axis=0, level=\"datetime\", resample_rule=\"ME\", n_jobs=-1\n):\n    \"\"\"datetime_groupby_apply\n    This function will apply the `apply_func` on the datetime level index.\n\n    Parameters\n    ----------\n    df :\n        DataFrame for processing\n    apply_func : Union[Callable, Text]\n        function for processing the data;\n        if a string is given, it is treated as the name of a pandas GroupBy method\n    axis :\n        the axis on which the datetime level is located\n    level :\n        which level is the datetime level\n    resample_rule :\n        the resample rule used to split the data into chunks that are processed in parallel\n    n_jobs :\n        n_jobs for joblib\n    Returns:\n        
pd.DataFrame\n    \"\"\"\n\n    def _naive_group_apply(df):\n        if isinstance(apply_func, str):\n            return getattr(df.groupby(axis=axis, level=level, group_keys=False), apply_func)()\n        return df.groupby(level=level, group_keys=False).apply(apply_func)\n\n    if n_jobs != 1:\n        dfs = ParallelExt(n_jobs=n_jobs)(\n            delayed(_naive_group_apply)(sub_df) for idx, sub_df in df.resample(resample_rule, level=level)\n        )\n        return pd.concat(dfs, axis=axis).sort_index()\n    else:\n        return _naive_group_apply(df)\n\n\nclass AsyncCaller:\n    \"\"\"\n    AsyncCaller makes it easier to call functions asynchronously in a background thread.\n\n    Currently, it is used in MLflowRecorder to make functions like `log_params` asynchronous.\n\n    NOTE:\n    - This caller discards the return values of the called functions.\n    \"\"\"\n\n    STOP_MARK = \"__STOP\"\n\n    def __init__(self) -> None:\n        self._q = Queue()\n        self._stop = False\n        self._t = Thread(target=self.run)\n        self._t.start()\n\n    def close(self):\n        self._q.put(self.STOP_MARK)\n\n    def run(self):\n        while True:\n            # NOTE:\n            # atexit handlers only run after all non-daemon threads have ended, so relying on them may result in a deadlock.\n            # Therefore, the child thread actively watches the status of the main thread and stops itself.\n            main_thread = threading.main_thread()\n            if not main_thread.is_alive():\n                break\n            try:\n                data = self._q.get(timeout=1)\n            except Empty:\n                # NOTE: avoid deadlock. 
Timing out makes it possible to check the main thread periodically\n                continue\n            if data == self.STOP_MARK:\n                break\n            data()\n\n    def __call__(self, func, *args, **kwargs):\n        self._q.put(partial(func, *args, **kwargs))\n\n    def wait(self, close=True):\n        if close:\n            self.close()\n        self._t.join()\n\n    @staticmethod\n    def async_dec(ac_attr):\n        def decorator_func(func):\n            def wrapper(self, *args, **kwargs):\n                if isinstance(getattr(self, ac_attr, None), Callable):\n                    return getattr(self, ac_attr)(func, self, *args, **kwargs)\n                else:\n                    return func(self, *args, **kwargs)\n\n            return wrapper\n\n        return decorator_func\n\n\n# # Outline: Joblib enhancement\n# The code below implements the following workflow\n# - Construct a complex data structure nested with delayed joblib tasks\n#      - For example, {\"job\": [<delayed_joblib_task>, {\"1\": <delayed_joblib_task>}]}\n# - Execute all the tasks and replace each <delayed_joblib_task> with its return value\n\n# This makes it easier to convert existing code into a parallel version\n\n\nclass DelayedTask:\n    def get_delayed_tuple(self):\n        \"\"\"get_delayed_tuple.\n        Return the delayed_tuple created by joblib.delayed\n        \"\"\"\n        raise NotImplementedError(\"NotImplemented\")\n\n    def set_res(self, res):\n        \"\"\"set_res.\n\n        Parameters\n        ----------\n        res :\n            the executed result of the delayed tuple\n        \"\"\"\n        self.res = res\n\n    def get_replacement(self):\n        \"\"\"return the object to replace the delayed task\"\"\"\n        raise NotImplementedError(\"NotImplemented\")\n\n\nclass DelayedTuple(DelayedTask):\n    def __init__(self, delayed_tpl):\n        self.delayed_tpl = delayed_tpl\n        self.res = None\n\n    def get_delayed_tuple(self):\n        return self.delayed_tpl\n\n    def get_replacement(self):\n        return self.res\n\n\nclass DelayedDict(DelayedTask):\n    \"\"\"DelayedDict.\n    It is designed for converting existing code with the following pattern to parallel code:\n    - a dict is constructed\n    - the keys are available immediately\n    - computing the values takes a lot of time\n        - AND ALL the values are calculated in a SINGLE function\n    \"\"\"\n\n    def __init__(self, key_l, delayed_tpl):\n        self.key_l = key_l\n        self.delayed_tpl = delayed_tpl\n        self.res = None\n\n    def get_delayed_tuple(self):\n        return self.delayed_tpl\n\n    def get_replacement(self):\n        return dict(zip(self.key_l, self.res))\n\n\ndef is_delayed_tuple(obj) -> bool:\n    \"\"\"is_delayed_tuple.\n\n    Parameters\n    ----------\n    obj : object\n\n    Returns\n    -------\n    bool\n        whether `obj` is a joblib.delayed tuple\n    \"\"\"\n    return isinstance(obj, tuple) and len(obj) == 3 and callable(obj[0])\n\n\ndef _replace_and_get_dt(complex_iter):\n    \"\"\"_replace_and_get_dt.\n\n    FIXME: this function may cause infinite recursion when the complex data structure contains circular references\n\n    Parameters\n    ----------\n    complex_iter :\n        the (possibly nested) data structure to be explored\n    \"\"\"\n    if isinstance(complex_iter, DelayedTask):\n        dt = complex_iter\n        return dt, [dt]\n    elif is_delayed_tuple(complex_iter):\n        dt = DelayedTuple(complex_iter)\n        return dt, [dt]\n    elif isinstance(complex_iter, (list, tuple)):\n        new_ci = []\n        dt_all = []\n        for item in complex_iter:\n            new_item, dt_list = _replace_and_get_dt(item)\n            new_ci.append(new_item)\n            dt_all += dt_list\n        return new_ci, dt_all\n    elif isinstance(complex_iter, dict):\n        new_ci = {}\n        dt_all = []\n        for key, item in complex_iter.items():\n            new_item, dt_list = _replace_and_get_dt(item)\n            new_ci[key] = new_item\n            dt_all += dt_list\n        return new_ci, dt_all\n    else:\n        return complex_iter, []\n\n\ndef _recover_dt(complex_iter):\n    \"\"\"_recover_dt.\n\n    replace every DelayedTask in `complex_iter` with its replacement (see `DelayedTask.get_replacement`)\n\n    FIXME: this function may cause infinite recursion when the complex data structure contains circular references\n\n    Parameters\n    ----------\n    complex_iter :\n        the (possibly nested) data structure to be explored\n    \"\"\"\n    if isinstance(complex_iter, DelayedTask):\n        return complex_iter.get_replacement()\n    elif isinstance(complex_iter, (list, tuple)):\n        return [_recover_dt(item) for item in complex_iter]\n    elif isinstance(complex_iter, dict):\n        return {key: _recover_dt(item) for key, item in complex_iter.items()}\n    else:\n        return complex_iter\n\n\ndef complex_parallel(paral: Parallel, complex_iter):\n    \"\"\"complex_parallel.\n    Find all the delayed functions created by `delayed` in complex_iter, run them in parallel, and then replace each of them with its result\n\n    >>> from qlib.utils.paral import complex_parallel\n    >>> from joblib import Parallel, delayed\n    >>> complex_iter = {\"a\": delayed(sum)([1,2,3]), \"b\": [1, 2, delayed(sum)([10, 1])]}\n    >>> complex_parallel(Parallel(), complex_iter)\n    {'a': 6, 'b': [1, 2, 11]}\n\n    Parameters\n    ----------\n    paral : Parallel\n        the joblib Parallel instance used to execute the tasks\n    complex_iter :\n        NOTE: only list, tuple and dict will be explored! Tuples are returned as lists.\n\n    Returns\n    -------\n    complex_iter whose delayed joblib tasks are replaced with their execution results.\n    \"\"\"\n\n    complex_iter, dt_all = _replace_and_get_dt(complex_iter)\n    for res, dt in zip(paral(dt.get_delayed_tuple() for dt in dt_all), dt_all):\n        dt.set_res(res)\n    complex_iter = _recover_dt(complex_iter)\n    return complex_iter\n\n\nclass call_in_subproc:\n    \"\"\"\n    When we repeatedly run functions, it is hard to avoid memory leaks.\n    So we run the function in a subprocess, whose memory is released when it exits.\n\n    NOTE: Because local objects can't be 
pickled, we can't implement it via a closure.\n          We have to implement it via a callable class.\n    \"\"\"\n\n    def __init__(self, func: Callable, qlib_config: QlibConfig = None):\n        \"\"\"\n        Parameters\n        ----------\n        func : Callable\n            the function to be wrapped\n\n        qlib_config : QlibConfig\n            Qlib config used to initialize Qlib in the subprocess\n\n        Returns\n        -------\n        Callable\n        \"\"\"\n        self.func = func\n        self.qlib_config = qlib_config\n\n    def _func_mod(self, *args, **kwargs):\n        \"\"\"Wrap the original function with Qlib initialization\"\"\"\n        if self.qlib_config is not None:\n            C.register_from_C(self.qlib_config)\n        return self.func(*args, **kwargs)\n\n    def __call__(self, *args, **kwargs):\n        with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:\n            return executor.submit(self._func_mod, *args, **kwargs).result()\n"
  },
  {
    "path": "qlib/utils/pickle_utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nSecure pickle utilities to prevent arbitrary code execution through deserialization.\n\nThis module provides a secure alternative to pickle.load() and pickle.loads()\nthat restricts deserialization to a whitelist of safe classes.\n\"\"\"\n\nimport io\nimport pickle\nfrom typing import Any, BinaryIO, Set, Tuple\n\n# Whitelist of safe classes that are allowed to be unpickled\n# These are common data types used in qlib that should be safe to deserialize\nSAFE_PICKLE_CLASSES: Set[Tuple[str, str]] = {\n    # python builtins\n    (\"builtins\", \"slice\"),\n    (\"builtins\", \"range\"),\n    (\"builtins\", \"dict\"),\n    (\"builtins\", \"list\"),\n    (\"builtins\", \"tuple\"),\n    (\"builtins\", \"set\"),\n    (\"builtins\", \"frozenset\"),\n    (\"builtins\", \"bytearray\"),\n    (\"builtins\", \"bytes\"),\n    (\"builtins\", \"str\"),\n    (\"builtins\", \"int\"),\n    (\"builtins\", \"float\"),\n    (\"builtins\", \"bool\"),\n    (\"builtins\", \"complex\"),\n    (\"builtins\", \"type\"),\n    (\"builtins\", \"property\"),\n    # common utility classes\n    (\"datetime\", \"datetime\"),\n    (\"datetime\", \"date\"),\n    (\"datetime\", \"time\"),\n    (\"datetime\", \"timedelta\"),\n    (\"datetime\", \"timezone\"),\n    (\"decimal\", \"Decimal\"),\n    (\"collections\", \"OrderedDict\"),\n    (\"collections\", \"defaultdict\"),\n    (\"collections\", \"Counter\"),\n    (\"collections\", \"namedtuple\"),\n    (\"enum\", \"Enum\"),\n    (\"pathlib\", \"Path\"),\n    (\"pathlib\", \"PosixPath\"),\n    (\"pathlib\", \"WindowsPath\"),\n    (\"qlib.data.dataset.handler\", \"DataHandler\"),\n    (\"qlib.data.dataset.handler\", \"DataHandlerLP\"),\n    (\"qlib.data.dataset.loader\", \"StaticDataLoader\"),\n}\n\n\nTRUSTED_MODULE_PREFIXES = (\n    \"pandas\",\n    \"numpy\",\n)\n\n\nclass RestrictedUnpickler(pickle.Unpickler):\n    \"\"\"Custom unpickler that only allows 
safe classes to be deserialized.\n\n    This prevents arbitrary code execution through malicious pickle files by\n    restricting deserialization to a whitelist of safe classes.\n\n    Example:\n        >>> with open(\"data.pkl\", \"rb\") as f:\n        ...     data = RestrictedUnpickler(f).load()\n    \"\"\"\n\n    def find_class(self, module: str, name: str):\n        \"\"\"Override find_class to restrict allowed classes.\n\n        Args:\n            module: Module name of the class\n            name: Class name\n\n        Returns:\n            The class object if it is allowed\n\n        Raises:\n            pickle.UnpicklingError: If the class is not allowed\n        \"\"\"\n        # 1. trusted third-party packages: match the top-level package name exactly,\n        #    so that look-alike packages (e.g. `numpyro`) are not trusted by accident\n        if module.split(\".\")[0] in TRUSTED_MODULE_PREFIXES:\n            return super().find_class(module, name)\n\n        # 2. explicit whitelist (builtins, standard library and qlib internals)\n        if (module, name) in SAFE_PICKLE_CLASSES:\n            return super().find_class(module, name)\n\n        raise pickle.UnpicklingError(\n            f\"Forbidden class: {module}.{name}. \"\n            f\"Only whitelisted classes are allowed for security reasons. \"\n            f\"This is to prevent arbitrary code execution through pickle deserialization.\"\n        )\n\n\ndef restricted_pickle_load(file: BinaryIO) -> Any:\n    \"\"\"Safely load a pickle file with restricted classes.\n\n    This is a drop-in replacement for pickle.load() that prevents\n    arbitrary code execution by only allowing whitelisted classes.\n\n    Args:\n        file: A file object opened in binary mode\n\n    Returns:\n        The unpickled Python object\n\n    Raises:\n        pickle.UnpicklingError: If the pickle contains forbidden classes\n\n    Example:\n        >>> with open(\"data.pkl\", \"rb\") as f:\n        ...     
data = restricted_pickle_load(f)\n    \"\"\"\n    return RestrictedUnpickler(file).load()\n\n\ndef restricted_pickle_loads(data: bytes) -> Any:\n    \"\"\"Safely load a pickle from bytes with restricted classes.\n\n    This is a drop-in replacement for pickle.loads() that prevents\n    arbitrary code execution by only allowing whitelisted classes.\n\n    Args:\n        data: Bytes object containing pickled data\n\n    Returns:\n        The unpickled Python object\n\n    Raises:\n        pickle.UnpicklingError: If the pickle contains forbidden classes\n\n    Example:\n        >>> data = b'\\\\x80\\\\x04\\\\x95...'\n        >>> obj = restricted_pickle_loads(data)\n    \"\"\"\n    file_like = io.BytesIO(data)\n    return RestrictedUnpickler(file_like).load()\n\n\ndef add_safe_class(module: str, name: str) -> None:\n    \"\"\"Add a class to the whitelist of safe classes for unpickling.\n\n    Use this function to extend the whitelist if your code needs to deserialize\n    additional classes. However, be very careful when adding classes, as this\n    could potentially introduce security vulnerabilities.\n\n    Args:\n        module: Module name of the class (e.g., 'my_package.my_module')\n        name: Class name (e.g., 'MyClass')\n\n    Warning:\n        Only add classes that you fully control and trust. Adding arbitrary\n        classes from external packages could introduce security risks.\n\n    Example:\n        >>> add_safe_class('my_package.models', 'CustomModel')\n    \"\"\"\n    SAFE_PICKLE_CLASSES.add((module, name))\n\n\ndef get_safe_classes() -> Set[Tuple[str, str]]:\n    \"\"\"Get a copy of the current whitelist of safe classes.\n\n    Returns:\n        A set of (module, name) tuples representing allowed classes\n    \"\"\"\n    return SAFE_PICKLE_CLASSES.copy()\n"
  },
  {
    "path": "qlib/utils/resam.py",
    "content": "import numpy as np\nimport pandas as pd\n\nfrom functools import partial\nfrom typing import Union, Callable\n\nfrom . import lazy_sort_index\nfrom .time import Freq, cal_sam_minute\nfrom ..config import C\n\n\ndef resam_calendar(\n    calendar_raw: np.ndarray, freq_raw: Union[str, Freq], freq_sam: Union[str, Freq], region: str = None\n) -> np.ndarray:\n    \"\"\"\n    Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam\n    Assumption:\n        - Fixed length (240) of the calendar in each day.\n\n    Parameters\n    ----------\n    calendar_raw : np.ndarray\n        The calendar with frequency freq_raw\n    freq_raw : Union[str, Freq]\n        Frequency of the raw calendar\n    freq_sam : Union[str, Freq]\n        Sample frequency\n    region : str\n        Region, for example, \"cn\", \"us\"; by default C[\"region\"]\n\n    Returns\n    -------\n    np.ndarray\n        The calendar with frequency freq_sam\n    \"\"\"\n    if region is None:\n        region = C[\"region\"]\n\n    freq_raw = Freq(freq_raw)\n    freq_sam = Freq(freq_sam)\n    if not len(calendar_raw):\n        return calendar_raw\n\n    # if freq_sam is xminute, divide each trading day into several bars evenly\n    if freq_sam.base == Freq.NORM_FREQ_MINUTE:\n        if freq_raw.base != Freq.NORM_FREQ_MINUTE:\n            raise ValueError(\"when sampling a minute calendar, the freq of the raw calendar must be minute or min\")\n        else:\n            if freq_raw.count > freq_sam.count:\n                raise ValueError(\"raw freq must be higher than or equal to sampling freq\")\n        _calendar_minute = np.unique(list(map(lambda x: cal_sam_minute(x, freq_sam.count, region), calendar_raw)))\n        return _calendar_minute\n\n    # else, convert the raw calendar into a day calendar, and divide the whole calendar into several bars evenly\n    else:\n        _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))\n        if freq_sam.base == Freq.NORM_FREQ_DAY:\n            return _calendar_day[:: freq_sam.count]\n\n        elif freq_sam.base == Freq.NORM_FREQ_WEEK:\n            _day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))\n            _calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]\n            return _calendar_week[:: freq_sam.count]\n\n        elif freq_sam.base == Freq.NORM_FREQ_MONTH:\n            _day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))\n            _calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]\n            return _calendar_month[:: freq_sam.count]\n        else:\n            raise ValueError(\"sampling freq must be xmin, xd, xw, xm\")\n\n\ndef get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=None, freq=\"day\", disk_cache=1):\n    \"\"\"get the feature with a frequency higher than or equal to `freq`.\n\n    Returns\n    -------\n    Tuple[pd.DataFrame, str]\n        the feature with higher or equal frequency, and the frequency actually used\n    \"\"\"\n\n    from ..data.data import D  # pylint: disable=C0415\n\n    try:\n        _result = D.features(instruments, fields, start_time, end_time, freq=freq, disk_cache=disk_cache)\n        _freq = freq\n    except (ValueError, KeyError) as value_key_e:\n        _, norm_freq = Freq.parse(freq)\n        if norm_freq in [Freq.NORM_FREQ_MONTH, Freq.NORM_FREQ_WEEK, Freq.NORM_FREQ_DAY]:\n            try:\n                _result = D.features(instruments, fields, start_time, end_time, freq=\"day\", disk_cache=disk_cache)\n                _freq = \"day\"\n            except (ValueError, KeyError):\n                _result = D.features(instruments, fields, start_time, end_time, freq=\"1min\", disk_cache=disk_cache)\n                _freq = \"1min\"\n        elif norm_freq == Freq.NORM_FREQ_MINUTE:\n            _result = D.features(instruments, fields, start_time, end_time, freq=\"1min\", disk_cache=disk_cache)\n            _freq = \"1min\"\n        else:\n            raise ValueError(f\"freq {freq} is not 
supported\") from value_key_e\n    return _result, _freq\n\n\ndef resam_ts_data(\n    ts_feature: Union[pd.DataFrame, pd.Series],\n    start_time: Union[str, pd.Timestamp] = None,\n    end_time: Union[str, pd.Timestamp] = None,\n    method: Union[str, Callable] = \"last\",\n    method_kwargs: dict = {},\n):\n    \"\"\"\n    Resample value from time-series data\n\n        - If `feature` has MultiIndex[instrument, datetime], apply the `method` to the data of each instrument with datetime in [start_time, end_time]\n            Example:\n\n            .. code-block::\n\n                print(feature)\n                                        $close      $volume\n                instrument  datetime\n                SH600000    2010-01-04  86.778313   16162960.0\n                            2010-01-05  87.433578   28117442.0\n                            2010-01-06  85.713585   23632884.0\n                            2010-01-07  83.788803   20813402.0\n                            2010-01-08  84.730675   16044853.0\n\n                SH600655    2010-01-04  2699.567383  158193.328125\n                            2010-01-08  2612.359619   77501.406250\n                            2010-01-11  2712.982422  160852.390625\n                            2010-01-12  2788.688232  164587.937500\n                            2010-01-13  2790.604004  145460.453125\n\n                print(resam_ts_data(feature, start_time=\"2010-01-04\", end_time=\"2010-01-05\", method=\"last\"))\n                            $close      $volume\n                instrument\n                SH600000    87.433578 28117442.0\n                SH600655    2699.567383  158193.328125\n\n        - Else, the `feature` should have Index[datetime], just apply the `method` to `feature` directly\n            Example:\n\n            .. code-block::\n\n                print(feature)\n                            $close      $volume\n                datetime\n                2010-01-04  86.778313   16162960.0\n                2010-01-05  87.433578   28117442.0\n                2010-01-06  85.713585   23632884.0\n                2010-01-07  83.788803   20813402.0\n                2010-01-08  84.730675   16044853.0\n\n                print(resam_ts_data(feature, start_time=\"2010-01-04\", end_time=\"2010-01-05\", method=\"last\"))\n\n                $close 87.433578\n                $volume 28117442.0\n\n                print(resam_ts_data(feature['$close'], start_time=\"2010-01-04\", end_time=\"2010-01-05\", method=\"last\"))\n\n                87.433578\n\n    Parameters\n    ----------\n    ts_feature : Union[pd.DataFrame, pd.Series]\n        Raw time-series feature to be resampled\n    start_time : Union[str, pd.Timestamp], optional\n        start sampling time, by default None\n    end_time : Union[str, pd.Timestamp], optional\n        end sampling time, by default None\n    method : Union[str, Callable], optional\n        sampling method applied to the series data of each stock, by default \"last\"\n        - If method is a str, it should be the name of a method of SeriesGroupBy or DataFrameGroupBy (or of Series/DataFrame for a single-level index), and it is applied to the sliced time-series data\n        - If method is callable, it is applied to the sliced time-series data of each instrument (or to the whole sliced data for a single-level index)\n        - If method is None, do nothing for the sliced time-series data.\n    method_kwargs : dict, optional\n        arguments of method, by default {}\n\n    Returns\n    -------\n        The resampled DataFrame/Series/value, or None when the sliced data is empty.\n    \"\"\"\n\n    selector_datetime = slice(start_time, end_time)\n\n    from ..data.dataset.utils import get_level_index  # pylint: disable=C0415\n\n    feature = lazy_sort_index(ts_feature)\n\n    datetime_level = get_level_index(feature, level=\"datetime\") == 0\n    if datetime_level:\n        feature = feature.loc[selector_datetime]\n    else:\n        feature = feature.loc(axis=0)[(slice(None), selector_datetime)]\n\n    if feature.empty:\n        return None\n    if isinstance(feature.index, pd.MultiIndex):\n        if callable(method):\n            method_func = method\n            return feature.groupby(level=\"instrument\", group_keys=False).apply(method_func, **method_kwargs)\n        elif isinstance(method, str):\n            return getattr(feature.groupby(level=\"instrument\", group_keys=False), method)(**method_kwargs)\n    else:\n        if callable(method):\n            method_func = method\n            return method_func(feature, **method_kwargs)\n        elif isinstance(method, str):\n            return getattr(feature, method)(**method_kwargs)\n    return feature\n\n\ndef get_valid_value(series, last=True):\n    \"\"\"get the first/last non-NaN value of a pd.Series with a single-level index\n\n    Parameters\n    ----------\n    series : pd.Series\n        series should not be empty\n    last : bool, optional\n        whether to get the last valid value, by default True\n        - if last is True, get the last valid value\n        - else, get the first valid value\n\n    Returns\n    -------\n    float\n        the first/last valid value, or NaN if there is none\n    \"\"\"\n    return series.ffill().iloc[-1] if last else series.bfill().iloc[0]\n\n\ndef _ts_data_valid(ts_feature, last=False):\n    \"\"\"get the first/last non-NaN value of a pd.Series|DataFrame with a single-level index\"\"\"\n    if isinstance(ts_feature, pd.DataFrame):\n        return ts_feature.apply(lambda column: get_valid_value(column, last=last))\n    elif isinstance(ts_feature, pd.Series):\n        return get_valid_value(ts_feature, last=last)\n    else:\n        raise TypeError(f\"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}\")\n\n\nts_data_last = partial(_ts_data_valid, last=True)\nts_data_first = partial(_ts_data_valid, last=False)\n"
  },
  {
    "path": "qlib/utils/serial.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport pickle\nimport dill\nfrom pathlib import Path\nfrom typing import Union\nfrom ..config import C\n\n\nclass Serializable:\n    \"\"\"\n    Serializable changes the behavior of pickle.\n\n        The rules that decide whether an attribute is kept or dropped when dumping\n        (rules with higher priority are listed first):\n        - in the config attribute list -> always dropped\n        - in the include attribute list -> always kept\n        - in the exclude attribute list -> always dropped\n        - name does not start with `_` -> kept\n        - name starts with `_` -> kept if `dump_all` is true else dropped\n\n    It provides syntactic sugar for excluding the attributes that the user doesn't want to dump.\n    - For example, a learnable DataHandler just wants to save the parameters without data when dumping to disk\n    \"\"\"\n\n    pickle_backend = \"pickle\"  # another option is \"dill\", which can pickle more kinds of Python objects.\n    default_dump_all = False  # whether to dump all attributes by default\n    config_attr = [\"_include\", \"_exclude\"]\n    exclude_attr = []  # exclude_attr has lower priority than `self._exclude`\n    include_attr = []  # include_attr has lower priority than `self._include`\n    FLAG_KEY = \"_qlib_serial_flag\"\n\n    def __init__(self):\n        self._dump_all = self.default_dump_all\n        self._exclude = None  # this attribute has higher priority than `exclude_attr`\n\n    def _is_kept(self, key):\n        if key in self.config_attr:\n            return False\n        if key in self._get_attr_list(\"include\"):\n            return True\n        if key in self._get_attr_list(\"exclude\"):\n            return False\n        return self.dump_all or not key.startswith(\"_\")\n\n    def __getstate__(self) -> dict:\n        return {k: v for k, v in self.__dict__.items() if self._is_kept(k)}\n\n    def __setstate__(self, state: dict):\n        self.__dict__.update(state)\n\n    @property\n    def dump_all(self):\n        \"\"\"\n        whether the object dumps all attributes (including those starting with `_`)\n        \"\"\"\n        return getattr(self, \"_dump_all\", False)\n\n    def _get_attr_list(self, attr_type: str) -> list:\n        \"\"\"\n        Get the list of attributes of the given type\n\n        Parameters\n        ----------\n        attr_type : str\n            \"include\" or \"exclude\"\n\n        Returns\n        -------\n        list:\n            the attribute names; the instance-level setting (`_include`/`_exclude`) takes priority over\n            the class-level `include_attr`/`exclude_attr`\n        \"\"\"\n        if hasattr(self, f\"_{attr_type}\"):\n            res = getattr(self, f\"_{attr_type}\", [])\n        else:\n            res = getattr(self.__class__, f\"{attr_type}_attr\", [])\n        if res is None:\n            return []\n        return res\n\n    def config(self, recursive=False, **kwargs):\n        \"\"\"\n        configure the serializable object\n\n        Parameters\n        ----------\n        kwargs may include the following keys\n\n            dump_all : bool\n                whether the object dumps all attributes\n            exclude : list\n                attributes that will not be dumped\n            include : list\n                attributes that will be dumped\n\n        recursive : bool\n            whether to apply the configuration recursively to nested Serializable attributes\n        \"\"\"\n        keys = {\"dump_all\", \"exclude\", \"include\"}\n        for k, v in kwargs.items():\n            if k in keys:\n                attr_name = f\"_{k}\"\n                setattr(self, attr_name, v)\n            else:\n                raise KeyError(f\"Unknown parameter: {k}\")\n\n        if recursive:\n            # set flag to prevent endless loop\n            self.__dict__[self.FLAG_KEY] = True\n            # iterate over a snapshot, because the flag is added to and removed from `self.__dict__`\n            for obj in list(self.__dict__.values()):\n                if isinstance(obj, Serializable) and self.FLAG_KEY not in obj.__dict__:\n                    obj.config(recursive=True, **kwargs)\n            del self.__dict__[self.FLAG_KEY]\n\n    def to_pickle(self, path: Union[Path, str], **kwargs):\n        \"\"\"\n        Dump self to a pickle file.\n\n        path (Union[Path, str]): the path to dump\n\n        kwargs may include the following keys\n\n            dump_all : bool\n                whether the object dumps all attributes\n            exclude : list\n                attributes that will not be dumped\n            include : list\n                attributes that will be dumped\n        \"\"\"\n        self.config(**kwargs)\n        with Path(path).open(\"wb\") as f:\n            # any backend with a pickle-like interface, such as dill\n            self.get_backend().dump(self, f, protocol=C.dump_protocol_version)\n\n    @classmethod\n    def load(cls, filepath):\n        \"\"\"\n        Load an instance of the serializable class from a file path.\n\n        Args:\n            filepath (str): the path of the file\n\n        Raises:\n            TypeError: if the unpickled object is not an instance of `cls`\n\n        Returns:\n            `cls`: the loaded instance of `cls`\n        \"\"\"\n        with open(filepath, \"rb\") as f:\n            obj = cls.get_backend().load(f)\n        if isinstance(obj, cls):\n            return obj\n        else:\n            raise TypeError(f\"The instance of {type(obj)} is not a valid `{cls}`!\")\n\n    @classmethod\n    def get_backend(cls):\n        \"\"\"\n        Return the real backend of a Serializable class. 
The pickle_backend value can be \"pickle\" or \"dill\".\n\n        Returns:\n            module: pickle or dill module based on pickle_backend\n        \"\"\"\n        # NOTE: any backend that follows the pickle interface (e.g. dill) can be used\n        if cls.pickle_backend == \"pickle\":\n            return pickle\n        elif cls.pickle_backend == \"dill\":\n            return dill\n        else:\n            raise ValueError(\"Unknown pickle backend, please use 'pickle' or 'dill'.\")\n\n    @staticmethod\n    def general_dump(obj, path: Union[Path, str]):\n        \"\"\"\n        A general dumping method for objects. Serializable objects are dumped with `to_pickle`; other objects are dumped with pickle.\n\n        Parameters\n        ----------\n        obj : object\n            the object to be dumped\n        path : Union[Path, str]\n            the target path to which the data will be dumped\n        \"\"\"\n        path = Path(path)\n        if isinstance(obj, Serializable):\n            obj.to_pickle(path)\n        else:\n            with path.open(\"wb\") as f:\n                pickle.dump(obj, f, protocol=C.dump_protocol_version)\n"
  },
  {
    "path": "qlib/utils/time.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nTime related utils are compiled in this script\n\"\"\"\n\nimport bisect\nfrom datetime import datetime, time, date, timedelta\nfrom typing import List, Optional, Tuple, Union\nimport functools\nimport re\n\nimport pandas as pd\n\nfrom qlib.config import C\nfrom qlib.constant import REG_CN, REG_TW, REG_US\n\nCN_TIME = [\n    datetime.strptime(\"9:30\", \"%H:%M\"),\n    datetime.strptime(\"11:30\", \"%H:%M\"),\n    datetime.strptime(\"13:00\", \"%H:%M\"),\n    datetime.strptime(\"15:00\", \"%H:%M\"),\n]\nUS_TIME = [datetime.strptime(\"9:30\", \"%H:%M\"), datetime.strptime(\"16:00\", \"%H:%M\")]\nTW_TIME = [\n    datetime.strptime(\"9:00\", \"%H:%M\"),\n    datetime.strptime(\"13:30\", \"%H:%M\"),\n]\n\n\n@functools.lru_cache(maxsize=240)\ndef get_min_cal(shift: int = 0, region: str = REG_CN) -> List[time]:\n    \"\"\"\n    get the minute level calendar in day period\n\n    Parameters\n    ----------\n    shift : int\n        the shift direction would be like pandas shift.\n        series.shift(1) will replace the value at `i`-th with the one at `i-1`-th\n    region: str\n        Region, for example, \"cn\", \"us\"\n\n    Returns\n    -------\n    List[time]:\n\n    \"\"\"\n    cal = []\n\n    if region == REG_CN:\n        for ts in list(\n            pd.date_range(CN_TIME[0], CN_TIME[1] - timedelta(minutes=1), freq=\"1min\") - pd.Timedelta(minutes=shift)\n        ) + list(\n            pd.date_range(CN_TIME[2], CN_TIME[3] - timedelta(minutes=1), freq=\"1min\") - pd.Timedelta(minutes=shift)\n        ):\n            cal.append(ts.time())\n    elif region == REG_TW:\n        for ts in list(\n            pd.date_range(TW_TIME[0], TW_TIME[1] - timedelta(minutes=1), freq=\"1min\") - pd.Timedelta(minutes=shift)\n        ):\n            cal.append(ts.time())\n    elif region == REG_US:\n        for ts in list(\n            pd.date_range(US_TIME[0], US_TIME[1] - 
timedelta(minutes=1), freq=\"1min\") - pd.Timedelta(minutes=shift)\n        ):\n            cal.append(ts.time())\n    else:\n        raise ValueError(f\"{region} is not supported\")\n    return cal\n\n\ndef is_single_value(start_time, end_time, freq, region: str = REG_CN):\n    \"\"\"Check whether only one piece of data can be obtained from the stock market in the given time range.\n\n    Parameters\n    ----------\n    start_time : pd.Timestamp\n        closed start time for data.\n    end_time : pd.Timestamp\n        closed end time for data.\n    freq :\n        the length of one piece of data; it is compared with `end_time - start_time`.\n    region: str\n        Region, for example, \"cn\", \"us\"\n\n    Returns\n    -------\n    bool\n        True if only one piece of data is to be obtained.\n    \"\"\"\n    if region == REG_CN:\n        if end_time - start_time < freq:\n            return True\n        if start_time.hour == 11 and start_time.minute == 29 and start_time.second == 0:\n            return True\n        if start_time.hour == 14 and start_time.minute == 59 and start_time.second == 0:\n            return True\n        return False\n    elif region == REG_TW:\n        if end_time - start_time < freq:\n            return True\n        if start_time.hour == 13 and start_time.minute >= 25 and start_time.second == 0:\n            return True\n        return False\n    elif region == REG_US:\n        if end_time - start_time < freq:\n            return True\n        if start_time.hour == 15 and start_time.minute == 59 and start_time.second == 0:\n            return True\n        return False\n    else:\n        raise NotImplementedError(f\"please implement the is_single_value func for {region}\")\n\n\nclass Freq:\n    NORM_FREQ_MONTH = \"month\"\n    NORM_FREQ_WEEK = \"week\"\n    NORM_FREQ_DAY = \"day\"\n    NORM_FREQ_MINUTE = \"min\"  # use \"min\" instead of \"minute\" to align with Qlib's data filenames\n    SUPPORT_CAL_LIST = [NORM_FREQ_MINUTE, NORM_FREQ_DAY]  # FIXME: this list should be derived from the data\n\n    def __init__(self, freq: Union[str, \"Freq\"]) -> None:\n        if 
isinstance(freq, str):\n            self.count, self.base = self.parse(freq)\n        elif isinstance(freq, Freq):\n            self.count, self.base = freq.count, freq.base\n        else:\n            raise NotImplementedError(f\"This type of input is not supported: {type(freq)}\")\n\n    def __eq__(self, freq):\n        freq = Freq(freq)\n        return freq.count == self.count and freq.base == self.base\n\n    def __str__(self):\n        # align with Qlib's data filenames: day, 30min, 5min, 1min...\n        return f\"{self.count if self.count != 1 or self.base != 'day' else ''}{self.base}\"\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}({str(self)})\"\n\n    @staticmethod\n    def parse(freq: str) -> Tuple[int, str]:\n        \"\"\"\n        Parse freq into a unified format\n\n        Parameters\n        ----------\n        freq : str\n            Raw freq, supported freq should match the re '^([0-9]*)(month|mon|week|w|day|d|minute|min)$'\n\n        Returns\n        -------\n        freq: Tuple[int, str]\n            Unified freq, including freq count and unified freq unit. The freq unit is one of 'month', 'week', 'day' or 'min'.\n                Example:\n\n                .. 
code-block::\n\n                    print(Freq.parse(\"day\"))\n                    (1, 'day')\n                    print(Freq.parse(\"2mon\"))\n                    (2, 'month')\n                    print(Freq.parse(\"10w\"))\n                    (10, 'week')\n\n        \"\"\"\n        freq = freq.lower()\n        match_obj = re.match(\"^([0-9]*)(month|mon|week|w|day|d|minute|min)$\", freq)\n        if match_obj is None:\n            raise ValueError(\n                \"freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min\"\n            )\n        _count = int(match_obj.group(1)) if match_obj.group(1) else 1\n        _freq = match_obj.group(2)\n        _freq_format_dict = {\n            \"month\": Freq.NORM_FREQ_MONTH,\n            \"mon\": Freq.NORM_FREQ_MONTH,\n            \"week\": Freq.NORM_FREQ_WEEK,\n            \"w\": Freq.NORM_FREQ_WEEK,\n            \"day\": Freq.NORM_FREQ_DAY,\n            \"d\": Freq.NORM_FREQ_DAY,\n            \"minute\": Freq.NORM_FREQ_MINUTE,\n            \"min\": Freq.NORM_FREQ_MINUTE,\n        }\n        return _count, _freq_format_dict[_freq]\n\n    @staticmethod\n    def get_timedelta(n: int, freq: str) -> pd.Timedelta:\n        \"\"\"\n        get a pd.Timedelta object\n\n        Parameters\n        ----------\n        n : int\n        freq : str\n            Typically, `n` and `freq` are the return values of Freq.parse\n\n        Returns\n        -------\n        pd.Timedelta:\n        \"\"\"\n        return pd.Timedelta(f\"{n}{freq}\")\n\n    @staticmethod\n    def get_min_delta(left_frq: str, right_freq: str):\n        \"\"\"Calculate the difference between two freqs in minutes\n\n        Parameters\n        ----------\n        left_frq: str\n        right_freq: str\n\n        Returns\n        -------\n        int:\n            the number of minutes of `left_frq` minus that of `right_freq` (a month is approximated as 30 days)\n        \"\"\"\n        minutes_map = {\n            Freq.NORM_FREQ_MINUTE: 1,\n            Freq.NORM_FREQ_DAY: 60 * 24,\n            Freq.NORM_FREQ_WEEK: 7 * 60 * 24,\n            Freq.NORM_FREQ_MONTH: 30 * 60 * 
24,\n        }\n        left_freq = Freq(left_frq)\n        left_minutes = left_freq.count * minutes_map[left_freq.base]\n        right_freq = Freq(right_freq)\n        right_minutes = right_freq.count * minutes_map[right_freq.base]\n        return left_minutes - right_minutes\n\n    @staticmethod\n    def get_recent_freq(base_freq: Union[str, \"Freq\"], freq_list: List[Union[str, \"Freq\"]]) -> Optional[\"Freq\"]:\n        \"\"\"Get the closest freq to base_freq from freq_list\n\n        Parameters\n        ----------\n        base_freq\n        freq_list\n\n        Returns\n        -------\n        if the recent frequency is found\n            Freq\n        else:\n            None\n        \"\"\"\n        base_freq = Freq(base_freq)\n        # use the nearest freq greater than 0\n        min_freq = None\n        for _freq in freq_list:\n            _min_delta = Freq.get_min_delta(base_freq, _freq)\n            if _min_delta < 0:\n                continue\n            if min_freq is None:\n                min_freq = (_min_delta, str(_freq))\n                continue\n            min_freq = min_freq if min_freq[0] <= _min_delta else (_min_delta, _freq)\n        return min_freq[1] if min_freq else None\n\n\ndef time_to_day_index(time_obj: Union[str, datetime], region: str = REG_CN):\n    if isinstance(time_obj, str):\n        time_obj = datetime.strptime(time_obj, \"%H:%M\")\n\n    if region == REG_CN:\n        if CN_TIME[0] <= time_obj < CN_TIME[1]:\n            return int((time_obj - CN_TIME[0]).total_seconds() / 60)\n        elif CN_TIME[2] <= time_obj < CN_TIME[3]:\n            return int((time_obj - CN_TIME[2]).total_seconds() / 60) + 120\n        else:\n            raise ValueError(f\"{time_obj} is not the opening time of the {region} stock market\")\n    elif region == REG_US:\n        if US_TIME[0] <= time_obj < US_TIME[1]:\n            return int((time_obj - US_TIME[0]).total_seconds() / 60)\n        else:\n            raise ValueError(f\"{time_obj} is not 
the opening time of the {region} stock market\")\n    elif region == REG_TW:\n        if TW_TIME[0] <= time_obj < TW_TIME[1]:\n            return int((time_obj - TW_TIME[0]).total_seconds() / 60)\n        else:\n            raise ValueError(f\"{time_obj} is not the opening time of the {region} stock market\")\n    else:\n        raise ValueError(f\"{region} is not supported\")\n\n\ndef get_day_min_idx_range(start: str, end: str, freq: str, region: str) -> Tuple[int, int]:\n    \"\"\"\n    get the min-bar index in a day for a time range (both left and right are closed) given a fixed frequency\n\n    Parameters\n    ----------\n    start : str\n        e.g. \"9:30\"\n    end : str\n        e.g. \"14:30\"\n    freq : str\n        e.g. \"1min\"\n    region : str\n        Region, for example, \"cn\", \"us\"\n\n    Returns\n    -------\n    Tuple[int, int]:\n        The index of start and end in the calendar. Both left and right are **closed**\n    \"\"\"\n    start = pd.Timestamp(start).time()\n    end = pd.Timestamp(end).time()\n    freq = Freq(freq)\n    in_day_cal = get_min_cal(region=region)[:: freq.count]\n    left_idx = bisect.bisect_left(in_day_cal, start)\n    right_idx = bisect.bisect_right(in_day_cal, end) - 1\n    return left_idx, right_idx\n\n\ndef concat_date_time(date_obj: date, time_obj: time) -> pd.Timestamp:\n    return pd.Timestamp(\n        datetime(\n            date_obj.year,\n            month=date_obj.month,\n            day=date_obj.day,\n            hour=time_obj.hour,\n            minute=time_obj.minute,\n            second=time_obj.second,\n            microsecond=time_obj.microsecond,\n        )\n    )\n\n\ndef cal_sam_minute(x: pd.Timestamp, sam_minutes: int, region: str = REG_CN) -> pd.Timestamp:\n    \"\"\"\n    align the minute-level data to a down-sampled calendar\n\n    e.g. 
align 10:38 to 10:35 in 5 minute-level (10:30 in 10 minute-level)\n\n    Parameters\n    ----------\n    x : pd.Timestamp\n        datetime to be aligned\n    sam_minutes : int\n        align to `sam_minutes` minute-level calendar\n    region: str\n        Region, for example, \"cn\", \"us\"\n\n    Returns\n    -------\n    pd.Timestamp:\n        the aligned datetime\n    \"\"\"\n    cal = get_min_cal(C.min_data_shift, region)[::sam_minutes]\n    idx = bisect.bisect_right(cal, x.time()) - 1\n    _date, new_time = x.date(), cal[idx]\n    return concat_date_time(_date, new_time)\n\n\ndef epsilon_change(date_time: pd.Timestamp, direction: str = \"backward\") -> pd.Timestamp:\n    \"\"\"\n    change the time by an infinitesimally small quantity (one second in this implementation).\n\n    Parameters\n    ----------\n    date_time : pd.Timestamp\n        the original time\n    direction : str\n        the direction in which the time is shifted\n        - \"backward\" for going to history\n        - \"forward\" for going to the future\n\n    Returns\n    -------\n    pd.Timestamp:\n        the shifted time\n    \"\"\"\n    if direction == \"backward\":\n        return date_time - pd.Timedelta(seconds=1)\n    elif direction == \"forward\":\n        return date_time + pd.Timedelta(seconds=1)\n    else:\n        raise ValueError(f\"Unknown direction: {direction}, please use 'backward' or 'forward'\")\n\n\nif __name__ == \"__main__\":\n    print(get_day_min_idx_range(\"8:30\", \"14:59\", \"10min\", REG_CN))\n"
  },
  {
    "path": "qlib/workflow/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nMotivation of this design (instead of using mlflow directly):\n- Better design than mlflow native design\n    - we have a recorder object with many methods (more intuitive), instead of using run_id every time as in mlflow\n        - So the recorder's interfaces, like log and start, are more intuitive.\n- Provide richer and more tailored features than native mlflow\n    - Logging the code diff at the start of a run.\n    - log_object and load_object for Python objects directly, instead of log_artifact and download_artifact\n- (weak) Allow diverse backend support\n\nTo be honest, this design also adds burdens. For example,\n- You need to create an experiment before you can get a recorder. (In MLflow, experiments are more like tags, and you often just use a run_id in many interfaces without first defining an experiment.)\n\"\"\"\n\nfrom contextlib import contextmanager\nfrom typing import Text, Optional, Any, Dict\nfrom .expm import ExpManager\nfrom .exp import Experiment\nfrom .recorder import Recorder\nfrom ..utils import Wrapper\nfrom ..utils.exceptions import RecorderInitializationError\n\n\nclass QlibRecorder:\n    \"\"\"\n    A global system that helps to manage the experiments.\n    \"\"\"\n\n    def __init__(self, exp_manager: ExpManager):\n        self.exp_manager: ExpManager = exp_manager\n\n    def __repr__(self):\n        return \"{name}(manager={manager})\".format(name=self.__class__.__name__, manager=self.exp_manager)\n\n    @contextmanager\n    def start(\n        self,\n        *,\n        experiment_id: Optional[Text] = None,\n        experiment_name: Optional[Text] = None,\n        recorder_id: Optional[Text] = None,\n        recorder_name: Optional[Text] = None,\n        uri: Optional[Text] = None,\n        resume: bool = False,\n    ):\n        \"\"\"\n        Method to start an experiment. This method can only be called within a Python `with` statement. 
Here is the example code:\n\n        .. code-block:: Python\n\n            # start new experiment and recorder\n            with R.start(experiment_name='test', recorder_name='recorder_1'):\n                model.fit(dataset)\n                R.log...\n                ... # further operations\n\n            # resume previous experiment and recorder\n            with R.start(experiment_name='test', recorder_name='recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder.\n                ... # further operations\n\n\n        Parameters\n        ----------\n        experiment_id : str\n            id of the experiment one wants to start.\n        experiment_name : str\n            name of the experiment one wants to start.\n        recorder_id : str\n            id of the recorder under the experiment one wants to start.\n        recorder_name : str\n            name of the recorder under the experiment one wants to start.\n        uri : str\n            The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.\n            The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file.\n            Therefore, the next time when users call this function in the same experiment,\n            they have to also specify this argument with the same value. 
Otherwise, inconsistent uri may occur.\n        resume : bool\n            whether to resume the specific recorder with the given name under the given experiment.\n        \"\"\"\n        run = self.start_exp(\n            experiment_id=experiment_id,\n            experiment_name=experiment_name,\n            recorder_id=recorder_id,\n            recorder_name=recorder_name,\n            uri=uri,\n            resume=resume,\n        )\n        try:\n            yield run\n        except Exception as e:\n            self.end_exp(Recorder.STATUS_FA)  # end the experiment if something went wrong\n            raise e\n        self.end_exp(Recorder.STATUS_FI)\n\n    def start_exp(\n        self,\n        *,\n        experiment_id=None,\n        experiment_name=None,\n        recorder_id=None,\n        recorder_name=None,\n        uri=None,\n        resume=False,\n    ):\n        \"\"\"\n        Lower level method for starting an experiment. When using this method, one should end the experiment manually,\n        and the status of the recorder may not be handled properly. Here is the example code:\n\n        .. code-block:: Python\n\n            R.start_exp(experiment_name='test', recorder_name='recorder_1')\n            ... # further operations\n            R.end_exp('FINISHED')  # or R.end_exp(Recorder.STATUS_FI)\n\n\n        Parameters\n        ----------\n        experiment_id : str\n            id of the experiment one wants to start.\n        experiment_name : str\n            the name of the experiment to be started\n        recorder_id : str\n            id of the recorder under the experiment one wants to start.\n        recorder_name : str\n            name of the recorder under the experiment one wants to start.\n        uri : str\n            the tracking uri of the experiment, where all the artifacts/metrics etc. 
will be stored.\n            The default uri is set in the qlib.config.\n        resume : bool\n            whether to resume the specific recorder with the given name under the given experiment.\n\n        Returns\n        -------\n        An experiment instance being started.\n        \"\"\"\n        return self.exp_manager.start_exp(\n            experiment_id=experiment_id,\n            experiment_name=experiment_name,\n            recorder_id=recorder_id,\n            recorder_name=recorder_name,\n            uri=uri,\n            resume=resume,\n        )\n\n    def end_exp(self, recorder_status=Recorder.STATUS_FI):\n        \"\"\"\n        Method for ending an experiment manually. It will end the current active experiment, as well as its\n        active recorder with the specified `recorder_status`. Here is the example code of the method:\n\n        .. code-block:: Python\n\n            R.start_exp(experiment_name='test')\n            ... # further operations\n            R.end_exp('FINISHED')  # or R.end_exp(Recorder.STATUS_FI)\n\n        Parameters\n        ----------\n        recorder_status : str\n            The status of a recorder, which can be SCHEDULED, RUNNING, FINISHED, FAILED.\n        \"\"\"\n        self.exp_manager.end_exp(recorder_status)\n\n    def search_records(self, experiment_ids, **kwargs):\n        \"\"\"\n        Get a pandas DataFrame of records that fit the search criteria.\n\n        The arguments of this function are not fixed; they differ across implementations of\n        ``ExpManager`` in ``Qlib``. ``Qlib`` now provides an implementation of ``ExpManager`` with mlflow, and here is the\n        example code of the method with the ``MLflowExpManager``:\n\n        .. 
code-block:: Python\n\n            R.log_metrics(m=2.50, step=0)\n            records = R.search_records([experiment_id], order_by=[\"metrics.m DESC\"])\n\n        Parameters\n        ----------\n        experiment_ids : list\n            list of experiment IDs.\n        filter_string : str\n            filter query string, defaults to searching all runs.\n        run_view_type : int\n            one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL (e.g. in mlflow.entities.ViewType).\n        max_results : int\n            the maximum number of runs to put in the dataframe.\n        order_by : list\n            list of columns to order by (e.g., “metrics.rmse”).\n\n        Returns\n        -------\n        A pandas.DataFrame of records, where each metric, parameter, and tag\n        is expanded into its own column, named metrics.*, params.*, or tags.*\n        respectively. For records that don't have a particular metric, parameter, or tag, their\n        value will be (NumPy) NaN, None, or None respectively.\n        \"\"\"\n        return self.exp_manager.search_records(experiment_ids, **kwargs)\n\n    def list_experiments(self):\n        \"\"\"\n        Method for listing all the existing experiments (except for those that have been deleted).\n\n        .. code-block:: Python\n\n            exps = R.list_experiments()\n\n        Returns\n        -------\n        A dictionary (name -> experiment) of the stored experiments.\n        \"\"\"\n        return self.exp_manager.list_experiments()\n\n    def list_recorders(self, experiment_id=None, experiment_name=None):\n        \"\"\"\n        Method for listing all the recorders of the experiment with the given id or name.\n\n        If the user doesn't provide the id or name of the experiment, this method will try to retrieve the default experiment and\n        list all the recorders of the default experiment. 
If the default experiment doesn't exist, the method will first\n        create the default experiment, and then create a new recorder under it. (More information about the default experiment\n        can be found `here <../component/recorder.html#qlib.workflow.exp.Experiment>`__).\n\n        Here is the example code:\n\n        .. code-block:: Python\n\n            recorders = R.list_recorders(experiment_name='test')\n\n        Parameters\n        ----------\n        experiment_id : str\n            id of the experiment.\n        experiment_name : str\n            name of the experiment.\n\n        Returns\n        -------\n        A dictionary (id -> recorder) of the stored recorders.\n        \"\"\"\n        return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()\n\n    def get_exp(\n        self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False\n    ) -> Experiment:\n        \"\"\"\n        Method for retrieving an experiment with given id or name. If the `create` argument is set to\n        True and no valid experiment is found, this method will create one for you. Otherwise, it will\n        only retrieve a specific experiment or raise an Error.\n\n        - If '`create`' is True:\n\n            - If `active experiment` exists:\n\n                - no id or name specified, return the active experiment.\n\n                - if id or name is specified, return the specified experiment. If no such experiment is found, create a new experiment with the given id or name.\n\n            - If `active experiment` does not exist:\n\n                - no id or name specified, create a default experiment and set it to be active.\n\n                - if id or name is specified, return the specified experiment. 
If no such experiment is found, create a new experiment with the given name or the default experiment.\n\n        - If '`create`' is False:\n\n            - If `active experiment` exists:\n\n                - no id or name specified, return the active experiment.\n\n                - if id or name is specified, return the specified experiment. If no such experiment is found, raise Error.\n\n            - If `active experiment` does not exist:\n\n                - no id or name specified: if the default experiment exists, return it; otherwise, raise Error.\n\n                - if id or name is specified, return the specified experiment. If no such experiment is found, raise Error.\n\n        Here are some use cases:\n\n        .. code-block:: Python\n\n            # Case 1\n            with R.start(experiment_name='test'):\n                exp = R.get_exp()\n                recorders = exp.list_recorders()\n\n            # Case 2\n            with R.start(experiment_name='test'):\n                exp = R.get_exp(experiment_name='test1')\n\n            # Case 3\n            exp = R.get_exp()  # a default experiment\n\n            # Case 4\n            exp = R.get_exp(experiment_name='test')\n\n            # Case 5\n            exp = R.get_exp(create=False)  # the default experiment, if it exists\n\n        Parameters\n        ----------\n        experiment_id : str\n            id of the experiment.\n        experiment_name : str\n            name of the experiment.\n        create : bool\n            whether the method will automatically create a new experiment\n            according to the user's specification if the experiment hasn't been created before.\n        start : bool\n            when start is True, the experiment will be started if it has not started yet (i.e. not activated).\n            It is designed for R.log_params to start experiments automatically.\n\n        Returns\n        -------\n        An experiment instance with given id or name.\n        \"\"\"\n        return self.exp_manager.get_exp(\n        
    experiment_id=experiment_id,\n            experiment_name=experiment_name,\n            create=create,\n            start=start,\n        )\n\n    def delete_exp(self, experiment_id=None, experiment_name=None):\n        \"\"\"\n        Method for deleting the experiment with given id or name. At least one of id or name must be given,\n        otherwise an error will occur.\n\n        Here is the example code:\n\n        .. code-block:: Python\n\n            R.delete_exp(experiment_name='test')\n\n        Parameters\n        ----------\n        experiment_id : str\n            id of the experiment.\n        experiment_name : str\n            name of the experiment.\n        \"\"\"\n        self.exp_manager.delete_exp(experiment_id, experiment_name)\n\n    def get_uri(self):\n        \"\"\"\n        Method for retrieving the uri of the current experiment manager.\n\n        Here is the example code:\n\n        .. code-block:: Python\n\n            uri = R.get_uri()\n\n        Returns\n        -------\n        The uri of the current experiment manager.\n        \"\"\"\n        return self.exp_manager.uri\n\n    def set_uri(self, uri: Optional[Text]):\n        \"\"\"\n        Method to reset the **default** uri of the current experiment manager.\n\n        NOTE:\n\n        - When the uri refers to a file path, please use an absolute path instead of strings like \"~/mlruns/\";\n          the backend doesn't support strings like this.\n        \"\"\"\n        self.exp_manager.default_uri = uri\n\n    @contextmanager\n    def uri_context(self, uri: Text):\n        \"\"\"\n        Temporarily set the exp_manager's **default_uri** to uri\n\n        NOTE:\n\n        - Please refer to the NOTE in `set_uri`\n\n        Parameters\n        ----------\n        uri : Text\n            the temporary uri\n        \"\"\"\n        prev_uri = self.exp_manager.default_uri\n        self.exp_manager.default_uri = uri\n        try:\n            yield\n        finally:\n            
self.exp_manager.default_uri = prev_uri\n\n    def get_recorder(\n        self,\n        *,\n        recorder_id=None,\n        recorder_name=None,\n        experiment_id=None,\n        experiment_name=None,\n    ) -> Recorder:\n        \"\"\"\n        Method for retrieving a recorder.\n\n        - If `active recorder` exists:\n\n            - no id or name specified, return the active recorder.\n\n            - if id or name is specified, return the specified recorder.\n\n        - If `active recorder` does not exist:\n\n            - no id or name specified, raise Error.\n\n            - if id or name is specified and the corresponding experiment_name is also given, return the specified recorder. Otherwise, raise Error.\n\n        The recorder can be used for further processing, such as `save_objects`, `load_object`, `log_params`,\n        `log_metrics`, etc.\n\n        Here are some use cases:\n\n        .. code-block:: Python\n\n            # Case 1\n            with R.start(experiment_name='test'):\n                recorder = R.get_recorder()\n\n            # Case 2\n            with R.start(experiment_name='test'):\n                recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')\n\n            # Case 3\n            recorder = R.get_recorder()  # raises an Error\n\n            # Case 4\n            recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')  # raises an Error\n\n            # Case 5\n            recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d', experiment_name='test')\n\n\n        Here are some things users may be concerned about:\n\n        - Q: Which recorder will be returned if multiple recorders meet the query (e.g. a query with experiment_name)?\n        - A: If the mlflow backend is used, the recorder with the latest `start_time` will be returned. 
This is guaranteed by MLflow's `search_runs` function.\n\n        Parameters\n        ----------\n        recorder_id : str\n            id of the recorder.\n        recorder_name : str\n            name of the recorder.\n        experiment_id : str\n            id of the experiment.\n        experiment_name : str\n            name of the experiment.\n\n        Returns\n        -------\n        A recorder instance.\n        \"\"\"\n        return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder(\n            recorder_id, recorder_name, create=False, start=False\n        )\n\n    def delete_recorder(self, recorder_id=None, recorder_name=None):\n        \"\"\"\n        Method for deleting the recorder with given id or name. At least one of id or name must be given,\n        otherwise an error will occur.\n\n        Here is the example code:\n\n        .. code-block:: Python\n\n            R.delete_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')\n\n        Parameters\n        ----------\n        recorder_id : str\n            id of the recorder.\n        recorder_name : str\n            name of the recorder.\n        \"\"\"\n        self.get_exp().delete_recorder(recorder_id, recorder_name)\n\n    def save_objects(self, local_path=None, artifact_path=None, **kwargs: Dict[Text, Any]):\n        \"\"\"\n        Method for saving objects as artifacts in the experiment to the uri. It supports either saving\n        from a local file/directory, or directly saving objects. Users can use valid Python keyword arguments\n        to specify the object to be saved as well as its name (name: value).\n\n        In summary, this API is designed for saving **objects** to **the experiments management backend path**:\n        1. Qlib provides two methods to specify **objects**\n        - Passing in the object directly via `**kwargs` (e.g. R.save_objects(trained_model=model))\n        - Passing in the local path to the object, i.e. the `local_path` parameter.\n        2. 
`artifact_path` represents **the experiments management backend path**.\n\n        - If an `active recorder` exists: it will save the objects through the active recorder.\n        - If no `active recorder` exists: the system will create a default experiment and a new recorder, and save the objects under it.\n\n        .. note::\n\n            If one wants to save objects with a specific recorder, it is recommended to first get that recorder through the `get_recorder` API and use it to save the objects. The supported arguments are the same as this method.\n\n        Here are some use cases:\n\n        .. code-block:: Python\n\n            # Case 1\n            with R.start(experiment_name='test'):\n                pred = model.predict(dataset)\n                R.save_objects(**{\"pred.pkl\": pred}, artifact_path='prediction')\n                rid = R.get_recorder().id\n            ...\n            R.get_recorder(recorder_id=rid, experiment_name='test').load_object(\"prediction/pred.pkl\")  # after saving, the object can be loaded back with this API\n\n            # Case 2\n            with R.start(experiment_name='test'):\n                R.save_objects(local_path='results/pred.pkl', artifact_path=\"prediction\")\n                rid = R.get_recorder().id\n            ...\n            R.get_recorder(recorder_id=rid, experiment_name='test').load_object(\"prediction/pred.pkl\")  # after saving, the object can be loaded back with this API\n\n\n        Parameters\n        ----------\n        local_path : str\n            if provided, then save the file or directory to the artifact URI.\n        artifact_path : str\n            the relative path for the artifact to be stored in the URI.\n        **kwargs: Dict[Text, Any]\n            the objects to be saved, keyed by name.\n            For example, `{\"pred.pkl\": pred}`\n        \"\"\"\n        if local_path is not None and len(kwargs) > 0:\n            raise ValueError(\n                \"You can choose only one of `local_path`(save the files in a 
path) or `kwargs`(pass in the objects directly)\"\n            )\n        self.get_exp().get_recorder(start=True).save_objects(local_path, artifact_path, **kwargs)\n\n    def load_object(self, name: Text):\n        \"\"\"\n        Method for loading an object from the artifacts of the experiment stored in the uri.\n        \"\"\"\n        return self.get_exp().get_recorder(start=True).load_object(name)\n\n    def log_params(self, **kwargs):\n        \"\"\"\n        Method for logging parameters during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with the `get_recorder` API.\n\n        - If an `active recorder` exists: it will log parameters through the active recorder.\n        - If no `active recorder` exists: the system will create a default experiment as well as a new recorder, and log parameters under it.\n\n        Here are some use cases:\n\n        .. code-block:: Python\n\n            # Case 1\n            with R.start(experiment_name='test'):\n                R.log_params(learning_rate=0.01)\n\n            # Case 2\n            R.log_params(learning_rate=0.01)\n\n        Parameters\n        ----------\n        keyword arguments:\n            name1=value1, name2=value2, ...\n        \"\"\"\n        self.get_exp(start=True).get_recorder(start=True).log_params(**kwargs)\n\n    def log_metrics(self, step=None, **kwargs):\n        \"\"\"\n        Method for logging metrics during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with the `get_recorder` API.\n\n        - If an `active recorder` exists: it will log metrics through the active recorder.\n        - If no `active recorder` exists: the system will create a default experiment as well as a new recorder, and log metrics under it.\n\n        Here are some use cases:\n\n        .. 
code-block:: Python\n\n            # Case 1\n            with R.start(experiment_name='test'):\n                R.log_metrics(train_loss=0.33, step=1)\n\n            # Case 2\n            R.log_metrics(train_loss=0.33, step=1)\n\n        Parameters\n        ----------\n        step : int\n            the step at which the metrics are logged (optional).\n        keyword arguments:\n            name1=value1, name2=value2, ...\n        \"\"\"\n        self.get_exp(start=True).get_recorder(start=True).log_metrics(step, **kwargs)\n\n    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):\n        \"\"\"\n        Log a local file or directory as an artifact of the currently active run.\n\n        - If an `active recorder` exists: it will log the artifact through the active recorder.\n        - If no `active recorder` exists: the system will create a default experiment as well as a new recorder, and log the artifact under it.\n\n        Parameters\n        ----------\n        local_path : str\n            Path to the file or directory to log.\n        artifact_path : Optional[str]\n            If provided, the directory in ``artifact_uri`` to write to.\n        \"\"\"\n        self.get_exp(start=True).get_recorder(start=True).log_artifact(local_path, artifact_path)\n\n    def download_artifact(self, path: str, dst_path: Optional[str] = None) -> str:\n        \"\"\"\n        Download an artifact file or directory from a run to a local directory if applicable,\n        and return a local path for it.\n\n        Parameters\n        ----------\n        path : str\n            Relative source path to the desired artifact.\n        dst_path : Optional[str]\n            Absolute path of the local filesystem destination directory to which to\n            download the specified artifacts. 
This directory must already exist.\n            If unspecified, the artifacts will be downloaded to a new\n            uniquely-named directory on the local filesystem.\n\n        Returns\n        -------\n        str\n            Local path of desired artifact.\n        \"\"\"\n        return self.get_exp(start=True).get_recorder(start=True).download_artifact(path, dst_path)\n\n    def set_tags(self, **kwargs):\n        \"\"\"\n        Method for setting tags for a recorder. In addition to using ``R``, one can also set tags on a specific recorder after getting it with the `get_recorder` API.\n\n        - If an `active recorder` exists: it will set tags through the active recorder.\n        - If no `active recorder` exists: the system will create a default experiment as well as a new recorder, and set the tags under it.\n\n        Here are some use cases:\n\n        .. code-block:: Python\n\n            # Case 1\n            with R.start(experiment_name='test'):\n                R.set_tags(release_version=\"2.2.0\")\n\n            # Case 2\n            R.set_tags(release_version=\"2.2.0\")\n\n        Parameters\n        ----------\n        keyword arguments:\n            name1=value1, name2=value2, ...\n        \"\"\"\n        self.get_exp(start=True).get_recorder(start=True).set_tags(**kwargs)\n\n\nclass RecorderWrapper(Wrapper):\n    \"\"\"\n    Wrapper class for QlibRecorder, which detects whether users reinitialize qlib after an experiment has already been started.\n    \"\"\"\n\n    def register(self, provider):\n        if self._provider is not None:\n            expm = getattr(self._provider, \"exp_manager\")\n            if expm.active_experiment is not None:\n                raise RecorderInitializationError(\n                    \"Please don't reinitialize Qlib if QlibRecorder is already activated. 
Otherwise, the location where experiments are stored will be changed.\"\n                )\n        self._provider = provider\n\n\nimport sys\n\nif sys.version_info >= (3, 9):\n    from typing import Annotated\n\n    QlibRecorderWrapper = Annotated[QlibRecorder, RecorderWrapper]\nelse:\n    QlibRecorderWrapper = QlibRecorder\n\n# global recorder\nR: QlibRecorderWrapper = RecorderWrapper()\n"
  },
  {
    "path": "qlib/workflow/exp.py",
    "content": "# Copyright (c) Microsoft Corporation.\r\n# Licensed under the MIT License.\r\n\r\nfrom typing import Dict, List, Union\r\nfrom qlib.typehint import Literal\r\nimport mlflow\r\nfrom mlflow.entities import ViewType\r\nfrom mlflow.exceptions import MlflowException\r\nfrom .recorder import Recorder, MLflowRecorder\r\nfrom ..log import get_module_logger\r\n\r\nlogger = get_module_logger(\"workflow\")\r\n\r\n\r\nclass Experiment:\r\n    \"\"\"\r\n    This is the `Experiment` class for each experiment being run. The API is designed similar to mlflow.\r\n    (The link: https://mlflow.org/docs/latest/python_api/mlflow.html)\r\n    \"\"\"\r\n\r\n    def __init__(self, id, name):\r\n        self.id = id\r\n        self.name = name\r\n        self.active_recorder = None  # only one recorder can run each time\r\n        self._default_rec_name = \"abstract_recorder\"\r\n\r\n    def __repr__(self):\r\n        return \"{name}(id={id}, info={info})\".format(name=self.__class__.__name__, id=self.id, info=self.info)\r\n\r\n    def __str__(self):\r\n        return str(self.info)\r\n\r\n    @property\r\n    def info(self):\r\n        recorders = self.list_recorders()\r\n        output = dict()\r\n        output[\"class\"] = \"Experiment\"\r\n        output[\"id\"] = self.id\r\n        output[\"name\"] = self.name\r\n        output[\"active_recorder\"] = self.active_recorder.id if self.active_recorder is not None else None\r\n        output[\"recorders\"] = list(recorders.keys())\r\n        return output\r\n\r\n    def start(self, *, recorder_id=None, recorder_name=None, resume=False):\r\n        \"\"\"\r\n        Start the experiment and set it to be active. 
This method will also start a new recorder.\r\n\r\n        Parameters\r\n        ----------\r\n        recorder_id : str\r\n            the id of the recorder to be created.\r\n        recorder_name : str\r\n            the name of the recorder to be created.\r\n        resume : bool\r\n            whether to resume an existing recorder with the given id or name instead of creating a new one.\r\n\r\n        Returns\r\n        -------\r\n        An active recorder.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `start` method.\")\r\n\r\n    def end(self, recorder_status=Recorder.STATUS_S):\r\n        \"\"\"\r\n        End the experiment.\r\n\r\n        Parameters\r\n        ----------\r\n        recorder_status : str\r\n            the status to set on the recorder when ending (SCHEDULED, RUNNING, FINISHED, FAILED).\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `end` method.\")\r\n\r\n    def create_recorder(self, recorder_name=None):\r\n        \"\"\"\r\n        Create a recorder for this experiment.\r\n\r\n        Parameters\r\n        ----------\r\n        recorder_name : str\r\n            the name of the recorder to be created.\r\n\r\n        Returns\r\n        -------\r\n        A recorder object.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `create_recorder` method.\")\r\n\r\n    def search_records(self, **kwargs):\r\n        \"\"\"\r\n        Get a pandas DataFrame of records that fit the search criteria of the experiment.\r\n        Inputs are the search criteria the user wants to apply.\r\n\r\n        Returns\r\n        -------\r\n        A pandas.DataFrame of records, where each metric, parameter, and tag\r\n        is expanded into its own column named metrics.*, params.*, or tags.*\r\n        respectively. 
For records that don't have a particular metric, parameter, or tag, their\r\n        value will be (NumPy) NaN, None, or None respectively.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `search_records` method.\")\r\n\r\n    def delete_recorder(self, recorder_id):\r\n        \"\"\"\r\n        Delete a recorder of this experiment.\r\n\r\n        Parameters\r\n        ----------\r\n        recorder_id : str\r\n            the id of the recorder to be deleted.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `delete_recorder` method.\")\r\n\r\n    def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True, start: bool = False) -> Recorder:\r\n        \"\"\"\r\n        Retrieve a Recorder for the user. When the user specifies a recorder id or name, the method will try to return the\r\n        specific recorder. When the user does not provide a recorder id or name, the method will try to return the current\r\n        active recorder. The `create` argument determines whether the method will automatically create a new recorder\r\n        according to the user's specification if the recorder hasn't been created before.\r\n\r\n        * If `create` is True:\r\n\r\n            * If an `active recorder` exists:\r\n\r\n                * no id or name specified, return the active recorder.\r\n                * if id or name is specified, return the specified recorder. If no such recorder is found, create a new recorder with the given id or name. If `start` is set to True, the recorder is set to be active.\r\n\r\n            * If no `active recorder` exists:\r\n\r\n                * no id or name specified, return the default recorder, creating it if it does not exist.\r\n                * if id or name is specified, return the specified recorder. If no such recorder is found, create a new recorder with the given id or name. 
If `start` is set to True, the recorder is set to be active.\r\n\r\n        * Else, if `create` is False:\r\n\r\n            * If an `active recorder` exists:\r\n\r\n                * no id or name specified, return the active recorder.\r\n                * if id or name is specified, return the specified recorder. If no such recorder is found, raise Error.\r\n\r\n            * If no `active recorder` exists:\r\n\r\n                * no id or name specified, return the default recorder if it exists; otherwise, raise Error.\r\n                * if id or name is specified, return the specified recorder. If no such recorder is found, raise Error.\r\n\r\n        Parameters\r\n        ----------\r\n        recorder_id : str\r\n            the id of the recorder to be retrieved.\r\n        recorder_name : str\r\n            the name of the recorder to be retrieved.\r\n        create : boolean\r\n            create the recorder if it hasn't been created before.\r\n        start : boolean\r\n            start the new recorder if one is **created**.\r\n\r\n        Returns\r\n        -------\r\n        A recorder object.\r\n        \"\"\"\r\n        # special case of getting the recorder\r\n        if recorder_id is None and recorder_name is None:\r\n            if self.active_recorder is not None:\r\n                return self.active_recorder\r\n            recorder_name = self._default_rec_name\r\n        if create:\r\n            recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)\r\n        else:\r\n            recorder, is_new = (\r\n                self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),\r\n                False,\r\n            )\r\n        if is_new and start:\r\n            self.active_recorder = recorder\r\n            # start the recorder\r\n            self.active_recorder.start_run()\r\n        return recorder\r\n\r\n    def _get_or_create_rec(self, recorder_id=None, recorder_name=None) -> (object, bool):\r\n        \"\"\"\r\n        Method for getting or 
creating a recorder. It first tries to get a valid recorder; if a `ValueError` is raised, it\r\n        automatically creates a new recorder with the given name (or the default name).\r\n        \"\"\"\r\n        try:\r\n            if recorder_id is None and recorder_name is None:\r\n                recorder_name = self._default_rec_name\r\n            return (\r\n                self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),\r\n                False,\r\n            )\r\n        except ValueError:\r\n            if recorder_name is None:\r\n                recorder_name = self._default_rec_name\r\n            logger.info(f\"No valid recorder found. Create a new recorder with name {recorder_name}.\")\r\n            return self.create_recorder(recorder_name), True\r\n\r\n    def _get_recorder(self, recorder_id=None, recorder_name=None):\r\n        \"\"\"\r\n        Get a specific recorder by name or id. If it does not exist, raise ValueError.\r\n\r\n        Parameters\r\n        ----------\r\n        recorder_id :\r\n            The id of the recorder\r\n        recorder_name :\r\n            The name of the recorder\r\n\r\n        Returns\r\n        -------\r\n        Recorder:\r\n            The searched recorder\r\n\r\n        Raises\r\n        ------\r\n        ValueError\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `_get_recorder` method\")\r\n\r\n    RT_D = \"dict\"  # return type dict\r\n    RT_L = \"list\"  # return type list\r\n\r\n    def list_recorders(\r\n        self, rtype: Literal[\"dict\", \"list\"] = RT_D, **flt_kwargs\r\n    ) -> Union[List[Recorder], Dict[str, Recorder]]:\r\n        \"\"\"\r\n        List all the existing recorders of this experiment. 
Please first get the experiment instance before calling this method.\r\n        If the user wants to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.\r\n\r\n        Parameters\r\n        ----------\r\n        rtype : str\r\n            the return type, either `dict` or `list`.\r\n        flt_kwargs : dict\r\n            filter recorders by conditions,\r\n            e.g. list_recorders(status=Recorder.STATUS_FI)\r\n\r\n        Returns\r\n        -------\r\n        The return type depends on `rtype`:\r\n            if `rtype` == \"dict\":\r\n                A dictionary (id -> recorder) of the stored recorders.\r\n            elif `rtype` == \"list\":\r\n                A list of recorders.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `list_recorders` method.\")\r\n\r\n\r\nclass MLflowExperiment(Experiment):\r\n    \"\"\"\r\n    Use mlflow to implement Experiment.\r\n    \"\"\"\r\n\r\n    def __init__(self, id, name, uri):\r\n        super(MLflowExperiment, self).__init__(id, name)\r\n        self._uri = uri\r\n        self._default_rec_name = \"mlflow_recorder\"\r\n        self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)\r\n\r\n    def __repr__(self):\r\n        return \"{name}(id={id}, info={info})\".format(name=self.__class__.__name__, id=self.id, info=self.info)\r\n\r\n    def start(self, *, recorder_id=None, recorder_name=None, resume=False):\r\n        logger.info(f\"Experiment {self.id} starts running ...\")\r\n        # Get or create recorder\r\n        if recorder_name is None:\r\n            recorder_name = self._default_rec_name\r\n        # resume the recorder\r\n        if resume:\r\n            recorder, _ = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)\r\n        # create a new recorder\r\n        else:\r\n            recorder = self.create_recorder(recorder_name)\r\n        # Set up active recorder\r\n        self.active_recorder = recorder\r\n        # Start the recorder\r\n        self.active_recorder.start_run()\r\n\r\n        
return self.active_recorder\r\n\r\n    def end(self, recorder_status=Recorder.STATUS_S):\r\n        if self.active_recorder is not None:\r\n            self.active_recorder.end_run(recorder_status)\r\n            self.active_recorder = None\r\n\r\n    def create_recorder(self, recorder_name=None):\r\n        if recorder_name is None:\r\n            recorder_name = self._default_rec_name\r\n        recorder = MLflowRecorder(self.id, self._uri, recorder_name)\r\n\r\n        return recorder\r\n\r\n    def _get_recorder(self, recorder_id=None, recorder_name=None):\r\n        \"\"\"\r\n        Method for getting a recorder by id or name. If no matching recorder is found, it raises a ValueError;\r\n        it never creates a new recorder.\r\n\r\n        Quoting the docs of search_runs from MLflow (used when searching by name):\r\n        > The default ordering is to sort by start_time DESC, then run_id.\r\n        \"\"\"\r\n        assert (\r\n            recorder_id is not None or recorder_name is not None\r\n        ), \"Please input at least one of recorder id or name before retrieving recorder.\"\r\n        if recorder_id is not None:\r\n            try:\r\n                run = self._client.get_run(recorder_id)\r\n                recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)\r\n                return recorder\r\n            except MlflowException as mlflow_exp:\r\n                raise ValueError(\r\n                    \"No valid recorder has been found, please make sure the input recorder id is correct.\"\r\n                ) from mlflow_exp\r\n        elif recorder_name is not None:\r\n            logger.warning(\r\n                f\"Please make sure the recorder name {recorder_name} is unique; if several recorders match the given name, only the latest one will be returned.\"\r\n            )\r\n            recorders = self.list_recorders()\r\n            for rid in recorders:\r\n                if recorders[rid].name == recorder_name:\r\n                    return 
recorders[rid]\r\n            raise ValueError(\"No valid recorder has been found, please make sure the input recorder name is correct.\")\r\n\r\n    def search_records(self, **kwargs):\r\n        filter_string = \"\" if kwargs.get(\"filter_string\") is None else kwargs.get(\"filter_string\")\r\n        run_view_type = ViewType.ACTIVE_ONLY if kwargs.get(\"run_view_type\") is None else kwargs.get(\"run_view_type\")\r\n        max_results = 100000 if kwargs.get(\"max_results\") is None else kwargs.get(\"max_results\")\r\n        order_by = kwargs.get(\"order_by\")\r\n\r\n        return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)\r\n\r\n    def delete_recorder(self, recorder_id=None, recorder_name=None):\r\n        assert (\r\n            recorder_id is not None or recorder_name is not None\r\n        ), \"Please input a valid recorder id or name before deleting.\"\r\n        try:\r\n            if recorder_id is not None:\r\n                self._client.delete_run(recorder_id)\r\n            else:\r\n                recorder = self._get_recorder(recorder_name=recorder_name)\r\n                self._client.delete_run(recorder.id)\r\n        except MlflowException as e:\r\n            raise ValueError(\r\n                f\"Error: {e}. Something went wrong when deleting the recorder. 
Please check if the name/id of the recorder is correct.\"\r\n            ) from e\r\n\r\n    UNLIMITED = 50000  # FIXME: MLflow can list at most 50000 records\r\n\r\n    def list_recorders(\r\n        self,\r\n        rtype: Literal[\"dict\", \"list\"] = Experiment.RT_D,\r\n        max_results: int = UNLIMITED,\r\n        status: Union[str, None] = None,\r\n        filter_string: str = \"\",\r\n    ):\r\n        \"\"\"\r\n        Quoting the docs of search_runs from MLflow:\r\n        > The default ordering is to sort by start_time DESC, then run_id.\r\n\r\n        Parameters\r\n        ----------\r\n        max_results : int\r\n            the maximum number of results.\r\n        status : str\r\n            the criteria based on status to filter results.\r\n            `None` indicates no filtering.\r\n        filter_string : str\r\n            mlflow supported filter string like 'params.\"my_param\"=\"a\" and tags.\"my_tag\"=\"b\"'; using it helps reduce the number of returned runs.\r\n        \"\"\"\r\n        runs = self._client.search_runs(\r\n            self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results, filter_string=filter_string\r\n        )\r\n        rids = []\r\n        recorders = []\r\n        for run in runs:\r\n            recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)\r\n            if status is None or recorder.status == status:\r\n                rids.append(run.info.run_id)\r\n                recorders.append(recorder)\r\n\r\n        if rtype == Experiment.RT_D:\r\n            return dict(zip(rids, recorders))\r\n        elif rtype == Experiment.RT_L:\r\n            return recorders\r\n        else:\r\n            raise NotImplementedError(f\"This type of input is not supported: {rtype}\")\r\n"
  },
  {
    "path": "qlib/workflow/expm.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom urllib.parse import urlparse\nimport mlflow\nfrom filelock import FileLock\nfrom mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCode\nfrom mlflow.entities import ViewType\nimport os\nfrom typing import Optional, Text\nfrom pathlib import Path\n\nfrom .exp import MLflowExperiment, Experiment\nfrom ..config import C\nfrom .recorder import Recorder\nfrom ..log import get_module_logger\nfrom ..utils.exceptions import ExpAlreadyExistError\n\nlogger = get_module_logger(\"workflow\")\n\n\nclass ExpManager:\n    \"\"\"\n    This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.\n    (The link: https://mlflow.org/docs/latest/python_api/mlflow.html)\n\n    The `ExpManager` is expected to be a singleton (btw, we can have multiple `Experiment`s with different uri. user can get different experiments from different uri, and then compare records of them). Global Config (i.e. `C`)  is also a singleton.\n\n    So we try to align them together.  They share the same variable, which is called **default uri**. Please refer to `ExpManager.default_uri` for details of variable sharing.\n\n    When the user starts an experiment, the user may want to set the uri to a specific uri (it will override **default uri** during this period), and then unset the **specific uri** and fallback to the **default uri**.    `ExpManager._active_exp_uri` is that **specific uri**.\n    \"\"\"\n\n    active_experiment: Optional[Experiment]\n\n    def __init__(self, uri: Text, default_exp_name: Optional[Text]):\n        self.default_uri = uri\n        self._active_exp_uri = None  # No active experiments. 
So it is set to None\n        self._default_exp_name = default_exp_name\n        self.active_experiment = None  # only one experiment can be active at a time\n        logger.debug(f\"experiment manager uri is at {self.uri}\")\n\n    def __repr__(self):\n        return \"{name}(uri={uri})\".format(name=self.__class__.__name__, uri=self.uri)\n\n    def start_exp(\n        self,\n        *,\n        experiment_id: Optional[Text] = None,\n        experiment_name: Optional[Text] = None,\n        recorder_id: Optional[Text] = None,\n        recorder_name: Optional[Text] = None,\n        uri: Optional[Text] = None,\n        resume: bool = False,\n        **kwargs,\n    ) -> Experiment:\n        \"\"\"\n        Start an experiment. This method first gets or creates an experiment, and then\n        sets it to be active.\n\n        `start_exp` maintains `_active_exp_uri`; the remaining implementation should be provided by `_start_exp` in subclasses.\n\n        Parameters\n        ----------\n        experiment_id : str\n            id of the active experiment.\n        experiment_name : str\n            name of the active experiment.\n        recorder_id : str\n            id of the recorder to be started.\n        recorder_name : str\n            name of the recorder to be started.\n        uri : str\n            the current tracking URI.\n        resume : boolean\n            whether to resume the experiment and recorder.\n\n        Returns\n        -------\n        An active experiment.\n        \"\"\"\n        self._active_exp_uri = uri\n        # The subclass may set the underlying uri back.\n        # So setting `_active_exp_uri` comes before `_start_exp`\n        return self._start_exp(\n            experiment_id=experiment_id,\n            experiment_name=experiment_name,\n            recorder_id=recorder_id,\n            recorder_name=recorder_name,\n            resume=resume,\n            **kwargs,\n        )\n\n    def _start_exp(self, *args, **kwargs) -> 
Experiment:\n        \"\"\"Please refer to the doc of `start_exp`\"\"\"\n        raise NotImplementedError(f\"Please implement the `start_exp` method.\")\n\n    def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):\n        \"\"\"\n        End an active experiment.\n\n        Maintaining `_active_exp_uri` is included in end_exp, remaining implementation should be included in _end_exp in subclass\n\n        Parameters\n        ----------\n        experiment_name : str\n            name of the active experiment.\n        recorder_status : str\n            the status of the active recorder of the experiment.\n        \"\"\"\n        self._active_exp_uri = None\n        # The subclass may set the underlying uri back.\n        # So setting `_active_exp_uri` come before `_end_exp`\n        self._end_exp(recorder_status=recorder_status, **kwargs)\n\n    def _end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):\n        raise NotImplementedError(f\"Please implement the `end_exp` method.\")\n\n    def create_exp(self, experiment_name: Optional[Text] = None):\n        \"\"\"\n        Create an experiment.\n\n        Parameters\n        ----------\n        experiment_name : str\n            the experiment name, which must be unique.\n\n        Returns\n        -------\n        An experiment object.\n\n        Raise\n        -----\n        ExpAlreadyExistError\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `create_exp` method.\")\n\n    def search_records(self, experiment_ids=None, **kwargs):\n        \"\"\"\n        Get a pandas DataFrame of records that fit the search criteria of the experiment.\n        Inputs are the search criteria user want to apply.\n\n        Returns\n        -------\n        A pandas.DataFrame of records, where each metric, parameter, and tag\n        are expanded into their own columns named metrics.*, params.*, and tags.*\n        respectively. 
For records that don't have a particular metric, parameter, or tag, their\n        value will be (NumPy) NaN, None, or None respectively.\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `search_records` method.\")\n\n    def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):\n        \"\"\"\n        Retrieve an experiment. This method either returns the active experiment or gets (or creates) a specific experiment.\n\n        When the user specifies an experiment id or name, the method will try to return the specific experiment.\n        When the user does not provide an experiment id or name, the method will try to return the current active experiment.\n        The `create` argument determines whether the method will automatically create a new experiment according\n        to the user's specification if the experiment hasn't been created before.\n\n        * If `create` is True:\n\n            * If an `active experiment` exists:\n\n                * no id or name specified, return the active experiment.\n                * if id or name is specified, return the specified experiment. If no such experiment is found, create a new experiment with the given id or name. `start` has no effect because an experiment is already active.\n\n            * If no `active experiment` exists:\n\n                * no id or name specified, return the default experiment, creating it if it does not exist.\n                * if id or name is specified, return the specified experiment. If no such experiment is found, create a new experiment with the given id or name. If `start` is set to True, the experiment is set to be active.\n\n        * Else, if `create` is False:\n\n            * If an `active experiment` exists:\n\n                * no id or name specified, return the active experiment.\n                * if id or name is specified, return the specified experiment. 
If no such experiment is found, raise Error.\n\n            * If no `active experiment` exists:\n\n                * no id or name specified, return the default experiment if it exists; otherwise, raise Error.\n                * if id or name is specified, return the specified experiment. If no such experiment is found, raise Error.\n\n        Parameters\n        ----------\n        experiment_id : str\n            id of the experiment to return.\n        experiment_name : str\n            name of the experiment to return.\n        create : boolean\n            create the experiment if it hasn't been created before.\n        start : boolean\n            if there is no active experiment, set the returned experiment to be active and start it.\n\n        Returns\n        -------\n        An experiment object.\n        \"\"\"\n        # special case of getting experiment\n        if experiment_id is None and experiment_name is None:\n            if self.active_experiment is not None:\n                return self.active_experiment\n            # No active experiment: fall back to the default experiment.\n            experiment_name = self._default_exp_name\n\n        if create:\n            exp, _ = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)\n        else:\n            exp = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name)\n        if self.active_experiment is None and start:\n            self.active_experiment = exp\n            # start the experiment (which also starts a recorder)\n            self.active_experiment.start()\n        return exp\n\n    def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (object, bool):\n        \"\"\"\n        Method for getting or creating an experiment. 
It first tries to get a valid experiment; if that raises a ValueError, it\n        automatically creates a new experiment with the given name (or the default name if no name is given).\n        \"\"\"\n        try:\n            return (\n                self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),\n                False,\n            )\n        except ValueError:\n            if experiment_name is None:\n                experiment_name = self._default_exp_name\n            logger.warning(f\"No valid experiment found. Creating a new experiment with name {experiment_name}.\")\n\n            # NOTE: mlflow doesn't handle locking when multiple runs are recorded concurrently,\n            # so we support it in this interface wrapper\n            pr = urlparse(self.uri)\n            if pr.scheme == \"file\":\n                with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip(\"/\"), \"filelock\"))):  # pylint: disable=E0110\n                    return self.create_exp(experiment_name), True\n            # NOTE: for other schemes like http, we double check to avoid conflicts when creating the experiment\n            try:\n                return self.create_exp(experiment_name), True\n            except ExpAlreadyExistError:\n                return (\n                    self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),\n                    False,\n                )\n\n    def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:\n        \"\"\"\n        Get a specific experiment by name or id. 
If it does not exist, raise ValueError.\n\n        Parameters\n        ----------\n        experiment_id :\n            The id of the experiment\n        experiment_name :\n            The name of the experiment\n\n        Returns\n        -------\n        Experiment:\n            The experiment that was found\n\n        Raises\n        ------\n        ValueError\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `_get_exp` method\")\n\n    def delete_exp(self, experiment_id=None, experiment_name=None):\n        \"\"\"\n        Delete an experiment.\n\n        Parameters\n        ----------\n        experiment_id : str\n            the experiment id.\n        experiment_name : str\n            the experiment name.\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `delete_exp` method.\")\n\n    @property\n    def default_uri(self):\n        \"\"\"\n        Get the default tracking URI from qlib.config.C\n        \"\"\"\n        if \"kwargs\" not in C.exp_manager or \"uri\" not in C.exp_manager[\"kwargs\"]:\n            raise ValueError(\"The default URI is not set in qlib.config.C\")\n        return C.exp_manager[\"kwargs\"][\"uri\"]\n\n    @default_uri.setter\n    def default_uri(self, value):\n        C.exp_manager.setdefault(\"kwargs\", {})[\"uri\"] = value\n\n    @property\n    def uri(self):\n        \"\"\"\n        Get the current URI if one is set, otherwise the default tracking URI.\n\n        Returns\n        -------\n        The tracking URI string.\n        \"\"\"\n        return self._active_exp_uri or self.default_uri\n\n    def list_experiments(self):\n        \"\"\"\n        List all the existing experiments.\n\n        Returns\n        -------\n        A dictionary (name -> experiment) of the stored experiments.\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `list_experiments` method.\")\n\n\nclass MLflowExpManager(ExpManager):\n    \"\"\"\n    Use mlflow to implement ExpManager.\n    
\"\"\"\n\n    @property\n    def client(self):\n        # Please refer to `tests/dependency_tests/test_mlflow.py::MLflowTest::test_creating_client`\n        # The test ensures that creating a new client is fast\n        return mlflow.tracking.MlflowClient(tracking_uri=self.uri)\n\n    def _start_exp(\n        self,\n        *,\n        experiment_id: Optional[Text] = None,\n        experiment_name: Optional[Text] = None,\n        recorder_id: Optional[Text] = None,\n        recorder_name: Optional[Text] = None,\n        resume: bool = False,\n    ):\n        # Get or create the experiment\n        if experiment_name is None:\n            experiment_name = self._default_exp_name\n        experiment, _ = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)\n        # Set up active experiment\n        self.active_experiment = experiment\n        # Start the experiment\n        self.active_experiment.start(recorder_id=recorder_id, recorder_name=recorder_name, resume=resume)\n\n        return self.active_experiment\n\n    def _end_exp(self, recorder_status: Text = Recorder.STATUS_S):\n        if self.active_experiment is not None:\n            self.active_experiment.end(recorder_status)\n            self.active_experiment = None\n\n    def create_exp(self, experiment_name: Optional[Text] = None):\n        assert experiment_name is not None\n        # init experiment\n        try:\n            experiment_id = self.client.create_experiment(experiment_name)\n        except MlflowException as e:\n            if e.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS):\n                raise ExpAlreadyExistError() from e\n            raise e\n\n        return MLflowExperiment(experiment_id, experiment_name, self.uri)\n\n    def _get_exp(self, experiment_id=None, experiment_name=None):\n        \"\"\"\n        Method for getting an experiment. 
It tries to get a valid experiment and raises a ValueError\n        if none is found.\n        \"\"\"\n        assert (\n            experiment_id is not None or experiment_name is not None\n        ), \"Please input at least one of experiment/recorder id or name before retrieving experiment/recorder.\"\n        if experiment_id is not None:\n            try:\n                # NOTE: mlflow's experiment_id must be of type str.\n                # https://www.mlflow.org/docs/latest/python_api/mlflow.tracking.html#mlflow.tracking.MlflowClient.get_experiment\n                exp = self.client.get_experiment(experiment_id)\n                if exp.lifecycle_stage.upper() == \"DELETED\":\n                    raise MlflowException(\"No valid experiment has been found.\")\n                experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)\n                return experiment\n            except MlflowException as e:\n                raise ValueError(\n                    \"No valid experiment has been found, please make sure the input experiment id is correct.\"\n                ) from e\n        elif experiment_name is not None:\n            try:\n                exp = self.client.get_experiment_by_name(experiment_name)\n                if exp is None or exp.lifecycle_stage.upper() == \"DELETED\":\n                    raise MlflowException(\"No valid experiment has been found.\")\n                experiment = MLflowExperiment(exp.experiment_id, experiment_name, self.uri)\n                return experiment\n            except MlflowException as e:\n                raise ValueError(\n                    \"No valid experiment has been found, please make sure the input experiment name is correct.\"\n                ) from e\n\n    def search_records(self, experiment_ids=None, **kwargs):\n        filter_string = \"\" if kwargs.get(\"filter_string\") is None else kwargs.get(\"filter_string\")\n        run_view_type = ViewType.ACTIVE_ONLY if 
kwargs.get(\"run_view_type\") is None else kwargs.get(\"run_view_type\")\n        max_results = 100000 if kwargs.get(\"max_results\") is None else kwargs.get(\"max_results\")\n        order_by = kwargs.get(\"order_by\")\n        return self.client.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)\n\n    def delete_exp(self, experiment_id=None, experiment_name=None):\n        assert (\n            experiment_id is not None or experiment_name is not None\n        ), \"Please input a valid experiment id or name before deleting.\"\n        try:\n            if experiment_id is not None:\n                self.client.delete_experiment(experiment_id)\n            else:\n                experiment = self.client.get_experiment_by_name(experiment_name)\n                if experiment is None:\n                    raise MlflowException(\"No valid experiment has been found.\")\n                self.client.delete_experiment(experiment.experiment_id)\n        except MlflowException as e:\n            raise ValueError(\n                f\"Error: {e}. Something went wrong when deleting experiment. Please check if the name/id of the experiment is correct.\"\n            ) from e\n\n    def list_experiments(self):\n        # retrieve all the existing experiments\n        mlflow_version = int(mlflow.__version__.split(\".\", maxsplit=1)[0])\n        if mlflow_version >= 2:\n            exps = self.client.search_experiments(view_type=ViewType.ACTIVE_ONLY)\n        else:\n            exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY)  # pylint: disable=E1101\n        experiments = dict()\n        for exp in exps:\n            experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)\n            experiments[exp.name] = experiment\n        return experiments\n"
  },
  {
    "path": "qlib/workflow/online/__init__.py",
    "content": ""
  },
  {
    "path": "qlib/workflow/online/manager.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nOnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run them dynamically.\n\nOver time, the decisive models will also change. In this module, we call those contributing models `online` models.\nIn every routine (such as every day or every minute), the `online` models may change and their predictions need to be updated.\nSo this module provides a series of methods to control this process.\n\nThis module also provides a method to simulate `Online Strategy <#Online Strategy>`_ on historical data,\nwhich means you can verify your strategy or find a better one.\n\nThere are 4 situations in total for using different trainers:\n\n\n\n=========================  ===================================================================================\nSituations                 Description\n=========================  ===================================================================================\nOnline + Trainer           When you want to do a REAL routine, the Trainer will help you train the models. It\n                           will train models task by task and strategy by strategy.\n\nOnline + DelayTrainer      DelayTrainer will skip concrete training until all tasks have been prepared by\n                           different strategies. It lets users train all tasks in parallel at the end of\n                           `routine` or `first_train`. Otherwise, these functions will get stuck while each\n                           strategy prepares tasks.\n\nSimulation + Trainer       It will behave in the same way as `Online + Trainer`. 
The only difference is that it\n                           is for simulation/backtesting instead of online trading.\n\nSimulation + DelayTrainer  When your models don't have any temporal dependence, you can use DelayTrainer\n                           to enable multitasking. It means all tasks in all routines\n                           can be REAL trained at the end of simulating. The signals will be prepared well at\n                           different time segments (based on whether or not any new model is online).\n=========================  ===================================================================================\n\nHere is some pseudo code that demonstrates the workflow of each situation.\n\nFor simplicity\n    - Only one strategy is used.\n    - `update_online_pred` is only called in the online mode, so it is omitted here.\n\n1) `Online + Trainer`\n\n.. code-block:: python\n\n    tasks = first_train()\n    models = trainer.train(tasks)\n    trainer.end_train(models)\n    for day in online_trading_days:\n        # OnlineManager.routine\n        models = trainer.train(strategy.prepare_tasks())  # for each strategy\n        strategy.prepare_online_models(models)  # for each strategy\n\n        trainer.end_train(models)\n        prepare_signals()  # prepare trading signals daily\n\n\n`Online + DelayTrainer`: the workflow is the same as `Online + Trainer`.\n\n\n2) `Simulation + DelayTrainer`\n\n.. 
code-block:: python\n\n    # simulate\n    tasks = first_train()\n    models = trainer.train(tasks)\n    for day in historical_calendars:\n        # OnlineManager.routine\n        models = trainer.train(strategy.prepare_tasks())  # for each strategy\n        strategy.prepare_online_models(models)  # for each strategy\n    # delay_prepare()\n    # FIXME: Currently the delay_prepare is not implemented in a proper way.\n    trainer.end_train(<for all previous models>)\n    prepare_signals()\n\n\n# Can we simplify the current workflow?\n\n- Can we reduce the number of task states?\n\n    - For each task, we have three phases (i.e. task, partly trained task, fully trained task)\n\"\"\"\n\nimport logging\nfrom typing import Callable, List, Union\n\nimport pandas as pd\nfrom qlib import get_module_logger\nfrom qlib.data.data import D\nfrom qlib.log import set_global_logger_level\nfrom qlib.model.ens.ensemble import AverageEnsemble\nfrom qlib.model.trainer import Trainer, TrainerR\nfrom qlib.utils.serial import Serializable\nfrom qlib.workflow.online.strategy import OnlineStrategy\nfrom qlib.workflow.task.collect import MergeCollector\n\n\nclass OnlineManager(Serializable):\n    \"\"\"\n    OnlineManager can manage online models with `Online Strategy <#Online Strategy>`_.\n    It also records which models are online at each point in time.\n    \"\"\"\n\n    STATUS_SIMULATING = \"simulating\"  # when calling `simulate`\n    STATUS_ONLINE = \"online\"  # the normal status. 
It is used for online trading\n\n    def __init__(\n        self,\n        strategies: Union[OnlineStrategy, List[OnlineStrategy]],\n        trainer: Trainer = None,\n        begin_time: Union[str, pd.Timestamp] = None,\n        freq=\"day\",\n    ):\n        \"\"\"\n        Init OnlineManager.\n        One OnlineManager must have at least one OnlineStrategy.\n\n        Args:\n            strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy\n            begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None, which uses the latest date.\n            trainer (qlib.model.trainer.Trainer): the trainer used to train tasks. Defaults to None, which uses TrainerR.\n            freq (str, optional): data frequency. Defaults to \"day\".\n        \"\"\"\n        self.logger = get_module_logger(self.__class__.__name__)\n        if not isinstance(strategies, list):\n            strategies = [strategies]\n        self.strategies = strategies\n        self.freq = freq\n        if begin_time is None:\n            begin_time = D.calendar(freq=self.freq).max()\n        self.begin_time = pd.Timestamp(begin_time)\n        self.cur_time = self.begin_time\n        # OnlineManager records the history of online models as a dict like {pd.Timestamp: {strategy: [online_models]}}.\n        # It records the online serving models of each strategy for each day.\n        self.history = {}\n        if trainer is None:\n            trainer = TrainerR()\n        self.trainer = trainer\n        self.signals = None\n        self.status = self.STATUS_ONLINE\n\n    def _postpone_action(self):\n        \"\"\"\n        Whether the workflow should postpone the following actions to the end (in delay_prepare):\n        - trainer.end_train\n        - prepare_signals\n\n        Postponing these actions supports simulating/backtesting online strategies without time dependencies.\n        All the actions can be 
done in parallel at the end.\n        \"\"\"\n        return self.status == self.STATUS_SIMULATING and self.trainer.is_delay()\n\n    def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):\n        \"\"\"\n        Get tasks from every strategy's first_tasks method and train them.\n        If using DelayTrainer, it can finish training all together after every strategy's first_tasks.\n\n        Args:\n            strategies (List[OnlineStrategy]): the strategies list (this param is needed when adding strategies). None to use the default strategies.\n            model_kwargs (dict): the params for `prepare_online_models`\n        \"\"\"\n        if strategies is None:\n            strategies = self.strategies\n\n        models_list = []\n        for strategy in strategies:\n            self.logger.info(f\"Strategy `{strategy.name_id}` begins first training...\")\n            tasks = strategy.first_tasks()\n            models = self.trainer.train(tasks, experiment_name=strategy.name_id)\n            models_list.append(models)\n            self.logger.info(f\"Finished training {len(models)} models.\")\n            # FIXME: Training multiple online models at `first_train` will result in too many online models at the\n            # start.\n            online_models = strategy.prepare_online_models(models, **model_kwargs)\n            self.history.setdefault(self.cur_time, {})[strategy] = online_models\n\n        if not self._postpone_action():\n            for strategy, models in zip(strategies, models_list):\n                models = self.trainer.end_train(models, experiment_name=strategy.name_id)\n\n    def routine(\n        self,\n        cur_time: Union[str, pd.Timestamp] = None,\n        task_kwargs: dict = {},\n        model_kwargs: dict = {},\n        signal_kwargs: dict = {},\n    ):\n        \"\"\"\n        Run the typical update process for every strategy and record the online history.\n\n        The typical update process after a routine, 
such as day by day or month by month.\n        The process is: Prepare tasks -> Prepare online models -> Update predictions -> Prepare signals.\n\n        If using DelayTrainer, it can finish training all together after every strategy's prepare_tasks.\n\n        Args:\n            cur_time (Union[str,pd.Timestamp], optional): run the routine at this time. Defaults to None for the latest date.\n            task_kwargs (dict): the params for `prepare_tasks`\n            model_kwargs (dict): the params for `prepare_online_models`\n            signal_kwargs (dict): the params for `prepare_signals`\n        \"\"\"\n        if cur_time is None:\n            cur_time = D.calendar(freq=self.freq).max()\n        self.cur_time = pd.Timestamp(cur_time)  # None for latest date\n\n        models_list = []\n        for strategy in self.strategies:\n            self.logger.info(f\"Strategy `{strategy.name_id}` begins routine...\")\n\n            tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)\n            models = self.trainer.train(tasks, experiment_name=strategy.name_id)\n            models_list.append(models)\n            self.logger.info(f\"Finished training {len(models)} models.\")\n            online_models = strategy.prepare_online_models(models, **model_kwargs)\n            self.history.setdefault(self.cur_time, {})[strategy] = online_models\n\n            # The online models may change in the above processes,\n            # so updating the predictions of online models should be the last step\n            if self.status == self.STATUS_ONLINE:\n                strategy.tool.update_online_pred()\n\n        if not self._postpone_action():\n            for strategy, models in zip(self.strategies, models_list):\n                models = self.trainer.end_train(models, experiment_name=strategy.name_id)\n            self.prepare_signals(**signal_kwargs)\n\n    def get_collector(self, **kwargs) -> MergeCollector:\n        \"\"\"\n        Get the instance of `Collector 
<../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.\n        This collector can serve as the basis for signal preparation.\n\n        Args:\n            **kwargs: the params for each strategy's get_collector.\n\n        Returns:\n            MergeCollector: the collector to merge other collectors.\n        \"\"\"\n        collector_dict = {}\n        for strategy in self.strategies:\n            collector_dict[strategy.name_id] = strategy.get_collector(**kwargs)\n        return MergeCollector(collector_dict, process_list=[])\n\n    def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):\n        \"\"\"\n        Add some new strategies to OnlineManager.\n\n        Args:\n            strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an OnlineStrategy or a list of OnlineStrategy\n        \"\"\"\n        if not isinstance(strategies, list):\n            strategies = [strategies]\n        self.first_train(strategies)\n        self.strategies.extend(strategies)\n\n    def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False):\n        \"\"\"\n        After the data of the last routine (a box in a box plot) has been prepared, which marks the end of the routine, we can prepare trading signals for the next routine.\n\n        NOTE: Given a set of predictions, all signals before the end times of these predictions will be prepared.\n\n        Even if the latest signal already exists, it will be overwritten by the latest calculation result.\n\n        .. note::\n\n            Given a prediction at a certain time, all signals before this time will be prepared.\n\n        Args:\n            prepare_func (Callable, optional): Get signals from a dict after collecting. Defaults to AverageEnsemble(); the results collected by MergeCollector must be {xxx:pred}.\n            over_write (bool, optional): If True, the new signals will overwrite the old ones. 
Defaults to False.\n\n        Returns:\n            pd.DataFrame: the newly prepared signals.\n        \"\"\"\n        signals = prepare_func(self.get_collector()())\n        old_signals = self.signals\n        if old_signals is not None and not over_write:\n            old_max = old_signals.index.get_level_values(\"datetime\").max()\n            new_signals = signals.loc[old_max:]\n            signals = pd.concat([old_signals, new_signals], axis=0)\n        else:\n            new_signals = signals\n        self.logger.info(f\"Finished preparing new {len(new_signals)} signals.\")\n        self.signals = signals\n        return new_signals\n\n    def get_signals(self) -> Union[pd.Series, pd.DataFrame]:\n        \"\"\"\n        Get prepared online signals.\n\n        Returns:\n            Union[pd.Series, pd.DataFrame]: pd.Series when there is only one signal per datetime.\n            pd.DataFrame for multiple signals, for example, when buy and sell operations use different trading signals.\n        \"\"\"\n        return self.signals\n\n    SIM_LOG_LEVEL = logging.INFO + 1  # when simulating, reduce information\n    SIM_LOG_NAME = \"SIMULATE_INFO\"\n\n    def simulate(\n        self, end_time=None, frequency=\"day\", task_kwargs={}, model_kwargs={}, signal_kwargs={}\n    ) -> Union[pd.Series, pd.DataFrame]:\n        \"\"\"\n        Starting from the current time, this method will simulate every routine in OnlineManager until the end time.\n\n        To support parallel training, the models and signals can be prepared after all routines have been simulated.\n\n        Delayed training can be done with ``DelayTrainer``, and delayed signal preparation with ``delay_prepare``.\n\n        Args:\n            end_time: the time the simulation will end\n            frequency: the calendar frequency\n            task_kwargs (dict): the params for `prepare_tasks`\n            model_kwargs (dict): the params for `prepare_online_models`\n            signal_kwargs (dict): the params for `prepare_signals`\n\n 
       Returns:\n            Union[pd.Series, pd.DataFrame]: pd.Series when there is only one signal per datetime.\n            pd.DataFrame for multiple signals, for example, when buy and sell operations use different trading signals.\n        \"\"\"\n        self.status = self.STATUS_SIMULATING\n        cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)\n        self.first_train()\n\n        simulate_level = self.SIM_LOG_LEVEL\n        set_global_logger_level(simulate_level)\n        logging.addLevelName(simulate_level, self.SIM_LOG_NAME)\n\n        for cur_time in cal:\n            self.logger.log(level=simulate_level, msg=f\"Simulating at {str(cur_time)}......\")\n            self.routine(\n                cur_time,\n                task_kwargs=task_kwargs,\n                model_kwargs=model_kwargs,\n                signal_kwargs=signal_kwargs,\n            )\n        # prepare the postponed models and signals\n        if self._postpone_action():\n            self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)\n\n        # FIXME: save the original logging level first and restore it here\n        set_global_logger_level(logging.DEBUG)\n        self.logger.info(f\"Finished preparing signals\")\n        self.status = self.STATUS_ONLINE\n        return self.get_signals()\n\n    def delay_prepare(self, model_kwargs={}, signal_kwargs={}):\n        \"\"\"\n        Prepare all models and signals if something is waiting for preparation.\n\n        Args:\n            model_kwargs: the params for `end_train`\n            signal_kwargs: the params for `prepare_signals`\n        \"\"\"\n        # FIXME:\n        # This method is not implemented in the proper way!!!\n        last_models = {}\n        signals_time = D.calendar()[0]\n        need_prepare = False\n        for cur_time, strategy_models in self.history.items():\n            self.cur_time = cur_time\n\n            for strategy, models in strategy_models.items():\n                # only new 
online models need to be prepared\n                if last_models.setdefault(strategy, set()) != set(models):\n                    models = self.trainer.end_train(models, experiment_name=strategy.name_id, **model_kwargs)\n                    strategy.tool.reset_online_tag(models)\n                    need_prepare = True\n                last_models[strategy] = set(models)\n\n            if need_prepare:\n                # NOTE: Assumption: the predictions of online models do not extend beyond the next cur_time; otherwise this method will work incorrectly.\n                self.prepare_signals(**signal_kwargs)\n                if signals_time > cur_time:\n                    # FIXME: when using DelayTrainer with a worker (and the worker is faster than the main process), this warning may be shown.\n                    self.logger.warning(\n                        f\"The signals have already been prepared up to {signals_time} by the last preparation, but the current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models.\"\n                    )\n                need_prepare = False\n                signals_time = self.signals.index.get_level_values(\"datetime\").max()\n"
  },
  {
    "path": "qlib/workflow/online/strategy.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nOnlineStrategy module is an element of online serving.\n\"\"\"\n\nfrom typing import List, Union\nfrom qlib.log import get_module_logger\nfrom qlib.model.ens.group import RollingGroup\nfrom qlib.utils import transform_end_date\nfrom qlib.workflow.online.utils import OnlineTool, OnlineToolR\nfrom qlib.workflow.recorder import Recorder\nfrom qlib.workflow.task.collect import Collector, RecorderCollector\nfrom qlib.workflow.task.gen import RollingGen, task_generator\nfrom qlib.workflow.task.utils import TimeAdjuster\n\n\nclass OnlineStrategy:\n    \"\"\"\n    OnlineStrategy works with `Online Manager <#Online Manager>`_ and is responsible for how the tasks are generated, how the models are updated, and how signals are prepared.\n    \"\"\"\n\n    def __init__(self, name_id: str):\n        \"\"\"\n        Init OnlineStrategy.\n        This module **MUST** use `Trainer <../reference/api.html#qlib.model.trainer.Trainer>`_ to finish model training.\n\n        Args:\n            name_id (str): a unique name or id.\n        \"\"\"\n        self.name_id = name_id\n        self.logger = get_module_logger(self.__class__.__name__)\n        self.tool = OnlineTool()\n\n    def prepare_tasks(self, cur_time, **kwargs) -> List[dict]:\n        \"\"\"\n        After the end of a routine, check whether we need to prepare and train some new tasks based on cur_time (None for the latest).\n        Return the new tasks waiting for training.\n\n        You can find the last online models by OnlineTool.online_models.\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `prepare_tasks` method.\")\n\n    def prepare_online_models(self, trained_models, cur_time=None) -> List[object]:\n        \"\"\"\n        Select some models from the trained models and set them as online models.\n        This is a typical implementation that puts all trained models online; you can override it to implement a more complex method.\n        You can find the last online models by OnlineTool.online_models if you still need them.\n\n        NOTE: Reset all online models to the trained models. If there are no trained models, keep the current online models.\n\n        **NOTE**:\n            The current implementation is very naive. Here is a more complex situation that is closer to\n            practical scenarios.\n            1. Train new models on the day before `test_start` (at timestamp `T`)\n            2. Switch models at `test_start` (typically at timestamp `T + 1`)\n\n        Args:\n            trained_models (list): a list of models.\n            cur_time (pd.Timestamp): current time from OnlineManager. 
None for the latest.\n\n        Returns:\n            List[object]: a list of online models.\n        \"\"\"\n        if not trained_models:\n            return self.tool.online_models()\n        self.tool.reset_online_tag(trained_models)\n        return trained_models\n\n    def first_tasks(self) -> List[dict]:\n        \"\"\"\n        Generate the initial series of tasks and return them.\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `first_tasks` method.\")\n\n    def get_collector(self) -> Collector:\n        \"\"\"\n        Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect different results of this strategy.\n\n        For example:\n            1) collect predictions in Recorder\n            2) collect signals in a txt file\n\n        Returns:\n            Collector\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `get_collector` method.\")\n\n\nclass RollingStrategy(OnlineStrategy):\n    \"\"\"\n    This example strategy always uses the latest rolling models as online models.\n    \"\"\"\n\n    def __init__(\n        self,\n        name_id: str,\n        task_template: Union[dict, List[dict]],\n        rolling_gen: RollingGen,\n    ):\n        \"\"\"\n        Init RollingStrategy.\n\n        Assumption: the str of name_id, the experiment name, and the trainer's experiment name are the same.\n\n        Args:\n            name_id (str): a unique name or id. 
It will also be the name of the Experiment.\n            task_template (Union[dict, List[dict]]): a list of task templates or a single template, which will be used to generate many tasks using rolling_gen.\n            rolling_gen (RollingGen): an instance of RollingGen\n        \"\"\"\n        super().__init__(name_id=name_id)\n        self.exp_name = self.name_id\n        if not isinstance(task_template, list):\n            task_template = [task_template]\n        self.task_template = task_template\n        self.rg = rolling_gen\n        assert isinstance(self.rg, RollingGen), \"The rolling strategy relies on the features of RollingGen\"\n        self.tool = OnlineToolR(self.exp_name)\n        self.ta = TimeAdjuster()\n\n    def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):\n        \"\"\"\n        Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results from different models.\n\n        Assumption: the models can be distinguished based on the model name and rolling test segments.\n        If you do not want this assumption, please implement your own method or use another rec_key_func.\n\n        Args:\n            process_list (list, optional): the processes applied to the collected results. Defaults to [RollingGroup()].\n            rec_key_func (Callable): a function to get the key of a recorder. If None, use the model class and the rolling test segment as the key.\n            rec_filter_func (Callable, optional): filter the recorders by returning True or False. Defaults to None.\n            artifacts_key (List[str], optional): the artifacts key you want to get. 
If None, get all artifacts.\n        \"\"\"\n\n        def rec_key(recorder):\n            task_config = recorder.load_object(\"task\")\n            model_key = task_config[\"model\"][\"class\"]\n            rolling_key = task_config[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"]\n            return model_key, rolling_key\n\n        if rec_key_func is None:\n            rec_key_func = rec_key\n\n        artifacts_collector = RecorderCollector(\n            experiment=self.exp_name,\n            process_list=process_list,\n            rec_key_func=rec_key_func,\n            rec_filter_func=rec_filter_func,\n            artifacts_key=artifacts_key,\n        )\n\n        return artifacts_collector\n\n    def first_tasks(self) -> List[dict]:\n        \"\"\"\n        Use rolling_gen to generate different tasks based on task_template.\n\n        Returns:\n            List[dict]: a list of tasks\n        \"\"\"\n        return task_generator(\n            tasks=self.task_template,\n            generators=self.rg,  # generate different date segments\n        )\n\n    def prepare_tasks(self, cur_time) -> List[dict]:\n        \"\"\"\n        Prepare new tasks based on cur_time (None for the latest).\n\n        You can find the last online models by OnlineToolR.online_models.\n\n        Returns:\n            List[dict]: a list of new tasks.\n        \"\"\"\n        # TODO: filtering recorders by the latest test segments may not be necessary\n        latest_records, max_test = self._list_latest(self.tool.online_models())\n        if max_test is None:\n            self.logger.warning(\"No latest online recorders, so no new tasks are generated.\")\n            return []\n        calendar_latest = transform_end_date(cur_time)\n        self.logger.info(\n            f\"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}\"\n        )\n        res = []\n        for rec in 
latest_records:\n            task = rec.load_object(\"task\")\n            res.extend(self.rg.gen_following_tasks(task, calendar_latest))\n        return res\n\n    def _list_latest(self, rec_list: List[Recorder]):\n        \"\"\"\n        List the latest recorders from rec_list\n\n        Args:\n            rec_list (List[Recorder]): a list of Recorder\n\n        Returns:\n            Tuple[List[Recorder], tuple]: the latest recorders and their latest test segment, i.e. (start, end)\n        \"\"\"\n        if len(rec_list) == 0:\n            return rec_list, None\n        max_test = max(rec.load_object(\"task\")[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"] for rec in rec_list)\n        latest_rec = []\n        for rec in rec_list:\n            if rec.load_object(\"task\")[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"] == max_test:\n                latest_rec.append(rec)\n        return latest_rec, max_test\n"
  },
  {
    "path": "qlib/workflow/online/update.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nUpdater is a module to update artifacts such as predictions when the stock data is updating.\n\"\"\"\n\nfrom abc import ABCMeta, abstractmethod\nfrom typing import Optional\n\nimport pandas as pd\nfrom qlib import get_module_logger\nfrom qlib.data import D\nfrom qlib.data.dataset import Dataset, DatasetH, TSDatasetH\nfrom qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.model import Model\nfrom qlib.utils import get_date_by_shift\nfrom qlib.workflow.recorder import Recorder\nfrom qlib.workflow.record_temp import SignalRecord\n\n\nclass RMDLoader:\n    \"\"\"\n    Recorder Model Dataset Loader\n    \"\"\"\n\n    def __init__(self, rec: Recorder):\n        self.rec = rec\n\n    def get_dataset(\n        self, start_time, end_time, segments=None, unprepared_dataset: Optional[DatasetH] = None\n    ) -> DatasetH:\n        \"\"\"\n        Load, config and setup dataset.\n\n        This dataset is for inference.\n\n        Args:\n            start_time :\n                the start_time of underlying data\n            end_time :\n                the end_time of underlying data\n            segments : dict\n                the segments config for dataset\n                Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time\n            unprepared_dataset: Optional[DatasetH]\n                if user don't want to load dataset from recorder, please specify user's dataset\n\n        Returns:\n            DatasetH: the instance of DatasetH\n\n        \"\"\"\n        if segments is None:\n            segments = {\"test\": (start_time, end_time)}\n        if unprepared_dataset is None:\n            dataset: DatasetH = self.rec.load_object(\"dataset\")\n        else:\n            dataset = unprepared_dataset\n        dataset.config(handler_kwargs={\"start_time\": start_time, \"end_time\": end_time}, segments=segments)\n  
      dataset.setup_data(handler_kwargs={\"init_type\": DataHandlerLP.IT_LS})\n        return dataset\n\n    def get_model(self) -> Model:\n        return self.rec.load_object(\"params.pkl\")\n\n\nclass RecordUpdater(metaclass=ABCMeta):\n    \"\"\"\n    Update a specific recorder\n    \"\"\"\n\n    def __init__(self, record: Recorder, *args, **kwargs):\n        self.record = record\n        self.logger = get_module_logger(self.__class__.__name__)\n\n    @abstractmethod\n    def update(self, *args, **kwargs):\n        \"\"\"\n        Update info for a specific recorder\n        \"\"\"\n\n\nclass DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):\n    \"\"\"\n    Dataset-Based Updater\n\n    - Provides the feature of updating data based on a Qlib Dataset\n\n    Assumption\n\n    - Based on Qlib dataset\n    - The data to be updated is a multi-level index pd.DataFrame, for example, labels or predictions.\n\n        .. code-block::\n\n                                     LABEL0\n            datetime   instrument\n            2021-05-10 SH600000    0.006965\n                       SH600004    0.003407\n            ...                         
...\n            2021-05-28 SZ300498    0.015748\n                       SZ300676   -0.001321\n    \"\"\"\n\n    def __init__(\n        self,\n        record: Recorder,\n        to_date=None,\n        from_date=None,\n        hist_ref: Optional[int] = None,\n        freq=\"day\",\n        fname=\"pred.pkl\",\n        loader_cls: type = RMDLoader,\n    ):\n        \"\"\"\n        Init DSBasedUpdater.\n\n        Expected behavior in the following cases:\n\n        - if `to_date` is greater than the max date in the calendar, the data will be updated to the latest date\n        - if there are data before `from_date` or after `to_date`, only the data between `from_date` and `to_date` are affected.\n\n        Args:\n            record : Recorder\n            to_date :\n                update the data up to `to_date`\n\n                if to_date is None:\n\n                    the data will be updated to the latest date.\n            from_date :\n                the update will start from `from_date`\n\n                if from_date is None:\n\n                    the update will start on the next tick after the latest data in the historical data\n            hist_ref : int\n                Sometimes, the dataset has historical dependencies.\n                Users are responsible for setting the length of the historical dependency.\n                If the user doesn't specify this parameter, the Updater will try to load the dataset to determine hist_ref automatically\n\n                .. 
note::\n\n                    the start_time is not included in the `hist_ref`, so the `hist_ref` will be `step_len - 1` in most cases\n\n            freq : str\n                the frequency of the calendar used for updating, e.g. \"day\"\n            fname : str\n                the name of the artifact to be updated, e.g. \"pred.pkl\"\n            loader_cls : type\n                the class to load the model and dataset\n\n        \"\"\"\n        # TODO: automate this hist_ref in the future.\n        super().__init__(record=record)\n\n        self.to_date = to_date\n        self.hist_ref = hist_ref\n        self.freq = freq\n        self.fname = fname\n        self.rmdl = loader_cls(rec=record)\n\n        latest_date = D.calendar(freq=freq)[-1]\n        if to_date is None:\n            to_date = latest_date\n        to_date = pd.Timestamp(to_date)\n\n        if to_date > latest_date:\n            self.logger.warning(\n                f\"The given `to_date`({to_date}) is later than `latest_date`({latest_date}). So `to_date` is clipped to `latest_date`.\"\n            )\n            to_date = latest_date\n        self.to_date = to_date\n\n        # FIXME: it will raise an error when running the routine with a delay trainer\n        # should we use another prediction updater for the delay trainer?\n        self.old_data: pd.DataFrame = record.load_object(fname)\n        if from_date is None:\n            # dropna is for compatibility with data containing future information (e.g. 
label)\n            # The recent label data should be updated together\n            self.last_end = self.old_data.dropna().index.get_level_values(\"datetime\").max()\n        else:\n            self.last_end = get_date_by_shift(from_date, -1, align=\"right\")\n\n    def prepare_data(self, unprepared_dataset: Optional[DatasetH] = None) -> DatasetH:\n        \"\"\"\n        Load and prepare the dataset for updating.\n\n        - if `unprepared_dataset` is specified, it is prepared directly\n        - otherwise, the dataset saved in the recorder is loaded and prepared\n\n        Separating this function makes it easier to reuse the dataset.\n\n        Returns:\n            DatasetH: the instance of DatasetH\n        \"\"\"\n        # automatically determine the historical dependency if it is not specified\n        if self.hist_ref is None:\n            dataset: DatasetH = self.record.load_object(\"dataset\") if unprepared_dataset is None else unprepared_dataset\n            # Special treatment of historical dependencies\n            if isinstance(dataset, TSDatasetH):\n                hist_ref = dataset.step_len - 1\n            else:\n                hist_ref = 0  # only the latest data is used, so no historical data is needed\n        else:\n            hist_ref = self.hist_ref\n\n        start_time_buffer = get_date_by_shift(\n            self.last_end, -hist_ref + 1, clip_shift=False, freq=self.freq  # pylint: disable=E1130\n        )\n        start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)\n        seg = {\"test\": (start_time, self.to_date)}\n        return self.rmdl.get_dataset(\n            start_time=start_time_buffer, end_time=self.to_date, segments=seg, unprepared_dataset=unprepared_dataset\n        )\n\n    def update(self, dataset: DatasetH = None, write: bool = True, ret_new: bool = False) -> Optional[object]:\n        \"\"\"\n        Parameters\n        ----------\n        dataset : DatasetH\n            DatasetH: the instance of DatasetH. 
If None, it will be prepared again.\n        write : bool\n            whether to execute the write action\n        ret_new : bool\n            whether to return the updated data\n\n        Returns\n        -------\n        Optional[object]\n            the updated data if `ret_new` is True, otherwise None\n        \"\"\"\n        # FIXME: the problem below is not solved\n        # A model dumped on a GPU instance cannot be loaded on a CPU instance. The following exception will be raised:\n        # RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.\n        # https://github.com/pytorch/pytorch/issues/16797\n\n        if self.last_end >= self.to_date:\n            self.logger.info(\n                f\"The data in {self.record.info['id']} are up to date ({self.last_end}). No need to update to {self.to_date}.\"\n            )\n            return\n\n        # load dataset\n        if dataset is None:\n            # For reusing the dataset\n            dataset = self.prepare_data()\n\n        updated_data = self.get_update_data(dataset)\n\n        if write:\n            self.record.save_objects(**{self.fname: updated_data})\n        if ret_new:\n            return updated_data\n\n    @abstractmethod\n    def get_update_data(self, dataset: Dataset) -> pd.DataFrame:\n        \"\"\"\n        Return the updated data based on the given dataset\n\n        The difference between `get_update_data` and `update`\n        - `get_update_data` only includes the data-specific logic\n        - `update` includes general routine steps (e.g. 
prepare dataset, checking)\n        \"\"\"\n\n\ndef _replace_range(data, new_data):\n    dates = new_data.index.get_level_values(\"datetime\")\n    data = data.sort_index()\n    data = data.drop(data.loc[dates.min() : dates.max()].index)\n    cb_data = pd.concat([data, new_data], axis=0)\n    cb_data = cb_data[~cb_data.index.duplicated(keep=\"last\")].sort_index()\n    return cb_data\n\n\nclass PredUpdater(DSBasedUpdater):\n    \"\"\"\n    Update the prediction in the Recorder\n    \"\"\"\n\n    def get_update_data(self, dataset: Dataset) -> pd.DataFrame:\n        # Load model\n        model = self.rmdl.get_model()\n        new_pred: pd.Series = model.predict(dataset)\n        data = _replace_range(self.old_data, new_pred.to_frame(\"score\"))\n        self.logger.info(f\"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.\")\n        return data\n\n\nclass LabelUpdater(DSBasedUpdater):\n    \"\"\"\n    Update the label in the recorder\n\n    Assumption\n    - The label is generated from record_temp.SignalRecord.\n    \"\"\"\n\n    def __init__(self, record: Recorder, to_date=None, **kwargs):\n        super().__init__(record, to_date=to_date, fname=\"label.pkl\", **kwargs)\n\n    def get_update_data(self, dataset: Dataset) -> pd.DataFrame:\n        new_label = SignalRecord.generate_label(dataset)\n        cb_data = _replace_range(self.old_data.sort_index(), new_label)\n        return cb_data\n"
  },
  {
    "path": "qlib/workflow/online/utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nOnlineTool is a module to set and unset a series of `online` models.\nThe `online` models are some decisive models in some time points, which can be changed with the change of time.\nThis allows us to use efficient submodels as the market-style changing.\n\"\"\"\n\nfrom typing import List, Union\n\nfrom qlib.log import get_module_logger\nfrom qlib.utils.exceptions import LoadObjectError\nfrom qlib.workflow.online.update import PredUpdater\nfrom qlib.workflow.recorder import Recorder\nfrom qlib.workflow.task.utils import list_recorders\n\n\nclass OnlineTool:\n    \"\"\"\n    OnlineTool will manage `online` models in an experiment that includes the model recorders.\n    \"\"\"\n\n    ONLINE_KEY = \"online_status\"  # the online status key in recorder\n    ONLINE_TAG = \"online\"  # the 'online' model\n    OFFLINE_TAG = \"offline\"  # the 'offline' model, not for online serving\n\n    def __init__(self):\n        \"\"\"\n        Init OnlineTool.\n        \"\"\"\n        self.logger = get_module_logger(self.__class__.__name__)\n\n    def set_online_tag(self, tag, recorder: Union[list, object]):\n        \"\"\"\n        Set `tag` to the model to sign whether online.\n\n        Args:\n            tag (str): the tags in `ONLINE_TAG`, `OFFLINE_TAG`\n            recorder (Union[list,object]): the model's recorder\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `set_online_tag` method.\")\n\n    def get_online_tag(self, recorder: object) -> str:\n        \"\"\"\n        Given a model recorder and return its online tag.\n\n        Args:\n            recorder (Object): the model's recorder\n\n        Returns:\n            str: the online tag\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `get_online_tag` method.\")\n\n    def reset_online_tag(self, recorder: Union[list, object]):\n        \"\"\"\n        Offline all models 
and set the given recorders to 'online'.\n\n        Args:\n            recorder (Union[list,object]):\n                the recorder you want to reset to 'online'.\n\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `reset_online_tag` method.\")\n\n    def online_models(self) -> list:\n        \"\"\"\n        Get current `online` models\n\n        Returns:\n            list: a list of `online` models.\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `online_models` method.\")\n\n    def update_online_pred(self, to_date=None):\n        \"\"\"\n        Update the predictions of `online` models to to_date.\n\n        Args:\n            to_date (pd.Timestamp): predictions up to this date will be updated. If None, update to the latest date.\n\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `update_online_pred` method.\")\n\n\nclass OnlineToolR(OnlineTool):\n    \"\"\"\n    The implementation of OnlineTool based on (R)ecorder.\n    \"\"\"\n\n    def __init__(self, default_exp_name: str = None):\n        \"\"\"\n        Init OnlineToolR.\n\n        Args:\n            default_exp_name (str): the default experiment name.\n        \"\"\"\n        super().__init__()\n        self.default_exp_name = default_exp_name\n\n    def set_online_tag(self, tag, recorder: Union[Recorder, List]):\n        \"\"\"\n        Set `tag` on the model's recorder to mark whether it is online.\n\n        Args:\n            tag (str): one of `ONLINE_TAG`, `OFFLINE_TAG`\n            recorder (Union[Recorder, List]): a list of Recorder or an instance of Recorder\n        \"\"\"\n        if isinstance(recorder, Recorder):\n            recorder = [recorder]\n        for rec in recorder:\n            rec.set_tags(**{self.ONLINE_KEY: tag})\n        self.logger.info(f\"Set {len(recorder)} models to '{tag}'.\")\n\n    def get_online_tag(self, recorder: Recorder) -> str:\n        \"\"\"\n        Given a model recorder and 
return its online tag.\n\n        Args:\n            recorder (Recorder): an instance of recorder\n\n        Returns:\n            str: the online tag\n        \"\"\"\n        tags = recorder.list_tags()\n        return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)\n\n    def reset_online_tag(self, recorder: Union[Recorder, List], exp_name: str = None):\n        \"\"\"\n        Set all models in the experiment to 'offline', then set the given recorders to 'online'.\n\n        Args:\n            recorder (Union[Recorder, List]):\n                the recorder you want to reset to 'online'.\n            exp_name (str): the experiment name. If None, then use default_exp_name.\n\n        \"\"\"\n        exp_name = self._get_exp_name(exp_name)\n        if isinstance(recorder, Recorder):\n            recorder = [recorder]\n        recs = list_recorders(exp_name)\n        self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))\n        self.set_online_tag(self.ONLINE_TAG, recorder)\n\n    def online_models(self, exp_name: str = None) -> list:\n        \"\"\"\n        Get current `online` models\n\n        Args:\n            exp_name (str): the experiment name. If None, then use default_exp_name.\n\n        Returns:\n            list: a list of `online` models.\n        \"\"\"\n        exp_name = self._get_exp_name(exp_name)\n        return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())\n\n    def update_online_pred(self, to_date=None, from_date=None, exp_name: str = None):\n        \"\"\"\n        Update the predictions of online models to to_date.\n\n        Args:\n            to_date (pd.Timestamp): predictions up to this date will be updated. If None, update to the latest date in the calendar.\n            from_date (pd.Timestamp): the update starts from this date. If None, it starts on the next tick after the latest existing prediction.\n            exp_name (str): the experiment name. 
If None, then use default_exp_name.\n        \"\"\"\n        exp_name = self._get_exp_name(exp_name)\n        online_models = self.online_models(exp_name=exp_name)\n        for rec in online_models:\n            try:\n                updater = PredUpdater(rec, to_date=to_date, from_date=from_date)\n            except LoadObjectError as e:\n                # skip the recorder without pred\n                self.logger.warning(f\"An exception `{str(e)}` happened when loading `pred.pkl`; skipping it.\")\n                continue\n            updater.update()\n\n        self.logger.info(f\"Finished updating {len(online_models)} online model predictions of {exp_name}.\")\n\n    def _get_exp_name(self, exp_name):\n        if exp_name is None:\n            if self.default_exp_name is None:\n                raise ValueError(\n                    \"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment.\"\n                )\n            exp_name = self.default_exp_name\n        return exp_name\n"
  },
  {
    "path": "qlib/workflow/record_temp.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nimport logging\nimport warnings\nimport pandas as pd\nimport numpy as np\nfrom tqdm import trange\nfrom pprint import pprint\nfrom typing import Union, List, Optional, Dict\n\nfrom qlib.utils.exceptions import LoadObjectError\nfrom ..contrib.evaluate import risk_analysis, indicator_analysis\n\nfrom ..data.dataset import DatasetH\nfrom ..data.dataset.handler import DataHandlerLP\nfrom ..backtest import backtest as normal_backtest\nfrom ..log import get_module_logger\nfrom ..utils import fill_placeholder, flatten_dict, class_casting, get_date_by_shift\nfrom ..utils.time import Freq\nfrom ..utils.data import deepcopy_basic_type\nfrom ..utils.exceptions import QlibException\nfrom ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec\n\nlogger = get_module_logger(\"workflow\", logging.INFO)\n\n\nclass RecordTemp:\n    \"\"\"\n    This is the Records Template class that enables user to generate experiment results such as IC and\n    backtest in a certain format.\n    \"\"\"\n\n    artifact_path = None\n    depend_cls = None  # the dependant class of the record; the record will depend on the results generated by\n    # `depend_cls`\n\n    @classmethod\n    def get_path(cls, path=None):\n        names = []\n        if cls.artifact_path is not None:\n            names.append(cls.artifact_path)\n\n        if path is not None:\n            names.append(path)\n\n        return \"/\".join(names)\n\n    def save(self, **kwargs):\n        \"\"\"\n        It behaves the same as self.recorder.save_objects.\n        But it is an easier interface because users don't have to care about `get_path` and `artifact_path`\n        \"\"\"\n        art_path = self.get_path()\n        if art_path == \"\":\n            art_path = None\n        self.recorder.save_objects(artifact_path=art_path, **kwargs)\n\n    def __init__(self, recorder):\n        self._recorder = 
recorder\n\n    @property\n    def recorder(self):\n        if self._recorder is None:\n            raise ValueError(\"The recorder of this RecordTemp has not been set yet.\")\n        return self._recorder\n\n    def generate(self, **kwargs):\n        \"\"\"\n        Generate certain records such as IC, backtest etc., and save them.\n\n        Parameters\n        ----------\n        kwargs\n\n        Return\n        ------\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `generate` method.\")\n\n    def load(self, name: str, parents: bool = True):\n        \"\"\"\n        It behaves the same as self.recorder.load_object.\n        But it is an easier interface because users don't have to care about `get_path` and `artifact_path`\n\n        Parameters\n        ----------\n        name : str\n            the name of the file to be loaded.\n\n        parents : bool\n            Each record class has its own `artifact_path`.\n            If True and the file is not found, the path is searched recursively in the parent (`depend_cls`) records.\n            Subclasses have higher priority.\n\n        Return\n        ------\n        The stored records.\n        \"\"\"\n        try:\n            return self.recorder.load_object(self.get_path(name))\n        except LoadObjectError as e:\n            if parents:\n                if self.depend_cls is not None:\n                    with class_casting(self, self.depend_cls):\n                        return self.load(name, parents=True)\n            raise e\n\n    def list(self):\n        \"\"\"\n        List the supported artifacts.\n        Users don't have to consider self.get_path\n\n        Return\n        ------\n        A list of all the supported artifacts.\n        \"\"\"\n        return []\n\n    def check(self, include_self: bool = False, parents: bool = True):\n        \"\"\"\n        Check whether the records are properly generated and saved.\n        It is useful in the following cases\n\n        - checking whether the dependent files are complete before generating new 
things.\n        - checking whether the final files are complete\n\n        Parameters\n        ----------\n        include_self : bool\n            whether to check the files generated by self\n        parents : bool\n            whether to check the parent (dependent) records\n\n        Raises\n        ------\n        FileNotFoundError\n            if the records are not stored properly.\n        \"\"\"\n        if include_self:\n            # Some MLflow backends do not list directories recursively,\n            # so each directory is listed explicitly\n            artifacts = {}\n\n            def _get_arts(dirn):\n                if dirn not in artifacts:\n                    artifacts[dirn] = self.recorder.list_artifacts(dirn)\n                return artifacts[dirn]\n\n            for item in self.list():\n                ps = self.get_path(item).split(\"/\")\n                dirn = \"/\".join(ps[:-1])\n                if self.get_path(item) not in _get_arts(dirn):\n                    raise FileNotFoundError\n        if parents:\n            if self.depend_cls is not None:\n                with class_casting(self, self.depend_cls):\n                    self.check(include_self=True)\n\n\nclass SignalRecord(RecordTemp):\n    \"\"\"\n    This is the Signal Record class that generates the signal prediction. 
This class inherits the ``RecordTemp`` class.\n    \"\"\"\n\n    def __init__(self, model=None, dataset=None, recorder=None):\n        super().__init__(recorder=recorder)\n        self.model = model\n        self.dataset = dataset\n\n    @staticmethod\n    def generate_label(dataset):\n        with class_casting(dataset, DatasetH):\n            params = dict(segments=\"test\", col_set=\"label\", data_key=DataHandlerLP.DK_R)\n            try:\n                # Assume the backend handler is DataHandlerLP\n                raw_label = dataset.prepare(**params)\n            except TypeError:\n                # The number of arguments does not match the handler's signature\n                del params[\"data_key\"]\n                # The backend handler should be DataHandler\n                raw_label = dataset.prepare(**params)\n            except AttributeError as e:\n                # The data handler is initialized with `drop_raw=True`,\n                # so raw_label is not available\n                logger.warning(f\"Exception: {e}\")\n                raw_label = None\n        return raw_label\n\n    def generate(self, **kwargs):\n        # generate prediction\n        pred = self.model.predict(self.dataset)\n        if isinstance(pred, pd.Series):\n            pred = pred.to_frame(\"score\")\n        self.save(**{\"pred.pkl\": pred})\n\n        logger.info(\n            f\"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}\"\n        )\n        # print out results\n        pprint(f\"The following are prediction results of the {type(self.model).__name__} model.\")\n        pprint(pred.head(5))\n\n        if isinstance(self.dataset, DatasetH):\n            raw_label = self.generate_label(self.dataset)\n            self.save(**{\"label.pkl\": raw_label})\n\n    def list(self):\n        return [\"pred.pkl\", \"label.pkl\"]\n\n\nclass ACRecordTemp(RecordTemp):\n    \"\"\"Automatically checking record template\"\"\"\n\n    def __init__(self, 
recorder, skip_existing=False):\n        self.skip_existing = skip_existing\n        super().__init__(recorder=recorder)\n\n    def generate(self, *args, **kwargs):\n        \"\"\"Automatically check the files and then run the concrete generating task\"\"\"\n        if self.skip_existing:\n            try:\n                self.check(include_self=True, parents=False)\n            except FileNotFoundError:\n                pass  # continue generating metrics\n            else:\n                logger.info(\"The results have been generated previously. Generation skipped.\")\n                return\n\n        try:\n            self.check()\n        except FileNotFoundError:\n            logger.warning(\"The dependent data does not exist. Generation skipped.\")\n            return\n        artifact_dict = self._generate(*args, **kwargs)\n        if isinstance(artifact_dict, dict):\n            self.save(**artifact_dict)\n        return artifact_dict\n\n    def _generate(self, *args, **kwargs) -> Dict[str, object]:\n        \"\"\"\n        Run the concrete generating task and return a dictionary of the generated results.\n        The caller method will save the results to the recorder.\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `_generate` method\")\n\n\nclass HFSignalRecord(SignalRecord):\n    \"\"\"\n    This is the Signal Analysis Record class that generates the analysis results such as IC and IR. 
This class inherits the ``SignalRecord`` class.\n    \"\"\"\n\n    artifact_path = \"hg_sig_analysis\"\n    depend_cls = SignalRecord\n\n    def __init__(self, recorder, **kwargs):\n        super().__init__(recorder=recorder)\n\n    def generate(self):\n        pred = self.load(\"pred.pkl\")\n        raw_label = self.load(\"label.pkl\")\n        long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True)\n        ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0])\n        metrics = {\n            \"IC\": ic.mean(),\n            \"ICIR\": ic.mean() / ic.std(),\n            \"Rank IC\": ric.mean(),\n            \"Rank ICIR\": ric.mean() / ric.std(),\n            \"Long precision\": long_pre.mean(),\n            \"Short precision\": short_pre.mean(),\n        }\n        objects = {\"ic.pkl\": ic, \"ric.pkl\": ric}\n        objects.update({\"long_pre.pkl\": long_pre, \"short_pre.pkl\": short_pre})\n        long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], raw_label.iloc[:, 0])\n        metrics.update(\n            {\n                \"Long-Short Average Return\": long_short_r.mean(),\n                \"Long-Short Average Sharpe\": long_short_r.mean() / long_short_r.std(),\n            }\n        )\n        objects.update(\n            {\n                \"long_short_r.pkl\": long_short_r,\n                \"long_avg_r.pkl\": long_avg_r,\n            }\n        )\n        self.recorder.log_metrics(**metrics)\n        self.save(**objects)\n        pprint(metrics)\n\n    def list(self):\n        return [\"ic.pkl\", \"ric.pkl\", \"long_pre.pkl\", \"short_pre.pkl\", \"long_short_r.pkl\", \"long_avg_r.pkl\"]\n\n\nclass SigAnaRecord(ACRecordTemp):\n    \"\"\"\n    This is the Signal Analysis Record class that generates the analysis results such as IC and IR.\n    This class inherits the ``ACRecordTemp`` class.\n    \"\"\"\n\n    artifact_path = \"sig_analysis\"\n    depend_cls = SignalRecord\n\n    def __init__(self, 
recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False):\n        super().__init__(recorder=recorder, skip_existing=skip_existing)\n        self.ana_long_short = ana_long_short\n        self.ann_scaler = ann_scaler\n        self.label_col = label_col\n\n    def _generate(self, label: Optional[pd.DataFrame] = None, **kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        label : Optional[pd.DataFrame]\n            Label should be a dataframe.\n        \"\"\"\n        pred = self.load(\"pred.pkl\")\n        if label is None:\n            label = self.load(\"label.pkl\")\n        if label is None or not isinstance(label, pd.DataFrame) or label.empty:\n            logger.warning(f\"Empty label.\")\n            return\n        ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, self.label_col])\n        metrics = {\n            \"IC\": ic.mean(),\n            \"ICIR\": ic.mean() / ic.std(),\n            \"Rank IC\": ric.mean(),\n            \"Rank ICIR\": ric.mean() / ric.std(),\n        }\n        objects = {\"ic.pkl\": ic, \"ric.pkl\": ric}\n        if self.ana_long_short:\n            long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, self.label_col])\n            metrics.update(\n                {\n                    \"Long-Short Ann Return\": long_short_r.mean() * self.ann_scaler,\n                    \"Long-Short Ann Sharpe\": long_short_r.mean() / long_short_r.std() * self.ann_scaler**0.5,\n                    \"Long-Avg Ann Return\": long_avg_r.mean() * self.ann_scaler,\n                    \"Long-Avg Ann Sharpe\": long_avg_r.mean() / long_avg_r.std() * self.ann_scaler**0.5,\n                }\n            )\n            objects.update(\n                {\n                    \"long_short_r.pkl\": long_short_r,\n                    \"long_avg_r.pkl\": long_avg_r,\n                }\n            )\n        self.recorder.log_metrics(**metrics)\n        pprint(metrics)\n        return 
objects\n\n    def list(self):\n        paths = [\"ic.pkl\", \"ric.pkl\"]\n        if self.ana_long_short:\n            paths.extend([\"long_short_r.pkl\", \"long_avg_r.pkl\"])\n        return paths\n\n\nclass PortAnaRecord(ACRecordTemp):\n    \"\"\"\n    This is the Portfolio Analysis Record class that generates the analysis results such as those of the backtest. This class inherits the ``RecordTemp`` class.\n\n    The following files will be stored in the recorder:\n\n    - report_normal.pkl & positions_normal.pkl:\n\n        - The return report and detailed positions of the backtest, returned by `qlib/contrib/evaluate.py:backtest`\n    - port_analysis.pkl : The risk analysis of your portfolio, returned by `qlib/contrib/evaluate.py:risk_analysis`\n    \"\"\"\n\n    artifact_path = \"portfolio_analysis\"\n    depend_cls = SignalRecord\n\n    def __init__(\n        self,\n        recorder,\n        config=None,\n        risk_analysis_freq: Union[List, str] = None,\n        indicator_analysis_freq: Union[List, str] = None,\n        indicator_analysis_method=None,\n        skip_existing=False,\n        **kwargs,\n    ):\n        \"\"\"\n        config[\"strategy\"] : dict\n            define the strategy class as well as the kwargs.\n        config[\"executor\"] : dict\n            define the executor class as well as the kwargs.\n        config[\"backtest\"] : dict\n            define the backtest kwargs.\n        risk_analysis_freq : str|List[str]\n            the report frequency (or frequencies) used for risk analysis\n        indicator_analysis_freq : str|List[str]\n            the report frequency (or frequencies) used for indicator analysis\n        indicator_analysis_method : str, optional, default None\n            the candidate values include 'mean', 'amount_weighted', 'value_weighted'\n        \"\"\"\n        super().__init__(recorder=recorder, skip_existing=skip_existing, **kwargs)\n\n        if config is None:\n            config = {  # Default config for daily trading\n                \"strategy\": {\n                    
\"class\": \"TopkDropoutStrategy\",\n                    \"module_path\": \"qlib.contrib.strategy\",\n                    \"kwargs\": {\"signal\": \"<PRED>\", \"topk\": 50, \"n_drop\": 5},\n                },\n                \"backtest\": {\n                    \"start_time\": None,\n                    \"end_time\": None,\n                    \"account\": 100000000,\n                    \"benchmark\": \"SH000300\",\n                    \"exchange_kwargs\": {\n                        \"limit_threshold\": 0.095,\n                        \"deal_price\": \"close\",\n                        \"open_cost\": 0.0005,\n                        \"close_cost\": 0.0015,\n                        \"min_cost\": 5,\n                    },\n                },\n            }\n        # We only deepcopy_basic_type because\n        # - We don't want to affect the config outside.\n        # - We don't want to deepcopy complex object to avoid overhead\n        config = deepcopy_basic_type(config)\n\n        self.strategy_config = config[\"strategy\"]\n        _default_executor_config = {\n            \"class\": \"SimulatorExecutor\",\n            \"module_path\": \"qlib.backtest.executor\",\n            \"kwargs\": {\n                \"time_per_step\": \"day\",\n                \"generate_portfolio_metrics\": True,\n            },\n        }\n        self.executor_config = config.get(\"executor\", _default_executor_config)\n        self.backtest_config = config[\"backtest\"]\n\n        self.all_freq = self._get_report_freq(self.executor_config)\n        if risk_analysis_freq is None:\n            risk_analysis_freq = [self.all_freq[0]]\n        if indicator_analysis_freq is None:\n            indicator_analysis_freq = [self.all_freq[0]]\n\n        if isinstance(risk_analysis_freq, str):\n            risk_analysis_freq = [risk_analysis_freq]\n        if isinstance(indicator_analysis_freq, str):\n            indicator_analysis_freq = [indicator_analysis_freq]\n\n        
self.risk_analysis_freq = [\n            \"{0}{1}\".format(*Freq.parse(_analysis_freq)) for _analysis_freq in risk_analysis_freq\n        ]\n        self.indicator_analysis_freq = [\n            \"{0}{1}\".format(*Freq.parse(_analysis_freq)) for _analysis_freq in indicator_analysis_freq\n        ]\n        self.indicator_analysis_method = indicator_analysis_method\n\n    def _get_report_freq(self, executor_config):\n        ret_freq = []\n        if executor_config[\"kwargs\"].get(\"generate_portfolio_metrics\", False):\n            _count, _freq = Freq.parse(executor_config[\"kwargs\"][\"time_per_step\"])\n            ret_freq.append(f\"{_count}{_freq}\")\n        if \"inner_executor\" in executor_config[\"kwargs\"]:\n            ret_freq.extend(self._get_report_freq(executor_config[\"kwargs\"][\"inner_executor\"]))\n        return ret_freq\n\n    def _generate(self, **kwargs):\n        pred = self.load(\"pred.pkl\")\n\n        # replace the \"<PRED>\" with prediction saved before\n        placeholder_value = {\"<PRED>\": pred}\n        for k in \"executor_config\", \"strategy_config\":\n            setattr(self, k, fill_placeholder(getattr(self, k), placeholder_value))\n\n        # if the backtesting time range is not set, it will automatically extract time range from the prediction file\n        dt_values = pred.index.get_level_values(\"datetime\")\n        if self.backtest_config[\"start_time\"] is None:\n            self.backtest_config[\"start_time\"] = dt_values.min()\n        if self.backtest_config[\"end_time\"] is None:\n            self.backtest_config[\"end_time\"] = get_date_by_shift(dt_values.max(), -1)\n            warnings.warn(\n                \"No explicit backtest end_time provided. \"\n                \"Qlib requires one extra calendar step to determine the right boundary of a bar. 
\"\n                \"Therefore the end_time is shifted backward by one trading day from \"\n                f\"{dt_values.max()} -> {self.backtest_config['end_time']}.\"\n            )\n\n        artifact_objects = {}\n        # custom strategy and get backtest\n        portfolio_metric_dict, indicator_dict = normal_backtest(\n            executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config\n        )\n        for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():\n            artifact_objects.update({f\"report_normal_{_freq}.pkl\": report_normal})\n            artifact_objects.update({f\"positions_normal_{_freq}.pkl\": positions_normal})\n\n        for _freq, indicators_normal in indicator_dict.items():\n            artifact_objects.update({f\"indicators_normal_{_freq}.pkl\": indicators_normal[0]})\n            artifact_objects.update({f\"indicators_normal_{_freq}_obj.pkl\": indicators_normal[1]})\n\n        for _analysis_freq in self.risk_analysis_freq:\n            if _analysis_freq not in portfolio_metric_dict:\n                warnings.warn(\n                    f\"the freq {_analysis_freq} report is not found, please set the corresponding env with `generate_portfolio_metrics=True`\"\n                )\n            else:\n                report_normal, _ = portfolio_metric_dict.get(_analysis_freq)\n                analysis = dict()\n                analysis[\"excess_return_without_cost\"] = risk_analysis(\n                    report_normal[\"return\"] - report_normal[\"bench\"], freq=_analysis_freq\n                )\n                analysis[\"excess_return_with_cost\"] = risk_analysis(\n                    report_normal[\"return\"] - report_normal[\"bench\"] - report_normal[\"cost\"], freq=_analysis_freq\n                )\n\n                analysis_df = pd.concat(analysis)  # type: pd.DataFrame\n                # log metrics\n                analysis_dict = 
flatten_dict(analysis_df[\"risk\"].unstack().T.to_dict())\n                self.recorder.log_metrics(**{f\"{_analysis_freq}.{k}\": v for k, v in analysis_dict.items()})\n                # save results\n                artifact_objects.update({f\"port_analysis_{_analysis_freq}.pkl\": analysis_df})\n                logger.info(\n                    f\"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}\"\n                )\n                # print out results\n                pprint(f\"The following are analysis results of benchmark return({_analysis_freq}).\")\n                pprint(risk_analysis(report_normal[\"bench\"], freq=_analysis_freq))\n                pprint(f\"The following are analysis results of the excess return without cost({_analysis_freq}).\")\n                pprint(analysis[\"excess_return_without_cost\"])\n                pprint(f\"The following are analysis results of the excess return with cost({_analysis_freq}).\")\n                pprint(analysis[\"excess_return_with_cost\"])\n\n        for _analysis_freq in self.indicator_analysis_freq:\n            if _analysis_freq not in indicator_dict:\n                warnings.warn(f\"the freq {_analysis_freq} indicator is not found\")\n            else:\n                indicators_normal = indicator_dict.get(_analysis_freq)[0]\n                if self.indicator_analysis_method is None:\n                    analysis_df = indicator_analysis(indicators_normal)\n                else:\n                    analysis_df = indicator_analysis(indicators_normal, method=self.indicator_analysis_method)\n                # log metrics\n                analysis_dict = analysis_df[\"value\"].to_dict()\n                self.recorder.log_metrics(**{f\"{_analysis_freq}.{k}\": v for k, v in analysis_dict.items()})\n                # save results\n                artifact_objects.update({f\"indicator_analysis_{_analysis_freq}.pkl\": 
analysis_df})\n                logger.info(\n                    f\"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}\"\n                )\n                pprint(f\"The following are analysis results of indicators({_analysis_freq}).\")\n                pprint(analysis_df)\n        return artifact_objects\n\n    def list(self):\n        list_path = []\n        for _freq in self.all_freq:\n            list_path.extend(\n                [\n                    f\"report_normal_{_freq}.pkl\",\n                    f\"positions_normal_{_freq}.pkl\",\n                ]\n            )\n        for _analysis_freq in self.risk_analysis_freq:\n            if _analysis_freq in self.all_freq:\n                list_path.append(f\"port_analysis_{_analysis_freq}.pkl\")\n            else:\n                warnings.warn(f\"risk_analysis freq {_analysis_freq} is not found\")\n\n        for _analysis_freq in self.indicator_analysis_freq:\n            if _analysis_freq in self.all_freq:\n                list_path.append(f\"indicator_analysis_{_analysis_freq}.pkl\")\n            else:\n                warnings.warn(f\"indicator_analysis freq {_analysis_freq} is not found\")\n        return list_path\n\n\nclass MultiPassPortAnaRecord(PortAnaRecord):\n    \"\"\"\n    This is the Multiple Pass Portfolio Analysis Record class that runs the backtest multiple times and generates analysis results such as those of the backtest. This class inherits the ``PortAnaRecord`` class.\n\n    If shuffle_init_score is enabled, the prediction scores of the first backtest date will be shuffled so that the initial position is random.\n    The shuffle_init_score option only works when the signal is passed as the <PRED> placeholder. 
The placeholder will be replaced by the pred.pkl saved in the recorder.\n\n    Parameters\n    ----------\n    recorder : Recorder\n        The recorder used to save the backtest results.\n    pass_num : int\n        The number of backtest passes.\n    shuffle_init_score : bool\n        Whether to shuffle the prediction score of the first backtest date.\n    \"\"\"\n\n    depend_cls = SignalRecord\n\n    def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        recorder : Recorder\n            The recorder used to save the backtest results.\n        pass_num : int\n            The number of backtest passes.\n        shuffle_init_score : bool\n            Whether to shuffle the prediction score of the first backtest date.\n        \"\"\"\n        self.pass_num = pass_num\n        self.shuffle_init_score = shuffle_init_score\n\n        super().__init__(recorder, **kwargs)\n\n        # Save the original strategy so that the pred df can be replaced in the next generate call\n        self.original_strategy = deepcopy_basic_type(self.strategy_config)\n        if not isinstance(self.original_strategy, dict):\n            raise QlibException(\"MultiPassPortAnaRecord requires the passed-in strategy to be a dict\")\n        if \"signal\" not in self.original_strategy.get(\"kwargs\", {}):\n            raise QlibException(\"MultiPassPortAnaRecord requires the passed-in strategy to have signal as a parameter\")\n\n    def random_init(self):\n        pred_df = self.load(\"pred.pkl\")\n\n        all_pred_dates = pred_df.index.get_level_values(\"datetime\")\n        bt_start_date = pd.to_datetime(self.backtest_config.get(\"start_time\"))\n        if bt_start_date is None:\n            first_bt_pred_date = all_pred_dates.min()\n        else:\n            first_bt_pred_date = all_pred_dates[all_pred_dates >= bt_start_date].min()\n\n        # Shuffle the first backtest date's pred score\n        first_date_score = 
pred_df.loc[first_bt_pred_date][\"score\"]\n        np.random.shuffle(first_date_score.values)\n\n        # Use shuffled signal as the strategy signal\n        self.strategy_config = deepcopy_basic_type(self.original_strategy)\n        self.strategy_config[\"kwargs\"][\"signal\"] = pred_df\n\n    def _generate(self, **kwargs):\n        risk_analysis_df_map = {}\n\n        # Collect each frequency's analysis df into a list\n        for i in trange(self.pass_num):\n            if self.shuffle_init_score:\n                self.random_init()\n\n            # Do not check the cached file list\n            single_run_artifacts = super()._generate(**kwargs)\n\n            for _analysis_freq in self.risk_analysis_freq:\n                risk_analysis_df_list = risk_analysis_df_map.get(_analysis_freq, [])\n                risk_analysis_df_map[_analysis_freq] = risk_analysis_df_list\n\n                analysis_df = single_run_artifacts[f\"port_analysis_{_analysis_freq}.pkl\"]\n                analysis_df[\"run_id\"] = i\n                risk_analysis_df_list.append(analysis_df)\n\n        result_artifacts = {}\n        # Concat df list\n        for _analysis_freq in self.risk_analysis_freq:\n            combined_df = pd.concat(risk_analysis_df_map[_analysis_freq])\n\n            # Calculate return and information ratio's mean, std and mean/std\n            multi_pass_port_analysis_df = combined_df.groupby(level=[0, 1], group_keys=False).apply(\n                lambda x: pd.Series(\n                    {\"mean\": x[\"risk\"].mean(), \"std\": x[\"risk\"].std(), \"mean_std\": x[\"risk\"].mean() / x[\"risk\"].std()}\n                )\n            )\n\n            # Only look at \"annualized_return\" and \"information_ratio\"\n            multi_pass_port_analysis_df = multi_pass_port_analysis_df.loc[\n                (slice(None), [\"annualized_return\", \"information_ratio\"]), :\n            ]\n            pprint(multi_pass_port_analysis_df)\n\n            # Save new df\n            
result_artifacts.update({f\"multi_pass_port_analysis_{_analysis_freq}.pkl\": multi_pass_port_analysis_df})\n\n            # Log metrics\n            metrics = flatten_dict(\n                {\n                    \"mean\": multi_pass_port_analysis_df[\"mean\"].unstack().T.to_dict(),\n                    \"std\": multi_pass_port_analysis_df[\"std\"].unstack().T.to_dict(),\n                    \"mean_std\": multi_pass_port_analysis_df[\"mean_std\"].unstack().T.to_dict(),\n                }\n            )\n            self.recorder.log_metrics(**metrics)\n        return result_artifacts\n\n    def list(self):\n        list_path = []\n        for _analysis_freq in self.risk_analysis_freq:\n            if _analysis_freq in self.all_freq:\n                list_path.append(f\"multi_pass_port_analysis_{_analysis_freq}.pkl\")\n            else:\n                warnings.warn(f\"risk_analysis freq {_analysis_freq} is not found\")\n        return list_path\n"
  },
  {
    "path": "qlib/workflow/recorder.py",
    "content": "# Copyright (c) Microsoft Corporation.\r\n# Licensed under the MIT License.\r\n\r\nimport os\r\nimport sys\r\nfrom typing import Optional\r\nimport mlflow\r\nimport shutil\r\nimport pickle\r\nimport tempfile\r\nimport subprocess\r\nimport platform\r\nfrom pathlib import Path\r\nfrom datetime import datetime\r\n\r\nfrom qlib.utils.serial import Serializable\r\nfrom qlib.utils.exceptions import LoadObjectError\r\nfrom qlib.utils.paral import AsyncCaller\r\n\r\nfrom ..log import TimeInspector, get_module_logger\r\nfrom mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository\r\n\r\nlogger = get_module_logger(\"workflow\")\r\n# mlflow limits the length of log_param to 500, but this caused errors when using qrun, so we extended the mlflow limit.\r\nmlflow.utils.validation.MAX_PARAM_VAL_LENGTH = 1000\r\n\r\n\r\nclass Recorder:\r\n    \"\"\"\r\n    This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow.\r\n    (The link: https://mlflow.org/docs/latest/python_api/mlflow.html)\r\n\r\n    The status of the recorder can be SCHEDULED, RUNNING, FINISHED, FAILED.\r\n    \"\"\"\r\n\r\n    # status type\r\n    STATUS_S = \"SCHEDULED\"\r\n    STATUS_R = \"RUNNING\"\r\n    STATUS_FI = \"FINISHED\"\r\n    STATUS_FA = \"FAILED\"\r\n\r\n    def __init__(self, experiment_id, name):\r\n        self.id = None\r\n        self.name = name\r\n        self.experiment_id = experiment_id\r\n        self.start_time = None\r\n        self.end_time = None\r\n        self.status = Recorder.STATUS_S\r\n\r\n    def __repr__(self):\r\n        return \"{name}(info={info})\".format(name=self.__class__.__name__, info=self.info)\r\n\r\n    def __str__(self):\r\n        return str(self.info)\r\n\r\n    def __hash__(self) -> int:\r\n        return hash(self.info[\"id\"])\r\n\r\n    @property\r\n    def info(self):\r\n        output = dict()\r\n        output[\"class\"] = \"Recorder\"\r\n        output[\"id\"] = self.id\r\n     
   output[\"name\"] = self.name\r\n        output[\"experiment_id\"] = self.experiment_id\r\n        output[\"start_time\"] = self.start_time\r\n        output[\"end_time\"] = self.end_time\r\n        output[\"status\"] = self.status\r\n        return output\r\n\r\n    def set_recorder_name(self, rname):\r\n        self.recorder_name = rname\r\n\r\n    def save_objects(self, local_path=None, artifact_path=None, **kwargs):\r\n        \"\"\"\r\n        Save objects such as prediction files or model checkpoints to the artifact URI. Users\r\n        can save objects through keyword arguments (name:value).\r\n\r\n        Please refer to the docs of qlib.workflow:R.save_objects\r\n\r\n        Parameters\r\n        ----------\r\n        local_path : str\r\n            if provided, then save the file or directory to the artifact URI.\r\n        artifact_path=None : str\r\n            the relative path for the artifact to be stored in the URI.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `save_objects` method.\")\r\n\r\n    def load_object(self, name):\r\n        \"\"\"\r\n        Load an object such as a prediction file or a model checkpoint.\r\n\r\n        Parameters\r\n        ----------\r\n        name : str\r\n            name of the file to be loaded.\r\n\r\n        Returns\r\n        -------\r\n        The saved object.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `load_object` method.\")\r\n\r\n    def start_run(self):\r\n        \"\"\"\r\n        Start running or resuming the Recorder. The return value can be used as a context manager within a `with` block;\r\n        otherwise, you must call end_run() to terminate the current run. (See `ActiveRun` class in mlflow)\r\n\r\n        Returns\r\n        -------\r\n        An active running object (e.g. 
mlflow.ActiveRun object).\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `start_run` method.\")\r\n\r\n    def end_run(self):\r\n        \"\"\"\r\n        End an active Recorder.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `end_run` method.\")\r\n\r\n    def log_params(self, **kwargs):\r\n        \"\"\"\r\n        Log a batch of params for the current run.\r\n\r\n        Parameters\r\n        ----------\r\n        keyword arguments\r\n            key-value pairs to be logged as parameters.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `log_params` method.\")\r\n\r\n    def log_metrics(self, step=None, **kwargs):\r\n        \"\"\"\r\n        Log multiple metrics for the current run.\r\n\r\n        Parameters\r\n        ----------\r\n        keyword arguments\r\n            key-value pairs to be logged as metrics.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `log_metrics` method.\")\r\n\r\n    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):\r\n        \"\"\"\r\n        Log a local file or directory as an artifact of the currently active run.\r\n\r\n        Parameters\r\n        ----------\r\n        local_path : str\r\n            Path to the file to write.\r\n        artifact_path : Optional[str]\r\n            If provided, the directory in ``artifact_uri`` to write to.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `log_artifact` method.\")\r\n\r\n    def set_tags(self, **kwargs):\r\n        \"\"\"\r\n        Log a batch of tags for the current run.\r\n\r\n        Parameters\r\n        ----------\r\n        keyword arguments\r\n            key-value pairs to be logged as tags.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `set_tags` method.\")\r\n\r\n    def delete_tags(self, *keys):\r\n        \"\"\"\r\n        Delete some tags from a 
run.\r\n\r\n        Parameters\r\n        ----------\r\n        keys : series of str\r\n            the names of the tags to be deleted.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `delete_tags` method.\")\r\n\r\n    def list_artifacts(self, artifact_path: str = None):\r\n        \"\"\"\r\n        List all the artifacts of a recorder.\r\n\r\n        Parameters\r\n        ----------\r\n        artifact_path : str\r\n            the relative path in the URI under which to list artifacts.\r\n\r\n        Returns\r\n        -------\r\n        A list of artifact information (name, path, etc.) that is stored.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `list_artifacts` method.\")\r\n\r\n    def download_artifact(self, path: str, dst_path: Optional[str] = None) -> str:\r\n        \"\"\"\r\n        Download an artifact file or directory from a run to a local directory if applicable,\r\n        and return a local path for it.\r\n\r\n        Parameters\r\n        ----------\r\n        path : str\r\n            Relative source path to the desired artifact.\r\n        dst_path : Optional[str]\r\n            Absolute path of the local filesystem destination directory to which to\r\n            download the specified artifacts. 
This directory must already exist.\r\n            If unspecified, the artifacts will be downloaded to a new\r\n            uniquely-named directory on the local filesystem.\r\n\r\n        Returns\r\n        -------\r\n        str\r\n            Local path of desired artifact.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `download_artifact` method.\")\r\n\r\n    def list_metrics(self):\r\n        \"\"\"\r\n        List all the metrics of a recorder.\r\n\r\n        Returns\r\n        -------\r\n        A dictionary of metrics that are stored.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `list_metrics` method.\")\r\n\r\n    def list_params(self):\r\n        \"\"\"\r\n        List all the params of a recorder.\r\n\r\n        Returns\r\n        -------\r\n        A dictionary of params that are stored.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `list_params` method.\")\r\n\r\n    def list_tags(self):\r\n        \"\"\"\r\n        List all the tags of a recorder.\r\n\r\n        Returns\r\n        -------\r\n        A dictionary of tags that are stored.\r\n        \"\"\"\r\n        raise NotImplementedError(f\"Please implement the `list_tags` method.\")\r\n\r\n\r\nclass MLflowRecorder(Recorder):\r\n    \"\"\"\r\n    Use mlflow to implement a Recorder.\r\n\r\n    Because mlflow only logs artifacts from a file or directory, we use a\r\n    file manager to help maintain the objects in the project.\r\n\r\n    Instead of using mlflow directly, we use another interface wrapping mlflow to log experiments.\r\n    Though it takes extra effort, it brings users benefits for the following reasons.\r\n    - It is more convenient to change the experiment logging backend without changing any code at the upper level\r\n    - We can automatically do some extra things and make the interface easier to use. 
For example:\r\n        - Automatically logging the uncommitted code\r\n        - Automatically logging part of the environment variables\r\n        - Users can control several different runs just by creating different Recorders (in mlflow, you always have to switch artifact_uri and pass in run ids frequently)\r\n    \"\"\"\r\n\r\n    def __init__(self, experiment_id, uri, name=None, mlflow_run=None):\r\n        super(MLflowRecorder, self).__init__(experiment_id, name)\r\n        self._uri = uri\r\n        self._artifact_uri = None\r\n        self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)\r\n        # construct from mlflow run\r\n        if mlflow_run is not None:\r\n            assert isinstance(mlflow_run, mlflow.entities.run.Run), \"Please input an MLflow Run object.\"\r\n            self.name = mlflow_run.data.tags[\"mlflow.runName\"]\r\n            self.id = mlflow_run.info.run_id\r\n            self.status = mlflow_run.info.status\r\n            self.start_time = (\r\n                datetime.fromtimestamp(float(mlflow_run.info.start_time) / 1000.0).strftime(\"%Y-%m-%d %H:%M:%S\")\r\n                if mlflow_run.info.start_time is not None\r\n                else None\r\n            )\r\n            self.end_time = (\r\n                datetime.fromtimestamp(float(mlflow_run.info.end_time) / 1000.0).strftime(\"%Y-%m-%d %H:%M:%S\")\r\n                if mlflow_run.info.end_time is not None\r\n                else None\r\n            )\r\n            self._artifact_uri = mlflow_run.info.artifact_uri\r\n        self.async_log = None\r\n\r\n    def __repr__(self):\r\n        name = self.__class__.__name__\r\n        space_length = len(name) + 1\r\n        return \"{name}(info={info},\\n{space}uri={uri},\\n{space}artifact_uri={artifact_uri},\\n{space}client={client})\".format(\r\n            name=name,\r\n            space=\" \" * space_length,\r\n            info=self.info,\r\n            uri=self.uri,\r\n            
artifact_uri=self.artifact_uri,\r\n            client=self.client,\r\n        )\r\n\r\n    def __hash__(self) -> int:\r\n        return hash(self.info[\"id\"])\r\n\r\n    def __eq__(self, o: object) -> bool:\r\n        if isinstance(o, MLflowRecorder):\r\n            return self.info[\"id\"] == o.info[\"id\"]\r\n        return False\r\n\r\n    @property\r\n    def uri(self):\r\n        return self._uri\r\n\r\n    @property\r\n    def artifact_uri(self):\r\n        return self._artifact_uri\r\n\r\n    def get_local_dir(self):\r\n        \"\"\"\r\n        This function will return the directory path of this recorder.\r\n        \"\"\"\r\n        if self.artifact_uri is not None:\r\n            if platform.system() == \"Windows\":\r\n                local_dir_path = Path(self.artifact_uri.lstrip(\"file:\").lstrip(\"/\")).parent\r\n            else:\r\n                local_dir_path = Path(self.artifact_uri.lstrip(\"file:\")).parent\r\n            local_dir_path = str(local_dir_path.resolve())\r\n            if os.path.isdir(local_dir_path):\r\n                return local_dir_path\r\n            else:\r\n                raise RuntimeError(\"This recorder is not saved in the local file system.\")\r\n\r\n        else:\r\n            raise ValueError(\r\n                \"Please make sure the recorder has been created and started properly before getting artifact uri.\"\r\n            )\r\n\r\n    def start_run(self):\r\n        # set the tracking uri\r\n        mlflow.set_tracking_uri(self.uri)\r\n        # start the run\r\n        run = mlflow.start_run(self.id, self.experiment_id, self.name)\r\n        # save the run id and artifact_uri\r\n        self.id = run.info.run_id\r\n        self._artifact_uri = run.info.artifact_uri\r\n        self.start_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\r\n        self.status = Recorder.STATUS_R\r\n        logger.info(f\"Recorder {self.id} starts running under Experiment {self.experiment_id} ...\")\r\n\r\n        # NOTE: 
making logging async.\r\n        # - This may delay uploading results\r\n        # - The logging time may not be accurate\r\n        self.async_log = AsyncCaller()\r\n\r\n        # TODO: currently, this is only supported in MLflowRecorder.\r\n        # Maybe we can make this feature more general.\r\n        self._log_uncommitted_code()\r\n\r\n        self.log_params(**{\"cmd-sys.argv\": \" \".join(sys.argv)})  # log the command that produced the current experiment\r\n        self.log_params(\r\n            **{k: v for k, v in os.environ.items() if k.startswith(\"_QLIB_\")}\r\n        )  # Log necessary environment variables\r\n        return run\r\n\r\n    def _log_uncommitted_code(self):\r\n        \"\"\"\r\n        MLflow only logs the commit id of the current repo, but users often have many uncommitted changes.\r\n        So this tries to log them all automatically.\r\n        \"\"\"\r\n        # TODO: the sub-directories may be git repos.\r\n        # So it would be better to walk the sub-directories and log their uncommitted changes as well.\r\n        for cmd, fname in [\r\n            (\"git diff\", \"code_diff.txt\"),\r\n            (\"git status\", \"code_status.txt\"),\r\n            (\"git diff --cached\", \"code_cached.txt\"),\r\n        ]:\r\n            try:\r\n                out = subprocess.check_output(cmd, shell=True)\r\n                self.client.log_text(self.id, out.decode(), fname)  # log the command output as a text artifact\r\n            except subprocess.CalledProcessError:\r\n                logger.info(f\"Failed to log the uncommitted code of $CWD({os.getcwd()}) when running {cmd}.\")\r\n\r\n    def end_run(self, status: str = Recorder.STATUS_S):\r\n        assert status in [\r\n            Recorder.STATUS_S,\r\n            Recorder.STATUS_R,\r\n            Recorder.STATUS_FI,\r\n            Recorder.STATUS_FA,\r\n        ], f\"The status type {status} is not supported.\"\r\n        self.end_time = datetime.now().strftime(\"%Y-%m-%d 
%H:%M:%S\")\r\n        if self.status != Recorder.STATUS_S:\r\n            self.status = status\r\n        if self.async_log is not None:\r\n            # Wait for the async queue before calling mlflow.end_run; otherwise mlflow will raise an error\r\n            with TimeInspector.logt(\"waiting `async_log`\"):\r\n                self.async_log.wait()\r\n        self.async_log = None\r\n        mlflow.end_run(status)\r\n\r\n    def save_objects(self, local_path=None, artifact_path=None, **kwargs):\r\n        assert self.uri is not None, \"Please start the experiment and recorder first before using recorder directly.\"\r\n        if local_path is not None:\r\n            path = Path(local_path)\r\n            if path.is_dir():\r\n                self.client.log_artifacts(self.id, local_path, artifact_path)\r\n            else:\r\n                self.client.log_artifact(self.id, local_path, artifact_path)\r\n        else:\r\n            temp_dir = Path(tempfile.mkdtemp()).resolve()\r\n            for name, data in kwargs.items():\r\n                path = temp_dir / name\r\n                Serializable.general_dump(data, path)\r\n                self.client.log_artifact(self.id, temp_dir / name, artifact_path)\r\n            shutil.rmtree(temp_dir)\r\n\r\n    def load_object(self, name, unpickler=pickle.Unpickler):\r\n        \"\"\"\r\n        Load an object, such as a prediction file or a model checkpoint, from mlflow.\r\n\r\n        Args:\r\n            name (str): the object name\r\n\r\n            unpickler: the unpickler class to use; a custom one may be supplied\r\n\r\n        Raises:\r\n            LoadObjectError: if any exception is raised while loading the object\r\n\r\n        Returns:\r\n            object: the saved object in mlflow.\r\n        \"\"\"\r\n        assert self.uri is not None, \"Please start the experiment and recorder first before using recorder directly.\"\r\n\r\n        path = None\r\n        try:\r\n            path = self.client.download_artifacts(self.id, name)\r\n            with 
Path(path).open(\"rb\") as f:\r\n                data = unpickler(f).load()\r\n            return data\r\n        except Exception as e:\r\n            raise LoadObjectError(str(e)) from e\r\n        finally:\r\n            ar = self.client._tracking_client._get_artifact_repo(self.id)\r\n            if isinstance(ar, AzureBlobArtifactRepository) and path is not None:\r\n                # for saving disk space\r\n                # For safety, only remove redundant file for specific ArtifactRepository\r\n                shutil.rmtree(Path(path).absolute().parent)\r\n\r\n    @AsyncCaller.async_dec(ac_attr=\"async_log\")\r\n    def log_params(self, **kwargs):\r\n        for name, data in kwargs.items():\r\n            self.client.log_param(self.id, name, data)\r\n\r\n    @AsyncCaller.async_dec(ac_attr=\"async_log\")\r\n    def log_metrics(self, step=None, **kwargs):\r\n        for name, data in kwargs.items():\r\n            self.client.log_metric(self.id, name, data, step=step)\r\n\r\n    def log_artifact(self, local_path, artifact_path: Optional[str] = None):\r\n        self.client.log_artifact(self.id, local_path=local_path, artifact_path=artifact_path)\r\n\r\n    @AsyncCaller.async_dec(ac_attr=\"async_log\")\r\n    def set_tags(self, **kwargs):\r\n        for name, data in kwargs.items():\r\n            self.client.set_tag(self.id, name, data)\r\n\r\n    def delete_tags(self, *keys):\r\n        for key in keys:\r\n            self.client.delete_tag(self.id, key)\r\n\r\n    def get_artifact_uri(self):\r\n        if self.artifact_uri is not None:\r\n            return self.artifact_uri\r\n        else:\r\n            raise ValueError(\r\n                \"Please make sure the recorder has been created and started properly before getting artifact uri.\"\r\n            )\r\n\r\n    def list_artifacts(self, artifact_path=None):\r\n        assert self.uri is not None, \"Please start the experiment and recorder first before using recorder directly.\"\r\n        artifacts 
= self.client.list_artifacts(self.id, artifact_path)\r\n        return [art.path for art in artifacts]\r\n\r\n    def download_artifact(self, path: str, dst_path: Optional[str] = None) -> str:\r\n        return self.client.download_artifacts(self.id, path, dst_path)\r\n\r\n    def list_metrics(self):\r\n        run = self.client.get_run(self.id)\r\n        return run.data.metrics\r\n\r\n    def list_params(self):\r\n        run = self.client.get_run(self.id)\r\n        return run.data.params\r\n\r\n    def list_tags(self):\r\n        run = self.client.get_run(self.id)\r\n        return run.data.tags\r\n"
  },
  {
    "path": "qlib/workflow/task/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nTask related workflow is implemented in this folder\n\nA typical task workflow\n\n| Step                  | Description                                    |\n|-----------------------+------------------------------------------------|\n| TaskGen               | Generating tasks.                              |\n| TaskManager(optional) | Manage generated tasks                         |\n| run task              | retrieve  tasks from TaskManager and run tasks. |\n\"\"\"\n"
  },
  {
    "path": "qlib/workflow/task/collect.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nCollector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on.\n\"\"\"\n\nfrom collections import defaultdict\nfrom qlib.log import TimeInspector\nfrom typing import Callable, Dict, Iterable, List\nfrom qlib.log import get_module_logger\nfrom qlib.utils.serial import Serializable\nfrom qlib.utils.exceptions import LoadObjectError\nfrom qlib.workflow import R\nfrom qlib.workflow.exp import Experiment\nfrom qlib.workflow.recorder import Recorder\n\n\nclass Collector(Serializable):\n    \"\"\"The collector to collect different results\"\"\"\n\n    pickle_backend = \"dill\"  # use dill to dump user method\n\n    def __init__(self, process_list=[]):\n        \"\"\"\n        Init Collector.\n\n        Args:\n            process_list (list or Callable):  the list of processors or the instance of a processor to process dict.\n        \"\"\"\n        if not isinstance(process_list, list):\n            process_list = [process_list]\n        self.process_list = process_list\n\n    def collect(self) -> dict:\n        \"\"\"\n        Collect the results and return a dict like {key: things}\n\n        Returns:\n            dict: the dict after collecting.\n\n            For example:\n\n            {\"prediction\": pd.Series}\n\n            {\"IC\": {\"Xgboost\": pd.Series, \"LSTM\": pd.Series}}\n\n            ...\n        \"\"\"\n        raise NotImplementedError(f\"Please implement the `collect` method.\")\n\n    @staticmethod\n    def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict:\n        \"\"\"\n        Do a series of processing to the dict returned by collect and return a dict like {key: things}\n        For example, you can group and ensemble.\n\n        Args:\n            collected_dict (dict): the dict return by `collect`\n            process_list (list or Callable): the list of processors 
or the instance of a processor to process dict.\n                The processor order is the same as the list order.\n                For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]\n\n        Returns:\n            dict: the dict after processing.\n        \"\"\"\n        if not isinstance(process_list, list):\n            process_list = [process_list]\n        result = {}\n        for artifact in collected_dict:\n            value = collected_dict[artifact]\n            for process in process_list:\n                if not callable(process):\n                    raise NotImplementedError(f\"{type(process)} is not supported in `process_collect`.\")\n                value = process(value, *args, **kwargs)\n            result[artifact] = value\n        return result\n\n    def __call__(self, *args, **kwargs) -> dict:\n        \"\"\"\n        Run the workflow, including ``collect`` and ``process_collect``\n\n        Returns:\n            dict: the dict after collecting and processing.\n        \"\"\"\n        collected = self.collect()\n        return self.process_collect(collected, self.process_list, *args, **kwargs)\n\n\nclass MergeCollector(Collector):\n    \"\"\"\n    A collector that merges the results of other Collectors\n\n    For example:\n\n        We have 2 collectors, named A and B.\n        A can collect {\"prediction\": pd.Series} and B can collect {\"IC\": {\"Xgboost\": pd.Series, \"LSTM\": pd.Series}}.\n        Then, after this class's collect (with a ``merge_func`` that joins the two keys with an underscore), we can collect {\"A_prediction\": pd.Series, \"B_IC\": {\"Xgboost\": pd.Series, \"LSTM\": pd.Series}}\n\n        ...\n\n    \"\"\"\n\n    def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None):\n        \"\"\"\n        Init MergeCollector.\n\n        Args:\n            collector_dict (Dict[str,Collector]): the dict like {collector_key: Collector}\n            process_list (List[Callable]): the list of processors or the instance of a processor to 
process dict.\n            merge_func (Callable): a function to generate the outermost key. The given params are ``collector_key`` from collector_dict and ``key`` from every collector after collecting.\n                If None, a tuple is used to combine them, such as \"ABC\"+(\"a\",\"b\") -> (\"ABC\", (\"a\",\"b\")).\n        \"\"\"\n        super().__init__(process_list=process_list)\n        self.collector_dict = collector_dict\n        self.merge_func = merge_func\n\n    def collect(self) -> dict:\n        \"\"\"\n        Collect all results of collector_dict and replace the outermost key with a combined key.\n\n        Returns:\n            dict: the dict after collecting.\n        \"\"\"\n        collect_dict = {}\n        for collector_key, collector in self.collector_dict.items():\n            tmp_dict = collector()\n            for key, value in tmp_dict.items():\n                if self.merge_func is not None:\n                    collect_dict[self.merge_func(collector_key, key)] = value\n                else:\n                    collect_dict[(collector_key, key)] = value\n        return collect_dict\n\n\nclass RecorderCollector(Collector):\n    ART_KEY_RAW = \"__raw\"\n\n    def __init__(\n        self,\n        experiment,\n        process_list=[],\n        rec_key_func=None,\n        rec_filter_func=None,\n        artifacts_path={\"pred\": \"pred.pkl\"},\n        artifacts_key=None,\n        list_kwargs={},\n        status: Iterable = {Recorder.STATUS_FI},\n    ):\n        \"\"\"\n        Init RecorderCollector.\n\n        Args:\n            experiment:\n                (Experiment or str): an instance of an Experiment or the name of an Experiment\n                (Callable): a callable that returns a list of recorders\n            process_list (list or Callable): the list of processors or the instance of a processor to process dict.\n            rec_key_func (Callable): a function to get the key of a recorder. 
If None, use the recorder id.\n            rec_filter_func (Callable, optional): a function that returns True to keep a recorder or False to drop it. Defaults to None.\n            artifacts_path (dict, optional): a mapping from artifact names to their paths in the Recorder. Defaults to {\"pred\": \"pred.pkl\"}.\n            artifacts_key (str or List, optional): the artifact key(s) you want to get. If None, get all artifacts.\n            list_kwargs (dict): keyword arguments for the list_recorders function.\n            status (Iterable): only collect recorders with the specified statuses. None means collecting all the recorders.\n        \"\"\"\n        super().__init__(process_list=process_list)\n        if isinstance(experiment, str):\n            experiment = R.get_exp(experiment_name=experiment)\n        assert isinstance(experiment, (Experiment, Callable))\n        self.experiment = experiment\n        self.artifacts_path = artifacts_path\n        if rec_key_func is None:\n\n            def rec_key_func(rec):\n                return rec.info[\"id\"]\n\n        if artifacts_key is None:\n            artifacts_key = list(self.artifacts_path.keys())\n        self.rec_key_func = rec_key_func\n        self.artifacts_key = artifacts_key\n        self.rec_filter_func = rec_filter_func\n        self.list_kwargs = list_kwargs\n        self.status = status\n\n    def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:\n        \"\"\"\n        Collect different artifacts from the recorders left after filtering.\n\n        Args:\n            artifacts_key (str or List, optional): the artifact key(s) you want to get. If None, use the default.\n            rec_filter_func (Callable, optional): a function that returns True to keep a recorder or False to drop it. If None, use the default.\n            only_exist (bool, optional): whether to collect only the artifacts that a recorder actually has.\n                If True, artifacts that raise an exception while loading will be skipped. 
But if False, it will raise the exception.\n\n        Returns:\n            dict: the collected dict, like {artifact: {rec_key: object}}\n        \"\"\"\n        if artifacts_key is None:\n            artifacts_key = self.artifacts_key\n        if rec_filter_func is None:\n            rec_filter_func = self.rec_filter_func\n\n        if isinstance(artifacts_key, str):\n            artifacts_key = [artifacts_key]\n\n        collect_dict = {}\n        # filter recorders\n\n        if isinstance(self.experiment, Experiment):\n            with TimeInspector.logt(\"Time to `list_recorders` in RecorderCollector\"):\n                recs = list(self.experiment.list_recorders(**self.list_kwargs).values())\n        elif isinstance(self.experiment, Callable):\n            recs = self.experiment()\n\n        recs = [\n            rec\n            for rec in recs\n            if (\n                (self.status is None or rec.status in self.status) and (rec_filter_func is None or rec_filter_func(rec))\n            )\n        ]\n\n        logger = get_module_logger(\"RecorderCollector\")\n        status_stat = defaultdict(int)\n        for r in recs:\n            status_stat[r.status] += 1\n        logger.info(f\"Number of recorders after filter: {status_stat}\")\n        for rec in recs:\n            rec_key = self.rec_key_func(rec)\n            for key in artifacts_key:\n                if self.ART_KEY_RAW == key:\n                    artifact = rec\n                else:\n                    try:\n                        artifact = rec.load_object(self.artifacts_path[key])\n                    except LoadObjectError as e:\n                        if only_exist:\n                            # only collect existing artifact\n                            logger.warning(f\"Failed to load {self.artifacts_path[key]}; it is ignored.\")\n                            continue\n                        raise e\n                # give user some warning if the values are overridden\n    
            cdd = collect_dict.setdefault(key, {})\n                if rec_key in cdd:\n                    logger.warning(\n                        f\"key '{rec_key}' is duplicated. The previous value will be overridden. Please check your `rec_key_func`\"\n                    )\n                cdd[rec_key] = artifact\n\n        return collect_dict\n\n    def get_exp_name(self) -> str:\n        \"\"\"\n        Get experiment name\n\n        Returns:\n            str: experiment name\n        \"\"\"\n        return self.experiment.name\n"
  },
  {
    "path": "qlib/workflow/task/gen.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nTaskGenerator module can generate many tasks based on TaskGen and some task templates.\n\"\"\"\n\nimport abc\nimport copy\nimport pandas as pd\nfrom typing import Dict, List, Union, Callable\n\nfrom qlib.utils import transform_end_date\nfrom .utils import TimeAdjuster\n\n\ndef task_generator(tasks, generators) -> list:\n    \"\"\"\n    Use a list of TaskGen and a list of task templates to generate different tasks.\n\n    For examples:\n\n        There are 3 task templates a,b,c and 2 TaskGen A,B. A will generates 2 tasks from a template and B will generates 3 tasks from a template.\n        task_generator([a, b, c], [A, B]) will finally generate 3*2*3 = 18 tasks.\n\n    Parameters\n    ----------\n    tasks : List[dict] or dict\n        a list of task templates or a single task\n    generators : List[TaskGen] or TaskGen\n        a list of TaskGen or a single TaskGen\n\n    Returns\n    -------\n    list\n        a list of tasks\n    \"\"\"\n\n    if isinstance(tasks, dict):\n        tasks = [tasks]\n    if isinstance(generators, TaskGen):\n        generators = [generators]\n\n    # generate gen_task_list\n    for gen in generators:\n        new_task_list = []\n        for task in tasks:\n            new_task_list.extend(gen.generate(task))\n        tasks = new_task_list\n\n    return tasks\n\n\nclass TaskGen(metaclass=abc.ABCMeta):\n    \"\"\"\n    The base class for generating different tasks\n\n    Example 1:\n\n        input: a specific task template and rolling steps\n\n        output: rolling version of the tasks\n\n    Example 2:\n\n        input: a specific task template and losses list\n\n        output: a set of tasks with different losses\n\n    \"\"\"\n\n    @abc.abstractmethod\n    def generate(self, task: dict) -> List[dict]:\n        \"\"\"\n        Generate different tasks based on a task template\n\n        Parameters\n        ----------\n        task: 
dict\n            a task template\n\n        Returns\n        -------\n        typing.List[dict]:\n            A list of tasks\n        \"\"\"\n\n    def __call__(self, *args, **kwargs):\n        \"\"\"\n        This is just syntactic sugar for generate\n        \"\"\"\n        return self.generate(*args, **kwargs)\n\n\ndef handler_mod(task: dict, rolling_gen):\n    \"\"\"\n    Help to modify the handler's end time when using RollingGen.\n    It tries to handle the following case:\n\n    - The handler's data end_time is earlier than the end of the dataset's test segment.\n\n        - To handle this, the handler's data end_time is extended.\n\n    If the handler's end_time is None, then it is not necessary to change its end time.\n\n    Args:\n        task (dict): a task template\n        rolling_gen (RollingGen): an instance of RollingGen\n    \"\"\"\n    try:\n        handler_kwargs = task[\"dataset\"][\"kwargs\"][\"handler\"][\"kwargs\"]\n        handler_end_time = handler_kwargs.get(\"end_time\")\n        test_seg_end_time = task[\"dataset\"][\"kwargs\"][\"segments\"][rolling_gen.test_key][1]\n        # if the end of test_segments is None (open-ended segment, i.e., \"until now\") or end_time < the end of test_segments,\n        # then extend end_time so that more data can be loaded\n        if test_seg_end_time is None or rolling_gen.ta.cal_interval(handler_end_time, test_seg_end_time) < 0:\n            handler_kwargs[\"end_time\"] = copy.deepcopy(test_seg_end_time)\n    except KeyError:\n        # Maybe the dataset does not have a handler; then do nothing.\n        pass\n    except TypeError:\n        # Maybe the handler is a string. `\"handler.pkl\"[\"kwargs\"]` will raise TypeError\n        # e.g. 
a dumped file like file:///<file>/\n        pass\n\n\ndef trunc_segments(ta: TimeAdjuster, segments: Dict[str, pd.Timestamp], days, test_key=\"test\"):\n    \"\"\"\n    To avoid leaking future information, the segments other than the test segment are truncated according to the test start_time.\n\n    NOTE:\n        This function will change segments **inplace**\n    \"\"\"\n    # adjust segment\n    test_start = min(t for t in segments[test_key] if t is not None)\n    for k in list(segments.keys()):\n        if k != test_key:\n            segments[k] = ta.truncate(segments[k], test_start, days)\n\n\nclass RollingGen(TaskGen):\n    ROLL_EX = TimeAdjuster.SHIFT_EX  # fixed start date, expanding end date\n    ROLL_SD = TimeAdjuster.SHIFT_SD  # fixed segment size, sliding from the start date\n\n    def __init__(\n        self,\n        step: int = 40,\n        rtype: str = ROLL_EX,\n        ds_extra_mod_func: Union[None, Callable] = handler_mod,\n        test_key=\"test\",\n        train_key=\"train\",\n        trunc_days: int = None,\n        task_copy_func: Callable = copy.deepcopy,\n    ):\n        \"\"\"\n        Generate tasks for rolling\n\n        Parameters\n        ----------\n        step : int\n            the rolling step size\n        rtype : str\n            rolling type: ``ROLL_EX`` (expanding) or ``ROLL_SD`` (sliding)\n        ds_extra_mod_func: Callable\n            A function like: handler_mod(task: dict, rg: RollingGen)\n            It performs some extra action after generating a task. For example, use ``handler_mod`` to modify the end time of the handler of a dataset.\n        trunc_days: int\n            the number of days to truncate to avoid future information leakage\n        task_copy_func: Callable\n            the function used to copy the entire task. 
This is useful when users want to share something between tasks\n        \"\"\"\n        self.step = step\n        self.rtype = rtype\n        self.ds_extra_mod_func = ds_extra_mod_func\n        self.ta = TimeAdjuster(future=True)\n\n        self.test_key = test_key\n        self.train_key = train_key\n        self.trunc_days = trunc_days\n        self.task_copy_func = task_copy_func\n\n    def _update_task_segs(self, task, segs):\n        # update segments of this task\n        task[\"dataset\"][\"kwargs\"][\"segments\"] = copy.deepcopy(segs)\n        if self.ds_extra_mod_func is not None:\n            self.ds_extra_mod_func(task, self)\n\n    def gen_following_tasks(self, task: dict, test_end: pd.Timestamp) -> List[dict]:\n        \"\"\"\n        Generate the rolling tasks following `task` until test_end\n\n        Parameters\n        ----------\n        task : dict\n            Qlib task format\n        test_end : pd.Timestamp\n            the latest rolling task includes `test_end`\n\n        Returns\n        -------\n        List[dict]:\n            the tasks following `task` (`task` itself is excluded)\n        \"\"\"\n        prev_seg = task[\"dataset\"][\"kwargs\"][\"segments\"]\n        while True:\n            segments = {}\n            try:\n                for k, seg in prev_seg.items():\n                    # decide how to shift\n                    # expand only the train segment; the sizes of the test and valid segments won't change\n                    if k == self.train_key and self.rtype == self.ROLL_EX:\n                        rtype = self.ta.SHIFT_EX\n                    else:\n                        rtype = self.ta.SHIFT_SD\n                    # shift the segments data\n                    segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)\n                if segments[self.test_key][0] > test_end:\n                    break\n            except KeyError:\n                # We have reached the end of the tasks\n                # 
No more rolling\n                break\n\n            prev_seg = segments\n            t = self.task_copy_func(task)  # deepcopy is necessary to avoid replace task inplace\n            self._update_task_segs(t, segments)\n            yield t\n\n    def generate(self, task: dict) -> List[dict]:\n        \"\"\"\n        Converting the task into a rolling task.\n\n        Parameters\n        ----------\n        task: dict\n            A dict describing a task. For example.\n\n            .. code-block:: python\n\n                DEFAULT_TASK = {\n                    \"model\": {\n                        \"class\": \"LGBModel\",\n                        \"module_path\": \"qlib.contrib.model.gbdt\",\n                    },\n                    \"dataset\": {\n                        \"class\": \"DatasetH\",\n                        \"module_path\": \"qlib.data.dataset\",\n                        \"kwargs\": {\n                            \"handler\": {\n                                \"class\": \"Alpha158\",\n                                \"module_path\": \"qlib.contrib.data.handler\",\n                                \"kwargs\": {\n                                    \"start_time\": \"2008-01-01\",\n                                    \"end_time\": \"2020-08-01\",\n                                    \"fit_start_time\": \"2008-01-01\",\n                                    \"fit_end_time\": \"2014-12-31\",\n                                    \"instruments\": \"csi100\",\n                                },\n                            },\n                            \"segments\": {\n                                \"train\": (\"2008-01-01\", \"2014-12-31\"),\n                                \"valid\": (\"2015-01-01\", \"2016-12-20\"),  # Please avoid leaking the future test data into validation\n                                \"test\": (\"2017-01-01\", \"2020-08-01\"),\n                            },\n                        },\n                    },\n              
      \"record\": [\n                        {\n                            \"class\": \"SignalRecord\",\n                            \"module_path\": \"qlib.workflow.record_temp\",\n                        },\n                    ]\n                }\n\n        Returns\n        ----------\n        List[dict]: a list of tasks\n        \"\"\"\n        res = []\n\n        t = self.task_copy_func(task)\n\n        # calculate segments\n\n        # First rolling\n        # 1) prepare the end point\n        segments: dict = copy.deepcopy(self.ta.align_seg(t[\"dataset\"][\"kwargs\"][\"segments\"]))\n        test_end = transform_end_date(segments[self.test_key][1])\n        # 2) and init test segments\n        test_start_idx = self.ta.align_idx(segments[self.test_key][0])\n        segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))\n        if self.trunc_days is not None:\n            trunc_segments(self.ta, segments, self.trunc_days, self.test_key)\n\n        # update segments of this task\n        self._update_task_segs(t, segments)\n\n        res.append(t)\n\n        # Update the following rolling\n        res.extend(self.gen_following_tasks(t, test_end))\n        return res\n\n\nclass MultiHorizonGenBase(TaskGen):\n    def __init__(self, horizon: List[int] = [5], label_leak_n=2):\n        \"\"\"\n        This task generator tries to generate tasks for different horizons based on an existing task\n\n        Parameters\n        ----------\n        horizon : List[int]\n            the possible horizons of the tasks\n        label_leak_n : int\n            How many future days it will take to get complete label after the day making prediction\n            For example:\n            - User make prediction on day `T`(after getting the close price on `T`)\n            - The label is the return of buying stock on `T + 1` and selling it on `T + 2`\n            - the `label_leak_n` will be 2 (e.g. 
two days of information is leaked to leverage this sample)\n        \"\"\"\n        self.horizon = list(horizon)\n        self.label_leak_n = label_leak_n\n        self.ta = TimeAdjuster()\n        self.test_key = \"test\"\n\n    @abc.abstractmethod\n    def set_horizon(self, task: dict, hr: int):\n        \"\"\"\n        This method is designed to change the task **in place**\n\n        Parameters\n        ----------\n        task : dict\n            Qlib's task\n        hr : int\n            the horizon of task\n        \"\"\"\n\n    def generate(self, task: dict):\n        res = []\n        for hr in self.horizon:\n            # Add horizon\n            t = copy.deepcopy(task)\n            self.set_horizon(t, hr)\n\n            # adjust segment\n            segments = self.ta.align_seg(t[\"dataset\"][\"kwargs\"][\"segments\"])\n            trunc_segments(self.ta, segments, days=hr + self.label_leak_n, test_key=self.test_key)\n            t[\"dataset\"][\"kwargs\"][\"segments\"] = segments\n            res.append(t)\n        return res\n"
  },
  {
    "path": "qlib/workflow/task/manage.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\"\"\"\nTaskManager can fetch unused tasks automatically and manage the lifecycle of a set of tasks with error handling.\nThese features can run tasks concurrently and ensure every task will be used only once.\nTask Manager will store all tasks in `MongoDB <https://www.mongodb.com/>`_.\nUsers **MUST** finished the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.\n\nA task in TaskManager consists of 3 parts\n- tasks description: the desc will define the task\n- tasks status: the status of the task\n- tasks result: A user can get the task with the task description and task result.\n\"\"\"\n\nimport concurrent\nimport pickle\nimport time\nfrom contextlib import contextmanager\nfrom typing import Callable, List\n\nimport fire\nimport pymongo\nfrom bson.binary import Binary\nfrom bson.objectid import ObjectId\nfrom pymongo.errors import InvalidDocument\nfrom qlib import auto_init, get_module_logger\nfrom tqdm.cli import tqdm\n\nfrom .utils import get_mongodb\nfrom ...config import C\nfrom ...utils.pickle_utils import restricted_pickle_loads\n\n\nclass TaskManager:\n    \"\"\"\n    TaskManager\n\n    Here is what will a task looks like when it created by TaskManager\n\n    .. code-block:: python\n\n        {\n            'def': pickle serialized task definition.  using pickle will make it easier\n            'filter': json-like data. This is for filtering the tasks.\n            'status': 'waiting' | 'running' | 'done'\n            'res': pickle serialized task result,\n        }\n\n    The tasks manager assumes that you will only update the tasks you fetched.\n    The mongo fetch one and update will make it date updating secure.\n\n    This class can be used as a tool from commandline. 
Here are several examples.\n    You can view the help of the manage module with the following commands:\n    python -m qlib.workflow.task.manage -h # show the manual of the manage module CLI\n    python -m qlib.workflow.task.manage wait -h # show the manual of the wait command of manage\n\n    .. code-block:: shell\n\n        python -m qlib.workflow.task.manage -t <pool_name> wait\n        python -m qlib.workflow.task.manage -t <pool_name> task_stat\n\n\n    .. note::\n\n        Assumption: data stored in MongoDB is encoded, and data read out of MongoDB is decoded\n\n    There are four statuses:\n\n        STATUS_WAITING: waiting for training\n\n        STATUS_RUNNING: training\n\n        STATUS_PART_DONE: finished some steps and waiting for the next step\n\n        STATUS_DONE: all work done\n    \"\"\"\n\n    STATUS_WAITING = \"waiting\"\n    STATUS_RUNNING = \"running\"\n    STATUS_DONE = \"done\"\n    STATUS_PART_DONE = \"part_done\"\n\n    ENCODE_FIELDS_PREFIX = [\"def\", \"res\"]\n\n    def __init__(self, task_pool: str):\n        \"\"\"\n        Init TaskManager. Remember to configure the MongoDB URL and database name first.\n        A TaskManager instance serves a specific task pool.\n        The static methods of this class serve the whole MongoDB.\n\n        Parameters\n        ----------\n        task_pool: str\n            the name of the collection in MongoDB\n        \"\"\"\n        self.task_pool: pymongo.collection.Collection = getattr(get_mongodb(), task_pool)\n        self.logger = get_module_logger(self.__class__.__name__)\n        self.logger.info(f\"task_pool:{task_pool}\")\n\n    @staticmethod\n    def list() -> list:\n        \"\"\"\n        List all collections (task pools) of the database.\n\n        Returns:\n            list\n        \"\"\"\n        return get_mongodb().list_collection_names()\n\n    def _encode_task(self, task):\n        for prefix in self.ENCODE_FIELDS_PREFIX:\n            for k in list(task.keys()):\n                if 
k.startswith(prefix):\n                    task[k] = Binary(pickle.dumps(task[k], protocol=C.dump_protocol_version))\n        return task\n\n    def _decode_task(self, task):\n        \"\"\"\n        _decode_task is a deserialization tool.\n        The fields whose names start with a prefix in ENCODE_FIELDS_PREFIX are stored in MongoDB as pickled binary data, so they are unpickled back into Python objects here.\n\n        Parameters\n        ----------\n        task : dict\n            task information\n\n        Returns\n        -------\n        dict\n            the task with its encoded fields decoded into Python objects\n        \"\"\"\n        for prefix in self.ENCODE_FIELDS_PREFIX:\n            for k in list(task.keys()):\n                if k.startswith(prefix):\n                    task[k] = restricted_pickle_loads(task[k])\n        return task\n\n    def _dict_to_str(self, flt):\n        return {k: str(v) for k, v in flt.items()}\n\n    def _decode_query(self, query):\n        \"\"\"\n        If the query includes any `_id`, it needs to be converted with `ObjectId`.\n        For example, when using TrainerRM, it needs the query `{\"_id\": {\"$in\": _id_list}}`. Then every `_id` in `_id_list` needs to be converted to an `ObjectId`.\n\n        Args:\n            query (dict): query dict. 
Defaults to {}.\n\n        Returns:\n            dict: the query after decoding.\n        \"\"\"\n        if \"_id\" in query:\n            if isinstance(query[\"_id\"], dict):\n                for key in query[\"_id\"]:\n                    query[\"_id\"][key] = [ObjectId(i) for i in query[\"_id\"][key]]\n            else:\n                query[\"_id\"] = ObjectId(query[\"_id\"])\n        return query\n\n    def replace_task(self, task, new_task):\n        \"\"\"\n        Use a new task to replace an old one\n\n        Args:\n            task: old task\n            new_task: new task\n        \"\"\"\n        new_task = self._encode_task(new_task)\n        query = {\"_id\": ObjectId(task[\"_id\"])}\n        try:\n            self.task_pool.replace_one(query, new_task)\n        except InvalidDocument:\n            # the filter of the document being written is not BSON-encodable; stringify it and retry\n            new_task[\"filter\"] = self._dict_to_str(new_task[\"filter\"])\n            self.task_pool.replace_one(query, new_task)\n\n    def insert_task(self, task):\n        \"\"\"\n        Insert a task.\n\n        Args:\n            task: the task to insert\n\n        Returns:\n            pymongo.results.InsertOneResult\n        \"\"\"\n        try:\n            insert_result = self.task_pool.insert_one(task)\n        except InvalidDocument:\n            task[\"filter\"] = self._dict_to_str(task[\"filter\"])\n            insert_result = self.task_pool.insert_one(task)\n        return insert_result\n\n    def insert_task_def(self, task_def):\n        \"\"\"\n        Insert a task into task_pool\n\n        Parameters\n        ----------\n        task_def: dict\n            the task definition\n\n        Returns\n        -------\n        pymongo.results.InsertOneResult\n        \"\"\"\n        task = self._encode_task(\n            {\n                \"def\": task_def,\n                \"filter\": task_def,  # FIXME: catch the raised error\n                \"status\": self.STATUS_WAITING,\n            }\n        )\n        insert_result = self.insert_task(task)\n    
    return insert_result\n\n    def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]:\n        \"\"\"\n        If the tasks in task_def_l are new, then insert new tasks into the task_pool, and record inserted_id.\n        If a task is not new, then just query its _id.\n\n        Parameters\n        ----------\n        task_def_l: list\n            a list of task\n        dry_run: bool\n            if insert those new tasks to task pool\n        print_nt: bool\n            if print new task\n\n        Returns\n        -------\n        List[str]\n            a list of the _id of task_def_l\n        \"\"\"\n        new_tasks = []\n        _id_list = []\n        for t in task_def_l:\n            try:\n                r = self.task_pool.find_one({\"filter\": t})\n            except InvalidDocument:\n                r = self.task_pool.find_one({\"filter\": self._dict_to_str(t)})\n            # When r is none, it indicates that r s a new task\n            if r is None:\n                new_tasks.append(t)\n                if not dry_run:\n                    insert_result = self.insert_task_def(t)\n                    _id_list.append(insert_result.inserted_id)\n                else:\n                    _id_list.append(None)\n            else:\n                _id_list.append(self._decode_task(r)[\"_id\"])\n\n        self.logger.info(f\"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}\")\n\n        if print_nt:  # print new task\n            for t in new_tasks:\n                print(t)\n\n        if dry_run:\n            return []\n\n        return _id_list\n\n    def fetch_task(self, query={}, status=STATUS_WAITING) -> dict:\n        \"\"\"\n        Use query to fetch tasks.\n\n        Args:\n            query (dict, optional): query dict. Defaults to {}.\n            status (str, optional): [description]. 
Defaults to STATUS_WAITING.\n\n        Returns:\n            dict: a task (document in collection) after decoding, or None if no task matches\n        \"\"\"\n        query = query.copy()\n        query = self._decode_query(query)\n        query.update({\"status\": status})\n        task = self.task_pool.find_one_and_update(\n            query, {\"$set\": {\"status\": self.STATUS_RUNNING}}, sort=[(\"priority\", pymongo.DESCENDING)]\n        )\n        # null sorts first in ASCENDING order, so DESCENDING is used: the larger the number, the higher the priority\n        if task is None:\n            return None\n        task[\"status\"] = self.STATUS_RUNNING\n        return self._decode_task(task)\n\n    @contextmanager\n    def safe_fetch_task(self, query={}, status=STATUS_WAITING):\n        \"\"\"\n        Fetch a task from task_pool using query, as a context manager.\n        If an error is raised inside the context, the task is returned to its original status.\n\n        Parameters\n        ----------\n        query: dict\n            the dict of query\n        status: str\n            the status of the task to fetch\n\n        Returns\n        -------\n        dict: a task (document in collection) after decoding, or None if no task matches\n        \"\"\"\n        task = self.fetch_task(query=query, status=status)\n        try:\n            yield task\n        except (Exception, KeyboardInterrupt):  # KeyboardInterrupt is not a subclass of Exception\n            if task is not None:\n                self.logger.info(\"Returning task before raising error\")\n                self.return_task(task, status=status)  # return task as the original status\n                self.logger.info(\"Task returned\")\n            raise\n\n    def task_fetcher_iter(self, query={}):\n        while True:\n            with self.safe_fetch_task(query=query) as task:\n                if task is None:\n                    break\n                yield task\n\n    def query(self, query={}, decode=True):\n        \"\"\"\n        Query tasks in the collection.\n        This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator\n\n     
   python -m qlib.workflow.task.manage -t <your task pool> query '{\"_id\": \"615498be837d0053acbc5d58\"}'\n\n        Parameters\n        ----------\n        query: dict\n            the dict of query\n        decode: bool\n            whether to decode each task before yielding it\n\n        Returns\n        -------\n        Iterator[dict]: tasks (documents in collection), decoded if decode is True\n        \"\"\"\n        query = query.copy()\n        query = self._decode_query(query)\n        for t in self.task_pool.find(query):\n            yield self._decode_task(t) if decode else t\n\n    def re_query(self, _id) -> dict:\n        \"\"\"\n        Use _id to query task.\n\n        Args:\n            _id (str): _id of a document\n\n        Returns:\n            dict: a task(document in collection) after decoding\n        \"\"\"\n        t = self.task_pool.find_one({\"_id\": ObjectId(_id)})\n        return self._decode_task(t)\n\n    def commit_task_res(self, task, res, status=STATUS_DONE):\n        \"\"\"\n        Commit the result to task['res'].\n\n        Args:\n            task (dict): the task whose result will be committed\n            res (object): the result you want to save\n            status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_DONE.\n        \"\"\"\n        # A workaround to use the class attribute.\n        if status is None:\n            status = TaskManager.STATUS_DONE\n        self.task_pool.update_one(\n            {\"_id\": task[\"_id\"]},\n            {\"$set\": {\"status\": status, \"res\": Binary(pickle.dumps(res, protocol=C.dump_protocol_version))}},\n        )\n\n    def return_task(self, task, status=STATUS_WAITING):\n        \"\"\"\n        Return a task to the given status. Typically used in error handling.\n\n        Args:\n            task (dict): the task to return\n            status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. 
Defaults to STATUS_WAITING.\n        \"\"\"\n        if status is None:\n            status = TaskManager.STATUS_WAITING\n        update_dict = {\"$set\": {\"status\": status}}\n        self.task_pool.update_one({\"_id\": task[\"_id\"]}, update_dict)\n\n    def remove(self, query={}):\n        \"\"\"\n        Remove the tasks matching query\n\n        Parameters\n        ----------\n        query: dict\n            the dict of query\n\n        \"\"\"\n        query = query.copy()\n        query = self._decode_query(query)\n        self.task_pool.delete_many(query)\n\n    def task_stat(self, query={}) -> dict:\n        \"\"\"\n        Count the tasks in every status.\n\n        Args:\n            query (dict, optional): the query dict. Defaults to {}.\n\n        Returns:\n            dict: a mapping from each status to the number of tasks in it\n        \"\"\"\n        query = query.copy()\n        query = self._decode_query(query)\n        tasks = self.query(query=query, decode=False)\n        status_stat = {}\n        for t in tasks:\n            status_stat[t[\"status\"]] = status_stat.get(t[\"status\"], 0) + 1\n        return status_stat\n\n    def reset_waiting(self, query={}):\n        \"\"\"\n        Reset all running tasks to waiting status. This can be used when some running tasks exited unexpectedly.\n\n        Args:\n            query (dict, optional): the query dict. 
Defaults to {}.\n        \"\"\"\n        query = query.copy()\n        # default query\n        if \"status\" not in query:\n            query[\"status\"] = self.STATUS_RUNNING\n        return self.reset_status(query=query, status=self.STATUS_WAITING)\n\n    def reset_status(self, query, status):\n        query = query.copy()\n        query = self._decode_query(query)\n        print(self.task_pool.update_many(query, {\"$set\": {\"status\": status}}))\n\n    def prioritize(self, task, priority: int):\n        \"\"\"\n        Set the priority of a task\n\n        Parameters\n        ----------\n        task : dict\n            The task queried from the database\n        priority : int\n            the target priority\n        \"\"\"\n        update_dict = {\"$set\": {\"priority\": priority}}\n        self.task_pool.update_one({\"_id\": task[\"_id\"]}, update_dict)\n\n    def _get_undone_n(self, task_stat):\n        return (\n            task_stat.get(self.STATUS_WAITING, 0)\n            + task_stat.get(self.STATUS_RUNNING, 0)\n            + task_stat.get(self.STATUS_PART_DONE, 0)\n        )\n\n    def _get_total(self, task_stat):\n        return sum(task_stat.values())\n\n    def wait(self, query={}):\n        \"\"\"\n        When multiprocessing, the main process may fetch nothing from TaskManager because some tasks are still running.\n        So the main process should wait until all tasks have been trained by other processes or machines.\n\n        Args:\n            query (dict, optional): the query dict. Defaults to {}.\n        \"\"\"\n        task_stat = self.task_stat(query)\n        total = self._get_total(task_stat)\n        last_undone_n = self._get_undone_n(task_stat)\n        if last_undone_n == 0:\n            return\n        self.logger.warning(f\"Waiting for {last_undone_n} undone tasks. 
Please make sure they are running.\")\n        with tqdm(total=total, initial=total - last_undone_n) as pbar:\n            while True:\n                time.sleep(10)\n                undone_n = self._get_undone_n(self.task_stat(query))\n                pbar.update(last_undone_n - undone_n)\n                last_undone_n = undone_n\n                if undone_n == 0:\n                    break\n\n    def __str__(self):\n        return f\"TaskManager({self.task_pool})\"\n\n\ndef run_task(\n    task_func: Callable,\n    task_pool: str,\n    query: dict = {},\n    force_release: bool = False,\n    before_status: str = TaskManager.STATUS_WAITING,\n    after_status: str = TaskManager.STATUS_DONE,\n    **kwargs,\n):\n    r\"\"\"\n    While the task pool is not empty (i.e., it has tasks in before_status), fetch the tasks in task_pool and run them with task_func\n\n    After running this method, there are 4 possible transitions (before_status -> after_status):\n\n        STATUS_WAITING -> STATUS_DONE: use task[\"def\"] as the `task_func` param, since the task has not been started\n\n        STATUS_WAITING -> STATUS_PART_DONE: use task[\"def\"] as the `task_func` param\n\n        STATUS_PART_DONE -> STATUS_PART_DONE: use task[\"res\"] as the `task_func` param, since the task has been started but not completed\n\n        STATUS_PART_DONE -> STATUS_DONE: use task[\"res\"] as the `task_func` param\n\n    Parameters\n    ----------\n    task_func : Callable\n        def (task_def, \\**kwargs) -> <res which will be committed>\n\n        the function to run the task\n    task_pool : str\n        the name of the task pool (Collection in MongoDB)\n    query: dict\n        will use this dict to query task_pool when fetching task\n    force_release : bool\n        whether to run `task_func` in a separate process so that its resources are forcibly released after each task\n    before_status : str\n        the tasks in before_status will be fetched and trained. 
Can be STATUS_WAITING, STATUS_PART_DONE.\n    after_status : str\n        the tasks will be set to after_status after training. Can be STATUS_DONE, STATUS_PART_DONE.\n    kwargs\n        the params for `task_func`\n\n    Returns\n    -------\n    bool\n        whether any task has been run\n    \"\"\"\n    tm = TaskManager(task_pool)\n\n    ever_run = False\n\n    while True:\n        with tm.safe_fetch_task(status=before_status, query=query) as task:\n            if task is None:\n                break\n            get_module_logger(\"run_task\").info(task[\"def\"])\n            # when fetching `WAITING` task, use task[\"def\"] to train\n            if before_status == TaskManager.STATUS_WAITING:\n                param = task[\"def\"]\n            # when fetching `PART_DONE` task, use task[\"res\"] to train because the intermediate result has been saved to task[\"res\"]\n            elif before_status == TaskManager.STATUS_PART_DONE:\n                param = task[\"res\"]\n            else:\n                raise ValueError(\"The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!\")\n            if force_release:\n                with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:\n                    res = executor.submit(task_func, param, **kwargs).result()\n            else:\n                res = task_func(param, **kwargs)\n            tm.commit_task_res(task, res, status=after_status)\n            ever_run = True\n\n    return ever_run\n\n\nif __name__ == \"__main__\":\n    # This is for command-line usage\n    # E.g. : `python -m qlib.workflow.task.manage list`\n    auto_init()\n    fire.Fire(TaskManager)\n"
  },
  {
    "path": "qlib/workflow/task/utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nSome tools for task management.\n\"\"\"\n\nimport bisect\nfrom copy import deepcopy\nimport pandas as pd\nfrom qlib.data import D\nfrom qlib.utils import hash_args\nfrom qlib.utils.mod import init_instance_by_config\nfrom qlib.workflow import R\nfrom qlib.config import C\nfrom qlib.log import get_module_logger\nfrom pymongo import MongoClient\nfrom pymongo.database import Database\nfrom typing import Union\nfrom pathlib import Path\n\n\ndef get_mongodb() -> Database:\n    \"\"\"\n    Get database in MongoDB, which means you need to declare the address and the name of a database at first.\n\n    For example:\n\n        Using qlib.init():\n\n            .. code-block:: python\n\n                mongo_conf = {\n                    \"task_url\": task_url,  # your MongoDB url\n                    \"task_db_name\": task_db_name,  # database name\n                }\n                qlib.init(..., mongo=mongo_conf)\n\n        After qlib.init():\n\n            .. 
code-block:: python\n\n                C[\"mongo\"] = {\n                    \"task_url\" : \"mongodb://localhost:27017/\",\n                    \"task_db_name\" : \"rolling_db\"\n                }\n\n    Returns:\n        Database: the Database instance\n    \"\"\"\n    try:\n        cfg = C[\"mongo\"]\n    except KeyError:\n        get_module_logger(\"task\").error(\"Please configure `C['mongo']` before using TaskManager\")\n        raise\n    get_module_logger(\"task\").info(f\"mongo config:{cfg}\")\n    client = MongoClient(cfg[\"task_url\"])\n    return client.get_database(name=cfg[\"task_db_name\"])\n\n\ndef list_recorders(experiment, rec_filter_func=None):\n    \"\"\"\n    List all recorders which can pass the filter in an experiment.\n\n    Args:\n        experiment (str or Experiment): the name of an Experiment or an instance\n        rec_filter_func (Callable, optional): return True to retain the given recorder. Defaults to None.\n\n    Returns:\n        dict: a dict {rid: recorder} after filtering.\n    \"\"\"\n    if isinstance(experiment, str):\n        experiment = R.get_exp(experiment_name=experiment)\n    recs = experiment.list_recorders()\n    recs_flt = {}\n    for rid, rec in recs.items():\n        if rec_filter_func is None or rec_filter_func(rec):\n            recs_flt[rid] = rec\n\n    return recs_flt\n\n\nclass TimeAdjuster:\n    \"\"\"\n    Find appropriate date and adjust date.\n    \"\"\"\n\n    def __init__(self, future=True, end_time=None):\n        self._future = future\n        self.cals = D.calendar(future=future, end_time=end_time)\n\n    def set_end_time(self, end_time=None):\n        \"\"\"\n        Set end time. 
None means using the calendar's end time.\n\n        Args:\n            end_time\n        \"\"\"\n        self.cals = D.calendar(future=self._future, end_time=end_time)\n\n    def get(self, idx: int):\n        \"\"\"\n        Get datetime by index.\n\n        Parameters\n        ----------\n        idx : int\n            index of the calendar\n        \"\"\"\n        if idx is None or idx >= len(self.cals):\n            return None\n        return self.cals[idx]\n\n    def max(self) -> pd.Timestamp:\n        \"\"\"\n        Return the max calendar datetime\n        \"\"\"\n        return max(self.cals)\n\n    def align_idx(self, time_point, tp_type=\"start\") -> int:\n        \"\"\"\n        Get the index of time_point aligned to the calendar.\n\n        Parameters\n        ----------\n        time_point\n        tp_type : str\n            \"start\" aligns to the first calendar time on or after time_point; \"end\" aligns to the last calendar time on or before it\n\n        Returns\n        -------\n        index : int\n        \"\"\"\n        if time_point is None:\n            # `None` indicates unbounded index/border\n            return None\n        time_point = pd.Timestamp(time_point)\n        if tp_type == \"start\":\n            idx = bisect.bisect_left(self.cals, time_point)\n        elif tp_type == \"end\":\n            idx = bisect.bisect_right(self.cals, time_point) - 1\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n        return idx\n\n    def cal_interval(self, time_point_A, time_point_B) -> int:\n        \"\"\"\n        Calculate the trading day interval (time_point_A - time_point_B)\n\n        Args:\n            time_point_A : time_point_A\n            time_point_B : time_point_B (earlier than time_point_A)\n\n        Returns:\n            int: the interval between A and B\n        \"\"\"\n        return self.align_idx(time_point_A) - self.align_idx(time_point_B)\n\n    def align_time(self, time_point, tp_type=\"start\") -> pd.Timestamp:\n        \"\"\"\n        Align time_point to a trading date in the calendar\n\n        Args:\n            time_point\n 
               Time point\n            tp_type : str\n                time point type (`\"start\"`, `\"end\"`)\n\n        Returns:\n            pd.Timestamp\n        \"\"\"\n        if time_point is None:\n            return None\n        return self.cals[self.align_idx(time_point, tp_type=tp_type)]\n\n    def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]:\n        \"\"\"\n        Align the given date to the trade date\n\n        for example:\n\n            .. code-block:: python\n\n                input: {'train': ('2008-01-01', '2014-12-31'), 'valid': ('2015-01-01', '2016-12-31'), 'test': ('2017-01-01', '2020-08-01')}\n\n                output: {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')),\n                        'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')),\n                        'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2020-07-31 00:00:00'))}\n\n        Parameters\n        ----------\n        segment\n\n        Returns\n        -------\n        Union[dict, tuple]: the start and end trade date (pd.Timestamp) between the given start and end date.\n        \"\"\"\n        if isinstance(segment, dict):\n            return {k: self.align_seg(seg) for k, seg in segment.items()}\n        elif isinstance(segment, (tuple, list)):\n            return self.align_time(segment[0], tp_type=\"start\"), self.align_time(segment[1], tp_type=\"end\")\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n\n    def truncate(self, segment: tuple, test_start, days: int) -> tuple:\n        \"\"\"\n        Truncate the segment based on the test_start date\n\n        Parameters\n        ----------\n        segment : tuple\n            time segment\n        test_start\n        days : int\n            The trading days to be truncated\n            the data in this segment may need 'days' data\n            `days` are based on the `test_start`.\n  
           For example, if the label contains information from 2 days in the future and the prediction horizon is 1 day\n            (e.g. the prediction target is `Ref($close, -2)/Ref($close, -1) - 1`),\n            then days should be 2 + 1 == 3.\n\n        Returns\n        ---------\n        tuple: new segment\n        \"\"\"\n        test_idx = self.align_idx(test_start)\n        if isinstance(segment, tuple):\n            new_seg = []\n            for time_point in segment:\n                tp_idx = min(self.align_idx(time_point), test_idx - days)\n                assert tp_idx > 0\n                new_seg.append(self.get(tp_idx))\n            return tuple(new_seg)\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n\n    SHIFT_SD = \"sliding\"\n    SHIFT_EX = \"expanding\"\n\n    def _add_step(self, index, step):\n        if index is None:\n            return None\n        return index + step\n\n    def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:\n        \"\"\"\n        Shift the datetime of segment\n\n        If an endpoint of the segment is None (which indicates an unbounded index), that endpoint stays None.\n\n        Parameters\n        ----------\n        seg :\n            datetime segment\n        step : int\n            rolling step\n        rtype : str\n            rolling type (\"sliding\" or \"expanding\")\n\n        Returns\n        --------\n        tuple: new segment\n\n        Raises\n        ------\n        KeyError:\n            raised if the shifted start index is beyond the end of the calendar (self.cals)\n        \"\"\"\n        if isinstance(seg, tuple):\n            start_idx, end_idx = self.align_idx(seg[0], tp_type=\"start\"), self.align_idx(seg[1], tp_type=\"end\")\n            if rtype == self.SHIFT_SD:\n                start_idx = self._add_step(start_idx, step)\n                end_idx = self._add_step(end_idx, step)\n            elif rtype == self.SHIFT_EX:\n                end_idx = self._add_step(end_idx, step)\n            else:\n                raise NotImplementedError(f\"This type of input is not supported\")\n            if start_idx is not None and start_idx > len(self.cals):\n                raise KeyError(\"The segment is out of valid calendar\")\n            return self.get(start_idx), self.get(end_idx)\n        else:\n            raise NotImplementedError(f\"This type of input is not supported\")\n\n\ndef replace_task_handler_with_cache(task: dict, cache_dir: Union[str, Path] = \".\") -> dict:\n    \"\"\"\n    Replace the handler in task with a cache handler.\n    The handler is initialized and pickled to cache_dir if no cached file exists yet.\n\n    >>> import qlib\n    >>> qlib.auto_init()\n    >>> import datetime\n    >>> # a simplified task\n    >>> task = {\"dataset\": {\"kwargs\":{'handler': {'class': 'Alpha158', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': datetime.date(2008, 1, 1), 'end_time': datetime.date(2020, 8, 1), 'fit_start_time': datetime.date(2008, 1, 1), 'fit_end_time': datetime.date(2014, 12, 31), 'instruments': 'CSI300'}}}}}\n    >>> new_task = replace_task_handler_with_cache(task)\n    >>> print(new_task)\n    {'dataset': {'kwargs': {'handler': 'file...Alpha158.3584f5f8b4.pkl'}}}\n\n    \"\"\"\n    cache_dir = Path(cache_dir)\n    task = deepcopy(task)\n    handler = task[\"dataset\"][\"kwargs\"][\"handler\"]\n    if isinstance(handler, dict):\n        hash = hash_args(handler)\n        h_path = cache_dir / f\"{handler['class']}.{hash[:10]}.pkl\"\n        if not h_path.exists():\n            h = init_instance_by_config(handler)\n            h.to_pickle(h_path, dump_all=True)\n        task[\"dataset\"][\"kwargs\"][\"handler\"] = f\"file://{h_path}\"\n    return task\n"
  },
  {
    "path": "qlib/workflow/utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport atexit\nimport logging\nimport sys\nimport traceback\n\nfrom ..log import get_module_logger\nfrom . import R\nfrom .recorder import Recorder\n\nlogger = get_module_logger(\"workflow\", logging.INFO)\n\n\n# function to handle the experiment when unusual program ending occurs\ndef experiment_exit_handler():\n    \"\"\"\n    Method for handling the experiment when any unusual program ending occurs.\n    The `atexit` handler should be registered last, since it is called whenever the program ends.\n    Thus, if any exception or user interruption occurs beforehand, we should handle them first. Once `R` is\n    ended, another call of `R.end_exp` will not take effect.\n\n    Limitations:\n    - If pdb is used in your program, excepthook will not be triggered when it ends, so the status will be FINISHED.\n    \"\"\"\n    sys.excepthook = experiment_exception_hook  # handle uncaught exception\n    atexit.register(R.end_exp, recorder_status=Recorder.STATUS_FI)  # will not take effect if the experiment has already ended\n\n\ndef experiment_exception_hook(exc_type, value, tb):\n    \"\"\"\n    End an experiment with the status \"FAILED\". This hook catches uncaught exceptions\n    and ends the experiment automatically.\n\n    Parameters\n    ----------\n    exc_type: Exception type\n    value: Exception's value\n    tb: Exception's traceback\n    \"\"\"\n    logger.error(f\"An exception has been raised[{exc_type.__name__}: {value}].\")\n\n    # Same as original format\n    traceback.print_tb(tb)\n    print(f\"{exc_type.__name__}: {value}\")\n\n    R.end_exp(recorder_status=Recorder.STATUS_FA)\n"
  },
  {
    "path": "scripts/README.md",
    "content": "\n- [Download Qlib Data](#Download-Qlib-Data)\n  - [Download CN Data](#Download-CN-Data)\n  - [Download US Data](#Download-US-Data)\n  - [Download CN Simple Data](#Download-CN-Simple-Data)\n  - [Help](#Help)\n- [Using in Qlib](#Using-in-Qlib)\n  - [US data](#US-data)\n  - [CN data](#CN-data)\n- [Use Crowd Sourced Data](#Use-Crowd-Sourced-Data)\n\n\n## Download Qlib Data\n\n\n### Download CN Data\n\n```bash\n# daily data\npython get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n\n# 1min data (optional; not required for non-high-frequency strategies)\npython get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min\n```\n\n### Download US Data\n\n\n```bash\npython get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us\n```\n\n### Download CN Simple Data\n\n```bash\npython get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --region cn\n```\n\n### Help\n\n```bash\npython get_data.py qlib_data --help\n```\n\n## Using in Qlib\n> For more information: https://qlib.readthedocs.io/en/latest/start/initialization.html\n\n\n### US data\n\n> Download the data first: [Download US Data](#Download-US-Data)\n\n```python\nimport qlib\nfrom qlib.constant import REG_US\n\nprovider_uri = \"~/.qlib/qlib_data/us_data\"  # target_dir\nqlib.init(provider_uri=provider_uri, region=REG_US)\n```\n\n### CN data\n\n> Download the data first: [Download CN Data](#Download-CN-Data)\n\n```python\nimport qlib\nfrom qlib.constant import REG_CN\n\nprovider_uri = \"~/.qlib/qlib_data/cn_data\"  # target_dir\nqlib.init(provider_uri=provider_uri, region=REG_CN)\n```\n\n## Use Crowd Sourced Data\nThere is also a [crowd-sourced version of Qlib data](data_collector/crowd_source/README.md): https://github.com/chenditc/investment_data/releases\n```bash\nwget https://github.com/chenditc/investment_data/releases/latest/download/qlib_bin.tar.gz\ntar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2\n```\n"
  },
  {
    "path": "scripts/check_data_health.py",
    "content": "import os\nfrom typing import Optional\n\nimport fire\nimport pandas as pd\nfrom loguru import logger\nfrom tqdm import tqdm\n\nimport qlib\nfrom qlib.data import D\n\n\nclass DataHealthChecker:\n    \"\"\"Checks a dataset for data completeness and correctness. The data will be converted to a pd.DataFrame and checked for the following problems:\n    - any of the columns [\"open\", \"high\", \"low\", \"close\", \"volume\"] are missing\n    - any data is missing\n    - any step change in the OHLCV columns is above a threshold (default: 0.5 for price, 3 for volume)\n    - any factor is missing\n    \"\"\"\n\n    def __init__(\n        self,\n        csv_path=None,\n        qlib_dir=None,\n        freq=\"day\",\n        large_step_threshold_price=0.5,\n        large_step_threshold_volume=3,\n        missing_data_num=0,\n    ):\n        assert csv_path or qlib_dir, \"Either csv_path or qlib_dir must be provided.\"\n        assert not (csv_path and qlib_dir), \"Only one of csv_path or qlib_dir should be provided.\"\n\n        self.data = {}\n        self.problems = {}\n        self.freq = freq\n        self.large_step_threshold_price = large_step_threshold_price\n        self.large_step_threshold_volume = large_step_threshold_volume\n        self.missing_data_num = missing_data_num\n        # qlib_dir is None when checking CSV files, so only expand it when it is given\n        self.qlib_dir = os.path.abspath(os.path.expanduser(qlib_dir)) if qlib_dir else None\n\n        if csv_path:\n            assert os.path.isdir(csv_path), f\"{csv_path} should be a directory.\"\n            files = [f for f in os.listdir(csv_path) if f.endswith(\".csv\")]\n            for filename in tqdm(files, desc=\"Loading data\"):\n                df = pd.read_csv(os.path.join(csv_path, filename))\n                self.data[filename] = df\n\n        elif qlib_dir:\n            qlib.init(provider_uri=qlib_dir)\n            self.load_qlib_data()\n\n    def load_qlib_data(self):\n        instruments = D.instruments(market=\"all\")\n        instrument_list = 
D.list_instruments(instruments=instruments, as_list=True, freq=self.freq)\n        required_fields = [\"$open\", \"$close\", \"$low\", \"$high\", \"$volume\", \"$factor\"]\n        for instrument in instrument_list:\n            df = D.features([instrument], required_fields, freq=self.freq)\n            df.rename(\n                columns={\n                    \"$open\": \"open\",\n                    \"$close\": \"close\",\n                    \"$low\": \"low\",\n                    \"$high\": \"high\",\n                    \"$volume\": \"volume\",\n                    \"$factor\": \"factor\",\n                },\n                inplace=True,\n            )\n            self.data[instrument] = df\n\n    # NOTE:\n    # This check is added due to a known issue in Qlib where feature paths\n    # are constructed using lowercased instrument names. On case-sensitive\n    # file systems (e.g. Linux), uppercase directory names under `features/`\n    # will cause data loading failures.\n    #\n    # See: https://github.com/microsoft/qlib/issues/2053\n    def check_features_dir_lowercase(self) -> Optional[pd.DataFrame]:\n        \"\"\"\n        Check whether all subdirectories under `<qlib_dir>/features` are named in lowercase.\n\n        This validation helps prevent data loading issues on case-sensitive\n        file systems caused by uppercase instrument directory names.\n        \"\"\"\n        if not self.qlib_dir:\n            return None\n\n        features_dir = os.path.join(self.qlib_dir, \"features\")\n        if not os.path.isdir(features_dir):\n            logger.warning(f\"`features` directory not found under {self.qlib_dir}\")\n            return None\n\n        bad_dirs = []\n        for name in os.listdir(features_dir):\n            full_path = os.path.join(features_dir, name)\n            if os.path.isdir(full_path) and name != name.lower():\n                bad_dirs.append(name)\n\n        if bad_dirs:\n            result_df = 
pd.DataFrame({\"non_lowercase_dir\": bad_dirs})\n            return result_df\n        else:\n            logger.info(\n                f\"✅ All subdirectories under `{os.path.join(self.qlib_dir, 'features')}` are named in lowercase.\"\n            )\n            return None\n\n    def check_missing_data(self) -> Optional[pd.DataFrame]:\n        \"\"\"Check if any data is missing in the DataFrame.\"\"\"\n        result_dict = {\n            \"instruments\": [],\n            \"open\": [],\n            \"high\": [],\n            \"low\": [],\n            \"close\": [],\n            \"volume\": [],\n        }\n        for filename, df in self.data.items():\n            missing_data_columns = df.isnull().sum()[df.isnull().sum() > self.missing_data_num].index.tolist()\n            if len(missing_data_columns) > 0:\n                result_dict[\"instruments\"].append(filename)\n                result_dict[\"open\"].append(df.isnull().sum()[\"open\"])\n                result_dict[\"high\"].append(df.isnull().sum()[\"high\"])\n                result_dict[\"low\"].append(df.isnull().sum()[\"low\"])\n                result_dict[\"close\"].append(df.isnull().sum()[\"close\"])\n                result_dict[\"volume\"].append(df.isnull().sum()[\"volume\"])\n\n        result_df = pd.DataFrame(result_dict).set_index(\"instruments\")\n        if not result_df.empty:\n            return result_df\n        else:\n            logger.info(f\"✅ There is no missing data.\")\n            return None\n\n    def check_large_step_changes(self) -> Optional[pd.DataFrame]:\n        \"\"\"Check if there are any large step changes above the threshold in the OHLCV columns.\"\"\"\n        result_dict = {\n            \"instruments\": [],\n            \"col_name\": [],\n            \"date\": [],\n            \"pct_change\": [],\n        }\n        for filename, df in self.data.items():\n            affected_columns = []\n            for col in [\"open\", \"high\", \"low\", \"close\", \"volume\"]:\n 
               if col in df.columns:\n                    pct_change = df[col].pct_change(fill_method=None).abs()\n                    threshold = self.large_step_threshold_volume if col == \"volume\" else self.large_step_threshold_price\n                    if pct_change.max() > threshold:\n                        large_steps = pct_change[pct_change > threshold]\n                        result_dict[\"instruments\"].append(filename)\n                        result_dict[\"col_name\"].append(col)\n                        result_dict[\"date\"].append(large_steps.index.to_list()[0][1].strftime(\"%Y-%m-%d\"))\n                        result_dict[\"pct_change\"].append(pct_change.max())\n                        affected_columns.append(col)\n\n        result_df = pd.DataFrame(result_dict).set_index(\"instruments\")\n        if not result_df.empty:\n            return result_df\n        else:\n            logger.info(f\"✅ There are no large step changes in the OHLCV columns above the threshold.\")\n            return None\n\n    def check_required_columns(self) -> Optional[pd.DataFrame]:\n        \"\"\"Check if any of the required columns (OHLCV) are missing in the DataFrame.\"\"\"\n        required_columns = [\"open\", \"high\", \"low\", \"close\", \"volume\"]\n        result_dict = {\n            \"instruments\": [],\n            \"missing_col\": [],\n        }\n        for filename, df in self.data.items():\n            if not all(column in df.columns for column in required_columns):\n                missing_required_columns = [column for column in required_columns if column not in df.columns]\n                # one row per missing column keeps both lists the same length\n                for column in missing_required_columns:\n                    result_dict[\"instruments\"].append(filename)\n                    result_dict[\"missing_col\"].append(column)\n\n        result_df = pd.DataFrame(result_dict).set_index(\"instruments\")\n        if not result_df.empty:\n            return result_df\n        else:\n            logger.info(f\"✅ The columns (OHLCV) are complete and not missing.\")\n          
  return None\n\n    def check_missing_factor(self) -> Optional[pd.DataFrame]:\n        \"\"\"Check if the 'factor' column is missing or empty in the DataFrame.\"\"\"\n        result_dict = {\n            \"instruments\": [],\n            \"missing_factor_col\": [],\n            \"missing_factor_data\": [],\n        }\n        for filename, df in self.data.items():\n            if \"000300\" in filename or \"000903\" in filename or \"000905\" in filename:\n                continue\n            if \"factor\" not in df.columns:\n                # a missing column also means missing data; avoid indexing df[\"factor\"]\n                result_dict[\"instruments\"].append(filename)\n                result_dict[\"missing_factor_col\"].append(True)\n                result_dict[\"missing_factor_data\"].append(True)\n            elif df[\"factor\"].isnull().all():\n                result_dict[\"instruments\"].append(filename)\n                result_dict[\"missing_factor_col\"].append(False)\n                result_dict[\"missing_factor_data\"].append(True)\n\n        result_df = pd.DataFrame(result_dict).set_index(\"instruments\")\n        if not result_df.empty:\n            return result_df\n        else:\n            logger.info(f\"✅ The `factor` column already exists and is not empty.\")\n            return None\n\n    def check_data(self):\n        check_missing_data_result = self.check_missing_data()\n        check_large_step_changes_result = self.check_large_step_changes()\n        check_required_columns_result = self.check_required_columns()\n        check_missing_factor_result = self.check_missing_factor()\n        check_features_dir_case_result = self.check_features_dir_lowercase()\n        if (\n            check_missing_data_result is not None\n            or check_large_step_changes_result is not None\n            or check_required_columns_result is not None\n            or check_missing_factor_result is not None\n            or 
check_features_dir_case_result is not None\n        ):\n            print(f\"\\nSummary of data health check ({len(self.data)} files checked):\")\n            print(\"-------------------------------------------------\")\n            if isinstance(check_missing_data_result, pd.DataFrame):\n                logger.warning(f\"There is missing data.\")\n                print(check_missing_data_result)\n            if isinstance(check_large_step_changes_result, pd.DataFrame):\n                logger.warning(f\"The OHLCV columns have large step changes.\")\n                print(check_large_step_changes_result)\n            if isinstance(check_required_columns_result, pd.DataFrame):\n                logger.warning(f\"Columns (OHLCV) are missing.\")\n                print(check_required_columns_result)\n            if isinstance(check_missing_factor_result, pd.DataFrame):\n                logger.warning(f\"The factor column does not exist or is empty.\")\n                print(check_missing_factor_result)\n            if isinstance(check_features_dir_case_result, pd.DataFrame):\n                logger.warning(\n                    f\"Some subdirectories under `{os.path.join(self.qlib_dir, 'features')}` contain uppercase letters; please rename them to lowercase manually.\"\n                )\n                print(check_features_dir_case_result)\n\n\nif __name__ == \"__main__\":\n    fire.Fire(DataHealthChecker)\n"
  },
  {
    "path": "scripts/check_dump_bin.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom pathlib import Path\nfrom concurrent.futures import ProcessPoolExecutor\n\nimport qlib\nfrom qlib.data import D\n\nimport fire\nimport datacompy\nimport pandas as pd\nfrom tqdm import tqdm\nfrom loguru import logger\n\n\nclass CheckBin:\n    NOT_IN_FEATURES = \"not in features\"\n    COMPARE_FALSE = \"compare False\"\n    COMPARE_TRUE = \"compare True\"\n    COMPARE_ERROR = \"compare error\"\n\n    def __init__(\n        self,\n        qlib_dir: str,\n        csv_path: str,\n        check_fields: str = None,\n        freq: str = \"day\",\n        symbol_field_name: str = \"symbol\",\n        date_field_name: str = \"date\",\n        file_suffix: str = \".csv\",\n        max_workers: int = 16,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        qlib_dir : str\n            qlib dir\n        csv_path : str\n            origin csv path\n        check_fields : str, optional\n            check fields, by default None, check qlib_dir/features/<first_dir>/*.<freq>.bin\n        freq : str, optional\n            freq, value from [\"day\", \"1m\"]\n        symbol_field_name: str, optional\n            symbol field name, by default \"symbol\"\n        date_field_name: str, optional\n            date field name, by default \"date\"\n        file_suffix: str, optional\n            csv file suffix, by default \".csv\"\n        max_workers: int, optional\n            max workers, by default 16\n        \"\"\"\n        self.qlib_dir = Path(qlib_dir).expanduser()\n        bin_path_list = list(self.qlib_dir.joinpath(\"features\").iterdir())\n        self.qlib_symbols = sorted(map(lambda x: x.name.lower(), bin_path_list))\n        qlib.init(\n            provider_uri=str(self.qlib_dir.resolve()),\n            mount_path=str(self.qlib_dir.resolve()),\n            auto_mount=False,\n            redis_port=-1,\n        )\n        csv_path = 
Path(csv_path).expanduser()\n        self.csv_files = sorted(csv_path.glob(f\"*{file_suffix}\") if csv_path.is_dir() else [csv_path])\n\n        if check_fields is None:\n            check_fields = list(map(lambda x: x.name.split(\".\")[0], bin_path_list[0].glob(f\"*.bin\")))\n        else:\n            check_fields = check_fields.split(\",\") if isinstance(check_fields, str) else check_fields\n        self.check_fields = list(map(lambda x: x.strip(), check_fields))\n        self.qlib_fields = list(map(lambda x: f\"${x}\", self.check_fields))\n        self.max_workers = max_workers\n        self.symbol_field_name = symbol_field_name\n        self.date_field_name = date_field_name\n        self.freq = freq\n        self.file_suffix = file_suffix\n\n    def _compare(self, file_path: Path):\n        # remove the suffix by length; str.strip() would drop any of its characters from both ends\n        symbol = file_path.name[: len(file_path.name) - len(self.file_suffix)]\n        if symbol.lower() not in self.qlib_symbols:\n            return self.NOT_IN_FEATURES\n        # qlib data\n        qlib_df = D.features([symbol], self.qlib_fields, freq=self.freq)\n        qlib_df.rename(columns={_c: _c.strip(\"$\") for _c in qlib_df.columns}, inplace=True)\n        # csv data\n        origin_df = pd.read_csv(file_path)\n        origin_df[self.date_field_name] = pd.to_datetime(origin_df[self.date_field_name])\n        if self.symbol_field_name not in origin_df.columns:\n            origin_df[self.symbol_field_name] = symbol\n        origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True)\n        origin_df.index.names = qlib_df.index.names\n        origin_df = origin_df.reindex(qlib_df.index)\n        try:\n            compare = datacompy.Compare(\n                origin_df,\n                qlib_df,\n                on_index=True,\n                abs_tol=1e-08,  # Optional, defaults to 0\n                rel_tol=1e-05,  # Optional, defaults to 0\n                df1_name=\"Original\",  # Optional, defaults to 
'df1'\n                df2_name=\"New\",  # Optional, defaults
 to 'df2'\n            )\n            _r = compare.matches(ignore_extra_columns=True)\n            return self.COMPARE_TRUE if _r else self.COMPARE_FALSE\n        except Exception as e:\n            logger.warning(f\"{symbol} compare error: {e}\")\n            return self.COMPARE_ERROR\n\n    def check(self):\n        \"\"\"Check whether the bin files generated by ``dump_bin.py`` are consistent with the original csv data.\"\"\"\n        logger.info(\"start check......\")\n\n        error_list = []\n        not_in_features = []\n        compare_false = []\n        with tqdm(total=len(self.csv_files)) as p_bar:\n            with ProcessPoolExecutor(max_workers=self.max_workers) as executor:\n                for file_path, _check_res in zip(self.csv_files, executor.map(self._compare, self.csv_files)):\n                    symbol = file_path.name[: len(file_path.name) - len(self.file_suffix)]\n                    if _check_res == self.NOT_IN_FEATURES:\n                        not_in_features.append(symbol)\n                    elif _check_res == self.COMPARE_ERROR:\n                        error_list.append(symbol)\n                    elif _check_res == self.COMPARE_FALSE:\n                        compare_false.append(symbol)\n                    p_bar.update()\n\n        logger.info(\"end of check......\")\n        if error_list:\n            logger.warning(f\"compare error: {error_list}\")\n        if not_in_features:\n            logger.warning(f\"not in features: {not_in_features}\")\n        if compare_false:\n            logger.warning(f\"compare False: {compare_false}\")\n        logger.info(\n            f\"total {len(self.csv_files)}, {len(error_list)} errors, {len(not_in_features)} not in features, {len(compare_false)} compare false\"\n        )\n\n\nif __name__ == \"__main__\":\n    fire.Fire(CheckBin)\n"
  },
  {
    "path": "scripts/collect_info.py",
    "content": "import sys\nimport platform\nimport qlib\nimport fire\nimport pkg_resources\nfrom pathlib import Path\n\nQLIB_PATH = Path(__file__).absolute().resolve().parent.parent\n\n\nclass InfoCollector:\n    \"\"\"\n    Users can collect system info with the following command:\n    `cd scripts && python collect_info.py all`\n    - NOTE: please avoid running this script in the project folder which contains `qlib`\n    \"\"\"\n\n    def sys(self):\n        \"\"\"collect system related info\"\"\"\n        for method in [\"system\", \"machine\", \"platform\", \"version\"]:\n            print(getattr(platform, method)())\n\n    def py(self):\n        \"\"\"collect Python related info\"\"\"\n        print(\"Python version: {}\".format(sys.version.replace(\"\\n\", \" \")))\n\n    def qlib(self):\n        \"\"\"collect qlib related info\"\"\"\n        print(\"Qlib version: {}\".format(qlib.__version__))\n        REQUIRED = [\n            \"setuptools\",\n            \"wheel\",\n            \"cython\",\n            \"pyyaml\",\n            \"numpy\",\n            \"pandas\",\n            \"mlflow\",\n            \"filelock\",\n            \"redis\",\n            \"dill\",\n            \"fire\",\n            \"ruamel.yaml\",\n            \"python-redis-lock\",\n            \"tqdm\",\n            \"pymongo\",\n            \"loguru\",\n            \"lightgbm\",\n            \"gym\",\n            \"cvxpy\",\n            \"joblib\",\n            \"matplotlib\",\n            \"jupyter\",\n            \"nbconvert\",\n            \"pyarrow\",\n            \"pydantic-settings\",\n            \"setuptools-scm\",\n        ]\n\n        for package in REQUIRED:\n            version = pkg_resources.get_distribution(package).version\n            print(f\"{package}=={version}\")\n\n    def all(self):\n        \"\"\"collect all info\"\"\"\n        for method in [\"sys\", \"py\", \"qlib\"]:\n            getattr(self, method)()\n            print()\n\n\nif __name__ == \"__main__\":\n
    fire.Fire(InfoCollector)\n"
  },
  {
    "path": "scripts/data_collector/README.md",
    "content": "# Data Collector\n\n## Introduction\n\nScripts for data collection\n\n- yahoo: get *US/CN* stock data from *Yahoo Finance*\n- fund: get fund data from *http://fund.eastmoney.com*\n- cn_index: get *CN index* from *http://www.csindex.com.cn*, *CSI300*/*CSI100*\n- us_index: get *US index* from *https://en.wikipedia.org/wiki*, *SP500*/*NASDAQ100*/*DJIA*/*SP400*\n- contrib: scripts for some auxiliary functions\n\n\n## Custom Data Collection\n\n> Reference implementation: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo\n\n1. Create a dataset code directory in the current directory\n2. Add `collector.py`\n   - add a collector class:\n     ```python\n     import sys\n     from pathlib import Path\n\n     CUR_DIR = Path(__file__).resolve().parent\n     sys.path.append(str(CUR_DIR.parent.parent))\n     from data_collector.base import BaseCollector, BaseNormalize, BaseRun\n\n     class UserCollector(BaseCollector):\n         ...\n     ```\n   - add a normalize class:\n     ```python\n     class UserNormalize(BaseNormalize):\n         ...\n     ```\n   - add a `CLI` class:\n     ```python\n     class Run(BaseRun):\n         ...\n     ```\n3. Add `README.md`\n4. 
Add `requirements.txt`\n\n\n## Description of dataset\n\n  |             | Basic data |\n  |-------------|------------|\n  | Features    | **Price/Volume**: <br>&nbsp;&nbsp; - $close/$open/$low/$high/$volume/$change/$factor |\n  | Calendar    | **\\<freq>.txt**: <br>&nbsp;&nbsp; - day.txt<br>&nbsp;&nbsp; - 1min.txt |\n  | Instruments | **\\<market>.txt**: <br>&nbsp;&nbsp; - required: **all.txt**; <br>&nbsp;&nbsp; - csi300.txt/csi500.txt/sp500.txt |\n\n  - `Features`: data, **numeric**\n    - if not **adjusted**, **factor=1**\n\n### Data-dependent components\n\n> Each component requires the following data to run correctly:\n\n  | Component      | Required data                                     |\n  |----------------|---------------------------------------------------|\n  | Data retrieval | Features, Calendar, Instruments                   |\n  | Backtest       | **Features[Price/Volume]**, Calendar, Instruments |"
  },
  {
    "path": "scripts/data_collector/baostock_5min/README.md",
    "content": "## Collector Data\n\n### Get Qlib data (`bin file`)\n\n  - get data: `python scripts/get_data.py qlib_data`\n  - parameters:\n    - `target_dir`: save dir, by default *~/.qlib/qlib_data/cn_data_5min*\n    - `version`: dataset version, value from [`v2`], by default `v2`\n      - `v2` end date is *2022-12*\n    - `interval`: `5min`\n    - `region`: `hs300`\n    - `delete_old`: delete existing data from `target_dir` (*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`\n    - `exists_skip`: if data already exists in `target_dir`, skip `get_data`, value from [`True`, `False`], by default `False`\n  - examples:\n    ```bash\n    # hs300 5min\n    python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/hs300_data_5min --region hs300 --interval 5min\n    ```\n    \n### Collect *Baostock high frequency* data into qlib\n> Collect *Baostock high frequency* data and *dump* it into `qlib` format.\n> If the ready-made data above does not meet your requirements, follow this section to crawl the latest data and convert it to qlib data.\n  1. download data to csv: `python scripts/data_collector/baostock_5min/collector.py download_data`\n     \n     This downloads the raw data (date, symbol, open, high, low, close, volume, amount, adjustflag) from baostock to a local directory, one file per symbol.\n     - parameters:\n          - `source_dir`: directory to save the data\n          - `interval`: `5min`\n          - `region`: `HS300`\n          - `start`: start datetime, by default *None*\n          - `end`: end datetime, by default *None*\n     - examples:\n          ```bash\n          # cn 5min data\n          python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300\n          ```\n  2. 
normalize data: `python scripts/data_collector/baostock_5min/collector.py normalize_data`\n     \n     This will:\n     1. Normalize the high, low, close, open prices using adjclose.\n     2. Normalize the high, low, close, open prices so that the close price on the first valid trading date is 1.\n     - parameters:\n          - `source_dir`: csv directory\n          - `normalize_dir`: result directory\n          - `interval`: `5min`\n            > if **`interval == 5min`**, `qlib_data_1d_dir` cannot be `None`\n          - `region`: `HS300`\n          - `date_field_name`: column *name* identifying time in csv files, by default `date`\n          - `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`\n          - `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None`\n          - `qlib_data_1d_dir`: qlib directory (1d data)\n            if `interval == 5min`, `qlib_data_1d_dir` cannot be `None`, because normalizing 5min data requires 1d data;\n            ```bash\n            # qlib_data_1d can be obtained like this:\n            python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3\n            ```\n     - examples:\n          ```bash\n          # normalize 5min cn\n          python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --normalize_dir ~/.qlib/stock_data/source/hs300_5min_nor --region HS300 --interval 5min\n          ```\n  3. dump data: `python scripts/dump_bin.py dump_all`\n    \n     This converts the normalized csv files into numpy arrays under the `features` directory, storing one file per column and one directory per symbol.
\n    \n     - parameters:\n       - `data_path`: stock data path or directory, **normalize result (normalize_dir)**\n       - `qlib_dir`: qlib (dump) data directory\n       - `freq`: transaction frequency, by default `day`\n         > `freq_map = {1d: day, 5min: 5min}`\n       - `max_workers`: number of threads, by default *16*\n       - `include_fields`: dump fields, by default `\"\"`\n       - `exclude_fields`: fields not dumped, by default `\"\"`\n         > dump_fields = `include_fields if include_fields else set(symbol_df.columns) - set(exclude_fields) if exclude_fields else symbol_df.columns`\n       - `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`\n       - `date_field_name`: column *name* identifying time in csv files, by default `date`\n       - `file_suffix`: stock data file format, by default \".csv\"\n     - examples:\n       ```bash\n       # dump 5min cn\n       python dump_bin.py dump_all --data_path ~/.qlib/stock_data/source/hs300_5min_nor --qlib_dir ~/.qlib/qlib_data/hs300_5min_bin --freq 5min --exclude_fields date,symbol\n       ```"
  },
  {
    "path": "scripts/data_collector/baostock_5min/collector.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nimport sys\nimport copy\nimport fire\nimport numpy as np\nimport pandas as pd\nimport baostock as bs\nfrom tqdm import tqdm\nfrom pathlib import Path\nfrom loguru import logger\nfrom typing import Iterable, List\n\nimport qlib\nfrom qlib.data import D\n\nCUR_DIR = Path(__file__).resolve().parent\nsys.path.append(str(CUR_DIR.parent.parent))\n\nfrom data_collector.base import BaseCollector, BaseNormalize, BaseRun\nfrom data_collector.utils import generate_minutes_calendar_from_daily, calc_adjusted_price\n\n\nclass BaostockCollectorHS3005min(BaseCollector):\n    def __init__(\n        self,\n        save_dir: [str, Path],\n        start=None,\n        end=None,\n        interval=\"5min\",\n        max_workers=4,\n        max_collector_count=2,\n        delay=0,\n        check_data_length: int = None,\n        limit_nums: int = None,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        save_dir: str\n            stock save dir\n        max_workers: int\n            workers, default 4\n        max_collector_count: int\n            default 2\n        delay: float\n            time.sleep(delay), default 0\n        interval: str\n            freq, value from [5min], default 5min\n        start: str\n            start datetime, default None\n        end: str\n            end datetime, default None\n        check_data_length: int\n            check data length, by default None\n        limit_nums: int\n            used for debugging, by default None\n        \"\"\"\n        bs.login()\n        super(BaostockCollectorHS3005min, self).__init__(\n            save_dir=save_dir,\n            start=start,\n            end=end,\n            interval=interval,\n            max_workers=max_workers,\n            max_collector_count=max_collector_count,\n            delay=delay,\n            check_data_length=check_data_length,\n            limit_nums=limit_nums,\n        
)\n\n    def get_trade_calendar(self):\n        _format = \"%Y-%m-%d\"\n        start = self.start_datetime.strftime(_format)\n        end = self.end_datetime.strftime(_format)\n        rs = bs.query_trade_dates(start_date=start, end_date=end)\n        calendar_list = []\n        while rs.error_code == \"0\" and rs.next():\n            calendar_list.append(rs.get_row_data())\n        calendar_df = pd.DataFrame(calendar_list, columns=rs.fields)\n        trade_calendar_df = calendar_df[~calendar_df[\"is_trading_day\"].isin([\"0\"])]\n        return trade_calendar_df[\"calendar_date\"].values\n\n    @staticmethod\n    def process_interval(interval: str):\n        if interval == \"1d\":\n            return {\"interval\": \"d\", \"fields\": \"date,code,open,high,low,close,volume,amount,adjustflag\"}\n        if interval == \"5min\":\n            return {\"interval\": \"5\", \"fields\": \"date,time,code,open,high,low,close,volume,amount,adjustflag\"}\n\n    def get_data(\n        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp\n    ) -> pd.DataFrame:\n        df = self.get_data_from_remote(\n            symbol=symbol, interval=interval, start_datetime=start_datetime, end_datetime=end_datetime\n        )\n        if df.empty:\n            # no rows returned for this symbol/range; renaming the columns below would fail\n            return df\n        df.columns = [\"date\", \"time\", \"symbol\", \"open\", \"high\", \"low\", \"close\", \"volume\", \"amount\", \"adjustflag\"]\n        df[\"time\"] = pd.to_datetime(df[\"time\"], format=\"%Y%m%d%H%M%S%f\")\n        df[\"date\"] = df[\"time\"].dt.strftime(\"%Y-%m-%d %H:%M:%S\")\n        df[\"date\"] = df[\"date\"].map(lambda x: pd.Timestamp(x) - pd.Timedelta(minutes=5))\n        df.drop([\"time\"], axis=1, inplace=True)\n        df[\"symbol\"] = df[\"symbol\"].map(lambda x: str(x).replace(\".\", \"\").upper())\n        return df\n\n    @staticmethod\n    def get_data_from_remote(\n        symbol: str, interval: str, 
start_datetime: pd.Timestamp, end_datetime: pd.Timestamp\n    ) -> pd.DataFrame:\n        df = 
pd.DataFrame()\n        rs = bs.query_history_k_data_plus(\n            symbol,\n            BaostockCollectorHS3005min.process_interval(interval=interval)[\"fields\"],\n            start_date=str(start_datetime.strftime(\"%Y-%m-%d\")),\n            end_date=str(end_datetime.strftime(\"%Y-%m-%d\")),\n            frequency=BaostockCollectorHS3005min.process_interval(interval=interval)[\"interval\"],\n            adjustflag=\"3\",\n        )\n        if rs.error_code == \"0\" and len(rs.data) > 0:\n            data_list = rs.data\n            columns = rs.fields\n            df = pd.DataFrame(data_list, columns=columns)\n        return df\n\n    def get_hs300_symbols(self) -> List[str]:\n        hs300_stocks = []\n        trade_calendar = self.get_trade_calendar()\n        with tqdm(total=len(trade_calendar)) as p_bar:\n            for date in trade_calendar:\n                rs = bs.query_hs300_stocks(date=date)\n                while rs.error_code == \"0\" and rs.next():\n                    hs300_stocks.append(rs.get_row_data())\n                p_bar.update()\n        return sorted({e[1] for e in hs300_stocks})\n\n    def get_instrument_list(self):\n        logger.info(\"get HS stock symbols......\")\n        symbols = self.get_hs300_symbols()\n        logger.info(f\"get {len(symbols)} symbols.\")\n        return symbols\n\n    def normalize_symbol(self, symbol: str):\n        return str(symbol).replace(\".\", \"\").upper()\n\n\nclass BaostockNormalizeHS3005min(BaseNormalize):\n    COLUMNS = [\"open\", \"close\", \"high\", \"low\", \"volume\"]\n    AM_RANGE = (\"09:30:00\", \"11:29:00\")\n    PM_RANGE = (\"13:00:00\", \"14:59:00\")\n\n    def __init__(\n        self, qlib_data_1d_dir: [str, Path], date_field_name: str = \"date\", symbol_field_name: str = \"symbol\", **kwargs\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        qlib_data_1d_dir: str, Path\n            qlib 1d data directory; the 5min data is normalized 
using local 1d data\n        date_field_name: str\n            date field name, default is date\n        symbol_field_name: str\n            symbol field name, default is symbol\n        \"\"\"\n        bs.login()\n        qlib.init(provider_uri=qlib_data_1d_dir)\n        self.all_1d_data = D.features(D.instruments(\"all\"), [\"$paused\", \"$volume\", \"$factor\", \"$close\"], freq=\"day\")\n        super(BaostockNormalizeHS3005min, self).__init__(date_field_name, symbol_field_name)\n\n    @staticmethod\n    def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series:\n        df = df.copy()\n        _tmp_series = df[\"close\"].ffill()\n        _tmp_shift_series = _tmp_series.shift(1)\n        if last_close is not None:\n            _tmp_shift_series.iloc[0] = float(last_close)\n        change_series = _tmp_series / _tmp_shift_series - 1\n        return change_series\n\n    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:\n        return self.generate_5min_from_daily(self.calendar_list_1d)\n\n    @property\n    def calendar_list_1d(self):\n        calendar_list_1d = getattr(self, \"_calendar_list_1d\", None)\n        if calendar_list_1d is None:\n            calendar_list_1d = self._get_1d_calendar_list()\n            setattr(self, \"_calendar_list_1d\", calendar_list_1d)\n        return calendar_list_1d\n\n    @staticmethod\n    def normalize_baostock(\n        df: pd.DataFrame,\n        calendar_list: list = None,\n        date_field_name: str = \"date\",\n        symbol_field_name: str = \"symbol\",\n        last_close: float = None,\n    ):\n        if df.empty:\n            return df\n        symbol = df.loc[df[symbol_field_name].first_valid_index(), symbol_field_name]\n        columns = copy.deepcopy(BaostockNormalizeHS3005min.COLUMNS)\n        df = df.copy()\n        df.set_index(date_field_name, inplace=True)\n        df.index = pd.to_datetime(df.index)\n        df = df[~df.index.duplicated(keep=\"first\")]\n        if calendar_list is not 
None:\n            df = df.reindex(\n                pd.DataFrame(index=calendar_list)\n                .loc[pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timedelta(days=1)]\n                .index\n            )\n        df.sort_index(inplace=True)\n        df.loc[(df[\"volume\"] <= 0) | np.isnan(df[\"volume\"]), list(set(df.columns) - {symbol_field_name})] = np.nan\n\n        df[\"change\"] = BaostockNormalizeHS3005min.calc_change(df, last_close)\n\n        columns += [\"change\"]\n        df.loc[(df[\"volume\"] <= 0) | np.isnan(df[\"volume\"]), columns] = np.nan\n\n        df[symbol_field_name] = symbol\n        df.index.names = [date_field_name]\n        return df.reset_index()\n\n    def generate_5min_from_daily(self, calendars: Iterable) -> pd.Index:\n        return generate_minutes_calendar_from_daily(\n            calendars, freq=\"5min\", am_range=self.AM_RANGE, pm_range=self.PM_RANGE\n        )\n\n    def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:\n        df = calc_adjusted_price(\n            df=df,\n            _date_field_name=self._date_field_name,\n            _symbol_field_name=self._symbol_field_name,\n            frequence=\"5min\",\n            _1d_data_all=self.all_1d_data,\n        )\n        return df\n\n    def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:\n        return list(D.calendar(freq=\"day\"))\n\n    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:\n        # normalize\n        df = self.normalize_baostock(df, self._calendar_list, self._date_field_name, self._symbol_field_name)\n        # adjusted price\n        df = self.adjusted_price(df)\n        return df\n\n\nclass Run(BaseRun):\n    def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval=\"5min\", region=\"HS300\"):\n        \"\"\"\n        Changed the default value of: scripts.data_collector.base.BaseRun.\n        \"\"\"\n        super().__init__(source_dir, normalize_dir, max_workers, 
interval)\n        self.region = region\n\n    @property\n    def collector_class_name(self):\n        return f\"BaostockCollector{self.region.upper()}{self.interval}\"\n\n    @property\n    def normalize_class_name(self):\n        return f\"BaostockNormalize{self.region.upper()}{self.interval}\"\n\n    @property\n    def default_base_dir(self) -> [Path, str]:\n        return CUR_DIR\n\n    def download_data(\n        self,\n        max_collector_count=2,\n        delay=0.5,\n        start=None,\n        end=None,\n        check_data_length=None,\n        limit_nums=None,\n    ):\n        \"\"\"download data from Baostock\n\n        Notes\n        -----\n            check_data_length, example:\n                hs300 5min, a week: 4 * 12 * 5 (4 trading hours * 12 bars per hour * 5 days)\n\n        Examples\n        ---------\n            # get hs300 5min data\n            $ python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300\n        \"\"\"\n        super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)\n\n    def normalize_data(\n        self,\n        date_field_name: str = \"date\",\n        symbol_field_name: str = \"symbol\",\n        end_date: str = None,\n        qlib_data_1d_dir: str = None,\n    ):\n        \"\"\"normalize data\n\n        Attention\n        ---------\n        qlib_data_1d_dir cannot be None, because normalizing 5min data requires 1d data;\n\n            qlib_data_1d can be obtained like this:\n                $ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3\n            or:\n                download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo\n\n        Examples\n        ---------\n            $ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir
 ~/.qlib/stock_data/source/hs300_5min_original --normalize_dir ~/.qlib/stock_data/source/hs300_5min_nor --region HS300 --interval 5min\n        \"\"\"\n        if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():\n            raise ValueError(\n                \"When normalizing 5min data, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data>. Reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance\"\n            )\n        super(Run, self).normalize_data(\n            date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir\n        )\n\n\nif __name__ == \"__main__\":\n    fire.Fire(Run)\n"
  },
  {
    "path": "scripts/data_collector/baostock_5min/requirements.txt",
    "content": "loguru\nfire\nrequests\nnumpy\npandas\ntqdm\nlxml\nyahooquery\njoblib\nbeautifulsoup4\nbs4\nsoupsieve\nbaostock"
  },
  {
    "path": "scripts/data_collector/base.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nimport abc\nimport time\nimport datetime\nimport importlib\nfrom pathlib import Path\nfrom typing import Type, Iterable\nfrom concurrent.futures import ProcessPoolExecutor\n\nimport pandas as pd\nfrom tqdm import tqdm\nfrom loguru import logger\nfrom joblib import Parallel, delayed\nfrom qlib.utils import code_to_fname\n\n\nclass BaseCollector(abc.ABC):\n    CACHE_FLAG = \"CACHED\"\n    NORMAL_FLAG = \"NORMAL\"\n\n    DEFAULT_START_DATETIME_1D = pd.Timestamp(\"2000-01-01\")\n    DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6 - 1)).date()\n    DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)).date()\n    DEFAULT_END_DATETIME_1MIN = DEFAULT_END_DATETIME_1D\n\n    INTERVAL_1min = \"1min\"\n    INTERVAL_1d = \"1d\"\n\n    def __init__(\n        self,\n        save_dir: [str, Path],\n        start=None,\n        end=None,\n        interval=\"1d\",\n        max_workers=1,\n        max_collector_count=2,\n        delay=0,\n        check_data_length: int = None,\n        limit_nums: int = None,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        save_dir: str\n            instrument save dir\n        max_workers: int\n            Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1\n        max_collector_count: int\n            maximum number of collection rounds, default 2\n        delay: float\n            time.sleep(delay), default 0\n        interval: str\n            freq, value from [1min, 1d], default 1d\n        start: str\n            start datetime, default None\n        end: str\n            end datetime, default None\n        check_data_length: int\n            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be 
fetched again, the maximum number of fetches being (max_collector_count). By default None.\n        limit_nums: int\n            using for debug, by default None\n        \"\"\"\n        self.save_dir = Path(save_dir).expanduser().resolve()\n        self.save_dir.mkdir(parents=True, exist_ok=True)\n\n        self.delay = delay\n        self.max_workers = max_workers\n        self.max_collector_count = max_collector_count\n        self.mini_symbol_map = {}\n        self.interval = interval\n        self.check_data_length = max(int(check_data_length) if check_data_length is not None else 0, 0)\n\n        self.start_datetime = self.normalize_start_datetime(start)\n        self.end_datetime = self.normalize_end_datetime(end)\n\n        self.instrument_list = sorted(set(self.get_instrument_list()))\n\n        if limit_nums is not None:\n            try:\n                self.instrument_list = self.instrument_list[: int(limit_nums)]\n            except Exception as e:\n                logger.warning(f\"Cannot use limit_nums={limit_nums}, the parameter will be ignored\")\n\n    def normalize_start_datetime(self, start_datetime: [str, pd.Timestamp] = None):\n        return (\n            pd.Timestamp(str(start_datetime))\n            if start_datetime\n            else getattr(self, f\"DEFAULT_START_DATETIME_{self.interval.upper()}\")\n        )\n\n    def normalize_end_datetime(self, end_datetime: [str, pd.Timestamp] = None):\n        return (\n            pd.Timestamp(str(end_datetime))\n            if end_datetime\n            else getattr(self, f\"DEFAULT_END_DATETIME_{self.interval.upper()}\")\n        )\n\n    @abc.abstractmethod\n    def get_instrument_list(self):\n        raise NotImplementedError(\"rewrite get_instrument_list\")\n\n    @abc.abstractmethod\n    def normalize_symbol(self, symbol: str):\n        \"\"\"normalize symbol\"\"\"\n        raise NotImplementedError(\"rewrite normalize_symbol\")\n\n    @abc.abstractmethod\n    def get_data(\n        self, 
symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp\n    ) -> pd.DataFrame:\n        \"\"\"get data with symbol\n\n        Parameters\n        ----------\n        symbol: str\n        interval: str\n            value from [1min, 1d]\n        start_datetime: pd.Timestamp\n        end_datetime: pd.Timestamp\n\n        Returns\n        ---------\n            pd.DataFrame, \"symbol\" and \"date\" in pd.columns\n\n        \"\"\"\n        raise NotImplementedError(\"rewrite get_data\")\n\n    def sleep(self):\n        time.sleep(self.delay)\n\n    def _simple_collector(self, symbol: str):\n        \"\"\"\n\n        Parameters\n        ----------\n        symbol: str\n\n        \"\"\"\n        self.sleep()\n        df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)\n        _result = self.NORMAL_FLAG\n        if self.check_data_length > 0:\n            _result = self.cache_small_data(symbol, df)\n        if _result == self.NORMAL_FLAG:\n            self.save_instrument(symbol, df)\n        return _result\n\n    def save_instrument(self, symbol, df: pd.DataFrame):\n        \"\"\"save instrument data to file\n\n        Parameters\n        ----------\n        symbol: str\n            instrument code\n        df : pd.DataFrame\n            df.columns must contain \"symbol\" and \"datetime\"\n        \"\"\"\n        if df is None or df.empty:\n            logger.warning(f\"{symbol} is empty\")\n            return\n\n        symbol = self.normalize_symbol(symbol)\n        symbol = code_to_fname(symbol)\n        instrument_path = self.save_dir.joinpath(f\"{symbol}.csv\")\n        df[\"symbol\"] = symbol\n        if instrument_path.exists():\n            _old_df = pd.read_csv(instrument_path)\n            df = pd.concat([_old_df, df], sort=False)\n        df.to_csv(instrument_path, index=False)\n\n    def cache_small_data(self, symbol, df):\n        if len(df) < self.check_data_length:\n            
logger.warning(f\"the number of trading days of {symbol} is less than {self.check_data_length}!\")\n            _temp = self.mini_symbol_map.setdefault(symbol, [])\n            _temp.append(df.copy())\n            return self.CACHE_FLAG\n        else:\n            if symbol in self.mini_symbol_map:\n                self.mini_symbol_map.pop(symbol)\n            return self.NORMAL_FLAG\n\n    def _collector(self, instrument_list):\n        error_symbol = []\n        res = Parallel(n_jobs=self.max_workers)(\n            delayed(self._simple_collector)(_inst) for _inst in tqdm(instrument_list)\n        )\n        for _symbol, _result in zip(instrument_list, res):\n            if _result != self.NORMAL_FLAG:\n                error_symbol.append(_symbol)\n        print(error_symbol)\n        logger.info(f\"error symbol nums: {len(error_symbol)}\")\n        logger.info(f\"current get symbol nums: {len(instrument_list)}\")\n        error_symbol.extend(self.mini_symbol_map.keys())\n        return sorted(set(error_symbol))\n\n    def collector_data(self):\n        \"\"\"collector data\"\"\"\n        logger.info(\"start collector data......\")\n        instrument_list = self.instrument_list\n        for i in range(self.max_collector_count):\n            if not instrument_list:\n                break\n            logger.info(f\"getting data: {i+1}\")\n            instrument_list = self._collector(instrument_list)\n            logger.info(f\"{i+1} finish.\")\n        for _symbol, _df_list in self.mini_symbol_map.items():\n            _df = pd.concat(_df_list, sort=False)\n            if not _df.empty:\n                self.save_instrument(_symbol, _df.drop_duplicates([\"date\"]).sort_values([\"date\"]))\n        if self.mini_symbol_map:\n            logger.warning(f\"less than {self.check_data_length} instrument list: {list(self.mini_symbol_map.keys())}\")\n        logger.info(f\"total {len(self.instrument_list)}, error: {len(set(instrument_list))}\")\n\n\nclass 
BaseNormalize(abc.ABC):\n    def __init__(self, date_field_name: str = \"date\", symbol_field_name: str = \"symbol\", **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        date_field_name: str\n            date field name, default is date\n        symbol_field_name: str\n            symbol field name, default is symbol\n        \"\"\"\n        self._date_field_name = date_field_name\n        self._symbol_field_name = symbol_field_name\n        self.kwargs = kwargs\n        self._calendar_list = self._get_calendar_list()\n\n    @abc.abstractmethod\n    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:\n        # normalize\n        raise NotImplementedError(\"\")\n\n    @abc.abstractmethod\n    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:\n        \"\"\"Get benchmark calendar\"\"\"\n        raise NotImplementedError(\"\")\n\n\nclass Normalize:\n    def __init__(\n        self,\n        source_dir: [str, Path],\n        target_dir: [str, Path],\n        normalize_class: Type[BaseNormalize],\n        max_workers: int = 16,\n        date_field_name: str = \"date\",\n        symbol_field_name: str = \"symbol\",\n        **kwargs,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        source_dir: str or Path\n            The directory where the raw data collected from the Internet is saved\n        target_dir: str or Path\n            Directory for normalize data\n        normalize_class: Type[YahooNormalize]\n            normalize class\n        max_workers: int\n            Concurrent number, default is 16\n        date_field_name: str\n            date field name, default is date\n        symbol_field_name: str\n            symbol field name, default is symbol\n        \"\"\"\n        if not (source_dir and target_dir):\n            raise ValueError(\"source_dir and target_dir cannot be None\")\n        self._source_dir = Path(source_dir).expanduser()\n        self._target_dir = Path(target_dir).expanduser()\n 
       self._target_dir.mkdir(parents=True, exist_ok=True)\n        self._date_field_name = date_field_name\n        self._symbol_field_name = symbol_field_name\n        self._end_date = kwargs.get(\"end_date\", None)\n        self._max_workers = max_workers\n\n        self._normalize_obj = normalize_class(\n            date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs\n        )\n\n    def _executor(self, file_path: Path):\n        file_path = Path(file_path)\n\n        # some symbol_field values such as TRUE, NA are decoded as True(bool), NaN(np.float) by pandas default csv parsing.\n        # manually defines dtype and na_values of the symbol_field.\n        default_na = pd._libs.parsers.STR_NA_VALUES  # pylint: disable=I1101\n        symbol_na = default_na.copy()\n        symbol_na.remove(\"NA\")\n        columns = pd.read_csv(file_path, nrows=0).columns\n        df = pd.read_csv(\n            file_path,\n            dtype={self._symbol_field_name: str},\n            keep_default_na=False,\n            na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns},\n        )\n\n        # NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.\n        df = self._normalize_obj.normalize(df)\n        if df is not None and not df.empty:\n            if self._end_date is not None:\n                _mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date)\n                df = df[_mask]\n            df.to_csv(self._target_dir.joinpath(file_path.name), index=False)\n\n    def normalize(self):\n        logger.info(\"normalize data......\")\n\n        with ProcessPoolExecutor(max_workers=self._max_workers) as worker:\n            file_list = list(self._source_dir.glob(\"*.csv\"))\n            with tqdm(total=len(file_list)) as p_bar:\n                for _ in worker.map(self._executor, file_list):\n         
           p_bar.update()\n\n\nclass BaseRun(abc.ABC):\n    def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval=\"1d\"):\n        \"\"\"\n\n        Parameters\n        ----------\n        source_dir: str\n            The directory where the raw data collected from the Internet is saved, default \"Path(__file__).parent/source\"\n        normalize_dir: str\n            Directory for normalize data, default \"Path(__file__).parent/normalize\"\n        max_workers: int\n            Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1\n        interval: str\n            freq, value from [1min, 1d], default 1d\n        \"\"\"\n        if source_dir is None:\n            source_dir = Path(self.default_base_dir).joinpath(\"source\")\n        self.source_dir = Path(source_dir).expanduser().resolve()\n        self.source_dir.mkdir(parents=True, exist_ok=True)\n\n        if normalize_dir is None:\n            normalize_dir = Path(self.default_base_dir).joinpath(\"normalize\")\n        self.normalize_dir = Path(normalize_dir).expanduser().resolve()\n        self.normalize_dir.mkdir(parents=True, exist_ok=True)\n\n        self._cur_module = importlib.import_module(\"collector\")\n        self.max_workers = max_workers\n        self.interval = interval\n\n    @property\n    @abc.abstractmethod\n    def collector_class_name(self):\n        raise NotImplementedError(\"rewrite collector_class_name\")\n\n    @property\n    @abc.abstractmethod\n    def normalize_class_name(self):\n        raise NotImplementedError(\"rewrite normalize_class_name\")\n\n    @property\n    @abc.abstractmethod\n    def default_base_dir(self) -> [Path, str]:\n        raise NotImplementedError(\"rewrite default_base_dir\")\n\n    def download_data(\n        self,\n        max_collector_count=2,\n        delay=0,\n        start=None,\n        end=None,\n        check_data_length: int = None,\n        
limit_nums=None,\n        **kwargs,\n    ):\n        \"\"\"download data from Internet\n\n        Parameters\n        ----------\n        max_collector_count: int\n            default 2\n        delay: float\n            time.sleep(delay), default 0\n        start: str\n            start datetime, default \"2000-01-01\"\n        end: str\n            end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``\n        check_data_length: int\n            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.\n        limit_nums: int\n            using for debug, by default None\n\n        Examples\n        ---------\n            # get daily data\n            $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d\n            # get 1m data\n            $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m\n        \"\"\"\n\n        _class = getattr(self._cur_module, self.collector_class_name)  # type: Type[BaseCollector]\n        _class(\n            self.source_dir,\n            max_workers=self.max_workers,\n            max_collector_count=max_collector_count,\n            delay=delay,\n            start=start,\n            end=end,\n            interval=self.interval,\n            check_data_length=check_data_length,\n            limit_nums=limit_nums,\n            **kwargs,\n        ).collector_data()\n\n    def normalize_data(self, date_field_name: str = \"date\", symbol_field_name: str = \"symbol\", **kwargs):\n        \"\"\"normalize data\n\n        Parameters\n        ----------\n        date_field_name: str\n            date field 
name, default date\n        symbol_field_name: str\n            symbol field name, default symbol\n\n        Examples\n        ---------\n            $ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d\n        \"\"\"\n        _class = getattr(self._cur_module, self.normalize_class_name)\n        yc = Normalize(\n            source_dir=self.source_dir,\n            target_dir=self.normalize_dir,\n            normalize_class=_class,\n            max_workers=self.max_workers,\n            date_field_name=date_field_name,\n            symbol_field_name=symbol_field_name,\n            **kwargs,\n        )\n        yc.normalize()\n"
  },
  {
    "path": "scripts/data_collector/br_index/README.md",
    "content": "# iBOVESPA History Companies Collection\n\n## Requirements\n\n- Install the libs from the file `requirements.txt`\n\n    ```bash\n    pip install -r requirements.txt\n    ```\n- The `requirements.txt` file was generated using Python 3.8\n\n## For the ibovespa (IBOV) index, we have:\n\n<hr/>\n\n### Method `get_new_companies`\n\n#### <b>Index start date</b>\n\n- The ibovespa index started on 2 January 1968 ([wiki](https://en.wikipedia.org/wiki/%C3%8Dndice_Bovespa)). In order to use this start date in our `bench_start_date(self)` method, two conditions must be satisfied:\n    1) APIs used to download Brazilian stocks (B3) historical prices must keep track of such historic data since 2 January 1968\n\n    2) Some website or API must provide, from that date, the historic index composition, in other words, the companies used to build the index.\n\n    As a consequence, the method `bench_start_date(self)` inside `collector.py` was implemented using `pd.Timestamp(\"2003-01-03\")` for two reasons:\n\n    1) The earliest ibov composition that has been found is from the first quarter of 2003. More information about this composition can be found in the sections below.\n\n    2) Yahoo Finance, one of the sources used to download the symbols' historical prices, keeps track of prices from this date forward.\n\n- Within the `get_new_companies` method, logic was implemented to get, for each ibovespa component stock, the start date that Yahoo Finance keeps track of.\n\n#### <b>Code Logic</b>\n\nThe code scrapes B3's [website](https://sistemaswebb3-listados.b3.com.br/indexPage/day/IBOV?language=pt-br), which keeps track of the ibovespa stock composition on the current day.\n\nOther approaches, such as `requests` and `Beautiful Soup`, could have been used. 
However, the website renders the stock table with some delay, since it uses an embedded script to load the composition.\nInstead, `selenium` was used to download the stock composition in order to overcome this problem.\n\nFurthermore, the data downloaded by the selenium script was preprocessed so it could be saved in the `csv` format established by `scripts/data_collector/index.py`.\n\n<hr/>\n\n### Method `get_changes`\n\nNo suitable data source that keeps track of ibovespa's historical stock composition has been found, except for this [repository](https://github.com/igor17400/IBOV-HCI), which has been used here. However, it only provides data from the 1st quarter of 2003 to the 3rd quarter of 2021.\n\nWith that reference, the index's composition can be compared quarter by quarter and year by year to generate a file that keeps track of which stocks have been removed and which have been added each quarter and year.\n\n<hr/>\n\n### Collector Data\n\n```bash\n# parse instruments, used in qlib/instruments.\npython collector.py --index_name IBOV --qlib_dir ~/.qlib/qlib_data/br_data --method parse_instruments\n\n# parse new companies\npython collector.py --index_name IBOV --qlib_dir ~/.qlib/qlib_data/br_data --method save_new_companies\n```\n\n"
  },
  {
    "path": "scripts/data_collector/br_index/collector.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom functools import partial\nimport sys\nfrom pathlib import Path\nimport datetime\n\nimport fire\nimport pandas as pd\nfrom tqdm import tqdm\nfrom loguru import logger\n\nCUR_DIR = Path(__file__).resolve().parent\nsys.path.append(str(CUR_DIR.parent.parent))\n\nfrom data_collector.index import IndexBase\nfrom data_collector.utils import get_instruments\n\nquarter_dict = {\"1Q\": \"01-03\", \"2Q\": \"05-01\", \"3Q\": \"09-01\"}\n\n\nclass IBOVIndex(IndexBase):\n    ibov_index_composition = \"https://raw.githubusercontent.com/igor17400/IBOV-HCI/main/historic_composition/{}.csv\"\n    years_4_month_periods = []\n\n    def __init__(\n        self,\n        index_name: str,\n        qlib_dir: [str, Path] = None,\n        freq: str = \"day\",\n        request_retry: int = 5,\n        retry_sleep: int = 3,\n    ):\n        super(IBOVIndex, self).__init__(\n            index_name=index_name, qlib_dir=qlib_dir, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep\n        )\n\n        self.today: datetime.date = datetime.date.today()\n        self.current_4_month_period = self.get_current_4_month_period(self.today.month)\n        self.year = str(self.today.year)\n        self.years_4_month_periods = self.get_four_month_period()\n\n    @property\n    def bench_start_date(self) -> pd.Timestamp:\n        \"\"\"\n        The ibovespa index started on 2 January 1968 (wiki), however,\n        no suitable data source that keeps track of ibovespa's history\n        stocks composition has been found, except for the repo indicated\n        in README, 
which keeps track of such information starting from\n        the first quarter of 2003\n        \"\"\"\n        return pd.Timestamp(\"2003-01-03\")\n\n    def get_current_4_month_period(self, current_month: int):\n        \"\"\"\n        This function is used to calculate the current\n        four month period for the current month. For example,\n        if the current month is August (8), its four month period\n        is 2Q.\n\n        OBS: In English Q is used to represent *quarter*\n        which means a three month period. However, in\n        Portuguese we use Q to represent a four month period.\n        In other words,\n\n        Jan, Feb, Mar, Apr: 1Q\n        May, Jun, Jul, Aug: 2Q\n        Sep, Oct, Nov, Dec: 3Q\n\n        Parameters\n        ----------\n        current_month : int\n            Current month (1 <= current_month <= 12)\n\n        Returns\n        -------\n        current_4m_period:str\n            Current Four Month Period (1Q or 2Q or 3Q)\n        \"\"\"\n        if current_month < 5:\n            return \"1Q\"\n        if current_month < 9:\n            return \"2Q\"\n        if current_month <= 12:\n            return \"3Q\"\n        else:\n            return -1\n\n    def get_four_month_period(self):\n        \"\"\"\n        The ibovespa index is updated every four months.\n        Therefore, we will represent each time period as 2003_1Q\n        which means 2003 first four month period (Jan, Feb, Mar, Apr)\n        \"\"\"\n        # use a fresh instance-level list so repeated instantiation does not\n        # keep appending to the class-level years_4_month_periods\n        self.years_4_month_periods = []\n        four_months_period = [\"1Q\", \"2Q\", \"3Q\"]\n        init_year = 2003\n        now = datetime.datetime.now()\n        current_year = now.year\n        current_month = now.month\n        for year in [item for item in range(init_year, current_year)]:  # pylint: disable=R1721\n            for el in four_months_period:\n                self.years_4_month_periods.append(str(year) + \"_\" + el)\n        # For current year the logic must be a little different\n        current_4_month_period = 
self.get_current_4_month_period(current_month)\n        for i in range(int(current_4_month_period[0])):\n            self.years_4_month_periods.append(str(current_year) + \"_\" + str(i + 1) + \"Q\")\n        return self.years_4_month_periods\n\n    def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"formatting the datetime in an instrument\n\n        Parameters\n        ----------\n        inst_df: pd.DataFrame\n            inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]\n\n        Returns\n        -------\n        inst_df: pd.DataFrame\n\n        \"\"\"\n        logger.info(\"Formatting Datetime\")\n        if self.freq != \"day\":\n            inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply(\n                lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=23, minutes=59)).strftime(\"%Y-%m-%d %H:%M:%S\")\n            )\n        else:\n            inst_df[self.START_DATE_FIELD] = inst_df[self.START_DATE_FIELD].apply(\n                lambda x: (pd.Timestamp(x)).strftime(\"%Y-%m-%d\")\n            )\n\n            inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply(\n                lambda x: (pd.Timestamp(x)).strftime(\"%Y-%m-%d\")\n            )\n        return inst_df\n\n    def format_quarter(self, cell: str):\n        \"\"\"\n        Parameters\n        ----------\n        cell: str\n            It must be in the format 2003_1Q --> years_4_month_periods\n\n        Returns\n        ----------\n        date: str\n            Returns the date in the format YYYY-MM-DD, e.g. 2003_1Q --> 2003-01-03\n        \"\"\"\n        cell_split = cell.split(\"_\")\n        return cell_split[0] + \"-\" + quarter_dict[cell_split[1]]\n\n    def get_changes(self):\n        \"\"\"\n        Access the index historic composition and compare it quarter\n        by quarter and year by year in order to generate a file that\n        keeps track of which stocks have been removed and which have\n        been added.\n\n        
The DataFrame used as reference provides the index\n        composition for each year and quarter:\n        pd.DataFrame:\n            symbol\n            SH600000\n            SH600001\n            .\n            .\n            .\n\n        Parameters\n        ----------\n        self: is used to represent the instance of the class.\n\n        Returns\n        ----------\n        pd.DataFrame:\n            symbol      date        type\n            SH600000  2019-11-11    add\n            SH600001  2020-11-10    remove\n            dtypes:\n                symbol: str\n                date: pd.Timestamp\n                type: str, value from [\"add\", \"remove\"]\n        \"\"\"\n        logger.info(\"Getting companies changes in {} index ...\".format(self.index_name))\n\n        try:\n            df_changes_list = []\n            for i in tqdm(range(len(self.years_4_month_periods) - 1)):\n                df = pd.read_csv(\n                    self.ibov_index_composition.format(self.years_4_month_periods[i]), on_bad_lines=\"skip\"\n                )[\"symbol\"]\n                df_ = pd.read_csv(\n                    self.ibov_index_composition.format(self.years_4_month_periods[i + 1]), on_bad_lines=\"skip\"\n                )[\"symbol\"]\n\n                ## Remove Dataframe\n                remove_date = (\n                    self.years_4_month_periods[i].split(\"_\")[0]\n                    + \"-\"\n                    + quarter_dict[self.years_4_month_periods[i].split(\"_\")[1]]\n                )\n                list_remove = list(df[~df.isin(df_)])\n                df_removed = pd.DataFrame(\n                    {\n                        \"date\": len(list_remove) * [remove_date],\n                        \"type\": len(list_remove) * [\"remove\"],\n                        \"symbol\": list_remove,\n                    }\n                )\n\n                ## Add Dataframe\n                add_date = (\n                    self.years_4_month_periods[i 
+ 1].split(\"_\")[0]\n                    + \"-\"\n                    + quarter_dict[self.years_4_month_periods[i + 1].split(\"_\")[1]]\n                )\n                list_add = list(df_[~df_.isin(df)])\n                df_added = pd.DataFrame(\n                    {\"date\": len(list_add) * [add_date], \"type\": len(list_add) * [\"add\"], \"symbol\": list_add}\n                )\n\n                df_changes_list.append(pd.concat([df_added, df_removed], sort=False))\n\n            df = pd.concat(df_changes_list).reset_index(drop=True)\n            df[\"symbol\"] = df[\"symbol\"].astype(str) + \".SA\"\n\n            return df\n\n        except Exception as E:\n            logger.error(\"An error occurred while getting index composition changes - {}\".format(E))\n\n    def get_new_companies(self):\n        \"\"\"\n        Get the latest index composition.\n        The repo indicated in the README has implemented a script\n        to get the latest index composition from the B3 website using\n        selenium. 
Therefore, this method will download the file\n        containing such composition\n\n        Parameters\n        ----------\n        self: is used to represent the instance of the class.\n\n        Returns\n        ----------\n        pd.DataFrame:\n            symbol      start_date  end_date\n            RRRP3       2020-11-13  2022-03-02\n            ALPA4       2008-01-02  2022-03-02\n            dtypes:\n                symbol: str\n                start_date: pd.Timestamp\n                end_date: pd.Timestamp\n        \"\"\"\n        logger.info(\"Getting new companies in {} index ...\".format(self.index_name))\n\n        try:\n            ## Get index composition\n\n            df_index = pd.read_csv(\n                self.ibov_index_composition.format(self.year + \"_\" + self.current_4_month_period), on_bad_lines=\"skip\"\n            )\n            df_date_first_added = pd.read_csv(\n                self.ibov_index_composition.format(\"date_first_added_\" + self.year + \"_\" + self.current_4_month_period),\n                on_bad_lines=\"skip\",\n            )\n            df = df_index.merge(df_date_first_added, on=\"symbol\")[[\"symbol\", \"Date First Added\"]]\n            df[self.START_DATE_FIELD] = df[\"Date First Added\"].map(self.format_quarter)\n\n            # end_date is set to the start date of the current four-month period, since the IBOV index is rebalanced every four months\n            df[self.END_DATE_FIELD] = self.year + \"-\" + quarter_dict[self.current_4_month_period]\n            df = df[[\"symbol\", self.START_DATE_FIELD, self.END_DATE_FIELD]]\n            df[\"symbol\"] = df[\"symbol\"].astype(str) + \".SA\"\n\n            return df\n\n        except Exception as E:\n            logger.error(\"An error occurred while getting new companies - {}\".format(E))\n\n    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:\n        if \"Código\" in df.columns:\n            return df.loc[:, [\"Código\"]].copy()\n\n\nif __name__ == \"__main__\":\n    
fire.Fire(partial(get_instruments, market_index=\"br_index\"))\n"
  },
  {
    "path": "scripts/data_collector/br_index/requirements.txt",
    "content": "async-generator==1.10\nattrs==21.4.0\ncertifi==2022.12.7\ncffi==1.15.0\ncharset-normalizer==2.0.12\ncryptography==36.0.1\nfire==0.4.0\nh11==0.13.0\nidna==3.3\nloguru==0.6.0\nlxml==4.9.1\nmultitasking==0.0.10\nnumpy==1.22.2\noutcome==1.1.0\npandas==1.4.1\npycoingecko==2.2.0\npycparser==2.21\npyOpenSSL==22.0.0\nPySocks==1.7.1\npython-dateutil==2.8.2\npytz==2021.3\nrequests==2.27.1\nrequests-futures==1.0.0\nsix==1.16.0\nsniffio==1.2.0\nsortedcontainers==2.4.0\ntermcolor==1.1.0\ntqdm==4.63.0\ntrio==0.20.0\ntrio-websocket==0.9.2\nurllib3==1.26.19\nwget==3.2\nwsproto==1.1.0\nyahooquery==2.2.15\n"
  },
  {
    "path": "scripts/data_collector/cn_index/README.md",
    "content": "# CSI300/CSI100/CSI500 History Companies Collection\n\n## Requirements\n\n```bash\npip install -r requirements.txt\n```\n\n## Collector Data\n\n```bash\n# parse instruments (used in qlib/instruments)\npython collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments\n\n# parse new companies\npython collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies\n\n# supported index_name values: CSI300, CSI100, CSI500\n# help\npython collector.py --help\n```\n\n"
  },
  {
    "path": "scripts/data_collector/cn_index/collector.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport re\nimport abc\nimport sys\nfrom io import BytesIO\nfrom typing import List, Iterable\nfrom pathlib import Path\n\nimport fire\nimport requests\nimport pandas as pd\nimport baostock as bs\nfrom tqdm import tqdm\nfrom loguru import logger\n\nCUR_DIR = Path(__file__).resolve().parent\nsys.path.append(str(CUR_DIR.parent.parent))\n\nfrom data_collector.index import IndexBase\nfrom data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry\nfrom data_collector.utils import get_instruments\n\nNEW_COMPANIES_URL = (\n    \"https://oss-ch.csindex.com.cn/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls\"\n)\n\n\nINDEX_CHANGES_URL = \"https://www.csindex.com.cn/csindex-home/search/search-content?lang=cn&searchInput=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC&pageNum={page_num}&pageSize={page_size}&sortField=date&dateRange=all&contentType=announcement\"\n\nREQ_HEADERS = {\n    \"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.101 Safari/537.36 Edg/91.0.864.48\"\n}\n\n\n@deco_retry\ndef retry_request(url: str, method: str = \"get\", exclude_status: List = None):\n    if exclude_status is None:\n        exclude_status = []\n    method_func = getattr(requests, method)\n    _resp = method_func(url, headers=REQ_HEADERS, timeout=None)\n    _status = _resp.status_code\n    if _status not in exclude_status and _status != 200:\n        raise ValueError(f\"response status: {_status}, url={url}\")\n    return _resp\n\n\nclass CSIIndex(IndexBase):\n    @property\n    def calendar_list(self) -> List[pd.Timestamp]:\n        \"\"\"get history trading date\n\n        Returns\n        -------\n            calendar list\n        \"\"\"\n        _calendar = getattr(self, 
\"_calendar_list\", None)\n        if not _calendar:\n            _calendar = get_calendar_list(bench_code=self.index_name.upper())\n            setattr(self, \"_calendar_list\", _calendar)\n        return _calendar\n\n    @property\n    def new_companies_url(self) -> str:\n        return NEW_COMPANIES_URL.format(index_code=self.index_code)\n\n    @property\n    def changes_url(self) -> str:\n        return INDEX_CHANGES_URL\n\n    @property\n    @abc.abstractmethod\n    def bench_start_date(self) -> pd.Timestamp:\n        \"\"\"\n        Returns\n        -------\n            index start date\n        \"\"\"\n        raise NotImplementedError(\"rewrite bench_start_date\")\n\n    @property\n    @abc.abstractmethod\n    def index_code(self) -> str:\n        \"\"\"\n        Returns\n        -------\n            index code\n        \"\"\"\n        raise NotImplementedError(\"rewrite index_code\")\n\n    @property\n    def html_table_index(self) -> int:\n        \"\"\"Which table of changes in html\n\n        CSI300: 0\n        CSI100: 1\n        :return:\n        \"\"\"\n        raise NotImplementedError(\"rewrite html_table_index\")\n\n    def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"formatting the datetime in an instrument\n\n        Parameters\n        ----------\n        inst_df: pd.DataFrame\n            inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]\n\n        Returns\n        -------\n\n        \"\"\"\n        if self.freq != \"day\":\n            inst_df[self.START_DATE_FIELD] = inst_df[self.START_DATE_FIELD].apply(\n                lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=9, minutes=30)).strftime(\"%Y-%m-%d %H:%M:%S\")\n            )\n            inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply(\n                lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=15, minutes=0)).strftime(\"%Y-%m-%d %H:%M:%S\")\n            )\n        return inst_df\n\n    def 
get_changes(self) -> pd.DataFrame:\n        \"\"\"get companies changes\n\n        Returns\n        -------\n            pd.DataFrame:\n                symbol      date        type\n                SH600000  2019-11-11    add\n                SH600000  2020-11-10    remove\n            dtypes:\n                symbol: str\n                date: pd.Timestamp\n                type: str, value from [\"add\", \"remove\"]\n        \"\"\"\n        logger.info(\"get companies changes......\")\n        res = []\n        for _url in self._get_change_notices_url():\n            _df = self._read_change_from_url(_url)\n            if not _df.empty:\n                res.append(_df)\n        logger.info(\"get companies changes finish\")\n        return pd.concat(res, sort=False)\n\n    @staticmethod\n    def normalize_symbol(symbol: str) -> str:\n        \"\"\"\n\n        Parameters\n        ----------\n        symbol: str\n            symbol\n\n        Returns\n        -------\n            symbol\n        \"\"\"\n        symbol = f\"{int(symbol):06}\"\n        return f\"SH{symbol}\" if symbol.startswith(\"60\") or symbol.startswith(\"688\") else f\"SZ{symbol}\"\n\n    def _parse_excel(self, excel_url: str, add_date: pd.Timestamp, remove_date: pd.Timestamp) -> pd.DataFrame:\n        content = retry_request(excel_url, exclude_status=[404]).content\n        _io = BytesIO(content)\n        df_map = pd.read_excel(_io, sheet_name=None)\n        with self.cache_dir.joinpath(\n            f\"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.{excel_url.split('.')[-1]}\"\n        ).open(\"wb\") as fp:\n            fp.write(content)\n        tmp = []\n        for _s_name, _type, _date in [(\"调入\", self.ADD, add_date), (\"调出\", self.REMOVE, remove_date)]:\n            _df = df_map[_s_name]\n            _df = _df.loc[_df[\"指数代码\"] == self.index_code, [\"证券代码\"]]\n            _df = _df.applymap(self.normalize_symbol)\n            _df.columns = [self.SYMBOL_FIELD_NAME]\n         
   _df[\"type\"] = _type\n            _df[self.DATE_FIELD_NAME] = _date\n            tmp.append(_df)\n        df = pd.concat(tmp)\n        return df\n\n    def _parse_table(self, content: str, add_date: pd.Timestamp, remove_date: pd.Timestamp) -> pd.DataFrame:\n        df = pd.DataFrame()\n        _tmp_count = 0\n        for _df in pd.read_html(content):\n            if _df.shape[-1] != 4 or _df.isnull().loc(0)[0][0]:\n                continue\n            _tmp_count += 1\n            if self.html_table_index + 1 > _tmp_count:\n                continue\n            tmp = []\n            for _s, _type, _date in [\n                (_df.iloc[2:, 0], self.REMOVE, remove_date),\n                (_df.iloc[2:, 2], self.ADD, add_date),\n            ]:\n                _tmp_df = pd.DataFrame()\n                _tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol)\n                _tmp_df[\"type\"] = _type\n                _tmp_df[self.DATE_FIELD_NAME] = _date\n                tmp.append(_tmp_df)\n            df = pd.concat(tmp)\n            df.to_csv(\n                str(\n                    self.cache_dir.joinpath(\n                        f\"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.csv\"\n                    ).resolve()\n                )\n            )\n            break\n        return df\n\n    def _read_change_from_url(self, url: str) -> pd.DataFrame:\n        \"\"\"read change from url\n        The parameter url is from the _get_change_notices_url method.\n        Determine the stock add_date/remove_date based on the announcement text.\n        The response contains three cases:\n            1. Only excel_url (extract data from excel_url)\n            2. Both the excel_url and the body text (try to extract data from excel_url first, and fall back to the body text if that fails)\n            3. Only body text (extract data from body text)\n\n        Parameters\n        ----------\n        url : str\n            change url\n\n        Returns\n        
-------\n            pd.DataFrame:\n                symbol      date        type\n                SH600000  2019-11-11    add\n                SH600000  2020-11-10    remove\n            dtypes:\n                symbol: str\n                date: pd.Timestamp\n                type: str, value from [\"add\", \"remove\"]\n        \"\"\"\n        resp = retry_request(url).json()[\"data\"]\n        title = resp[\"title\"]\n        if not title.startswith(\"关于\"):\n            return pd.DataFrame()\n        if \"沪深300\" not in title:\n            return pd.DataFrame()\n\n        logger.info(f\"load index data from https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}\")\n        _text = resp[\"content\"]\n        date_list = re.findall(r\"(\\d{4}).*?年.*?(\\d+).*?月.*?(\\d+).*?日\", _text)\n        if len(date_list) >= 2:\n            add_date = pd.Timestamp(\"-\".join(date_list[0]))\n        else:\n            _date = pd.Timestamp(\"-\".join(re.findall(r\"(\\d{4}).*?年.*?(\\d+).*?月\", _text)[0]))\n            add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0)\n        if \"盘后\" in _text or \"市后\" in _text:\n            add_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=1)\n        remove_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=-1)\n\n        excel_url = None\n        if resp.get(\"enclosureList\", []):\n            excel_url = resp[\"enclosureList\"][0][\"fileUrl\"]\n        else:\n            excel_url_list = re.findall('.*href=\"(.*?xls.*?)\".*', _text)\n            if excel_url_list:\n                excel_url = excel_url_list[0]\n                if not excel_url.startswith(\"http\"):\n                    excel_url = excel_url if excel_url.startswith(\"/\") else \"/\" + excel_url\n                    excel_url = f\"http://www.csindex.com.cn{excel_url}\"\n        if excel_url:\n            try:\n                logger.info(f\"get {add_date} changes from the excel, title={title}, 
excel_url={excel_url}\")\n                df = self._parse_excel(excel_url, add_date, remove_date)\n            except ValueError:\n                logger.info(\n                    f\"get {add_date} changes from the web page, title={title}, url=https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}\"\n                )\n                df = self._parse_table(_text, add_date, remove_date)\n        else:\n            logger.info(\n                f\"get {add_date} changes from the web page, title={title}, url=https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}\"\n            )\n            df = self._parse_table(_text, add_date, remove_date)\n        return df\n\n    def _get_change_notices_url(self) -> Iterable[str]:\n        \"\"\"get change notices url\n\n        Returns\n        -------\n            [url1, url2]\n        \"\"\"\n        page_num = 1\n        page_size = 5\n        data = retry_request(self.changes_url.format(page_size=page_size, page_num=page_num)).json()\n        data = retry_request(self.changes_url.format(page_size=data[\"total\"], page_num=page_num)).json()\n        for item in data[\"data\"]:\n            yield f\"https://www.csindex.com.cn/csindex-home/announcement/queryAnnouncementById?id={item['id']}\"\n\n    def get_new_companies(self) -> pd.DataFrame:\n        \"\"\"\n\n        Returns\n        -------\n            pd.DataFrame:\n\n                symbol     start_date    end_date\n                SH600000   2000-01-01    2099-12-31\n\n            dtypes:\n                symbol: str\n                start_date: pd.Timestamp\n                end_date: pd.Timestamp\n        \"\"\"\n        logger.info(\"get new companies......\")\n        context = retry_request(self.new_companies_url).content\n        with self.cache_dir.joinpath(\n            f\"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}\"\n        ).open(\"wb\") as fp:\n            fp.write(context)\n        
_io = BytesIO(context)\n        df = pd.read_excel(_io)\n        df = df.iloc[:, [0, 4]]\n        df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]\n        df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol)\n        df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str))\n        df[self.START_DATE_FIELD] = self.bench_start_date\n        logger.info(\"end of get new companies.\")\n        return df\n\n\nclass CSI300Index(CSIIndex):\n    @property\n    def index_code(self):\n        return \"000300\"\n\n    @property\n    def bench_start_date(self) -> pd.Timestamp:\n        return pd.Timestamp(\"2005-01-01\")\n\n    @property\n    def html_table_index(self) -> int:\n        return 0\n\n\nclass CSI100Index(CSIIndex):\n    @property\n    def index_code(self):\n        return \"000903\"\n\n    @property\n    def bench_start_date(self) -> pd.Timestamp:\n        return pd.Timestamp(\"2006-05-29\")\n\n    @property\n    def html_table_index(self) -> int:\n        return 1\n\n\nclass CSI500Index(CSIIndex):\n    @property\n    def index_code(self) -> str:\n        return \"000905\"\n\n    @property\n    def bench_start_date(self) -> pd.Timestamp:\n        return pd.Timestamp(\"2007-01-15\")\n\n    def get_changes(self) -> pd.DataFrame:\n        \"\"\"get companies changes\n\n        Return\n        --------\n           pd.DataFrame:\n               symbol      date        type\n               SH600000  2019-11-11    add\n               SH600000  2020-11-10    remove\n           dtypes:\n               symbol: str\n               date: pd.Timestamp\n               type: str, value from [\"add\", \"remove\"]\n        \"\"\"\n        return self.get_changes_with_history_companies(self.get_history_companies())\n\n    def get_history_companies(self) -> pd.DataFrame:\n        \"\"\"\n\n        Returns\n        -------\n\n            pd.DataFrame:\n                symbol      date        type\n                SH600000 
 2019-11-11    add\n                SH600000  2020-11-10    remove\n            dtypes:\n                symbol: str\n                date: pd.Timestamp\n                type: str, value from [\"add\", \"remove\"]\n        \"\"\"\n        bs.login()\n        today = pd.Timestamp.now()\n        date_range = pd.DataFrame(pd.date_range(start=\"2007-01-15\", end=today, freq=\"7D\"))[0].dt.date\n        ret_list = []\n        for date in tqdm(date_range, desc=\"Download CSI500\"):\n            result = self.get_data_from_baostock(date)\n            ret_list.append(result[[\"date\", \"symbol\"]])\n        bs.logout()\n        return pd.concat(ret_list, sort=False)\n\n    @staticmethod\n    def get_data_from_baostock(date) -> pd.DataFrame:\n        \"\"\"\n        Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1\n        Avoid issuing a large number of parallel requests\n        (e.g. 1000 concurrent queries), because the IP may be blocked\n\n        Returns\n        -------\n            pd.DataFrame:\n                date        symbol    code_name\n                2007-01-15  SH600039  四川路桥\n                2020-01-15  SH600051  宁波联合\n            dtypes:\n                date: pd.Timestamp\n                symbol: str\n                code_name: str\n        \"\"\"\n        col = [\"date\", \"symbol\", \"code_name\"]\n        rs = bs.query_zz500_stocks(date=str(date))\n        zz500_stocks = []\n        while (rs.error_code == \"0\") & rs.next():\n            zz500_stocks.append(rs.get_row_data())\n        result = pd.DataFrame(zz500_stocks, columns=col)\n        result[\"symbol\"] = result[\"symbol\"].apply(lambda x: x.replace(\".\", \"\").upper())\n        return result\n\n    def get_new_companies(self) -> pd.DataFrame:\n        \"\"\"\n\n        Returns\n        -------\n            pd.DataFrame:\n\n                symbol     start_date    end_date\n                SH600000   2000-01-01    
2099-12-31\n\n            dtypes:\n                symbol: str\n                start_date: pd.Timestamp\n                end_date: pd.Timestamp\n        \"\"\"\n        logger.info(\"get new companies......\")\n        today = pd.Timestamp.now().normalize()\n        bs.login()\n        result = self.get_data_from_baostock(today.strftime(\"%Y-%m-%d\"))\n        bs.logout()\n        # copy to avoid modifying a slice of `result` (SettingWithCopyWarning)\n        df = result[[\"date\", \"symbol\"]].copy()\n        df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]\n        df[self.END_DATE_FIELD] = today\n        df[self.START_DATE_FIELD] = self.bench_start_date\n        logger.info(\"end of get new companies.\")\n        return df\n\n\nif __name__ == \"__main__\":\n    fire.Fire(get_instruments)\n"
  },
  {
    "path": "scripts/data_collector/cn_index/requirements.txt",
    "content": "baostock\nfire\nrequests\npandas\nlxml\nloguru\ntqdm\nyahooquery\nopenpyxl\n"
  },
  {
    "path": "scripts/data_collector/contrib/fill_cn_1min_data/README.md",
    "content": "# Use 1d data to fill in the missing symbols relative to 1min\n\n## Requirements\n\n```bash\npip install -r requirements.txt\n```\n\n## Fill 1min data\n\n```bash\npython fill_cn_1min_data.py --data_1min_dir ~/.qlib/csv_data/cn_data_1min --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data\n```\n\n## Parameters\n\n- data_1min_dir: directory of the 1min csv data\n- qlib_data_1d_dir: directory of the 1d qlib data (bin format)\n- max_workers: `ThreadPoolExecutor(max_workers=max_workers)`, by default *16*\n- date_field_name: date field name, by default *date*\n- symbol_field_name: symbol field name, by default *symbol*\n\n"
  },
  {
    "path": "scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport sys\nfrom pathlib import Path\nfrom concurrent.futures import ThreadPoolExecutor\n\nimport fire\nimport qlib\nimport pandas as pd\nfrom tqdm import tqdm\nfrom qlib.data import D\nfrom loguru import logger\n\nCUR_DIR = Path(__file__).resolve().parent\nsys.path.append(str(CUR_DIR.parent.parent.parent))\nfrom data_collector.utils import generate_minutes_calendar_from_daily\n\n\ndef get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name: str = \"date\"):\n    csv_files = list(data_1min_dir.glob(\"*.csv\"))\n    min_date = None\n    max_date = None\n    with tqdm(total=len(csv_files)) as p_bar:\n        with ThreadPoolExecutor(max_workers=max_workers) as executor:\n            for _file, _result in zip(csv_files, executor.map(pd.read_csv, csv_files)):\n                if not _result.empty:\n                    _dates = pd.to_datetime(_result[date_field_name])\n\n                    _tmp_min = _dates.min()\n                    min_date = min(min_date, _tmp_min) if min_date is not None else _tmp_min\n                    _tmp_max = _dates.max()\n                    max_date = max(max_date, _tmp_max) if max_date is not None else _tmp_max\n                p_bar.update()\n    return min_date, max_date\n\n\ndef get_symbols(data_1min_dir: Path):\n    return list(map(lambda x: x.name[:-4].upper(), data_1min_dir.glob(\"*.csv\")))\n\n\ndef fill_1min_using_1d(\n    data_1min_dir: [str, Path],\n    qlib_data_1d_dir: [str, Path],\n    max_workers: int = 16,\n    date_field_name: str = \"date\",\n    symbol_field_name: str = \"symbol\",\n):\n    \"\"\"Use 1d data to fill in the missing symbols relative to 1min\n\n    Parameters\n    ----------\n    data_1min_dir: str\n        1min data dir\n    qlib_data_1d_dir: str\n        1d qlib data(bin data) dir, from: https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format\n    
max_workers: int\n        ThreadPoolExecutor(max_workers), by default 16\n    date_field_name: str\n        date field name, by default date\n    symbol_field_name: str\n        symbol field name, by default symbol\n\n    \"\"\"\n    data_1min_dir = Path(data_1min_dir).expanduser().resolve()\n    qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve()\n\n    min_date, max_date = get_date_range(data_1min_dir, max_workers, date_field_name)\n    symbols_1min = get_symbols(data_1min_dir)\n\n    qlib.init(provider_uri=str(qlib_data_1d_dir))\n    data_1d = D.features(D.instruments(\"all\"), [\"$close\"], min_date, max_date, freq=\"day\")\n\n    miss_symbols = set(data_1d.index.get_level_values(level=\"instrument\").unique()) - set(symbols_1min)\n    if not miss_symbols:\n        logger.warning(\"No symbols are missing from 1min relative to 1d, no padding required\")\n        return\n\n    logger.info(f\"missing symbols ({len(miss_symbols)}): {miss_symbols}\")\n    tmp_df = pd.read_csv(list(data_1min_dir.glob(\"*.csv\"))[0])\n    columns = tmp_df.columns\n    _si = tmp_df[symbol_field_name].first_valid_index()\n    is_lower = tmp_df.loc[_si][symbol_field_name].islower()\n    for symbol in tqdm(miss_symbols):\n        if is_lower:\n            symbol = symbol.lower()\n        index_1d = data_1d.loc(axis=0)[symbol.upper()].index\n        index_1min = generate_minutes_calendar_from_daily(index_1d)\n        index_1min.name = date_field_name\n        _df = pd.DataFrame(columns=columns, index=index_1min)\n        if date_field_name in _df.columns:\n            del _df[date_field_name]\n        _df.reset_index(inplace=True)\n        _df[symbol_field_name] = symbol\n        _df[\"paused_num\"] = 0\n        _df.to_csv(data_1min_dir.joinpath(f\"{symbol}.csv\"), index=False)\n\n\nif __name__ == \"__main__\":\n    fire.Fire(fill_1min_using_1d)\n"
  },
  {
    "path": "scripts/data_collector/contrib/fill_cn_1min_data/requirements.txt",
    "content": "fire\npandas\nloguru\ntqdm\npyqlib"
  },
  {
    "path": "scripts/data_collector/contrib/future_trading_date_collector/README.md",
    "content": "# Get future trading days\n\n> The collected calendar is used by `D.calendar(future=True)`\n\n## Requirements\n\n```bash\npip install -r requirements.txt\n```\n\n## Collector Data\n\n```bash\n# collect future trading days into <qlib_dir>/calendars/<freq>_future.txt\npython future_trading_date_collector.py --qlib_dir ~/.qlib/qlib_data/cn_data --freq day\n```\n\n## Parameters\n\n- qlib_dir: qlib data directory\n- freq: value from [`day`, `1min`], default `day`\n"
  },
  {
    "path": "scripts/data_collector/contrib/future_trading_date_collector/future_trading_date_collector.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport sys\nfrom typing import List\nfrom pathlib import Path\n\nimport fire\nimport numpy as np\nimport pandas as pd\nfrom loguru import logger\n\n# get data from baostock\nimport baostock as bs\n\nCUR_DIR = Path(__file__).resolve().parent\nsys.path.append(str(CUR_DIR.parent.parent.parent))\n\n\nfrom data_collector.utils import generate_minutes_calendar_from_daily\n\n\ndef read_calendar_from_qlib(qlib_dir: Path) -> pd.DataFrame:\n    calendar_path = qlib_dir.joinpath(\"calendars\").joinpath(\"day.txt\")\n    if not calendar_path.exists():\n        return pd.DataFrame()\n    return pd.read_csv(calendar_path, header=None)\n\n\ndef write_calendar_to_qlib(qlib_dir: Path, date_list: List[str], freq: str = \"day\"):\n    calendar_path = str(qlib_dir.joinpath(\"calendars\").joinpath(f\"{freq}_future.txt\"))\n\n    np.savetxt(calendar_path, date_list, fmt=\"%s\", encoding=\"utf-8\")\n    logger.info(f\"write future calendars success: {calendar_path}\")\n\n\ndef generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]:\n    if freq == \"day\":\n        return date_list\n    elif freq == \"1min\":\n        date_list = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist()\n        return list(map(lambda x: pd.Timestamp(x).strftime(\"%Y-%m-%d %H:%M:%S\"), date_list))\n    else:\n        raise ValueError(f\"Unsupported freq: {freq}\")\n\n\ndef future_calendar_collector(qlib_dir: [str, Path], freq: str = \"day\"):\n    \"\"\"get future calendar\n\n    Parameters\n    ----------\n    qlib_dir: str or Path\n        qlib data directory\n    freq: str\n        value from [\"day\", \"1min\"], by default day\n    \"\"\"\n    qlib_dir = Path(qlib_dir).expanduser().resolve()\n    if not qlib_dir.exists():\n        raise FileNotFoundError(str(qlib_dir))\n\n    lg = bs.login()\n    if lg.error_code != \"0\":\n        logger.error(f\"login error: 
{lg.error_msg}\")\n        return\n    # read daily calendar\n    daily_calendar = read_calendar_from_qlib(qlib_dir)\n    end_year = pd.Timestamp.now().year\n    if daily_calendar.empty:\n        start_year = pd.Timestamp.now().year\n    else:\n        start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year\n    # baostock expects date strings in YYYY-MM-DD format\n    rs = bs.query_trade_dates(start_date=f\"{start_year}-01-01\", end_date=f\"{end_year}-12-31\")\n    data_list = []\n    while (rs.error_code == \"0\") & rs.next():\n        _row_data = rs.get_row_data()\n        # _row_data: [calendar_date, is_trading_day]\n        if int(_row_data[1]) == 1:\n            data_list.append(_row_data[0])\n    data_list = sorted(data_list)\n    date_list = generate_qlib_calendar(data_list, freq=freq)\n    # merge with the existing daily calendar (if any) and de-duplicate\n    existing_dates = [] if daily_calendar.empty else daily_calendar.loc[:, 0].values.tolist()\n    date_list = sorted(set(existing_dates + date_list))\n    write_calendar_to_qlib(qlib_dir, date_list, freq=freq)\n    bs.logout()\n    logger.info(f\"get trading dates success: {start_year}-01-01 to {end_year}-12-31\")\n\n\nif __name__ == \"__main__\":\n    fire.Fire(future_calendar_collector)\n"
  },
  {
    "path": "scripts/data_collector/contrib/future_trading_date_collector/requirements.txt",
    "content": "baostock\nfire\nnumpy\npandas\nloguru\n"
  },
  {
    "path": "scripts/data_collector/crowd_source/README.md",
    "content": "# Crowd Source Data\n\n## Initiative\nPublic data sources like Yahoo are flawed: they may miss data for delisted stocks, and they may contain wrong data. Missing delisted stocks can introduce survivorship bias into our training process.\n\nThe Crowd Source Data is introduced to merge data from multiple data sources and cross-validate them against each other, so that:\n1. We have a more complete historical record.\n2. We can identify anomalous data and apply corrections when necessary.\n\n## Related Repo\nThe raw data is hosted on the dolthub repo: https://www.dolthub.com/repositories/chenditc/investment_data\n\nThe processing scripts and SQL are hosted on the github repo: https://github.com/chenditc/investment_data\n\nThe packaged docker runtime is hosted on dockerhub: https://hub.docker.com/repository/docker/chenditc/investment_data\n\n## How to use it in qlib\n### Option 1: Download release bin data\nUsers can download data in qlib bin format and use it directly: https://github.com/chenditc/investment_data/releases/latest\n```bash\nwget https://github.com/chenditc/investment_data/releases/latest/download/qlib_bin.tar.gz\ntar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2\n```\n\n### Option 2: Generate qlib data from dolthub\nDolthub data is updated daily, so users who want up-to-date data can dump the qlib bin using docker (both commands must run inside the container, hence `bash -c`):\n```\ndocker run -v /<some output directory>:/output -it --rm chenditc/investment_data bash -c \"bash dump_qlib_bin.sh && cp ./qlib_bin.tar.gz /output/\"\n```\n\n## FAQ and other info\nSee: https://github.com/chenditc/investment_data/blob/main/README.md\n"
  },
  {
    "path": "scripts/data_collector/crypto/README.md",
    "content": "# Collect Crypto Data\n\n> *Please pay **ATTENTION** that the data is collected from [Coingecko](https://www.coingecko.com/en/api) and the data might not be perfect. We recommend that users prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*\n\n## Requirements\n\n```bash\npip install -r requirements.txt\n```\n\n## Usage of the dataset\n> *The crypto dataset only supports data retrieval; backtesting is not supported due to the lack of OHLC data.*\n\n## Collector Data\n\n### Crypto Data\n\n#### 1d from Coingecko\n\n```bash\n# download from https://api.coingecko.com/api/v3/\npython collector.py download_data --source_dir ~/.qlib/crypto_data/source/1d --start 2015-01-01 --end 2021-11-30 --delay 1 --interval 1d\n\n# normalize\npython collector.py normalize_data --source_dir ~/.qlib/crypto_data/source/1d --normalize_dir ~/.qlib/crypto_data/source/1d_nor --interval 1d --date_field_name date\n\n# dump data\ncd qlib/scripts\npython dump_bin.py dump_all --data_path ~/.qlib/crypto_data/source/1d_nor --qlib_dir ~/.qlib/qlib_data/crypto_data --freq day --date_field_name date --include_fields prices,total_volumes,market_caps\n```\n\n### Using data\n\n```python\nimport qlib\nfrom qlib.data import D\n\nqlib.init(provider_uri=\"~/.qlib/qlib_data/crypto_data\")\ndf = D.features(D.instruments(market=\"all\"), [\"$prices\", \"$total_volumes\", \"$market_caps\"], freq=\"day\")\n```\n\n### Help\n```bash\npython collector.py collector_data --help\n```\n\n## Parameters\n\n- interval: 1d\n- delay: 1\n"
  },
  {
    "path": "scripts/data_collector/crypto/collector.py",
    "content": "import abc\nimport sys\nimport datetime\nfrom abc import ABC\nfrom pathlib import Path\n\nimport fire\nimport pandas as pd\nfrom loguru import logger\nfrom dateutil.tz import tzlocal\n\nCUR_DIR = Path(__file__).resolve().parent\nsys.path.append(str(CUR_DIR.parent.parent))\nfrom data_collector.base import BaseCollector, BaseNormalize, BaseRun\nfrom data_collector.utils import deco_retry\n\nfrom pycoingecko import CoinGeckoAPI\nfrom time import mktime\nfrom datetime import datetime as dt\nimport time\n\n_CG_CRYPTO_SYMBOLS = None\n\n\ndef get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:\n    \"\"\"get crypto symbols in coingecko\n\n    Returns\n    -------\n        list of crypto symbol ids from coingecko\n    \"\"\"\n    global _CG_CRYPTO_SYMBOLS  # pylint: disable=W0603\n\n    @deco_retry\n    def _get_coingecko():\n        try:\n            cg = CoinGeckoAPI()\n            resp = pd.DataFrame(cg.get_coins_markets(vs_currency=\"usd\"))\n        except Exception as e:\n            raise ValueError(\"request error\") from e\n        try:\n            _symbols = resp[\"id\"].to_list()\n        except Exception as e:\n            logger.warning(f\"request error: {e}\")\n            raise\n        return _symbols\n\n    if _CG_CRYPTO_SYMBOLS is None:\n        _all_symbols = _get_coingecko()\n\n        _CG_CRYPTO_SYMBOLS = sorted(set(_all_symbols))\n\n    return _CG_CRYPTO_SYMBOLS\n\n\nclass CryptoCollector(BaseCollector):\n    def __init__(\n        self,\n        save_dir: [str, Path],\n        start=None,\n        end=None,\n        interval=\"1d\",\n        max_workers=1,\n        max_collector_count=2,\n        delay=1,  # delay needs to be 1\n        check_data_length: int = None,\n        limit_nums: int = None,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        save_dir: str\n            crypto save dir\n        max_workers: int\n            workers, default 1\n        max_collector_count: int\n 
           default 2\n        delay: float\n            time.sleep(delay), default 1\n        interval: str\n            freq, value from [1min, 1d], default 1d\n        start: str\n            start datetime, default None\n        end: str\n            end datetime, default None\n        check_data_length: int\n            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.\n        limit_nums: int\n            used for debugging, by default None\n        \"\"\"\n        super(CryptoCollector, self).__init__(\n            save_dir=save_dir,\n            start=start,\n            end=end,\n            interval=interval,\n            max_workers=max_workers,\n            max_collector_count=max_collector_count,\n            delay=delay,\n            check_data_length=check_data_length,\n            limit_nums=limit_nums,\n        )\n\n        self.init_datetime()\n\n    def init_datetime(self):\n        if self.interval == self.INTERVAL_1min:\n            self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)\n        elif self.interval == self.INTERVAL_1d:\n            pass\n        else:\n            raise ValueError(f\"interval error: {self.interval}\")\n\n        self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)\n        self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)\n\n    @staticmethod\n    def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):\n        try:\n            dt = pd.Timestamp(dt, tz=timezone).timestamp()\n            dt = pd.Timestamp(dt, tz=tzlocal(), unit=\"s\")\n        except ValueError:\n            pass\n        return dt\n\n    @property\n    @abc.abstractmethod\n    def _timezone(self):\n        raise 
NotImplementedError(\"rewrite get_timezone\")\n\n    @staticmethod\n    def get_data_from_remote(symbol, interval, start, end):\n        error_msg = f\"{symbol}-{interval}-{start}-{end}\"\n        try:\n            cg = CoinGeckoAPI()\n            data = cg.get_coin_market_chart_by_id(id=symbol, vs_currency=\"usd\", days=\"max\")\n            _resp = pd.DataFrame(columns=[\"date\"] + list(data.keys()))\n            _resp[\"date\"] = [dt.fromtimestamp(mktime(time.localtime(x[0] / 1000))) for x in data[\"prices\"]]\n            for key in data.keys():\n                _resp[key] = [x[1] for x in data[key]]\n            _resp[\"date\"] = pd.to_datetime(_resp[\"date\"])\n            _resp[\"date\"] = [x.date() for x in _resp[\"date\"]]\n            _resp = _resp[(_resp[\"date\"] < pd.to_datetime(end).date()) & (_resp[\"date\"] > pd.to_datetime(start).date())]\n            if _resp.shape[0] != 0:\n                _resp = _resp.reset_index()\n            if isinstance(_resp, pd.DataFrame):\n                return _resp.reset_index()\n        except Exception as e:\n            logger.warning(f\"{error_msg}:{e}\")\n\n    def get_data(\n        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp\n    ) -> [pd.DataFrame]:\n        def _get_simple(start_, end_):\n            self.sleep()\n            _remote_interval = interval\n            return self.get_data_from_remote(\n                symbol,\n                interval=_remote_interval,\n                start=start_,\n                end=end_,\n            )\n\n        if interval == self.INTERVAL_1d:\n            _result = _get_simple(start_datetime, end_datetime)\n        else:\n            raise ValueError(f\"cannot support {interval}\")\n        return _result\n\n\nclass CryptoCollector1d(CryptoCollector, ABC):\n    def get_instrument_list(self):\n        logger.info(\"get coingecko crypto symbols......\")\n        symbols = get_cg_crypto_symbols()\n        logger.info(f\"get 
{len(symbols)} symbols.\")\n        return symbols\n\n    def normalize_symbol(self, symbol):\n        return symbol\n\n    @property\n    def _timezone(self):\n        return \"Asia/Shanghai\"\n\n\nclass CryptoNormalize(BaseNormalize):\n    DAILY_FORMAT = \"%Y-%m-%d\"\n\n    @staticmethod\n    def normalize_crypto(\n        df: pd.DataFrame,\n        calendar_list: list = None,\n        date_field_name: str = \"date\",\n        symbol_field_name: str = \"symbol\",\n    ):\n        if df.empty:\n            return df\n        df = df.copy()\n        df.set_index(date_field_name, inplace=True)\n        df.index = pd.to_datetime(df.index)\n        df = df[~df.index.duplicated(keep=\"first\")]\n        if calendar_list is not None:\n            df = df.reindex(\n                pd.DataFrame(index=calendar_list)\n                .loc[\n                    pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()\n                    + pd.Timedelta(hours=23, minutes=59)\n                ]\n                .index\n            )\n        df.sort_index(inplace=True)\n\n        df.index.names = [date_field_name]\n        return df.reset_index()\n\n    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:\n        df = self.normalize_crypto(df, self._calendar_list, self._date_field_name, self._symbol_field_name)\n        return df\n\n\nclass CryptoNormalize1d(CryptoNormalize):\n    def _get_calendar_list(self):\n        return None\n\n\nclass Run(BaseRun):\n    def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval=\"1d\"):\n        \"\"\"\n\n        Parameters\n        ----------\n        source_dir: str\n            The directory where the raw data collected from the Internet is saved, default \"Path(__file__).parent/source\"\n        normalize_dir: str\n            Directory for normalize data, default \"Path(__file__).parent/normalize\"\n        max_workers: int\n            Concurrent number, default is 1\n        interval: 
str\n            freq, value from [1min, 1d], default 1d\n        \"\"\"\n        super().__init__(source_dir, normalize_dir, max_workers, interval)\n\n    @property\n    def collector_class_name(self):\n        return f\"CryptoCollector{self.interval}\"\n\n    @property\n    def normalize_class_name(self):\n        return f\"CryptoNormalize{self.interval}\"\n\n    @property\n    def default_base_dir(self) -> [Path, str]:\n        return CUR_DIR\n\n    def download_data(\n        self,\n        max_collector_count=2,\n        delay=0,\n        start=None,\n        end=None,\n        check_data_length: int = None,\n        limit_nums=None,\n    ):\n        \"\"\"download data from the Internet\n\n        Parameters\n        ----------\n        max_collector_count: int\n            default 2\n        delay: float\n            time.sleep(delay), default 0\n        interval: str\n            freq, value from [1min, 1d], default 1d, currently only support 1d\n        start: str\n            start datetime, default \"2000-01-01\"\n        end: str\n            end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``\n        check_data_length: int # is this param useful?\n            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). 
By default None.\n        limit_nums: int\n            used for debugging, by default None\n\n        Examples\n        ---------\n            # get daily data\n            $ python collector.py download_data --source_dir ~/.qlib/crypto_data/source/1d --start 2015-01-01 --end 2021-11-30 --delay 1 --interval 1d\n        \"\"\"\n\n        super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)\n\n    def normalize_data(self, date_field_name: str = \"date\", symbol_field_name: str = \"symbol\"):\n        \"\"\"normalize data\n\n        Parameters\n        ----------\n        date_field_name: str\n            date field name, default date\n        symbol_field_name: str\n            symbol field name, default symbol\n\n        Examples\n        ---------\n            $ python collector.py normalize_data --source_dir ~/.qlib/crypto_data/source/1d --normalize_dir ~/.qlib/crypto_data/source/1d_nor --interval 1d --date_field_name date\n        \"\"\"\n        super(Run, self).normalize_data(date_field_name, symbol_field_name)\n\n\nif __name__ == \"__main__\":\n    fire.Fire(Run)\n"
  },
  {
    "path": "scripts/data_collector/crypto/requirement.txt",
    "content": "loguru\nfire\nrequests\nnumpy\npandas\ntqdm\nlxml\npycoingecko"
  },
  {
    "path": "scripts/data_collector/fund/README.md",
    "content": "# Collect Fund Data\n\n> *Please pay **ATTENTION** that the data is collected from [天天基金网](https://fund.eastmoney.com/) and the data might not be perfect. We recommend that users prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*\n\n## Requirements\n\n```bash\npip install -r requirements.txt\n```\n\n## Collect Data\n\n\n### CN Data\n\n#### 1d from East Money\n\n```bash\n\n# download from eastmoney.com\npython collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_data --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d\n\n# normalize\npython collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_data --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ\n\n# dump data\ncd qlib/scripts\npython dump_bin.py dump_all --data_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ\n\n```\n\n### Using Data\n\n```python\nimport qlib\nfrom qlib.data import D\n\nqlib.init(provider_uri=\"~/.qlib/qlib_data/cn_fund_data\")\ndf = D.features(D.instruments(market=\"all\"), [\"$DWJZ\", \"$LJJZ\"], freq=\"day\")\n```\n\n\n### Help\n```bash\npython collector.py download_data --help\n```\n\n## Parameters\n\n- interval: 1d\n- region: CN\n\n## Disclaimer\n\nThis project is intended for learning and research purposes only and does not constitute guidance or advice for any action. This project bears no responsibility for any dispute or controversy arising from its use.\n"
  },
  {
    "path": "scripts/data_collector/fund/collector.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport abc\nimport sys\nimport datetime\nimport json\nfrom abc import ABC\nfrom pathlib import Path\n\nimport fire\nimport requests\nimport pandas as pd\nfrom loguru import logger\nfrom dateutil.tz import tzlocal\nfrom qlib.constant import REG_CN as REGION_CN\n\nCUR_DIR = Path(__file__).resolve().parent\nsys.path.append(str(CUR_DIR.parent.parent))\nfrom data_collector.base import BaseCollector, BaseNormalize, BaseRun\nfrom data_collector.utils import get_calendar_list, get_en_fund_symbols\n\nINDEX_BENCH_URL = \"http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}\"\n\n\nclass FundCollector(BaseCollector):\n    def __init__(\n        self,\n        save_dir: [str, Path],\n        start=None,\n        end=None,\n        interval=\"1d\",\n        max_workers=4,\n        max_collector_count=2,\n        delay=0,\n        check_data_length: int = None,\n        limit_nums: int = None,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        save_dir: str\n            fund save dir\n        max_workers: int\n            workers, default 4\n        max_collector_count: int\n            default 2\n        delay: float\n            time.sleep(delay), default 0\n        interval: str\n            freq, value from [1min, 1d], default 1d\n        start: str\n            start datetime, default None\n        end: str\n            end datetime, default None\n        check_data_length: int\n            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). 
By default None.\n        limit_nums: int\n            used for debugging, by default None\n        \"\"\"\n        super(FundCollector, self).__init__(\n            save_dir=save_dir,\n            start=start,\n            end=end,\n            interval=interval,\n            max_workers=max_workers,\n            max_collector_count=max_collector_count,\n            delay=delay,\n            check_data_length=check_data_length,\n            limit_nums=limit_nums,\n        )\n\n        self.init_datetime()\n\n    def init_datetime(self):\n        if self.interval == self.INTERVAL_1min:\n            self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)\n        elif self.interval == self.INTERVAL_1d:\n            pass\n        else:\n            raise ValueError(f\"interval error: {self.interval}\")\n\n        self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)\n        self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)\n\n    @staticmethod\n    def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):\n        try:\n            dt = pd.Timestamp(dt, tz=timezone).timestamp()\n            dt = pd.Timestamp(dt, tz=tzlocal(), unit=\"s\")\n        except ValueError as e:\n            pass\n        return dt\n\n    @property\n    @abc.abstractmethod\n    def _timezone(self):\n        raise NotImplementedError(\"rewrite get_timezone\")\n\n    @staticmethod\n    def get_data_from_remote(symbol, interval, start, end):\n        error_msg = f\"{symbol}-{interval}-{start}-{end}\"\n\n        try:\n            # TODO: numberOfHistoricalDaysToCrawl should be big enough\n            url = INDEX_BENCH_URL.format(\n                index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end\n            )\n            resp = requests.get(url, headers={\"referer\": \"http://fund.eastmoney.com/110022.html\"}, timeout=None)\n\n            if resp.status_code != 200:\n 
               raise ValueError(\"request error\")\n\n            data = json.loads(resp.text.split(\"(\")[-1].split(\")\")[0])\n\n            # Some funds report earnings per N units (e.g. per 10,000 units) instead of a net value, example: http://fundf10.eastmoney.com/jjjz_010288.html\n            SYType = data[\"Data\"][\"SYType\"]\n            if SYType in {\"每万份收益\", \"每百份收益\", \"每百万份收益\"}:\n                raise ValueError(\"The fund contains 每*份收益\")\n\n            # TODO: should we sort the value by datetime?\n            _resp = pd.DataFrame(data[\"Data\"][\"LSJZList\"])\n\n            if isinstance(_resp, pd.DataFrame):\n                return _resp.reset_index()\n        except Exception as e:\n            logger.warning(f\"{error_msg}:{e}\")\n\n    def get_data(\n        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp\n    ) -> [pd.DataFrame]:\n        def _get_simple(start_, end_):\n            self.sleep()\n            _remote_interval = interval\n            return self.get_data_from_remote(\n                symbol,\n                interval=_remote_interval,\n                start=start_,\n                end=end_,\n            )\n\n        if interval == self.INTERVAL_1d:\n            _result = _get_simple(start_datetime, end_datetime)\n        else:\n            raise ValueError(f\"cannot support {interval}\")\n        return _result\n\n\nclass FundCollectorCN(FundCollector, ABC):\n    def get_instrument_list(self):\n        logger.info(\"get cn fund symbols......\")\n        symbols = get_en_fund_symbols()\n        logger.info(f\"get {len(symbols)} symbols.\")\n        return symbols\n\n    def normalize_symbol(self, symbol):\n        return symbol\n\n    @property\n    def _timezone(self):\n        return \"Asia/Shanghai\"\n\n\nclass FundCollectorCN1d(FundCollectorCN):\n    pass\n\n\nclass FundNormalize(BaseNormalize):\n    DAILY_FORMAT = \"%Y-%m-%d\"\n\n    @staticmethod\n    def normalize_fund(\n        df: pd.DataFrame,\n        calendar_list: list = 
None,\n        date_field_name: str = \"date\",\n        symbol_field_name: str = \"symbol\",\n    ):\n        if df.empty:\n            return df\n        df = df.copy()\n        df.set_index(date_field_name, inplace=True)\n        df.index = pd.to_datetime(df.index)\n        df = df[~df.index.duplicated(keep=\"first\")]\n        if calendar_list is not None:\n            df = df.reindex(\n                pd.DataFrame(index=calendar_list)\n                .loc[\n                    pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()\n                    + pd.Timedelta(hours=23, minutes=59)\n                ]\n                .index\n            )\n        df.sort_index(inplace=True)\n\n        df.index.names = [date_field_name]\n        return df.reset_index()\n\n    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:\n        # normalize\n        df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name)\n        return df\n\n\nclass FundNormalize1d(FundNormalize):\n    pass\n\n\nclass FundNormalizeCN:\n    def _get_calendar_list(self):\n        return get_calendar_list(\"ALL\")\n\n\nclass FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d):\n    pass\n\n\nclass Run(BaseRun):\n    def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval=\"1d\", region=REGION_CN):\n        \"\"\"\n\n        Parameters\n        ----------\n        source_dir: str\n            The directory where the raw data collected from the Internet is saved, default \"Path(__file__).parent/source\"\n        normalize_dir: str\n            Directory for normalize data, default \"Path(__file__).parent/normalize\"\n        max_workers: int\n            Concurrent number, default is 4\n        interval: str\n            freq, value from [1min, 1d], default 1d\n        region: str\n            region, value from [\"CN\"], default \"CN\"\n        \"\"\"\n        super().__init__(source_dir, normalize_dir, 
max_workers, interval)\n        self.region = region\n\n    @property\n    def collector_class_name(self):\n        return f\"FundCollector{self.region.upper()}{self.interval}\"\n\n    @property\n    def normalize_class_name(self):\n        return f\"FundNormalize{self.region.upper()}{self.interval}\"\n\n    @property\n    def default_base_dir(self) -> [Path, str]:\n        return CUR_DIR\n\n    def download_data(\n        self,\n        max_collector_count=2,\n        delay=0,\n        start=None,\n        end=None,\n        check_data_length: int = None,\n        limit_nums=None,\n    ):\n        \"\"\"download data from the Internet\n\n        Parameters\n        ----------\n        max_collector_count: int\n            default 2\n        delay: float\n            time.sleep(delay), default 0\n        interval: str\n            freq, value from [1min, 1d], default 1d\n        start: str\n            start datetime, default \"2000-01-01\"\n        end: str\n            end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``\n        check_data_length: int # is this param useful?\n            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). 
By default None.\n        limit_nums: int\n            used for debugging, by default None\n\n        Examples\n        ---------\n            # get daily data\n            $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_data --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d\n        \"\"\"\n\n        super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)\n\n    def normalize_data(self, date_field_name: str = \"date\", symbol_field_name: str = \"symbol\"):\n        \"\"\"normalize data\n\n        Parameters\n        ----------\n        date_field_name: str\n            date field name, default date\n        symbol_field_name: str\n            symbol field name, default symbol\n\n        Examples\n        ---------\n            $ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_data --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ\n        \"\"\"\n        super(Run, self).normalize_data(date_field_name, symbol_field_name)\n\n\nif __name__ == \"__main__\":\n    fire.Fire(Run)\n"
  },
  {
    "path": "scripts/data_collector/fund/requirements.txt",
    "content": "loguru\nfire\nrequests\nnumpy\npandas\ntqdm\nlxml\nyahooquery\n"
  },
  {
    "path": "scripts/data_collector/future_calendar_collector.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport abc\nimport importlib\nfrom pathlib import Path\nfrom typing import Union, Iterable, List\n\nimport fire\nimport numpy as np\nimport pandas as pd\n\n# pip install baostock\nimport baostock as bs\nfrom loguru import logger\n\n\nclass CollectorFutureCalendar:\n    calendar_format = \"%Y-%m-%d\"\n\n    def __init__(self, qlib_dir: Union[str, Path], start_date: str = None, end_date: str = None):\n        \"\"\"\n\n        Parameters\n        ----------\n        qlib_dir:\n            qlib data directory\n        start_date\n            start date\n        end_date\n            end date\n        \"\"\"\n        self.qlib_dir = Path(qlib_dir).expanduser().absolute()\n        self.calendar_path = self.qlib_dir.joinpath(\"calendars/day.txt\")\n        self.future_path = self.qlib_dir.joinpath(\"calendars/day_future.txt\")\n        self._calendar_list = self.calendar_list\n        _latest_date = self._calendar_list[-1]\n        self.start_date = _latest_date if start_date is None else pd.Timestamp(start_date)\n        self.end_date = _latest_date + pd.Timedelta(days=365 * 2) if end_date is None else pd.Timestamp(end_date)\n\n    @property\n    def calendar_list(self) -> List[pd.Timestamp]:\n        # load old calendar\n        if not self.calendar_path.exists():\n            raise ValueError(f\"calendar does not exist: {self.calendar_path}\")\n        calendar_df = pd.read_csv(self.calendar_path, header=None)\n        calendar_df.columns = [\"date\"]\n        calendar_df[\"date\"] = pd.to_datetime(calendar_df[\"date\"])\n        return calendar_df[\"date\"].to_list()\n\n    def _format_datetime(self, datetime_d: [str, pd.Timestamp]):\n        datetime_d = pd.Timestamp(datetime_d)\n        return datetime_d.strftime(self.calendar_format)\n\n    def write_calendar(self, calendar: Iterable):\n        calendars_list = [self._format_datetime(x) for x in 
sorted(set(self.calendar_list + calendar))]\n        np.savetxt(self.future_path, calendars_list, fmt=\"%s\", encoding=\"utf-8\")\n\n    @abc.abstractmethod\n    def collector(self) -> Iterable[pd.Timestamp]:\n        \"\"\"collect future trading dates\n\n        Returns\n        -------\n            trading dates (pd.Timestamp) between self.start_date and self.end_date\n        \"\"\"\n        raise NotImplementedError(\"Please implement the `collector` method\")\n\n\nclass CollectorFutureCalendarCN(CollectorFutureCalendar):\n    def collector(self) -> Iterable[pd.Timestamp]:\n        lg = bs.login()\n        if lg.error_code != \"0\":\n            raise ValueError(f\"login respond error_msg: {lg.error_msg}\")\n        rs = bs.query_trade_dates(\n            start_date=self._format_datetime(self.start_date), end_date=self._format_datetime(self.end_date)\n        )\n        if rs.error_code != \"0\":\n            raise ValueError(f\"query_trade_dates respond error_msg: {rs.error_msg}\")\n        data_list = []\n        while (rs.error_code == \"0\") & rs.next():\n            data_list.append(rs.get_row_data())\n        calendar = pd.DataFrame(data_list, columns=rs.fields)\n        calendar[\"is_trading_day\"] = calendar[\"is_trading_day\"].astype(int)\n        return pd.to_datetime(calendar[calendar[\"is_trading_day\"] == 1][\"calendar_date\"]).to_list()\n\n\nclass CollectorFutureCalendarUS(CollectorFutureCalendar):\n    def collector(self) -> Iterable[pd.Timestamp]:\n        # TODO: US future calendar\n        raise ValueError(\"US calendar is not supported\")\n\n\ndef run(qlib_dir: Union[str, Path], region: str = \"cn\", start_date: str = None, end_date: str = None):\n    \"\"\"Collect the future trading calendar (daily)\n\n    Parameters\n    ----------\n    qlib_dir:\n        qlib data directory\n    region:\n        cn/CN or us/US\n    start_date\n        start date\n    end_date\n        end date\n\n    Examples\n    -------\n        # get cn future calendar\n        $ python future_calendar_collector.py --qlib_dir <user data dir> --region cn\n    \"\"\"\n    
logger.info(f\"collector future calendar: region={region}\")\n    _cur_module = importlib.import_module(\"future_calendar_collector\")\n    _class = getattr(_cur_module, f\"CollectorFutureCalendar{region.upper()}\")\n    collector = _class(qlib_dir=qlib_dir, start_date=start_date, end_date=end_date)\n    collector.write_calendar(collector.collector())\n\n\nif __name__ == \"__main__\":\n    fire.Fire(run)\n"
  },
  {
    "path": "scripts/data_collector/index.py",
    "content": "import sys\nimport abc\nfrom pathlib import Path\nfrom typing import List\n\nimport pandas as pd\nfrom tqdm import tqdm\nfrom loguru import logger\n\nCUR_DIR = Path(__file__).resolve().parent\nsys.path.append(str(CUR_DIR.parent))\n\n\nfrom data_collector.utils import get_trading_date_by_shift\n\n\nclass IndexBase:\n    DEFAULT_END_DATE = pd.Timestamp(\"2099-12-31\")\n    SYMBOL_FIELD_NAME = \"symbol\"\n    DATE_FIELD_NAME = \"date\"\n    START_DATE_FIELD = \"start_date\"\n    END_DATE_FIELD = \"end_date\"\n    CHANGE_TYPE_FIELD = \"type\"\n    INSTRUMENTS_COLUMNS = [SYMBOL_FIELD_NAME, START_DATE_FIELD, END_DATE_FIELD]\n    REMOVE = \"remove\"\n    ADD = \"add\"\n    INST_PREFIX = \"\"\n\n    def __init__(\n        self,\n        index_name: str,\n        qlib_dir: [str, Path] = None,\n        freq: str = \"day\",\n        request_retry: int = 5,\n        retry_sleep: int = 3,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        index_name: str\n            index name\n        qlib_dir: str\n            qlib directory, by default Path(__file__).resolve().parent.joinpath(\"qlib_data\")\n        freq: str\n            freq, value from [\"day\", \"1min\"]\n        request_retry: int\n            request retry, by default 5\n        retry_sleep: int\n            request sleep, by default 3\n        \"\"\"\n        self.index_name = index_name\n        if qlib_dir is None:\n            qlib_dir = Path(__file__).resolve().parent.joinpath(\"qlib_data\")\n        self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath(\"instruments\")\n        self.instruments_dir.mkdir(exist_ok=True, parents=True)\n        self.cache_dir = Path(f\"~/.cache/qlib/index/{self.index_name}\").expanduser().resolve()\n        self.cache_dir.mkdir(exist_ok=True, parents=True)\n        self._request_retry = request_retry\n        self._retry_sleep = retry_sleep\n        self.freq = freq\n\n    @property\n    @abc.abstractmethod\n    def 
bench_start_date(self) -> pd.Timestamp:\n        \"\"\"\n        Returns\n        -------\n            index start date\n        \"\"\"\n        raise NotImplementedError(\"rewrite bench_start_date\")\n\n    @property\n    @abc.abstractmethod\n    def calendar_list(self) -> List[pd.Timestamp]:\n        \"\"\"get history trading date\n\n        Returns\n        -------\n            calendar list\n        \"\"\"\n        raise NotImplementedError(\"rewrite calendar_list\")\n\n    @abc.abstractmethod\n    def get_new_companies(self) -> pd.DataFrame:\n        \"\"\"\n\n        Returns\n        -------\n            pd.DataFrame:\n\n                symbol     start_date    end_date\n                SH600000   2000-01-01    2099-12-31\n\n            dtypes:\n                symbol: str\n                start_date: pd.Timestamp\n                end_date: pd.Timestamp\n        \"\"\"\n        raise NotImplementedError(\"rewrite get_new_companies\")\n\n    @abc.abstractmethod\n    def get_changes(self) -> pd.DataFrame:\n        \"\"\"get companies changes\n\n        Returns\n        -------\n            pd.DataFrame:\n                symbol      date        type\n                SH600000  2019-11-11    add\n                SH600000  2020-11-10    remove\n            dtypes:\n                symbol: str\n                date: pd.Timestamp\n                type: str, value from [\"add\", \"remove\"]\n        \"\"\"\n        raise NotImplementedError(\"rewrite get_changes\")\n\n    @abc.abstractmethod\n    def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"formatting the datetime in an instrument\n\n        Parameters\n        ----------\n        inst_df: pd.DataFrame\n            inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]\n\n        Returns\n        -------\n\n        \"\"\"\n        raise NotImplementedError(\"rewrite format_datetime\")\n\n    def save_new_companies(self):\n        \"\"\"save new 
companies\n\n        Examples\n        -------\n            $ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data\n        \"\"\"\n        df = self.get_new_companies()\n        if df is None or df.empty:\n            raise ValueError(f\"get new companies error: {self.index_name}\")\n        df = df.drop_duplicates([self.SYMBOL_FIELD_NAME])\n        df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv(\n            self.instruments_dir.joinpath(f\"{self.index_name.lower()}_only_new.txt\"), sep=\"\\t\", index=False, header=None\n        )\n\n    def get_changes_with_history_companies(self, history_companies: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"get changes with history companies\n\n        Parameters\n        ----------\n        history_companies : pd.DataFrame\n            symbol        date\n            SH600000   2020-11-11\n\n            dtypes:\n                symbol: str\n                date: pd.Timestamp\n\n        Return\n        --------\n            pd.DataFrame:\n                symbol      date        type\n                SH600000  2019-11-11    add\n                SH600000  2020-11-10    remove\n            dtypes:\n                symbol: str\n                date: pd.Timestamp\n                type: str, value from [\"add\", \"remove\"]\n\n        \"\"\"\n        logger.info(\"parse changes from history companies......\")\n        last_code = []\n        result_df_list = []\n        _columns = [self.DATE_FIELD_NAME, self.SYMBOL_FIELD_NAME, self.CHANGE_TYPE_FIELD]\n        for _trading_date in tqdm(sorted(history_companies[self.DATE_FIELD_NAME].unique(), reverse=True)):\n            _currenet_code = history_companies[history_companies[self.DATE_FIELD_NAME] == _trading_date][\n                self.SYMBOL_FIELD_NAME\n            ].tolist()\n            if last_code:\n                add_code = list(set(last_code) - set(_currenet_code))\n                remote_code = list(set(_currenet_code) - 
set(last_code))\n                for _code in add_code:\n                    result_df_list.append(\n                        pd.DataFrame(\n                            [[get_trading_date_by_shift(self.calendar_list, _trading_date, 1), _code, self.ADD]],\n                            columns=_columns,\n                        )\n                    )\n                for _code in remote_code:\n                    result_df_list.append(\n                        pd.DataFrame(\n                            [[get_trading_date_by_shift(self.calendar_list, _trading_date, 0), _code, self.REMOVE]],\n                            columns=_columns,\n                        )\n                    )\n            last_code = _currenet_code\n        df = pd.concat(result_df_list)\n        logger.info(\"end of parse changes from history companies.\")\n        return df\n\n    def parse_instruments(self):\n        \"\"\"parse instruments, eg: csi300.txt\n\n        Examples\n        -------\n            $ python collector.py parse_instruments --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data\n        \"\"\"\n        logger.info(f\"start parse {self.index_name.lower()} companies.....\")\n        instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]\n        changers_df = self.get_changes()\n        new_df = self.get_new_companies()\n        if new_df is None or new_df.empty:\n            raise ValueError(f\"get new companies error: {self.index_name}\")\n        new_df = new_df.copy()\n        logger.info(\"parse history companies by changes......\")\n        for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)):\n            if _row.type == self.ADD:\n                min_end_date = new_df.loc[new_df[self.SYMBOL_FIELD_NAME] == _row.symbol, self.END_DATE_FIELD].min()\n                new_df.loc[\n                    (new_df[self.END_DATE_FIELD] == min_end_date) & (new_df[self.SYMBOL_FIELD_NAME] 
== _row.symbol),\n                    self.START_DATE_FIELD,\n                ] = _row.date\n            else:\n                _tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns)\n                new_df = pd.concat([new_df, _tmp_df], sort=False)\n\n        inst_df = new_df.loc[:, instruments_columns]\n        _inst_prefix = self.INST_PREFIX.strip()\n        if _inst_prefix:\n            inst_df[\"save_inst\"] = inst_df[self.SYMBOL_FIELD_NAME].apply(lambda x: f\"{_inst_prefix}{x}\")\n        inst_df = self.format_datetime(inst_df)\n        inst_df.to_csv(\n            self.instruments_dir.joinpath(f\"{self.index_name.lower()}.txt\"), sep=\"\\t\", index=False, header=None\n        )\n        logger.info(f\"parse {self.index_name.lower()} companies finished.\")\n"
  },
  {
    "path": "scripts/data_collector/pit/README.md",
    "content": "# Collect Point-in-Time Data\n\n> *Please pay **ATTENTION** that the data is collected from [baostock](http://baostock.com) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*\n\n## Requirements\n\n```bash\npip install -r requirements.txt\n```\n\n## Collector Data\n\n\n### Download Quarterly CN Data\n\n```bash\ncd qlib/scripts/data_collector/pit/\n# download from baostock.com\npython collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly\n```\n\nDownloading all data from the stock is very time-consuming. If you just want to run a quick test on a few stocks,  you can run the command below\n```bash\npython collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex \"^(600519|000725).*\"\n```\n\n\n### Normalize Data\n```bash\npython collector.py normalize_data --interval quarterly --source_dir ~/.qlib/stock_data/source/pit --normalize_dir ~/.qlib/stock_data/source/pit_normalized\n```\n\n\n\n### Dump Data into PIT Format\n\n```bash\ncd qlib/scripts\npython dump_pit.py dump --data_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly\n```\n"
  },
  {
    "path": "scripts/data_collector/pit/collector.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport re\nimport sys\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import List, Iterable, Optional, Union\n\nimport fire\nimport pandas as pd\nimport baostock as bs\nfrom loguru import logger\n\nBASE_DIR = Path(__file__).resolve().parent\nsys.path.append(str(BASE_DIR.parent.parent))\n\nfrom data_collector.base import BaseCollector, BaseRun, BaseNormalize\nfrom data_collector.utils import get_hs_stock_symbols, get_calendar_list\n\n\nclass PitCollector(BaseCollector):\n    DEFAULT_START_DATETIME_QUARTERLY = pd.Timestamp(\"2000-01-01\")\n    DEFAULT_START_DATETIME_ANNUAL = pd.Timestamp(\"2000-01-01\")\n    DEFAULT_END_DATETIME_QUARTERLY = pd.Timestamp(datetime.now() + pd.Timedelta(days=1))\n    DEFAULT_END_DATETIME_ANNUAL = pd.Timestamp(datetime.now() + pd.Timedelta(days=1))\n\n    INTERVAL_QUARTERLY = \"quarterly\"\n    INTERVAL_ANNUAL = \"annual\"\n\n    def __init__(\n        self,\n        save_dir: Union[str, Path],\n        start: Optional[str] = None,\n        end: Optional[str] = None,\n        interval: str = \"quarterly\",\n        max_workers: int = 1,\n        max_collector_count: int = 1,\n        delay: int = 0,\n        check_data_length: bool = False,\n        limit_nums: Optional[int] = None,\n        symbol_regex: Optional[str] = None,\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        save_dir: str\n            instrument save dir\n        max_workers: int\n            workers, default 1; Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1\n        max_collector_count: int\n            default 2\n        delay: float\n            time.sleep(delay), default 0\n        interval: str\n            freq, value from [1min, 1d], default 1d\n        start: str\n            start datetime, default None\n        end: str\n            end datetime, default None\n        
check_data_length: int\n            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.\n        limit_nums: int\n            using for debug, by default None\n        symbol_regex: str\n            symbol regular expression, by default None.\n        \"\"\"\n        self.symbol_regex = symbol_regex\n        super().__init__(\n            save_dir=save_dir,\n            start=start,\n            end=end,\n            interval=interval,\n            max_workers=max_workers,\n            max_collector_count=max_collector_count,\n            delay=delay,\n            check_data_length=check_data_length,\n            limit_nums=limit_nums,\n        )\n\n    def get_instrument_list(self) -> List[str]:\n        logger.info(\"get cn stock symbols......\")\n        symbols = get_hs_stock_symbols()\n        if self.symbol_regex is not None:\n            regex_compile = re.compile(self.symbol_regex)\n            symbols = [symbol for symbol in symbols if regex_compile.match(symbol)]\n        logger.info(f\"get {len(symbols)} symbols.\")\n        return symbols\n\n    def normalize_symbol(self, symbol: str) -> str:\n        symbol, exchange = symbol.split(\".\")\n        exchange = \"sh\" if exchange == \"ss\" else \"sz\"\n        return f\"{exchange}{symbol}\"\n\n    @staticmethod\n    def get_performance_express_report_df(code: str, start_date: str, end_date: str) -> pd.DataFrame:\n        column_mapping = {\n            \"performanceExpPubDate\": \"date\",\n            \"performanceExpStatDate\": \"period\",\n            \"performanceExpressROEWa\": \"value\",\n        }\n\n        resp = bs.query_performance_express_report(code=code, start_date=start_date, end_date=end_date)\n        report_list = []\n        while (resp.error_code == \"0\") and resp.next():\n       
     report_list.append(resp.get_row_data())\n        report_df = pd.DataFrame(report_list, columns=resp.fields)\n        try:\n            report_df = report_df[list(column_mapping.keys())]\n        except KeyError:\n            return pd.DataFrame()\n        report_df.rename(columns=column_mapping, inplace=True)\n        report_df[\"field\"] = \"roeWa\"\n        report_df[\"value\"] = pd.to_numeric(report_df[\"value\"], errors=\"ignore\")\n        report_df[\"value\"] = report_df[\"value\"].apply(lambda x: x / 100.0)\n        return report_df\n\n    @staticmethod\n    def get_profit_df(code: str, start_date: str, end_date: str) -> pd.DataFrame:\n        column_mapping = {\"pubDate\": \"date\", \"statDate\": \"period\", \"roeAvg\": \"value\"}\n        fields = bs.query_profit_data(code=\"sh.600519\", year=2020, quarter=1).fields\n        start_date = datetime.strptime(start_date, \"%Y-%m-%d\")\n        end_date = datetime.strptime(end_date, \"%Y-%m-%d\")\n        args = [(year, quarter) for quarter in range(1, 5) for year in range(start_date.year - 1, end_date.year + 1)]\n        profit_list = []\n        for year, quarter in args:\n            resp = bs.query_profit_data(code=code, year=year, quarter=quarter)\n            while (resp.error_code == \"0\") and resp.next():\n                if \"pubDate\" not in resp.fields:\n                    continue\n                row_data = resp.get_row_data()\n                pub_date = pd.Timestamp(row_data[resp.fields.index(\"pubDate\")])\n                if start_date <= pub_date <= end_date and row_data:\n                    profit_list.append(row_data)\n        profit_df = pd.DataFrame(profit_list, columns=fields)\n        try:\n            profit_df = profit_df[list(column_mapping.keys())]\n        except KeyError:\n            return pd.DataFrame()\n        profit_df.rename(columns=column_mapping, inplace=True)\n        profit_df[\"field\"] = \"roeWa\"\n        profit_df[\"value\"] = 
pd.to_numeric(profit_df[\"value\"], errors=\"ignore\")\n        return profit_df\n\n    @staticmethod\n    def get_forecast_report_df(code: str, start_date: str, end_date: str) -> pd.DataFrame:\n        column_mapping = {\n            \"profitForcastExpPubDate\": \"date\",\n            \"profitForcastExpStatDate\": \"period\",\n            \"value\": \"value\",\n        }\n        resp = bs.query_forecast_report(code=code, start_date=start_date, end_date=end_date)\n        forecast_list = []\n        while (resp.error_code == \"0\") and resp.next():\n            forecast_list.append(resp.get_row_data())\n        forecast_df = pd.DataFrame(forecast_list, columns=resp.fields)\n        numeric_fields = [\"profitForcastChgPctUp\", \"profitForcastChgPctDwn\"]\n        try:\n            forecast_df[numeric_fields] = forecast_df[numeric_fields].apply(pd.to_numeric, errors=\"ignore\")\n        except KeyError:\n            return pd.DataFrame()\n        forecast_df[\"value\"] = (forecast_df[\"profitForcastChgPctUp\"] + forecast_df[\"profitForcastChgPctDwn\"]) / 200\n        forecast_df = forecast_df[list(column_mapping.keys())]\n        forecast_df.rename(columns=column_mapping, inplace=True)\n        forecast_df[\"field\"] = \"YOYNI\"\n        return forecast_df\n\n    @staticmethod\n    def get_growth_df(code: str, start_date: str, end_date: str) -> pd.DataFrame:\n        column_mapping = {\"pubDate\": \"date\", \"statDate\": \"period\", \"YOYNI\": \"value\"}\n        fields = bs.query_growth_data(code=\"sh.600519\", year=2020, quarter=1).fields\n        start_date = datetime.strptime(start_date, \"%Y-%m-%d\")\n        end_date = datetime.strptime(end_date, \"%Y-%m-%d\")\n        args = [(year, quarter) for quarter in range(1, 5) for year in range(start_date.year - 1, end_date.year + 1)]\n        growth_list = []\n        for year, quarter in args:\n            resp = bs.query_growth_data(code=code, year=year, quarter=quarter)\n            while (resp.error_code == 
\"0\") and resp.next():\n                if \"pubDate\" not in resp.fields:\n                    continue\n                row_data = resp.get_row_data()\n                pub_date = pd.Timestamp(row_data[resp.fields.index(\"pubDate\")])\n                if start_date <= pub_date <= end_date and row_data:\n                    growth_list.append(row_data)\n        growth_df = pd.DataFrame(growth_list, columns=fields)\n        try:\n            growth_df = growth_df[list(column_mapping.keys())]\n        except KeyError:\n            return pd.DataFrame()\n        growth_df.rename(columns=column_mapping, inplace=True)\n        growth_df[\"field\"] = \"YOYNI\"\n        growth_df[\"value\"] = pd.to_numeric(growth_df[\"value\"], errors=\"ignore\")\n        return growth_df\n\n    def get_data(\n        self,\n        symbol: str,\n        interval: str,\n        start_datetime: pd.Timestamp,\n        end_datetime: pd.Timestamp,\n    ) -> pd.DataFrame:\n        if interval != self.INTERVAL_QUARTERLY:\n            raise ValueError(f\"cannot support {interval}\")\n        symbol, exchange = symbol.split(\".\")\n        exchange = \"sh\" if exchange == \"ss\" else \"sz\"\n        code = f\"{exchange}.{symbol}\"\n        start_date = start_datetime.strftime(\"%Y-%m-%d\")\n        end_date = end_datetime.strftime(\"%Y-%m-%d\")\n\n        performance_express_report_df = self.get_performance_express_report_df(code, start_date, end_date)\n        profit_df = self.get_profit_df(code, start_date, end_date)\n        forecast_report_df = self.get_forecast_report_df(code, start_date, end_date)\n        growth_df = self.get_growth_df(code, start_date, end_date)\n\n        df = pd.concat(\n            [performance_express_report_df, profit_df, forecast_report_df, growth_df],\n            axis=0,\n        )\n        return df\n\n\nclass PitNormalize(BaseNormalize):\n    def __init__(self, interval: str = \"quarterly\", *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        
self.interval = interval\n\n    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:\n        dt = df[\"period\"].apply(\n            lambda x: (\n                pd.to_datetime(x) + pd.DateOffset(days=(45 if self.interval == PitCollector.INTERVAL_QUARTERLY else 90))\n            ).date()\n        )\n        df[\"date\"] = df[\"date\"].fillna(dt.astype(str))\n\n        df[\"period\"] = pd.to_datetime(df[\"period\"])\n        df[\"period\"] = df[\"period\"].apply(\n            lambda x: x.year if self.interval == PitCollector.INTERVAL_ANNUAL else x.year * 100 + (x.month - 1) // 3 + 1\n        )\n        return df\n\n    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:\n        return get_calendar_list()\n\n\nclass Run(BaseRun):\n    @property\n    def collector_class_name(self) -> str:\n        return f\"PitCollector\"\n\n    @property\n    def normalize_class_name(self) -> str:\n        return f\"PitNormalize\"\n\n    @property\n    def default_base_dir(self) -> [Path, str]:\n        return BASE_DIR\n\n\nif __name__ == \"__main__\":\n    bs.login()\n    fire.Fire(Run)\n    bs.logout()\n"
  },
  {
    "path": "scripts/data_collector/pit/requirements.txt",
    "content": "loguru\nfire\ntqdm\nrequests\npandas\nlxml\nloguru\nbaostock\nyahooquery\nbeautifulsoup4\n"
  },
  {
    "path": "scripts/data_collector/us_index/README.md",
    "content": "# NASDAQ100/SP500/SP400/DJIA History Companies Collection\n\n## Requirements\n\n```bash\npip install -r requirements.txt\n```\n\n## Collector Data\n\n```bash\n# parse instruments, using in qlib/instruments.\npython collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method parse_instruments\n\n# parse new companies\npython collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method save_new_companies\n\n# index_name support: SP500, NASDAQ100, DJIA, SP400\n# help\npython collector.py --help\n```\n\n"
  },
  {
    "path": "scripts/data_collector/us_index/collector.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport abc\nfrom functools import partial\nimport sys\nfrom pathlib import Path\nfrom concurrent.futures import ThreadPoolExecutor\nfrom typing import List\nfrom io import StringIO\n\nimport fire\nimport requests\nimport pandas as pd\nfrom tqdm import tqdm\nfrom loguru import logger\nfrom fake_useragent import UserAgent\n\nCUR_DIR = Path(__file__).resolve().parent\nsys.path.append(str(CUR_DIR.parent.parent))\n\nfrom data_collector.index import IndexBase\nfrom data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift\nfrom data_collector.utils import get_instruments\n\nWIKI_URL = \"https://en.wikipedia.org/wiki\"\n\nWIKI_INDEX_NAME_MAP = {\n    \"NASDAQ100\": \"NASDAQ-100\",\n    \"SP500\": \"List_of_S%26P_500_companies\",\n    \"SP400\": \"List_of_S%26P_400_companies\",\n    \"DJIA\": \"Dow_Jones_Industrial_Average\",\n}\n\n\nclass WIKIIndex(IndexBase):\n    # NOTE: The US stock code contains \"PRN\", and the directory cannot be created on Windows system, use the \"_\" prefix\n    # https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows\n    INST_PREFIX = \"\"\n\n    def __init__(\n        self,\n        index_name: str,\n        qlib_dir: [str, Path] = None,\n        freq: str = \"day\",\n        request_retry: int = 5,\n        retry_sleep: int = 3,\n    ):\n        super(WIKIIndex, self).__init__(\n            index_name=index_name, qlib_dir=qlib_dir, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep\n        )\n\n        self._target_url = f\"{WIKI_URL}/{WIKI_INDEX_NAME_MAP[self.index_name.upper()]}\"\n        self._ua = UserAgent()\n\n    @property\n    @abc.abstractmethod\n    def bench_start_date(self) -> pd.Timestamp:\n        \"\"\"\n        Returns\n        -------\n            index start date\n        \"\"\"\n        raise NotImplementedError(\"rewrite bench_start_date\")\n\n    
@abc.abstractmethod\n    def get_changes(self) -> pd.DataFrame:\n        \"\"\"get companies changes\n\n        Returns\n        -------\n            pd.DataFrame:\n                symbol      date        type\n                SH600000  2019-11-11    add\n                SH600000  2020-11-10    remove\n            dtypes:\n                symbol: str\n                date: pd.Timestamp\n                type: str, value from [\"add\", \"remove\"]\n        \"\"\"\n        raise NotImplementedError(\"rewrite get_changes\")\n\n    def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"formatting the datetime in an instrument\n\n        Parameters\n        ----------\n        inst_df: pd.DataFrame\n            inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]\n\n        Returns\n        -------\n\n        \"\"\"\n        if self.freq != \"day\":\n            inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply(\n                lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=23, minutes=59)).strftime(\"%Y-%m-%d %H:%M:%S\")\n            )\n        return inst_df\n\n    @property\n    def calendar_list(self) -> List[pd.Timestamp]:\n        \"\"\"get history trading date\n\n        Returns\n        -------\n            calendar list\n        \"\"\"\n        _calendar_list = getattr(self, \"_calendar_list\", None)\n        if _calendar_list is None:\n            _calendar_list = list(filter(lambda x: x >= self.bench_start_date, get_calendar_list(\"US_ALL\")))\n            setattr(self, \"_calendar_list\", _calendar_list)\n        return _calendar_list\n\n    def _request_new_companies(self) -> requests.Response:\n        headers = {\"User-Agent\": self._ua.random}\n        resp = requests.get(self._target_url, timeout=None, headers=headers)\n        if resp.status_code != 200:\n            raise ValueError(f\"request error: {self._target_url}\")\n\n        return resp\n\n    def 
set_default_date_range(self, df: pd.DataFrame) -> pd.DataFrame:\n        _df = df.copy()\n        _df[self.SYMBOL_FIELD_NAME] = _df[self.SYMBOL_FIELD_NAME].str.strip()\n        _df[self.START_DATE_FIELD] = self.bench_start_date\n        _df[self.END_DATE_FIELD] = self.DEFAULT_END_DATE\n        return _df.loc[:, self.INSTRUMENTS_COLUMNS]\n\n    def get_new_companies(self):\n        logger.info(f\"get new companies {self.index_name} ......\")\n        _data = deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)(self._request_new_companies)()\n        df_list = pd.read_html(StringIO(_data.text))\n        for _df in df_list:\n            _df = self.filter_df(_df)\n            if (_df is not None) and (not _df.empty):\n                _df.columns = [self.SYMBOL_FIELD_NAME]\n                _df = self.set_default_date_range(_df)\n                logger.info(f\"end of get new companies {self.index_name} ......\")\n                return _df\n\n    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:\n        raise NotImplementedError(\"rewrite filter_df\")\n\n\nclass NASDAQ100Index(WIKIIndex):\n    HISTORY_COMPANIES_URL = (\n        \"https://indexes.nasdaqomx.com/Index/WeightingData?id=NDX&tradeDate={trade_date}T00%3A00%3A00.000&timeOfDay=SOD\"\n    )\n    MAX_WORKERS = 16\n\n    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:\n        if len(df) >= 100 and \"Ticker\" in df.columns:\n            return df.loc[:, [\"Ticker\"]].copy()\n\n    @property\n    def bench_start_date(self) -> pd.Timestamp:\n        return pd.Timestamp(\"2003-01-02\")\n\n    @deco_retry\n    def _request_history_companies(self, trade_date: pd.Timestamp, use_cache: bool = True) -> pd.DataFrame:\n        trade_date = trade_date.strftime(\"%Y-%m-%d\")\n        cache_path = self.cache_dir.joinpath(f\"{trade_date}_history_companies.pkl\")\n        if cache_path.exists() and use_cache:\n            df = pd.read_pickle(cache_path)\n        else:\n            url = 
self.HISTORY_COMPANIES_URL.format(trade_date=trade_date)\n            resp = requests.post(url, timeout=None)\n            if resp.status_code != 200:\n                raise ValueError(f\"request error: {url}\")\n            df = pd.DataFrame(resp.json()[\"aaData\"])\n            df[self.DATE_FIELD_NAME] = trade_date\n            df.rename(columns={\"Name\": \"name\", \"Symbol\": self.SYMBOL_FIELD_NAME}, inplace=True)\n            if not df.empty:\n                df.to_pickle(cache_path)\n        return df\n\n    def get_history_companies(self):\n        logger.info(f\"start get history companies......\")\n        all_history = []\n        error_list = []\n        with tqdm(total=len(self.calendar_list)) as p_bar:\n            with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:\n                for _trading_date, _df in zip(\n                    self.calendar_list, executor.map(self._request_history_companies, self.calendar_list)\n                ):\n                    if _df.empty:\n                        error_list.append(_trading_date)\n                    else:\n                        all_history.append(_df)\n                    p_bar.update()\n\n        if error_list:\n            logger.warning(f\"get error: {error_list}\")\n        logger.info(f\"total {len(self.calendar_list)}, error {len(error_list)}\")\n        logger.info(f\"end of get history companies.\")\n        return pd.concat(all_history, sort=False)\n\n    def get_changes(self):\n        return self.get_changes_with_history_companies(self.get_history_companies())\n\n\nclass DJIAIndex(WIKIIndex):\n    @property\n    def bench_start_date(self) -> pd.Timestamp:\n        return pd.Timestamp(\"2000-01-01\")\n\n    def get_changes(self) -> pd.DataFrame:\n        pass\n\n    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:\n        if \"Symbol\" in df.columns:\n            _df = df.loc[:, [\"Symbol\"]].copy()\n            _df[\"Symbol\"] = _df[\"Symbol\"].apply(lambda x: 
x.split(\":\")[-1])\n            return _df\n\n    def parse_instruments(self):\n        logger.warning(f\"No suitable data source has been found!\")\n\n\nclass SP500Index(WIKIIndex):\n    WIKISP500_CHANGES_URL = \"https://en.wikipedia.org/wiki/List_of_S%26P_500_companies\"\n\n    @property\n    def bench_start_date(self) -> pd.Timestamp:\n        return pd.Timestamp(\"1999-01-01\")\n\n    def get_changes(self) -> pd.DataFrame:\n        logger.info(f\"get sp500 history changes......\")\n        # NOTE: may update the index of the table\n        # Add headers to avoid 403 Forbidden error from Wikipedia\n        headers = {\"User-Agent\": self._ua.random}\n        response = requests.get(self.WIKISP500_CHANGES_URL, headers=headers, timeout=None)\n        response.raise_for_status()\n        changes_df = pd.read_html(StringIO(response.text))[-1]\n        changes_df = changes_df.iloc[:, [0, 1, 3]]\n        changes_df.columns = [self.DATE_FIELD_NAME, self.ADD, self.REMOVE]\n        changes_df[self.DATE_FIELD_NAME] = pd.to_datetime(changes_df[self.DATE_FIELD_NAME])\n        _result = []\n        for _type in [self.ADD, self.REMOVE]:\n            _df = changes_df.copy()\n            _df[self.CHANGE_TYPE_FIELD] = _type\n            _df[self.SYMBOL_FIELD_NAME] = _df[_type]\n            _df.dropna(subset=[self.SYMBOL_FIELD_NAME], inplace=True)\n            if _type == self.ADD:\n                _df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply(\n                    lambda x: get_trading_date_by_shift(self.calendar_list, x, 0)\n                )\n            else:\n                _df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply(\n                    lambda x: get_trading_date_by_shift(self.calendar_list, x, -1)\n                )\n            _result.append(_df[[self.DATE_FIELD_NAME, self.CHANGE_TYPE_FIELD, self.SYMBOL_FIELD_NAME]])\n        logger.info(f\"end of get sp500 history changes.\")\n        return pd.concat(_result, sort=False)\n\n    def 
filter_df(self, df: pd.DataFrame) -> pd.DataFrame:\n        if \"Symbol\" in df.columns:\n            return df.loc[:, [\"Symbol\"]].copy()\n\n\nclass SP400Index(WIKIIndex):\n    @property\n    def bench_start_date(self) -> pd.Timestamp:\n        return pd.Timestamp(\"2000-01-01\")\n\n    def get_changes(self) -> pd.DataFrame:\n        pass\n\n    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:\n        if \"Ticker symbol\" in df.columns:\n            return df.loc[:, [\"Ticker symbol\"]].copy()\n\n    def parse_instruments(self):\n        logger.warning(f\"No suitable data source has been found!\")\n\n\nif __name__ == \"__main__\":\n    fire.Fire(partial(get_instruments, market_index=\"us_index\"))\n"
  },
  {
    "path": "scripts/data_collector/us_index/requirements.txt",
    "content": "fire\nrequests\npandas\nlxml\nloguru\nfake-useragent\n"
  },
  {
    "path": "scripts/data_collector/utils.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nimport re\nimport copy\nimport importlib\nimport time\nimport bisect\nimport pickle\nimport requests\nimport functools\nfrom pathlib import Path\nfrom typing import Iterable, Tuple, List\n\nimport numpy as np\nimport pandas as pd\nfrom loguru import logger\nfrom yahooquery import Ticker\nfrom tqdm import tqdm\nfrom functools import partial\nfrom concurrent.futures import ProcessPoolExecutor\nfrom bs4 import BeautifulSoup\n\nfrom qlib.utils.pickle_utils import restricted_pickle_load\n\nHS_SYMBOLS_URL = \"http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}\"\n\nCALENDAR_URL_BASE = \"http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20991231\"\nSZSE_CALENDAR_URL = \"http://www.szse.cn/api/report/exchange/onepersistenthour/monthList?month={month}&random={random}\"\n\nCALENDAR_BENCH_URL_MAP = {\n    \"CSI300\": CALENDAR_URL_BASE.format(market=1, bench_code=\"000300\"),\n    \"CSI500\": CALENDAR_URL_BASE.format(market=1, bench_code=\"000905\"),\n    \"CSI100\": CALENDAR_URL_BASE.format(market=1, bench_code=\"000903\"),\n    # NOTE: Use the time series of SH600000 as the sequence of all stocks\n    \"ALL\": CALENDAR_URL_BASE.format(market=1, bench_code=\"000905\"),\n    # NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks\n    \"US_ALL\": \"^GSPC\",\n    \"IN_ALL\": \"^NSEI\",\n    \"BR_ALL\": \"^BVSP\",\n}\n\n_BENCH_CALENDAR_LIST = None\n_ALL_CALENDAR_LIST = None\n_HS_SYMBOLS = None\n_US_SYMBOLS = None\n_IN_SYMBOLS = None\n_BR_SYMBOLS = None\n_EN_FUND_SYMBOLS = None\n_CALENDAR_MAP = {}\n\n# NOTE: Until 2020-10-20 20:00:00\nMINIMUM_SYMBOLS_NUM = 3900\n\n\ndef get_calendar_list(bench_code=\"CSI300\") -> List[pd.Timestamp]:\n    \"\"\"get SH/SZ history calendar list\n\n    Parameters\n    
----------\n    bench_code: str\n        value from [\"CSI300\", \"CSI500\", \"CSI100\", \"ALL\", \"US_ALL\", \"IN_ALL\", \"BR_ALL\"]\n\n    Returns\n    -------\n        history calendar list\n    \"\"\"\n\n    logger.info(f\"get calendar list: {bench_code}......\")\n\n    def _get_calendar(url):\n        _value_list = requests.get(url, timeout=None).json()[\"data\"][\"klines\"]\n        return sorted(map(lambda x: pd.Timestamp(x.split(\",\")[0]), _value_list))\n\n    calendar = _CALENDAR_MAP.get(bench_code, None)\n    if calendar is None:\n        if bench_code.startswith(\"US_\") or bench_code.startswith(\"IN_\") or bench_code.startswith(\"BR_\"):\n            df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval=\"1d\", period=\"max\")\n            calendar = df.index.get_level_values(level=\"date\").map(pd.Timestamp).unique().tolist()\n        else:\n            if bench_code.upper() == \"ALL\":\n                import akshare as ak  # pylint: disable=C0415\n\n                trade_date_df = ak.tool_trade_date_hist_sina()\n                trade_date_list = trade_date_df[\"trade_date\"].tolist()\n                trade_date_list = [pd.Timestamp(d) for d in trade_date_list]\n                dates = pd.DatetimeIndex(trade_date_list)\n                filtered_dates = dates[(dates >= \"2000-01-04\") & (dates <= pd.Timestamp.today().normalize())]\n                calendar = filtered_dates.tolist()\n            else:\n                calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])\n        _CALENDAR_MAP[bench_code] = calendar\n    logger.info(f\"end of get calendar list: {bench_code}.\")\n    return calendar\n\n\ndef return_date_list(date_field_name: str, file_path: Path):\n    date_list = pd.read_csv(file_path, sep=\",\", index_col=0)[date_field_name].to_list()\n    return sorted([pd.Timestamp(x) for x in 
date_list])\n\n\ndef get_calendar_list_by_ratio(\n    source_dir: [str, Path],\n    date_field_name: str = \"date\",\n    threshold: float = 0.5,\n    minimum_count: int = 10,\n    max_workers: int = 16,\n) -> list:\n    \"\"\"get calendar list, excluding the dates on which only a few funds trade\n\n    Parameters\n    ----------\n    source_dir: str or Path\n        The directory where the raw data collected from the Internet is saved\n    date_field_name: str\n        date field name, default is date\n    threshold: float\n        a date is kept only if the number of funds trading on it is at least threshold times the number of funds founded by that date, default 0.5\n    minimum_count: int\n        minimum number of funds that must trade on a date for it to be kept, default 10\n    max_workers: int\n        Concurrent number, default is 16\n\n    Returns\n    -------\n        history calendar list\n    \"\"\"\n    logger.info(f\"get calendar list from {source_dir} by threshold = {threshold}......\")\n\n    source_dir = Path(source_dir).expanduser()\n    file_list = list(source_dir.glob(\"*.csv\"))\n\n    _number_all_funds = len(file_list)\n\n    logger.info(f\"count how many funds trade on each day......\")\n    _dict_count_trade = dict()  # dict{date:count}\n    _fun = partial(return_date_list, date_field_name)\n    all_oldest_list = []\n    with tqdm(total=_number_all_funds) as p_bar:\n        with ProcessPoolExecutor(max_workers=max_workers) as executor:\n            for date_list in executor.map(_fun, file_list):\n                if date_list:\n                    all_oldest_list.append(date_list[0])\n                for date in date_list:\n                    if date not in _dict_count_trade:\n                        _dict_count_trade[date] = 0\n\n                    _dict_count_trade[date] += 1\n\n                p_bar.update()\n\n    logger.info(f\"count how many funds have been founded by each day......\")\n    _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade}  # dict{date:count}\n    with 
tqdm(total=_number_all_funds) as p_bar:\n        for oldest_date in all_oldest_list:\n            for date in _dict_count_founding.keys():\n                if date < oldest_date:\n                    _dict_count_founding[date] -= 1\n            p_bar.update()\n\n    # keep a date only if enough of the funds already founded by that date trade on it\n    calendar = [\n        date\n        for date, count in _dict_count_trade.items()\n        if count >= max(int(_dict_count_founding[date] * threshold), minimum_count)\n    ]\n\n    return sorted(calendar)\n\n\ndef get_hs_stock_symbols() -> list:\n    \"\"\"get SH/SZ stock symbols\n\n    Returns\n    -------\n        stock symbols\n    \"\"\"\n    global _HS_SYMBOLS  # pylint: disable=W0603\n\n    def _get_symbol():\n        \"\"\"\n        Get the stock pool from a web page and process it into the format required by yahooquery.\n        Format of data retrieved from the web page: 600519, 000001\n        The data format required by yahooquery: 600519.ss, 000001.sz\n\n        Returns\n        -------\n            set: Returns the set of symbol codes.\n\n        Examples:\n        -------\n            {600000.ss, 600001.ss, 600002.ss, 600003.ss, ...}\n        \"\"\"\n        # url = \"http://99.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&po=1&np=1&fs=m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048&fields=f12\"\n\n        base_url = \"http://99.push2.eastmoney.com/api/qt/clist/get\"\n        params = {\n            \"pn\": 1,  # page number\n            \"pz\": 100,  # page size, default to 100\n            \"po\": 1,\n            \"np\": 1,\n            \"fs\": \"m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048\",\n            \"fields\": \"f12\",\n        }\n\n        _symbols = []\n        page = 1\n\n        while True:\n            params[\"pn\"] = page\n            try:\n                resp = requests.get(base_url, params=params, timeout=None)\n                resp.raise_for_status()\n                data = resp.json()\n\n                # Check if response contains valid data\n                if not data or \"data\" not in data or not data[\"data\"] or \"diff\" 
not in data[\"data\"]:\n                    logger.warning(f\"Invalid response structure on page {page}\")\n                    break\n\n                # fetch the current page data\n                current_symbols = [_v[\"f12\"] for _v in data[\"data\"][\"diff\"]]\n\n                if not current_symbols:  # It's the last page if there is no data in current page\n                    logger.info(f\"Last page reached: {page - 1}\")\n                    break\n\n                _symbols.extend(current_symbols)\n\n                # show progress\n                logger.info(\n                    f\"Page {page}: fetch {len(current_symbols)} stocks:[{current_symbols[0]} ... {current_symbols[-1]}]\"\n                )\n\n                page += 1\n\n                # sleep time to avoid overloading the server\n                time.sleep(0.5)\n\n            except requests.exceptions.HTTPError as e:\n                raise requests.exceptions.HTTPError(\n                    f\"Request to {base_url} failed with status code {resp.status_code}\"\n                ) from e\n            except Exception as e:\n                logger.warning(\"An error occurred while extracting data from the response.\")\n                raise\n\n        if len(_symbols) < 3900:\n            raise ValueError(\"The complete list of stocks is not available.\")\n\n        # Add suffix after the stock code to conform to yahooquery standard, otherwise the data will not be fetched.\n        _symbols = [\n            _symbol + \".ss\" if _symbol.startswith(\"6\") else _symbol + \".sz\" if _symbol.startswith((\"0\", \"3\")) else None\n            for _symbol in _symbols\n        ]\n        _symbols = [_symbol for _symbol in _symbols if _symbol is not None]\n\n        return set(_symbols)\n\n    if _HS_SYMBOLS is None:\n        symbols = set()\n        _retry = 60\n        # It may take multiple attempts to get the complete list of symbols\n        while len(symbols) < MINIMUM_SYMBOLS_NUM:\n            symbols |= 
_get_symbol()\n            time.sleep(3)\n\n        symbol_cache_path = Path(\"~/.cache/hs_symbols_cache.pkl\").expanduser().resolve()\n        symbol_cache_path.parent.mkdir(parents=True, exist_ok=True)\n        if symbol_cache_path.exists():\n            with symbol_cache_path.open(\"rb\") as fp:\n                cache_symbols = restricted_pickle_load(fp)\n                symbols |= cache_symbols\n        with symbol_cache_path.open(\"wb\") as fp:\n            pickle.dump(symbols, fp)\n\n        _HS_SYMBOLS = sorted(list(symbols))\n\n    return _HS_SYMBOLS\n\n\ndef get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:\n    \"\"\"get US stock symbols\n\n    Returns\n    -------\n        stock symbols\n    \"\"\"\n    import akshare as ak  # pylint: disable=C0415\n\n    global _US_SYMBOLS  # pylint: disable=W0603\n\n    @deco_retry\n    def _get_eastmoney():\n        df = ak.get_us_stock_name()\n        _symbols = df[\"symbol\"].to_list()\n\n        if len(_symbols) < 8000:\n            raise ValueError(\"request error\")\n\n        return _symbols\n\n    @deco_retry\n    def _get_nasdaq():\n        _res_symbols = []\n        for _name in [\"otherlisted\", \"nasdaqtraded\"]:\n            url = f\"ftp://ftp.nasdaqtrader.com/SymbolDirectory/{_name}.txt\"\n            df = pd.read_csv(url, sep=\"|\")\n            df = df.rename(columns={\"ACT Symbol\": \"Symbol\"})\n            _symbols = df[\"Symbol\"].dropna()\n            _symbols = _symbols.str.replace(\"$\", \"-P\", regex=False)\n            _symbols = _symbols.str.replace(\".W\", \"-WT\", regex=False)\n            _symbols = _symbols.str.replace(\".U\", \"-UN\", regex=False)\n            _symbols = _symbols.str.replace(\".R\", \"-RI\", regex=False)\n            _symbols = _symbols.str.replace(\".\", \"-\", regex=False)\n            _res_symbols += _symbols.unique().tolist()\n        return _res_symbols\n\n    @deco_retry\n    def _get_nyse():\n        url = 
\"https://www.nyse.com/api/quotes/filter\"\n        _parms = {\n            \"instrumentType\": \"EQUITY\",\n            \"pageNumber\": 1,\n            \"sortColumn\": \"NORMALIZED_TICKER\",\n            \"sortOrder\": \"ASC\",\n            \"maxResultsPerPage\": 10000,\n            \"filterToken\": \"\",\n        }\n        resp = requests.post(url, json=_parms, timeout=None)\n        if resp.status_code != 200:\n            raise ValueError(\"request error\")\n\n        try:\n            _symbols = [_v[\"symbolTicker\"].replace(\"-\", \"-P\") for _v in resp.json()]\n        except Exception as e:\n            logger.warning(f\"request error: {e}\")\n            _symbols = []\n        return _symbols\n\n    if _US_SYMBOLS is None:\n        _all_symbols = _get_eastmoney() + _get_nasdaq() + _get_nyse()\n        if qlib_data_path is not None:\n            for _index in [\"nasdaq100\", \"sp500\"]:\n                ins_df = pd.read_csv(\n                    Path(qlib_data_path).joinpath(f\"instruments/{_index}.txt\"),\n                    sep=\"\\t\",\n                    names=[\"symbol\", \"start_date\", \"end_date\"],\n                )\n                _all_symbols += ins_df[\"symbol\"].unique().tolist()\n\n        def _format(s_):\n            s_ = s_.replace(\".\", \"-\")\n            s_ = s_.strip(\"$\")\n            s_ = s_.strip(\"*\")\n            return s_\n\n        _US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith(\"WS\"), _all_symbols))))\n\n    return _US_SYMBOLS\n\n\ndef get_in_stock_symbols(qlib_data_path: [str, Path] = None) -> list:\n    \"\"\"get IN stock symbols\n\n    Returns\n    -------\n        stock symbols\n    \"\"\"\n    global _IN_SYMBOLS  # pylint: disable=W0603\n\n    @deco_retry\n    def _get_nifty():\n        url = f\"https://www1.nseindia.com/content/equities/EQUITY_L.csv\"\n        df = pd.read_csv(url)\n        df = df.rename(columns={\"SYMBOL\": \"Symbol\"})\n        df[\"Symbol\"] = 
df[\"Symbol\"] + \".NS\"\n        _symbols = df[\"Symbol\"].dropna()\n        _symbols = _symbols.unique().tolist()\n        return _symbols\n\n    if _IN_SYMBOLS is None:\n        _all_symbols = _get_nifty()\n        if qlib_data_path is not None:\n            for _index in [\"nifty\"]:\n                ins_df = pd.read_csv(\n                    Path(qlib_data_path).joinpath(f\"instruments/{_index}.txt\"),\n                    sep=\"\\t\",\n                    names=[\"symbol\", \"start_date\", \"end_date\"],\n                )\n                _all_symbols += ins_df[\"symbol\"].unique().tolist()\n\n        def _format(s_):\n            s_ = s_.replace(\".\", \"-\")\n            s_ = s_.strip(\"$\")\n            s_ = s_.strip(\"*\")\n            return s_\n\n        _IN_SYMBOLS = sorted(set(_all_symbols))\n\n    return _IN_SYMBOLS\n\n\ndef get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list:\n    \"\"\"get Brazil(B3) stock symbols\n\n    Returns\n    -------\n        B3 stock symbols\n    \"\"\"\n    global _BR_SYMBOLS  # pylint: disable=W0603\n\n    @deco_retry\n    def _get_ibovespa():\n        _symbols = []\n        url = \"https://www.fundamentus.com.br/detalhes.php?papel=\"\n\n        # Request\n        agent = {\"User-Agent\": \"Mozilla/5.0\"}\n        page = requests.get(url, headers=agent, timeout=None)\n\n        # BeautifulSoup\n        soup = BeautifulSoup(page.content, \"html.parser\")\n        tbody = soup.find(\"tbody\")\n\n        children = tbody.findChildren(\"a\", recursive=True)\n        for child in children:\n            _symbols.append(str(child).rsplit('\"', maxsplit=1)[-1].split(\">\")[1].split(\"<\")[0])\n\n        return _symbols\n\n    if _BR_SYMBOLS is None:\n        _all_symbols = _get_ibovespa()\n        if qlib_data_path is not None:\n            for _index in [\"ibov\"]:\n                ins_df = pd.read_csv(\n                    Path(qlib_data_path).joinpath(f\"instruments/{_index}.txt\"),\n                    
sep=\"\\t\",\n                    names=[\"symbol\", \"start_date\", \"end_date\"],\n                )\n                _all_symbols += ins_df[\"symbol\"].unique().tolist()\n\n        def _format(s_):\n            s_ = s_.strip()\n            s_ = s_.strip(\"$\")\n            s_ = s_.strip(\"*\")\n            s_ = s_ + \".SA\"\n            return s_\n\n        _BR_SYMBOLS = sorted(set(map(_format, _all_symbols)))\n\n    return _BR_SYMBOLS\n\n\ndef get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:\n    \"\"\"get en fund symbols\n\n    Returns\n    -------\n        fund symbols in China\n    \"\"\"\n    global _EN_FUND_SYMBOLS  # pylint: disable=W0603\n\n    @deco_retry\n    def _get_eastmoney():\n        url = \"http://fund.eastmoney.com/js/fundcode_search.js\"\n        resp = requests.get(url, timeout=None)\n        if resp.status_code != 200:\n            raise ValueError(\"request error\")\n        try:\n            _symbols = []\n            for sub_data in re.findall(r\"[\\[](.*?)[\\]]\", resp.content.decode().split(\"= [\")[-1].replace(\"];\", \"\")):\n                data = sub_data.replace('\"', \"\").replace(\"'\", \"\")\n                # TODO: do we need other information, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE']\n                _symbols.append(data.split(\",\")[0])\n        except Exception as e:\n            logger.warning(f\"request error: {e}\")\n            raise\n        if len(_symbols) < 8000:\n            raise ValueError(\"request error\")\n        return _symbols\n\n    if _EN_FUND_SYMBOLS is None:\n        _all_symbols = _get_eastmoney()\n\n        _EN_FUND_SYMBOLS = sorted(set(_all_symbols))\n\n    return _EN_FUND_SYMBOLS\n\n\ndef symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:\n    \"\"\"symbol suffix to prefix\n\n    Parameters\n    ----------\n    symbol: str\n        symbol\n    capital : bool\n        by default True\n    Returns\n    -------\n\n    \"\"\"\n    
code, exchange = symbol.split(\".\")\n    if exchange.lower() in [\"sh\", \"ss\"]:\n        res = f\"sh{code}\"\n    else:\n        res = f\"{exchange}{code}\"\n    return res.upper() if capital else res.lower()\n\n\ndef symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:\n    \"\"\"symbol prefix to suffix, e.g. SH600000 -> 600000.SH\n\n    Parameters\n    ----------\n    symbol: str\n        symbol\n    capital : bool\n        by default True\n    Returns\n    -------\n\n    \"\"\"\n    # the first two characters are the exchange prefix, the rest is the code\n    res = f\"{symbol[2:]}.{symbol[:2]}\"\n    return res.upper() if capital else res.lower()\n\n\ndef deco_retry(retry: int = 5, retry_sleep: int = 3):\n    def deco_func(func):\n        @functools.wraps(func)\n        def wrapper(*args, **kwargs):\n            _retry = 5 if callable(retry) else retry\n            _result = None\n            for _i in range(1, _retry + 1):\n                try:\n                    _result = func(*args, **kwargs)\n                    break\n\n                except Exception as e:\n                    logger.warning(f\"{func.__name__}: {_i} :{e}\")\n                    if _i == _retry:\n                        raise\n\n                time.sleep(retry_sleep)\n            return _result\n\n        return wrapper\n\n    return deco_func(retry) if callable(retry) else deco_func\n\n\ndef get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1):\n    \"\"\"get trading date by shift\n\n    Parameters\n    ----------\n    trading_list: list\n        trading calendar list\n    shift : int\n        shift, default is 1\n\n    trading_date : pd.Timestamp\n        trading date\n    Returns\n    -------\n\n    \"\"\"\n    trading_date = pd.Timestamp(trading_date)\n    left_index = bisect.bisect_left(trading_list, trading_date)\n    try:\n        res = trading_list[left_index + shift]\n    except IndexError:\n        res = trading_date\n    return res\n\n\ndef generate_minutes_calendar_from_daily(\n    calendars: Iterable,\n    freq: str = 
\"1min\",\n    am_range: Tuple[str, str] = (\"09:30:00\", \"11:29:00\"),\n    pm_range: Tuple[str, str] = (\"13:00:00\", \"14:59:00\"),\n) -> pd.Index:\n    \"\"\"generate minutes calendar\n\n    Parameters\n    ----------\n    calendars: Iterable\n        daily calendar\n    freq: str\n        by default 1min\n    am_range: Tuple[str, str]\n        AM Time Range, by default China-Stock: (\"09:30:00\", \"11:29:00\")\n    pm_range: Tuple[str, str]\n        PM Time Range, by default China-Stock: (\"13:00:00\", \"14:59:00\")\n\n    \"\"\"\n    daily_format: str = \"%Y-%m-%d\"\n    res = []\n    for _day in calendars:\n        for _range in [am_range, pm_range]:\n            res.append(\n                pd.date_range(\n                    f\"{pd.Timestamp(_day).strftime(daily_format)} {_range[0]}\",\n                    f\"{pd.Timestamp(_day).strftime(daily_format)} {_range[1]}\",\n                    freq=freq,\n                )\n            )\n\n    return pd.Index(sorted(set(np.hstack(res))))\n\n\ndef get_instruments(\n    qlib_dir: str,\n    index_name: str,\n    method: str = \"parse_instruments\",\n    freq: str = \"day\",\n    request_retry: int = 5,\n    retry_sleep: int = 3,\n    market_index: str = \"cn_index\",\n):\n    \"\"\"\n\n    Parameters\n    ----------\n    qlib_dir: str\n        qlib data dir, default \"Path(__file__).parent/qlib_data\"\n    index_name: str\n        index name, value from [\"csi100\", \"csi300\"]\n    method: str\n        method, value from [\"parse_instruments\", \"save_new_companies\"]\n    freq: str\n        freq, value from [\"day\", \"1min\"]\n    request_retry: int\n        request retry, by default 5\n    retry_sleep: int\n        request sleep, by default 3\n    market_index: str\n        Where the files to obtain the index are located,\n        for example data_collector.cn_index.collector\n\n    Examples\n    -------\n        # parse instruments\n        $ python collector.py --index_name CSI300 --qlib_dir 
~/.qlib/qlib_data/cn_data --method parse_instruments\n\n        # parse new companies\n        $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies\n\n    \"\"\"\n    _cur_module = importlib.import_module(\"data_collector.{}.collector\".format(market_index))\n    obj = getattr(_cur_module, f\"{index_name.upper()}Index\")(\n        qlib_dir=qlib_dir, index_name=index_name, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep\n    )\n    getattr(obj, method)()\n\n\ndef _get_all_1d_data(_date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame):\n    df = copy.deepcopy(_1d_data_all)\n    df.reset_index(inplace=True)\n    df.rename(columns={\"datetime\": _date_field_name, \"instrument\": _symbol_field_name}, inplace=True)\n    df.columns = list(map(lambda x: x[1:] if x.startswith(\"$\") else x, df.columns))\n    return df\n\n\ndef get_1d_data(\n    _date_field_name: str,\n    _symbol_field_name: str,\n    symbol: str,\n    start: str,\n    end: str,\n    _1d_data_all: pd.DataFrame,\n) -> pd.DataFrame:\n    \"\"\"get 1d data\n\n    Returns\n    ------\n        data_1d: pd.DataFrame\n            data_1d.columns = [_date_field_name, _symbol_field_name, \"paused\", \"volume\", \"factor\", \"close\"]\n\n    \"\"\"\n    _all_1d_data = _get_all_1d_data(_date_field_name, _symbol_field_name, _1d_data_all)\n    return _all_1d_data[\n        (_all_1d_data[_symbol_field_name] == symbol.upper())\n        & (_all_1d_data[_date_field_name] >= pd.Timestamp(start))\n        & (_all_1d_data[_date_field_name] < pd.Timestamp(end))\n    ]\n\n\ndef calc_adjusted_price(\n    df: pd.DataFrame,\n    _1d_data_all: pd.DataFrame,\n    _date_field_name: str,\n    _symbol_field_name: str,\n    frequence: str,\n    consistent_1d: bool = True,\n    calc_paused: bool = True,\n) -> pd.DataFrame:\n    \"\"\"calc adjusted price\n    This method does 4 things.\n    1. 
Adds the `paused` field.\n        - The added paused field comes from the paused field of the 1d data.\n    2. Aligns the timestamps with the 1d trading calendar (if `consistent_1d` is True).\n    3. Reweights the data with `factor`.\n        - The reweighting method:\n            - volume / factor\n            - open * factor\n            - high * factor\n            - low * factor\n            - close * factor\n    4. Calls the `calc_paused_num` method to add the `paused_num` field (if `calc_paused` is True).\n        - The `paused_num` is the number of consecutive days of trading suspension.\n    \"\"\"\n    # TODO: using daily data factor\n    if df.empty:\n        return df\n    df = df.copy()\n    df.drop_duplicates(subset=_date_field_name, inplace=True)\n    df.sort_values(_date_field_name, inplace=True)\n    symbol = df.iloc[0][_symbol_field_name]\n    df[_date_field_name] = pd.to_datetime(df[_date_field_name])\n    # get 1d data from qlib\n    _start = pd.Timestamp(df[_date_field_name].min()).strftime(\"%Y-%m-%d\")\n    _end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime(\"%Y-%m-%d\")\n    data_1d: pd.DataFrame = get_1d_data(_date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all)\n    if data_1d is None or data_1d.empty:\n        df[\"factor\"] = 1 / df.loc[df[\"close\"].first_valid_index()][\"close\"]\n        # TODO: np.nan or 1 or 0\n        df[\"paused\"] = np.nan\n    else:\n        # copy the filtered 1d data before adding the `paused` column to it\n        data_1d = data_1d.copy()\n        # NOTE: if volume is np.nan or volume <= 0, paused = 1\n        # FIXME: find a more accurate data source\n        data_1d[\"paused\"] = 0\n        data_1d.loc[(data_1d[\"volume\"].isna()) | (data_1d[\"volume\"] <= 0), \"paused\"] = 1\n        data_1d = data_1d.set_index(_date_field_name)\n\n        # add factor from 1d data\n        # NOTE: 1d data info:\n        #   - Close price adjusted for splits. 
Adjusted close price adjusted for both dividends and splits.\n        #   - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.\n        #   - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`\n        def _calc_factor(df_1d: pd.DataFrame):\n            try:\n                _date = pd.Timestamp(pd.Timestamp(df_1d[_date_field_name].iloc[0]).date())\n                df_1d[\"factor\"] = data_1d.loc[_date][\"close\"] / df_1d.loc[df_1d[\"close\"].last_valid_index()][\"close\"]\n                df_1d[\"paused\"] = data_1d.loc[_date][\"paused\"]\n            except Exception:\n                df_1d[\"factor\"] = np.nan\n                df_1d[\"paused\"] = np.nan\n            return df_1d\n\n        df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply(_calc_factor)\n        if consistent_1d:\n            # the date sequence is consistent with 1d\n            df.set_index(_date_field_name, inplace=True)\n            df = df.reindex(\n                generate_minutes_calendar_from_daily(\n                    calendars=pd.to_datetime(data_1d.reset_index()[_date_field_name].drop_duplicates()),\n                    freq=frequence,\n                    am_range=(\"09:30:00\", \"11:29:00\"),\n                    pm_range=(\"13:00:00\", \"14:59:00\"),\n                )\n            )\n            df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][_symbol_field_name]\n            df.index.names = [_date_field_name]\n            df.reset_index(inplace=True)\n    for _col in [\"open\", \"close\", \"high\", \"low\", \"volume\"]:\n        if _col not in df.columns:\n            continue\n        if _col == \"volume\":\n            df[_col] = df[_col] / df[\"factor\"]\n        else:\n            df[_col] = df[_col] * df[\"factor\"]\n    if calc_paused:\n        df = calc_paused_num(df, _date_field_name, _symbol_field_name)\n    return df\n\n\ndef calc_paused_num(df: 
pd.DataFrame, _date_field_name, _symbol_field_name):\n    \"\"\"calc paused num\n    This method adds the paused_num field\n        - The `paused_num` is the number of consecutive days of trading suspension.\n    \"\"\"\n    _symbol = df.iloc[0][_symbol_field_name]\n    df = df.copy()\n    df[\"_tmp_date\"] = df[_date_field_name].apply(lambda x: pd.Timestamp(x).date())\n    # remove the leading and trailing trading days whose data is `np.nan` for the whole day\n    all_data = []\n    # Number of consecutive all-nan trading days seen most recently, used to drop the trailing all-nan days\n    all_nan_nums = 0\n    # Number of consecutive trading days that are not all nan\n    not_nan_nums = 0\n    for _date, _df in df.groupby(\"_tmp_date\", group_keys=False):\n        # work on an explicit copy of the group to avoid SettingWithCopyWarning\n        _df = _df.copy()\n        _df[\"paused\"] = 0\n        if not _df.loc[_df[\"volume\"] < 0].empty:\n            logger.warning(f\"volume < 0, will fill np.nan: {_date} {_symbol}\")\n            _df.loc[_df[\"volume\"] < 0, \"volume\"] = np.nan\n\n        check_fields = set(_df.columns) - {\n            \"_tmp_date\",\n            \"paused\",\n            \"factor\",\n            _date_field_name,\n            _symbol_field_name,\n        }\n        if _df.loc[:, list(check_fields)].isna().values.all() or (_df[\"volume\"] == 0).all():\n            all_nan_nums += 1\n            not_nan_nums = 0\n            _df[\"paused\"] = 1\n            if all_data:\n                _df[\"paused_num\"] = not_nan_nums\n                all_data.append(_df)\n        else:\n            all_nan_nums = 0\n            not_nan_nums += 1\n            _df[\"paused_num\"] = not_nan_nums\n            all_data.append(_df)\n    all_data = all_data[: len(all_data) - all_nan_nums]\n    if all_data:\n        df = pd.concat(all_data, sort=False)\n    else:\n        logger.warning(f\"data is empty: {_symbol}\")\n        df = pd.DataFrame()\n        return df\n    del df[\"_tmp_date\"]\n    return df\n\n\nif __name__ == 
\"__main__\":\n    assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM\n"
  },
  {
    "path": "scripts/data_collector/yahoo/README.md",
    "content": "\n- [Collector Data](#collector-data)\n  - [Get Qlib data](#get-qlib-databin-file)\n  - [Collector *YahooFinance* data to qlib](#collector-yahoofinance-data-to-qlib)\n  - [Automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance)\n- [Using qlib data](#using-qlib-data)\n\n\n# Collect Data From Yahoo Finance\n\n> *Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend that users prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*\n\n**NOTE**: Yahoo! Finance has blocked access from China. Please change your network if you want to use the Yahoo data crawler.\n\n> **Examples of abnormal data**\n\n- [SZ000661](https://finance.yahoo.com/quote/000661.SZ/history?period1=1558310400&period2=1590796800&interval=1d&filter=history&frequency=1d)\n- [SZ300144](https://finance.yahoo.com/quote/300144.SZ/history?period1=1557446400&period2=1589932800&interval=1d&filter=history&frequency=1d)\n\nWe have considered **STOCK PRICE ADJUSTMENT**, but some price series still seem very abnormal.\n\n## Requirements\n\n```bash\npip install -r requirements.txt\n```\n\n## Collector Data\n\n### Get Qlib data(`bin file`)\n  > The `qlib-data` from *YahooFinance* has already been dumped and can be used directly in `qlib`.\n  > This ready-made qlib-data is not updated regularly. If users want the latest data, please follow [these steps](#collector-yahoofinance-data-to-qlib) to download it. 
\n\n  - get data: `python scripts/get_data.py qlib_data`\n  - parameters:\n    - `target_dir`: directory to save the data, by default *~/.qlib/qlib_data/cn_data*\n    - `version`: dataset version, value from [`v1`, `v2`], by default `v1`\n      - `v2` end date is *2021-06*, `v1` end date is *2020-09*\n      - If users want to incrementally update data, they need to use the Yahoo collector to [collect data from scratch](#collector-yahoofinance-data-to-qlib).\n      - **the [benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks) for qlib use `v1`**, *due to the unstable access to historical data by YahooFinance, there are some differences between `v2` and `v1`*\n    - `interval`: `1d` or `1min`, by default `1d`\n    - `region`: `cn` or `us` or `in`, by default `cn`\n    - `delete_old`: delete existing data from `target_dir`(*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`\n    - `exists_skip`: if data already exists in `target_dir`, skip `get_data`, value from [`True`, `False`], by default `False`\n  - examples:\n    ```bash\n    # cn 1d\n    python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n    # cn 1min\n    python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min\n    # us 1d\n    python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us --interval 1d\n    ```\n\n### Collector *YahooFinance* data to qlib\n> Collect *YahooFinance* data and *dump* it into `qlib` format.\n> If the above ready-made data can't meet users' requirements, users can follow this section to crawl the latest data and convert it to qlib-data.\n  1. download data to csv: `python scripts/data_collector/yahoo/collector.py download_data`\n     \n     This will download the raw data, such as the high, low, open, close and adjclose prices, from Yahoo to a local directory. 
One file per symbol.\n\n     - parameters:\n          - `source_dir`: directory where the downloaded data is saved\n          - `interval`: `1d` or `1min`, by default `1d`\n            > **due to the limitation of the *YahooFinance API*, only the last month's data is available in `1min`**\n          - `region`: `CN` or `US` or `IN` or `BR`, by default `CN`\n          - `delay`: `time.sleep(delay)`, by default *0.5*\n          - `start`: start datetime, by default *\"2000-01-01\"*; *closed interval(including start)*\n          - `end`: end datetime, by default `pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`; *open interval(excluding end)*\n          - `max_workers`: number of symbols fetched concurrently; changing this parameter is not recommended, in order to maintain the integrity of the symbol data; by default *1*\n          - `check_data_length`: check the number of rows per *symbol*, by default `None`\n            > if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter\n          - `max_collector_count`: number of *\"failed\"* symbol retries, by default 2\n     - examples:\n          ```bash\n          # cn 1d data\n          python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_data --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region CN\n          # cn 1min data\n          python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_data_1min --delay 1 --interval 1min --region CN\n\n          # us 1d data\n          python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_data --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region US\n          # us 1min data\n          python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_data_1min --delay 1 --interval 1min --region US\n\n          # in 1d data\n          python collector.py download_data --source_dir 
~/.qlib/stock_data/source/in_data --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region IN\n          # in 1min data\n          python collector.py download_data --source_dir ~/.qlib/stock_data/source/in_data_1min --delay 1 --interval 1min --region IN\n\n          # br 1d data\n          python collector.py download_data --source_dir ~/.qlib/stock_data/source/br_data --start 2003-01-03 --end 2022-03-01 --delay 1 --interval 1d --region BR\n          # br 1min data\n          python collector.py download_data --source_dir ~/.qlib/stock_data/source/br_data_1min --delay 1 --interval 1min --region BR\n          ```\n  2. normalize data: `python scripts/data_collector/yahoo/collector.py normalize_data`\n     \n     This will:\n     1. Normalize the high, low, close and open prices using adjclose.\n     2. Normalize the high, low, close and open prices so that the close price on the first valid trading date is 1. \n\n     - parameters:\n          - `source_dir`: csv directory\n          - `normalize_dir`: result directory\n          - `max_workers`: number of concurrent workers, by default *1*\n          - `interval`: `1d` or `1min`, by default `1d`\n            > if **`interval == 1min`**, `qlib_data_1d_dir` cannot be `None`\n          - `region`: `CN` or `US` or `IN` or `BR`, by default `CN`\n          - `date_field_name`: column *name* identifying time in csv files, by default `date`\n          - `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`\n          - `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None`\n          - `qlib_data_1d_dir`: qlib directory(1d data)\n            ```\n            if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data;\n        \n                qlib_data_1d can be obtained like this:\n                    $ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d\n          
          $ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --end_date <end_date>\n                or:\n                    download 1d data from YahooFinance\n            \n            ```\n      - examples:\n        ```bash\n        # normalize 1d cn\n        python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_data --normalize_dir ~/.qlib/stock_data/source/cn_1d_nor --region CN --interval 1d\n\n        # normalize 1min cn\n        python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/cn_data_1min --normalize_dir ~/.qlib/stock_data/source/cn_1min_nor --region CN --interval 1min\n\n        # normalize 1d br\n        python scripts/data_collector/yahoo/collector.py normalize_data --source_dir ~/.qlib/stock_data/source/br_data --normalize_dir ~/.qlib/stock_data/source/br_1d_nor --region BR --interval 1d\n\n        # normalize 1min br\n        python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/br_data --source_dir ~/.qlib/stock_data/source/br_data_1min --normalize_dir ~/.qlib/stock_data/source/br_1min_nor --region BR --interval 1min\n        ```\n  3. dump data: `python scripts/dump_bin.py dump_all`\n    \n     This will convert the normalized CSV files into numpy arrays and store them in the `features` directory, one file per column and one directory per symbol. 
\n    \n     - parameters:\n       - `data_path`: stock data path or directory, **normalize result(normalize_dir)**\n       - `qlib_dir`: qlib(dump) data directory\n       - `freq`: transaction frequency, by default `day`\n         > `freq_map = {1d:day, 1min: 1min}`\n       - `max_workers`: number of concurrent workers, by default *16*\n       - `include_fields`: dump fields, by default `\"\"`\n       - `exclude_fields`: fields not dumped, by default `\"\"`\n         > dump_fields = `include_fields if include_fields else set(symbol_df.columns) - set(exclude_fields) if exclude_fields else symbol_df.columns`\n       - `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`\n       - `date_field_name`: column *name* identifying time in csv files, by default `date`\n       - `file_suffix`: stock data file suffix, by default \".csv\"\n     - examples:\n       ```bash\n       # dump 1d cn\n       python dump_bin.py dump_all --data_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_data --freq day --exclude_fields date,symbol --file_suffix .csv\n       # dump 1min cn\n       python dump_bin.py dump_all --data_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/cn_data_1min --freq 1min --exclude_fields date,symbol --file_suffix .csv\n       ```\n\n### Automatic update of daily frequency data(from yahoo finance)\n  > It is recommended that users update the data manually once (--end_date 2021-05-25) and then set it to update automatically.\n  >\n  > **NOTE**: Users can't incrementally update data based on the offline data provided by Qlib (some fields are removed to reduce the data size). 
Users should use [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance) to download Yahoo data from scratch and then incrementally update it.\n  > \n\n  * Automatic update of data to the \"qlib\" directory each trading day(Linux)\n      * use *crontab*: `crontab -e`\n      * set up timed tasks:\n\n        ```\n        * * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>\n        ```\n        * **script path**: *scripts/data_collector/yahoo/collector.py*\n\n  * Manual update of data\n      ```\n      python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --end_date <end date>\n      ```\n      * `end_date`: end of trading day(not included)\n      * `check_data_length`: check the number of rows per *symbol*, by default `None`\n        > if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter\n\n  * `scripts/data_collector/yahoo/collector.py update_data_to_bin` parameters:\n      * `source_dir`: The directory where the raw data collected from the Internet is saved, default \"Path(__file__).parent/source\"\n      * `normalize_dir`: Directory for normalize data, default \"Path(__file__).parent/normalize\"\n      * `qlib_data_1d_dir`: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)\n      * `end_date`: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)\n      * `region`: region, value from [\"CN\", \"US\"], default \"CN\"\n      * `interval`: interval, default \"1d\"(Currently only supports 1d data)\n      * `exists_skip`: exists skip, by default False\n\n## Using qlib data\n\n  ```python\n  import qlib\n  from qlib.data import D\n\n  # 1d data cn\n  # freq=day, 
freq default day\n  qlib.init(provider_uri=\"~/.qlib/qlib_data/cn_data\", region=\"cn\")\n  df = D.features(D.instruments(\"all\"), [\"$close\"], freq=\"day\")\n\n  # 1min data cn\n  # freq=1min\n  qlib.init(provider_uri=\"~/.qlib/qlib_data/cn_data_1min\", region=\"cn\")\n  inst = D.list_instruments(D.instruments(\"all\"), freq=\"1min\", as_list=True)\n  # get 100 symbols\n  df = D.features(inst[:100], [\"$close\"], freq=\"1min\")\n  # get all symbol data\n  # df = D.features(D.instruments(\"all\"), [\"$close\"], freq=\"1min\")\n\n  # 1d data us\n  qlib.init(provider_uri=\"~/.qlib/qlib_data/us_data\", region=\"us\")\n  df = D.features(D.instruments(\"all\"), [\"$close\"], freq=\"day\")\n\n  # 1min data us\n  qlib.init(provider_uri=\"~/.qlib/qlib_data/us_data_1min\", region=\"cn\")\n  inst = D.list_instruments(D.instruments(\"all\"), freq=\"1min\", as_list=True)\n  # get 100 symbols\n  df = D.features(inst[:100], [\"$close\"], freq=\"1min\")\n  # get all symbol data\n  # df = D.features(D.instruments(\"all\"), [\"$close\"], freq=\"1min\")\n  ```\n\n"
  },
  {
    "path": "scripts/data_collector/yahoo/collector.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport abc\nimport sys\nimport copy\nimport time\nimport datetime\nimport importlib\nfrom abc import ABC\nimport multiprocessing\nfrom pathlib import Path\nfrom typing import Iterable\n\nimport fire\nimport requests\nimport numpy as np\nimport pandas as pd\nfrom loguru import logger\nfrom yahooquery import Ticker\nfrom dateutil.tz import tzlocal\n\nimport qlib\nfrom qlib.data import D\nfrom qlib.tests.data import GetData\nfrom qlib.utils import code_to_fname, fname_to_code, exists_qlib_data\nfrom qlib.constant import REG_CN as REGION_CN\n\nCUR_DIR = Path(__file__).resolve().parent\nsys.path.append(str(CUR_DIR.parent.parent))\n\nfrom dump_bin import DumpDataUpdate\nfrom data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize\nfrom data_collector.utils import (\n    deco_retry,\n    get_calendar_list,\n    get_hs_stock_symbols,\n    get_us_stock_symbols,\n    get_in_stock_symbols,\n    get_br_stock_symbols,\n    generate_minutes_calendar_from_daily,\n    calc_adjusted_price,\n)\n\nINDEX_BENCH_URL = \"http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}\"\n\n\nclass YahooCollector(BaseCollector):\n    retry = 5  # Configuration attribute.  
Number of times to retry a request if the network fails.\n\n    def __init__(\n        self,\n        save_dir: [str, Path],\n        start=None,\n        end=None,\n        interval=\"1d\",\n        max_workers=4,\n        max_collector_count=2,\n        delay=0,\n        check_data_length: int = None,\n        limit_nums: int = None,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        save_dir: str\n            stock save dir\n        max_workers: int\n            workers, default 4\n        max_collector_count: int\n            default 2\n        delay: float\n            time.sleep(delay), default 0\n        interval: str\n            freq, value from [1min, 1d], default 1d\n        start: str\n            start datetime, default None\n        end: str\n            end datetime, default None\n        check_data_length: int\n            check data length, by default None\n        limit_nums: int\n            used for debugging, by default None\n        \"\"\"\n        super(YahooCollector, self).__init__(\n            save_dir=save_dir,\n            start=start,\n            end=end,\n            interval=interval,\n            max_workers=max_workers,\n            max_collector_count=max_collector_count,\n            delay=delay,\n            check_data_length=check_data_length,\n            limit_nums=limit_nums,\n        )\n\n        self.init_datetime()\n\n    def init_datetime(self):\n        if self.interval == self.INTERVAL_1min:\n            self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)\n        elif self.interval == self.INTERVAL_1d:\n            pass\n        else:\n            raise ValueError(f\"interval error: {self.interval}\")\n\n        self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)\n        self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)\n\n    @staticmethod\n    def convert_datetime(dt: [pd.Timestamp, 
datetime.date, str], timezone):\n        try:\n            dt = pd.Timestamp(dt, tz=timezone).timestamp()\n            dt = pd.Timestamp(dt, tz=tzlocal(), unit=\"s\")\n        except ValueError:\n            pass\n        return dt\n\n    @property\n    @abc.abstractmethod\n    def _timezone(self):\n        raise NotImplementedError(\"rewrite get_timezone\")\n\n    @staticmethod\n    def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):\n        error_msg = f\"{symbol}-{interval}-{start}-{end}\"\n\n        def _show_logging_func():\n            # NOTE: `interval` is read when this closure runs, i.e. after it has been mapped to \"1m\" below\n            if interval in (\"1m\", YahooCollector.INTERVAL_1min) and show_1min_logging:\n                logger.warning(f\"{error_msg}:{_resp}\")\n\n        interval = \"1m\" if interval in [\"1m\", \"1min\"] else interval\n        try:\n            _resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end)\n            if isinstance(_resp, pd.DataFrame):\n                return _resp.reset_index()\n            elif isinstance(_resp, dict):\n                _temp_data = _resp.get(symbol, {})\n                if isinstance(_temp_data, str) or (\n                    isinstance(_resp, dict) and _temp_data.get(\"indicators\", {}).get(\"quote\", None) is None\n                ):\n                    _show_logging_func()\n            else:\n                _show_logging_func()\n        except Exception as e:\n            logger.warning(\n                f\"get data error: {symbol}--{start}--{end}. \"\n                + \"Your data request failed. This may be caused by your firewall (e.g. GFW). Please switch your network if you want to access Yahoo! 
data\"\n            )\n\n    def get_data(\n        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp\n    ) -> pd.DataFrame:\n        @deco_retry(retry_sleep=self.delay, retry=self.retry)\n        def _get_simple(start_, end_):\n            self.sleep()\n            _remote_interval = \"1m\" if interval == self.INTERVAL_1min else interval\n            resp = self.get_data_from_remote(\n                symbol,\n                interval=_remote_interval,\n                start=start_,\n                end=end_,\n            )\n            if resp is None or resp.empty:\n                raise ValueError(\n                    f\"get data error: {symbol}--{start_}--{end_}. \" + \"The stock may be delisted, please check\"\n                )\n            return resp\n\n        _result = None\n        if interval == self.INTERVAL_1d:\n            try:\n                _result = _get_simple(start_datetime, end_datetime)\n            except ValueError:\n                pass\n        elif interval == self.INTERVAL_1min:\n            _res = []\n            _start = start_datetime\n            while _start < end_datetime:\n                _tmp_end = min(_start + pd.Timedelta(days=7), end_datetime)\n                try:\n                    _resp = _get_simple(_start, _tmp_end)\n                    _res.append(_resp)\n                except ValueError:\n                    pass\n                _start = _tmp_end\n            if _res:\n                _result = pd.concat(_res, sort=False).sort_values([\"symbol\", \"date\"])\n        else:\n            raise ValueError(f\"cannot support {interval}\")\n        return pd.DataFrame() if _result is None else _result\n\n    def collector_data(self):\n        \"\"\"collect data\"\"\"\n        super(YahooCollector, self).collector_data()\n        self.download_index_data()\n\n    @abc.abstractmethod\n    def download_index_data(self):\n        \"\"\"download index 
data\"\"\"\n        raise NotImplementedError(\"rewrite download_index_data\")\n\n\nclass YahooCollectorCN(YahooCollector, ABC):\n    def get_instrument_list(self):\n        logger.info(\"get HS stock symbols......\")\n        symbols = get_hs_stock_symbols()\n        logger.info(f\"get {len(symbols)} symbols.\")\n        return symbols\n\n    def normalize_symbol(self, symbol):\n        symbol_s = symbol.split(\".\")\n        symbol = f\"sh{symbol_s[0]}\" if symbol_s[-1] == \"ss\" else f\"sz{symbol_s[0]}\"\n        return symbol\n\n    @property\n    def _timezone(self):\n        return \"Asia/Shanghai\"\n\n\nclass YahooCollectorCN1d(YahooCollectorCN):\n    def download_index_data(self):\n        # TODO: from MSN\n        _format = \"%Y%m%d\"\n        _begin = self.start_datetime.strftime(_format)\n        _end = self.end_datetime.strftime(_format)\n        for _index_name, _index_code in {\"csi300\": \"000300\", \"csi100\": \"000903\", \"csi500\": \"000905\"}.items():\n            logger.info(f\"get bench data: {_index_name}({_index_code})......\")\n            try:\n                df = pd.DataFrame(\n                    map(\n                        lambda x: x.split(\",\"),\n                        requests.get(\n                            INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end), timeout=None\n                        ).json()[\"data\"][\"klines\"],\n                    )\n                )\n            except Exception as e:\n                logger.warning(f\"get {_index_name} error: {e}\")\n                continue\n            df.columns = [\"date\", \"open\", \"close\", \"high\", \"low\", \"volume\", \"money\", \"change\"]\n            df[\"date\"] = pd.to_datetime(df[\"date\"])\n            df = df.astype(float, errors=\"ignore\")\n            df[\"adjclose\"] = df[\"close\"]\n            df[\"symbol\"] = f\"sh{_index_code}\"\n            _path = self.save_dir.joinpath(f\"sh{_index_code}.csv\")\n            if 
_path.exists():\n                _old_df = pd.read_csv(_path)\n                df = pd.concat([_old_df, df], sort=False)\n            df.to_csv(_path, index=False)\n            time.sleep(5)\n\n\nclass YahooCollectorCN1min(YahooCollectorCN):\n    def get_instrument_list(self):\n        symbols = super(YahooCollectorCN1min, self).get_instrument_list()\n        return symbols + [\"000300.ss\", \"000905.ss\", \"000903.ss\"]\n\n    def download_index_data(self):\n        pass\n\n\nclass YahooCollectorUS(YahooCollector, ABC):\n    def get_instrument_list(self):\n        logger.info(\"get US stock symbols......\")\n        symbols = get_us_stock_symbols() + [\n            \"^GSPC\",\n            \"^NDX\",\n            \"^DJI\",\n        ]\n        logger.info(f\"get {len(symbols)} symbols.\")\n        return symbols\n\n    def download_index_data(self):\n        pass\n\n    def normalize_symbol(self, symbol):\n        return code_to_fname(symbol).upper()\n\n    @property\n    def _timezone(self):\n        return \"America/New_York\"\n\n\nclass YahooCollectorUS1d(YahooCollectorUS):\n    pass\n\n\nclass YahooCollectorUS1min(YahooCollectorUS):\n    pass\n\n\nclass YahooCollectorIN(YahooCollector, ABC):\n    def get_instrument_list(self):\n        logger.info(\"get INDIA stock symbols......\")\n        symbols = get_in_stock_symbols()\n        logger.info(f\"get {len(symbols)} symbols.\")\n        return symbols\n\n    def download_index_data(self):\n        pass\n\n    def normalize_symbol(self, symbol):\n        return code_to_fname(symbol).upper()\n\n    @property\n    def _timezone(self):\n        return \"Asia/Kolkata\"\n\n\nclass YahooCollectorIN1d(YahooCollectorIN):\n    pass\n\n\nclass YahooCollectorIN1min(YahooCollectorIN):\n    pass\n\n\nclass YahooCollectorBR(YahooCollector, ABC):\n    def retry(cls):  # pylint: disable=E0213\n        \"\"\"\n        The reason to use retry=2 is due to the fact that\n        Yahoo Finance unfortunately does not keep track of 
some\n        Brazilian stocks.\n\n        Therefore, the decorator deco_retry with the retry argument\n        set to 5 will keep trying to get the stock data up to 5 times,\n        which makes downloading Brazilian stocks very slow.\n\n        This may change in the future, but for now\n        I suggest setting the retry argument to 1 or 2 in\n        order to improve download speed.\n\n        To achieve this goal, an abstract attribute (retry)\n        was added to the YahooCollectorBR base class\n        \"\"\"\n        raise NotImplementedError\n\n    def get_instrument_list(self):\n        logger.info(\"get BR stock symbols......\")\n        symbols = get_br_stock_symbols() + [\n            \"^BVSP\",\n        ]\n        logger.info(f\"get {len(symbols)} symbols.\")\n        return symbols\n\n    def download_index_data(self):\n        pass\n\n    def normalize_symbol(self, symbol):\n        return code_to_fname(symbol).upper()\n\n    @property\n    def _timezone(self):\n        return \"Brazil/East\"\n\n\nclass YahooCollectorBR1d(YahooCollectorBR):\n    retry = 2\n\n\nclass YahooCollectorBR1min(YahooCollectorBR):\n    retry = 2\n\n\nclass YahooNormalize(BaseNormalize):\n    COLUMNS = [\"open\", \"close\", \"high\", \"low\", \"volume\"]\n    DAILY_FORMAT = \"%Y-%m-%d\"\n\n    @staticmethod\n    def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series:\n        df = df.copy()\n        _tmp_series = df[\"close\"].ffill()\n        _tmp_shift_series = _tmp_series.shift(1)\n        if last_close is not None:\n            _tmp_shift_series.iloc[0] = float(last_close)\n        change_series = _tmp_series / _tmp_shift_series - 1\n        return change_series\n\n    @staticmethod\n    def normalize_yahoo(\n        df: pd.DataFrame,\n        calendar_list: list = None,\n        date_field_name: str = \"date\",\n        symbol_field_name: str = \"symbol\",\n        last_close: float = None,\n    ):\n        if df.empty:\n            return df\n        symbol = 
df.loc[df[symbol_field_name].first_valid_index(), symbol_field_name]\n        columns = copy.deepcopy(YahooNormalize.COLUMNS)\n        df = df.copy()\n        df.set_index(date_field_name, inplace=True)\n        df.index = pd.to_datetime(df.index)\n        df.index = df.index.tz_localize(None)\n        df = df[~df.index.duplicated(keep=\"first\")]\n        if calendar_list is not None:\n            df = df.reindex(\n                pd.DataFrame(index=calendar_list)\n                .loc[\n                    pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()\n                    + pd.Timedelta(hours=23, minutes=59)\n                ]\n                .index\n            )\n        df.sort_index(inplace=True)\n        df.loc[(df[\"volume\"] <= 0) | np.isnan(df[\"volume\"]), list(set(df.columns) - {symbol_field_name})] = np.nan\n\n        change_series = YahooNormalize.calc_change(df, last_close)\n        # NOTE: The data obtained by Yahoo finance sometimes has exceptions\n        # WARNING: If it is normal for a `symbol(exchange)` to differ by a factor of *89* to *111* for consecutive trading days,\n        # WARNING: the logic in the following line needs to be modified\n        _count = 0\n        while True:\n            # NOTE: may appear unusual for many days in a row\n            change_series = YahooNormalize.calc_change(df, last_close)\n            _mask = (change_series >= 89) & (change_series <= 111)\n            if not _mask.any():\n                break\n            _tmp_cols = [\"high\", \"close\", \"low\", \"open\", \"adjclose\"]\n            df.loc[_mask, _tmp_cols] = df.loc[_mask, _tmp_cols] / 100\n            _count += 1\n            if _count >= 10:\n                _symbol = df.loc[df[symbol_field_name].first_valid_index()][\"symbol\"]\n                logger.warning(\n                    f\"{_symbol} `change` is abnormal for {_count} consecutive days, please check the specific data file carefully\"\n                )\n\n     
   df[\"change\"] = YahooNormalize.calc_change(df, last_close)\n\n        columns += [\"change\"]\n        df.loc[(df[\"volume\"] <= 0) | np.isnan(df[\"volume\"]), columns] = np.nan\n\n        df[symbol_field_name] = symbol\n        df.index.names = [date_field_name]\n        return df.reset_index()\n\n    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:\n        # normalize\n        df = self.normalize_yahoo(df, self._calendar_list, self._date_field_name, self._symbol_field_name)\n        # adjusted price\n        df = self.adjusted_price(df)\n        return df\n\n    @abc.abstractmethod\n    def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"adjusted price\"\"\"\n        raise NotImplementedError(\"rewrite adjusted_price\")\n\n\nclass YahooNormalize1d(YahooNormalize, ABC):\n    DAILY_FORMAT = \"%Y-%m-%d\"\n\n    def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:\n        if df.empty:\n            return df\n        df = df.copy()\n        df.set_index(self._date_field_name, inplace=True)\n        if \"adjclose\" in df:\n            df[\"factor\"] = df[\"adjclose\"] / df[\"close\"]\n            df[\"factor\"] = df[\"factor\"].ffill()\n        else:\n            df[\"factor\"] = 1\n        for _col in self.COLUMNS:\n            if _col not in df.columns:\n                continue\n            if _col == \"volume\":\n                df[_col] = df[_col] / df[\"factor\"]\n            else:\n                df[_col] = df[_col] * df[\"factor\"]\n        df.index.names = [self._date_field_name]\n        return df.reset_index()\n\n    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:\n        df = super(YahooNormalize1d, self).normalize(df)\n        df = self._manual_adj_data(df)\n        return df\n\n    def _get_first_close(self, df: pd.DataFrame) -> float:\n        \"\"\"get first close value\n\n        Notes\n        -----\n            For incremental updates(append) to Yahoo 1D data, user need to use a close that is not 0 
on the first trading day of the existing data\n        \"\"\"\n        df = df.loc[df[\"close\"].first_valid_index() :]\n        _close = df[\"close\"].iloc[0]\n        return _close\n\n    def _manual_adj_data(self, df: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"Manually adjust data: all fields except symbol, adjclose and change are normalized by the first valid close\"\"\"\n        if df.empty:\n            return df\n        df = df.copy()\n        df.sort_values(self._date_field_name, inplace=True)\n        df = df.set_index(self._date_field_name)\n        _close = self._get_first_close(df)\n        for _col in df.columns:\n            # NOTE: retain original adjclose, required for incremental updates\n            if _col in [self._symbol_field_name, \"adjclose\", \"change\"]:\n                continue\n            if _col == \"volume\":\n                df[_col] = df[_col] * _close\n            else:\n                df[_col] = df[_col] / _close\n        return df.reset_index()\n\n\nclass YahooNormalize1dExtend(YahooNormalize1d):\n    def __init__(\n        self, old_qlib_data_dir: [str, Path], date_field_name: str = \"date\", symbol_field_name: str = \"symbol\", **kwargs\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        old_qlib_data_dir: str, Path\n            the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data\n        date_field_name: str\n            date field name, default is date\n        symbol_field_name: str\n            symbol field name, default is symbol\n        \"\"\"\n        super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name)\n        self.column_list = [\"open\", \"high\", \"low\", \"close\", \"volume\", \"factor\", \"change\"]\n        self.old_qlib_data = self._get_old_data(old_qlib_data_dir)\n\n    def _get_old_data(self, qlib_data_dir: [str, Path]):\n        qlib_data_dir = 
str(Path(qlib_data_dir).expanduser().resolve())\n        qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None)\n        df = D.features(D.instruments(\"all\"), [\"$\" + col for col in self.column_list])\n        df.columns = self.column_list\n        return df\n\n    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:\n        df = super(YahooNormalize1dExtend, self).normalize(df)\n        df.set_index(self._date_field_name, inplace=True)\n        symbol_name = df[self._symbol_field_name].iloc[0]\n        old_symbol_list = self.old_qlib_data.index.get_level_values(\"instrument\").unique().to_list()\n        if str(symbol_name).upper() not in old_symbol_list:\n            return df.reset_index()\n        old_df = self.old_qlib_data.loc[str(symbol_name).upper()]\n        latest_date = old_df.index[-1]\n        df = df.loc[latest_date:]\n        new_latest_data = df.iloc[0]\n        old_latest_data = old_df.loc[latest_date]\n        for col in self.column_list[:-1]:\n            if col == \"volume\":\n                df[col] = df[col] / (new_latest_data[col] / old_latest_data[col])\n            else:\n                df[col] = df[col] * (old_latest_data[col] / new_latest_data[col])\n        return df.drop(df.index[0]).reset_index()\n\n\nclass YahooNormalize1min(YahooNormalize, ABC):\n    \"\"\"Normalize 1min data using local 1d data\"\"\"\n\n    AM_RANGE = None  # type: tuple  # eg: (\"09:30:00\", \"11:29:00\")\n    PM_RANGE = None  # type: tuple  # eg: (\"13:00:00\", \"14:59:00\")\n\n    # Whether the trading day of 1min data is consistent with 1d\n    CONSISTENT_1d = True\n    CALC_PAUSED_NUM = True\n\n    def __init__(\n        self, qlib_data_1d_dir: [str, Path], date_field_name: str = \"date\", symbol_field_name: str = \"symbol\", **kwargs\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        qlib_data_1d_dir: str, Path\n            the qlib 1d data directory; 1min data is normalized using this 
local 1d data\n        date_field_name: str\n            date field name, default is date\n        symbol_field_name: str\n            symbol field name, default is symbol\n        \"\"\"\n        super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name)\n        qlib.init(provider_uri=qlib_data_1d_dir)\n        self.all_1d_data = D.features(D.instruments(\"all\"), [\"$paused\", \"$volume\", \"$factor\", \"$close\"], freq=\"day\")\n\n    def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:\n        return list(D.calendar(freq=\"day\"))\n\n    @property\n    def calendar_list_1d(self):\n        calendar_list_1d = getattr(self, \"_calendar_list_1d\", None)\n        if calendar_list_1d is None:\n            calendar_list_1d = self._get_1d_calendar_list()\n            setattr(self, \"_calendar_list_1d\", calendar_list_1d)\n        return calendar_list_1d\n\n    def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:\n        return generate_minutes_calendar_from_daily(\n            calendars, freq=\"1min\", am_range=self.AM_RANGE, pm_range=self.PM_RANGE\n        )\n\n    def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:\n        df = calc_adjusted_price(\n            df=df,\n            _date_field_name=self._date_field_name,\n            _symbol_field_name=self._symbol_field_name,\n            frequence=\"1min\",\n            consistent_1d=self.CONSISTENT_1d,\n            calc_paused=self.CALC_PAUSED_NUM,\n            _1d_data_all=self.all_1d_data,\n        )\n        return df\n\n    @abc.abstractmethod\n    def symbol_to_yahoo(self, symbol):\n        raise NotImplementedError(\"rewrite symbol_to_yahoo\")\n\n\nclass YahooNormalizeUS:\n    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:\n        # TODO: from MSN\n        return get_calendar_list(\"US_ALL\")\n\n\nclass YahooNormalizeUS1d(YahooNormalizeUS, YahooNormalize1d):\n    pass\n\n\nclass YahooNormalizeUS1dExtend(YahooNormalizeUS, YahooNormalize1dExtend):\n  
  pass\n\n\nclass YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):\n    CALC_PAUSED_NUM = False\n\n    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:\n        # TODO: support 1min\n        raise ValueError(\"Does not support 1min\")\n\n    def _get_1d_calendar_list(self):\n        return get_calendar_list(\"US_ALL\")\n\n    def symbol_to_yahoo(self, symbol):\n        return fname_to_code(symbol)\n\n\nclass YahooNormalizeIN:\n    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:\n        return get_calendar_list(\"IN_ALL\")\n\n\nclass YahooNormalizeIN1d(YahooNormalizeIN, YahooNormalize1d):\n    pass\n\n\nclass YahooNormalizeIN1min(YahooNormalizeIN, YahooNormalize1min):\n    CALC_PAUSED_NUM = False\n\n    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:\n        # TODO: support 1min\n        raise ValueError(\"Does not support 1min\")\n\n    def _get_1d_calendar_list(self):\n        return get_calendar_list(\"IN_ALL\")\n\n    def symbol_to_yahoo(self, symbol):\n        return fname_to_code(symbol)\n\n\nclass YahooNormalizeCN:\n    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:\n        # TODO: from MSN\n        return get_calendar_list(\"ALL\")\n\n\nclass YahooNormalizeCN1d(YahooNormalizeCN, YahooNormalize1d):\n    pass\n\n\nclass YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend):\n    pass\n\n\nclass YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):\n    AM_RANGE = (\"09:30:00\", \"11:29:00\")\n    PM_RANGE = (\"13:00:00\", \"14:59:00\")\n\n    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:\n        return self.generate_1min_from_daily(self.calendar_list_1d)\n\n    def symbol_to_yahoo(self, symbol):\n        if \".\" not in symbol:\n            _exchange = symbol[:2]\n            _exchange = (\"ss\" if _exchange.islower() else \"SS\") if _exchange.lower() == \"sh\" else _exchange\n            symbol = symbol[2:] + \".\" + _exchange\n        return symbol\n\n    def 
_get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:\n        return get_calendar_list(\"ALL\")\n\n\nclass YahooNormalizeBR:\n    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:\n        return get_calendar_list(\"BR_ALL\")\n\n\nclass YahooNormalizeBR1d(YahooNormalizeBR, YahooNormalize1d):\n    pass\n\n\nclass YahooNormalizeBR1min(YahooNormalizeBR, YahooNormalize1min):\n    CALC_PAUSED_NUM = False\n\n    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:\n        # TODO: support 1min\n        raise ValueError(\"Does not support 1min\")\n\n    def _get_1d_calendar_list(self):\n        return get_calendar_list(\"BR_ALL\")\n\n    def symbol_to_yahoo(self, symbol):\n        return fname_to_code(symbol)\n\n\nclass Run(BaseRun):\n    def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval=\"1d\", region=REGION_CN):\n        \"\"\"\n\n        Parameters\n        ----------\n        source_dir: str\n            The directory where the raw data collected from the Internet is saved, default \"Path(__file__).parent/source\"\n        normalize_dir: str\n            Directory for normalized data, default \"Path(__file__).parent/normalize\"\n        max_workers: int\n            Number of concurrent workers, default is 1; when collecting data, it is recommended that max_workers be set to 1\n        interval: str\n            freq, value from [1min, 1d], default 1d\n        region: str\n            region, value from [\"CN\", \"US\", \"IN\", \"BR\"], default \"CN\"\n        \"\"\"\n        super().__init__(source_dir, normalize_dir, max_workers, interval)\n        self.region = region\n\n    @property\n    def collector_class_name(self):\n        return f\"YahooCollector{self.region.upper()}{self.interval}\"\n\n    @property\n    def normalize_class_name(self):\n        return f\"YahooNormalize{self.region.upper()}{self.interval}\"\n\n    @property\n    def default_base_dir(self) -> [Path, str]:\n        return CUR_DIR\n\n    def download_data(\n        self,\n 
       max_collector_count=2,\n        delay=0.5,\n        start=None,\n        end=None,\n        check_data_length=None,\n        limit_nums=None,\n    ):\n        \"\"\"download data from Internet\n\n        Parameters\n        ----------\n        max_collector_count: int\n            default 2\n        delay: float\n            time.sleep(delay), default 0.5\n        start: str\n            start datetime, default \"2000-01-01\"; closed interval (including start)\n        end: str\n            end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``; open interval (excluding end)\n        check_data_length: int\n            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.\n        limit_nums: int\n            used for debugging, by default None\n\n        Notes\n        -----\n            check_data_length, example:\n                daily, one year: 252 // 4\n                us 1min, a week: 6.5 * 60 * 5\n                cn 1min, a week: 4 * 60 * 5\n\n        Examples\n        ---------\n            # get daily data\n            $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d\n            # get 1min data\n            $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1min\n        \"\"\"\n        if self.interval == \"1d\" and pd.Timestamp(end) > pd.Timestamp(datetime.datetime.now().strftime(\"%Y-%m-%d\")):\n            raise ValueError(f\"end_date: {end} is greater than the current date.\")\n\n        super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)\n\n    def normalize_data(\n        
self,\n        date_field_name: str = \"date\",\n        symbol_field_name: str = \"symbol\",\n        end_date: str = None,\n        qlib_data_1d_dir: str = None,\n    ):\n        \"\"\"normalize data\n\n        Parameters\n        ----------\n        date_field_name: str\n            date field name, default date\n        symbol_field_name: str\n            symbol field name, default symbol\n        end_date: str\n            if not None, only normalized data up to and including end_date is saved; if None, this parameter is ignored; by default None\n        qlib_data_1d_dir: str\n            if interval==1min, qlib_data_1d_dir cannot be None, because normalizing 1min data requires 1d data;\n\n                qlib_data_1d can be obtained like this:\n                    $ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d\n                    $ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir>\n                or:\n                    download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo\n\n        Examples\n        ---------\n            $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d\n            $ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min\n        \"\"\"\n        if self.interval.lower() == \"1min\":\n            if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():\n                raise ValueError(\n                    \"To normalize 1min data, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: 
https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance\"\n                )\n        super(Run, self).normalize_data(\n            date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir\n        )\n\n    def normalize_data_1d_extend(\n        self, old_qlib_data_dir, date_field_name: str = \"date\", symbol_field_name: str = \"symbol\"\n    ):\n        \"\"\"normalize data to extend existing yahoo qlib data (from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)\n\n        Notes\n        -----\n            Steps to extend yahoo qlib data:\n\n                1. download qlib data: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data; save to <dir1>\n\n                2. collect source data: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#collector-data; save to <dir2>\n\n                3. normalize new source data (from step 2): python scripts/data_collector/yahoo/collector.py normalize_data_1d_extend --old_qlib_data_dir <dir1> --source_dir <dir2> --normalize_dir <dir3> --region CN --interval 1d\n\n                4. dump data: python scripts/dump_bin.py dump_update --data_path <dir3> --qlib_dir <dir1> --freq day --date_field_name date --symbol_field_name symbol --exclude_fields symbol,date\n\n                5. update instruments (e.g. 
csi300): python scripts/data_collector/cn_index/collector.py --index_name CSI300 --qlib_dir <dir1> --method parse_instruments\n\n        Parameters\n        ----------\n        old_qlib_data_dir: str\n            the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data\n        date_field_name: str\n            date field name, default date\n        symbol_field_name: str\n            symbol field name, default symbol\n\n        Examples\n        ---------\n            $ python collector.py normalize_data_1d_extend --old_qlib_data_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d\n        \"\"\"\n        _class = getattr(self._cur_module, f\"{self.normalize_class_name}Extend\")\n        yc = Normalize(\n            source_dir=self.source_dir,\n            target_dir=self.normalize_dir,\n            normalize_class=_class,\n            max_workers=self.max_workers,\n            date_field_name=date_field_name,\n            symbol_field_name=symbol_field_name,\n            old_qlib_data_dir=old_qlib_data_dir,\n        )\n        yc.normalize()\n\n    def download_today_data(\n        self,\n        max_collector_count=2,\n        delay=0.5,\n        check_data_length=None,\n        limit_nums=None,\n    ):\n        \"\"\"download today's data from the Internet\n\n        Parameters\n        ----------\n        max_collector_count: int\n            default 2\n        delay: float\n            time.sleep(delay), default 0.5\n        check_data_length: int\n            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). 
By default None.\n        limit_nums: int\n            used for debugging, by default None\n\n        Notes\n        -----\n            Download today's data:\n                start_time = datetime.datetime.now().date(); closed interval (including start)\n                end_time = pd.Timestamp(start_time + pd.Timedelta(days=1)).date(); open interval (excluding end)\n\n            check_data_length, example:\n                daily, one year: 252 // 4\n                us 1min, a week: 6.5 * 60 * 5\n                cn 1min, a week: 4 * 60 * 5\n\n        Examples\n        ---------\n            # get daily data\n            $ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1d\n            # get 1min data\n            $ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1min\n        \"\"\"\n        start = datetime.datetime.now().date()\n        end = pd.Timestamp(start + pd.Timedelta(days=1)).date()\n        self.download_data(\n            max_collector_count,\n            delay,\n            start.strftime(\"%Y-%m-%d\"),\n            end.strftime(\"%Y-%m-%d\"),\n            check_data_length,\n            limit_nums,\n        )\n\n    def update_data_to_bin(\n        self,\n        qlib_data_1d_dir: str,\n        end_date: str = None,\n        check_data_length: int = None,\n        delay: float = 1,\n        exists_skip: bool = False,\n    ):\n        \"\"\"update yahoo data to bin\n\n        Parameters\n        ----------\n        qlib_data_1d_dir: str\n            the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data\n        end_date: str\n            end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval (excluding end)\n        check_data_length: int\n            check data length, if not None and greater than 0, each symbol will be 
considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.\n        delay: float\n            time.sleep(delay), default 1\n        exists_skip: bool\n            passed to GetData().qlib_data as exists_skip, by default False\n\n        Notes\n        -----\n            If the data in qlib_data_1d_dir is incomplete, np.nan will be populated to trading_date for the previous trading day\n\n        Examples\n        ---------\n            $ python collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --end_date <end date>\n        \"\"\"\n\n        if self.interval.lower() != \"1d\":\n            logger.warning(\"only 1d data updates are currently supported: --interval 1d\")\n\n        # download qlib 1d data\n        qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())\n        if not exists_qlib_data(qlib_data_1d_dir):\n            GetData().qlib_data(\n                target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region, exists_skip=exists_skip\n            )\n\n        # start/end date\n        calendar_df = pd.read_csv(Path(qlib_data_1d_dir).joinpath(\"calendars/day.txt\"))\n        trading_date = (pd.Timestamp(calendar_df.iloc[-1, 0]) - pd.Timedelta(days=1)).strftime(\"%Y-%m-%d\")\n\n        if end_date is None:\n            end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime(\"%Y-%m-%d\")\n\n        # download data from yahoo\n        # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1\n        self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length)\n        # NOTE: a larger max_workers setting here would be faster\n        self.max_workers = (\n            max(multiprocessing.cpu_count() - 2, 1)\n            if self.max_workers is None or self.max_workers <= 1\n            else self.max_workers\n        
)\n        # normalize data\n        self.normalize_data_1d_extend(qlib_data_1d_dir)\n\n        # dump bin\n        _dump = DumpDataUpdate(\n            data_path=self.normalize_dir,\n            qlib_dir=qlib_data_1d_dir,\n            exclude_fields=\"symbol,date\",\n            max_workers=self.max_workers,\n        )\n        _dump.dump()\n\n        # parse index\n        _region = self.region.lower()\n        if _region not in [\"cn\", \"us\"]:\n            logger.warning(f\"Unsupported region: region={_region}, component downloads will be ignored\")\n            return\n        index_list = [\"CSI100\", \"CSI300\"] if _region == \"cn\" else [\"SP500\", \"NASDAQ100\", \"DJIA\", \"SP400\"]\n        get_instruments = getattr(\n            importlib.import_module(f\"data_collector.{_region}_index.collector\"), \"get_instruments\"\n        )\n        for _index in index_list:\n            get_instruments(str(qlib_data_1d_dir), _index, market_index=f\"{_region}_index\")\n\n\nif __name__ == \"__main__\":\n    fire.Fire(Run)\n"
  },
  {
    "path": "scripts/data_collector/yahoo/requirements.txt",
    "content": "loguru\nfire\nrequests\nnumpy\npandas\ntqdm\nlxml\nyahooquery\njoblib\nbeautifulsoup4\nbs4\nsoupsieve\nakshare"
  },
  {
    "path": "scripts/dump_bin.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport abc\nimport shutil\nimport traceback\nfrom pathlib import Path\nfrom typing import Iterable, List, Union\nfrom functools import partial\nfrom concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor\n\nimport fire\nimport numpy as np\nimport pandas as pd\nfrom tqdm import tqdm\nfrom loguru import logger\nfrom qlib.utils import fname_to_code, code_to_fname\n\n\ndef read_as_df(file_path: Union[str, Path], **kwargs) -> pd.DataFrame:\n    \"\"\"\n    Read a csv or parquet file into a pandas DataFrame.\n\n    Parameters\n    ----------\n    file_path : Union[str, Path]\n        Path to the data file.\n    **kwargs :\n        Additional keyword arguments passed to the underlying pandas\n        reader.\n\n    Returns\n    -------\n    pd.DataFrame\n    \"\"\"\n    file_path = Path(file_path).expanduser()\n    suffix = file_path.suffix.lower()\n\n    keep_keys = {\".csv\": (\"low_memory\",)}\n    kept_kwargs = {}\n    for k in keep_keys.get(suffix, []):\n        if k in kwargs:\n            kept_kwargs[k] = kwargs[k]\n\n    if suffix == \".csv\":\n        return pd.read_csv(file_path, **kept_kwargs)\n    elif suffix == \".parquet\":\n        return pd.read_parquet(file_path, **kept_kwargs)\n    else:\n        raise ValueError(f\"Unsupported file format: {suffix}\")\n\n\nclass DumpDataBase:\n    INSTRUMENTS_START_FIELD = \"start_datetime\"\n    INSTRUMENTS_END_FIELD = \"end_datetime\"\n    CALENDARS_DIR_NAME = \"calendars\"\n    FEATURES_DIR_NAME = \"features\"\n    INSTRUMENTS_DIR_NAME = \"instruments\"\n    DUMP_FILE_SUFFIX = \".bin\"\n    DAILY_FORMAT = \"%Y-%m-%d\"\n    HIGH_FREQ_FORMAT = \"%Y-%m-%d %H:%M:%S\"\n    INSTRUMENTS_SEP = \"\\t\"\n    INSTRUMENTS_FILE_NAME = \"all.txt\"\n\n    UPDATE_MODE = \"update\"\n    ALL_MODE = \"all\"\n\n    def __init__(\n        self,\n        data_path: str,\n        qlib_dir: str,\n        backup_dir: str = 
None,\n        freq: str = \"day\",\n        max_workers: int = 16,\n        date_field_name: str = \"date\",\n        file_suffix: str = \".csv\",\n        symbol_field_name: str = \"symbol\",\n        exclude_fields: str = \"\",\n        include_fields: str = \"\",\n        limit_nums: int = None,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        data_path: str\n            stock data path or directory\n        qlib_dir: str\n            qlib (dump) data directory\n        backup_dir: str, default None\n            if backup_dir is not None, backup qlib_dir to backup_dir\n        freq: str, default \"day\"\n            transaction frequency\n        max_workers: int, default 16\n            number of parallel workers\n        date_field_name: str, default \"date\"\n            the name of the date field in the source files\n        file_suffix: str, default \".csv\"\n            file suffix\n        symbol_field_name: str, default \"symbol\"\n            symbol field name\n        include_fields: str or tuple\n            fields to dump; a str is split on commas\n        exclude_fields: str or tuple\n            fields not to dump; a str is split on commas\n        limit_nums: int\n            Use when debugging, default None\n        \"\"\"\n        data_path = Path(data_path).expanduser()\n        if isinstance(exclude_fields, str):\n            exclude_fields = exclude_fields.split(\",\")\n        if isinstance(include_fields, str):\n            include_fields = include_fields.split(\",\")\n        self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))\n        self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))\n        self.file_suffix = file_suffix\n        self.symbol_field_name = symbol_field_name\n        self.df_files = sorted(data_path.glob(f\"*{self.file_suffix}\") if data_path.is_dir() else [data_path])\n        if limit_nums is not None:\n            self.df_files = self.df_files[: int(limit_nums)]\n        self.qlib_dir = 
Path(qlib_dir).expanduser()\n        self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()\n        if backup_dir is not None:\n            self._backup_qlib_dir(Path(backup_dir).expanduser())\n\n        self.freq = freq\n        self.calendar_format = self.DAILY_FORMAT if self.freq == \"day\" else self.HIGH_FREQ_FORMAT\n\n        self.works = max_workers\n        self.date_field_name = date_field_name\n\n        self._calendars_dir = self.qlib_dir.joinpath(self.CALENDARS_DIR_NAME)\n        self._features_dir = self.qlib_dir.joinpath(self.FEATURES_DIR_NAME)\n        self._instruments_dir = self.qlib_dir.joinpath(self.INSTRUMENTS_DIR_NAME)\n\n        self._calendars_list = []\n\n        self._mode = self.ALL_MODE\n        self._kwargs = {}\n\n    def _backup_qlib_dir(self, target_dir: Path):\n        shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve()))\n\n    def _format_datetime(self, datetime_d: [str, pd.Timestamp]):\n        datetime_d = pd.Timestamp(datetime_d)\n        return datetime_d.strftime(self.calendar_format)\n\n    def _get_date(\n        self, file_or_df: [Path, pd.DataFrame], *, is_begin_end: bool = False, as_set: bool = False\n    ) -> Iterable[pd.Timestamp]:\n        if not isinstance(file_or_df, pd.DataFrame):\n            df = self._get_source_data(file_or_df)\n        else:\n            df = file_or_df\n        if df.empty or self.date_field_name not in df.columns.tolist():\n            _calendars = pd.Series(dtype=np.float32)\n        else:\n            _calendars = df[self.date_field_name]\n\n        if is_begin_end and as_set:\n            return (_calendars.min(), _calendars.max()), set(_calendars)\n        elif is_begin_end:\n            return _calendars.min(), _calendars.max()\n        elif as_set:\n            return set(_calendars)\n        else:\n            return _calendars.tolist()\n\n    def _get_source_data(self, file_path: Path) -> pd.DataFrame:\n        df = 
read_as_df(file_path, low_memory=False)\n        if self.date_field_name in df.columns:\n            df[self.date_field_name] = pd.to_datetime(df[self.date_field_name])\n        # df.drop_duplicates([self.date_field_name], inplace=True)\n        return df\n\n    def get_symbol_from_file(self, file_path: Path) -> str:\n        return fname_to_code(file_path.stem.strip().lower())\n\n    def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:\n        return (\n            self._include_fields\n            if self._include_fields\n            else set(df_columns) - set(self._exclude_fields) if self._exclude_fields else df_columns\n        )\n\n    @staticmethod\n    def _read_calendars(calendar_path: Path) -> List[pd.Timestamp]:\n        return sorted(\n            map(\n                pd.Timestamp,\n                pd.read_csv(calendar_path, header=None).loc[:, 0].tolist(),\n            )\n        )\n\n    def _read_instruments(self, instrument_path: Path) -> pd.DataFrame:\n        df = pd.read_csv(\n            instrument_path,\n            sep=self.INSTRUMENTS_SEP,\n            names=[\n                self.symbol_field_name,\n                self.INSTRUMENTS_START_FIELD,\n                self.INSTRUMENTS_END_FIELD,\n            ],\n        )\n\n        return df\n\n    def save_calendars(self, calendars_data: list):\n        self._calendars_dir.mkdir(parents=True, exist_ok=True)\n        calendars_path = str(self._calendars_dir.joinpath(f\"{self.freq}.txt\").expanduser().resolve())\n        result_calendars_list = [self._format_datetime(x) for x in calendars_data]\n        np.savetxt(calendars_path, result_calendars_list, fmt=\"%s\", encoding=\"utf-8\")\n\n    def save_instruments(self, instruments_data: Union[list, pd.DataFrame]):\n        self._instruments_dir.mkdir(parents=True, exist_ok=True)\n        instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())\n        if isinstance(instruments_data, 
pd.DataFrame):\n            _df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]\n            instruments_data = instruments_data.loc[:, _df_fields]\n            instruments_data[self.symbol_field_name] = instruments_data[self.symbol_field_name].apply(\n                lambda x: fname_to_code(x.lower()).upper()\n            )\n            instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP, index=False)\n        else:\n            np.savetxt(instruments_path, instruments_data, fmt=\"%s\", encoding=\"utf-8\")\n\n    def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestamp]) -> pd.DataFrame:\n        # calendars\n        calendars_df = pd.DataFrame(data=calendars_list, columns=[self.date_field_name])\n        calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(\"datetime64[ns]\")\n        cal_df = calendars_df[\n            (calendars_df[self.date_field_name] >= df[self.date_field_name].min())\n            & (calendars_df[self.date_field_name] <= df[self.date_field_name].max())\n        ]\n        # align index\n        cal_df.set_index(self.date_field_name, inplace=True)\n        df.set_index(self.date_field_name, inplace=True)\n        r_df = df.reindex(cal_df.index)\n        return r_df\n\n    @staticmethod\n    def get_datetime_index(df: pd.DataFrame, calendar_list: List[pd.Timestamp]) -> int:\n        return calendar_list.index(df.index.min())\n\n    def _data_to_bin(self, df: pd.DataFrame, calendar_list: List[pd.Timestamp], features_dir: Path):\n        if df.empty:\n            logger.warning(f\"{features_dir.name} data is None or empty\")\n            return\n        if not calendar_list:\n            logger.warning(\"calendar_list is empty\")\n            return\n        # align index\n        _df = self.data_merge_calendar(df, calendar_list)\n        if _df.empty:\n            logger.warning(f\"{features_dir.name} data is not in 
calendars\")\n            return\n        # calendar index of the first date; written as the first value of a new bin file\n        date_index = self.get_datetime_index(_df, calendar_list)\n        for field in self.get_dump_fields(_df.columns):\n            bin_path = features_dir.joinpath(f\"{field.lower()}.{self.freq}{self.DUMP_FILE_SUFFIX}\")\n            if field not in _df.columns:\n                continue\n            if bin_path.exists() and self._mode == self.UPDATE_MODE:\n                # update: append the new values to the existing bin file\n                with bin_path.open(\"ab\") as fp:\n                    np.array(_df[field]).astype(\"<f\").tofile(fp)\n            else:\n                # write a new bin file; self._mode == self.ALL_MODE or not bin_path.exists()\n                np.hstack([date_index, _df[field]]).astype(\"<f\").tofile(str(bin_path.resolve()))\n\n    def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.Timestamp]):\n        if not calendar_list:\n            logger.warning(\"calendar_list is empty\")\n            return\n        if isinstance(file_or_data, pd.DataFrame):\n            if file_or_data.empty:\n                return\n            code = fname_to_code(str(file_or_data.iloc[0][self.symbol_field_name]).lower())\n            df = file_or_data\n        elif isinstance(file_or_data, Path):\n            code = self.get_symbol_from_file(file_or_data)\n            df = self._get_source_data(file_or_data)\n        else:\n            raise ValueError(f\"not support {type(file_or_data)}\")\n        if df is None or df.empty:\n            logger.warning(f\"{code} data is None or empty\")\n            return\n\n        # remove duplicate dates; otherwise reindex will raise an exception\n        df = df.drop_duplicates(self.date_field_name)\n\n        # features save dir\n        features_dir = self._features_dir.joinpath(code_to_fname(code).lower())\n        features_dir.mkdir(parents=True, exist_ok=True)\n        self._data_to_bin(df, calendar_list, features_dir)\n\n    @abc.abstractmethod\n    def dump(self):\n   
     raise NotImplementedError(\"dump not implemented!\")\n\n    def __call__(self, *args, **kwargs):\n        self.dump()\n\n\nclass DumpDataAll(DumpDataBase):\n    def _get_all_date(self):\n        logger.info(\"start get all date......\")\n        all_datetime = set()\n        date_range_list = []\n        _fun = partial(self._get_date, as_set=True, is_begin_end=True)\n        with tqdm(total=len(self.df_files)) as p_bar:\n            with ProcessPoolExecutor(max_workers=self.works) as executor:\n                for file_path, ((_begin_time, _end_time), _set_calendars) in zip(\n                    self.df_files, executor.map(_fun, self.df_files)\n                ):\n                    all_datetime = all_datetime | _set_calendars\n                    if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):\n                        _begin_time = self._format_datetime(_begin_time)\n                        _end_time = self._format_datetime(_end_time)\n                        symbol = self.get_symbol_from_file(file_path)\n                        _inst_fields = [symbol.upper(), _begin_time, _end_time]\n                        date_range_list.append(f\"{self.INSTRUMENTS_SEP.join(_inst_fields)}\")\n                    p_bar.update()\n        self._kwargs[\"all_datetime_set\"] = all_datetime\n        self._kwargs[\"date_range_list\"] = date_range_list\n        logger.info(\"end of get all date.\\n\")\n\n    def _dump_calendars(self):\n        logger.info(\"start dump calendars......\")\n        self._calendars_list = sorted(map(pd.Timestamp, self._kwargs[\"all_datetime_set\"]))\n        self.save_calendars(self._calendars_list)\n        logger.info(\"end of calendars dump.\\n\")\n\n    def _dump_instruments(self):\n        logger.info(\"start dump instruments......\")\n        self.save_instruments(self._kwargs[\"date_range_list\"])\n        logger.info(\"end of instruments dump.\\n\")\n\n    def _dump_features(self):\n        
logger.info(\"start dump features......\")\n        _dump_func = partial(self._dump_bin, calendar_list=self._calendars_list)\n        with tqdm(total=len(self.df_files)) as p_bar:\n            with ProcessPoolExecutor(max_workers=self.works) as executor:\n                for _ in executor.map(_dump_func, self.df_files):\n                    p_bar.update()\n\n        logger.info(\"end of features dump.\\n\")\n\n    def dump(self):\n        self._get_all_date()\n        self._dump_calendars()\n        self._dump_instruments()\n        self._dump_features()\n\n\nclass DumpDataFix(DumpDataAll):\n    def _dump_instruments(self):\n        logger.info(\"start dump instruments......\")\n        _fun = partial(self._get_date, is_begin_end=True)\n        new_stock_files = sorted(\n            filter(\n                lambda x: self.get_symbol_from_file(x).upper() not in self._old_instruments,\n                self.df_files,\n            )\n        )\n        with tqdm(total=len(new_stock_files)) as p_bar:\n            with ProcessPoolExecutor(max_workers=self.works) as execute:\n                for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):\n                    if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):\n                        symbol = self.get_symbol_from_file(file_path).upper()\n                        _dt_map = self._old_instruments.setdefault(symbol, dict())\n                        _dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)\n                        _dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)\n                    p_bar.update()\n        _inst_df = pd.DataFrame.from_dict(self._old_instruments, orient=\"index\")\n        _inst_df.index.names = [self.symbol_field_name]\n        self.save_instruments(_inst_df.reset_index())\n        logger.info(\"end of instruments dump.\\n\")\n\n    def dump(self):\n        
self._calendars_list = self._read_calendars(self._calendars_dir.joinpath(f\"{self.freq}.txt\"))\n        # noinspection PyAttributeOutsideInit\n        self._old_instruments = (\n            self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))\n            .set_index([self.symbol_field_name])\n            .to_dict(orient=\"index\")\n        )  # type: dict\n        self._dump_instruments()\n        self._dump_features()\n\n\nclass DumpDataUpdate(DumpDataBase):\n    def __init__(\n        self,\n        data_path: str,\n        qlib_dir: str,\n        backup_dir: str = None,\n        freq: str = \"day\",\n        max_workers: int = 16,\n        date_field_name: str = \"date\",\n        file_suffix: str = \".csv\",\n        symbol_field_name: str = \"symbol\",\n        exclude_fields: str = \"\",\n        include_fields: str = \"\",\n        limit_nums: int = None,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        data_path: str\n            stock data path or directory\n        qlib_dir: str\n            qlib (dump) data directory\n        backup_dir: str, default None\n            if backup_dir is not None, backup qlib_dir to backup_dir\n        freq: str, default \"day\"\n            transaction frequency\n        max_workers: int, default 16\n            number of parallel workers\n        date_field_name: str, default \"date\"\n            the name of the date field in the source files\n        file_suffix: str, default \".csv\"\n            file suffix\n        symbol_field_name: str, default \"symbol\"\n            symbol field name\n        include_fields: str or tuple\n            fields to dump; a str is split on commas\n        exclude_fields: str or tuple\n            fields not to dump; a str is split on commas\n        limit_nums: int\n            Use when debugging, default None\n        \"\"\"\n        super().__init__(\n            data_path,\n            qlib_dir,\n            backup_dir,\n            freq,\n            max_workers,\n            date_field_name,\n            
file_suffix,\n            symbol_field_name,\n            exclude_fields,\n            include_fields,\n            limit_nums,\n        )\n        self._mode = self.UPDATE_MODE\n        self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f\"{self.freq}.txt\"))\n        # NOTE: all.txt only exists once for each stock\n        # NOTE: if a stock corresponds to multiple different time ranges, users need to modify self._update_instruments\n        self._update_instruments = (\n            self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))\n            .set_index([self.symbol_field_name])\n            .to_dict(orient=\"index\")\n        )  # type: dict\n\n        # load all source files\n        self._all_data = self._load_all_source_data()  # type: pd.DataFrame\n        self._new_calendar_list = self._old_calendar_list + sorted(\n            filter(lambda x: x > self._old_calendar_list[-1], self._all_data[self.date_field_name].unique())\n        )\n\n    def _load_all_source_data(self):\n        # NOTE: all source data is loaded into memory at once\n        logger.info(\"start load all source data....\")\n        all_df = []\n\n        def _read_df(file_path: Path):\n            _df = read_as_df(file_path)\n            if self.date_field_name in _df.columns and not np.issubdtype(\n                _df[self.date_field_name].dtype, np.datetime64\n            ):\n                _df[self.date_field_name] = pd.to_datetime(_df[self.date_field_name])\n            if self.symbol_field_name not in _df.columns:\n                _df[self.symbol_field_name] = self.get_symbol_from_file(file_path)\n            return _df\n\n        with tqdm(total=len(self.df_files)) as p_bar:\n            with ThreadPoolExecutor(max_workers=self.works) as executor:\n                for df in executor.map(_read_df, self.df_files):\n                    if not df.empty:\n                        all_df.append(df)\n                    
p_bar.update()\n\n        logger.info(\"end of load all data.\\n\")\n  
      return pd.concat(all_df, sort=False)\n\n    def _dump_calendars(self):\n        pass\n\n    def _dump_instruments(self):\n        pass\n\n    def _dump_features(self):\n        logger.info(\"start dump features......\")\n        error_code = {}\n        with ProcessPoolExecutor(max_workers=self.works) as executor:\n            futures = {}\n            for _code, _df in self._all_data.groupby(self.symbol_field_name, group_keys=False):\n                _code = fname_to_code(str(_code).lower()).upper()\n                _start, _end = self._get_date(_df, is_begin_end=True)\n                if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):\n                    continue\n                if _code in self._update_instruments:\n                    # existing stock: append only the dates after its current end date\n                    _update_calendars = (\n                        _df[_df[self.date_field_name] > self._update_instruments[_code][self.INSTRUMENTS_END_FIELD]][\n                            self.date_field_name\n                        ]\n                        .sort_values()\n                        .to_list()\n                    )\n                    if _update_calendars:\n                        self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)\n                        futures[executor.submit(self._dump_bin, _df, _update_calendars)] = _code\n                else:\n                    # new stock\n                    _dt_range = self._update_instruments.setdefault(_code, dict())\n                    _dt_range[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_start)\n                    _dt_range[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)\n                    futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code\n\n            with tqdm(total=len(futures)) as p_bar:\n                for _future in as_completed(futures):\n                    try:\n                        
_future.result()\n                    except Exception:\n                        error_code[futures[_future]] = traceback.format_exc()\n                    p_bar.update()\n            logger.info(f\"dump bin errors: {error_code}\")\n\n        logger.info(\"end of features dump.\\n\")\n\n    def dump(self):\n        self.save_calendars(self._new_calendar_list)\n        self._dump_features()\n        df = pd.DataFrame.from_dict(self._update_instruments, orient=\"index\")\n        df.index.names = [self.symbol_field_name]\n        self.save_instruments(df.reset_index())\n\n\nif __name__ == \"__main__\":\n    fire.Fire({\"dump_all\": DumpDataAll, \"dump_fix\": DumpDataFix, \"dump_update\": DumpDataUpdate})\n"
  },
  {
    "path": "scripts/dump_pit.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\"\"\"\nTODO:\n- A better-designed PIT database is required.\n    - separate insert, delete, update, and query operations are required.\n\"\"\"\n\nimport shutil\nimport struct\nfrom pathlib import Path\nfrom typing import Iterable\nfrom functools import partial\nfrom concurrent.futures import ProcessPoolExecutor\n\nimport fire\nimport pandas as pd\nfrom tqdm import tqdm\nfrom loguru import logger\nfrom qlib.utils import fname_to_code, get_period_offset\nfrom qlib.config import C\n\n\nclass DumpPitData:\n    PIT_DIR_NAME = \"financial\"\n    PIT_CSV_SEP = \",\"\n    DATA_FILE_SUFFIX = \".data\"\n    INDEX_FILE_SUFFIX = \".index\"\n\n    INTERVAL_quarterly = \"quarterly\"\n    INTERVAL_annual = \"annual\"\n\n    PERIOD_DTYPE = C.pit_record_type[\"period\"]\n    INDEX_DTYPE = C.pit_record_type[\"index\"]\n    DATA_DTYPE = \"\".join(\n        [\n            C.pit_record_type[\"date\"],\n            C.pit_record_type[\"period\"],\n            C.pit_record_type[\"value\"],\n            C.pit_record_type[\"index\"],\n        ]\n    )\n\n    NA_INDEX = C.pit_record_nan[\"index\"]\n\n    INDEX_DTYPE_SIZE = struct.calcsize(INDEX_DTYPE)\n    PERIOD_DTYPE_SIZE = struct.calcsize(PERIOD_DTYPE)\n    DATA_DTYPE_SIZE = struct.calcsize(DATA_DTYPE)\n\n    UPDATE_MODE = \"update\"\n    ALL_MODE = \"all\"\n\n    def __init__(\n        self,\n        csv_path: str,\n        qlib_dir: str,\n        backup_dir: str = None,\n        freq: str = \"quarterly\",\n        max_workers: int = 16,\n        date_column_name: str = \"date\",\n        period_column_name: str = \"period\",\n        value_column_name: str = \"value\",\n        field_column_name: str = \"field\",\n        file_suffix: str = \".csv\",\n        exclude_fields: str = \"\",\n        include_fields: str = \"\",\n        limit_nums: int = None,\n    ):\n        \"\"\"\n\n        Parameters\n        ----------\n        csv_path: str\n   
         path of a csv file or a directory containing csv files\n        qlib_dir: str\n            qlib(dump) data directory\n        backup_dir: str, default None\n            if backup_dir is not None, back up qlib_dir to backup_dir\n        freq: str, default \"quarterly\"\n            data frequency\n        max_workers: int, default 16\n            number of worker processes\n        date_column_name: str, default \"date\"\n            the name of the date field in the csv\n        period_column_name: str, default \"period\"\n            the name of the period field in the csv\n        value_column_name: str, default \"value\"\n            the name of the value field in the csv\n        field_column_name: str, default \"field\"\n            the name of the column holding the field names in the csv\n        file_suffix: str, default \".csv\"\n            file suffix\n        include_fields: str or tuple\n            fields to dump; a comma-separated string is also accepted\n        exclude_fields: str or tuple\n            fields not to dump; a comma-separated string is also accepted\n        limit_nums: int, default None\n            limit the number of files to process; use when debugging\n        \"\"\"\n        csv_path = Path(csv_path).expanduser()\n        if isinstance(exclude_fields, str):\n            exclude_fields = exclude_fields.split(\",\")\n        if isinstance(include_fields, str):\n            include_fields = include_fields.split(\",\")\n        self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))\n        self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))\n        self.file_suffix = file_suffix\n        self.csv_files = sorted(csv_path.glob(f\"*{self.file_suffix}\") if csv_path.is_dir() else [csv_path])\n        if limit_nums is not None:\n            self.csv_files = self.csv_files[: int(limit_nums)]\n        self.qlib_dir = Path(qlib_dir).expanduser()\n        self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()\n        if backup_dir is not None:\n            self._backup_qlib_dir(Path(backup_dir).expanduser())\n\n        self.works = max_workers\n        self.date_column_name = date_column_name\n        self.period_column_name = period_column_name\n        self.value_column_name = value_column_name\n        self.field_column_name = field_column_name\n\n        self._mode = self.ALL_MODE\n\n    def _backup_qlib_dir(self, 
target_dir: Path):\n        shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve()))\n\n    def get_source_data(self, file_path: Path) -> pd.DataFrame:\n        df = pd.read_csv(str(file_path.resolve()), low_memory=False)\n        df[self.value_column_name] = df[self.value_column_name].astype(\"float32\")\n        df[self.date_column_name] = df[self.date_column_name].str.replace(\"-\", \"\").astype(\"int32\")\n        # df.drop_duplicates([self.date_field_name], inplace=True)\n        return df\n\n    def get_symbol_from_file(self, file_path: Path) -> str:\n        return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())\n\n    def get_dump_fields(self, df: pd.DataFrame) -> Iterable[str]:\n        return (\n            set(self._include_fields)\n            if self._include_fields\n            else (\n                set(df[self.field_column_name]) - set(self._exclude_fields)\n                if self._exclude_fields\n                else set(df[self.field_column_name])\n            )\n        )\n\n    def get_filenames(self, symbol, field, interval):\n        dir_name = self.qlib_dir.joinpath(self.PIT_DIR_NAME, symbol)\n        dir_name.mkdir(parents=True, exist_ok=True)\n        return (\n            dir_name.joinpath(f\"{field}_{interval[0]}{self.DATA_FILE_SUFFIX}\".lower()),\n            dir_name.joinpath(f\"{field}_{interval[0]}{self.INDEX_FILE_SUFFIX}\".lower()),\n        )\n\n    def _dump_pit(\n        self,\n        file_path: Path,\n        interval: str = \"quarterly\",\n        overwrite: bool = False,\n    ):\n        \"\"\"\n        dump data in the following format:\n            `/path/to/<field>.data`\n                [date, period, value, _next]\n                [date, period, value, _next]\n                [...]\n            `/path/to/<field>.index`\n                [first_year, index, index, ...]\n\n        `<field>.data` stores the data in point-in-time (PIT) order: `value` of `period`\n        is 
published at `date`; later revisions of the same period can be found by following `_next` (a linked list).\n\n        `<field>.index` stores, for each period (quarter or year), the byte offset of its first record in `<field>.data`. To save\n        disk space, only `first_year` is stored, since the following periods can be inferred from their position.\n\n        Parameters\n        ----------\n        file_path: Path\n            path to the csv file of a single stock\n        interval: str\n            data interval, \"quarterly\" or \"annual\"\n        overwrite: bool\n            whether to overwrite existing data or only append new records\n        \"\"\"\n        symbol = self.get_symbol_from_file(file_path)\n        df = self.get_source_data(file_path)\n        if df.empty:\n            logger.warning(f\"{symbol} file is empty\")\n            return\n        for field in self.get_dump_fields(df):\n            df_sub = df.query(f'{self.field_column_name}==\"{field}\"').sort_values(self.date_column_name)\n            if df_sub.empty:\n                logger.warning(f\"field {field} of {symbol} is empty\")\n                continue\n            data_file, index_file = self.get_filenames(symbol, field, interval)\n\n            ## calculate first & last period\n            start_year = df_sub[self.period_column_name].min()\n            end_year = df_sub[self.period_column_name].max()\n            if interval == self.INTERVAL_quarterly:\n                start_year //= 100\n                end_year //= 100\n\n            # adjust `first_year` if existing data found\n            if not overwrite and index_file.exists():\n                with open(index_file, \"rb\") as fi:\n                    (first_year,) = struct.unpack(self.PERIOD_DTYPE, fi.read(self.PERIOD_DTYPE_SIZE))\n                    n_years = len(fi.read()) // self.INDEX_DTYPE_SIZE\n                    if interval == self.INTERVAL_quarterly:\n                        n_years //= 4\n                    start_year = first_year + n_years\n            else:\n                with open(index_file, \"wb\") as f:\n                    
f.write(struct.pack(self.PERIOD_DTYPE, start_year))\n                first_year = start_year\n\n            # if data already exists, continue to the next field\n            if start_year > end_year:\n                logger.warning(f\"{symbol}-{field} data already exists, continue to the next field\")\n                continue\n\n            # dump index filled with NA\n            with open(index_file, \"ab\") as fi:\n                for year in range(start_year, end_year + 1):\n                    if interval == self.INTERVAL_quarterly:\n                        fi.write(struct.pack(self.INDEX_DTYPE * 4, *[self.NA_INDEX] * 4))\n                    else:\n                        fi.write(struct.pack(self.INDEX_DTYPE, self.NA_INDEX))\n\n            # if data already exists, remove overlapped data\n            if not overwrite and data_file.exists():\n                with open(data_file, \"rb\") as fd:\n                    fd.seek(-self.DATA_DTYPE_SIZE, 2)\n                    last_date, _, _, _ = struct.unpack(self.DATA_DTYPE, fd.read())\n                df_sub = df_sub.query(f\"{self.date_column_name}>{last_date}\")\n            # otherwise,\n            # 1) truncate existing file or create a new file with `wb+` if overwrite,\n            # 2) or append existing file or create a new file with `ab+` if not overwrite\n            else:\n                with open(data_file, \"wb+\" if overwrite else \"ab+\"):\n                    pass\n\n            with open(data_file, \"rb+\") as fd, open(index_file, \"rb+\") as fi:\n                # update index if needed\n                for i, row in df_sub.iterrows():\n                    # get index\n                    offset = get_period_offset(first_year, row.period, interval == self.INTERVAL_quarterly)\n\n                    fi.seek(self.PERIOD_DTYPE_SIZE + self.INDEX_DTYPE_SIZE * offset)\n                    (cur_index,) = struct.unpack(self.INDEX_DTYPE, fi.read(self.INDEX_DTYPE_SIZE))\n\n                    # Case I: 
new data => update `_next` with current index\n                    if cur_index == self.NA_INDEX:\n                        fi.seek(self.PERIOD_DTYPE_SIZE + self.INDEX_DTYPE_SIZE * offset)\n                        fi.write(struct.pack(self.INDEX_DTYPE, fd.tell()))\n                    # Case II: previous data exists => find and update the last `_next`\n                    else:\n                        _cur_fd = fd.tell()\n                        prev_index = self.NA_INDEX\n                        while cur_index != self.NA_INDEX:  # NOTE: first iter always != NA_INDEX\n                            fd.seek(cur_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE)\n                            prev_index = cur_index\n                            (cur_index,) = struct.unpack(self.INDEX_DTYPE, fd.read(self.INDEX_DTYPE_SIZE))\n                        fd.seek(prev_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE)\n                        fd.write(struct.pack(self.INDEX_DTYPE, _cur_fd))  # NOTE: add _next pointer\n                        fd.seek(_cur_fd)\n\n                    # dump data\n                    fd.write(struct.pack(self.DATA_DTYPE, row.date, row.period, row.value, self.NA_INDEX))\n\n    def dump(self, interval=\"quarterly\", overwrite=False):\n        logger.info(\"start dump pit data......\")\n        _dump_func = partial(self._dump_pit, interval=interval, overwrite=overwrite)\n\n        with tqdm(total=len(self.csv_files)) as p_bar:\n            with ProcessPoolExecutor(max_workers=self.works) as executor:\n                for _ in executor.map(_dump_func, self.csv_files):\n                    p_bar.update()\n\n    def __call__(self, *args, **kwargs):\n        self.dump()\n\n\nif __name__ == \"__main__\":\n    fire.Fire(DumpPitData)\n"
  },
  {
    "path": "scripts/get_data.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nimport fire\nfrom qlib.tests.data import GetData\n\nif __name__ == \"__main__\":\n    fire.Fire(GetData)\n"
  },
  {
    "path": "setup.py",
    "content": "import os\n\nimport numpy\nfrom setuptools import Extension, setup\n\nNUMPY_INCLUDE = numpy.get_include()\n\n\nsetup(\n    ext_modules=[\n        Extension(\n            \"qlib.data._libs.rolling\",\n            [\"qlib/data/_libs/rolling.pyx\"],\n            language=\"c++\",\n            include_dirs=[NUMPY_INCLUDE],\n        ),\n        Extension(\n            \"qlib.data._libs.expanding\",\n            [\"qlib/data/_libs/expanding.pyx\"],\n            language=\"c++\",\n            include_dirs=[NUMPY_INCLUDE],\n        ),\n    ],\n)\n"
  },
  {
    "path": "tests/backtest/test_file_strategy.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport unittest\nfrom qlib.backtest import backtest\nfrom qlib.tests import TestAutoData\nimport pandas as pd\nfrom pathlib import Path\nfrom qlib.data import D\nimport numpy as np\n\nDIRNAME = Path(__file__).absolute().resolve().parent\n\n\nclass FileStrTest(TestAutoData):\n    # Assumption to ensure the correctness of the test\n    # - No price adjustment in these several trading days.\n    TEST_INST = \"SH600519\"\n\n    EXAMPLE_FILE = DIRNAME / \"order_example.csv\"\n\n    def _gen_orders(self, dealt_num_for_1000) -> pd.DataFrame:\n        headers = [\n            \"datetime\",\n            \"instrument\",\n            \"amount\",\n            \"direction\",\n        ]\n        orders = [\n            # test cash limit for buying\n            [\"20200103\", self.TEST_INST, \"1000\", \"buy\"],\n            # test min_cost for buying\n            [\"20200106\", self.TEST_INST, \"1\", \"buy\"],\n            # test held stock limit for selling\n            [\"20200107\", self.TEST_INST, \"1000\", \"sell\"],\n            # test cash limit for buying\n            [\"20200108\", self.TEST_INST, \"1000\", \"buy\"],\n            # test min_cost for selling\n            [\"20200109\", self.TEST_INST, \"1\", \"sell\"],\n            # test selling all stocks\n            [\"20200110\", self.TEST_INST, str(dealt_num_for_1000), \"sell\"],\n        ]\n        return pd.DataFrame(orders, columns=headers).set_index([\"datetime\", \"instrument\"])\n\n    def test_file_str(self):\n        # 0) basic settings\n        account_money = 150000\n\n        # 1) get information\n        df = D.features([self.TEST_INST], [\"$close\", \"$factor\"], start_time=\"20200103\", end_time=\"20200103\")\n        price = df[\"$close\"].item()\n        factor = df[\"$factor\"].item()\n        price_unit = price / factor * 100\n        dealt_num_for_1000 = (account_money // price_unit) * (100 / factor)\n     
   print(price, factor, price_unit, dealt_num_for_1000)\n\n        # 2) generate orders\n        orders = self._gen_orders(dealt_num_for_1000)\n        orders.to_csv(self.EXAMPLE_FILE)\n        print(orders)\n\n        # 3) run the strategy\n        strategy_config = {\n            \"class\": \"FileOrderStrategy\",\n            \"module_path\": \"qlib.contrib.strategy.rule_strategy\",\n            \"kwargs\": {\"file\": self.EXAMPLE_FILE},\n        }\n\n        freq = \"day\"\n        start_time = \"2020-01-01\"\n        end_time = \"2020-01-16\"\n        codes = [self.TEST_INST]\n\n        backtest_config = {\n            \"start_time\": start_time,\n            \"end_time\": end_time,\n            \"account\": account_money,\n            \"benchmark\": None,  # benchmark is not required here for trading\n            \"exchange_kwargs\": {\n                \"freq\": freq,\n                \"limit_threshold\": 0.095,\n                \"deal_price\": \"close\",\n                \"open_cost\": 0.0005,\n                \"close_cost\": 0.0015,\n                \"min_cost\": 500,\n                \"codes\": codes,\n                \"trade_unit\": 100,\n            },\n            # \"pos_type\": \"InfPosition\"  # Position with infinite position\n        }\n        executor_config = {\n            \"class\": \"SimulatorExecutor\",\n            \"module_path\": \"qlib.backtest.executor\",\n            \"kwargs\": {\n                \"time_per_step\": freq,\n                \"generate_portfolio_metrics\": False,\n                \"verbose\": True,\n                \"indicator_config\": {\n                    \"show_indicator\": False,\n                },\n            },\n        }\n        report_dict, indicator_dict = backtest(\n            executor=executor_config,\n            strategy=strategy_config,\n            **backtest_config,\n        )\n\n        # validate ffr\n        ffr_dict = indicator_dict[\"1day\"][0][\"ffr\"].to_dict()\n        ffr_dict = 
{str(date).split()[0]: ffr_dict[date] for date in ffr_dict}\n        assert np.isclose(ffr_dict[\"2020-01-03\"], dealt_num_for_1000 / 1000)\n        assert np.isclose(ffr_dict[\"2020-01-06\"], 0)\n        assert np.isclose(ffr_dict[\"2020-01-07\"], dealt_num_for_1000 / 1000)\n        assert np.isclose(ffr_dict[\"2020-01-08\"], dealt_num_for_1000 / 1000)\n        assert np.isclose(ffr_dict[\"2020-01-09\"], 0)\n        assert np.isclose(ffr_dict[\"2020-01-10\"], 1)\n\n        self.EXAMPLE_FILE.unlink()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/backtest/test_high_freq_trading.py",
    "content": "from typing import List, Tuple, Union\nfrom qlib.backtest.position import Position\nfrom qlib.backtest import collect_data, format_decisions\nfrom qlib.backtest.decision import BaseTradeDecision, TradeRangeByTime\nimport qlib\nfrom qlib.tests import TestAutoData\nimport unittest\nimport pandas as pd\n\n\n@unittest.skip(\"This test takes a lot of time due to the large size of high-frequency data\")\nclass TestHFBacktest(TestAutoData):\n    @classmethod\n    def setUpClass(cls) -> None:\n        super().setUpClass(enable_1min=True, enable_1d_type=\"full\")\n\n    def _gen_orders(self, inst, date, pos) -> pd.DataFrame:\n        headers = [\n            \"datetime\",\n            \"instrument\",\n            \"amount\",\n            \"direction\",\n        ]\n        orders = [\n            [date, inst, pos, \"sell\"],\n        ]\n        return pd.DataFrame(orders, columns=headers)\n\n    def test_trading(self):\n        # date = \"2020-02-03\"\n        # inst = \"SH600068\"\n        # pos = 2.0167\n        pos = 100000\n        inst, date = \"SH600519\", \"2021-01-18\"\n        market = [inst]\n\n        start_time = f\"{date}\"\n        end_time = f\"{date} 15:00\"  # include the high-freq data on the end day\n        freq_l0 = \"day\"\n        freq_l1 = \"30min\"\n        freq_l2 = \"1min\"\n\n        orders = self._gen_orders(inst=inst, date=date, pos=pos * 0.90)\n\n        strategy_config = {\n            \"class\": \"FileOrderStrategy\",\n            \"module_path\": \"qlib.contrib.strategy.rule_strategy\",\n            \"kwargs\": {\n                \"trade_range\": TradeRangeByTime(\"10:45\", \"14:44\"),\n                \"file\": orders,\n            },\n        }\n        backtest_config = {\n            \"start_time\": start_time,\n            \"end_time\": end_time,\n            \"account\": {\n                \"cash\": 0,\n                inst: pos,\n            },\n            \"benchmark\": None,  # benchmark is not required here for 
trading\n            \"exchange_kwargs\": {\n                \"freq\": freq_l2,  # use the finest-grained data for the exchange\n                \"limit_threshold\": 0.095,\n                \"deal_price\": \"close\",\n                \"open_cost\": 0.0005,\n                \"close_cost\": 0.0015,\n                \"min_cost\": 5,\n                \"codes\": market,\n                \"trade_unit\": 100,\n            },\n            # \"pos_type\": \"InfPosition\"  # Position with infinite position\n        }\n        executor_config = {\n            \"class\": \"NestedExecutor\",  # Level 1 order execution\n            \"module_path\": \"qlib.backtest.executor\",\n            \"kwargs\": {\n                \"time_per_step\": freq_l0,\n                \"inner_executor\": {\n                    \"class\": \"NestedExecutor\",  # Level 2 order execution\n                    \"module_path\": \"qlib.backtest.executor\",\n                    \"kwargs\": {\n                        \"time_per_step\": freq_l1,\n                        \"inner_executor\": {\n                            \"class\": \"SimulatorExecutor\",\n                            \"module_path\": \"qlib.backtest.executor\",\n                            \"kwargs\": {\n                                \"time_per_step\": freq_l2,\n                                \"generate_portfolio_metrics\": False,\n                                \"verbose\": True,\n                                \"indicator_config\": {\n                                    \"show_indicator\": False,\n                                },\n                                \"track_data\": True,\n                            },\n                        },\n                        \"inner_strategy\": {\n                            \"class\": \"TWAPStrategy\",\n                            \"module_path\": \"qlib.contrib.strategy.rule_strategy\",\n                        },\n                        \"generate_portfolio_metrics\": False,\n             
           \"indicator_config\": {\n                            \"show_indicator\": True,\n                        },\n                        \"track_data\": True,\n                    },\n                },\n                \"inner_strategy\": {\n                    \"class\": \"TWAPStrategy\",\n                    \"module_path\": \"qlib.contrib.strategy.rule_strategy\",\n                },\n                \"generate_portfolio_metrics\": False,\n                \"indicator_config\": {\n                    \"show_indicator\": True,\n                },\n                \"track_data\": True,\n            },\n        }\n\n        ret_val = {}\n        decisions = list(\n            collect_data(executor=executor_config, strategy=strategy_config, **backtest_config, return_value=ret_val)\n        )\n        report, indicator = ret_val[\"report\"], ret_val[\"indicator\"]\n        # NOTE: please refer to the docs of format_decisions\n        # NOTE: `\"track_data\": True` is REQUIRED in every executor config for collecting the decisions!\n        f_dec = format_decisions(decisions)\n        print(indicator[\"1day\"][0])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/backtest/test_soft_topk_strategy.py",
    "content": "import pandas as pd\nimport pytest\nfrom qlib.contrib.strategy.cost_control import SoftTopkStrategy\n\n\nclass MockPosition:\n    def __init__(self, weights):\n        self.weights = weights\n\n    def get_stock_weight_dict(self, only_stock=True):\n        return self.weights\n\n\ndef test_soft_topk_logic():\n    # Initial: A=0.8, B=0.2 (Total=1.0). Target Risk=0.95.\n    # Scores: A and B are low, C and D are topk.\n    scores = pd.Series({\"C\": 0.9, \"D\": 0.8, \"A\": 0.1, \"B\": 0.1})\n    current_pos = MockPosition({\"A\": 0.8, \"B\": 0.2})\n\n    topk = 2\n    risk_degree = 0.95\n    impact_limit = 0.1  # Max change per step\n\n    def create_test_strategy(impact_limit_value):\n        strat = SoftTopkStrategy.__new__(SoftTopkStrategy)\n        strat.topk = topk\n        strat.risk_degree = risk_degree\n        strat.trade_impact_limit = impact_limit_value\n        return strat\n\n    # 1. With impact limit: Expect deterministic sell and limited buy\n    strat_i = create_test_strategy(impact_limit)\n    res_i = strat_i.generate_target_weight_position(scores, current_pos, None, None)\n\n    # A should be exactly 0.8 - 0.1 = 0.7\n    assert abs(res_i[\"A\"] - 0.7) < 1e-8\n    # B should be exactly 0.2 - 0.1 = 0.1\n    assert abs(res_i[\"B\"] - 0.1) < 1e-8\n    # Total sells = 0.2 released. New budget = 0.2 + (0.95 - 1.0) = 0.15.\n    # C and D share 0.15 -> 0.075 each.\n    assert abs(res_i[\"C\"] - 0.075) < 1e-8\n    assert abs(res_i[\"D\"] - 0.075) < 1e-8\n\n    # 2. 
Without impact limit: Expect full liquidation and full target fill\n    strat_c = create_test_strategy(1.0)\n    res_c = strat_c.generate_target_weight_position(scores, current_pos, None, None)\n\n    # A, B not in topk -> Liquidated\n    assert \"A\" not in res_c and \"B\" not in res_c\n    # C, D should reach ideal_per_stock (0.95/2 = 0.475)\n    assert abs(res_c[\"C\"] - 0.475) < 1e-8\n    assert abs(res_c[\"D\"] - 0.475) < 1e-8\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__])\n"
  },
  {
    "path": "tests/backtest/test_soft_topk_strategy_cold_start.py",
    "content": "import pandas as pd\nimport pytest\n\nfrom qlib.contrib.strategy.cost_control import SoftTopkStrategy\n\n\nclass MockPosition:\n    def __init__(self, weights):\n        self.weights = weights\n\n    def get_stock_weight_dict(self, only_stock=True):\n        return self.weights\n\n\ndef create_test_strategy(topk, risk_degree, impact_limit):\n    strat = SoftTopkStrategy.__new__(SoftTopkStrategy)\n    strat.topk = topk\n    strat.risk_degree = risk_degree\n    strat.trade_impact_limit = impact_limit\n    return strat\n\n\n@pytest.mark.parametrize(\n    (\"impact_limit\", \"expected_fill\"),\n    [\n        (0.1, 0.1),\n        (1.0, 0.475),\n    ],\n)\ndef test_soft_topk_cold_start_impact_limit(impact_limit, expected_fill):\n    scores = pd.Series({\"C\": 0.9, \"D\": 0.8, \"A\": 0.1, \"B\": 0.1})\n    current_pos = MockPosition({})\n\n    strat = create_test_strategy(topk=2, risk_degree=0.95, impact_limit=impact_limit)\n    res = strat.generate_target_weight_position(scores, current_pos, None, None)\n\n    assert abs(res[\"C\"] - expected_fill) < 1e-8\n    assert abs(res[\"D\"] - expected_fill) < 1e-8\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "import os\nimport sys\n\n\"\"\"Ignore RL tests on non-linux platform.\"\"\"\ncollect_ignore = []\n\nif sys.platform != \"linux\":\n    for root, dirs, files in os.walk(\"rl\"):\n        for file in files:\n            collect_ignore.append(os.path.join(root, file))\n"
  },
  {
    "path": "tests/data_mid_layer_tests/README.md",
    "content": "# Introduction\nThe middle layers of data mainly include:\n- Handler\n    - processors\n- Datasets\n"
  },
  {
    "path": "tests/data_mid_layer_tests/test_dataloader.py",
    "content": "# TODO:\n# dump alpha 360 to dataframe and merge it with Alpha158\n\nimport sys\nimport unittest\nimport qlib\nfrom pathlib import Path\n\nsys.path.append(str(Path(__file__).resolve().parent))\nfrom qlib.data.dataset.loader import NestedDataLoader, QlibDataLoader\nfrom qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.contrib.data.loader import Alpha158DL, Alpha360DL\nfrom qlib.data.dataset.processor import Fillna\nfrom qlib.data import D\n\n\nclass TestDataLoader(unittest.TestCase):\n\n    def test_nested_data_loader(self):\n        qlib.init(kernels=1)\n        nd = NestedDataLoader(\n            dataloader_l=[\n                {\n                    \"class\": \"qlib.contrib.data.loader.Alpha158DL\",\n                },\n                {\n                    \"class\": \"qlib.contrib.data.loader.Alpha360DL\",\n                    \"kwargs\": {\"config\": {\"label\": ([\"Ref($close, -2)/Ref($close, -1) - 1\"], [\"LABEL0\"])}},\n                },\n            ]\n        )\n        # StaticDataLoader can also be used here\n\n        dataset = nd.load(instruments=\"csi300\", start_time=\"2020-01-01\", end_time=\"2020-01-31\")\n\n        assert dataset is not None\n\n        columns = dataset.columns.tolist()\n        columns_list = [tup[1] for tup in columns]\n\n        for col in Alpha158DL.get_feature_config()[1]:\n            assert col in columns_list\n\n        for col in Alpha360DL.get_feature_config()[1]:\n            assert col in columns_list\n\n        assert \"LABEL0\" in columns_list\n\n        assert dataset.isna().any().any()\n\n        fn = Fillna(fields_group=\"feature\", fill_value=0)\n        fn_dataset = fn(dataset)\n\n        assert not fn_dataset.isna().any().any()\n\n        # The loader can then be used with a DataHandler;\n        # NOTE: please note that the data processors are missing!!!  
You should add based on your requirements\n\n        \"\"\"\n        dataset.to_pickle(\"test_df.pkl\")\n        nested_data_loader = NestedDataLoader(\n            dataloader_l=[\n                {\n                    \"class\": \"qlib.contrib.data.loader.Alpha158DL\",\n                    \"kwargs\": {\"config\": {\"label\": ([\"Ref($close, -2)/Ref($close, -1) - 1\"], [\"LABEL0\"])}},\n                },\n                {\n                    \"class\": \"qlib.contrib.data.loader.Alpha360DL\",\n                },\n                {\n                    \"class\": \"qlib.data.dataset.loader.StaticDataLoader\",\n                    \"kwargs\": {\"config\": \"test_df.pkl\"},\n                },\n            ]\n        )\n        data_handler_config = {\n            \"start_time\": \"2008-01-01\",\n            \"end_time\": \"2020-08-01\",\n            \"instruments\": \"csi300\",\n            \"data_loader\": nested_data_loader,\n        }\n        data_handler = DataHandlerLP(**data_handler_config)\n        data = data_handler.fetch()\n        print(data)\n        \"\"\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data_mid_layer_tests/test_dataset.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport unittest\nimport pytest\nimport sys\nfrom qlib.tests import TestAutoData\nfrom qlib.data.dataset import TSDatasetH, TSDataSampler\nimport numpy as np\nimport pandas as pd\nimport time\nfrom qlib.data.dataset.handler import DataHandlerLP\n\n\nclass TestDataset(TestAutoData):\n    @pytest.mark.slow\n    def testTSDataset(self):\n        tsdh = TSDatasetH(\n            handler={\n                \"class\": \"Alpha158\",\n                \"module_path\": \"qlib.contrib.data.handler\",\n                \"kwargs\": {\n                    \"start_time\": \"2017-01-01\",\n                    \"end_time\": \"2020-08-01\",\n                    \"fit_start_time\": \"2017-01-01\",\n                    \"fit_end_time\": \"2017-12-31\",\n                    \"instruments\": \"csi300\",\n                    \"infer_processors\": [\n                        {\"class\": \"FilterCol\", \"kwargs\": {\"col_list\": [\"RESI5\", \"WVMA5\", \"RSQR5\"]}},\n                        {\"class\": \"RobustZScoreNorm\", \"kwargs\": {\"fields_group\": \"feature\", \"clip_outlier\": \"true\"}},\n                        {\"class\": \"Fillna\", \"kwargs\": {\"fields_group\": \"feature\"}},\n                    ],\n                    \"learn_processors\": [\n                        \"DropnaLabel\",\n                        {\"class\": \"CSRankNorm\", \"kwargs\": {\"fields_group\": \"label\"}},  # CSRankNorm\n                    ],\n                },\n            },\n            segments={\n                \"train\": (\"2017-01-01\", \"2017-12-31\"),\n                \"valid\": (\"2018-01-01\", \"2018-12-31\"),\n                \"test\": (\"2019-01-01\", \"2020-08-01\"),\n            },\n        )\n        tsds_train = tsdh.prepare(\"train\", data_key=DataHandlerLP.DK_L)  # Test the correctness\n        tsds = tsdh.prepare(\"valid\", data_key=DataHandlerLP.DK_L)\n\n        t = time.time()\n        for 
idx in np.random.randint(0, len(tsds_train), size=2000):\n            _ = tsds_train[idx]\n        print(f\"2000 sample takes {time.time() - t}s\")\n\n        t = time.time()\n        for _ in range(20):\n            data = tsds_train[np.random.randint(0, len(tsds_train), size=2000)]\n        print(data.shape)\n        print(f\"2000 sample(batch index) * 20 times takes {time.time() - t}s\")\n\n        # The dimension of sample is same as tabular data, but it will return timeseries data of the sample\n\n        # We have two method to get the time-series of a sample\n\n        # 1) sample by int index directly\n        tsds[len(tsds) - 1]\n\n        # 2) sample by <datetime,instrument> index\n        data_from_ds = tsds[\"2017-12-31\", \"SZ300315\"]\n\n        # Check the data\n        # Get data from DataFrame Directly\n        data_from_df = (\n            tsdh.handler.fetch(data_key=DataHandlerLP.DK_L)\n            .loc(axis=0)[\"2017-01-01\":\"2017-12-31\", \"SZ300315\"]\n            .iloc[-30:]\n            .values\n        )\n\n        equal = np.isclose(data_from_df, data_from_ds)\n        self.assertTrue(equal[~np.isnan(data_from_df)].all())\n\n        if False:\n            # 3) get both index and data\n            # NOTE: We don't want to reply on pytorch, so this test can't be included. 
It is just a example\n            from torch.utils.data import DataLoader\n            from qlib.model.utils import IndexSampler\n\n            i = len(tsds) - 1\n            idx = tsds.get_index()\n            tsds[i]\n            idx[i]\n\n            s_w_i = IndexSampler(tsds)\n            test_loader = DataLoader(s_w_i)\n\n            s_w_i[3]\n            for data, i in test_loader:\n                break\n            print(data.shape)\n            print(idx[i])\n\n\nclass TestTSDataSampler(unittest.TestCase):\n    def test_TSDataSampler(self):\n        \"\"\"\n        Test TSDataSampler for issue #1716\n        \"\"\"\n        datetime_list = [\"2000-01-31\", \"2000-02-29\", \"2000-03-31\", \"2000-04-30\", \"2000-05-31\"]\n        instruments = [\"000001\", \"000002\", \"000003\", \"000004\", \"000005\"]\n        index = pd.MultiIndex.from_product(\n            [pd.to_datetime(datetime_list), instruments], names=[\"datetime\", \"instrument\"]\n        )\n        data = np.random.randn(len(datetime_list) * len(instruments))\n        test_df = pd.DataFrame(data=data, index=index, columns=[\"factor\"])\n        dataset = TSDataSampler(test_df, datetime_list[0], datetime_list[-1], step_len=2)\n        print()\n        print(\"--------------dataset[0]--------------\")\n        print(dataset[0])\n        print(\"--------------dataset[1]--------------\")\n        print(dataset[1])\n        assert len(dataset[0]) == 2\n        self.assertTrue(np.isnan(dataset[0][0]))\n        self.assertEqual(dataset[0][1], dataset[1][0])\n        self.assertEqual(dataset[1][1], dataset[2][0])\n        self.assertEqual(dataset[2][1], dataset[3][0])\n\n    def test_TSDataSampler2(self):\n        \"\"\"\n        Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front\n        \"\"\"\n        datetime_list = [\"2000-01-31\", \"2000-02-29\", \"2000-03-31\", \"2000-04-30\", \"2000-05-31\"]\n        instruments = [\"000001\", \"000002\", \"000003\", 
\"000004\", \"000005\"]\n        index = pd.MultiIndex.from_product(\n            [pd.to_datetime(datetime_list), instruments], names=[\"datetime\", \"instrument\"]\n        )\n        data = np.random.randn(len(datetime_list) * len(instruments))\n        test_df = pd.DataFrame(data=data, index=index, columns=[\"factor\"])\n        dataset = TSDataSampler(test_df, datetime_list[2], datetime_list[-1], step_len=3)\n        print()\n        print(\"--------------dataset[0]--------------\")\n        print(dataset[0])\n        print(\"--------------dataset[1]--------------\")\n        print(dataset[1])\n        for i in range(3):\n            self.assertFalse(np.isnan(dataset[0][i]))\n            self.assertFalse(np.isnan(dataset[1][i]))\n        self.assertEqual(dataset[0][1], dataset[1][0])\n        self.assertEqual(dataset[0][2], dataset[1][1])\n\n\nif __name__ == \"__main__\":\n    unittest.main(verbosity=10)\n\n    # User could use following code to run test when using line_profiler\n    # td = TestDataset()\n    # td.setUpClass()\n    # td.testTSDataset()\n"
  },
  {
    "path": "tests/data_mid_layer_tests/test_handler.py",
    "content": "import os\nimport unittest\n\nfrom qlib.data import D\nfrom qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.tests import TestAutoData\nfrom qlib.utils.pickle_utils import restricted_pickle_load\n\n\nclass HandlerTests(TestAutoData):\n    def to_str(self, obj):\n        return \"\".join(str(obj).split())\n\n    def test_handler_df(self):\n        df = D.features([\"sh600519\"], start_time=\"20190101\", end_time=\"20190201\", fields=[\"$close\"])\n        dh = DataHandlerLP.from_df(df)\n        print(dh.fetch())\n        self.assertTrue(dh._data.equals(df))\n        self.assertTrue(dh._infer is dh._data)\n        self.assertTrue(dh._learn is dh._data)\n        self.assertTrue(dh.data_loader._data is dh._data)\n        fname = \"_handler_test.pkl\"\n        dh.to_pickle(fname, dump_all=True)\n\n        with open(fname, \"rb\") as f:\n            dh_d = restricted_pickle_load(f)\n\n        self.assertTrue(dh_d._data.equals(df))\n        self.assertTrue(dh_d._infer is dh_d._data)\n        self.assertTrue(dh_d._learn is dh_d._data)\n        # Data loader will no longer be useful\n        self.assertTrue(\"_data\" not in dh_d.data_loader.__dict__.keys())\n        os.remove(fname)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data_mid_layer_tests/test_handler_storage.py",
    "content": "import unittest\nimport time\nimport numpy as np\nfrom qlib.data import D\nfrom qlib.tests import TestAutoData\n\nfrom qlib.data.dataset.handler import DataHandlerLP\nfrom qlib.contrib.data.handler import check_transform_proc\nfrom qlib.log import TimeInspector\n\n\nclass TestHandler(DataHandlerLP):\n    def __init__(\n        self,\n        instruments=\"csi300\",\n        start_time=None,\n        end_time=None,\n        infer_processors=[],\n        learn_processors=[],\n        fit_start_time=None,\n        fit_end_time=None,\n        drop_raw=True,\n    ):\n        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)\n        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)\n\n        data_loader = {\n            \"class\": \"QlibDataLoader\",\n            \"kwargs\": {\n                \"freq\": \"day\",\n                \"config\": self.get_feature_config(),\n                \"swap_level\": False,\n            },\n        }\n\n        super().__init__(\n            instruments=instruments,\n            start_time=start_time,\n            end_time=end_time,\n            data_loader=data_loader,\n            infer_processors=infer_processors,\n            learn_processors=learn_processors,\n            drop_raw=drop_raw,\n        )\n\n    def get_feature_config(self):\n        fields = [\"Ref($open, 1)\", \"Ref($close, 1)\", \"Ref($volume, 1)\", \"$open\", \"$close\", \"$volume\"]\n        names = [\"open_0\", \"close_0\", \"volume_0\", \"open_1\", \"close_1\", \"volume_1\"]\n        return fields, names\n\n\nclass TestHandlerStorage(TestAutoData):\n    market = \"all\"\n\n    start_time = \"2010-01-01\"\n    end_time = \"2020-12-31\"\n    train_end_time = \"2015-12-31\"\n    test_start_time = \"2016-01-01\"\n\n    data_handler_kwargs = {\n        \"start_time\": start_time,\n        \"end_time\": end_time,\n        \"fit_start_time\": start_time,\n        
\"fit_end_time\": train_end_time,\n        \"instruments\": market,\n    }\n\n    def test_handler_storage(self):\n        # init data handler\n        data_handler = TestHandler(**self.data_handler_kwargs)\n\n        # init data handler with hasing storage\n        data_handler_hs = TestHandler(**self.data_handler_kwargs, infer_processors=[\"HashStockFormat\"])\n\n        fetch_start_time = \"2019-01-01\"\n        fetch_end_time = \"2019-12-31\"\n        instruments = D.instruments(market=self.market)\n        instruments = D.list_instruments(\n            instruments=instruments, start_time=fetch_start_time, end_time=fetch_end_time, as_list=True\n        )\n\n        with TimeInspector.logt(\"random fetch with DataFrame Storage\"):\n            # single stock\n            for i in range(100):\n                random_index = np.random.randint(len(instruments), size=1)[0]\n                fetch_stock = instruments[random_index]\n                data_handler.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None)\n\n            # multi stocks\n            for i in range(100):\n                random_indexs = np.random.randint(len(instruments), size=5)\n                fetch_stocks = [instruments[_index] for _index in random_indexs]\n                data_handler.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None)\n\n        with TimeInspector.logt(\"random fetch with HashingStock Storage\"):\n            # single stock\n            for i in range(100):\n                random_index = np.random.randint(len(instruments), size=1)[0]\n                fetch_stock = instruments[random_index]\n                data_handler_hs.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None)\n\n            # multi stocks\n            for i in range(100):\n                random_indexs = np.random.randint(len(instruments), size=5)\n                fetch_stocks = [instruments[_index] for _index in 
random_indexs]\n                data_handler_hs.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data_mid_layer_tests/test_processor.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport unittest\nimport numpy as np\nfrom qlib.data import D\nfrom qlib.tests import TestAutoData\nfrom qlib.data.dataset.processor import MinMaxNorm, ZScoreNorm, CSZScoreNorm, CSZFillna\n\n\nclass TestProcessor(TestAutoData):\n    TEST_INST = \"SH600519\"\n\n    def test_MinMaxNorm(self):\n        def normalize(df):\n            min_val = np.nanmin(df.values, axis=0)\n            max_val = np.nanmax(df.values, axis=0)\n            ignore = min_val == max_val\n            for _i, _con in enumerate(ignore):\n                if _con:\n                    max_val[_i] = 1\n                    min_val[_i] = 0\n            df.loc(axis=1)[df.columns] = (df.values - min_val) / (max_val - min_val)\n            return df\n\n        origin_df = D.features([self.TEST_INST], [\"$high\", \"$open\", \"$low\", \"$close\"]).tail(10)\n        origin_df[\"test\"] = 0\n        df = origin_df.copy()\n        mmn = MinMaxNorm(fields_group=None, fit_start_time=\"2021-05-31\", fit_end_time=\"2021-06-11\")\n        mmn.fit(df)\n        mmn.__call__(df)\n        origin_df = normalize(origin_df)\n        assert (df == origin_df).all().all()\n\n    def test_ZScoreNorm(self):\n        def normalize(df):\n            mean_train = np.nanmean(df.values, axis=0)\n            std_train = np.nanstd(df.values, axis=0)\n            ignore = std_train == 0\n            for _i, _con in enumerate(ignore):\n                if _con:\n                    std_train[_i] = 1\n                    mean_train[_i] = 0\n            df.loc(axis=1)[df.columns] = (df.values - mean_train) / std_train\n            return df\n\n        origin_df = D.features([self.TEST_INST], [\"$high\", \"$open\", \"$low\", \"$close\"]).tail(10)\n        origin_df[\"test\"] = 0\n        df = origin_df.copy()\n        zsn = ZScoreNorm(fields_group=None, fit_start_time=\"2021-05-31\", fit_end_time=\"2021-06-11\")\n        zsn.fit(df)\n        
zsn.__call__(df)\n        origin_df = normalize(origin_df)\n        assert (df == origin_df).all().all()\n\n    def test_CSZFillna(self):\n        origin_df = D.features(D.instruments(market=\"csi300\"), fields=[\"$high\", \"$open\", \"$low\", \"$close\"])\n        origin_df = origin_df.groupby(\"datetime\", group_keys=False).apply(lambda x: x[97:99])[228:238]\n        df = origin_df.copy()\n        CSZFillna(fields_group=None).__call__(df)\n        assert ~df[1:2].isna().all().all() and origin_df[1:2].isna().all().all()\n\n    def test_CSZScoreNorm(self):\n        origin_df = D.features(D.instruments(market=\"csi300\"), fields=[\"$high\", \"$open\", \"$low\", \"$close\"])\n        origin_df = origin_df.groupby(\"datetime\", group_keys=False).apply(lambda x: x[10:12])[50:60]\n        df = origin_df.copy()\n        CSZScoreNorm(fields_group=None).__call__(df)\n        # If we use the formula directly on the original data, we cannot get the correct result,\n        # because the original data is processed by `groupby`, so we use the method of slicing,\n        # taking the 2nd group of data from the original data, to calculate and compare.\n        assert (df[2:4] == ((origin_df[2:4] - origin_df[2:4].mean()).div(origin_df[2:4].std()))).all().all()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/dataset_tests/README.md",
    "content": "# About dataset tests\nTests in this folder check the quality of the prepared dataset from Yahoo.\n"
  },
  {
    "path": "tests/dataset_tests/test_datalayer.py",
    "content": "import unittest\nimport numpy as np\nfrom qlib.data import D\nfrom qlib.tests import TestAutoData\n\n\nclass TestDataset(TestAutoData):\n    def testCSI300(self):\n        close_p = D.features(D.instruments(\"csi300\"), [\"$close\"])\n        size = close_p.groupby(\"datetime\", group_keys=False).size()\n        cnt = close_p.groupby(\"datetime\", group_keys=False).count()[\"$close\"]\n        size_desc = size.describe(percentiles=np.arange(0.1, 1.0, 0.1))\n        cnt_desc = cnt.describe(percentiles=np.arange(0.1, 1.0, 0.1))\n\n        print(size_desc)\n        print(cnt_desc)\n\n        self.assertLessEqual(size_desc.loc[\"max\"], 305, \"Excessive number of CSI300 constituent stocks\")\n        self.assertGreaterEqual(size_desc.loc[\"80%\"], 290, \"Insufficient number of CSI300 constituent stocks\")\n\n        self.assertLessEqual(cnt_desc.loc[\"max\"], 305, \"Excessive number of CSI300 constituent stocks\")\n        # FIXME: Due to low data quality, it is hard to guarantee that there are enough data points\n        # self.assertEqual(cnt_desc.loc[\"80%\"], 300, \"Insufficient number of CSI300 constituent stocks\")\n\n    def testClose(self):\n        close_p = D.features(D.instruments(\"csi300\"), [\"Ref($close, 1)/$close - 1\"])\n        close_desc = close_p.describe(percentiles=np.arange(0.1, 1.0, 0.1))\n        print(close_desc)\n        self.assertLessEqual(abs(close_desc.loc[\"90%\"].iloc[0]), 0.1, \"Close value is abnormal\")\n        self.assertLessEqual(abs(close_desc.loc[\"10%\"].iloc[0]), 0.1, \"Close value is abnormal\")\n        # FIXME: The Yahoo data is not perfect, so the following checks are disabled\n        # self.assertLessEqual(abs(close_desc.loc[\"max\"].iloc[0]), 0.2, \"Close value is abnormal\")\n        # self.assertGreaterEqual(close_desc.loc[\"min\"].iloc[0], -0.2, \"Close value is abnormal\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/dependency_tests/README.md",
    "content": "Some implementations in Qlib rely on assumptions about the behavior of its dependencies.\n\nThe tests in this folder verify that these assumptions still hold.\n"
  },
  {
    "path": "tests/dependency_tests/test_mlflow.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport unittest\nimport platform\nimport mlflow\nimport time\nfrom pathlib import Path\nimport shutil\n\n\nclass MLflowTest(unittest.TestCase):\n    TMP_PATH = Path(\"./.mlruns_tmp/\")\n\n    def tearDown(self) -> None:\n        if self.TMP_PATH.exists():\n            shutil.rmtree(self.TMP_PATH)\n\n    def test_creating_client(self):\n        \"\"\"\n        Please refer to qlib/workflow/expm.py:MLflowExpManager._client.\n        We don't cache _client (this reduces maintenance work when MLflowExpManager's uri is changed).\n\n        This implementation is based on the assumption that creating a client is fast.\n        \"\"\"\n        start = time.time()\n        for i in range(10):\n            _ = mlflow.tracking.MlflowClient(tracking_uri=str(self.TMP_PATH))\n        end = time.time()\n        elapsed = end - start\n        if platform.system() == \"Linux\":\n            self.assertLess(elapsed, 1e-2)  # creating 10 clients takes less than 10ms\n        else:\n            self.assertLess(elapsed, 2e-2)\n        print(elapsed)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/misc/test_get_multi_proc.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nimport unittest\n\nimport qlib\nfrom qlib.data import D\nfrom qlib.tests import TestAutoData\nfrom multiprocessing import Pool\n\n\ndef get_features(fields):\n    qlib.init(provider_uri=TestAutoData.provider_uri, expression_cache=None, dataset_cache=None, joblib_backend=\"loky\")\n    return D.features(D.instruments(\"csi300\"), fields)\n\n\nclass TestGetData(TestAutoData):\n    FIELDS = \"$open,$close,$high,$low,$volume,$factor,$change\".split(\",\")\n\n    def test_multi_proc(self):\n        \"\"\"\n        Check that fetching features from multiple processes does not raise an error\n        \"\"\"\n        iter_n = 2\n        pool = Pool(iter_n)\n\n        res = []\n        for _ in range(iter_n):\n            res.append(pool.apply_async(get_features, (self.FIELDS,), {}))\n\n        for r in res:\n            print(r.get())\n\n        pool.close()\n        pool.join()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/misc/test_index_data.py",
    "content": "import numpy as np\nimport pandas as pd\nimport qlib.utils.index_data as idd\n\nimport unittest\n\n\nclass IndexDataTest(unittest.TestCase):\n    def test_index_single_data(self):\n        # Auto broadcast for scalar\n        sd = idd.SingleData(0, index=[\"foo\", \"bar\"])\n        print(sd)\n\n        # Support empty value\n        sd = idd.SingleData()\n        print(sd)\n\n        # Bad case: the input is not aligned\n        with self.assertRaises(ValueError):\n            idd.SingleData(range(10), index=[\"foo\", \"bar\"])\n\n        # test indexing\n        sd = idd.SingleData([1, 2, 3, 4], index=[\"foo\", \"bar\", \"f\", \"g\"])\n        print(sd)\n        print(sd.iloc[1])  # get second row\n\n        # Bad case: it is not in the index\n        with self.assertRaises(KeyError):\n            print(sd.loc[1])\n\n        print(sd.loc[\"foo\"])\n\n        # Test slicing\n        print(sd.loc[:\"bar\"])\n\n        print(sd.iloc[:3])\n\n    def test_index_multi_data(self):\n        # Auto broadcast for scalar\n        sd = idd.MultiData(0, index=[\"foo\", \"bar\"], columns=[\"f\", \"g\"])\n        print(sd)\n\n        # Bad case: the input is not aligned\n        with self.assertRaises(ValueError):\n            idd.MultiData(range(10), index=[\"foo\", \"bar\"], columns=[\"f\", \"g\"])\n\n        # test indexing\n        sd = idd.MultiData(np.arange(4).reshape(2, 2), index=[\"foo\", \"bar\"], columns=[\"f\", \"g\"])\n        print(sd)\n        print(sd.iloc[1])  # get second row\n\n        # Bad case: it is not in the index\n        with self.assertRaises(KeyError):\n            print(sd.loc[1])\n\n        print(sd.loc[\"foo\"])\n\n        # Test slicing\n\n        print(sd.loc[:\"foo\"])\n\n        print(sd.loc[:, \"g\":])\n\n    def test_sorting(self):\n        sd = idd.MultiData(np.arange(4).reshape(2, 2), index=[\"foo\", \"bar\"], columns=[\"f\", \"g\"])\n        print(sd)\n        sd.sort_index()\n\n        print(sd)\n        
print(sd.loc[:\"c\"])\n\n    def test_corner_cases(self):\n        sd = idd.MultiData([[1, 2], [3, np.nan]], index=[\"foo\", \"bar\"], columns=[\"f\", \"g\"])\n        print(sd)\n\n        self.assertTrue(np.isnan(sd.loc[\"bar\", \"g\"]))\n\n        # support slicing\n        print(sd.loc[~sd.loc[:, \"g\"].isna().data.astype(bool)])\n\n        print(self.assertTrue(idd.SingleData().index == idd.SingleData().index))\n\n        # empty dict\n        print(idd.SingleData({}))\n        print(idd.SingleData(pd.Series()))\n\n        sd = idd.SingleData()\n        with self.assertRaises(KeyError):\n            sd.loc[\"foo\"]\n\n        # replace\n        sd = idd.SingleData([1, 2, 3, 4], index=[\"foo\", \"bar\", \"f\", \"g\"])\n        sd = sd.replace(dict(zip(range(1, 5), range(2, 6))))\n        print(sd)\n        self.assertTrue(sd.iloc[0] == 2)\n\n        # test different precisions of time data\n        timeindex = [\n            np.datetime64(\"2024-06-22T00:00:00.000000000\"),\n            np.datetime64(\"2024-06-21T00:00:00.000000000\"),\n            np.datetime64(\"2024-06-20T00:00:00.000000000\"),\n        ]\n        sd = idd.SingleData([1, 2, 3], index=timeindex)\n        self.assertTrue(\n            sd.index.index(np.datetime64(\"2024-06-21T00:00:00.000000000\"))\n            == sd.index.index(np.datetime64(\"2024-06-21T00:00:00\"))\n        )\n        self.assertTrue(sd.index.index(pd.Timestamp(\"2024-06-21 00:00\")) == 1)\n\n        # Bad case: the input is not aligned\n        timeindex[1] = (np.datetime64(\"2024-06-21T00:00:00.00\"),)\n        with self.assertRaises(TypeError):\n            sd = idd.SingleData([1, 2, 3], index=timeindex)\n\n    def test_ops(self):\n        sd1 = idd.SingleData([1, 2, 3, 4], index=[\"foo\", \"bar\", \"f\", \"g\"])\n        sd2 = idd.SingleData([1, 2, 3, 4], index=[\"foo\", \"bar\", \"f\", \"g\"])\n        print(sd1 + sd2)\n        new_sd = sd2 * 2\n        self.assertTrue(new_sd.index == sd2.index)\n\n        sd1 = 
idd.SingleData([1, 2, None, 4], index=[\"foo\", \"bar\", \"f\", \"g\"])\n        sd2 = idd.SingleData([1, 2, 3, None], index=[\"foo\", \"bar\", \"f\", \"g\"])\n        self.assertTrue(np.isnan((sd1 + sd2).iloc[3]))\n        self.assertTrue(sd1.add(sd2).sum() == 13)\n\n        self.assertTrue(idd.sum_by_index([sd1, sd2], sd1.index, fill_value=0.0).sum() == 13)\n\n    def test_todo(self):\n        pass\n        # here are some examples which do not affect the current system, but it is weird not to support it\n        # sd2 = idd.SingleData([1, 2, 3, 4], index=[\"foo\", \"bar\", \"f\", \"g\"])\n        # 2 * sd2\n\n    def test_squeeze(self):\n        sd1 = idd.SingleData([1, 2, 3, 4], index=[\"foo\", \"bar\", \"f\", \"g\"])\n        # automatically squeezing\n        self.assertTrue(not isinstance(np.nansum(sd1), idd.IndexData))\n        self.assertTrue(not isinstance(np.sum(sd1), idd.IndexData))\n        self.assertTrue(not isinstance(sd1.sum(), idd.IndexData))\n        self.assertEqual(np.nansum(sd1), 10)\n        self.assertEqual(np.sum(sd1), 10)\n        self.assertEqual(sd1.sum(), 10)\n        self.assertEqual(np.nanmean(sd1), 2.5)\n        self.assertEqual(np.mean(sd1), 2.5)\n        self.assertEqual(sd1.mean(), 2.5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/misc/test_sepdf.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport unittest\nimport numpy as np\nimport pandas as pd\nfrom qlib.contrib.data.utils.sepdf import SepDataFrame\n\n\nclass SepDF(unittest.TestCase):\n    def to_str(self, obj):\n        return \"\".join(str(obj).split())\n\n    def test_index_data(self):\n        np.random.seed(42)\n\n        index = [\n            np.array([\"bar\", \"bar\", \"baz\", \"baz\", \"foo\", \"foo\", \"qux\", \"qux\"]),\n            np.array([\"one\", \"two\", \"one\", \"two\", \"one\", \"two\", \"one\", \"two\"]),\n        ]\n\n        cols = [\n            np.repeat(np.array([\"g1\", \"g2\"]), 2),\n            np.arange(4),\n        ]\n        df = pd.DataFrame(np.random.randn(8, 4), index=index, columns=cols)\n        sdf = SepDataFrame(df_dict={\"g2\": df[\"g2\"]}, join=None)\n        sdf[(\"g2\", 4)] = 3\n        sdf[\"g1\"] = df[\"g1\"]\n        exp = \"\"\"\n        {'g2':                 2         3  4\n        bar one  0.647689  1.523030  3\n            two  1.579213  0.767435  3\n        baz one -0.463418 -0.465730  3\n            two -1.724918 -0.562288  3\n        foo one -0.908024 -1.412304  3\n            two  0.067528 -1.424748  3\n        qux one -1.150994  0.375698  3\n            two -0.601707  1.852278  3, 'g1':                 0         1\n        bar one  0.496714 -0.138264\n            two -0.234153 -0.234137\n        baz one -0.469474  0.542560\n            two  0.241962 -1.913280\n        foo one -1.012831  0.314247\n            two  1.465649 -0.225776\n        qux one -0.544383  0.110923\n            two -0.600639 -0.291694}\n        \"\"\"\n        self.assertEqual(self.to_str(sdf._df_dict), self.to_str(exp))\n\n        del df[\"g1\"]\n        del df[\"g2\"]\n        # it will not raise error, and df will be an empty dataframe\n\n        del sdf[\"g1\"]\n        del sdf[\"g2\"]\n        # sdf should support deleting all the columns\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/misc/test_utils.py",
    "content": "from typing import List\nfrom unittest.case import TestCase\nimport unittest\nimport pandas as pd\nimport numpy as np\nfrom datetime import datetime\nfrom qlib import init\nfrom qlib.config import C\nfrom qlib.log import TimeInspector\nfrom qlib.constant import REG_CN, REG_US, REG_TW\nfrom qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal, CN_TIME, US_TIME, TW_TIME\nfrom qlib.utils.data import guess_horizon\n\nREG_MAP = {REG_CN: CN_TIME, REG_US: US_TIME, REG_TW: TW_TIME}\n\n\ndef cal_sam_minute(x: pd.Timestamp, sam_minutes: int, region: str):\n    \"\"\"\n    Sample raw calendar into calendar with sam_minutes freq, shift represents the shift minute the market time\n        - open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]\n        - mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)]\n        - mid open time of stock market is [13:00 - shift*pd.Timedelta(minutes=1)]\n        - close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)]\n    \"\"\"\n    # TODO: actually, this version is much faster when no cache or optimization\n    day_time = pd.Timestamp(x.date())\n    shift = C.min_data_shift\n    region_time = REG_MAP[region]\n\n    open_time = (\n        day_time\n        + pd.Timedelta(hours=region_time[0].hour, minutes=region_time[0].minute)\n        - shift * pd.Timedelta(minutes=1)\n    )\n    close_time = (\n        day_time\n        + pd.Timedelta(hours=region_time[-1].hour, minutes=region_time[-1].minute)\n        - shift * pd.Timedelta(minutes=1)\n    )\n    if region_time == CN_TIME:\n        mid_close_time = (\n            day_time\n            + pd.Timedelta(hours=region_time[1].hour, minutes=region_time[1].minute - 1)\n            - shift * pd.Timedelta(minutes=1)\n        )\n        mid_open_time = (\n            day_time\n            + pd.Timedelta(hours=region_time[2].hour, minutes=region_time[2].minute)\n            - shift * pd.Timedelta(minutes=1)\n      
  )\n    else:\n        mid_close_time = close_time\n        mid_open_time = open_time\n\n    if open_time <= x <= mid_close_time:\n        minute_index = (x - open_time).seconds // 60\n    elif mid_open_time <= x <= close_time:\n        minute_index = (x - mid_open_time).seconds // 60 + 120\n    else:\n        raise ValueError(\"datetime of calendar is out of range\")\n\n    minute_index = minute_index // sam_minutes * sam_minutes\n\n    if 0 <= minute_index < 120 or region_time != CN_TIME:\n        return open_time + minute_index * pd.Timedelta(minutes=1)\n    elif 120 <= minute_index < 240:\n        return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1)\n    else:\n        raise ValueError(\"calendar minute_index error, check `min_data_shift` in qlib.config.C\")\n\n\nclass TimeUtils(TestCase):\n    @classmethod\n    def setUpClass(cls):\n        init()\n\n    def test_cal_sam_minute(self):\n        # test the correctness of the code\n        random_n = 1000\n        regions = [REG_CN, REG_US, REG_TW]\n\n        def gen_args(cal: List):\n            for time in np.random.choice(cal, size=random_n, replace=True):\n                sam_minutes = np.random.choice([1, 2, 3, 4, 5, 6])\n                dt = pd.Timestamp(\n                    datetime(\n                        2021,\n                        month=3,\n                        day=3,\n                        hour=time.hour,\n                        minute=time.minute,\n                        second=time.second,\n                        microsecond=time.microsecond,\n                    )\n                )\n                args = dt, sam_minutes\n                yield args\n\n        for region in regions:\n            cal_time = get_min_cal(region=region)\n            for args in gen_args(cal_time):\n                assert cal_sam_minute(*args, region) == cal_sam_minute_new(*args, region=region)\n\n            # test the performance of the code\n            args_l = 
list(gen_args(cal_time))\n\n            with TimeInspector.logt():\n                for args in args_l:\n                    cal_sam_minute(*args, region=region)\n\n            with TimeInspector.logt():\n                for args in args_l:\n                    cal_sam_minute_new(*args, region=region)\n\n\nclass DataUtils(TestCase):\n    @classmethod\n    def setUpClass(cls):\n        init()\n\n    def test_guess_horizon(self):\n        label = [\"Ref($close, -2) / Ref($close, -1) - 1\"]\n        result = guess_horizon(label)\n        assert result == 2\n\n        label = [\"Ref($close, -5) / Ref($close, -1) - 1\"]\n        result = guess_horizon(label)\n        assert result == 5\n\n        label = [\"Ref($close, -1) / Ref($close, -1) - 1\"]\n        result = guess_horizon(label)\n        assert result == 1\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/model/test_general_nn.py",
    "content": "import unittest\nfrom qlib.tests import TestAutoData\n\n\nclass TestNN(TestAutoData):\n    def test_both_dataset(self):\n        try:\n            from qlib.contrib.model.pytorch_general_nn import GeneralPTNN\n            from qlib.data.dataset import DatasetH, TSDatasetH\n            from qlib.data.dataset.handler import DataHandlerLP\n        except ImportError:\n            print(\"Import error.\")\n            return\n\n        data_handler_config = {\n            \"start_time\": \"2008-01-01\",\n            \"end_time\": \"2020-08-01\",\n            \"instruments\": \"csi300\",\n            \"data_loader\": {\n                \"class\": \"QlibDataLoader\",  # Assuming QlibDataLoader is a string reference to the class\n                \"kwargs\": {\n                    \"config\": {\n                        \"feature\": [[\"$high\", \"$close\", \"$low\"], [\"H\", \"C\", \"L\"]],\n                        \"label\": [[\"Ref($close, -2)/Ref($close, -1) - 1\"], [\"LABEL0\"]],\n                    },\n                    \"freq\": \"day\",\n                },\n            },\n            # TODO: processors\n            \"learn_processors\": [\n                {\n                    \"class\": \"DropnaLabel\",\n                },\n                {\"class\": \"CSZScoreNorm\", \"kwargs\": {\"fields_group\": \"label\"}},\n            ],\n        }\n        segments = {\n            \"train\": [\"2008-01-01\", \"2014-12-31\"],\n            \"valid\": [\"2015-01-01\", \"2016-12-31\"],\n            \"test\": [\"2017-01-01\", \"2020-08-01\"],\n        }\n        data_handler = DataHandlerLP(**data_handler_config)\n\n        # time-series dataset\n        tsds = TSDatasetH(handler=data_handler, segments=segments)\n\n        # tabular dataset\n        tbds = DatasetH(handler=data_handler, segments=segments)\n\n        model_l = [\n            GeneralPTNN(\n                n_epochs=2,\n                batch_size=32,\n                n_jobs=0,\n                
pt_model_uri=\"qlib.contrib.model.pytorch_gru_ts.GRUModel\",\n                pt_model_kwargs={\n                    \"d_feat\": 3,\n                    \"hidden_size\": 8,\n                    \"num_layers\": 1,\n                    \"dropout\": 0.0,\n                },\n            ),\n            GeneralPTNN(\n                n_epochs=2,\n                batch_size=32,\n                n_jobs=0,\n                pt_model_uri=\"qlib.contrib.model.pytorch_nn.Net\",  # it is a MLP\n                pt_model_kwargs={\n                    \"input_dim\": 3,\n                },\n            ),\n        ]\n\n        for ds, model in list(zip((tsds, tbds), model_l)):\n            model.fit(ds)  # It works\n            model.predict(ds)  # It works\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/ops/test_elem_operator.py",
    "content": "import unittest\nimport numpy as np\nimport pytest\n\nfrom qlib.data import DatasetProvider\nfrom qlib.data.data import ExpressionD\nfrom qlib.tests import TestOperatorData, TestMockData, MOCK_DF\nfrom qlib.config import C\n\n\nclass TestElementOperator(TestMockData):\n    def setUp(self) -> None:\n        self.instrument = \"0050\"\n        self.start_time = \"2022-01-01\"\n        self.end_time = \"2022-02-01\"\n        self.freq = \"day\"\n        self.mock_df = MOCK_DF[MOCK_DF[\"symbol\"] == self.instrument]\n\n    def test_Abs(self):\n        field = \"Abs($close-Ref($close, 1))\"\n        result = ExpressionD.expression(self.instrument, field, self.start_time, self.end_time, self.freq)\n        self.assertGreaterEqual(result.min(), 0)\n        result = result.to_numpy()\n        prev_close = self.mock_df[\"close\"].shift(1)\n        close = self.mock_df[\"close\"]\n        change = prev_close - close\n        golden = change.abs().to_numpy()\n        self.assertIsNone(np.testing.assert_allclose(result, golden))\n\n    def test_Sign(self):\n        field = \"Sign($close-Ref($close, 1))\"\n        result = ExpressionD.expression(self.instrument, field, self.start_time, self.end_time, self.freq)\n        result = result.to_numpy()\n        prev_close = self.mock_df[\"close\"].shift(1)\n        close = self.mock_df[\"close\"]\n        change = close - prev_close\n        change[change > 0] = 1.0\n        change[change < 0] = -1.0\n        golden = change.to_numpy()\n        self.assertIsNone(np.testing.assert_allclose(result, golden))\n\n\nclass TestOperatorDataSetting(TestOperatorData):\n    def test_setting(self):\n        self.assertEqual(len(self.instruments_d), 1)\n        self.assertGreater(len(self.cal), 0)\n\n\nclass TestInstElementOperator(TestOperatorData):\n    def setUp(self) -> None:\n        freq = \"day\"\n        expressions = [\n            \"$change\",\n            \"Abs($change)\",\n        ]\n        columns = [\"change\", 
\"abs\"]\n        self.data = DatasetProvider.inst_calculator(\n            self.inst, self.start_time, self.end_time, freq, expressions, self.spans, C, []\n        )\n        self.data.columns = columns\n\n    @pytest.mark.slow\n    def test_abs(self):\n        abs_values = self.data[\"abs\"]\n        self.assertGreater(abs_values.iloc[2], 0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/ops/test_special_ops.py",
    "content": "import unittest\n\nfrom qlib.data import D\nfrom qlib.data.dataset.loader import QlibDataLoader\nfrom qlib.data.ops import ChangeInstrument, Cov, Feature, Ref, Var\nfrom qlib.tests import TestOperatorData\n\n\nclass TestOperatorDataSetting(TestOperatorData):\n    def test_setting(self):\n        # All the query below passes\n        df = D.features([\"SH600519\"], [\"ChangeInstrument('SH000300', $close)\"])\n\n        # get market return for \"SH600519\"\n        df = D.features([\"SH600519\"], [\"ChangeInstrument('SH000300', Feature('close')/Ref(Feature('close'),1) -1)\"])\n        df = D.features([\"SH600519\"], [\"ChangeInstrument('SH000300', $close/Ref($close,1) -1)\"])\n        # excess return\n        df = D.features(\n            [\"SH600519\"], [\"($close/Ref($close,1) -1) - ChangeInstrument('SH000300', $close/Ref($close,1) -1)\"]\n        )\n        print(df)\n\n    def test_case2(self):\n        def test_case(instruments, queries, note=None):\n            if note:\n                print(note)\n            print(f\"checking {instruments} with queries {queries}\")\n            df = D.features(instruments, queries)\n            print(df)\n            return df\n\n        test_case([\"SH600519\"], [\"ChangeInstrument('SH000300', $close)\"], \"get market index close\")\n        test_case(\n            [\"SH600519\"],\n            [\"ChangeInstrument('SH000300', Feature('close')/Ref(Feature('close'),1) -1)\"],\n            \"get market index return with Feature\",\n        )\n        test_case(\n            [\"SH600519\"],\n            [\"ChangeInstrument('SH000300', $close/Ref($close,1) -1)\"],\n            \"get market index return with expression\",\n        )\n        test_case(\n            [\"SH600519\"],\n            [\"($close/Ref($close,1) -1) - ChangeInstrument('SH000300', $close/Ref($close,1) -1)\"],\n            \"get excess return with expression with beta=1\",\n        )\n\n        ret = \"Feature('close') / Ref(Feature('close'), 
1) - 1\"\n        benchmark = \"SH000300\"\n        n_period = 252\n        marketRet = f\"ChangeInstrument('{benchmark}', Feature('close') / Ref(Feature('close'), 1) - 1)\"\n        marketVar = f\"ChangeInstrument('{benchmark}', Var({marketRet}, {n_period}))\"\n        beta = f\"Cov({ret}, {marketRet}, {n_period}) / {marketVar}\"\n        excess_return = f\"{ret} - {beta}*({marketRet})\"\n        fields = [\n            \"Feature('close')\",\n            f\"ChangeInstrument('{benchmark}', Feature('close'))\",\n            ret,\n            marketRet,\n            beta,\n            excess_return,\n        ]\n        test_case([\"SH600519\"], fields[5:], \"get market beta and excess_return with estimated beta\")\n\n        instrument = \"sh600519\"\n        ret = Feature(\"close\") / Ref(Feature(\"close\"), 1) - 1\n        benchmark = \"sh000300\"\n        n_period = 252\n        marketRet = ChangeInstrument(benchmark, Feature(\"close\") / Ref(Feature(\"close\"), 1) - 1)\n        marketVar = ChangeInstrument(benchmark, Var(marketRet, n_period))\n        beta = Cov(ret, marketRet, n_period) / marketVar\n        fields = [\n            Feature(\"close\"),\n            ChangeInstrument(benchmark, Feature(\"close\")),\n            ret,\n            marketRet,\n            beta,\n            ret - beta * marketRet,\n        ]\n        names = [\"close\", \"marketClose\", \"ret\", \"marketRet\", f\"beta_{n_period}\", \"excess_return\"]\n        data_loader_config = {\"feature\": (fields, names)}\n        data_loader = QlibDataLoader(config=data_loader_config)\n        df = data_loader.load(instruments=[instrument])  # , start_time=start_time)\n        print(df)\n\n        # test_case([\"sh600519\"],fields,\n        # \"get market beta and excess_return with estimated beta\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/pytest.ini",
    "content": "[pytest]\nmarkers =\n    slow: marks tests as slow (deselect with '-m \"not slow\"')\nfilterwarnings =\n    ignore:.*rng.randint:DeprecationWarning\n    ignore:.*Casting input x to numpy array:UserWarning\n"
  },
  {
    "path": "tests/rl/test_data_queue.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport multiprocessing\nimport time\n\nimport numpy as np\nimport pandas as pd\n\nfrom torch.utils.data import Dataset, DataLoader\nfrom qlib.rl.utils.data_queue import DataQueue\n\n\nclass DummyDataset(Dataset):\n    def __init__(self, length):\n        self.length = length\n\n    def __getitem__(self, index):\n        assert 0 <= index < self.length\n        return pd.DataFrame(np.random.randint(0, 100, size=(index + 1, 4)), columns=list(\"ABCD\"))\n\n    def __len__(self):\n        return self.length\n\n\ndef _worker(dataloader, collector):\n    # for i in range(3):\n    for i, data in enumerate(dataloader):\n        collector.put(len(data))\n\n\ndef _queue_to_list(queue):\n    result = []\n    while not queue.empty():\n        result.append(queue.get())\n    return result\n\n\ndef test_pytorch_dataloader():\n    dataset = DummyDataset(100)\n    dataloader = DataLoader(dataset, batch_size=None, num_workers=1)\n    queue = multiprocessing.Queue()\n    _worker(dataloader, queue)\n    assert len(set(_queue_to_list(queue))) == 100\n\n\ndef test_multiprocess_shared_dataloader():\n    dataset = DummyDataset(100)\n    with DataQueue(dataset, producer_num_workers=1) as data_queue:\n        queue = multiprocessing.Queue()\n        processes = []\n        for _ in range(3):\n            processes.append(multiprocessing.Process(target=_worker, args=(data_queue, queue)))\n            processes[-1].start()\n        for p in processes:\n            p.join()\n        assert len(set(_queue_to_list(queue))) == 100\n\n\ndef test_exit_on_crash_finite():\n    def _exit_finite():\n        dataset = DummyDataset(100)\n\n        with DataQueue(dataset, producer_num_workers=4) as data_queue:\n            time.sleep(3)\n            raise ValueError\n\n        # https://stackoverflow.com/questions/34506638/how-to-register-atexit-function-in-pythons-multiprocessing-subprocess\n\n    process = 
multiprocessing.Process(target=_exit_finite)\n    process.start()\n    process.join()\n\n\ndef test_exit_on_crash_infinite():\n    def _exit_infinite():\n        dataset = DummyDataset(100)\n        with DataQueue(dataset, repeat=-1, queue_maxsize=100) as data_queue:\n            time.sleep(3)\n            raise ValueError\n\n    process = multiprocessing.Process(target=_exit_infinite)\n    process.start()\n    process.join()\n\n\nif __name__ == \"__main__\":\n    test_multiprocess_shared_dataloader()\n"
  },
  {
    "path": "tests/rl/test_finite_env.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom collections import Counter\n\nimport gym\nimport numpy as np\nfrom tianshou.data import Batch, Collector\nfrom tianshou.policy import BasePolicy\nfrom torch.utils.data import DataLoader, Dataset, DistributedSampler\nfrom qlib.rl.utils.finite_env import (\n    LogWriter,\n    FiniteDummyVectorEnv,\n    FiniteShmemVectorEnv,\n    FiniteSubprocVectorEnv,\n    check_nan_observation,\n    generate_nan_observation,\n)\n\n_test_space = gym.spaces.Dict(\n    {\n        \"sensors\": gym.spaces.Dict(\n            {\n                \"position\": gym.spaces.Box(low=-100, high=100, shape=(3,)),\n                \"velocity\": gym.spaces.Box(low=-1, high=1, shape=(3,)),\n                \"front_cam\": gym.spaces.Tuple(\n                    (gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)), gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)))\n                ),\n                \"rear_cam\": gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)),\n            }\n        ),\n        \"ext_controller\": gym.spaces.MultiDiscrete((5, 2, 2)),\n        \"inner_state\": gym.spaces.Dict(\n            {\n                \"charge\": gym.spaces.Discrete(100),\n                \"system_checks\": gym.spaces.MultiBinary(10),\n                \"job_status\": gym.spaces.Dict(\n                    {\n                        \"task\": gym.spaces.Discrete(5),\n                        \"progress\": gym.spaces.Box(low=0, high=100, shape=()),\n                    }\n                ),\n            }\n        ),\n    }\n)\n\n\nclass FiniteEnv(gym.Env):\n    def __init__(self, dataset, num_replicas, rank):\n        self.dataset = dataset\n        self.num_replicas = num_replicas\n        self.rank = rank\n        self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None)\n        self.iterator = None\n        self.observation_space = gym.spaces.Discrete(255)\n        
self.action_space = gym.spaces.Discrete(2)\n\n    def reset(self):\n        if self.iterator is None:\n            self.iterator = iter(self.loader)\n        try:\n            self.current_sample, self.step_count = next(self.iterator)\n            self.current_step = 0\n            return self.current_sample\n        except StopIteration:\n            self.iterator = None\n            return generate_nan_observation(self.observation_space)\n\n    def step(self, action):\n        self.current_step += 1\n        assert self.current_step <= self.step_count\n        return (\n            0,\n            1.0,\n            self.current_step >= self.step_count,\n            {\"sample\": self.current_sample, \"action\": action, \"metric\": 2.0},\n        )\n\n\nclass FiniteEnvWithComplexObs(FiniteEnv):\n    def __init__(self, dataset, num_replicas, rank):\n        self.dataset = dataset\n        self.num_replicas = num_replicas\n        self.rank = rank\n        self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None)\n        self.iterator = None\n        self.observation_space = gym.spaces.Discrete(255)\n        self.action_space = gym.spaces.Discrete(2)\n\n    def reset(self):\n        if self.iterator is None:\n            self.iterator = iter(self.loader)\n        try:\n            self.current_sample, self.step_count = next(self.iterator)\n            self.current_step = 0\n            return _test_space.sample()\n        except StopIteration:\n            self.iterator = None\n            return generate_nan_observation(self.observation_space)\n\n    def step(self, action):\n        self.current_step += 1\n        assert self.current_step <= self.step_count\n        return (\n            _test_space.sample(),\n            1.0,\n            self.current_step >= self.step_count,\n            {\"sample\": _test_space.sample(), \"action\": action, \"metric\": 2.0},\n        )\n\n\nclass DummyDataset(Dataset):\n    def 
__init__(self, length):\n        self.length = length\n        self.episodes = [3 * i % 5 + 1 for i in range(self.length)]\n\n    def __getitem__(self, index):\n        assert 0 <= index < self.length\n        return index, self.episodes[index]\n\n    def __len__(self):\n        return self.length\n\n\nclass AnyPolicy(BasePolicy):\n    def forward(self, batch, state=None):\n        return Batch(act=np.stack([1] * len(batch)))\n\n    def learn(self, batch):\n        pass\n\n\ndef _finite_env_factory(dataset, num_replicas, rank, complex=False):\n    if complex:\n        return lambda: FiniteEnvWithComplexObs(dataset, num_replicas, rank)\n    return lambda: FiniteEnv(dataset, num_replicas, rank)\n\n\nclass MetricTracker(LogWriter):\n    def __init__(self, length):\n        super().__init__()\n        self.counter = Counter()\n        self.finished = set()\n        self.length = length\n\n    def on_env_step(self, env_id, obs, rew, done, info):\n        assert rew == 1.0\n        index = info[\"sample\"]\n        if done:\n            # assert index not in self.finished\n            self.finished.add(index)\n        self.counter[index] += 1\n\n    def validate(self):\n        assert len(self.finished) == self.length\n        for k, v in self.counter.items():\n            assert v == k * 3 % 5 + 1\n\n\nclass DoNothingTracker(LogWriter):\n    def on_env_step(self, *args, **kwargs):\n        pass\n\n\ndef test_finite_dummy_vector_env():\n    length = 100\n    dataset = DummyDataset(length)\n    envs = FiniteDummyVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])\n    envs._collector_guarded = True\n    policy = AnyPolicy()\n    test_collector = Collector(policy, envs, exploration_noise=True)\n\n    for _ in range(1):\n        envs._logger = [MetricTracker(length)]\n        try:\n            test_collector.collect(n_step=10**18)\n        except StopIteration:\n            envs._logger[0].validate()\n\n\ndef 
test_finite_shmem_vector_env():\n    length = 100\n    dataset = DummyDataset(length)\n    envs = FiniteShmemVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])\n    envs._collector_guarded = True\n    policy = AnyPolicy()\n    test_collector = Collector(policy, envs, exploration_noise=True)\n\n    for _ in range(1):\n        envs._logger = [MetricTracker(length)]\n        try:\n            test_collector.collect(n_step=10**18)\n        except StopIteration:\n            envs._logger[0].validate()\n\n\ndef test_finite_subproc_vector_env():\n    length = 100\n    dataset = DummyDataset(length)\n    envs = FiniteSubprocVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])\n    envs._collector_guarded = True\n    policy = AnyPolicy()\n    test_collector = Collector(policy, envs, exploration_noise=True)\n\n    for _ in range(1):\n        envs._logger = [MetricTracker(length)]\n        try:\n            test_collector.collect(n_step=10**18)\n        except StopIteration:\n            envs._logger[0].validate()\n\n\ndef test_nan():\n    assert check_nan_observation(generate_nan_observation(_test_space))\n    assert not check_nan_observation(_test_space.sample())\n\n\ndef test_finite_dummy_vector_env_complex():\n    length = 100\n    dataset = DummyDataset(length)\n    envs = FiniteDummyVectorEnv(\n        DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)]\n    )\n    envs._collector_guarded = True\n    policy = AnyPolicy()\n    test_collector = Collector(policy, envs, exploration_noise=True)\n\n    try:\n        test_collector.collect(n_step=10**18)\n    except StopIteration:\n        pass\n\n\ndef test_finite_shmem_vector_env_complex():\n    length = 100\n    dataset = DummyDataset(length)\n    envs = FiniteShmemVectorEnv(\n        DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)]\n    )\n    envs._collector_guarded = 
True\n    policy = AnyPolicy()\n    test_collector = Collector(policy, envs, exploration_noise=True)\n\n    try:\n        test_collector.collect(n_step=10**18)\n    except StopIteration:\n        pass\n"
  },
  {
    "path": "tests/rl/test_logger.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nfrom random import randint, choice\nfrom pathlib import Path\nimport logging\n\nimport re\nfrom typing import Any, Tuple\n\nimport gym\nimport numpy as np\nimport pandas as pd\nfrom gym import spaces\nfrom tianshou.data import Collector, Batch\nfrom tianshou.policy import BasePolicy\n\nfrom qlib.log import set_log_with_config\nfrom qlib.config import C\nfrom qlib.constant import INF\nfrom qlib.rl.interpreter import StateInterpreter, ActionInterpreter\nfrom qlib.rl.simulator import Simulator\nfrom qlib.rl.utils.data_queue import DataQueue\nfrom qlib.rl.utils.env_wrapper import InfoDict, EnvWrapper\nfrom qlib.rl.utils.log import LogLevel, LogCollector, CsvWriter, ConsoleWriter\nfrom qlib.rl.utils.finite_env import vectorize_env\n\n\nclass SimpleEnv(gym.Env[int, int]):\n    def __init__(self) -> None:\n        self.logger = LogCollector()\n        self.observation_space = gym.spaces.Discrete(2)\n        self.action_space = gym.spaces.Discrete(2)\n\n    def reset(self, *args: Any, **kwargs: Any) -> int:\n        self.step_count = 0\n        return 0\n\n    def step(self, action: int) -> Tuple[int, float, bool, dict]:\n        self.logger.reset()\n\n        self.logger.add_scalar(\"reward\", 42.0)\n\n        self.logger.add_scalar(\"a\", randint(1, 10))\n        self.logger.add_array(\"b\", pd.DataFrame({\"a\": [1, 2], \"b\": [3, 4]}))\n\n        if self.step_count >= 3:\n            done = choice([False, True])\n        else:\n            done = False\n\n        if 2 <= self.step_count <= 3:\n            self.logger.add_scalar(\"c\", randint(11, 20))\n\n        self.step_count += 1\n\n        return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={})\n\n    def render(self, mode: str = \"human\") -> None:\n        pass\n\n\nclass AnyPolicy(BasePolicy):\n    def forward(self, batch, state=None):\n        return Batch(act=np.stack([1] * len(batch)))\n\n    def learn(self, 
batch):\n        pass\n\n\ndef test_simple_env_logger(caplog):\n    set_log_with_config(C.logging_config)\n    # In order for caplog to capture log messages, we configure it here:\n    # allow logs from the qlib logger to be passed to the parent logger.\n    C.logging_config[\"loggers\"][\"qlib\"][\"propagate\"] = True\n    logging.config.dictConfig(C.logging_config)\n    for venv_cls_name in [\"dummy\", \"shmem\", \"subproc\"]:\n        writer = ConsoleWriter()\n        csv_writer = CsvWriter(Path(__file__).parent / \".output\")\n        venv = vectorize_env(lambda: SimpleEnv(), venv_cls_name, 4, [writer, csv_writer])\n        with venv.collector_guard():\n            collector = Collector(AnyPolicy(), venv)\n            collector.collect(n_episode=30)\n\n        output_file = pd.read_csv(Path(__file__).parent / \".output\" / \"result.csv\")\n        assert output_file.columns.tolist() == [\"reward\", \"a\", \"c\"]\n        assert len(output_file) >= 30\n    line_counter = 0\n    for line in caplog.text.splitlines():\n        line = line.strip()\n        if line:\n            line_counter += 1\n            assert re.match(r\".*reward .* {2}a .* \\(([456])\\.\\d+\\) {2}c .* \\((14|15|16)\\.\\d+\\)\", line)\n    assert line_counter >= 3\n\n\nclass SimpleSimulator(Simulator[int, float, float]):\n    def __init__(self, initial: int, **kwargs: Any) -> None:\n        super(SimpleSimulator, self).__init__(initial, **kwargs)\n        self.initial = float(initial)\n\n    def step(self, action: float) -> None:\n        import torch\n\n        self.initial += action\n        self.env.logger.add_scalar(\"test_a\", torch.tensor(233.0))\n        self.env.logger.add_scalar(\"test_b\", np.array(200))\n\n    def get_state(self) -> float:\n        return self.initial\n\n    def done(self) -> bool:\n        return self.initial % 1 > 0.5\n\n\nclass DummyStateInterpreter(StateInterpreter[float, float]):\n    def interpret(self, state: float) -> float:\n        return state\n\n    
@property\n    def observation_space(self) -> spaces.Box:\n        return spaces.Box(0, np.inf, shape=(), dtype=np.float32)\n\n\nclass DummyActionInterpreter(ActionInterpreter[float, int, float]):\n    def interpret(self, state: float, action: int) -> float:\n        return action / 100\n\n    @property\n    def action_space(self) -> spaces.Discrete:\n        return spaces.Discrete(5)\n\n\nclass RandomFivePolicy(BasePolicy):\n    def forward(self, batch, state=None):\n        return Batch(act=np.random.randint(5, size=len(batch)))\n\n    def learn(self, batch):\n        pass\n\n\ndef test_logger_with_env_wrapper():\n    with DataQueue(list(range(20)), shuffle=False) as data_iterator:\n\n        def env_wrapper_factory():\n            return EnvWrapper(\n                SimpleSimulator,\n                DummyStateInterpreter(),\n                DummyActionInterpreter(),\n                data_iterator,\n                logger=LogCollector(LogLevel.DEBUG),\n            )\n\n        # loglevel can be DEBUG here because all metrics can be dumped into the CSV;\n        # otherwise, the CSV writer might crash\n        csv_writer = CsvWriter(Path(__file__).parent / \".output\", loglevel=LogLevel.DEBUG)\n        venv = vectorize_env(env_wrapper_factory, \"shmem\", 4, csv_writer)\n        with venv.collector_guard():\n            collector = Collector(RandomFivePolicy(), venv)\n            collector.collect(n_episode=INF * len(venv))\n\n    output_df = pd.read_csv(Path(__file__).parent / \".output\" / \"result.csv\")\n    assert len(output_df) == 20\n    # obs has an increasing trend\n    assert output_df[\"obs\"].to_numpy()[:10].sum() < output_df[\"obs\"].to_numpy()[10:].sum()\n    assert (output_df[\"test_a\"] == 233).all()\n    assert (output_df[\"test_b\"] == 200).all()\n    assert \"steps_per_episode\" in output_df and \"reward\" in output_df\n"
  },
  {
    "path": "tests/rl/test_qlib_simulator.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport sys\nfrom pathlib import Path\nfrom typing import Tuple\n\nimport pandas as pd\nimport pytest\n\nfrom qlib.backtest.decision import Order, OrderDir\nfrom qlib.backtest.executor import SimulatorExecutor\nfrom qlib.rl.order_execution import CategoricalActionInterpreter\nfrom qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution\n\nTOTAL_POSITION = 2100.0\n\npython_version_request = pytest.mark.skipif(sys.version_info < (3, 8), reason=\"requires python3.8 or higher\")\n\n\ndef is_close(a: float, b: float, epsilon: float = 1e-4) -> bool:\n    return abs(a - b) <= epsilon\n\n\ndef get_order() -> Order:\n    return Order(\n        stock_id=\"SH600000\",\n        amount=TOTAL_POSITION,\n        direction=OrderDir.BUY,\n        start_time=pd.Timestamp(\"2019-03-04 09:30:00\"),\n        end_time=pd.Timestamp(\"2019-03-04 14:29:00\"),\n    )\n\n\ndef get_configs(order: Order) -> Tuple[dict, dict]:\n    executor_config = {\n        \"class\": \"NestedExecutor\",\n        \"module_path\": \"qlib.backtest.executor\",\n        \"kwargs\": {\n            \"time_per_step\": \"1day\",\n            \"inner_strategy\": {\"class\": \"ProxySAOEStrategy\", \"module_path\": \"qlib.rl.order_execution.strategy\"},\n            \"track_data\": True,\n            \"inner_executor\": {\n                \"class\": \"NestedExecutor\",\n                \"module_path\": \"qlib.backtest.executor\",\n                \"kwargs\": {\n                    \"time_per_step\": \"30min\",\n                    \"inner_strategy\": {\n                        \"class\": \"TWAPStrategy\",\n                        \"module_path\": \"qlib.contrib.strategy.rule_strategy\",\n                    },\n                    \"inner_executor\": {\n                        \"class\": \"SimulatorExecutor\",\n                        \"module_path\": \"qlib.backtest.executor\",\n                        
\"kwargs\": {\n                            \"time_per_step\": \"1min\",\n                            \"verbose\": False,\n                            \"trade_type\": SimulatorExecutor.TT_SERIAL,\n                            \"generate_report\": False,\n                            \"track_data\": True,\n                        },\n                    },\n                    \"track_data\": True,\n                },\n            },\n            \"start_time\": pd.Timestamp(order.start_time.date()),\n            \"end_time\": pd.Timestamp(order.start_time.date()),\n        },\n    }\n\n    exchange_config = {\n        \"freq\": \"1min\",\n        \"codes\": [order.stock_id],\n        \"limit_threshold\": (\"$ask == 0\", \"$bid == 0\"),\n        \"deal_price\": (\"If($ask == 0, $bid, $ask)\", \"If($bid == 0, $ask, $bid)\"),\n        \"volume_threshold\": {\n            \"all\": (\"cum\", \"0.2 * DayCumsum($volume, '9:30', '14:29')\"),\n            \"buy\": (\"current\", \"$askV1\"),\n            \"sell\": (\"current\", \"$bidV1\"),\n        },\n        \"open_cost\": 0.0005,\n        \"close_cost\": 0.0015,\n        \"min_cost\": 5.0,\n        \"trade_unit\": None,\n    }\n\n    return executor_config, exchange_config\n\n\ndef get_simulator(order: Order) -> SingleAssetOrderExecution:\n    DATA_ROOT_DIR = Path(__file__).parent.parent / \".data\" / \"rl\" / \"qlib_simulator\"\n\n    # fmt: off\n    qlib_config = {\n        \"provider_uri_day\": DATA_ROOT_DIR / \"qlib_1d\",\n        \"provider_uri_1min\": DATA_ROOT_DIR / \"qlib_1min\",\n        \"feature_root_dir\": DATA_ROOT_DIR / \"qlib_handler_stock\",\n        \"feature_columns_today\": [\n            \"$open\", \"$high\", \"$low\", \"$close\", \"$vwap\", \"$bid\", \"$ask\", \"$volume\",\n            \"$bidV\", \"$bidV1\", \"$bidV3\", \"$bidV5\", \"$askV\", \"$askV1\", \"$askV3\", \"$askV5\",\n        ],\n        \"feature_columns_yesterday\": [\n            \"$open_1\", \"$high_1\", \"$low_1\", \"$close_1\", 
\"$vwap_1\", \"$bid_1\", \"$ask_1\", \"$volume_1\",\n            \"$bidV_1\", \"$bidV1_1\", \"$bidV3_1\", \"$bidV5_1\", \"$askV_1\", \"$askV1_1\", \"$askV3_1\", \"$askV5_1\",\n        ],\n    }\n    # fmt: on\n\n    executor_config, exchange_config = get_configs(order)\n\n    return SingleAssetOrderExecution(\n        order=order,\n        qlib_config=qlib_config,\n        executor_config=executor_config,\n        exchange_config=exchange_config,\n    )\n\n\n@python_version_request\ndef test_simulator_first_step():\n    order = get_order()\n    simulator = get_simulator(order)\n    state = simulator.get_state()\n    assert state.cur_time == pd.Timestamp(\"2019-03-04 09:30:00\")\n    assert state.position == TOTAL_POSITION\n\n    AMOUNT = 300.0\n    simulator.step(AMOUNT)\n    state = simulator.get_state()\n    assert state.cur_time == pd.Timestamp(\"2019-03-04 10:00:00\")\n    assert state.position == TOTAL_POSITION - AMOUNT\n    assert len(state.history_exec) == 30\n    assert state.history_exec.index[0] == pd.Timestamp(\"2019-03-04 09:30:00\")\n\n    assert is_close(state.history_exec[\"market_volume\"].iloc[0], 109382.382812)\n    assert is_close(state.history_exec[\"market_price\"].iloc[0], 149.566483)\n    assert (state.history_exec[\"amount\"] == AMOUNT / 30).all()\n    assert (state.history_exec[\"deal_amount\"] == AMOUNT / 30).all()\n    assert is_close(state.history_exec[\"trade_price\"].iloc[0], 149.566483)\n    assert is_close(state.history_exec[\"trade_value\"].iloc[0], 1495.664825)\n    assert is_close(state.history_exec[\"position\"].iloc[0], TOTAL_POSITION - AMOUNT / 30)\n    assert is_close(state.history_exec[\"ffr\"].iloc[0], AMOUNT / TOTAL_POSITION / 30)\n\n    assert is_close(state.history_steps[\"market_volume\"].iloc[0], 1254848.5756835938)\n    assert state.history_steps[\"amount\"].iloc[0] == AMOUNT\n    assert state.history_steps[\"deal_amount\"].iloc[0] == AMOUNT\n    assert state.history_steps[\"ffr\"].iloc[0] == AMOUNT / TOTAL_POSITION\n  
  assert is_close(\n        state.history_steps[\"pa\"].iloc[0] * (1.0 if order.direction == OrderDir.SELL else -1.0),\n        (state.history_steps[\"trade_price\"].iloc[0] / simulator.twap_price - 1) * 10000,\n    )\n\n\n@python_version_request\ndef test_simulator_stop_twap() -> None:\n    order = get_order()\n    simulator = get_simulator(order)\n    NUM_STEPS = 7\n    for i in range(NUM_STEPS):\n        simulator.step(TOTAL_POSITION / NUM_STEPS)\n\n    HISTORY_STEP_LENGTH = 30 * NUM_STEPS\n    state = simulator.get_state()\n    assert len(state.history_exec) == HISTORY_STEP_LENGTH\n\n    assert (state.history_exec[\"deal_amount\"] == TOTAL_POSITION / HISTORY_STEP_LENGTH).all()\n    assert is_close(state.history_steps[\"position\"].iloc[0], TOTAL_POSITION * (NUM_STEPS - 1) / NUM_STEPS)\n    assert is_close(state.history_steps[\"position\"].iloc[-1], 0.0)\n    assert is_close(state.position, 0.0)\n    assert is_close(state.metrics[\"ffr\"], 1.0)\n\n    assert is_close(state.metrics[\"market_price\"], state.backtest_data.get_deal_price().mean())\n    assert is_close(state.metrics[\"market_volume\"], state.backtest_data.get_volume().sum())\n    assert is_close(state.metrics[\"trade_price\"], state.metrics[\"market_price\"])\n    assert is_close(state.metrics[\"pa\"], 0.0)\n\n    assert simulator.done()\n\n\n@python_version_request\ndef test_interpreter() -> None:\n    NUM_EXECUTION = 3\n    order = get_order()\n    simulator = get_simulator(order)\n    interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)\n\n    NUM_STEPS = 7\n    state = simulator.get_state()\n    position_history = []\n    for i in range(NUM_STEPS):\n        simulator.step(interpreter_action(state, 1))\n        state = simulator.get_state()\n        position_history.append(state.position)\n\n        assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0)\n"
  },
  {
    "path": "tests/rl/test_saoe_simple.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport sys\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import NamedTuple\n\nimport numpy as np\nimport pandas as pd\nimport pytest\nimport torch\nfrom tianshou.data import Batch\n\nfrom qlib.backtest import Order\nfrom qlib.config import C\nfrom qlib.log import set_log_with_config\nfrom qlib.rl.data import pickle_styled\nfrom qlib.rl.data.pickle_styled import PickleProcessedDataProvider\nfrom qlib.rl.order_execution import *\nfrom qlib.rl.trainer import backtest, train\nfrom qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus\n\npytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason=\"Pickle styled data only supports Python >= 3.8\")\n\n\nDATA_ROOT_DIR = Path(__file__).parent.parent / \".data\" / \"rl\" / \"intraday_saoe\"\nDATA_DIR = DATA_ROOT_DIR / \"us\"\nBACKTEST_DATA_DIR = DATA_DIR / \"backtest\"\nFEATURE_DATA_DIR = DATA_DIR / \"processed\"\nORDER_DIR = DATA_DIR / \"order\" / \"valid_bidir\"\n\nCN_DATA_DIR = DATA_ROOT_DIR / \"cn\"\nCN_FEATURE_DATA_DIR = CN_DATA_DIR / \"processed\"\nCN_ORDER_DIR = CN_DATA_DIR / \"order\" / \"test\"\nCN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / \"weights\"\n\n\ndef test_pickle_data_inspect():\n    data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, \"AAL\", \"2013-12-11\", \"close\", 0)\n    assert len(data) == 390\n\n    provider = PickleProcessedDataProvider(DATA_DIR / \"processed\")\n    data = provider.get_data(\"AAL\", \"2013-12-11\", 5, data.get_time_index())\n    assert len(data.today) == len(data.yesterday) == 390\n\n\ndef test_simulator_first_step():\n    order = Order(\"AAL\", 30.0, 0, pd.Timestamp(\"2013-12-11 00:00:00\"), pd.Timestamp(\"2013-12-11 23:59:59\"))\n\n    simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)\n    state = simulator.get_state()\n    assert state.cur_time == pd.Timestamp(\"2013-12-11 09:30:00\")\n    assert state.position == 
30.0\n\n    simulator.step(15.0)\n    state = simulator.get_state()\n    assert len(state.history_exec) == 30\n    assert state.history_exec.index[0] == pd.Timestamp(\"2013-12-11 09:30:00\")\n    assert state.history_exec[\"market_volume\"].iloc[0] == 450072.0\n    assert abs(state.history_exec[\"market_price\"].iloc[0] - 25.370001) < 1e-4\n    assert (state.history_exec[\"amount\"] == 0.5).all()\n    assert (state.history_exec[\"deal_amount\"] == 0.5).all()\n    assert abs(state.history_exec[\"trade_price\"].iloc[0] - 25.370001) < 1e-4\n    assert abs(state.history_exec[\"trade_value\"].iloc[0] - 12.68500) < 1e-4\n    assert state.history_exec[\"position\"].iloc[0] == 29.5\n    assert state.history_exec[\"ffr\"].iloc[0] == 1 / 60\n\n    assert state.history_steps[\"market_volume\"].iloc[0] == 5041147.0\n    assert state.history_steps[\"amount\"].iloc[0] == 15.0\n    assert state.history_steps[\"deal_amount\"].iloc[0] == 15.0\n    assert state.history_steps[\"ffr\"].iloc[0] == 0.5\n    assert (\n        state.history_steps[\"pa\"].iloc[0]\n        == (state.history_steps[\"trade_price\"].iloc[0] / simulator.twap_price - 1) * 10000\n    )\n\n    assert state.position == 15.0\n    assert state.cur_time == pd.Timestamp(\"2013-12-11 10:00:00\")\n\n\ndef test_simulator_stop_twap():\n    order = Order(\"AAL\", 13.0, 0, pd.Timestamp(\"2013-12-11 00:00:00\"), pd.Timestamp(\"2013-12-11 23:59:59\"))\n\n    simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)\n    for _ in range(13):\n        simulator.step(1.0)\n\n    state = simulator.get_state()\n    assert len(state.history_exec) == 390\n    assert (state.history_exec[\"deal_amount\"] == 13 / 390).all()\n    assert state.history_steps[\"position\"].iloc[0] == 12 and state.history_steps[\"position\"].iloc[-1] == 0\n\n    assert abs(state.metrics[\"ffr\"] - 1) < 1e-3\n    assert abs(state.metrics[\"market_price\"] - state.backtest_data.get_deal_price().mean()) < 1e-4\n    assert 
np.isclose(state.metrics[\"market_volume\"], state.backtest_data.get_volume().sum())\n    assert state.position == 0.0\n    assert abs(state.metrics[\"trade_price\"] - state.metrics[\"market_price\"]) < 1e-4\n    assert abs(state.metrics[\"pa\"]) < 1e-2\n\n    assert simulator.done()\n\n\ndef test_simulator_stop_early():\n    order = Order(\"AAL\", 1.0, 1, pd.Timestamp(\"2013-12-11 00:00:00\"), pd.Timestamp(\"2013-12-11 23:59:59\"))\n\n    with pytest.raises(ValueError):\n        simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)\n        simulator.step(2.0)\n\n    simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)\n    simulator.step(1.0)\n\n    with pytest.raises(AssertionError):\n        simulator.step(1.0)\n\n\ndef test_simulator_start_middle():\n    order = Order(\"AAL\", 15.0, 1, pd.Timestamp(\"2013-12-11 10:15:00\"), pd.Timestamp(\"2013-12-11 15:44:59\"))\n\n    simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)\n    assert len(simulator.ticks_for_order) == 330\n    assert simulator.cur_time == pd.Timestamp(\"2013-12-11 10:15:00\")\n    simulator.step(2.0)\n    assert simulator.cur_time == pd.Timestamp(\"2013-12-11 10:30:00\")\n\n    for _ in range(10):\n        simulator.step(1.0)\n\n    simulator.step(2.0)\n    assert len(simulator.history_exec) == 330\n    assert simulator.done()\n    assert abs(simulator.history_exec[\"amount\"].iloc[-1] - (1 + 2 / 15)) < 1e-4\n    assert abs(simulator.metrics[\"ffr\"] - 1) < 1e-4\n\n\ndef test_interpreter():\n    order = Order(\"AAL\", 15.0, 1, pd.Timestamp(\"2013-12-11 10:15:00\"), pd.Timestamp(\"2013-12-11 15:44:59\"))\n\n    simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)\n    assert len(simulator.ticks_for_order) == 330\n    assert simulator.cur_time == pd.Timestamp(\"2013-12-11 10:15:00\")\n\n    # emulate an env wrapper status\n    class EmulateEnvWrapper(NamedTuple):\n        status: EnvWrapperStatus\n\n    interpreter = FullHistoryStateInterpreter(13, 390, 5, 
PickleProcessedDataProvider(FEATURE_DATA_DIR))\n    interpreter_step = CurrentStepStateInterpreter(13)\n    interpreter_action = CategoricalActionInterpreter(20)\n    interpreter_action_twap = TwapRelativeActionInterpreter()\n\n    wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])\n\n    # first step\n    interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs))\n\n    obs = interpreter(simulator.get_state())\n    assert obs[\"cur_tick\"] == 45\n    assert obs[\"cur_step\"] == 0\n    assert obs[\"position\"] == 15.0\n    assert obs[\"position_history\"][0] == 15.0\n    assert all(np.sum(obs[\"data_processed\"][i]) != 0 for i in range(45))\n    assert np.sum(obs[\"data_processed\"][45:]) == 0\n    assert obs[\"data_processed_prev\"].shape == (390, 5)\n\n    # first step: second interpreter\n    interpreter_step.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs))\n\n    obs = interpreter_step(simulator.get_state())\n    assert obs[\"acquiring\"] == 1\n    assert obs[\"position\"] == 15.0\n\n    # second step\n    simulator.step(5.0)\n    interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs))\n\n    obs = interpreter(simulator.get_state())\n    assert obs[\"cur_tick\"] == 60\n    assert obs[\"cur_step\"] == 1\n    assert obs[\"position\"] == 10.0\n    assert obs[\"position_history\"][:2].tolist() == [15.0, 10.0]\n    assert all(np.sum(obs[\"data_processed\"][i]) != 0 for i in range(60))\n    assert np.sum(obs[\"data_processed\"][60:]) == 0\n\n    # second step: action\n    action = interpreter_action(simulator.get_state(), 1)\n    assert action == 15 / 20\n\n    interpreter_action_twap.env = EmulateEnvWrapper(\n        status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs)\n    )\n    action = interpreter_action_twap(simulator.get_state(), 1.5)\n  
  assert action == 1.5\n\n    # fast-forward\n    for _ in range(10):\n        simulator.step(0.0)\n\n    # last step\n    simulator.step(5.0)\n    interpreter.env = EmulateEnvWrapper(\n        status=EnvWrapperStatus(cur_step=12, done=simulator.done(), **wrapper_status_kwargs)\n    )\n\n    assert interpreter.env.status[\"done\"]\n\n    obs = interpreter(simulator.get_state())\n    assert obs[\"cur_tick\"] == 375\n    assert obs[\"cur_step\"] == 12\n    assert obs[\"position\"] == 0.0\n    assert obs[\"position_history\"][1:11].tolist() == [10.0] * 10\n    assert all(np.sum(obs[\"data_processed\"][i]) != 0 for i in range(375))\n    assert np.sum(obs[\"data_processed\"][375:]) == 0\n\n\ndef test_network_sanity():\n    # we won't check the correctness of networks here\n    order = Order(\"AAL\", 15.0, 1, pd.Timestamp(\"2013-12-11 9:30:00\"), pd.Timestamp(\"2013-12-11 15:59:59\"))\n\n    simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)\n    assert len(simulator.ticks_for_order) == 390\n\n    class EmulateEnvWrapper(NamedTuple):\n        status: EnvWrapperStatus\n\n    interpreter = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))\n    action_interp = CategoricalActionInterpreter(13)\n\n    wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])\n\n    network = Recurrent(interpreter.observation_space)\n    policy = PPO(network, interpreter.observation_space, action_interp.action_space, 1e-3)\n\n    for i in range(14):\n        interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=i, done=False, **wrapper_status_kwargs))\n        obs = interpreter(simulator.get_state())\n        batch = Batch(obs=[obs])\n        output = policy(batch)\n        assert 0 <= output[\"act\"].item() <= 13\n        if i < 13:\n            simulator.step(1.0)\n        else:\n            assert obs[\"cur_tick\"] == 389\n            assert obs[\"cur_step\"] == 12\n            assert 
obs[\"position_history\"][-1] == 3\n\n\n@pytest.mark.parametrize(\"finite_env_type\", [\"dummy\", \"subproc\", \"shmem\"])\ndef test_twap_strategy(finite_env_type):\n    set_log_with_config(C.logging_config)\n    orders = pickle_styled.load_orders(ORDER_DIR)\n    assert len(orders) == 248\n\n    state_interp = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))\n    action_interp = TwapRelativeActionInterpreter()\n    policy = AllOne(state_interp.observation_space, action_interp.action_space)\n    csv_writer = CsvWriter(Path(__file__).parent / \".output\")\n\n    backtest(\n        partial(SingleAssetOrderExecutionSimple, data_dir=DATA_DIR, ticks_per_step=30),\n        state_interp,\n        action_interp,\n        orders,\n        policy,\n        [ConsoleWriter(total_episodes=len(orders)), csv_writer],\n        concurrency=4,\n        finite_env_type=finite_env_type,\n    )\n\n    metrics = pd.read_csv(Path(__file__).parent / \".output\" / \"result.csv\")\n    assert len(metrics) == 248\n    assert np.isclose(metrics[\"ffr\"].mean(), 1.0)\n    assert np.isclose(metrics[\"pa\"].mean(), 0.0)\n    assert np.allclose(metrics[\"pa\"], 0.0, atol=2e-3)\n\n\ndef test_cn_ppo_strategy():\n    set_log_with_config(C.logging_config)\n    # The data starts with 9:31 and ends with 15:00\n    orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp(\"9:31\"), end_time=pd.Timestamp(\"14:58\"))\n    assert len(orders) == 40\n\n    state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))\n    action_interp = CategoricalActionInterpreter(4)\n    network = Recurrent(state_interp.observation_space)\n    policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)\n    policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / \"ppo_recurrent_30min.pth\", map_location=\"cpu\"))\n    csv_writer = CsvWriter(Path(__file__).parent / \".output\")\n\n    backtest(\n     
   partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),\n        state_interp,\n        action_interp,\n        orders,\n        policy,\n        [ConsoleWriter(total_episodes=len(orders)), csv_writer],\n        concurrency=4,\n    )\n\n    metrics = pd.read_csv(Path(__file__).parent / \".output\" / \"result.csv\")\n    assert len(metrics) == len(orders)\n    assert np.isclose(metrics[\"ffr\"].mean(), 1.0)\n    assert np.isclose(metrics[\"pa\"].mean(), -16.21578303474833)\n    assert np.isclose(metrics[\"market_price\"].mean(), 58.68277690875527)\n    assert np.isclose(metrics[\"trade_price\"].mean(), 58.76063985000002)\n\n\ndef test_ppo_train():\n    set_log_with_config(C.logging_config)\n    # The data starts with 9:31 and ends with 15:00\n    orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp(\"9:31\"), end_time=pd.Timestamp(\"14:58\"))\n    assert len(orders) == 40\n\n    state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))\n    action_interp = CategoricalActionInterpreter(4)\n    network = Recurrent(state_interp.observation_space)\n    policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)\n\n    train(\n        partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),\n        state_interp,\n        action_interp,\n        orders,\n        policy,\n        PAPenaltyReward(),\n        vessel_kwargs={\"episode_per_iter\": 100, \"update_kwargs\": {\"batch_size\": 64, \"repeat\": 5}},\n        trainer_kwargs={\"max_iters\": 2, \"loggers\": ConsoleWriter(total_episodes=100)},\n    )\n"
  },
  {
    "path": "tests/rl/test_trainer.py",
    "content": "import os\nimport random\nimport sys\nfrom pathlib import Path\n\nimport pytest\n\nimport torch\nimport torch.nn as nn\nfrom gym import spaces\nfrom tianshou.policy import PPOPolicy\n\nfrom qlib.config import C\nfrom qlib.log import set_log_with_config\nfrom qlib.rl.interpreter import StateInterpreter, ActionInterpreter\nfrom qlib.rl.simulator import Simulator\nfrom qlib.rl.reward import Reward\nfrom qlib.rl.trainer import Trainer, TrainingVessel, EarlyStopping, Checkpoint\n\npytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason=\"Pickle styled data only supports Python >= 3.8\")\n\n\nclass ZeroSimulator(Simulator):\n    def __init__(self, *args, **kwargs):\n        self.action = self.correct = 0\n\n    def step(self, action):\n        self.action = action\n        self.correct = action == 0\n        self._done = random.choice([False, True])\n        if self._done:\n            self.env.logger.add_scalar(\"acc\", self.correct * 100)\n\n    def get_state(self):\n        return {\n            \"acc\": self.correct * 100,\n            \"action\": self.action,\n        }\n\n    def done(self) -> bool:\n        return self._done\n\n\nclass NoopStateInterpreter(StateInterpreter):\n    observation_space = spaces.Dict(\n        {\n            \"acc\": spaces.Discrete(200),\n            \"action\": spaces.Discrete(2),\n        }\n    )\n\n    def interpret(self, simulator_state):\n        return simulator_state\n\n\nclass NoopActionInterpreter(ActionInterpreter):\n    action_space = spaces.Discrete(2)\n\n    def interpret(self, simulator_state, action):\n        return action\n\n\nclass AccReward(Reward):\n    def reward(self, simulator_state):\n        if self.env.status[\"done\"]:\n            return simulator_state[\"acc\"] / 100\n        return 0.0\n\n\nclass PolicyNet(nn.Module):\n    def __init__(self, out_features=1, return_state=False):\n        super().__init__()\n        self.fc = nn.Linear(32, out_features)\n        self.return_state = 
return_state\n\n    def forward(self, obs, state=None, **kwargs):\n        res = self.fc(torch.randn(obs[\"acc\"].shape[0], 32))\n        if self.return_state:\n            return nn.functional.softmax(res, dim=-1), state\n        else:\n            return res\n\n\ndef _ppo_policy():\n    actor = PolicyNet(2, True)\n    critic = PolicyNet()\n    policy = PPOPolicy(\n        actor,\n        critic,\n        torch.optim.Adam(tuple(actor.parameters()) + tuple(critic.parameters())),\n        torch.distributions.Categorical,\n        action_space=NoopActionInterpreter().action_space,\n    )\n    return policy\n\n\ndef test_trainer():\n    set_log_with_config(C.logging_config)\n    trainer = Trainer(max_iters=10, finite_env_type=\"subproc\")\n    policy = _ppo_policy()\n\n    vessel = TrainingVessel(\n        simulator_fn=lambda init: ZeroSimulator(init),\n        state_interpreter=NoopStateInterpreter(),\n        action_interpreter=NoopActionInterpreter(),\n        policy=policy,\n        train_initial_states=list(range(100)),\n        val_initial_states=list(range(10)),\n        test_initial_states=list(range(10)),\n        reward=AccReward(),\n        episode_per_iter=500,\n        update_kwargs=dict(repeat=10, batch_size=64),\n    )\n    trainer.fit(vessel)\n    assert trainer.current_iter == 10\n    assert trainer.current_episode == 5000\n    assert abs(trainer.metrics[\"acc\"] - trainer.metrics[\"reward\"] * 100) < 1e-4\n    assert trainer.metrics[\"acc\"] > 80\n    trainer.test(vessel)\n    assert trainer.metrics[\"acc\"] > 60\n\n\ndef test_trainer_fast_dev_run():\n    set_log_with_config(C.logging_config)\n    trainer = Trainer(max_iters=2, fast_dev_run=2, finite_env_type=\"shmem\")\n    policy = _ppo_policy()\n\n    vessel = TrainingVessel(\n        simulator_fn=lambda init: ZeroSimulator(init),\n        state_interpreter=NoopStateInterpreter(),\n        action_interpreter=NoopActionInterpreter(),\n        policy=policy,\n        
train_initial_states=list(range(100)),\n        val_initial_states=list(range(10)),\n        test_initial_states=list(range(10)),\n        reward=AccReward(),\n        episode_per_iter=500,\n        update_kwargs=dict(repeat=10, batch_size=64),\n    )\n    trainer.fit(vessel)\n    assert trainer.current_episode == 4\n\n\ndef test_trainer_earlystop():\n    # TODO: this is just a sanity check.\n    # Check the logs to verify that early stopping actually works.\n    set_log_with_config(C.logging_config)\n    trainer = Trainer(\n        max_iters=10,\n        val_every_n_iters=1,\n        finite_env_type=\"dummy\",\n        callbacks=[EarlyStopping(\"val/reward\", restore_best_weights=True)],\n    )\n    policy = _ppo_policy()\n\n    vessel = TrainingVessel(\n        simulator_fn=lambda init: ZeroSimulator(init),\n        state_interpreter=NoopStateInterpreter(),\n        action_interpreter=NoopActionInterpreter(),\n        policy=policy,\n        train_initial_states=list(range(100)),\n        val_initial_states=list(range(10)),\n        test_initial_states=list(range(10)),\n        reward=AccReward(),\n        episode_per_iter=500,\n        update_kwargs=dict(repeat=10, batch_size=64),\n    )\n    trainer.fit(vessel)\n    assert trainer.metrics[\"val/acc\"] > 30\n    assert trainer.current_iter == 2  # second iteration\n\n\ndef test_trainer_checkpoint():\n    set_log_with_config(C.logging_config)\n    output_dir = Path(__file__).parent / \".output\"\n    trainer = Trainer(max_iters=2, finite_env_type=\"dummy\", callbacks=[Checkpoint(output_dir, every_n_iters=1)])\n    policy = _ppo_policy()\n\n    vessel = TrainingVessel(\n        simulator_fn=lambda init: ZeroSimulator(init),\n        state_interpreter=NoopStateInterpreter(),\n        action_interpreter=NoopActionInterpreter(),\n        policy=policy,\n        train_initial_states=list(range(100)),\n        val_initial_states=list(range(10)),\n        test_initial_states=list(range(10)),\n        reward=AccReward(),\n        
episode_per_iter=100,\n        update_kwargs=dict(repeat=10, batch_size=64),\n    )\n    trainer.fit(vessel)\n\n    assert (output_dir / \"001.pth\").exists()\n    assert (output_dir / \"002.pth\").exists()\n    assert os.readlink(output_dir / \"latest.pth\") == str(output_dir / \"002.pth\")\n\n    trainer.load_state_dict(torch.load(output_dir / \"001.pth\", weights_only=False))\n    assert trainer.current_iter == 1\n    assert trainer.current_episode == 100\n\n    # Reload the checkpoint at first iteration\n    trainer.fit(vessel, ckpt_path=output_dir / \"001.pth\")\n"
  },
  {
    "path": "tests/rolling_tests/test_update_pred.py",
    "content": "import copy\nimport unittest\nimport pytest\n\nimport fire\nimport pandas as pd\n\nimport qlib\nfrom qlib.data import D\nfrom qlib.model.trainer import task_train\nfrom qlib.tests import TestAutoData\nfrom qlib.tests.config import CSI300_GBDT_TASK\nfrom qlib.workflow.online.utils import OnlineToolR\nfrom qlib.workflow.online.update import LabelUpdater\n\n\nclass TestRolling(TestAutoData):\n    @pytest.mark.slow\n    def test_update_pred(self):\n        \"\"\"\n        This test is for testing if it will raise error if the `to_date` is out of the boundary.\n        \"\"\"\n        task = copy.deepcopy(CSI300_GBDT_TASK)\n\n        task[\"record\"] = [\"qlib.workflow.record_temp.SignalRecord\"]\n\n        exp_name = \"online_srv_test\"\n\n        cal = D.calendar()\n        latest_date = cal[-1]\n\n        train_start = latest_date - pd.Timedelta(days=61)\n        train_end = latest_date - pd.Timedelta(days=41)\n        task[\"dataset\"][\"kwargs\"][\"segments\"] = {\n            \"train\": (train_start, train_end),\n            \"valid\": (latest_date - pd.Timedelta(days=40), latest_date - pd.Timedelta(days=21)),\n            \"test\": (latest_date - pd.Timedelta(days=20), latest_date),\n        }\n\n        task[\"dataset\"][\"kwargs\"][\"handler\"][\"kwargs\"] = {\n            \"start_time\": train_start,\n            \"end_time\": latest_date,\n            \"fit_start_time\": train_start,\n            \"fit_end_time\": train_end,\n            \"instruments\": \"csi300\",\n        }\n\n        rec = task_train(task, exp_name)\n\n        pred = rec.load_object(\"pred.pkl\")\n\n        online_tool = OnlineToolR(exp_name)\n        online_tool.reset_online_tag(rec)  # set to online model\n\n        online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10))\n\n        good_pred = rec.load_object(\"pred.pkl\")\n\n        mod_range = slice(latest_date - pd.Timedelta(days=20), latest_date - pd.Timedelta(days=10))\n        mod_range2 = 
slice(latest_date - pd.Timedelta(days=9), latest_date - pd.Timedelta(days=2))\n        mod_pred = good_pred.copy()\n\n        mod_pred.loc[mod_range] = -1\n        mod_pred.loc[mod_range2] = -2\n\n        rec.save_objects(**{\"pred.pkl\": mod_pred})\n        online_tool.update_online_pred(\n            to_date=latest_date - pd.Timedelta(days=10), from_date=latest_date - pd.Timedelta(days=20)\n        )\n\n        updated_pred = rec.load_object(\"pred.pkl\")\n\n        # this range is not fixed\n        self.assertTrue((updated_pred.loc[mod_range] == good_pred.loc[mod_range]).all().item())\n        # this range is fixed now\n        self.assertTrue((updated_pred.loc[mod_range2] == -2).all().item())\n\n    @pytest.mark.slow\n    def test_update_label(self):\n        task = copy.deepcopy(CSI300_GBDT_TASK)\n\n        task[\"record\"] = {\n            \"class\": \"SignalRecord\",\n            \"module_path\": \"qlib.workflow.record_temp\",\n            \"kwargs\": {\"dataset\": \"<DATASET>\", \"model\": \"<MODEL>\"},\n        }\n\n        exp_name = \"online_srv_test\"\n\n        cal = D.calendar()\n        shift = 10\n        latest_date = cal[-1 - shift]\n\n        train_start = latest_date - pd.Timedelta(days=61)\n        train_end = latest_date - pd.Timedelta(days=41)\n        task[\"dataset\"][\"kwargs\"][\"segments\"] = {\n            \"train\": (train_start, train_end),\n            \"valid\": (latest_date - pd.Timedelta(days=40), latest_date - pd.Timedelta(days=21)),\n            \"test\": (latest_date - pd.Timedelta(days=20), latest_date),\n        }\n\n        task[\"dataset\"][\"kwargs\"][\"handler\"][\"kwargs\"] = {\n            \"start_time\": train_start,\n            \"end_time\": latest_date,\n            \"fit_start_time\": train_start,\n            \"fit_end_time\": train_end,\n            \"instruments\": \"csi300\",\n        }\n\n        rec = task_train(task, exp_name)\n\n        pred = rec.load_object(\"pred.pkl\")\n\n        online_tool = 
OnlineToolR(exp_name)\n        online_tool.reset_online_tag(rec)  # set to online model\n        online_tool.update_online_pred()\n\n        new_pred = rec.load_object(\"pred.pkl\")\n        label = rec.load_object(\"label.pkl\")\n        label_date = label.dropna().index.get_level_values(\"datetime\").max()\n        pred_date = new_pred.dropna().index.get_level_values(\"datetime\").max()\n\n        # The prediction is updated, but the label is not updated.\n        self.assertTrue(label_date < pred_date)\n\n        # Update label now\n        lu = LabelUpdater(rec)\n        lu.update()\n        new_label = rec.load_object(\"label.pkl\")\n        new_label_date = new_label.index.get_level_values(\"datetime\").max()\n        self.assertTrue(new_label_date == pred_date)  # make sure the label is updated now\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/storage_tests/test_storage.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nfrom pathlib import Path\nfrom collections.abc import Iterable\n\nimport numpy as np\nfrom qlib.tests import TestAutoData\n\nfrom qlib.data.storage.file_storage import (\n    FileCalendarStorage as CalendarStorage,\n    FileInstrumentStorage as InstrumentStorage,\n    FileFeatureStorage as FeatureStorage,\n)\n\n_file_name = Path(__file__).name.split(\".\")[0]\nDATA_DIR = Path(__file__).parent.joinpath(f\"{_file_name}_data\")\nQLIB_DIR = DATA_DIR.joinpath(\"qlib\")\nQLIB_DIR.mkdir(exist_ok=True, parents=True)\n\n\nclass TestStorage(TestAutoData):\n    def test_calendar_storage(self):\n        calendar = CalendarStorage(freq=\"day\", future=False, provider_uri=self.provider_uri)\n        assert isinstance(calendar[:], Iterable), f\"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable\"\n        assert isinstance(calendar.data, Iterable), f\"{calendar.__class__.__name__}.data is not Iterable\"\n\n        print(f\"calendar[1: 5]: {calendar[1:5]}\")\n        print(f\"calendar[0]: {calendar[0]}\")\n        print(f\"calendar[-1]: {calendar[-1]}\")\n\n        calendar = CalendarStorage(freq=\"1min\", future=False, provider_uri=\"not_found\")\n        with self.assertRaises(ValueError):\n            print(calendar.data)\n\n        with self.assertRaises(ValueError):\n            print(calendar[:])\n\n        with self.assertRaises(ValueError):\n            print(calendar[0])\n\n    def test_instrument_storage(self):\n        \"\"\"\n        The meaning of instrument, such as CSI500:\n\n            CSI500 composition changes:\n\n                date            add         remove\n                2005-01-01      SH600000\n                2005-01-01      SH600001\n                2005-01-01      SH600002\n                2005-02-01      SH600003    SH600000\n                2005-02-15      SH600000    SH600002\n\n            Calendar:\n                
pd.date_range(start=\"2005-01-01\", end=\"2005-03-01\", freq=\"1D\")\n\n            Instrument:\n                symbol      start_time      end_time\n                SH600000    2005-01-01      2005-01-31 (2005-02-01 Last trading day)\n                SH600000    2005-02-15      2005-03-01\n                SH600001    2005-01-01      2005-03-01\n                SH600002    2005-01-01      2005-02-14 (2005-02-15 Last trading day)\n                SH600003    2005-02-01      2005-03-01\n\n            InstrumentStorage:\n                {\n                    \"SH600000\": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)],\n                    \"SH600001\": [(2005-01-01, 2005-03-01)],\n                    \"SH600002\": [(2005-01-01, 2005-02-14)],\n                    \"SH600003\": [(2005-02-01, 2005-03-01)],\n                }\n\n        \"\"\"\n\n        instrument = InstrumentStorage(market=\"csi300\", provider_uri=self.provider_uri, freq=\"day\")\n\n        for inst, spans in instrument.data.items():\n            assert isinstance(inst, str) and isinstance(\n                spans, Iterable\n            ), f\"{instrument.__class__.__name__} value is not Iterable\"\n            for s_e in spans:\n                assert (\n                    isinstance(s_e, tuple) and len(s_e) == 2\n                ), f\"{instrument.__class__.__name__}.__getitem__(k) TypeError\"\n\n        print(f\"instrument['SH600000']: {instrument['SH600000']}\")\n\n        instrument = InstrumentStorage(market=\"csi300\", provider_uri=\"not_found\", freq=\"day\")\n        with self.assertRaises(ValueError):\n            print(instrument.data)\n\n        with self.assertRaises(ValueError):\n            print(instrument[\"sSH600000\"])\n\n    def test_feature_storage(self):\n        \"\"\"\n        Calendar:\n            pd.date_range(start=\"2005-01-01\", end=\"2005-03-01\", freq=\"1D\")\n\n        Instrument:\n            {\n                \"SH600000\": [(2005-01-01, 2005-01-31), 
(2005-02-15, 2005-03-01)],\n                \"SH600001\": [(2005-01-01, 2005-03-01)],\n                \"SH600002\": [(2005-01-01, 2005-02-14)],\n                \"SH600003\": [(2005-02-01, 2005-03-01)],\n            }\n\n        Feature:\n            Stock data(close):\n                            2005-01-01  ...   2005-02-01   ...   2005-02-14  2005-02-15  ...  2005-03-01\n                SH600000     1          ...      3         ...      4           5               6\n                SH600001     1          ...      4         ...      5           6               7\n                SH600002     1          ...      5         ...      6           nan             nan\n                SH600003     nan        ...      1         ...      2           3               4\n\n            FeatureStorage(SH600000, close):\n\n                [\n                    (calendar.index(\"2005-01-01\"), 1),\n                    ...,\n                    (calendar.index(\"2005-03-01\"), 6)\n                ]\n\n                ====> [(0, 1), ..., (59, 6)]\n\n\n            FeatureStorage(SH600002, close):\n\n                [\n                    (calendar.index(\"2005-01-01\"), 1),\n                    ...,\n                    (calendar.index(\"2005-02-14\"), 6)\n                ]\n\n                ===> [(0, 1), ..., (44, 6)]\n\n            FeatureStorage(SH600003, close):\n\n                [\n                    (calendar.index(\"2005-02-01\"), 1),\n                    ...,\n                    (calendar.index(\"2005-03-01\"), 4)\n                ]\n\n                ===> [(31, 1), ..., (59, 4)]\n\n        \"\"\"\n\n        feature = FeatureStorage(instrument=\"SZ300677\", field=\"close\", freq=\"day\", provider_uri=self.provider_uri)\n\n        with self.assertRaises(IndexError):\n            print(feature[0])\n        assert isinstance(\n            feature[3049][1], (float, np.float32)\n        ), f\"{feature.__class__.__name__}.__getitem__(i: int) error\"\n        assert 
len(feature[3049:3052]) == 3, f\"{feature.__class__.__name__}.__getitem__(s: slice) error\"\n        print(f\"feature[3049: 3052]: \\n{feature[3049: 3052]}\")\n\n        print(f\"feature[:].tail(): \\n{feature[:].tail()}\")\n\n        feature = FeatureStorage(instrument=\"SH600004\", field=\"close\", freq=\"day\", provider_uri=\"not_found\")\n\n        with self.assertRaises(ValueError):\n            print(feature[0])\n        with self.assertRaises(ValueError):\n            print(feature[:].empty)\n        with self.assertRaises(ValueError):\n            print(feature.data.empty)\n"
  },
  {
    "path": "tests/test_all_pipeline.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport sys\nimport shutil\nimport unittest\nimport pytest\nfrom pathlib import Path\n\nimport qlib\nfrom qlib.config import C\nfrom qlib.utils import init_instance_by_config, flatten_dict\nfrom qlib.workflow import R\nfrom qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord\nfrom qlib.tests import TestAutoData\nfrom qlib.tests.config import CSI300_GBDT_TASK, CSI300_BENCH\n\n\ndef train(uri_path: str = None):\n    \"\"\"train model\n\n    Returns\n    -------\n        pred_score: pandas.DataFrame\n            predict scores\n        performance: dict\n            model performance\n    \"\"\"\n\n    # model initialization\n    model = init_instance_by_config(CSI300_GBDT_TASK[\"model\"])\n    dataset = init_instance_by_config(CSI300_GBDT_TASK[\"dataset\"])\n    # To test __repr__\n    print(dataset)\n    print(R)\n\n    # start exp\n    with R.start(experiment_name=\"workflow\", uri=uri_path):\n        R.log_params(**flatten_dict(CSI300_GBDT_TASK))\n        model.fit(dataset)\n        R.save_objects(trained_model=model)\n        # prediction\n        recorder = R.get_recorder()\n        # To test __repr__\n        print(recorder)\n        # To test get_local_dir\n        print(recorder.get_local_dir())\n        rid = recorder.id\n        sr = SignalRecord(model, dataset, recorder)\n        sr.generate()\n        pred_score = sr.load(\"pred.pkl\")\n\n        # calculate ic and ric\n        sar = SigAnaRecord(recorder)\n        sar.generate()\n        ic = sar.load(\"ic.pkl\")\n        ric = sar.load(\"ric.pkl\")\n\n        uri_path = R.get_uri()\n    return pred_score, {\"ic\": ic, \"ric\": ric}, rid, uri_path\n\n\ndef fake_experiment():\n    \"\"\"A fake experiment workflow to test uri\n\n    Returns\n    -------\n        pass_or_not_for_default_uri: bool\n        pass_or_not_for_current_uri: bool\n        temporary_exp_dir: str\n    \"\"\"\n\n    # 
start exp\n    default_uri = R.get_uri()\n    current_uri = \"file:./temp-test-exp-mag\"\n    with R.start(experiment_name=\"fake_workflow_for_expm\", uri=current_uri):\n        R.log_params(**flatten_dict(CSI300_GBDT_TASK))\n\n        current_uri_to_check = R.get_uri()\n    default_uri_to_check = R.get_uri()\n    return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri\n\n\ndef backtest_analysis(pred, rid, uri_path: str = None):\n    \"\"\"backtest and analysis\n\n    Parameters\n    ----------\n    rid : str\n        the id of the recorder to be used in this function\n    uri_path: str\n        mlflow uri path\n\n    Returns\n    -------\n    analysis : pandas.DataFrame\n        the analysis result\n\n    \"\"\"\n    with R.uri_context(uri=uri_path):\n        recorder = R.get_recorder(experiment_name=\"workflow\", recorder_id=rid)\n\n    dataset = init_instance_by_config(CSI300_GBDT_TASK[\"dataset\"])\n    model = recorder.load_object(\"trained_model\")\n\n    port_analysis_config = {\n        \"executor\": {\n            \"class\": \"SimulatorExecutor\",\n            \"module_path\": \"qlib.backtest.executor\",\n            \"kwargs\": {\n                \"time_per_step\": \"day\",\n                \"generate_portfolio_metrics\": True,\n            },\n        },\n        \"strategy\": {\n            \"class\": \"TopkDropoutStrategy\",\n            \"module_path\": \"qlib.contrib.strategy.signal_strategy\",\n            \"kwargs\": {\n                \"signal\": (model, dataset),\n                \"topk\": 50,\n                \"n_drop\": 5,\n            },\n        },\n        \"backtest\": {\n            \"start_time\": \"2017-01-01\",\n            \"end_time\": \"2020-08-01\",\n            \"account\": 100000000,\n            \"benchmark\": CSI300_BENCH,\n            \"exchange_kwargs\": {\n                \"freq\": \"day\",\n                \"limit_threshold\": 0.095,\n                \"deal_price\": \"close\",\n           
     \"open_cost\": 0.0005,\n                \"close_cost\": 0.0015,\n                \"min_cost\": 5,\n            },\n        },\n    }\n    # backtest\n    par = PortAnaRecord(recorder, port_analysis_config, risk_analysis_freq=\"day\")\n    par.generate()\n    analysis_df = par.load(\"port_analysis_1day.pkl\")\n    print(analysis_df)\n    return analysis_df\n\n\nclass TestAllFlow(TestAutoData):\n    REPORT_NORMAL = None\n    POSITIONS = None\n    RID = None\n    URI_PATH = \"file:\" + str(Path(__file__).parent.joinpath(\"test_all_flow_mlruns\").resolve())\n\n    @classmethod\n    def tearDownClass(cls) -> None:\n        shutil.rmtree(cls.URI_PATH.lstrip(\"file:\"))\n\n    @pytest.mark.slow\n    def test_0_train(self):\n        TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID, uri_path = train(self.URI_PATH)\n        self.assertGreaterEqual(ic_ric[\"ic\"].all(), 0, \"train failed\")\n        self.assertGreaterEqual(ic_ric[\"ric\"].all(), 0, \"train failed\")\n\n    @pytest.mark.slow\n    def test_1_backtest(self):\n        analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID, self.URI_PATH)\n        self.assertGreaterEqual(\n            analyze_df.loc(axis=0)[\"excess_return_with_cost\", \"annualized_return\"].values[0],\n            0.05,\n            \"backtest failed\",\n        )\n        self.assertTrue(not analyze_df.isna().any().any(), \"backtest failed\")\n\n    @pytest.mark.slow\n    def test_2_expmanager(self):\n        pass_default, pass_current, uri_path = fake_experiment()\n        self.assertTrue(pass_default, msg=\"default uri is incorrect\")\n        self.assertTrue(pass_current, msg=\"current uri is incorrect\")\n        shutil.rmtree(str(Path(uri_path.strip(\"file:\")).resolve()))\n\n\ndef suite():\n    _suite = unittest.TestSuite()\n    _suite.addTest(TestAllFlow(\"test_0_train\"))\n    _suite.addTest(TestAllFlow(\"test_1_backtest\"))\n    _suite.addTest(TestAllFlow(\"test_2_expmanager\"))\n    return _suite\n\n\nif __name__ == 
\"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/test_contrib_model.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport unittest\n\nfrom qlib.contrib.model import all_model_classes\n\n\nclass TestAllFlow(unittest.TestCase):\n    def test_0_initialize(self):\n        num = 0\n        for model_class in all_model_classes:\n            if model_class is not None:\n                model = model_class()\n                num += 1\n        print(\"There are {:}/{:} valid models in total.\".format(num, len(all_model_classes)))\n\n\ndef suite():\n    _suite = unittest.TestSuite()\n    _suite.addTest(TestAllFlow(\"test_0_initialize\"))\n    return _suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/test_contrib_workflow.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom qlib.workflow.record_temp import SignalRecord\nimport shutil\nimport unittest\nimport pytest\nfrom pathlib import Path\n\nfrom qlib.contrib.workflow import MultiSegRecord, SignalMseRecord\nfrom qlib.utils import init_instance_by_config, flatten_dict\nfrom qlib.workflow import R\nfrom qlib.tests import TestAutoData\nfrom qlib.tests.config import GBDT_MODEL, get_dataset_config, CSI300_MARKET\n\nCSI300_GBDT_TASK = {\n    \"model\": GBDT_MODEL,\n    \"dataset\": get_dataset_config(\n        train=(\"2020-05-01\", \"2020-06-01\"),\n        valid=(\"2020-06-01\", \"2020-07-01\"),\n        test=(\"2020-07-01\", \"2020-08-01\"),\n        handler_kwargs={\n            \"start_time\": \"2020-05-01\",\n            \"end_time\": \"2020-08-01\",\n            \"fit_start_time\": \"<dataset.kwargs.segments.train.0>\",\n            \"fit_end_time\": \"<dataset.kwargs.segments.train.1>\",\n            \"instruments\": CSI300_MARKET,\n        },\n    ),\n}\n\n\ndef train_multiseg(uri_path: str = None):\n    model = init_instance_by_config(CSI300_GBDT_TASK[\"model\"])\n    dataset = init_instance_by_config(CSI300_GBDT_TASK[\"dataset\"])\n    with R.start(experiment_name=\"workflow\", uri=uri_path):\n        R.log_params(**flatten_dict(CSI300_GBDT_TASK))\n        model.fit(dataset)\n        recorder = R.get_recorder()\n        sr = MultiSegRecord(model, dataset, recorder)\n        sr.generate(dict(valid=\"valid\", test=\"test\"), True)\n        uri = R.get_uri()\n    return uri\n\n\ndef train_mse(uri_path: str = None):\n    model = init_instance_by_config(CSI300_GBDT_TASK[\"model\"])\n    dataset = init_instance_by_config(CSI300_GBDT_TASK[\"dataset\"])\n    with R.start(experiment_name=\"workflow\", uri=uri_path):\n        R.log_params(**flatten_dict(CSI300_GBDT_TASK))\n        model.fit(dataset)\n        recorder = R.get_recorder()\n        SignalRecord(recorder=recorder, model=model, 
dataset=dataset).generate()\n        sr = SignalMseRecord(recorder)\n        sr.generate()\n        uri = R.get_uri()\n    return uri\n\n\nclass TestAllFlow(TestAutoData):\n    URI_PATH = \"file:\" + str(Path(__file__).parent.joinpath(\"test_contrib_mlruns\").resolve())\n\n    @classmethod\n    def tearDownClass(cls) -> None:\n        shutil.rmtree(cls.URI_PATH.lstrip(\"file:\"))\n\n    @pytest.mark.slow\n    def test_0_multiseg(self):\n        uri_path = train_multiseg(self.URI_PATH)\n\n    @pytest.mark.slow\n    def test_1_mse(self):\n        uri_path = train_mse(self.URI_PATH)\n\n\ndef suite():\n    _suite = unittest.TestSuite()\n    _suite.addTest(TestAllFlow(\"test_0_multiseg\"))\n    _suite.addTest(TestAllFlow(\"test_1_mse\"))\n    return _suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/test_dump_data.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\n\nimport sys\nimport shutil\nimport unittest\nfrom pathlib import Path\n\nimport qlib\nimport numpy as np\nimport pandas as pd\nfrom qlib.data import D\n\nsys.path.append(str(Path(__file__).resolve().parent.parent.joinpath(\"scripts\")))\nfrom get_data import GetData\nfrom dump_bin import DumpDataAll, DumpDataFix\n\nDATA_DIR = Path(__file__).parent.joinpath(\"test_dump_data\")\nSOURCE_DIR = DATA_DIR.joinpath(\"source\")\nSOURCE_DIR.mkdir(exist_ok=True, parents=True)\nQLIB_DIR = DATA_DIR.joinpath(\"qlib\")\nQLIB_DIR.mkdir(exist_ok=True, parents=True)\n\n\nclass TestDumpData(unittest.TestCase):\n    FIELDS = \"open,close,high,low,volume\".split(\",\")\n    QLIB_FIELDS = list(map(lambda x: f\"${x}\", FIELDS))\n    DUMP_DATA = None\n    STOCK_NAMES = None\n\n    # simple data\n    SIMPLE_DATA = None\n\n    @classmethod\n    def setUpClass(cls) -> None:\n        GetData().download_data(file_name=\"csv_data_cn.zip\", target_dir=SOURCE_DIR)\n        TestDumpData.DUMP_DATA = DumpDataAll(data_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS)\n        TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob(\"*.csv\")))\n        provider_uri = str(QLIB_DIR.resolve())\n        qlib.init(\n            provider_uri=provider_uri,\n            expression_cache=None,\n            dataset_cache=None,\n        )\n\n    @classmethod\n    def tearDownClass(cls) -> None:\n        shutil.rmtree(str(DATA_DIR.resolve()))\n\n    def test_0_dump_bin(self):\n        self.DUMP_DATA.dump()\n\n    def test_1_dump_calendars(self):\n        ori_calendars = set(\n            map(\n                pd.Timestamp,\n                pd.read_csv(QLIB_DIR.joinpath(\"calendars\", \"day.txt\"), header=None).loc[:, 0].values,\n            )\n        )\n        res_calendars = set(D.calendar())\n        assert len(ori_calendars - res_calendars) == len(res_calendars - 
ori_calendars) == 0, \"dump calendars failed\"\n\n    def test_2_dump_instruments(self):\n        ori_ins = set(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob(\"*.csv\")))\n        res_ins = set(D.list_instruments(D.instruments(\"all\"), as_list=True))\n        assert len(ori_ins - res_ins) == len(res_ins - ori_ins) == 0, \"dump instruments failed\"\n\n    def test_3_dump_features(self):\n        df = D.features(self.STOCK_NAMES, self.QLIB_FIELDS)\n        TestDumpData.SIMPLE_DATA = df.loc(axis=0)[self.STOCK_NAMES[0], :]\n        self.assertFalse(df.dropna().empty, \"features data failed\")\n        self.assertListEqual(list(df.columns), self.QLIB_FIELDS, \"features columns failed\")\n\n    def test_4_dump_features_simple(self):\n        stock = self.STOCK_NAMES[0]\n        dump_data = DumpDataFix(\n            data_path=SOURCE_DIR.joinpath(f\"{stock.lower()}.csv\"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS\n        )\n        dump_data.dump()\n\n        df = D.features([stock], self.QLIB_FIELDS)\n\n        self.assertEqual(len(df), len(TestDumpData.SIMPLE_DATA), \"dump features simple failed\")\n        self.assertTrue(np.isclose(df.dropna(), self.SIMPLE_DATA.dropna()).all(), \"dump features simple failed\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_get_data.py",
    "content": "#  Copyright (c) Microsoft Corporation.\n#  Licensed under the MIT License.\n\nimport shutil\nimport unittest\nfrom pathlib import Path\n\nimport qlib\nfrom qlib.data import D\nfrom qlib.tests.data import GetData\n\nDATA_DIR = Path(__file__).parent.joinpath(\"test_get_data\")\nSOURCE_DIR = DATA_DIR.joinpath(\"source\")\nSOURCE_DIR.mkdir(exist_ok=True, parents=True)\nQLIB_DIR = DATA_DIR.joinpath(\"qlib\")\nQLIB_DIR.mkdir(exist_ok=True, parents=True)\n\n\nclass TestGetData(unittest.TestCase):\n    FIELDS = \"$open,$close,$high,$low,$volume,$factor,$change\".split(\",\")\n\n    @classmethod\n    def setUpClass(cls) -> None:\n        provider_uri = str(QLIB_DIR.resolve())\n        qlib.init(\n            provider_uri=provider_uri,\n            expression_cache=None,\n            dataset_cache=None,\n        )\n\n    @classmethod\n    def tearDownClass(cls) -> None:\n        shutil.rmtree(str(DATA_DIR.resolve()))\n\n    def test_0_qlib_data(self):\n        GetData().qlib_data(\n            name=\"qlib_data_simple\", target_dir=QLIB_DIR, region=\"cn\", interval=\"1d\", delete_old=False, exists_skip=True\n        )\n        df = D.features(D.instruments(\"csi300\"), self.FIELDS)\n        self.assertListEqual(list(df.columns), self.FIELDS, \"get qlib data failed\")\n        self.assertFalse(df.dropna().empty, \"get qlib data failed\")\n\n    def test_1_csv_data(self):\n        GetData().download_data(file_name=\"csv_data_cn.zip\", target_dir=SOURCE_DIR)\n        stock_name = set(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob(\"*.csv\")))\n        self.assertEqual(len(stock_name), 85, \"get csv data failed\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_pit.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n\nimport sys\nimport qlib\nimport shutil\nimport unittest\nimport pytest\nimport pandas as pd\nfrom pathlib import Path\n\nfrom qlib.data import D\nfrom qlib.tests.data import GetData\n\nsys.path.append(str(Path(__file__).resolve().parent.parent.joinpath(\"scripts\")))\nfrom dump_pit import DumpPitData\n\nsys.path.append(str(Path(__file__).resolve().parent.parent.joinpath(\"scripts/data_collector/pit\")))\nfrom collector import Run\n\npd.set_option(\"display.width\", 1000)\npd.set_option(\"display.max_columns\", None)\n\nDATA_DIR = Path(__file__).parent.joinpath(\"test_pit_data\")\nSOURCE_DIR = DATA_DIR.joinpath(\"stock_data/source\")\nSOURCE_DIR.mkdir(exist_ok=True, parents=True)\nQLIB_DIR = DATA_DIR.joinpath(\"qlib_data\")\nQLIB_DIR.mkdir(exist_ok=True, parents=True)\n\n\nclass TestPIT(unittest.TestCase):\n    @classmethod\n    def tearDownClass(cls) -> None:\n        shutil.rmtree(str(DATA_DIR.resolve()))\n\n    @classmethod\n    def setUpClass(cls) -> None:\n        cn_data_dir = str(QLIB_DIR.joinpath(\"cn_data\").resolve())\n        pit_dir = str(SOURCE_DIR.joinpath(\"pit\").resolve())\n        pit_normalized_dir = str(SOURCE_DIR.joinpath(\"pit_normalized\").resolve())\n        GetData().qlib_data(\n            name=\"qlib_data_simple\", target_dir=cn_data_dir, region=\"cn\", delete_old=False, exists_skip=True\n        )\n        GetData().qlib_data(name=\"qlib_data\", target_dir=pit_dir, region=\"pit\", delete_old=False, exists_skip=True)\n\n        # NOTE: The commented-out code below fetches the same PIT source data as the GetData().qlib_data(..., region=\"pit\") call above.\n        # Because downloading from baostock is unreliable, we download the pre-packaged offline data instead.\n        # bs.login()\n        # Run(\n        #     source_dir=pit_dir,\n        #     interval=\"quarterly\",\n        # ).download_data(start=\"2000-01-01\", end=\"2020-01-01\", symbol_regex=\"^(600519|000725).*\")\n        # bs.logout()\n\n        Run(\n            
source_dir=pit_dir,\n            normalize_dir=pit_normalized_dir,\n            interval=\"quarterly\",\n        ).normalize_data()\n        DumpPitData(\n            csv_path=pit_normalized_dir,\n            qlib_dir=cn_data_dir,\n        ).dump(interval=\"quarterly\")\n\n    def setUp(self):\n        # qlib.init(kernels=1)  # NOTE: set kernels to 1 to make debugging easier\n        provider_uri = str(QLIB_DIR.joinpath(\"cn_data\").resolve())\n        qlib.init(provider_uri=provider_uri)\n\n    def to_str(self, obj):\n        return \"\".join(str(obj).split())\n\n    def check_same(self, a, b):\n        self.assertEqual(self.to_str(a), self.to_str(b))\n\n    def test_query(self):\n        instruments = [\"sh600519\"]\n        fields = [\"P($$roewa_q)\", \"P($$yoyni_q)\"]\n        # Mao Tai published its 2019Q2 reports on 2019-07-13 and 2019-07-18\n        # - http://www.cninfo.com.cn/new/commonUrl/pageOfSearch?url=disclosure/list/search&lastPage=index\n        data = D.features(instruments, fields, start_time=\"2019-01-01\", end_time=\"2019-07-19\", freq=\"day\")\n        res = \"\"\"\n               P($$roewa_q)  P($$yoyni_q)\n        count    133.000000    133.000000\n        mean       0.196412      0.277930\n        std        0.097591      0.030262\n        min        0.000000      0.243892\n        25%        0.094737      0.243892\n        50%        0.255220      0.304181\n        75%        0.255220      0.305041\n        max        0.344644      0.305041\n        \"\"\"\n        self.check_same(data.describe(), res)\n\n        res = \"\"\"\n                               P($$roewa_q)  P($$yoyni_q)\n        instrument datetime\n        sh600519   2019-07-15      0.000000      0.305041\n                   2019-07-16      0.000000      0.305041\n                   2019-07-17      0.000000      0.305041\n                   2019-07-18      0.175322      0.252650\n                   2019-07-19      0.175322      0.252650\n        \"\"\"\n        
self.check_same(data.tail(), res)\n\n    def test_no_exist_data(self):\n        fields = [\"P($$roewa_q)\", \"P($$yoyni_q)\", \"$close\"]\n        data = D.features([\"sh600519\", \"sh601988\"], fields, start_time=\"2019-01-01\", end_time=\"2019-07-19\", freq=\"day\")\n        data[\"$close\"] = 1  # in case of different dataset gives different values\n        expect = \"\"\"\n                               P($$roewa_q)  P($$yoyni_q)  $close\n        instrument datetime\n        sh600519   2019-01-02       0.25522      0.243892       1\n                   2019-01-03       0.25522      0.243892       1\n                   2019-01-04       0.25522      0.243892       1\n                   2019-01-07       0.25522      0.243892       1\n                   2019-01-08       0.25522      0.243892       1\n        ...                             ...           ...     ...\n        sh601988   2019-07-15           NaN           NaN       1\n                   2019-07-16           NaN           NaN       1\n                   2019-07-17           NaN           NaN       1\n                   2019-07-18           NaN           NaN       1\n                   2019-07-19           NaN           NaN       1\n\n        [266 rows x 3 columns]\n        \"\"\"\n        self.check_same(data, expect)\n\n    @pytest.mark.slow\n    def test_expr(self):\n        fields = [\n            \"P(Mean($$roewa_q, 1))\",\n            \"P($$roewa_q)\",\n            \"P(Mean($$roewa_q, 2))\",\n            \"P(Ref($$roewa_q, 1))\",\n            \"P((Ref($$roewa_q, 1) +$$roewa_q) / 2)\",\n        ]\n        instruments = [\"sh600519\"]\n        data = D.features(instruments, fields, start_time=\"2019-01-01\", end_time=\"2019-07-19\", freq=\"day\")\n        expect = \"\"\"\n                               P(Mean($$roewa_q, 1))  P($$roewa_q)  P(Mean($$roewa_q, 2))  P(Ref($$roewa_q, 1))  P((Ref($$roewa_q, 1) +$$roewa_q) / 2)\n        instrument datetime\n        sh600519   2019-07-01               
0.094737      0.094737               0.219691              0.344644                               0.219691\n                   2019-07-02               0.094737      0.094737               0.219691              0.344644                               0.219691\n                   2019-07-03               0.094737      0.094737               0.219691              0.344644                               0.219691\n                   2019-07-04               0.094737      0.094737               0.219691              0.344644                               0.219691\n                   2019-07-05               0.094737      0.094737               0.219691              0.344644                               0.219691\n                   2019-07-08               0.094737      0.094737               0.219691              0.344644                               0.219691\n                   2019-07-09               0.094737      0.094737               0.219691              0.344644                               0.219691\n                   2019-07-10               0.094737      0.094737               0.219691              0.344644                               0.219691\n                   2019-07-11               0.094737      0.094737               0.219691              0.344644                               0.219691\n                   2019-07-12               0.094737      0.094737               0.219691              0.344644                               0.219691\n                   2019-07-15               0.000000      0.000000               0.047369              0.094737                               0.047369\n                   2019-07-16               0.000000      0.000000               0.047369              0.094737                               0.047369\n                   2019-07-17               0.000000      0.000000               0.047369              0.094737                               0.047369\n                   2019-07-18               0.175322      0.175322  
             0.135029              0.094737                               0.135029\n                   2019-07-19               0.175322      0.175322               0.135029              0.094737                               0.135029\n        \"\"\"\n        self.check_same(data.tail(15), expect)\n\n    def test_unlimit(self):\n        # fields = [\"P(Mean($$roewa_q, 1))\", \"P($$roewa_q)\", \"P(Mean($$roewa_q, 2))\", \"P(Ref($$roewa_q, 1))\", \"P((Ref($$roewa_q, 1) +$$roewa_q) / 2)\"]\n        fields = [\"P($$roewa_q)\"]\n        instruments = [\"sh600519\"]\n        _ = D.features(instruments, fields, freq=\"day\")  # this should not raise error\n        data = D.features(instruments, fields, end_time=\"2020-01-01\", freq=\"day\")  # this should not raise error\n        s = data.iloc[:, 0]\n        # You can check the expected value based on the content in `docs/advanced/PIT.rst`\n        expect = \"\"\"\n        instrument  datetime\n        sh600519    2005-01-04         NaN\n                    2007-04-30    0.090219\n                    2007-08-17    0.139330\n                    2007-10-23    0.245863\n                    2008-03-03    0.347900\n                    2008-03-13    0.395989\n                    2008-04-22    0.100724\n                    2008-08-28    0.249968\n                    2008-10-27    0.334120\n                    2009-03-25    0.390117\n                    2009-04-21    0.102675\n                    2009-08-07    0.230712\n                    2009-10-26    0.300730\n                    2010-04-02    0.335461\n                    2010-04-26    0.083825\n                    2010-08-12    0.200545\n                    2010-10-29    0.260986\n                    2011-03-21    0.307393\n                    2011-04-25    0.097411\n                    2011-08-31    0.248251\n                    2011-10-18    0.318919\n                    2012-03-23    0.403900\n                    2012-04-11    0.403925\n                    2012-04-26    
0.112148\n                    2012-08-10    0.264847\n                    2012-10-26    0.370487\n                    2013-03-29    0.450047\n                    2013-04-18    0.099958\n                    2013-09-02    0.210442\n                    2013-10-16    0.304543\n                    2014-03-25    0.394328\n                    2014-04-25    0.083217\n                    2014-08-29    0.164503\n                    2014-10-30    0.234085\n                    2015-04-21    0.078494\n                    2015-08-28    0.137504\n                    2015-10-23    0.201709\n                    2016-03-24    0.264205\n                    2016-04-21    0.073664\n                    2016-08-29    0.136576\n                    2016-10-31    0.188062\n                    2017-04-17    0.244385\n                    2017-04-25    0.080614\n                    2017-07-28    0.151510\n                    2017-10-26    0.254166\n                    2018-03-28    0.329542\n                    2018-05-02    0.088887\n                    2018-08-02    0.170563\n                    2018-10-29    0.255220\n                    2019-03-29    0.344644\n                    2019-04-25    0.094737\n                    2019-07-15    0.000000\n                    2019-07-18    0.175322\n                    2019-10-16    0.255819\n        Name: P($$roewa_q), dtype: float32\n        \"\"\"\n        self.check_same(s[~s.duplicated().values], expect)\n\n    def test_expr2(self):\n        instruments = [\"sh600519\"]\n        fields = [\"P($$roewa_q)\", \"P($$yoyni_q)\"]\n        fields += [\"P(($$roewa_q / $$yoyni_q) / Ref($$roewa_q / $$yoyni_q, 1) - 1)\"]\n        fields += [\"P(Sum($$yoyni_q, 4))\"]\n        fields += [\"$close\", \"P($$roewa_q) * $close\"]\n        data = D.features(instruments, fields, start_time=\"2019-01-01\", end_time=\"2020-01-01\", freq=\"day\")\n        except_data = \"\"\"\n                                       P($$roewa_q)  P($$yoyni_q)  P(($$roewa_q / 
$$yoyni_q) / Ref($$roewa_q / $$yoyni_q, 1) - 1)  P(Sum($$yoyni_q, 4))      $close  P($$roewa_q) * $close\n        instrument datetime\n        sh600519   2019-01-02      0.255220      0.243892                                           1.484224                           1.661578   63.595333              16.230801\n                   2019-01-03      0.255220      0.243892                                           1.484224                           1.661578   62.641907              15.987467\n                   2019-01-04      0.255220      0.243892                                           1.484224                           1.661578   63.915985              16.312637\n                   2019-01-07      0.255220      0.243892                                           1.484224                           1.661578   64.286530              16.407207\n                   2019-01-08      0.255220      0.243892                                           1.484224                           1.661578   64.212196              16.388237\n        ...                             ...           ...                                                ...                                ...         ...                    
...\n                   2019-12-25      0.255819      0.219821                                           0.677052                           1.081693  122.150467              31.248409\n                   2019-12-26      0.255819      0.219821                                           0.677052                           1.081693  122.301315              31.286999\n                   2019-12-27      0.255819      0.219821                                           0.677052                           1.081693  125.307404              32.056015\n                   2019-12-30      0.255819      0.219821                                           0.677052                           1.081693  127.763992              32.684456\n                   2019-12-31      0.255819      0.219821                                           0.677052                           1.081693  127.462303              32.607277\n\n        [244 rows x 6 columns]\n        \"\"\"\n        self.check_same(data, except_data)\n\n    def test_pref_operator(self):\n        instruments = [\"sh600519\"]\n        fields = [\n            \"PRef($$roewa_q, 201902)\",\n            \"PRef($$yoyni_q, 201801)\",\n            \"P($$roewa_q)\",\n            \"P($$roewa_q) / PRef($$roewa_q, 201801)\",\n        ]\n        data = D.features(instruments, fields, start_time=\"2018-04-28\", end_time=\"2019-07-19\", freq=\"day\")\n        except_data = \"\"\"\n                               PRef($$roewa_q, 201902)  PRef($$yoyni_q, 201801)  P($$roewa_q)  P($$roewa_q) / PRef($$roewa_q, 201801)\n        instrument datetime\n        sh600519   2018-05-02                      NaN                 0.395075      0.088887                                1.000000\n                   2018-05-03                      NaN                 0.395075      0.088887                                1.000000\n                   2018-05-04                      NaN                 0.395075      0.088887                                1.000000\n          
         2018-05-07                      NaN                 0.395075      0.088887                                1.000000\n                   2018-05-08                      NaN                 0.395075      0.088887                                1.000000\n        ...                                        ...                      ...           ...                                     ...\n                   2019-07-15                 0.000000                 0.395075      0.000000                                0.000000\n                   2019-07-16                 0.000000                 0.395075      0.000000                                0.000000\n                   2019-07-17                 0.000000                 0.395075      0.000000                                0.000000\n                   2019-07-18                 0.175322                 0.395075      0.175322                                1.972414\n                   2019-07-19                 0.175322                 0.395075      0.175322                                1.972414\n\n        [299 rows x 4 columns]\n        \"\"\"\n        self.check_same(data, except_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_register_ops.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport unittest\nimport numpy as np\n\nfrom qlib.data import D\nfrom qlib.data.ops import ElemOperator, PairOperator\nfrom qlib.tests import TestAutoData\n\n\nclass Diff(ElemOperator):\n    \"\"\"Feature First Difference\n\n    Parameters\n    ----------\n    feature : Expression\n        feature instance\n\n    Returns\n    -------\n    Expression\n        a feature instance with first difference\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series = self.feature.load(instrument, start_index, end_index, freq)\n        return series.diff()\n\n    def get_extended_window_size(self):\n        lft_etd, rght_etd = self.feature.get_extended_window_size()\n        return lft_etd + 1, rght_etd\n\n\nclass Distance(PairOperator):\n    \"\"\"Feature Distance\n\n    Parameters\n    ----------\n    feature_left : Expression\n        feature instance\n    feature_right : Expression\n        feature instance\n\n    Returns\n    -------\n    Expression\n        a feature instance with the absolute distance between the two features\n    \"\"\"\n\n    def _load_internal(self, instrument, start_index, end_index, freq):\n        series_left = self.feature_left.load(instrument, start_index, end_index, freq)\n        series_right = self.feature_right.load(instrument, start_index, end_index, freq)\n        return np.abs(series_left - series_right)\n\n\nclass TestRegisterCustomOps(TestAutoData):\n    @classmethod\n    def setUpClass(cls) -> None:\n        cls._setup_kwargs.update({\"custom_ops\": [Diff, Distance]})\n        super().setUpClass()\n\n    def test_register_custom_ops(self):\n        instruments = [\"SH600000\"]\n        fields = [\"Diff($close)\", \"Distance($close, Ref($close, 1))\"]\n        print(D.features(instruments, fields, start_time=\"2010-01-01\", end_time=\"2017-12-31\", freq=\"day\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_structured_cov_estimator.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport unittest\nimport numpy as np\nfrom scipy.linalg import sqrtm\n\nfrom qlib.model.riskmodel import StructuredCovEstimator\n\n\nclass TestStructuredCovEstimator(unittest.TestCase):\n    def test_random_covariance(self):\n        # Try to estimate the covariance from a randomly generated matrix.\n        NUM_VARIABLE = 10\n        NUM_OBSERVATION = 200\n        EPS = 1e-6\n\n        estimator = StructuredCovEstimator(scale_return=False, assume_centered=True)\n\n        X = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE)\n\n        est_cov = estimator.predict(X, is_price=False)\n        np_cov = np.cov(X.T)  # np.cov treats each row as a variable, whereas qlib treats each column as one.\n\n        delta = abs(est_cov - np_cov)\n        if_identical = (delta < EPS).all()\n\n        self.assertTrue(if_identical)\n\n    def test_nan_option_covariance(self):\n        # Test that nan_option is passed through correctly.\n        NUM_VARIABLE = 10\n        NUM_OBSERVATION = 200\n        EPS = 1e-6\n\n        estimator = StructuredCovEstimator(scale_return=False, assume_centered=True, nan_option=\"fill\")\n\n        X = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE)\n\n        est_cov = estimator.predict(X, is_price=False)\n        np_cov = np.cov(X.T)  # np.cov treats each row as a variable, whereas qlib treats each column as one.\n\n        delta = abs(est_cov - np_cov)\n        if_identical = (delta < EPS).all()\n\n        self.assertTrue(if_identical)\n\n    def test_decompose_covariance(self):\n        # Test that return_decomposed_components is passed through correctly.\n        NUM_VARIABLE = 10\n        NUM_OBSERVATION = 200\n\n        estimator = StructuredCovEstimator(scale_return=False, assume_centered=True, nan_option=\"fill\")\n\n        X = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE)\n\n        F, cov_b, var_u = estimator.predict(X, is_price=False, return_decomposed_components=True)\n\n        
self.assertTrue(F is not None and cov_b is not None and var_u is not None)\n\n    def test_constructed_covariance(self):\n        # Try to estimate the covariance from a specially crafted matrix.\n        # Because X is specially crafted, its variables should be significantly correlated.\n        NUM_VARIABLE = 7\n        NUM_OBSERVATION = 500\n        EPS = 0.1\n\n        estimator = StructuredCovEstimator(scale_return=False, assume_centered=True, num_factors=NUM_VARIABLE - 1)\n\n        sqrt_cov = None\n        while sqrt_cov is None or (np.iscomplex(sqrt_cov)).any():\n            cov = np.random.rand(NUM_VARIABLE, NUM_VARIABLE)\n            for i in range(NUM_VARIABLE):\n                cov[i][i] = 1\n            sqrt_cov = sqrtm(cov)\n        X = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE) @ sqrt_cov\n\n        est_cov = estimator.predict(X, is_price=False)\n        np_cov = np.cov(X.T)  # numpy treats each row as a variable, whereas qlib treats each column as one, hence the transpose.\n\n        delta = abs(est_cov - np_cov)\n        if_identical = (delta < EPS).all()\n\n        self.assertTrue(if_identical)\n\n    def test_decomposition(self):\n        # Try to estimate the covariance from a specially crafted matrix.\n        # The matrix is generated under the assumption that the observations can be explained by multiple factors.\n        NUM_VARIABLE = 30\n        NUM_OBSERVATION = 100\n        NUM_FACTOR = 10\n        EPS = 0.1\n\n        estimator = StructuredCovEstimator(scale_return=False, assume_centered=True, num_factors=NUM_FACTOR)\n\n        F = np.random.rand(NUM_VARIABLE, NUM_FACTOR)\n        B = np.random.rand(NUM_FACTOR, NUM_OBSERVATION)\n        U = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE)\n        X = (F @ B).T + U\n\n        est_cov = estimator.predict(X, is_price=False)\n        np_cov = np.cov(X.T)  # numpy treats each row as a variable, whereas qlib treats each column as one, hence the transpose.\n\n        delta = abs(est_cov - np_cov)\n        if_identical = (delta < EPS).all()\n\n        
self.assertTrue(if_identical)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_workflow.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\nimport unittest\nfrom pathlib import Path\nimport shutil\n\nfrom qlib.workflow import R\nfrom qlib.tests import TestAutoData\n\n\nclass WorkflowTest(TestAutoData):\n    # Creating the directory manually doesn't work with mlflow,\n    # so we add a subfolder named .trash when we create the directory.\n    TMP_PATH = Path(\"./.mlruns_tmp/.trash\")\n\n    def tearDown(self) -> None:\n        if self.TMP_PATH.exists():\n            shutil.rmtree(self.TMP_PATH)\n\n    def test_get_local_dir(self):\n        \"\"\"Check that the local directory of a recorder from a finished run can be retrieved.\"\"\"\n        self.TMP_PATH.mkdir(parents=True, exist_ok=True)\n\n        with R.start(uri=str(self.TMP_PATH)):\n            pass\n\n        with R.uri_context(uri=str(self.TMP_PATH)):\n            resume_recorder = R.get_recorder()\n            resume_recorder.get_local_dir()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  }
]