[
  {
    "path": ".gitignore",
    "content": "# vscode IDE\n.vscode\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control\n.pdm.toml\n.pdm-python\n.pdm-build/\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n\n1.sh\nlog/\n.vscode/\n\n*.DS_Store\n*.AppleDouble\n*.LSOverride\n*__MACOSX\n\n# Icon must end with two \\r characters\nIcon\n\n# Thumbnails / metadata\n._*\n.Spotlight-V100\n.Trashes\n.fseventsd\n\n# Volumes / network\n.AppleDB\n.AppleDesktop\nNetwork Trash Folder\nTemporary Items\n.VolumeIcon.icns\n\n# iCloud placeholders\n*.icloud"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"submodules/uni2ts\"]\n\tpath = submodules/uni2ts\n\turl = https://github.com/SalesforceAIResearch/uni2ts.git\n[submodule \"submodules/lag_llama\"]\n\tpath = submodules/lag_llama\n\turl = https://github.com/time-series-foundation-models/lag-llama.git\n[submodule \"submodules/timesfm\"]\n\tpath = submodules/timesfm\n\turl = https://github.com/google-research/timesfm.git\n[submodule \"submodules/tsfm\"]\n\tpath = submodules/tsfm\n\turl = https://github.com/ibm-granite/granite-tsfm.git\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "# Microsoft Open Source Code of Conduct\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).\n\nResources:\n\n- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)\n- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)\n- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns\n"
  },
  {
    "path": "LICENSE",
    "content": "    MIT License\n\n    Copyright (c) Microsoft Corporation.\n\n    Permission is hereby granted, free of charge, to any person obtaining a copy\n    of this software and associated documentation files (the \"Software\"), to deal\n    in the Software without restriction, including without limitation the rights\n    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n    copies of the Software, and to permit persons to whom the Software is\n    furnished to do so, subject to the following conditions:\n\n    The above copyright notice and this permission notice shall be included in all\n    copies or substantial portions of the Software.\n\n    THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n    SOFTWARE\n"
  },
  {
    "path": "README.md",
    "content": "<div align=center> <img src=\"docs/figs/probts_logo.png\" width = 50%/> </div>\n\n# ProbTS: Benchmarking Point and Distributional Forecasting across Diverse Prediction Horizons\n\n[![arxiv](https://img.shields.io/badge/arXiv-2310.07446-red?link=https%3A%2F%2Farxiv.org%2Fabs%2F2310.07446)](https://arxiv.org/abs/2310.07446) [![benchmarking](https://img.shields.io/badge/Benchmarking-ExpResults-blue?style=flat&link=https%3A%2F%2Fgithub.com%2Fmicrosoft%2FProbTS%2Ftree%2Fadd_elastst%2Fdocs%2Fbenchmark)](./docs/benchmark/README.md) [![documentation](https://img.shields.io/badge/Toolkit-Documentation-green?style=flat&link=https%3A%2F%2Fgithub.com%2Fmicrosoft%2FProbTS%2Fblob%2Fadd_elastst%2Fdocs%2Fdocumentation%2FREADME.md)](./docs/documentation/README.md)\n\n\n## News :tada:\n\n:triangular_flag_on_post: **May 2025**: We have integrated [ModernTCN](https://github.com/luodhhh/ModernTCN/tree/main) into ProbTS. You can find the corresponding configuration file [here](./config/default/moderntcn.yaml).\n\n:triangular_flag_on_post: **Apr 2025**: ProbTS now includes [Time-MoE](https://github.com/Time-MoE/Time-MoE) and offers improved support for foundation models of varying sizes. See [Foundation Models](#foundation-models) for details.\n\n:triangular_flag_on_post: **Dec 2024**: ProbTS now supports [GIFT-EVAL](https://github.com/SalesforceAIResearch/gift-eval?tab=readme-ov-file#installation) benchmark datasets! Visit [this page](./docs/documentation/Gift_eval.md) for detailed instructions. *Please note that this feature is still in beta version and may contain bugs or inconsistencies. We will continue to update and improve it.*\n\n:triangular_flag_on_post: **Dec 2024**: Added quick guides for benchmarking foundation models. Visit [this page](./docs/benchmark/foundation_model/README.md) for detailed instructions.\n\n:triangular_flag_on_post: **Oct 2024**: ProbTS now includes the ElasTST model! 
Check out the [ElasTST branch](https://github.com/microsoft/ProbTS/tree/elastst) to reproduce all results reported in the paper or run `bash scripts/run_elastst.sh` for a quick start.\n\n:triangular_flag_on_post: **Oct 2024**: The [camera-ready version](https://arxiv.org/abs/2310.07446) of ProbTS is now available, with more in-depth analyses on the impact of normalization.\n\n## About ProbTS :bulb:\n\nA wide range of industrial applications desire precise point and distributional forecasting for diverse prediction horizons. ProbTS serves as a benchmarking tool to aid in understanding how advanced time-series models fulfill these essential forecasting needs. It also sheds light on their advantages and disadvantages in addressing different challenges and unveils the possibilities for future research.\n\nTo achieve these objectives, ProbTS provides a unified pipeline that implements [cutting-edge models](#-available-models) from different research threads, including:\n- Supervised long-term point forecasting models, such as [PatchTST](https://arxiv.org/abs/2211.14730), [iTransformer](https://arxiv.org/abs/2310.06625), etc.\n- Supervised short-term probabilistic forecasting models, such as [TimeGrad](https://arxiv.org/abs/2101.12072), [CSDI](https://arxiv.org/abs/2107.03502), etc.\n- Pre-trained time-series foundation models for zero-shot forecasting, such as [TimesFM](https://arxiv.org/abs/2310.10688), [MOIRAI](https://arxiv.org/abs/2402.02592), etc.\n\nSpecifically, ProbTS emphasizes the differences in their primary methodological designs, including:\n- Supporting point or distributional forecasts\n- Using autoregressive or non-autoregressive decoding schemes for multi-step outputs\n\n<div align=center> <img src=\"docs/figs/probts_framework.png\" width = 95%/> </div>\n\n\n\n## Available Models 🧩\n\nProbTS includes both classical time-series models, specializing in long-term point forecasting or short-term distributional forecasting, and recent time-series foundation 
models that offer zero-shot and arbitrary-horizon forecasting capabilities for new time series.\n\n### Classical Time-series Models\n\n| **Model** | **Original Eval. Horizon** | **Estimation** | **Decoding Scheme** | **Class Path** |\n| --- | --- | --- | --- | --- |\n| Linear | - | Point | Auto / Non-auto | `probts.model.forecaster.point_forecaster.LinearForecaster` |\n| [GRU](https://arxiv.org/abs/1412.3555) | - | Point | AR / NAR | `probts.model.forecaster.point_forecaster.GRUForecaster` |\n| [Transformer](https://arxiv.org/abs/1706.03762) | - | Point | AR / NAR | `probts.model.forecaster.point_forecaster.TransformerForecaster` |\n| [Autoformer](https://arxiv.org/abs/2106.13008) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.Autoformer` |\n| [N-HiTS](https://arxiv.org/abs/2201.12886) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.NHiTS` |\n| [NLinear](https://arxiv.org/abs/2205.13504) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.NLinear` |\n| [DLinear](https://arxiv.org/abs/2205.13504) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.DLinear` |\n| [TSMixer](https://arxiv.org/abs/2303.06053) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.TSMixer` |\n| [TimesNet](https://arxiv.org/abs/2210.02186) | Short / Long | Point | NAR | `probts.model.forecaster.point_forecaster.TimesNet` |\n| [PatchTST](https://arxiv.org/abs/2211.14730) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.PatchTST` |\n| [iTransformer](https://arxiv.org/abs/2310.06625) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.iTransformer` |\n| [ElasTST](https://arxiv.org/abs/2411.01842) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.ElasTST` |\n| [GRU NVP](https://arxiv.org/abs/2002.06103) | Short | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.GRU_NVP` |\n| [GRU MAF](https://arxiv.org/abs/2002.06103) | Short | Probabilistic | AR | 
`probts.model.forecaster.prob_forecaster.GRU_MAF` |\n| [Trans MAF](https://arxiv.org/abs/2002.06103) | Short | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.Trans_MAF` |\n| [TimeGrad](https://arxiv.org/abs/2101.12072) | Short | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.TimeGrad` |\n| [CSDI](https://arxiv.org/abs/2107.03502) | Short | Probabilistic | NAR | `probts.model.forecaster.prob_forecaster.CSDI` |\n| [TSDiff](https://arxiv.org/abs/2307.11494) | Short | Probabilistic | NAR | `probts.model.forecaster.prob_forecaster.TSDiffCond` |\n\n### Foundation Models\n\n| **Model** | **Any Horizon** | **Estimation** | **Decoding Scheme** | **Class Path** | **Model Size** | \n| --- | --- | --- | --- | --- | --- |\n| [Lag-Llama](https://arxiv.org/abs/2310.08278) | &#x2714; | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.LagLlama` | - |\n| [ForecastPFN](https://arxiv.org/abs/2311.01933) | &#x2714; | Point | NAR | `probts.model.forecaster.point_forecaster.ForecastPFN` | - |\n| [TimesFM](https://arxiv.org/abs/2310.10688) | &#x2714; | Point | AR | `probts.model.forecaster.point_forecaster.TimesFM` | `200m`, `500m` |\n| [TTM](https://arxiv.org/abs/2401.03955) | &#x2718; | Point | NAR | `probts.model.forecaster.point_forecaster.TinyTimeMixer` | - |\n| [Timer](https://arxiv.org/abs/2402.02368) | &#x2714; | Point | AR | `probts.model.forecaster.point_forecaster.Timer` | - |\n| [MOIRAI](https://arxiv.org/abs/2402.02592) | &#x2714; | Probabilistic | NAR | `probts.model.forecaster.prob_forecaster.Moirai` | `small`, `base`, `large` |\n| [UniTS](https://arxiv.org/abs/2403.00131) | &#x2714; | Point | NAR | `probts.model.forecaster.point_forecaster.UniTS` | - |\n| [Chronos](https://arxiv.org/abs/2403.07815) | &#x2714; | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.Chronos` | `tiny`, `mini`, `small`, `base`, `large` |\n| [Time-MoE](https://arxiv.org/abs/2409.16040) | &#x2714; | Point | AR | 
`probts.model.forecaster.point_forecaster.TimeMoE` | `50M`, `200M` |\n\nSee the [tsfm configuration directory](./config/tsfm/) for more details. More models will be added soon—stay tuned!\n\n## Setup :wrench:\n\n### Environment\n\nProbTS is developed with Python 3.10 and relies on [PyTorch Lightning](https://github.com/Lightning-AI/lightning). To set up the environment:\n\n```bash\n# Create a new conda environment\nconda create -n probts python=3.10\nconda activate probts\n\n# Install required packages\npip install .\npip uninstall -y probts # recommended to uninstall the root package (optional)\n```\n\n<details>\n\n<summary>Optional for TSFMs reproducibility</summary>\n\nFor time-series foundation models, you need to install basic packages and additional dependencies:\n\n**1. Set Up Environment**\n```bash\n# Create a new conda environment\nconda create -n probts_fm python=3.10\nconda activate probts_fm\n\n# Git submodule\ngit submodule update --init --recursive\n\n# Install additional packages for foundation models\npip install \".[tsfm]\"\npip uninstall -y probts # recommended to uninstall the root package (optional)\n```\n\n**2. Initialize Submodules**\n```bash\n# Run from the repository root.\n# For MOIRAI, we fix the version of the package for better performance\ncd submodules/uni2ts\ngit reset --hard fce6a6f57bc3bc1a57c7feb3abc6c7eb2f264301\n\n# For Lag-Llama, fix the version for reproducibility (optional)\ncd ../lag_llama\ngit reset --hard 4ad82d9\n\n# For TinyTimeMixer, fix the version for reproducibility (optional)\ncd ../tsfm\ngit reset --hard bb125c14a05e4231636d6b64f8951d5fe96da1dc\n\n# Return to the repository root\ncd ../..\n```\n\n</details>\n\n### Datasets\n\nFor a complete dataset list, refer to the [Datasets Overview](./docs/documentation/README.md#datasets-overview).\n\n- **Short-Term Forecasting**: We use datasets from [GluonTS](https://github.com/awslabs/gluonts). \n    Configure the datasets using `--data.data_manager.init_args.dataset {DATASET_NAME}`. 
You can choose from multivariate or univariate datasets as per your requirement.\n    ```bash\n    ['exchange_rate_nips', 'electricity_nips', 'traffic_nips', 'solar_nips', 'wiki2000_nips']\n    ```\n\n- **Long-Term Forecasting**: To download the [long-term forecasting datasets](https://drive.google.com/drive/folders/1ZOYpTUa82_jCcxIdTmyr0LXQfvaM9vIy), please follow these steps:\n    ```bash\n    bash scripts/prepare_datasets.sh \"./datasets\"\n    ```\n\n    Configure the datasets using `--data.data_manager.init_args.dataset {DATASET_NAME}` with the following list of available datasets:\n    ```bash\n    ['etth1', 'etth2','ettm1','ettm2','traffic_ltsf', 'electricity_ltsf', 'exchange_ltsf', 'illness_ltsf', 'weather_ltsf', 'caiso', 'nordpool']\n    ```\n    *Note: When utilizing long-term forecasting datasets, you must explicitly specify the `context_length` and `prediction_length` parameters. For example, to set a context length of 96 and a prediction length of 192, use the following command-line arguments:*\n    ```bash\n    --data.data_manager.init_args.context_length 96 \\\n    --data.data_manager.init_args.prediction_length 192 \\\n    ```\n\n- **Using Datasets from Monash Time Series Forecasting Repository**: To use datasets from the [Monash Time Series Forecasting Repository](https://forecastingdata.org/), follow these steps:\n\n    1. **Download the Dataset**: \n    - Navigate to the target dataset, such as the [Electricity Hourly Dataset](https://zenodo.org/records/4656140).\n    - Download the `.tsf` file and place it in your local `datasets` directory (e.g., `./datasets`).\n\n    1. 
**Configure the Dataset**:\n    - Use the following configuration to specify the dataset, file path, and frequency:\n        ```bash\n        --data.data_manager.init_args.dataset {DATASET_NAME} \\\n        --data.data_manager.init_args.data_path /path/to/data_file.tsf \\\n        --data.data_manager.init_args.freq {FREQ} \n        ```\n\n    - **Example Configuration**:\n        ```bash\n        --data.data_manager.init_args.dataset monash_electricity_hourly \\\n        --data.data_manager.init_args.data_path ./datasets/electricity_hourly_dataset.tsf \\\n        --data.data_manager.init_args.freq H \\\n        --data.data_manager.init_args.context_length 96 \\\n        --data.data_manager.init_args.prediction_length 96 \\\n        --data.data_manager.init_args.multivariate true\n        ```\n\n    *Note 1: Refer to the [Pandas Time Series Offset Aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-offset-aliases) for the correct frequency values (`{FREQ}`) to use in your configuration.*\n\n    *Note 2: You can adjust the test instance sampling using the `--data.data_manager.init_args.test_rolling_length` parameter.*\n\n### Checkpoints for Foundation Models\n\nDownload the checkpoints with the following command (details can be found [here](./checkpoints/README.md)):\n```bash\nbash scripts/prepare_tsfm_checkpoints.sh # By downloading, you agree to the original licenses\n```\n\n## Quick Start :rocket:\n\nSpecify `--config` with a specific configuration file to reproduce results of point or probabilistic models on commonly used long- and short-term forecasting datasets. 
Configuration files are included in the [config](./config/) folder.\n\nTo run models:\n```bash \nbash run.sh\n```\n\nExperimental results reproduction:\n\n- **Long-term Forecasting:**\n\n    ```bash \n    bash scripts/reproduce_ltsf_results.sh\n    ```\n\n\n- **Short-term Forecasting:**\n\n    ```bash \n    bash scripts/reproduce_stsf_results.sh\n    ```\n\n- **Time Series Foundation Models:**\n\n    ```bash \n    bash scripts/reproduce_tsfm_results.sh\n    ```\n\n### Short-term Forecasting Configuration\n\nFor short-term forecasting scenarios, datasets and corresponding `context_length` and `prediction_length` are automatically obtained from [GluonTS](https://github.com/awslabs/gluonts). Use the following command:\n\n```bash \npython run.py --config config/path/to/model.yaml \\\n                --data.data_manager.init_args.path /path/to/datasets/ \\\n                --trainer.default_root_dir /path/to/log_dir/ \\\n                --data.data_manager.init_args.dataset {DATASET_NAME}\n```\nSee full `DATASET_NAME` list:\n```python\nfrom gluonts.dataset.repository import dataset_names\nprint(dataset_names)\n```\n\n### Long-term Forecasting Configuration\n\nFor long-term forecasting scenarios, `context_length` and `prediction_length` must be explicitly assigned:\n\n```bash \npython run.py --config config/path/to/model.yaml \\\n                --data.data_manager.init_args.path /path/to/datasets/ \\\n                --trainer.default_root_dir /path/to/log_dir/ \\\n                --data.data_manager.init_args.dataset {DATASET_NAME} \\\n                --data.data_manager.init_args.context_length {CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length {PRED_LEN} \n```\n\n`DATASET_NAME` options:\n```bash \n['etth1', 'etth2','ettm1','ettm2','traffic_ltsf', 'electricity_ltsf', 'exchange_ltsf', 'illness_ltsf', 'weather_ltsf', 'caiso', 'nordpool']\n```\n\n### Forecasting with Varied Prediction Lengths\n\n\nConventional forecasting models typically 
require specific training and deployment for each prediction horizon. However, with the growing importance of varied-horizon forecasting, there is a need for models that can deliver robust predictions across multiple inference horizons after a single training phase.\n\nProbTS has been updated to support varied-horizon forecasting by enabling the specification of distinct context and prediction lengths for the training, validation, and testing phases.\n\n**Quick Start**\n\nTo quickly train and evaluate ElasTST:\n\n```bash \nbash scripts/run_elastst.sh\n```\n\nTo quickly set up varied-horizon training:\n\n```bash \nbash scripts/run_varied_hor_training.sh\n```\n\nFor detailed information on the configuration, refer to the [documentation](./docs/documentation/README.md#forecasting-with-varied-prediction-lengths).\n\n*Note: Currently, this feature is only supported by ElasTST, Autoformer, and foundation models.*\n\n\n## Benchmarking :balance_scale:\n\nBy utilizing ProbTS, we conduct a systematic comparison between studies that focus on point forecasting and those aimed at distributional estimation, employing various forecasting horizons and evaluation metrics. 
For more details, see:\n\n- [Short-term & Long-term Forecasting Benchmarking](./docs/benchmark/README.md)\n- [Evaluating Time Series Foundation Models](./docs/benchmark/FOUNDATION_MODEL.md)\n\n\n## Documentation :open_book:\n\nFor detailed information on configuration parameters and model customization, please refer to the [documentation](./docs/documentation/README.md).\n\n\n- To print the full pipeline configuration to a file:\n\n    ```bash\n    python run.py --print_config > config/pipeline_config.yaml\n    ```\n\n## Acknowledgement 🌟\n\nSpecial thanks to the following repositories for their open-sourced code bases and datasets.\n\n### Tools/Packages\n\n- [GluonTS](https://github.com/awslabs/gluonts)\n- [PyTorch-TS](https://github.com/zalandoresearch/pytorch-ts)\n- [TSLib](https://github.com/libts/tslib) \n- [NeuralForecast](https://github.com/Nixtla/neuralforecast)\n\n### Official Implementations\n\n**Classical Time-series Models**\n\n- [Autoformer](https://github.com/thuml/Autoformer)\n- [N-HiTS](https://github.com/cchallu/n-hits)\n- [NLinear, DLinear](https://github.com/cure-lab/LTSF-Linear)\n- [TimesNet](https://github.com/thuml/Time-Series-Library)\n- [RevIN](https://github.com/ts-kim/RevIN)\n- [PatchTST](https://github.com/yuqinie98/PatchTST)\n- [iTransformer](https://github.com/thuml/iTransformer)\n- [GRU NVP, GRU MAF, Trans MAF, TimeGrad](https://github.com/zalandoresearch/pytorch-ts/tree/master)\n- [CSDI](https://github.com/ermongroup/CSDI)\n- [TSDiff](https://github.com/amazon-science/unconditional-time-series-diffusion)\n\n\n**Time-series Foundation Models**\n\n- [MOIRAI](https://github.com/SalesforceAIResearch/uni2ts)\n- [Chronos](https://github.com/amazon-science/chronos-forecasting)\n- [Lag-Llama](https://github.com/time-series-foundation-models/lag-llama)\n- [TimesFM](https://github.com/google-research/timesfm)\n- [Timer](https://github.com/thuml/Large-Time-Series-Model)\n- [UniTS](https://github.com/mims-harvard/UniTS)\n- 
[ForecastPFN](https://github.com/abacusai/ForecastPFN)\n- [TTM](https://github.com/ibm-granite/granite-tsfm)\n\n## Citing ProbTS :beers:\n\nIf you have used ProbTS for research or production, please cite it as follows.\n```tex\n@inproceedings{zhang2024probts,\n  title={{ProbTS}: Benchmarking Point and Distributional Forecasting across Diverse Prediction Horizons},\n  author={Zhang, Jiawen and Wen, Xumeng and Zhang, Zhenwei and Zheng, Shun and Li, Jia and Bian, Jiang},\n  booktitle={NeurIPS Datasets and Benchmarks Track},\n  year={2024}\n}\n```\n"
  },
  {
    "path": "SECURITY.md",
    "content": "<!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->\n\n## Security\n\nMicrosoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).\n\nIf you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.\n\n## Reporting Security Issues\n\n**Please do not report security vulnerabilities through public GitHub issues.**\n\nInstead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).\n\nIf you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).  If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).\n\nYou should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). \n\nPlease include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:\n\n  * Type of issue (e.g. 
buffer overflow, SQL injection, cross-site scripting, etc.)\n  * Full paths of source file(s) related to the manifestation of the issue\n  * The location of the affected source code (tag/branch/commit or direct URL)\n  * Any special configuration required to reproduce the issue\n  * Step-by-step instructions to reproduce the issue\n  * Proof-of-concept or exploit code (if possible)\n  * Impact of the issue, including how an attacker might exploit the issue\n\nThis information will help us triage your report more quickly.\n\nIf you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.\n\n## Preferred Languages\n\nWe prefer all communications to be in English.\n\n## Policy\n\nMicrosoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).\n\n<!-- END MICROSOFT SECURITY.MD BLOCK -->\n"
  },
  {
    "path": "checkpoints/README.md",
    "content": "# Checkpoints for Foundation Models\n\nFor full reproducibility, we provide the checkpoints for some foundation models as of the paper completion date. \n\nDownload the checkpoints from [Google Drive](https://drive.google.com/drive/folders/1FaCk9Lj9KZGEO09gehNqC4fbTj4wnN8j?usp=sharing) with:\n    \n```bash\n# By downloading, you agree to the terms of the original license agreements.\nsh scripts/prepare_checkpoints.sh # in root directory\n```\n\n\nYou can also download the newest checkpoints from the following repositories:\n\n- For `Timer`, download the checkpoints from its [official repository](https://github.com/thuml/Large-Time-Series-Model?tab=readme-ov-file#code-for-fine-tuning) ([Google Drive](https://drive.google.com/drive/folders/15oaiAl4OO5gFqZMJD2lOtX2fxHbpgcU8) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/235e6bfcf5fa440bb119/)) under the folder `./checkpoints/timer/Timer_67M_UTSD_4G.pt`.\n- For `ForecastPFN`, download the checkpoints from its [official repository](https://github.com/abacusai/ForecastPFN#installation-) ([Google Drive](https://drive.google.com/file/d/1acp5thS7I4g_6Gw40wNFGnU1Sx14z0cU/view)) under the folder `./checkpoints/ForecastPFN/saved_weights`.\n- For `UniTS`, download the checkpoints `units_x128_pretrain_checkpoint.pth` from its [official repository](https://github.com/mims-harvard/UniTS/releases/tag/ckpt) under the folder `./checkpoints/units/units_x128_pretrain_checkpoint.pth`.\n- For `Lag-Llama`, download the checkpoints `lag-llama.ckpt` from its [huggingface repository](https://huggingface.co/time-series-foundation-models/Lag-Llama/tree/main) under the folder `./checkpoints/lag-llama/lag-llama.ckpt`.\n- For other models, they can be automatically downloaded from huggingface during the first run.\n\n<center>\n\n| **Model** | **HuggingFace** |\n| --- | --- |\n| `MOIRAI` | [Link](https://huggingface.co/Salesforce/moirai-1.0-R-small) |\n| `Chronos` | [Link](https://huggingface.co/amazon/chronos-t5-large) |\n| 
`TinyTimeMixer` | [Link](https://huggingface.co/ibm-granite/granite-timeseries-ttm-v1) |\n| `TimesFM` | [Link](https://huggingface.co/google/timesfm-1.0-200m) |\n\n</center>\n"
  },
  {
    "path": "config/default/autoformer.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\n  # num_sanity_val_steps: 0\n  # gradient_clip_algorithm: 'norm'\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.Autoformer\n    init_args:\n      moving_avg: 25\n      factor: 1\n      n_heads: 8\n      activation: 'gelu'\n      e_layers: 2\n      d_layers: 1\n      output_attention: false\n      d_ff: 512\n      f_hidden_size: 512\n      embed: 'timeF'\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  num_samples: 1\n  learning_rate: 1e-3\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # none, standard, scaling\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/default/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 128\n      emb_feature_dim: 16\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 128\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 4\n  test_batch_size: 4\n  num_workers: 8\n"
  },
  {
    "path": "config/default/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 3\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n  learning_rate: 0.01\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/default/gru.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.GRUForecaster\n    init_args:\n      f_hidden_size: 40\n      num_layers: 2\n      dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/default/gru_maf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_MAF\n    init_args:\n      enc_num_layers: 2\n      enc_hidden_size: 40\n      enc_dropout: 0.1\n      n_blocks: 4\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: false\n      conditional_length: 200\n      dequantize: true\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      scaler: identity # identity, standard, temporal\n      split_val: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/default/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 7\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 40\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 4\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 200\n      dequantize: true\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/default/itransformer.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.iTransformer\n    init_args:\n      factor: 1\n      n_heads: 8\n      activation: 'gelu'\n      e_layers: 2\n      output_attention: false\n      f_hidden_size: 256\n      d_ff: 256\n      label_len: 48\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 1\n  learning_rate: 1e-4\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/default/linear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 30\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.LinearForecaster\n    init_args:\n      individual: false\n      use_lags: true\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/default/mean.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 40\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.MeanForecaster\n    init_args:\n      mode: global\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/default/moderntcn.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.ModernTCN\n    init_args:\n      ffn_ratio: 1\n      patch_size: 8\n      patch_stride: 4\n      num_blocks: [1]\n      large_size: [51]\n      dims: [64, 64, 64, 64]\n      dropout: 0.3\n      kernel_size: 3\n      small_size: [5]\n      use_multi_scale: false\n      small_kernel_merged: false\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/default/naive.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 40\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.NaiveForecaster\n  learning_rate: 0.001\n  quantiles_num: 10\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/default/nhits.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.NHiTS\n    init_args:\n      n_blocks: [1,1,1]\n      hidden_size: 512\n      pooling_mode: 'max'\n      interpolation_mode: 'linear'\n      activation: 'ReLU'\n      initialization: 'lecun_normal'\n      batch_normalization: false\n      shared_weights: false\n      naive_level: \n      dropout: 0\n      n_layers: 2\n      use_lags: false\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/default/nlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.NLinear\n    init_args:\n      individual: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.01\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/default/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 3\n      patch_len: 6\n      dropout: 0.1\n      f_hidden_size: 32\n      n_layers: 3\n      n_heads: 8\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: false\n  optimizer_config:\n    class_name: torch.optim.Adam\n    init_args:\n      weight_decay: 0\n  lr_scheduler_config:\n    class_name: torch.optim.lr_scheduler.OneCycleLR\n    init_args:\n      max_lr: 0.0001\n      steps_per_epoch: 100\n      pct_start: 0.3\n      epochs: 50\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/default/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 100\n      enc_hidden_size: 128\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      scaler: identity # identity, standard, temporal\n      split_val: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/default/timesnet.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TimesNet\n    init_args:\n      n_layers: 2\n      num_kernels: 6\n      top_k: 5\n      d_ff: 32\n      dropout: 0.1\n      f_hidden_size: 40\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/default/trans_maf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Trans_MAF\n    init_args:\n      enc_hidden_size: 32\n      enc_num_heads: 8\n      enc_num_encoder_layers: 2\n      enc_num_decoder_layers: 2\n      enc_dim_feedforward_scale: 4\n      enc_dropout: 0.1\n      enc_activation: gelu\n      n_blocks: 4\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: false\n      conditional_length: 200\n      dequantize: true\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      scaler: identity # identity, standard, temporal\n      split_val: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/default/transformer.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TransformerForecaster\n    init_args:\n      f_hidden_size: 16\n      num_heads: 4\n      num_encoder_layers: 3\n      num_decoder_layers: 3\n      dim_feedforward_scale: 4\n      dropout: 0.1\n      activation: gelu\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/default/tsdiff.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\n  gradient_clip_val: 0.5\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TSDiffCond\n    init_args:\n      timesteps: 100\n      hidden_dim: 64\n      step_emb: 128\n      num_residual_blocks: 3\n      dropout: 0.0\n      mode: diag # diag, nplr\n      measure: diag # 'diag', 'diag-lin', 'diag-inv', or 'diag-legs' for diag\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n      use_scaling: false\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: temporal # identity, standard, temporal\n      context_length: 336\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8\n"
  },
  {
    "path": "config/default/tsmixer.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 40\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TSMixer\n    init_args:\n      num_blocks: 6\n      dropout_rate: 0.7\n      ff_dim: 64\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/electricity_ltsf/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 3\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 64\n      emb_feature_dim: 8\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 64\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 16\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_ltsf\n      scaler: standard # identity, standard, temporal\n      split_val: true\n  batch_size: 4\n  test_batch_size: 8\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/electricity_ltsf/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 200\n  log_every_n_steps: 1\n  accumulate_grad_batches: 2\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinearEncoder\n    init_args:\n      individual: true\n      kernel_size: 25\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/electricity_ltsf/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 128\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 2\n      hidden_size: 64\n      n_hidden: 2\n      batch_norm: false\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_ltsf\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/electricity_ltsf/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 8\n      patch_len: 16\n      dropout: 0.2\n      f_hidden_size: 128\n      n_layers: 3\n      n_heads: 16\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: false\n  num_samples: 100\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 8\n  test_batch_size: 8\n  num_workers: 8"
  },
  {
    "path": "config/ltsf/electricity_ltsf/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 200\n      enc_hidden_size: 128\n      enc_num_layers: 3\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_ltsf\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/etth1/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 128\n      emb_feature_dim: 16\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 128\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 8\n  test_batch_size: 8\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/etth1/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  accumulate_grad_batches: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: true\n      kernel_size: 25\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.005\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/etth1/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 64\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 4\n      hidden_size: 64\n      n_hidden: 3\n      batch_norm: false\n      conditional_length: 100\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth1\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/etth1/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 8\n      patch_len: 16\n      dropout: 0.3\n      f_hidden_size: 16\n      n_layers: 3\n      n_heads: 4\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: true\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/ltsf/etth1/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 200\n      enc_hidden_size: 128\n      enc_num_layers: 3\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth1\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/etth2/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 128\n      emb_feature_dim: 16\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 128\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth2\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 8\n  test_batch_size: 8\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/etth2/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  accumulate_grad_batches: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 25\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.05\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth2\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/etth2/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 64\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      n_blocks: 2\n      hidden_size: 128\n      n_hidden: 3\n      batch_norm: true\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth2\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/etth2/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 8\n      patch_len: 16\n      dropout: 0.3\n      f_hidden_size: 16\n      d_ff: 128\n      n_layers: 3\n      n_heads: 4\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: false\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth2\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/ltsf/etth2/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 100\n      enc_hidden_size: 64\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth2\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/ettm1/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 128\n      emb_feature_dim: 16\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 128\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 8\n  test_batch_size: 8\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/ettm1/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  accumulate_grad_batches: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: true\n      kernel_size: 25\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/ettm1/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 64\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 4\n      hidden_size: 64\n      n_hidden: 3\n      batch_norm: false\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm1\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/ettm1/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 8\n      patch_len: 16\n      dropout: 0.2\n      f_hidden_size: 128\n      n_layers: 3\n      n_heads: 16\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: true\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/ltsf/ettm1/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 200\n      enc_hidden_size: 128\n      enc_num_layers: 3\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm1\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/ettm2/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 128\n      emb_feature_dim: 16\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 128\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm2\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 8\n  test_batch_size: 8\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/ettm2/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  accumulate_grad_batches: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 25\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm2\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/ettm2/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 64\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      n_blocks: 2\n      hidden_size: 128\n      n_hidden: 3\n      batch_norm: false\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm2\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/ettm2/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 8\n      patch_len: 16\n      dropout: 0.2\n      f_hidden_size: 128\n      n_layers: 3\n      n_heads: 16\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: true\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm2\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/ltsf/ettm2/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 200\n      enc_hidden_size: 64\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm2\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/exchange_ltsf/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 128\n      emb_feature_dim: 16\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 128\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 8\n  test_batch_size: 8\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/exchange_ltsf/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  accumulate_grad_batches: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: true\n      kernel_size: 25\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.0005\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/exchange_ltsf/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 128\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 2\n      hidden_size: 128\n      n_hidden: 3\n      batch_norm: false\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_ltsf\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/exchange_ltsf/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 8\n      patch_len: 16\n      dropout: 0.2\n      f_hidden_size: 16\n      n_layers: 3\n      n_heads: 4\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: true\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/ltsf/exchange_ltsf/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 200\n      enc_hidden_size: 64\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_ltsf\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/illness_ltsf/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 128\n      emb_feature_dim: 16\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 128\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: illness_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 8\n  test_batch_size: 8\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/illness_ltsf/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  accumulate_grad_batches: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 25\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.01\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: illness_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 36\n      prediction_length: 36\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/illness_ltsf/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 64\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      n_blocks: 4\n      hidden_size: 128\n      n_hidden: 2\n      batch_norm: false\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: illness_ltsf\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 36\n      prediction_length: 36\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/illness_ltsf/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 2\n      patch_len: 24\n      dropout: 0.3\n      f_hidden_size: 16\n      n_layers: 3\n      n_heads: 4\n      fc_dropout: 0.3\n      head_dropout: 0\n      individual: true\n  learning_rate: 0.0025\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: illness_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 36\n      prediction_length: 36\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/ltsf/illness_ltsf/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 200\n      enc_hidden_size: 64\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: illness_ltsf\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 36\n      prediction_length: 36\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/traffic_ltsf/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 3\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 64\n      emb_feature_dim: 8\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 64\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 16\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 4\n  test_batch_size: 4\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/traffic_ltsf/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  accumulate_grad_batches: 4\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 25\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.05\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 8\n  test_batch_size: 8\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/traffic_ltsf/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 128\n      enc_num_layers: 3\n      enc_dropout: 0.1\n      n_blocks: 4\n      hidden_size: 128\n      n_hidden: 3\n      batch_norm: true\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_ltsf\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/traffic_ltsf/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 300\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 3\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 8\n      patch_len: 16\n      dropout: 0.2\n      f_hidden_size: 128\n      n_layers: 3\n      n_heads: 16\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: false\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 8\n  test_batch_size: 8\n  num_workers: 8"
  },
  {
    "path": "config/ltsf/traffic_ltsf/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 200\n      enc_hidden_size: 128\n      enc_num_layers: 3\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_ltsf\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/weather_ltsf/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 128\n      emb_feature_dim: 16\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 128\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: weather_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 8\n  test_batch_size: 8\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/weather_ltsf/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  accumulate_grad_batches: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 25\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: weather_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/weather_ltsf/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 64\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      n_blocks: 4\n      hidden_size: 128\n      n_hidden: 3\n      batch_norm: false\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: weather_ltsf\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/ltsf/weather_ltsf/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 8\n      patch_len: 16\n      dropout: 0.2\n      f_hidden_size: 128\n      n_layers: 3\n      n_heads: 16\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: false\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: weather_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/ltsf/weather_ltsf/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 200\n      enc_hidden_size: 64\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: weather_ltsf\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      context_length: 96\n      prediction_length: 96\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/m4_daily/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 32\n      emb_feature_dim: 4\n      channels: 16\n      n_layers: 4\n      num_heads: 4\n      num_steps: 50\n      diffusion_embedding_dim: 32\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m4_daily\n      context_length_factor: 3\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/m4_daily/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 3\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m4_daily\n      context_length_factor: 3\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/m4_daily/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 40\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 2\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 100\n      dequantize: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m4_daily\n      context_length_factor: 3\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/m4_daily/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 2\n      patch_len: 6\n      dropout: 0.3\n      f_hidden_size: 32\n      d_ff: 128\n      n_layers: 3\n      n_heads: 8\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: true\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m4_daily\n      context_length_factor: 3\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 128\n  num_workers: 8"
  },
  {
    "path": "config/m4/m4_daily/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 50\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 100\n      enc_hidden_size: 64\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m4_daily\n      context_length_factor: 3\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/m4_weekly/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 32\n      emb_feature_dim: 4\n      channels: 16\n      n_layers: 4\n      num_heads: 4\n      num_steps: 50\n      diffusion_embedding_dim: 32\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m4_weekly\n      context_length_factor: 3\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/m4_weekly/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 3\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m4_weekly\n      context_length_factor: 3\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/m4_weekly/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 40\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 2\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 100\n      dequantize: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m4_weekly\n      context_length_factor: 3\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/m4_weekly/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 3\n      patch_len: 6\n      dropout: 0.3\n      f_hidden_size: 32\n      d_ff: 128\n      n_layers: 3\n      n_heads: 8\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: true\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m4_weekly\n      context_length_factor: 3\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 128\n  num_workers: 8"
  },
  {
    "path": "config/m4/m4_weekly/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 50\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 100\n      enc_hidden_size: 64\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m4_weekly\n      context_length_factor: 3\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/m5/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 32\n      emb_feature_dim: 4\n      channels: 16\n      n_layers: 4\n      num_heads: 4\n      num_steps: 50\n      diffusion_embedding_dim: 32\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m5\n      context_length_factor: 3\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/m5/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 3\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m5\n      context_length_factor: 3\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 256\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/m5/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 40\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 2\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 100\n      dequantize: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m5\n      context_length_factor: 3\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/m5/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 2\n      patch_len: 4\n      dropout: 0.3\n      f_hidden_size: 64\n      d_ff: 128\n      n_layers: 3\n      n_heads: 8\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: true\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m5\n      context_length_factor: 3\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 128\n  num_workers: 8"
  },
  {
    "path": "config/m4/m5/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 30\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 50\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 100\n      enc_hidden_size: 64\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: m5\n      context_length_factor: 3\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 512\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/tourism_monthly/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 32\n      emb_feature_dim: 4\n      channels: 16\n      n_layers: 4\n      num_heads: 4\n      num_steps: 50\n      diffusion_embedding_dim: 32\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: tourism_monthly\n      context_length_factor: 3\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/tourism_monthly/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 3\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: tourism_monthly\n      context_length_factor: 3\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/tourism_monthly/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 40\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 2\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 100\n      dequantize: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: tourism_monthly\n      context_length_factor: 3\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/m4/tourism_monthly/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 2\n      patch_len: 6\n      dropout: 0.3\n      f_hidden_size: 64\n      d_ff: 128\n      n_layers: 3\n      n_heads: 8\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: true\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: tourism_monthly\n      context_length_factor: 3\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 128\n  num_workers: 8"
  },
  {
    "path": "config/m4/tourism_monthly/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 2\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 50\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 100\n      enc_hidden_size: 64\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: tourism_monthly\n      context_length_factor: 3\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8\n"
  },
  {
    "path": "config/multi_hor/autoformer.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\n  # num_sanity_val_steps: 0\n  # gradient_clip_algorithm: 'norm'\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.Autoformer\n    init_args:\n      moving_avg: 25\n      factor: 1\n      n_heads: 8\n      activation: 'gelu'\n      e_layers: 2\n      d_layers: 1\n      output_attention: false\n      d_ff: 512\n      f_hidden_size: 512\n      embed: 'timeF'\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  num_samples: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 24-96-192-336-720-1024\n      train_ctx_len: 96\n      train_pred_len_list: 720\n      val_ctx_len: 96\n      val_pred_len_list: 720\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/multi_hor/elastst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.ElasTST\n    init_args:\n      l_patch_size: '8_16_32'\n      dropout: 0.0\n      f_hidden_size: 256\n      d_inner: 256\n      t_layers: 2\n      v_layers: 0\n      n_heads: 8\n      d_v: 64\n      d_k: 64\n      structured_mask: true\n      rotate: true\n      rope_theta_init: 'exp'\n      learnable_rope: true\n      min_period: 1\n      max_period: 1000\n      addv: false\n      bin_att: false\n      learn_tem_emb: false\n  learning_rate: 0.001\n  quantiles_num: 20\n  sampling_weight_scheme: random\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 96\n      prediction_length: 24-96-192-336-720-1024\n      train_ctx_len: 96\n      train_pred_len_list: 720\n      val_ctx_len: 96\n      val_pred_len_list: 720\n      continuous_sample: false\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/pipeline_config.yaml",
    "content": "# lightning.pytorch==2.3.0dev\nseed_everything: true\ntrainer:\n  accelerator: auto\n  strategy: auto\n  devices: auto\n  num_nodes: 1\n  precision: null\n  logger: null\n  callbacks: null\n  fast_dev_run: false\n  max_epochs: null\n  min_epochs: null\n  max_steps: -1\n  min_steps: null\n  max_time: null\n  limit_train_batches: null\n  limit_val_batches: null\n  limit_test_batches: null\n  limit_predict_batches: null\n  overfit_batches: 0.0\n  val_check_interval: null\n  check_val_every_n_epoch: 1\n  num_sanity_val_steps: null\n  log_every_n_steps: null\n  enable_checkpointing: null\n  enable_progress_bar: null\n  enable_model_summary: null\n  accumulate_grad_batches: 1\n  gradient_clip_val: null\n  gradient_clip_algorithm: null\n  deterministic: null\n  benchmark: null\n  inference_mode: true\n  use_distributed_sampler: true\n  profiler: null\n  detect_anomaly: false\n  barebones: false\n  plugins: null\n  sync_batchnorm: false\n  reload_dataloaders_every_n_epochs: 0\n  default_root_dir: null\nmodel:\n  forecaster: null\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 10\n  load_from_ckpt: null\ndata:\n  data_manager: null\n  batch_size: 64\n  test_batch_size: 8\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/electricity/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 128\n      emb_feature_dim: 16\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 128\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 4\n  test_batch_size: 4\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/electricity/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: true\n      kernel_size: 3\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.01\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/stsf/electricity/gru.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.GRUForecaster\n    init_args:\n      f_hidden_size: 40\n      num_layers: 2\n      dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/electricity/gru_maf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_MAF\n    init_args:\n      enc_num_layers: 2\n      enc_hidden_size: 40\n      enc_dropout: 0.1\n      n_blocks: 4\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/electricity/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 40\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 3\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/electricity/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 2\n      patch_len: 4\n      dropout: 0.1\n      f_hidden_size: 64\n      n_layers: 4\n      n_heads: 8\n      fc_dropout: 0.1\n      head_dropout: 0\n      individual: true\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/stsf/electricity/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 100\n      enc_hidden_size: 128\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/electricity/timesnet.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TimesNet\n    init_args:\n      n_layers: 2\n      num_kernels: 6\n      top_k: 5\n      d_ff: 64\n      dropout: 0.1\n      f_hidden_size: 64\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/electricity/trans_maf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Trans_MAF\n    init_args:\n      enc_hidden_size: 32\n      enc_num_heads: 8\n      enc_num_encoder_layers: 2\n      enc_num_decoder_layers: 2\n      enc_dim_feedforward_scale: 4\n      enc_dropout: 0.1\n      enc_activation: gelu\n      n_blocks: 4\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/electricity/transformer.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TransformerForecaster\n    init_args:\n      f_hidden_size: 32\n      num_heads: 8\n      num_encoder_layers: 3\n      num_decoder_layers: 3\n      dim_feedforward_scale: 4\n      dropout: 0.1\n      activation: gelu\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/exchange/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 128\n      emb_feature_dim: 16\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 128\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 4\n  test_batch_size: 4\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/exchange/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 3\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.01\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/stsf/exchange/gru.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.GRUForecaster\n    init_args:\n      f_hidden_size: 40\n      num_layers: 2\n      dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/exchange/gru_maf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_MAF\n    init_args:\n      enc_num_layers: 2\n      enc_hidden_size: 40\n      enc_dropout: 0.1\n      n_blocks: 4\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: false\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/exchange/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 40\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 4\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/exchange/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 3\n      patch_len: 6\n      dropout: 0.1\n      f_hidden_size: 32\n      n_layers: 3\n      n_heads: 8\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: true\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/stsf/exchange/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 100\n      enc_hidden_size: 128\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/exchange/timesnet.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TimesNet\n    init_args:\n      n_layers: 2\n      num_kernels: 6\n      top_k: 5\n      d_ff: 64\n      dropout: 0.1\n      f_hidden_size: 64\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/exchange/trans_maf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Trans_MAF\n    init_args:\n      enc_hidden_size: 16\n      enc_num_heads: 8\n      enc_num_encoder_layers: 2\n      enc_num_decoder_layers: 2\n      enc_dim_feedforward_scale: 4\n      enc_dropout: 0.1\n      enc_activation: gelu\n      n_blocks: 4\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: false\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/stsf/exchange/transformer.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TransformerForecaster\n    init_args:\n      f_hidden_size: 32\n      num_heads: 8\n      num_encoder_layers: 3\n      num_decoder_layers: 3\n      dim_feedforward_scale: 4\n      dropout: 0.1\n      activation: gelu\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/stsf/solar/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 800\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 2\n  default_root_dir: ./results\n  accumulate_grad_batches: 8\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 128\n      emb_feature_dim: 16\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 128\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 64\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 4\n  test_batch_size: 4\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/solar/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 3\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n  learning_rate: 0.01\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/stsf/solar/gru.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.GRUForecaster\n    init_args:\n      f_hidden_size: 40\n      num_layers: 2\n      dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/solar/gru_maf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_MAF\n    init_args:\n      enc_num_layers: 2\n      enc_hidden_size: 40\n      enc_dropout: 0.1\n      n_blocks: 4\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: false\n      conditional_length: 200\n      dequantize: true\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/solar/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 40\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 4\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 200\n      dequantize: true\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/solar/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 3\n      patch_len: 6\n      dropout: 0.1\n      f_hidden_size: 32\n      n_layers: 3\n      n_heads: 8\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: true\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/stsf/solar/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 100\n      enc_hidden_size: 128\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/solar/timesnet.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TimesNet\n    init_args:\n      n_layers: 2\n      num_kernels: 6\n      top_k: 5\n      d_ff: 16\n      dropout: 0.1\n      f_hidden_size: 16\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/solar/trans_maf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Trans_MAF\n    init_args:\n      enc_hidden_size: 32\n      enc_num_heads: 8\n      enc_num_encoder_layers: 2\n      enc_num_decoder_layers: 2\n      enc_dim_feedforward_scale: 4\n      enc_dropout: 0.1\n      enc_activation: gelu\n      n_blocks: 4\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: false\n      conditional_length: 200\n      dequantize: true\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/stsf/solar/transformer.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TransformerForecaster\n    init_args:\n      f_hidden_size: 16\n      num_heads: 4\n      num_encoder_layers: 3\n      num_decoder_layers: 3\n      dim_feedforward_scale: 4\n      dropout: 0.1\n      activation: gelu\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/stsf/traffic/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 3\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 64\n      emb_feature_dim: 8\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 64\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 16\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 8\n  test_batch_size: 8\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/traffic/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 3\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/stsf/traffic/gru.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.GRUForecaster\n    init_args:\n      f_hidden_size: 128\n      num_layers: 2\n      dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/traffic/gru_maf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_MAF\n    init_args:\n      enc_num_layers: 2\n      enc_hidden_size: 128\n      enc_dropout: 0.3\n      n_blocks: 3\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/stsf/traffic/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 128\n      enc_num_layers: 2\n      enc_dropout: 0.3\n      n_blocks: 4\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/stsf/traffic/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 1\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 3\n      patch_len: 6\n      dropout: 0.1\n      f_hidden_size: 32\n      n_layers: 3\n      n_heads: 8\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: false\n  num_samples: 100\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/stsf/traffic/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 100\n      enc_hidden_size: 128\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/traffic/timesnet.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TimesNet\n    init_args:\n      n_layers: 2\n      num_kernels: 6\n      top_k: 5\n      d_ff: 16\n      dropout: 0.1\n      f_hidden_size: 16\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/traffic/trans_maf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Trans_MAF\n    init_args:\n      enc_hidden_size: 128\n      enc_num_heads: 4\n      enc_num_encoder_layers: 2\n      enc_num_decoder_layers: 2\n      enc_dim_feedforward_scale: 4\n      enc_dropout: 0.1\n      enc_activation: gelu\n      n_blocks: 3\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 200\n      dequantize: false\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/stsf/traffic/transformer.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TransformerForecaster\n    init_args:\n      f_hidden_size: 32\n      num_heads: 8\n      num_encoder_layers: 3\n      num_decoder_layers: 3\n      dim_feedforward_scale: 4\n      dropout: 0.1\n      activation: gelu\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: traffic_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/stsf/wiki/csdi.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  check_val_every_n_epoch: 3\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.CSDI\n    init_args:\n      emb_time_dim: 64\n      emb_feature_dim: 8\n      channels: 64\n      n_layers: 4\n      num_heads: 8\n      num_steps: 50\n      diffusion_embedding_dim: 64\n      beta_start: 0.001\n      beta_end: 0.5\n      sample_size: 16\n      linear_trans: false\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n      feat_idx_emb_dim: 1\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: wiki2000_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 8\n  test_batch_size: 8\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/wiki/dlinear.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.DLinear\n    init_args:\n      individual: false\n      kernel_size: 3\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: wiki2000_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8"
  },
  {
    "path": "config/stsf/wiki/gru.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.GRUForecaster\n    init_args:\n      f_hidden_size: 40\n      num_layers: 2\n      dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: wiki2000_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/wiki/gru_maf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_MAF\n    init_args:\n      enc_num_layers: 2\n      enc_hidden_size: 40\n      enc_dropout: 0.1\n      n_blocks: 3\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 200\n      dequantize: true\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: wiki2000_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/wiki/gru_nvp.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.GRU_NVP\n    init_args:\n      enc_hidden_size: 40\n      enc_num_layers: 2\n      enc_dropout: 0.1\n      n_blocks: 3\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 200\n      dequantize: true\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: wiki2000_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/wiki/patchtst.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 400\n  log_every_n_steps: 1\n  default_root_dir: ./results\n  accumulate_grad_batches: 4\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      stride: 4\n      patch_len: 8\n      dropout: 0.1\n      f_hidden_size: 32\n      n_layers: 2\n      n_heads: 8\n      fc_dropout: 0.2\n      head_dropout: 0\n      individual: false\n  num_samples: 100\n  learning_rate: 0.0001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: wiki2000_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8"
  },
  {
    "path": "config/stsf/wiki/timegrad.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.TimeGrad\n    init_args:\n      loss_type: l2\n      diff_steps: 100\n      beta_end: 0.1\n      beta_schedule: linear\n      conditional_length: 100\n      enc_hidden_size: 128\n      enc_num_layers: 4\n      enc_dropout: 0.1\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: wiki2000_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/wiki/timesnet.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TimesNet\n    init_args:\n      n_layers: 2\n      num_kernels: 6\n      top_k: 5\n      d_ff: 32\n      dropout: 0.1\n      f_hidden_size: 32\n      use_lags: false\n      use_feat_idx_emb: false\n      use_time_feat: false\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: wiki2000_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8\n"
  },
  {
    "path": "config/stsf/wiki/trans_maf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Trans_MAF\n    init_args:\n      enc_hidden_size: 128\n      enc_num_heads: 4\n      enc_num_encoder_layers: 2\n      enc_num_decoder_layers: 2\n      enc_dim_feedforward_scale: 4\n      enc_dropout: 0.1\n      enc_activation: gelu\n      n_blocks: 3\n      hidden_size: 100\n      n_hidden: 2\n      batch_norm: true\n      conditional_length: 200\n      dequantize: true\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n      use_scaling: true\n  num_samples: 100\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: wiki2000_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/stsf/wiki/transformer.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 1\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TransformerForecaster\n    init_args:\n      f_hidden_size: 32\n      num_heads: 8\n      num_encoder_layers: 3\n      num_decoder_layers: 3\n      dim_feedforward_scale: 4\n      dropout: 0.1\n      activation: gelu\n      use_lags: true\n      use_feat_idx_emb: true\n      use_time_feat: true\n      feat_idx_emb_dim: 1\n  learning_rate: 0.001\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: wiki2000_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/chronos.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 40\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Chronos\n    init_args:\n      model_size: base # tiny, mini, small, base, large\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 16\n  test_batch_size: 16\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/forecastpfn.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 40\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.ForecastPFN\n    init_args:\n        label_len: 48\n        ckpt_path: ./checkpoints/ForecastPFN/saved_weights\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      timeenc: 2\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/lag_llama.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 40\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.LagLlama\n    init_args:\n      use_rope_scaling: true\n      ckpt_path: ./checkpoints/lag-llama/lag-llama.ckpt\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      timeenc: 2\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_5000/electricity_ltsf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: S\n      patch_size: 128\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      var_specific_norm: false\n      context_length: 5000\n      auto_search: true\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_5000/electricity_nips.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: S\n      patch_size: 64\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      var_specific_norm: true\n      context_length: 3800  # maximum history length\n      auto_search: true\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_5000/etth1.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: M\n      patch_size: 64\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 5000\n      auto_search: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_5000/etth2.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: M\n      patch_size: 64\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth2\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 5000\n      auto_search: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_5000/ettm1.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: S\n      patch_size: 64\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 5000\n      auto_search: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_5000/ettm2.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: M\n      patch_size: 128\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm2\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      context_length: 5000\n      auto_search: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_5000/exchange_rate_nips.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: M\n      patch_size: 128\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      var_specific_norm: false\n      context_length: 5000\n      auto_search: true\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_5000/solar_nips.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: S\n      patch_size: auto\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      var_specific_norm: false\n      context_length: 5000\n      auto_search: true\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_5000/weather_ltsf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: M\n      patch_size: 128\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: weather_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      var_specific_norm: true\n      context_length: 5000\n      auto_search: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_96/electricity_ltsf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: S\n      patch_size: auto\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      var_specific_norm: false\n      context_length: 96\n      auto_search: true\n  batch_size: 4\n  test_batch_size: 4\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_96/electricity_nips.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: S\n      patch_size: 64\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: electricity_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      var_specific_norm: true\n      context_length: 96\n      auto_search: true\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_96/etth1.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: M\n      patch_size: auto\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      var_specific_norm: false\n      context_length: 96\n      auto_search: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_96/etth2.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: M\n      patch_size: auto\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: etth2\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      var_specific_norm: false\n      context_length: 96\n      auto_search: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_96/ettm1.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: M\n      patch_size: auto\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm1\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      var_specific_norm: false\n      context_length: 96\n      auto_search: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_96/ettm2.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: M\n      patch_size: auto\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: ettm2\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      var_specific_norm: false\n      context_length: 96\n      auto_search: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_96/exchange_rate_nips.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: M\n      patch_size: auto\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: exchange_rate_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      var_specific_norm: true\n      context_length: 96\n      auto_search: true\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_96/solar_nips.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: S\n      patch_size: auto\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      var_specific_norm: false\n      context_length: 96\n      auto_search: true\n  batch_size: 1\n  test_batch_size: 1\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai/context_96/weather_ltsf.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 1\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: M\n      patch_size: auto\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: weather_ltsf\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      var_specific_norm: true\n      context_length: 96\n      auto_search: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/moirai.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 40\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.prob_forecaster.Moirai\n    init_args:\n      variate_mode: S\n      patch_size: auto\n      model_size: base\n      scaling: true\n  num_samples: 100\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      auto_search: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/time_moe.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 40\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TimeMoE\n    init_args:\n        model_size: 200M # select from ['50M', '200M']\n        instance_norm: true\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      var_specific_norm: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/timer.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 40\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.Timer\n    init_args:\n        label_len: 96\n        ckpt_path: ./checkpoints/timer/Timer_67M_UTSD_4G.pt\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/timesfm.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 40\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TimesFM\n    init_args:\n        model_size: 200m # select from ['200m', '500m']\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: identity # identity, standard, temporal\n      var_specific_norm: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/tinytimemixer.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 40\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.TinyTimeMixer\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "config/tsfm/units.yaml",
    "content": "# lightning==2.3.0.dev0\nseed_everything: 0\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 40\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results\nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.UniTS\n    init_args:\n      ckpt_path: ./checkpoints/units/units_x128_pretrain_checkpoint.pth\n  quantiles_num: 20\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips\n      split_val: true\n      scaler: standard # identity, standard, temporal\n      # var_norm: true\n  batch_size: 64\n  test_batch_size: 64\n  num_workers: 8"
  },
  {
    "path": "datasets/.gitignore",
    "content": "*\n!.gitignore"
  },
  {
    "path": "docs/benchmark/README.md",
    "content": "# Benchmarking :balance_scale:\n\nAccurate point and distributional forecasts across diverse horizons are crucial for time-series forecasting. However, existing research often focuses on isolated aspects, such as long-term point forecasting or short-term probabilistic estimation. This raises a fundamental question: **How do different methodological designs address these diverse forecasting needs?**\n\nIn this repository, we:\n1. **Provide Detailed Reproduction Guides:** Offer comprehensive instructions for replicating supervised models and pre-trained foundation models.\n2. **Evaluate Methods Under a Unified Framework:** Align and assess existing methods across various data scenarios using a consistent benchmarking framework.\n3. **Deliver In-Depth Insights:** Present detailed analyses and insights into the experimental results.\n\n\n## Benchmarking Scripts\n\n- [Supervised Forecasting Models](./supervised_model/README.md)\n- [Pre-trained Time-Series Foundation Models](./foundation_model/README.md)\n\n## Methodology Overview\n\n![Methodology](./figs/methodology.jpg)"
  },
  {
    "path": "docs/benchmark/foundation_model/README.md",
    "content": "# Time Series Foundation Models Benchmarking\n\n- [Time Series Foundation Models Benchmarking](#time-series-foundation-models-benchmarking)\n  - [Foundation Models](#foundation-models)\n    - [Overview](#overview)\n    - [Results Reproduction](#results-reproduction)\n  - [Key Insights \\& Takeaways](#key-insights--takeaways)\n  - [Experimental Results](#experimental-results)\n    - [Comparison Across Horizons](#comparison-across-horizons)\n    - [Short-term Probabilistic Forecasting](#short-term-probabilistic-forecasting)\n\n\n## Foundation Models\n\n### Overview\n\n| Model | Backbone | Dec. | Varied Hor. | Dist. Head | Var. | Hyper-param in Inference | Running Guides |\n| --- | --- | --- | --- | --- | --- | --- | --- |\n| [Lag-Llama](https://github.com/time-series-foundation-models/lag-llama) | Dec-only Trans. | AR | √ | Student' t | Uni | `context len`, `pred len`, `use_rope_scaling` | [Details](./lag-llama.md) |\n| [Chronos](https://github.com/amazon-science/chronos-forecasting) | Enc-Dec Trans. | AR | √ | Arbitrary | Uni | `context len`, `pred len`, `num_samples`, `temperature`, `top_k`, `top_p` | [Details](./chronos.md) |\n| [TimesFM](https://github.com/google-research/timesfm) | Dec-only Trans. | AR | √ | - | Uni | `context len`, `frequency`, `window size` | [Details](./timesfm.md) |\n| [Timer](https://github.com/thuml/Large-Time-Series-Model) | Dec-only Trans. | AR | √ | - | Uni | `context len`, `pred len`, `use_ims`  | [Details](./timer.md) |\n| [MOIRAI](https://github.com/SalesforceAIResearch/uni2ts) | Enc-only Trans.  | NAR | √ | Mixture dist. | Multi | `context len`, `pred len`, `patch size`, `variate_mode` | [Details](./moirai.md) |\n| [ForecastPFN](https://github.com/abacusai/ForecastPFN) | Enc-only Trans.  | NAR | √ | - | Uni | `context len`, `pred len` | [Details](./forecastpfn.md) |\n| [UniTS](https://github.com/mims-harvard/UniTS) | Enc-only Trans.  
| NAR | √ | - | Multi | `context len`, `pred len` | [Details](./units.md) |\n| [Tiny Time Mixers](https://github.com/ibm-granite/granite-tsfm/tree/main/tsfm_public/models/tinytimemixer) | TSMixer | NAR | x | - | Multi | `context len`, `pred len` | [Details](./ttm.md) |\n\n### Results Reproduction\n\nFor time-series foundation models, you need to install basic packages and additional dependencies:\n\n**1. Set Up Environment**\n```bash\n# Create a new conda environment\nconda create -n probts_fm python=3.10\nconda activate probts_fm\n\n# Git submodule\ngit submodule update --init --recursive\n\n# Install additional packages for foundation models\npip install \".[tsfm]\"\npip uninstall -y probts # recommended to uninstall the root package (optional)\n```\n\n**2. Initialize Submodules**\n\nTo run MOIRAI, TimesFM, Lag-Llama, or TinyTimeMixer, run the following commands from the repository root to pin each submodule to a tested version.\n```bash\n# For MOIRAI, we fix the version of the package for better performance\ncd submodules/uni2ts\ngit reset --hard fce6a6f57bc3bc1a57c7feb3abc6c7eb2f264301\ncd ../..\n\n# For TimesFM, fix the version for reproducibility (optional)\ncd submodules/timesfm\ngit reset --hard 5c7b905\ncd ../..\n\n# For Lag-Llama, fix the version for reproducibility (optional)\ncd submodules/lag_llama\ngit reset --hard 4ad82d9\ncd ../..\n\n# For TinyTimeMixer, fix the version for reproducibility (optional)\ncd submodules/tsfm\ngit reset --hard bb125c14a05e4231636d6b64f8951d5fe96da1dc\ncd ../..\n```\n\n**3. Download Model Checkpoints**\n\nDownload the necessary checkpoints (more details are available [here](../../../checkpoints/README.md)):\n```bash\nbash scripts/prepare_tsfm_checkpoints.sh\n```\nNote: By downloading, you agree to the original license terms. \n\n**4. Run Benchmarking:**\n\nReproduce the results reported in the ProbTS paper:\n\n```bash \nbash scripts/reproduce_tsfm_results.sh\n```\n\nConfiguration files are in [config/tsfm/](../../../config/tsfm/).\n\n\n**5. 
Experimental Results Analysis (Coming Soon)** :construction:\n\nAnalysis notebooks will be added in a future update.\n\n## Key Insights & Takeaways\n\n**1. Similar Insights in Evaluating Supervised Models Reconfirmed**\n\n- Handling **Varied Forecasting Horizons:** Current AR-based time-series foundation models also encounter error accumulation problems.\n- Addressing **Complex Data Distributions:** Predefined distribution heads lack the capability to fully capture complex data distributions.\n\n**2. Supervised Time-Series Models vs. Pre-trained Foundation Models**\n- There is no definitive winner yet!\n\n![tsfm_analysis](./figs/tsfm_analysis.jpg)\n\n**Takeaways:** \n- In practice, you may need to choose the right paradigm based on specific cases:\n  - Unique data patterns → supervised models\n  - Scarce training data → pre-trained models, etc.\n\n\n## Experimental Results\n\n### Comparison Across Horizons\n\n![tsfm_res](./figs/tsfm_results.jpg)\nFigure. We use a dashed line to denote the datasets on which the model was pre-trained, e.g., both TimesFM and MOIRAI have leveraged Traffic datasets for their pre-training. ETT denotes the average of the results on ETTh1, ETTh2, ETTm1, and ETTm2. \n\nTable 3. NMAE of time-series foundation models on diverse prediction horizons. The input sequence length is set to 96 if not specified. For every model, we exclude the evaluation results on its pre-trained datasets.\n\n![Comparison of Time-series Foundation Models on Diverse Prediction Horizons](./figs/fm_var_hor.jpg)\n\n### Short-term Probabilistic Forecasting\n\nTable 4. Results of probabilistic foundation models on short-term distributional forecasting. For every model, we exclude the evaluation results on its pre-trained datasets.\n\n![Comparison of Time-series Foundation Models on short-term scenarios](./figs/fm_short_term.jpg)\n"
  },
  {
    "path": "docs/benchmark/foundation_model/chronos.md",
    "content": "# Running Inference with Chronos\n\n[Original Repository](https://github.com/amazon-science/chronos-forecasting) | [Paper](https://arxiv.org/abs/2403.07815)\n\nFollow these steps to set up and run inference using Chronos:\n\n1. Set up the [environment](../README.md#results-reproduction).\n2. Run the inference script with the following commands:\n\n```bash\nMODEL='chronos'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do\n    for CTX_LEN in 5000 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --data.test_batch_size 1\n        done\n    done\ndone\n```\n\n\n## Hyper-param in Inference\n\n\n`Temperature` (default: 1): If Temperature=0, the output is consistent. The bigger the more diverse\n\n`top_k`(default: 50): Only conduct softmax for top-k logits.\n\n`top-p` (default: 1): Nucleus sampling. The model sums the probabilities of the most likely next value in descending order and stops when the sum reaches p.\n\n"
  },
  {
    "path": "docs/benchmark/foundation_model/forecastpfn.md",
    "content": "# Running Inference with ForecastPFN\n\n[Original Repository](https://github.com/abacusai/ForecastPFN) | [Paper](https://arxiv.org/abs/2311.01933)\n\nFollow these steps to set up and run inference using ForecastPFN:\n\n1. Set up the [environment](../README.md#results-reproduction).\n2. Run the inference script with the following commands:\n\n```bash\n# ForecastPFN\nMODEL='forecastpfn'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do\n    for CTX_LEN in 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --model.forecaster.init_args.ckpt_path './checkpoints/ForecastPFN/saved_weights' \\\n                --data.test_batch_size 64\n        done\n    done\ndone\n```\n"
  },
  {
    "path": "docs/benchmark/foundation_model/lag-llama.md",
    "content": "# Running Inference with Lag-Llama\n\n[Original Repository](https://github.com/time-series-foundation-models/lag-llama) | [Paper](https://arxiv.org/abs/2310.08278)\n\nFollow these steps to set up and run inference using Lag-Llama:\n\n1. Set up the [environment and initialize submodules](../README.md#results-reproduction).\n2. Run the inference script with the following commands:\n\n```bash\n# Lag-Llama\nMODEL='lag_llama'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do\n    for CTX_LEN in 512; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --model.forecaster.init_args.ckpt_path './checkpoints/lag-llama/lag-llama.ckpt' \\\n                --data.test_batch_size 1\n        done\n    done\ndone\n```\n"
  },
  {
    "path": "docs/benchmark/foundation_model/moirai.md",
    "content": "# Running Inference with MOIRAI\n\n[Original Repository](https://github.com/SalesforceAIResearch/uni2ts) | [Paper](https://arxiv.org/abs/2402.02592)\n\nFollow these steps to set up and run inference using MOIRAI:\n\n1. Set up the [environment and initialize submodules](../README.md#results-reproduction).\n2. Run the inference script with the following commands:\n\n```bash\nMODEL='moirai'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf' 'electricity_ltsf'; do\n    for CTX_LEN in 5000 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}/context_${CTX_LEN}/${DATASET}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN}\n        done\n    done\ndone\n```\n\n## Hyper-param in Inference\n\n`patch size` (default: `auto`): Specifies the patch size used during inference. When set to `auto`, the model selects the patch size that minimizes validation loss based on historical data.\n\n`variate_mode` (default: `S`): Determines whether the model operates in univariate (`S`) or multivariate mode (`M`) during inference."
  },
  {
    "path": "docs/benchmark/foundation_model/timer.md",
    "content": "# Running Inference with Timer\n\n[Original Repository](https://github.com/thuml/Large-Time-Series-Model) | [Paper](https://arxiv.org/abs/2402.02368)\n\nFollow these steps to set up and run inference using Timer:\n\n1. Set up the [environment](../README.md#results-reproduction).\n2. Run the inference script with the following commands:\n\n```bash\nMODEL='timer'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf' 'electricity_ltsf'; do\n    for CTX_LEN in 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --model.forecaster.init_args.ckpt_path './checkpoints/timer/Timer_67M_UTSD_4G.pt' \\\n                --data.test_batch_size 64\n        done\n    done\ndone\n```\n\n## Hyper-param in Inference\n\n`use_ims` (default: false): Evaluate decoder-only models in the Iterative Multi-step (IMS) way or encoder-only forecasters in Direct Multi-step (DMS) approach\n\n`sub_rand_ratio`: The ratio of training samples in few-shot scenarios."
  },
  {
    "path": "docs/benchmark/foundation_model/timesfm.md",
    "content": "# Running Inference with TimesFM\n\n[Original Repository](https://github.com/google-research/timesfm) | [Paper](https://arxiv.org/abs/2310.10688)\n\nFollow these steps to set up and run inference using TimesFM:\n\n1. Set up the [environment](../README.md#results-reproduction).\n2. Run the inference script with the following commands:\n\n```bash\nMODEL='timesfm'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2'; do\n    for CTX_LEN in 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --data.test_batch_size 64\n        done\n    done\ndone\n```\n\n## Hyper-param in Inference\n\n`frequency` (default: 0): Chose from {0, 1, 2}.\n\n\n- **0 (default):** High frequency, long horizon time series. We recommend using this for time series up to daily granularity.\n- **1:** Medium frequency time series. We recommend using this for weekly and monthly data.\n- **2:** Low frequency, short horizon time series. We recommend using this for anything beyond monthly, e.g., quarterly or yearly.\n\n\n`window size` (default: None):  Window size of trend + residual decomposition\n\n\n"
  },
  {
    "path": "docs/benchmark/foundation_model/ttm.md",
    "content": "# Running Inference with Tiny Time Mixers\n\n[Original Repository](https://github.com/ibm-granite/granite-tsfm/tree/main/tsfm_public/models/tinytimemixer) | [Paper](https://arxiv.org/abs/2401.03955)\n\nFollow these steps to set up and run inference using Tiny Time Mixers:\n\n1. Set up the [environment and initialize submodules](../README.md#results-reproduction).\n2. Run the inference script with the following commands:\n\n```bash\nMODEL='tinytimemixer'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do\n    for CTX_LEN in 5000 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --data.test_batch_size 1\n        done\n    done\ndone\n```\n"
  },
  {
    "path": "docs/benchmark/foundation_model/units.md",
    "content": "# Running Inference with UniTS\n\n[Original Repository](https://github.com/mims-harvard/UniTS) | [Paper](https://arxiv.org/pdf/2403.00131)\n\nFollow these steps to set up and run inference using UniTS:\n\n1. Set up the [environment](../README.md#results-reproduction).\n2. Run the inference script with the following commands:\n\n```bash\nMODEL='units'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2'; do\n    for CTX_LEN in 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --model.forecaster.init_args.ckpt_path './checkpoints/units/units_x128_pretrain_checkpoint.pth' \\\n                --data.test_batch_size 64\n        done\n    done\ndone\n```\n"
  },
  {
    "path": "docs/benchmark/supervised_model/README.md",
    "content": "# Supervised Forecasting Models Benchmarking\n\n- [Supervised Forecasting Models Benchmarking](#supervised-forecasting-models-benchmarking)\n  - [Experimental Results Reproduction](#experimental-results-reproduction)\n  - [Key Insights \\& Takeaways](#key-insights--takeaways)\n    - [Point vs. Probabilistic Estimation](#point-vs-probabilistic-estimation)\n    - [Autoregressive vs. Non-autoregressive Decoding Scheme](#autoregressive-vs-non-autoregressive-decoding-scheme)\n    - [Instance-level Normalization Choice](#instance-level-normalization-choice)\n  - [Experimental Result Details](#experimental-result-details)\n\n\n\n## Experimental Results Reproduction\n\nReproduce the experimental results using the provided scripts:\n\n- **Long-Term Forecasting:**\n\n```bash \nbash scripts/reproduce_ltsf_results.sh\n```\nConfiguration files: [config/ltsf/](../../../config/ltsf/).\n\n- **Short-Term Forecasting:**\n\n```bash \nbash scripts/reproduce_stsf_results.sh\n```\n\nConfiguration files: [config/stsf/](../../../config/stsf/).\n\n\n## Key Insights & Takeaways\n\n### Point vs. Probabilistic Estimation\n\n**Insights**\n\n- Current supervised long-term point forecasting models (e.g., DLinear, PatchTST, iTransformer) **struggle with intricate data distributions**.\n- Current supervised short-term probabilistic forecasting models (e.g., GRU NVP, TimeGrad, CSDI) **face challenges in extended forecasting horizons**.\n\n\n![point_vs_prob](./figs/point_vs_prob.jpg)\n\n**Takeaways**\n- It is important to consider both long-term and short-term evaluation scenarios.\n- Leverage both point and distributional metrics for more comprehensive insights.\n\n\n\n### Autoregressive vs. 
Non-autoregressive Decoding Scheme\n\n**Insights**\n\n- Current Supervised Non-Autoregressive (NAR) Models (e.g., PatchTST, iTransformer, CSDI)\n  - Primarily developed for long-term forecasting scenarios.\n  - **Suboptimal for short-term forecasting, and some models are memory-intensive.**\n- Current Supervised Autoregressive (AR) Models (e.g., GRU, GRU NVP, TimeGrad)\n  - Primarily developed for short-term forecasting scenarios\n  - **Perform well with strong seasonality but struggle with long-term, strong trends**\n\n![ar_vs_nar](./figs/ar_vs_nar.jpg)\n\n**Takeaways**\n\n- It is crucial to select the right **methodological design** based on the specific **data characteristics**.\n- There are tremendous **re-design opportunities**, given the **comprehensive forecasting needs**.\n\n\n### Instance-level Normalization Choice\n\n**Insights**\n\n- Reversible Instance Normalization (RevIN): Essential for Long-term Forecasting Scenarios\n  - Our observation: **AR models in the literature are scarce for long-term forecasting**\n  - Our finding: RevIN + AR => **A simple yet highly effective baseline that has been overlooked**\n- Normalization Choices under Short-term Forecasting Scenarios\n  - **No dominating normalization strategies**\n\n\n![norm](./figs/norm.jpg)\n\n**Takeaways**\n\n- The **co-design** of **normalization** techniques and **model** architectures warrants further research attention.\n- The **challenges and opportunities** in time-series normalization persist in balancing short-term and long-term forecasting needs.\n\n\n\n## Experimental Result Details\n\n\n**Long-Term Forecasting Benchmarking**\n\n\n\n\nTable 1. Results ($\\textrm{mean}_{\\textrm{std}}$) on long-term forecasting scenarios with the best in $\\textbf{bold}$ and the second $\\underline{\\textrm{underlined}}$, each containing five independent runs with different seeds. The input sequence length is set to 36 for the ILI-L dataset and 96 for the others. 
Due to the excessive time and memory consumption of CSDI in producing long-term forecasts, its results are unavailable for some datasets.\n\n![long-term forecasting experimental results](./figs/long_bench.jpg)\n\n\n\n**Short-Term Forecasting Benchmarking**\n\n\n\nTable 2. Results ($\\textrm{mean}_{\\textrm{std}}$) on short-term forecasting scenarios with the best in $\\textbf{bold}$ and the second $\\underline{\\textrm{underlined}}$, each containing five independent runs with different seeds.\n\n![short-term forecasting experimental results](./figs/short_bench.jpg)\n\n\n"
  },
  {
    "path": "docs/documentation/Gift_eval.md",
    "content": "\n## How to evaluate the models in ProbTS using the GIFT-EVAL benchmark\n\nLink to the GIFT-EVAL benchmark: [Github Repo](https://github.com/SalesforceAIResearch/gift-eval) [Paper](https://openreview.net/forum?id=9EBSEkFSje)\n\n1. Follow installation instructions in the GIFT-EVAL repository to **download the dataset** from its huggingface dataset repository.\n2. Also, set the environment variable `GIFT_EVAL` to the path where the dataset is downloaded.\n``` bash\necho \"GIFT_EVAL=/path/to/gift-eval\" >> .env\n```\n3. Quick start example:\n``` bash\npython run.py --config config/default/mean.yaml \\\n              --seed_everything 0 \\\n              --model.forecaster.init_args.mode batch \\\n              --data.data_manager.init_args.dataset gift/ett1/H/long \\\n              --data.data_manager.init_args.path ./datasets \\\n              --trainer.default_root_dir ./exps\n```\n\n> [!NOTE]  \n> The dataset name for the GIFT-EVAL format should be specified as follows: `\"gift/\" + \"dataset_name (main_name/freq)\" + \"short/medium/long\"`. For example, `gift/ett1/H/long`. More dataset names can be found in the GIFT-EVAL repository (for example [naive.ipynb](https://github.com/SalesforceAIResearch/gift-eval/blob/main/notebooks/naive.ipynb)).\n"
  },
  {
    "path": "docs/documentation/README.md",
    "content": "# Documentation :open_book:\n\n- [Documentation :open\\_book:](#documentation-open_book)\n  - [Setup](#setup)\n  - [Configuration Parameters](#configuration-parameters)\n    - [Trainer](#trainer)\n    - [Model](#model)\n    - [Data](#data)\n  - [Datasets](#datasets)\n    - [Datasets Overview](#datasets-overview)\n      - [Short-Term Setting](#short-term-setting)\n      - [Long-Term Setting](#long-term-setting)\n    - [Data Processing Pipeline](#data-processing-pipeline)\n    - [Using Build-in Datasets](#using-build-in-datasets)\n    - [Using Customized Dataset](#using-customized-dataset)\n  - [Model](#model-1)\n    - [Available Models](#available-models)\n    - [Using Customized Model](#using-customized-model)\n  - [Training](#training)\n    - [Configuring Optimizers and Learning Rate Schedulers](#configuring-optimizers-and-learning-rate-schedulers)\n  - [Forecasting with Varied Prediction Lengths](#forecasting-with-varied-prediction-lengths)\n    - [Example 1: Varied-Horizon Training](#example-1-varied-horizon-training)\n    - [Example 2: Validation and Testing with Multiple Horizons](#example-2-validation-and-testing-with-multiple-horizons)\n\n\n## Setup\n\nProbTS is developed with Python 3.10 and relies on [PyTorch Lightning](https://github.com/Lightning-AI/lightning). 
To set up the environment:\n\n```bash\n# Create a new conda environment\nconda create -n probts python=3.10\nconda activate probts\n\n# Install required packages\npip install .\npip uninstall -y probts # recommended to uninstall the root package (optional)\n```\n\n[Optional] For time-series foundation models, you need to install basic packages and additional dependencies:\n\n```bash\n# Create a new conda environment\nconda create -n probts_fm python=3.10\nconda activate probts_fm\n\n# Git submodule\ngit submodule update --init --recursive\n\n# Install additional packages for foundation models\npip install \".[tsfm]\"\npip uninstall -y probts # recommended to uninstall the root package (optional)\n\n# For MOIRAI, we fix the version of the package for better performance\n# (git -C runs the command inside the submodule without changing the current directory)\ngit -C submodules/uni2ts reset --hard fce6a6f57bc3bc1a57c7feb3abc6c7eb2f264301\n```\n\n<details>\n\n<summary>Optional: pin TSFM versions for reproducibility</summary>\n\n```bash\n# For TimesFM, fix the version for reproducibility (optional)\ngit -C submodules/timesfm reset --hard 5c7b905\n\n# For Lag-Llama, fix the version for reproducibility (optional)\ngit -C submodules/lag_llama reset --hard 4ad82d9\n\n# For TinyTimeMixer, fix the version for reproducibility (optional)\ngit -C submodules/tsfm reset --hard bb125c14a05e4231636d6b64f8951d5fe96da1dc\n```\n\n</details>\n\n\n## Configuration Parameters \n\n- To print the full pipeline configuration to a file:\n\n    ```bash\n    python run.py --print_config > config/pipeline_config.yaml\n    ```\n\n### Trainer\n\n| Config Name | Type | Description |\n| --- | --- | --- |\n| `trainer.max_epochs` | `int` | Maximum number of training epochs. |\n| `trainer.limit_train_batches` | `int` | Limits the number of training batches per epoch. |\n| `trainer.check_val_every_n_epoch` | `int` | Perform validation every n training epochs. |\n| `trainer.default_root_dir` | `str` | Default path for logs and weights. 
|\n| `trainer.accumulate_grad_batches` | `int` | Number of batches to accumulate gradients before updating. |\n\n### Model\n\n| Config Name | Type | Description |\n| --- | --- | --- |\n| `model.forecaster.class_path` | `str` | Forecaster module path (e.g., `probts.model.forecaster.point_forecaster.PatchTST`). |\n| `model.forecaster.init_args.{ARG}` | - | Model-specific hyperparameters. |\n| `model.num_samples` | `int` | Number of samples per distribution during evaluation. |\n| `model.learning_rate` | `float` | Learning rate. |\n| `model.quantiles_num` | `int` | Number of quantiles for evaluation. |\n| `model.sampling_weight_scheme` | `str`  | The scheme of training horizon reweighting. Options: ['random', 'none', 'const'].|\n| `model.optimizer_config.class_name` | `str` | Optimizer class (e.g., `torch.optim.Adam`). |\n| `model.optimizer_config.init_args.{ARG}` | - | Optimizer hyperparameters. |\n| `model.lr_scheduler_config.class_name` | `str` | Learning rate scheduler class (e.g., `torch.optim.lr_scheduler.OneCycleLR`). |\n| `model.lr_scheduler_config.init_args.{ARG}` | - | Learning rate scheduler hyperparameters. |\n\n### Data\n\n| Config Name | Type | Description |\n| --- | --- | --- |\n| `data.data_manager.init_args.dataset` | `str` | Dataset for training and evaluation. |\n| `data.data_manager.init_args.path` | `str` | Path to the dataset folder. |\n| `data.data_manager.init_args.split_val` | `bool` | Whether to split a validation set during training. |\n| `data.data_manager.init_args.scaler` | `str` | Scaler type: `identity`, `standard` (z-score normalization), or `temporal` (scale based on average temporal absolute value). |\n| `data.data_manager.init_args.target_dim` | `int` | The number of variates. |\n| `data.data_manager.init_args.var_specific_norm` | `bool` | Whether to apply per-variate normalization. |\n| `data.data_manager.init_args.timeenc` | `int` | Time feature type. Select from `[0,1,2]`. See the explanation below for details. 
|\n| `data.data_manager.init_args.context_length`    | `Union[str, int, list]`       | Length of observation window in inference phase. |\n| `data.data_manager.init_args.prediction_length` | `Union[str, int, list]`       | Forecasting horizon length in inference phase. |\n| `data.data_manager.init_args.val_pred_len_list` | `Union[str, int, list]`       | Forecasting horizon length for performance validation. |\n| `data.data_manager.init_args.val_ctx_len`       | `Union[str, int, list]`      | Length of observation window for performance validation. |\n| `data.data_manager.init_args.train_pred_len_list`| `Union[str, int, list]`      | Forecasting horizons in training phase. |\n| `data.data_manager.init_args.train_ctx_len` | `Union[str, int, list]`      | Length of observation window in training phase. |\n| `data.data_manager.init_args.continuous_sample`  | `bool` | If true, sample horizons from the range `[min(train_pred_len_list), max(train_pred_len_list)]`; otherwise, sample from the set `train_pred_len_list`.|\n| `data.data_manager.init_args.test_rolling_length`  | `Union[int, str]` | Defines the gap window for rolling evaluations during testing. Defaults to `96` if not explicitly specified. If set to `auto`, the value is determined based on the dataset frequency: `{'h': 24, 'd': 7, 'b': 5, 'w': 4, 'min': 60}`. |\n| `data.data_manager.init_args.train_ratio`  | `float` | Specifies proportion of the dataset used for training. Default value is 0.7.|\n| `data.data_manager.init_args.test_ratio`  | `float` | Specifies proportion of the dataset used for testing. Default value is 0.2.|\n| `data.batch_size` | `int` | Batch size. 
|\n\n**Temporal Features**\n\nFor datasets used in the long-term forecasting scenario, we support three types of time feature encoding:\n\n```bash\n--data.data_manager.init_args.timeenc {the encoding type} # select from [0,1,2]\n```\n\n- **[timeenc 0] temporal information**\n\n    The dimension of time feature is 5, containing `month, day, weekday, hour, minute`.\n\n- **[timeenc 1] time feature based on frequency**\n\n    Extract time features using the `time_features_from_frequency_str()` function. The dimensionality follows:\n    ```python\n    freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}\n    ```\n\n    *Note: timeenc = 0 if model.embed != 'timeF' else 1.*\n\n- **[timeenc 2] Raw date information**\n\n    The dimension of time feature is 5. Use the following code to convert it back to a datetime type:\n    ```python\n    data_stamp = batch_data.past_time_feat.cpu().numpy().astype('datetime64[s]')\n    data_stamp = batch_data.future_time_feat.cpu().numpy().astype('datetime64[s]')\n    ```\n\n## Datasets\n\n### Datasets Overview\n\n\n#### Short-Term Setting\n\n| Dataset | DATASET_NAME | Domain | Frequency | #Var | time steps | Description |\n| --- | --- | --- | --- | --- | --- | --- |\n| Exchange | `exchange_rate_nips` | Finance | Busi. 
Day | 8 | 6,071  | Daily exchange rates of 8 countries |\n| Solar | `solar_nips` | Energy | H | 137 | 7,009 | Solar power production records |\n| Electricity | `electricity_nips` | Energy | H | 370 | 5,833  | Electricity consumption |\n| Traffic | `traffic_nips` | Transport | H | 963 | 4,001  | Road occupancy rates |\n| Wikipedia | `wiki2000_nips` | Web | D | 2,000 | 792 | Page views of 2000 Wikipedia pages |\n\n#### Long-Term Setting\n\n| Dataset | DATASET_NAME | Domain | Frequency | #Var | time steps | Description |\n| --- | --- | --- | --- | --- | --- | --- |\n| ETTh | `etth1` / `etth2` | Energy | H | 7 | 17,420 | Electricity transformer temperature per hour |\n| ETTm | `ettm1` / `ettm2` | Energy | 15min | 7 | 69,680  | Electricity transformer temperature every 15 min |\n| Electricity | `electricity_ltsf` | Energy | H | 321  | 26,304  | Electricity consumption (kWh) |\n| Weather | `weather_ltsf` | Climate | 10min | 21 | 52,696  | Local climatological data |\n| Traffic  | `traffic_ltsf` | Transport | H  | 862 | 17,544  | Road occupancy rates |\n| Exchange | `exchange_ltsf` | Finance | Busi. Day | 8 | 7,588 | Daily exchange rates of 8 countries |\n| ILI  | `illness_ltsf` | Epidemiology | W | 7 | 966 | Ratio of patients seen with influenza-like illness |\n| Caiso | `caiso` | Energy | H | 10 | 74,472  | Electricity load series in different zones of California |\n| Nordpool | `nordpool` | Energy | H | 18 | 70,128  | Energy production volume in European countries |\n| Turkey Power | `turkey_power` | Energy | H | 18 | 26,304 | Electrical power demand in Turkey |\n| Istanbul Traffic | `istanbul_traffic` | Transport | H | 3 | 14,244 | Traffic Index data for Istanbul traffic |\n\n\n### Data Processing Pipeline\n\n<div align=center> <img src=\"../figs/data_pipeline.png\" width = 95%/> </div>\n\n### Using Build-in Datasets\n\n- **Short-Term Forecasting**: We use datasets from [GluonTS](https://github.com/awslabs/gluonts). 
\n    Configure the datasets using `--data.data_manager.init_args.dataset {DATASET_NAME}` with available `DATASET_NAME` in [short-term setting](#short-term-setting).\n\n- **Long-Term Forecasting**: To download the [long-term forecasting datasets](https://drive.google.com/drive/folders/1ZOYpTUa82_jCcxIdTmyr0LXQfvaM9vIy), please follow these steps:\n    ```bash\n    bash scripts/prepare_datasets.sh \"./datasets\"\n    ```\n\n    Configure the datasets using `--data.data_manager.init_args.dataset {DATASET_NAME}` with available `DATASET_NAME` in [long-term setting](#long-term-setting).\n\n    *Note: When utilizing long-term forecasting datasets, you must explicitly specify the `context_length` and `prediction_length` parameters. For example, to set a context length of 96 and a prediction length of 192, use the following command-line arguments:*\n    ```bash\n    --data.data_manager.init_args.context_length 96 \\\n    --data.data_manager.init_args.prediction_length 192 \\\n    ```\n\n- **Using Datasets from Monash Time Series Forecasting Repository**: To use datasets from the [Monash Time Series Forecasting Repository](https://forecastingdata.org/), follow these steps:\n\n    1. **Download the Dataset**: \n    - Navigate to the target dataset, such as the [Electricity Hourly Dataset](https://zenodo.org/records/4656140).\n    - Download the `.tsf` file and place it in your local `datasets` directory (e.g., `./datasets`).\n\n    1. 
**Configure the Dataset**:\n    - Use the following configuration to specify the dataset, file path, and frequency:\n        ```bash\n        --data.data_manager.init_args.dataset {DATASET_NAME} \\\n        --data.data_manager.init_args.data_path /path/to/data.tsf \\\n        --data.data_manager.init_args.freq {FREQ} \n        ```\n\n    - **Example Configuration**:\n        ```bash\n        --data.data_manager.init_args.dataset monash_electricity_hourly \\\n        --data.data_manager.init_args.data_path ./datasets/electricity_hourly_dataset.tsf \\\n        --data.data_manager.init_args.freq H \\\n        --data.data_manager.init_args.context_length 96 \\\n        --data.data_manager.init_args.prediction_length 96 \\\n        --data.data_manager.init_args.multivariate true\n        ```\n\n    *Note: Refer to the [Pandas Time Series Offset Aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-offset-aliases) for the correct frequency values (`{FREQ}`) to use in your configuration.*\n\n\n- **Using Datasets from the GIFT-EVAL Benchmark**: see [this page](./Gift_eval.md) for detailed instructions.\n\n\n### Using Customized Dataset\n\n1. **Prepare the Data**: \n\n- Format your dataset as a `.csv` file with the following structure:\n\n  | date                | VAR1   | VAR2   | ... |\n  |---------------------|--------|--------|-----|\n  | 2013-01-01 00:00:00 | 2611.0 | 1539.0 | ... |\n  | 2013-01-01 01:00:00 | 2132.0 | 1535.0 | ... |\n\n  Note1: The date column represents timestamps.\n\n  Note2: VAR1, VAR2, etc., represent different variables (features) for each timestamp.\n\n- Place the csv file in your local `datasets` directory (e.g., `./datasets`).\n\n1. 
**Configure the Dataset**:\n- Use the following configuration to specify the dataset, file path, and frequency:\n   ```bash\n   --data.data_manager.init_args.dataset {DATASET_NAME} \\\n   --data.data_manager.init_args.data_path /path/to/data_file.csv \\\n   --data.data_manager.init_args.freq {FREQ} \n   ```\n\n- **Example Configuration**:\n   ```bash\n   --data.data_manager.init_args.dataset my_data \\\n   --data.data_manager.init_args.data_path ./datasets/my_data.csv \\\n   --data.data_manager.init_args.freq H \\\n   --data.data_manager.init_args.context_length 96 \\\n   --data.data_manager.init_args.prediction_length 96 \\\n   --data.data_manager.init_args.multivariate true\n   ```\n\n*Note: You can adjust the test instance sampling using the `--data.data_manager.init_args.test_rolling_length` parameter.*\n\n\n\n## Model\n\n### Available Models\n\nProbTS includes both classical time-series models, specializing in long-term point forecasting or short-term distributional forecasting, and recent time-series foundation models that offer zero-shot and arbitrary-horizon forecasting capabilities for new time series.\n\n**Classical Time-series Models**\n\n| **Model** | **Original Eval. 
Horizon** | **Estimation** | **Decoding Scheme** | **Class Path** |\n| --- | --- | --- | --- | --- |\n| Linear | - | Point | Auto / Non-auto | `probts.model.forecaster.point_forecaster.LinearForecaster` |\n| [GRU](https://arxiv.org/abs/1412.3555) | - | Point | Auto / Non-auto | `probts.model.forecaster.point_forecaster.GRUForecaster` |\n| [Transformer](https://arxiv.org/abs/1706.03762) | - | Point | Auto / Non-auto | `probts.model.forecaster.point_forecaster.TransformerForecaster` |\n| [Autoformer](https://arxiv.org/abs/2106.13008) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.Autoformer` |\n| [N-HiTS](https://arxiv.org/abs/2201.12886) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.NHiTS` |\n| [NLinear](https://arxiv.org/abs/2205.13504) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.NLinear` |\n| [DLinear](https://arxiv.org/abs/2205.13504) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.DLinear` |\n| [TSMixer](https://arxiv.org/abs/2303.06053) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.TSMixer` |\n| [TimesNet](https://arxiv.org/abs/2210.02186) | Short- / Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.TimesNet` |\n| [PatchTST](https://arxiv.org/abs/2211.14730) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.PatchTST` |\n| [iTransformer](https://arxiv.org/abs/2310.06625) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.iTransformer` |\n| [ElasTST](https://arxiv.org/abs/2411.01842) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.ElasTST` |\n| [GRU NVP](https://arxiv.org/abs/2002.06103) | Short-term | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.GRU_NVP` |\n| [GRU MAF](https://arxiv.org/abs/2002.06103) | Short-term | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.GRU_MAF` |\n| [Trans 
MAF](https://arxiv.org/abs/2002.06103) | Short-term | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.Trans_MAF` |\n| [TimeGrad](https://arxiv.org/abs/2101.12072) | Short-term | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.TimeGrad` |\n| [CSDI](https://arxiv.org/abs/2107.03502) | Short-term | Probabilistic | Non-auto | `probts.model.forecaster.prob_forecaster.CSDI` |\n| [TSDiff](https://arxiv.org/abs/2307.11494) | Short-term | Probabilistic | Non-auto | `probts.model.forecaster.prob_forecaster.TSDiffCond` |\n\n**Foundation Models**\n\n| **Model** | **Any Horizon** | **Estimation** | **Decoding Scheme** | **Class Path** | **Model Size** | \n| --- | --- | --- | --- | --- | --- |\n| [Lag-Llama](https://arxiv.org/abs/2310.08278) | &#x2714; | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.LagLlama` | - |\n| [ForecastPFN](https://arxiv.org/abs/2311.01933) | &#x2714; | Point | NAR | `probts.model.forecaster.point_forecaster.ForecastPFN` | - |\n| [TimesFM](https://arxiv.org/abs/2310.10688) | &#x2714; | Point | AR | `probts.model.forecaster.point_forecaster.TimesFM` | `200m`, `500m` |\n| [TTM](https://arxiv.org/abs/2401.03955) | &#x2718; | Point | NAR | `probts.model.forecaster.point_forecaster.TinyTimeMixer` | - |\n| [Timer](https://arxiv.org/abs/2402.02368) | &#x2714; | Point | AR | `probts.model.forecaster.point_forecaster.Timer` | - |\n| [MOIRAI](https://arxiv.org/abs/2402.02592) | &#x2714; | Probabilistic | NAR | `probts.model.forecaster.prob_forecaster.Moirai` | `small`, `base`, `large` |\n| [UniTS](https://arxiv.org/abs/2403.00131) | &#x2714; | Point | NAR | `probts.model.forecaster.point_forecaster.UniTS` | - |\n| [Chronos](https://arxiv.org/abs/2403.07815) | &#x2714; | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.Chronos` | `tiny`, `mini`, `small`, `base`, `large` |\n| [Time-MoE](https://arxiv.org/abs/2409.16040) | &#x2714; | Point | AR | `probts.model.forecaster.point_forecaster.TimeMoE` | `50M`, 
`200M` |\n\nSee the [tsfm configuration directory](../../config/tsfm/) for more details. More models will be added soon—stay tuned!\n\n\n\n### Using Customized Model\n\nWith our platform, you can easily evaluate customized models across various datasets. Follow the steps below to create and evaluate your model.\n\n\n**Step 1: Create a New Python File**\n\nCreate a new Python file and follow the structure below to define your custom model:\n\n```python\nfrom probts.model.forecaster import Forecaster\n\nclass ModelName(Forecaster):\n    def __init__(\n        self,\n        **kwargs\n    ):\n        \"\"\"\n        Initialize the model with parameters.\n        \"\"\"\n        super().__init__(**kwargs)\n        # Initialize model parameters here\n\n    def forward(self, inputs):\n        \"\"\"\n        Forward pass for the model.\n\n        Parameters:\n        inputs [Tensor]: Input tensor for the model.\n\n        Returns:\n        Tensor: Output tensor.\n        \"\"\"\n        # Perform the forward pass of the model\n        return outputs\n\n    def loss(self, batch_data):\n        \"\"\"\n        Compute the loss for the given batch data.\n\n        Parameters:\n        batch_data [dict]: Dictionary containing input data and possibly target data.\n\n        Returns:\n        Tensor: Computed loss.\n        \"\"\"\n        # Extract inputs and targets from batch_data\n        inputs = batch_data.past_target_cdf[:, -self.context_length:, :] # [batch_size, context_length, var_num]\n        target = batch_data.future_target_cdf # [batch_size, prediction_length, var_num]\n\n        # Forward pass\n        outputs = self.forward(inputs)\n        \n        # Calculate loss using a loss function, e.g., Mean Squared Error\n        loss = self.loss_function(outputs, target)\n\n        return loss\n\n    def forecast(self, batch_data, num_samples=None):\n        \"\"\"\n        Generate forecasts for the given batch data.\n\n        Parameters:\n        
batch_data [dict]: Dictionary containing input data.\n        num_samples [int, optional]: Number of samples per distribution during evaluation. Defaults to None.\n\n        Returns:\n        Tensor: Forecasted outputs.\n        \"\"\"\n        # Perform the forward pass to get the outputs\n        outputs = self(batch_data.past_target_cdf[:, -self.context_length:, :])\n\n        if num_samples is not None:\n            # If num_samples is specified, use it to sample from the distribution\n            outputs = self.sample_from_distribution(outputs, num_samples)\n        else: \n            # If perform point estimation, the num_samples is equal to 1\n            outputs = outputs.unsqueeze(1)\n        return outputs # [batch_size, num_samples, prediction_length, var_num]\n```\n\n  **Input Data Format**\n\n  The `batch_data` dictionary contains several fields that provide necessary information for the model's operation. Each field is described below:\n\n  - **`target_dimension_indicator`**: \n    - **Shape**: [var_num]\n    - **Description**: Indicator that specifies which dimension or feature of the target is being referenced. \n\n  - **`{past|future}_time_feat`**: \n    - **Shape**: [batch_size,length,time_feature_dim]\n    - **Description**: Time features associated with each time step in the past or future. This can include various time-related information such as timestamps, seasonal indicators (e.g., month, day of the week), or other temporal features that provide context to the observations.\n  - **`{past|future}_target_cdf`**: \n    - **Shape**: [batch_size,length,var_num]\n    - **Description**: The observation values of the target variable(s) for past or future time steps. \n  - **`{past|future}_observed_values`**: \n    - **Shape**: [batch_size,length,var_num]\n    - **Description**: Binary masks indicating which values in the past or future target data are observed (1) and which are missing or unobserved (0). 
\n\n**Step 2: Create YAML Configuration File**\n\nCreate a YAML configuration file (`model.yaml`) for the customized model:\n\n```yaml\nseed_everything: 1 # random seed\ntrainer:\n  accelerator: gpu\n  devices: 1\n  strategy: auto\n  max_epochs: 50\n  use_distributed_sampler: false\n  limit_train_batches: 100\n  log_every_n_steps: 1\n  default_root_dir: ./results # path to the log folder\nmodel:\n  forecaster:\n    class_path: class.path.to.ModelName\n    init_args:\n      # init your hyperparameter here\n  learning_rate: 0.001 # learning rate\ndata:\n  data_manager:\n    class_path: probts.data.data_manager.DataManager\n    init_args:\n      dataset: solar_nips # dataset name\n      split_val: true\n      scaler: standard # identity, standard, temporal\n  batch_size: 32\n  test_batch_size: 32\n  num_workers: 8\n```\n\n**Step 3: Run the Customized Model**\n\nRun the customized model using the configuration file:\n\n```bash\npython run.py --config config/path/to/model.yaml\n```\n\n\n## Training\n\n\n### Configuring Optimizers and Learning Rate Schedulers\n\nProbTS supports customizable optimizers and learning rate schedulers. 
You can specify them directly in the YAML configuration file.\n\n**Example Configuration**\n```yaml \nmodel:\n  forecaster:\n    class_path: probts.model.forecaster.point_forecaster.PatchTST\n    init_args:\n      # Add forecaster-specific parameters here\n\n  optimizer_config:\n    class_name: torch.optim.Adam\n    init_args:\n      weight_decay: 0  # Add optimizer-specific parameters here\n\n  lr_scheduler_config:\n    class_name: torch.optim.lr_scheduler.OneCycleLR\n    init_args:\n      max_lr: 0.0001\n      steps_per_epoch: 100\n      pct_start: 0.3\n      epochs: 50  # Add scheduler-specific parameters here\n```\n\nExample configurations can be found in [config/default/patchtst.yaml](../../config/default/patchtst.yaml).\n\n**Notes**\n\n- If no configuration is provided, ProbTS defaults to the Adam optimizer with a constant learning rate.\n- Adjust `init_args` for both the optimizer and scheduler to suit your specific use case.\n\n\n## Forecasting with Varied Prediction Lengths\n\n\n**Example:**\n```bash \npython run.py --config config/multi_hor/elastst.yaml \\\n                --data.data_manager.init_args.path ./datasets \\\n                --trainer.default_root_dir /path/to/log_dir/ \\\n                --data.data_manager.init_args.dataset ${DATASET_NAME} \\\n                --data.data_manager.init_args.context_length ${TEST_CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${TEST_PRED_LEN} \\\n                --data.data_manager.init_args.train_ctx_len ${TRAIN_CTX_LEN} \\\n                --data.data_manager.init_args.train_pred_len_list ${TRAIN_PRED_LEN} \\\n                --data.data_manager.init_args.val_ctx_len ${VAL_CTX_LEN} \\\n                --data.data_manager.init_args.val_pred_len_list ${VAL_PRED_LEN} \n```\n\n- `DATASET_NAME`: Select from datasets used in long-term forecasting scenarios.\n- `TEST_CTX_LEN`: Context length in the testing phase.\n- `VAL_CTX_LEN` (Default: `TEST_CTX_LEN`): Context length in the validation 
phase.\n- `TRAIN_CTX_LEN` (Default: `TEST_CTX_LEN`): Context length in the training phase.\n- `TEST_PRED_LEN`: Forecasting horizons in the testing phase.\n- `VAL_PRED_LEN` (Default: `TEST_PRED_LEN`): Forecasting horizons for performance validation.\n- `TRAIN_PRED_LEN` (Default: `TEST_PRED_LEN`): Forecasting horizons in the training phase.\n\nThe results across multiple horizons will be saved to: \n```bash \n/path/to/log_dir/{DATASET_NAME}_{MODEL}_{seed}_TrainCTX_{TRAIN_CTX_LEN}_TrainPRED_{TRAIN_PRED_LEN}_ValCTX_{VAL_CTX_LEN}_ValPRED_{VAL_PRED_LEN}/horizons_results.csv\n```\n\n### Example 1: Varied-Horizon Training\n\n**Mode 1: Random sampling from a set of horizons**\n\n```bash \n# train_pred_len_list: random selection from {96, 192, 336, 720}\npython run.py --config config/multi_hor/elastst.yaml \\\n                --data.data_manager.init_args.path ./datasets \\\n                --trainer.default_root_dir /path/to/log_dir/ \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length 96 \\\n                --data.data_manager.init_args.prediction_length 720 \\\n                --data.data_manager.init_args.train_ctx_len 96 \\\n                --data.data_manager.init_args.val_pred_len_list 720 \\\n                --data.data_manager.init_args.train_pred_len_list 96-192-336-720 \\\n                --data.data_manager.init_args.continuous_sample false \n```\n\n**Mode 2: Random sampling from a horizon range**\n\n```bash \n# train_pred_len_list: random sampling from [1, 720]\npython run.py --config config/multi_hor/elastst.yaml \\\n                --data.data_manager.init_args.path ./datasets \\\n                --trainer.default_root_dir /path/to/log_dir/ \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length 96 \\\n                --data.data_manager.init_args.prediction_length 720 \\\n                --data.data_manager.init_args.train_ctx_len 96 \\\n                
--data.data_manager.init_args.val_pred_len_list 720 \\\n                --data.data_manager.init_args.train_pred_len_list 1-720 \\\n                --data.data_manager.init_args.continuous_sample true \n```\n\n### Example 2: Validation and Testing with Multiple Horizons\n\n```bash \n# validation on {96, 192, 336, 720}\n# testing on {24, 96, 192, 336, 720, 1024}\npython run.py --config config/multi_hor/elastst.yaml \\\n                --data.data_manager.init_args.path ./datasets \\\n                --trainer.default_root_dir /path/to/log_dir/ \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length 96 \\\n                --data.data_manager.init_args.train_pred_len_list 720 \\\n                --data.data_manager.init_args.train_ctx_len 96 \\\n                --data.data_manager.init_args.val_pred_len_list 96-192-336-720 \\\n                --data.data_manager.init_args.prediction_length 24-96-192-336-720-1024 \n```\n"
  },
  {
    "path": "exps/.gitignore",
    "content": "*\n!.gitignore"
  },
  {
    "path": "notebook/data_characteristics.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pathlib import Path\\n\",\n    \"from gluonts.dataset.repository.datasets import get_dataset\\n\",\n    \"from gluonts.dataset.multivariate_grouper import MultivariateGrouper\\n\",\n    \"import pandas as pd\\n\",\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"data_path = 'path/to/datasets/'\\n\",\n    \"save_path = Path(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Decomposition\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from statsmodels.tsa.seasonal import STL\\n\",\n    \"from tqdm import trange\\n\",\n    \"\\n\",\n    \"def measure_strength(df, dataset, win=0):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Measures the strength of trend (F_t) and seasonality (F_s) in time series data.\\n\",\n    \"\\n\",\n    \"    Parameters:\\n\",\n    \"    - df (pd.DataFrame): The input data containing time series columns.\\n\",\n    \"    - dataset (str): The name of the dataset to identify frequency or specific configurations.\\n\",\n    \"    - win (int): Window size for decomposition; if 0, applies decomposition on the full time series.\\n\",\n    \"\\n\",\n    \"    Outputs:\\n\",\n    \"    Prints the average strength of trend and seasonality for the dataset.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    # Decompose the time series for each dimension\\n\",\n    \"    dim_list = ts_decompose(df, dataset, win=win)\\n\",\n    \"    \\n\",\n    \"    F_t_list = []  # List to store trend strength values\\n\",\n    \"    F_s_list = []  # List to store seasonality strength values\\n\",\n    \"    \\n\",\n    \"    for res in dim_list:\\n\",\n    \"        # Skip calculations if variance of the decomposed components is zero\\n\",\n    \" 
       if (res.trend + res.resid).var() == 0 or (res.seasonal + res.resid).var() == 0:\\n\",\n    \"            continue\\n\",\n    \"        \\n\",\n    \"        # Calculate trend strength (F_t)\\n\",\n    \"        F_t = max(0, 1 - (res.resid.var() / (res.trend + res.resid).var()))\\n\",\n    \"        F_t_list.append(F_t)\\n\",\n    \"        \\n\",\n    \"        # Calculate seasonality strength (F_s)\\n\",\n    \"        F_s = max(0, 1 - (res.resid.var() / (res.seasonal + res.resid).var()))\\n\",\n    \"        F_s_list.append(F_s)\\n\",\n    \"    \\n\",\n    \"    # Print summary of results\\n\",\n    \"    print('dataset: {dataset}, \\\\t win. size: {win},\\\\t Avg. F_t: {avg_ft:2.4f},\\\\t Avg. F_s: {avg_fs:2.4f}'.format(\\n\",\n    \"        dataset=dataset, win=win, avg_ft=np.mean(F_t_list), avg_fs=np.mean(F_s_list)\\n\",\n    \"    ))\\n\",\n    \"\\n\",\n    \"def ts_decompose(df, dataset, win=0):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Decomposes time series data into trend, seasonal, and residual components.\\n\",\n    \"\\n\",\n    \"    Parameters:\\n\",\n    \"    - df (pd.DataFrame): The input data containing time series columns.\\n\",\n    \"    - dataset (str): The name of the dataset to identify frequency or specific configurations.\\n\",\n    \"    - win (int): Window size for decomposition; if 0, applies decomposition on the full time series.\\n\",\n    \"\\n\",\n    \"    Returns:\\n\",\n    \"    - dim_list (list): A list of decomposition results for each dimension of the time series.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    # Define frequency mapping for datasets\\n\",\n    \"    freq_dict = {\\n\",\n    \"        'ETT-small/ETTh1': 'H', 'ETT-small/ETTh2': 'H', 'ETT-small/ETTm1': 'T', 'ETT-small/ETTm2': 'T',\\n\",\n    \"        'electricity/electricity': 'H', 'exchange_rate/exchange_rate': 'B',\\n\",\n    \"        'illness/national_illness': 'W', 'traffic/traffic': 'H', 'weather/weather': 'T',\\n\",\n    \"        
'exchange_rate_nips': 'B', 'solar_nips': 'H', 'electricity_nips': 'H',\\n\",\n    \"        'traffic_nips': 'H', 'wiki2000_nips': 'D'\\n\",\n    \"    }\\n\",\n    \"    \\n\",\n    \"    # Define minimum period mapping for datasets\\n\",\n    \"    min_dict = {\\n\",\n    \"        'ETT-small/ETTm1': (24 * 60) // 15, 'ETT-small/ETTm2': (24 * 60) // 15,\\n\",\n    \"        'weather/weather': (24 * 60) // 10\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    dim = len(df.iloc[0])  # Number of dimensions (columns) in the data\\n\",\n    \"    dim_list = []  # List to store decomposition results for each dimension\\n\",\n    \"    \\n\",\n    \"    for i in trange(dim):  # Iterate over each column in the dataset\\n\",\n    \"        if win == 0:\\n\",\n    \"            # Standardize the time series column\\n\",\n    \"            tmp_df = (df.iloc[:, i] - df.iloc[:, i].mean()) / (df.iloc[:, i].std())\\n\",\n    \"            \\n\",\n    \"            # Perform STL decomposition with appropriate frequency settings\\n\",\n    \"            if dataset in freq_dict and freq_dict[dataset] == 'T':\\n\",\n    \"                stl = STL(tmp_df.fillna(0), period=7, robust=True)\\n\",\n    \"            else:\\n\",\n    \"                stl = STL(tmp_df.fillna(0), robust=True)\\n\",\n    \"            \\n\",\n    \"            res = stl.fit()  # Fit the decomposition model\\n\",\n    \"            dim_list.append(res)  # Store the result\\n\",\n    \"        else:\\n\",\n    \"            # Perform windowed decomposition\\n\",\n    \"            right = win  # Initialize the right boundary of the window\\n\",\n    \"            while right < len(df.iloc[1:, i]):\\n\",\n    \"                tmp_df = df.iloc[right - win:right, i]  # Extract the windowed data\\n\",\n    \"                tmp_df = (tmp_df - tmp_df.mean()) / (tmp_df.std())  # Standardize the windowed data\\n\",\n    \"                \\n\",\n    \"                # Perform STL decomposition with appropriate 
frequency settings\\n\",\n    \"                if dataset in freq_dict and freq_dict[dataset] == 'T':\\n\",\n    \"                    stl = STL(tmp_df.fillna(0), period=7, robust=True)\\n\",\n    \"                else:\\n\",\n    \"                    stl = STL(tmp_df.fillna(0), robust=True)\\n\",\n    \"                \\n\",\n    \"                res = stl.fit()  # Fit the decomposition model\\n\",\n    \"                right += win  # Move the window forward\\n\",\n    \"                dim_list.append(res)  # Store the result\\n\",\n    \"        \\n\",\n    \"    return dim_list  # Return the list of decomposition results\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Normality\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from scipy.stats import normaltest\\n\",\n    \"import numpy as np\\n\",\n    \"import scipy.stats\\n\",\n    \"from scipy.stats import norm\\n\",\n    \"\\n\",\n    \"def test_normal(df, dataset, win=0):\\n\",\n    \"    dim = len(df.iloc[0])\\n\",\n    \"    score_list = []\\n\",\n    \"    gaussian_count = 0\\n\",\n    \"    count = 0\\n\",\n    \"    for i in range(dim):\\n\",\n    \"        # z-score\\n\",\n    \"        # df.iloc[:,i]=(df.iloc[:,i]-df.iloc[:,i].mean())/(df.iloc[:,i].std())\\n\",\n    \"        value = df.iloc[:,i].dropna().values\\n\",\n    \"        if len(value) < 10:\\n\",\n    \"            continue\\n\",\n    \"        \\n\",\n    \"        right = win\\n\",\n    \"        pvalue = []\\n\",\n    \"        if win > 0:\\n\",\n    \"            while right < len(value):\\n\",\n    \"                res = normaltest(value[right-win:right])[1]\\n\",\n    \"                pvalue.append(res)\\n\",\n    \"                right += win\\n\",\n    \"            res = np.mean(pvalue)\\n\",\n    \"        else:\\n\",\n    \"            res = normaltest(value)[1]\\n\",\n  
  \"            # res = kstest(value, 'norm')[1]\\n\",\n    \"            if sum(value) == 0:\\n\",\n    \"                continue\\n\",\n    \"            \\n\",\n    \"        if res >= 0.05:\\n\",\n    \"            gaussian_count += 1\\n\",\n    \"        count += 1\\n\",\n    \"            \\n\",\n    \"        score_list.append(res)\\n\",\n    \"\\n\",\n    \"    \\n\",\n    \"    print(dataset, \\\" gaussian pvalue: \\\", str(np.mean(score_list)), '\\\\t gaussian ratio: ', str(gaussian_count/count))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def JS_divergence(p,q):\\n\",\n    \"    M=(p+q)/2\\n\",\n    \"    return 0.5*scipy.stats.entropy(p, M, base=2)+0.5*scipy.stats.entropy(q, M, base=2)\\n\",\n    \"\\n\",\n    \"def JS_div(arr1,arr2,num_bins):\\n\",\n    \"    max0 = max(np.max(arr1),np.max(arr2))\\n\",\n    \"    min0 = min(np.min(arr1),np.min(arr2))\\n\",\n    \"    bins = np.linspace(min0-1e-4, max0-1e-4, num=num_bins)\\n\",\n    \"    \\n\",\n    \"    PDF1 = pd.cut(arr1,bins,duplicates='drop').value_counts()\\n\",\n    \"    PDF2 = pd.cut(arr2,bins, duplicates='drop').value_counts()\\n\",\n    \"    \\n\",\n    \"    if sum(PDF1) > 0 and sum(PDF2) > 0:\\n\",\n    \"        PDF1 = PDF1 / len(arr1)\\n\",\n    \"        PDF2 = PDF2 / len(arr2)\\n\",\n    \"        return JS_divergence(PDF1.values,PDF2.values)\\n\",\n    \"    else:\\n\",\n    \"        return None\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def cal_JS_divergence(df, dataset, win=0):\\n\",\n    \"    \\n\",\n    \"    dim = len(df.iloc[0])\\n\",\n    \"    js_list = []\\n\",\n    \"    for i in range(1, dim):\\n\",\n    \"        \\n\",\n    \"        # z-score\\n\",\n    \"        global_mu = df.iloc[:,i].mean()\\n\",\n    \"        global_std = df.iloc[:,i].std()\\n\",\n    \"        df.iloc[:,i]=(df.iloc[:,i]-global_mu) / global_std\\n\",\n    \"        value = df.iloc[:,i].dropna().values\\n\",\n    \"        \\n\",\n    \"        if sum(value) == 0:\\n\",\n    \"            
continue\\n\",\n    \"        \\n\",\n    \"        right = win\\n\",\n    \"        dim_list = []\\n\",\n    \"        if win > 0:\\n\",\n    \"            while right < len(value):\\n\",\n    \"                tmp_value = value[right-win:right]\\n\",\n    \"                mu = tmp_value.mean()\\n\",\n    \"                std = tmp_value.std()\\n\",\n    \"\\n\",\n    \"                norm_dist = norm.rvs(loc=mu, scale=std, size=len(tmp_value))\\n\",\n    \"                res = JS_div(tmp_value,norm_dist,num_bins=20)\\n\",\n    \"                if res is not None:\\n\",\n    \"                    dim_list.append(res)\\n\",\n    \"                right += win\\n\",\n    \"                \\n\",\n    \"            js_div = np.mean(dim_list)\\n\",\n    \"\\n\",\n    \"        else:\\n\",\n    \"            norm_dist = norm.rvs(loc=global_mu, scale=global_std, size=len(value))\\n\",\n    \"            js_div = JS_div(value,norm_dist,num_bins=20)\\n\",\n    \"        \\n\",\n    \"        if js_div is not None:\\n\",\n    \"            js_list.append(js_div)\\n\",\n    \"        \\n\",\n    \"    print(\\\"window size: \\\", win, \\\"\\\\t dataset: \\\", dataset, \\\"\\\\t JS DIV avg: \\\", str(np.mean(js_list)))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Long-term Datasets\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def load_csv_data(filename, dataset):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Loads time series data from a CSV file and processes it based on dataset-specific requirements.\\n\",\n    \"\\n\",\n    \"    Parameters:\\n\",\n    \"    - filename (str): Path to the directory containing the CSV file.\\n\",\n    \"    - dataset (str): Name of the dataset to be loaded, used for specific handling.\\n\",\n    \"\\n\",\n    \"    Returns:\\n\",\n    \"    - df (pd.DataFrame): Processed DataFrame 
with time series data, indexed by date.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    # Dictionary to map dataset names to their respective data frequency\\n\",\n    \"    freq_dict = {\\n\",\n    \"        'ETT-small/ETTh1': 'H', 'ETT-small/ETTh2': 'H', 'ETT-small/ETTm1': 'T', 'ETT-small/ETTm2': 'T',\\n\",\n    \"        'electricity/electricity': 'H', 'exchange_rate/exchange_rate': 'D',\\n\",\n    \"        'illness/national_illness': 'D', 'traffic/traffic': 'H', 'weather/weather': 'T'\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    # Special handling for 'caiso' dataset\\n\",\n    \"    if 'caiso' in dataset:\\n\",\n    \"        # Load the dataset and convert the 'Date' column to datetime\\n\",\n    \"        data = pd.read_csv(filename + dataset + '.csv')\\n\",\n    \"        data['Date'] = data['Date'].astype('datetime64[ns]')\\n\",\n    \"        \\n\",\n    \"        # Names of zones in the dataset\\n\",\n    \"        names = ['PGE', 'SCE', 'SDGE', 'VEA', 'CA ISO', 'PACE', 'PACW', 'NEVP', 'AZPS', 'PSEI']\\n\",\n    \"        \\n\",\n    \"        # Create a DataFrame with a complete hourly date range\\n\",\n    \"        df = pd.DataFrame(pd.date_range('20130101', '20210630', freq='H')[:-1], columns=['Date'])\\n\",\n    \"        \\n\",\n    \"        # Process each zone's data and merge into a single DataFrame\\n\",\n    \"        for name in names:\\n\",\n    \"            current_df = (\\n\",\n    \"                data[data['zone'] == name]\\n\",\n    \"                .drop_duplicates(subset='Date', keep='last')  # Remove duplicate entries, keeping the last\\n\",\n    \"                .rename(columns={'load': name})  # Rename 'load' column to the zone name\\n\",\n    \"                .drop(columns=['zone'])  # Drop the 'zone' column\\n\",\n    \"            )\\n\",\n    \"            df = df.merge(current_df, on='Date', how='outer')  # Merge with the main DataFrame\\n\",\n    \"        \\n\",\n    \"        # Rename the 'Date' column to 
'date'\\n\",\n    \"        df = df.rename(columns={'Date': 'date'})\\n\",\n    \"    elif 'nordpool' in dataset:\\n\",\n    \"        # Special handling for 'nordpool' dataset: Parse the 'Time' column as datetime\\n\",\n    \"        df = pd.read_csv(filename + dataset + '.csv', parse_dates=['Time'])\\n\",\n    \"        df = df.rename(columns={'Time': 'date'})  # Rename the 'Time' column to 'date'\\n\",\n    \"    else:\\n\",\n    \"        # General case: Load the dataset as-is\\n\",\n    \"        df = pd.read_csv(filename + dataset + '.csv')\\n\",\n    \"    \\n\",\n    \"    # Convert the 'date' column to datetime format and set it as the index\\n\",\n    \"    df['date'] = pd.to_datetime(df['date'])\\n\",\n    \"    df = df.set_index('date')\\n\",\n    \"\\n\",\n    \"    # Drop the first column (usually an index column or non-relevant column)\\n\",\n    \"    df = df.iloc[:, 1:]\\n\",\n    \"    \\n\",\n    \"    return df  # Return the processed DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"100%|██████████| 6/6 [00:10<00:00,  1.67s/it]\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"dataset: ETT-small/ETTh1, \\t win. size: 0,\\t Avg. F_t: 0.7728,\\t Avg. 
F_s: 0.4772\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"dataset = 'ETT-small/ETTh1' # 'exchange_rate/exchange_rate'\\n\",\n    \"win_len = 0\\n\",\n    \"df = load_csv_data(data_path, dataset)\\n\",\n    \"measure_strength(df, dataset, win=win_len)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"window size:  336 \\t dataset:  ETT-small/ETTh1 \\t JS DIV avg:  0.0719988819816385\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"dataset = 'ETT-small/ETTh1' # 'exchange_rate/exchange_rate'\\n\",\n    \"win_len = 336\\n\",\n    \"df = load_csv_data(data_path, dataset)\\n\",\n    \"cal_JS_divergence(df, dataset, win=win_len)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Short-term Datasets\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def load_prob_data(dataset, win=0):\\n\",\n    \"    freq_dict = {'exchange_rate_nips':'B','solar_nips':'H','electricity_nips':'H','traffic_nips':'H', 'wiki2000_nips':'D'}\\n\",\n    \"    \\n\",\n    \"    idx = 0\\n\",\n    \"    dataname = dataset\\n\",\n    \"    dataset = get_dataset(dataset, path=save_path, regenerate=False)\\n\",\n    \"    dim = int(dataset.metadata.feat_static_cat[0].cardinality)\\n\",\n    \"    train_grouper = MultivariateGrouper(max_target_dim=dim)\\n\",\n    \"    dataset_train = train_grouper(dataset.train)\\n\",\n    \"    data = list(dataset_train)[0]['target']\\n\",\n    \"    start_date = dataset_train[0]['start'].to_timestamp()\\n\",\n    \"    \\n\",\n    \"    # multi\\n\",\n    \"    idx = [i for i in range(dim)]\\n\",\n    \"\\n\",\n    \"    data = 
data.transpose(1,0)\\n\",\n    \"    df = pd.DataFrame(data,columns=idx,dtype=float)\\n\",\n    \"\\n\",\n    \"    df['date'] = pd.date_range(start_date,periods=len(data),freq=freq_dict[dataname]) \\n\",\n    \"    df = df.set_index('date')\\n\",\n    \"        \\n\",\n    \"    return df\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/home/v-zhangjiaw/miniconda3/envs/probts/lib/python3.10/site-packages/gluonts/dataset/common.py:263: FutureWarning: Period with BDay freq is deprecated and will be removed in a future version. Use a DatetimeIndex with BDay freq instead.\\n\",\n      \"  return pd.Period(val, freq)\\n\",\n      \"/home/v-zhangjiaw/miniconda3/envs/probts/lib/python3.10/site-packages/gluonts/dataset/multivariate_grouper.py:114: FutureWarning: Period with BDay freq is deprecated and will be removed in a future version. Use a DatetimeIndex with BDay freq instead.\\n\",\n      \"  timestamp + len(data[FieldName.TARGET]) - 1,\\n\",\n      \"/home/v-zhangjiaw/miniconda3/envs/probts/lib/python3.10/site-packages/gluonts/dataset/multivariate_grouper.py:243: FutureWarning: Period with BDay freq is deprecated and will be removed in a future version. Use a DatetimeIndex with BDay freq instead.\\n\",\n      \"  index=pd.period_range(\\n\",\n      \"/home/v-zhangjiaw/miniconda3/envs/probts/lib/python3.10/site-packages/gluonts/dataset/multivariate_grouper.py:243: FutureWarning: PeriodDtype[B] is deprecated and will be removed in a future version. Use a DatetimeIndex with freq='B' instead\\n\",\n      \"  index=pd.period_range(\\n\",\n      \"/home/v-zhangjiaw/miniconda3/envs/probts/lib/python3.10/site-packages/gluonts/dataset/multivariate_grouper.py:188: FutureWarning: Period with BDay freq is deprecated and will be removed in a future version. 
Use a DatetimeIndex with BDay freq instead.\\n\",\n      \"  pd.period_range(\\n\",\n      \"/home/v-zhangjiaw/miniconda3/envs/probts/lib/python3.10/site-packages/gluonts/dataset/multivariate_grouper.py:188: FutureWarning: PeriodDtype[B] is deprecated and will be removed in a future version. Use a DatetimeIndex with freq='B' instead\\n\",\n      \"  pd.period_range(\\n\",\n      \"/tmp/ipykernel_1399510/2105741496.py:11: FutureWarning: Period with BDay freq is deprecated and will be removed in a future version. Use a DatetimeIndex with BDay freq instead.\\n\",\n      \"  start_date = dataset_train[0]['start'].to_timestamp()\\n\",\n      \"100%|██████████| 8/8 [00:01<00:00,  5.98it/s]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"dataset: exchange_rate_nips, \\t win. size: 0,\\t Avg. F_t: 0.9982,\\t Avg. F_s: 0.1256\\n\",\n      \"window size:  30 \\t dataset:  exchange_rate_nips \\t JS DIV avg:  0.2964380648448922\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# \\\"exchange_rate_nips\\\", \\\"solar_nips\\\", \\\"electricity_nips\\\", \\\"traffic_nips\\\", \\\"taxi_30min\\\", \\\"wiki2000_nips\\\"\\n\",\n    \"dataset = \\\"exchange_rate_nips\\\"\\n\",\n    \"df = load_prob_data(dataset, win=0)\\n\",\n    \"\\n\",\n    \"measure_strength(df, dataset, win=0)\\n\",\n    \"cal_JS_divergence(df, dataset, win=30)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"probts\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "probts/__init__.py",
    "content": "from .data import *\nfrom .model import *\nfrom .utils import *"
  },
  {
    "path": "probts/callbacks/__init__.py",
    "content": "from .memory_callback import MemoryCallback\nfrom .time_callback import TimeCallback"
  },
  {
    "path": "probts/callbacks/memory_callback.py",
    "content": "import gc\nimport threading\nimport time\n\nimport psutil\nimport torch\n\nimport lightning.pytorch as pl\nfrom lightning.pytorch.callbacks.callback import Callback\n\n\ndef byte2gb(x):\n    \"\"\"Convert a byte count to GB (2**30 bytes) as a float.\"\"\"\n    return float(x / 2**30)\n\n\nclass MemoryTrace:\n    \"\"\"\n    Trace GPU and CPU memory usage between construction and ``__exit__``.\n\n    Construction runs garbage collection, empties the CUDA cache, resets the CUDA\n    peak-memory statistics, records the starting GPU allocation and process RSS,\n    and starts a daemon thread that polls the RSS to track its peak. ``__exit__``\n    stops that thread and records the end/peak statistics. The object can also be\n    used as a context manager.\n\n    ``begin``, ``end``, ``peak``, ``used``, ``peaked``, ``peak_active_gb``,\n    ``max_reserved``, ``cpu_begin``, ``cpu_end``, ``cpu_used`` and ``cpu_peaked``\n    are in GB; ``cpu_peak`` is the raw peak RSS in bytes.\n\n    ``torch.cuda`` APIs are called unconditionally, so CUDA must be available.\n    \"\"\"\n\n    def __init__(self, poll_interval=0.001):\n        \"\"\"\n        Parameters\n        ----------\n        poll_interval : float, optional, default=0.001\n            Seconds the monitor thread sleeps between RSS samples. Sleeping\n            releases the GIL, so the monitor does not busy-spin a core and\n            compete with the code being traced.\n        \"\"\"\n        gc.collect()\n        torch.cuda.empty_cache()\n        # Reset all CUDA peak gauges (allocated, reserved, active) so the peaks\n        # read in __exit__ cover only this trace. The deprecated\n        # reset_max_memory_allocated() forwards to this same call.\n        torch.cuda.reset_peak_memory_stats()\n        self.begin = byte2gb(torch.cuda.memory_allocated())\n        self.process = psutil.Process()\n        cpu_begin_bytes = self.cpu_mem_used()\n        self.cpu_begin = byte2gb(cpu_begin_bytes)\n        # Seed the peak (bytes) before the thread starts so it always exists,\n        # even if __exit__ runs before the first sample is taken.\n        self.cpu_peak = cpu_begin_bytes\n        self.poll_interval = poll_interval\n        self.peak_monitoring = True\n        self.peak_monitor_thread = threading.Thread(target=self.peak_monitor_func, daemon=True)\n        self.peak_monitor_thread.start()\n\n    def __enter__(self):\n        return self\n\n    def cpu_mem_used(self):\n        \"\"\"Return the resident set size (RSS) of the current process in bytes.\"\"\"\n        return self.process.memory_info().rss\n\n    def peak_monitor_func(self):\n        \"\"\"Keep the running maximum RSS (bytes) in ``cpu_peak`` until monitoring stops.\"\"\"\n        while self.peak_monitoring:\n            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)\n            time.sleep(self.poll_interval)\n\n    def __exit__(self, *exc):\n        \"\"\"Stop the RSS monitor and record the end/peak GPU and CPU statistics.\"\"\"\n        self.peak_monitoring = False\n        # Wait for the monitor so it cannot overwrite cpu_peak after the final sample below.\n        self.peak_monitor_thread.join(timeout=1.0)\n\n        gc.collect()\n        torch.cuda.empty_cache()\n        self.end = byte2gb(torch.cuda.memory_allocated())\n        self.peak = byte2gb(torch.cuda.max_memory_allocated())\n        cuda_info = torch.cuda.memory_stats()\n        self.peak_active_gb = byte2gb(cuda_info[\"active_bytes.all.peak\"])\n        self.cuda_malloc_retires = cuda_info.get(\"num_alloc_retries\", 0)\n        self.m_cuda_ooms = cuda_info.get(\"num_ooms\", 0)\n        # begin/end/peak are already in GB, so subtract them without converting again.\n        self.used = self.end - self.begin\n        self.peaked = self.peak - self.begin\n        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())\n\n        cpu_end_bytes = self.cpu_mem_used()\n        self.cpu_peak = max(self.cpu_peak, cpu_end_bytes)\n        self.cpu_end = byte2gb(cpu_end_bytes)\n        # Convert bytes to GB before subtracting cpu_begin, which is already in GB.\n        self.cpu_used = self.cpu_end - self.cpu_begin\n        self.cpu_peaked = byte2gb(self.cpu_peak) - self.cpu_begin\n\n\nclass MemoryCallback(Callback):\n    \"\"\"\n    Trace the memory usage.\n\n    When CUDA is available, a MemoryTrace is started at the beginning of every\n    train/validation/test epoch and finished at its end. ``memory_summary`` maps\n    each stage ('train', 'val', 'test') to the maximum of each statistic seen over\n    all epochs of that stage (sizes in GB, allocation retries as a count).\n\n    NOTE(review): each MemoryTrace resets the CUDA peak statistics, so if a\n    validation epoch starts while a training-epoch trace is still open, the\n    training trace's GPU peaks only cover the time after that reset. Confirm\n    against Lightning's hook order.\n    \"\"\"\n    def __init__(self):\n        self.memory_summary = {\n            'train': {},\n            'val': {},\n            'test': {}\n        }\n\n    def update_memory_summary(self, key, memtrace):\n        \"\"\"Fold a finished MemoryTrace into ``memory_summary[key]``, keeping per-field maxima.\"\"\"\n        previous = self.memory_summary[key]\n        self.memory_summary[key] = {\n            \"mem_peak\": max(memtrace.peak, previous.get(\"mem_peak\", 0)),\n            \"max_reserved\": max(memtrace.max_reserved, previous.get(\"max_reserved\", 0)),\n            \"peak_active_gb\": max(memtrace.peak_active_gb, previous.get(\"peak_active_gb\", 0)),\n            \"cuda_malloc_retires\": max(memtrace.cuda_malloc_retires, previous.get(\"cuda_malloc_retires\", 0)),\n            # cpu_peaked is relative to cpu_begin; adding it back gives the absolute peak RSS in GB.\n            \"cpu_total_peaked\": max(memtrace.cpu_peaked + memtrace.cpu_begin, previous.get(\"cpu_total_peaked\", 0)),\n        }\n\n    def on_train_epoch_start(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\"\n    ) -> None:\n        \"\"\"Called when the train epoch begins\"\"\"\n        if torch.cuda.is_available():\n            self.train_memtrace = MemoryTrace()\n\n    def on_train_epoch_end(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\"\n    ) -> None:\n        \"\"\"Called when the train epoch ends\"\"\"\n        if torch.cuda.is_available():\n            self.train_memtrace.__exit__()\n            self.update_memory_summary('train', self.train_memtrace)\n\n    def on_validation_epoch_start(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\"\n    ) -> None:\n        \"\"\"Called when the validation epoch begins\"\"\"\n        if torch.cuda.is_available():\n            self.val_memtrace = MemoryTrace()\n\n    def on_validation_epoch_end(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\"\n    ) -> None:\n        \"\"\"Called when the validation epoch ends\"\"\"\n        if torch.cuda.is_available():\n            self.val_memtrace.__exit__()\n            self.update_memory_summary('val', self.val_memtrace)\n\n    def on_test_epoch_start(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\"\n    ) -> None:\n        \"\"\"Called when the test epoch begins\"\"\"\n        if torch.cuda.is_available():\n            self.test_memtrace = MemoryTrace()\n\n    def on_test_epoch_end(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\"\n    ) -> None:\n        \"\"\"Called when the test epoch ends\"\"\"\n        if torch.cuda.is_available():\n            self.test_memtrace.__exit__()\n            self.update_memory_summary('test', self.test_memtrace)\n"
  },
  {
    "path": "probts/callbacks/time_callback.py",
    "content": "import time\nfrom typing import Any\n\nimport lightning.pytorch as pl\nfrom lightning.pytorch.utilities.types import STEP_OUTPUT\nfrom lightning.pytorch.callbacks.callback import Callback\n\n\nclass TimeCallback(Callback):\n    \"\"\"\n    Trace the computation time.\n\n    Appends the duration in seconds of every train, validation and test batch to\n    ``time_summary['train_batch_time']``, ``time_summary['val_batch_time']`` and\n    ``time_summary['test_batch_time']`` respectively.\n\n    Durations come from ``time.perf_counter()``, a monotonic high-resolution\n    clock, instead of ``time.time()``, whose wall-clock value can jump when the\n    system time is adjusted.\n\n    NOTE(review): CUDA work is launched asynchronously; unless the batch\n    synchronizes with the device, a recorded duration may not include all GPU\n    work that the batch launched.\n    \"\"\"\n    def __init__(self):\n        self.time_summary = {\n            'train_batch_time': [],\n            'val_batch_time': [],\n            'test_batch_time': []\n        }\n\n    def on_train_batch_start(\n        self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\", batch: Any, batch_idx: int\n    ) -> None:\n        \"\"\"Called when the train batch begins.\"\"\"\n        self.train_start_time = time.perf_counter()\n\n    def on_train_batch_end(\n        self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\", outputs: STEP_OUTPUT, batch: Any, batch_idx: int\n    ) -> None:\n        \"\"\"Called when the train batch ends\"\"\"\n        self.time_summary['train_batch_time'].append(time.perf_counter() - self.train_start_time)\n\n    def on_validation_batch_start(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\",\n        batch: Any,\n        batch_idx: int,\n        dataloader_idx: int = 0,\n    ) -> None:\n        \"\"\"Called when the validation batch begins\"\"\"\n        self.val_start_time = time.perf_counter()\n\n    def on_validation_batch_end(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\",\n        outputs: STEP_OUTPUT,\n        batch: Any,\n        batch_idx: int,\n        dataloader_idx: int = 0,\n    ) -> None:\n        \"\"\"Called when the validation batch ends\"\"\"\n        self.time_summary['val_batch_time'].append(time.perf_counter() - self.val_start_time)\n\n    def on_test_batch_start(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\",\n        batch: Any,\n        batch_idx: int,\n        dataloader_idx: int = 0,\n    ) -> None:\n        \"\"\"Called when the test batch begins\"\"\"\n        self.test_start_time = time.perf_counter()\n\n    def on_test_batch_end(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\",\n        outputs: STEP_OUTPUT,\n        batch: Any,\n        batch_idx: int,\n        dataloader_idx: int = 0,\n    ) -> None:\n        \"\"\"Called when the test batch ends\"\"\"\n        self.time_summary['test_batch_time'].append(time.perf_counter() - self.test_start_time)\n"
  },
  {
    "path": "probts/data/__init__.py",
    "content": "from .data_module import *\nfrom .data_manager import *\nfrom .data_utils.time_features import *"
  },
  {
    "path": "probts/data/data_manager.py",
    "content": "import torch\nfrom pathlib import Path\nfrom functools import cached_property\n\nfrom gluonts.dataset.repository import dataset_names, datasets\nfrom gluonts.dataset.multivariate_grouper import MultivariateGrouper\n\nfrom probts.data.data_utils.get_datasets import get_dataset_info, get_dataset_borders, load_dataset\nfrom probts.data.datasets.single_horizon_datasets import SingleHorizonDataset\nfrom probts.data.datasets.multi_horizon_datasets import MultiHorizonDataset\nfrom probts.data.datasets.gift_eval_datasets import GiftEvalDataset\n\nfrom probts.data.data_utils.time_features import get_lags\nfrom probts.data.data_utils.data_utils import split_train_val, truncate_test, get_rolling_test, df_to_mvds\nfrom probts.data.data_wrapper import ProbTSBatchData\nfrom probts.utils.utils import ensure_list\nfrom probts.data.data_utils.data_scaler import StandardScaler, TemporalScaler, IdentityScaler\nfrom typing import Union\n\nMULTI_VARIATE_DATASETS = [\n    'exchange_rate_nips',\n    'solar_nips',\n    'electricity_nips',\n    'traffic_nips',\n    'taxi_30min',\n    'wiki-rolling_nips',\n    'wiki2000_nips'\n]\n\nclass DataManager:\n    def __init__(\n        self,\n        dataset: str,\n        path: str = './datasets',\n        history_length: int = None,\n        context_length: int = None,\n        prediction_length: Union[list,int,str] = None,\n        train_ctx_len: int = None,\n        train_pred_len_list: Union[list,int,str] = None,\n        val_ctx_len: int = None,\n        val_pred_len_list: Union[list,int,str] = None,\n        test_rolling_length: int = 96,\n        split_val: bool = True,\n        scaler: str = 'none',\n        context_length_factor: int = 1,\n        timeenc: int = 1,\n        var_specific_norm: bool = True,\n        data_path: str = None,\n        freq: str = None,\n        multivariate: bool = True,\n        continuous_sample: bool = False,\n        train_ratio: float = 0.7,\n        test_ratio: float = 0.2,\n        
auto_search: bool = False,\n    ):\n        \"\"\"\n        DataManager class for handling datasets and preparing data for time-series models.\n\n        Parameters\n        ----------\n        dataset : str\n            Name of the dataset to load. Examples include \"etth1\", \"electricity_ltsf\", etc.\n        path : str, optional, default='./datasets'\n            Root directory path where datasets are stored.\n        history_length : int, optional, default=None\n            Length of the historical input window for the model.\n            If not specified, it is automatically calculated based on `context_length` and lag features.\n        context_length : int, optional, default=None\n            Length of the input context for the model. \n        prediction_length : Union[list, int, str], optional, default=None\n            Length of the prediction horizon for the model. Can be:\n            - int: Fixed prediction length.\n            - list: Variable prediction lengths for multi-horizon training.\n            - str: The string format of multiple prediction length. 
E.g., '96-192-336-720' represents [96, 192, 336, 720]\n        train_ctx_len : int, optional, default=None\n            Context length for the training dataset.\n            If not specified, defaults to the value of `context_length`.\n        train_pred_len_list : Union[list, int, str], optional, default=None\n            List of prediction lengths for the training dataset.\n            If not specified, defaults to the value of `prediction_length`.\n        val_ctx_len : int, optional, default=None\n            Context length for the validation dataset.\n            If not specified, defaults to the value of `context_length`.\n        val_pred_len_list : Union[list, int, str], optional, default=None\n            List of prediction lengths for the validation dataset.\n            If not specified, defaults to the value of `prediction_length`.\n        test_rolling_length : int, optional, default=96\n            Gap window size used for rolling predictions in the testing phase.\n            - If set to `auto`, it is dynamically determined based on the dataset frequency\n            (e.g., 'H' -> 24, 'D' -> 7, 'W' -> 4).\n        split_val : bool, optional, default=True\n            Whether to split the training dataset into training and validation sets.\n        scaler : str, optional, default='none'\n            Type of normalization or scaling applied to the dataset. Options include:\n            - 'none': No scaling.\n            - 'standard': Standard normalization (z-score).\n            - 'temporal': Mean-scaling normalization.\n        context_length_factor : int, optional, default=1\n            Scaling factor for context length, allowing dynamic adjustment of `context_length`.\n        timeenc : int, optional, default=1\n            Time encoding strategy. 
Options include:\n            - 0: The dimension of time feature is 5, containing `month, day, weekday, hour, minute`\n            - 1: Cyclic time features (e.g., sine/cosine of timestamps).\n            - 2: Raw Timestamp information.\n        var_specific_norm : bool, optional, default=True\n            Whether to normalize variables independently. Only applies when `scaler='standard'`.\n        data_path : str, optional, default=None\n            Specific path to the dataset file.\n        freq : str, optional, default=None\n            Data frequency (e.g., 'H' for hourly, 'D' for daily).\n        multivariate : bool, optional, default=True\n            Whether the dataset is multivariables.\n        continuous_sample : bool, optional, default=False\n            Whether to enable continuous sampling for forecasting horizons during training phase.\n        train_ratio : float, optional, default=0.7\n            Proportion of the dataset used for training. Default is 70% of the data.\n        test_ratio : float, optional, default=0.2\n            Proportion of the dataset used for testing. 
Default is 20% of the data.\n        auto_search : bool, optional, default=False\n            Make past_len=ctx_len+pred_len, enabling post training search.\n        \"\"\"\n\n        self.dataset = dataset\n        self.path = path\n        self.history_length = history_length\n        self.context_length = context_length\n        self.prediction_length = prediction_length\n        self.train_ctx_len = train_ctx_len if train_ctx_len is not None else context_length\n        self.val_ctx_len = val_ctx_len if val_ctx_len is not None else context_length\n        self.train_pred_len_list = train_pred_len_list if train_pred_len_list is not None else prediction_length\n        self.val_pred_len_list = val_pred_len_list if val_pred_len_list is not None else prediction_length\n        self.test_rolling_length = test_rolling_length\n        self.split_val = split_val\n        self.scaler_type = scaler\n        self.context_length_factor = context_length_factor\n        self.timeenc = timeenc\n        self.var_specific_norm = var_specific_norm\n        self.data_path = data_path\n        self.freq = freq\n        self.multivariate = multivariate\n        self.continuous_sample = continuous_sample\n        self.train_ratio = train_ratio\n        self.test_ratio = test_ratio\n        self.auto_search = auto_search\n        \n        self.test_rolling_dict = {'h': 24, 'd': 7, 'b':5, 'w':4, 'min': 60}\n        self.global_mean = None\n\n        # Configure scaler\n        self.scaler = self._configure_scaler(self.scaler_type)\n  \n        # Load dataset and prepare for processing\n        if dataset in dataset_names:\n            self.multi_hor = False\n            self._load_short_term_dataset()\n        elif self.is_gift_eval:\n            self.multi_hor = False\n            # Load GIFT eval datasets from salesforce\n            self._load_gift_eval_dataset()\n        else:\n            # Process context and prediction lengths\n            
self._process_context_and_prediction_lengths()\n            self._load_long_term_dataset()\n            # Print configuration details\n            self._print_configurations()\n        \n    def _configure_scaler(self, scaler_type: str):\n        \"\"\"Configure the scaler.\"\"\"\n        if scaler_type == \"standard\":\n            return StandardScaler(var_specific=self.var_specific_norm)\n        elif scaler_type == \"temporal\":\n            return TemporalScaler()\n        return IdentityScaler()\n    \n    def _load_gift_eval_dataset(self):\n        parts = self.dataset[5:].split('/')  # Remove first 'gift/'\n        self.dataset = '/'.join(parts[:-1])  # Join all parts except last one with '/'\n        gift_term = parts[-1] # corresponding to \"term\" parameter in GiftEvalDataset\n        TO_UNIVARIATE = False\n        self.dataset_raw = GiftEvalDataset(self.dataset, term=gift_term, to_univariate=TO_UNIVARIATE)\n        self._set_meta_parameters(self.dataset_raw.target_dim, self.dataset_raw.freq, self.dataset_raw.prediction_length)\n\n        dataset_loader = SingleHorizonDataset(\n            ProbTSBatchData.input_names_, \n            self.history_length,\n            self.context_length,\n            self.prediction_length,\n            self.freq,\n            self.multivariate\n        )\n\n        self.train_iter_dataset = dataset_loader.get_iter_dataset(self.dataset_raw.training_dataset, mode='train')\n        self.val_iter_dataset = dataset_loader.get_iter_dataset(self.dataset_raw.validation_dataset, mode='val')\n        self.test_iter_dataset = dataset_loader.get_iter_dataset(self.dataset_raw.test_dataset, mode='test')\n        self.time_feat_dim = dataset_loader.time_feat_dim\n        # TODO: Implement global mean for GIFT eval datasets\n        # self.global_mean = torch.mean(torch.tensor(self.dataset_raw.training_dataset[0]['target']), dim=-1)\n    \n    def _load_short_term_dataset(self):\n        \"\"\"Load short-term dataset using 
GluonTS.\"\"\"\n        print(f\"Loading Short-term Dataset: {self.dataset}\")\n        self.dataset_raw = datasets.get_dataset(self.dataset, path=Path(self.path), regenerate=True)\n        metadata = self.dataset_raw.metadata\n        if self.is_univar_dataset:\n            target_dim = 1\n        else:\n            target_dim = metadata.feat_static_cat[0].cardinality\n        self._set_meta_parameters(target_dim, metadata.freq.upper(), metadata.prediction_length)\n        self.prepare_STSF_dataset(self.dataset)\n\n    def _set_meta_parameters(self, target_dim, freq, prediction_length):\n        \"\"\"Set meta parameters from base dataset.\"\"\"\n        self.target_dim = int(target_dim)\n        self.multivariate = self.target_dim > 1\n        self.freq = freq\n        self.lags_list = get_lags(self.freq)\n        self.prediction_length = prediction_length\n        self.context_length = self.context_length or self.prediction_length * self.context_length_factor\n        self.history_length = self.history_length or (self.context_length + max(self.lags_list))\n        \n    def _process_context_and_prediction_lengths(self):\n        \"\"\"Convert context and prediction lengths to lists for multi-horizon processing.\"\"\"\n        self.train_ctx_len_list = ensure_list(self.train_ctx_len, default_value=self.context_length)\n        self.val_ctx_len_list = ensure_list(self.val_ctx_len, default_value=self.context_length)\n        self.test_ctx_len_list = ensure_list(self.context_length)\n        self.train_pred_len_list = ensure_list(self.train_pred_len_list, default_value=self.prediction_length)\n        self.val_pred_len_list = ensure_list(self.val_pred_len_list, default_value=self.prediction_length)\n        self.test_pred_len_list = ensure_list(self.prediction_length)\n\n        # Validate context length support\n        assert len(self.train_ctx_len_list) == 1, \"Assign a single context length for training.\"\n        assert len(self.val_ctx_len_list) == 1, 
\"Assign a single context length for validation.\"\n        assert len(self.test_ctx_len_list) == 1, \"Assign a single context length for testing.\"\n\n        self.multi_hor = len(self.train_pred_len_list) > 1 or \\\n                         len(self.val_pred_len_list) > 1 or \\\n                         len(self.test_pred_len_list) > 1\n\n    def _load_long_term_dataset(self):\n        \"\"\"Load long-term dataset or customized dataset.\"\"\"\n        print(f\"Loading Long-term Dataset: {self.dataset}\")\n        if not self.context_length or not self.prediction_length:\n            raise ValueError(\"context_length or prediction_length must be specified.\")\n\n        data_path, self.freq = get_dataset_info(self.dataset, data_path=self.data_path, freq=self.freq)\n        self.dataset_raw, self.data_stamp, self.target_dim, data_size = load_dataset(\n            self.path, data_path, freq=self.freq, timeenc=self.timeenc, multivariate=self.multivariate\n        )\n        self.border_begin, self.border_end = get_dataset_borders(\n            self.dataset, data_size, train_ratio=self.train_ratio, test_ratio=self.test_ratio\n        )\n        self._set_meta_parameters_from_raw(data_size)\n        self.prepare_dataset()\n        \n    def _set_meta_parameters_from_raw(self, data_size):\n        \"\"\"Set meta parameters directly from raw dataset.\"\"\"\n        self.lags_list = get_lags(self.freq)\n        self.prediction_length = ensure_list(self.prediction_length) if self.multi_hor else self.prediction_length\n        self.context_length = ensure_list(self.context_length) if self.multi_hor else self.context_length\n        self.history_length = self.history_length or (\n            max(self.context_length) + max(self.lags_list) if self.multi_hor else self.context_length + max(self.lags_list)\n        )\n        if not self.multivariate:\n            self.target_dim = 1\n            raise NotImplementedError(\"Customized univariate datasets are not yet 
supported.\")\n        assert data_size >= self.border_end[2], \"border_end index exceeds dataset size!\"\n        \n        # define the test_rolling_length\n        if self.test_rolling_length == 'auto':\n            if self.freq.lower() in self.test_rolling_dict:\n                self.test_rolling_length = self.test_rolling_dict[self.freq.lower()]\n            else:\n                self.test_rolling_length = 24\n            \n\n    def prepare_dataset(self):\n        \"\"\"Prepare datasets for training, validation, and testing.\"\"\"\n        # Split raw data into train, validation, and test sets\n        train_data = self.dataset_raw[: self.border_end[0]]\n        val_data = self.dataset_raw[: self.border_end[1]]\n        test_data = self.dataset_raw[: self.border_end[2]]\n        \n        # Calculate statictics using training data\n        self.scaler.fit(torch.tensor(train_data.values))\n        \n        # Convert dataframes to multivariate datasets\n        train_set = df_to_mvds(train_data, freq=self.freq)\n        val_set = df_to_mvds(val_data,freq=self.freq)\n        test_set = df_to_mvds(test_data,freq=self.freq)\n        \n        train_grouper = MultivariateGrouper(max_target_dim=self.target_dim)\n        test_grouper = MultivariateGrouper(max_target_dim=self.target_dim)\n        \n        group_train_set = train_grouper(train_set)\n        group_val_set = test_grouper(val_set)\n        group_test_set = test_grouper(test_set)\n        \n        if self.multi_hor:\n            # Handle multi-horizon datasets\n            dataset_loader = self._prepare_multi_horizon_datasets(group_val_set, group_test_set)\n        else:\n            # Handle single-horizon datasets\n            dataset_loader = self._prepare_single_horizon_datasets(group_val_set, group_test_set)\n\n        self.train_iter_dataset = dataset_loader.get_iter_dataset(group_train_set, mode='train', data_stamp=self.data_stamp[: self.border_end[0]])\n        \n        self.time_feat_dim = 
dataset_loader.time_feat_dim\n        self.global_mean = torch.mean(torch.tensor(group_train_set[0]['target']), dim=-1)\n    \n    \n    def _prepare_multi_horizon_datasets(self, group_val_set, group_test_set):\n        \"\"\"Prepare multi-horizon datasets for validation and testing.\"\"\"\n        self.val_iter_dataset = {}\n        self.test_iter_dataset = {}\n        dataset_loader = MultiHorizonDataset(\n            input_names = ProbTSBatchData.input_names_,\n            freq = self.freq,\n            train_ctx_range = self.train_ctx_len_list,\n            train_pred_range = self.train_pred_len_list,\n            val_ctx_range = self.val_ctx_len_list,\n            val_pred_range = self.val_pred_len_list,\n            test_ctx_range = self.test_ctx_len_list,\n            test_pred_range = self.test_pred_len_list,\n            multivariate = self.multivariate,\n            continuous_sample = self.continuous_sample\n        )\n\n        # Prepare validation datasets\n        for pred_len in self.val_pred_len_list:\n            local_group_val_set = get_rolling_test(\n                'val', group_val_set, self.border_begin[1], self.border_end[1],\n                rolling_length=self.test_rolling_length, pred_len=pred_len, freq=self.freq\n            )\n            self.val_iter_dataset[str(pred_len)] = dataset_loader.get_iter_dataset(\n                local_group_val_set, mode='val', data_stamp=self.data_stamp[:self.border_end[1]], pred_len=[pred_len]\n            )\n\n        # Prepare testing datasets\n        for pred_len in self.test_pred_len_list:\n            local_group_test_set = get_rolling_test(\n                'test', group_test_set, self.border_begin[2], self.border_end[2],\n                rolling_length=self.test_rolling_length, pred_len=pred_len, freq=self.freq\n            )\n            self.test_iter_dataset[str(pred_len)] = dataset_loader.get_iter_dataset(\n                local_group_test_set, mode='test', 
data_stamp=self.data_stamp[:self.border_end[2]], pred_len=[pred_len], auto_search=self.auto_search,\n            )\n            \n        return dataset_loader\n    \n    def _prepare_single_horizon_datasets(self, group_val_set, group_test_set):\n        \"\"\"Prepare single-horizon datasets for training, validation, and testing.\"\"\"\n        dataset_loader = SingleHorizonDataset(\n            ProbTSBatchData.input_names_,\n            self.history_length,\n            self.context_length,\n            self.prediction_length,\n            self.freq,\n            self.multivariate,\n        )\n\n        # Validation dataset\n        local_group_val_set = get_rolling_test(\n            'val', group_val_set, self.border_begin[1], self.border_end[1],\n            rolling_length=self.test_rolling_length, pred_len=self.val_pred_len_list[0], freq=self.freq\n        )\n        self.val_iter_dataset = dataset_loader.get_iter_dataset(local_group_val_set, mode='val', data_stamp=self.data_stamp[:self.border_end[1]])\n\n        # Testing dataset\n        local_group_test_set = get_rolling_test(\n            'test', group_test_set, self.border_begin[2], self.border_end[2],\n            rolling_length=self.test_rolling_length, pred_len=self.prediction_length, freq=self.freq\n        )\n        self.test_iter_dataset = dataset_loader.get_iter_dataset(local_group_test_set, mode='test', data_stamp=self.data_stamp[:self.border_end[2]], auto_search=self.auto_search)\n\n        return dataset_loader\n    \n    def prepare_STSF_dataset(self, dataset: str):\n        \"\"\"Prepare datasets for short-term series forecasting.\"\"\"\n        if dataset in MULTI_VARIATE_DATASETS:\n            self.num_test_dates = int(len(self.dataset_raw.test)/len(self.dataset_raw.train))\n\n            train_grouper = MultivariateGrouper(max_target_dim=int(self.target_dim))\n            test_grouper = MultivariateGrouper(\n                num_test_dates=self.num_test_dates, \n                
max_target_dim=int(self.target_dim)\n            )\n            train_set = train_grouper(self.dataset_raw.train)\n            test_set = test_grouper(self.dataset_raw.test)\n            self.scaler.fit(torch.tensor(train_set[0]['target'].transpose(1, 0)))\n            self.global_mean = torch.mean(torch.tensor(train_set[0]['target']), dim=-1)\n            \n            # split_val\n            if self.split_val:\n                train_set, val_set = split_train_val(train_set, self.num_test_dates, self.context_length, self.prediction_length, self.freq)\n            else:\n                val_set = None\n        else:\n            self.target_dim = 1\n            self.multivariate = False\n            self.num_test_dates = 1\n            train_set = self.dataset_raw.train\n            test_set = self.dataset_raw.test\n            test_set = truncate_test(test_set, self.context_length, self.prediction_length, self.freq)\n            # for univariate dataset, e.g., M4 and M5, no validation set is used\n            val_set = None\n\n        if val_set is None:\n            print('No validation set is used.')\n            \n        dataset_loader = SingleHorizonDataset(\n            ProbTSBatchData.input_names_, \n            self.history_length,\n            self.context_length,\n            self.prediction_length,\n            self.freq,\n            self.multivariate\n        )\n\n        self.train_iter_dataset = dataset_loader.get_iter_dataset(train_set, mode='train')\n        if val_set is not None:\n            self.val_iter_dataset = dataset_loader.get_iter_dataset(val_set, mode='val')\n        else:\n            self.val_iter_dataset = None\n        self.test_iter_dataset = dataset_loader.get_iter_dataset(test_set, mode='test')\n        self.time_feat_dim = dataset_loader.time_feat_dim\n\n    def _print_configurations(self):\n        \"\"\"Print dataset and configuration details.\"\"\"\n        print(f\"Test context length: {self.test_ctx_len_list}, prediction 
length: {self.test_pred_len_list}\")\n        print(f\"Validation context length: {self.val_ctx_len_list}, prediction length: {self.val_pred_len_list}\")\n        print(f\"Training context length: {self.train_ctx_len_list}, prediction lengths: {self.train_pred_len_list}\")\n        print(f\"Test rolling length: {self.test_rolling_length}\")\n        if self.scaler_type == \"standard\":\n            print(f\"Variable-specific normalization: {self.var_specific_norm}\")\n\n    @cached_property\n    def is_gift_eval(self) -> bool:\n        return self.dataset[:5] == \"gift/\"\n    \n    @cached_property\n    def is_univar_dataset(self) -> bool:\n        if 'm4' in self.dataset or 'm5' in self.dataset:\n            return True\n        return False"
  },
  {
    "path": "probts/data/data_module.py",
    "content": "import torch\nimport lightning.pytorch as pl\nfrom torch.utils.data import DataLoader, Dataset\nfrom lightning.pytorch.utilities.combined_loader import CombinedLoader\nfrom probts.data.data_manager import DataManager\nfrom probts.data.data_wrapper import ProbTSBatchData\n\nclass EmptyDataset(Dataset):\n    def __len__(self):\n        return 0\n\n    def __getitem__(self, idx):\n        raise IndexError(\"This dataset is empty.\")\n\nclass ProbTSDataModule(pl.LightningDataModule):\n    r\"\"\"\n        DataModule for probablistic time series datasets.\n    \"\"\"\n    def __init__(\n        self,\n        data_manager: DataManager,\n        batch_size: int = 64,\n        test_batch_size: int = 8,\n        num_workers: int = 8\n    ):\n        super().__init__()\n        self.data_manager = data_manager\n        self.batch_size = batch_size\n        self.test_batch_size = test_batch_size\n        self.num_workers = num_workers\n        self.save_hyperparameters()\n\n        self.dataset_train = self.data_manager.train_iter_dataset\n        self.dataset_val = self.data_manager.val_iter_dataset\n        self.dataset_test = self.data_manager.test_iter_dataset\n\n    def train_dataloader(self):\n        if self.data_manager.multi_hor:\n                return DataLoader(\n                self.dataset_train,\n                batch_size=self.batch_size,\n                num_workers=0,\n                pin_memory=True,\n                collate_fn=self.train_collate_fn\n            )\n        else:\n            return DataLoader(\n                self.dataset_train,\n                batch_size=self.batch_size,\n                num_workers=self.num_workers,\n                persistent_workers=True,\n                pin_memory=True\n            )\n\n    def val_dataloader(self):\n        # if no validation set available\n        if self.dataset_val is None:\n            return DataLoader(EmptyDataset(), batch_size=1)\n        \n        if 
self.data_manager.multi_hor:\n            val_dataloader = self.combine_dataloader(self.dataset_val)\n        else:\n            val_dataloader = DataLoader(self.dataset_val, batch_size=self.test_batch_size, num_workers=1)\n        return val_dataloader\n\n    def test_dataloader(self):\n        if self.data_manager.multi_hor:\n            return self.combine_dataloader(self.dataset_test)\n        else:\n            return DataLoader(self.dataset_test, batch_size=self.test_batch_size, num_workers=1)\n\n    def predict_dataloader(self):\n        return DataLoader(self.dataset_test, batch_size=self.test_batch_size, num_workers=0)\n    \n    def combine_dataloader(self, dataset_dict):\n        dataloader_dict = {}\n        for hor in dataset_dict:\n            dataloader_dict[hor] = DataLoader(dataset_dict[hor], batch_size=self.test_batch_size, num_workers=0, persistent_workers=False,)\n        \n        combined_loader = CombinedLoader(dataloader_dict, mode=\"sequential\")\n        return combined_loader\n    \n    def train_collate_fn(self, batch):\n        '''\n        Training with varied horizons is achieved by padding horizons in training phase.\n        The look-back window for each sample can different within a batch.\n        '''\n        \n        past_len_list = [len(x['past_target_cdf']) for x in batch]\n        future_len_list = [len(x['future_target_cdf']) for x in batch]\n        \n        max_past_length = max(past_len_list)\n        max_future_length = max(future_len_list)\n        B = len(batch)\n        batch_dict = {}\n        batch_dict['context_length'] = []\n        batch_dict['prediction_length'] = []\n        batch_dict['target_dimension_indicator'] = []\n        \n        for idx in range(len(batch)):\n            local_past_len = len(batch[idx]['past_target_cdf'])\n            local_future_len = len(batch[idx]['future_target_cdf'])\n                \n            for input in ProbTSBatchData.input_names_:\n                K = 
batch[0][input].shape[-1]\n                if input in ['past_target_cdf','past_observed_values','past_time_feat','past_is_pad']:\n                    if input not in batch_dict and input in ['past_target_cdf','past_observed_values','past_time_feat']:\n                        batch_dict[input] = torch.zeros([B, max_past_length, K])\n                    if input not in batch_dict and input in ['past_is_pad']:\n                        batch_dict[input] = torch.zeros([B, max_past_length])\n                        \n                    batch_dict[input][idx][-local_past_len:] = torch.tensor(batch[idx][input])[:local_past_len]\n\n                elif input in ['future_target_cdf','future_observed_values','future_time_feat']:\n                    if input not in batch_dict:\n                        batch_dict[input] = torch.zeros([B, max_future_length, K])\n                    batch_dict[input][idx][:local_future_len] = torch.tensor(batch[idx][input])[:local_future_len]\n\n            batch_dict['target_dimension_indicator'].append(batch[idx]['target_dimension_indicator'])\n            batch_dict['context_length'].append(local_past_len)\n            batch_dict['prediction_length'].append(local_future_len)\n            \n        batch_dict['target_dimension_indicator'] = torch.tensor(batch_dict['target_dimension_indicator'])\n        \n        batch_dict['max_context_length'] = max_past_length\n        batch_dict['max_prediction_length'] = max_future_length\n        return batch_dict"
  },
  {
    "path": "probts/data/data_utils/data_scaler.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PyTorch-TS\n# - Source: https://github.com/zalandoresearch/pytorch-ts\n# - License: MIT, Apache-2.0 license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\n\nclass Scaler:\n    def __init__(self):\n        super().__init__()\n\n    def fit(self, values):\n        raise NotImplementedError\n\n    def transform(self, values):\n        raise NotImplementedError\n\n    def fit_transform(self, values):\n        raise NotImplementedError\n\n    def inverse_transform(self, values):\n        raise NotImplementedError\n\n\nclass StandardScaler(Scaler):\n    def __init__(\n        self,\n        mean: float = None,\n        std: float = None,\n        epsilon: float = 1e-9,\n        var_specific: bool = True\n    ):\n        \"\"\"\n        The class can be used to normalize PyTorch Tensors using native functions. The module does not expect the\n        tensors to be of any specific shape; as long as the features are the last dimension in the tensor, the module\n        will work fine.\n        \n        Args:\n            mean: The mean of the features. The property will be set after a call to fit.\n            std: The standard deviation of the features. 
The property will be set after a call to fit.\n            epsilon: Used to avoid a Division-By-Zero exception.\n            var_specific: If True, the mean and standard deviation will be computed per variate.\n        \"\"\"\n        self.mean = mean\n        self.scale = std\n        self.epsilon = epsilon\n        self.var_specific = var_specific\n\n    def fit(self, values):\n        \"\"\"\n        Args:\n            values: Input values should be a PyTorch tensor of shape (T, C) or (N, T, C), \n                where N is the batch size, T is the timesteps and C is the number of variates.\n        \"\"\"\n        dims = list(range(values.dim() - 1))\n        if not self.var_specific:\n            self.mean = torch.mean(values)\n            self.scale = torch.std(values)\n        else:\n            self.mean = torch.mean(values, dim=dims)\n            self.scale = torch.std(values, dim=dims)\n\n    def transform(self, values):\n        if self.mean is None:\n            return values\n\n        values = (values - self.mean.to(values.device)) / (self.scale.to(values.device) + self.epsilon)\n        return values.to(torch.float32)\n\n    def fit_transform(self, values):\n        self.fit(values)\n        return self.transform(values)\n\n    def inverse_transform(self, values):\n        if self.mean is None:\n            return values\n        \n        values = values * (self.scale.to(values.device) + self.epsilon)\n        values = values + self.mean.to(values.device)\n        return values\n\n\nclass TemporalScaler(Scaler):\n    def __init__(\n        self,\n        minimum_scale:float = 1e-10,\n        time_first: bool = True\n    ):\n        \"\"\"\n        The ``TemporalScaler`` computes a per-item scale according to the average\n        absolute value over time of each item. The average is computed only among\n        the observed values in the data tensor, as indicated by the second\n        argument. 
Items with no observed data are assigned a scale based on the\n        global average.\n\n        Args:\n            minimum_scale: default scale that is used if the time series has only zeros.\n            time_first: if True, the input tensor has shape (N, T, C), otherwise (N, C, T).\n        \"\"\"\n        super().__init__()\n        self.scale = None\n        self.minimum_scale = torch.tensor(minimum_scale)\n        self.time_first = time_first\n\n    def fit(\n        self,\n        data: torch.Tensor,\n        observed_indicator: torch.Tensor = None\n    ):\n        \"\"\"\n        Fit the scaler to the data.\n        \n        Args:\n            data: tensor of shape (N, T, C) if ``time_first == True`` or (N, C, T)\n                if ``time_first == False`` containing the data to be scaled\n\n            observed_indicator: observed_indicator: binary tensor with the same shape as\n                ``data``, that has 1 in correspondence of observed data points,\n                and 0 in correspondence of missing data points.\n\n        Note:\n            Tensor containing the scale, of shape (N, 1, C) or (N, C, 1).\n        \"\"\"\n        if self.time_first:\n            dim = -2\n        else:\n            dim = -1\n\n        if observed_indicator is None:\n            observed_indicator = torch.ones_like(data)\n\n        # These will have shape (N, C)\n        num_observed = observed_indicator.sum(dim=dim)\n        sum_observed = (data.abs() * observed_indicator).sum(dim=dim)\n\n        # First compute a global scale per-dimension\n        total_observed = num_observed.sum(dim=0)\n        denominator = torch.max(total_observed, torch.ones_like(total_observed))\n        default_scale = sum_observed.sum(dim=0) / denominator\n\n        # Then compute a per-item, per-dimension scale\n        denominator = torch.max(num_observed, torch.ones_like(num_observed))\n        scale = sum_observed / denominator\n\n        # Use per-batch scale when no element is 
observed\n        # or when the sequence contains only zeros\n        scale = torch.where(\n            sum_observed > torch.zeros_like(sum_observed),\n            scale,\n            default_scale * torch.ones_like(num_observed),\n        )\n\n        self.scale = torch.max(scale, self.minimum_scale).unsqueeze(dim=dim).detach()\n\n    def transform(self, data):\n        return data / self.scale.to(data.device)\n\n    def fit_transform(self, data, observed_indicator=None):\n        self.fit(data, observed_indicator)\n        return self.transform(data)\n\n    def inverse_transform(self, data):\n        return data * self.scale.to(data.device)\n\n\nclass IdentityScaler(Scaler):\n    \"\"\"\n    No scaling is applied upon calling the ``IdentityScaler``.\n    \"\"\"\n    def __init__(self, time_first: bool = True):\n        super().__init__()\n        self.scale = None\n        \n    def fit(self, data):\n        pass\n\n    def transform(self, data):\n        return data\n    \n    def inverse_transform(self, data):\n        return data\n    \nclass InstanceNorm(nn.Module):\n    def __init__(self, eps=1e-5):\n        \"\"\"\n        :param eps: a value added for numerical stability\n        \"\"\"\n        super(InstanceNorm, self).__init__()\n        self.eps = eps\n\n    def forward(self, x, mode:str):\n        if mode == 'norm':\n            self._get_statistics(x)\n            x = self._normalize(x)\n        elif mode == 'denorm':\n            x = self._denormalize(x)\n        else: raise NotImplementedError\n        return x\n\n    def _get_statistics(self, x):\n        dim2reduce = tuple(range(1, x.ndim-1))\n        self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()\n        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()\n\n    def _normalize(self, x):\n        x = x - self.mean\n        x = x / self.stdev\n        return x\n\n    def _denormalize(self, x):\n        x = x * self.stdev\n      
  x = x + self.mean\n        return x\n\n"
  },
  {
    "path": "probts/data/data_utils/data_utils.py",
    "content": "from copy import deepcopy\nimport math\nimport pandas as pd\nimport numpy as np\nfrom datetime import datetime\nfrom distutils.util import strtobool\nfrom gluonts.dataset.common import ListDataset\nfrom gluonts.dataset.field_names import FieldName\n\n\ndef split_train_val(train_set, num_test_windows, context_length, prediction_length, freq):\n    \"\"\"\n    Splits a training dataset into a truncated training set and a validation set.\n\n    Parameters:\n    - train_set: The input training dataset.\n    - num_test_windows: Number of rolling windows for validation.\n    - context_length: Context length for the model.\n    - prediction_length: Prediction horizon for the model.\n    - freq: Data frequency (e.g., 'H' for hourly).\n\n    Returns:\n    - trunc_train_set: Truncated training dataset (ListDataset).\n    - val_set: Validation dataset (ListDataset).\n    \"\"\"\n    trunc_train_list = []\n    val_set_list = []\n    univariate = False\n\n    for train_seq in iter(train_set):\n        # truncate train set\n        offset = num_test_windows * prediction_length\n        trunc_train_seq = deepcopy(train_seq)\n\n        if len(train_seq[FieldName.TARGET].shape) == 1:\n            trunc_train_len = train_seq[FieldName.TARGET].shape[0] - offset\n            trunc_train_seq[FieldName.TARGET] = train_seq[FieldName.TARGET][:trunc_train_len]\n            univariate = True\n        elif len(train_seq[FieldName.TARGET].shape) == 2:\n            trunc_train_len = train_seq[FieldName.TARGET].shape[1] - offset\n            trunc_train_seq[FieldName.TARGET] = train_seq[FieldName.TARGET][:, :trunc_train_len]\n        else:\n            raise ValueError(f\"Invalid Data Shape: {str(len(train_seq[FieldName.TARGET].shape))}\")\n\n        trunc_train_list.append(trunc_train_seq)\n\n        # construct val set by rolling\n        for i in range(num_test_windows):\n            val_seq = deepcopy(train_seq)\n            rolling_len = trunc_train_len + prediction_length 
* (i+1)\n            if univariate:\n                val_seq[FieldName.TARGET] = val_seq[FieldName.TARGET][trunc_train_len + prediction_length * (i-1) - context_length : rolling_len]\n            else:\n                val_seq[FieldName.TARGET] = val_seq[FieldName.TARGET][:, :rolling_len]\n            \n            val_set_list.append(val_seq)\n\n    trunc_train_set = ListDataset(\n        trunc_train_list, freq=freq, one_dim_target=univariate\n    )\n\n    val_set = ListDataset(\n        val_set_list, freq=freq, one_dim_target=univariate\n    )\n    \n    return trunc_train_set, val_set\n\n\ndef truncate_test(test_set, context_length, prediction_length, freq):\n    \"\"\"\n    Truncates the test dataset to ensure only the last context and prediction lengths are retained.\n\n    Parameters:\n    - test_set: The input test dataset.\n    - context_length: Context length for the model.\n    - prediction_length: Prediction horizon for the model.\n    - freq: Data frequency.\n\n    Returns:\n    - trunc_test_set: Truncated test dataset (ListDataset).\n    \"\"\"\n    trunc_test_list = []\n    for test_seq in iter(test_set):\n        # truncate train set\n        trunc_test_seq = deepcopy(test_seq)\n\n        trunc_test_seq[FieldName.TARGET] = trunc_test_seq[FieldName.TARGET][- (prediction_length * 2 + context_length):]\n\n        trunc_test_list.append(trunc_test_seq)\n\n    trunc_test_set = ListDataset(\n        trunc_test_list, freq=freq, one_dim_target=True\n    )\n\n    return trunc_test_set\n\n\ndef get_rolling_test(stage, test_set, border_begin_idx, border_end_idx, rolling_length, pred_len, freq=None):\n    \"\"\"\n    Using rolling windows to build the test dataset.\n\n    Parameters:\n    - stage: Stage name (e.g., 'test', 'val').\n    - test_set: The test dataset.\n    - border_begin_idx: Start index for rolling windows.\n    - border_end_idx: End index for rolling windows.\n    - rolling_length: Gap length of each rolling window.\n    - pred_len: Prediction 
length.\n    - freq: Data frequency.\n\n    Returns:\n    - rolling_test_set: Rolling test dataset (ListDataset).\n    \"\"\"\n    num_test_windows = math.ceil(((border_end_idx - border_begin_idx - pred_len) / rolling_length))\n    print(f\"{stage}  pred_len: {pred_len} : num_test_windows: {num_test_windows}\")\n\n    test_set = next(iter(test_set))\n    rolling_test_seq_list = list()\n    for i in range(num_test_windows):\n        rolling_test_seq = deepcopy(test_set)\n        rolling_end = border_begin_idx + pred_len + i * rolling_length\n        rolling_test_seq[FieldName.TARGET] = rolling_test_seq[FieldName.TARGET][:, :rolling_end]\n        rolling_test_seq_list.append(rolling_test_seq)\n\n    rolling_test_set = ListDataset(\n        rolling_test_seq_list, freq=freq, one_dim_target=False\n    )\n    return rolling_test_set\n\n\ndef get_rolling_test_of_gift_eval(dataset, prediction_length, windows):\n    \"\"\"\n    Using rolling windows to build the test dataset for GiftEval.\n    https://github.com/SalesforceAIResearch/gift-eval/blob/61ec5e563188bc4b2d7e86f6a7fcc78270607ae7/src/gift_eval/data.py#L213\n    Get the windows from the back of the dataset, for example if the dataset has N time points:\n    - The first window will be from the first time point to the N - prediction_length * windows time point.\n    - The second window will be from the first time point to the N - prediction_length * (windows - 1) time point.\n    - The last window will be from the first time point to the N time point.\n\n    Parameters:\n    - dataset: The input dataset.\n    - prediction_length: Prediction length.\n    - windows: Number of rolling windows.\n\n    Returns:\n    - rolling_test_set: Rolling test dataset (ListDataset).\n    \"\"\"\n    rolling_test_seq_list = list()\n    dataset = next(iter(dataset))\n    if \"freq\" not in dataset.keys():\n        raise ValueError(\"The dataset must contain the 'freq' key.\")\n    freq = dataset[\"freq\"]\n    is_univariate = 
len(dataset[FieldName.TARGET].shape) == 1\n\n    for i in range(windows):\n        rolling_test_seq = deepcopy(dataset)\n        rolling_end = dataset[FieldName.TARGET].shape[-1] - prediction_length * (windows - i)\n        if is_univariate:\n            rolling_test_seq[FieldName.TARGET] = dataset[FieldName.TARGET][:rolling_end]\n        elif len(dataset[FieldName.TARGET].shape) == 2:\n            rolling_test_seq[FieldName.TARGET] = dataset[FieldName.TARGET][:, :rolling_end]\n        else:\n            raise ValueError(f\"Invalid Data Shape: expected 1 or 2 dimensions, got {len(dataset[FieldName.TARGET].shape)}\")\n        rolling_test_seq_list.append(rolling_test_seq)\n\n    rolling_test_set = ListDataset(\n        rolling_test_seq_list, freq=freq, one_dim_target=is_univariate\n    )\n    return rolling_test_set\n\n\n\ndef df_to_mvds(df, freq='H'):\n    \"\"\"\n    Converts a pandas DataFrame to a multivariate ListDataset for GluonTS.\n\n    Parameters:\n    - df: Input DataFrame where columns represent time series variables.\n    - freq: Data frequency (e.g., 'H' for hourly).\n\n    Returns:\n    - dataset: Multivariate ListDataset.\n    \"\"\"\n    datasets = []\n    for variable in df.keys():\n        ds = {\"item_id\" : variable, \"target\" : df[variable], \"start\": str(df.index[0])}\n        datasets.append(ds)\n    dataset = ListDataset(datasets,freq=freq)\n    return dataset\n\n\ndef convert_monash_data_to_dataframe(\n    full_file_path_and_name,\n    replace_missing_vals_with=\"NaN\",\n    value_column_name=\"series_value\",\n):\n    col_names = []\n    col_types = []\n    all_data = {}\n    line_count = 0\n    frequency = None\n    forecast_horizon = None\n    contain_missing_values = None\n    contain_equal_length = None\n    found_data_tag = False\n    found_data_section = False\n    started_reading_data_section = False\n\n    with open(full_file_path_and_name, \"r\", encoding=\"cp1252\") as file:\n        for line in file:\n            # Strip white 
space from start/end of line\n            line = line.strip()\n\n            if line:\n                if line.startswith(\"@\"):  # Read meta-data\n                    if not line.startswith(\"@data\"):\n                        line_content = line.split(\" \")\n                        if line.startswith(\"@attribute\"):\n                            if (\n                                len(line_content) != 3\n                            ):  # Attributes have both name and type\n                                raise Exception(\"Invalid meta-data specification.\")\n\n                            col_names.append(line_content[1])\n                            col_types.append(line_content[2])\n                        else:\n                            if (\n                                len(line_content) != 2\n                            ):  # Other meta-data have only values\n                                raise Exception(\"Invalid meta-data specification.\")\n\n                            if line.startswith(\"@frequency\"):\n                                frequency = line_content[1]\n                            elif line.startswith(\"@horizon\"):\n                                forecast_horizon = int(line_content[1])\n                            elif line.startswith(\"@missing\"):\n                                contain_missing_values = bool(\n                                    strtobool(line_content[1])\n                                )\n                            elif line.startswith(\"@equallength\"):\n                                contain_equal_length = bool(strtobool(line_content[1]))\n\n                    else:\n                        if len(col_names) == 0:\n                            raise Exception(\n                                \"Missing attribute section. 
Attribute section must come before data.\"\n                            )\n\n                        found_data_tag = True\n                elif not line.startswith(\"#\"):\n                    if len(col_names) == 0:\n                        raise Exception(\n                            \"Missing attribute section. Attribute section must come before data.\"\n                        )\n                    elif not found_data_tag:\n                        raise Exception(\"Missing @data tag.\")\n                    else:\n                        if not started_reading_data_section:\n                            started_reading_data_section = True\n                            found_data_section = True\n                            all_series = []\n\n                            for col in col_names:\n                                all_data[col] = []\n\n                        full_info = line.split(\":\")\n\n                        if len(full_info) != (len(col_names) + 1):\n                            raise Exception(\"Missing attributes/values in series.\")\n\n                        series = full_info[len(full_info) - 1]\n                        series = series.split(\",\")\n\n                        if len(series) == 0:\n                            raise Exception(\n                                \"A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series. Missing values should be indicated with ? 
symbol\"\n                            )\n\n                        numeric_series = []\n\n                        for val in series:\n                            if val == \"?\":\n                                numeric_series.append(replace_missing_vals_with)\n                            else:\n                                numeric_series.append(float(val))\n\n                        if numeric_series.count(replace_missing_vals_with) == len(\n                            numeric_series\n                        ):\n                            raise Exception(\n                                \"All series values are missing. A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series.\"\n                            )\n\n                        all_series.append(pd.Series(numeric_series).array)\n\n                        for i in range(len(col_names)):\n                            att_val = None\n                            if col_types[i] == \"numeric\":\n                                att_val = int(full_info[i])\n                            elif col_types[i] == \"string\":\n                                att_val = str(full_info[i])\n                            elif col_types[i] == \"date\":\n                                att_val = datetime.strptime(\n                                    full_info[i], \"%Y-%m-%d %H-%M-%S\"\n                                )\n                            else:\n                                raise Exception(\n                                    \"Invalid attribute type.\"\n                                )  # Currently, the code supports only numeric, string and date types. 
Extend this as required.\n\n                            if att_val is None:\n                                raise Exception(\"Invalid attribute value.\")\n                            else:\n                                all_data[col_names[i]].append(att_val)\n\n                line_count = line_count + 1\n\n        if line_count == 0:\n            raise Exception(\"Empty file.\")\n        if len(col_names) == 0:\n            raise Exception(\"Missing attribute section.\")\n        if not found_data_section:\n            raise Exception(\"Missing series information under data section.\")\n\n        all_data[value_column_name] = all_series\n        loaded_data = pd.DataFrame(all_data)\n\n        return (\n            loaded_data,\n            frequency,\n            forecast_horizon,\n            contain_missing_values,\n            contain_equal_length,\n        )\n\ndef monash_format_convert(loaded_data, frequency, multivariate):\n    series_names = loaded_data['series_name'].values\n\n    if str(frequency) == '10_minutes':\n        freq = '10min'\n    elif str(frequency) == 'daily':\n        freq = 'D'\n    else:\n        freq = frequency\n\n    if multivariate:\n        timestamps = pd.date_range(start=loaded_data['start_timestamp'][0], periods=len(loaded_data['series_value'][0]), freq=freq)\n        new_df = pd.DataFrame({ 'date': timestamps })\n\n        series_df = pd.DataFrame({ series: loaded_data['series_value'][i] for i, series in enumerate(series_names) })\n        result_df = pd.concat([new_df, series_df], axis=1)\n    else:\n        result = []\n        for idx, row in loaded_data.iterrows():\n            result.append({\n                'target': np.array(row['series_value'], dtype=np.float32),\n                'start': pd.Period(row['start_timestamp'], freq=freq),\n                'feat_static_cat': np.array([idx], dtype=np.int32),\n                'item_id': idx,\n            })\n        result_df = pd.DataFrame(result)\n    return result_df"
  },
  {
    "path": "probts/data/data_utils/get_datasets.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from Autoformer\n# - Source: https://github.com/thuml/Autoformer/tree/main\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport os\nimport pandas as pd\nfrom probts.data.data_utils.time_features import time_features\nfrom probts.data.data_utils.data_utils import convert_monash_data_to_dataframe, monash_format_convert\nimport numpy as np\n\n\ndef get_dataset_info(dataset, data_path=None, freq=None):\n    \"\"\"\n    Get the file path and frequency associated with the specified dataset.\n    Parameters:\n        dataset (str): The name of the dataset.\n        data_path (str): Optional custom data path for the dataset.\n        freq (str): Optional custom frequency for the dataset.\n    Returns:\n        tuple: A tuple containing the data path and frequency.\n    \"\"\"\n    paths = {\n        'etth1': ('ETT-small/ETTh1.csv', 'H'),\n        'etth2': ('ETT-small/ETTh2.csv', 'H'),\n        'ettm1': ('ETT-small/ETTm1.csv', 'min'),\n        'ettm2': ('ETT-small/ETTm2.csv', 'min'),\n        'traffic_ltsf': ('traffic/traffic.csv', 'H'),\n        'electricity_ltsf': ('electricity/electricity.csv', 'H'),\n        'exchange_ltsf': ('exchange_rate/exchange_rate.csv', 'B'),\n        'illness_ltsf': ('illness/national_illness.csv', 'W'),\n        'weather_ltsf': ('weather/weather.csv', 'min'),\n        'caiso': ('caiso/caiso_20130101_20210630.csv', 'H'),\n        'nordpool': ('nordpool/production.csv', 'H'),\n        'turkey_power': ('kaggle/power Generation and consumption.csv', 'H'),\n        'istanbul_traffic': ('kaggle/istanbul_traffic.csv', 'H')\n    }\n    \n    if dataset in paths:\n        data_path, freq = paths[dataset]\n    else:\n        assert data_path is not None, f'Invalid dataset name: {dataset}! 
Provide --data.data_manager.init_args.data_path for custom datasets.'\n        assert freq is not None, 'Provide --data.data_manager.init_args.freq for custom datasets.'\n    return data_path, freq\n\ndef get_dataset_borders(dataset, data_size, train_ratio=0.7, test_ratio=0.2):\n    \"\"\"\n    Compute the start and end indices for train, validation, and test splits.\n    Parameters:\n        dataset (str): The name of the dataset.\n        data_size (int): Total number of time points in the dataset.\n        train_ratio (float): Proportion of the dataset used for training.\n        test_ratio (float): Proportion of the dataset used for testing.\n    Returns:\n        tuple: Two lists representing the start and end indices of each split.\n    \"\"\"\n    # Validate ratios\n    assert 0 < train_ratio <= 1, \"train_ratio must be between 0 and 1 (exclusive of 0).\"\n    assert 0 < test_ratio <= 1, \"test_ratio must be between 0 and 1 (exclusive of 0).\"\n    assert train_ratio + test_ratio <= 1, \"The sum of train_ratio and test_ratio must not exceed 1.\"\n\n    # Predefined borders for ETT datasets\n    if dataset == 'etth1' or dataset == 'etth2':\n        border_begin = [0, 12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24]\n        border_end = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]\n    elif dataset == 'ettm1' or dataset == 'ettm2':\n        border_begin = [0, 12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4]\n        border_end = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]\n    else:\n        # Calculate borders for custom datasets\n        num_train = int(data_size * train_ratio)\n        num_test = int(data_size * test_ratio)\n        num_vali = data_size - num_train - num_test\n        border_begin = [0, num_train, data_size - num_test]\n        border_end = [num_train, num_train + num_vali, data_size]\n    return border_begin, border_end\n\ndef load_dataset(root_path, 
data_path,freq='h', timeenc=1, multivariate=True):\n    \"\"\"\n    Load and process datasets.\n    Parameters:\n        root_path (str): Root directory for datasets.\n        data_path (str): Path to the specific dataset.\n        freq (str): Frequency of the dataset (e.g., 'H', 'min').\n        timeenc (int): Time encoding method (0 for temporal information, 1 for time feature based on frequency, 2 for raw date information).\n        multivariate (bool): Whether the dataset is multivariate.\n    Returns:\n        df_raw: the processed DataFrame\n        data_stamp: time features\n        target_dim: target dimensions\n        data_size: total length of timestamps.\n    \"\"\"\n    data_format = None\n    if '.tsf' in data_path:\n        # Load Monash time series dataset\n        df_raw, _, _, _, _ = convert_monash_data_to_dataframe(data_path)\n        df_raw = monash_format_convert(df_raw, freq, multivariate)\n        \n        if multivariate:\n            if freq.lower() == 'h':\n                df_raw.set_index('date', inplace=True)\n                df_raw = df_raw.resample(freq).mean().reset_index()\n    elif 'caiso' in data_path:\n        # Load and process CAISO dataset\n        data = pd.read_csv(os.path.join(root_path, data_path))\n        data['Date'] = data['Date'].astype('datetime64[ns]')\n        names = ['PGE','SCE','SDGE','VEA','CA ISO','PACE','PACW','NEVP','AZPS','PSEI']\n        df_raw = pd.DataFrame(pd.date_range('20130101','20210630',freq='H')[:-1], columns=['Date'])\n        for name in names:\n            current_df = data[data['zone'] == name].drop_duplicates(subset='Date', keep='last').rename(columns={'load':name}).drop(columns=['zone'])\n            df_raw = df_raw.merge(current_df, on='Date', how='outer')\n        df_raw = df_raw.rename(columns={'Date': 'date'})\n    elif 'nordpool' in data_path:\n        # Load and process Nordpool dataset\n        df_raw = pd.read_csv(os.path.join(root_path, data_path), parse_dates=['Time'])\n        
df_raw = df_raw.rename(columns={'Time': 'date'})\n    elif 'power Generation and consumption' in data_path:\n        # Load and process Turkey Power dataset\n        df_raw = pd.read_csv(os.path.join(root_path, data_path), parse_dates=['Date_Time'])\n        df_raw = df_raw.rename(columns={'Date_Time': 'date'})\n        data_format = \"%d.%m.%Y %H:%M\"\n    elif 'istanbul_traffic' in data_path:\n        # Load and process Istanbul Traffic dataset\n        df_raw = pd.read_csv(os.path.join(root_path, data_path), parse_dates=['datetime'])\n        df_raw = df_raw.rename(columns={'datetime': 'date'})\n        df_raw.set_index('date', inplace=True)\n        df_raw = df_raw.resample(freq).mean().reset_index()\n    else:\n        # Load customized dataset\n        df_raw = pd.read_csv(os.path.join(root_path, data_path), parse_dates=['date'])\n    \n    # Process time encoding\n    if multivariate:\n        df_stamp = df_raw[['date']]\n        df_stamp['date'] = pd.to_datetime(df_stamp.date, format=data_format)\n        \n        if timeenc == 0:\n            df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)\n            df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)\n            df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)\n            df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)\n            df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1)\n            df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15)\n            data_stamp = df_stamp.drop(labels='date', axis=1).values\n        elif timeenc == 1:\n            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=freq)\n            data_stamp = data_stamp.transpose(1, 0)\n        elif timeenc == 2:\n            data_stamp = pd.to_datetime(df_stamp['date'].values)\n            data_stamp = np.array(data_stamp, dtype='datetime64[s]')\n        else:\n            raise ValueError('Invalid timeenc 
value. timeenc should be sellected within [0, 1, 2].')\n        df_raw = df_raw.set_index(keys='date')\n        \n    else:\n        data_stamp = None\n    \n    # Replace missing values with 0\n    df_raw = df_raw.fillna(0)\n    # Determine target dimension and dataset size\n    target_dim = len(df_raw.columns) if multivariate else 1\n    data_size = len(df_raw)\n    return df_raw, data_stamp, target_dim, data_size"
  },
  {
    "path": "probts/data/data_utils/time_features.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from GluonTS\n# - Source: https://github.com/awslabs/gluonts\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\n\nfrom typing import List\n\nimport numpy as np\nimport pandas as pd\nfrom pandas.tseries import offsets\nfrom pandas.tseries.frequencies import to_offset\nfrom gluonts.core.component import validated\nfrom gluonts.dataset.common import DataEntry\nfrom gluonts.transform import MapTransformation\nfrom typing import List, Type\n\nclass TimeFeature:\n    def __init__(self):\n        pass\n\n    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:\n        pass\n\n    def __repr__(self):\n        return self.__class__.__name__ + \"()\"\n\n\nclass SecondOfMinute(TimeFeature):\n    \"\"\"Minute of hour encoded as value between [-0.5, 0.5]\"\"\"\n\n    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:\n        return index.second / 59.0 - 0.5\n\n\nclass MinuteOfHour(TimeFeature):\n    \"\"\"Minute of hour encoded as value between [-0.5, 0.5]\"\"\"\n\n    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:\n        return index.minute / 59.0 - 0.5\n\n\nclass HourOfDay(TimeFeature):\n    \"\"\"Hour of day encoded as value between [-0.5, 0.5]\"\"\"\n\n    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:\n        return index.hour / 23.0 - 0.5\n\n\nclass DayOfWeek(TimeFeature):\n    \"\"\"Hour of day encoded as value between [-0.5, 0.5]\"\"\"\n\n    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:\n        return index.dayofweek / 6.0 - 0.5\n\n\nclass DayOfMonth(TimeFeature):\n    \"\"\"Day of month encoded as value between [-0.5, 0.5]\"\"\"\n\n    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:\n        return (index.day - 1) / 30.0 - 0.5\n\n\nclass DayOfYear(TimeFeature):\n    \"\"\"Day 
of year encoded as value between [-0.5, 0.5]\"\"\"\n\n    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:\n        return (index.dayofyear - 1) / 365.0 - 0.5\n\n\nclass MonthOfYear(TimeFeature):\n    \"\"\"Month of year encoded as value between [-0.5, 0.5]\"\"\"\n\n    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:\n        return (index.month - 1) / 11.0 - 0.5\n\n\nclass WeekOfYear(TimeFeature):\n    \"\"\"Week of year encoded as value between [-0.5, 0.5]\"\"\"\n\n    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:\n        return (index.isocalendar().week - 1) / 52.0 - 0.5\n\n\ndef time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:\n    \"\"\"\n    Returns a list of time features that will be appropriate for the given frequency string.\n    Parameters\n    ----------\n    freq_str\n        Frequency string of the form [multiple][granularity] such as \"12H\", \"5min\", \"1D\" etc.\n    \"\"\"\n\n    features_by_offsets = {\n        offsets.YearEnd: [],\n        offsets.QuarterEnd: [MonthOfYear],\n        offsets.MonthEnd: [MonthOfYear],\n        offsets.Week: [DayOfMonth, WeekOfYear],\n        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],\n        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],\n        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],\n        offsets.Minute: [\n            MinuteOfHour,\n            HourOfDay,\n            DayOfWeek,\n            DayOfMonth,\n            DayOfYear,\n        ],\n        offsets.Second: [\n            SecondOfMinute,\n            MinuteOfHour,\n            HourOfDay,\n            DayOfWeek,\n            DayOfMonth,\n            DayOfYear,\n        ],\n    }\n\n    offset = to_offset(freq_str)\n\n    for offset_type, feature_classes in features_by_offsets.items():\n        if isinstance(offset, offset_type):\n            return [cls() for cls in feature_classes]\n\n    supported_freq_msg = f\"\"\"\n    Unsupported frequency {freq_str}\n    
The following frequencies are supported:\n        Y   - yearly\n            alias: A\n        M   - monthly\n        W   - weekly\n        D   - daily\n        B   - business days\n        H   - hourly\n        T   - minutely\n            alias: min\n        S   - secondly\n    \"\"\"\n    raise RuntimeError(supported_freq_msg)\n\n\ndef time_features(dates, freq='h'):\n    return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])\n\n\nclass FourierDateFeatures(TimeFeature):\n    def __init__(self, freq: str) -> None:\n        super().__init__()\n        # reocurring freq\n        freqs = [\n            \"month\",\n            \"day\",\n            \"hour\",\n            \"minute\",\n            \"weekofyear\",\n            \"weekday\",\n            \"dayofweek\",\n            \"dayofyear\",\n            \"daysinmonth\",\n        ]\n\n        assert freq in freqs\n        self.freq = freq\n\n    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:\n        values = getattr(index, self.freq)\n        num_values = max(values) + 1\n        steps = [x * 2.0 * np.pi / num_values for x in values]\n        return np.vstack([np.cos(steps), np.sin(steps)])\n\n\ndef norm_freq_str(freq_str: str) -> str:\n    base_freq = freq_str.split(\"-\")[0]\n\n    # Pandas has start and end frequencies, e.g `AS` and `A` for yearly start\n    # and yearly end frequencies. 
We don't make that difference and instead\n    # rely only on the end frequencies which don't have the `S` prefix.\n    # Note: Secondly (\"S\") frequency exists, where we don't want to remove the\n    # \"S\"!\n    if len(base_freq) >= 2 and base_freq.endswith(\"S\"):\n        return base_freq[:-1]\n\n    return base_freq\n\n\ndef fourier_time_features_from_frequency(freq_str: str) -> List[TimeFeature]:\n    offset = to_offset(freq_str)\n    granularity = norm_freq_str(offset.name)\n    granularity = granularity.upper()\n    features = {\n        \"M\": [\"weekofyear\"],\n        \"W\": [\"daysinmonth\", \"weekofyear\"],\n        \"D\": [\"dayofweek\"],\n        \"B\": [\"dayofweek\", \"dayofyear\"],\n        \"H\": [\"hour\", \"dayofweek\"],\n        \"min\": [\"minute\", \"hour\", \"dayofweek\"],\n        \"T\": [\"minute\", \"hour\", \"dayofweek\"],\n    }\n\n    assert granularity in features, f\"freq {granularity} not supported\"\n\n    feature_classes: List[TimeFeature] = [\n        FourierDateFeatures(freq=freq) for freq in features[granularity]\n    ]\n    return feature_classes\n\n\ndef get_lags(freq_str:str):\n    \"\"\"\n    Calculate appropriate lag values for time series forecasting based on data frequency.\n\n    Parameters\n    ----------\n    freq_str : str\n        The frequency of the time series data. 
Supported values include:\n\n    Returns\n    -------\n    lags : list[int]\n        A list of lag values, representing the offsets of past observations to include in the model.\n        The lags are tailored to capture autocorrelation and seasonality patterns for the specified frequency.\n\n    Examples\n    --------\n    >>> get_lags(\"H\")\n    [1, 24, 168]  # Captures hourly, daily, and weekly seasonality\n\n    >>> get_lags(\"D\")\n    [1, 7, 14]  # Captures daily, weekly, and bi-weekly seasonality\n    \"\"\"\n    freq_str = freq_str.upper()\n    if freq_str == \"M\":\n        lags = [1, 12]\n    elif freq_str == \"D\":\n        lags = [1, 7, 14]\n    elif freq_str == \"B\":\n        lags = [1, 2]\n    elif freq_str == \"H\":\n        lags = [1, 24, 168]\n    elif freq_str in (\"T\", \"min\"):\n        lags = [1, 4, 12, 24, 48]\n    else:\n        lags = [1]\n\n    return lags\n\n\ndef target_transformation_length(\n    target: np.ndarray, pred_length: int, is_train: bool\n) -> int:\n    return target.shape[-1] + (0 if is_train else pred_length)\n\n\nclass AddCustomizedTimeFeatures(MapTransformation):\n    \"\"\"\n    Adds a set of time features.\n\n    If `is_train=True` the feature matrix has the same length as the `target`\n    field. 
If `is_train=False` the feature matrix has length\n    `len(target) + pred_length`\n\n    Parameters\n    ----------\n    start_field\n        Field with the start time stamp of the time series\n    target_field\n        Field with the array containing the time series values\n    output_field\n        Field name for result.\n    time_features\n        list of time features to use.\n    pred_length\n        Prediction length\n    \"\"\"\n\n    @validated()\n    def __init__(\n        self,\n        start_field: str,\n        target_field: str,\n        output_field: str,\n        time_features,\n        pred_length: int,\n        dtype: Type = np.float32,\n    ) -> None:\n        self.date_features = time_features\n        self.pred_length = pred_length\n        self.start_field = start_field\n        self.target_field = target_field\n        self.output_field = output_field\n        self.dtype = dtype\n\n    def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:\n        length = target_transformation_length(\n            data[self.target_field], self.pred_length, is_train=is_train\n        )\n\n        if len(self.date_features.shape) == 2:\n            data[self.output_field] = self.date_features[:length].astype(self.dtype)\n        else:\n            data[self.output_field] = self.date_features[:length].astype(np.float64)\n        data[self.output_field] = self.date_features[:length].astype(np.float64)\n        data[self.output_field] = np.transpose(data[self.output_field])\n        \n        return data\n"
  },
  {
    "path": "probts/data/data_wrapper.py",
    "content": "import torch\n\nclass ProbTSBatchData:\n    input_names_ = [\n        'target_dimension_indicator',\n        'past_time_feat',\n        'past_target_cdf',\n        'past_observed_values',\n        'past_is_pad',\n        'future_time_feat',\n        'future_target_cdf',\n        'future_observed_values',\n    ]\n    \n    def __init__(self, data_dict, device):\n        # Initialize attributes from the provided data dictionary\n        self.__dict__.update(data_dict)\n        self.__dict__['context_length'] = data_dict.get('context_length', None)\n        self.__dict__['prediction_length'] = data_dict.get('prediction_length', None)\n        self.__dict__['max_context_length'] = data_dict.get('max_context_length', None)\n        self.__dict__['max_prediction_length'] = data_dict.get('max_prediction_length', None)\n        \n        # Expand dimensions for univariate data\n        if len(self.__dict__['past_target_cdf'].shape) == 2:\n            self._expand_dimensions()\n        \n        # Set tensors to the specified device\n        self._set_device(device)\n        # Fill missing inputs with None\n        self._ensure_all_inputs_present()\n        # Process padding for observed values\n        self._process_padding()\n\n    def _ensure_all_inputs_present(self):\n        \"\"\"Ensure all expected inputs are present in the data.\"\"\"\n        for input in self.input_names_:\n            if input not in self.__dict__:\n                self.__dict__[input] = None\n\n    def _set_device(self, device):\n        \"\"\"Move all tensors to the specified device.\"\"\"\n        for k, v in self.__dict__.items():\n            if v is not None and torch.is_tensor(v):\n                v.to(device)\n        self.device = device\n\n    def _expand_dimensions(self):\n        \"\"\"Expand dimensions for target-related tensors if necessary.\"\"\"\n        self.__dict__[\"target_dimension_indicator\"] = self.__dict__[\"target_dimension_indicator\"][:, :1]\n        
for input in ['past_target_cdf','past_observed_values','future_target_cdf','future_observed_values']:\n            self.__dict__[input] = self.__dict__[input].unsqueeze(-1)\n\n    def _process_padding(self):\n        \"\"\"Adjust observed values based on the padding indicator.\"\"\"\n        if self.__dict__['past_is_pad'] is not None:\n            self.__dict__[\"past_observed_values\"] = torch.min(\n                self.__dict__[\"past_observed_values\"],\n                1 - self.__dict__[\"past_is_pad\"].unsqueeze(-1)\n            )\n\n"
  },
  {
    "path": "probts/data/datasets/gift_eval_datasets.py",
    "content": "# Copyright (c) 2023, Salesforce, Inc.\n# SPDX-License-Identifier: Apache-2\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport math\nfrom functools import cached_property\nfrom enum import Enum\nfrom pathlib import Path\nfrom typing import Iterable, Iterator\n\nimport datasets\nfrom dotenv import load_dotenv\nfrom gluonts.dataset import DataEntry\nfrom gluonts.dataset.common import ProcessDataEntry\nfrom gluonts.dataset.split import TestData, TrainingDataset, split\nfrom gluonts.itertools import Map\nfrom gluonts.time_feature import norm_freq_str\nfrom gluonts.transform import Transformation\nfrom pandas.tseries.frequencies import to_offset\nimport pyarrow.compute as pc\nfrom toolz import compose\n\n# add for probts transform\nfrom probts.data.data_utils.data_utils import get_rolling_test_of_gift_eval\n\nTEST_SPLIT = 0.1\nMAX_WINDOW = 20\n\nM4_PRED_LENGTH_MAP = {\n    \"A\": 6,\n    \"Q\": 8,\n    \"M\": 18,\n    \"W\": 13,\n    \"D\": 14,\n    \"H\": 48,\n}\n\nPRED_LENGTH_MAP = {\n    \"M\": 12,\n    \"W\": 8,\n    \"D\": 30,\n    \"H\": 48,\n    \"T\": 48,\n    \"S\": 60,\n}\n\nTFB_PRED_LENGTH_MAP = {\n    \"A\": 6,\n    \"H\": 48,\n    \"Q\": 8,\n    \"D\": 14,\n    \"M\": 18,\n    \"W\": 13,\n    \"U\": 8,\n    \"T\": 8,\n}\n\n\nclass Term(Enum):\n    SHORT = \"short\"\n    MEDIUM = \"medium\"\n    LONG = \"long\"\n\n    @property\n    def multiplier(self) -> int:\n        if self == Term.SHORT:\n            return 1\n      
  elif self == Term.MEDIUM:\n            return 10\n        elif self == Term.LONG:\n            return 15\n\n\ndef itemize_start(data_entry: DataEntry) -> DataEntry:\n    data_entry[\"start\"] = data_entry[\"start\"].item()\n    return data_entry\n\n\nclass MultivariateToUnivariate(Transformation):\n    def __init__(self, field):\n        self.field = field\n\n    def __call__(\n        self, data_it: Iterable[DataEntry], is_train: bool = False\n    ) -> Iterator:\n        for data_entry in data_it:\n            item_id = data_entry[\"item_id\"]\n            val_ls = list(data_entry[self.field])\n            for id, val in enumerate(val_ls):\n                data_entry[self.field] = val\n                data_entry[\"item_id\"] = item_id + \"_dim\" + str(id)\n                yield data_entry\n\n\nclass GiftEvalDataset:\n    def __init__(\n        self,\n        name: str,\n        term: Term | str = Term.SHORT,\n        to_univariate: bool = False,\n        storage_env_var: str = \"GIFT_EVAL\",\n    ):\n        self.term = Term(term)\n        self.name = name\n        self.to_univariate = to_univariate\n\n        load_dotenv()\n        storage_path = Path(os.getenv(storage_env_var))\n        self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format(\n            \"numpy\"\n        )\n\n    @cached_property\n    def gluonts_dataset(self):\n        process = ProcessDataEntry(\n            self.freq,\n            one_dim_target=self.target_dim == 1,\n        )\n        gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)\n        if self.to_univariate:\n            gluonts_dataset = MultivariateToUnivariate(\"target\").apply(\n                gluonts_dataset\n            )\n        return gluonts_dataset\n\n    @cached_property\n    def prediction_length(self) -> int:\n        freq = norm_freq_str(to_offset(self.freq).name)\n        pred_len = (\n            M4_PRED_LENGTH_MAP[freq] if \"m4\" in self.name else 
PRED_LENGTH_MAP[freq]\n        )\n        return self.term.multiplier * pred_len\n\n    @cached_property\n    def freq(self) -> str:\n        return self.hf_dataset[0][\"freq\"]\n\n    @cached_property\n    def target_dim(self) -> int:\n        return (\n            target.shape[0]\n            if len((target := self.hf_dataset[0][\"target\"]).shape) > 1\n            else 1\n        )\n\n    @cached_property\n    def target_ndim(self) -> int:\n        return 1 if self.target_dim == 1 else 2\n\n    @cached_property\n    def past_feat_dynamic_real_dim(self) -> int:\n        if \"past_feat_dynamic_real\" not in self.hf_dataset[0]:\n            return 0\n        elif (\n            len(\n                (\n                    past_feat_dynamic_real := self.hf_dataset[0][\n                        \"past_feat_dynamic_real\"\n                    ]\n                ).shape\n            )\n            > 1\n        ):\n            return past_feat_dynamic_real.shape[0]\n        else:\n            return 1\n\n    @cached_property\n    def windows(self) -> int:\n        if \"m4\" in self.name:\n            return 1\n        w = math.ceil(TEST_SPLIT * self._min_series_length / self.prediction_length)\n        return min(max(1, w), MAX_WINDOW)\n\n    @cached_property\n    def _min_series_length(self) -> int:\n        if self.hf_dataset[0][\"target\"].ndim > 1:\n            lengths = pc.list_value_length(\n                pc.list_flatten(\n                    pc.list_slice(self.hf_dataset.data.column(\"target\"), 0, 1)\n                )\n            )\n        else:\n            lengths = pc.list_value_length(self.hf_dataset.data.column(\"target\"))\n        return min(lengths.to_numpy())\n\n    @cached_property\n    def sum_series_length(self) -> int:\n        if self.hf_dataset[0][\"target\"].ndim > 1:\n            lengths = pc.list_value_length(\n                pc.list_flatten(self.hf_dataset.data.column(\"target\"))\n            )\n        else:\n            lengths = 
pc.list_value_length(self.hf_dataset.data.column(\"target\"))\n        return sum(lengths.to_numpy())\n\n    @property\n    def training_dataset(self) -> TrainingDataset:\n        training_dataset, _ = split(\n            self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1)\n        )\n        return training_dataset\n\n    @property\n    def validation_dataset(self) -> TrainingDataset:\n        validation_dataset, _ = split(\n            self.gluonts_dataset, offset=-self.prediction_length * self.windows\n        )\n        return validation_dataset\n\n    @property\n    def test_dataset(self) -> TrainingDataset:\n        print(f\"BETA version: generating test datasets for gift eval, should contain {self.windows} windows.\")\n        test_dataset = get_rolling_test_of_gift_eval(\n            dataset=self.gluonts_dataset,\n            prediction_length=self.prediction_length,\n            windows=self.windows,\n        )\n        return test_dataset\n\n    @property\n    def test_data(self) -> TestData:\n        _, test_template = split(\n            self.gluonts_dataset, offset=-self.prediction_length * self.windows\n        )\n        test_data = test_template.generate_instances(\n            prediction_length=self.prediction_length,\n            windows=self.windows,\n            distance=self.prediction_length,\n        )\n        return test_data\n"
  },
  {
    "path": "probts/data/datasets/multi_horizon_datasets.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from GluonTS\n# - Source: https://github.com/awslabs/gluonts\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\nfrom torch.utils.data import IterableDataset\nfrom gluonts.env import env\nfrom gluonts.dataset.common import Dataset\nfrom gluonts.dataset.field_names import FieldName\nfrom gluonts.transform import (\n    SelectFields,\n    Transformation,\n    Chain,\n    ValidationSplitSampler,\n    ExpectedNumInstanceSampler,\n    RenameFields,\n    AsNumpyArray,\n    ExpandDimArray,\n    AddObservedValuesIndicator,\n    AddTimeFeatures,\n    VstackFeatures,\n    SetFieldIfNotPresent,\n    TargetDimIndicator,\n    InstanceSplitter\n)\nfrom gluonts.dataset.common import DataEntry\nfrom gluonts.transform import InstanceSampler\nfrom gluonts.zebras._util import pad_axis\nfrom gluonts.dataset.common import DataEntry\nfrom gluonts.transform._base import FlatMapTransformation\n\nfrom probts.data.data_utils.time_features import fourier_time_features_from_frequency, AddCustomizedTimeFeatures\nfrom probts.data.datasets.single_horizon_datasets import TransformedIterableDataset\nfrom typing import Union\nfrom typing import Iterator, List, Optional, Tuple, Union\nimport numpy as np\nimport random\n\n\nclass MultiHorizonDataset():\n    \"\"\"\n    MultiHorizonDataset: Supports multi-horizon forecasting by enabling flexible context and prediction lengths.\n\n    Parameters:\n    ----------\n    input_names : list\n        Names of input fields required by the model.\n    freq : str\n        Frequency of the data (e.g., 'H' for hourly, 'D' for daily).\n    train_ctx_range : Union[int, list]\n        Range of context lengths for the training dataset.\n    train_pred_range : Union[int, list]\n        Range of prediction lengths for the training dataset.\n    
val_ctx_range : Union[int, list]\n        Range of context lengths for the validation dataset.\n    val_pred_range : Union[int, list]\n        Range of prediction lengths for the validation dataset.\n    test_ctx_range : Union[int, list]\n        Range of context lengths for the testing dataset.\n    test_pred_range : Union[int, list]\n        Range of prediction lengths for the testing dataset.\n    multivariate : bool, optional, default=True\n        Whether the dataset contains multiple target variables.\n    continuous_sample : bool, optional, default=False\n        Whether to enable continuous sampling horizons from the train_pred_range.\n    \"\"\"\n    def __init__(\n        self,\n        input_names: list,\n        freq: str,\n        train_ctx_range: Union[int, list],\n        train_pred_range: Union[int, list],\n        val_ctx_range: Union[int, list],\n        val_pred_range: Union[int, list],\n        test_ctx_range: Union[int, list],\n        test_pred_range: Union[int, list],\n        multivariate: bool = True,\n        continuous_sample: bool = False,\n    ):\n        super().__init__()\n        self.input_names_ = input_names\n        self.train_ctx_range = train_ctx_range\n        self.train_pred_range = train_pred_range\n        self.val_ctx_range = val_ctx_range\n        self.val_pred_range = val_pred_range\n        self.test_ctx_range = test_ctx_range\n        self.test_pred_range=test_pred_range\n        self.continuous_sample = continuous_sample\n        \n        self.freq = freq\n        if multivariate:\n            self.expected_ndim = 2\n        else:\n            self.expected_ndim = 1\n\n    def get_sampler(self):\n        \"\"\"\n        Creates samplers for training, validation, and testing datasets.\n        Samplers control how data instances are selected for each mode.\n        \"\"\"\n        \n        # for training\n        train_min_past = min(self.train_ctx_range)\n        train_min_future = min(self.train_pred_range)\n       
 \n        # for validation\n        val_min_past = max(self.val_ctx_range)\n        val_min_future = max(self.val_pred_range)\n        \n        # for testing\n        if (type(self.test_ctx_range).__name__=='list'):\n            test_min_past = max(self.test_ctx_range)\n        else:\n            test_min_past=self.test_ctx_range\n        \n        if (type(self.test_pred_range).__name__=='list'):\n            test_min_future = max(self.test_pred_range)\n        else:\n            test_min_future=self.test_pred_range\n\n        self.train_sampler = ExpectedNumInstanceSampler(\n            num_instances=1.0,\n            min_past=train_min_past,\n            min_future=train_min_future,\n        )\n\n        self.val_sampler = ValidationSplitSampler(\n            min_past=val_min_past,\n            min_future=val_min_future,\n        )\n        \n        self.test_sampler = ValidationSplitSampler(\n            min_past=test_min_past,\n            min_future=test_min_future,\n        )\n\n        \n    def create_transformation(self, data_stamp=None, pred_len=None) -> Transformation:\n        \"\"\"\n        Creates a transformation pipeline for data preprocessing.\n\n        Parameters:\n        ----------\n        data_stamp : np.array, optional\n            Precomputed time features. If None, features are generated based on the frequency.\n        pred_len : int, optional\n            Prediction length for the transformation. 
If None, uses the maximum training prediction range.\n\n        Returns:\n        ----------\n        Chain : Transformation\n            A chain of transformations applied to the dataset.\n        \"\"\"\n        if data_stamp is None:\n            if self.freq in [\"M\", \"W\", \"D\", \"B\", \"H\", \"min\", \"T\"]:\n                time_features = fourier_time_features_from_frequency(self.freq)\n            else:\n                time_features = fourier_time_features_from_frequency('D')\n            self.time_feat_dim = len(time_features) * 2\n            time_feature_func = AddTimeFeatures\n        else:\n            self.time_feat_dim = data_stamp.shape[-1]\n            time_features = data_stamp\n            time_feature_func = AddCustomizedTimeFeatures\n            \n        if pred_len is None:\n            pred_len = max(self.train_pred_range)\n        else:\n            pred_len = max(pred_len)\n            \n        return Chain(\n            [\n                AsNumpyArray(\n                    field=FieldName.TARGET,\n                    expected_ndim=self.expected_ndim,\n                ),\n                ExpandDimArray(\n                    field=FieldName.TARGET,\n                    axis=None,\n                ),\n                AddObservedValuesIndicator(\n                    target_field=FieldName.TARGET,\n                    output_field=FieldName.OBSERVED_VALUES,\n                ),\n                time_feature_func(\n                    start_field=FieldName.START,\n                    target_field=FieldName.TARGET,\n                    output_field=FieldName.FEAT_TIME,\n                    time_features=time_features,\n                    pred_length=pred_len,\n                ),\n                VstackFeatures(\n                    output_field=FieldName.FEAT_TIME,\n                    input_fields=[FieldName.FEAT_TIME],\n                ),\n                SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]),\n                
TargetDimIndicator(\n                    field_name=\"target_dimension_indicator\",\n                    target_field=FieldName.TARGET,\n                ),\n                AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1),\n            ]\n        )\n\n    def create_instance_splitter(self, mode: str, pred_len=None, auto_search=False):\n        \"\"\"\n        Creates an instance splitter for slicing data sequences.\n\n        Parameters:\n        ----------\n        mode : str\n            Dataset mode. Must be one of ['train', 'val', 'test'].\n        pred_len : list, optional\n            Prediction length for validation or testing. If None, defaults to the predefined ranges.\n\n        Returns:\n        ----------\n        MultiHorizonSplitter : Transformation\n            Transformation that slices time series sequences.\n        \"\"\"\n        assert mode in [\"train\", \"val\", \"test\"]\n\n        self.get_sampler()\n        instance_sampler = {\n            \"train\": self.train_sampler,\n            \"val\": self.val_sampler,\n            \"test\": self.test_sampler,\n        }[mode]\n\n        if mode == \"train\":\n            past_length = self.train_ctx_range\n            future_length = self.train_pred_range\n        elif mode == 'val':\n            past_length = self.val_ctx_range\n            if pred_len is None:\n                future_length = self.val_pred_range\n            else:\n                future_length = pred_len\n        else:\n            if pred_len is None:\n                future_length = self.test_pred_range\n            else:\n                future_length = pred_len\n                \n            if auto_search:\n                past_length = [max(self.test_ctx_range) + max(future_length)]\n            else:\n                past_length = self.test_ctx_range\n            \n            \n        return MultiHorizonSplitter(\n            target_field=FieldName.TARGET,\n            is_pad_field=FieldName.IS_PAD,\n     
       start_field=FieldName.START,\n            forecast_start_field=FieldName.FORECAST_START,\n            instance_sampler=instance_sampler,\n            past_length=past_length,\n            future_length=future_length,\n            mode=mode,\n            continuous_sample=self.continuous_sample,\n            time_series_fields=[\n                FieldName.FEAT_TIME,\n                FieldName.OBSERVED_VALUES,\n            ],\n        ) + (\n            RenameFields(\n                {\n                    f\"past_{FieldName.TARGET}\": f\"past_{FieldName.TARGET}_cdf\",\n                    f\"future_{FieldName.TARGET}\": f\"future_{FieldName.TARGET}_cdf\",\n                }\n            )\n        )\n\n\n    def get_iter_dataset(self, dataset: Dataset, mode: str, data_stamp=None, pred_len=None, auto_search=False) -> IterableDataset:\n        \"\"\"\n        Creates an iterable dataset with applied transformations and splitters.\n\n        Parameters:\n        ----------\n        dataset : Dataset\n            Input dataset to transform.\n        mode : str\n            Mode of operation. 
Must be one of ['train', 'val', 'test'].\n        data_stamp : np.array, optional\n            Precomputed time features.\n        pred_len : list, optional\n            Prediction length for validation or testing.\n\n        Returns:\n        ----------\n        IterableDataset : TransformedIterableDataset\n            Transformed dataset ready for model training or evaluation.\n        \"\"\"\n        assert mode in [\"train\", \"val\", \"test\"]\n\n        transform = self.create_transformation(data_stamp, pred_len=pred_len)\n            \n            \n        if mode == 'train':\n            with env._let(max_idle_transforms=100):\n                instance_splitter = self.create_instance_splitter(mode)\n        else:\n            instance_splitter = self.create_instance_splitter(mode, pred_len=pred_len, auto_search=auto_search)\n\n\n        input_names = self.input_names_\n\n        iter_dataset = TransformedIterableDataset(\n            dataset,\n            transform=transform\n            + instance_splitter\n            + SelectFields(input_names),\n            is_train=True if mode == 'train' else False\n        )\n\n        return iter_dataset\n\n\nclass MultiHorizonSplitter(FlatMapTransformation):\n    \"\"\"\n    Split instances from a dataset, by slicing the target and other time series\n    fields at points in time selected by the specified sampler. 
The assumption\n    is that all time series fields start at the same time point.\n\n    It is assumed that time axis is always the last axis.\n\n    The ``target_field`` and each field in ``time_series_fields`` are removed and\n    replaced by two new fields, with prefix `past_` and `future_` respectively.\n\n    A ``past_is_pad`` is also added, that indicates whether values at a given\n    time point are padding or not.\n\n    Parameters\n    ----------\n\n    target_field\n        field containing the target\n    is_pad_field\n        output field indicating whether padding happened\n    start_field\n        field containing the start date of the time series\n    forecast_start_field\n        output field that will contain the time point where the forecast starts\n    instance_sampler\n        instance sampler that provides sampling indices given a time series\n    past_length\n        length of the target seen before making prediction\n    future_length\n        length of the target that must be predicted\n    lead_time\n        gap between the past and future windows (default: 0)\n    output_NTC\n        whether to have time series output in (time, dimension) or in\n        (dimension, time) layout (default: True)\n    time_series_fields\n        fields that contains time series, they are split in the same interval\n        as the target (default: None)\n    dummy_value\n        Value to use for padding. 
(default: 0.0)\n    \"\"\"\n\n    # @validated()\n    def __init__(\n        self,\n        target_field: str,\n        is_pad_field: str,\n        start_field: str,\n        forecast_start_field: str,\n        instance_sampler: InstanceSampler,\n        past_length: Union[int, list],\n        future_length: Union[int, list],\n        mode: str,\n        lead_time: int = 0,\n        output_NTC: bool = True,\n        time_series_fields: List[str] = [],\n        dummy_value: float = 0.0,\n        continuous_sample: bool = False,\n    ) -> None:\n        super().__init__()\n\n        # assert future_length > 0, \"The value of `future_length` should be > 0\"\n\n        self.instance_sampler = instance_sampler\n        self.past_length = past_length\n        self.future_length = future_length\n        self.continuous_sample = continuous_sample\n        \n        self.lead_time = lead_time\n        self.output_NTC = output_NTC\n        self.ts_fields = time_series_fields\n        self.target_field = target_field\n        self.is_pad_field = is_pad_field\n        self.start_field = start_field\n        self.forecast_start_field = forecast_start_field\n        self.dummy_value = dummy_value\n        self.mode = mode\n\n    def _past(self, col_name):\n        return f\"past_{col_name}\"\n\n    def _future(self, col_name):\n        return f\"future_{col_name}\"\n\n    def _split_array(\n        self, array: np.ndarray, idx: int, past_length: int, future_length: int\n    ) -> Tuple[np.ndarray, np.ndarray]:\n        if idx >= past_length:\n            past_piece = array[..., idx - past_length : idx]\n        else:\n            past_piece = pad_axis(\n                array[..., :idx],\n                axis=-1,\n                left=past_length - idx,\n                value=self.dummy_value,\n            )\n\n        future_start = idx + self.lead_time\n        future_slice = slice(future_start, future_start + future_length)\n        future_piece = array[..., future_slice]\n\n   
     return past_piece, future_piece\n\n    def _split_instance(self, entry: DataEntry, idx: int, is_train) -> DataEntry:\n        slice_cols = self.ts_fields + [self.target_field]\n        dtype = entry[self.target_field].dtype\n        entry = entry.copy()\n        \n        if is_train:\n            if self.continuous_sample:\n                past_len = random.randint(min(self.past_length), max(self.past_length))\n                pred_len = random.randint(min(self.future_length), max(self.future_length))\n            else:\n                past_len = random.choice(self.past_length) \n                pred_len = random.choice(self.future_length) \n        else:\n            past_len = max(self.past_length)\n            pred_len = max(self.future_length)\n\n        for ts_field in slice_cols:\n            past_piece, future_piece = self._split_array(entry[ts_field], idx, past_length=past_len, future_length=pred_len)\n\n            if self.output_NTC:\n                past_piece = past_piece.transpose()\n                future_piece = future_piece.transpose()\n\n            entry[self._past(ts_field)] = past_piece\n            entry[self._future(ts_field)] = future_piece\n            del entry[ts_field]\n\n        pad_indicator = np.zeros(past_len, dtype=dtype)\n        pad_length = max(past_len - idx, 0)\n        pad_indicator[:pad_length] = 1\n\n        entry[self._past(self.is_pad_field)] = pad_indicator\n        entry[self.forecast_start_field] = (\n            entry[self.start_field] + idx + self.lead_time\n        )\n        entry['context_length'] = past_len\n        entry['prediction_length'] = pred_len\n\n        return entry\n\n    def flatmap_transform(\n            self, entry: DataEntry, is_train: bool\n        ) -> Iterator[DataEntry]:\n        sampled_indices = self.instance_sampler(entry[self.target_field])\n        \n        for idx in sampled_indices:\n            yield self._split_instance(entry, idx, is_train)"
  },
  {
    "path": "probts/data/datasets/single_horizon_datasets.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PyTorch-TS\n# - Source: https://github.com/zalandoresearch/pytorch-ts\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nfrom torch.utils.data import IterableDataset\nfrom gluonts.env import env\nfrom gluonts.dataset.common import Dataset\nfrom gluonts.itertools import Cyclic\nfrom gluonts.dataset.field_names import FieldName\nfrom gluonts.transform import (\n    SelectFields,\n    Transformation,\n    Chain,\n    InstanceSplitter,\n    ValidationSplitSampler,\n    ExpectedNumInstanceSampler,\n    RenameFields,\n    AsNumpyArray,\n    ExpandDimArray,\n    AddObservedValuesIndicator,\n    AddTimeFeatures,\n    VstackFeatures,\n    SetFieldIfNotPresent,\n    TargetDimIndicator,\n    TransformedDataset,\n)\nfrom probts.data.data_utils.time_features import fourier_time_features_from_frequency, AddCustomizedTimeFeatures\n\n\nclass SingleHorizonDataset():\n    \"\"\"\n    SingleHorizonDataset: Handles dataset transformation and instance splitting for single-horizon forecasting tasks.\n\n    Parameters:\n    ----------\n    input_names : list\n        List of input field names required by the model.\n    history_length : int\n        Length of the historical time series window for input data.\n    prediction_length : int\n        Length of the forecasting horizon.\n    freq : str\n        Data frequency (e.g., 'H' for hourly, 'D' for daily).\n    multivariate : bool, optional, default=True\n        Indicates if the dataset contains multiple target variables.\n    \"\"\"\n    def __init__(\n        self,\n        input_names: list,\n        history_length: int,\n        context_length: int,\n        prediction_length: int,\n        freq: str,\n        multivariate: bool = True\n    ):\n        super().__init__()\n        self.input_names_ = 
input_names\n        self.history_length = history_length\n        self.context_length = context_length\n        self.prediction_length = prediction_length\n        self.freq = freq\n        if multivariate:\n            self.expected_ndim = 2\n        else:\n            self.expected_ndim = 1\n\n    def get_sampler(self):\n        \"\"\"\n        Creates samplers for training, validation, and testing.\n        - Training: Generates instances randomly.\n        - Validation and Testing: Always selects the last time point.\n        \"\"\"\n        # returns a set of indices at which training instances will be generated\n        self.train_sampler = ExpectedNumInstanceSampler(\n            num_instances=1.0,\n            min_past=self.history_length,\n            min_future=self.prediction_length,\n        )\n\n        self.val_sampler = ValidationSplitSampler(\n            min_past=self.history_length,\n            min_future=self.prediction_length,\n        )\n        \n        self.test_sampler = ValidationSplitSampler(\n            min_past=self.history_length,\n            min_future=self.prediction_length,\n        )\n\n\n    def create_transformation(self, data_stamp=None) -> Transformation:\n        \"\"\"\n        Creates a data transformation pipeline to prepare inputs for the model.\n        Adds features such as time attributes and observed value indicators.\n\n        Parameters:\n        ----------\n        data_stamp : np.array, optional\n            Precomputed time features. 
If None, features are generated based on the data frequency.\n\n        Returns:\n        ----------\n        Chain : Transformation\n            A chain of transformations applied to the dataset.\n        \"\"\"\n        if data_stamp is None:\n            if self.freq in [\"M\", \"W\", \"D\", \"B\", \"H\", \"min\", \"T\"]:\n                time_features = fourier_time_features_from_frequency(self.freq)\n            else:\n                time_features = fourier_time_features_from_frequency('D')\n            self.time_feat_dim = len(time_features) * 2\n            time_feature_func = AddTimeFeatures\n        else:\n            self.time_feat_dim = data_stamp.shape[-1]\n            time_features = data_stamp\n            time_feature_func = AddCustomizedTimeFeatures\n\n        return Chain(\n            [\n                AsNumpyArray(\n                    field=FieldName.TARGET,\n                    expected_ndim=self.expected_ndim,\n                ),\n                AddObservedValuesIndicator(\n                    target_field=FieldName.TARGET,\n                    output_field=FieldName.OBSERVED_VALUES,\n                ),\n                time_feature_func(\n                    start_field=FieldName.START,\n                    target_field=FieldName.TARGET,\n                    output_field=FieldName.FEAT_TIME,\n                    time_features=time_features,\n                    pred_length=self.prediction_length,\n                ),\n                VstackFeatures(\n                    output_field=FieldName.FEAT_TIME,\n                    input_fields=[FieldName.FEAT_TIME],\n                ),\n                SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]),\n                TargetDimIndicator(\n                    field_name=\"target_dimension_indicator\",\n                    target_field=FieldName.TARGET,\n                ),\n                AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1),\n            ]\n        )\n\n   
 def create_instance_splitter(self, mode: str, auto_search=False):\n        \"\"\"\n        Creates an instance splitter for training, validation, or testing.\n\n        Parameters:\n        ----------\n        mode : str\n            Mode of operation. Must be one of ['train', 'val', 'test'].\n\n        Returns:\n        ----------\n        InstanceSplitter : Transformation\n            A splitter transformation that slices input data for model training or evaluation.\n        \"\"\"\n        assert mode in [\"train\", \"val\", \"test\"]\n\n        self.get_sampler()\n        instance_sampler = {\n            \"train\": self.train_sampler,\n            \"val\": self.val_sampler,\n            \"test\": self.test_sampler,\n        }[mode]\n\n        if auto_search:\n            past_length = self.context_length + self.prediction_length\n        else:\n            past_length=self.history_length\n        \n        return InstanceSplitter(\n            target_field=FieldName.TARGET,\n            is_pad_field=FieldName.IS_PAD,\n            start_field=FieldName.START,\n            forecast_start_field=FieldName.FORECAST_START,\n            instance_sampler=instance_sampler,\n            past_length=past_length,\n            future_length=self.prediction_length,\n            time_series_fields=[\n                FieldName.FEAT_TIME,\n                FieldName.OBSERVED_VALUES,\n            ],\n        ) + (\n            RenameFields(\n                {\n                    f\"past_{FieldName.TARGET}\": f\"past_{FieldName.TARGET}_cdf\",\n                    f\"future_{FieldName.TARGET}\": f\"future_{FieldName.TARGET}_cdf\",\n                }\n            )\n        )\n\n    def get_iter_dataset(self, dataset: Dataset, mode: str, data_stamp=None, auto_search=False) -> IterableDataset:\n        \"\"\"\n        Creates an iterable dataset for training, validation, or testing.\n\n        Parameters:\n        ----------\n        dataset : Dataset\n            Input dataset to 
transform.\n        mode : str\n            Mode of operation. Must be one of ['train', 'val', 'test'].\n        data_stamp : np.array, optional\n            Precomputed time features.\n\n        Returns:\n        ----------\n        IterableDataset : TransformedIterableDataset\n            Transformed dataset with applied transformations and instance splitting.\n        \"\"\"\n        assert mode in [\"train\", \"val\", \"test\"]\n\n        transform = self.create_transformation(data_stamp)\n        if mode == 'train':\n            with env._let(max_idle_transforms=100):\n                instance_splitter = self.create_instance_splitter(mode)\n        else:\n            instance_splitter = self.create_instance_splitter(mode, auto_search=auto_search)\n\n\n        input_names = self.input_names_\n\n        iter_dataset = TransformedIterableDataset(\n            dataset,\n            transform=transform\n            + instance_splitter\n            + SelectFields(input_names),\n            is_train=True if mode == 'train' else False\n        )\n\n        return iter_dataset\n\n\n\nclass TransformedIterableDataset(IterableDataset):\n    \"\"\"\n    A transformed iterable dataset that applies a transformation pipeline on-the-fly.\n\n    Parameters:\n    ----------\n    dataset : Dataset\n        The original dataset to transform.\n    transform : Transformation\n        The transformation pipeline to apply.\n    is_train : bool, optional, default=True\n        Whether the dataset is used for training.\n    \"\"\"\n    def __init__(\n        self,\n        dataset: Dataset,\n        transform: Transformation,\n        is_train: bool = True\n    ):\n        super().__init__()\n\n        self.transformed_dataset = TransformedDataset(\n            Cyclic(dataset) if is_train else dataset,\n            transform,\n            is_train=is_train,\n        )\n\n    def __iter__(self):\n        return iter(self.transformed_dataset)"
  },
  {
    "path": "probts/model/__init__.py",
    "content": "from .forecast_module import *"
  },
  {
    "path": "probts/model/forecast_module.py",
    "content": "import numpy as np\nimport torch\nfrom torch import optim\nfrom typing import Dict\nimport lightning.pytorch as pl\nimport sys\n\nfrom probts.data import ProbTSBatchData\nfrom probts.data.data_utils.data_scaler import Scaler\nfrom probts.model.forecaster import Forecaster\nfrom probts.utils.evaluator import Evaluator\nfrom probts.utils.metrics import *\nfrom probts.utils.save_utils import update_metrics, calculate_weighted_average, load_checkpoint, get_hor_str\nfrom probts.utils.utils import init_class_helper\n\ndef get_weights(sampling_weight_scheme, max_hor):\n    '''\n    return: w [max_hor]\n    '''\n    if sampling_weight_scheme == 'random':\n        i_array = np.linspace(1 + 1e-5, max_hor - 1e-3, max_hor)\n        w = (1 / max_hor) * (np.log(max_hor) - np.log(i_array))\n    elif sampling_weight_scheme == 'const':\n        w = np.array([1 / max_hor] * max_hor)\n    elif sampling_weight_scheme == 'none':\n        return None\n    else:\n        raise ValueError(f\"Invalid sampling scheme {sampling_weight_scheme}.\")\n    \n    return torch.tensor(w)\n\n\nclass ProbTSForecastModule(pl.LightningModule):\n    def __init__(\n        self,\n        forecaster: Forecaster,\n        scaler: Scaler = None,\n        train_pred_len_list: list = None,\n        num_samples: int = 100,\n        learning_rate: float = 1e-3,\n        quantiles_num: int = 10,\n        load_from_ckpt: str = None,\n        sampling_weight_scheme: str = 'none',\n        optimizer_config = None,\n        lr_scheduler_config = None,\n        **kwargs\n    ):\n        super().__init__()\n        self.num_samples = num_samples\n        self.learning_rate = learning_rate\n        self.load_from_ckpt = load_from_ckpt\n        self.train_pred_len_list = train_pred_len_list\n        self.forecaster = forecaster\n        self.optimizer_config = optimizer_config\n        self.scheduler_config = lr_scheduler_config\n        \n        if self.optimizer_config is not None:\n            
print(\"optimizer config: \", self.optimizer_config)\n            \n        if self.scheduler_config is not None:\n            print(\"lr_scheduler config: \", self.scheduler_config)\n        \n        self.scaler = scaler\n        self.evaluator = Evaluator(quantiles_num=quantiles_num)\n        \n        # init the parapemetr for sampling\n        self.sampling_weight_scheme = sampling_weight_scheme\n        print(f'sampling_weight_scheme: {sampling_weight_scheme}')\n        self.save_hyperparameters()\n\n    @classmethod\n    def load_from_checkpoint(self, checkpoint_path, scaler=None, learning_rate=None, no_training=False, **kwargs):\n        model = load_checkpoint(self, checkpoint_path, scaler=scaler, learning_rate=learning_rate, no_training=no_training, **kwargs)\n        return model\n\n    def training_forward(self, batch_data):\n        batch_data.past_target_cdf = self.scaler.transform(batch_data.past_target_cdf)\n        batch_data.future_target_cdf = self.scaler.transform(batch_data.future_target_cdf)\n        loss = self.forecaster.loss(batch_data)\n\n        if len(loss.shape) > 1:\n            loss_weights = get_weights(self.sampling_weight_scheme, loss.shape[1])\n            loss = (loss_weights.detach().to(loss.device).unsqueeze(0).unsqueeze(-1) * loss).sum(dim=1)\n            loss = loss.mean()\n        \n        return loss\n\n    def training_step(self, batch, batch_idx):\n        batch_data = ProbTSBatchData(batch, self.device)\n        loss = self.training_forward(batch_data)\n        self.log(\"train_loss\", loss, on_step=True, prog_bar=True, logger=True)\n        return loss\n\n    def evaluate(self, batch, stage='',dataloader_idx=None):\n        batch_data = ProbTSBatchData(batch, self.device)\n        pred_len = batch_data.future_target_cdf.shape[1]\n        orin_past_data = batch_data.past_target_cdf[:]\n        orin_future_data = batch_data.future_target_cdf[:]\n\n        norm_past_data = 
self.scaler.transform(batch_data.past_target_cdf)\n        norm_future_data = self.scaler.transform(batch_data.future_target_cdf)\n        self.batch_size.append(orin_past_data.shape[0])\n        \n        batch_data.past_target_cdf = self.scaler.transform(batch_data.past_target_cdf)\n        forecasts = self.forecaster.forecast(batch_data, self.num_samples)[:,:, :pred_len]\n        \n        # Calculate denorm metrics\n        denorm_forecasts = self.scaler.inverse_transform(forecasts)\n        metrics = self.evaluator(orin_future_data, denorm_forecasts, past_data=orin_past_data, freq=self.forecaster.freq)\n        self.metrics_dict = update_metrics(metrics, stage, target_dict=self.metrics_dict)\n        \n        # Calculate norm metrics\n        norm_metrics = self.evaluator(norm_future_data, forecasts, past_data=norm_past_data, freq=self.forecaster.freq)\n        self.metrics_dict = update_metrics(norm_metrics, stage, 'norm', target_dict=self.metrics_dict)\n        \n        l = orin_future_data.shape[1]\n        \n        if stage != 'test' and self.sampling_weight_scheme not in ['fix', 'none']:\n            loss_weights = get_weights('random', l)\n        else:\n            loss_weights = None\n\n        hor_metrics = self.evaluator(orin_future_data, denorm_forecasts, past_data=orin_past_data, freq=self.forecaster.freq, loss_weights=loss_weights)\n        \n        if stage == 'test':\n            hor_str = get_hor_str(self.forecaster.prediction_length, dataloader_idx)\n            if hor_str not in self.hor_metrics:\n                self.hor_metrics[hor_str] = {}\n\n            \n            self.hor_metrics[hor_str] = update_metrics(hor_metrics, stage, target_dict=self.hor_metrics[hor_str])\n\n        return hor_metrics\n\n    def validation_step(self, batch, batch_idx, dataloader_idx=None):\n        metrics = self.evaluate(batch, stage='val',dataloader_idx=dataloader_idx)\n        return metrics\n\n\n    def on_validation_epoch_start(self):\n        
self.metrics_dict = {}\n        self.hor_metrics = {}\n        self.batch_size = []\n\n    def on_validation_epoch_end(self):\n        avg_metrics = calculate_weighted_average(self.metrics_dict, self.batch_size)\n        self.log_dict(avg_metrics, prog_bar=True)\n\n    def test_step(self, batch, batch_idx, dataloader_idx=None):\n        metrics = self.evaluate(batch, stage='test',dataloader_idx=dataloader_idx)\n        return metrics\n\n    def on_test_epoch_start(self):\n        self.metrics_dict = {}\n        self.hor_metrics = {}\n        self.avg_metrics = {}\n        self.avg_hor_metrics = {}\n        self.batch_size = []\n\n    def on_test_epoch_end(self):\n        if len(self.hor_metrics) > 0:\n            for hor_str, metric in self.hor_metrics.items():\n                self.avg_hor_metrics[hor_str] = calculate_weighted_average(metric, batch_size=self.batch_size)\n                self.avg_metrics.update(calculate_weighted_average(metric, batch_size=self.batch_size, hor=hor_str+'_'))\n        else:\n            self.avg_metrics = calculate_weighted_average(self.metrics_dict, self.batch_size)\n        \n        if isinstance(self.forecaster.prediction_length, int) or len(self.forecaster.prediction_length) < 2:\n            self.log_dict(self.avg_metrics, logger=True)\n\n    def predict_step(self, batch, batch_idx):\n        batch_data = ProbTSBatchData(batch, self.device)\n        forecasts = self.forecaster.forecast(batch_data, self.num_samples)\n        return forecasts\n\n    def configure_optimizers(self):\n        if self.optimizer_config is None:\n            optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)\n        else:\n            optimizer = init_class_helper(self.optimizer_config['class_name'])\n            params = self.optimizer_config['init_args']\n            optimizer = optimizer(self.parameters(), **params)\n        \n        if self.scheduler_config is not None:\n            scheduler = 
init_class_helper(self.scheduler_config['class_name'])\n            params = self.scheduler_config['init_args']\n            scheduler = scheduler(optimizer=optimizer, **params)\n            \n            lr_scheduler = {\n                \"scheduler\": scheduler,\n                \"interval\": \"epoch\",\n                \"frequency\": 1,\n                \"monitor\": \"val_loss\",\n                \"strict\": True,\n                \"name\": None,\n            }\n\n            return {\"optimizer\": optimizer, \"lr_scheduler\": lr_scheduler}\n\n        return optimizer"
  },
  {
    "path": "probts/model/forecaster/__init__.py",
    "content": "from .forecaster import Forecaster\nfrom .point_forecaster import *\nfrom .prob_forecaster import *"
  },
  {
    "path": "probts/model/forecaster/forecaster.py",
    "content": "import torch\nfrom torch import nn\nfrom typing import List\n\nfrom probts.utils import weighted_average\nfrom probts.data.data_utils.data_scaler import TemporalScaler\nfrom typing import Union\n\nclass Forecaster(nn.Module):\n    def __init__(\n        self,\n        target_dim: int,\n        context_length: Union[list,int],\n        prediction_length: Union[list,int],\n        freq: str ,\n        use_lags: bool = False,\n        use_feat_idx_emb: bool = False,\n        use_time_feat: bool = False,\n        lags_list: List[int] = [],\n        feat_idx_emb_dim: int = 1,\n        time_feat_dim: int = 1,\n        use_scaling: bool = False,\n        autoregressive: bool = False,\n        no_training: bool = False,\n        dataset: str = None,\n        **kwargs\n    ):\n        super().__init__()\n        \n        self.context_length = context_length\n        self.prediction_length = prediction_length\n        \n        if isinstance(self.context_length, list):\n            self.max_context_length = max(self.context_length)\n        else:\n            self.max_context_length = self.context_length\n        \n        if isinstance(self.prediction_length, list):\n            self.max_prediction_length = max(self.prediction_length)\n        else:\n            self.max_prediction_length = self.prediction_length\n            \n        self.target_dim = target_dim\n        self.freq = freq\n        self.use_lags = use_lags\n        self.use_feat_idx_emb = use_feat_idx_emb\n        self.use_time_feat = use_time_feat\n        self.feat_idx_emb_dim = feat_idx_emb_dim\n        self.time_feat_dim = time_feat_dim\n        self.autoregressive = autoregressive\n        self.no_training = no_training\n        self.use_scaling = use_scaling\n        self.dataset = dataset\n        # Lag parameters\n        self.lags_list = lags_list\n        if self.use_scaling:\n            self.scaler = TemporalScaler()\n        else:\n            self.scaler = None\n        \n     
   self.lags_dim = len(self.lags_list) * target_dim\n        \n        if use_feat_idx_emb:\n            self.feat_idx_emb = nn.Embedding(\n                num_embeddings=self.target_dim, embedding_dim=self.feat_idx_emb_dim\n            )\n        else:\n            self.feat_idx_emb = None\n            \n        self.input_size = self.get_input_size()\n            \n\n    @property\n    def name(self):\n        return self.__class__.__name__\n\n    def get_input_size(self):\n        input_size = self.target_dim if not self.use_lags else self.lags_dim\n        if self.use_feat_idx_emb:\n            input_size += self.use_feat_idx_emb * self.target_dim\n        if self.use_time_feat:\n            input_size += self.time_feat_dim\n        return input_size\n\n    def get_lags(self, sequence, lags_list, lags_length=1):\n        \"\"\"\n        Get several lags from the sequence of shape (B, L, C) to (B, L', C*N),\n        where L' = lag_length and N = len(lag_list).\n        \"\"\"\n        assert max(lags_list) + lags_length <= sequence.shape[1]\n\n        lagged_values = []\n        for lag_index in lags_list:\n            begin_index = -lag_index - lags_length\n            end_index = -lag_index if lag_index > 0 else None\n            lagged_value = sequence[:, begin_index:end_index, ...]\n            if self.use_scaling:\n                lagged_value = lagged_value / self.scaler.scale\n            lagged_values.append(lagged_value)\n        return torch.cat(lagged_values, dim=-1)\n\n    def get_input_sequence(\n        self,\n        past_target_cdf,\n        future_target_cdf,\n        mode\n    ):\n        if mode == 'all':\n            sequence = torch.cat((past_target_cdf, future_target_cdf), dim=1)\n            seq_length = self.max_context_length + self.max_prediction_length\n        elif mode == 'encode':\n            sequence = past_target_cdf\n            seq_length = self.max_context_length\n        elif mode == 'decode':\n            sequence = 
past_target_cdf\n            seq_length = 1\n        else:\n            raise ValueError(f\"Unsupported input mode: {mode}\")\n        \n        if self.use_lags:\n            input_seq = self.get_lags(sequence, self.lags_list, seq_length)\n        else: \n            input_seq = sequence[:, -seq_length:, ...]\n            if self.use_scaling:\n                input_seq = input_seq / self.scaler.scale\n        return input_seq\n    \n    def get_input_feat_idx_emb(self, target_dimension_indicator, input_length):\n        input_feat_idx_emb = self.feat_idx_emb(target_dimension_indicator) # [B K D]\n\n        input_feat_idx_emb = (\n            input_feat_idx_emb.unsqueeze(1)\n            .expand(-1, input_length, -1, -1)\n            .reshape(-1, input_length, self.target_dim * self.feat_idx_emb_dim)\n        )\n        return input_feat_idx_emb # [B L K*D]\n\n    def get_input_time_feat(\n        self,\n        past_time_feat,\n        future_time_feat,\n        mode\n    ):\n        if mode == 'all':\n            time_feat = torch.cat(\n                (past_time_feat[:, -self.max_context_length:, ...], future_time_feat), dim=1)\n        elif mode == 'encode':\n            time_feat = past_time_feat[:, -self.max_context_length:, ...]\n        elif mode == 'decode':\n            time_feat = future_time_feat\n        return time_feat\n\n    def get_inputs(self, batch_data, mode):\n        inputs_list = []\n\n        input_seq = self.get_input_sequence(\n            batch_data.past_target_cdf, batch_data.future_target_cdf, mode=mode)\n        input_length = input_seq.shape[1] # [B L n_lags*K]\n        inputs_list.append(input_seq)\n\n        if self.use_feat_idx_emb:\n            input_feat_idx_emb = self.get_input_feat_idx_emb(\n                batch_data.target_dimension_indicator, input_length) # [B L K*D]\n            inputs_list.append(input_feat_idx_emb)\n\n        if self.use_time_feat:\n            input_time_feat = self.get_input_time_feat(\n                
batch_data.past_time_feat, batch_data.future_time_feat, mode=mode) # [B L Dt]\n            inputs_list.append(input_time_feat)\n        return torch.cat(inputs_list, dim=-1).to(dtype=torch.float32)\n    \n    def get_scale(self, batch_data):\n        self.scaler.fit(\n            batch_data.past_target_cdf[:, -self.max_context_length:, ...],\n            batch_data.past_observed_values[:, -self.max_context_length:, ...]\n        )\n    \n    def get_weighted_loss(self, batch_data, loss):\n        observed_values =  batch_data.future_observed_values\n        loss_weights, _ = observed_values.min(dim=-1, keepdim=True)\n        loss = weighted_average(loss, weights=loss_weights, dim=1)\n        return loss\n    \n    def loss(self, batch_data):\n        raise NotImplementedError\n    \n    def forecast(self, batch_data=None, num_samples=None):\n        raise NotImplementedError\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/__init__.py",
    "content": "from .mean import MeanForecaster\nfrom .naive import NaiveForecaster\nfrom .linear import LinearForecaster\nfrom .patchtst import PatchTST\nfrom .transformer import TransformerForecaster\nfrom .gru import GRUForecaster\nfrom .dlinear import DLinear\nfrom .nlinear import NLinear\nfrom .nhits import NHiTS\nfrom .timesnet import TimesNet\nfrom .itransformer import iTransformer\nfrom .autoformer import Autoformer\nfrom .tsmixer import TSMixer\nfrom .elastst import ElasTST\nfrom .time_moe import TimeMoE\nfrom .timesfm import TimesFM\nfrom .moderntcn import ModernTCN\n\n# ------- add timesfm to sys.path ----------\ntry:\n    import os, sys\n    current_dir = os.path.dirname(os.path.realpath(__file__))\n    project_root = os.path.abspath(os.path.join(current_dir, '..', '..', '..', '..'))\n    timesfm_path = os.path.join(project_root, 'submodules', 'timesfm', 'src')\n\n    if timesfm_path not in sys.path:\n        sys.path.append(timesfm_path)\nexcept Exception as e:\n    print(f\"Warning: Unable to add timesfm to sys.path. {e}\")\n# ------------------------------------------\n\nimport importlib\n\nmodules = [\n    ('timer', 'Timer'),\n    ('units', 'UniTS'),\n    ('forecastpfn', 'ForecastPFN'),\n    ('tinytimemixer', 'TinyTimeMixer'),\n]\n\nfor module, class_name in modules:\n    try:\n        mod = importlib.import_module(f\".{module}\", package=__package__)\n        globals()[class_name] = getattr(mod, class_name)\n    except ImportError:\n        # print(f\"Warning: {class_name} is not available due to missing dependencies.\")\n        pass"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/autoformer.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from Autoformer\n# - Source: https://github.com/thuml/Autoformer\n# - Paper: Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting\n# - License: MIT License\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.arch.TransformerModule.Embed import DataEmbedding_wo_pos\nfrom probts.model.nn.arch.AutoformerModule.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer\nfrom probts.model.nn.arch.AutoformerModule.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp\n\n\nclass Autoformer(Forecaster):\n    def __init__(\n        self,\n        moving_avg: int = 25,\n        factor: int = 1,\n        n_heads: int = 8,\n        activation: str = 'gelu',\n        e_layers: int = 2,\n        d_layers: int = 1,\n        output_attention: bool = False,\n        d_ff: int = 256,\n        label_len: int = 48,\n        embed: str = 'timeF',\n        dropout: float = 0.1,\n        f_hidden_size: int = 256,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        if isinstance(self.context_length, list):\n            self.context_length = max(self.context_length)\n        self.label_len = self.context_length\n\n        # Decomp\n        kernel_size = moving_avg\n        self.decomp = series_decomp(kernel_size)\n\n        # Embedding\n        # The series-wise connection inherently contains the sequential information.\n        # Thus, we can discard the position embedding of transformers.\n        self.enc_embedding = DataEmbedding_wo_pos(self.target_dim, f_hidden_size, embed, self.freq.lower(),\n                                                  dropout)\n  
      self.dec_embedding = DataEmbedding_wo_pos(self.target_dim, f_hidden_size, embed, self.freq.lower(),\n                                                  dropout)\n\n        # Encoder\n        self.model_encoder = Encoder(\n            [\n                EncoderLayer(\n                    AutoCorrelationLayer(\n                        AutoCorrelation(False, factor, attention_dropout=dropout,\n                                        output_attention=output_attention),\n                        f_hidden_size, n_heads),\n                    f_hidden_size,\n                    d_ff,\n                    moving_avg=moving_avg,\n                    dropout=dropout,\n                    activation=activation\n                ) for l in range(e_layers)\n            ],\n            norm_layer=my_Layernorm(f_hidden_size)\n        )\n        \n        # Decoder\n        self.model_decoder = Decoder(\n            [\n                DecoderLayer(\n                    AutoCorrelationLayer(\n                        AutoCorrelation(True, factor, attention_dropout=dropout,\n                                        output_attention=False),\n                        f_hidden_size, n_heads),\n                    AutoCorrelationLayer(\n                        AutoCorrelation(False, factor, attention_dropout=dropout,\n                                        output_attention=False),\n                        f_hidden_size, n_heads),\n                    f_hidden_size,\n                    self.target_dim,\n                    d_ff,\n                    moving_avg=moving_avg,\n                    dropout=dropout,\n                    activation=activation,\n                )\n                for l in range(d_layers)\n            ],\n            norm_layer=my_Layernorm(f_hidden_size),\n            projection=nn.Linear(f_hidden_size, self.target_dim, bias=True)\n        )\n        self.loss_fn = nn.MSELoss(reduction='none')\n        \n    def forward(self, inputs, pred_len, 
enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None, *args, **kwargs):\n        B, _, _ = inputs.shape\n\n        if self.use_time_feat:\n            past_target = inputs[:,:self.context_length, :self.target_dim]\n            x_mark_enc = inputs[:,:self.context_length, self.target_dim:]\n            time_feat = inputs[:,:,self.target_dim:]\n        else:\n            past_target = inputs[:,:self.context_length,:self.target_dim]\n            x_mark_enc = None\n            time_feat = None\n            \n        \n        # decomp init\n        mean = torch.mean(past_target, dim=1).unsqueeze(1).repeat(1, pred_len, 1)\n        zeros = torch.zeros([B, pred_len, self.target_dim], device=past_target.device)\n        seasonal_init, trend_init = self.decomp(past_target)\n        # decoder input\n        trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)\n        seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)\n\n        enc_out = self.enc_embedding(past_target, x_mark_enc)\n        enc_out, attns = self.model_encoder(enc_out, attn_mask=enc_self_mask)\n        # dec\n        dec_out = self.dec_embedding(seasonal_init, time_feat)\n        seasonal_part, trend_part = self.model_decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask,\n                                                 trend=trend_init)\n        # final\n        dec_out = trend_part + seasonal_part\n        return dec_out[:, -pred_len:, :]\n\n    def loss(self, batch_data):\n        max_pred_len = batch_data.max_prediction_length if batch_data.max_prediction_length is not None else max(self.train_prediction_length)\n        inputs = self.get_inputs(batch_data, 'all')\n        outputs = self(inputs, max_pred_len)\n        \n        loss = self.loss_fn(batch_data.future_target_cdf, outputs)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n       
 max_pred_len = batch_data.future_target_cdf.shape[1]\n        inputs = self.get_inputs(batch_data, 'all')\n\n        outputs = self(inputs, max_pred_len)\n        return outputs.unsqueeze(1)"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/dlinear.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from LTSF-Linear\n# - Source: https://github.com/cure-lab/LTSF-Linear\n# - Paper: Are Transformers Effective for Time Series Forecasting?\n# - License: Apache-2.0\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.arch.decomp import series_decomp\n\nclass DLinear(Forecaster):\n    def __init__(\n        self,\n        kernel_size: int,\n        individual: bool,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        if self.input_size != self.target_dim:\n            self.enc_linear = nn.Linear(\n                in_features=self.input_size, out_features=self.target_dim\n            )\n        else:\n            self.enc_linear = nn.Identity()\n\n\n        # Decompsition Kernel Size\n        self.kernel_size = kernel_size\n        self.decompsition = series_decomp(kernel_size)\n        self.individual = individual\n\n        if self.individual:\n            self.Linear_Seasonal = nn.ModuleList()\n            self.Linear_Trend = nn.ModuleList()\n            \n            for i in range(self.target_dim):\n                self.Linear_Seasonal.append(nn.Linear(self.context_length, self.prediction_length))\n                self.Linear_Trend.append(nn.Linear(self.context_length, self.prediction_length))\n        else:\n            self.Linear_Seasonal = nn.Linear(self.context_length, self.prediction_length)\n            self.Linear_Trend = nn.Linear(self.context_length, self.prediction_length)\n        self.loss_fn = nn.MSELoss(reduction='none')\n\n    def encoder(self, inputs):\n        seasonal_init, trend_init = self.decompsition(inputs)\n\n        # [B,C,L]\n        seasonal_init, trend_init = seasonal_init.permute(0,2,1), 
trend_init.permute(0,2,1)\n\n        if self.individual:\n            seasonal_output = torch.zeros([seasonal_init.size(0),seasonal_init.size(1),self.prediction_length],dtype=seasonal_init.dtype).to(seasonal_init.device)\n            trend_output = torch.zeros([trend_init.size(0),trend_init.size(1),self.prediction_length],dtype=trend_init.dtype).to(trend_init.device)\n            for i in range(self.target_dim):\n                seasonal_output[:,i,:] = self.Linear_Seasonal[i](seasonal_init[:,i,:])\n                trend_output[:,i,:] = self.Linear_Trend[i](trend_init[:,i,:])\n        else:\n            seasonal_output = self.Linear_Seasonal(seasonal_init)\n            trend_output = self.Linear_Trend(trend_init)\n\n        outputs = seasonal_output + trend_output # [B,C,L]\n        return outputs.permute(0,2,1)\n\n    def loss(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'encode')\n        inputs = self.enc_linear(inputs)\n        outputs = self.encoder(inputs)\n        \n        loss = self.loss_fn(batch_data.future_target_cdf, outputs)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        inputs = self.get_inputs(batch_data, 'encode')\n        inputs = self.enc_linear(inputs)\n        outputs = self.encoder(inputs)\n        return outputs.unsqueeze(1)\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/elastst.py",
    "content": "import torch\nimport torch.nn as nn\nfrom typing import Union\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.arch.ElasTSTModule.ElasTST_backbone import ElasTST_backbone\nfrom probts.utils import convert_to_list, weighted_average\nfrom probts.data.data_utils.data_scaler import InstanceNorm\n\nclass ElasTST(Forecaster):\n    def __init__(\n        self,\n        l_patch_size: Union[str, int, list] = '8_16_32',\n        k_patch_size: int = 1,\n        stride: int = None,\n        rotate: bool = True, \n        addv: bool = False,\n        bin_att: bool = False,\n        rope_theta_init: str = 'exp',\n        min_period: float = 1, \n        max_period: float = 1000,\n        learn_tem_emb: bool = False,\n        learnable_rope: bool = True, \n        abs_tem_emb: bool = False,\n        structured_mask: bool = True,\n        max_seq_len: int = 1024,\n        theta_base: float = 10000,\n        t_layers: int = 1, \n        v_layers: int = 0,\n        patch_share_backbone: bool = True,\n        n_heads: int = 16, \n        d_k: int = 8, \n        d_v: int = 8,\n        d_inner: int = 256, \n        dropout: float = 0.,\n        in_channels: int = 1,\n        f_hidden_size: int = 40,\n        use_norm: bool = True,\n        **kwargs\n    ):\n        \"\"\"\n        ElasTST model.\n\n        Parameters\n        ----------\n        l_patch_size : Union[str, int, list]\n            Patch sizes configuration.\n        k_patch_size : int\n            Patch size for variables.\n        stride : int\n            Stride for patch splitting. If None, uses patch size as default.\n        rotate : bool\n            Apply rotational positional embeddings.\n        addv : bool\n            Whether to add RoPE information to value in attention. 
If False, only rotate the key and query embeddings.\n        bin_att : bool\n            Use binary attention biases to encode variate indices (any-variate attention).\n        rope_theta_init : str\n            Initialization for TRoPE, default is 'exp', as used in the paper. Options: ['exp', 'linear', 'uniform', 'rope'].\n        min_period : float\n            Minimum initialized period coefficient for rotary embeddings.\n        max_period : float\n            Maximum initialized period coefficient for rotary embeddings.\n        learn_tem_emb : bool\n            Whether to use learnable temporal embeddings.\n        learnable_rope : bool\n            Make period coefficient in TRoPE learnable.\n        abs_tem_emb : bool\n            Use absolute temporal embeddings if True.\n        structured_mask : bool\n            Apply structured mask or not.\n        max_seq_len : int\n            Maximum sequence length for the input time series.\n        theta_base : int\n            Base frequency of vanilla RoPE.\n        t_layers : int\n            Number of temporal attention layers.\n        v_layers : int\n            Number of variable attention layers.\n        patch_share_backbone : bool\n            Share Transformer backbone across patches.\n        n_heads : int\n            Number of attention heads in the multi-head attention mechanism.\n        d_k : int\n            Dimensionality of key embeddings in attention.\n        d_v : int\n            Dimensionality of value embeddings in attention.\n        d_inner : int\n            Size of inner layers in the feed-forward network.\n        dropout : float\n            Dropout rate for regularization during training.\n        in_channels : int\n            Number of input channels in the time series data. 
We only consider univariable.\n        f_hidden_size : int\n            Hidden size for the feed-forward layers.\n        use_norm : bool\n            Whether to apply instance normalization.\n        **kwargs : dict\n            Additional keyword arguments for extended functionality.\n        \"\"\"\n\n        super().__init__(**kwargs)\n        \n        self.l_patch_size = convert_to_list(l_patch_size)\n        self.use_norm = use_norm\n        # Model\n        self.model = ElasTST_backbone(l_patch_size=self.l_patch_size, \n            stride=stride, \n            k_patch_size=k_patch_size, \n            in_channels=in_channels,\n            t_layers=t_layers, \n            v_layers=v_layers, \n            hidden_size=f_hidden_size, \n            d_inner=d_inner,\n            n_heads=n_heads, \n            d_k=d_k, \n            d_v=d_v,\n            dropout=dropout,\n            rotate=rotate, \n            max_seq_len=max_seq_len, \n            theta=theta_base,\n            addv=addv, \n            bin_att=bin_att,\n            learn_tem_emb=learn_tem_emb, \n            abs_tem_emb=abs_tem_emb, \n            learnable_theta=learnable_rope, \n            structured_mask=structured_mask,\n            rope_theta_init=rope_theta_init, \n            min_period=min_period, \n            max_period=max_period,\n            patch_share_backbone=patch_share_backbone\n        )\n        \n        self.loss_fn = nn.MSELoss(reduction='none')\n        self.instance_norm = InstanceNorm()\n    \n    def forward(self, batch_data, pred_len, dataset_name=None):\n        new_pred_len = pred_len\n        for p in self.l_patch_size:\n            new_pred_len = self.check_divisibility(new_pred_len, p)\n        \n        B, _, K = batch_data.past_target_cdf.shape\n        past_target = batch_data.past_target_cdf\n        past_observed_values = batch_data.past_observed_values\n        \n        if self.use_norm:\n            past_target = self.instance_norm(past_target, 'norm')\n\n 
       # future_observed_values is the mask indicate whether there is a value in a position\n        future_observed_values = torch.zeros([B, new_pred_len, K]).to(batch_data.future_observed_values.device)\n\n        pred_len = batch_data.future_observed_values.shape[1]\n        future_observed_values[:,:pred_len] = batch_data.future_observed_values\n\n        # target placeholder\n        future_placeholder = torch.zeros([B, new_pred_len, K]).to(batch_data.past_target_cdf.device)\n\n        x, pred_list = self.model(past_target, future_placeholder, past_observed_values, future_observed_values, dataset_name=dataset_name)\n        dec_out = x[:, :pred_len]\n        if self.use_norm:\n            dec_out = self.instance_norm(dec_out, 'denorm')\n\n        return dec_out # [b l k], [b l k #patch_size]\n\n\n    def loss(self, batch_data, reduce='none'):\n        max_pred_len = batch_data.max_prediction_length if batch_data.max_prediction_length is not None else self.max_prediction_length\n            \n        predict = self(batch_data, max_pred_len, dataset_name=None, )\n        target = batch_data.future_target_cdf\n        \n        observed_values = batch_data.future_observed_values\n        loss = self.loss_fn(target, predict)\n\n        loss = self.get_weighted_loss(observed_values, loss, reduce=reduce)\n        \n        if reduce=='mean':\n            loss = loss.mean()\n        return loss\n\n    def forecast(self, batch_data, num_samples=None):\n        # max_pred_len = batch_data.max_prediction_length if batch_data.max_prediction_length is not None else max(self.prediction_length)\n        max_pred_len = batch_data.future_target_cdf.shape[1]\n        outputs = self(batch_data, max_pred_len, dataset_name=None, )\n        return outputs.unsqueeze(1)\n    \n    def check_divisibility(self, pred_len, patch_size):\n        if pred_len % patch_size == 0:\n            return pred_len\n        else:  \n            return (pred_len // patch_size + 1) * patch_size  \n\n 
   def get_weighted_loss(self, observed_values, loss, reduce='mean'):\n        loss_weights, _ = observed_values.min(dim=-1, keepdim=True)\n        loss = weighted_average(loss, weights=loss_weights, dim=1, reduce=reduce)\n        return loss"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/forecastpfn.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from ForecastPFN\n# - Source: https://github.com/abacusai/ForecastPFN\n# - Paper: ForecastPFN: Synthetically-Trained Zero-Shot Forecasting\n# - License: Apache License 2.0\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport datetime\n\nimport numpy as np\nimport pandas as pd\nimport tensorflow as tf\nimport torch\nfrom keras import backend\nfrom sklearn.preprocessing import StandardScaler\n\nfrom probts.model.forecaster import Forecaster\n\n\ndef smape(y_true, y_pred):\n    \"\"\" Calculate Armstrong's original definition of sMAPE between `y_true` & `y_pred`.\n        `loss = 200 * mean(abs((y_true - y_pred) / (y_true + y_pred), axis=-1)`\n        Args:\n        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.\n        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.\n        Returns:\n        Symmetric mean absolute percentage error values. 
shape = `[batch_size, d0, ..\n        dN-1]`.\n        \"\"\"\n    y_pred = tf.convert_to_tensor(y_pred)\n    y_true = tf.cast(y_true, y_pred.dtype)\n    diff = tf.abs(\n        (y_true - y_pred) /\n        backend.maximum(y_true + y_pred, backend.epsilon())\n    )\n    return 200.0 * backend.mean(diff, axis=-1)\n\n\nclass ForecastPFN(Forecaster):\n    def __init__(\n        self,\n        label_len: int = 48,\n        ckpt_path: str = None,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.no_training = True\n\n        self.label_len = label_len\n        \n        self.model = tf.keras.models.load_model(ckpt_path, custom_objects={'smape': smape})\n\n\n    def _ForecastPFN_time_features(self, x_mark_enc: np.ndarray, x_mark_dec: np.ndarray):\n        def extract_time_features(ts):\n            original_shape = ts.shape\n            ts = ts.reshape(-1)  # Flatten the array\n            if type(ts[0]) == datetime.datetime:\n                year = np.array([x.year for x in ts])\n                month = np.array([x.month for x in ts])\n                day = np.array([x.day for x in ts])\n                day_of_week = np.array([x.weekday() + 1 for x in ts])\n                day_of_year = np.array([x.timetuple().tm_yday for x in ts])\n            else:\n                ts = pd.to_datetime(ts)\n                year = ts.year.values\n                month = ts.month.values\n                day = ts.day.values\n                day_of_week = ts.day_of_week.values + 1\n                day_of_year = ts.day_of_year.values\n            \n            features = np.stack([year, month, day, day_of_week, day_of_year], axis=-1)\n            return features.reshape(*original_shape, -1).squeeze()\n\n        # Process the encoder and decoder inputs\n        x_mark_enc_features = extract_time_features(x_mark_enc)\n        x_mark_dec_features = extract_time_features(x_mark_dec)\n\n        return x_mark_enc_features, x_mark_dec_features\n\n    def 
_process_tuple(self, x, x_mark, y_mark, horizon):\n        \"\"\"\n        x: tensor of shape (n, 1)\n        x_mark: tensor of shape (n, d)\n        y_mark: tensor of shape (horizon, d)\n\n        where\n        n       is the input sequence length\n        horizon is the output sequence length\n        d is the dimensionality of the time_stamp (5 for ForecastPFN)\n        \"\"\"\n        if tf.reduce_all(x == x[0]):\n            x = tf.concat([x[:-1], x[-1:] + 1], axis=0)\n        \n        history = x.numpy()\n        scaler = StandardScaler()\n        scaler.fit(history)\n        history = scaler.transform(history)\n        \n        history_mean = np.nanmean(history[-6:])\n        history_std = np.nanstd(history[-6:])\n        local_scale = history_mean + history_std + 1e-4\n        \n        history = np.clip(history / local_scale, a_min=0, a_max=1)\n        \n        if x.shape[0] != 100:\n            if x.shape[0] > 100:\n                target = x_mark[-100:, :]\n                history = history[-100:, :]\n            else:\n                target = tf.pad(x_mark, [[100 - x.shape[0], 0], [0, 0]])\n                history = tf.pad(history, [[100 - x.shape[0], 0], [0, 0]])\n            \n            history = tf.repeat(tf.expand_dims(history, axis=0), horizon, axis=0)[:, :, 0]\n            ts = tf.repeat(tf.expand_dims(target, axis=0), horizon, axis=0)\n        else:\n            ts = tf.repeat(tf.expand_dims(x_mark, axis=0), horizon, axis=0)\n            history = tf.convert_to_tensor(history, dtype=tf.float32)\n        \n        task = tf.fill([horizon], 1)\n        y_mark_tensor = tf.convert_to_tensor(y_mark[-horizon:, :], dtype=tf.int64)\n        target_ts = tf.expand_dims(y_mark_tensor, axis=1)\n        \n        model_input = {'ts': ts, 'history': history, 'target_ts': target_ts, 'task': task}\n        pred_vals = self.model(model_input)\n        \n        scaled_vals = pred_vals['result'].numpy().T.reshape(-1) * 
pred_vals['scale'].numpy().reshape(-1)\n        scaled_vals = scaler.inverse_transform([scaled_vals])\n        return scaled_vals\n    \n    \n    def _process_batch(self, batch_x, batch_y, batch_x_mark, batch_y_mark):\n        preds = []\n        for idx, (x, y, x_mark, y_mark) in enumerate(zip(batch_x, batch_y, batch_x_mark, batch_y_mark)):\n            pred = self._process_tuple(x, x_mark, y_mark, self.prediction_length)\n            preds.append(pred)\n        return preds\n\n\n    def forecast(self, batch_data, num_samples=None):\n        # For now, we only support batch_size=1\n        B, _, K = batch_data.past_target_cdf.shape\n        inputs = batch_data.past_target_cdf[:, -self.context_length:, ...].cpu()\n        x_mark_enc = batch_data.past_time_feat[:, -self.context_length:, ...].cpu().numpy().astype('datetime64[s]')\n        x_mark_dec = batch_data.future_time_feat.cpu().numpy().astype('datetime64[s]')\n        x_mark_enc, x_mark_dec = self._ForecastPFN_time_features(x_mark_enc, x_mark_dec)\n\n        x_mark_dec = tf.concat([x_mark_enc[:, -self.label_len:, :], x_mark_dec], axis=1)\n        \n        inputs = tf.reshape(inputs, [-1, self.context_length, 1])\n        x_mark_enc = tf.repeat(x_mark_enc, repeats=K, axis=0)\n        x_mark_dec = tf.repeat(x_mark_dec, repeats=K, axis=0)\n        \n        dec_inp = tf.zeros_like(inputs[:, -self.prediction_length:, :])\n        dec_inp = tf.concat([inputs[:, -self.label_len:, :], dec_inp], axis=1)\n        x_mark_enc = tf.cast(x_mark_enc, tf.int64)\n        x_mark_dec = tf.cast(x_mark_dec, tf.int64)\n        \n        outputs = self._process_batch(inputs, dec_inp, x_mark_enc, x_mark_dec)\n        outputs = tf.concat(outputs, axis=0)\n        outputs = tf.reshape(outputs, [B, -1, K])\n        outputs = outputs[:, :self.prediction_length, :].numpy()\n        outputs = torch.tensor(outputs)\n        return outputs.unsqueeze(1)"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/gru.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom probts.data import ProbTSBatchData\nfrom probts.utils import repeat\nfrom probts.model.forecaster import Forecaster\n\n\nclass GRUForecaster(Forecaster):\n    def __init__(\n        self,\n        num_layers: int = 2,\n        f_hidden_size: int = 40,\n        dropout: float = 0.1,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.autoregressive = True\n        \n        self.model = nn.GRU(\n            input_size=self.input_size,\n            hidden_size=f_hidden_size,\n            num_layers=num_layers,\n            dropout=dropout,\n            batch_first=True\n        )\n        self.linear = nn.Linear(f_hidden_size, self.target_dim)\n        self.loss_fn = nn.MSELoss(reduction='none')\n\n    def loss(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'all')\n        outputs, _ = self.model(inputs)\n        outputs = outputs[:, -self.prediction_length-1:-1, ...]\n        outputs = self.linear(outputs)\n        \n        loss = self.loss_fn(batch_data.future_target_cdf, outputs)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        forecasts = []\n        states = self.encode(batch_data)\n        past_target_cdf = batch_data.past_target_cdf\n        \n        for k in range(self.prediction_length):\n            current_batch_data = ProbTSBatchData({\n                'target_dimension_indicator': batch_data.target_dimension_indicator,\n                'past_target_cdf': past_target_cdf,\n                'future_time_feat': batch_data.future_time_feat[:, k : k + 1:, ...]\n            }, device=batch_data.device)\n\n            outputs, states = self.decode(current_batch_data, states)\n            outputs = self.linear(outputs)\n            forecasts.append(outputs)\n\n            past_target_cdf = torch.cat(\n                (past_target_cdf, outputs), dim=1\n          
  )\n\n        forecasts = torch.cat(forecasts, dim=1).reshape(\n            -1, self.prediction_length, self.target_dim)\n        return forecasts.unsqueeze(1)\n\n    def encode(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'encode')\n        outputs, states = self.model(inputs)\n        return states\n\n    def decode(self, batch_data, states=None):\n        inputs = self.get_inputs(batch_data, 'decode')\n        outputs, states = self.model(inputs, states)\n        return outputs, states\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/itransformer.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from iTransformer\n# - Source: https://github.com/thuml/iTransformer\n# - Paper: iTransformer: Inverted Transformers Are Effective for Time Series Forecasting\n# - License: MIT License\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.arch.TransformerModule.Transformer_EncDec import Encoder, EncoderLayer\nfrom probts.model.nn.arch.TransformerModule.SelfAttention_Family import FullAttention, AttentionLayer\nfrom probts.model.nn.arch.TransformerModule.Embed import DataEmbedding_inverted\n\nclass iTransformer(Forecaster):\n    def __init__(\n        self,\n        factor: int = 1,\n        n_heads: int = 8,\n        activation: str = 'gelu',\n        e_layers: int = 2,\n        output_attention: bool = False,\n        d_ff: int = 512,\n        label_len: int = 48,\n        use_norm: bool = True,\n        class_strategy:str = 'projection',\n        dropout: float = 0.1,\n        f_hidden_size: int = 512,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        \n        self.label_len = label_len\n        \n        self.use_norm = use_norm\n        # Embedding\n        self.enc_embedding = DataEmbedding_inverted(self.context_length, f_hidden_size,\n                                                    dropout)\n        self.class_strategy = class_strategy\n        # Encoder-only architecture\n        self.model_encoder = Encoder(\n            [\n                EncoderLayer(\n                    AttentionLayer(\n                        FullAttention(False, factor, attention_dropout=dropout,\n                                      output_attention=output_attention), f_hidden_size, n_heads),\n                    f_hidden_size,\n     
               d_ff,\n                    dropout=dropout,\n                    activation=activation\n                ) for l in range(e_layers)\n            ],\n            norm_layer=torch.nn.LayerNorm(f_hidden_size)\n        )\n        self.projector = nn.Linear(f_hidden_size, self.prediction_length, bias=True)\n        self.loss_fn = nn.MSELoss(reduction='none')\n\n    def forward(self, inputs):\n        if self.use_time_feat:\n            past_target = inputs[:,:,:self.target_dim]\n            x_mark_enc = inputs[:,:,-self.target_dim:]\n        else:\n            past_target = inputs\n            x_mark_enc = None\n            \n        \n        if self.use_norm:\n            # Normalization from Non-stationary Transformer\n            means = past_target.mean(1, keepdim=True).detach()\n            past_target = past_target - means\n            stdev = torch.sqrt(torch.var(past_target, dim=1, keepdim=True, unbiased=False) + 1e-5)\n            past_target /= stdev\n\n        _, _, N = past_target.shape # B L N\n        # B: batch_size;    E: d_model; \n        # L: seq_len;       S: pred_len;\n        # N: number of variate (tokens), can also includes covariates\n\n        # Embedding\n        # B L N -> B N E                (B L N -> B L E in the vanilla Transformer)\n        enc_out = self.enc_embedding(past_target, x_mark_enc) # covariates (e.g timestamp) can be also embedded as tokens\n        \n        # B N E -> B N E                (B L E -> B L E in the vanilla Transformer)\n        # the dimensions of embedded time series has been inverted, and then processed by native attn, layernorm and ffn modules\n        enc_out, attns = self.model_encoder(enc_out, attn_mask=None)\n\n        # B N E -> B N S -> B S N \n        dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] # filter the covariates\n\n        if self.use_norm:\n            # De-Normalization from Non-stationary Transformer\n            dec_out = dec_out * (stdev[:, 0, 
:].unsqueeze(1).repeat(1, self.prediction_length, 1))\n            dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.prediction_length, 1))\n\n        return dec_out[:, -self.prediction_length:, :]\n\n    def forecast(self, batch_data, num_samples=None):\n        inputs = self.get_inputs(batch_data, 'encode')\n        output = self(inputs)\n\n        return output.unsqueeze(1)\n\n    def loss(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'encode')\n        outputs = self(inputs)\n        \n        loss = self.loss_fn(batch_data.future_target_cdf, outputs)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/linear.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from LTSF-Linear\n# - Source: https://github.com/cure-lab/LTSF-Linear\n# - Paper: Are Transformers Effective for Time Series Forecasting?\n# - License: Apache-2.0\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\n\nfrom probts.model.forecaster import Forecaster\n\n\nclass LinearForecaster(Forecaster):\n    def __init__(\n        self,\n        individual: bool = True,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.individual = individual\n        \n        if self.individual:\n            self.linear = nn.ModuleList()\n            for i in range(self.input_size):\n                self.linear.append(nn.Linear(self.context_length, self.prediction_length))\n        else:\n            self.linear = nn.Linear(self.context_length, self.prediction_length)\n        self.out_linear = nn.Linear(self.input_size, self.target_dim)\n        self.loss_fn = nn.MSELoss(reduction='none')\n\n    def forward(self, x):\n        if self.individual:\n            outputs = torch.zeros([x.size(0), self.prediction_length, x.size(2)], dtype=x.dtype).to(x.device)\n            for i in range(self.input_size):\n                outputs[:, :, i] = self.linear[i](x[:, :, i])\n        else:\n            outputs = self.linear(x.permute(0,2,1)).permute(0,2,1)\n        outputs = self.out_linear(outputs)\n        return outputs\n\n    def forecast(self, batch_data, num_samples=None):\n        inputs = self.get_inputs(batch_data, 'encode')\n        forecasts = self(inputs).unsqueeze(1)\n        return forecasts\n\n    def loss(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'encode')\n        outputs = self(inputs)\n        \n        loss = self.loss_fn(batch_data.future_target_cdf, outputs)\n     
   loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/mean.py",
    "content": "import torch\nfrom einops import repeat\nfrom probts.model.forecaster import Forecaster\n\n\nclass MeanForecaster(Forecaster):\n    def __init__(\n        self,\n        global_mean: torch.Tensor,\n        mode: str = 'batch',\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.global_mean = global_mean\n        self.mode = mode\n        self.no_training = True\n\n    @property\n    def name(self):\n        return self.mode + self.__class__.__name__\n        \n    def forecast(self, batch_data, num_samples=None):\n        B = batch_data.past_target_cdf.shape[0]\n        if self.mode == 'global':\n            outputs = self.global_mean.clone()\n        elif self.mode == 'batch':\n            outputs = torch.mean(batch_data.past_target_cdf, dim=1)\n            outputs = torch.mean(outputs, dim=0)\n        else:\n            raise ValueError(f\"Unsupported mode: {self.mode}\")\n            \n        outputs = repeat(outputs,'d -> b n l d', b=B, n=1, l=self.prediction_length)\n        return outputs\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/moderntcn.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from ModernTCN\n# - Source: https://github.com/luodhhh/ModernTCN/tree/main\n# - Paper: ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis\n# - License: MIT License\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\nimport sys\nimport torch\nimport torch.nn as nn\nfrom typing import List\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.arch.decomp import series_decomp\nfrom probts.model.nn.arch.ModernTCN_backbone import ModernTCNModel\n# torch.backends.cudnn.enabled = False\n\nclass ModernTCN(Forecaster):\n    def __init__(\n        self,\n        kernel_size: int = 25,             \n        decomposition: int = 0,           \n        stem_ratio: int = 6,             \n        downsample_ratio: int = 2,      \n        ffn_ratio: int = 2,          \n        num_blocks: List[int] = [1, 1, 1, 1],  \n        large_size: List[int] = [31, 29, 27, 13], \n        small_size: List[int] = [5, 5, 5, 5],  \n        dims: List[int] = [256, 256, 256, 256], \n        dw_dims: List[int] = [256, 256, 256, 256], \n        small_kernel_merged: bool = False, \n        use_multi_scale: bool = True,     \n        revin: int = 1,                  \n        affine: int = 0,     \n        subtract_last: int = 0,  \n        individual: int = 0,      \n        patch_size: int = 16,  \n        patch_stride: int = 8,  \n        dropout: float = 0.05,\n        head_dropout: float = 0.0,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        \n        self.stem_ratio = stem_ratio\n        self.downsample_ratio = downsample_ratio\n        self.ffn_ratio = ffn_ratio\n        self.num_blocks = num_blocks\n        self.large_size = large_size\n        self.small_size = small_size\n        self.dims = dims\n        
self.dw_dims = dw_dims\n\n        self.nvars = self.target_dim\n        self.small_kernel_merged = small_kernel_merged\n        self.drop_backbone = dropout\n        self.drop_head = head_dropout\n        self.use_multi_scale = use_multi_scale\n        self.revin = revin\n        self.affine = affine\n        self.subtract_last = subtract_last\n\n        self.seq_len = self.context_length\n        self.c_in = self.nvars,\n        self.individual = individual\n        self.target_window = self.prediction_length\n\n        self.kernel_size = kernel_size\n        self.patch_size = patch_size\n        self.patch_stride = patch_stride\n\n        self.decomposition = decomposition\n        if self.decomposition:\n            self.decomp_module = series_decomp(self.kernel_size)\n            self.model_res = ModernTCNModel(patch_size=self.patch_size,patch_stride=self.patch_stride,stem_ratio=self.stem_ratio, downsample_ratio=self.downsample_ratio, ffn_ratio=self.ffn_ratio, num_blocks=self.num_blocks, large_size=self.large_size, small_size=self.small_size, dims=self.dims, dw_dims=self.dw_dims,\n                 nvars=self.nvars, small_kernel_merged=self.small_kernel_merged, backbone_dropout=self.drop_backbone, head_dropout=self.drop_head, use_multi_scale=self.use_multi_scale, revin=self.revin, affine=self.affine,\n                 subtract_last=self.subtract_last, freq=self.freq, seq_len=self.seq_len, c_in=self.c_in, individual=self.individual, target_window=self.target_window)\n            self.model_trend = ModernTCNModel(patch_size=self.patch_size,patch_stride=self.patch_stride,stem_ratio=self.stem_ratio, downsample_ratio=self.downsample_ratio, ffn_ratio=self.ffn_ratio, num_blocks=self.num_blocks, large_size=self.large_size, small_size=self.small_size, dims=self.dims, dw_dims=self.dw_dims,\n                 nvars=self.nvars, small_kernel_merged=self.small_kernel_merged, backbone_dropout=self.drop_backbone, head_dropout=self.drop_head, use_multi_scale=self.use_multi_scale, 
revin=self.revin, affine=self.affine,\n                 subtract_last=self.subtract_last, freq=self.freq, seq_len=self.seq_len, c_in=self.c_in, individual=self.individual, target_window=self.target_window)\n        else:\n            self.model = ModernTCNModel(patch_size=self.patch_size,patch_stride=self.patch_stride,stem_ratio=self.stem_ratio, downsample_ratio=self.downsample_ratio, ffn_ratio=self.ffn_ratio, num_blocks=self.num_blocks, large_size=self.large_size, small_size=self.small_size, dims=self.dims, dw_dims=self.dw_dims,\n                 nvars=self.nvars, small_kernel_merged=self.small_kernel_merged, backbone_dropout=self.drop_backbone, head_dropout=self.drop_head, use_multi_scale=self.use_multi_scale, revin=self.revin, affine=self.affine,\n                 subtract_last=self.subtract_last, freq=self.freq, seq_len=self.seq_len, c_in=self.c_in, individual=self.individual, target_window=self.target_window)\n            \n        self.loss_fn = nn.MSELoss(reduction='none')\n        \n        if self.input_size != self.target_dim:\n            self.enc_linear = nn.Linear(\n                in_features=self.input_size, out_features=self.target_dim\n            )\n        else:\n            self.enc_linear = nn.Identity()\n\n    def encoder(self, x, te=None):\n        if self.decomposition:\n            res_init, trend_init = self.decomp_module(x)\n            res_init, trend_init = res_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)\n            if te is not None:\n                te = te.permute(0, 2, 1)\n            res = self.model_res(res_init, te)\n            trend = self.model_trend(trend_init, te)\n            x = res + trend\n            x = x.permute(0, 2, 1)\n        else:\n            x = x.permute(0, 2, 1)\n            if te is not None:\n                te = te.permute(0, 2, 1)\n\n            x = self.model(x, te)\n            x = x.permute(0, 2, 1)\n        return x\n\n    def loss(self, batch_data):\n        inputs = 
self.get_inputs(batch_data, 'encode')\n        # inputs = inputs[:,:,:self.target_dim]\n        inputs = self.enc_linear(inputs)\n        outputs = self.encoder(inputs)\n        \n        loss = self.loss_fn(batch_data.future_target_cdf, outputs)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        # b l k\n        inputs = self.get_inputs(batch_data, 'encode')\n        # inputs = inputs[:,:,:self.target_dim]\n        inputs = self.enc_linear(inputs)\n        outputs = self.encoder(inputs)\n        return outputs.unsqueeze(1)\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/naive.py",
    "content": "import torch\nfrom einops import repeat\nfrom probts.model.forecaster import Forecaster\nimport sys\n\nclass NaiveForecaster(Forecaster):\n    def __init__(\n        self,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.no_training = True\n\n\n    def forecast(self, batch_data, num_samples=None):\n        last_value = batch_data.past_target_cdf[:,-1,:]\n        outputs = repeat(last_value,'b k -> b n l k', n=1, l=self.prediction_length)\n        return outputs\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/nhits.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from NeuralForecast\n# - Source: https://github.com/Nixtla/neuralforecast\n# - Paper: N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting\n# - License: Apache-2.0\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport numpy as np\nfrom einops import rearrange, repeat\nfrom functools import partial\nfrom typing import List, Tuple\n\nfrom probts.model.forecaster import Forecaster\n\n\nclass StaticFeaturesEncoder(nn.Module):\n    def __init__(self, in_features, out_features):\n        super().__init__()\n        layers = [nn.Dropout(p=0.5), nn.Linear(in_features=in_features, out_features=out_features), nn.ReLU()]\n        self.encoder = nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.encoder(x)\n        return x\n\n\nclass IdentityBasis(nn.Module):\n    def __init__(self, backcast_size: int, forecast_size: int, interpolation_mode: str):\n        super().__init__()\n        assert (interpolation_mode in [\"linear\", \"nearest\"]) or (\"cubic\" in interpolation_mode)\n        self.forecast_size = forecast_size\n        self.backcast_size = backcast_size\n        self.interpolation_mode = interpolation_mode\n\n    def forward(\n        self,\n        backcast_theta: torch.Tensor,\n        forecast_theta: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        backcast = backcast_theta\n        knots = forecast_theta\n\n        if self.interpolation_mode == \"nearest\":\n            knots = knots[:, None, :]\n            forecast = F.interpolate(knots, size=self.forecast_size, mode=self.interpolation_mode)\n            forecast = forecast[:, 0, :]\n        elif self.interpolation_mode == \"linear\":\n            knots = 
knots[:, None, :]\n            forecast = F.interpolate(\n                knots, size=self.forecast_size, mode=self.interpolation_mode\n            )  # , align_corners=True)\n            forecast = forecast[:, 0, :]\n        elif \"cubic\" in self.interpolation_mode:\n            batch_size = int(self.interpolation_mode.split(\"-\")[-1])\n            knots = knots[:, None, None, :]\n            forecast = torch.zeros((len(knots), self.forecast_size)).to(knots.device)\n            n_batches = int(np.ceil(len(knots) / batch_size))\n            for i in range(n_batches):\n                forecast_i = F.interpolate(\n                    knots[i * batch_size : (i + 1) * batch_size], size=self.forecast_size, mode=\"bicubic\"\n                )  # , align_corners=True)\n                forecast[i * batch_size : (i + 1) * batch_size] += forecast_i[:, 0, 0, :]\n\n        return backcast, forecast\n\n\ndef init_weights(module, initialization):\n    if type(module) == torch.nn.Linear:\n        if initialization == \"orthogonal\":\n            torch.nn.init.orthogonal_(module.weight)\n        elif initialization == \"he_uniform\":\n            torch.nn.init.kaiming_uniform_(module.weight)\n        elif initialization == \"he_normal\":\n            torch.nn.init.kaiming_normal_(module.weight)\n        elif initialization == \"glorot_uniform\":\n            torch.nn.init.xavier_uniform_(module.weight)\n        elif initialization == \"glorot_normal\":\n            torch.nn.init.xavier_normal_(module.weight)\n        elif initialization == \"lecun_normal\":\n            pass  # torch.nn.init.normal_(module.weight, 0.0, std=1/np.sqrt(module.weight.numel()))\n        else:\n            assert 1 < 0, f\"Initialization {initialization} not found\"\n\n\nACTIVATIONS = [\"ReLU\", \"Softplus\", \"Tanh\", \"SELU\", \"LeakyReLU\", \"PReLU\", \"Sigmoid\"]\n\n\nclass NHiTSBlock(nn.Module):\n    \"\"\"\n    N-HiTS block which takes a basis function as an argument.\n    \"\"\"\n\n    def 
__init__(\n        self,\n        context_length: int,\n        prediction_length: int,\n        output_size: int,\n        covariate_size: int,\n        static_size: int,\n        static_hidden_size: int,\n        n_theta: int,\n        hidden_size: List[int],\n        pooling_sizes: int,\n        pooling_mode: str,\n        basis: nn.Module,\n        n_layers: int,\n        batch_normalization: bool,\n        dropout: float,\n        activation: str,\n    ):\n        super().__init__()\n\n        assert pooling_mode in [\"max\", \"average\"]\n\n        self.context_length_pooled = int(np.ceil(context_length / pooling_sizes))\n\n        if static_size == 0:\n            static_hidden_size = 0\n\n        self.context_length = context_length\n        self.output_size = [output_size]\n        self.n_theta = n_theta\n        self.prediction_length = prediction_length\n        self.static_size = static_size\n        self.static_hidden_size = static_hidden_size\n        self.covariate_size = covariate_size\n        self.pooling_sizes = pooling_sizes\n        self.batch_normalization = batch_normalization\n        self.dropout = dropout\n\n        hidden1 = [self.context_length_pooled * len(self.output_size) + (self.context_length + self.prediction_length) * self.covariate_size + self.static_hidden_size]\n        self.hidden_size = hidden1 + hidden_size\n\n\n\n        assert activation in ACTIVATIONS, f\"{activation} is not in {ACTIVATIONS}\"\n        activ = getattr(nn, activation)()\n\n        if pooling_mode == \"max\":\n            self.pooling_layer = nn.MaxPool1d(kernel_size=self.pooling_sizes, stride=self.pooling_sizes, ceil_mode=True)\n        elif pooling_mode == \"average\":\n            self.pooling_layer = nn.AvgPool1d(kernel_size=self.pooling_sizes, stride=self.pooling_sizes, ceil_mode=True)\n\n        hidden_layers = []\n        for i in range(n_layers):\n            hidden_layers.append(nn.Linear(in_features=self.hidden_size[i], 
out_features=self.hidden_size[i + 1]))\n            hidden_layers.append(activ)\n\n            if self.batch_normalization:\n                hidden_layers.append(nn.BatchNorm1d(num_features=self.hidden_size[i + 1]))\n\n            if self.dropout > 0:\n                hidden_layers.append(nn.Dropout(p=self.dropout))\n\n        output_layer = [\n            nn.Linear(\n                in_features=self.hidden_size[-1],\n                out_features=context_length * len(self.output_size) + n_theta * sum(self.output_size),\n            )\n        ]\n        layers = hidden_layers + output_layer\n\n        # static_size is computed with data, static_hidden_size is provided by user, if 0 no statics are used\n        if (self.static_size > 0) and (self.static_hidden_size > 0):\n            self.static_encoder = StaticFeaturesEncoder(in_features=static_size, out_features=static_hidden_size)\n        self.layers = nn.Sequential(*layers)\n        self.basis = basis\n\n    def forward(\n        self, encoder_y: torch.Tensor, encoder_x_t: torch.Tensor, decoder_x_t: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        batch_size = len(encoder_y)\n\n        encoder_y = encoder_y.transpose(1, 2)\n        # Pooling layer to downsample input\n        encoder_y = self.pooling_layer(encoder_y)\n\n        encoder_y = encoder_y.transpose(1, 2).reshape(batch_size, -1)\n\n\n        if self.covariate_size > 0:\n            encoder_y = torch.cat(\n                (\n                    encoder_y,\n                    encoder_x_t.reshape(batch_size, -1),\n                    decoder_x_t.reshape(batch_size, -1),\n                ),\n                1,\n            )\n\n        # Compute local projection weights and projection\n        theta = self.layers(encoder_y)\n        backcast_theta = theta[:, : self.context_length * len(self.output_size)].reshape(-1, self.context_length)\n        forecast_theta = theta[:, self.context_length * len(self.output_size) :].reshape(-1, 
self.n_theta)\n        backcast, forecast = self.basis(backcast_theta, forecast_theta)\n        backcast = backcast.reshape(-1, len(self.output_size), self.context_length).transpose(1, 2)\n        forecast = forecast.reshape(-1, sum(self.output_size), self.prediction_length).transpose(1, 2)\n\n        return backcast, forecast\n\n\n\nclass NHiTS(Forecaster):\n    def __init__(\n        self,\n        n_blocks: list,\n        pooling_mode,\n        interpolation_mode,\n        dropout,\n        activation,\n        initialization,\n        batch_normalization,\n        shared_weights,\n        output_size: int = 1,\n        hidden_size: int = 512,\n        naive_level: bool = True,\n        static_size: int = 0,\n        static_hidden_size: int = 0,\n        n_layers: int = 2,\n        pooling_sizes: list = None,\n        downsample_frequencies: list = None,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n\n        \"\"\"\n        N-HiTS model.\n\n        Parameters\n        ----------\n        n_time_in: int\n            Multiplier to get insample size.\n            Insample size = n_time_in * output_size\n        n_time_out: int\n            Forecast horizon.\n        shared_weights: bool\n            If True, repeats first block.\n        activation: str\n            Activation function.\n            An item from ['relu', 'softplus', 'tanh', 'selu', 'lrelu', 'prelu', 'sigmoid'].\n        initialization: str\n            Initialization function.\n            An item from ['orthogonal', 'he_uniform', 'glorot_uniform', 'glorot_normal', 'lecun_normal'].\n        stack_types: List[str]\n            List of stack types.\n            Subset from ['identity'].\n        n_blocks: List[int]\n            Number of blocks for each stack type.\n            Note that len(n_blocks) = len(stack_types).\n        n_layers: List[int]\n            Number of layers for each stack type.\n            Note that len(n_layers) = len(stack_types).\n        n_theta_hidden: 
List[List[int]]\n            Structure of hidden layers for each stack type.\n            Each internal list should contain the number of units of each hidden layer.\n            Note that len(n_theta_hidden) = len(stack_types).\n        n_pool_kernel_size List[int]:\n            Pooling size for input for each stack.\n            Note that len(n_pool_kernel_size) = len(stack_types).\n        n_freq_downsample List[int]:\n            Downsample multiplier of output for each stack.\n            Note that len(n_freq_downsample) = len(stack_types).\n        batch_normalization: bool\n            Whether perform batch normalization.\n        dropout_prob_theta: float\n            Float between (0, 1).\n            Dropout for Nbeats basis.\n        \"\"\"\n\n        n_stacks = len(n_blocks)\n        covariate_size = 0\n        if self.use_feat_idx_emb:\n            covariate_size = covariate_size + self.feat_idx_emb_dim\n        if self.use_time_feat:\n            covariate_size = covariate_size + self.time_feat_dim\n        self.covariate_size = covariate_size\n        self.output_size = output_size\n        self.naive_level = naive_level\n\n        n_layers = [n_layers] * n_stacks\n        hidden_size = n_stacks * [2 * [hidden_size]]\n\n        if pooling_sizes is None:\n            pooling_sizes = np.exp2(np.round(np.linspace(0.49, np.log2(self.prediction_length / 2), n_stacks)))\n            pooling_sizes = [int(x) for x in pooling_sizes[::-1]]\n\n        if downsample_frequencies is None:\n            downsample_frequencies = [min(self.prediction_length, int(np.power(x, 1.5))) for x in pooling_sizes]\n\n        blocks = self.create_stack(\n            n_blocks=n_blocks,\n            context_length=self.context_length,\n            prediction_length=self.prediction_length,\n            output_size=output_size,\n            covariate_size=covariate_size,\n            static_size=static_size,\n            static_hidden_size=static_hidden_size,\n            
n_layers=n_layers,\n            hidden_size=hidden_size,\n            pooling_sizes=pooling_sizes,\n            downsample_frequencies=downsample_frequencies,\n            pooling_mode=pooling_mode,\n            interpolation_mode=interpolation_mode,\n            batch_normalization=batch_normalization,\n            dropout=dropout,\n            activation=activation,\n            shared_weights=shared_weights,\n            initialization=initialization,\n        )\n        self.blocks = torch.nn.ModuleList(blocks)\n        self.loss_fn = nn.MSELoss(reduction='none')\n\n    def create_stack(\n        self,\n        n_blocks,\n        context_length,\n        prediction_length,\n        output_size,\n        covariate_size,\n        static_size,\n        static_hidden_size,\n        n_layers,\n        hidden_size,\n        pooling_sizes,\n        downsample_frequencies,\n        pooling_mode,\n        interpolation_mode,\n        batch_normalization,\n        dropout,\n        activation,\n        shared_weights,\n        initialization,\n    ):\n        block_list = []\n\n        for i in range(len(n_blocks)):\n            for block_id in range(n_blocks[i]):\n                # Batch norm only on first block\n                if (len(block_list) == 0) and (batch_normalization):\n                    batch_normalization_block = True\n                else:\n                    batch_normalization_block = False\n\n                # Shared weights\n                if shared_weights and block_id > 0:\n                    nbeats_block = block_list[-1]\n                else:\n                    n_theta = max(prediction_length // downsample_frequencies[i], 1)\n                    basis = IdentityBasis(\n                        backcast_size=context_length,\n                        forecast_size=prediction_length,\n                        interpolation_mode=interpolation_mode,\n                    )\n\n                    nbeats_block = NHiTSBlock(\n                        
context_length=context_length,\n                        prediction_length=prediction_length,\n                        output_size=output_size,\n                        covariate_size=covariate_size,\n                        static_size=static_size,\n                        static_hidden_size=static_hidden_size,\n                        n_theta=n_theta,\n                        hidden_size=hidden_size[i],\n                        pooling_sizes=pooling_sizes[i],\n                        pooling_mode=pooling_mode,\n                        basis=basis,\n                        n_layers=n_layers[i],\n                        batch_normalization=batch_normalization_block,\n                        dropout=dropout,\n                        activation=activation,\n                    )\n\n                # Select type of evaluation and apply it to all layers of block\n                init_function = partial(init_weights, initialization=initialization)\n                nbeats_block.layers.apply(init_function)\n                block_list.append(nbeats_block)\n        return block_list\n\n        \n\n    def encoder(self, encoder_y, encoder_x_t, decoder_x_t):\n        # encoder_y: [B L D]\n        residuals = (encoder_y)\n        level = encoder_y[:, -1:].repeat(1, self.prediction_length, 1)  # Level with Naive1\n        forecast_level = level.repeat_interleave(torch.tensor(self.output_size, device=level.device), dim=2)\n\n        # level with last available observation\n        if self.naive_level:\n            block_forecasts = [forecast_level]\n            forecast = block_forecasts[0]\n        else:\n            block_forecasts = []\n            forecast = torch.zeros_like(forecast_level, device=forecast_level.device)\n\n        # forecast by block\n        for block in self.blocks:\n            block_backcast, block_forecast = block(\n                encoder_y=residuals, encoder_x_t=encoder_x_t, decoder_x_t=decoder_x_t\n            )\n            residuals = (residuals - 
block_backcast) # * encoder_mask\n\n            forecast = forecast + block_forecast\n        return forecast\n\n    def get_cov(self, inputs):\n        if self.use_feat_idx_emb:\n            if self.use_time_feat:\n                encoder_dim_fea = inputs[:, : self.context_length, self.target_dim:-self.time_feat_dim]  # [B L K*D]\n                decoder_dim_fea = inputs[:, -self.prediction_length:, self.target_dim:-self.time_feat_dim]  # [B L K*D]\n            else:\n                encoder_dim_fea = inputs[:, : self.context_length, self.target_dim:]  # [B L K*D]\n                decoder_dim_fea = inputs[:, -self.prediction_length:, self.target_dim:]  # [B L K*D]\n\n            encoder_dim_fea = rearrange(encoder_dim_fea, \"b l (k d) -> (b k) l d\", k=self.target_dim, d=self.feat_idx_emb_dim)\n            decoder_dim_fea = rearrange(decoder_dim_fea, \"b l (k d) -> (b k) l d\", k=self.target_dim, d=self.feat_idx_emb_dim)\n        else:\n            encoder_dim_fea = []\n\n        if self.time_feat_dim:\n            encoder_time_fea = inputs[:, : self.context_length, -self.time_feat_dim: ] # [B L Dt]\n            encoder_time_fea = repeat(encoder_time_fea, 'b l d -> (b k) l d', k=self.target_dim)\n\n            decoder_time_fea = inputs[:, -self.prediction_length:, -self.time_feat_dim: ] # [B L Dt]\n            decoder_time_fea = repeat(decoder_time_fea, 'b l d -> (b k) l d', k=self.target_dim)\n\n        else:\n            encoder_time_fea = []\n\n        if self.use_feat_idx_emb and self.use_time_feat:\n            encoder_x_t = torch.cat([encoder_dim_fea, encoder_time_fea], dim=-1)\n            decoder_x_t = torch.cat([decoder_dim_fea, decoder_time_fea], dim=-1)\n        elif self.use_feat_idx_emb:\n            encoder_x_t, decoder_x_t = encoder_dim_fea, decoder_dim_fea\n        elif self.use_time_feat:\n            encoder_x_t, decoder_x_t = encoder_time_fea, decoder_time_fea\n        else:\n            encoder_x_t, decoder_x_t = None, None\n        return 
encoder_x_t, decoder_x_t\n\n    def loss(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'all') # [B L D]\n        \n        # Encode\n        encoder_y = inputs[:, : self.context_length, :self.target_dim] # [B L K]\n        encoder_y = rearrange(encoder_y, \"b l k -> (b k) l 1\")\n        encoder_x_t, decoder_x_t = self.get_cov(inputs)\n        outputs = self.encoder(encoder_y, encoder_x_t, decoder_x_t)\n        outputs = rearrange(outputs, \"(b k) l 1 -> b l k\", k=self.target_dim)\n        \n        loss = self.loss_fn(batch_data.future_target_cdf, outputs)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        inputs = self.get_inputs(batch_data, 'all') # [B L D]\n        encoder_y = inputs[:, : self.context_length, :self.target_dim] # [B L K]\n        encoder_y = rearrange(encoder_y, \"b l k -> (b k) l 1\")\n        encoder_x_t, decoder_x_t = self.get_cov(inputs)\n        output = self.encoder(encoder_y,encoder_x_t, decoder_x_t)\n        outputs = rearrange(output, \"(b k) l 1 -> b l k\", k=self.target_dim)\n        return outputs.unsqueeze(1)\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/nlinear.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from LTSF-Linear\n# - Source: https://github.com/cure-lab/LTSF-Linear\n# - Paper: Are Transformers Effective for Time Series Forecasting?\n# - License: Apache-2.0\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\nfrom probts.model.forecaster import Forecaster\n\n\nclass NLinear(Forecaster):\n    def __init__(\n        self,\n        individual: bool,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        if self.input_size != self.target_dim:\n            self.enc_linear = nn.Linear(\n                in_features=self.input_size, out_features=self.target_dim\n            )\n        else:\n            self.enc_linear = nn.Identity()\n\n        self.target_dim = self.target_dim\n        self.individual = individual\n        if individual:\n            self.Linear = nn.ModuleList()\n            for i in range(self.target_dim):\n                self.Linear.append(nn.Linear(self.context_length,self.prediction_length))\n        else:\n            self.Linear = nn.Linear(self.context_length, self.prediction_length)\n        self.loss_fn = nn.MSELoss(reduction='none')\n\n    def forward(self, inputs):\n        seq_last = inputs[:,-1:,:].detach()\n        inputs = inputs - seq_last\n        if self.individual:\n            output = torch.zeros([inputs.size(0),self.prediction_length,inputs.size(2)],dtype=inputs.dtype).to(inputs.device)\n            for i in range(self.target_dim):\n                output[:,:,i] = self.Linear[i](inputs[:,:,i])\n        else:\n            output = self.Linear(inputs.permute(0,2,1)).permute(0,2,1)\n        output = output + seq_last\n        return output\n\n    def loss(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'all')\n        inputs = inputs[:, 
: self.context_length, ...]\n        inputs = self.enc_linear(inputs)\n        outputs = self(inputs)\n        \n        loss = self.loss_fn(batch_data.future_target_cdf, outputs)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        inputs = self.get_inputs(batch_data, 'encode')\n        inputs = self.enc_linear(inputs)\n        outputs = self(inputs)\n        return outputs.unsqueeze(1)\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/patchtst.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PatchTST\n# - Source: https://github.com/yuqinie98/PatchTST/tree/main\n# - Paper: PatchTST: A Time Series is Worth 64 Words: Long-term Forecasting with Transformers\n# - License: Apache-2.0\n\n# We thank the authors for their contributions.\n# -----\n# ----------------------------------------------------------------------------\n\n\nimport torch.nn as nn\nfrom torch import Tensor\nfrom typing import Optional\n\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.arch.PatchTSTModule.PatchTST_backbone import PatchTST_backbone\nfrom probts.model.nn.arch.PatchTSTModule.PatchTST_layers import series_decomp\n\nclass PatchTST(Forecaster):\n    def __init__(\n        self,\n        stride: int,\n        patch_len: int,\n        padding_patch: str = None,\n        max_seq_len: int = 1024,\n        n_layers:int = 3,\n        n_heads = 16,\n        d_k: int = None,\n        d_v: int = None,\n        d_ff: int = 256,\n        attn_dropout: float = 0.,\n        dropout: float = 0.,\n        act: str = \"gelu\", \n        res_attention: bool = True,\n        pre_norm: bool = False,\n        store_attn: bool = False,\n        pe: str = 'zeros',\n        learn_pe: bool = True,\n        attn_mask: Optional[Tensor] = None,\n        individual: bool = False,\n        head_type: str = 'flatten',\n        padding_var: Optional[int] = None, \n        revin: bool = True,\n        key_padding_mask: str = 'auto',\n        affine: bool = False,\n        subtract_last: bool = False,\n        decomposition: bool = False,\n        kernel_size: int = 3,\n        fc_dropout: float = 0.,\n        head_dropout: float = 0.,\n        f_hidden_size: int = 40,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        \n        if self.input_size != self.target_dim:\n            self.enc_linear = nn.Linear(\n                
in_features=self.input_size, out_features=self.target_dim\n            )\n        else:\n            self.enc_linear = nn.Identity()\n\n        # Load parameters\n        c_in = self.input_size\n        context_window = self.context_length\n        target_window = self.prediction_length\n\n        # Model\n        self.decomposition = decomposition\n        if self.decomposition:\n            self.decomp_module = series_decomp(kernel_size)\n            self.model_trend = PatchTST_backbone(c_in=c_in, context_window=context_window, target_window=target_window, patch_len=patch_len, stride=stride, \n                                  max_seq_len=max_seq_len, n_layers=n_layers, d_model=f_hidden_size,\n                                  n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, attn_dropout=attn_dropout,\n                                  dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, \n                                  attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,\n                                  pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch,\n                                  pretrain_head=False, head_type=head_type, individual=individual, revin=revin, affine=affine,\n                                  subtract_last=subtract_last)\n            self.model_res = PatchTST_backbone(c_in=c_in, context_window=context_window, target_window=target_window, patch_len=patch_len, stride=stride, \n                                  max_seq_len=max_seq_len, n_layers=n_layers, d_model=f_hidden_size,\n                                  n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, attn_dropout=attn_dropout,\n                                  dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, \n                                  attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, 
store_attn=store_attn,\n                                  pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch,\n                                  pretrain_head=False, head_type=head_type, individual=individual, revin=revin, affine=affine,\n                                  subtract_last=subtract_last)\n        else:\n            self.model = PatchTST_backbone(c_in=c_in, context_window=context_window, target_window=target_window, patch_len=patch_len, stride=stride, \n                                  max_seq_len=max_seq_len, n_layers=n_layers, d_model=f_hidden_size,\n                                  n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, attn_dropout=attn_dropout,\n                                  dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, \n                                  attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,\n                                  pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch,\n                                  pretrain_head=False, head_type=head_type, individual=individual, revin=revin, affine=affine,\n                                  subtract_last=subtract_last)\n        self.loss_fn = nn.MSELoss(reduction='none')\n    \n    def forward(self, x):\n        if self.decomposition:\n            res_init, trend_init = self.decomp_module(x)\n            res_init, trend_init = res_init.permute(0,2,1), trend_init.permute(0,2,1)  # x: [Batch, Channel, Input length]\n            res = self.model_res(res_init)\n            trend = self.model_trend(trend_init)\n            x = res + trend\n            x = x.permute(0,2,1)    # x: [Batch, Input length, Channel]\n        else:\n            x = x.permute(0,2,1)    # x: [Batch, Channel, Input length]\n            x = self.model(x)\n            x = x.permute(0,2,1)    # x: [Batch, Input length, 
Channel]\n        return x\n\n    def loss(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'encode')\n        inputs = self.enc_linear(inputs)\n        outputs = self(inputs)\n        \n        loss = self.loss_fn(batch_data.future_target_cdf, outputs)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        inputs = self.get_inputs(batch_data, 'encode')\n        inputs = self.enc_linear(inputs)\n        outputs = self(inputs)\n        return outputs.unsqueeze(1)\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/time_moe.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from Time-MoE\n# - Source: https://github.com/Time-MoE/Time-MoE\n# - Paper: Time-MoE: Billion-Scale Time Series Foundation Models with Mixture of Experts\n# - License: Apache License 2.0\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom transformers import AutoModelForCausalLM\nfrom probts.model.forecaster import Forecaster\nimport sys\nfrom probts.data.data_utils.data_scaler import InstanceNorm\n\nclass TimeMoE(Forecaster):\n    def __init__(\n        self,\n        model_size: str = '50M',\n        instance_norm=True,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.no_training = True\n        \n        if (type(self.target_dim).__name__=='dict'):\n            for dataset_name in self.target_dim:\n                target_dim = target_dim[dataset_name]\n                freq = freq[dataset_name]\n        else:\n            freq = self.freq\n                \n        if (type(self.context_length).__name__=='list'):\n            context_length = max(context_length)\n            \n        if (type(self.prediction_length).__name__=='list'):\n            prediction_length = max(prediction_length)\n            \n        if model_size not in ['50M', '200M']:\n            print('Invalid model size. 
Please choose from 50M or 200M')\n            sys.exit()\n        \n        if instance_norm:\n            self.normalization = InstanceNorm()\n        else:\n            self.normalization = None\n            \n        self.model = AutoModelForCausalLM.from_pretrained(\n            f'Maple728/TimeMoE-{model_size}',\n            trust_remote_code=True,\n            torch_dtype=torch.bfloat16,\n        )\n        print(f\"loaded TimeMoE-{model_size} model\")\n        \n\n    def forecast(self, batch_data, num_samples=None):\n        inputs = batch_data.past_target_cdf[:, -self.context_length:]\n        # inputs = inputs[:, -self.context_length:].cpu()\n        B, _, K = inputs.shape\n        inputs = inputs.to(dtype=torch.bfloat16)\n        inputs = rearrange(inputs, 'b l k -> (b k) l')\n        \n        if self.normalization:\n            inputs = self.normalization(inputs, mode='norm')\n            \n        forecasts = self.model.generate(inputs, max_new_tokens=self.prediction_length)  # shape is [batch_size, 12 + 6]\n        point_forecast = forecasts[:, -self.prediction_length:]\n        \n        \n        if self.normalization:\n            point_forecast = self.normalization(point_forecast, mode='denorm')\n            \n        point_forecast = point_forecast.to(dtype=torch.float32)\n        point_forecast = rearrange(point_forecast, '(b k) l -> b l k', b=B,k=K)\n        \n        point_forecast = point_forecast[:, :self.prediction_length]\n        return point_forecast.unsqueeze(1)\n    "
  },
  {
    "path": "probts/model/forecaster/point_forecaster/timer.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from Large-Time-Series-Model\n# - Source: https://github.com/thuml/Large-Time-Series-Model\n# - Paper: Timer: Generative Pre-trained Transformers Are Large Time Series Models\n# - License: MIT License\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nfrom einops import rearrange, repeat\nfrom torch import nn\n\nfrom probts.model.forecaster import Forecaster\n\n\nclass Model(nn.Module):\n    \"\"\"\n    Paper link: https://arxiv.org/pdf/2402.02368.pdf\n    \"\"\"\n\n    def __init__(self, ckpt_path):\n        super().__init__()\n        if ckpt_path and ckpt_path != \"\":\n            if ckpt_path.endswith('.pt'):\n                # print(f\"Loading Timer model from {ckpt_path}\")\n                self.timer = torch.jit.load(ckpt_path)\n        else:\n            raise NotImplementedError\n\n    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):\n        return self.timer(x_enc, x_mark_enc, x_dec, x_mark_dec)\n\n\nclass Timer(Forecaster):\n    def __init__(\n        self,\n        label_len: int = 576,\n        ckpt_path: str = None,\n        ckpt_path_finetune: str = None,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.no_training = True\n        \n        self.output_patch_len = 96 # fixed by the pre-trained model\n        self.label_len = label_len\n\n        # Load Timer\n        self.model = Model(ckpt_path)\n        if ckpt_path_finetune:\n            print(f\"Loading Timer finetune model from {ckpt_path_finetune}\")\n            self.model.load_state_dict(torch.load(ckpt_path_finetune))\n       \n\n    def forecast(self, batch_data, num_samples=None):        \n        # for now, we only support batch_size=1\n        B, _, K = batch_data.past_target_cdf.shape\n        inputs = 
batch_data.past_target_cdf[:, -self.context_length:, ...]\n        x_mark_enc = batch_data.past_time_feat[:, -self.context_length:, ...]\n        x_mark_dec = batch_data.future_time_feat\n        x_mark_dec = torch.cat([x_mark_enc[:, -self.label_len:, :], x_mark_dec], dim=1)\n\n        inputs = rearrange(inputs, 'b l k -> (b k) l 1')\n        x_mark_enc = repeat(x_mark_enc, 'b l f -> (b k) l f', k=K)\n        x_mark_dec = repeat(x_mark_dec, 'b l f -> (b k) l f', k=K)\n\n        dec_inp = torch.zeros_like(inputs[:, -self.prediction_length:, :]).float()\n        dec_inp = torch.cat((inputs[:, -self.label_len:, ...], dec_inp), dim=1).float()\n\n        inference_steps = self.prediction_length // self.output_patch_len\n        dis = self.prediction_length - inference_steps * self.output_patch_len\n        if dis != 0:\n            inference_steps += 1\n\n        pred_y = []\n\n        for j in range(inference_steps):\n            if len(pred_y) != 0:\n                inputs = torch.cat([inputs[:, self.output_patch_len:, :], pred_y[-1]], dim=1)\n                tmp = x_mark_dec[:, j - 1:j, :]\n                x_mark_enc = torch.cat([x_mark_enc[:, 1:, :], tmp], dim=1)\n\n            outputs = self.model(inputs, x_mark_enc, dec_inp, x_mark_dec)\n            pred_y.append(outputs[:, -self.output_patch_len:, :])\n\n        pred_y = torch.cat(pred_y, dim=1)\n        if dis != 0:\n            pred_y = pred_y[:, :-dis, :]\n        pred_y = rearrange(pred_y, '(b k) l 1 -> b l k', b=B, k=K)\n        pred_y = pred_y[:, :self.prediction_length, :]\n        return pred_y.unsqueeze(1)\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/timesfm.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from timesfm\n# - Source: https://github.com/google-research/timesfm\n# - Paper: A decoder-only foundation model for time-series forecasting\n# - License: Apache License 2.0\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport numpy as np\nimport torch\nfrom einops import rearrange\nimport sys\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.arch.TimesFMModule import TimesFm, TimesFmCheckpoint, TimesFmHparams\n# from submodules.timesfm.src.timesfm import TimesFm\n\nclass TimesFM(Forecaster):\n    def __init__(\n        self,\n        model_size: str = '200m',\n        # input_patch_len: int = 32,\n        # output_patch_len: int = 128,\n        # num_layers: int = 20,\n        # model_dims: int = 1280,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.no_training = True\n        \n        if (type(self.target_dim).__name__=='dict'):\n            for dataset_name in self.target_dim:\n                target_dim = target_dim[dataset_name]\n                freq = freq[dataset_name]\n        else:\n            freq = self.freq\n                \n        if (type(self.context_length).__name__=='list'):\n            context_length = max(context_length)\n            \n        if (type(self.prediction_length).__name__=='list'):\n            prediction_length = max(prediction_length)\n            \n        if model_size not in ['200m', '500m']:\n            print('Invalid model size. 
Please choose from 200m or 500m')\n            sys.exit()\n\n        if model_size == '200m':\n            self.tfm = TimesFm(\n                hparams=TimesFmHparams(\n                    backend=\"gpu\",\n                    per_core_batch_size=32,\n                    horizon_len=128,\n                ),\n                checkpoint=TimesFmCheckpoint(\n                    huggingface_repo_id=\"google/timesfm-1.0-200m-pytorch\"),\n            )\n        elif model_size == '500m':\n            self.tfm = TimesFm(\n                hparams=TimesFmHparams(\n                    backend=\"gpu\",\n                    per_core_batch_size=32,\n                    horizon_len=128,\n                    num_layers=50,\n                    use_positional_embedding=False,\n                    context_len=2048,\n                ),\n                checkpoint=TimesFmCheckpoint(\n                    huggingface_repo_id=\"google/timesfm-2.0-500m-pytorch\"),\n            )\n\n        \n        freq_dict = {'h': 0, 'min': 0, 'd': 0, 'b': 0, 'u': 0, 'w': 1, 'm': 1, 'q': 2, 'y': 2}\n        freq = freq.lower()\n        \n        if freq in freq_dict:\n            self.freq_int = freq_dict[freq]\n        else:\n            self.freq_int = 0\n\n        print(f\"TimesFM-{model_size} - frequency: {freq}, freq_num: {self.freq_int}\")\n\n\n    def forecast(self, batch_data, num_samples=None):\n        inputs = self.get_inputs(batch_data, 'encode')\n        inputs = inputs[:, -self.context_length:].cpu()\n        B, _, K = inputs.shape\n        # past_target = batch_data.past_target_cdf[:, -self.context_length:]\n        \n        inputs = np.array(rearrange(inputs, 'b l k -> (b k) l'))\n        frequency_input = [self.freq_int] * inputs.shape[0]\n        \n        _, out = self.tfm.forecast(\n            inputs,\n            freq=frequency_input,\n        )\n        point_forecast = out[:, :, 5]\n        point_forecast = rearrange(point_forecast, '(b k) l -> b l k', b=B,k=K)\n        \n      
  point_forecast = torch.tensor(point_forecast[:, :self.prediction_length])\n        return point_forecast.unsqueeze(1)\n    "
  },
  {
    "path": "probts/model/forecaster/point_forecaster/timesnet.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from TSLib\n# - Source: https://github.com/libts/tslib\n# - Paper: TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis\n# - License:  LGPL-2.1\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.fft\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.arch.TransformerModule.Embed import DataEmbedding\nfrom probts.model.nn.arch.Conv_Blocks import Inception_Block_V1\n\n\ndef FFT_for_Period(x, k=2):\n    # [B, T, C]\n    xf = torch.fft.rfft(x, dim=1)\n    # find period by amplitudes\n    frequency_list = abs(xf).mean(0).mean(-1)\n    frequency_list[0] = 0\n    _, top_list = torch.topk(frequency_list, k)\n    top_list = top_list.detach().cpu().numpy()\n    period = x.shape[1] // top_list\n    return period, abs(xf).mean(-1)[:, top_list]\n\n\nclass TimesBlock(nn.Module):\n    def __init__(self, context_length, prediction_length, top_k, d_model, d_ff, num_kernels):\n        super(TimesBlock, self).__init__()\n        self.seq_len = context_length\n        self.pred_len = prediction_length\n        self.k = top_k\n        # parameter-efficient design\n        self.conv = nn.Sequential(\n            Inception_Block_V1(d_model, d_ff,\n                               num_kernels=num_kernels),\n            nn.GELU(),\n            Inception_Block_V1(d_ff, d_model,\n                               num_kernels=num_kernels)\n        )\n\n    def forward(self, x):\n        B, T, N = x.size()\n        period_list, period_weight = FFT_for_Period(x, self.k)\n\n        res = []\n        for i in range(self.k):\n            period = period_list[i]\n            # padding\n            if (self.seq_len + self.pred_len) % period != 0:\n               
 length = (\n                                 ((self.seq_len + self.pred_len) // period) + 1) * period\n                padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)\n                out = torch.cat([x, padding], dim=1)\n            else:\n                length = (self.seq_len + self.pred_len)\n                out = x\n            # reshape\n            out = out.reshape(B, length // period, period,\n                              N).permute(0, 3, 1, 2).contiguous()\n            # 2D conv: from 1d Variation to 2d Variation\n            out = self.conv(out)\n            # reshape back\n            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)\n            res.append(out[:, :(self.seq_len + self.pred_len), :])\n        res = torch.stack(res, dim=-1)\n        # adaptive aggregation\n        period_weight = F.softmax(period_weight, dim=1)\n        period_weight = period_weight.unsqueeze(\n            1).unsqueeze(1).repeat(1, T, N, 1)\n        res = torch.sum(res * period_weight, -1)\n        # residual connection\n        res = res + x\n        return res\n\n\nclass TimesNet(Forecaster):\n    def __init__(\n        self,\n        n_layers: int = 2,\n        num_kernels: int = 6,\n        top_k: int = 5,\n        d_ff: int = 32,\n        embed: str = 'timeF',\n        dropout: float = 0.1,\n        f_hidden_size: int = 40,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n\n        self.seq_len = self.context_length\n        self.pred_len = self.prediction_length\n\n        self.model = nn.ModuleList(\n            [TimesBlock(self.context_length, self.prediction_length, top_k, f_hidden_size, d_ff, num_kernels)\n                for _ in range(n_layers)]\n        )\n        self.enc_embedding = DataEmbedding(self.target_dim, f_hidden_size, embed, self.freq.lower(), dropout)\n        self.layer = n_layers\n        self.layer_norm = nn.LayerNorm(f_hidden_size)\n\n        self.predict_linear = 
nn.Linear(\n            self.seq_len, self.pred_len + self.seq_len)\n        self.projection = nn.Linear(\n            f_hidden_size, self.target_dim, bias=True)\n        \n        if self.input_size != self.target_dim:\n            self.enc_linear = nn.Linear(\n                in_features=self.input_size, out_features=self.target_dim\n            )\n        else:\n            self.enc_linear = nn.Identity()\n        self.loss_fn = nn.MSELoss(reduction='none')\n\n    def forward(self, x_enc, x_mark_enc=None):\n        # Normalization from Non-stationary Transformer\n        means = x_enc.mean(1, keepdim=True).detach()\n        x_enc = x_enc - means\n        stdev = torch.sqrt(\n            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)\n        x_enc = x_enc / stdev\n\n        # embedding\n        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B,T,C]\n        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(\n            0, 2, 1)  # align temporal dimension\n        # TimesNet\n        for i in range(self.layer):\n            enc_out = self.layer_norm(self.model[i](enc_out))\n        # porject back\n        dec_out = self.projection(enc_out)\n\n        # De-Normalization from Non-stationary Transformer\n        dec_out = dec_out * \\\n                  (stdev[:, 0, :].unsqueeze(1).repeat(\n                      1, self.pred_len + self.seq_len, 1))\n        dec_out = dec_out + \\\n                  (means[:, 0, :].unsqueeze(1).repeat(\n                      1, self.pred_len + self.seq_len, 1))\n        return dec_out[:, -self.pred_len:, :]  # [B, L, D]\n\n\n    def loss(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'all')\n        inputs = inputs[:, : self.context_length, ...]\n        inputs = self.enc_linear(inputs)\n        # x: [Batch, Input length, Channel]\n        outputs = self(inputs)\n    \n        loss = self.loss_fn(batch_data.future_target_cdf, outputs)\n        loss = 
self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        inputs = self.get_inputs(batch_data, 'encode')\n        inputs = self.enc_linear(inputs)\n        outputs = self(inputs)\n        return outputs.unsqueeze(1)"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/tinytimemixer.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from granite-tsfm\n# - Source: https://github.com/ibm-granite/granite-tsfm\n# - Paper: Tiny Time Mixers (TTMs): Fast Pre-trained Models for Enhanced Zero/Few-Shot Forecasting of Multivariate Time Series\n# - License: Apache License 2.0\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nfrom probts.model.forecaster import Forecaster\n\nfrom submodules.tsfm.tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction\n\n\nclass TinyTimeMixer(Forecaster):\n    \"\"\"\n    TinyTimeMixer from https://github.com/ibm-granite/granite-tsfm/blob/main/notebooks/hfdemo/ttm_getting_started.ipynb\n    prediction length originally 96\n    context length originally 512\n    changes might cause degradation in performance\n\n    The pretrained checkpoint is used zero-shot (``self.no_training = True``).\n    List-valued ``context_length`` / ``prediction_length`` are collapsed to\n    their maximum in ``__init__``.\n    \"\"\"\n\n    def __init__(\n        self,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.no_training = True\n\n        # TTM model branch\n        # Use main for 512-96 model\n        # Use \"1024_96_v1\" for 1024-96 model\n        TTM_MODEL_REVISION = \"main\"\n\n        # Store the collapsed lengths on the instance: forecast() slices the\n        # input with self.context_length, which must be an int, not a list.\n        if isinstance(self.context_length, list):\n            self.context_length = max(self.context_length)\n\n        if isinstance(self.prediction_length, list):\n            self.prediction_length = max(self.prediction_length)\n\n        self.zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(\n            \"ibm/TTM\", revision=TTM_MODEL_REVISION\n        )\n\n    def forecast(self, batch_data, num_samples=None):\n        \"\"\"Zero-shot point forecast with the pretrained TTM model.\n\n        ``num_samples`` is unused. The model's ``prediction_outputs`` are\n        returned with a singleton sample dimension inserted at dim 1.\n        \"\"\"\n        inputs = self.get_inputs(batch_data, 'encode')\n        # Keep only the most recent context_length time steps.\n        inputs = inputs[:, -self.context_length:]\n        # past_target = batch_data.past_target_cdf[:, -self.context_length:]\n        self.zeroshot_model.eval()\n        point_forecast = self.zeroshot_model.forward(inputs).prediction_outputs\n        return point_forecast.unsqueeze(1)\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/transformer.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom probts.data import ProbTSBatchData\nfrom probts.model.forecaster import Forecaster\n\n\nclass TransformerForecaster(Forecaster):\n    \"\"\"Encoder-decoder ``nn.Transformer`` point forecaster.\n\n    ``loss`` trains with teacher forcing over the whole horizon under a causal\n    target mask; ``forecast`` rolls out one step at a time and appends each\n    prediction to ``past_target_cdf``.\n    \"\"\"\n\n    def __init__(\n        self,\n        f_hidden_size: int = 32,\n        num_heads: int = 8,\n        num_encoder_layers: int = 3,\n        num_decoder_layers: int = 3,\n        dim_feedforward_scale: int = 4,\n        dropout: float = 0.1,\n        activation: str = 'gelu',\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.autoregressive = True\n        self.f_hidden_size = f_hidden_size\n\n        self.enc_linear = nn.Linear(self.input_size, self.f_hidden_size)\n        self.dec_linear = nn.Linear(self.input_size, self.f_hidden_size)\n        # batch_first is left at its default (False): tensors are [L, B, H].\n        self.model = nn.Transformer(\n            d_model=self.f_hidden_size,\n            nhead=num_heads,\n            num_encoder_layers=num_encoder_layers,\n            num_decoder_layers=num_decoder_layers,\n            dim_feedforward=dim_feedforward_scale * self.f_hidden_size,\n            dropout=dropout,\n            activation=activation\n        )\n\n        # Causal mask over the prediction horizon, used by loss().\n        self.register_buffer(\n            \"tgt_mask\",\n            self.model.generate_square_subsequent_mask(self.prediction_length),\n        )\n        self.linear = nn.Linear(self.f_hidden_size, self.target_dim)\n        self.loss_fn = nn.MSELoss(reduction='none')\n\n    def loss(self, batch_data):\n        \"\"\"Teacher-forced MSE over the prediction horizon, weighted and averaged.\"\"\"\n        inputs = self.get_inputs(batch_data, 'all') # [B L D]\n\n        # Encode\n        enc_inputs = inputs[:, :self.context_length, ...]\n        enc_inputs = self.enc_linear(enc_inputs).permute(1, 0, 2)\n        enc_outputs = self.model.encoder(enc_inputs) # [L_in B H]\n\n        # Decode: the last prediction_length steps shifted back by one\n        # (teacher forcing), restricted to the past by tgt_mask.\n        dec_inputs = inputs[:, -self.prediction_length-1:-1, ...]\n        dec_inputs = self.dec_linear(dec_inputs).permute(1, 0, 2)\n        dec_outputs = self.model.decoder(\n            dec_inputs, enc_outputs, tgt_mask=self.tgt_mask)\n        dec_outputs = dec_outputs.permute(1, 0, 2)  # [B L_out H]\n        outputs = self.linear(dec_outputs)\n\n        loss = self.loss_fn(batch_data.future_target_cdf, outputs)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        \"\"\"Autoregressive rollout; returns [B, 1, prediction_length, target_dim].\n\n        ``num_samples`` is unused.\n        \"\"\"\n        forecasts = []\n        states = self.encode(batch_data)\n        past_target_cdf = batch_data.past_target_cdf\n\n        for k in range(self.prediction_length):\n            current_batch_data = ProbTSBatchData({\n                'target_dimension_indicator': batch_data.target_dimension_indicator,\n                'past_target_cdf': past_target_cdf,\n                'future_time_feat': batch_data.future_time_feat[:, k : k + 1, ...]\n            }, device=batch_data.device)\n\n            # NOTE(review): future_time_feat is sliced to a single step and\n            # decode() passes tgt_mask=None, so the decoder presumably sees only\n            # the current position here, unlike the full causal sequence it sees\n            # in loss(). Confirm this train/inference mismatch is intended.\n            outputs, states = self.decode(current_batch_data, states)\n            outputs = self.linear(outputs)\n            forecasts.append(outputs)\n\n            # Feed the prediction back as the newest observed target.\n            past_target_cdf = torch.cat(\n                (past_target_cdf, outputs), dim=1\n            )\n\n        forecasts = torch.cat(forecasts, dim=1).reshape(\n            -1, self.prediction_length, self.target_dim)\n        return forecasts.unsqueeze(1)\n\n    def encode(self, batch_data):\n        \"\"\"Run the encoder on the 'encode' inputs and return its memory.\"\"\"\n        inputs = self.get_inputs(batch_data, 'encode')\n        inputs = self.enc_linear(inputs).permute(1, 0, 2)\n        states = self.model.encoder(inputs)\n        return states\n\n    def decode(self, batch_data, states=None):\n        \"\"\"Decode the 'decode' inputs against encoder memory ``states``.\n\n        Returns the batch-first decoder outputs and ``states`` unchanged.\n        \"\"\"\n        inputs = self.get_inputs(batch_data, 'decode')\n        inputs = self.dec_linear(inputs).permute(1, 0, 2)\n        outputs = self.model.decoder(inputs, states, tgt_mask=None)\n        return outputs.permute(1, 0, 2), states\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/tsmixer.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from TSMixer\n# - Source: https://github.com/google-research/google-research/tree/master/tsmixer\n# - Paper: TSMixer: An All-MLP Architecture for Time Series Forecasting\n# - License: Apache-2.0 license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom probts.model.nn.arch.TSMixer_layers import MixerLayer, TimeBatchNorm2d, feature_to_time, time_to_feature\n\n\nfrom probts.model.forecaster import Forecaster\nimport sys\n\n\nclass TSMixer(Forecaster):\n    \"\"\"TSMixer model for time series forecasting.\n\n    This model uses a series of mixer layers to process time series data,\n    followed by a linear transformation to project the output to the desired\n    prediction length.\n\n    The number of input/output channels is ``self.target_dim``; the input and\n    output sequence lengths are ``self.context_length`` and\n    ``self.prediction_length`` (list values are collapsed to their maximum).\n\n    Attributes:\n        mixer_layers: Sequential container of mixer layers.\n        temporal_projection: Linear layer for temporal projection.\n\n    Args:\n        activation_fn: Name of a ``torch.nn.functional`` activation. Defaults to \"relu\".\n        num_blocks: Mixer block setting; ``_build_mixer`` creates ``num_blocks - 1``\n            mixer layers. Defaults to 2.\n        dropout_rate: Dropout rate for regularization. Defaults to 0.1.\n        ff_dim: Dimension of feedforward network inside mixer layer. Defaults to 64.\n        normalize_before: Whether to apply normalization before or after mixer layer.\n        norm_type: Type of normalization to use. \"batch\" or \"layer\". Defaults to \"batch\".\n        **kwargs: Forwarded to ``Forecaster``.\n    \"\"\"\n\n    def __init__(\n        self,\n        activation_fn: str = \"relu\",\n        num_blocks: int = 2,\n        dropout_rate: float = 0.1,\n        ff_dim: int = 64,\n        normalize_before: bool = True,\n        norm_type: str = \"batch\",\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n\n        # Transform activation_fn to callable\n        activation_fn = getattr(F, activation_fn)\n\n        input_channels = self.target_dim\n        output_channels = self.target_dim\n\n        if isinstance(self.prediction_length, list):\n            self.prediction_length = max(self.prediction_length)\n\n        if isinstance(self.context_length, list):\n            self.context_length = max(self.context_length)\n\n        sequence_length = self.context_length\n        prediction_length = self.prediction_length\n        # Transform norm_type to callable\n        assert norm_type in {\n            \"batch\",\n            \"layer\",\n        }, f\"Invalid norm_type: {norm_type}, must be one of batch, layer.\"\n        norm_type = TimeBatchNorm2d if norm_type == \"batch\" else nn.LayerNorm\n\n        # Build mixer layers\n        self.mixer_layers = self._build_mixer(\n            num_blocks,\n            input_channels,\n            output_channels,\n            ff_dim=ff_dim,\n            activation_fn=activation_fn,\n            dropout_rate=dropout_rate,\n            sequence_length=sequence_length,\n            normalize_before=normalize_before,\n            norm_type=norm_type,\n        )\n\n        # Temporal projection layer\n        self.temporal_projection = nn.Linear(sequence_length, prediction_length)\n        self.loss_fn = nn.MSELoss(reduction='none')\n\n    def _build_mixer(\n        self, num_blocks: int, input_channels: int, output_channels: int, **kwargs\n    ):\n        \"\"\"Build the mixer blocks for the model.\n\n        Args:\n            num_blocks (int): Length of the channel schedule; one mixer layer is\n                built per consecutive channel pair, i.e. ``num_blocks - 1`` layers.\n            input_channels (int): Number of input channels for the first block.\n            output_channels (int): Number of output channels for the last block.\n            **kwargs: Additional keyword arguments for mixer layer configuration.\n\n        Returns:\n            nn.Sequential: Sequential container of mixer layers.\n        \"\"\"\n        output_channels = output_channels if output_channels is not None else input_channels\n        channels = [input_channels] * (num_blocks - 1) + [output_channels]\n        # NOTE(review): zip(channels[:-1], channels[1:]) yields num_blocks - 1\n        # layers (none for num_blocks <= 1). Left unchanged so existing configs\n        # and checkpoints keep their architecture; confirm whether num_blocks\n        # layers were intended.\n        return nn.Sequential(\n            *[\n                MixerLayer(input_channels=in_ch, output_channels=out_ch, **kwargs)\n                for in_ch, out_ch in zip(channels[:-1], channels[1:])\n            ]\n        )\n\n    def forward(self, x_hist: torch.Tensor) -> torch.Tensor:\n        \"\"\"Forward pass of the TSMixer model.\n\n        Args:\n            x_hist (torch.Tensor): Input time series tensor.\n\n        Returns:\n            torch.Tensor: The output tensor after processing by the model.\n        \"\"\"\n        x = self.mixer_layers(x_hist)\n\n        # Project along the time axis: context_length -> prediction_length.\n        x_temp = feature_to_time(x)\n        x_temp = self.temporal_projection(x_temp)\n        x = time_to_feature(x_temp)\n\n        return x\n\n    def loss(self, batch_data):\n        \"\"\"MSE against ``future_target_cdf``, weighted and averaged.\"\"\"\n        inputs = self.get_inputs(batch_data, 'encode')\n        outputs = self(inputs)\n\n        loss = self.loss_fn(batch_data.future_target_cdf, outputs)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        \"\"\"Point forecast with a singleton sample dimension at dim 1.\n\n        ``num_samples`` is unused.\n        \"\"\"\n        inputs = self.get_inputs(batch_data, 'encode')\n        outputs = self(inputs)\n        return outputs.unsqueeze(1)\n"
  },
  {
    "path": "probts/model/forecaster/point_forecaster/units.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from UniTS\n# - Source: https://github.com/mims-harvard/UniTS\n# - Paper: UNITS: A Unified Multi-Task Time Series Model\n# - License: MIT License\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport math\n\nimport torch\nimport torch.nn.functional as F\nfrom timm.layers import DropPath, Mlp\nfrom timm.layers.helpers import to_2tuple\nfrom torch import nn\n\nfrom probts.model.forecaster import Forecaster\n\n\ndef calculate_unfold_output_length(input_length, size, step):\n    # Calculate the number of windows\n    num_windows = (input_length - size) // step + 1\n    return num_windows\n\n\nclass CrossAttention(nn.Module):\n    def __init__(\n            self,\n            dim,\n            num_heads=8,\n            qkv_bias=False,\n            qk_norm=False,\n            attn_drop=0.,\n            proj_drop=0.,\n            norm_layer=nn.LayerNorm,\n            var_num=None,\n    ):\n        super().__init__()\n        assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.scale = self.head_dim ** -0.5\n\n        self.q = nn.Linear(dim, dim, bias=qkv_bias)\n        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)\n        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()\n        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        if var_num is not None:\n            self.template = nn.Parameter(\n                torch.zeros(var_num, dim), requires_grad=True)\n            torch.nn.init.normal_(self.template, std=.02)\n        self.var_num = var_num\n\n    
def forward(self, x, query=None):\n        B, N, C = x.shape\n        if query is not None:\n            q = self.q(query).reshape(\n                B, query.shape[1], self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n            q = self.q_norm(q)\n            var_num = query.shape[1]\n        else:\n            q = self.q(self.template).reshape(1, self.var_num,\n                                              self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n            q = self.q_norm(q)\n            q = q.repeat(B, 1, 1, 1)\n            var_num = self.var_num\n        kv = self.kv(x).reshape(B, N, 2, self.num_heads,\n                                self.head_dim).permute(2, 0, 3, 1, 4)\n        k, v = kv.unbind(0)\n        k = self.k_norm(k)\n\n        x = F.scaled_dot_product_attention(\n            q, k, v,\n            dropout_p=self.attn_drop.p if self.training else 0.,\n        )\n\n        x = x.transpose(1, 2).reshape(B, var_num, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass DynamicLinear(nn.Module):\n    \"\"\"\n    A dynamic linear layer that can interpolate the weight size to support any given input and output feature dimension.\n    \"\"\"\n\n    def __init__(self, in_features=None, out_features=None, fixed_in=0, bias=True):\n        super(DynamicLinear, self).__init__()\n        assert fixed_in < in_features, \"fixed_in < in_features is required !!!\"\n        self.in_features = in_features\n        self.out_features = out_features\n        self.weights = nn.Parameter(torch.Tensor(out_features, in_features))\n        self.bias = nn.Parameter(torch.Tensor(out_features))\n        self.fixed_in = fixed_in\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))\n        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)\n        bound = 1 / math.sqrt(fan_in)\n        nn.init.uniform_(self.bias, -bound, bound)\n\n    def 
forward(self, x, out_features):\n        \"\"\"\n        Forward pass for the dynamic linear layer.\n        \"\"\"\n        fixed_weights = self.weights[:, :self.fixed_in]\n        dynamic_weights = self.weights[:, self.fixed_in:]\n        this_bias = self.bias\n        in_features = x.shape[-1]\n\n        if in_features != self.weights.size(1) or out_features != self.weights.size(0):\n            dynamic_weights = F.interpolate(dynamic_weights.unsqueeze(0).unsqueeze(0), size=(\n                out_features, in_features-self.fixed_in), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)\n            if self.fixed_in != 0:\n                fixed_weights = F.interpolate(fixed_weights.unsqueeze(0).unsqueeze(0), size=(\n                    out_features, self.fixed_in), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)\n        if out_features != self.weights.size(0):\n            this_bias = F.interpolate(this_bias.unsqueeze(0).unsqueeze(0).unsqueeze(0), size=(\n                1, out_features), mode='bilinear', align_corners=False).squeeze(0).squeeze(0).squeeze(0)\n        return F.linear(x, torch.cat((fixed_weights, dynamic_weights), dim=1), this_bias)\n\n\nclass DynamicLinearMlp(nn.Module):\n    def __init__(\n            self,\n            in_features,\n            hidden_features=None,\n            out_features=None,\n            act_layer=nn.GELU,\n            norm_layer=None,\n            bias=True,\n            drop=0.,\n            prefix_token_length=None,\n            group=1,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        bias = to_2tuple(bias)\n        drop_probs = to_2tuple(drop)\n\n        self.fc1 = nn.Conv1d(in_features, hidden_features,\n                             3, groups=group, bias=bias[0], padding=1)\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n\n        self.norm = norm_layer(\n      
      hidden_features) if norm_layer is not None else nn.Identity()\n        self.seq_fc = DynamicLinear(\n            hidden_features//4, hidden_features//4, bias=bias[1], fixed_in=prefix_token_length)\n        self.prompt_fc = DynamicLinear(\n            hidden_features//4, prefix_token_length, bias=bias[1], fixed_in=prefix_token_length)\n\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])\n        self.drop2 = nn.Dropout(drop_probs[1])\n        self.hidden_features = hidden_features\n        self.prefix_token_length = prefix_token_length\n\n    def dynamic_linear(self, x, prefix_seq_len):\n        x_func = x[:, :, prefix_seq_len:]\n        x_seq = x[:, :, :prefix_seq_len]\n        x_seq_out = self.seq_fc(\n            x_seq, x_seq.shape[-1]-self.prefix_token_length)\n        x_prompt = self.prompt_fc(x_seq, self.prefix_token_length)\n        x = torch.cat((x_prompt, x_seq_out, x_func), dim=-1)\n        return x\n\n    def split_dynamic_linear(self, x, prefix_seq_len):\n        x1, x2 = x.chunk(2, dim=-2)\n        x1 = self.dynamic_linear(x1, prefix_seq_len)\n        return torch.cat((x1, x2), dim=-2)\n\n    def forward(self, x, prefix_seq_len, dim=2):\n        n, var, l, c = x.shape\n        x = x.view(-1, l, c)\n        x = x.transpose(-1, -2)\n        x = self.fc1(x)\n        x = self.split_dynamic_linear(x, prefix_seq_len)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = x.transpose(1, 2)\n        x = self.norm(x)\n        x = self.fc2(x).view(n, var, l, c)\n        x = self.drop2(x)\n        return x\n\n\nclass LearnablePositionalEmbedding(nn.Module):\n    def __init__(self, d_model, max_len=5000):\n        super(LearnablePositionalEmbedding, self).__init__()\n        # Compute the positional encodings once in log space.\n        self.pe = nn.Parameter(torch.zeros(\n            1, 1, max_len, d_model), requires_grad=True)\n\n        pe = torch.zeros(max_len, d_model).float()\n        position = torch.arange(0, 
max_len).float().unsqueeze(1)\n        div_term = (torch.arange(0, d_model, 2).float()\n                    * -(math.log(10000.0) / d_model)).exp()\n\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n\n        pe = pe.unsqueeze(0).unsqueeze(0)\n        self.pe.data.copy_(pe.float())\n        del pe\n\n    def forward(self, x, offset=0):\n        return self.pe[:, :, offset:offset+x.size(2)]\n\n\nclass SeqAttention(nn.Module):\n\n    def __init__(\n            self,\n            dim,\n            num_heads=8,\n            qkv_bias=False,\n            qk_norm=False,\n            attn_drop=0.,\n            proj_drop=0.,\n            norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.scale = self.head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()\n        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, attn_mask=None):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,\n                                  self.head_dim).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv.unbind(0)\n        q, k = self.q_norm(q), self.k_norm(k)\n        x = F.scaled_dot_product_attention(\n            q, k, v,  # attn_mask=attn_mask,\n            dropout_p=self.attn_drop.p if self.training else 0.,\n        )\n\n        x = x.transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass VarAttention(nn.Module):\n\n    def __init__(\n            self,\n            dim,\n            
num_heads=8,\n            qkv_bias=False,\n            qk_norm=False,\n            attn_drop=0.,\n            proj_drop=0.,\n            norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.scale = self.head_dim ** -0.5\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()\n        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n        B, N, P, C = x.shape\n\n        qkv = self.qkv(x).reshape(B, N, P, 3, self.num_heads,\n                                  self.head_dim).permute(3, 0, 2, 4, 1, 5)\n        q, k, v = qkv.unbind(0)\n        q, k = self.q_norm(q), self.k_norm(k)\n\n        q = q.mean(dim=1, keepdim=False)\n        k = k.mean(dim=1, keepdim=False)\n        v = v.permute(0, 2, 3, 4, 1).reshape(B, self.num_heads, N, -1)\n\n        x = F.scaled_dot_product_attention(\n            q, k, v,\n            dropout_p=self.attn_drop.p if self.training else 0.,\n        )\n\n        x = x.view(B, self.num_heads, N, -1, P).permute(0,\n                                                        2, 4, 1, 3).reshape(B, N, P, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass GateLayer(nn.Module):\n    def __init__(self, dim, init_values=1e-5, inplace=False):\n        super().__init__()\n        self.inplace = inplace\n        self.gate = nn.Linear(dim, 1)\n\n    def forward(self, x):\n        gate_value = self.gate(x)\n        return gate_value.sigmoid() * x\n\n\nclass SeqAttBlock(nn.Module):\n\n    def __init__(\n            self,\n            dim,\n            num_heads,\n            
qkv_bias=False,\n            qk_norm=False,\n            proj_drop=0.,\n            attn_drop=0.,\n            init_values=None,\n            drop_path=0.,\n            norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn_seq = SeqAttention(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_norm=qk_norm,\n            attn_drop=attn_drop,\n            proj_drop=proj_drop,\n            norm_layer=norm_layer,\n        )\n\n        self.ls1 = GateLayer(dim, init_values=init_values)\n        self.drop_path1 = DropPath(\n            drop_path) if drop_path > 0. else nn.Identity()\n        self.proj = nn.Linear(dim, dim)\n\n    def forward(self, x, attn_mask):\n        x_input = x\n        x = self.norm1(x)\n        n_vars, n_seqs = x.shape[1], x.shape[2]\n        x = torch.reshape(\n            x, (-1, x.shape[-2], x.shape[-1]))\n        x = self.attn_seq(x, attn_mask)\n        x = torch.reshape(\n            x, (-1, n_vars, n_seqs, x.shape[-1]))\n        x = x_input + self.drop_path1(self.ls1(x))\n        return x\n\n\nclass VarAttBlock(nn.Module):\n\n    def __init__(\n            self,\n            dim,\n            num_heads,\n            qkv_bias=False,\n            qk_norm=False,\n            proj_drop=0.,\n            attn_drop=0.,\n            init_values=None,\n            drop_path=0.,\n            norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn_var = VarAttention(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_norm=qk_norm,\n            attn_drop=attn_drop,\n            proj_drop=proj_drop,\n            norm_layer=norm_layer,\n        )\n        self.ls1 = GateLayer(dim, init_values=init_values)\n        self.drop_path1 = DropPath(\n            drop_path) if drop_path > 0. 
else nn.Identity()\n        self.proj = nn.Linear(dim, dim)\n\n    def forward(self, x):\n        x = x + self.drop_path1(self.ls1(self.attn_var(self.norm1(x))))\n        return x\n\n\nclass MLPBlock(nn.Module):\n\n    def __init__(\n            self,\n            dim,\n            mlp_ratio=4.,\n            proj_drop=0.,\n            init_values=None,\n            drop_path=0.,\n            act_layer=nn.GELU,\n            norm_layer=nn.LayerNorm,\n            mlp_layer=None,\n            prefix_token_length=0,\n    ):\n        super().__init__()\n        self.norm2 = norm_layer(dim)\n        if mlp_layer is DynamicLinearMlp:\n            self.mlp = mlp_layer(\n                in_features=dim,\n                hidden_features=int(dim * mlp_ratio),\n                act_layer=act_layer,\n                drop=proj_drop,\n                prefix_token_length=prefix_token_length,\n            )\n        else:\n            self.mlp = mlp_layer(\n                in_features=dim,\n                hidden_features=int(dim * mlp_ratio),\n                act_layer=act_layer,\n                drop=proj_drop,\n            )\n        self.ls2 = GateLayer(dim, init_values=init_values)\n        self.drop_path2 = DropPath(\n            drop_path) if drop_path > 0. 
else nn.Identity()\n\n    def forward(self, x, prefix_seq_len=None):\n        if prefix_seq_len is not None:\n            x = x + \\\n                self.drop_path2(\n                    self.ls2(self.mlp(self.norm2(x), prefix_seq_len=prefix_seq_len)))\n        else:\n            x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))\n        return x\n\n\nclass BasicBlock(nn.Module):\n    def __init__(\n            self,\n            dim,\n            num_heads,\n            mlp_ratio=8.,\n            qkv_bias=False,\n            qk_norm=False,\n            proj_drop=0.,\n            attn_drop=0.,\n            init_values=None,\n            drop_path=0.,\n            act_layer=nn.GELU,\n            norm_layer=nn.LayerNorm,\n            prefix_token_length=0,\n    ):\n        super().__init__()\n        self.seq_att_block = SeqAttBlock(dim=dim, num_heads=num_heads,\n                                         qkv_bias=qkv_bias, qk_norm=qk_norm,\n                                         attn_drop=attn_drop, init_values=init_values, proj_drop=proj_drop,\n                                         drop_path=drop_path, norm_layer=norm_layer)\n\n        self.var_att_block = VarAttBlock(dim=dim, num_heads=num_heads,\n                                         qkv_bias=qkv_bias, qk_norm=qk_norm,\n                                         attn_drop=attn_drop, init_values=init_values, proj_drop=proj_drop,\n                                         drop_path=drop_path, norm_layer=norm_layer)\n\n        self.dynamic_mlp = MLPBlock(dim=dim, mlp_ratio=mlp_ratio, mlp_layer=DynamicLinearMlp,\n                                    proj_drop=proj_drop, init_values=init_values, drop_path=drop_path,\n                                    act_layer=act_layer, norm_layer=norm_layer,\n                                    prefix_token_length=prefix_token_length)\n\n    def forward(self, x, prefix_seq_len, attn_mask):\n        x = self.seq_att_block(x, attn_mask)\n        x = 
self.var_att_block(x)\n        x = self.dynamic_mlp(x, prefix_seq_len=prefix_seq_len)\n        return x\n\n\nclass PatchEmbedding(nn.Module):\n    def __init__(self, d_model, patch_len, stride, padding, dropout):\n        super(PatchEmbedding, self).__init__()\n        # Patching\n        self.patch_len = patch_len\n        self.stride = stride\n        assert self.patch_len == self.stride, \"non-overlap\"\n        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        n_vars = x.shape[1]\n        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)\n        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))\n        x = self.value_embedding(x)\n        return self.dropout(x), n_vars\n\n\nclass CLSHead(nn.Module):\n    def __init__(self, d_model, head_dropout=0):\n        super().__init__()\n        d_mid = d_model\n        self.proj_in = nn.Linear(d_model, d_mid)\n        self.cross_att = CrossAttention(d_mid)\n\n        self.mlp = MLPBlock(dim=d_mid, mlp_ratio=8, mlp_layer=Mlp,\n                            proj_drop=head_dropout, init_values=None, drop_path=0.0,\n                            act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                            prefix_token_length=None)\n\n    def forward(self, x, category_token=None, return_feature=False):\n        x = self.proj_in(x)\n        B, V, L, C = x.shape\n        x = x.view(-1, L, C)\n        cls_token = x[:, -1:]\n        cls_token = self.cross_att(x, query=cls_token)\n        cls_token = cls_token.reshape(B, V, -1, C)\n\n        cls_token = self.mlp(cls_token)\n        if return_feature:\n            return cls_token\n        m = category_token.shape[2]\n        cls_token = cls_token.expand(B, V, m, C)\n        distance = torch.einsum('nvkc,nvmc->nvm', cls_token, category_token)\n\n        distance = distance.mean(dim=1)\n        return distance\n\n\nclass 
ForecastHead(nn.Module):\n    def __init__(self, d_model, patch_len, stride, pad, head_dropout=0, prefix_token_length=None):\n        super().__init__()\n        d_mid = d_model\n        self.proj_in = nn.Linear(d_model, d_mid)\n        self.mlp = Mlp(\n            in_features=d_model,\n            hidden_features=int(d_model * 4),\n            act_layer=nn.GELU,\n            drop=head_dropout,\n        )\n        self.proj_out = nn.Linear(d_model, patch_len)\n        self.pad = pad\n        self.patch_len = patch_len\n        self.stride = stride\n        self.pos_proj = DynamicLinear(\n            in_features=128, out_features=128, fixed_in=prefix_token_length)\n\n    def forward(self, x_full, pred_len, token_len):\n        x_full = self.proj_in(x_full)\n        x_pred = x_full[:, :, -token_len:]\n        x = x_full.transpose(-1, -2)\n        x = self.pos_proj(x, token_len)\n        x = x.transpose(-1, -2)\n        x = x + x_pred\n        x = self.mlp(x)\n        x = self.proj_out(x)\n\n        bs, n_vars = x.shape[0], x.shape[1]\n        x = x.reshape(-1, x.shape[-2], x.shape[-1])\n        x = x.permute(0, 2, 1)\n        x = torch.nn.functional.fold(x, output_size=(\n            pred_len, 1), kernel_size=(self.patch_len, 1), stride=(self.stride, 1))\n        x = x.squeeze(dim=-1)\n        x = x.reshape(bs, n_vars, -1)\n        x = x.permute(0, 2, 1)\n        return x\n\n\nclass Model(nn.Module):\n    \"\"\"\n    UniTS: Building a Unified Time Series Model\n    \"\"\"\n\n    def __init__(self, args, configs_list, pretrain=False):\n        super().__init__()\n\n        # (zhenwei) we do not pretrain the model in this stage\n        # if pretrain:\n        #     self.right_prob = args.right_prob\n        #     self.min_mask_ratio = args.min_mask_ratio\n        #     self.max_mask_ratio = args.max_mask_ratio\n\n        # Tokens settings\n        self.num_task = len(configs_list)\n        self.prompt_tokens = nn.ParameterDict({})\n        self.mask_tokens = 
nn.ParameterDict({})\n        self.cls_tokens = nn.ParameterDict({})\n        self.category_tokens = nn.ParameterDict({})\n\n        for i in range(self.num_task):\n            dataset_name = configs_list[i][1]['dataset']\n            task_data_name = configs_list[i][0]\n            if dataset_name not in self.prompt_tokens:\n                self.prompt_tokens[dataset_name] = torch.zeros(\n                    1, configs_list[i][1]['enc_in'], args.prompt_num, args.d_model)\n                torch.nn.init.normal_(\n                    self.prompt_tokens[dataset_name], std=.02)\n                self.mask_tokens[dataset_name] = torch.zeros(\n                    1, configs_list[i][1]['enc_in'], 1, args.d_model)\n\n            if configs_list[i][1]['task_name'] == 'classification':\n                self.category_tokens[task_data_name] = torch.zeros(\n                    1, configs_list[i][1]['enc_in'], configs_list[i][1]['num_class'], args.d_model)\n                torch.nn.init.normal_(\n                    self.category_tokens[task_data_name], std=.02)\n                self.cls_tokens[task_data_name] = torch.zeros(\n                    1, configs_list[i][1]['enc_in'], 1, args.d_model)\n                torch.nn.init.normal_(self.cls_tokens[task_data_name], std=.02)\n            if pretrain:\n                self.cls_tokens[task_data_name] = torch.zeros(\n                    1, configs_list[i][1]['enc_in'], 1, args.d_model)\n                torch.nn.init.normal_(self.cls_tokens[task_data_name], std=.02)\n\n        self.cls_nums = {}\n        for i in range(self.num_task):\n            task_data_name = configs_list[i][0]\n            if configs_list[i][1]['task_name'] == 'classification':\n                self.cls_nums[task_data_name] = configs_list[i][1]['num_class']\n            elif configs_list[i][1]['task_name'] == 'long_term_forecast':\n                remainder = configs_list[i][1]['seq_len'] % args.patch_len\n                if remainder == 0:\n                    
padding = 0\n                else:\n                    padding = args.patch_len - remainder\n                input_token_len = calculate_unfold_output_length(\n                    configs_list[i][1]['seq_len']+padding, args.stride, args.patch_len)\n                input_pad = args.stride * \\\n                    (input_token_len - 1) + args.patch_len - \\\n                    configs_list[i][1]['seq_len']\n                pred_token_len = calculate_unfold_output_length(\n                    configs_list[i][1]['pred_len']-input_pad, args.stride, args.patch_len)\n                real_len = configs_list[i][1]['seq_len'] + \\\n                    configs_list[i][1]['pred_len']\n                self.cls_nums[task_data_name] = [pred_token_len,\n                                                 configs_list[i][1]['pred_len'], real_len]\n\n        self.configs_list = configs_list\n\n        ### model settings ###\n        self.prompt_num = args.prompt_num\n        self.stride = args.stride\n        self.pad = args.stride\n        self.patch_len = args.patch_len\n\n        # input processing\n        self.patch_embeddings = PatchEmbedding(\n            args.d_model, args.patch_len, args.stride, args.stride, args.dropout)\n        self.position_embedding = LearnablePositionalEmbedding(args.d_model)\n        self.prompt2forecat = DynamicLinear(128, 128, fixed_in=args.prompt_num)\n\n        # basic blocks\n        self.block_num = args.e_layers\n        self.blocks = nn.ModuleList(\n            [BasicBlock(dim=args.d_model, num_heads=args.n_heads, qkv_bias=False, qk_norm=False,\n                        mlp_ratio=8., proj_drop=args.dropout, attn_drop=0., drop_path=0.,\n                        init_values=None, prefix_token_length=args.prompt_num) for l in range(args.e_layers)]\n        )\n\n        # output processing\n        self.cls_head = CLSHead(args.d_model, head_dropout=args.dropout)\n        self.forecast_head = ForecastHead(\n            args.d_model, args.patch_len, 
args.stride, args.stride, prefix_token_length=args.prompt_num, head_dropout=args.dropout)\n        if pretrain:\n            self.pretrain_head = ForecastHead(\n                args.d_model, args.patch_len, args.stride, args.stride, prefix_token_length=1, head_dropout=args.dropout)\n\n    def tokenize(self, x, mask=None):\n        # Normalization from Non-stationary Transformer\n        means = x.mean(1, keepdim=True).detach()\n        x = x - means\n        if mask is not None:\n            x = x.masked_fill(mask == 0, 0)\n            stdev = torch.sqrt(torch.sum(x * x, dim=1) /\n                               torch.sum(mask == 1, dim=1) + 1e-5)\n            stdev = stdev.unsqueeze(dim=1)\n        else:\n            stdev = torch.sqrt(\n                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)\n        x /= stdev\n        x = x.permute(0, 2, 1)\n        remainder = x.shape[2] % self.patch_len\n        if remainder != 0:\n            padding = self.patch_len - remainder\n            x = F.pad(x, (0, padding))\n        else:\n            padding = 0\n        x, n_vars = self.patch_embeddings(x)\n        return x, means, stdev, n_vars, padding\n\n    def prepare_prompt(self, x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name=None, mask=None):\n        x = torch.reshape(\n            x, (-1, n_vars, x.shape[-2], x.shape[-1]))\n        # append prompt tokens\n        this_prompt = prefix_prompt.repeat(x.shape[0], 1, 1, 1)\n\n        if task_name == 'forecast':\n            this_mask_prompt = task_prompt.repeat(\n                x.shape[0], 1, task_prompt_num, 1)\n            init_full_input = torch.cat(\n                (this_prompt, x, this_mask_prompt), dim=-2)\n            init_mask_prompt = self.prompt2forecat(init_full_input.transpose(\n                -1, -2), init_full_input.shape[2]-prefix_prompt.shape[2]).transpose(-1, -2)\n            this_function_prompt = init_mask_prompt[:, :, -task_prompt_num:]\n            x = 
torch.cat((this_prompt, x, this_function_prompt), dim=2)\n            x[:, :, self.prompt_num:] = x[:, :, self.prompt_num:] + \\\n                self.position_embedding(x[:, :, self.prompt_num:])\n        elif task_name == 'classification':\n            this_function_prompt = task_prompt.repeat(x.shape[0], 1, 1, 1)\n            x = x + self.position_embedding(x)\n            x = torch.cat((this_prompt, x, this_function_prompt), dim=2)\n        elif task_name == 'imputation':\n            # fill the masked parts with mask tokens\n            # for imputation, masked is 0, unmasked is 1, so here to reverse mask\n            mask = 1-mask\n            mask = mask.permute(0, 2, 1)\n            mask = self.mark2token(mask)\n            mask_repeat = mask.unsqueeze(dim=-1)\n\n            mask_token = task_prompt\n            mask_repeat = mask_repeat.repeat(1, 1, 1, x.shape[-1])\n            x = x * (1-mask_repeat) + mask_token * mask_repeat\n\n            init_full_input = torch.cat((this_prompt, x), dim=-2)\n            init_mask_prompt = self.prompt2forecat(\n                init_full_input.transpose(-1, -2), x.shape[2]).transpose(-1, -2)\n            # keep the unmasked tokens and fill the masked ones with init_mask_prompt.\n            x = x * (1-mask_repeat) + init_mask_prompt * mask_repeat\n            x = x + self.position_embedding(x)\n            x = torch.cat((this_prompt, x), dim=2)\n        elif task_name == 'anomaly_detection':\n            x = x + self.position_embedding(x)\n            x = torch.cat((this_prompt, x), dim=2)\n\n        return x\n\n    def mark2token(self, x_mark):\n        x_mark = x_mark.unfold(\n            dimension=-1, size=self.patch_len, step=self.stride)\n        x_mark = x_mark.mean(dim=-1)\n        x_mark = (x_mark > 0).float()\n        return x_mark\n\n    def backbone(self, x, prefix_len, seq_len):\n        attn_mask = None\n        for block in self.blocks:\n            x = block(x, prefix_seq_len=prefix_len +\n                
      seq_len, attn_mask=attn_mask)\n        return x\n\n    def forecast(self, x, x_mark, task_id):\n        dataset_name = self.configs_list[task_id][1]['dataset']\n        task_data_name = self.configs_list[task_id][0]\n        prefix_prompt = self.prompt_tokens[dataset_name]\n        task_prompt = self.mask_tokens[dataset_name]\n        task_prompt_num = self.cls_nums[task_data_name][0]\n        task_seq_num = self.cls_nums[task_data_name][1]\n        real_seq_len = self.cls_nums[task_data_name][2]\n\n        x, means, stdev, n_vars, _ = self.tokenize(x)\n\n        x = self.prepare_prompt(\n            x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name='forecast')\n\n        seq_token_len = x.shape[-2]-prefix_prompt.shape[2]\n        x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)\n\n        x = self.forecast_head(\n            x, real_seq_len, seq_token_len)\n        x = x[:, -task_seq_num:]\n\n        # De-Normalization from Non-stationary Transformer\n        x = x * (stdev[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))\n        x = x + (means[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))\n\n        return x\n\n    def classification(self, x, x_mark, task_id):\n        dataset_name = self.configs_list[task_id][1]['dataset']\n        task_data_name = self.configs_list[task_id][0]\n        prefix_prompt = self.prompt_tokens[dataset_name]\n        task_prompt = self.cls_tokens[task_data_name]\n        task_prompt_num = 1\n        category_token = self.category_tokens[task_data_name]\n\n        x, means, stdev, n_vars, _ = self.tokenize(x)\n\n        seq_len = x.shape[-2]\n\n        x = self.prepare_prompt(\n            x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name='classification')\n\n        x = self.backbone(x, prefix_prompt.shape[2], seq_len)\n\n        x = self.cls_head(x, category_token)\n\n        return x\n\n    def imputation(self, x, x_mark, mask, task_id):\n        dataset_name = 
self.configs_list[task_id][1]['dataset']\n        prefix_prompt = self.prompt_tokens[dataset_name]\n        task_prompt = self.mask_tokens[dataset_name]\n\n        seq_len = x.shape[1]\n        x, means, stdev, n_vars, padding = self.tokenize(x, mask)\n\n        x = self.prepare_prompt(\n            x, n_vars, prefix_prompt, task_prompt, None, mask=mask, task_name='imputation')\n        seq_token_len = x.shape[-2]-prefix_prompt.shape[2]\n        x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)\n\n        x = self.forecast_head(\n            x, seq_len+padding, seq_token_len)\n        x = x[:, :seq_len]\n\n        # De-Normalization from Non-stationary Transformer\n        x = x * (stdev[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))\n        x = x + (means[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))\n\n        return x\n\n    def anomaly_detection(self, x, x_mark, task_id):\n        dataset_name = self.configs_list[task_id][1]['dataset']\n        prefix_prompt = self.prompt_tokens[dataset_name]\n\n        seq_len = x.shape[1]\n        x, means, stdev, n_vars, padding = self.tokenize(x)\n\n        x = self.prepare_prompt(x, n_vars, prefix_prompt,\n                                None, None, task_name='anomaly_detection')\n        seq_token_len = x.shape[-2]-prefix_prompt.shape[2]\n        x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)\n\n        x = self.forecast_head(\n            x, seq_len+padding, seq_token_len)\n        x = x[:, :seq_len]\n\n        # De-Normalization from Non-stationary Transformer\n        x = x * (stdev[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))\n        x = x + (means[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))\n\n        return x\n\n    def random_masking(self, x, min_mask_ratio, max_mask_ratio):\n        \"\"\"\n        Perform per-sample random masking.\n        \"\"\"\n        N, V, L, D = x.shape  # batch, var, length, dim\n\n        # Calculate mask ratios and lengths to keep for each sample in 
the batch\n        mask_ratios = torch.rand(N, device=x.device) * \\\n            (max_mask_ratio - min_mask_ratio) + min_mask_ratio\n        len_keeps = (L * (1 - mask_ratios)).long()\n\n        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]\n\n        # sort noise for each sample\n        # ascend: small is keep, large is remove\n        ids_shuffle = torch.argsort(noise, dim=1)\n        ids_restore = torch.argsort(ids_shuffle, dim=1)\n\n        # generate the binary mask: 0 is keep, 1 is remove\n        mask = torch.ones([N, L], device=x.device)\n\n        # Create a range tensor and compare with len_keeps for mask generation\n        range_tensor = torch.arange(L, device=x.device).expand(N, L)\n        mask = (range_tensor >= len_keeps.unsqueeze(1))\n\n        # unshuffle to get the binary mask\n        mask = torch.gather(mask, dim=1, index=ids_restore)\n        mask = mask.float()\n\n        return mask\n\n    def right_masking(self, x, min_mask_ratio, max_mask_ratio):\n        N, V, L, D = x.shape  # batch, var, length, dim\n\n        # Randomly choose a mask ratio for each sample within the specified range\n        mask_ratios = torch.rand(N, device=x.device) * \\\n            (max_mask_ratio - min_mask_ratio) + min_mask_ratio\n        len_keeps = (L * (1 - mask_ratios)).long()\n\n        # Binary mask creation without a for loop\n        len_keeps_matrix = len_keeps.unsqueeze(1).expand(N, L)\n        indices = torch.arange(L, device=x.device).expand_as(len_keeps_matrix)\n        mask = indices >= len_keeps_matrix\n        mask = mask.float()\n\n        return mask\n\n    def choose_masking(self, x, right_prob, min_mask_ratio, max_mask_ratio):\n        # Generate a random number to decide which masking function to use\n        if torch.rand(1).item() > right_prob:\n            return self.random_masking(x, min_mask_ratio, max_mask_ratio)\n        else:\n            return self.right_masking(x, min_mask_ratio, max_mask_ratio)\n\n    def 
get_mask_seq(self, mask, seq_len):\n        mask_seq = mask.unsqueeze(dim=-1).repeat(1, 1, self.patch_len)\n        mask_seq = mask_seq.permute(0, 2, 1)\n        mask_seq = mask_seq.masked_fill(mask_seq == 0, -1e9)\n        # Fold operation\n        mask_seq = torch.nn.functional.fold(mask_seq, output_size=(\n            seq_len, 1), kernel_size=(self.patch_len, 1), stride=(self.stride, 1))\n        # Apply threshold to bring back to 0/1 values\n        mask_seq = (mask_seq > 0).float()\n        mask_seq = mask_seq.squeeze(dim=-1).squeeze(dim=1)\n        return mask_seq\n\n    def pretraining(self, x, x_mark, task_id, enable_mask=False):\n        dataset_name = self.configs_list[task_id][1]['dataset']\n        task_data_name = self.configs_list[task_id][0]\n        prefix_prompt = self.prompt_tokens[dataset_name]\n        mask_token = self.mask_tokens[dataset_name]\n        cls_token = self.cls_tokens[task_data_name]\n\n        seq_len = x.shape[1]\n        x, means, stdev, n_vars, padding = self.tokenize(x)\n        seq_token_len = x.shape[-2]\n\n        # append prompt tokens\n        x = torch.reshape(\n            x, (-1, n_vars, x.shape[-2], x.shape[-1]))\n        # prepare prompts\n        this_prompt = prefix_prompt.repeat(x.shape[0], 1, 1, 1)\n\n        if enable_mask:\n            mask = self.choose_masking(x, self.right_prob,\n                                       self.min_mask_ratio, self.max_mask_ratio)\n            mask_repeat = mask.unsqueeze(dim=1).unsqueeze(dim=-1)\n            mask_repeat = mask_repeat.repeat(1, x.shape[1], 1, x.shape[-1])\n            x = x * (1-mask_repeat) + mask_token * mask_repeat  # todo\n\n            init_full_input = torch.cat((this_prompt, x), dim=-2)\n            init_mask_prompt = self.prompt2forecat(\n                init_full_input.transpose(-1, -2), x.shape[2]).transpose(-1, -2)\n            # keep the unmasked tokens and fill the masked ones with init_mask_prompt.\n            x = x * (1-mask_repeat) + 
init_mask_prompt * mask_repeat\n            x = x + self.position_embedding(x)\n            mask_seq = self.get_mask_seq(mask, seq_len+padding)\n            mask_seq = mask_seq[:, :seq_len]\n        this_function_prompt = cls_token.repeat(x.shape[0], 1, 1, 1)\n        x = torch.cat((this_prompt, x, this_function_prompt), dim=2)\n\n        x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)\n\n        if enable_mask:\n            mask_dec_out = self.forecast_head(\n                x[:, :, :-1], seq_len+padding, seq_token_len)\n            mask_dec_out = mask_dec_out[:, :seq_len]\n            # De-Normalization from Non-stationary Transformer\n            mask_dec_out = mask_dec_out * \\\n                (stdev[:, 0, :].unsqueeze(1).repeat(\n                    1, mask_dec_out.shape[1], 1))\n            mask_dec_out = mask_dec_out + \\\n                (means[:, 0, :].unsqueeze(1).repeat(\n                    1, mask_dec_out.shape[1], 1))\n            cls_dec_out = self.cls_head(x, return_feature=True)\n            # detach grad of the forecasting on tokens\n            fused_dec_out = torch.cat(\n                (cls_dec_out, x[:, :, self.prompt_num:-1].detach()), dim=2)\n            cls_dec_out = self.pretrain_head(\n                fused_dec_out, seq_len+padding, seq_token_len)\n            cls_dec_out = cls_dec_out[:, :seq_len]\n            cls_dec_out = cls_dec_out * \\\n                (stdev[:, 0, :].unsqueeze(1).repeat(\n                    1, cls_dec_out.shape[1], 1))\n            cls_dec_out = cls_dec_out + \\\n                (means[:, 0, :].unsqueeze(1).repeat(\n                    1, cls_dec_out.shape[1], 1))\n\n            return cls_dec_out, mask_dec_out, mask_seq\n        else:\n            return cls_dec_out\n\n    def forward(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None,\n                mask=None, task_id=None, task_name=None, enable_mask=None):\n        task_id = 0\n\n        # if task_name == 'long_term_forecast' or task_name == 
'short_term_forecast':\n        dec_out = self.forecast(x_enc, x_mark_enc, task_id)\n        return dec_out  # [B, L, D]\n        # if task_name == 'imputation':\n        #     dec_out = self.imputation(\n        #         x_enc, x_mark_enc, mask, task_id)\n        #     return dec_out  # [B, L, D]\n        # if task_name == 'anomaly_detection':\n        #     dec_out = self.anomaly_detection(x_enc, x_mark_enc, task_id)\n        #     return dec_out  # [B, L, D]\n        # if task_name == 'classification':\n        #     dec_out = self.classification(x_enc, x_mark_enc, task_id)\n        #     return dec_out  # [B, N]\n        # if 'pretrain' in task_name:\n        #     dec_out = self.pretraining(x_enc, x_mark_enc, task_id,\n        #                                enable_mask=enable_mask)\n        #     return dec_out\n        # return None\n\n\nclass UniTS(Forecaster):\n    def __init__(\n        self,\n        ckpt_path: str = None,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.no_training = True\n        \n        if (type(self.context_length).__name__=='list'):\n            context_length = max(context_length)\n            \n        if (type(self.prediction_length).__name__=='list'):\n            prediction_length = max(prediction_length)\n\n        args, configs_list = self.generate_units_default_args(self.dataset)\n        self.model = Model(args, configs_list, pretrain=False)\n        \n        pretrain_weight_path = ckpt_path\n\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        state_dict = torch.load(pretrain_weight_path, map_location=device)['student']\n        ckpt = {}\n        for k, v in state_dict.items():\n            if not ('cls_prompts' in k):\n                k = k.replace('module.', '') if 'module.' 
in k else k\n                ckpt[k] = v\n        \n        msg = self.model.load_state_dict(ckpt, strict=False)\n        if len(msg.missing_keys) > 0:\n            print(f\"\"\"Warning: There are missing keys in the pretrained model: {msg.missing_keys}, \n                which may cause prediction results less accurate.\"\"\")\n\n\n    def generate_units_default_args(self, dataset_name='ETTh1'):\n        class Args:\n            def __init__(self):\n                self.d_model = 128\n                self.n_heads = 8\n                self.e_layers = 3\n                self.prompt_num = 10\n                self.dropout = 0.1\n                self.patch_len = 16\n                self.stride = 16\n                self.batch_size = 32\n\n        args = Args()\n\n        # parse dataset names - ECL, ETTh1, Exchange, ILI, Traffic, Weather\n        units_valid_dataset_map = {\n            'ECL': ['ECL', 'electricity'],\n            'ETTh1': ['ETT'],\n            'Exchange': ['Exchange'],\n            'ILI': ['ILI'],\n            'Traffic': ['Traffic'],\n            'Weather': ['Weather']\n        }\n\n        units_dataset_name = 'DEFAULT'\n        for key, value_list in units_valid_dataset_map.items():\n            if any(substring.lower() in dataset_name for substring in value_list):\n                units_dataset_name = key\n                break\n        task_name = f\"LTF_{units_dataset_name}_p{self.prediction_length}\"\n\n        task_data_config = {\n            task_name: {\n                \"task_name\": \"long_term_forecast\",\n                \"dataset\": units_dataset_name,\n                \"data\": units_dataset_name,\n                \"embed\": \"timeF\",\n                \"features\": \"M\",\n                \"seq_len\": self.context_length,\n                \"label_len\": 48,\n                \"pred_len\": self.prediction_length,\n                \"enc_in\": self.target_dim,\n                \"dec_in\": self.target_dim,\n                \"c_out\": 
self.target_dim\n            }\n        }\n        task_data_config_list = []\n        for task_name, task_config in task_data_config.items():\n            task_config['max_batch'] = args.batch_size\n            task_data_config_list.append([task_name, task_config])\n        return args, task_data_config_list\n\n\n    def forecast(self, batch_data, pred_len=None, dataset_name=None, *args, **kwargs):\n        inputs = self.get_inputs(batch_data, 'encode')\n        inputs = inputs[:, -self.context_length:]\n        B, _, K = inputs.shape\n        point_forecast = self.model.forward(inputs, None)\n        return point_forecast.unsqueeze(1)\n"
  },
  {
    "path": "probts/model/forecaster/prob_forecaster/__init__.py",
    "content": "from .gru_nvp import GRU_NVP\nfrom .gru_maf import GRU_MAF\nfrom .timegrad import TimeGrad\nfrom .trans_maf import Trans_MAF\nfrom .csdi import CSDI\nfrom .tsdiff import TSDiffCond\n\n# ------- add lag_llama to sys.path ---------\ntry:\n    import os, sys\n    current_dir = os.path.dirname(os.path.realpath(__file__))\n    project_root = os.path.abspath(os.path.join(current_dir, '..', '..', '..', '..'))\n    lag_llama_path = os.path.join(project_root, 'submodules', 'lag_llama')\n    moirai_path = os.path.join(project_root, 'submodules', 'uni2ts', 'src')\n\n    if lag_llama_path not in sys.path:\n        sys.path.append(lag_llama_path)\n\n    if moirai_path not in sys.path:\n        sys.path.append(moirai_path)\n\nexcept Exception as e:\n    print(f\"Warning: Unable to add lag_llama to sys.path. {e}\")\n# -------------------------------------------\n\nimport importlib\n\nmodules = [\n    ('moirai', 'Moirai'),\n    ('chronos', 'Chronos'),\n    ('lag_llama', 'LagLlama'),\n]\n\nfor module, class_name in modules:\n    try:\n        mod = importlib.import_module(f\".{module}\", package=__package__)\n        globals()[class_name] = getattr(mod, class_name)\n    except ImportError:\n        # print(f\"Warning: {class_name} is not available due to missing dependencies.\")\n        pass"
  },
  {
    "path": "probts/model/forecaster/prob_forecaster/chronos.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from Chronos\n# - Source: https://github.com/amazon-science/chronos-forecasting\n# - Paper: Chronos: Learning the Language of Time Series\n# - License: Apache License 2.0\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\n# from chronos import ChronosPipeline\nfrom einops import rearrange\nfrom probts.model.nn.arch.ChronosModule.base import BaseChronosPipeline\nfrom probts.model.forecaster import Forecaster\n\n\nclass Chronos(Forecaster):\n    def __init__(\n        self,\n        model_size: str = 'base',\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n\n        if type(self.prediction_length) == list:\n            self.prediction_length = max(self.prediction_length)\n            \n\n        if type(self.context_length) == list:\n            self.context_length = max(self.context_length)\n            \n        self.pred_len = self.prediction_length\n\n        # Load pretrained model\n        self.no_training = True\n\n        self.pipeline = BaseChronosPipeline.from_pretrained(\n            f\"amazon/chronos-t5-{model_size}\",  # use \"amazon/chronos-bolt-small\" for the corresponding Chronos-Bolt model\n            device_map=\"cuda\", \n            torch_dtype=torch.bfloat16,)\n        \n        self.q = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] # Quantile levels\n\n\n\n    def forecast(self, batch_data, num_samples=None):\n        inputs = self.get_inputs(batch_data, 'encode')\n        inputs = inputs[:, -self.context_length:]\n        \n        B, _, K = inputs.shape\n        inputs = rearrange(inputs, 'b l k -> (b k) l')#.cpu()\n        context = [inputs[i] for i in range(B*K)]\n        inner_batch_size = 12 # for 80G gpu\n        forecast_samples = []\n\n        # Process in batches of size 
`inner_batch_size`\n        for i in range(0, len(context), inner_batch_size):\n            batch_context = context[i:i + inner_batch_size]\n            batch_forecast_samples = self.pipeline.predict(\n                batch_context,\n                prediction_length=self.pred_len,\n                num_samples=num_samples,\n                limit_prediction_length=False\n            )\n            forecast_samples.append(batch_forecast_samples)\n        \n        forecast_samples = torch.cat(forecast_samples, dim=0)\n        prob_forecast = rearrange(forecast_samples, '(b k) s l -> b s l k', b=B, k=K)\n        \n        return prob_forecast\n\n\n"
  },
  {
    "path": "probts/model/forecaster/prob_forecaster/csdi.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from CSDI\n# - Source: https://github.com/ermongroup/CSDI\n# - Paper: CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation\n# - License: MIT license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom einops import repeat\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.prob.diffusion_layers import diff_CSDI\n\n\nclass CSDI(Forecaster):\n    def __init__(\n        self, \n        channels: int = 64,\n        emb_time_dim: int = 128,\n        emb_feature_dim: int = 16,\n        num_steps: int = 50,\n        schedule: str = \"quad\",\n        beta_start: float = 0.0001,\n        beta_end: float = 0.5,\n        diffusion_embedding_dim: int = 128,\n        num_heads: int = 8,\n        n_layers: int = 4,\n        sample_size: int = 64,\n        linear_trans: bool = False,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.autoregressive = False\n        self.dist_args = nn.Identity()\n\n        self.emb_time_dim = emb_time_dim\n        self.emb_feature_dim = emb_feature_dim\n        self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim\n        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n        self.emb_total_dim += 1  # for conditional mask\n        self.embed_layer = nn.Embedding(\n            num_embeddings=self.target_dim, embedding_dim=self.emb_feature_dim\n        )\n        side_dim = self.emb_total_dim\n        self.sample_size = sample_size\n\n        input_dim = 2\n        self.diffmodel = diff_CSDI(channels, diffusion_embedding_dim, side_dim, num_steps, num_heads, n_layers, inputdim=input_dim,linear=linear_trans)\n\n        # parameters for diffusion 
models\n        self.num_steps = num_steps\n        if schedule == \"quad\":\n            self.beta = np.linspace(\n                beta_start ** 0.5, beta_end ** 0.5, self.num_steps\n            ) ** 2\n        elif schedule == \"linear\":\n            self.beta = np.linspace(\n                beta_start, beta_end, self.num_steps\n            )\n\n        self.alpha_hat = 1 - self.beta\n        self.alpha = np.cumprod(self.alpha_hat)\n        self.alpha_torch = torch.tensor(self.alpha).float().unsqueeze(1).unsqueeze(1).to(self.device)\n\n    def time_embedding(self, pos, device, d_model=128):\n        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(device)\n        position = pos.unsqueeze(2)\n        div_term = 1 / torch.pow(\n            10000.0, torch.arange(0, d_model, 2).to(device) / d_model\n        )\n        pe[:, :, 0::2] = torch.sin(position * div_term)\n        pe[:, :, 1::2] = torch.cos(position * div_term)\n        return pe\n\n    def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask):\n        cond_obs = (cond_mask * observed_data).unsqueeze(1)\n        noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1)\n        total_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)\n        return total_input\n\n    def get_masks(self, batch_data):\n        hist_observed_mask = batch_data.past_observed_values[:, -self.context_length:, ...]\n        target_observed_mask = batch_data.future_observed_values\n        observed_mask = torch.cat((hist_observed_mask, target_observed_mask), dim=1)\n\n        cond_mask = torch.cat((hist_observed_mask, torch.zeros_like(target_observed_mask)), dim=1)\n        return observed_mask, cond_mask # [B L K]\n\n    def get_side_info(self, observed_data, cond_mask, target_dimension_indicator, observed_tp=None):\n        \n        B, K, L = observed_data.shape\n        if observed_tp is None:\n            observed_tp = torch.arange(L) * 1.0\n            observed_tp = repeat(observed_tp, 'l 
-> b l', b=B).to(observed_data.device)\n\n        time_embed = self.time_embedding(observed_tp, observed_data.device, self.emb_time_dim)  # (B,L,emb)\n        time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1) # (B,L,K, emb)\n        feature_embed = self.embed_layer(target_dimension_indicator)  # (B, K,emb)\n        feature_embed = feature_embed.unsqueeze(1).expand(-1, L, -1, -1) # (B,L,K, emb)\n\n        side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B,L,K,*)\n        side_info = side_info.permute(0, 3, 2, 1)  # (B,*,K,L)\n        side_mask = cond_mask.unsqueeze(1)  # (B,1,K,L)\n\n        side_info = torch.cat([side_info, side_mask], dim=1)\n        return side_info # (B,D,K,L)\n\n    def loss(self, batch_data, observed_tp=None):\n        past_target_cdf = batch_data.past_target_cdf[:, -self.context_length:, ...]\n        future_target_cdf = batch_data.future_target_cdf\n\n        observed_data = torch.cat([past_target_cdf, future_target_cdf], dim=1)\n        B, L, K = observed_data.shape\n        t = torch.randint(0, self.num_steps, [B]).to(past_target_cdf.device)\n\n        observed_mask, gt_mask = self.get_masks(batch_data)\n        feature_id = batch_data.target_dimension_indicator\n\n        if K > self.sample_size:\n            # sample subset\n            sampled_data = []\n            sampled_mask = []\n            sampled_feature_id = []\n            sampled_gt_mask = []\n            for i in range(len(observed_data)):\n                ind = np.arange(K)\n                np.random.shuffle(ind)\n                sampled_data.append(observed_data[i,...,ind[:self.sample_size]])\n                sampled_mask.append(observed_mask[i,...,ind[:self.sample_size]])\n                sampled_feature_id.append(feature_id[i,ind[:self.sample_size]])\n                sampled_gt_mask.append(gt_mask[i,...,ind[:self.sample_size]])\n            observed_data = torch.stack(sampled_data,0)\n            observed_mask = torch.stack(sampled_mask,0)\n        
    feature_id = torch.stack(sampled_feature_id,0)\n            gt_mask = torch.stack(sampled_gt_mask,0)\n\n        observed_data = observed_data.permute(0,2,1) # [B K L]\n        observed_mask = observed_mask.permute(0,2,1) # [B K L]\n        cond_mask = gt_mask.permute(0,2,1) # [B K L]\n\n        side_info = self.get_side_info(observed_data, cond_mask, feature_id, observed_tp)\n\n        target_mask = observed_mask - cond_mask\n        current_alpha = self.alpha_torch[t]  # (B,1,1)\n        noise = torch.randn_like(observed_data).to(observed_data.device)\n        noisy_data = (current_alpha ** 0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise\n\n\n        total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)\n\n        predicted = self.diffmodel(total_input, side_info, t)  # (B,K,L)\n        residual = (noise - predicted) * target_mask\n\n        num_eval = target_mask.sum()\n        loss = (residual ** 2).sum() / (num_eval if num_eval > 0 else 1)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples):\n        observed_data = torch.cat([batch_data.past_target_cdf[:, -self.context_length:, ...], torch.zeros_like(batch_data.future_target_cdf)], dim=1).permute(0,2,1) \n        _, cond_mask = self.get_masks(batch_data)\n        cond_mask = cond_mask.permute(0,2,1)\n        side_info = self.get_side_info(observed_data, cond_mask, batch_data.target_dimension_indicator)\n        sample = self.sample(observed_data, cond_mask, side_info, num_samples)\n        sample = sample.permute(0,1,3,2)\n        return sample[:, : , -self.prediction_length:, :] # [B N L K]\n\n    def sample(self, observed_data, cond_mask, side_info, n_samples):\n        B, K, L = observed_data.shape\n        imputed_samples = torch.zeros(B, n_samples, K, L).to(observed_data.device)\n\n        for i in range(n_samples):\n            current_sample = 
torch.randn_like(observed_data).to(observed_data.device)\n\n            for t in range(self.num_steps - 1, -1, -1):\n                cond_obs = (cond_mask * observed_data).unsqueeze(1)\n                noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1) # [B 1 K L]\n                diff_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)\n                predicted = self.diffmodel(diff_input, side_info, torch.tensor([t]).to(observed_data.device))\n\n                coeff1 = 1 / self.alpha_hat[t] ** 0.5\n                coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5\n                current_sample = coeff1 * (current_sample - coeff2 * predicted)\n\n                if t > 0:\n                    noise = torch.randn_like(current_sample).to(observed_data.device)\n                    sigma = (\n                        (1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]\n                    ) ** 0.5\n                    current_sample += sigma * noise\n\n            imputed_samples[:, i] = current_sample.detach()\n        return imputed_samples\n\n\n"
  },
  {
    "path": "probts/model/forecaster/prob_forecaster/gru_maf.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PyTorch-TS\n# - Source: https://github.com/zalandoresearch/pytorch-ts\n# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows\n# - License: MIT, Apache-2.0 license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\n\nfrom probts.data import ProbTSBatchData\nfrom probts.utils import repeat\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.prob.MAF import MAF\n\n\nclass GRU_MAF(Forecaster):\n    def __init__(\n        self,\n        enc_num_layers: int = 2,\n        enc_hidden_size: int = 40,\n        enc_dropout: float = 0.1,\n        n_blocks: int = 4,\n        hidden_size: int = 100,\n        n_hidden: int = 2,\n        conditional_length: int = 200,\n        dequantize: bool = False,\n        batch_norm: bool = True,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.autoregressive = True\n        \n        self.encoder = nn.GRU(\n            input_size=self.input_size,\n            hidden_size=enc_hidden_size,\n            num_layers=enc_num_layers,\n            dropout=enc_dropout,\n            batch_first=True\n        )\n        self.prob_model = MAF(\n            n_blocks=n_blocks,\n            target_dim=self.target_dim,\n            hidden_size=hidden_size,\n            n_hidden=n_hidden,\n            f_hidden_size=enc_hidden_size,\n            conditional_length=conditional_length,\n            dequantize=dequantize,\n            batch_norm=batch_norm\n        )\n\n    def loss(self, batch_data):\n        if self.use_scaling:\n            self.get_scale(batch_data)\n            self.prob_model.scale = self.scaler.scale\n        \n        inputs = self.get_inputs(batch_data, 'all')\n        enc_outs, states 
= self.encoder(inputs)\n        enc_outs = enc_outs[:, -self.prediction_length-1:-1, ...]\n        \n        dist_args = self.prob_model.dist_args(enc_outs)\n        loss = self.prob_model.loss(batch_data.future_target_cdf, dist_args).unsqueeze(-1)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        if self.use_scaling:\n            self.get_scale(batch_data)\n        \n        states = self.encode(batch_data)\n        \n        repeated_target_dimension_indicator = repeat(batch_data.target_dimension_indicator, num_samples)\n        repeated_past_target_cdf = repeat(batch_data.past_target_cdf, num_samples)\n        repeated_future_time_feat = repeat(batch_data.future_time_feat, num_samples)\n        repeated_states = repeat(states, num_samples, dim=1)\n        if self.use_scaling:\n            repeated_scale = repeat(self.scaler.scale, num_samples)\n            self.scaler.scale = repeated_scale\n            self.prob_model.scale = repeated_scale\n\n        future_samples = []\n        for k in range(self.prediction_length):\n            repeated_batch_data = ProbTSBatchData({\n                'target_dimension_indicator': repeated_target_dimension_indicator,\n                'past_target_cdf': repeated_past_target_cdf,\n                'future_time_feat': repeated_future_time_feat[:, k:k+1, ...]\n            }, device=batch_data.device)\n\n            enc_outs, repeated_states = self.decode(repeated_batch_data, repeated_states)\n            # Sample\n            dist_args = self.prob_model.dist_args(enc_outs)\n            new_samples = self.prob_model.sample(cond=dist_args)\n            future_samples.append(new_samples)\n\n            repeated_past_target_cdf = torch.cat(\n                (repeated_past_target_cdf, new_samples), dim=1\n            )\n\n        forecasts = torch.cat(future_samples, dim=1).reshape(\n            -1, num_samples, self.prediction_length, 
self.target_dim)\n        return forecasts\n\n    def encode(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'encode')\n        outputs, states = self.encoder(inputs)\n        return states\n\n    def decode(self, batch_data, states=None):\n        inputs = self.get_inputs(batch_data, 'decode')\n        outputs, states = self.encoder(inputs, states)\n        return outputs, states\n"
  },
  {
    "path": "probts/model/forecaster/prob_forecaster/gru_nvp.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PyTorch-TS\n# - Source: https://github.com/zalandoresearch/pytorch-ts\n# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows\n# - License: MIT, Apache-2.0 license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\n\nfrom probts.data import ProbTSBatchData\nfrom probts.utils import repeat\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.prob.RealNVP import RealNVP\n\n\nclass GRU_NVP(Forecaster):\n    def __init__(\n        self,\n        enc_num_layers: int = 2,\n        enc_hidden_size: int = 40,\n        enc_dropout: float = 0.1,\n        n_blocks: int = 4,\n        hidden_size: int = 100,\n        n_hidden: int = 2,\n        conditional_length: int = 200,\n        dequantize: bool = False,\n        batch_norm: bool = True,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.autoregressive = True\n        \n        self.encoder = nn.GRU(\n            input_size=self.input_size,\n            hidden_size=enc_hidden_size,\n            num_layers=enc_num_layers,\n            dropout=enc_dropout,\n            batch_first=True\n        )\n        self.prob_model = RealNVP(\n            n_blocks=n_blocks,\n            target_dim=self.target_dim,\n            hidden_size=hidden_size,\n            n_hidden=n_hidden,\n            f_hidden_size=enc_hidden_size,\n            conditional_length=conditional_length,\n            dequantize=dequantize,\n            batch_norm=batch_norm\n        )\n\n    def loss(self, batch_data):\n        if self.use_scaling:\n            self.get_scale(batch_data)\n            self.prob_model.scale = self.scaler.scale\n        \n        inputs = self.get_inputs(batch_data, 'all')\n        
enc_outs, states = self.encoder(inputs)\n        enc_outs = enc_outs[:, -self.prediction_length-1:-1, ...]\n        \n        dist_args = self.prob_model.dist_args(enc_outs)\n        loss = self.prob_model.loss(batch_data.future_target_cdf, dist_args).unsqueeze(-1)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        if self.use_scaling:\n            self.get_scale(batch_data)\n        \n        states = self.encode(batch_data)\n        \n        repeated_target_dimension_indicator = repeat(batch_data.target_dimension_indicator, num_samples)\n        repeated_past_target_cdf = repeat(batch_data.past_target_cdf, num_samples)\n        repeated_future_time_feat = repeat(batch_data.future_time_feat, num_samples)\n        repeated_states = repeat(states, num_samples, dim=1)\n        if self.use_scaling:\n            repeated_scale = repeat(self.scaler.scale, num_samples)\n            self.scaler.scale = repeated_scale\n            self.prob_model.scale = repeated_scale\n\n        future_samples = []\n        for k in range(self.prediction_length):\n            repeated_batch_data = ProbTSBatchData({\n                'target_dimension_indicator': repeated_target_dimension_indicator,\n                'past_target_cdf': repeated_past_target_cdf,\n                'future_time_feat': repeated_future_time_feat[:, k:k+1, ...]\n            }, device=batch_data.device)\n\n            enc_outs, repeated_states = self.decode(repeated_batch_data, repeated_states)\n            # Sample\n            dist_args = self.prob_model.dist_args(enc_outs)\n            new_samples = self.prob_model.sample(cond=dist_args)\n            future_samples.append(new_samples)\n\n            repeated_past_target_cdf = torch.cat(\n                (repeated_past_target_cdf, new_samples), dim=1\n            )\n\n        forecasts = torch.cat(future_samples, dim=1).reshape(\n            -1, num_samples, 
self.prediction_length, self.target_dim)\n        return forecasts\n\n    def encode(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'encode')\n        outputs, states = self.encoder(inputs)\n        return states\n\n    def decode(self, batch_data, states=None):\n        inputs = self.get_inputs(batch_data, 'decode')\n        outputs, states = self.encoder(inputs, states)\n        return outputs, states\n"
  },
  {
    "path": "probts/model/forecaster/prob_forecaster/lag_llama.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from lag-llama\n# - Source: https://github.com/time-series-foundation-models/lag-llama\n# - Paper: Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting\n# - License: Apache License 2.0\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport numpy as np\nimport torch\n\nfrom gluonts.dataset.common import ListDataset\n\nfrom probts.model.forecaster import Forecaster\nfrom submodules.lag_llama.lag_llama.gluon.estimator import LagLlamaEstimator\n\n\nclass LagLlama(Forecaster):\n    def __init__(\n        self,\n        use_rope_scaling: bool = True,\n        ckpt_path: str = None,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        \n        # self.ctx_len = kwargs.get('context_length')\n        # self.pred_len = kwargs.get('prediction_length')\n        \n        if type(self.prediction_length) == list:\n            self.prediction_length = max(self.prediction_length)\n            \n\n        if type(self.context_length) == list:\n            self.context_length = max(self.context_length)\n            \n        self.ctx_len = self.context_length\n        self.pred_len = self.prediction_length\n\n        # Load pretrained model\n        self.no_training = True\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n        ckpt = torch.load(ckpt_path, map_location=device)\n        estimator_args = ckpt[\"hyper_parameters\"][\"model_kwargs\"]\n        rope_scaling_arguments = {\n            \"type\": \"linear\",\n            \"factor\": max(1.0, (self.ctx_len + self.pred_len) / estimator_args[\"context_length\"]), # 32\n        }\n        # Load model checkpoint\n        estimator = LagLlamaEstimator(\n            ckpt_path=ckpt_path,\n            
prediction_length=self.pred_len,\n            context_length=self.ctx_len, # Lag-Llama was trained with a context length of 32, but can work with any context length\n\n            # estimator args\n            input_size=estimator_args[\"input_size\"], # 1\n            n_layer=estimator_args[\"n_layer\"], # 8\n            n_embd_per_head=estimator_args[\"n_embd_per_head\"], # 16\n            n_head=estimator_args[\"n_head\"], # 9\n            scaling=estimator_args[\"scaling\"], # robust\n            time_feat=estimator_args[\"time_feat\"], # True\n            rope_scaling=rope_scaling_arguments if use_rope_scaling else None, # long-term set to True\n\n            batch_size=4,\n            num_parallel_samples=100,\n            device=device,\n        )\n\n        lightning_module = estimator.create_lightning_module()\n        transformation = estimator.create_transformation()\n        self.predictor = estimator.create_predictor(transformation, lightning_module)\n\n    \n    def forecast(self, batch_data, num_samples=None):\n        inputs = self.get_inputs(batch_data, 'encode')\n        inputs = inputs[:, -self.context_length:]\n        datastamps = batch_data.past_time_feat.cpu().numpy().astype('datetime64[s]')\n\n        # for now, we only support batch_size=1\n        B, _, K = inputs.shape \n        # past_target = batch_data.past_target_cdf[:, -self.context_length:]\n        start_time = datastamps.reshape(-1)[0]\n        data = [{\"start\": start_time, \"target\": inputs[:,:,i].cpu().squeeze()} for i in range(K)]\n        dataset = ListDataset(data, freq='1h')\n\n        forecasts = self.predictor.predict(dataset, num_samples=num_samples)\n        samples = [fs.samples for fs in forecasts]\n        forecasts = np.array(samples).transpose(1, 2, 0)\n\n        prob_forecast = forecasts[np.newaxis, :, :]\n        prob_forecast = torch.tensor(prob_forecast) # shape: b s l k\n        \n        return prob_forecast\n"
  },
  {
    "path": "probts/model/forecaster/prob_forecaster/moirai.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from uni2ts\n# - Source: https://github.com/SalesforceAIResearch/uni2ts\n# - Paper: Unified Training of Universal Time Series Forecasting Transformers\n# - License: Apache License 2.0\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nfrom typing import Union\nfrom probts.model.forecaster import Forecaster\nfrom einops import rearrange, repeat \nfrom probts.model.nn.arch.Moirai_backbone import MoiraiBackbone\nfrom uni2ts.model.moirai.module import MoiraiModule\nimport sys\n\nclass Moirai(Forecaster):\n    def __init__(\n        self,\n        variate_mode: str = 'M',\n        patch_size: Union[str, int] = 'auto',\n        model_size: str = 'base',\n        scaling: bool = True,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.variate_mode = variate_mode\n        self.patch_size = patch_size if patch_size == 'auto' else int(patch_size)\n        \n        if type(self.prediction_length) == list:\n            self.prediction_length = max(self.prediction_length)\n\n        if type(self.context_length) == list:\n            self.context_length = max(self.context_length)\n        \n        # Load pretrained model\n        self.no_training = True\n        self.moirai = MoiraiBackbone(\n            module=MoiraiModule.from_pretrained(f\"Salesforce/moirai-1.0-R-{model_size}\"),\n            prediction_length=self.prediction_length,\n            context_length=self.context_length,\n            patch_size=self.patch_size,\n            target_dim=self.target_dim if self.variate_mode == 'M' else 1,\n            scaling=scaling\n        )\n\n    def forecast(self, batch_data, num_samples=None):\n        if self.variate_mode == 'M':\n            forecasts = self.moirai(\n                past_target=batch_data.past_target_cdf,\n    
            past_observed_target=batch_data.past_observed_values,\n                past_is_pad=batch_data.past_is_pad,\n                num_samples=num_samples\n            )\n        elif self.variate_mode == 'S':\n            B, L, K = batch_data.past_target_cdf.shape\n            forecasts = self.moirai(\n                past_target=rearrange(batch_data.past_target_cdf, 'b l k -> (b k) l').unsqueeze(-1),\n                past_observed_target=rearrange(batch_data.past_observed_values, 'b l k -> (b k) l').unsqueeze(-1),\n                past_is_pad=repeat(batch_data.past_is_pad, 'b l -> (b k) l', k=K),\n                num_samples=num_samples\n            )\n            forecasts = forecasts.squeeze(-1)\n            forecasts = rearrange(forecasts, '(b k) n l -> b n l k', b=B, k=K)\n        else:\n            raise ValueError(f\"Unknown variate mode: {self.variate_mode}\")\n        return forecasts\n"
  },
  {
    "path": "probts/model/forecaster/prob_forecaster/timegrad.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PyTorch-TS\n# - Source: https://github.com/zalandoresearch/pytorch-ts\n# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows\n# - License: MIT, Apache-2.0 license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\n\nfrom probts.data import ProbTSBatchData\nfrom probts.utils import repeat\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.prob.gaussian_diffusion import GaussianDiffusion\n\n\nclass TimeGrad(Forecaster):\n    def __init__(\n        self,\n        enc_num_layers: int = 2,\n        enc_hidden_size: int = 40,\n        enc_dropout: float = 0.1,\n        conditional_length: int = 100,\n        beta_end: float = 0.1,\n        diff_steps: int = 100,\n        loss_type: str = \"l2\",\n        beta_schedule: str = \"linear\",\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.autoregressive = True\n        \n        self.encoder = nn.GRU(\n            input_size=self.input_size,\n            hidden_size=enc_hidden_size,\n            num_layers=enc_num_layers,\n            dropout=enc_dropout,\n            batch_first=True\n        )\n        self.prob_model = GaussianDiffusion(\n            target_dim=self.target_dim,\n            f_hidden_size=enc_hidden_size,\n            conditional_length=conditional_length,\n            beta_end=beta_end,\n            diff_steps=diff_steps,\n            loss_type=loss_type,\n            beta_schedule=beta_schedule\n        )\n\n    def loss(self, batch_data):\n        if self.use_scaling:\n            self.get_scale(batch_data)\n            self.prob_model.scale = self.scaler.scale\n        \n        inputs = self.get_inputs(batch_data, 'all')\n        enc_outs, states = 
self.encoder(inputs)\n        enc_outs = enc_outs[:, -self.prediction_length-1:-1, ...]\n        \n        dist_args = self.prob_model.dist_args(enc_outs)\n        loss = self.prob_model.loss(batch_data.future_target_cdf, dist_args).unsqueeze(-1)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        if self.use_scaling:\n            self.get_scale(batch_data)\n        \n        states = self.encode(batch_data)\n        \n        repeated_target_dimension_indicator = repeat(batch_data.target_dimension_indicator, num_samples)\n        repeated_past_target_cdf = repeat(batch_data.past_target_cdf, num_samples)\n        repeated_future_time_feat = repeat(batch_data.future_time_feat, num_samples)\n        repeated_states = repeat(states, num_samples, dim=1)\n        if self.use_scaling:\n            repeated_scale = repeat(self.scaler.scale, num_samples)\n            self.scaler.scale = repeated_scale\n            self.prob_model.scale = repeated_scale\n\n        future_samples = []\n        for k in range(self.prediction_length):\n            repeated_batch_data = ProbTSBatchData({\n                'target_dimension_indicator': repeated_target_dimension_indicator,\n                'past_target_cdf': repeated_past_target_cdf,\n                'future_time_feat': repeated_future_time_feat[:, k:k+1, ...]\n            }, device=batch_data.device)\n\n            enc_outs, repeated_states = self.decode(repeated_batch_data, repeated_states)\n            # Sample\n            dist_args = self.prob_model.dist_args(enc_outs)\n            new_samples = self.prob_model.sample(cond=dist_args)\n            future_samples.append(new_samples)\n\n            repeated_past_target_cdf = torch.cat(\n                (repeated_past_target_cdf, new_samples), dim=1\n            )\n\n        forecasts = torch.cat(future_samples, dim=1).reshape(\n            -1, num_samples, self.prediction_length, 
self.target_dim)\n        return forecasts\n\n    def encode(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'encode')\n        outputs, states = self.encoder(inputs)\n        return states\n\n    def decode(self, batch_data, states=None):\n        inputs = self.get_inputs(batch_data, 'decode')\n        outputs, states = self.encoder(inputs, states)\n        return outputs, states\n"
  },
  {
    "path": "probts/model/forecaster/prob_forecaster/trans_maf.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PyTorch-TS\n# - Source: https://github.com/zalandoresearch/pytorch-ts\n# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows\n# - License: MIT, Apache-2.0 license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\n\nfrom probts.data import ProbTSBatchData\nfrom probts.utils import repeat\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.prob.MAF import MAF\n\n\nclass Trans_MAF(Forecaster):\n    def __init__(\n        self,\n        enc_hidden_size: int = 32,\n        enc_num_heads: int = 8,\n        enc_num_encoder_layers: int = 3,\n        enc_num_decoder_layers: int = 3,\n        enc_dim_feedforward_scale: int = 4,\n        enc_dropout: float = 0.1,\n        enc_activation: str = 'gelu',\n        n_blocks: int = 4,\n        hidden_size: int = 100,\n        n_hidden: int = 2,\n        conditional_length: int = 200,\n        dequantize: bool = False,\n        batch_norm: bool = True,\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.autoregressive = True\n\n        self.enc_linear = nn.Linear(self.input_size, enc_hidden_size)\n        self.dec_linear = nn.Linear(self.input_size, enc_hidden_size)\n        self.model = nn.Transformer(\n            d_model=enc_hidden_size,\n            nhead=enc_num_heads,\n            num_encoder_layers=enc_num_encoder_layers,\n            num_decoder_layers=enc_num_decoder_layers,\n            dim_feedforward=enc_dim_feedforward_scale * enc_hidden_size,\n            dropout=enc_dropout,\n            activation=enc_activation\n        )\n\n        self.register_buffer(\n            \"tgt_mask\",\n            self.model.generate_square_subsequent_mask(self.prediction_length),\n       
 )\n        \n        self.prob_model = MAF(\n            n_blocks=n_blocks,\n            target_dim=self.target_dim,\n            hidden_size=hidden_size,\n            n_hidden=n_hidden,\n            f_hidden_size=enc_hidden_size,\n            conditional_length=conditional_length,\n            dequantize=dequantize,\n            batch_norm=batch_norm\n        )\n\n    def loss(self, batch_data):\n        if self.use_scaling:\n            self.get_scale(batch_data)\n            self.prob_model.scale = self.scaler.scale\n        \n        inputs = self.get_inputs(batch_data, 'all') # [B L D]\n\n        enc_inputs = inputs[:, :self.context_length, ...]\n        enc_inputs = self.enc_linear(enc_inputs).permute(1, 0, 2)\n        enc_outputs = self.model.encoder(enc_inputs) # [L_in B H]\n\n        dec_inputs = inputs[:, -self.prediction_length-1:-1, ...]\n        dec_inputs = self.dec_linear(dec_inputs).permute(1, 0, 2)\n        dec_outputs = self.model.decoder(\n            dec_inputs, enc_outputs, tgt_mask=self.tgt_mask)\n        dec_outputs = dec_outputs.permute(1, 0, 2)  # [L_out B D]\n        \n        dist_args = self.prob_model.dist_args(dec_outputs)\n        loss = self.prob_model.loss(batch_data.future_target_cdf, dist_args).unsqueeze(-1)\n        loss = self.get_weighted_loss(batch_data, loss)\n        return loss.mean()\n\n    def forecast(self, batch_data, num_samples=None):\n        if self.use_scaling:\n            self.get_scale(batch_data)\n        \n        states = self.encode(batch_data)\n        \n        repeated_target_dimension_indicator = repeat(batch_data.target_dimension_indicator, num_samples)\n        repeated_past_target_cdf = repeat(batch_data.past_target_cdf, num_samples)\n        repeated_future_time_feat = repeat(batch_data.future_time_feat, num_samples)\n        repeated_states = repeat(states, num_samples, dim=1)\n        if self.use_scaling:\n            repeated_scale = repeat(self.scaler.scale, num_samples)\n            
self.scaler.scale = repeated_scale\n            self.prob_model.scale = repeated_scale\n\n        future_samples = []\n        for k in range(self.prediction_length):\n            repeated_batch_data = ProbTSBatchData({\n                'target_dimension_indicator': repeated_target_dimension_indicator,\n                'past_target_cdf': repeated_past_target_cdf,\n                'future_time_feat': repeated_future_time_feat[:, k:k+1, ...]\n            }, device=batch_data.device)\n\n            enc_outs, repeated_states = self.decode(repeated_batch_data, repeated_states)\n            # Sample\n            dist_args = self.prob_model.dist_args(enc_outs)\n            new_samples = self.prob_model.sample(cond=dist_args)\n            future_samples.append(new_samples)\n\n            repeated_past_target_cdf = torch.cat(\n                (repeated_past_target_cdf, new_samples), dim=1\n            )\n\n        forecasts = torch.cat(future_samples, dim=1).reshape(\n            -1, num_samples, self.prediction_length, self.target_dim)\n        return forecasts\n\n    def encode(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'encode')\n        inputs = self.enc_linear(inputs).permute(1, 0, 2)\n        states = self.model.encoder(inputs)\n        return states\n\n    def decode(self, batch_data, states=None):\n        inputs = self.get_inputs(batch_data, 'decode')\n        inputs = self.dec_linear(inputs).permute(1, 0, 2)\n        outputs = self.model.decoder(inputs, states, tgt_mask=None)\n        return outputs.permute(1, 0, 2), states\n"
  },
  {
    "path": "probts/model/forecaster/prob_forecaster/tsdiff.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n# ---------------------------------------------------------------------------------\n# Portions of this file are derived from TSDiff\n# - Source: https://github.com/amazon-science/unconditional-time-series-diffusion\n# - Paper: Predict, Refine, Synthesize: Self-Guiding Diffusion Models for Probabilistic Time Series Forecasting\n# - License: Apache-2.0\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\nimport torch\nimport torch.nn.functional as F\nfrom probts.utils import extract\nfrom probts.model.forecaster import Forecaster\nfrom probts.model.nn.arch.S4.s4_backbones import BackboneModel\nfrom probts.utils import repeat\nimport sys\n\ndef linear_beta_schedule(timesteps):\n    beta_start = 0.0001\n    beta_end = 0.1\n    return torch.linspace(beta_start, beta_end, timesteps)\n\n\nclass TSDiffCond(Forecaster):\n    def __init__(\n        self,\n        hidden_dim: int,\n        step_emb: int,\n        timesteps: int,\n        num_residual_blocks: int,\n        dropout: float = 0,\n        # use_features: bool = False,\n        init_skip=True,\n        noise_observed=False, # reconstruct past\n        mode=\"diag\",\n        measure=\"diag\",\n        **kwargs\n    ):\n        super().__init__(**kwargs)\n        backbone_parameters = {\n            \"input_dim\": self.target_dim,\n            \"hidden_dim\": hidden_dim,\n            \"output_dim\": self.target_dim,\n            \"step_emb\": step_emb,\n            \"num_residual_blocks\": num_residual_blocks,\n            \"residual_block\": \"s4\",\n            \"mode\": mode,\n            'measure': measure,\n        }\n        # self.use_features=use_features\n        self.timesteps = timesteps\n        self.betas = linear_beta_schedule(timesteps)\n        self.sqrt_one_minus_beta = torch.sqrt(1.0 - 
self.betas)\n        self.alphas = 1 - self.betas\n        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)\n        self.alphas_cumprod_prev = F.pad(\n            self.alphas_cumprod[:-1], (1, 0), value=1.0\n        )\n        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)\n        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)\n        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(\n            1.0 - self.alphas_cumprod\n        )\n        self.posterior_variance = (\n            self.betas\n            * (1.0 - self.alphas_cumprod_prev)\n            / (1.0 - self.alphas_cumprod)\n        )\n        self.backbone = BackboneModel(\n            **backbone_parameters,\n            num_features=self.target_dim,\n            init_skip=init_skip,\n            dropout=dropout,\n        )\n        self.noise_observed = noise_observed\n\n    def _extract_features(self, batch_data):\n        inputs = self.get_inputs(batch_data, 'all')\n        x = inputs[:,:, :self.target_dim]\n        features = inputs.clone()\n        \n        if self.use_time_feat:\n            features[:,self.context_length:, :self.target_dim] = 0\n        else:\n            features = features[:,:, :self.target_dim]\n            features[:,self.context_length:] = 0\n        \n        observation_mask = torch.zeros_like(x, device=x.device)\n        observation_mask[:,:self.context_length] = 1\n        \n        return x, features, observation_mask\n\n    def q_sample(self, x_start, t, noise=None):\n        device = next(self.backbone.parameters()).device\n        if noise is None:\n            noise = torch.randn_like(x_start, device=device)\n        sqrt_alphas_cumprod_t = extract(\n            self.sqrt_alphas_cumprod, t, x_start.shape\n        )\n        sqrt_one_minus_alphas_cumprod_t = extract(\n            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape\n        )\n\n        return (\n            sqrt_alphas_cumprod_t * x_start\n            + 
sqrt_one_minus_alphas_cumprod_t * noise\n        )\n\n    def p_losses(\n        self,\n        x_start,\n        t,\n        features=None,\n        noise=None,\n        loss_type=\"l2\",\n        reduction=\"none\",\n    ):\n        device = next(self.backbone.parameters()).device\n        if noise is None:\n            noise = torch.randn_like(x_start, device=device)\n\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n        predicted_noise = self.backbone(x_noisy, t, features)\n\n        if loss_type == \"l1\":\n            loss = F.l1_loss(noise, predicted_noise, reduction=reduction)\n        elif loss_type == \"l2\":\n            loss = F.mse_loss(noise, predicted_noise, reduction=reduction)\n        elif loss_type == \"huber\":\n            loss = F.smooth_l1_loss(\n                noise, predicted_noise, reduction=reduction\n            )\n        else:\n            raise NotImplementedError()\n\n        return loss, x_noisy, predicted_noise\n\n    @torch.no_grad()\n    def p_sample(self, x, t, t_index, features=None):\n        betas_t = extract(self.betas, t, x.shape)\n        sqrt_one_minus_alphas_cumprod_t = extract(\n            self.sqrt_one_minus_alphas_cumprod, t, x.shape\n        )\n        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)\n\n\n        predicted_noise = self.backbone(x, t, features)\n\n        model_mean = sqrt_recip_alphas_t * (\n            x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t\n        )\n\n        if t_index == 0:\n            return model_mean\n        else:\n            posterior_variance_t = extract(self.posterior_variance, t, x.shape)\n            noise = torch.randn_like(x)\n            return model_mean + torch.sqrt(posterior_variance_t) * noise\n\n    def step(self, x, t, features, loss_mask):\n        noise = torch.randn_like(x)\n        if not self.noise_observed:\n            noise = (1 - loss_mask) * x + noise * loss_mask\n\n        num_eval = 
loss_mask.sum()\n        sq_err, _, _ = self.p_losses(\n            x,\n            t,\n            features,\n            loss_type=\"l2\",\n            reduction=\"none\",\n            noise=noise,\n        )\n\n        if self.noise_observed:\n            elbo_loss = sq_err.mean()\n        else:\n            sq_err = sq_err * loss_mask\n            elbo_loss = sq_err.sum() / (num_eval if num_eval else 1)\n        return elbo_loss\n\n\n\n    def loss(self, batch_data):\n        # [b l k 1], [b l k 2]\n        x, features, observation_mask = self._extract_features(batch_data)\n        loss_mask = 1 - observation_mask\n\n        t = torch.randint(\n            0, self.timesteps, [x.shape[0]], device=x.device\n        ).long()\n        \n        loss = self.step(x, t, features, loss_mask)\n\n        if torch.isnan(loss):\n            print(\"Loss is NaN, exiting.\")\n            sys.exit(1)\n        return loss\n\n    def forecast(self, batch_data, num_samples):\n        observation, features, observation_mask = self._extract_features(batch_data)\n\n        observation = observation.to(observation.device)\n\n        pred = self.sample(\n            observation=observation,\n            observation_mask=observation_mask,\n            n_samples=num_samples,\n            features=features,\n        )  \n\n        return pred[:,:,-self.prediction_length:,:]\n\n    @torch.no_grad()\n    def sample(self, observation, observation_mask, n_samples, features=None):\n\n        repeated_observation = repeat(observation, n_samples)\n        repeated_observation_mask = repeat(observation_mask, n_samples)\n        repeated_features = repeat(features, n_samples)\n        \n        batch_size, length, ch = repeated_observation.shape\n        seq = torch.randn_like(repeated_observation)\n\n        for i in reversed(range(0, self.timesteps)):\n            if not self.noise_observed:\n                seq = repeated_observation_mask * repeated_observation + seq * (1 - 
repeated_observation_mask)\n\n            seq = self.p_sample(\n                seq,\n                torch.full((batch_size,), i, device=repeated_observation.device, dtype=torch.long),\n                i,\n                repeated_features,\n            )\n\n        seq = seq.reshape(-1, n_samples, length, ch)\n        return seq \n"
  },
  {
    "path": "probts/model/nn/__init__.py",
    "content": ""
  },
  {
    "path": "probts/model/nn/arch/AutoformerModule/AutoCorrelation.py",
    "content": "import torch\nimport torch.nn as nn\nimport math\n\n\nclass AutoCorrelation(nn.Module):\n    \"\"\"\n    AutoCorrelation Mechanism with the following two phases:\n    (1) period-based dependencies discovery\n    (2) time delay aggregation\n    This block can replace the self-attention family mechanism seamlessly.\n    \"\"\"\n    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):\n        super(AutoCorrelation, self).__init__()\n        self.factor = factor\n        self.scale = scale\n        self.mask_flag = mask_flag\n        self.output_attention = output_attention\n        self.dropout = nn.Dropout(attention_dropout)\n\n    def time_delay_agg_training(self, values, corr):\n        \"\"\"\n        SpeedUp version of Autocorrelation (a batch-normalization style design)\n        This is for the training phase.\n        \"\"\"\n        head = values.shape[1]\n        channel = values.shape[2]\n        length = values.shape[3]\n        # find top k\n        top_k = int(self.factor * math.log(length))\n        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)\n        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]\n        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)\n        # update corr\n        tmp_corr = torch.softmax(weights, dim=-1)\n        # aggregation\n        tmp_values = values\n        delays_agg = torch.zeros_like(values).float()\n        for i in range(top_k):\n            pattern = torch.roll(tmp_values, -int(index[i]), -1)\n            delays_agg = delays_agg + pattern * \\\n                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))\n        return delays_agg\n\n    def time_delay_agg_inference(self, values, corr):\n        \"\"\"\n        SpeedUp version of Autocorrelation (a batch-normalization style design)\n        This is for the inference phase.\n        
\"\"\"\n        batch = values.shape[0]\n        head = values.shape[1]\n        channel = values.shape[2]\n        length = values.shape[3]\n        # index init\n        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0)\\\n            .repeat(batch, head, channel, 1).to(values.device)\n        # find top k\n        top_k = int(self.factor * math.log(length))\n        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)\n        weights, delay = torch.topk(mean_value, top_k, dim=-1)\n        # update corr\n        tmp_corr = torch.softmax(weights, dim=-1)\n        # aggregation\n        tmp_values = values.repeat(1, 1, 1, 2)\n        delays_agg = torch.zeros_like(values).float()\n        for i in range(top_k):\n            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)\n            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)\n            delays_agg = delays_agg + pattern * \\\n                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))\n        return delays_agg\n\n    def time_delay_agg_full(self, values, corr):\n        \"\"\"\n        Standard version of Autocorrelation\n        \"\"\"\n        batch = values.shape[0]\n        head = values.shape[1]\n        channel = values.shape[2]\n        length = values.shape[3]\n        # index init\n        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0)\\\n            .repeat(batch, head, channel, 1).to(values.device)\n        # find top k\n        top_k = int(self.factor * math.log(length))\n        weights, delay = torch.topk(corr, top_k, dim=-1)\n        # update corr\n        tmp_corr = torch.softmax(weights, dim=-1)\n        # aggregation\n        tmp_values = values.repeat(1, 1, 1, 2)\n        delays_agg = torch.zeros_like(values).float()\n        for i in range(top_k):\n            tmp_delay = init_index + delay[..., 
i].unsqueeze(-1)\n            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)\n            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))\n        return delays_agg\n\n    def forward(self, queries, keys, values, attn_mask):\n        B, L, H, E = queries.shape\n        _, S, _, D = values.shape\n        if L > S:\n            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()\n            values = torch.cat([values, zeros], dim=1)\n            keys = torch.cat([keys, zeros], dim=1)\n        else:\n            values = values[:, :L, :, :]\n            keys = keys[:, :L, :, :]\n\n        # period-based dependencies\n        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)\n        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)\n        res = q_fft * torch.conj(k_fft)\n        corr = torch.fft.irfft(res, n=L, dim=-1)\n\n        # time delay agg\n        if self.training:\n            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)\n        else:\n            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)\n\n        if self.output_attention:\n            return (V.contiguous(), corr.permute(0, 3, 1, 2))\n        else:\n            return (V.contiguous(), None)\n\n\nclass AutoCorrelationLayer(nn.Module):\n    def __init__(self, correlation, d_model, n_heads, d_keys=None,\n                 d_values=None):\n        super(AutoCorrelationLayer, self).__init__()\n\n        d_keys = d_keys or (d_model // n_heads)\n        d_values = d_values or (d_model // n_heads)\n\n        self.inner_correlation = correlation\n        self.query_projection = nn.Linear(d_model, d_keys * n_heads)\n        self.key_projection = nn.Linear(d_model, d_keys * n_heads)\n        self.value_projection = nn.Linear(d_model, d_values * n_heads)\n        self.out_projection = nn.Linear(d_values * n_heads, 
d_model)\n        self.n_heads = n_heads\n\n    def forward(self, queries, keys, values, attn_mask):\n        B, L, _ = queries.shape\n        _, S, _ = keys.shape\n        H = self.n_heads\n\n        queries = self.query_projection(queries).view(B, L, H, -1)\n        keys = self.key_projection(keys).view(B, S, H, -1)\n        values = self.value_projection(values).view(B, S, H, -1)\n\n        out, attn = self.inner_correlation(\n            queries,\n            keys,\n            values,\n            attn_mask\n        )\n        out = out.view(B, L, -1)\n\n        return self.out_projection(out), attn\n"
  },
  {
    "path": "probts/model/nn/arch/AutoformerModule/Autoformer_EncDec.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass my_Layernorm(nn.Module):\n    \"\"\"\n    Special designed layernorm for the seasonal part\n    \"\"\"\n    def __init__(self, channels):\n        super(my_Layernorm, self).__init__()\n        self.layernorm = nn.LayerNorm(channels)\n\n    def forward(self, x):\n        x_hat = self.layernorm(x)\n        bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)\n        return x_hat - bias\n\n\nclass moving_avg(nn.Module):\n    \"\"\"\n    Moving average block to highlight the trend of time series\n    \"\"\"\n    def __init__(self, kernel_size, stride):\n        super(moving_avg, self).__init__()\n        self.kernel_size = kernel_size\n        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)\n\n    def forward(self, x):\n        # padding on the both ends of time series\n        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n        x = torch.cat([front, x, end], dim=1)\n        x = self.avg(x.permute(0, 2, 1))\n        x = x.permute(0, 2, 1)\n        return x\n\n\nclass series_decomp(nn.Module):\n    \"\"\"\n    Series decomposition block\n    \"\"\"\n    def __init__(self, kernel_size):\n        super(series_decomp, self).__init__()\n        self.moving_avg = moving_avg(kernel_size, stride=1)\n\n    def forward(self, x):\n        moving_mean = self.moving_avg(x)\n        res = x - moving_mean\n        return res, moving_mean\n\n\nclass EncoderLayer(nn.Module):\n    \"\"\"\n    Autoformer encoder layer with the progressive decomposition architecture\n    \"\"\"\n    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation=\"relu\"):\n        super(EncoderLayer, self).__init__()\n        d_ff = d_ff or 4 * d_model\n        self.attention = attention\n        self.conv1 = nn.Conv1d(in_channels=d_model, 
out_channels=d_ff, kernel_size=1, bias=False)\n        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)\n        self.decomp1 = series_decomp(moving_avg)\n        self.decomp2 = series_decomp(moving_avg)\n        self.dropout = nn.Dropout(dropout)\n        self.activation = F.relu if activation == \"relu\" else F.gelu\n\n    def forward(self, x, attn_mask=None):\n        new_x, attn = self.attention(\n            x, x, x,\n            attn_mask=attn_mask\n        )\n        x = x + self.dropout(new_x)\n        x, _ = self.decomp1(x)\n        y = x\n        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))\n        y = self.dropout(self.conv2(y).transpose(-1, 1))\n        res, _ = self.decomp2(x + y)\n        return res, attn\n\n\nclass Encoder(nn.Module):\n    \"\"\"\n    Autoformer encoder\n    \"\"\"\n    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):\n        super(Encoder, self).__init__()\n        self.attn_layers = nn.ModuleList(attn_layers)\n        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None\n        self.norm = norm_layer\n\n    def forward(self, x, attn_mask=None):\n        attns = []\n        if self.conv_layers is not None:\n            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):\n                x, attn = attn_layer(x, attn_mask=attn_mask)\n                x = conv_layer(x)\n                attns.append(attn)\n            x, attn = self.attn_layers[-1](x)\n            attns.append(attn)\n        else:\n            for attn_layer in self.attn_layers:\n                x, attn = attn_layer(x, attn_mask=attn_mask)\n                attns.append(attn)\n\n        if self.norm is not None:\n            x = self.norm(x)\n\n        return x, attns\n\n\nclass DecoderLayer(nn.Module):\n    \"\"\"\n    Autoformer decoder layer with the progressive decomposition architecture\n    \"\"\"\n    def __init__(self, 
self_attention, cross_attention, d_model, c_out, d_ff=None,\n                 moving_avg=25, dropout=0.1, activation=\"relu\"):\n        super(DecoderLayer, self).__init__()\n        d_ff = d_ff or 4 * d_model\n        self.self_attention = self_attention\n        self.cross_attention = cross_attention\n        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)\n        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)\n        self.decomp1 = series_decomp(moving_avg)\n        self.decomp2 = series_decomp(moving_avg)\n        self.decomp3 = series_decomp(moving_avg)\n        self.dropout = nn.Dropout(dropout)\n        self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,\n                                    padding_mode='circular', bias=False)\n        self.activation = F.relu if activation == \"relu\" else F.gelu\n\n    def forward(self, x, cross, x_mask=None, cross_mask=None):\n        x = x + self.dropout(self.self_attention(\n            x, x, x,\n            attn_mask=x_mask\n        )[0])\n        x, trend1 = self.decomp1(x)\n        x = x + self.dropout(self.cross_attention(\n            x, cross, cross,\n            attn_mask=cross_mask\n        )[0])\n        x, trend2 = self.decomp2(x)\n        y = x\n        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))\n        y = self.dropout(self.conv2(y).transpose(-1, 1))\n        x, trend3 = self.decomp3(x + y)\n\n        residual_trend = trend1 + trend2 + trend3\n        residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)\n        return x, residual_trend\n\n\nclass Decoder(nn.Module):\n    \"\"\"\n    Autoformer encoder\n    \"\"\"\n    def __init__(self, layers, norm_layer=None, projection=None):\n        super(Decoder, self).__init__()\n        self.layers = nn.ModuleList(layers)\n        self.norm = norm_layer\n        
self.projection = projection\n\n    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):\n        for layer in self.layers:\n            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)\n            trend = trend + residual_trend\n\n        if self.norm is not None:\n            x = self.norm(x)\n\n        if self.projection is not None:\n            x = self.projection(x)\n        return x, trend\n"
  },
  {
    "path": "probts/model/nn/arch/ChronosModule/__init__.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom .base import BaseChronosPipeline, ForecastType\nfrom .chronos import (\n    ChronosConfig,\n    ChronosModel,\n    ChronosPipeline,\n    ChronosTokenizer,\n    MeanScaleUniformBins,\n)\n\nfrom .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline\n\n__all__ = [\n    \"BaseChronosPipeline\",\n    \"ForecastType\",\n    \"ChronosConfig\",\n    \"ChronosModel\",\n    \"ChronosPipeline\",\n    \"ChronosTokenizer\",\n    \"MeanScaleUniformBins\",\n    \"ChronosBoltConfig\",\n    \"ChronosBoltPipeline\",\n]\n"
  },
  {
    "path": "probts/model/nn/arch/ChronosModule/base.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n# Authors: Caner Turkmen <atturkm@amazon.com>, Abdul Fatir Ansari <ansarnd@amazon.com>, Lorenzo Stella <stellalo@amazon.com>\n# Original source:\n# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/src/autogluon/timeseries/models/chronos/pipeline/base.py\n\nfrom enum import Enum\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union\n\nimport torch\n\nif TYPE_CHECKING:\n    from transformers import PreTrainedModel\n\nfrom .utils import left_pad_and_stack_1D\n\n\nclass ForecastType(Enum):\n    SAMPLES = \"samples\"\n    QUANTILES = \"quantiles\"\n\n\nclass PipelineRegistry(type):\n    REGISTRY: Dict[str, \"PipelineRegistry\"] = {}\n\n    def __new__(cls, name, bases, attrs):\n        \"\"\"See, https://github.com/faif/python-patterns.\"\"\"\n        new_cls = type.__new__(cls, name, bases, attrs)\n        if name is not None:\n            cls.REGISTRY[name] = new_cls\n\n        return new_cls\n\n\nclass BaseChronosPipeline(metaclass=PipelineRegistry):\n    forecast_type: ForecastType\n    dtypes = {\"bfloat16\": torch.bfloat16, \"float32\": torch.float32}\n\n    def __init__(self, inner_model: \"PreTrainedModel\"):\n        \"\"\"\n        Parameters\n        ----------\n        inner_model : PreTrainedModel\n            A hugging-face transformers PreTrainedModel, e.g., T5ForConditionalGeneration\n        \"\"\"\n        # for easy access to the inner HF-style model\n        self.inner_model = inner_model\n\n    def _prepare_and_validate_context(\n        self, context: Union[torch.Tensor, List[torch.Tensor]]\n    ):\n        if isinstance(context, list):\n            context = left_pad_and_stack_1D(context)\n        assert isinstance(context, torch.Tensor)\n        if context.ndim == 1:\n            context = context.unsqueeze(0)\n        assert 
context.ndim == 2\n\n        return context\n\n    def predict(\n        self,\n        context: Union[torch.Tensor, List[torch.Tensor]],\n        prediction_length: Optional[int] = None,\n        **kwargs,\n    ):\n        \"\"\"\n        Get forecasts for the given time series. Predictions will be\n        returned in fp32 on the cpu.\n\n        Parameters\n        ----------\n        context\n            Input series. This is either a 1D tensor, or a list\n            of 1D tensors, or a 2D tensor whose first dimension\n            is batch. In the latter case, use left-padding with\n            ``torch.nan`` to align series of different lengths.\n        prediction_length\n            Time steps to predict. Defaults to a model-dependent\n            value if not given.\n\n        Returns\n        -------\n        forecasts\n            Tensor containing forecasts. The layout and meaning\n            of the forecasts values depends on ``self.forecast_type``.\n        \"\"\"\n        raise NotImplementedError()\n\n    def predict_quantiles(\n        self,\n        context: Union[torch.Tensor, List[torch.Tensor]],\n        prediction_length: Optional[int] = None,\n        quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],\n        **kwargs,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Get quantile and mean forecasts for given time series.\n        Predictions will be returned in fp32 on the cpu.\n\n        Parameters\n        ----------\n        context : Union[torch.Tensor, List[torch.Tensor]]\n            Input series. This is either a 1D tensor, or a list\n            of 1D tensors, or a 2D tensor whose first dimension\n            is batch. In the latter case, use left-padding with\n            ``torch.nan`` to align series of different lengths.\n        prediction_length : Optional[int], optional\n            Time steps to predict. 
Defaults to a model-dependent\n            value if not given.\n        quantile_levels : List[float], optional\n            Quantile levels to compute, by default [0.1, 0.2, ..., 0.9]\n\n        Returns\n        -------\n        quantiles\n            Tensor containing quantile forecasts. Shape\n            (batch_size, prediction_length, num_quantiles)\n        mean\n            Tensor containing mean (point) forecasts. Shape\n            (batch_size, prediction_length)\n        \"\"\"\n        raise NotImplementedError()\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        pretrained_model_name_or_path: Union[str, Path],\n        *model_args,\n        **kwargs,\n    ):\n        \"\"\"\n        Load the model, either from a local path or from the HuggingFace Hub.\n        Supports the same arguments as ``AutoConfig`` and ``AutoModel``\n        from ``transformers``.\n        \"\"\"\n        from transformers import AutoConfig\n\n        torch_dtype = kwargs.get(\"torch_dtype\", \"auto\")\n        if torch_dtype != \"auto\" and isinstance(torch_dtype, str):\n            kwargs[\"torch_dtype\"] = cls.dtypes[torch_dtype]\n\n        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)\n        is_valid_config = hasattr(config, \"chronos_pipeline_class\") or hasattr(\n            config, \"chronos_config\"\n        )\n\n        if not is_valid_config:\n            raise ValueError(\"Not a Chronos config file\")\n\n        pipeline_class_name = getattr(\n            config, \"chronos_pipeline_class\", \"ChronosPipeline\"\n        )\n        class_ = PipelineRegistry.REGISTRY.get(pipeline_class_name)\n        if class_ is None:\n            raise ValueError(\n                f\"Trying to load unknown pipeline class: {pipeline_class_name}\"\n            )\n\n        return class_.from_pretrained(  # type: ignore[attr-defined]\n            pretrained_model_name_or_path, *model_args, **kwargs\n        )\n"
  },
  {
    "path": "probts/model/nn/arch/ChronosModule/chronos.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>, Lorenzo Stella <stellalo@amazon.com>, Caner Turkmen <atturkm@amazon.com>\n\nimport logging\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Literal, Optional, Tuple, Union\nfrom einops import rearrange\nimport sys\nfrom .loss import LabelSmoother\nimport torch\nimport torch.nn as nn\nfrom transformers import (\n    AutoConfig,\n    AutoModelForCausalLM,\n    AutoModelForSeq2SeqLM,\n    GenerationConfig,\n    PreTrainedModel,\n)\n# import chronos\nfrom probts.model.nn.arch import ChronosModule\nfrom .base import BaseChronosPipeline, ForecastType\nfrom .utils import left_pad_and_stack_1D\n\nlogger = logging.getLogger(__file__)\n\n\n@dataclass\nclass ChronosConfig:\n    \"\"\"\n    This class holds all the configuration parameters to be used\n    by ``ChronosTokenizer`` and ``ChronosModel``.\n    \"\"\"\n\n    tokenizer_class: str\n    tokenizer_kwargs: Dict[str, Any]\n    context_length: int\n    prediction_length: int\n    n_tokens: int\n    n_special_tokens: int\n    pad_token_id: int\n    eos_token_id: int\n    use_eos_token: bool\n    model_type: Literal[\"causal\", \"seq2seq\"]\n    num_samples: int\n    temperature: float\n    top_k: int\n    top_p: float\n\n    def __post_init__(self):\n        assert (\n            self.pad_token_id < self.n_special_tokens\n            and self.eos_token_id < self.n_special_tokens\n        ), f\"Special token id's must be smaller than {self.n_special_tokens=}\"\n\n    def create_tokenizer(self) -> \"ChronosTokenizer\":\n        class_ = getattr(ChronosModule, self.tokenizer_class)\n        return class_(**self.tokenizer_kwargs, config=self)\n\n\nclass ChronosTokenizer:\n    \"\"\"\n    A ``ChronosTokenizer`` definines how time series are mapped into token IDs\n    and back.\n\n    For details, see the ``input_transform`` and 
``output_transform`` methods,\n    which concrete classes must implement.\n    \"\"\"\n\n    def context_input_transform(\n        self,\n        context: torch.Tensor,\n    ) -> Tuple:\n        \"\"\"\n        Turn a batch of time series into token IDs, attention map, and tokenizer_state.\n\n        Parameters\n        ----------\n        context\n            A tensor shaped (batch_size, time_length), containing the\n            timeseries to forecast. Use left-padding with ``torch.nan``\n            to align time series of different lengths.\n\n        Returns\n        -------\n        token_ids\n            A tensor of integers, shaped (batch_size, time_length + 1)\n            if ``config.use_eos_token`` and (batch_size, time_length)\n            otherwise, containing token IDs for the input series.\n        attention_mask\n            A boolean tensor, same shape as ``token_ids``, indicating\n            which input observations are not ``torch.nan`` (i.e. not\n            missing nor padding).\n        tokenizer_state\n            An object that can be passed to ``label_input_transform``\n            and ``output_transform``. Contains the relevant information\n            to decode output samples into real values,\n            such as location and scale parameters.\n        \"\"\"\n        raise NotImplementedError()\n\n    def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:\n        \"\"\"\n        Turn a batch of label slices of time series into token IDs and attention map\n        using the ``tokenizer_state`` provided by ``context_input_transform``.\n\n        Parameters\n        ----------\n        context\n            A tensor shaped (batch_size, time_length), containing the\n            timeseries to forecast. 
Use left-padding with ``torch.nan``\n            to align time series of different lengths.\n        tokenizer_state\n            An object returned by ``context_input_transform`` containing\n            relevant information to preprocess data, such as location and\n            scale. The nature of this depends on the specific tokenizer.\n            This is used for tokenizing the label, in order to use the same\n            scaling used to tokenize the context.\n\n        Returns\n        -------\n        token_ids\n            A tensor of integers, shaped (batch_size, time_length + 1)\n            if ``config.use_eos_token`` and (batch_size, time_length)\n            otherwise, containing token IDs for the input series.\n        attention_mask\n            A boolean tensor, same shape as ``token_ids``, indicating\n            which input observations are not ``torch.nan`` (i.e. not\n            missing nor padding).\n        \"\"\"\n        raise NotImplementedError()\n\n    def output_transform(\n        self, samples: torch.Tensor, tokenizer_state: Any\n    ) -> torch.Tensor:\n        \"\"\"\n        Turn a batch of sample token IDs into real values.\n\n        Parameters\n        ----------\n        samples\n            A tensor of integers, shaped (batch_size, num_samples, time_length),\n            containing token IDs of sample trajectories.\n        tokenizer_state\n            An object returned by ``input_transform`` containing\n            relevant context to decode samples, such as location and scale.\n            The nature of this depends on the specific tokenizer.\n\n        Returns\n        -------\n        forecasts\n            A real tensor, shaped (batch_size, num_samples, time_length),\n            containing forecasted sample paths.\n        \"\"\"\n        raise NotImplementedError()\n\n\nclass MeanScaleUniformBins(ChronosTokenizer):\n    def __init__(\n        self, low_limit: float, high_limit: float, config: ChronosConfig, \n    ) -> 
None:\n        self.config = config\n        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n        self.centers = torch.linspace(\n            low_limit,\n            high_limit,\n            config.n_tokens - config.n_special_tokens - 1,\n        ).to(device)\n        self.boundaries = torch.concat(\n            (\n                torch.tensor([-1e20], device=self.centers.device),\n                (self.centers[1:] + self.centers[:-1]) / 2,\n                torch.tensor([1e20], device=self.centers.device),\n            )\n        )\n\n    def _input_transform(\n        self, context: torch.Tensor, scale: Optional[torch.Tensor] = None\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        context = context.to(dtype=torch.float32)\n        attention_mask = ~torch.isnan(context) #.to(context.device)\n\n        if scale is None:\n            scale = torch.nansum(\n                torch.abs(context) * attention_mask, dim=-1\n            ) / torch.nansum(attention_mask, dim=-1)\n            scale[~(scale > 0)] = 1.0\n\n        scaled_context = context / scale.unsqueeze(dim=-1)\n        token_ids = (\n            torch.bucketize(\n                input=scaled_context,\n                boundaries=self.boundaries,\n                # buckets are open to the right, see:\n                # https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize\n                right=True,\n            )\n            + self.config.n_special_tokens\n        )\n\n        token_ids.clamp_(0, self.config.n_tokens - 1)\n\n        token_ids[~attention_mask] = self.config.pad_token_id\n\n        return token_ids, attention_mask, scale\n\n    def _append_eos_token(\n        self, token_ids: torch.Tensor, attention_mask: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        batch_size = token_ids.shape[0]\n        eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id).to(token_ids.device)\n      
  token_ids = torch.concat((token_ids, eos_tokens), dim=1)\n        eos_mask = torch.full((batch_size, 1), fill_value=True).to(attention_mask.device)\n        attention_mask = torch.concat((attention_mask, eos_mask), dim=1)\n\n        return token_ids, attention_mask\n\n    def context_input_transform(\n        self, context: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        length = context.shape[-1]\n\n        if length > self.config.context_length:\n            context = context[..., -self.config.context_length :]\n\n        token_ids, attention_mask, scale = self._input_transform(context=context)\n\n        if self.config.use_eos_token and self.config.model_type == \"seq2seq\":\n            token_ids, attention_mask = self._append_eos_token(\n                token_ids=token_ids, attention_mask=attention_mask\n            )\n\n        return token_ids, attention_mask, scale\n\n    def label_input_transform(\n        self, label: torch.Tensor, scale: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        length = label.shape[-1]\n\n        assert length == self.config.prediction_length\n        token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale)\n\n        if self.config.use_eos_token:\n            token_ids, attention_mask = self._append_eos_token(\n                token_ids=token_ids, attention_mask=attention_mask\n            )\n\n        return token_ids, attention_mask\n\n    def output_transform(\n        self, samples: torch.Tensor, scale: torch.Tensor\n    ) -> torch.Tensor:\n        scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)\n        indices = torch.clamp(\n            samples - self.config.n_special_tokens - 1,\n            min=0,\n            max=len(self.centers) - 1,\n        )\n        return self.centers[indices] * scale_unsqueezed\n\n\nclass ChronosModel(nn.Module):\n    \"\"\"\n    A ``ChronosModel`` wraps a ``PreTrainedModel`` object from ``transformers``\n    and 
uses it to predict sample paths for time series tokens.\n\n    Parameters\n    ----------\n    config\n        The configuration to use.\n    model\n        The pretrained model to use.\n    \"\"\"\n\n    def __init__(self, config: ChronosConfig, model: PreTrainedModel) -> None:\n        super().__init__()\n        self.config = config\n        self.model = model\n\n    @property\n    def device(self):\n        return self.model.device\n\n    def encode(\n        self,\n        input_ids: torch.Tensor,\n        attention_mask: torch.Tensor,\n    ):\n        \"\"\"\n        Extract the encoder embedding for the given token sequences.\n\n        Parameters\n        ----------\n        input_ids\n            Tensor of indices of input sequence tokens in the vocabulary\n            with shape (batch_size, sequence_length).\n        attention_mask\n            A mask tensor of the same shape as input_ids to avoid attending\n            on padding or missing tokens.\n\n        Returns\n        -------\n        embedding\n            A tensor of encoder embeddings with shape\n            (batch_size, sequence_length, d_model).\n        \"\"\"\n        assert (\n            self.config.model_type == \"seq2seq\"\n        ), \"Encoder embeddings are only supported for encoder-decoder models\"\n        return self.model.encoder(\n            input_ids=input_ids, attention_mask=attention_mask\n        ).last_hidden_state\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        attention_mask: torch.Tensor,\n        prediction_length: Optional[int] = None,\n        num_samples: Optional[int] = None,\n        temperature: Optional[float] = None,\n        top_k: Optional[int] = None,\n        top_p: Optional[float] = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Predict future sample tokens for the given token sequences.\n\n        Arguments ``prediction_length``, ``num_samples``, ``temperature``,\n        ``top_k``, ``top_p`` can be used to 
customize the model inference,\n        and default to the corresponding attributes in ``self.config`` if\n        not provided.\n\n        Returns\n        -------\n        samples\n            A tensor of integers, shaped (batch_size, num_samples, time_length),\n            containing forecasted sample paths.\n        \"\"\"\n        if prediction_length is None:\n            prediction_length = self.config.prediction_length\n        if num_samples is None:\n            num_samples = self.config.num_samples\n        if temperature is None:\n            temperature = self.config.temperature\n        if top_k is None:\n            top_k = self.config.top_k\n        if top_p is None:\n            top_p = self.config.top_p\n\n        preds = self.model.generate(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            generation_config=GenerationConfig(\n                min_new_tokens=prediction_length,\n                max_new_tokens=prediction_length,\n                do_sample=True,\n                num_return_sequences=num_samples,\n                eos_token_id=self.config.eos_token_id,\n                pad_token_id=self.config.pad_token_id,\n                temperature=temperature,\n                top_k=top_k,\n                top_p=top_p,\n            ),\n        )\n\n        if self.config.model_type == \"seq2seq\":\n            preds = preds[..., 1:]  # remove the decoder start token\n        else:\n            assert self.config.model_type == \"causal\"\n            assert preds.size(-1) == input_ids.size(-1) + prediction_length\n            preds = preds[..., -prediction_length:]\n\n        return preds.reshape(input_ids.size(0), num_samples, -1)\n\n\nclass ChronosPipeline(BaseChronosPipeline):\n    \"\"\"\n    A ``ChronosPipeline`` uses the given tokenizer and model to forecast\n    input time series.\n\n    Use the ``from_pretrained`` class method to load serialized models.\n    Use the ``predict`` method to get 
forecasts.\n\n    Parameters\n    ----------\n    tokenizer\n        The tokenizer object to use.\n    model\n        The model to use.\n    \"\"\"\n\n    tokenizer: ChronosTokenizer\n    model: ChronosModel\n    forecast_type: ForecastType = ForecastType.SAMPLES\n\n    def __init__(self, tokenizer, model):\n        super().__init__(inner_model=model.model)\n        self.tokenizer = tokenizer\n        self.model = model\n        self.loss_func = LabelSmoother()\n\n    def _prepare_and_validate_context(\n        self, context: Union[torch.Tensor, List[torch.Tensor]]\n    ):\n        if isinstance(context, list):\n            context = left_pad_and_stack_1D(context)\n        assert isinstance(context, torch.Tensor)\n        if context.ndim == 1:\n            context = context.unsqueeze(0)\n        assert context.ndim == 2\n\n        return context\n\n    @torch.no_grad()\n    def embed(\n        self, context: Union[torch.Tensor, List[torch.Tensor]]\n    ) -> Tuple[torch.Tensor, Any]:\n        \"\"\"\n        Get encoder embeddings for the given time series.\n\n        Parameters\n        ----------\n        context\n            Input series. This is either a 1D tensor, or a list\n            of 1D tensors, or a 2D tensor whose first dimension\n            is batch. 
In the latter case, use left-padding with\n            ``torch.nan`` to align series of different lengths.\n\n        Returns\n        -------\n        embeddings, tokenizer_state\n            A tuple of two tensors: the encoder embeddings and the tokenizer_state,\n            e.g., the scale of the time series in the case of mean scaling.\n            The encoder embeddings are shaped (batch_size, context_length, d_model)\n            or (batch_size, context_length + 1, d_model), where context_length\n            is the size of the context along the time axis if a 2D tensor was provided\n            or the length of the longest time series, if a list of 1D tensors was\n            provided, and the extra 1 is for EOS.\n        \"\"\"\n        context_tensor = self._prepare_and_validate_context(context=context)\n        token_ids, attention_mask, tokenizer_state = (\n            self.tokenizer.context_input_transform(context_tensor)\n        )\n        embeddings = self.model.encode(\n            input_ids=token_ids.to(self.model.device),\n            attention_mask=attention_mask.to(self.model.device),\n        ).cpu()\n        return embeddings, tokenizer_state\n\n    def predict(  # type: ignore[override]\n        self,\n        context: Union[torch.Tensor, List[torch.Tensor]],\n        prediction_length: Optional[int] = None,\n        num_samples: Optional[int] = None,\n        temperature: Optional[float] = None,\n        top_k: Optional[int] = None,\n        top_p: Optional[float] = None,\n        limit_prediction_length: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"\n        Get forecasts for the given time series.\n\n        Refer to the base method (``BaseChronosPipeline.predict``)\n        for details on shared parameters.\n\n        Additional parameters\n        ---------------------\n        num_samples\n            Number of sample paths to predict. 
Defaults to what\n            specified in ``self.model.config``.\n        temperature\n            Temperature to use for generating sample tokens.\n            Defaults to what specified in ``self.model.config``.\n        top_k\n            Top-k parameter to use for generating sample tokens.\n            Defaults to what specified in ``self.model.config``.\n        top_p\n            Top-p parameter to use for generating sample tokens.\n            Defaults to what specified in ``self.model.config``.\n        limit_prediction_length\n            Force prediction length smaller or equal than the\n            built-in prediction length from the model. False by\n            default. When true, fail loudly if longer predictions\n            are requested, otherwise longer predictions are allowed.\n\n        Returns\n        -------\n        samples\n            Tensor of sample forecasts, of shape\n            (batch_size, num_samples, prediction_length).\n        \"\"\"\n        context_tensor = self._prepare_and_validate_context(context=context)\n\n        if prediction_length is None:\n            prediction_length = self.model.config.prediction_length\n\n        # if prediction_length > self.model.config.prediction_length:\n        #     msg = (\n        #         f\"We recommend keeping prediction length <= {self.model.config.prediction_length}. \"\n        #         \"The quality of longer predictions may degrade since the model is not optimized for it. 
\"\n        #     )\n        #     if limit_prediction_length:\n        #         msg += \"You can turn off this check by setting `limit_prediction_length=False`.\"\n        #         raise ValueError(msg)\n        #     logger.warning(msg)\n\n        predictions = []\n        remaining = prediction_length\n\n        while remaining > 0:\n            token_ids, attention_mask, scale = self.tokenizer.context_input_transform(\n                context_tensor\n            )\n            samples = self.model(\n                token_ids.to(self.model.device),\n                attention_mask.to(self.model.device),\n                min(remaining, self.model.config.prediction_length),\n                num_samples,\n                temperature,\n                top_k,\n                top_p,\n            )\n            prediction = self.tokenizer.output_transform(\n                samples.to(scale.device), scale\n            )\n\n            predictions.append(prediction)\n            remaining -= prediction.shape[-1]\n\n            if remaining <= 0:\n                break\n\n            context_tensor = torch.cat(\n                [context_tensor, prediction.median(dim=1).values], dim=-1\n            )\n\n        return torch.cat(predictions, dim=-1).to(dtype=torch.float32, device=\"cpu\")\n\n    def predict_quantiles(\n        self,\n        context: Union[torch.Tensor, List[torch.Tensor]],\n        prediction_length: Optional[int] = None,\n        quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],\n        **predict_kwargs,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Refer to the base method (``BaseChronosPipeline.predict_quantiles``).\n        \"\"\"\n        \n        shape_dim = context.shape\n        if len(shape_dim) == 3:\n            context = rearrange(context, 'b k l -> (b k) l')\n            \n        prediction_samples = (\n            self.predict(context, prediction_length=prediction_length, 
**predict_kwargs)\n            .detach()\n            .swapaxes(1, 2)\n        )\n        mean = prediction_samples.mean(dim=-1)\n        quantiles = torch.quantile(\n            prediction_samples,\n            q=torch.tensor(quantile_levels, dtype=prediction_samples.dtype),\n            dim=-1,\n        ).permute(1, 2, 0)\n        \n        if len(shape_dim) == 3:\n            quantiles = rearrange(quantiles, '(b k) l q -> b k l q', b=shape_dim[0])\n            mean = rearrange(mean, '(b k) l -> b k l',b=shape_dim[0])\n\n        return mean, quantiles\n\n    @classmethod\n    def from_pretrained(cls, *args, **kwargs):\n        \"\"\"\n        Load the model, either from a local path or from the HuggingFace Hub.\n        Supports the same arguments as ``AutoConfig`` and ``AutoModel``\n        from ``transformers``.\n        \"\"\"\n\n        config = AutoConfig.from_pretrained(*args, **kwargs)\n\n        assert hasattr(config, \"chronos_config\"), \"Not a Chronos config file\"\n\n        chronos_config = ChronosConfig(**config.chronos_config)\n\n        if chronos_config.model_type == \"seq2seq\":\n            inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)\n        else:\n            assert chronos_config.model_type == \"causal\"\n            inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)\n\n        return cls(\n            tokenizer=chronos_config.create_tokenizer(),\n            model=ChronosModel(config=chronos_config, model=inner_model),\n        )\n        \n"
  },
  {
    "path": "probts/model/nn/arch/ChronosModule/chronos_bolt.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>, Caner Turkmen <atturkm@amazon.com>, Lorenzo Stella <stellalo@amazon.com>\n# Original source:\n# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/src/autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py\n\nimport copy\nimport logging\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom transformers import AutoConfig\nfrom transformers.models.t5.modeling_t5 import (\n    ACT2FN,\n    T5Config,\n    T5LayerNorm,\n    T5PreTrainedModel,\n    T5Stack,\n)\nfrom transformers.utils import ModelOutput\n\nfrom .base import BaseChronosPipeline, ForecastType\n\nlogger = logging.getLogger(__file__)\n\n\n@dataclass\nclass ChronosBoltConfig:\n    context_length: int\n    prediction_length: int\n    input_patch_size: int\n    input_patch_stride: int\n    quantiles: List[float]\n    use_reg_token: bool = False\n\n\n@dataclass\nclass ChronosBoltOutput(ModelOutput):\n    loss: Optional[torch.Tensor] = None\n    quantile_preds: Optional[torch.Tensor] = None\n    attentions: Optional[torch.Tensor] = None\n    cross_attentions: Optional[torch.Tensor] = None\n\n\nclass Patch(nn.Module):\n    def __init__(self, patch_size: int, patch_stride: int) -> None:\n        super().__init__()\n        self.patch_size = patch_size\n        self.patch_stride = patch_stride\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        length = x.shape[-1]\n\n        if length % self.patch_size != 0:\n            padding_size = (\n                *x.shape[:-1],\n                self.patch_size - (length % self.patch_size),\n            )\n            padding = torch.full(\n                size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device\n            )\n 
           x = torch.concat((padding, x), dim=-1)\n\n        x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)\n        return x\n\n\nclass InstanceNorm(nn.Module):\n    \"\"\"\n    See, also, RevIN. Apply standardization along the last dimension.\n    \"\"\"\n\n    def __init__(self, eps: float = 1e-5) -> None:\n        super().__init__()\n        self.eps = eps\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        loc_scale: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        if loc_scale is None:\n            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=0.0)\n            scale = torch.nan_to_num(\n                torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0\n            )\n            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)\n        else:\n            loc, scale = loc_scale\n\n        return (x - loc) / scale, (loc, scale)\n\n    def inverse(\n        self, x: torch.Tensor, loc_scale: Tuple[torch.Tensor, torch.Tensor]\n    ) -> torch.Tensor:\n        loc, scale = loc_scale\n        return x * scale + loc\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(\n        self,\n        in_dim: int,\n        h_dim: int,\n        out_dim: int,\n        act_fn_name: str,\n        dropout_p: float = 0.0,\n        use_layer_norm: bool = False,\n    ) -> None:\n        super().__init__()\n\n        self.dropout = nn.Dropout(dropout_p)\n        self.hidden_layer = nn.Linear(in_dim, h_dim)\n        self.act = ACT2FN[act_fn_name]\n        self.output_layer = nn.Linear(h_dim, out_dim)\n        self.residual_layer = nn.Linear(in_dim, out_dim)\n\n        self.use_layer_norm = use_layer_norm\n        if use_layer_norm:\n            self.layer_norm = T5LayerNorm(out_dim)\n\n    def forward(self, x: torch.Tensor):\n        hid = self.act(self.hidden_layer(x))\n        out = 
self.dropout(self.output_layer(hid))\n        res = self.residual_layer(x)\n\n        out = out + res\n\n        if self.use_layer_norm:\n            return self.layer_norm(out)\n        return out\n\n\nclass ChronosBoltModelForForecasting(T5PreTrainedModel):\n    _keys_to_ignore_on_load_missing = [\n        r\"input_patch_embedding\\.\",\n        r\"output_patch_embedding\\.\",\n    ]\n    _keys_to_ignore_on_load_unexpected = [r\"lm_head.weight\"]\n    _tied_weights_keys = [\"encoder.embed_tokens.weight\", \"decoder.embed_tokens.weight\"]\n\n    def __init__(self, config: T5Config):\n        assert hasattr(config, \"chronos_config\"), \"Not a Chronos config file\"\n\n        super().__init__(config)\n        self.model_dim = config.d_model\n\n        self.chronos_config = ChronosBoltConfig(**config.chronos_config)\n\n        # Only decoder_start_id (and optionally REG token)\n        if self.chronos_config.use_reg_token:\n            config.reg_token_id = 1\n\n        config.vocab_size = 2 if self.chronos_config.use_reg_token else 1\n        self.shared = nn.Embedding(config.vocab_size, config.d_model)\n\n        # Input patch embedding layer\n        self.input_patch_embedding = ResidualBlock(\n            in_dim=self.chronos_config.input_patch_size * 2,\n            h_dim=config.d_ff,\n            out_dim=config.d_model,\n            act_fn_name=config.dense_act_fn,\n            dropout_p=config.dropout_rate,\n        )\n\n        # patching layer\n        self.patch = Patch(\n            patch_size=self.chronos_config.input_patch_size,\n            patch_stride=self.chronos_config.input_patch_stride,\n        )\n\n        # instance normalization, also referred to as \"scaling\" in Chronos and GluonTS\n        self.instance_norm = InstanceNorm()\n\n        encoder_config = copy.deepcopy(config)\n        encoder_config.is_decoder = False\n        encoder_config.use_cache = False\n        encoder_config.is_encoder_decoder = False\n        self.encoder = 
T5Stack(encoder_config, self.shared)\n\n        self._init_decoder(config)\n\n        self.num_quantiles = len(self.chronos_config.quantiles)\n        quantiles = torch.tensor(self.chronos_config.quantiles, dtype=self.dtype)\n        self.register_buffer(\"quantiles\", quantiles, persistent=False)\n\n        self.output_patch_embedding = ResidualBlock(\n            in_dim=config.d_model,\n            h_dim=config.d_ff,\n            out_dim=self.num_quantiles * self.chronos_config.prediction_length,\n            act_fn_name=config.dense_act_fn,\n            dropout_p=config.dropout_rate,\n        )\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n        # Model parallel\n        self.model_parallel = False\n        self.device_map = None\n\n    def _init_weights(self, module):\n        super()._init_weights(module)\n        \"\"\"Initialize the weights\"\"\"\n        factor = self.config.initializer_factor\n        if isinstance(module, (self.__class__)):\n            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)\n        elif isinstance(module, ResidualBlock):\n            module.hidden_layer.weight.data.normal_(\n                mean=0.0,\n                std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),\n            )\n            if (\n                hasattr(module.hidden_layer, \"bias\")\n                and module.hidden_layer.bias is not None\n            ):\n                module.hidden_layer.bias.data.zero_()\n\n            module.residual_layer.weight.data.normal_(\n                mean=0.0,\n                std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),\n            )\n            if (\n                hasattr(module.residual_layer, \"bias\")\n                and module.residual_layer.bias is not None\n            ):\n                module.residual_layer.bias.data.zero_()\n\n            module.output_layer.weight.data.normal_(\n                mean=0.0, 
std=factor * ((self.config.d_ff) ** -0.5)\n            )\n            if (\n                hasattr(module.output_layer, \"bias\")\n                and module.output_layer.bias is not None\n            ):\n                module.output_layer.bias.data.zero_()\n\n    def encode(\n        self, context: torch.Tensor, mask: Optional[torch.Tensor] = None\n    ) -> Tuple[\n        torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor\n    ]:\n        mask = (\n            mask.to(context.dtype)\n            if mask is not None\n            else torch.isnan(context).logical_not().to(context.dtype)\n        )\n\n        batch_size, _ = context.shape\n        if context.shape[-1] > self.chronos_config.context_length:\n            context = context[..., -self.chronos_config.context_length :]\n            mask = mask[..., -self.chronos_config.context_length :]\n\n        # scaling\n        context, loc_scale = self.instance_norm(context)\n\n        # the scaling op above is done in 32-bit precision,\n        # then the context is moved to model's dtype\n        context = context.to(self.dtype)\n        mask = mask.to(self.dtype)\n\n        # patching\n        patched_context = self.patch(context)\n        patched_mask = torch.nan_to_num(self.patch(mask), nan=0.0)\n        patched_context = torch.where(patched_mask > 0.0, patched_context, 0.0)\n        # concat context and mask along patch dim\n        patched_context = torch.cat([patched_context, patched_mask], dim=-1)\n\n        # attention_mask = 1 if at least one item in the patch is observed\n        attention_mask = (\n            patched_mask.sum(dim=-1) > 0\n        )  # (batch_size, patched_seq_length)\n\n        input_embeds = self.input_patch_embedding(patched_context)\n\n        if self.chronos_config.use_reg_token:\n            # Append [REG]\n            reg_input_ids = torch.full(\n                (batch_size, 1),\n                self.config.reg_token_id,\n                
device=input_embeds.device,\n            )\n            reg_embeds = self.shared(reg_input_ids)\n            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)\n            attention_mask = torch.cat(\n                [\n                    attention_mask.to(self.dtype),\n                    torch.ones_like(reg_input_ids).to(self.dtype),\n                ],\n                dim=-1,\n            )\n\n        encoder_outputs = self.encoder(\n            attention_mask=attention_mask,\n            inputs_embeds=input_embeds,\n        )\n\n        return encoder_outputs[0], loc_scale, input_embeds, attention_mask\n\n    def forward(\n        self,\n        context: torch.Tensor,\n        mask: Optional[torch.Tensor] = None,\n        target: Optional[torch.Tensor] = None,\n        target_mask: Optional[torch.Tensor] = None,\n    ) -> ChronosBoltOutput:\n        batch_size = context.size(0)\n\n        hidden_states, loc_scale, input_embeds, attention_mask = self.encode(\n            context=context, mask=mask\n        )\n        sequence_output = self.decode(input_embeds, attention_mask, hidden_states)\n\n        quantile_preds_shape = (\n            batch_size,\n            self.num_quantiles,\n            self.chronos_config.prediction_length,\n        )\n        quantile_preds = self.output_patch_embedding(sequence_output).view(\n            *quantile_preds_shape\n        )\n\n        loss = None\n        if target is not None:\n            # normalize target\n            target, _ = self.instance_norm(target, loc_scale)\n            target = target.unsqueeze(1)  # type: ignore\n            assert self.chronos_config.prediction_length >= target.shape[-1]\n\n            target = target.to(quantile_preds.device)\n            target_mask = (\n                target_mask.unsqueeze(1).to(quantile_preds.device)\n                if target_mask is not None\n                else ~torch.isnan(target)\n            )\n            target[~target_mask] = 0.0\n\n           
 # pad target and target_mask if they are shorter than model's prediction_length\n            if self.chronos_config.prediction_length > target.shape[-1]:\n                padding_shape = (\n                    *target.shape[:-1],\n                    self.chronos_config.prediction_length - target.shape[-1],\n                )\n                target = torch.cat(\n                    [target, torch.zeros(padding_shape).to(target)], dim=-1\n                )\n                target_mask = torch.cat(\n                    [target_mask, torch.zeros(padding_shape).to(target_mask)], dim=-1\n                )\n\n            loss = (\n                2\n                * torch.abs(\n                    (target - quantile_preds)\n                    * (\n                        (target <= quantile_preds).float()\n                        - self.quantiles.view(1, self.num_quantiles, 1)\n                    )\n                )\n                * target_mask.float()\n            )\n            loss = loss.mean(dim=-2)  # Mean over prediction horizon\n            loss = loss.sum(dim=-1)  # Sum over quantile levels\n            loss = loss.mean()  # Mean over batch\n\n        # Unscale predictions\n        quantile_preds = self.instance_norm.inverse(\n            quantile_preds.view(batch_size, -1),\n            loc_scale,\n        ).view(*quantile_preds_shape)\n\n        return ChronosBoltOutput(\n            loss=loss,\n            quantile_preds=quantile_preds,\n        )\n\n    def _init_decoder(self, config):\n        decoder_config = copy.deepcopy(config)\n        decoder_config.is_decoder = True\n        decoder_config.is_encoder_decoder = False\n        decoder_config.num_layers = config.num_decoder_layers\n        self.decoder = T5Stack(decoder_config, self.shared)\n\n    def decode(\n        self,\n        input_embeds,\n        attention_mask,\n        hidden_states,\n        output_attentions=False,\n    ):\n        \"\"\"\n        Parameters\n        ----------\n    
    input_embeds: torch.Tensor\n            Patched and embedded inputs. Shape (batch_size, patched_context_length, d_model)\n        attention_mask: torch.Tensor\n            Attention mask for the patched context, as returned by ``encode`` (bool, or the model dtype when a [REG] token is appended). Shape (batch_size, patched_context_length)\n        hidden_states: torch.Tensor\n            Hidden states returned by the encoder. Shape (batch_size, patched_context_length, d_model)\n\n        Returns\n        -------\n        last_hidden_state\n            Last hidden state returned by the decoder, of shape (batch_size, 1, d_model)\n        \"\"\"\n        batch_size = input_embeds.shape[0]\n        decoder_input_ids = torch.full(\n            (batch_size, 1),\n            self.config.decoder_start_token_id,\n            device=input_embeds.device,\n        )\n        decoder_outputs = self.decoder(\n            input_ids=decoder_input_ids,\n            encoder_hidden_states=hidden_states,\n            encoder_attention_mask=attention_mask,\n            output_attentions=output_attentions,\n            return_dict=True,\n        )\n\n        return decoder_outputs.last_hidden_state  # sequence_outputs, b x 1 x d_model\n\n\nclass ChronosBoltPipeline(BaseChronosPipeline):\n    forecast_type: ForecastType = ForecastType.QUANTILES\n    default_context_length: int = 2048\n\n    def __init__(self, model: ChronosBoltModelForForecasting):\n        super().__init__(inner_model=model)\n        self.model = model\n\n    @property\n    def quantiles(self) -> List[float]:\n        return self.model.config.chronos_config[\"quantiles\"]\n\n    @torch.no_grad()\n    def embed(\n        self, context: Union[torch.Tensor, List[torch.Tensor]]\n    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"\n        Get encoder embeddings for the given time series.\n\n        Parameters\n        ----------\n        context\n            Input series. 
This is either a 1D tensor, or a list\n            of 1D tensors, or a 2D tensor whose first dimension\n            is batch. In the latter case, use left-padding with\n            ``torch.nan`` to align series of different lengths.\n\n        Returns\n        -------\n        embeddings, loc_scale\n            A tuple of two items: the encoder embeddings and the loc_scale,\n            i.e., the mean and std of the original time series.\n            The encoder embeddings are shaped (batch_size, num_patches + 1, d_model),\n            where num_patches is the number of patches in the time series\n            and the extra 1 is for the [REG] token (if used by the model).\n        \"\"\"\n        context_tensor = self._prepare_and_validate_context(context=context)\n        model_context_length = self.model.config.chronos_config[\"context_length\"]\n\n        if context_tensor.shape[-1] > model_context_length:\n            context_tensor = context_tensor[..., -model_context_length:]\n\n        context_tensor = context_tensor.to(\n            device=self.model.device,\n            dtype=torch.float32,\n        )\n        embeddings, loc_scale, *_ = self.model.encode(context=context_tensor)\n        return embeddings.cpu(), (\n            loc_scale[0].squeeze(-1).cpu(),\n            loc_scale[1].squeeze(-1).cpu(),\n        )\n\n    def predict(  # type: ignore[override]\n        self,\n        context: Union[torch.Tensor, List[torch.Tensor]],\n        prediction_length: Optional[int] = None,\n        limit_prediction_length: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"\n        Get forecasts for the given time series.\n\n        Refer to the base method (``BaseChronosPipeline.predict``)\n        for details on shared parameters.\n        Additional parameters\n        ---------------------\n        limit_prediction_length\n            Force the prediction length to be less than or equal to the\n            built-in prediction length of the model. 
False by\n            default. When True, raise a ValueError if longer predictions\n            are requested; otherwise longer predictions are allowed with a warning.\n\n        Returns\n        -------\n        torch.Tensor\n            Forecasts of shape (batch_size, num_quantiles, prediction_length)\n            where num_quantiles is the number of quantiles the model has been\n            trained to output. For official Chronos-Bolt models, the value of\n            num_quantiles is 9 for [0.1, 0.2, ..., 0.9]-quantiles.\n\n        Raises\n        ------\n        ValueError\n            When limit_prediction_length is True and the prediction_length is\n            greater than the model's training prediction_length.\n        \"\"\"\n        context_tensor = self._prepare_and_validate_context(context=context)\n\n        model_context_length = self.model.config.chronos_config[\"context_length\"]\n        model_prediction_length = self.model.config.chronos_config[\"prediction_length\"]\n        if prediction_length is None:\n            prediction_length = model_prediction_length\n\n        if prediction_length > model_prediction_length:\n            msg = (\n                f\"We recommend keeping prediction length <= {model_prediction_length}. \"\n                \"The quality of longer predictions may degrade since the model is not optimized for it. 
\"\n            )\n            if limit_prediction_length:\n                msg += \"You can turn off this check by setting `limit_prediction_length=False`.\"\n                raise ValueError(msg)\n            warnings.warn(msg)\n\n        predictions = []\n        remaining = prediction_length\n\n        # We truncate the context here because otherwise batches with very long\n        # context could take up large amounts of GPU memory unnecessarily.\n        if context_tensor.shape[-1] > model_context_length:\n            context_tensor = context_tensor[..., -model_context_length:]\n\n        # TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast\n        # horizon that the model was trained with (i.e., 64). This results in variance collapsing\n        # every 64 steps.\n        context_tensor = context_tensor.to(\n            device=self.model.device,\n            dtype=torch.float32,\n        )\n        while remaining > 0:\n            with torch.no_grad():\n                prediction = self.model(\n                    context=context_tensor,\n                ).quantile_preds.to(context_tensor)\n\n            predictions.append(prediction)\n            remaining -= prediction.shape[-1]\n\n            if remaining <= 0:\n                break\n\n            central_idx = torch.abs(torch.tensor(self.quantiles) - 0.5).argmin()\n            central_prediction = prediction[:, central_idx]\n\n            context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)\n\n        return torch.cat(predictions, dim=-1)[..., :prediction_length].to(\n            dtype=torch.float32, device=\"cpu\"\n        )\n\n    def predict_quantiles(\n        self,\n        context: Union[torch.Tensor, List[torch.Tensor]],\n        prediction_length: Optional[int] = None,\n        quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],\n        **predict_kwargs,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        
Refer to the base method (``BaseChronosPipeline.predict_quantiles``).\n        \"\"\"\n        # shape (batch_size, prediction_length, len(training_quantile_levels))\n        predictions = (\n            self.predict(context, prediction_length=prediction_length, **predict_kwargs)\n            .detach()\n            .swapaxes(1, 2)\n        )\n\n        training_quantile_levels = self.quantiles\n\n        if set(quantile_levels).issubset(set(training_quantile_levels)):\n            # no need to perform intra/extrapolation\n            quantiles = predictions[\n                ..., [training_quantile_levels.index(q) for q in quantile_levels]\n            ]\n        else:\n            # we rely on torch for interpolating quantiles if quantiles that\n            # Chronos Bolt was trained on were not provided\n            if min(quantile_levels) < min(training_quantile_levels) or max(\n                quantile_levels\n            ) > max(training_quantile_levels):\n                logger.warning(\n                    f\"\\tQuantiles to be predicted ({quantile_levels}) are not within the range of \"\n                    f\"quantiles that Chronos-Bolt was trained on ({training_quantile_levels}). \"\n                    \"Quantile predictions will be set to the minimum/maximum levels at which Chronos-Bolt \"\n                    \"was trained on. This may significantly affect the quality of the predictions.\"\n                )\n\n            # TODO: this is a hack that assumes the model's quantiles during training (training_quantile_levels)\n            # made up an equidistant grid along the quantile dimension. 
i.e., they were (0.1, 0.2, ..., 0.9).\n            # While this holds for official Chronos-Bolt models, this may not be true in the future, and this\n            # function may have to be revised.\n            augmented_predictions = torch.cat(\n                [predictions[..., [0]], predictions, predictions[..., [-1]]],\n                dim=-1,\n            )\n            quantiles = torch.quantile(\n                augmented_predictions,\n                q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),\n                dim=-1,\n            ).permute(1, 2, 0)\n        # NOTE: the median is returned as the mean here\n        mean = predictions[:, :, training_quantile_levels.index(0.5)]\n        return quantiles, mean\n\n    @classmethod\n    def from_pretrained(cls, *args, **kwargs):\n        \"\"\"\n        Load the model, either from a local path or from the HuggingFace Hub.\n        Supports the same arguments as ``AutoConfig`` and ``AutoModel``\n        from ``transformers``.\n        \"\"\"\n\n        config = AutoConfig.from_pretrained(*args, **kwargs)\n        assert hasattr(config, \"chronos_config\"), \"Not a Chronos config file\"\n\n        architecture = config.architectures[0]\n        class_ = globals().get(architecture)\n\n        if class_ is None:\n            logger.warning(\n                f\"Unknown architecture: {architecture}, defaulting to ChronosBoltModelForForecasting\"\n            )\n            class_ = ChronosBoltModelForForecasting\n\n        model = class_.from_pretrained(*args, **kwargs)\n        return cls(model=model)\n"
  },
  {
    "path": "probts/model/nn/arch/ChronosModule/loss.py",
    "content": "import torch\nimport torch.nn as nn\n\n\n# from huggingface transformers/trainer_pt_utils.py\nclass LabelSmoother:\n    \"\"\"\n    Adds label-smoothing on a pre-computed output from a Transformers model.\n\n    Args:\n        epsilon (`float`, *optional*, defaults to 0.1):\n            The label smoothing factor.\n        ignore_index (`int`, *optional*, defaults to -100):\n            The index in the labels to ignore when computing the loss.\n    \"\"\"\n\n    epsilon: float = 0.1\n    ignore_index: int = -100\n\n    def __call__(self, model_output, labels):\n        # logits = model_output[\"logits\"] if isinstance(model_output, dict) else model_output[0]\n        logits = model_output[\"logits\"] if isinstance(model_output, dict) else model_output\n        logits = logits.to(torch.float32)\n        log_probs = -nn.functional.log_softmax(logits, dim=-1)\n        if labels.dim() == log_probs.dim() - 1:\n            labels = labels.unsqueeze(-1)\n\n        padding_mask = labels.eq(self.ignore_index)\n        # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask\n        # will ignore them in any case.\n        labels = torch.clamp(labels, min=0)\n        nll_loss = log_probs.gather(dim=-1, index=labels)\n        # works for fp16 input tensor too, by internally upcasting it to fp32\n        smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)\n\n        nll_loss.masked_fill_(padding_mask, 0.0)\n        smoothed_loss.masked_fill_(padding_mask, 0.0)\n\n        # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):\n        num_active_elements = padding_mask.numel() - padding_mask.long().sum()\n        nll_loss = nll_loss.sum() / num_active_elements\n        smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])\n        return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss\n"
  },
  {
    "path": "probts/model/nn/arch/ChronosModule/utils.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n\nfrom typing import List\n\nimport torch\n\n\ndef left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:\n    max_len = max(len(c) for c in tensors)\n    padded = []\n    for c in tensors:\n        assert isinstance(c, torch.Tensor)\n        assert c.ndim == 1\n        padding = torch.full(\n            size=(max_len - len(c),), fill_value=torch.nan, device=c.device\n        )\n        padded.append(torch.concat((padding, c), dim=-1))\n    return torch.stack(padded)\n"
  },
  {
    "path": "probts/model/nn/arch/Conv_Blocks.py",
    "content": "import torch\nimport torch.nn as nn\n\n\nclass Inception_Block_V1(nn.Module):\n    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):\n        super(Inception_Block_V1, self).__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_kernels = num_kernels\n        kernels = []\n        for i in range(self.num_kernels):\n            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i))\n        self.kernels = nn.ModuleList(kernels)\n        if init_weight:\n            self._initialize_weights()\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        res_list = []\n        for i in range(self.num_kernels):\n            res_list.append(self.kernels[i](x))\n        res = torch.stack(res_list, dim=-1).mean(-1)\n        return res\n\n\nclass Inception_Block_V2(nn.Module):\n    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):\n        super(Inception_Block_V2, self).__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_kernels = num_kernels\n        kernels = []\n        for i in range(self.num_kernels // 2):\n            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[1, 2 * i + 3], padding=[0, i + 1]))\n            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[2 * i + 3, 1], padding=[i + 1, 0]))\n        kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))\n        self.kernels = nn.ModuleList(kernels)\n        if init_weight:\n            self._initialize_weights()\n\n    def _initialize_weights(self):\n        for m in 
self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        res_list = []\n        for i in range(self.num_kernels + 1):\n            res_list.append(self.kernels[i](x))\n        res = torch.stack(res_list, dim=-1).mean(-1)\n        return res\n"
  },
  {
    "path": "probts/model/nn/arch/ElasTSTModule/ElasTST_backbone.py",
    "content": "__all__ = ['PatchTST_backbone']\n\n# Cell\nfrom typing import Callable, Optional\nimport torch\nfrom torch import nn\nfrom torch import Tensor\nimport numpy as np\nfrom einops import rearrange, repeat\nfrom probts.utils.position_emb import Time_Encoder, sin_cos_encoding\nfrom probts.model.nn.arch.ElasTSTModule.Layers import EncoderLayer\n\n# Cell\nclass ElasTST_backbone(nn.Module):\n    def __init__(self, \n                 l_patch_size: list,\n                 stride: int = None, \n                 k_patch_size: int = 1, \n                 in_channels: int = 1,\n                 n_layers: int = 0, \n                 t_layers: int = 1, \n                 v_layers: int = 1,\n                 hidden_size: int = 256, \n                 n_heads: int = 16, \n                 d_k: Optional[int] = None, \n                 d_v: Optional[int] = None,\n                 d_inner: int = 256, \n                 dropout: float = 0.,\n                 rotate: bool = False, \n                 max_seq_len = 1000, \n                 theta = 10000, \n                 learnable_theta = False, \n                 addv: bool = False,\n                 bin_att: bool = False,\n                 abs_tem_emb: bool = False,\n                 learn_tem_emb: bool = False,\n                 structured_mask: bool = True,\n                 rope_theta_init: str = 'exp',\n                 min_period: float = 1, \n                 max_period: float = 1000,\n                 patch_share_backbone: bool = True,):\n\n        super().__init__()\n        \n\n        if rotate:\n            print(f'Using Rotary Embedding... 
[theta init]: {rope_theta_init}, [period range]: [{min_period},{max_period}], [learnable]: {learnable_theta}')\n        print(\"[Binary Att.]: \", bin_att, \" [Learned time emb]: \", learn_tem_emb, \" [Abs time emb]: \", abs_tem_emb)\n        print(\"[Multi Patch Share Backbone]: \", patch_share_backbone)\n        print(\"[Structured Mask]: \", not structured_mask)\n        # Patching\n        self.l_patch_size = l_patch_size\n        self.k_patch_size = k_patch_size\n        self.in_channels = in_channels\n        self.out_channels = in_channels\n        self.patch_share_backbone = patch_share_backbone\n        self.abs_tem_emb= abs_tem_emb\n\n        self.hidden_size = hidden_size\n        if stride is not None:\n            self.stride = stride\n        else:\n            self.stride = self.l_patch_size\n\n        x_embedder = []\n        final_layer = []\n        backbone = []\n        for p in self.l_patch_size:\n            print(f\"=== Patch {p} Branch ===\")\n            x_embedder.append(TimePatchEmbed(p, self.k_patch_size, self.in_channels, self.hidden_size, bias=True,stride=p))\n            final_layer.append(MLP_FinalLayer(self.hidden_size, p, self.k_patch_size, self.out_channels))\n            \n            if not patch_share_backbone:\n                backbone.append(DoublyAtt(d_model=self.hidden_size,n_layers=n_layers, t_layers=t_layers, v_layers=v_layers, d_inner=d_inner, n_heads=n_heads, d_k=d_k, d_v=d_v, dropout=dropout, \n                                    rotate=rotate, max_seq_len=max_seq_len, theta=theta, addv=addv, bin_att=bin_att,\n                                    learnable_theta=learnable_theta, structured_mask=structured_mask,rope_theta_init=rope_theta_init, min_period=min_period, max_period=max_period))\n            \n        self.x_embedder = nn.ModuleList(x_embedder)\n        self.final_layer = nn.ModuleList(final_layer)\n        \n        if not patch_share_backbone:\n            self.backbone = nn.ModuleList(backbone)\n        
else:\n            self.backbone = DoublyAtt(d_model=self.hidden_size,n_layers=n_layers, t_layers=t_layers, v_layers=v_layers, d_inner=d_inner, n_heads=n_heads, d_k=d_k, d_v=d_v, dropout=dropout, \n                                    rotate=rotate, max_seq_len=max_seq_len, theta=theta, addv=addv, bin_att=bin_att,\n                                    learnable_theta=learnable_theta, structured_mask=structured_mask,rope_theta_init=rope_theta_init, min_period=min_period, max_period=max_period)\n       \n        self.learn_tem_emb = learn_tem_emb\n        if self.learn_tem_emb:\n            self.learn_time_embedding = Time_Encoder(self.hidden_size)\n\n    def get_patch_num(self, dim_size, len_size, l_patch_size):\n        num_k_patches = int((dim_size - self.k_patch_size)/self.k_patch_size + 1)\n        num_l_patches = int((len_size - l_patch_size)/l_patch_size + 1)\n        return num_k_patches, num_l_patches\n\n\n    def forward(self, past_target, future_placeholder, past_observed_values, future_observed_values, dataset_name=None):                                                                   # z: [bs x nvars x seq_len]\n\n        pred_shape = future_placeholder.shape\n        future_observed_indicator = torch.zeros(future_observed_values.shape).to(future_observed_values.device)\n        \n        x = torch.cat((past_target, future_placeholder), dim=1) # B L+T K\n        \n        past_value_indicator = torch.cat((past_observed_values, future_observed_indicator), dim=1) # B L+T K\n        observed_value_indicator = torch.cat((past_observed_values, future_observed_values), dim=1) # B L+T K\n        \n        pred_list = []\n\n        for idx in range(len(self.l_patch_size)):\n\n            x_p = x.clone()\n            \n            num_k_patches, num_l_patches = self.get_patch_num(x_p.shape[-1], x_p.shape[-2],self.l_patch_size[idx])\n\n            # do patching\n            x_p, past_value_indicator_p, observed_value_indicator_p = self.x_embedder[idx](x_p, 
past_value_indicator, observed_value_indicator)  # b k l d\n            \n            if self.learn_tem_emb:\n                grid_len = np.arange(num_l_patches, dtype=np.float32)\n                grid_len = torch.tensor(grid_len, requires_grad=False).float().unsqueeze(0).to(x.device)\n                pos_embed = repeat(grid_len, '1 l -> b l', b=pred_shape[0])\n                pos_embed = self.learn_time_embedding(pos_embed) # b l 1 d\n                pos_embed = rearrange(pos_embed, 'b l 1 d -> b 1 l d')\n                x_p = x_p + pos_embed\n            \n            # use a absolute position embedding\n            if self.abs_tem_emb:\n                B, K, L, embed_dim = x_p.shape\n                pos_embed = sin_cos_encoding(B, K, L, embed_dim).float() # b k l d\n                x_p = x_p + pos_embed.to(x_p.device)\n\n            # model\n            if self.patch_share_backbone:\n                x_p = self.backbone(x_p, past_value_indicator_p, observed_value_indicator_p)        # b k l d\n            else:\n                x_p = self.backbone[idx](x_p, past_value_indicator_p, observed_value_indicator_p)        # b k l d\n\n            \n            x_p = self.final_layer[idx](x_p) # b k l p\n\n            x_p = rearrange(x_p, 'b k t p -> b (t p) k')\n\n            x_p = x_p[:,-pred_shape[1]:,:]\n            \n            pred_list.append(x_p.unsqueeze(-1))\n        \n        pred_list = torch.cat(pred_list, dim=-1)\n        multi_patch_mean_res = torch.mean(pred_list, dim=-1)\n\n        return multi_patch_mean_res, pred_list\n\n    \nclass DoublyAtt(nn.Module):  \n    def __init__(self, d_model,n_layers, d_inner, n_heads, d_k, d_v, dropout, \n                 rotate=False, max_seq_len=1024, theta=10000, t_layers=2, v_layers=1,\n                 bin_att=False, addv=False, learnable_theta=False, structured_mask=True,\n                 rope_theta_init='exp',min_period=0.1, max_period=10):\n        super().__init__()\n        # assert n_layers <= (t_layers + 
v_layers) <= 2*n_layers , \"Sum of t_layers and n_layers must be between 1 and 2\"  \n        \n        # Configuration based on temporal and variate ratios\n        self.layer_stack = nn.ModuleList()  \n        num_t = t_layers\n        num_v = v_layers\n        num_both = min(t_layers, v_layers)\n\n        num_t = num_t - num_both\n        num_v = num_v - num_both\n        \n        t_count = 0\n        v_count= 0\n        for _ in range(num_t + num_v):\n            if t_count < num_t  :\n                self.layer_stack.append(EncoderLayer(d_model, d_inner, n_heads, d_k, d_v, dropout=dropout, tem_att=True, type_att=False, \n                         structured_mask=structured_mask, rotate=rotate, max_seq_len=max_seq_len,theta=theta, addv=addv, \n                         learnable_theta=learnable_theta, bin_att=bin_att,rope_theta_init=rope_theta_init, min_period=min_period, max_period=max_period))  \n                t_count = t_count + 1\n                print(f\"[Encoder Layer {t_count+v_count}] Use tem att\")\n            if v_count < num_v:\n                self.layer_stack.append(EncoderLayer(d_model, d_inner, n_heads, d_k, d_v, dropout=dropout, tem_att=False, type_att=True, \n                         structured_mask=structured_mask, rotate=rotate, max_seq_len=max_seq_len,theta=theta, addv=addv, \n                         learnable_theta=learnable_theta, bin_att=bin_att,rope_theta_init=rope_theta_init, min_period=min_period, max_period=max_period))  \n                v_count = v_count + 1\n                print(f\"[Encoder Layer {t_count+v_count}] Use var att\")\n                \n        for idx in range(num_both):  \n            self.layer_stack.append(EncoderLayer(d_model, d_inner, n_heads, d_k, d_v, dropout=dropout, tem_att=True, type_att=True, \n                         structured_mask=structured_mask, rotate=rotate, max_seq_len=max_seq_len,theta=theta, addv=addv, \n                         learnable_theta=learnable_theta, 
bin_att=bin_att,rope_theta_init=rope_theta_init, min_period=min_period, max_period=max_period))  \n\n            print(f\"[Encoder Layer {idx+t_count+v_count}] Use tem and var att\")\n\n    def forward(self, x, past_value_indicator, observed_indicator) -> Tensor:                \n\n        for enc_layer in self.layer_stack:\n            x = enc_layer(x, past_value_indicator=past_value_indicator, observed_indicator=observed_indicator)\n\n        return x  \n\n\nclass MLP_FinalLayer(nn.Module):\n    \"\"\"\n    The final layer of DiT.\n    \"\"\"\n    def __init__(self, hidden_size, l_patch_size, k_patch_size, out_channels):\n        super().__init__()\n        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.linear = nn.Linear(hidden_size, l_patch_size * k_patch_size * out_channels, bias=True)\n\n\n    def forward(self, x):\n        x = self.norm_final(x)\n        x = self.linear(x)\n        return x\n\nclass TimePatchEmbed(nn.Module):\n    \"\"\" Time Patch Embedding\n    \"\"\"\n    def __init__(\n            self,\n            l_patch_size: int = 16,\n            k_patch_size = 1,\n            in_chans: int = 1,\n            embed_dim: int = 768,\n            norm_layer: Optional[Callable] = None,\n            flatten: bool = False,\n            bias: bool = True,\n            # padding_patch = None,\n            stride = None,\n            # strict_img_size: bool = True,\n    ):\n        super().__init__()\n        self.l_patch_size = l_patch_size\n        self.k_patch_size = k_patch_size\n        if stride is None:\n            stride = l_patch_size\n\n        self.flatten = flatten\n\n        padding = 0\n        kernel_size = (l_patch_size,k_patch_size)\n        stride_size = (stride,k_patch_size)\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride_size, bias=bias, padding=padding)\n        self.mask_proj = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride_size, 
bias=False, padding=padding)\n\n        self.mask_proj.weight.data.fill_(1.0)\n\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x, future_mask, obv_mask):\n        '''\n        future_mask: only past values are set to 1\n        obv_mask: past values and values to be predicted are set to 1\n        '''\n        \n        # B, C, K, L = x.shape\n        if len(x.shape) == 3:\n            x = rearrange(x, 'b l k -> b 1 l k')\n            \n        future_mask = rearrange(future_mask, 'b l k -> b 1 l k')\n        obv_mask = rearrange(obv_mask, 'b l k -> b 1 l k')\n            \n        x = self.proj(x)  # B C L K -> B C L' K\n\n        with torch.no_grad():\n            future_mask = self.mask_proj(future_mask)\n            obv_mask = self.mask_proj(obv_mask)\n\n        if self.flatten:\n            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC\n            future_mask = future_mask.flatten(2).transpose(1, 2)  # NCHW -> NLC\n            obv_mask = obv_mask.flatten(2).transpose(1, 2)  # NCHW -> NLC\n\n        x = self.norm(x)\n\n        x = rearrange(x, 'b d l k -> b k l d')\n        future_mask = rearrange(future_mask, 'b 1 l k -> b k l')\n        obv_mask = rearrange(obv_mask, 'b 1 l k -> b k l')\n        return x, future_mask, obv_mask\n\n"
  },
  {
    "path": "probts/model/nn/arch/ElasTSTModule/Layers.py",
    "content": "import torch.nn as nn\nimport sys\nimport torch\nfrom probts.model.nn.arch.ElasTSTModule.SubLayers import  PositionwiseFeedForward, MultiHeadAttention_tem_bias, MultiHeadAttention_type_bias\nfrom einops import rearrange, repeat\n\n\nPAD = 0\n\ndef get_attn_key_pad_mask_K(past_value_indicator, observed_indicator , transpose=False, structured_mask=False):\n    \"\"\" For masking out the padding part of key sequence. \n    input: mask: transpose=False: [b k l]\n    \"\"\"\n\n    if structured_mask:\n        mask = past_value_indicator\n    else:\n        mask = observed_indicator\n\n\n    if transpose:\n        mask = rearrange(mask, 'b l k -> b k l')\n    padding_mask = repeat(mask, 'b k l1 -> b k l2 l1', l2=mask.shape[-1]).eq(PAD)\n\n    return padding_mask\n\nclass EncoderLayer(nn.Module):\n    \"\"\" Compose with two layers \"\"\"\n\n    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1, \n                 tem_att=True, type_att=False, structured_mask=True, \n                 rotate=False, max_seq_len=100, theta=10000,\n                 addv=False, learnable_theta=False, bin_att=False,\n                 rope_theta_init='exp',min_period=0.1, max_period=10):\n        super(EncoderLayer, self).__init__()\n\n        self.structured_mask = structured_mask\n        self.tem_att = tem_att\n        self.type_att = type_att\n\n        if tem_att:\n            self.slf_tem_attn = MultiHeadAttention_tem_bias(\n                n_head, d_model, d_k, d_v, dropout=dropout, rotate=rotate, max_seq_len=max_seq_len, theta=theta, addv=addv, \n                learnable_theta=learnable_theta, bin_att=bin_att,rope_theta_init=rope_theta_init, min_period=min_period, max_period=max_period)\n\n        if type_att:\n            self.slf_type_attn = MultiHeadAttention_type_bias(\n                n_head, d_model, d_k, d_v, dropout=dropout, rotate=False, max_seq_len=max_seq_len, bin_att=bin_att)\n\n\n        self.pos_ffn = PositionwiseFeedForward(d_model, 
d_inner, dropout=dropout)\n        \n        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)\n\n    def forward(self, input, past_value_indicator=None, observed_indicator=None):\n        # time attention\n        # [B, K, L, D]\n        if self.tem_att:\n            tem_mask = get_attn_key_pad_mask_K(past_value_indicator=past_value_indicator, observed_indicator=observed_indicator, transpose=False, structured_mask=self.structured_mask)\n            tem_output = self.layer_norm(input)\n            \n            tem_output, enc_tem_attn = self.slf_tem_attn(\n                tem_output, tem_output, tem_output, mask=tem_mask) \n            \n            tem_output = tem_output + input\n        else:\n            tem_output = input\n        \n        tem_output = rearrange(tem_output, 'b k l d -> b l k d')\n\n        \n        # type attention\n        # [B, L, K, D]\n        if self.type_att:\n            type_mask = get_attn_key_pad_mask_K(past_value_indicator=past_value_indicator, observed_indicator=observed_indicator, transpose=True, structured_mask=self.structured_mask)\n            \n            type_output = self.layer_norm(tem_output)\n            \n            type_output, enc_type_attn = self.slf_type_attn(\n                type_output, type_output, type_output, mask=type_mask) \n            \n            enc_output = type_output + tem_output\n        else:\n            enc_output = tem_output\n            \n        # FFNN\n        output = self.layer_norm(enc_output)\n        \n        output = self.pos_ffn(output)\n\n        output = output + enc_output\n        \n        output = rearrange(output, 'b l k d -> b k l d')\n        \n        # optional\n        output = self.layer_norm(output)\n\n        return output #, enc_tem_attn, enc_type_attn\n\n\n"
  },
  {
    "path": "probts/model/nn/arch/ElasTSTModule/Modules.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom probts.model.nn.arch.ElasTSTModule.TRoPE import RotaryEmbedding\n\nclass ScaledDotProductAttention(nn.Module):\n    \"\"\" Scaled Dot-Product Attention \"\"\"\n\n    def __init__(self, temperature, attn_dropout=0.2):\n        super().__init__()\n\n        self.temperature = temperature\n        self.dropout = nn.Dropout(attn_dropout)\n\n    def forward(self, q, k, v, mask=None):\n        attn = torch.bmm(q / self.temperature, k.transpose(-2, -1))\n        \n        if mask is not None and mask.dim() == 5:\n            mask = mask.transpose(2, 4)\n\n        if mask is not None:\n            attn = attn.masked_fill(mask, -1e9)\n\n        attn = self.dropout(F.softmax(attn, dim=-1))\n        output = torch.bmm(attn, v)\n\n        return output, attn\n\n\nclass ScaledDotProductAttention_bias(nn.Module):\n\n    def __init__(self, d_model, n_head, d_k, d_v, temperature, \n                 attn_dropout=0.2, rotate=False, max_seq_len=100, \n                 theta=10000, addv=False, learnable_theta=False, \n                 bin_att=False,rope_theta_init='exp',\n                 min_period=0.1, max_period=10):\n        super().__init__()\n        \n        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)\n        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)\n        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)\n\n        self.temperature = temperature\n        self.dropout = nn.Dropout(attn_dropout)\n        self.n_head = n_head\n        self.bin_att = bin_att\n        self.rotate = rotate\n        self.addv = addv\n        self.trope = RotaryEmbedding(d_v, max_seq_len,base=theta, learnable=learnable_theta,init=rope_theta_init,min_period=min_period, max_period=max_period)\n\n        if self.bin_att:\n            self.alpha = nn.Parameter(torch.zeros([1,1,n_head,1,1]))\n            self.beta = 
nn.Parameter(torch.zeros([1,1,n_head,1,1]))\n\n    def forward(self, q, k, v, mask):\n        # input: [B,K,H,LQ,LK] for temporal, [B,L,H,Kq,Kk] for category\n        \n        # [B,K,L,H,D]\n        q = rearrange(self.w_qs(q), 'b k l (n d) -> b k n l d', n=self.n_head)\n        k = rearrange(self.w_ks(k), 'b k l (n d) -> b k n d l', n=self.n_head)\n        v = rearrange(self.w_vs(v), 'b k l (n d) -> b k n l d', n=self.n_head)\n        \n        B, K, N, L, D = q.shape\n        if self.rotate:\n            xq = rearrange(q, 'b k n l d -> (b k n) l d')\n            xk = rearrange(k, 'b k n d l -> (b k n) l d')\n            xv = rearrange(v, 'b k n l d -> (b k n) l d')\n\n            xq, xk, xv = self.trope(xq, xk, xv)\n\n            attn = torch.matmul(xq, xk.transpose(1, 2)) / self.temperature\n            attn = rearrange(attn, '(b k n) l t -> b k n l t', b=B, k=K,n=N)\n            if self.addv:\n                v = rearrange(xv, '(b k n) l d -> b k n l d', b=B, k=K,n=N)\n        else:\n            attn = torch.matmul(q , k) / self.temperature\n\n        if self.bin_att:\n            self_mask = torch.eye(L).to(mask.device)\n            self_mask = repeat(self_mask, 'l t -> b k n l t', b=B, k=K,n=N)\n\n            attn = attn + self_mask * self.alpha + (1-self_mask) * self.beta\n\n        if mask is not None:\n            if attn.dim() > mask.dim():\n                mask = mask.unsqueeze(2).expand(attn.shape)\n            attn = attn.masked_fill(mask, -1e9)\n            \n\n        attn = self.dropout(F.softmax(attn, dim=-1))\n\n        v = torch.matmul(attn, v)\n\n        v = rearrange(v, 'b k n l d -> b k l (n d)')\n\n        # sys.exit(0)\n        return v, attn\n    \nclass Attention(nn.Module):\n\n    def __init__(self, hin_d, d_model):\n        super().__init__()\n\n        self.linear = nn.Linear(d_model, hin_d)\n        self.W = nn.Linear(hin_d,1, bias=False)\n        \n    def forward(self, x, mask=None, mask_value=-1e30):\n        # [B,K,L,D]\n        \n 
       # map directly\n        attn = self.W(torch.tanh(self.linear(x))) # [B,K,L,1]\n        \n        if mask is not None:\n            attn = mask * attn + (1-mask)*mask_value\n            \n        attn = F.softmax(attn, dim=-2)\n        \n        x = torch.matmul(x.transpose(-1, -2), attn).squeeze(-1) # [B,K,D,1]\n\n        return x, attn"
  },
  {
    "path": "probts/model/nn/arch/ElasTSTModule/SubLayers.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport sys\n\nfrom probts.model.nn.arch.ElasTSTModule.Modules import ScaledDotProductAttention_bias\n\nclass MultiHeadAttention_tem_bias(nn.Module):\n    \"\"\" Multi-Head Attention module \"\"\"\n\n    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, rotate=False, max_seq_len=100, theta=10000, addv=False, \n                 learnable_theta=False, bin_att=False,rope_theta_init='exp',min_period=0.1, max_period=10):\n        super().__init__()\n        self.n_head = n_head\n        self.d_k = d_k\n        self.d_v = d_v\n\n        self.fc = nn.Linear(d_v * n_head, d_model)\n\n        self.attention = ScaledDotProductAttention_bias(d_model, n_head, d_k, d_v, temperature=d_k ** 0.5, \n                                                        attn_dropout=dropout, rotate=rotate, max_seq_len=max_seq_len, \n                                                        theta=theta, addv=addv, learnable_theta=learnable_theta, bin_att=bin_att, \n                                                        rope_theta_init=rope_theta_init,min_period=min_period, max_period=max_period)\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, q, k, v, mask=None):\n        # event_matrix [B,L,K]\n\n        # [B,K,H,Lq,Lk]\n        output, attn = self.attention(q, k, v, mask=mask) # [B,K,H,L,D]\n\n        output = self.dropout(self.fc(output))\n\n        return output, attn\n\n\nclass MultiHeadAttention_type_bias(nn.Module):\n    \"\"\" Multi-Head Attention module \"\"\"\n\n    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, rotate=False, max_seq_len=1024, bin_att=False):\n        super().__init__()\n        self.n_head = n_head\n        self.d_k = d_k\n        self.d_v = d_v\n\n        self.fc = nn.Linear(d_v * n_head, d_model)\n        self.attention = ScaledDotProductAttention_bias(d_model, n_head, d_k, d_v, temperature=d_k ** 0.5, attn_dropout=dropout, rotate=False, 
max_seq_len=max_seq_len, bin_att=bin_att)\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, q, k, v, mask=None):\n        # [B,L,K,D]\n        output, attn = self.attention(q, k, v, mask=mask) \n\n        output = self.dropout(self.fc(output))\n\n        return output, attn\n\n\nclass PositionwiseFeedForward(nn.Module):\n    \"\"\" Two-layer position-wise feed-forward neural network. \"\"\"\n\n    def __init__(self, d_in, d_hid, dropout=0.1):\n        super().__init__()\n        self.w_1 = nn.Linear(d_in, d_hid)\n        self.w_2 = nn.Linear(d_hid, d_in)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        x = F.gelu(self.w_1(x))\n        x = self.dropout(x)\n        x = self.w_2(x)\n        x = self.dropout(x)\n\n        return x\n\n\n"
  },
  {
    "path": "probts/model/nn/arch/ElasTSTModule/TRoPE.py",
    "content": "import torch\nfrom typing import Tuple\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport sys\n\nclass RotaryEmbedding(nn.Module):  \n    def __init__(self, dim: int, seq_len: int, base: float = 10000.0, learnable=False, init=\"exp\",min_period=0.01, max_period=1000):  \n        super(RotaryEmbedding, self).__init__()  \n        if init == 'linear':\n            theta = get_linear_period(min_period, max_period, dim)\n        elif init == 'uniform':\n            theta = torch.ones([dim//2])\n            periods = torch.nn.init.uniform_(theta, a=min_period, b=max_period)\n            theta = 2 * np.pi / periods\n        elif init == 'exp':\n            theta = get_exp_period(min_period, max_period, dim)\n        elif init == 'rope':\n            theta = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))\n        else:\n            print(\"invalid theta init\")\n            sys.exit(0)\n\n        if learnable:  \n            self.freqs = nn.Parameter(theta)\n        else:  \n            self.register_buffer('freqs', torch.tensor(theta))\n        \n        self.dim = dim  \n        self.seq_len = seq_len  \n        self.learnable = learnable  \n\n    def forward(self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor):\n        L = xq.shape[-2]\n        t = torch.arange(L, device=xq.device)\n            \n        freqs = torch.outer(t, self.freqs).float()  # m * \\theta\n        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)\n        \n        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)\n        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)\n        xv_ = xv.float().reshape(*xv.shape[:-1], -1, 2)\n    \n        xq_ = torch.view_as_complex(xq_).to(xq.device)\n        xk_ = torch.view_as_complex(xk_).to(xq.device)\n        xv_ = torch.view_as_complex(xv_).to(xq.device)\n        \n        # rotate and then map to real number field\n        xq_out = torch.view_as_real(xq_ * 
freqs_cis).flatten(2).to(xq.device)\n        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2).to(xq.device)\n        xv_out = torch.view_as_real(xv_ * freqs_cis).flatten(2).to(xq.device)\n        return xq_out.type_as(xq), xk_out.type_as(xk), xv_out.type_as(xv)\n\n\ndef get_linear_period(min_period, max_period, dim):\n    i = torch.arange(0, dim, 2)[: (dim // 2)]\n\n    periods = min_period + ((max_period - min_period) / dim )  * i\n    theta = 2 * np.pi / periods  \n    return theta\n\ndef get_exp_period(min_period, max_period, dim):\n    i = torch.arange(0, dim, 2)[: (dim // 2)]\n    max_theta = 2 * np.pi / min_period\n    min_theta = 2 * np.pi / max_period\n    alpha = np.log(max_theta/min_theta) * (1/(dim-2))\n    thetas = max_theta * np.exp(-alpha * i)\n    return thetas\n\n# generate rotation matrix\ndef precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):\n    \n    # rotate \\theta_i\n    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))\n    # generate token indexes t = [0, 1,..., seq_len-1]\n    t = torch.arange(seq_len, device=freqs.device)\n    # freqs.shape = [seq_len, dim // 2] \n    freqs = torch.outer(t, freqs).float()  # m * \\theta\n\n    freqs_cis = torch.polar(torch.ones_like(freqs), freqs) \n    return freqs_cis\n\ndef apply_rotary_emb(\n    xq: torch.Tensor,\n    xk: torch.Tensor,\n    xv: torch.Tensor,\n    freqs_cis: torch.Tensor,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)\n    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)\n    xv_ = xv.float().reshape(*xv.shape[:-1], -1, 2)\n    \n    freqs_cis = freqs_cis.to(xq.device)\n\n    xq_ = torch.view_as_complex(xq_).to(xq.device)\n    xk_ = torch.view_as_complex(xk_).to(xq.device)\n    xv_ = torch.view_as_complex(xv_).to(xq.device)\n    \n    # rotate and then map to real number field\n    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2).to(xq.device)\n    xk_out = 
torch.view_as_real(xk_ * freqs_cis).flatten(2).to(xq.device)\n    xv_out = torch.view_as_real(xv_ * freqs_cis).flatten(2).to(xq.device)\n    return xq_out.type_as(xq), xk_out.type_as(xk), xv_out.type_as(xv)\n\n"
  },
  {
    "path": "probts/model/nn/arch/ElasTSTModule/__init__.py",
    "content": ""
  },
  {
    "path": "probts/model/nn/arch/ModernTCN_backbone.py",
    "content": "import torch\r\nfrom torch import nn\r\nimport torch.nn.functional as F\r\nfrom probts.model.nn.arch.RevIN import RevIN\r\nfrom probts.model.nn.arch.decomp import series_decomp\r\n\r\n# forecast task head\r\nclass Flatten_Head(nn.Module):\r\n    def __init__(self, individual, n_vars, nf, target_window, head_dropout=0):\r\n        super(Flatten_Head, self).__init__()\r\n\r\n        self.individual = individual\r\n        self.n_vars = n_vars\r\n\r\n        if self.individual:\r\n            self.linears = nn.ModuleList()\r\n            self.dropouts = nn.ModuleList()\r\n            self.flattens = nn.ModuleList()\r\n            for i in range(self.n_vars):\r\n                self.flattens.append(nn.Flatten(start_dim=-2))\r\n                self.linears.append(nn.Linear(nf, target_window))\r\n                self.dropouts.append(nn.Dropout(head_dropout))\r\n        else:\r\n            self.flatten = nn.Flatten(start_dim=-2)\r\n            self.linear = nn.Linear(nf, target_window)\r\n            self.dropout = nn.Dropout(head_dropout)\r\n\r\n    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]\r\n        if self.individual:\r\n            x_out = []\r\n            for i in range(self.n_vars):\r\n                z = self.flattens[i](x[:, i, :, :])  # z: [bs x d_model * patch_num]\r\n                z = self.linears[i](z)  # z: [bs x target_window]\r\n                z = self.dropouts[i](z)\r\n                x_out.append(z)\r\n            x = torch.stack(x_out, dim=1)  # x: [bs x nvars x target_window]\r\n        else:\r\n            x = self.flatten(x)\r\n            x = self.linear(x)\r\n            x = self.dropout(x)\r\n        return x\r\n\r\nclass LayerNorm(nn.Module):\r\n    def __init__(self, channels, eps=1e-6, data_format=\"channels_last\"):\r\n        super(LayerNorm, self).__init__()\r\n        self.norm = nn.Layernorm(channels)\r\n\r\n    def forward(self, x):\r\n        B, M, D, N = x.shape\r\n        x = x.permute(0, 1, 3, 
2)\r\n        x = x.reshape(B * M, N, D)\r\n        x = self.norm(x)\r\n        x = x.reshape(B, M, N, D)\r\n        x = x.permute(0, 1, 3, 2)\r\n        return x\r\n\r\ndef get_conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias):\r\n    return nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,\r\n                     padding=padding, dilation=dilation, groups=groups, bias=bias)\r\n\r\n\r\ndef get_bn(channels):\r\n    return nn.BatchNorm1d(channels)\r\n\r\ndef conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1,bias=False):\r\n    if padding is None:\r\n        padding = kernel_size // 2\r\n    result = nn.Sequential()\r\n    result.add_module('conv', get_conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,\r\n                                         stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias))\r\n    result.add_module('bn', get_bn(out_channels))\r\n    return result\r\n\r\ndef fuse_bn(conv, bn):\r\n\r\n    kernel = conv.weight\r\n    running_mean = bn.running_mean\r\n    running_var = bn.running_var\r\n    gamma = bn.weight\r\n    beta = bn.bias\r\n    eps = bn.eps\r\n    std = (running_var + eps).sqrt()\r\n    t = (gamma / std).reshape(-1, 1, 1)\r\n    return kernel * t, beta - running_mean * gamma / std\r\n\r\nclass ReparamLargeKernelConv(nn.Module):\r\n\r\n    def __init__(self, in_channels, out_channels, kernel_size,\r\n                 stride, groups,\r\n                 small_kernel,\r\n                 small_kernel_merged=False, nvars=7):\r\n        super(ReparamLargeKernelConv, self).__init__()\r\n        self.kernel_size = kernel_size\r\n        self.small_kernel = small_kernel\r\n        # We assume the conv does not change the feature map size, so padding = k//2. 
Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly.\r\n        padding = kernel_size // 2\r\n        if small_kernel_merged:\r\n            self.lkb_reparam = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,\r\n                                         stride=stride, padding=padding, dilation=1, groups=groups, bias=True)\r\n        else:\r\n            self.lkb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,\r\n                                        stride=stride, padding=padding, dilation=1, groups=groups,bias=False)\r\n            if small_kernel is not None:\r\n                assert small_kernel <= kernel_size, 'The kernel size for re-param cannot be larger than the large kernel!'\r\n                self.small_conv = conv_bn(in_channels=in_channels, out_channels=out_channels,\r\n                                            kernel_size=small_kernel,\r\n                                            stride=stride, padding=small_kernel // 2, groups=groups, dilation=1,bias=False)\r\n\r\n\r\n    def forward(self, inputs):\r\n\r\n        if hasattr(self, 'lkb_reparam'):\r\n            out = self.lkb_reparam(inputs)\r\n        else:\r\n            out = self.lkb_origin(inputs)\r\n            if hasattr(self, 'small_conv'):\r\n                out += self.small_conv(inputs)\r\n\r\n        return out\r\n\r\n    def PaddingTwoEdge1d(self,x,pad_length_left,pad_length_right,pad_values=0):\r\n\r\n        D_out,D_in,ks=x.shape\r\n        if pad_values ==0:\r\n            pad_left = torch.zeros(D_out,D_in,pad_length_left)\r\n            pad_right = torch.zeros(D_out,D_in,pad_length_right)\r\n        else:\r\n            pad_left = torch.ones(D_out, D_in, pad_length_left) * pad_values\r\n            pad_right = torch.ones(D_out, D_in, pad_length_right) * pad_values\r\n        x = torch.cat([pad_left,x],dims=-1)\r\n        x = 
torch.cat([x,pad_right],dims=-1)\r\n        return x\r\n\r\n    def get_equivalent_kernel_bias(self):\r\n\r\n        eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)\r\n\r\n        if hasattr(self, 'small_conv'):\r\n            small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn)\r\n\r\n            eq_b += small_b\r\n\r\n            eq_k += self.PaddingTwoEdge1d(small_k, (self.kernel_size - self.small_kernel) // 2,\r\n                                          (self.kernel_size - self.small_kernel) // 2, 0)\r\n        return eq_k, eq_b\r\n\r\n    def merge_kernel(self):\r\n        eq_k, eq_b = self.get_equivalent_kernel_bias()\r\n        self.lkb_reparam = nn.Conv1d(in_channels=self.lkb_origin.conv.in_channels,\r\n                                     out_channels=self.lkb_origin.conv.out_channels,\r\n                                     kernel_size=self.lkb_origin.conv.kernel_size, stride=self.lkb_origin.conv.stride,\r\n                                     padding=self.lkb_origin.conv.padding, dilation=self.lkb_origin.conv.dilation,\r\n                                     groups=self.lkb_origin.conv.groups, bias=True)\r\n        self.lkb_reparam.weight.data = eq_k\r\n        self.lkb_reparam.bias.data = eq_b\r\n        self.__delattr__('lkb_origin')\r\n        if hasattr(self, 'small_conv'):\r\n            self.__delattr__('small_conv')\r\n\r\nclass Block(nn.Module):\r\n    def __init__(self, large_size, small_size, dmodel, dff, nvars, small_kernel_merged=False, drop=0.1):\r\n\r\n        super(Block, self).__init__()\r\n        self.dw = ReparamLargeKernelConv(in_channels=nvars * dmodel, out_channels=nvars * dmodel,\r\n                                         kernel_size=large_size, stride=1, groups=nvars * dmodel,\r\n                                         small_kernel=small_size, small_kernel_merged=small_kernel_merged, nvars=nvars)\r\n        self.norm = nn.BatchNorm1d(dmodel)\r\n\r\n        #convffn1\r\n        self.ffn1pw1 = 
nn.Conv1d(in_channels=nvars * dmodel, out_channels=nvars * dff, kernel_size=1, stride=1,\r\n                                 padding=0, dilation=1, groups=nvars)\r\n        self.ffn1act = nn.GELU()\r\n        self.ffn1pw2 = nn.Conv1d(in_channels=nvars * dff, out_channels=nvars * dmodel, kernel_size=1, stride=1,\r\n                                 padding=0, dilation=1, groups=nvars)\r\n        self.ffn1drop1 = nn.Dropout(drop)\r\n        self.ffn1drop2 = nn.Dropout(drop)\r\n\r\n        #convffn2\r\n        self.ffn2pw1 = nn.Conv1d(in_channels=nvars * dmodel, out_channels=nvars * dff, kernel_size=1, stride=1,\r\n                                 padding=0, dilation=1, groups=dmodel)\r\n        self.ffn2act = nn.GELU()\r\n        self.ffn2pw2 = nn.Conv1d(in_channels=nvars * dff, out_channels=nvars * dmodel, kernel_size=1, stride=1,\r\n                                 padding=0, dilation=1, groups=dmodel)\r\n        self.ffn2drop1 = nn.Dropout(drop)\r\n        self.ffn2drop2 = nn.Dropout(drop)\r\n\r\n        self.ffn_ratio = dff//dmodel\r\n    def forward(self,x):\r\n\r\n        input = x\r\n        B, M, D, N = x.shape\r\n        x = x.reshape(B,M*D,N)\r\n        x = self.dw(x)\r\n        x = x.reshape(B,M,D,N)\r\n        x = x.reshape(B*M,D,N)\r\n        x = self.norm(x)\r\n        x = x.reshape(B, M, D, N)\r\n        x = x.reshape(B, M * D, N)\r\n\r\n        x = self.ffn1drop1(self.ffn1pw1(x))\r\n        x = self.ffn1act(x)\r\n        x = self.ffn1drop2(self.ffn1pw2(x))\r\n        x = x.reshape(B, M, D, N)\r\n\r\n        x = x.permute(0, 2, 1, 3)\r\n        x = x.reshape(B, D * M, N)\r\n        x = self.ffn2drop1(self.ffn2pw1(x))\r\n        x = self.ffn2act(x)\r\n        x = self.ffn2drop2(self.ffn2pw2(x))\r\n        x = x.reshape(B, D, M, N)\r\n        x = x.permute(0, 2, 1, 3)\r\n\r\n        x = input + x\r\n        return x\r\n\r\n\r\nclass Stage(nn.Module):\r\n    def __init__(self, ffn_ratio, num_blocks, large_size, small_size, dmodel, dw_model, nvars,\r\n      
           small_kernel_merged=False, drop=0.1):\r\n\r\n        super(Stage, self).__init__()\r\n        d_ffn = dmodel * ffn_ratio\r\n        blks = []\r\n        for i in range(num_blocks):\r\n            blk = Block(large_size=large_size, small_size=small_size, dmodel=dmodel, dff=d_ffn, nvars=nvars, small_kernel_merged=small_kernel_merged, drop=drop)\r\n            blks.append(blk)\r\n\r\n        self.blocks = nn.ModuleList(blks)\r\n\r\n    def forward(self, x):\r\n\r\n        for blk in self.blocks:\r\n            x = blk(x)\r\n\r\n        return x\r\n\r\n\r\nclass ModernTCNModel(nn.Module):\r\n    def __init__(self,patch_size,patch_stride, stem_ratio, downsample_ratio, ffn_ratio, num_blocks, large_size, small_size, dims, dw_dims,\r\n                 nvars, small_kernel_merged=False, backbone_dropout=0.1, head_dropout=0.1, use_multi_scale=True, revin=True, affine=True,\r\n                 subtract_last=False, freq=None, seq_len=512, c_in=7, individual=False, target_window=96):\r\n\r\n        super(ModernTCNModel, self).__init__()\r\n\r\n\r\n\r\n        # RevIN\r\n        self.revin = revin\r\n        if self.revin:\r\n            self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)\r\n\r\n        # stem layer & down sampling layers(if needed)\r\n        self.downsample_layers = nn.ModuleList()\r\n        stem = nn.Sequential(\r\n\r\n            nn.Conv1d(1, dims[0], kernel_size=patch_size, stride=patch_stride),\r\n            nn.BatchNorm1d(dims[0])\r\n        )\r\n        self.downsample_layers.append(stem)\r\n        for i in range(3):\r\n            downsample_layer = nn.Sequential(\r\n                nn.BatchNorm1d(dims[i]),\r\n                nn.Conv1d(dims[i], dims[i + 1], kernel_size=downsample_ratio, stride=downsample_ratio),\r\n            )\r\n            self.downsample_layers.append(downsample_layer)\r\n        self.patch_size = patch_size\r\n        self.patch_stride = patch_stride\r\n        self.downsample_ratio = 
downsample_ratio\r\n\r\n        # if freq == 'h':\r\n        #     time_feature_num = 4\r\n        # elif freq == 't':\r\n        #     time_feature_num = 5\r\n        # else:\r\n        #     raise NotImplementedError(\"time_feature_num should be 4 or 5\")\r\n        if freq.lower() == 'h':\r\n            time_feature_num = 4\r\n        else:\r\n            time_feature_num = 5\r\n        \r\n        self.te_patch = nn.Sequential(\r\n\r\n            nn.Conv1d(time_feature_num, time_feature_num, kernel_size=patch_size, stride=patch_stride,groups=time_feature_num),\r\n            nn.Conv1d(time_feature_num, dims[0], kernel_size=1, stride=1, groups=1),\r\n            nn.BatchNorm1d(dims[0]))\r\n\r\n        # backbone\r\n\r\n        self.num_stage = len(num_blocks)\r\n        self.stages = nn.ModuleList()\r\n        for stage_idx in range(self.num_stage):\r\n            layer = Stage(ffn_ratio, num_blocks[stage_idx], large_size[stage_idx], small_size[stage_idx], dmodel=dims[stage_idx],\r\n                          dw_model=dw_dims[stage_idx], nvars=nvars, small_kernel_merged=small_kernel_merged, drop=backbone_dropout)\r\n            self.stages.append(layer)\r\n\r\n        # Multi scale fusing (if needed)\r\n        self.use_multi_scale = use_multi_scale\r\n        self.up_sample_ratio = downsample_ratio\r\n\r\n        self.lat_layer = nn.ModuleList()\r\n        self.smooth_layer = nn.ModuleList()\r\n        self.up_sample_conv = nn.ModuleList()\r\n        for i in range(self.num_stage):\r\n            align_dim = dims[-1]\r\n            lat = nn.Conv1d(dims[i], align_dim, kernel_size=1,\r\n                            stride=1)\r\n            self.lat_layer.append(lat)\r\n            smooth = nn.Conv1d(align_dim, align_dim, kernel_size=3, stride=1, padding=1)\r\n            self.smooth_layer.append(smooth)\r\n\r\n            up_conv = nn.Sequential(\r\n                nn.ConvTranspose1d(align_dim, align_dim, kernel_size=self.up_sample_ratio, 
stride=self.up_sample_ratio),\r\n                nn.BatchNorm1d(align_dim))\r\n            self.up_sample_conv.append(up_conv)\r\n\r\n        # head\r\n        patch_num = seq_len // patch_stride\r\n\r\n        self.n_vars = c_in\r\n        self.individual = individual\r\n        d_model = dims[-1]\r\n        if use_multi_scale:\r\n            self.head_nf = d_model * patch_num\r\n            self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window,\r\n                                     head_dropout=head_dropout)\r\n        else:\r\n\r\n            if patch_num % pow(downsample_ratio,(self.num_stage - 1)) == 0:\r\n                self.head_nf = d_model * patch_num // pow(downsample_ratio,(self.num_stage - 1))\r\n            else:\r\n                self.head_nf = d_model * (patch_num // pow(downsample_ratio, (self.num_stage - 1))+1)\r\n            self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window,\r\n                                     head_dropout=head_dropout)\r\n\r\n    def up_sample(self, x, upsample_ratio):\r\n        _, _, _, N = x.shape\r\n        return F.upsample(x, size=N, scale_factor=upsample_ratio, mode='bilinear')\r\n\r\n    def forward_feature(self, x, te=None):\r\n\r\n        B,M,L=x.shape\r\n\r\n        x = x.unsqueeze(-2)\r\n        for i in range(self.num_stage):\r\n            B, M, D, N = x.shape\r\n            x = x.reshape(B * M, D, N)\r\n            if i==0:\r\n                if self.patch_size != self.patch_stride:\r\n                    # stem layer padding\r\n                    pad_len = self.patch_size - self.patch_stride\r\n                    pad = x[:,:,-1:].repeat(1,1,pad_len)\r\n                    x = torch.cat([x,pad],dim=-1)\r\n            else:\r\n                if N % self.downsample_ratio != 0:\r\n                    pad_len = self.downsample_ratio - (N % self.downsample_ratio)\r\n                    x = torch.cat([x, x[:, :, -pad_len:]],dim=-1)\r\n            x 
= self.downsample_layers[i](x)\r\n            _, D_, N_ = x.shape\r\n            x = x.reshape(B, M, D_, N_)\r\n            x = self.stages[i](x)\r\n        return x\r\n\r\n    def forward(self, x, te=None):\r\n\r\n        # instance norm\r\n        if self.revin:\r\n            x = x.permute(0, 2, 1)\r\n            x = self.revin_layer(x, 'norm')\r\n            x = x.permute(0, 2, 1)\r\n        x = self.forward_feature(x,te)\r\n        x = self.head(x)\r\n        # de-instance norm\r\n        if self.revin:\r\n            x = x.permute(0, 2, 1)\r\n            x = self.revin_layer(x, 'denorm')\r\n            x = x.permute(0, 2, 1)\r\n        return x\r\n\r\n    def structural_reparam(self):\r\n        for m in self.modules():\r\n            if hasattr(m, 'merge_kernel'):\r\n                m.merge_kernel()\r\n\r\n"
  },
  {
    "path": "probts/model/nn/arch/Moirai_backbone.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from uni2ts\n# - Source: https://github.com/SalesforceAIResearch/uni2ts\n# - Paper: Unified Training of Universal Time Series Forecasting Transformers\n# - License: Apache License 2.0\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport math\nfrom contextlib import contextmanager\nfrom copy import deepcopy\nfrom typing import Any, Generator, Optional\nimport sys\n\nimport lightning as L\nimport torch\nfrom einops import rearrange, reduce, repeat\nfrom jaxtyping import Bool, Float, Int\nfrom torch.distributions import Distribution\n\nfrom uni2ts.common.torch_util import safe_div\nfrom uni2ts.loss.packed import PackedNLLLoss as _PackedNLLLoss\nfrom uni2ts.model.moirai.module import MoiraiModule\nfrom uni2ts.module.packed_scaler import PackedNOPScaler, PackedStdScaler\n\n\nclass SampleNLLLoss(_PackedNLLLoss):\n    def reduce_loss(\n        self,\n        loss: Float[torch.Tensor, \"batch seq_len #dim\"],\n        prediction_mask: Optional[Bool[torch.Tensor, \"batch seq_len\"]],\n        observed_mask: Optional[Bool[torch.Tensor, \"batch seq_len #dim\"]],\n        sample_id: Optional[Int[torch.Tensor, \"batch seq_len\"]],\n        variate_id: Optional[Int[torch.Tensor, \"batch seq_len\"]],\n    ) -> Float[torch.Tensor, \"batch\"]:\n        id_mask = torch.logical_and(\n            torch.eq(sample_id.unsqueeze(-1), sample_id.unsqueeze(-2)),\n            torch.eq(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)),\n        )\n        mask = prediction_mask.unsqueeze(-1) * observed_mask\n        tobs = reduce(\n            id_mask\n            * reduce(\n                mask,\n                \"... seq dim -> ... 1 seq\",\n                \"sum\",\n            ),\n            \"... seq1 seq2 -> ... 
seq1 1\",\n            \"sum\",\n        )\n        loss = safe_div(loss, tobs)\n        return (loss * mask).sum(dim=(-1, -2))\n\n\nclass MoiraiBackbone(L.LightningModule):\n    def __init__(\n        self,\n        prediction_length: int,\n        target_dim: int,\n        context_length: int,\n        module_kwargs: Optional[dict[str, Any]] = None,\n        module: Optional[MoiraiModule] = None,\n        patch_size: int | str = \"auto\",\n        num_samples: int = 100,\n        scaling: bool = True,\n    ):\n        assert (module is not None) or (\n            module_kwargs is not None\n        ), \"if module is not provided, module_kwargs is required\"\n        super().__init__()\n        self.save_hyperparameters(ignore=[\"module\"])\n        self.module = MoiraiModule(**module_kwargs) if module is None else module\n        self.module.scaling = scaling\n        self.module.scaler = PackedStdScaler() if scaling else PackedNOPScaler()\n        self.per_sample_loss_func = SampleNLLLoss()\n\n    @contextmanager\n    def hparams_context(\n        self,\n        prediction_length: Optional[int] = None,\n        target_dim: Optional[int] = None,\n        context_length: Optional[int] = None,\n        patch_size: Optional[int | str] = None,\n        num_samples: Optional[int] = None,\n    ) -> Generator[\"MoiraiForecast\", None, None]:\n        kwargs = {\n            \"prediction_length\": prediction_length,\n            \"target_dim\": target_dim,\n            \"context_length\": context_length,\n            \"patch_size\": patch_size,\n            \"num_samples\": num_samples,\n        }\n        old_hparams = deepcopy(self.hparams)\n        for kw, arg in kwargs.items():\n            if arg is not None:\n                self.hparams[kw] = arg\n\n        yield self\n\n        for kw in kwargs:\n            self.hparams[kw] = old_hparams[kw]\n\n    @property\n    def past_length(self) -> int:\n        return (\n            self.hparams.context_length + 
self.hparams.prediction_length\n            if self.hparams.patch_size == \"auto\"\n            else self.hparams.context_length\n        )\n\n    def context_token_length(self, patch_size: int) -> int:\n        return math.ceil(self.hparams.context_length / patch_size)\n\n    def prediction_token_length(self, patch_size) -> int:\n        return math.ceil(self.hparams.prediction_length / patch_size)\n\n    @property\n    def max_patch_size(self) -> int:\n        return max(self.module.patch_sizes)\n\n    def forward(\n        self,\n        past_target: Float[torch.Tensor, \"batch past_time tgt\"],\n        past_observed_target: Bool[torch.Tensor, \"batch past_time tgt\"],\n        past_is_pad: Bool[torch.Tensor, \"batch past_time\"],\n        num_samples: Optional[int] = None,\n    ) -> Float[torch.Tensor, \"batch sample future_time *tgt\"]:\n        \n        if self.hparams.patch_size == \"auto\":\n            val_loss = []\n            preds = []\n            for patch_size in self.module.patch_sizes:\n                val_loss.append(\n                    self._val_loss(\n                        patch_size=patch_size,\n                        target=past_target[..., : self.past_length, :],\n                        observed_target=past_observed_target[\n                            ..., : self.past_length, :\n                        ],\n                        is_pad=past_is_pad[..., : self.past_length]\n                    )\n                )\n                distr = self._get_distr(\n                    patch_size,\n                    past_target[..., -self.hparams.context_length :, :],\n                    past_observed_target[..., -self.hparams.context_length :, :],\n                    past_is_pad[..., -self.hparams.context_length :]\n                )\n                preds.append(\n                    self._format_preds(\n                        patch_size,\n                        distr.sample(\n                            torch.Size((num_samples or 
self.hparams.num_samples,))\n                        ),\n                        past_target.shape[-1],\n                    )\n                )\n            val_loss = torch.stack(val_loss)\n            preds = torch.stack(preds)\n            idx = val_loss.argmin(dim=0)\n            return preds[idx, torch.arange(len(idx), device=idx.device)]\n        else:\n            distr = self._get_distr(\n                self.hparams.patch_size,\n                past_target[..., -self.hparams.context_length :, :],\n                past_observed_target[..., -self.hparams.context_length :, :],\n                past_is_pad[..., -self.hparams.context_length :],\n            )\n            preds = distr.sample(torch.Size((num_samples or self.hparams.num_samples,)))\n            return self._format_preds(\n                self.hparams.patch_size, preds, past_target.shape[-1]\n            )\n\n    def _val_loss(\n        self,\n        patch_size: int,\n        target: Float[torch.Tensor, \"batch time tgt\"],\n        observed_target: Bool[torch.Tensor, \"batch time tgt\"],\n        is_pad: Bool[torch.Tensor, \"batch time\"]\n    ) -> Float[torch.Tensor, \"batch\"]:\n        # convert format\n        (\n            target,\n            observed_mask,\n            sample_id,\n            time_id,\n            variate_id,\n            prediction_mask,\n        ) = self._convert(\n            patch_size,\n            past_target=target[..., : self.hparams.context_length, :],\n            past_observed_target=observed_target[..., : self.hparams.context_length, :],\n            past_is_pad=is_pad[..., : self.hparams.context_length],\n            future_target=target[..., self.hparams.context_length :, :],\n            future_observed_target=observed_target[\n                ..., self.hparams.context_length :, :\n            ],\n            future_is_pad=is_pad[..., self.hparams.context_length :]\n        )\n        # get predictions\n        distr = self.module(\n            
target,\n            observed_mask,\n            sample_id,\n            time_id,\n            variate_id,\n            prediction_mask,\n            torch.ones_like(time_id, dtype=torch.long) * patch_size,\n        )\n        val_loss = self.per_sample_loss_func(\n            pred=distr,\n            target=target,\n            prediction_mask=prediction_mask,\n            observed_mask=observed_mask,\n            sample_id=sample_id,\n            variate_id=variate_id,\n        )\n        return val_loss\n\n    def _get_distr(\n        self,\n        patch_size: int,\n        past_target: Float[torch.Tensor, \"batch past_time tgt\"],\n        past_observed_target: Bool[torch.Tensor, \"batch past_time tgt\"],\n        past_is_pad: Bool[torch.Tensor, \"batch past_time\"]\n    ) -> Distribution:\n        # convert format\n        (\n            target,\n            observed_mask,\n            sample_id,\n            time_id,\n            variate_id,\n            prediction_mask,\n        ) = self._convert(\n            patch_size,\n            past_target,\n            past_observed_target,\n            past_is_pad\n        )\n        # get predictions\n        distr = self.module(\n            target,\n            observed_mask,\n            sample_id,\n            time_id,\n            variate_id,\n            prediction_mask,\n            torch.ones_like(time_id, dtype=torch.long) * patch_size,\n        )\n        return distr\n\n    @staticmethod\n    def _patched_seq_pad(\n        patch_size: int,\n        x: torch.Tensor,\n        dim: int,\n        left: bool = True,\n        value: Optional[float] = None,\n    ) -> torch.Tensor:\n        if dim >= 0:\n            dim = -x.ndim + dim\n        pad_length = -x.size(dim) % patch_size\n        if left:\n            pad = (pad_length, 0)\n        else:\n            pad = (0, pad_length)\n        pad = (0, 0) * (abs(dim) - 1) + pad\n        return torch.nn.functional.pad(x, pad, value=value)\n\n    def 
_generate_time_id(\n        self,\n        patch_size: int,\n        past_observed_target: Bool[torch.Tensor, \"batch past_seq tgt\"],\n        future_target: Float[torch.Tensor, \"batch future_seq tgt\"],\n    ) -> tuple[\n        Int[torch.Tensor, \"batch past_token\"], Int[torch.Tensor, \"batch future_token\"]\n    ]:\n        past_seq_id = reduce(\n            self._patched_seq_pad(patch_size, past_observed_target, -2, left=True),\n            \"... (seq patch) dim -> ... seq\",\n            \"max\",\n            patch=patch_size,\n        )\n        past_seq_id = torch.clamp(past_seq_id.cumsum(dim=-1) - 1, min=0)\n        batch_shape = \" \".join(map(str, past_observed_target.shape[:-2]))\n        future_seq_id = (\n            repeat(\n                torch.arange(\n                    math.ceil(future_target.shape[-2] / patch_size),\n                    device=past_observed_target.device,\n                ),\n                f\"prediction -> {batch_shape} prediction\",\n            )\n            + past_seq_id.max(dim=-1, keepdim=True).values\n            + 1\n        )\n        past_seq_id = past_seq_id.to(dtype=torch.int32)\n        future_seq_id = future_seq_id.to(dtype=torch.int32)\n        return past_seq_id, future_seq_id\n\n    def _convert(\n        self,\n        patch_size: int,\n        past_target: Float[torch.Tensor, \"batch past_time tgt\"],\n        past_observed_target: Bool[torch.Tensor, \"batch past_time tgt\"],\n        past_is_pad: Bool[torch.Tensor, \"batch past_time\"],\n        future_target: Optional[Float[torch.Tensor, \"batch future_time tgt\"]] = None,\n        future_observed_target: Optional[\n            Bool[torch.Tensor, \"batch future_time tgt\"]\n        ] = None,\n        future_is_pad: Optional[Bool[torch.Tensor, \"batch future_time\"]] = None\n    ) -> tuple[\n        Float[torch.Tensor, \"batch combine_seq patch\"],  # target\n        Bool[torch.Tensor, \"batch combine_seq patch\"],  # observed_mask\n        
Int[torch.Tensor, \"batch combine_seq\"],  # sample_id\n        Int[torch.Tensor, \"batch combine_seq\"],  # time_id\n        Int[torch.Tensor, \"batch combine_seq\"],  # variate_id\n        Bool[torch.Tensor, \"batch combine_seq\"],  # prediction_mask\n    ]:\n        batch_shape = past_target.shape[:-2]\n        device = past_target.device\n\n        target = []\n        observed_mask = []\n        sample_id = []\n        time_id = []\n        variate_id = []\n        prediction_mask = []\n        dim_count = 0\n\n        if future_target is None:\n            future_target = torch.zeros(\n                batch_shape\n                + (\n                    self.hparams.prediction_length,\n                    past_target.shape[-1],\n                ),\n                dtype=past_target.dtype,\n                device=device,\n            )\n        \n        past_seq_id, future_seq_id = self._generate_time_id(\n            patch_size, past_observed_target, future_target\n        )\n\n        target.extend(\n            [\n                torch.nn.functional.pad(\n                    rearrange(\n                        self._patched_seq_pad(patch_size, past_target, -2, left=True),\n                        \"... (seq patch) dim -> ... (dim seq) patch\",\n                        patch=patch_size,\n                    ),\n                    (0, self.max_patch_size - patch_size),\n                ),\n                torch.nn.functional.pad(\n                    rearrange(\n                        self._patched_seq_pad(\n                            patch_size, future_target, -2, left=False\n                        ),\n                        \"... (seq patch) dim -> ... 
(dim seq) patch\",\n                        patch=patch_size,\n                    ),\n                    (0, self.max_patch_size - patch_size),\n                ),\n            ]\n        )\n        if future_observed_target is None:\n            future_observed_target = torch.ones(\n                batch_shape\n                + (\n                    self.hparams.prediction_length,\n                    past_observed_target.shape[-1],\n                ),\n                dtype=torch.bool,\n                device=device,\n            )\n        observed_mask.extend(\n            [\n                torch.nn.functional.pad(\n                    rearrange(\n                        self._patched_seq_pad(\n                            patch_size, past_observed_target, -2, left=True\n                        ),\n                        \"... (seq patch) dim -> ... (dim seq) patch\",\n                        patch=patch_size,\n                    ),\n                    (0, self.max_patch_size - patch_size),\n                ),\n                torch.nn.functional.pad(\n                    rearrange(\n                        self._patched_seq_pad(\n                            patch_size, future_observed_target, -2, left=False\n                        ),\n                        \"... (seq patch) dim -> ... 
(dim seq) patch\",\n                        patch=patch_size,\n                    ),\n                    (0, self.max_patch_size - patch_size),\n                ),\n            ]\n        )\n        if future_is_pad is None:\n            future_is_pad = torch.zeros(\n                batch_shape + (self.hparams.prediction_length,),\n                dtype=torch.long,\n                device=device,\n            )\n        sample_id.extend(\n            [\n                repeat(\n                    reduce(\n                        (\n                            self._patched_seq_pad(\n                                patch_size, past_is_pad, -1, left=True, value=1\n                            )\n                            == 0\n                        ).int(),\n                        \"... (seq patch) -> ... seq\",\n                        \"max\",\n                        patch=patch_size,\n                    ),\n                    \"... seq -> ... (dim seq)\",\n                    dim=past_target.shape[-1],\n                ),\n                repeat(\n                    reduce(\n                        (\n                            self._patched_seq_pad(\n                                patch_size, future_is_pad, -1, left=False, value=1\n                            )\n                            == 0\n                        ).int(),\n                        \"... (seq patch) -> ... seq\",\n                        \"max\",\n                        patch=patch_size,\n                    ),\n                    \"... seq -> ... 
(dim seq)\",\n                    dim=past_target.shape[-1],\n                ),\n            ]\n        )\n        time_id.extend(\n            [past_seq_id] * past_target.shape[-1]\n            + [future_seq_id] * past_target.shape[-1]\n        )\n        variate_id.extend(\n            [\n                repeat(\n                    torch.arange(past_target.shape[-1], device=device) + dim_count,\n                    f\"dim -> {' '.join(map(str, batch_shape))} (dim past)\",\n                    past=self.context_token_length(patch_size),\n                ),\n                repeat(\n                    torch.arange(past_target.shape[-1], device=device) + dim_count,\n                    f\"dim -> {' '.join(map(str, batch_shape))} (dim future)\",\n                    # future=self.prediction_token_length(patch_size),\n                    future = math.ceil(future_target.shape[-2] / patch_size)\n                ),\n            ]\n        )\n        dim_count += past_target.shape[-1]\n        prediction_mask.extend(\n            [\n                torch.zeros(\n                    batch_shape\n                    + (self.context_token_length(patch_size) * past_target.shape[-1],),\n                    dtype=torch.bool,\n                    device=device,\n                ),\n                torch.ones(\n                    batch_shape\n                    + (\n                        # self.prediction_token_length(patch_size)\n                        math.ceil(future_target.shape[-2] / patch_size)\n                        * past_target.shape[-1],\n                    ),\n                    dtype=torch.bool,\n                    device=device,\n                ),\n            ]\n        )\n\n        target = torch.cat(target, dim=-2)\n        observed_mask = torch.cat(observed_mask, dim=-2)\n        sample_id = torch.cat(sample_id, dim=-1)\n        time_id = torch.cat(time_id, dim=-1)\n        variate_id = torch.cat(variate_id, dim=-1)\n        prediction_mask = 
torch.cat(prediction_mask, dim=-1)\n        return (\n            target,\n            observed_mask,\n            sample_id,\n            time_id,\n            variate_id,\n            prediction_mask,\n        )\n\n    def _format_preds(\n        self,\n        patch_size: int,\n        preds: Float[torch.Tensor, \"sample batch combine_seq patch\"],\n        target_dim: int,\n    ) -> Float[torch.Tensor, \"batch sample future_time *tgt\"]:\n        start = target_dim * self.context_token_length(patch_size)\n        end = start + target_dim * self.prediction_token_length(patch_size)\n        preds = preds[..., start:end, :patch_size]\n        preds = rearrange(\n            preds,\n            \"sample ... (dim seq) patch -> ... sample (seq patch) dim\",\n            dim=target_dim,\n        )[..., : self.hparams.prediction_length, :]\n        return preds.squeeze(-1)"
  },
  {
    "path": "probts/model/nn/arch/PatchTSTModule/PatchTST_backbone.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PatchTST\n# - Source: https://github.com/yuqinie98/PatchTST/tree/main\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\n__all__ = ['PatchTST_backbone']\n\n# Cell\nfrom typing import Callable, Optional\nimport torch\nfrom torch import nn\nfrom torch import Tensor\nimport torch.nn.functional as F\nimport numpy as np\n\n#from collections import OrderedDict\nfrom probts.model.nn.arch.PatchTSTModule.PatchTST_layers import *\nfrom probts.model.nn.arch.RevIN import RevIN\n\n# Cell\nclass PatchTST_backbone(nn.Module):\n    def __init__(self, c_in:int, context_window:int, target_window:int, patch_len:int, stride:int, max_seq_len:Optional[int]=1024, \n                 n_layers:int=3, d_model=128, n_heads=16, d_k:Optional[int]=None, d_v:Optional[int]=None,\n                 d_ff:int=256, norm:str='BatchNorm', attn_dropout:float=0., dropout:float=0., act:str=\"gelu\", key_padding_mask:bool='auto',\n                 padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,\n                 pe:str='zeros', learn_pe:bool=True, fc_dropout:float=0., head_dropout = 0, padding_patch = None,\n                 pretrain_head:bool=False, head_type = 'flatten', individual = False, revin = True, affine = True, subtract_last = False,\n                 verbose:bool=False):\n        \n        super().__init__()\n        \n        # RevIn\n        self.revin = revin\n        if self.revin: self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)\n        \n        # Patching\n        self.patch_len = patch_len\n        self.stride = stride\n        self.padding_patch = padding_patch\n        patch_num = int((context_window - patch_len)/stride + 1)\n        if 
padding_patch == 'end': # can be modified to general case\n            self.padding_patch_layer = nn.ReplicationPad1d((0, stride)) \n            patch_num += 1\n        \n        # Backbone \n        self.backbone = TSTiEncoder(c_in, patch_num=patch_num, patch_len=patch_len, max_seq_len=max_seq_len,\n                                n_layers=n_layers, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff,\n                                attn_dropout=attn_dropout, dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,\n                                attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,\n                                pe=pe, learn_pe=learn_pe, verbose=verbose)\n\n        # Head\n        self.head_nf = d_model * patch_num\n        self.n_vars = c_in\n        self.pretrain_head = pretrain_head\n        self.head_type = head_type\n        self.individual = individual\n\n        if self.pretrain_head: \n            self.head = self.create_pretrain_head(self.head_nf, c_in, fc_dropout) # custom head passed as a partial func with all its kwargs\n        elif head_type == 'flatten': \n            self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window, head_dropout=head_dropout)\n        \n    \n    def forward(self, z):                                                                   # z: [bs x nvars x seq_len]\n        # norm\n        if self.revin: \n            z = z.permute(0,2,1)\n            z = self.revin_layer(z, 'norm')\n            z = z.permute(0,2,1)\n            \n        # do patching\n        if self.padding_patch == 'end':\n            z = self.padding_patch_layer(z)\n        z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride)                   # z: [bs x nvars x patch_num x patch_len]\n        z = z.permute(0,1,3,2)                                                              # z: [bs x nvars x patch_len x patch_num]\n    
    \n        # model\n        z = self.backbone(z)                                                                # z: [bs x nvars x d_model x patch_num]\n        z = self.head(z)                                                                    # z: [bs x nvars x target_window] \n        \n        # denorm\n        if self.revin: \n            z = z.permute(0,2,1)\n            z = self.revin_layer(z, 'denorm')\n            z = z.permute(0,2,1)\n        return z\n    \n    def create_pretrain_head(self, head_nf, vars, dropout):\n        return nn.Sequential(nn.Dropout(dropout),\n                    nn.Conv1d(head_nf, vars, 1)\n                    )\n\n\nclass Flatten_Head(nn.Module):\n    def __init__(self, individual, n_vars, nf, target_window, head_dropout=0):\n        super().__init__()\n        \n        self.individual = individual\n        self.n_vars = n_vars\n        \n        if self.individual:\n            self.linears = nn.ModuleList()\n            self.dropouts = nn.ModuleList()\n            self.flattens = nn.ModuleList()\n            for i in range(self.n_vars):\n                self.flattens.append(nn.Flatten(start_dim=-2))\n                self.linears.append(nn.Linear(nf, target_window))\n                self.dropouts.append(nn.Dropout(head_dropout))\n        else:\n            self.flatten = nn.Flatten(start_dim=-2)\n            self.linear = nn.Linear(nf, target_window)\n            self.dropout = nn.Dropout(head_dropout)\n            \n    def forward(self, x):                                 # x: [bs x nvars x d_model x patch_num]\n        if self.individual:\n            x_out = []\n            for i in range(self.n_vars):\n                z = self.flattens[i](x[:,i,:,:])          # z: [bs x d_model * patch_num]\n                z = self.linears[i](z)                    # z: [bs x target_window]\n                z = self.dropouts[i](z)\n                x_out.append(z)\n            x = torch.stack(x_out, dim=1)                 # x: [bs x 
nvars x target_window]\n        else:\n            x = self.flatten(x)\n            x = self.linear(x)\n            x = self.dropout(x)\n        return x\n        \n        \n    \n    \nclass TSTiEncoder(nn.Module):  #i means channel-independent\n    def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024,\n                 n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None,\n                 d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0., act=\"gelu\", store_attn=False,\n                 key_padding_mask='auto', padding_var=None, attn_mask=None, res_attention=True, pre_norm=False,\n                 pe='zeros', learn_pe=True, verbose=False):\n        \n        \n        super().__init__()\n        \n        self.patch_num = patch_num\n        self.patch_len = patch_len\n        \n        # Input encoding\n        q_len = patch_num\n        self.W_P = nn.Linear(patch_len, d_model)        # Eq 1: projection of feature vectors onto a d-dim vector space\n        self.seq_len = q_len\n\n        # Positional encoding\n        self.W_pos = positional_encoding(pe, learn_pe, q_len, d_model)\n\n        # Residual dropout\n        self.dropout = nn.Dropout(dropout)\n\n        # Encoder\n        self.encoder = TSTEncoder(q_len, d_model, n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout,\n                                   pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=n_layers, store_attn=store_attn)\n\n        \n    def forward(self, x) -> Tensor:                                              # x: [bs x nvars x patch_len x patch_num]\n        \n        n_vars = x.shape[1]\n        # Input encoding\n        x = x.permute(0,1,3,2)                                                   # x: [bs x nvars x patch_num x patch_len]\n        x = self.W_P(x)                                                          # x: [bs x nvars x patch_num x d_model]\n\n        u = torch.reshape(x, 
(x.shape[0]*x.shape[1],x.shape[2],x.shape[3]))      # u: [bs * nvars x patch_num x d_model]\n        u = self.dropout(u + self.W_pos)                                         # u: [bs * nvars x patch_num x d_model]\n\n        # Encoder\n        z = self.encoder(u)                                                      # z: [bs * nvars x patch_num x d_model]\n        z = torch.reshape(z, (-1,n_vars,z.shape[-2],z.shape[-1]))                # z: [bs x nvars x patch_num x d_model]\n        z = z.permute(0,1,3,2)                                                   # z: [bs x nvars x d_model x patch_num]\n        \n        return z    \n            \n            \n    \n# Cell\nclass TSTEncoder(nn.Module):\n    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None, \n                        norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',\n                        res_attention=False, n_layers=1, pre_norm=False, store_attn=False):\n        super().__init__()\n\n        self.layers = nn.ModuleList([TSTEncoderLayer(q_len, d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm,\n                                                      attn_dropout=attn_dropout, dropout=dropout,\n                                                      activation=activation, res_attention=res_attention,\n                                                      pre_norm=pre_norm, store_attn=store_attn) for i in range(n_layers)])\n        self.res_attention = res_attention\n\n    def forward(self, src:Tensor, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):\n        output = src\n        scores = None\n        if self.res_attention:\n            for mod in self.layers: output, scores = mod(output, prev=scores, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n            return output\n        else:\n            for mod in self.layers: output = mod(output, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n            return 
output\n\n\n\nclass TSTEncoderLayer(nn.Module):\n    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=256, store_attn=False,\n                 norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation=\"gelu\", res_attention=False, pre_norm=False):\n        super().__init__()\n        assert not d_model%n_heads, f\"d_model ({d_model}) must be divisible by n_heads ({n_heads})\"\n        d_k = d_model // n_heads if d_k is None else d_k\n        d_v = d_model // n_heads if d_v is None else d_v\n\n        # Multi-Head attention\n        self.res_attention = res_attention\n        self.self_attn = _MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout, proj_dropout=dropout, res_attention=res_attention)\n\n        # Add & Norm\n        self.dropout_attn = nn.Dropout(dropout)\n        if \"batch\" in norm.lower():\n            self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))\n        else:\n            self.norm_attn = nn.LayerNorm(d_model)\n\n        # Position-wise Feed-Forward\n        self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias),\n                                get_activation_fn(activation),\n                                nn.Dropout(dropout),\n                                nn.Linear(d_ff, d_model, bias=bias))\n\n        # Add & Norm\n        self.dropout_ffn = nn.Dropout(dropout)\n        if \"batch\" in norm.lower():\n            self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))\n        else:\n            self.norm_ffn = nn.LayerNorm(d_model)\n\n        self.pre_norm = pre_norm\n        self.store_attn = store_attn\n\n\n    def forward(self, src:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None) -> Tensor:\n\n        # Multi-Head attention sublayer\n        if self.pre_norm:\n            src = self.norm_attn(src)\n        ## Multi-Head attention\n        if 
self.res_attention:\n            src2, attn, scores = self.self_attn(src, src, src, prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n        else:\n            src2, attn = self.self_attn(src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n        if self.store_attn:\n            self.attn = attn\n        ## Add & Norm\n        src = src + self.dropout_attn(src2) # Add: residual connection with residual dropout\n        if not self.pre_norm:\n            src = self.norm_attn(src)\n\n        # Feed-forward sublayer\n        if self.pre_norm:\n            src = self.norm_ffn(src)\n        ## Position-wise Feed-Forward\n        src2 = self.ff(src)\n        ## Add & Norm\n        src = src + self.dropout_ffn(src2) # Add: residual connection with residual dropout\n        if not self.pre_norm:\n            src = self.norm_ffn(src)\n\n        if self.res_attention:\n            return src, scores\n        else:\n            return src\n\n\n\n\nclass _MultiheadAttention(nn.Module):\n    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True):\n        \"\"\"Multi Head Attention Layer\n        Input shape:\n            Q:       [batch_size (bs) x max_q_len x d_model]\n            K, V:    [batch_size (bs) x q_len x d_model]\n            mask:    [q_len x q_len]\n        \"\"\"\n        super().__init__()\n        d_k = d_model // n_heads if d_k is None else d_k\n        d_v = d_model // n_heads if d_v is None else d_v\n\n        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v\n\n        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)\n        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)\n        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)\n\n        # Scaled Dot-Product Attention (multiple heads)\n        self.res_attention = res_attention\n        self.sdp_attn = _ScaledDotProductAttention(d_model, n_heads, 
attn_dropout=attn_dropout, res_attention=self.res_attention)\n\n        # Poject output\n        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))\n\n\n    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,\n                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):\n\n        bs = Q.size(0)\n        if K is None: K = Q\n        if V is None: V = Q\n\n        # Linear (+ split in multiple heads)\n        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)       # q_s    : [bs x n_heads x max_q_len x d_k]\n        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)     # k_s    : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)\n        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)       # v_s    : [bs x n_heads x q_len x d_v]\n\n        # Apply Scaled Dot-Product Attention (multiple heads)\n        if self.res_attention:\n            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n        else:\n            output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n        # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]\n\n        # back to the original inputs dimensions\n        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v]\n        output = self.to_out(output)\n\n        if self.res_attention: return output, attn_weights, attn_scores\n        else: return output, attn_weights\n\n\nclass _ScaledDotProductAttention(nn.Module):\n    r\"\"\"Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer\n    
(Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets\n    by Lee et al, 2021)\"\"\"\n\n    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False):\n        super().__init__()\n        self.attn_dropout = nn.Dropout(attn_dropout)\n        self.res_attention = res_attention\n        head_dim = d_model // n_heads\n        self.scale = torch.tensor(head_dim ** -0.5)\n\n    def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):\n        '''\n        Input shape:\n            q               : [bs x n_heads x max_q_len x d_k]\n            k               : [bs x n_heads x d_k x seq_len]\n            v               : [bs x n_heads x seq_len x d_v]\n            prev            : [bs x n_heads x q_len x seq_len]\n            key_padding_mask: [bs x seq_len]\n            attn_mask       : [1 x seq_len x seq_len]\n        Output shape:\n            output:  [bs x n_heads x q_len x d_v]\n            attn   : [bs x n_heads x q_len x seq_len]\n            scores : [bs x n_heads x q_len x seq_len]\n        '''\n\n        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence\n        attn_scores = torch.matmul(q, k) * self.scale      # attn_scores : [bs x n_heads x max_q_len x q_len]\n\n        # Add pre-softmax attention scores from the previous layer (optional)\n        if prev is not None: attn_scores = attn_scores + prev\n\n        # Attention mask (optional)\n        if attn_mask is not None:                                     # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len\n            if attn_mask.dtype == torch.bool:\n                attn_scores.masked_fill_(attn_mask, -np.inf)\n            else:\n                attn_scores += attn_mask\n\n        # Key padding mask (optional)\n        if key_padding_mask is 
not None:                              # mask with shape [bs x q_len] (only when max_q_len == q_len)\n            attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)\n\n        # normalize the attention weights\n        attn_weights = F.softmax(attn_scores, dim=-1)                 # attn_weights   : [bs x n_heads x max_q_len x q_len]\n        attn_weights = self.attn_dropout(attn_weights)\n\n        # compute the new values given the attention weights\n        output = torch.matmul(attn_weights, v)                        # output: [bs x n_heads x max_q_len x d_v]\n\n        if self.res_attention: return output, attn_weights, attn_scores\n        else: return output, attn_weights\n\n"
  },
  {
    "path": "probts/model/nn/arch/PatchTSTModule/PatchTST_layers.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PatchTST\n# - Source: https://github.com/yuqinie98/PatchTST/tree/main\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\n__all__ = ['Transpose', 'get_activation_fn', 'moving_avg', 'series_decomp', 'PositionalEncoding', 'SinCosPosEncoding', 'Coord2dPosEncoding', 'Coord1dPosEncoding', 'positional_encoding']           \n\nimport torch\nfrom torch import nn\nimport math\n\nclass Transpose(nn.Module):\n    def __init__(self, *dims, contiguous=False): \n        super().__init__()\n        self.dims, self.contiguous = dims, contiguous\n    def forward(self, x):\n        if self.contiguous: return x.transpose(*self.dims).contiguous()\n        else: return x.transpose(*self.dims)\n\n    \ndef get_activation_fn(activation):\n    if callable(activation): return activation()\n    elif activation.lower() == \"relu\": return nn.ReLU()\n    elif activation.lower() == \"gelu\": return nn.GELU()\n    raise ValueError(f'{activation} is not available. 
You can use \"relu\", \"gelu\", or a callable') \n    \n    \n# decomposition\n\nclass moving_avg(nn.Module):\n    \"\"\"\n    Moving average block to highlight the trend of time series\n    \"\"\"\n    def __init__(self, kernel_size, stride):\n        super(moving_avg, self).__init__()\n        self.kernel_size = kernel_size\n        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)\n\n    def forward(self, x):\n        # padding on the both ends of time series\n        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n        x = torch.cat([front, x, end], dim=1)\n        x = self.avg(x.permute(0, 2, 1))\n        x = x.permute(0, 2, 1)\n        return x\n\n\nclass series_decomp(nn.Module):\n    \"\"\"\n    Series decomposition block\n    \"\"\"\n    def __init__(self, kernel_size):\n        super(series_decomp, self).__init__()\n        self.moving_avg = moving_avg(kernel_size, stride=1)\n\n    def forward(self, x):\n        moving_mean = self.moving_avg(x)\n        res = x - moving_mean\n        return res, moving_mean\n    \n    \n    \n# pos_encoding\n\ndef PositionalEncoding(q_len, d_model, normalize=True):\n    pe = torch.zeros(q_len, d_model)\n    position = torch.arange(0, q_len).unsqueeze(1)\n    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))\n    pe[:, 0::2] = torch.sin(position * div_term)\n    pe[:, 1::2] = torch.cos(position * div_term)\n    if normalize:\n        pe = pe - pe.mean()\n        pe = pe / (pe.std() * 10)\n    return pe\n\nSinCosPosEncoding = PositionalEncoding\n\ndef Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False):\n    x = .5 if exponential else 1\n    i = 0\n    for i in range(100):\n        cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) * (torch.linspace(0, 1, d_model).reshape(1, -1) ** x) - 1\n        pv(f'{i:4.0f}  {x:5.3f}  
{cpe.mean():+6.3f}', verbose)\n        if abs(cpe.mean()) <= eps: break\n        elif cpe.mean() > eps: x += .001\n        else: x -= .001\n        i += 1\n    if normalize:\n        cpe = cpe - cpe.mean()\n        cpe = cpe / (cpe.std() * 10)\n    return cpe\n\ndef Coord1dPosEncoding(q_len, exponential=False, normalize=True):\n    cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1)**(.5 if exponential else 1)) - 1)\n    if normalize:\n        cpe = cpe - cpe.mean()\n        cpe = cpe / (cpe.std() * 10)\n    return cpe\n\ndef positional_encoding(pe, learn_pe, q_len, d_model):\n    # Positional encoding\n    if pe == None:\n        W_pos = torch.empty((q_len, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe\n        nn.init.uniform_(W_pos, -0.02, 0.02)\n        learn_pe = False\n    elif pe == 'zero':\n        W_pos = torch.empty((q_len, 1))\n        nn.init.uniform_(W_pos, -0.02, 0.02)\n    elif pe == 'zeros':\n        W_pos = torch.empty((q_len, d_model))\n        nn.init.uniform_(W_pos, -0.02, 0.02)\n    elif pe == 'normal' or pe == 'gauss':\n        W_pos = torch.zeros((q_len, 1))\n        torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)\n    elif pe == 'uniform':\n        W_pos = torch.zeros((q_len, 1))\n        nn.init.uniform_(W_pos, a=0.0, b=0.1)\n    elif pe == 'lin1d': W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)\n    elif pe == 'exp1d': W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)\n    elif pe == 'lin2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True)\n    elif pe == 'exp2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True)\n    elif pe == 'sincos': W_pos = PositionalEncoding(q_len, d_model, normalize=True)\n    else: raise ValueError(f\"{pe} is not a valid pe (positional encoder. 
Available types: 'gauss'=='normal', \\\n        'zeros', 'zero', uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.)\")\n    return nn.Parameter(W_pos, requires_grad=learn_pe)"
  },
  {
    "path": "probts/model/nn/arch/RevIN.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from RevIN\n# - Source: https://github.com/ts-kim/RevIN\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\n\nclass RevIN(nn.Module):\n    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):\n        \"\"\"\n        :param num_features: the number of features or channels\n        :param eps: a value added for numerical stability\n        :param affine: if True, RevIN has learnable affine parameters\n        \"\"\"\n        super(RevIN, self).__init__()\n        self.num_features = num_features\n        self.eps = eps\n        self.affine = affine\n        self.subtract_last = subtract_last\n        if self.affine:\n            self._init_params()\n\n    def forward(self, x, mode:str):\n        if mode == 'norm':\n            self._get_statistics(x)\n            x = self._normalize(x)\n        elif mode == 'denorm':\n            x = self._denormalize(x)\n        else: raise NotImplementedError\n        return x\n\n    def _init_params(self):\n        # initialize RevIN params: (C,)\n        self.affine_weight = nn.Parameter(torch.ones(self.num_features))\n        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))\n\n    def _get_statistics(self, x):\n        dim2reduce = tuple(range(1, x.ndim-1))\n        if self.subtract_last:\n            self.last = x[:,-1,:].unsqueeze(1)\n        else:\n            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()\n        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()\n\n    def _normalize(self, x):\n        if self.subtract_last:\n            x = x - self.last\n        else:\n            x = x - self.mean\n        x = x / self.stdev\n        if 
self.affine:\n            x = x * self.affine_weight\n            x = x + self.affine_bias\n        return x\n\n    def _denormalize(self, x):\n        if self.affine:\n            x = x - self.affine_bias\n            x = x / (self.affine_weight + self.eps*self.eps)\n        x = x * self.stdev\n        if self.subtract_last:\n            x = x + self.last\n        else:\n            x = x + self.mean\n        return x\n"
  },
  {
    "path": "probts/model/nn/arch/S4/s4.py",
    "content": "\"\"\"Standalone version of Structured (Sequence) State Space (S4) model.\"\"\"\n\nimport logging\nfrom functools import partial\nimport math\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom pytorch_lightning.utilities import rank_zero_only\nfrom einops import rearrange, repeat\nimport opt_einsum as oe\n\ncontract = oe.contract\ncontract_expression = oe.contract_expression\n\n\ndef get_logger(name=__name__, level=logging.INFO) -> logging.Logger:\n    \"\"\"Initializes multi-GPU-friendly python logger.\"\"\"\n\n    logger = logging.getLogger(name)\n    logger.setLevel(level)\n\n    # this ensures all logging levels get marked with the rank zero decorator\n    # otherwise logs would get multiplied for each GPU process in multi-GPU setup\n    for level in (\n        \"debug\",\n        \"info\",\n        \"warning\",\n        \"error\",\n        \"exception\",\n        \"fatal\",\n        \"critical\",\n    ):\n        setattr(logger, level, rank_zero_only(getattr(logger, level)))\n\n    return logger\n\n\nlog = get_logger(__name__)\n\n\"\"\" Cauchy and Vandermonde kernels \"\"\"\n\ntry:  # Try CUDA extension\n    from extensions.cauchy.cauchy import cauchy_mult\n\n    has_cauchy_extension = True\nexcept ImportError:\n    # log.warning(\n    #     \"CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. 
This should speed up end-to-end training by 10-50%\"\n    # )\n    has_cauchy_extension = False\n\ntry:  # Try pykeops\n    from pykeops.torch import Genred\n\n    has_pykeops = True\n    # log.info(\"Pykeops installation found.\")\n\n    def _broadcast_dims(*tensors):\n        max_dim = max([len(tensor.shape) for tensor in tensors])\n        tensors = [\n            tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape)\n            for tensor in tensors\n        ]\n        return tensors\n\n    def cauchy_conj(v, z, w):\n        \"\"\"Pykeops version\"\"\"\n        expr_num = \"z * ComplexReal(v) - Real2Complex(Sum(v * w))\"\n        expr_denom = \"ComplexMult(z-w, z-Conj(w))\"\n\n        cauchy_mult = Genred(\n            f\"ComplexDivide({expr_num}, {expr_denom})\",\n            [\n                \"v = Vj(2)\",\n                \"z = Vi(2)\",\n                \"w = Vj(2)\",\n            ],\n            reduction_op=\"Sum\",\n            axis=1,\n        )\n\n        v, z, w = _broadcast_dims(v, z, w)\n        v = _c2r(v)\n        z = _c2r(z)\n        w = _c2r(w)\n\n        r = 2 * cauchy_mult(v, z, w, backend=\"GPU\")\n        return _r2c(r)\n\n    def log_vandermonde(v, x, L):\n        expr = \"ComplexMult(v, ComplexExp(ComplexMult(x, l)))\"\n        vandermonde_mult = Genred(\n            expr,\n            [\n                \"v = Vj(2)\",\n                \"x = Vj(2)\",\n                \"l = Vi(2)\",\n            ],\n            reduction_op=\"Sum\",\n            axis=1,\n        )\n\n        l = torch.arange(L).to(x)\n        v, x, l = _broadcast_dims(v, x, l)\n        v = _c2r(v)\n        x = _c2r(x)\n        l = _c2r(l)\n\n        r = vandermonde_mult(v, x, l, backend=\"GPU\")\n        return 2 * _r2c(r).real\n\n    def log_vandermonde_transpose(u, v, x, L):\n        \"\"\"\n        u: ... H L\n        v: ... H N\n        x: ... H N\n        Returns: ... 
H N\n\n        V = Vandermonde(a, L) : (H N L)\n        contract_L(V * u * v)\n        \"\"\"\n        expr = \"ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))\"\n        vandermonde_mult = Genred(\n            expr,\n            [\n                \"u = Vj(2)\",\n                \"v = Vi(2)\",\n                \"x = Vi(2)\",\n                \"l = Vj(2)\",\n            ],\n            reduction_op=\"Sum\",\n            axis=1,\n        )\n\n        l = torch.arange(L).to(x)\n        u, v, x, l = _broadcast_dims(u, v, x, l)\n        u = _c2r(u)\n        v = _c2r(v)\n        x = _c2r(x)\n        l = _c2r(l)\n\n        r = vandermonde_mult(u, v, x, l, backend=\"GPU\")\n        return _r2c(r)\n\nexcept ImportError:\n    has_pykeops = False\n    if not has_cauchy_extension:\n        # log.warning(\n        #     \"Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency.\"\n        # )\n\n        def cauchy_naive(v, z, w):\n            \"\"\"\n            v, w: (..., N)\n            z: (..., L)\n            returns: (..., L)\n            \"\"\"\n            cauchy_matrix = v.unsqueeze(-1) / (\n                z.unsqueeze(-2) - w.unsqueeze(-1)\n            )  # (... N L)\n            return torch.sum(cauchy_matrix, dim=-2)\n\n    # Vandermonde functions\n    # log.warning(\n    #     \"Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency.\"\n    # )\n\n    def log_vandermonde(v, x, L):\n        \"\"\"\n        v: (..., N)\n        x: (..., N)\n        returns: (..., L) \\sum v x^l\n        \"\"\"\n        vandermonde_matrix = torch.exp(\n            x.unsqueeze(-1) * torch.arange(L).to(x)\n        )  # (... N L)\n        vandermonde_prod = contract(\n            \"... n, ... n l -> ... l\", v, vandermonde_matrix\n        )  # (... 
L)\n        return 2 * vandermonde_prod.real\n\n    def log_vandermonde_transpose(u, v, x, L):\n        vandermonde_matrix = torch.exp(\n            x.unsqueeze(-1) * torch.arange(L).to(x)\n        )  # (... N L)\n        vandermonde_prod = contract(\n            \"... l, ... n, ... n l -> ... n\",\n            u.to(x),\n            v.to(x),\n            vandermonde_matrix,\n        )  # (... L)\n        return vandermonde_prod\n\n\ndef _conj(x):\n    return torch.cat([x, x.conj()], dim=-1)\n\n\n_c2r = torch.view_as_real\n_r2c = torch.view_as_complex\nif tuple(map(int, torch.__version__.split(\".\")[:2])) >= (1, 10):\n\n    def _resolve_conj(x):\n        return x.conj().resolve_conj()\n\nelse:\n\n    def _resolve_conj(x):\n        return x.conj()\n\n\n\"\"\" Simple nn.Module components \"\"\"\n\n\ndef Activation(activation=None, dim=-1):\n    if activation in [None, \"id\", \"identity\", \"linear\"]:\n        return nn.Identity()\n    elif activation == \"tanh\":\n        return nn.Tanh()\n    elif activation == \"relu\":\n        return nn.ReLU()\n    elif activation == \"gelu\":\n        return nn.GELU()\n    elif activation in [\"swish\", \"silu\"]:\n        return nn.SiLU()\n    elif activation == \"glu\":\n        return nn.GLU(dim=dim)\n    elif activation == \"sigmoid\":\n        return nn.Sigmoid()\n    else:\n        raise NotImplementedError(\n            \"hidden activation '{}' is not implemented\".format(activation)\n        )\n\n\ndef LinearActivation(\n    d_input,\n    d_output,\n    bias=True,\n    transposed=False,\n    activation=None,\n    activate=False,  # Apply activation as part of this module\n    **kwargs,\n):\n    \"\"\"Returns a linear nn.Module with control over axes order, initialization, and activation\"\"\"\n\n    # Construct core module\n    linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear\n    if activation == \"glu\":\n        d_output *= 2\n    linear = linear_cls(d_input, d_output, bias=bias, 
**kwargs)\n\n    if activate and activation is not None:\n        activation = Activation(activation, dim=-2 if transposed else -1)\n        linear = nn.Sequential(linear, activation)\n    return linear\n\n\nclass DropoutNd(nn.Module):\n    def __init__(self, p: float = 0.5, tie=True, transposed=True):\n        \"\"\"\n        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)\n        \"\"\"\n        super().__init__()\n        if p < 0 or p >= 1:\n            raise ValueError(\n                \"dropout probability has to be in [0, 1), \"\n                \"but got {}\".format(p)\n            )\n        self.p = p\n        self.tie = tie\n        self.transposed = transposed\n        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)\n\n    def forward(self, X):\n        \"\"\"X: (batch, dim, lengths...)\"\"\"\n        if self.training:\n            if not self.transposed:\n                X = rearrange(X, \"b d ... -> b ... d\")\n            mask_shape = (\n                X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape\n            )\n            mask = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p\n            X = X * mask * (1.0 / (1 - self.p))\n            if not self.transposed:\n                X = rearrange(X, \"b ... 
d -> b d ...\")\n            return X\n        return X\n\n\n\"\"\" Misc functional utilities \"\"\"\n\n\ndef power(L, A, v=None):\n    \"\"\"Compute A^L and the scan sum_i A^i v_i\n\n    A: (..., N, N)\n    v: (..., N, L)\n    \"\"\"\n\n    I = torch.eye(A.shape[-1]).to(A)  # , dtype=A.dtype, device=A.device)\n\n    powers = [A]\n    l = 1\n    while True:\n        if L % 2 == 1:\n            I = powers[-1] @ I\n        L //= 2\n        if L == 0:\n            break\n        l *= 2\n        powers.append(powers[-1] @ powers[-1])\n\n    if v is None:\n        return I\n\n    # Invariants:\n    # powers[-1] := A^l\n    # l := largest po2 at most L\n\n    # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A\n    # We do this reverse divide-and-conquer for efficiency reasons:\n    # 1) it involves fewer padding steps for non-po2 L\n    # 2) it involves more contiguous arrays\n\n    # Take care of edge case for non-po2 arrays\n    # Note that this initial step is a no-op for the case of power of 2 (l == L)\n    k = v.size(-1) - l\n    v_ = powers.pop() @ v[..., l:]\n    v = v[..., :l]\n    v[..., :k] = v[..., :k] + v_\n\n    # Handle reduction for power of 2\n    while v.size(-1) > 1:\n        v = rearrange(v, \"... (z l) -> ... 
z l\", z=2)\n        v = v[..., 0, :] + powers.pop() @ v[..., 1, :]\n    return I, v.squeeze(-1)\n\n\n\"\"\" HiPPO utilities \"\"\"\n\n\ndef transition(measure, N):\n    \"\"\"A, B transition matrices for different measures\"\"\"\n    # Legendre (translated)\n    if measure == \"legt\":\n        Q = np.arange(N, dtype=np.float64)\n        R = (2 * Q + 1) ** 0.5\n        j, i = np.meshgrid(Q, Q)\n        A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :]\n        B = R[:, None]\n        A = -A\n\n        # Halve again for timescale correctness\n        A *= 0.5\n        B *= 0.5\n    # Legendre (scaled)\n    elif measure == \"legs\":\n        q = np.arange(N, dtype=np.float64)\n        col, row = np.meshgrid(q, q)\n        r = 2 * q + 1\n        M = -(np.where(row >= col, r, 0) - np.diag(q))\n        T = np.sqrt(np.diag(2 * q + 1))\n        A = T @ M @ np.linalg.inv(T)\n        B = np.diag(T)[:, None]\n        B = (\n            B.copy()\n        )  # Otherwise \"UserWarning: given NumPY array is not writeable...\" after torch.as_tensor(B)\n    elif measure == \"legsd\":\n        # Essentially equivalent to S4D-LegS\n        q = np.arange(N, dtype=np.float64)\n        col, row = np.meshgrid(q, q)\n        r = 2 * q + 1\n        M = -(np.where(row >= col, r, 0) - np.diag(q))\n        T = np.sqrt(np.diag(2 * q + 1))\n        A = T @ M @ np.linalg.inv(T)\n        B = np.diag(T)[:, None]\n        B = (\n            B.copy()\n        )  # Otherwise \"UserWarning: given NumPY array is not writeable...\" after torch.as_tensor(B)\n        A += 0.5 * B * B[None, :, 0]\n        B = B / 2.0\n    elif measure in [\"fourier_diag\", \"foud\"]:\n        # Essentially equivalent to S4D-Lin\n        freqs = np.arange(N // 2)\n        d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1]\n        A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1))\n        A = A - 0.5 * np.eye(N)\n        B = np.zeros(N)\n        B[0::2] = 2**0.5\n        B[0] = 1\n        
B = B[:, None]\n    elif measure in [\"fourier\", \"fout\"]:\n        freqs = np.arange(N // 2)\n        d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]\n        A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))\n        B = np.zeros(N)\n        B[0::2] = 2**0.5\n        B[0] = 1\n\n        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case\n        A = A - B[:, None] * B[None, :]\n        B = B[:, None]\n    else:\n        raise NotImplementedError\n\n    return A, B\n\n\ndef rank_correction(measure, N, rank=1, dtype=torch.float):\n    \"\"\"Return low-rank matrix L such that A + L is normal\"\"\"\n\n    if measure == \"legs\":\n        assert rank >= 1\n        P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(\n            0\n        )  # (1 N)\n    elif measure == \"legt\":\n        assert rank >= 2\n        P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype))  # (N)\n        P0 = P.clone()\n        P0[0::2] = 0.0\n        P1 = P.clone()\n        P1[1::2] = 0.0\n        P = torch.stack([P0, P1], dim=0)  # (2 N)\n        P *= 2 ** (\n            -0.5\n        )  # Halve the rank correct just like the original matrix was halved\n    elif measure in [\"fourier\", \"fout\"]:\n        P = torch.zeros(N)\n        P[0::2] = 2**0.5\n        P[0] = 1\n        P = P.unsqueeze(0)\n    elif measure in [\"fourier_diag\", \"foud\", \"legsd\"]:\n        P = torch.zeros(1, N, dtype=dtype)\n    else:\n        raise NotImplementedError\n\n    d = P.size(0)\n    if rank > d:\n        P = torch.cat(\n            [P, torch.zeros(rank - d, N, dtype=dtype)], dim=0\n        )  # (rank N)\n    return P\n\n\ndef nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True):\n    \"\"\"Return w, p, q, V, B such that\n    (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V\n    i.e. 
A = V[w - p q^*]V^*, B = V B\n    \"\"\"\n    assert dtype == torch.float or dtype == torch.double\n    cdtype = torch.cfloat if dtype == torch.float else torch.cdouble\n\n    A, B = transition(measure, N)\n    A = torch.as_tensor(A, dtype=dtype)  # (N, N)\n    B = torch.as_tensor(B, dtype=dtype)[:, 0]  # (N,)\n\n    P = rank_correction(measure, N, rank=rank, dtype=dtype)  # (r N)\n    AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3)\n\n    # We require AP to be nearly skew-symmetric\n    _A = AP + AP.transpose(-1, -2)\n    if (\n        err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N\n    ) > 1e-5:  # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5):\n        print(\"WARNING: HiPPO matrix not skew symmetric\", err)\n\n    # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately\n    # Imaginary part can use eigh instead of eig\n    w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True)\n\n    # Diagonalize in double precision\n    if diagonalize_precision:\n        AP = AP.to(torch.double)\n    w_im, V = torch.linalg.eigh(AP * -1j)  # (..., N) (..., N, N)\n    if diagonalize_precision:\n        w_im, V = w_im.to(cdtype), V.to(cdtype)\n    w = w_re + 1j * w_im\n    # Check: V w V^{-1} = A\n    # print(\"check\", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))\n\n    # Only keep half of each conjugate pair\n    _, idx = torch.sort(w.imag)\n    w_sorted = w[idx]\n    V_sorted = V[:, idx]\n\n    # There is an edge case when eigenvalues can be 0, which requires some machinery to handle\n    # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case)\n    V = V_sorted[:, : N // 2]\n    w = w_sorted[: N // 2]\n    assert (\n        w[-2].abs() > 1e-4\n    ), \"Only 1 zero eigenvalue allowed in diagonal part of A\"\n    if w[-1].abs() < 1e-4:\n        V[:, -1] = 0.0\n        V[0, -1] = 2**-0.5\n        
V[1, -1] = 2**-0.5 * 1j\n\n    _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)\n    if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5:\n        print(\n            \"Warning: Diagonalization of A matrix not numerically precise - error\",\n            err,\n        )\n    # print(\"check\", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))\n\n    V_inv = V.conj().transpose(-1, -2)\n\n    B = contract(\"ij, j -> i\", V_inv, B.to(V))  # V^* B\n    P = contract(\"ij, ...j -> ...i\", V_inv, P.to(V))  # V^* P\n\n    return w, P, B, V\n\n\ndef dplr(\n    scaling,\n    N,\n    rank=1,\n    H=1,\n    dtype=torch.float,\n    real_scale=1.0,\n    imag_scale=1.0,\n    random_real=False,\n    random_imag=False,\n    normalize=False,\n    diagonal=True,\n    random_B=False,\n):\n    assert dtype == torch.float or dtype == torch.double\n    dtype = torch.cfloat if dtype == torch.float else torch.cdouble\n\n    pi = torch.tensor(math.pi)\n    if random_real:\n        real_part = torch.rand(H, N // 2)\n    else:\n        real_part = 0.5 * torch.ones(H, N // 2)\n    if random_imag:\n        imag_part = N // 2 * torch.rand(H, N // 2)\n    else:\n        imag_part = repeat(torch.arange(N // 2), \"n -> h n\", h=H)\n\n    real_part = real_scale * real_part\n    if scaling == \"random\":\n        imag_part = torch.randn(H, N // 2)\n    elif scaling == \"real\":\n        imag_part = 0 * imag_part\n        real_part = 1 + repeat(torch.arange(N // 2), \"n -> h n\", h=H)\n    elif scaling in [\"linear\", \"lin\"]:\n        imag_part = pi * imag_part\n    elif scaling in [\n        \"inverse\",\n        \"inv\",\n    ]:  # Based on asymptotics of the default HiPPO matrix\n        imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1)\n    elif scaling in [\"inverse2\", \"inv2\"]:\n        imag_part = 1 / pi * N * (N / (1 + imag_part) - 1)\n    elif scaling in [\"quadratic\", \"quad\"]:\n        imag_part = 1 / pi * (1 + 2 * imag_part) ** 2\n    elif scaling in 
[\"legs\", \"hippo\"]:\n        w, _, _, _ = nplr(\"legsd\", N)\n        imag_part = w.imag\n\n    else:\n        raise NotImplementedError\n    imag_part = imag_scale * imag_part\n    w = -real_part + 1j * imag_part\n\n    # Initialize B\n    if random_B:\n        B = torch.randn(H, N // 2, dtype=dtype)\n    else:\n        B = torch.ones(H, N // 2, dtype=dtype)\n\n    if normalize:\n        norm = (\n            -B / w\n        )  # (H, N) # Result if you integrate the kernel with constant 1 function\n        zeta = 2 * torch.sum(\n            torch.abs(norm) ** 2, dim=-1, keepdim=True\n        )  # Variance with a random C vector\n        B = B / zeta**0.5\n\n    P = torch.randn(rank, H, N // 2, dtype=dtype)\n    if diagonal:\n        P = P * 0.0\n    V = torch.eye(N, dtype=dtype)[:: N // 2]  # Only used in testing\n    V = repeat(V, \"n m -> h n m\", h=H)\n\n    return w, P, B, V\n\n\ndef ssm(measure, N, R, H, **ssm_args):\n    \"\"\"Dispatcher to create single SSM initialization\n\n    N: state size\n    R: rank (for DPLR parameterization)\n    H: number of independent SSM copies\n    \"\"\"\n\n    if measure == \"dplr\":\n        w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args)\n    elif measure.startswith(\"diag\"):\n        args = measure.split(\"-\")\n        assert args[0] == \"diag\" and len(args) > 1\n        scaling = args[1]\n        w, P, B, V = dplr(\n            scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args\n        )\n    else:\n        w, P, B, V = nplr(measure, N, R, **ssm_args)\n        w = repeat(w, \"n -> s n\", s=H)\n        P = repeat(P, \"r n -> r s n\", s=H)\n        B = repeat(B, \"n -> s n\", s=H)\n        V = repeat(V, \"n m -> s n m\", s=H)\n    return w, P, B, V\n\n\ncombinations = {\n    \"hippo\": [\"legs\", \"fourier\"],\n    \"diag\": [\"diag-inv\", \"diag-lin\"],\n    \"all\": [\"legs\", \"fourier\", \"diag-inv\", \"diag-lin\"],\n}\n\n\ndef combination(measures, N, R, S, **ssm_args):\n    if isinstance(measures, 
str):\n        measures = (\n            combinations[measures] if measures in combinations else [measures]\n        )\n\n    assert (\n        S % len(measures) == 0\n    ), f\"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures\"\n    w, P, B, V = zip(\n        *[\n            ssm(measure, N, R, S // len(measures), **ssm_args)\n            for measure in measures\n        ]\n    )\n    w = torch.cat(w, dim=0)  # (S N)\n    P = torch.cat(P, dim=1)  # (R S N)\n    B = torch.cat(B, dim=0)  # (S N)\n    V = torch.cat(V, dim=0)  # (S N N)\n    return w, P, B, V\n\n\nclass OptimModule(nn.Module):\n    \"\"\"Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters\"\"\"\n\n    def register(self, name, tensor, lr=None):\n        \"\"\"Register a tensor with a configurable learning rate and 0 weight decay\"\"\"\n\n        if lr == 0.0:\n            self.register_buffer(name, tensor)\n        else:\n            self.register_parameter(name, nn.Parameter(tensor))\n\n            optim = {\"weight_decay\": 0.0}\n            if lr is not None:\n                optim[\"lr\"] = lr\n            setattr(getattr(self, name), \"_optim\", optim)\n\n\nclass SSKernelNPLR(OptimModule):\n    \"\"\"Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR)\"\"\"\n\n    @torch.no_grad()\n    def _setup_C(self, L):\n        \"\"\"Construct C~ from C\n\n        Two modes are supported: go directly to length L if self.L is 1, or length is doubled\n        \"\"\"\n\n        if self.L.item() == 0:\n            if self.verbose:\n                log.info(f\"S4: Initializing kernel to length {L}\")\n            double_length = False\n        elif L > self.L.item():  # 2*int(self.L) == L:\n            if self.verbose:\n                log.info(\n                    f\"S4: Doubling length from L = 
{self.L.item()} to {2*self.L.item()}\"\n                )\n            double_length = True\n            L = self.L.item()  # Convenience for the math below\n        else:\n            return\n\n        C = _r2c(self.C)\n        dA, _ = self._setup_state()\n        dA_L = power(L, dA)\n        # Multiply C by I - dA_L\n        C_ = _conj(C)\n        prod = contract(\"h m n, c h n -> c h m\", dA_L.transpose(-1, -2), C_)\n        if double_length:\n            prod = -prod  # Multiply by I + dA_L instead\n        C_ = C_ - prod\n        C_ = C_[..., : self.N]  # Take conjugate pairs again\n        self.C.copy_(_c2r(C_))\n\n        self.L = (\n            2 * self.L if double_length else self.L + L\n        )  # Preserve type/device\n\n    def _omega(self, L, dtype, device, cache=True):\n        \"\"\"Calculate (and cache) FFT nodes and their \"unprocessed\" version with the bilinear transform\n        This should be called everytime the internal length self.L changes\"\"\"\n\n        # Use cached if available\n        if (\n            cache\n            and hasattr(self, \"omega\")\n            and self.omega.size(-1) == L // 2 + 1\n        ):\n            return self.omega, self.z\n\n        omega = torch.tensor(\n            np.exp(-2j * np.pi / (L)), dtype=dtype, device=device\n        )  # \\omega_{2L}\n        omega = omega ** torch.arange(0, L // 2 + 1, device=device)\n        z = 2 * (1 - omega) / (1 + omega)\n\n        # Cache if necessary\n        if cache:\n            self.omega = omega\n            self.z = z\n        return omega, z\n\n    def __init__(\n        self,\n        w,\n        P,\n        B,\n        C,\n        log_dt,\n        L=None,  # starting/maximum length of kernel\n        lr=None,\n        verbose=False,\n        keops=False,\n        real_type=\"exp\",  # ['none' | 'exp' | 'relu' | sigmoid']\n        real_tolerance=1e-3,\n        bandlimit=None,\n    ):\n        \"\"\"\n        L: Maximum length; this module computes an SSM kernel 
of length L\n        A is represented by diag(w) - PP^*\n        w: (S, N) diagonal part\n        P: (R, S, N) low-rank part\n\n        B: (S, N)\n        C: (C, H, N)\n        dt: (H) timescale per feature\n        lr: [dict | float | None] hook to set lr of special parameters (A, B, dt)\n\n        Dimensions:\n        N (or d_state): state size\n        H (or d_model): total SSM copies\n        S (or n_ssm): number of trainable copies of (A, B, dt); must divide H\n        R (or rank): rank of low-rank part\n        C (or channels): system is 1-dim to C-dim\n\n        The forward pass of this Module returns a tensor of shape (C, H, L)\n\n        Note: tensor shape N here denotes half the true state size, because of conjugate symmetry\n        \"\"\"\n\n        super().__init__()\n        self.verbose = verbose\n        self.keops = keops\n        self.bandlimit = bandlimit\n        self.real_type = real_type\n        self.real_tolerance = real_tolerance\n\n        # Rank of low-rank correction\n        self.rank = P.shape[-3]\n        assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1)\n        self.H = log_dt.size(-1)\n        self.N = w.size(-1)\n\n        # Check different SSM inits\n        assert w.size(-2) == P.size(-2) == B.size(-2)  # n_ssm\n        assert self.H % w.size(0) == 0\n        self.n_ssm = w.size(0)\n        self.repeat = self.H // w.size(\n            0\n        )  # Each trainable SSM needs to be duplicated this many times\n\n        # Broadcast everything to correct shapes\n        C = C.expand(\n            torch.broadcast_shapes(C.shape, (1, self.H, self.N))\n        )  # (C, H, N)\n        B = B.unsqueeze(0)  # (1, 1, N)\n\n        # Register parameters\n        self.C = nn.Parameter(_c2r(_resolve_conj(C)))\n        if lr is None or isinstance(lr, float):\n            lr_dict = {}\n        else:\n            lr_dict, lr = lr, None\n        self.register(\"log_dt\", log_dt, lr_dict.get(\"dt\", lr))\n        self.register(\"B\", 
_c2r(B), lr_dict.get(\"B\", lr))\n        self.register(\"P\", _c2r(P), lr_dict.get(\"A\", lr))\n        self.register(\"inv_w_real\", self._w_init(w.real), lr_dict.get(\"A\", lr))\n        self.register(\"w_imag\", w.imag, lr_dict.get(\"A\", lr))\n\n        self.l_max = L\n        self.register_buffer(\"L\", torch.tensor(0))  # Internal length\n\n    def _w_init(self, w_real):\n        w_real = torch.clamp(w_real, max=-self.real_tolerance)\n        if self.real_type == \"none\":\n            return -w_real\n        elif self.real_type == \"exp\":\n            return torch.log(\n                -w_real\n            )  # Some of the HiPPO methods have real part 0\n        elif self.real_type == \"relu\":\n            return -w_real\n        elif self.real_type == \"sigmoid\":\n            return torch.logit(-w_real)\n        elif self.real_type == \"softplus\":\n            return torch.log(torch.exp(-w_real) - 1)\n        else:\n            raise NotImplementedError\n\n    def _w(self):\n        # Get the internal w (diagonal) parameter\n        if self.real_type == \"none\":\n            w_real = -self.inv_w_real\n        elif self.real_type == \"exp\":\n            w_real = -torch.exp(self.inv_w_real)\n        elif self.real_type == \"relu\":\n            w_real = -F.relu(self.inv_w_real)\n        elif self.real_type == \"sigmoid\":\n            w_real = -F.sigmoid(self.inv_w_real)\n        elif self.real_type == \"softplus\":\n            w_real = -F.softplus(self.inv_w_real)\n        else:\n            raise NotImplementedError\n        w = w_real + 1j * self.w_imag\n        return w\n\n    def forward(self, state=None, rate=1.0, L=None):\n        \"\"\"\n        state: (B, H, N) initial state\n        rate: sampling rate factor\n        L: target length\n\n        returns:\n        (C, H, L) convolution kernel (generally C=1)\n        (B, H, L) output from initial state\n        \"\"\"\n\n        # Initialize C~ if necessary (done in forward pass so it's on 
the correct device)\n        if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:\n            self._setup_C(self.l_max)\n\n        # Handle sampling rate logic\n        # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate\n        if L is None:\n            L = round(self.L.item() / rate)\n\n        # Increase the internal length if needed\n        continuous_L = round(rate * L)\n        while continuous_L > self.L.item():\n            self._setup_C(continuous_L)\n        discrete_L = round(self.L.item() / rate)\n\n        dt = torch.exp(self.log_dt) * rate\n        B = _r2c(self.B)\n        C = _r2c(self.C)\n        P = _r2c(self.P)\n        Q = P.conj()\n        w = self._w()  # (n_ssm, N)\n\n        # Address bandlimiting\n        if self.bandlimit is not None:\n            freqs = w.imag.abs() / (2 * math.pi)  # (H, N)\n            freqs = dt[:, None] / rate * freqs  # (H, N)\n            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)\n            C = C * mask\n\n        # Get FFT nodes of right length\n        omega, z = self._omega(\n            discrete_L, dtype=w.dtype, device=w.device, cache=(rate == 1.0)\n        )\n\n        # Broadcast parameters to same hidden features H\n        B = repeat(B, \"1 t n -> 1 (v t) n\", v=self.repeat)\n        P = repeat(P, \"r t n -> r (v t) n\", v=self.repeat)\n        Q = repeat(Q, \"r t n -> r (v t) n\", v=self.repeat)\n        w = repeat(w, \"t n -> (v t) n\", v=self.repeat)\n\n        # Augment B\n        if state is not None:\n            # Have to \"unbilinear\" the state to put it into the same \"type\" as B\n            # Compute 1/dt * (I + dt/2 A) @ state\n\n            # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way\n            s = _conj(state) if state.size(-1) == self.N else state  # (B H N)\n            sA = s * 
_conj(w) - contract(  # (B H N)\n                \"bhm, rhm, rhn -> bhn\", s, _conj(Q), _conj(P)\n            )\n            s = s / dt.unsqueeze(-1) + sA / 2\n            s = s[..., : self.N]\n\n            B = torch.cat([s, B], dim=-3)  # (B+1, H, N)\n\n        # Incorporate dt into A\n        w = w * dt.unsqueeze(-1)  # (H N)\n\n        # Stack B and p, C and q for convenient batching\n        B = torch.cat([B, P], dim=-3)  # (B+1+R, H, N)\n        C = torch.cat([C, Q], dim=-3)  # (C+R, H, N)\n\n        # Incorporate B and C batch dimensions\n        v = B.unsqueeze(-3) * C.unsqueeze(-4)  # (B+1+R, C+R, H, N)\n\n        # Calculate resolvent at omega\n        if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops:\n            r = cauchy_mult(v, z, w, symmetric=True)\n        elif has_pykeops:\n            r = cauchy_conj(v, z, w)\n        else:\n            r = cauchy_naive(v, z, w)\n        r = r * dt[None, None, :, None]  # (B+1+R, C+R, H, L)\n\n        # Low-rank Woodbury correction\n        if self.rank == 1:\n            k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (\n                1 + r[-1:, -1:, :, :]\n            )\n        elif self.rank == 2:\n            r00 = r[: -self.rank, : -self.rank, :, :]\n            r01 = r[: -self.rank, -self.rank :, :, :]\n            r10 = r[-self.rank :, : -self.rank, :, :]\n            r11 = r[-self.rank :, -self.rank :, :, :]\n            det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[\n                :1, 1:, :, :\n            ] * r11[1:, :1, :, :]\n            s = (\n                r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :]\n                + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :]\n                - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :]\n                - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :]\n            )\n            s = s / det\n            k_f = r00 - s\n        else:\n         
   r00 = r[: -self.rank, : -self.rank, :, :]\n            r01 = r[: -self.rank, -self.rank :, :, :]\n            r10 = r[-self.rank :, : -self.rank, :, :]\n            r11 = r[-self.rank :, -self.rank :, :, :]\n            r11 = rearrange(r11, \"a b h n -> h n a b\")\n            r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11)\n            r11 = rearrange(r11, \"h n a b -> a b h n\")\n            k_f = r00 - torch.einsum(\n                \"i j h n, j k h n, k l h n -> i l h n\", r01, r11, r10\n            )\n\n        # Final correction for the bilinear transform\n        k_f = k_f * 2 / (1 + omega)\n\n        # Move from frequency to coefficients\n        k = torch.fft.irfft(k_f, n=discrete_L)  # (B+1, C, H, L)\n\n        # # Truncate to target length\n        k = k[..., :L]\n\n        if state is not None:\n            k_state = k[:-1, :, :, :]  # (B, C, H, L)\n        else:\n            k_state = None\n        k_B = k[-1, :, :, :]  # (C H L)\n\n        return k_B, k_state\n\n    @torch.no_grad()\n    def _setup_linear(self):\n        \"\"\"Create parameters that allow fast linear stepping of state\"\"\"\n        w = self._w()\n        B = _r2c(self.B)  # (H N)\n        P = _r2c(self.P)\n        Q = P.conj()\n\n        # Repeat w shape properly\n        B = repeat(B, \"1 t n -> 1 (v t) n\", v=self.repeat)\n        P = repeat(P, \"r t n -> r (v t) n\", v=self.repeat)\n        Q = repeat(Q, \"r t n -> r (v t) n\", v=self.repeat)\n        w = repeat(w, \"t n -> (v t) n\", v=self.repeat)\n\n        # Prepare Linear stepping\n        dt = torch.exp(self.log_dt)\n        D = (2.0 / dt.unsqueeze(-1) - w).reciprocal()  # (H, N)\n        R = (\n            torch.eye(self.rank, dtype=w.dtype, device=w.device)\n            + 2 * contract(\"r h n, h n, s h n -> h r s\", Q, D, P).real\n        )  # (H R R)\n        Q_D = rearrange(Q * D, \"r h n -> h r n\")\n        try:\n            R = torch.linalg.solve(R, Q_D)  # (H R N)\n        except Exception:\n   
         R = torch.tensor(\n                np.linalg.solve(\n                    R.to(Q_D).contiguous().detach().cpu(),\n                    Q_D.contiguous().detach().cpu(),\n                )\n            ).to(Q_D)\n        R = rearrange(R, \"h r n -> r h n\")\n\n        self.step_params = {\n            \"D\": D,  # (H N)\n            \"R\": R,  # (R H N)\n            \"P\": P,  # (R H N)\n            \"Q\": Q,  # (R H N)\n            \"B\": B,  # (1 H N)\n            \"E\": 2.0 / dt.unsqueeze(-1) + w,  # (H N)\n        }\n\n    def _step_state_linear(self, u=None, state=None):\n        \"\"\"\n        Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization.\n\n        Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster\n\n        u: (H) input\n        state: (H, N/2) state with conjugate pairs\n          Optionally, the state can have last dimension N\n        Returns: same shape as state\n        \"\"\"\n        C = _r2c(self.C)  # View used for dtype/device\n\n        if u is None:  # Special case used to find dA\n            u = torch.zeros(self.H, dtype=C.dtype, device=C.device)\n        if state is None:  # Special case used to find dB\n            state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device)\n\n        step_params = self.step_params.copy()\n        if (\n            state.size(-1) == self.N\n        ):  # Only store half of the conjugate pairs; should be true by default\n            # There should be a slightly faster way using conjugate symmetry\n            def contract_fn(p, x, y):\n                return contract(\n                    \"r h n, r h m, ... h m -> ... 
h n\",\n                    _conj(p),\n                    _conj(x),\n                    _conj(y),\n                )[\n                    ..., : self.N\n                ]  # inner outer product\n\n        else:\n            assert state.size(-1) == 2 * self.N\n            step_params = {k: _conj(v) for k, v in step_params.items()}\n\n            # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping\n            def contract_fn(p, x, y):\n                return contract(\n                    \"r h n, r h m, ... h m -> ... h n\", p, x, y\n                )  # inner outer product\n\n        D = step_params[\"D\"]  # (H N)\n        E = step_params[\"E\"]  # (H N)\n        R = step_params[\"R\"]  # (R H N)\n        P = step_params[\"P\"]  # (R H N)\n        Q = step_params[\"Q\"]  # (R H N)\n        B = step_params[\"B\"]  # (1 H N)\n\n        new_state = E * state - contract_fn(P, Q, state)  # (B H N)\n        new_state = new_state + 2.0 * B * u.unsqueeze(-1)  # (B H N)\n        new_state = D * (new_state - contract_fn(P, R, new_state))\n\n        return new_state\n\n    def _setup_state(self):\n        \"\"\"Construct dA and dB for discretized state equation\"\"\"\n\n        # Construct dA and dB by using the stepping\n        self._setup_linear()\n        C = _r2c(\n            self.C\n        )  # Just returns a view that we use for finding dtype/device\n\n        state = torch.eye(\n            2 * self.N, dtype=C.dtype, device=C.device\n        ).unsqueeze(\n            -2\n        )  # (N 1 N)\n        dA = self._step_state_linear(state=state)\n        dA = rearrange(dA, \"n h m -> h m n\")\n\n        u = C.new_ones(self.H)\n        dB = self._step_state_linear(u=u)\n        dB = _conj(dB)\n        dB = rearrange(dB, \"1 h n -> h n\")  # (H N)\n        return dA, dB\n\n    def _step_state(self, u, state):\n        \"\"\"Must be called after self.default_state() is used to construct an initial 
state!\"\"\"\n        next_state = self.state_contraction(\n            self.dA, state\n        ) + self.input_contraction(self.dB, u)\n        return next_state\n\n    def _setup_step(self, mode=\"dense\"):\n        \"\"\"Set up dA, dB, dC discretized parameters for stepping\"\"\"\n        self.dA, self.dB = self._setup_state()\n\n        # Calculate original C\n        C = _conj(_r2c(self.C))  # (H C N)\n        if self.L.item() == 0:\n            dC = C\n        else:\n            # self.C represents C_tilde\n            dA_L = power(self.L.item(), self.dA)\n            I = torch.eye(self.dA.size(-1)).to(dA_L)\n\n            dC = torch.linalg.solve(\n                I - dA_L.transpose(-1, -2),\n                C.unsqueeze(-1),\n            ).squeeze(-1)\n        self.dC = dC\n\n        # Do special preprocessing for different step modes\n\n        self._step_mode = mode\n        if mode == \"linear\":\n            # Linear case: special step function for the state, we need to handle output\n            # use conjugate symmetry by default, which affects the output projection\n            self.dC = 2 * self.dC[:, :, : self.N]\n        elif mode == \"diagonal\":\n            # Eigendecomposition of the A matrix\n            L, V = torch.linalg.eig(self.dA)\n            V_inv = torch.linalg.inv(V)\n            # Check that the eigendedecomposition is correct\n            if self.verbose:\n                print(\n                    \"Diagonalization error:\",\n                    torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA),\n                )\n\n            # Change the parameterization to diagonalize\n            self.dA = L\n            self.dB = contract(\"h n m, h m -> h n\", V_inv, self.dB)\n            self.dC = contract(\"h n m, c h n -> c h m\", V, self.dC)\n\n        elif mode == \"dense\":\n            pass\n        else:\n            raise NotImplementedError(\n                \"NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}\"\n  
          )\n\n    def default_state(self, *batch_shape):\n        C = _r2c(self.C)\n        N = C.size(-1)\n        H = C.size(-2)\n\n        # Cache the tensor contractions we will later do, for efficiency\n        # These are put in this function because they depend on the batch size\n        step_mode = getattr(\n            self, \"_step_mode\", \"dense\"\n        )  # Used in default_state, which is called without _setup_step() in forward_state()\n        if step_mode != \"linear\":\n            N *= 2\n\n            if step_mode == \"diagonal\":\n                self.state_contraction = contract_expression(\n                    \"h n, ... h n -> ... h n\",\n                    (H, N),\n                    batch_shape + (H, N),\n                )\n            else:\n                # Dense (quadratic) case: expand all terms\n                self.state_contraction = contract_expression(\n                    \"h m n, ... h n -> ... h m\",\n                    (H, N, N),\n                    batch_shape + (H, N),\n                )\n\n            self.input_contraction = contract_expression(\n                \"h n, ... h -> ... h n\",\n                (H, N),  # self.dB.shape\n                batch_shape + (H,),\n            )\n\n        self.output_contraction = contract_expression(\n            \"c h n, ... h n -> ... 
c h\",\n            (C.shape[0], H, N),  # self.dC.shape\n            batch_shape + (H, N),\n        )\n\n        state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device)\n        return state\n\n    def step(self, u, state):\n        \"\"\"Must have called self._setup_step() and created state with self.default_state() before calling this\"\"\"\n\n        if self._step_mode == \"linear\":\n            new_state = self._step_state_linear(u, state)\n        else:\n            new_state = self._step_state(u, state)\n        y = self.output_contraction(self.dC, new_state)\n        return y.real, new_state\n\n\nclass SSKernelDiag(OptimModule):\n    \"\"\"Version using (complex) diagonal state matrix (S4D)\"\"\"\n\n    def __init__(\n        self,\n        A,\n        B,\n        C,\n        log_dt,\n        L=None,\n        disc=\"bilinear\",\n        real_type=\"exp\",\n        lr=None,\n        bandlimit=None,\n    ):\n        super().__init__()\n        self.L = L\n        self.disc = disc\n        self.bandlimit = bandlimit\n        self.real_type = real_type\n\n        # Rank of low-rank correction\n        assert A.size(-1) == C.size(-1)\n        self.H = log_dt.size(-1)\n        self.N = A.size(-1)\n        assert A.size(-2) == B.size(-2)  # Number of independent SSMs trained\n        assert self.H % A.size(-2) == 0\n        self.n_ssm = A.size(-2)\n        self.repeat = self.H // A.size(0)\n\n        self.channels = C.shape[0]\n        self.C = nn.Parameter(_c2r(_resolve_conj(C)))\n\n        # Register parameters\n        if lr is None or isinstance(lr, float):\n            lr_dict = {}\n        else:\n            lr_dict, lr = lr, None\n\n        self.register(\"log_dt\", log_dt, lr_dict.get(\"dt\", lr))\n        self.register(\"B\", _c2r(B), lr_dict.get(\"B\", lr))\n        self.register(\"inv_A_real\", self._A_init(A.real), lr_dict.get(\"A\", lr))\n        self.register(\"A_imag\", A.imag, lr_dict.get(\"A\", lr))\n\n    def _A_init(self, 
A_real):\n        A_real = torch.clamp(A_real, max=-1e-4)\n        if self.real_type == \"none\":\n            return -A_real\n        elif self.real_type == \"exp\":\n            return torch.log(\n                -A_real\n            )  # Some of the HiPPO methods have real part 0\n        elif self.real_type == \"relu\":\n            return -A_real\n        elif self.real_type == \"sigmoid\":\n            return torch.logit(-A_real)\n        elif self.real_type == \"softplus\":\n            return torch.log(torch.exp(-A_real) - 1)\n        else:\n            raise NotImplementedError\n\n    def _A(self):\n        # Get the internal A (diagonal) parameter\n        if self.real_type == \"none\":\n            A_real = -self.inv_A_real\n        elif self.real_type == \"exp\":\n            A_real = -torch.exp(self.inv_A_real)\n        elif self.real_type == \"relu\":\n            # JAX version seems to NaN if you alloA 0's, although this code Aas fine Aithout it\n            A_real = -F.relu(self.inv_A_real) - 1e-4\n        elif self.real_type == \"sigmoid\":\n            A_real = -F.sigmoid(self.inv_A_real)\n        elif self.real_type == \"softplus\":\n            A_real = -F.softplus(self.inv_A_real)\n        else:\n            raise NotImplementedError\n        A = A_real + 1j * self.A_imag\n        return A\n\n    def forward(self, L, state=None, rate=1.0, u=None):\n        \"\"\"\n        state: (B, H, N) initial state\n        rate: sampling rate factor\n        L: target length\n\n        returns:\n        (C, H, L) convolution kernel (generally C=1)\n        (B, H, L) output from initial state\n        \"\"\"\n\n        dt = torch.exp(self.log_dt) * rate  # (H)\n        C = _r2c(self.C)  # (C H N)\n        A = self._A()  # (H N)\n\n        B = _r2c(self.B)\n        B = repeat(B, \"t n -> 1 (v t) n\", v=self.repeat)\n\n        if self.bandlimit is not None:\n            freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi)  # (H, N)\n            mask = 
torch.where(freqs < self.bandlimit * 0.5, 1, 0)\n            C = C * mask\n\n        # Incorporate dt into A\n        A = repeat(A, \"t n -> (v t) n\", v=self.repeat)\n        dtA = A * dt.unsqueeze(-1)  # (H N)\n\n        # Augment B with state\n        if state is not None:\n            s = state / dt.unsqueeze(-1)\n            if self.disc == \"bilinear\":\n                s = s * (1.0 + dtA / 2)\n            elif self.disc == \"zoh\":\n                s = s * dtA * dtA.exp() / (dtA.exp() - 1.0)\n            B = torch.cat([s, B], dim=-3)  # (1+B H N)\n\n        C = (B[:, None, :, :] * C).view(-1, self.H, self.N)\n        if self.disc == \"zoh\":\n            # Power up\n            C = C * (torch.exp(dtA) - 1.0) / A\n            K = log_vandermonde(C, dtA, L)  # (H L)\n        elif self.disc == \"bilinear\":\n            C = (\n                C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)\n            )  # or * dtA / A\n            dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)\n            K = log_vandermonde(C, dA.log(), L)\n        elif self.disc == \"dss\":\n            # Implementation from DSS meant for case when real eigenvalues can be positive\n            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device)  # [H N L]\n            A_gt_0 = A.real > 0  # [N]\n            if A_gt_0.any():\n                with torch.no_grad():\n                    P_max = dtA * (A_gt_0 * (L - 1))  # [H N]\n                P = P - P_max.unsqueeze(-1)  # [H N L]\n            S = P.exp()  # [H N L]\n\n            dtA_neg = dtA * (1 - 2 * A_gt_0)  # [H N]\n            num = dtA_neg.exp() - 1  # [H N]\n            den = (dtA_neg * L).exp() - 1  # [H N]\n\n            # Inline reciprocal function for DSS logic\n            x = den * A\n            x_conj = _resolve_conj(x)\n            r = x_conj / (x * x_conj + 1e-7)\n\n            C = C * num * r  # [C H N]\n            K = contract(\"chn,hnl->chl\", C, S).float()\n        else:\n            assert False, f\"{self.disc} not 
supported\"\n\n        K = K.view(-1, self.channels, self.H, L)  # (1+B C H L)\n        if state is not None:\n            K_state = K[:-1, :, :, :]  # (B C H L)\n        else:\n            K_state = None\n        K = K[-1, :, :, :]  # (C H L)\n        return K, K_state\n\n    def _setup_step(self):\n        # These methods are organized like this to be compatible with the NPLR kernel interface\n        dt = torch.exp(self.log_dt)  # (H)\n        B = _r2c(self.B)  # (H N)\n        C = _r2c(self.C)  # (C H N)\n        self.dC = C\n        A = self._A()  # (H N)\n\n        A = repeat(A, \"t n -> (v t) n\", v=self.repeat)\n        B = repeat(B, \"t n -> (v t) n\", v=self.repeat)\n\n        # Incorporate dt into A\n        dtA = A * dt.unsqueeze(-1)  # (H N)\n        if self.disc == \"zoh\":\n            self.dA = torch.exp(dtA)  # (H N)\n            self.dB = B * (torch.exp(dtA) - 1.0) / A  # (C H N)\n        elif self.disc == \"bilinear\":\n            self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)\n            self.dB = (\n                B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)\n            )  # or * dtA / A\n\n    def default_state(self, *batch_shape):\n        C = _r2c(self.C)\n        state = torch.zeros(\n            *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device\n        )\n        return state\n\n    def step(self, u, state):\n        next_state = contract(\n            \"h n, b h n -> b h n\", self.dA, state\n        ) + contract(\"h n, b h -> b h n\", self.dB, u)\n        y = contract(\"c h n, b h n -> b c h\", self.dC, next_state)\n        return 2 * y.real, next_state\n\n    def forward_state(self, u, state):\n        self._setup_step()\n        AL = self.dA ** u.size(-1)\n        u = u.flip(-1).to(self.dA).contiguous()  # (B H L)\n        v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))\n        next_state = AL * state + v\n        return next_state\n\n\nclass SSKernel(nn.Module):\n    \"\"\"Wrapper around 
SSKernel parameterizations.\n\n    The SSKernel is expected to support the interface\n    forward()\n    default_state()\n    _setup_step()\n    step()\n    \"\"\"\n\n    def __init__(\n        self,\n        H,\n        N=64,\n        L=None,\n        measure=\"legs\",\n        rank=1,\n        channels=1,\n        dt_min=0.001,\n        dt_max=0.1,\n        deterministic=False,\n        lr=None,\n        mode=\"nplr\",\n        n_ssm=None,\n        verbose=False,\n        measure_args={},\n        **kernel_args,\n    ):\n        \"\"\"State Space Kernel which computes the convolution kernel $\\\\bar{K}$\n\n        H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config.\n        N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doens't affect speed much.\n        L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known.\n        measure: Options for initialization of (A, B). For NPLR mode, recommendations are \"legs\", \"fout\", \"hippo\" (combination of both). For Diag mode, recommendations are \"diag-inv\", \"diag-lin\", \"diag-legs\", and \"diag\" (combination of diag-inv and diag-lin)\n        rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure \"legt\"\n        channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate \"heads\" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead\n        dt_min, dt_max: min and max values for the step size dt (\\Delta)\n        mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing\n        n_ssm: Number of independent trainable (A, B) SSMs, e.g. 
n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H\n        lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameers (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters.\n        \"\"\"\n        super().__init__()\n        self.N = N\n        self.H = H\n        dtype, cdtype = torch.float, torch.cfloat\n        self.channels = channels\n        self.n_ssm = n_ssm if n_ssm is not None else H\n        self.mode = mode\n        self.verbose = verbose\n        self.kernel_args = kernel_args\n\n        # Generate dt\n        if deterministic:\n            log_dt = torch.exp(\n                torch.linspace(math.log(dt_min), math.log(dt_max), H)\n            )\n        else:\n            log_dt = torch.rand(self.H, dtype=dtype) * (\n                math.log(dt_max) - math.log(dt_min)\n            ) + math.log(dt_min)\n\n        # Compute the preprocessed representation\n        w, P, B, V = combination(\n            measure, self.N, rank, self.n_ssm, **measure_args\n        )\n\n        # Broadcast C to have H channels\n        if deterministic:\n            C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype)\n            C[:, :, :1] = 1.0\n            C = contract(\n                \"hmn, chn -> chm\", V.conj().transpose(-1, -2), C\n            )  # V^* C\n            C = (\n                repeat(C, \"c t n -> c (v t) n\", v=self.n_ssm // C.size(-2))\n                .clone()\n                .contiguous()\n            )\n        else:\n            C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype)\n\n        # Broadcast other parameters to have n_ssm copies\n        assert (\n            self.n_ssm % B.size(-2) == 0\n            and 
self.n_ssm % P.size(-2) == 0\n            and self.n_ssm % w.size(-2) == 0\n        )\n        # Broadcast tensors to n_ssm copies\n        # These will be the parameters, so make sure tensors are materialized and contiguous\n        B = (\n            repeat(B, \"t n -> (v t) n\", v=self.n_ssm // B.size(-2))\n            .clone()\n            .contiguous()\n        )\n        P = (\n            repeat(P, \"r t n -> r (v t) n\", v=self.n_ssm // P.size(-2))\n            .clone()\n            .contiguous()\n        )\n        w = (\n            repeat(w, \"t n -> (v t) n\", v=self.n_ssm // w.size(-2))\n            .clone()\n            .contiguous()\n        )\n\n        if mode == \"nplr\":\n            self.kernel = SSKernelNPLR(\n                w,\n                P,\n                B,\n                C,\n                log_dt,\n                L=L,\n                lr=lr,\n                verbose=verbose,\n                **kernel_args,\n            )\n        elif mode == \"diag\":\n            if not measure.startswith(\"diag\"):\n                log.warning(\n                    \"Diagonal kernel (S4D) activated but initialization is not intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or 'diag-legs' for the main variants, or 'diag' for a combination of S4D-Lin and S4D-Inv.\"\n                )\n            C = C * repeat(B, \"t n -> (v t) n\", v=H // self.n_ssm)\n            self.kernel = SSKernelDiag(\n                w,\n                B,\n                C,\n                log_dt,\n                L=L,\n                lr=lr,\n                **kernel_args,\n            )\n        else:\n            raise NotImplementedError(f\"{mode=} is not valid\")\n\n    def forward(self, state=None, L=None, rate=1.0):\n        return self.kernel(state=state, L=L, rate=rate)\n\n    @torch.no_grad()\n    def forward_state(self, u, state):\n        \"\"\"Forward the state through a sequence, i.e. 
computes the state after passing chunk through SSM\n\n        state: (B, H, N)\n        u: (B, H, L)\n\n        Returns: (B, H, N)\n        \"\"\"\n\n        if hasattr(self.kernel, \"forward_state\"):\n            return self.kernel.forward_state(u, state)\n\n        dA, dB = self.kernel._setup_state()  # Construct dA, dB matrices\n        # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N)\n\n        conj = state.size(-1) != dA.size(-1)\n        if conj:\n            state = _conj(state)\n\n        v = contract(\n            \"h n, b h l -> b h n l\", dB, u.flip(-1)\n        )  # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2)\n        AL, v = power(u.size(-1), dA, v)\n        next_state = contract(\"h m n, b h n -> b h m\", AL, state)\n        next_state = next_state + v\n\n        if conj:\n            next_state = next_state[..., : next_state.size(-1) // 2]\n        return next_state\n\n    def _setup_step(self, **kwargs):\n        # This method is intended to be private so that setting up an S4 module with\n        # ```\n        # if hasattr(module, 'setup_step'): module.setup_step()\n        # ```\n        # will not trigger this method multiple times\n        self.kernel._setup_step(**kwargs)\n\n    def step(self, u, state, **kwargs):\n        y, state = self.kernel.step(u, state, **kwargs)\n        return y, state\n\n    def default_state(self, *args, **kwargs):\n        return self.kernel.default_state(*args, **kwargs)\n\n\nclass S4(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        d_state=64,\n        l_max=None,\n        channels=1,\n        mode=\"nplr\",\n        measure=\"legs\",\n        bidirectional=False,\n        # Arguments for position-wise feedforward components\n        activation=\"gelu\",\n        postact=\"glu\",\n        hyper_act=None,\n        dropout=0.0,\n        tie_dropout=False,\n        bottleneck=None,\n        gate=None,\n        transposed=True,\n        verbose=False,\n        # SSM Kernel 
arguments\n        **kernel_args,\n    ):\n        \"\"\"\n        d_state: the dimension of the state, also denoted by N\n        l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel\n        channels: can be interpreted as a number of \"heads\"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models\n        bidirectional: if True, convolution kernel will be two-sided\n\n        Position-wise feedforward components:\n        --------------------\n        activation: activation in between SS and FF\n        postact: activation after FF\n        hyper_act: use a \"hypernetwork\" multiplication (experimental)\n        dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d\n\n        Other arguments:\n        --------------------\n        transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension]\n        gate: add gated activation (GSS)\n        bottleneck: reduce SSM dimension (GSS)\n\n        See the class SSKernel for the kernel constructor which accepts kernel_args. 
Relevant options that are worth considering and tuning include \"mode\" + \"measure\", \"dt_min\", \"dt_max\", \"lr\"\n\n        Other options are all experimental and should not need to be configured\n        \"\"\"\n\n        super().__init__()\n        if verbose:\n            log.info(\n                f\"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})\"\n            )\n\n        self.d_model = d_model\n        self.H = d_model\n        self.N = d_state\n        self.L = l_max\n        self.bidirectional = bidirectional\n        self.channels = channels\n        self.transposed = transposed\n\n        self.gate = gate\n        self.bottleneck = bottleneck\n\n        if bottleneck is not None:\n            self.H = self.H // bottleneck\n            self.input_linear = LinearActivation(\n                self.d_model,\n                self.H,\n                transposed=self.transposed,\n                activation=activation,\n                activate=True,\n            )\n\n        if gate is not None:\n            self.input_gate = LinearActivation(\n                self.d_model,\n                self.d_model * gate,\n                transposed=self.transposed,\n                activation=activation,\n                activate=True,\n            )\n            self.output_gate = LinearActivation(\n                self.d_model * gate,\n                self.d_model,\n                transposed=self.transposed,\n                activation=None,\n                activate=False,\n            )\n\n        # optional multiplicative modulation GLU-style\n        # https://arxiv.org/abs/2002.05202\n        self.hyper = hyper_act is not None\n        if self.hyper:\n            channels *= 2\n            self.hyper_activation = Activation(hyper_act)\n\n        self.D = nn.Parameter(torch.randn(channels, self.H))\n\n        if self.bidirectional:\n            channels *= 2\n\n        # SSM Kernel\n        self.kernel = SSKernel(\n            self.H,\n            
N=self.N,\n            L=self.L,\n            channels=channels,\n            verbose=verbose,\n            mode=mode,\n            measure=measure,\n            **kernel_args,\n        )\n\n        # Pointwise\n        self.activation = Activation(activation)\n        dropout_fn = DropoutNd if tie_dropout else nn.Dropout\n        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()\n        # position-wise output transform to mix features\n        self.output_linear = LinearActivation(\n            self.H * self.channels,\n            self.d_model * (1 if self.gate is None else self.gate),\n            transposed=self.transposed,\n            activation=postact,\n            activate=True,\n        )\n\n    def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs):\n        \"\"\"\n        u: (B H L) if self.transposed else (B L H)\n        state: (H N) never needed unless you know what you're doing\n\n        Returns: same shape as u\n        \"\"\"\n        if not self.transposed:\n            u = u.transpose(-1, -2)\n        L = u.size(-1)\n\n        # Mask out padding tokens\n        if isinstance(lengths, int):\n            if lengths != L:\n                lengths = torch.tensor(\n                    lengths, dtype=torch.long, device=u.device\n                )\n            else:\n                lengths = None\n        if lengths is not None:\n            assert (\n                isinstance(lengths, torch.Tensor)\n                and lengths.ndim == 1\n                and lengths.size(0) in [1, u.size(0)]\n            )\n            mask = torch.where(\n                torch.arange(L, device=lengths.device)\n                < lengths[:, None, None],\n                1.0,\n                0.0,\n            )\n            u = u * mask\n\n        if self.gate is not None:\n            v = self.input_gate(u)\n        if self.bottleneck is not None:\n            u = self.input_linear(u)\n\n        # Compute SS Kernel\n        L_kernel 
= L if self.L is None else min(L, round(self.L / rate))\n        k, k_state = self.kernel(\n            L=L_kernel, rate=rate, state=state\n        )  # (C H L) (B C H L)\n\n        # Convolution\n        if self.bidirectional:\n            k0, k1 = rearrange(k, \"(s c) h l -> s c h l\", s=2)\n            k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0))\n        k_f = torch.fft.rfft(k, n=L_kernel + L)  # (C H L)\n        u_f = torch.fft.rfft(u, n=L_kernel + L)  # (B H L)\n        y_f = contract(\"bhl,chl->bchl\", u_f, k_f)\n        y = torch.fft.irfft(y_f, n=L_kernel + L)[..., :L]  # (B C H L)\n\n        # Compute D term in state space equation - essentially a skip connection\n        y = y + contract(\"bhl,ch->bchl\", u, self.D)\n\n        # Compute state update\n        if state is not None:\n            assert (\n                not self.bidirectional\n            ), \"Bidirectional not supported with state forwarding\"\n            y = y + k_state  #\n            next_state = self.kernel.forward_state(u, state)\n        else:\n            next_state = None\n\n        # Optional hyper-network multiplication\n        if self.hyper:\n            y, yh = rearrange(y, \"b (s c) h l -> s b c h l\", s=2)\n            y = self.hyper_activation(yh) * y\n\n        # Reshape to flatten channels\n        y = rearrange(y, \"... c h l -> ... (c h) l\")\n\n        y = self.dropout(self.activation(y))\n\n        if not self.transposed:\n            y = y.transpose(-1, -2)\n\n        y = self.output_linear(y)\n\n        if self.gate is not None:\n            y = self.output_gate(y * v)\n\n        return y, next_state\n\n    def setup_step(self, **kwargs):\n        self.kernel._setup_step(**kwargs)\n\n    def step(self, u, state):\n        \"\"\"Step one time step as a recurrent model. 
Intended to be used during validation.\n\n        u: (B H)\n        state: (B H N)\n        Returns: output (B H), state (B H N)\n        \"\"\"\n        assert not self.training\n\n        y, next_state = self.kernel.step(u, state)  # (B C H)\n        y = y + u.unsqueeze(-2) * self.D\n        y = rearrange(y, \"b c h -> b (c h)\")\n        y = self.activation(y)\n        if self.transposed:\n            y = self.output_linear(y.unsqueeze(-1)).squeeze(-1)\n        else:\n            y = self.output_linear(y)\n        return y, next_state\n\n    def default_state(self, *batch_shape, device=None):\n        # kernel is not a SequenceModule so it doesn't need to adhere to same interface\n        # the kernel will know the device of its own parameters\n        return self.kernel.default_state(*batch_shape)\n\n    @property\n    def d_output(self):\n        return self.d_model\n"
  },
  {
    "path": "probts/model/nn/arch/S4/s4_backbones.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport math\n\nimport torch\nfrom torch import nn\n\nfrom probts.model.nn.arch.S4.s4 import S4\n\n\nclass SinusoidalPositionEmbeddings(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.dim = dim\n\n    def forward(self, time):\n        device = time.device\n        half_dim = self.dim // 2\n        embeddings = math.log(10000) / (half_dim - 1)\n        embeddings = torch.exp(\n            torch.arange(half_dim, device=device) * -embeddings\n        )\n        embeddings = time[:, None] * embeddings[None, :]\n        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)\n        return embeddings\n\n\nclass S4Layer(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        dropout=0.0,\n        mode=\"nplr\",\n        l_max=None,\n        measure=\"legs\"\n    ):\n        super().__init__()\n        self.layer = S4(\n            d_model=d_model,\n            d_state=128,\n            bidirectional=True,\n            dropout=dropout,\n            transposed=True,\n            postact=None,\n            mode=mode,\n            l_max=l_max,\n            measure=measure,\n        )\n        self.norm = nn.LayerNorm(d_model)\n        self.dropout = (\n            nn.Dropout1d(dropout) if dropout > 0.0 else nn.Identity()\n        )\n\n    def forward(self, x):\n        \"\"\"\n        Input x is shape (B, d_input, L)\n        \"\"\"\n        z = x\n        # Prenorm\n        z = self.norm(z.transpose(-1, -2)).transpose(-1, -2)\n        # Apply layer: we ignore the state input and output for training\n        z, _ = self.layer(z)\n        # Dropout on the output of the layer\n        z = self.dropout(z)\n        # Residual connection\n        x = z + x\n        return x, None\n\n    def default_state(self, *args, **kwargs):\n        return self.layer.default_state(*args, **kwargs)\n\n    
def step(self, x, state, **kwargs):\n        z = x\n        # Prenorm\n        z = self.norm(z.transpose(-1, -2)).transpose(-1, -2)\n        # Apply layer\n        z, state = self.layer.step(z, state, **kwargs)\n        # Residual connection\n        x = z + x\n        return x, state\n\n\nclass S4Block(nn.Module):\n    def __init__(self, d_model, dropout=0.0, expand=2, num_features=0,mode=\"nplr\",l_max=None,measure=\"legs\"):\n        super().__init__()\n        self.s4block = S4Layer(d_model, dropout=dropout,mode=mode,l_max=l_max,measure=measure)\n\n        self.time_linear = nn.Linear(d_model, d_model)\n        self.tanh = nn.Tanh()\n        self.sigm = nn.Sigmoid()\n        self.out_linear1 = nn.Conv1d(\n            in_channels=d_model, out_channels=d_model, kernel_size=1\n        )\n        self.out_linear2 = nn.Conv1d(\n            in_channels=d_model, out_channels=d_model, kernel_size=1\n        )\n        self.feature_encoder = nn.Conv1d(num_features, d_model, kernel_size=1)\n\n    def forward(self, x, t, features=None):\n        t = self.time_linear(t)[:, None, :].repeat(1, x.shape[2], 1)\n        t = t.transpose(-1, -2)\n        out, _ = self.s4block(x + t)\n        if features is not None:\n            out = out + self.feature_encoder(features)\n        out = self.tanh(out) * self.sigm(out)\n        out1 = self.out_linear1(out)\n        out2 = self.out_linear2(out)\n        return out1 + x, out2\n\n\ndef Conv1dKaiming(in_channels, out_channels, kernel_size):\n    layer = nn.Conv1d(in_channels, out_channels, kernel_size)\n    nn.init.kaiming_normal_(layer.weight)\n    return layer\n\n\nclass BackboneModel(nn.Module):\n    def __init__(\n        self,\n        input_dim,\n        hidden_dim,\n        output_dim,\n        step_emb,\n        num_residual_blocks,\n        num_features,\n        residual_block=\"s4\",\n        mode=\"nplr\",\n        measure=\"legs\",\n        l_max=None,\n        dropout=0.0,\n        init_skip=True,\n    ):\n        
super().__init__()\n        if residual_block == \"s4\":\n            residual_block = S4Block\n        else:\n            raise ValueError(f\"Unknown residual block {residual_block}\")\n        self.input_init = nn.Sequential(\n            nn.Linear(input_dim, hidden_dim),\n            nn.ReLU(),\n        )\n        self.time_init = nn.Sequential(\n            nn.Linear(step_emb, hidden_dim),\n            nn.SiLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.SiLU(),\n        )\n        self.out_linear = nn.Sequential(\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, output_dim),\n        )\n        residual_blocks = []\n        for i in range(num_residual_blocks):\n            residual_blocks.append(\n                residual_block(\n                    hidden_dim, \n                    num_features=num_features, \n                    dropout=dropout, \n                    mode=mode,l_max=l_max,\n                    measure=measure,\n                )\n            )\n        self.residual_blocks = nn.ModuleList(residual_blocks)\n        self.step_embedding = SinusoidalPositionEmbeddings(step_emb)\n        self.init_skip = init_skip\n\n    def forward(self, input, t, features=None):\n        x = self.input_init(input)  # B, L ,C\n        t = self.time_init(self.step_embedding(t))\n        x = x.transpose(-1, -2)\n        if features is not None:\n            features = features.transpose(-1, -2)\n        skips = []\n        for layer in self.residual_blocks:\n            x, skip = layer(x, t, features)\n            skips.append(skip)\n\n        skip = torch.stack(skips).sum(0)\n        skip = skip.transpose(-1, -2)\n        out = self.out_linear(skip)\n        if self.init_skip:\n            out = out + input\n        return out\n"
  },
  {
    "path": "probts/model/nn/arch/TSMixer_layers.py",
    "content": "from __future__ import annotations\nfrom collections.abc import Callable\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor, nn\nimport sys\n\nclass TimeBatchNorm2d(nn.BatchNorm1d):\n    \"\"\"A batch normalization layer that normalizes over the last two dimensions of a\n    sequence in PyTorch, mimicking Keras behavior.\n\n    This class extends nn.BatchNorm1d to apply batch normalization across time and\n    feature dimensions.\n\n    Attributes:\n        num_time_steps (int): Number of time steps in the input.\n        num_channels (int): Number of channels in the input.\n    \"\"\"\n\n    def __init__(self, normalized_shape: tuple[int, int]):\n        \"\"\"Initializes the TimeBatchNorm2d module.\n\n        Args:\n            normalized_shape (tuple[int, int]): A tuple (num_time_steps, num_channels)\n                representing the shape of the time and feature dimensions to normalize.\n        \"\"\"\n        num_time_steps, num_channels = normalized_shape\n        \n        super().__init__(num_channels * num_time_steps)\n        self.num_time_steps = num_time_steps\n        self.num_channels = num_channels\n\n    def forward(self, x: Tensor) -> Tensor:\n        \"\"\"Applies the batch normalization over the last two dimensions of the input tensor.\n\n        Args:\n            x (Tensor): A 3D tensor with shape (N, S, C), where N is the batch size,\n                S is the number of time steps, and C is the number of channels.\n\n        Returns:\n            Tensor: A 3D tensor with batch normalization applied over the last two dims.\n\n        Raises:\n            ValueError: If the input tensor is not 3D.\n        \"\"\"\n        if x.ndim != 3:\n            raise ValueError(f\"Expected 3D input tensor, but got {x.ndim}D tensor instead.\")\n\n        # Reshaping input to combine time and feature dimensions for normalization\n        x = x.reshape(x.shape[0], -1, 1)\n\n        # Applying batch normalization\n       
 x = super().forward(x)\n\n        # Reshaping back to original dimensions (N, S, C)\n        x = x.reshape(x.shape[0], self.num_time_steps, self.num_channels)\n\n        return x\n\n\nclass FeatureMixing(nn.Module):\n    \"\"\"A module for feature mixing with flexibility in normalization and activation.\n\n    This module provides options for batch normalization before or after mixing features,\n    uses dropout for regularization, and allows for different activation functions.\n\n    Args:\n        sequence_length: The length of the sequences to be transformed.\n        input_channels: The number of input channels to the module.\n        output_channels: The number of output channels from the module.\n        ff_dim: The dimension of the feed-forward network internal to the module.\n        activation_fn: The activation function used within the feed-forward network.\n        dropout_rate: The dropout probability used for regularization.\n        normalize_before: A boolean indicating whether to apply normalization before\n            the rest of the operations.\n    \"\"\"\n\n    def __init__(\n        self,\n        sequence_length: int,\n        input_channels: int,\n        output_channels: int,\n        ff_dim: int,\n        activation_fn: Callable[[torch.Tensor], torch.Tensor] = F.relu,\n        dropout_rate: float = 0.1,\n        normalize_before: bool = True,\n        norm_type: type[nn.Module] = TimeBatchNorm2d,\n    ):\n        \"\"\"Initializes the FeatureMixing module with the provided parameters.\"\"\"\n        super().__init__()\n\n        self.norm_before = (\n            norm_type((sequence_length, input_channels))\n            if normalize_before\n            else nn.Identity()\n        )\n        self.norm_after = (\n            norm_type((sequence_length, output_channels))\n            if not normalize_before\n            else nn.Identity()\n        )\n\n        self.activation_fn = activation_fn\n        self.dropout = 
nn.Dropout(dropout_rate)\n        self.fc1 = nn.Linear(input_channels, ff_dim)\n        self.fc2 = nn.Linear(ff_dim, output_channels)\n\n        self.projection = (\n            nn.Linear(input_channels, output_channels)\n            if input_channels != output_channels\n            else nn.Identity()\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Forward pass for the FeatureMixing module.\n\n        Args:\n            x: A 3D tensor with shape (N, C, L) where C is the channel dimension.\n\n        Returns:\n            The output tensor after feature mixing.\n        \"\"\"\n        x_proj = self.projection(x)\n\n        x = self.norm_before(x)\n\n        x = self.fc1(x)  # Apply the first linear transformation.\n        x = self.activation_fn(x)  # Apply the activation function.\n        x = self.dropout(x)  # Apply dropout for regularization.\n        x = self.fc2(x)  # Apply the second linear transformation.\n        x = self.dropout(x)  # Apply dropout again if needed.\n\n        x = x_proj + x  # Add the projection shortcut to the transformed features.\n\n        return self.norm_after(x)\n\n\nclass ConditionalFeatureMixing(nn.Module):\n    \"\"\"Conditional feature mixing module that incorporates static features.\n\n    This module extends the feature mixing process by including static features. 
It uses\n    a linear transformation to integrate static features into the dynamic feature space,\n    then applies the feature mixing on the concatenated features.\n\n    Args:\n        input_channels: The number of input channels of the dynamic features.\n        output_channels: The number of output channels after feature mixing.\n        static_channels: The number of channels in the static feature input.\n        ff_dim: The inner dimension of the feedforward network used in feature mixing.\n        activation_fn: The activation function used in feature mixing.\n        dropout_rate: The dropout probability used in the feature mixing operation.\n    \"\"\"\n\n    def __init__(\n        self,\n        sequence_length: int,\n        input_channels: int,\n        output_channels: int,\n        static_channels: int,\n        ff_dim: int,\n        activation_fn: Callable = F.relu,\n        dropout_rate: float = 0.1,\n        normalize_before: bool = False,\n        norm_type: type[nn.Module] = nn.LayerNorm,\n    ):\n        super().__init__()\n\n        self.fr_static = nn.Linear(static_channels, output_channels)\n        self.fm = FeatureMixing(\n            sequence_length,\n            input_channels + output_channels,\n            output_channels,\n            ff_dim,\n            activation_fn,\n            dropout_rate,\n            normalize_before=normalize_before,\n            norm_type=norm_type,\n        )\n\n    def forward(\n        self, x: torch.Tensor, x_static: torch.Tensor\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Applies conditional feature mixing using both dynamic and static inputs.\n\n        Args:\n            x: A tensor representing dynamic features, typically with shape\n               [batch_size, time_steps, input_channels].\n            x_static: A tensor representing static features, typically with shape\n               [batch_size, static_channels].\n\n        Returns:\n            A tuple containing:\n            - 
The output tensor after applying conditional feature mixing.\n            - The transformed static features tensor for monitoring or further processing.\n        \"\"\"\n        v = self.fr_static(x_static)  # Transform static features to match output channels.\n        v = v.unsqueeze(1).repeat(\n            1, x.shape[1], 1\n        )  # Repeat static features across time steps.\n\n        return (\n            self.fm(\n                torch.cat([x, v], dim=-1)\n            ),  # Apply feature mixing on concatenated features.\n            v.detach(),  # Return detached static feature for monitoring or further use.\n        )\n\n\nclass TimeMixing(nn.Module):\n    \"\"\"Applies a transformation over the time dimension of a sequence.\n\n    This module applies a linear transformation followed by an activation function\n    and dropout over the sequence length of the input feature tensor after converting\n    feature maps to the time dimension and then back.\n\n    Args:\n        input_channels: The number of input channels to the module.\n        sequence_length: The length of the sequences to be transformed.\n        activation_fn: The activation function to be used after the linear transformation.\n        dropout_rate: The dropout probability to be used after the activation function.\n    \"\"\"\n\n    def __init__(\n        self,\n        sequence_length: int,\n        input_channels: int,\n        activation_fn: Callable = F.relu,\n        dropout_rate: float = 0.1,\n        norm_type: type[nn.Module] = TimeBatchNorm2d,\n    ):\n        \"\"\"Initializes the TimeMixing module with the specified parameters.\"\"\"\n        super().__init__()\n        self.norm = norm_type((sequence_length, input_channels))\n        self.activation_fn = activation_fn\n        self.dropout = nn.Dropout(dropout_rate)\n        self.fc1 = nn.Linear(sequence_length, sequence_length)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Applies the time mixing 
operations on the input tensor.\n\n        Args:\n            x: A 3D tensor with shape (N, C, L), where C = channel dimension and\n                L = sequence length.\n\n        Returns:\n            The normalized output tensor after time mixing transformations.\n        \"\"\"\n        x_temp = feature_to_time(\n            x\n        )  # Convert feature maps to time dimension. Assumes definition elsewhere.\n        x_temp = self.activation_fn(self.fc1(x_temp))\n        x_temp = self.dropout(x_temp)\n        x_res = time_to_feature(x_temp)  # Convert back from time to feature maps.\n\n        return self.norm(x + x_res)  # Apply normalization and combine with original input.\n\n\nclass MixerLayer(nn.Module):\n    \"\"\"A residual block that combines time and feature mixing for sequence data.\n\n    This module sequentially applies time mixing and feature mixing, which are forms\n    of data augmentation and feature transformation that can help in learning temporal\n    dependencies and feature interactions respectively.\n\n    Args:\n        sequence_length: The length of the input sequences.\n        input_channels: The number of input channels to the module.\n        output_channels: The number of output channels from the module.\n        ff_dim: The inner dimension of the feedforward network used in feature mixing.\n        activation_fn: The activation function used in both time and feature mixing.\n        dropout_rate: The dropout probability used in both mixing operations.\n    \"\"\"\n\n    def __init__(\n        self,\n        sequence_length: int,\n        input_channels: int,\n        output_channels: int,\n        ff_dim: int,\n        activation_fn: Callable = F.relu,\n        dropout_rate: float = 0.1,\n        normalize_before: bool = False,\n        norm_type: type[nn.Module] = nn.LayerNorm,\n    ):\n        \"\"\"Initializes the MixLayer with time and feature mixing modules.\"\"\"\n        super().__init__()\n\n        self.time_mixing = 
TimeMixing(\n            sequence_length,\n            input_channels,\n            activation_fn,\n            dropout_rate,\n            norm_type=norm_type,\n        )\n        self.feature_mixing = FeatureMixing(\n            sequence_length,\n            input_channels,\n            output_channels,\n            ff_dim,\n            activation_fn,\n            dropout_rate,\n            norm_type=norm_type,\n            normalize_before=normalize_before,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Forward pass for the MixLayer module.\n\n        Args:\n            x: A 3D tensor with shape (N, C, L) to be processed by the mixing layers.\n\n        Returns:\n            The output tensor after applying time and feature mixing operations.\n        \"\"\"\n        x = self.time_mixing(x)  # Apply time mixing first.\n        x = self.feature_mixing(x)  # Then apply feature mixing.\n\n        return x\n\n\nclass ConditionalMixerLayer(nn.Module):\n    \"\"\"Conditional mix layer combining time and feature mixing with static context.\n\n    This module combines time mixing and conditional feature mixing, where the latter\n    is influenced by static features. 
This allows the module to learn representations\n    that are influenced by both dynamic and static features.\n\n    Args:\n        sequence_length: The length of the input sequences.\n        input_channels: The number of input channels of the dynamic features.\n        output_channels: The number of output channels after feature mixing.\n        static_channels: The number of channels in the static feature input.\n        ff_dim: The inner dimension of the feedforward network used in feature mixing.\n        activation_fn: The activation function used in both mixing operations.\n        dropout_rate: The dropout probability used in both mixing operations.\n    \"\"\"\n\n    def __init__(\n        self,\n        sequence_length: int,\n        input_channels: int,\n        output_channels: int,\n        static_channels: int,\n        ff_dim: int,\n        activation_fn: Callable = F.relu,\n        dropout_rate: float = 0.1,\n        normalize_before: bool = False,\n        norm_type: type[nn.Module] = nn.LayerNorm,\n    ):\n        super().__init__()\n\n        self.time_mixing = TimeMixing(\n            sequence_length,\n            input_channels,\n            activation_fn,\n            dropout_rate,\n            norm_type=norm_type,\n        )\n        self.feature_mixing = ConditionalFeatureMixing(\n            sequence_length,\n            input_channels,\n            output_channels=output_channels,\n            static_channels=static_channels,\n            ff_dim=ff_dim,\n            activation_fn=activation_fn,\n            dropout_rate=dropout_rate,\n            normalize_before=normalize_before,\n            norm_type=norm_type,\n        )\n\n    def forward(self, x: torch.Tensor, x_static: torch.Tensor) -> torch.Tensor:\n        \"\"\"Forward pass for the conditional mix layer.\n\n        Args:\n            x: A tensor representing dynamic features, typically with shape\n               [batch_size, time_steps, input_channels].\n            x_static: A 
tensor representing static features, typically with shape\n               [batch_size, static_channels].\n\n        Returns:\n            The output tensor after applying time and conditional feature mixing.\n        \"\"\"\n        x = self.time_mixing(x)  # Apply time mixing first.\n        x, _ = self.feature_mixing(x, x_static)  # Then apply conditional feature mixing.\n\n        return x\n\n\ndef time_to_feature(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"Converts a time series tensor to a feature tensor.\"\"\"\n    return x.permute(0, 2, 1)\n\n\nfeature_to_time = time_to_feature"
  },
  {
    "path": "probts/model/nn/arch/TimesFMModule/__init__.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# # limitations under the License.\n\"\"\"TimesFM init file.\"\"\"\n# print(\n#     \"TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.\"\n# )\nfrom probts.model.nn.arch.TimesFMModule.timesfm_base import freq_map, TimesFmCheckpoint, TimesFmHparams, TimesFmBase\n\n# print(\"Loaded PyTorch TimesFM.\")\nfrom probts.model.nn.arch.TimesFMModule.timesfm_torch import TimesFmTorch as TimesFm\n"
  },
  {
    "path": "probts/model/nn/arch/TimesFMModule/patched_decoder.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Pax ML model for patched time-series decoder.\n\nThe file implements Residual MLPs, Patched Decoder layers and PAX ML models.\n\"\"\"\n\nimport dataclasses\nfrom typing import Optional, Tuple\n\nimport einshape as es\nfrom jax import lax\nimport jax.numpy as jnp\nfrom praxis import base_layer\nfrom praxis import base_model\nfrom praxis import layers\nfrom praxis import pax_fiddle\nfrom praxis import py_utils\nfrom praxis import pytypes\nfrom praxis.layers import activations\nfrom praxis.layers import embedding_softmax\nfrom praxis.layers import linears\nfrom praxis.layers import normalizations\nfrom praxis.layers import stochastics\nfrom praxis.layers import transformers\n\n# PAX shortcuts\nNestedMap = py_utils.NestedMap\nJTensor = pytypes.JTensor\n\nLayerTpl = pax_fiddle.Config[base_layer.BaseLayer]\ntemplate_field = base_layer.template_field\n\nPAD_VAL = 1123581321.0\nDEFAULT_QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n\n# NestedMap keys\n_INPUT_TS = \"input_ts\"\n_TARGET_FUTURE = \"actual_ts\"\n_INPUT_PADDING = \"input_padding\"\n_OUTPUT_TS = \"output_ts\"\n_FREQ = \"freq\"\n_OUTPUT_TOKENS = \"output_tokens\"\n_STATS = \"stats\"\n\n# Small numerical value.\n_TOLERANCE = 1e-7\n\n\ndef _shift_padded_seq(mask: JTensor, seq: JTensor) -> JTensor:\n  \"\"\"Shifts rows of seq based on the first 0 in each row of the mask.\"\"\"\n  num = seq.shape[1]\n\n  
# Find the index of the first 0 in each row of the mask\n  first_zero_idx = jnp.argmin(mask, axis=1)\n\n  # Create a range array for indexing\n  idx_range = jnp.arange(num)\n\n  def shift_row(carry, x):\n    seq_row, shift = x\n    shifted_idx = (idx_range - shift) % num\n    shifted_row = seq_row[shifted_idx]\n    return carry, shifted_row\n\n  # Use lax.scan to shift each row of seq based on the corresponding\n  # first_zero_idx.\n  _, shifted_seq = lax.scan(shift_row, None, (seq, first_zero_idx))\n\n  return shifted_seq\n\n\nclass ResidualBlock(base_layer.BaseLayer):\n  \"\"\"Simple feedforward block with residual connection.\n\n  Attributes:\n    input_dims: input dimension.\n    hidden_dims: hidden dimension.\n    output_dims: output dimension.\n    dropout_prob: dropout probability.\n    layer_norm: whether to use layer norm or not.\n    dropout_tpl: config for dropout.\n    ln_tpl: config for layer norm.\n    act_tpl: config for activation in hidden layer.\n  \"\"\"\n\n  input_dims: int = 0\n  hidden_dims: int = 0\n  output_dims: int = 0\n  dropout_prob: float = 0.0\n  layer_norm: bool = False\n  dropout_tpl: LayerTpl = template_field(stochastics.Dropout)\n  ln_tpl: LayerTpl = template_field(normalizations.LayerNorm)\n  act_tpl: LayerTpl = template_field(activations.Swish)\n\n  def setup(self):\n    lnorm_tpl = self.ln_tpl.clone()\n    lnorm_tpl.dim = self.output_dims\n    self.create_child(\"ln_layer\", lnorm_tpl)\n\n    dropout_tpl = self.dropout_tpl.clone()\n    dropout_tpl.keep_prob = 1.0 - self.dropout_prob\n    self.create_child(\"dropout\", dropout_tpl)\n\n    self.create_child(\n        \"hidden_layer\",\n        pax_fiddle.Config(\n            linears.FeedForward,\n            input_dims=self.input_dims,\n            output_dims=self.hidden_dims,\n            activation_tpl=self.act_tpl.clone(),\n        ),\n    )\n\n    self.create_child(\n        \"output_layer\",\n        pax_fiddle.Config(\n            linears.FeedForward,\n            
input_dims=self.hidden_dims,\n            output_dims=self.output_dims,\n            activation_tpl=pax_fiddle.Config(activations.Identity),\n        ),\n    )\n\n    self.create_child(\n        \"residual_layer\",\n        pax_fiddle.Config(\n            linears.FeedForward,\n            input_dims=self.input_dims,\n            output_dims=self.output_dims,\n            activation_tpl=pax_fiddle.Config(activations.Identity),\n        ),\n    )\n\n  def __call__(self, inputs: JTensor) -> JTensor:\n    hidden = self.hidden_layer(inputs)\n    output = self.output_layer(hidden)\n    output = self.dropout(output)\n    residual = self.residual_layer(inputs)\n    if self.layer_norm:\n      return self.ln_layer(output + residual)\n    else:\n      return output + residual\n\n\ndef _masked_mean_std(inputs: JTensor,\n                     padding: JTensor) -> Tuple[JTensor, JTensor]:\n  \"\"\"Calculates mean and standard deviation of arr across axis 1.\n\n  It should exclude values where pad is 1.\n\n  Args:\n    inputs: A JAX array of shape [b, n, p].\n    padding: A JAX array of shape [b, n, p] with values 0 or 1.\n\n  Returns:\n    A tuple containing the mean and standard deviation of arr. 
We return the\n    statistics of the first patch with more than three non-padded values.\n  \"\"\"\n  # Selecting the first pad with more than 3 unpadded values.\n  pad_sum = jnp.sum(1 - padding, axis=2)\n\n  def _get_patch_index(arr: JTensor):\n    indices = jnp.argmax(arr >= 3, axis=1)\n    row_sum = (arr >= 3).sum(axis=1)\n    return jnp.where(row_sum == 0, arr.shape[1] - 1, indices)\n\n  patch_indices = _get_patch_index(pad_sum)\n  bidxs = jnp.arange(inputs.shape[0])\n\n  arr = inputs[bidxs, patch_indices, :]\n  pad = padding[bidxs, patch_indices, :]\n\n  # Create a mask where P is 0\n  mask = 1 - pad\n\n  # Calculate the number of valid elements\n  num_valid_elements = jnp.sum(mask, axis=1)\n\n  num_valid_elements = jnp.where(num_valid_elements == 0, 1, num_valid_elements)\n\n  # Calculate the masked sum and squared sum of M\n  masked_sum = jnp.sum(arr * mask, axis=1)\n  masked_squared_sum = jnp.sum((arr * mask)**2, axis=1)\n\n  # Calculate the masked mean and standard deviation\n  masked_mean = masked_sum / num_valid_elements\n  masked_var = masked_squared_sum / num_valid_elements - masked_mean**2\n  masked_var = jnp.where(masked_var < 0.0, 0.0, masked_var)\n  masked_std = jnp.sqrt(masked_var)\n\n  return masked_mean, masked_std\n\n\ndef _create_quantiles() -> list[float]:\n  \"\"\"Returns the quantiles for forecasting.\"\"\"\n  return DEFAULT_QUANTILES\n\n\nclass PatchedTimeSeriesDecoder(base_layer.BaseLayer):\n  \"\"\"Patch decoder layer for time-series foundation model.\n\n  Attributes:\n    patch_len: length of input patches.\n    horizon_len: length of output patches. 
Referred to as `output_patch_len`\n      during inference.\n    model_dims: model dimension of stacked transformer layer.\n    hidden_dims: hidden dimensions in fully connected layers.\n    quantiles: list of quantiles for non prob model.\n    residual_block_tpl: config for residual block.\n    stacked_transformer_params_tpl: config for stacked transformer.\n    use_freq: whether to use frequency encoding.\n\n  In all of what followed, except specified otherwise, B is batch size, T is\n  sequence length of time-series. N is the number of input patches that can be\n  obtained from T. P is the input patch length and H is the horizon length. Q is\n  number of output logits. D is model dimension.\n  \"\"\"\n\n  patch_len: int = 0\n  horizon_len: int = 0\n  model_dims: int = 0\n  hidden_dims: int = 0\n  quantiles: list[float] = dataclasses.field(default_factory=_create_quantiles)\n  residual_block_tpl: LayerTpl = template_field(ResidualBlock)\n  stacked_transformer_params_tpl: LayerTpl = template_field(\n      transformers.StackedTransformer)\n  use_freq: bool = True\n  use_pos_emb: bool = True\n\n  def setup(self) -> None:\n    \"\"\"Construct the model.\"\"\"\n    num_outputs = len(self.quantiles) + 1\n\n    stl = self.stacked_transformer_params_tpl.clone()\n    stl.model_dims = self.model_dims\n    stl.hidden_dims = self.hidden_dims\n    stl.mask_self_attention = True\n\n    self.create_child(\"stacked_transformer_layer\", stl)\n\n    input_resl = self.residual_block_tpl.clone()\n    ff_in_dims = 2 * self.patch_len\n    input_resl.input_dims = ff_in_dims\n    input_resl.hidden_dims = self.hidden_dims\n    input_resl.output_dims = self.model_dims\n    self.create_child(\n        \"input_ff_layer\",\n        input_resl,\n    )\n\n    horizon_resl = self.residual_block_tpl.clone()\n    horizon_resl.input_dims = self.model_dims\n    horizon_resl.hidden_dims = self.hidden_dims\n    horizon_resl.output_dims = self.horizon_len * num_outputs\n    self.create_child(\n        
\"horizon_ff_layer\",\n        horizon_resl,\n    )\n\n    self.create_child(\n        \"position_emb\",\n        pax_fiddle.Config(layers.PositionalEmbedding,\n                          embedding_dims=self.model_dims),\n    )\n\n    if self.use_freq:\n      self.create_child(\n          \"freq_emb\",\n          pax_fiddle.Config(\n              embedding_softmax.Embedding,\n              num_classes=3,\n              input_dims=self.model_dims,\n          ),\n      )\n\n  def transform_decode_state(\n      self, transform_fn: base_layer.DecodeStateTransformFn) -> None:\n    \"\"\"Transforms all decode state variables based on transform_fn.\"\"\"\n    self.stacked_transformer_layer.transform_decode_state(transform_fn)\n\n  def _forward_transform(\n      self, inputs: JTensor,\n      patched_pads: JTensor) -> Tuple[JTensor, Tuple[JTensor, JTensor]]:\n    \"\"\"Input is of shape [B, N, P].\"\"\"\n    mu, sigma = _masked_mean_std(inputs, patched_pads)\n    sigma = jnp.where(sigma < _TOLERANCE, 1.0, sigma)\n    # Normalize each patch.\n    outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]\n    outputs = jnp.where(\n        jnp.abs(inputs - PAD_VAL) < _TOLERANCE, PAD_VAL, outputs)\n    return outputs, (mu, sigma)\n\n  def _reverse_transform(self, outputs: JTensor,\n                         stats: Tuple[JTensor, JTensor]) -> JTensor:\n    \"\"\"Output is of shape [B, N, P, Q].\"\"\"\n    mu, sigma = stats\n    return outputs * sigma[:, None, None, None] + mu[:, None, None, None]\n\n  def _preprocess_input(\n      self,\n      input_ts: JTensor,\n      input_padding: JTensor,\n      pos_emb: Optional[JTensor] = None,\n  ) -> Tuple[JTensor, JTensor, Optional[Tuple[JTensor, JTensor]], JTensor]:\n    \"\"\"Preprocess input for stacked transformer.\"\"\"\n    # Reshape into patches.\n    patched_inputs = es.jax_einshape(\"b(np)->bnp\", input_ts, p=self.patch_len)\n    patched_pads = es.jax_einshape(\"b(np)->bnp\",\n                                   
input_padding,\n                                   p=self.patch_len)\n    patched_inputs = jnp.where(\n        jnp.abs(patched_pads - 1.0) < _TOLERANCE, 0.0, patched_inputs)\n    patched_pads = jnp.where(\n        jnp.abs(patched_inputs - PAD_VAL) < _TOLERANCE, 1, patched_pads)\n    patched_inputs, stats = self._forward_transform(patched_inputs,\n                                                    patched_pads)\n\n    # B x N x D\n    patched_inputs = patched_inputs * (1.0 - patched_pads)\n    concat_inputs = jnp.concatenate([patched_inputs, patched_pads], axis=-1)\n    model_input = self.input_ff_layer(concat_inputs)\n    # A patch should not be padded even if there is at least one zero.\n    patched_padding = jnp.min(patched_pads, axis=-1)\n    \n    if self.use_pos_emb:\n      if pos_emb is None:\n        position_emb = self.position_emb(seq_length=model_input.shape[1])\n      else:\n        position_emb = pos_emb\n      if self.do_eval:\n        if position_emb.shape[0] != model_input.shape[0]:\n          position_emb = jnp.repeat(position_emb, model_input.shape[0], axis=0)\n        position_emb = _shift_padded_seq(patched_padding, position_emb)\n      model_input += position_emb\n\n    return model_input, patched_padding, stats, patched_inputs\n\n  def _postprocess_output(\n      self,\n      model_output: JTensor,\n      num_outputs: int,\n      stats: Tuple[JTensor, JTensor],\n  ) -> JTensor:\n    \"\"\"Postprocess output of stacked transformer.\"\"\"\n    # B x N x (H.Q)\n    output_ts = self.horizon_ff_layer(model_output)\n    output_ts = es.jax_einshape(\"bn(hq)->bnhq\",\n                                output_ts,\n                                q=num_outputs,\n                                h=self.horizon_len)\n    return self._reverse_transform(output_ts, stats)\n\n  def __call__(self, inputs: NestedMap) -> NestedMap:\n    \"\"\"PatchTST call.\n\n    Args:\n      inputs: A NestedMap containing (1) input_ts: input sequence of shape [B,\n        T] 
where T must be multiple of patch_length; (2) input_padding: that\n        contains padding map.\n\n    Returns:\n      A nested map with two keys:\n      (1) 'output_tokens' of shape [B, N, D].\n      (2) 'output_ts' of shape [B, N, H, Q]\n      (3) 'stats' a Tuple of statistics for renormalization.\n    \"\"\"\n    input_ts, input_padding = inputs[_INPUT_TS], inputs[_INPUT_PADDING]\n    num_outputs = len(self.quantiles) + 1\n    model_input, patched_padding, stats, _ = self._preprocess_input(\n        input_ts=input_ts,\n        input_padding=input_padding,\n    )\n    if self.use_freq:\n      freq = inputs[_FREQ].astype(jnp.int32)\n      f_emb = self.freq_emb(freq)  # B x 1 x D\n      f_emb = jnp.repeat(f_emb, model_input.shape[1], axis=1)\n      model_input += f_emb\n    model_output = self.stacked_transformer_layer(model_input, patched_padding)\n\n    output_ts = self._postprocess_output(model_output, num_outputs, stats)\n    return NestedMap({\n        _OUTPUT_TOKENS: model_output,\n        _OUTPUT_TS: output_ts,\n        _STATS: stats\n    })\n\n  def decode(\n      self,\n      inputs: NestedMap,\n      horizon_len: int,\n      output_patch_len: Optional[int] = None,\n      max_len: int = 512,\n      return_forecast_on_context: bool = False,\n  ) -> tuple[JTensor, JTensor]:\n    \"\"\"Auto-regressive decoding without caching.\n\n    Args:\n      inputs: input time-series and paddings. 
Time-series shape B x C, padding\n        shape shape B x (C + H) where H is the prediction length.\n      horizon_len: prediction length.\n      output_patch_len: output length to be fetched from one step of\n        auto-regressive decoding.\n      max_len: maximum training context length.\n      return_forecast_on_context: whether to return the model forecast on the\n        context except the first input patch.\n\n    Returns:\n      Tuple of two forecasting results:\n      - Point (mean) output predictions as a tensor with shape B x H'.\n      - Full predictions (mean and quantiles) as a tensor with shape\n        B x H' x (1 + # quantiles).\n      In particular, if return_forecast_on_context is True, H' is H plus\n      the forecastable context length, i.e. context_len - (first) patch_len.\n    \"\"\"\n    final_out = inputs[_INPUT_TS]\n    context_len = final_out.shape[1]\n    paddings = inputs[_INPUT_PADDING]\n    if self.use_freq:\n      freq = inputs[_FREQ].astype(jnp.int32)\n    else:\n      freq = jnp.zeros([final_out.shape[0], 1], dtype=jnp.int32)\n    full_outputs = []\n    if paddings.shape[1] != final_out.shape[1] + horizon_len:\n      raise ValueError(\n          \"Length of paddings must match length of input + horizon_len:\"\n          f\" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}\")\n    if output_patch_len is None:\n      output_patch_len = self.horizon_len\n    num_decode_patches = (horizon_len + output_patch_len -\n                          1) // output_patch_len\n    for step_index in range(num_decode_patches):\n      current_padding = paddings[:, 0:final_out.shape[1]]\n      input_ts = final_out[:, -max_len:]\n      input_padding = current_padding[:, -max_len:]\n      model_input = NestedMap(\n          input_ts=input_ts,\n          input_padding=input_padding,\n          freq=freq,\n      )\n      fprop_outputs = self(model_input)[_OUTPUT_TS]\n      if return_forecast_on_context and step_index == 0:\n        # For the 
first decodings step, collect the model forecast on the\n        # context except the unavailable first input batch forecast.\n        new_full_ts = fprop_outputs[:, :-1, :self.patch_len, :]\n        new_full_ts = es.jax_einshape(\"bnph->b(np)h\", new_full_ts)\n\n        full_outputs.append(new_full_ts)\n\n      # (full batch, last patch, output_patch_len, index of mean forecast = 0)\n      new_ts = fprop_outputs[:, -1, :output_patch_len, 0]\n      new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]\n      # (full batch, last patch, output_patch_len, all output indices)\n      full_outputs.append(new_full_ts)\n      final_out = jnp.concatenate([final_out, new_ts], axis=-1)\n\n    if return_forecast_on_context:\n      # `full_outputs` indexing starts at after the first input patch.\n      full_outputs = jnp.concatenate(full_outputs,\n                                     axis=1)[:, :(context_len - self.patch_len +\n                                                  horizon_len), :]\n    else:\n      # `full_outputs` indexing starts at the forecast horizon.\n      full_outputs = jnp.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :]\n\n    return (full_outputs[:, :, 0], full_outputs)\n\n\nclass PatchedDecoderFinetuneModel(base_model.BaseModel):\n  \"\"\"Model class for finetuning patched time-series decoder.\n\n  Attributes:\n    core_layer_tpl: config for core layer.\n    freq: freq to finetune on.\n  \"\"\"\n\n  core_layer_tpl: LayerTpl = template_field(PatchedTimeSeriesDecoder)\n  freq: int = 0\n\n  def setup(self) -> None:\n    self.create_child(\"core_layer\", self.core_layer_tpl)\n\n  def compute_predictions(self, input_batch: NestedMap) -> NestedMap:\n    input_ts = input_batch[_INPUT_TS]\n    input_padding = jnp.zeros_like(input_ts)\n    context_len = input_ts.shape[1]\n    input_patch_len = self.core_layer_tpl.patch_len\n    context_pad = ((context_len + input_patch_len - 1) //\n                   input_patch_len) * input_patch_len - context_len\n\n   
 input_ts = jnp.pad(input_ts, [(0, 0), (context_pad, 0)])\n    input_padding = jnp.pad(input_padding, [(0, 0), (context_pad, 0)],\n                            constant_values=1)\n    freq = jnp.ones([input_ts.shape[0], 1], dtype=jnp.int32) * self.freq\n    new_input_batch = NestedMap(\n        input_ts=input_ts,\n        input_padding=input_padding,\n        freq=freq,\n    )\n    return self.core_layer(new_input_batch)\n\n  def _quantile_loss(self, pred: JTensor, actual: JTensor,\n                     quantile: float) -> JTensor:\n    \"\"\"Calculates quantile loss.\n\n    Args:\n      pred: B x T\n      actual: B x T\n      quantile: quantile at which loss is computed.\n\n    Returns:\n      per coordinate loss.\n    \"\"\"\n    dev = actual - pred\n    loss_first = dev * quantile\n    loss_second = -dev * (1.0 - quantile)\n    return 2 * jnp.where(loss_first >= 0, loss_first, loss_second)\n\n  def compute_loss(self, prediction_output: NestedMap,\n                   input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]:\n    output_ts = prediction_output[_OUTPUT_TS]\n    actual_ts = input_batch[_TARGET_FUTURE]\n    pred_ts = output_ts[:, -1, 0:actual_ts.shape[1], :]\n    loss = jnp.square(pred_ts[:, :, 0] - actual_ts)\n    for i, quantile in enumerate(self.core_layer.quantiles):\n      loss += self._quantile_loss(pred_ts[:, :, i + 1], actual_ts, quantile)\n    loss = loss.mean()\n    loss_weight = jnp.array(1.0, dtype=jnp.float32)\n    per_example_out = NestedMap()\n    return {\"avg_qloss\": (loss, loss_weight)}, per_example_out\n"
  },
  {
    "path": "probts/model/nn/arch/TimesFMModule/pytorch_patched_decoder.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Pytorch version of patched decoder.\"\"\"\n\nimport dataclasses\nimport math\nfrom typing import List, Tuple\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\n\ndef _create_quantiles() -> list[float]:\n  return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n\n\n@dataclasses.dataclass\nclass TimesFMConfig:\n  \"\"\"Config for initializing timesfm patched_decoder class.\"\"\"\n\n  # The number of blocks in the model.\n  num_layers: int = 20\n  # The number of attention heads used in the attention layers of the model.\n  num_heads: int = 16\n  # The number of key-value heads for implementing attention.\n  num_kv_heads: int = 16\n  # The hidden size of the model.\n  hidden_size: int = 1280\n  # The dimension of the MLP representations.\n  intermediate_size: int = 1280\n  # The number of head dimensions.\n  head_dim: int = 80\n  # The epsilon used by the rms normalization layers.\n  rms_norm_eps: float = 1e-6\n  # Patch length\n  patch_len: int = 32\n  # Horizon length\n  horizon_len: int = 128\n  # quantiles\n  quantiles: List[float] = dataclasses.field(default_factory=_create_quantiles)\n  # Padding value\n  pad_val: float = 1123581321.0\n  # Tolerance\n  tolerance: float = 1e-6\n  # The dtype of the weights.\n  dtype: str = \"bfloat32\"\n  # use positional embedding\n  use_positional_embedding: bool = True\n\n\ndef _masked_mean_std(\n    
inputs: torch.Tensor,\n    padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n  \"\"\"Calculates mean and standard deviation of `inputs` across axis 1.\n\n  It excludes values where `padding` is 1.\n\n  Args:\n    inputs: A PyTorch tensor of shape [b, n, p].\n    padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.\n\n  Returns:\n    A tuple containing the mean and standard deviation.\n    We return the statistics of the first patch with more than three non-padded\n    values.\n  \"\"\"\n  # Selecting the first patch with more than 3 unpadded values.\n  pad_sum = torch.sum(1 - padding, dim=2)\n\n  def _get_patch_index(arr: torch.Tensor):\n    indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)\n    row_sum = (arr >= 3).to(torch.int32).sum(dim=1)\n    return torch.where(row_sum == 0, arr.shape[1] - 1, indices)\n\n  patch_indices = _get_patch_index(pad_sum)\n  bidxs = torch.arange(inputs.shape[0])\n\n  arr = inputs[bidxs, patch_indices, :]\n  pad = padding[bidxs, patch_indices, :]\n\n  # Create a mask where padding is 0\n  mask = 1 - pad\n\n  # Calculate the number of valid elements\n  num_valid_elements = torch.sum(mask, dim=1)\n  num_valid_elements = torch.where(\n      num_valid_elements == 0,\n      torch.tensor(1,\n                   dtype=num_valid_elements.dtype,\n                   device=num_valid_elements.device),\n      num_valid_elements,\n  )\n\n  # Calculate the masked sum and squared sum\n  masked_sum = torch.sum(arr * mask, dim=1)\n  masked_squared_sum = torch.sum((arr * mask)**2, dim=1)\n\n  # Calculate the masked mean and standard deviation\n  masked_mean = masked_sum / num_valid_elements\n  masked_var = masked_squared_sum / num_valid_elements - masked_mean**2\n  masked_var = torch.where(\n      masked_var < 0.0,\n      torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),\n      masked_var,\n  )\n  masked_std = torch.sqrt(masked_var)\n\n  return masked_mean, masked_std\n\n\ndef _shift_padded_seq(mask: 
torch.Tensor, seq: torch.Tensor) -> torch.Tensor:\n  \"\"\"Shifts rows of seq based on the first 0 in each row of the mask.\n\n  Args:\n    mask: mask tensor of shape [B, N]\n    seq: seq tensor of shape [B, N, P]\n\n  Returns:\n    Returns the shifted sequence.\n  \"\"\"\n  batch_size, num_seq, feature_dim = seq.shape\n\n  new_mask: torch.BoolTensor = mask == 0\n\n  # Use argmax to find the first True value in each row\n  indices = new_mask.to(torch.int32).argmax(dim=1)\n\n  # Handle rows with all zeros\n  indices[~new_mask.any(dim=1)] = -1\n\n  # Create index ranges for each sequence in the batch\n  idx_range = (torch.arange(num_seq).to(\n      seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1,\n                                                    feature_dim))\n\n  # Calculate shifted indices for each element in each sequence\n  shifted_idx = (idx_range - indices[:, None, None]) % num_seq\n\n  # Gather values from seq using shifted indices\n  shifted_seq = seq.gather(1, shifted_idx)\n\n  return shifted_seq\n\n\ndef get_large_negative_number(dtype: torch.dtype) -> torch.Tensor:\n  \"\"\"Returns a large negative value for the given dtype.\"\"\"\n  if dtype.is_floating_point:\n    dtype_max = torch.finfo(dtype).max\n  else:\n    dtype_max = torch.iinfo(dtype).max\n  return torch.tensor(-0.7 * dtype_max, dtype=dtype)\n\n\ndef apply_mask_to_logits(logits: torch.Tensor,\n                         mask: torch.Tensor) -> torch.Tensor:\n  \"\"\"Applies a floating-point mask to a set of logits.\n\n  Args:\n      logits: A torch.Tensor of logit values.\n      mask: A torch.Tensor (float32) of mask values with the encoding described\n        in the function documentation.\n\n  Returns:\n      Masked logits.\n  \"\"\"\n\n  min_value = get_large_negative_number(logits.dtype)\n\n  return torch.where((mask >= min_value * 0.5), logits, min_value)\n\n\ndef convert_paddings_to_mask(\n    paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:\n  
\"\"\"Converts binary paddings to a logit mask ready to add to attention matrix.\n\n  Args:\n      paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding\n        token.\n      dtype: data type of the input.\n\n  Returns:\n      A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits.\n  \"\"\"\n  attention_mask = paddings.detach().clone()\n  attention_mask = attention_mask[:, None, None, :]  # Equivalent to jnp.newaxis\n  attention_mask *= get_large_negative_number(dtype)\n  return attention_mask\n\n\ndef causal_mask(input_t: torch.Tensor) -> torch.Tensor:\n  \"\"\"Computes and returns causal mask.\n\n  Args:\n      input_t: A torch.Tensor of shape [B, T, D].\n\n  Returns:\n      An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has\n      already been converted to large negative values.\n  \"\"\"\n  assert input_t.dtype.is_floating_point, input_t.dtype\n  large_negative_number = get_large_negative_number(input_t.dtype)\n  t = input_t.shape[1]\n  col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1)\n  row_idx = torch.arange(t).unsqueeze(1).repeat(1, t)\n  mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number\n  return (mask.unsqueeze(0).unsqueeze(0).to(input_t.device)\n         )  # Equivalent to jnp.newaxis\n\n\ndef merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n  \"\"\"Merges 2 masks.\n\n  logscale mask is expected but 0/1 mask is also fine.\n\n  Args:\n      a: torch.Tensor of shape [1|B, 1, 1|T, S].\n      b: torch.Tensor of shape [1|B, 1, 1|T, S].\n\n  Returns:\n      torch.Tensor of shape [1|B, 1, 1|T, S].\n  \"\"\"\n\n  def expand_t(key_mask):\n    query_mask = key_mask.transpose(-1, -2)  # Equivalent of jnp.transpose\n    return torch.minimum(query_mask, key_mask)\n\n  if a.shape[2] != b.shape[2]:\n    if a.shape[2] == 1:\n      a = expand_t(a)\n    else:\n      assert b.shape[2] == 1\n      b = expand_t(b)\n\n  assert a.shape[1:] == b.shape[1:], f\"a.shape={a.shape}, 
b.shape={b.shape}.\"\n  return torch.minimum(a, b)  # Element-wise minimum, similar to jnp.minimum\n\n\nclass ResidualBlock(nn.Module):\n  \"\"\"TimesFM residual block.\"\"\"\n\n  def __init__(\n      self,\n      input_dims,\n      hidden_dims,\n      output_dims,\n  ):\n    super(ResidualBlock, self).__init__()\n    self.input_dims = input_dims\n    self.hidden_dims = hidden_dims\n    self.output_dims = output_dims\n\n    # Hidden Layer\n    self.hidden_layer = nn.Sequential(\n        nn.Linear(input_dims, hidden_dims),\n        nn.SiLU(),\n    )\n\n    # Output Layer\n    self.output_layer = nn.Linear(hidden_dims, output_dims)\n    # Residual Layer\n    self.residual_layer = nn.Linear(input_dims, output_dims)\n\n  def forward(self, x):\n    hidden = self.hidden_layer(x)\n    output = self.output_layer(hidden)\n    residual = self.residual_layer(x)\n    return output + residual\n\n\nclass RMSNorm(torch.nn.Module):\n  \"\"\"Pax rms norm in pytorch.\"\"\"\n\n  def __init__(\n      self,\n      dim: int,\n      eps: float = 1e-6,\n      add_unit_offset: bool = False,\n  ):\n    super().__init__()\n    self.eps = eps\n    self.add_unit_offset = add_unit_offset\n    self.weight = nn.Parameter(torch.zeros(dim))\n\n  def _norm(self, x):\n    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n  def forward(self, x):\n    output = self._norm(x.float())\n    if self.add_unit_offset:\n      output = output * (1 + self.weight.float())\n    else:\n      output = output * self.weight.float()\n    return output.type_as(x)\n\n\nclass TransformerMLP(nn.Module):\n  \"\"\"Pax transformer MLP in pytorch.\"\"\"\n\n  def __init__(\n      self,\n      hidden_size: int,\n      intermediate_size: int,\n  ):\n    super().__init__()\n    self.gate_proj = nn.Linear(hidden_size, intermediate_size)\n    self.down_proj = nn.Linear(intermediate_size, hidden_size)\n    self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)\n\n  def forward(self, x, 
paddings=None):\n    gate_inp = self.layer_norm(x)\n    gate = self.gate_proj(gate_inp)\n    gate = F.relu(gate)\n    outputs = self.down_proj(gate)\n    if paddings is not None:\n      outputs = outputs * (1.0 - paddings[:, :, None])\n    return outputs + x\n\n\nclass TimesFMAttention(nn.Module):\n  \"\"\"Implements the attention used in TimesFM.\"\"\"\n\n  def __init__(\n      self,\n      hidden_size: int,\n      num_heads: int,\n      num_kv_heads: int,\n      head_dim: int,\n  ):\n    super().__init__()\n\n    self.num_heads = num_heads\n    self.num_kv_heads = num_kv_heads\n\n    assert self.num_heads % self.num_kv_heads == 0\n    self.num_queries_per_kv = self.num_heads // self.num_kv_heads\n\n    self.hidden_size = hidden_size\n    self.head_dim = head_dim\n\n    self.q_size = self.num_heads * self.head_dim\n    self.kv_size = self.num_kv_heads * self.head_dim\n    self.scaling = nn.Parameter(\n        torch.empty((self.head_dim,), dtype=torch.float32),)\n\n    self.qkv_proj = nn.Linear(\n        self.hidden_size,\n        (self.num_heads + 2 * self.num_kv_heads) * self.head_dim,\n    )\n    self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)\n\n  def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor:\n    # [batch_size, n_local_heads, input_len, head_dim]\n    r_softplus_0 = 1.442695041\n    softplus_func = torch.nn.Softplus()\n    scale = r_softplus_0 / math.sqrt(self.head_dim)\n    scale = scale * softplus_func(self.scaling)\n    return query * scale[None, None, None, :]\n\n  def forward(\n      self,\n      hidden_states: torch.Tensor,\n      mask: torch.Tensor,\n      kv_write_indices: torch.Tensor | None = None,\n      kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,\n  ) -> torch.Tensor:\n    hidden_states_shape = hidden_states.shape\n    assert len(hidden_states_shape) == 3\n\n    batch_size, input_len, _ = hidden_states_shape\n\n    qkv = self.qkv_proj(hidden_states)\n    xq, xk, xv = 
qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)\n\n    xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)\n    xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)\n    xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)\n    xq = self._per_dim_scaling(xq)\n\n    # Write new kv cache.\n    # [batch_size, input_len, n_local_kv_heads, head_dim]\n    if kv_cache is not None and kv_write_indices is not None:\n      k_cache, v_cache = kv_cache\n      k_cache.index_copy_(1, kv_write_indices, xk)\n      v_cache.index_copy_(1, kv_write_indices, xv)\n\n      key = k_cache\n      value = v_cache\n    else:\n      key = xk\n      value = xv\n    if self.num_kv_heads != self.num_heads:\n      # [batch_size, max_seq_len, n_local_heads, head_dim]\n      key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)\n      value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2)\n\n    # [batch_size, n_local_heads, input_len, head_dim]\n    q = xq.transpose(1, 2)\n    # [batch_size, n_local_heads, max_seq_len, head_dim]\n    k = key.transpose(1, 2)\n    v = value.transpose(1, 2)\n\n    # [batch_size, n_local_heads, input_len, max_seq_len]\n    scores = torch.matmul(q, k.transpose(2, 3))\n    scores = scores + mask\n    scores = F.softmax(scores.float(), dim=-1).type_as(q)\n\n    # [batch_size, n_local_heads, input_len, head_dim]\n    output = torch.matmul(scores, v)\n    # return scores, output.transpose(1, 2).contiguous()\n\n    # [batch_size, input_len, hidden_dim]\n    output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)\n    output = self.o_proj(output)\n    return scores, output\n\n\nclass TimesFMDecoderLayer(nn.Module):\n  \"\"\"Transformer layer.\"\"\"\n\n  def __init__(\n      self,\n      hidden_size: int,\n      intermediate_size: int,\n      num_heads: int,\n      num_kv_heads: int,\n      head_dim: int,\n      rms_norm_eps: float = 1e-6,\n  ):\n    super().__init__()\n    
self.self_attn = TimesFMAttention(\n        hidden_size=hidden_size,\n        num_heads=num_heads,\n        num_kv_heads=num_kv_heads,\n        head_dim=head_dim,\n    )\n    self.mlp = TransformerMLP(\n        hidden_size=hidden_size,\n        intermediate_size=intermediate_size,\n    )\n    self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)\n\n  def forward(\n      self,\n      hidden_states: torch.Tensor,\n      mask: torch.Tensor,\n      paddings: torch.Tensor,\n      kv_write_indices: torch.Tensor | None = None,\n      kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,\n  ) -> torch.Tensor:\n    # Self Attention\n    residual = hidden_states\n    hidden_states = self.input_layernorm(hidden_states)\n    scores, hidden_states = self.self_attn(\n        hidden_states=hidden_states,\n        mask=mask,\n        kv_write_indices=kv_write_indices,\n        kv_cache=kv_cache,\n    )\n    hidden_states = residual + hidden_states\n\n    # MLP\n    hidden_states = self.mlp(hidden_states, paddings=paddings)\n\n    return scores, hidden_states\n\n\nclass StackedDecoder(nn.Module):\n  \"\"\"Stacked transformer layer.\"\"\"\n\n  def __init__(\n      self,\n      hidden_size: int,\n      intermediate_size: int,\n      num_heads: int,\n      num_kv_heads: int,\n      head_dim: int,\n      num_layers: int,\n      rms_norm_eps: float = 1e-6,\n  ):\n    super().__init__()\n\n    self.layers = nn.ModuleList()\n    for _ in range(num_layers):\n      self.layers.append(\n          TimesFMDecoderLayer(\n              hidden_size=hidden_size,\n              intermediate_size=intermediate_size,\n              num_heads=num_heads,\n              num_kv_heads=num_kv_heads,\n              head_dim=head_dim,\n              rms_norm_eps=rms_norm_eps,\n          ))\n\n  def forward(\n      self,\n      hidden_states: torch.Tensor,\n      paddings: torch.Tensor,\n      kv_write_indices: torch.Tensor | None = None,\n      kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | 
None = None,\n  ) -> torch.Tensor:\n    padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype)\n    atten_mask = causal_mask(hidden_states)\n    mask = merge_masks(padding_mask, atten_mask)\n    for i in range(len(self.layers)):\n      layer = self.layers[i]\n      kv_cache = kv_caches[i] if kv_caches is not None else None\n      _, hidden_states = layer(\n          hidden_states=hidden_states,\n          mask=mask,\n          paddings=paddings,\n          kv_write_indices=kv_write_indices,\n          kv_cache=kv_cache,\n      )\n    return hidden_states\n\n\nclass PositionalEmbedding(torch.nn.Module):\n  \"\"\"Generates position embedding for a given 1-d sequence.\n\n  Attributes:\n      min_timescale: Start of the geometric index. Determines the periodicity of\n        the added signal.\n      max_timescale: End of the geometric index. Determines the frequency of the\n        added signal.\n      embedding_dims: Dimension of the embedding to be generated.\n  \"\"\"\n\n  def __init__(\n      self,\n      embedding_dims: int,\n      min_timescale: int = 1,\n      max_timescale: int = 10_000,\n  ) -> None:\n    super().__init__()\n    self.min_timescale = min_timescale\n    self.max_timescale = max_timescale\n    self.embedding_dims = embedding_dims\n\n  def forward(self, seq_length=None, position=None):\n    \"\"\"Generates a Tensor of sinusoids with different frequencies.\n\n    Args:\n        seq_length: an optional Python int defining the output sequence length.\n          if the `position` argument is specified.\n        position:   [B, seq_length], optional position for each token in the\n          sequence, only required when the sequence is packed.\n\n    Returns:\n        [B, seqlen, D] if `position` is specified, else [1, seqlen, D]\n    \"\"\"\n    if position is None:\n      assert seq_length is not None\n      # [1, seqlen]\n      position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0)\n    else:\n      assert position.ndim 
== 2, position.shape\n\n    num_timescales = self.embedding_dims // 2\n    log_timescale_increment = math.log(\n        float(self.max_timescale) / float(self.min_timescale)) / max(\n            num_timescales - 1, 1)\n    inv_timescales = self.min_timescale * torch.exp(\n        torch.arange(num_timescales, dtype=torch.float32) *\n        -log_timescale_increment)\n    scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(\n        0)\n    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)\n    # Padding to ensure correct embedding dimension\n    signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))\n    return signal\n\n\nclass PatchedTimeSeriesDecoder(nn.Module):\n  \"\"\"Patched time-series decoder.\"\"\"\n\n  def __init__(self, config: TimesFMConfig):\n    super().__init__()\n    self.config = config\n    self.input_ff_layer = ResidualBlock(\n        input_dims=2 * config.patch_len,\n        output_dims=config.hidden_size,\n        hidden_dims=config.intermediate_size,\n    )\n    self.freq_emb = nn.Embedding(num_embeddings=3,\n                                 embedding_dim=config.hidden_size)\n    self.horizon_ff_layer = ResidualBlock(\n        input_dims=config.hidden_size,\n        output_dims=config.horizon_len * (1 + len(config.quantiles)),\n        hidden_dims=config.intermediate_size,\n    )\n    self.stacked_transformer = StackedDecoder(\n        hidden_size=self.config.hidden_size,\n        intermediate_size=self.config.intermediate_size,\n        num_heads=self.config.num_heads,\n        num_kv_heads=self.config.num_kv_heads,\n        head_dim=self.config.head_dim,\n        num_layers=self.config.num_layers,\n        rms_norm_eps=self.config.rms_norm_eps,\n    )\n    if self.config.use_positional_embedding:\n      self.position_emb = PositionalEmbedding(self.config.hidden_size)\n\n  def _forward_transform(\n      self, inputs: torch.Tensor, patched_pads: torch.Tensor\n  ) -> tuple[torch.Tensor, 
tuple[torch.Tensor, torch.Tensor]]:\n    \"\"\"Input is of shape [B, N, P].\"\"\"\n    mu, sigma = _masked_mean_std(inputs, patched_pads)\n    sigma = torch.where(\n        sigma < self.config.tolerance,\n        torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),\n        sigma,\n    )\n\n    # Normalize each patch\n    outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]\n    outputs = torch.where(\n        torch.abs(inputs - self.config.pad_val) < self.config.tolerance,\n        torch.tensor(self.config.pad_val,\n                     dtype=outputs.dtype,\n                     device=outputs.device),\n        outputs,\n    )\n    return outputs, (mu, sigma)\n\n  def _reverse_transform(\n      self, outputs: torch.Tensor, stats: tuple[torch.Tensor,\n                                                torch.Tensor]) -> torch.Tensor:\n    \"\"\"Output is of shape [B, N, P, Q].\"\"\"\n    mu, sigma = stats\n    return outputs * sigma[:, None, None, None] + mu[:, None, None, None]\n\n  def _preprocess_input(\n      self,\n      input_ts: torch.Tensor,\n      input_padding: torch.Tensor,\n  ) -> tuple[\n      torch.Tensor,\n      torch.Tensor,\n      tuple[torch.Tensor, torch.Tensor] | None,\n      torch.Tensor,\n  ]:\n    \"\"\"Preprocess input for stacked transformer.\"\"\"\n\n    # Reshape into patches (using view for efficiency)\n    bsize = input_ts.shape[0]\n    patched_inputs = input_ts.view(bsize, -1, self.config.patch_len)\n    patched_pads = input_padding.view(bsize, -1, self.config.patch_len)\n\n    patched_inputs = torch.where(\n        torch.abs(patched_pads - 1.0) < self.config.tolerance,\n        torch.tensor(0.0,\n                     dtype=patched_inputs.dtype,\n                     device=patched_inputs.device),\n        patched_inputs,\n    )\n    patched_pads = torch.where(\n        torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,\n        torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),\n  
      patched_pads,\n    )\n    patched_inputs, stats = self._forward_transform(patched_inputs,\n                                                    patched_pads)\n\n    # B x N x D\n    patched_inputs = patched_inputs * (1.0 - patched_pads)\n    concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)\n    model_input = self.input_ff_layer(concat_inputs)\n\n    # A patch should not be padded even if there is at least one zero.\n    patched_padding = torch.min(patched_pads,\n                                dim=-1)[0]  # Get the values from the min result\n    if self.config.use_positional_embedding:\n      pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device)\n      pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)\n      pos_emb = _shift_padded_seq(patched_padding, pos_emb)\n      model_input += pos_emb\n\n    return model_input, patched_padding, stats, patched_inputs\n\n  def _postprocess_output(\n      self,\n      model_output: torch.Tensor,\n      num_outputs: int,\n      stats: tuple[torch.Tensor, torch.Tensor],\n  ) -> torch.Tensor:\n    \"\"\"Postprocess output of stacked transformer.\"\"\"\n\n    # B x N x (H.Q)\n    output_ts = self.horizon_ff_layer(model_output)\n\n    # Reshape using view\n    b, n, _ = output_ts.shape\n    output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs)\n\n    return self._reverse_transform(output_ts, stats)\n\n  def forward(\n      self,\n      input_ts: torch.Tensor,\n      input_padding: torch.LongTensor,\n      freq: torch.Tensor,\n  ) -> torch.Tensor:\n    num_outputs = len(self.config.quantiles) + 1\n    model_input, patched_padding, stats, _ = self._preprocess_input(\n        input_ts=input_ts,\n        input_padding=input_padding,\n    )\n    f_emb = self.freq_emb(freq)  # B x 1 x D\n    model_input += f_emb\n    model_output = self.stacked_transformer(model_input, patched_padding)\n\n    output_ts = self._postprocess_output(model_output, num_outputs, stats)\n    
return output_ts\n\n  def decode(\n      self,\n      input_ts: torch.Tensor,\n      paddings: torch.Tensor,\n      freq: torch.LongTensor,\n      horizon_len: int,\n      output_patch_len: int | None = None,\n      max_len: int = 512,\n      return_forecast_on_context: bool = False,\n  ) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Auto-regressive decoding without caching.\n\n    Args:\n      input_ts: input time-series and paddings. Time-series shape B x C.\n      paddings: padding shape B x (C + H) where H is the prediction length.\n      freq: frequency shape B x 1\n      horizon_len: prediction length.\n      output_patch_len: output length to be fetched from one step of\n        auto-regressive decoding.\n      max_len: maximum training context length.\n      return_forecast_on_context: whether to return the model forecast on the\n        context except the first input patch.\n\n    Returns:\n      Tuple of two forecasting results:\n      - Point (mean) output predictions as a tensor with shape B x H'.\n      - Full predictions (mean and quantiles) as a tensor with shape\n        B x H' x (1 + # quantiles).\n      In particular, if return_forecast_on_context is True, H' is H plus\n      the forecastable context length, i.e. 
context_len - (first) patch_len.\n    \"\"\"\n    final_out = input_ts\n    context_len = final_out.shape[1]\n    full_outputs = []\n    if paddings.shape[1] != final_out.shape[1] + horizon_len:\n      raise ValueError(\n          \"Length of paddings must match length of input + horizon_len:\"\n          f\" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}\")\n    if output_patch_len is None:\n      output_patch_len = self.config.horizon_len\n    num_decode_patches = (horizon_len + output_patch_len -\n                          1) // output_patch_len\n    for step_index in range(num_decode_patches):\n      current_padding = paddings[:, 0:final_out.shape[1]]\n      input_ts = final_out[:, -max_len:]\n      input_padding = current_padding[:, -max_len:]\n      fprop_outputs = self(input_ts, input_padding, freq)\n      if return_forecast_on_context and step_index == 0:\n        # For the first decodings step, collect the model forecast on the\n        # context except the unavailable first input batch forecast.\n        new_full_ts = fprop_outputs[:, :-1, :self.config.patch_len, :]\n        new_full_ts = fprop_outputs.view(new_full_ts.size(0), -1,\n                                         new_full_ts.size(3))\n\n        full_outputs.append(new_full_ts)\n\n      # (full batch, last patch, output_patch_len, index of mean forecast = 0)\n      new_ts = fprop_outputs[:, -1, :output_patch_len, 0]\n      new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]\n      # (full batch, last patch, output_patch_len, all output indices)\n      full_outputs.append(new_full_ts)\n      final_out = torch.concatenate([final_out, new_ts], axis=-1)\n\n    if return_forecast_on_context:\n      # `full_outputs` indexing starts at after the first input patch.\n      full_outputs = torch.concatenate(\n          full_outputs,\n          axis=1)[:, :(context_len - self.config.patch_len + horizon_len), :]\n    else:\n      # `full_outputs` indexing starts at the forecast horizon.\n    
  full_outputs = torch.concatenate(full_outputs, axis=1)[:,\n                                                             0:horizon_len, :]\n\n    return (full_outputs[:, :, 0], full_outputs)\n"
  },
  {
    "path": "probts/model/nn/arch/TimesFMModule/timesfm_base.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Base class for TimesFM inference. This will be common to PAX and Pytorch.\"\"\"\n\nimport collections\nimport dataclasses\nimport logging\nimport multiprocessing\nfrom typing import Any, Literal, Sequence\n\nimport numpy as np\nimport pandas as pd\n\nfrom utilsforecast.processing import make_future_dataframe\n\nfrom probts.model.nn.arch.TimesFMModule import xreg_lib\n\nCategory = xreg_lib.Category\nXRegMode = xreg_lib.XRegMode\n\n_TOL = 1e-6\nDEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)\n\n\ndef process_group(key, group, value_name, forecast_context_len):\n  group = group.tail(forecast_context_len)\n  return np.array(group[value_name], dtype=np.float32), key\n\n\ndef moving_average(arr, window_size):\n  \"\"\"Calculates the moving average using NumPy's convolution function.\"\"\"\n  # Pad with zeros to handle initial window positions\n  arr_padded = np.pad(arr, (window_size - 1, 0), \"constant\")\n  smoothed_arr = (np.convolve(arr_padded, np.ones(window_size), \"valid\") /\n                  window_size)\n  return [smoothed_arr, arr - smoothed_arr]\n\n\ndef freq_map(freq: str):\n  \"\"\"Returns the frequency map for the given frequency string.\"\"\"\n  freq = str.upper(freq)\n  if (freq.endswith(\"H\") or freq.endswith(\"T\") or freq.endswith(\"MIN\") or\n      freq.endswith(\"D\") or freq.endswith(\"B\") or freq.endswith(\"U\") or\n      
freq.endswith(\"S\")):\n    return 0\n  elif freq.endswith((\"W\", \"M\", \"MS\")):\n    return 1\n  elif freq.endswith(\"Y\") or freq.endswith(\"Q\") or freq.endswith(\"A\"):\n    return 2\n  else:\n    raise ValueError(f\"Invalid frequency: {freq}\")\n\ndef strip_leading_nans(arr):\n  \"\"\"\n  Removes contiguous NaN values from the beginning of a NumPy array.\n\n  Args:\n    arr: The input NumPy array.\n\n  Returns:\n    A new NumPy array with leading NaN values removed.\n    If the array is all NaNs or empty, returns an empty array.\n  \"\"\"\n\n  isnan = np.isnan(arr)\n  first_valid_index = np.argmax(~isnan)\n  return arr[first_valid_index:]\n\ndef linear_interpolation(arr):\n  \"\"\"\n    Performs linear interpolation to fill NaN values in a 1D numpy array.\n\n    Args:\n        arr: The 1D numpy array containing NaN values.\n\n    Returns:\n        A new numpy array with NaN values filled using linear interpolation, \n        or the original array if no NaNs are present. \n        Returns None if the input is not a 1D array.\n        Returns the original array if there are no NaN values.\n    \"\"\"\n\n  nans = np.isnan(arr)\n  if not np.any(nans):  # Check if there are any NaNs\n    return arr\n\n  x = lambda z: z.nonzero()[0]\n  nans_indices = x(nans)\n  non_nans_indices = x(~nans)\n  non_nans_values = arr[~nans]\n\n  try:\n    arr[nans] = np.interp(nans_indices, non_nans_indices, non_nans_values)\n  except ValueError:\n    if len(non_nans_values) > 0:\n      mu = np.nanmean(arr)\n    else:\n      mu = 0.0\n    arr = np.where(np.isfinite(arr), arr, mu)\n  return arr\n\n\n# Per time series normalization: forward.\ndef _normalize(batch):\n  stats = [\n      (np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch\n  ]\n  new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)]\n  return new_batch, stats\n\n\n# Per time series normalization: inverse.\ndef _renormalize(batch, stats):\n  return [x * stat[1] + stat[0] for x, stat in 
zip(batch, stats)]\n\n\n@dataclasses.dataclass(kw_only=True)\nclass TimesFmHparams:\n  \"\"\"Hparams used to initialize a TimesFM model for inference.\n\n  These are the sufficient subset of hparams to configure TimesFM inference\n  agnostic to the checkpoint version, and are not necessarily the same as the\n  hparams used to train the checkpoint.\n\n  Attributes:\n    context_len: Largest context length the model allows for each decode call.\n      This technically can be any large, but practically should set to the\n      context length the checkpoint was trained with.\n    horizon_len: Forecast horizon.\n    input_patch_len: Input patch len.\n    output_patch_len: Output patch len. How many timepoints is taken from a\n      single step of autoregressive decoding. Can be set as the training horizon\n      of the checkpoint.\n    num_layers: Number of transformer layers in the model.\n    model_dims: Model dimension.\n    per_core_batch_size: Batch size on each core for data parallelism.\n    backend: One of \"cpu\", \"gpu\" or \"tpu\".\n    quantiles: Which quantiles are output by the model.\n  \"\"\"\n\n  context_len: int = 512\n  horizon_len: int = 128\n  input_patch_len: int = 32\n  output_patch_len: int = 128\n  num_layers: int = 20\n  num_heads: int = 16\n  model_dims: int = 1280\n  per_core_batch_size: int = 32\n  backend: Literal[\"cpu\", \"gpu\", \"tpu\"] = \"cpu\"\n  quantiles: Sequence[float] | None = DEFAULT_QUANTILES\n  use_positional_embedding: bool = True\n  # Hparams beyond the model.\n  point_forecast_mode: Literal[\"mean\", \"median\"] = \"median\"\n\n\n@dataclasses.dataclass(kw_only=True)\nclass TimesFmCheckpoint:\n  \"\"\"Checkpoint used to initialize a TimesFM model for inference.\n\n  Attributes:\n    version: Version of the checkpoint, e.g. 
\"jax\", \"torch\", \"tensorflow\", etc.\n      The factory will create the corresponding TimesFm inference class based on\n      this version.\n    path: Path to the checkpoint.\n    type: If provided, type of the checkpoint used by the specific checkpoint\n      loader per version.\n    step: If provided, step of the checkpoint.\n  \"\"\"\n\n  version: str = \"jax\"\n  path: str | None = None\n  huggingface_repo_id: str | None = None\n  type: Any = None\n  step: int | None = None\n  local_dir: str | None = None\n\n\nclass TimesFmBase:\n  \"\"\"Base TimesFM forecast API for inference.\n\n  This class is the scaffolding for calling TimesFM forecast. To properly use:\n    1. Create an instance with the correct hyperparameters of a TimesFM model.\n    2. Call `load_from_checkpoint` to load a compatible checkpoint.\n    3. Call `forecast` for inference.\n  \"\"\"\n\n  def _logging(self, s):\n    print(s)\n\n  def __post_init__(self) -> None:\n    \"\"\"Additional initialization for subclasses before checkpoint loading.\"\"\"\n    pass\n\n  def __init__(self, hparams: TimesFmHparams,\n               checkpoint: TimesFmCheckpoint) -> None:\n    \"\"\"Initializes the TimesFM forecast API.\n\n    Args:\n      hparams: Hyperparameters of the model.\n      checkpoint: Checkpoint to load. 
Notice `checkpoint.version` will decide\n        which TimesFM version to use.\n    \"\"\"\n    self.hparams = hparams\n\n    # Expand hparams for conciseness within the model code.\n    self.context_len = hparams.context_len\n    self.horizon_len = hparams.horizon_len\n    self.input_patch_len = hparams.input_patch_len\n    self.output_patch_len = hparams.output_patch_len\n    self.num_layers = hparams.num_layers\n    self.model_dims = hparams.model_dims\n    self.backend = hparams.backend\n    self.quantiles = hparams.quantiles\n    self.num_heads = hparams.num_heads\n    self.use_pos_emb = hparams.use_positional_embedding\n\n    # Rewrite these values in __post_init__ for SPMD.\n    self.num_cores = 1\n    self.per_core_batch_size = hparams.per_core_batch_size\n    self.global_batch_size = hparams.per_core_batch_size\n\n    self._horizon_start = self.context_len - self.input_patch_len\n    self.__post_init__()\n    self.load_from_checkpoint(checkpoint)\n\n  def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None:\n    \"\"\"Loads a checkpoint and compiles the decoder.\"\"\"\n    raise NotImplementedError(\"`load_from_checkpoint` is not implemented.\")\n\n  def _preprocess(\n      self, inputs: Sequence[np.ndarray],\n      freq: Sequence[int]) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:\n    \"\"\"Formats and pads raw inputs to feed into the model.\n\n    This function both pads each time series to match the context length, and\n    pads the inputs to meet the SPMD shape requirement.\n\n    Args:\n      inputs: A list of 1d JTensors. 
Each JTensor is the context time series of\n        a single forecast task.\n      freq: list of frequencies\n\n    Returns:\n    A tuple of:\n    - the padded input time series to meet the model required context.\n    - the padding indicator.\n    - the frequency of each input time series.\n    - the number of padded examples for SPMD so that each core has the same\n        number (a multiple of `batch_size`) of examples.\n    \"\"\"\n\n    input_ts, input_padding, inp_freq = [], [], []\n\n    pmap_pad = ((len(inputs) - 1) // self.global_batch_size +\n                1) * self.global_batch_size - len(inputs)\n\n    for i, ts in enumerate(inputs):\n      input_len = ts.shape[0]\n      padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float)\n      if input_len < self.context_len:\n        num_front_pad = self.context_len - input_len\n        ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts],\n                            axis=0)\n        padding = np.concatenate(\n            [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0)\n      elif input_len > self.context_len:\n        ts = ts[-self.context_len:]\n        padding = padding[-(self.context_len + self.horizon_len):]\n\n      input_ts.append(ts)\n      input_padding.append(padding)\n      inp_freq.append(freq[i])\n\n    # Padding the remainder batch.\n    for _ in range(pmap_pad):\n      input_ts.append(input_ts[-1])\n      input_padding.append(input_padding[-1])\n      inp_freq.append(inp_freq[-1])\n\n    return (\n        np.stack(input_ts, axis=0),\n        np.stack(input_padding, axis=0),\n        np.array(inp_freq).astype(np.int32).reshape(-1, 1),\n        pmap_pad,\n    )\n\n  def _forecast(\n      self,\n      inputs: Sequence[Any],\n      freq: Sequence[int] | None = None,\n      window_size: int | None = None,\n      forecast_context_len: int | None = None,\n      return_forecast_on_context: bool = False,\n  ) -> tuple[np.ndarray, np.ndarray]:\n    
\"\"\"Forecasts on a list of time series.\n\n    Args:\n      inputs: list of time series forecast contexts. Each context time series\n        should be in a format convertible to JTensor by `jnp.array`.\n      freq: frequency of each context time series. 0 for high frequency\n        (default), 1 for medium, and 2 for low. Notice this is different from\n        the `freq` required by `forecast_on_df`.\n      window_size: window size of trend + residual decomposition. If None then\n        we do not do decomposition.\n      forecast_context_len: optional max context length.\n      return_forecast_on_context: True to return the forecast on the context\n        when available, i.e. after the first input patch.\n\n    Returns:\n    A tuple for np.array:\n    - the mean forecast of size (# inputs, # forecast horizon),\n    - the full forecast (mean + quantiles) of size\n        (# inputs,  # forecast horizon, 1 + # quantiles).\n\n    Raises:\n    ValueError: If the checkpoint is not properly loaded.\n    \"\"\"\n    raise NotImplementedError(\"`_forecast` is not implemented.\")\n\n  def forecast(\n      self,\n      inputs: Sequence[Any],\n      freq: Sequence[int] | None = None,\n      window_size: int | None = None,\n      forecast_context_len: int | None = None,\n      return_forecast_on_context: bool = False,\n      normalize: bool = False,\n  ) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"Forecasts on a list of time series.\n\n    Args:\n      inputs: list of time series forecast contexts. Each context time series\n        should be in a format convertible to JTensor by `jnp.array`.\n      freq: frequency of each context time series. 0 for high frequency\n        (default), 1 for medium, and 2 for low. Notice this is different from\n        the `freq` required by `forecast_on_df`.\n      window_size: window size of trend + residual decomposition. 
If None then\n        we do not do decomposition.\n      forecast_context_len: optional max context length.\n      return_forecast_on_context: True to return the forecast on the context\n        when available, i.e. after the first input patch.\n      normalize: If True, then we normalize the inputs before forecasting and\n        the outputs are then renormalized to the original scale.\n\n    Returns:\n    A tuple for np.array:\n    - the mean forecast of size (# inputs, # forecast horizon),\n    - the full forecast (mean + quantiles) of size\n        (# inputs,  # forecast horizon, 1 + # quantiles).\n\n    Raises:\n    ValueError: If the checkpoint is not properly loaded.\n    \"\"\"\n    stats = None\n    \n    tmp_inputs = []\n    for each_input in inputs:\n      arr = np.array(each_input)\n      if not np.isfinite(arr).all():\n        arr = np.where(np.isfinite(arr), arr, np.nan)\n        arr = strip_leading_nans(arr)\n        arr = linear_interpolation(arr)\n      tmp_inputs.append(arr)\n  \n    inputs = tmp_inputs\n    if normalize:\n      inputs, stats = _normalize(inputs)\n    mean_forecast, quantile_forecast = self._forecast(\n        inputs,\n        freq,\n        window_size,\n        forecast_context_len,\n        return_forecast_on_context,\n    )\n    if stats is not None:\n      stats = np.array(stats)\n      mu = stats[:, 0]\n      sigma = stats[:, 1]\n      mean_forecast = mean_forecast * sigma[:, None] + mu[:, None]\n      quantile_forecast = (quantile_forecast * sigma[:, None, None] +\n                           mu[:, None, None])\n    if self.hparams.point_forecast_mode == \"mean\":\n      return mean_forecast, quantile_forecast\n    elif self.hparams.point_forecast_mode == \"median\":\n      if self._median_index == -1:\n        for i, quantile in enumerate(self.quantiles):\n          if quantile == 0.5:\n            self._median_index = i\n            break\n        if self._median_index == -1:\n          raise ValueError(\"Median (0.5) is 
not found in the model quantiles:\"\n                           f\" {self.quantiles}. Please check the hparams.\")\n      return (\n          quantile_forecast[:, :, 1 + self._median_index],\n          quantile_forecast,\n      )\n    else:\n      raise ValueError(\n          \"Unsupported point forecast mode:\"\n          f\" {self.hparams.point_forecast_mode}. Use 'mean' or 'median'.\")\n\n  def forecast_with_covariates(\n      self,\n      inputs: list[Sequence[float]],\n      dynamic_numerical_covariates: (dict[str, Sequence[Sequence[float]]] |\n                                     None) = None,\n      dynamic_categorical_covariates: (dict[str, Sequence[Sequence[Category]]] |\n                                       None) = None,\n      static_numerical_covariates: dict[str, Sequence[float]] | None = None,\n      static_categorical_covariates: (dict[str, Sequence[Category]] |\n                                      None) = None,\n      freq: Sequence[int] | None = None,\n      window_size: int | None = None,\n      forecast_context_len: int | None = None,\n      xreg_mode: XRegMode = \"xreg + timesfm\",\n      normalize_xreg_target_per_input: bool = True,\n      ridge: float = 0.0,\n      max_rows_per_col: int = 0,\n      force_on_cpu: bool = False,\n  ):\n    \"\"\"Forecasts on a list of time series with covariates.\n\n    To optimize inference speed, avoid string valued categorical covariates.\n\n    Args:\n      inputs: A list of time series forecast contexts. Each context time series\n        should be in a format convertible to JTensor by `jnp.array`.\n      dynamic_numerical_covariates: A dict of dynamic numerical covariates.\n      dynamic_categorical_covariates: A dict of dynamic categorical covariates.\n      static_numerical_covariates: A dict of static numerical covariates.\n      static_categorical_covariates: A dict of static categorical covariates.\n      freq: frequency of each context time series. 
0 for high frequency\n        (default), 1 for medium, and 2 for low. Notice this is different from\n        the `freq` required by `forecast_on_df`.\n      window_size: window size of trend + residual decomposition. If None then\n        we do not do decomposition.\n      forecast_context_len: optional max context length.\n      xreg_mode: one of \"xreg + timesfm\" or \"timesfm + xreg\". \"xreg + timesfm\"\n        fits a model on the residuals of the TimesFM forecast. \"timesfm + xreg\"\n        fits a model on the targets then forecasts on the residuals via TimesFM.\n      normalize_xreg_target_per_input: whether to normalize the xreg target per\n        input in the given batch.\n      ridge: ridge penalty for the linear model.\n      max_rows_per_col: max number of rows per column for the linear model.\n      force_on_cpu: whether to force running on cpu for the linear model.\n\n    Returns:\n      A tuple of two lists. The first is the outputs of the model. The second is\n      the outputs of the xreg.\n    \"\"\"\n\n    # Verify and bookkeep covariates.\n    if not (dynamic_numerical_covariates or dynamic_categorical_covariates or\n            static_numerical_covariates or static_categorical_covariates):\n      raise ValueError(\n          \"At least one of dynamic_numerical_covariates,\"\n          \" dynamic_categorical_covariates, static_numerical_covariates,\"\n          \" static_categorical_covariates must be set.\")\n\n    # Track the lengths of (1) each input, (2) the part that can be used in the\n    # linear model, and (3) the horizon.\n    input_lens, train_lens, test_lens = [], [], []\n\n    for i, input_ts in enumerate(inputs):\n      input_len = len(input_ts)\n      input_lens.append(input_len)\n\n      if xreg_mode == \"timesfm + xreg\":\n        # For fitting residuals, no TimesFM forecast on the first patch.\n        train_lens.append(max(0, input_len - self.input_patch_len))\n      elif xreg_mode == \"xreg + timesfm\":\n        
train_lens.append(input_len)\n      else:\n        raise ValueError(f\"Unsupported mode: {xreg_mode}\")\n\n      if dynamic_numerical_covariates:\n        test_lens.append(\n            len(list(dynamic_numerical_covariates.values())[0][i]) - input_len)\n      elif dynamic_categorical_covariates:\n        test_lens.append(\n            len(list(dynamic_categorical_covariates.values())[0][i]) -\n            input_len)\n      else:\n        test_lens.append(self.horizon_len)\n\n      if test_lens[-1] > self.horizon_len:\n        raise ValueError(\n            \"Forecast requested longer horizon than the model definition \"\n            f\"supports: {test_lens[-1]} vs {self.horizon_len}.\")\n\n    # Prepare the covariates into train and test.\n    train_dynamic_numerical_covariates = collections.defaultdict(list)\n    test_dynamic_numerical_covariates = collections.defaultdict(list)\n    train_dynamic_categorical_covariates = collections.defaultdict(list)\n    test_dynamic_categorical_covariates = collections.defaultdict(list)\n    for covariates, train_covariates, test_covariates in (\n        (\n            dynamic_numerical_covariates,\n            train_dynamic_numerical_covariates,\n            test_dynamic_numerical_covariates,\n        ),\n        (\n            dynamic_categorical_covariates,\n            train_dynamic_categorical_covariates,\n            test_dynamic_categorical_covariates,\n        ),\n    ):\n      if not covariates:\n        continue\n      for covariate_name, covariate_values in covariates.items():\n        for input_len, train_len, covariate_value in zip(\n            input_lens, train_lens, covariate_values):\n          train_covariates[covariate_name].append(\n              covariate_value[(input_len - train_len):input_len])\n          test_covariates[covariate_name].append(covariate_value[input_len:])\n\n    # Fit models.\n    if xreg_mode == \"timesfm + xreg\":\n      # Forecast via TimesFM then fit a model on the residuals.\n      
mean_outputs, _ = self.forecast(\n          inputs,\n          freq,\n          window_size,\n          forecast_context_len,\n          return_forecast_on_context=True,\n      )\n      targets = [\n          (np.array(input_ts)[-train_len:] -\n           mean_output[(self._horizon_start - train_len):self._horizon_start])\n          for input_ts, mean_output, train_len in zip(inputs, mean_outputs,\n                                                      train_lens)\n      ]\n      per_instance_stats = None\n      if normalize_xreg_target_per_input:\n        targets, per_instance_stats = _normalize(targets)\n      xregs = xreg_lib.BatchedInContextXRegLinear(\n          targets=targets,\n          train_lens=train_lens,\n          test_lens=test_lens,\n          train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,\n          test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,\n          train_dynamic_categorical_covariates=\n          train_dynamic_categorical_covariates,\n          test_dynamic_categorical_covariates=\n          test_dynamic_categorical_covariates,\n          static_numerical_covariates=static_numerical_covariates,\n          static_categorical_covariates=static_categorical_covariates,\n      ).fit(\n          ridge=ridge,\n          one_hot_encoder_drop=None if ridge > 0 else \"first\",\n          max_rows_per_col=max_rows_per_col,\n          force_on_cpu=force_on_cpu,\n          debug_info=False,\n          assert_covariates=True,\n          assert_covariate_shapes=True,\n      )\n      if normalize_xreg_target_per_input:\n        xregs = _renormalize(xregs, per_instance_stats)\n      outputs = [\n          (mean_output[self._horizon_start:(self._horizon_start + test_len)] +\n           xreg)\n          for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs)\n      ]\n\n    else:\n      # Fit a model on the targets then forecast on the residuals via TimesFM.\n      targets = [\n          
np.array(input_ts)[-train_len:]\n          for input_ts, train_len in zip(inputs, train_lens)\n      ]\n      per_instance_stats = None\n      if normalize_xreg_target_per_input:\n        targets, per_instance_stats = _normalize(targets)\n      xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear(\n          targets=targets,\n          train_lens=train_lens,\n          test_lens=test_lens,\n          train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,\n          test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,\n          train_dynamic_categorical_covariates=\n          train_dynamic_categorical_covariates,\n          test_dynamic_categorical_covariates=\n          test_dynamic_categorical_covariates,\n          static_numerical_covariates=static_numerical_covariates,\n          static_categorical_covariates=static_categorical_covariates,\n      ).fit(\n          ridge=ridge,\n          one_hot_encoder_drop=None if ridge > 0 else \"first\",\n          max_rows_per_col=max_rows_per_col,\n          force_on_cpu=force_on_cpu,\n          debug_info=True,\n          assert_covariates=True,\n          assert_covariate_shapes=True,\n      )\n      mean_outputs, _ = self.forecast(\n          [\n              target - xreg_on_context\n              for target, xreg_on_context in zip(targets, xregs_on_context)\n          ],\n          freq,\n          window_size,\n          forecast_context_len,\n          return_forecast_on_context=True,\n      )\n      outputs = [\n          (mean_output[self._horizon_start:(self._horizon_start + test_len)] +\n           xreg)\n          for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs)\n      ]\n      if normalize_xreg_target_per_input:\n        outputs = _renormalize(outputs, per_instance_stats)\n\n    return outputs, xregs\n\n  def forecast_on_df(\n      self,\n      inputs: pd.DataFrame,\n      freq: str,\n      forecast_context_len: int = 0,\n      
value_name: str = \"values\",\n      model_name: str = \"timesfm\",\n      window_size: int | None = None,\n      num_jobs: int = 1,\n      verbose: bool = True,\n  ) -> pd.DataFrame:\n    \"\"\"Forecasts on a list of time series.\n\n    Args:\n      inputs: A pd.DataFrame of all time series. The dataframe should have a\n        `unique_id` column for identifying the time series, a `ds` column for\n        timestamps and a value column for the time series values.\n      freq: string valued `freq` of data. Notice this is different from the\n        `freq` required by `forecast`. See `freq_map` for allowed values.\n      forecast_context_len: If provided none zero, we take the last\n        `forecast_context_len` time-points from each series as the forecast\n        context instead of the `context_len` set by the model.\n      value_name: The name of the value column.\n      model_name: name of the model to be written into future df.\n      window_size: window size of trend + residual decomposition. 
If None then\n        we do not do decomposition.\n      num_jobs: number of parallel processes to use for dataframe processing.\n      verbose: output model states in terminal.\n\n    Returns:\n      Future forecasts dataframe.\n    \"\"\"\n    if not (\"unique_id\" in inputs.columns and \"ds\" in inputs.columns and\n            value_name in inputs.columns):\n      raise ValueError(\n          f\"DataFrame must have unique_id, ds and {value_name} columns.\")\n    if not forecast_context_len:\n      forecast_context_len = self.context_len\n    logging.info(\"Preprocessing dataframe.\")\n    df_sorted = inputs.sort_values(by=[\"unique_id\", \"ds\"])\n    new_inputs = []\n    uids = []\n    if num_jobs == 1:\n      if verbose:\n        print(\"Processing dataframe with single process.\")\n      for key, group in df_sorted.groupby(\"unique_id\"):\n        inp, uid = process_group(\n            key,\n            group,\n            value_name,\n            forecast_context_len,\n        )\n        new_inputs.append(inp)\n        uids.append(uid)\n    else:\n      if num_jobs == -1:\n        num_jobs = multiprocessing.cpu_count()\n      if verbose:\n        print(\"Processing dataframe with multiple processes.\")\n      with multiprocessing.Pool(processes=num_jobs) as pool:\n        results = pool.starmap(\n            process_group,\n            [(key, group, value_name, forecast_context_len)\n             for key, group in df_sorted.groupby(\"unique_id\")],\n        )\n      new_inputs, uids = zip(*results)\n    if verbose:\n      print(\"Finished preprocessing dataframe.\")\n    freq_inps = [freq_map(freq)] * len(new_inputs)\n    _, full_forecast = self.forecast(new_inputs,\n                                     freq=freq_inps,\n                                     window_size=window_size)\n    if verbose:\n      print(\"Finished forecasting.\")\n    fcst_df = make_future_dataframe(\n        uids=uids,\n        
last_times=df_sorted.groupby(\"unique_id\")[\"ds\"].tail(1),\n        h=self.horizon_len,\n        freq=freq,\n    )\n    fcst_df[model_name] = full_forecast[:, 0:self.horizon_len, 0].reshape(-1, 1)\n\n    for i, q in enumerate(self.quantiles):\n      q_col = f\"{model_name}-q-{q}\"\n      fcst_df[q_col] = full_forecast[:, 0:self.horizon_len,\n                                     1 + i].reshape(-1, 1)\n      if q == 0.5:\n        fcst_df[model_name] = fcst_df[q_col]\n    logging.info(\"Finished creating output dataframe.\")\n    return fcst_df\n"
  },
  {
    "path": "probts/model/nn/arch/TimesFMModule/timesfm_jax.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"TimesFM JAX forecast API for inference.\"\"\"\n\nimport logging\nimport multiprocessing\nimport time\nfrom os import path\nfrom typing import Any, Sequence\n\nimport einshape as es\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom huggingface_hub import snapshot_download\n\nfrom paxml import checkpoints, tasks_lib\nfrom praxis import base_hyperparams, base_layer, pax_fiddle, py_utils, pytypes\nfrom praxis.layers import normalizations, transformers\nfrom probts.model.nn.arch.TimesFMModule import timesfm_base\nfrom probts.model.nn.arch.TimesFMModule import patched_decoder\n\ninstantiate = base_hyperparams.instantiate\nNestedMap = py_utils.NestedMap\nJTensor = pytypes.JTensor\n\n_TOL = 1e-6\n\n\nclass TimesFmJax(timesfm_base.TimesFmBase):\n  \"\"\"TimesFM forecast API for inference.\n\n  This class is the scaffolding for calling TimesFM forecast. To properly use:\n    1. Create an instance with the correct hyperparameters of a TimesFM model.\n    2. Call `load_from_checkpoint` to load a compatible checkpoint.\n    3. Call `forecast` for inference.\n\n  Given the model size, this API does not shard the model weights for SPMD. All\n  parallelism happens on the data dimension.\n\n  Compilation happens during the first time `forecast` is called and uses the\n  `per_core_batch_size` to set and freeze the input signature. 
Subsequent calls\n  to `forecast` reflect the actual inference latency.\n  \"\"\"\n\n  def _get_sample_inputs(self):\n    return {\n        \"input_ts\":\n            jnp.zeros(\n                (\n                    self.per_core_batch_size,\n                    self.context_len + self.output_patch_len,\n                ),\n                dtype=jnp.float32,\n            ),\n        \"input_padding\":\n            jnp.zeros(\n                (\n                    self.per_core_batch_size,\n                    self.context_len + self.output_patch_len,\n                ),\n                dtype=jnp.float32,\n            ),\n        \"freq\":\n            jnp.zeros(\n                (\n                    self.per_core_batch_size,\n                    1,\n                ),\n                dtype=jnp.int32,\n            ),\n    }\n\n  def __post_init__(self):\n    self.num_cores = jax.local_device_count(self.backend)\n    self.global_batch_size = self.per_core_batch_size * self.num_cores\n    self._eval_context = base_layer.JaxContext.HParams(do_eval=True)\n    self._pmapped_decode = None\n    self._model = None\n    self._train_state = None\n    self._median_index = -1\n\n  def load_from_checkpoint(\n      self,\n      checkpoint: timesfm_base.TimesFmCheckpoint,\n  ) -> None:\n    \"\"\"Loads a checkpoint and compiles the decoder.\"\"\"\n    checkpoint_type = (checkpoints.CheckpointType.FLAX\n                       if checkpoint.type is None else checkpoint.type)\n    checkpoint_path = checkpoint.path\n    step = checkpoint.step\n    repo_id = checkpoint.huggingface_repo_id\n    if checkpoint_path is None:\n      checkpoint_path = path.join(snapshot_download(repo_id), \"checkpoints\")\n    # Rewrite the devices for Jax.\n    self.mesh_shape = [1, self.num_cores, 1]\n    self.mesh_name = [\"replica\", \"data\", \"mdl\"]\n\n    self.model_p = pax_fiddle.Config(\n        patched_decoder.PatchedTimeSeriesDecoder,\n        name=\"patched_decoder\",\n        
horizon_len=self.output_patch_len,\n        patch_len=self.input_patch_len,\n        model_dims=self.model_dims,\n        hidden_dims=self.model_dims,\n        residual_block_tpl=pax_fiddle.Config(patched_decoder.ResidualBlock),\n        quantiles=self.quantiles,\n        use_freq=True,\n        use_pos_emb=self.use_pos_emb,\n        stacked_transformer_params_tpl=pax_fiddle.Config(\n            transformers.StackedTransformer,\n            num_heads=self.num_heads,\n            num_layers=self.num_layers,\n            transformer_layer_params_tpl=pax_fiddle.Config(\n                transformers.Transformer,\n                ln_tpl=pax_fiddle.Config(normalizations.RmsNorm,),\n            ),\n        ),\n    )\n\n    self._key1, self._key2 = jax.random.split(jax.random.PRNGKey(42))\n    self._model = None\n    self._train_state = None\n    self._pmapped_decode = None\n    self._eval_context = base_layer.JaxContext.HParams(do_eval=True)\n    try:\n      multiprocessing.set_start_method(\"spawn\")\n    except RuntimeError:\n      print(\"Multiprocessing context has already been set.\")\n    # Download the checkpoint from Hugging Face Hub if not given\n\n    #  Initialize the model weights.\n    self._logging(\"Constructing model weights.\")\n    start_time = time.time()\n    self._model = instantiate(self.model_p)\n    var_weight_hparams = self._model.abstract_init_with_metadata(\n        self._get_sample_inputs(), do_eval=True)\n    train_state_partition_specs = tasks_lib.create_state_partition_specs(\n        var_weight_hparams,\n        mesh_shape=self.mesh_shape,\n        mesh_axis_names=self.mesh_name,\n        discard_opt_states=True,\n        learners=None,\n    )\n    train_state_local_shapes = tasks_lib.create_state_unpadded_shapes(\n        var_weight_hparams,\n        discard_opt_states=True,\n        learners=None,\n    )\n    self._logging(\n        f\"Constructed model weights in {time.time() - start_time:.2f} seconds.\")\n\n    # Load the model 
weights.\n    self._logging(f\"Restoring checkpoint from {checkpoint_path}.\")\n    start_time = time.time()\n    self._train_state = checkpoints.restore_checkpoint(\n        train_state_local_shapes,\n        checkpoint_dir=checkpoint_path,\n        checkpoint_type=checkpoint_type,\n        state_specs=train_state_partition_specs,\n        step=step,\n    )\n    self._logging(\n        f\"Restored checkpoint in {time.time() - start_time:.2f} seconds.\")\n    self.jit_decode()\n\n  def jit_decode(self):\n    \"\"\"Jitting decoding function.\"\"\"\n\n    # Initialize and jit the decode fn.\n    def _decode(inputs):\n      assert self._model is not None\n      assert self._train_state is not None\n      return self._model.apply(\n          self._train_state.mdl_vars,\n          inputs,\n          horizon_len=self.horizon_len,\n          output_patch_len=self.output_patch_len,\n          max_len=self.context_len,\n          return_forecast_on_context=True,\n          rngs={\n              base_layer.PARAMS: self._key1,\n              base_layer.RANDOM: self._key2,\n          },\n          method=self._model.decode,\n      )\n\n    self._logging(\"Jitting decoding.\")\n    start_time = time.time()\n    self._pmapped_decode = jax.pmap(\n        _decode,\n        axis_name=\"batch\",\n        devices=jax.devices(self.backend),\n        backend=self.backend,\n        axis_size=self.num_cores,\n    )\n    with base_layer.JaxContext.new_context(hparams=self._eval_context):\n      _ = self._pmapped_decode(\n          NestedMap({\n              \"input_ts\":\n                  jnp.zeros(\n                      (\n                          self.num_cores,\n                          self.per_core_batch_size,\n                          self.context_len,\n                      ),\n                      dtype=jnp.float32,\n                  ),\n              \"input_padding\":\n                  jnp.zeros(\n                      (\n                          self.num_cores,\n       
                   self.per_core_batch_size,\n                          self.context_len + self.horizon_len,\n                      ),\n                      dtype=jnp.float32,\n                  ),\n              \"date_features\":\n                  None,\n              \"freq\":\n                  jnp.zeros(\n                      (self.num_cores, self.per_core_batch_size, 1),\n                      dtype=jnp.int32,\n                  ),\n          }))\n    self._logging(f\"Jitted decoding in {time.time() - start_time:.2f} seconds.\")\n\n  def _forecast(\n      self,\n      inputs: Sequence[Any],\n      freq: Sequence[int] | None = None,\n      window_size: int | None = None,\n      forecast_context_len: int | None = None,\n      return_forecast_on_context: bool = False,\n  ) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"Forecasts on a list of time series.\n\n    Args:\n      inputs: list of time series forecast contexts. Each context time series\n        should be in a format convertible to JTensor by `jnp.array`.\n      freq: frequency of each context time series. 0 for high frequency\n        (default), 1 for medium, and 2 for low. Notice this is different from\n        the `freq` required by `forecast_on_df`.\n      window_size: window size of trend + residual decomposition. If None then\n        we do not do decomposition.\n      forecast_context_len: optional max context length.\n      return_forecast_on_context: True to return the forecast on the context\n        when available, i.e. after the first input patch.\n\n    Returns:\n    A tuple for JTensors:\n    - the mean forecast of size (# inputs, # forecast horizon),\n    - the full forecast (mean + quantiles) of size\n        (# inputs,  # forecast horizon, 1 + # quantiles).\n\n    Raises:\n    ValueError: If the checkpoint is not properly loaded.\n    \"\"\"\n    if not self._train_state or not self._model:\n      raise ValueError(\n          \"Checkpoint not loaded. 
Call `load_from_checkpoint` before\"\n          \" `forecast`.\")\n    if forecast_context_len is None:\n      fcontext_len = self.context_len\n    else:\n      fcontext_len = forecast_context_len\n    inputs = [np.array(ts)[-fcontext_len:] for ts in inputs]\n\n    if window_size is not None:\n      new_inputs = []\n      for ts in inputs:\n        new_inputs.extend(timesfm_base.moving_average(ts, window_size))\n      inputs = new_inputs\n\n    if freq is None:\n      logging.info(\"No frequency provided via `freq`. Default to high (0).\")\n      freq = [0] * len(inputs)\n\n    input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq)\n    with base_layer.JaxContext.new_context(hparams=self._eval_context):\n      mean_outputs = []\n      full_outputs = []\n      assert input_ts.shape[0] % self.global_batch_size == 0\n      for i in range(input_ts.shape[0] // self.global_batch_size):\n        input_ts_in = jnp.array(input_ts[i * self.global_batch_size:(i + 1) *\n                                         self.global_batch_size])\n        input_padding_in = jnp.array(\n            input_padding[i * self.global_batch_size:(i + 1) *\n                          self.global_batch_size],)\n        inp_freq_in = jnp.array(\n            inp_freq[i * self.global_batch_size:(i + 1) *\n                     self.global_batch_size, :],\n            dtype=jnp.int32,\n        )\n        pmapped_inputs = NestedMap({\n            \"input_ts\":\n                es.jax_einshape(\n                    \"(db)...->db...\",\n                    input_ts_in,\n                    d=self.num_cores,\n                ),\n            \"input_padding\":\n                es.jax_einshape(\n                    \"(db)...->db...\",\n                    input_padding_in,\n                    d=self.num_cores,\n                ),\n            \"date_features\":\n                None,\n            \"freq\":\n                es.jax_einshape(\n                    \"(db)...->db...\",\n       
             inp_freq_in,\n                    d=self.num_cores,\n                ),\n        })\n        mean_output, full_output = self._pmapped_decode(pmapped_inputs)\n        if not return_forecast_on_context:\n          mean_output = mean_output[:, :, self._horizon_start:, ...]\n          full_output = full_output[:, :, self._horizon_start:, ...]\n        mean_output = es.jax_einshape(\"db...->(db)...\",\n                                      mean_output,\n                                      d=self.num_cores)\n        full_output = es.jax_einshape(\"db...->(db)...\",\n                                      full_output,\n                                      d=self.num_cores)\n        mean_output = np.array(mean_output)\n        full_output = np.array(full_output)\n        mean_outputs.append(mean_output)\n        full_outputs.append(full_output)\n\n    mean_outputs = np.concatenate(mean_outputs, axis=0)\n    full_outputs = np.concatenate(full_outputs, axis=0)\n\n    if pmap_pad > 0:\n      mean_outputs = mean_outputs[:-pmap_pad, ...]\n      full_outputs = full_outputs[:-pmap_pad, ...]\n\n    if window_size is not None:\n      mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]\n      full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]\n    return mean_outputs, full_outputs\n"
  },
  {
    "path": "probts/model/nn/arch/TimesFMModule/timesfm_torch.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"TimesFM pytorch forecast API for inference.\"\"\"\n\nimport logging\nfrom os import path\nfrom typing import Any, Sequence\n\nimport numpy as np\nimport torch\nfrom huggingface_hub import snapshot_download\nfrom probts.model.nn.arch.TimesFMModule import timesfm_base\n\nfrom probts.model.nn.arch.TimesFMModule import pytorch_patched_decoder as ppd\n\n_TOL = 1e-6\n\n\nclass TimesFmTorch(timesfm_base.TimesFmBase):\n  \"\"\"TimesFM forecast API for inference.\"\"\"\n\n  def __post_init__(self):\n    self._model_config = ppd.TimesFMConfig(\n        num_layers=self.num_layers,\n        num_heads=self.num_heads,\n        hidden_size=self.model_dims,\n        intermediate_size=self.model_dims,\n        patch_len=self.input_patch_len,\n        horizon_len=self.output_patch_len,\n        head_dim=self.model_dims // self.num_heads,\n        quantiles=self.quantiles,\n        use_positional_embedding=self.use_pos_emb,\n    )\n    self._model = None\n    self.num_cores = 1\n    self.global_batch_size = self.per_core_batch_size\n    self._device = torch.device(\"cuda:0\" if (\n        torch.cuda.is_available() and self.backend == \"gpu\") else \"cpu\")\n    self._median_index = -1\n\n  def load_from_checkpoint(\n      self,\n      checkpoint: timesfm_base.TimesFmCheckpoint,\n  ) -> None:\n    \"\"\"Loads a checkpoint and compiles the decoder.\"\"\"\n    checkpoint_path = 
checkpoint.path\n    repo_id = checkpoint.huggingface_repo_id\n    if checkpoint_path is None:\n      checkpoint_path = path.join(\n                snapshot_download(repo_id, local_dir=checkpoint.local_dir),\n                \"torch_model.ckpt\")\n    self._model = ppd.PatchedTimeSeriesDecoder(self._model_config)\n    loaded_checkpoint = torch.load(checkpoint_path, weights_only=True)\n    logging.info(\"Loading checkpoint from %s\", checkpoint_path)\n    self._model.load_state_dict(loaded_checkpoint)\n    logging.info(\"Sending checkpoint to device %s\", f\"{self._device}\")\n    self._model.to(self._device)\n    self._model.eval()\n    # TODO: add compilation.\n\n  def _forecast(\n      self,\n      inputs: Sequence[Any],\n      freq: Sequence[int] | None = None,\n      window_size: int | None = None,\n      forecast_context_len: int | None = None,\n      return_forecast_on_context: bool = False,\n  ) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"Forecasts on a list of time series.\n\n        Args:\n          inputs: list of time series forecast contexts. Each context time series\n            should be in a format convertible to JTensor by `jnp.array`.\n          freq: frequency of each context time series. 0 for high frequency\n            (default), 1 for medium, and 2 for low. Notice this is different from\n            the `freq` required by `forecast_on_df`.\n          window_size: window size of trend + residual decomposition. If None then\n            we do not do decomposition.\n          forecast_context_len: optional max context length.\n          return_forecast_on_context: True to return the forecast on the context\n            when available, i.e. 
after the first input patch.\n\n        Returns:\n        A tuple for JTensors:\n        - the mean forecast of size (# inputs, # forecast horizon),\n        - the full forecast (mean + quantiles) of size\n            (# inputs,  # forecast horizon, 1 + # quantiles).\n\n        Raises:\n        ValueError: If the checkpoint is not properly loaded.\n        \"\"\"\n    if not self._model:\n      raise ValueError(\n          \"Checkpoint not loaded. Call `load_from_checkpoint` before\"\n          \" `forecast`.\")\n    if forecast_context_len is None:\n      fcontext_len = self.context_len\n    else:\n      fcontext_len = forecast_context_len\n    inputs = [np.array(ts)[-fcontext_len:] for ts in inputs]\n\n    if window_size is not None:\n      new_inputs = []\n      for ts in inputs:\n        new_inputs.extend(timesfm_base.moving_average(ts, window_size))\n      inputs = new_inputs\n\n    if freq is None:\n      logging.info(\"No frequency provided via `freq`. Default to high (0).\")\n      freq = [0] * len(inputs)\n\n    input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq)\n    with torch.no_grad():\n      mean_outputs = []\n      full_outputs = []\n      assert input_ts.shape[0] % self.global_batch_size == 0\n      for i in range(input_ts.shape[0] // self.global_batch_size):\n        input_ts_in = torch.from_numpy(\n            np.array(input_ts[i * self.global_batch_size:(i + 1) *\n                              self.global_batch_size],\n                     dtype=np.float32)).to(self._device)\n        input_padding_in = torch.from_numpy(\n            np.array(input_padding[i * self.global_batch_size:(i + 1) *\n                                   self.global_batch_size],\n                     dtype=np.float32)).to(self._device)\n        inp_freq_in = torch.from_numpy(\n            np.array(inp_freq[\n                i * self.global_batch_size:(i + 1) * self.global_batch_size,\n                :,\n            ],\n                     
dtype=np.int32)).long().to(self._device)\n        mean_output, full_output = self._model.decode(\n            input_ts=input_ts_in,\n            paddings=input_padding_in,\n            freq=inp_freq_in,\n            horizon_len=self.horizon_len,\n            return_forecast_on_context=return_forecast_on_context,\n        )\n        mean_output = mean_output.detach().cpu().numpy()\n        full_output = full_output.detach().cpu().numpy()\n        mean_output = np.array(mean_output)\n        full_output = np.array(full_output)\n        mean_outputs.append(mean_output)\n        full_outputs.append(full_output)\n\n    mean_outputs = np.concatenate(mean_outputs, axis=0)\n    full_outputs = np.concatenate(full_outputs, axis=0)\n\n    if pmap_pad > 0:\n      mean_outputs = mean_outputs[:-pmap_pad, ...]\n      full_outputs = full_outputs[:-pmap_pad, ...]\n\n    if window_size is not None:\n      mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]\n      full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]\n    return mean_outputs, full_outputs\n"
  },
  {
    "path": "probts/model/nn/arch/TimesFMModule/xreg_lib.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Helper functions for in-context covariates and regression.\"\"\"\n\nimport itertools\nimport math\nfrom typing import Any, Iterable, Literal, Mapping, Sequence\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom sklearn import preprocessing\n\nCategory = int | str\n\n_TOL = 1e-6\nXRegMode = Literal[\"timesfm + xreg\", \"xreg + timesfm\"]\n\n\ndef _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:\n  return np.array(list(itertools.chain.from_iterable(nested)))\n\n\ndef _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:\n  return np.array(\n      list(\n          itertools.chain.from_iterable(map(itertools.repeat, elements,\n                                            counts))))\n\n\ndef _to_padded_jax_array(x: np.ndarray) -> jax.Array:\n  if x.ndim == 1:\n    (i,) = x.shape\n    di = 2**math.ceil(math.log2(i)) - i\n    return jnp.pad(x, ((0, di),), mode=\"constant\", constant_values=0.0)\n  elif x.ndim == 2:\n    i, j = x.shape\n    di = 2**math.ceil(math.log2(i)) - i\n    dj = 2**math.ceil(math.log2(j)) - j\n    return jnp.pad(x, ((0, di), (0, dj)), mode=\"constant\", constant_values=0.0)\n  else:\n    raise ValueError(f\"Unsupported array shape: {x.shape}\")\n\n\nclass BatchedInContextXRegBase:\n  \"\"\"Helper class for in-context regression covariate formatting.\n\n  Attributes:\n    targets: List of targets (responses) of 
the in-context regression.\n    train_lens: List of lengths of each target vector from the context.\n    test_lens: List of lengths of each forecast horizon.\n    train_dynamic_numerical_covariates: Dict of covariate names mapping to the\n      dynamic numerical covariates of each forecast task on the context. Their\n      lengths should match the corresponding lengths in `train_lens`.\n    train_dynamic_categorical_covariates: Dict of covariate names mapping to the\n      dynamic categorical covariates of each forecast task on the context. Their\n      lengths should match the corresponding lengths in `train_lens`.\n    test_dynamic_numerical_covariates: Dict of covariate names mapping to the\n      dynamic numerical covariates of each forecast task on the horizon. Their\n      lengths should match the corresponding lengths in `test_lens`.\n    test_dynamic_categorical_covariates: Dict of covariate names mapping to the\n      dynamic categorical covariates of each forecast task on the horizon. 
Their\n      lengths should match the corresponding lengths in `test_lens`.\n    static_numerical_covariates: Dict of covariate names mapping to the static\n      numerical covariates of each forecast task.\n    static_categorical_covariates: Dict of covariate names mapping to the static\n      categorical covariates of each forecast task.\n  \"\"\"\n\n  def __init__(\n      self,\n      targets: Sequence[Sequence[float]],\n      train_lens: Sequence[int],\n      test_lens: Sequence[int],\n      train_dynamic_numerical_covariates: (\n          Mapping[str, Sequence[Sequence[float]]] | None) = None,\n      train_dynamic_categorical_covariates: (\n          Mapping[str, Sequence[Sequence[Category]]] | None) = None,\n      test_dynamic_numerical_covariates: (\n          Mapping[str, Sequence[Sequence[float]]] | None) = None,\n      test_dynamic_categorical_covariates: (\n          Mapping[str, Sequence[Sequence[Category]]] | None) = None,\n      static_numerical_covariates: Mapping[str, Sequence[float]] | None = None,\n      static_categorical_covariates: (Mapping[str, Sequence[Category]] |\n                                      None) = None,\n  ) -> None:\n    \"\"\"Initializes with the exogenous covariate inputs.\n\n    Here we use model fitting language to refer to the context as 'train' and\n    the horizon as 'test'. We assume batched inputs. To properly format the\n    request:\n\n     - `train_lens` represents the contexts in the batch. Targets and all train\n     dynamic covariates should have the same lengths as the corresponding\n     elements\n     in `train_lens`. Notice each `train_len` can be different from the exact\n     length of the corresponding context depending on how much of the context is\n     used for fitting the in-context model.\n     - `test_lens` represents the horizon lengths in the batch. 
All test dynamic\n     covariates should have the same lengths as the corresponding elements in\n     `test_lens`.\n     - Static covariates should be one for each input.\n     - For train and test dynamic covariates, they should have the same\n     covariate names.\n\n     Pass an empty dict {} for a covariate type if it is not present.\n\n     Example:\n       Here is a set of valid inputs whose schema can be used for reference.\n       ```\n       targets = [\n           [0.0, 0.1, 0.2],\n           [0.0, 0.1, 0.2, 0.3],\n       ]  # Two inputs in this batch.\n       train_lens = [3, 4]\n       test_lens = [2, 5]  # Forecast horizons 2 and 5 respectively.\n       train_dynamic_numerical_covariates = {\n           \"cov_1_dn\": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]],\n           \"cov_2_dn\": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]],\n       }  # Each train dynamic covariate has 3 and 4 elements respectively.\n       test_dynamic_numerical_covariates = {\n           \"cov_1_dn\": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]],\n           \"cov_2_dn\": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]],\n       }  # Each test dynamic covariate has 2 and 5 elements respectively.\n       train_dynamic_categorical_covariates = {\n           \"cov_1_dc\": [[0, 1, 0], [0, 1, 2, 3]],\n           \"cov_2_dc\": [[\"good\", \"bad\", \"good\"], [\"good\", \"good\", \"bad\",\n           \"bad\"]],\n       }\n       test_dynamic_categorical_covariates = {\n           \"cov_1_dc\": [[1, 0], [1, 0, 2, 3, 1]],\n           \"cov_2_dc\": [[\"bad\", \"good\"], [\"bad\", \"bad\", \"bad\", \"bad\", \"bad\"]],\n       }\n       static_numerical_covariates = {\n           \"cov_1_sn\": [0.0, 3.0],\n           \"cov_2_sn\": [2.0, 1.0],\n           \"cov_3_sn\": [1.0, 2.0],\n       }  # Each static covariate has 1 element for each input.\n       static_categorical_covariates = {\n           \"cov_1_sc\": [\"apple\", \"orange\"],\n           \"cov_2_sc\": [2, 3],\n       }\n       ```\n\n    
Args:\n      targets: List of targets (responses) of the in-context regression.\n      train_lens: List of lengths of each target vector from the context.\n      test_lens: List of lengths of each forecast horizon.\n      train_dynamic_numerical_covariates: Dict of covariate names mapping to the\n        dynamic numerical covariates of each forecast task on the context. Their\n        lengths should match the corresponding lengths in `train_lens`.\n      train_dynamic_categorical_covariates: Dict of covariate names mapping to\n        the dynamic categorical covariates of each forecast task on the context.\n        Their lengths should match the corresponding lengths in `train_lens`.\n      test_dynamic_numerical_covariates: Dict of covariate names mapping to the\n        dynamic numerical covariates of each forecast task on the horizon. Their\n        lengths should match the corresponding lengths in `test_lens`.\n      test_dynamic_categorical_covariates: Dict of covariate names mapping to\n        the dynamic categorical covariates of each forecast task on the horizon.\n        Their lengths should match the corresponding lengths in `test_lens`.\n      static_numerical_covariates: Dict of covariate names mapping to the static\n        numerical covariates of each forecast task.\n      static_categorical_covariates: Dict of covariate names mapping to the\n        static categorical covariates of each forecast task.\n    \"\"\"\n    self.targets = targets\n    self.train_lens = train_lens\n    self.test_lens = test_lens\n    self.train_dynamic_numerical_covariates = (\n        train_dynamic_numerical_covariates or {})\n    self.train_dynamic_categorical_covariates = (\n        train_dynamic_categorical_covariates or {})\n    self.test_dynamic_numerical_covariates = (test_dynamic_numerical_covariates\n                                              or {})\n    self.test_dynamic_categorical_covariates = (\n        test_dynamic_categorical_covariates or {})\n    
self.static_numerical_covariates = static_numerical_covariates or {}\n    self.static_categorical_covariates = static_categorical_covariates or {}\n\n  def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None:\n    \"\"\"Verifies the validity of the covariate inputs.\"\"\"\n\n    # Check presence.\n    if (self.train_dynamic_numerical_covariates and\n        not self.test_dynamic_numerical_covariates) or (\n            not self.train_dynamic_numerical_covariates and\n            self.test_dynamic_numerical_covariates):\n      raise ValueError(\n          \"train_dynamic_numerical_covariates and\"\n          \" test_dynamic_numerical_covariates must be both present or both\"\n          \" absent.\")\n\n    if (self.train_dynamic_categorical_covariates and\n        not self.test_dynamic_categorical_covariates) or (\n            not self.train_dynamic_categorical_covariates and\n            self.test_dynamic_categorical_covariates):\n      raise ValueError(\n          \"train_dynamic_categorical_covariates and\"\n          \" test_dynamic_categorical_covariates must be both present or both\"\n          \" absent.\")\n\n    # Check keys.\n    for dict_a, dict_b, dict_a_name, dict_b_name in (\n        (\n            self.train_dynamic_numerical_covariates,\n            self.test_dynamic_numerical_covariates,\n            \"train_dynamic_numerical_covariates\",\n            \"test_dynamic_numerical_covariates\",\n        ),\n        (\n            self.train_dynamic_categorical_covariates,\n            self.test_dynamic_categorical_covariates,\n            \"train_dynamic_categorical_covariates\",\n            \"test_dynamic_categorical_covariates\",\n        ),\n    ):\n      if w := set(dict_a.keys()) - set(dict_b.keys()):\n        raise ValueError(\n            f\"{dict_a_name} has keys not present in {dict_b_name}: {w}\")\n      if w := set(dict_b.keys()) - set(dict_a.keys()):\n        raise ValueError(\n            f\"{dict_b_name} has keys not 
present in {dict_a_name}: {w}\")\n\n    # Check shapes.\n    if assert_covariate_shapes:\n      if len(self.targets) != len(self.train_lens):\n        raise ValueError(\n            \"targets and train_lens must have the same number of elements.\")\n\n      if len(self.train_lens) != len(self.test_lens):\n        raise ValueError(\n            \"train_lens and test_lens must have the same number of elements.\")\n\n      for i, (target, train_len) in enumerate(zip(self.targets,\n                                                  self.train_lens)):\n        if len(target) != train_len:\n          raise ValueError(\n              f\"targets[{i}] has length {len(target)} != expected {train_len}.\")\n\n      for key, values in self.static_numerical_covariates.items():\n        if len(values) != len(self.train_lens):\n          raise ValueError(\n              f\"static_numerical_covariates has key {key} with number of\"\n              f\" examples {len(values)} != expected {len(self.train_lens)}.\")\n\n      for key, values in self.static_categorical_covariates.items():\n        if len(values) != len(self.train_lens):\n          raise ValueError(\n              f\"static_categorical_covariates has key {key} with number of\"\n              f\" examples {len(values)} != expected {len(self.train_lens)}.\")\n\n      for lens, dict_cov, dict_cov_name in (\n          (\n              self.train_lens,\n              self.train_dynamic_numerical_covariates,\n              \"train_dynamic_numerical_covariates\",\n          ),\n          (\n              self.train_lens,\n              self.train_dynamic_categorical_covariates,\n              \"train_dynamic_categorical_covariates\",\n          ),\n          (\n              self.test_lens,\n              self.test_dynamic_numerical_covariates,\n              \"test_dynamic_numerical_covariates\",\n          ),\n          (\n              self.test_lens,\n              self.test_dynamic_categorical_covariates,\n              
\"test_dynamic_categorical_covariates\",\n          ),\n      ):\n        for key, cov_values in dict_cov.items():\n          if len(cov_values) != len(lens):\n            raise ValueError(\n                f\"{dict_cov_name} has key {key} with number of examples\"\n                f\" {len(cov_values)} != expected {len(lens)}.\")\n          for i, cov_value in enumerate(cov_values):\n            if len(cov_value) != lens[i]:\n              raise ValueError(\n                  f\"{dict_cov_name} has key {key} with its {i}-th example\"\n                  f\" length {len(cov_value)} != expected {lens[i]}.\")\n\n  def create_covariate_matrix(\n      self,\n      one_hot_encoder_drop: str | None = \"first\",\n      use_intercept: bool = True,\n      assert_covariates: bool = False,\n      assert_covariate_shapes: bool = False,\n  ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n    \"\"\"Creates target vector and covariate matrices for in context regression.\n\n    Here we use model fitting language to refer to the context as 'train' and\n    the horizon as 'test'.\n\n    Args:\n      one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.\n      use_intercept: Whether to prepare an intercept (all 1) column in the\n        matrices.\n      assert_covariates: Whether to assert the validity of the covariate inputs.\n      assert_covariate_shapes: Whether to assert the shapes of the covariate\n        inputs when `assert_covariates` is True.\n\n    Returns:\n      A tuple of the target vector, the covariate matrix for the context, and\n      the covariate matrix for the horizon.\n    \"\"\"\n    if assert_covariates:\n      self._assert_covariates(assert_covariate_shapes)\n\n    x_train, x_test = [], []\n\n    # Numerical features.\n    for name in sorted(self.train_dynamic_numerical_covariates):\n      x_train.append(\n          _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis])\n      x_test.append(\n          
_unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis])\n\n    for covs in self.static_numerical_covariates.values():\n      x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis])\n      x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis])\n\n    if x_train:\n      x_train = np.concatenate(x_train, axis=1)\n      x_test = np.concatenate(x_test, axis=1)\n\n      # Normalize for robustness.\n      x_mean = np.mean(x_train, axis=0, keepdims=True)\n      x_std = np.where((w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w,\n                       1.0)\n      x_train = [(x_train - x_mean) / x_std]\n      x_test = [(x_test - x_mean) / x_std]\n\n    # Categorical features. Encode one by one.\n    one_hot_encoder = preprocessing.OneHotEncoder(\n        drop=one_hot_encoder_drop,\n        sparse_output=False,\n        handle_unknown=\"ignore\",\n    )\n    for name in sorted(self.train_dynamic_categorical_covariates.keys()):\n      ohe_train = _unnest(\n          self.train_dynamic_categorical_covariates[name])[:, np.newaxis]\n      ohe_test = _unnest(\n          self.test_dynamic_categorical_covariates[name])[:, np.newaxis]\n      x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train)))\n      x_test.append(np.array(one_hot_encoder.transform(ohe_test)))\n\n    for covs in self.static_categorical_covariates.values():\n      ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis])\n      x_train.append(_repeat(ohe, self.train_lens))\n      x_test.append(_repeat(ohe, self.test_lens))\n\n    x_train = np.concatenate(x_train, axis=1)\n    x_test = np.concatenate(x_test, axis=1)\n\n    if use_intercept:\n      x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0)\n      x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0)\n\n    return _unnest(self.targets), x_train, x_test\n\n  def fit(self) -> Any:\n    raise NotImplementedError(\"Fit is not implemented.\")\n\n\nclass 
BatchedInContextXRegLinear(BatchedInContextXRegBase):\n  \"\"\"Linear in-context regression model.\"\"\"\n\n  def fit(\n      self,\n      ridge: float = 0.0,\n      one_hot_encoder_drop: str | None = \"first\",\n      use_intercept: bool = True,\n      force_on_cpu: bool = False,\n      max_rows_per_col: int = 0,\n      max_rows_per_col_sample_seed: int = 42,\n      debug_info: bool = False,\n      assert_covariates: bool = False,\n      assert_covariate_shapes: bool = False,\n  ) -> (list[np.ndarray] | tuple[list[np.ndarray], list[np.ndarray], jax.Array,\n                                 jax.Array, jax.Array]):\n    \"\"\"Fits a linear model for in-context regression.\n\n    Args:\n      ridge: A non-negative value for specifying the ridge regression penalty.\n        If 0 is provided, fallback to ordinary least squares. Note this penalty\n        is added to the normalized covariate matrix.\n      one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.\n      use_intercept: Whether to prepare an intercept (all 1) column in the\n        matrices.\n      force_on_cpu: Whether to force execution on cpu for accelerator machines.\n      max_rows_per_col: How many rows to subsample per column. 0 for no\n        subsampling. 
This is for speeding up model fitting.\n      max_rows_per_col_sample_seed: The seed for the subsampling if needed by\n        `max_rows_per_col`.\n      debug_info: Whether to return debug info.\n      assert_covariates: Whether to assert the validity of the covariate inputs.\n      assert_covariate_shapes: Whether to assert the shapes of the covariate\n        inputs when `assert_covariates` is True.\n\n    Returns:\n      If `debug_info` is False:\n        The linear fits on the horizon.\n      If `debug_info` is True:\n        A tuple of:\n        - the linear fits on the horizon,\n        - the linear fits on the context,\n        - the flattened target vector,\n        - the covariate matrix for the context, and\n        - the covariate matrix for the horizon.\n    \"\"\"\n    flat_targets, x_train_raw, x_test = self.create_covariate_matrix(\n        one_hot_encoder_drop=one_hot_encoder_drop,\n        use_intercept=use_intercept,\n        assert_covariates=assert_covariates,\n        assert_covariate_shapes=assert_covariate_shapes,\n    )\n\n    x_train = x_train_raw.copy()\n    if max_rows_per_col:\n      nrows, ncols = x_train.shape\n      if nrows > (w := ncols * max_rows_per_col):\n        subsample = jax.random.choice(\n            jax.random.PRNGKey(max_rows_per_col_sample_seed),\n            nrows,\n            (w,),\n            replace=False,\n        )\n        x_train = x_train[subsample]\n        flat_targets = flat_targets[subsample]\n\n    device = jax.devices(\"cpu\")[0] if force_on_cpu else None\n    # Runs jitted version of the solvers which are quicker at the cost of\n    # running jitting during the first time calling. Re-jitting happens whenever\n    # new (padded) shapes are encountered.\n    # Ocassionally it helps with the speed and the accuracy if we force single\n    # thread execution on cpu for accelerator machines:\n    # 1. Avoid moving data to accelarator memory.\n    # 2. 
Avoid precision loss if any.\n    with jax.default_device(device):\n      x_train_raw = _to_padded_jax_array(x_train_raw)\n      x_train = _to_padded_jax_array(x_train)\n      flat_targets = _to_padded_jax_array(flat_targets)\n      x_test = _to_padded_jax_array(x_test)\n      beta_hat = (jnp.linalg.pinv(\n          x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]),\n          hermitian=True,\n      ) @ x_train.T @ flat_targets)\n      y_hat = x_test @ beta_hat\n      y_hat_context = x_train_raw @ beta_hat if debug_info else None\n\n    outputs = []\n    outputs_context = []\n\n    # Reconstruct the ragged 2-dim batched forecasts from flattened linear fits.\n    train_index, test_index = 0, 0\n    for train_index_delta, test_index_delta in zip(self.train_lens,\n                                                   self.test_lens):\n      outputs.append(np.array(y_hat[test_index:(test_index +\n                                                test_index_delta)]))\n      if debug_info:\n        outputs_context.append(\n            np.array(y_hat_context[train_index:(train_index +\n                                                train_index_delta)]))\n      train_index += train_index_delta\n      test_index += test_index_delta\n\n    if debug_info:\n      return outputs, outputs_context, flat_targets, x_train, x_test\n    else:\n      return outputs\n"
  },
  {
    "path": "probts/model/nn/arch/TransformerModule/Embed.py",
    "content": "import torch\nimport torch.nn as nn\nimport math\n\nclass PositionalEmbedding(nn.Module):\n    def __init__(self, d_model, max_len=5000):\n        super(PositionalEmbedding, self).__init__()\n        # Compute the positional encodings once in log space.\n        pe = torch.zeros(max_len, d_model).float()\n        pe.require_grad = False\n\n        position = torch.arange(0, max_len).float().unsqueeze(1)\n        div_term = (torch.arange(0, d_model, 2).float()\n                    * -(math.log(10000.0) / d_model)).exp()\n\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n\n        pe = pe.unsqueeze(0)\n        self.register_buffer('pe', pe)\n\n    def forward(self, x):\n        return self.pe[:, :x.size(1)]\n\n\nclass TokenEmbedding(nn.Module):\n    def __init__(self, c_in, d_model):\n        super(TokenEmbedding, self).__init__()\n        padding = 1 if torch.__version__ >= '1.5.0' else 2\n        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,\n                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)\n        for m in self.modules():\n            if isinstance(m, nn.Conv1d):\n                nn.init.kaiming_normal_(\n                    m.weight, mode='fan_in', nonlinearity='leaky_relu')\n\n    def forward(self, x):\n        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)\n        return x\n\n\nclass FixedEmbedding(nn.Module):\n    def __init__(self, c_in, d_model):\n        super(FixedEmbedding, self).__init__()\n\n        w = torch.zeros(c_in, d_model).float()\n        w.require_grad = False\n\n        position = torch.arange(0, c_in).float().unsqueeze(1)\n        div_term = (torch.arange(0, d_model, 2).float()\n                    * -(math.log(10000.0) / d_model)).exp()\n\n        w[:, 0::2] = torch.sin(position * div_term)\n        w[:, 1::2] = torch.cos(position * div_term)\n\n        self.emb = nn.Embedding(c_in, 
d_model)\n        self.emb.weight = nn.Parameter(w, requires_grad=False)\n\n    def forward(self, x):\n        return self.emb(x).detach()\n\n\nclass TemporalEmbedding(nn.Module):\n    def __init__(self, d_model, embed_type='fixed', freq='h'):\n        super(TemporalEmbedding, self).__init__()\n\n        minute_size = 4\n        hour_size = 24\n        weekday_size = 7\n        day_size = 32\n        month_size = 13\n\n        Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding\n        if freq == 't':\n            self.minute_embed = Embed(minute_size, d_model)\n        self.hour_embed = Embed(hour_size, d_model)\n        self.weekday_embed = Embed(weekday_size, d_model)\n        self.day_embed = Embed(day_size, d_model)\n        self.month_embed = Embed(month_size, d_model)\n\n    def forward(self, x):\n        x = x.long()\n        minute_x = self.minute_embed(x[:, :, 4]) if hasattr(\n            self, 'minute_embed') else 0.\n        hour_x = self.hour_embed(x[:, :, 3])\n        weekday_x = self.weekday_embed(x[:, :, 2])\n        day_x = self.day_embed(x[:, :, 1])\n        month_x = self.month_embed(x[:, :, 0])\n\n        return hour_x + weekday_x + day_x + month_x + minute_x\n\n\nclass TimeFeatureEmbedding(nn.Module):\n    def __init__(self, d_model, embed_type='timeF', freq='h'):\n        super(TimeFeatureEmbedding, self).__init__()\n\n        if freq == 'min':\n            freq = 't'\n        freq_map = {'h': 4, 't': 5, 's': 6,\n                    'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}\n        d_inp = freq_map[freq]\n        self.embed = nn.Linear(d_inp, d_model, bias=False)\n\n    def forward(self, x):\n        return self.embed(x)\n\n\nclass DataEmbedding(nn.Module):\n    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):\n        super(DataEmbedding, self).__init__()\n\n        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)\n        self.position_embedding = 
PositionalEmbedding(d_model=d_model)\n        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,\n                                                    freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(\n            d_model=d_model, embed_type=embed_type, freq=freq)\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, x, x_mark):\n        if x_mark is None:\n            x = self.value_embedding(x) + self.position_embedding(x)\n        else:\n            x = self.value_embedding(\n                x) + self.temporal_embedding(x_mark) + self.position_embedding(x)\n        return self.dropout(x)\n\n\nclass DataEmbedding_wo_pos(nn.Module):\n    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):\n        super(DataEmbedding_wo_pos, self).__init__()\n\n        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)\n        self.position_embedding = PositionalEmbedding(d_model=d_model)\n        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,\n                                                    freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(\n            d_model=d_model, embed_type=embed_type, freq=freq)\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, x, x_mark):\n        if x_mark is None:\n            x = self.value_embedding(x)\n        else:\n            x = self.value_embedding(x) + self.temporal_embedding(x_mark)\n        return self.dropout(x)\n\n\nclass PatchEmbedding(nn.Module):\n    def __init__(self, d_model, patch_len, stride, padding, dropout):\n        super(PatchEmbedding, self).__init__()\n        # Patching\n        self.patch_len = patch_len\n        self.stride = stride\n        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))\n\n        # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space\n        self.value_embedding = nn.Linear(patch_len, 
d_model, bias=False)\n\n        # Positional embedding\n        self.position_embedding = PositionalEmbedding(d_model)\n\n        # Residual dropout\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        # do patching\n        n_vars = x.shape[1]\n        x = self.padding_patch_layer(x)\n        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)\n        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))\n        # Input encoding\n        x = self.value_embedding(x) + self.position_embedding(x)\n        return self.dropout(x), n_vars\n\n\n# Code implementation from https://github.com/thuml/iTransformer\nclass DataEmbedding_inverted(nn.Module):\n    def __init__(self, c_in, d_model, dropout=0.1):\n        super(DataEmbedding_inverted, self).__init__()\n        self.value_embedding = nn.Linear(c_in, d_model)\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, x, x_mark):\n        x = x.permute(0, 2, 1)\n        # x: [Batch Variate Time]\n        if x_mark is None:\n            x = self.value_embedding(x)\n        else:\n            # the potential to take covariates (e.g. timestamps) as tokens\n            x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) \n        # x: [Batch Variate d_model]\n        return self.dropout(x)\n"
  },
  {
    "path": "probts/model/nn/arch/TransformerModule/SelfAttention_Family.py",
    "content": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom math import sqrt\nfrom probts.utils.masking import TriangularCausalMask, ProbMask\nfrom reformer_pytorch import LSHSelfAttention\nfrom einops import rearrange\n\n\n# Code implementation from https://github.com/thuml/Flowformer\nclass FlowAttention(nn.Module):\n    def __init__(self, attention_dropout=0.1):\n        super(FlowAttention, self).__init__()\n        self.dropout = nn.Dropout(attention_dropout)\n\n    def kernel_method(self, x):\n        return torch.sigmoid(x)\n\n    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):\n        queries = queries.transpose(1, 2)\n        keys = keys.transpose(1, 2)\n        values = values.transpose(1, 2)\n        # kernel\n        queries = self.kernel_method(queries)\n        keys = self.kernel_method(keys)\n        # incoming and outgoing\n        normalizer_row = 1.0 / (torch.einsum(\"nhld,nhd->nhl\", queries + 1e-6, keys.sum(dim=2) + 1e-6))\n        normalizer_col = 1.0 / (torch.einsum(\"nhsd,nhd->nhs\", keys + 1e-6, queries.sum(dim=2) + 1e-6))\n        # reweighting\n        normalizer_row_refine = (\n            torch.einsum(\"nhld,nhd->nhl\", queries + 1e-6, (keys * normalizer_col[:, :, :, None]).sum(dim=2) + 1e-6))\n        normalizer_col_refine = (\n            torch.einsum(\"nhsd,nhd->nhs\", keys + 1e-6, (queries * normalizer_row[:, :, :, None]).sum(dim=2) + 1e-6))\n        # competition and allocation\n        normalizer_row_refine = torch.sigmoid(\n            normalizer_row_refine * (float(queries.shape[2]) / float(keys.shape[2])))\n        normalizer_col_refine = torch.softmax(normalizer_col_refine, dim=-1) * keys.shape[2]  # B h L vis\n        # multiply\n        kv = keys.transpose(-2, -1) @ (values * normalizer_col_refine[:, :, :, None])\n        x = (((queries @ kv) * normalizer_row[:, :, :, None]) * normalizer_row_refine[:, :, :, None]).transpose(1,\n                                                      
                                                          2).contiguous()\n        return x, None\n\n\n# Code implementation from https://github.com/shreyansh26/FlashAttention-PyTorch\nclass FlashAttention(nn.Module):\n    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):\n        super(FlashAttention, self).__init__()\n        self.scale = scale\n        self.mask_flag = mask_flag\n        self.output_attention = output_attention\n        self.dropout = nn.Dropout(attention_dropout)\n\n    def flash_attention_forward(self, Q, K, V, mask=None):\n        BLOCK_SIZE = 32\n        NEG_INF = -1e10  # -infinity\n        EPSILON = 1e-10\n        # mask = torch.randint(0, 2, (128, 8)).to(device='cuda')\n        O = torch.zeros_like(Q, requires_grad=True)\n        l = torch.zeros(Q.shape[:-1])[..., None]\n        m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF\n\n        O = O.to(device='cuda')\n        l = l.to(device='cuda')\n        m = m.to(device='cuda')\n\n        Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1])\n        KV_BLOCK_SIZE = BLOCK_SIZE\n\n        Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)\n        K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)\n        V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)\n        if mask is not None:\n            mask_BLOCKS = list(torch.split(mask, KV_BLOCK_SIZE, dim=1))\n\n        Tr = len(Q_BLOCKS)\n        Tc = len(K_BLOCKS)\n\n        O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))\n        l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))\n        m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))\n\n        for j in range(Tc):\n            Kj = K_BLOCKS[j]\n            Vj = V_BLOCKS[j]\n            if mask is not None:\n                maskj = mask_BLOCKS[j]\n\n            for i in range(Tr):\n                Qi = Q_BLOCKS[i]\n                Oi = O_BLOCKS[i]\n                li = l_BLOCKS[i]\n                mi = m_BLOCKS[i]\n\n                scale = 1 
/ np.sqrt(Q.shape[-1])\n                Qi_scaled = Qi * scale\n\n                S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi_scaled, Kj)\n                if mask is not None:\n                    # Masking\n                    maskj_temp = rearrange(maskj, 'b j -> b 1 1 j')\n                    S_ij = torch.where(maskj_temp > 0, S_ij, NEG_INF)\n\n                m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)\n                P_ij = torch.exp(S_ij - m_block_ij)\n                if mask is not None:\n                    # Masking\n                    P_ij = torch.where(maskj_temp > 0, P_ij, 0.)\n\n                l_block_ij = torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON\n\n                P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)\n\n                mi_new = torch.maximum(m_block_ij, mi)\n                li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * l_block_ij\n\n                O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi + (\n                        torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj\n                l_BLOCKS[i] = li_new\n                m_BLOCKS[i] = mi_new\n\n        O = torch.cat(O_BLOCKS, dim=2)\n        l = torch.cat(l_BLOCKS, dim=2)\n        m = torch.cat(m_BLOCKS, dim=2)\n        return O, l, m\n\n    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):\n        res = \\\n        self.flash_attention_forward(queries.permute(0, 2, 1, 3), keys.permute(0, 2, 1, 3), values.permute(0, 2, 1, 3),\n                                     attn_mask)[0]\n        return res.permute(0, 2, 1, 3).contiguous(), None\n\n\nclass FullAttention(nn.Module):\n    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):\n        super(FullAttention, self).__init__()\n        self.scale = scale\n        self.mask_flag = mask_flag\n        self.output_attention = output_attention\n        self.dropout = 
nn.Dropout(attention_dropout)\n\n    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):\n        B, L, H, E = queries.shape\n        _, S, _, D = values.shape\n        scale = self.scale or 1. / sqrt(E)\n\n        scores = torch.einsum(\"blhe,bshe->bhls\", queries, keys)\n\n        if self.mask_flag:\n            if attn_mask is None:\n                attn_mask = TriangularCausalMask(B, L, device=queries.device)\n\n            scores.masked_fill_(attn_mask.mask, -np.inf)\n\n        A = self.dropout(torch.softmax(scale * scores, dim=-1))\n        V = torch.einsum(\"bhls,bshd->blhd\", A, values)\n\n        if self.output_attention:\n            return (V.contiguous(), A)\n        else:\n            return (V.contiguous(), None)\n\n\n# Code implementation from https://github.com/zhouhaoyi/Informer2020\nclass ProbAttention(nn.Module):\n    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):\n        super(ProbAttention, self).__init__()\n        self.factor = factor\n        self.scale = scale\n        self.mask_flag = mask_flag\n        self.output_attention = output_attention\n        self.dropout = nn.Dropout(attention_dropout)\n\n    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)\n        # Q [B, H, L, D]\n        B, H, L_K, E = K.shape\n        _, _, L_Q, _ = Q.shape\n\n        # calculate the sampled Q_K\n        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)\n        # real U = U_part(factor*ln(L_k))*L_q\n        index_sample = torch.randint(L_K, (L_Q, sample_k))\n        K_sample = K_expand[:, :, torch.arange(\n            L_Q).unsqueeze(1), index_sample, :]\n        Q_K_sample = torch.matmul(\n            Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()\n\n        # find the Top_k query with sparisty measurement\n        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)\n        M_top = M.topk(n_top, sorted=False)[1]\n\n        # use the 
reduced Q to calculate Q_K\n        Q_reduce = Q[torch.arange(B)[:, None, None],\n                   torch.arange(H)[None, :, None],\n                   M_top, :]  # factor*ln(L_q)\n        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k\n\n        return Q_K, M_top\n\n    def _get_initial_context(self, V, L_Q):\n        B, H, L_V, D = V.shape\n        if not self.mask_flag:\n            # V_sum = V.sum(dim=-2)\n            V_sum = V.mean(dim=-2)\n            contex = V_sum.unsqueeze(-2).expand(B, H,\n                                                L_Q, V_sum.shape[-1]).clone()\n        else:  # use mask\n            # requires that L_Q == L_V, i.e. for self-attention only\n            assert (L_Q == L_V)\n            contex = V.cumsum(dim=-2)\n        return contex\n\n    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):\n        B, H, L_V, D = V.shape\n\n        if self.mask_flag:\n            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)\n            scores.masked_fill_(attn_mask.mask, -np.inf)\n\n        attn = torch.softmax(scores, dim=-1)  # nn.Softmax(dim=-1)(scores)\n\n        context_in[torch.arange(B)[:, None, None],\n        torch.arange(H)[None, :, None],\n        index, :] = torch.matmul(attn, V).type_as(context_in)\n        if self.output_attention:\n            attns = (torch.ones([B, H, L_V, L_V]) /\n                     L_V).type_as(attn).to(attn.device)\n            attns[torch.arange(B)[:, None, None], torch.arange(H)[\n                                                  None, :, None], index, :] = attn\n            return (context_in, attns)\n        else:\n            return (context_in, None)\n\n    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):\n        B, L_Q, H, D = queries.shape\n        _, L_K, _, _ = keys.shape\n\n        queries = queries.transpose(2, 1)\n        keys = keys.transpose(2, 1)\n        values = values.transpose(2, 1)\n\n        
U_part = self.factor * \\\n                 np.ceil(np.log(L_K)).astype('int').item()  # c*ln(L_k)\n        u = self.factor * \\\n            np.ceil(np.log(L_Q)).astype('int').item()  # c*ln(L_q)\n\n        U_part = U_part if U_part < L_K else L_K\n        u = u if u < L_Q else L_Q\n\n        scores_top, index = self._prob_QK(\n            queries, keys, sample_k=U_part, n_top=u)\n\n        # add scale factor\n        scale = self.scale or 1. / sqrt(D)\n        if scale is not None:\n            scores_top = scores_top * scale\n        # get the context\n        context = self._get_initial_context(values, L_Q)\n        # update the context with selected top_k queries\n        context, attn = self._update_context(\n            context, values, scores_top, index, L_Q, attn_mask)\n\n        return context.contiguous(), attn\n\n\nclass AttentionLayer(nn.Module):\n    def __init__(self, attention, d_model, n_heads, d_keys=None,\n                 d_values=None):\n        super(AttentionLayer, self).__init__()\n\n        d_keys = d_keys or (d_model // n_heads)\n        d_values = d_values or (d_model // n_heads)\n\n        self.inner_attention = attention\n        self.query_projection = nn.Linear(d_model, d_keys * n_heads)\n        self.key_projection = nn.Linear(d_model, d_keys * n_heads)\n        self.value_projection = nn.Linear(d_model, d_values * n_heads)\n        self.out_projection = nn.Linear(d_values * n_heads, d_model)\n        self.n_heads = n_heads\n\n    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):\n        B, L, _ = queries.shape\n        _, S, _ = keys.shape\n        H = self.n_heads\n\n        queries = self.query_projection(queries).view(B, L, H, -1)\n        keys = self.key_projection(keys).view(B, S, H, -1)\n        values = self.value_projection(values).view(B, S, H, -1)\n\n        out, attn = self.inner_attention(\n            queries,\n            keys,\n            values,\n            attn_mask,\n            
tau=tau,\n            delta=delta\n        )\n        out = out.view(B, L, -1)\n\n        return self.out_projection(out), attn\n\n\nclass ReformerLayer(nn.Module):\n    def __init__(self, attention, d_model, n_heads, d_keys=None,\n                 d_values=None, causal=False, bucket_size=4, n_hashes=4):\n        super().__init__()\n        self.bucket_size = bucket_size\n        self.attn = LSHSelfAttention(\n            dim=d_model,\n            heads=n_heads,\n            bucket_size=bucket_size,\n            n_hashes=n_hashes,\n            causal=causal\n        )\n\n    def fit_length(self, queries):\n        # inside reformer: assert N % (bucket_size * 2) == 0\n        B, N, C = queries.shape\n        if N % (self.bucket_size * 2) == 0:\n            return queries\n        else:\n            # fill the time series\n            fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))\n            return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1)\n\n    def forward(self, queries, keys, values, attn_mask, tau, delta):\n        # in Reformer: defalut queries=keys\n        B, N, C = queries.shape\n        queries = self.attn(self.fit_length(queries))[:, :N, :]\n        return queries, None\n\n"
  },
  {
    "path": "probts/model/nn/arch/TransformerModule/Transformer_EncDec.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ConvLayer(nn.Module):\n    def __init__(self, c_in):\n        super(ConvLayer, self).__init__()\n        self.downConv = nn.Conv1d(in_channels=c_in,\n                                  out_channels=c_in,\n                                  kernel_size=3,\n                                  padding=2,\n                                  padding_mode='circular')\n        self.norm = nn.BatchNorm1d(c_in)\n        self.activation = nn.ELU()\n        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)\n\n    def forward(self, x):\n        x = self.downConv(x.permute(0, 2, 1))\n        x = self.norm(x)\n        x = self.activation(x)\n        x = self.maxPool(x)\n        x = x.transpose(1, 2)\n        return x\n\n\nclass EncoderLayer(nn.Module):\n    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation=\"relu\"):\n        super(EncoderLayer, self).__init__()\n        d_ff = d_ff or 4 * d_model\n        self.attention = attention\n        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)\n        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)\n        self.norm1 = nn.LayerNorm(d_model)\n        self.norm2 = nn.LayerNorm(d_model)\n        self.dropout = nn.Dropout(dropout)\n        self.activation = F.relu if activation == \"relu\" else F.gelu\n\n    def forward(self, x, attn_mask=None, tau=None, delta=None):\n        new_x, attn = self.attention(\n            x, x, x,\n            attn_mask=attn_mask,\n            tau=tau, delta=delta\n        )\n        x = x + self.dropout(new_x)\n\n        y = x = self.norm1(x)\n        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))\n        y = self.dropout(self.conv2(y).transpose(-1, 1))\n\n        return self.norm2(x + y), attn\n\n\nclass Encoder(nn.Module):\n    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):\n        
super(Encoder, self).__init__()\n        self.attn_layers = nn.ModuleList(attn_layers)\n        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None\n        self.norm = norm_layer\n\n    def forward(self, x, attn_mask=None, tau=None, delta=None):\n        # x [B, L, D]\n        attns = []\n        if self.conv_layers is not None:\n            for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):\n                delta = delta if i == 0 else None\n                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)\n                x = conv_layer(x)\n                attns.append(attn)\n            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)\n            attns.append(attn)\n        else:\n            for attn_layer in self.attn_layers:\n                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)\n                attns.append(attn)\n\n        if self.norm is not None:\n            x = self.norm(x)\n\n        return x, attns\n\n\nclass DecoderLayer(nn.Module):\n    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,\n                 dropout=0.1, activation=\"relu\"):\n        super(DecoderLayer, self).__init__()\n        d_ff = d_ff or 4 * d_model\n        self.self_attention = self_attention\n        self.cross_attention = cross_attention\n        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)\n        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)\n        self.norm1 = nn.LayerNorm(d_model)\n        self.norm2 = nn.LayerNorm(d_model)\n        self.norm3 = nn.LayerNorm(d_model)\n        self.dropout = nn.Dropout(dropout)\n        self.activation = F.relu if activation == \"relu\" else F.gelu\n\n    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):\n        x = x + self.dropout(self.self_attention(\n            x, x, x,\n            
attn_mask=x_mask,\n            tau=tau, delta=None\n        )[0])\n        x = self.norm1(x)\n\n        x = x + self.dropout(self.cross_attention(\n            x, cross, cross,\n            attn_mask=cross_mask,\n            tau=tau, delta=delta\n        )[0])\n\n        y = x = self.norm2(x)\n        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))\n        y = self.dropout(self.conv2(y).transpose(-1, 1))\n\n        return self.norm3(x + y)\n\n\nclass Decoder(nn.Module):\n    def __init__(self, layers, norm_layer=None, projection=None):\n        super(Decoder, self).__init__()\n        self.layers = nn.ModuleList(layers)\n        self.norm = norm_layer\n        self.projection = projection\n\n    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):\n        for layer in self.layers:\n            x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)\n\n        if self.norm is not None:\n            x = self.norm(x)\n\n        if self.projection is not None:\n            x = self.projection(x)\n        return x\n"
  },
  {
    "path": "probts/model/nn/arch/__init__.py",
    "content": ""
  },
  {
    "path": "probts/model/nn/arch/decomp.py",
    "content": "import torch\nfrom torch import nn\n\nclass moving_avg(nn.Module):\n    \"\"\"\n    Moving average block to highlight the trend of time series\n    \"\"\"\n    def __init__(self, kernel_size, stride):\n        super(moving_avg, self).__init__()\n        self.kernel_size = kernel_size\n        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)\n\n    def forward(self, x):\n        # padding on the both ends of time series\n        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n        x = torch.cat([front, x, end], dim=1)\n        x = self.avg(x.permute(0, 2, 1))\n        x = x.permute(0, 2, 1)\n        return x\n\n\nclass series_decomp(nn.Module):\n    \"\"\"\n    Series decomposition block\n    \"\"\"\n    def __init__(self, kernel_size):\n        super(series_decomp, self).__init__()\n        self.moving_avg = moving_avg(kernel_size, stride=1)\n\n    def forward(self, x):\n        moving_mean = self.moving_avg(x)\n        res = x - moving_mean\n        return res, moving_mean"
  },
  {
    "path": "probts/model/nn/prob/MAF.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PyTorch-TS\n# - Source: https://github.com/zalandoresearch/pytorch-ts\n# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows\n# - License: MIT, Apache-2.0 license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.distributions import Normal\nfrom probts.model.nn.prob.flow_model import FlowModel, BatchNorm, FlowSequential\n\n\ndef create_masks(\n    input_size, hidden_size, n_hidden, input_order=\"sequential\", input_degrees=None\n):\n    # MADE paper sec 4:\n    # degrees of connections between layers -- ensure at most in_degree - 1 connections\n    degrees = []\n\n    # set input degrees to what is provided in args (the flipped order of the previous layer in a stack of mades);\n    # else init input degrees based on strategy in input_order (sequential or random)\n    if input_order == \"sequential\":\n        degrees += (\n            [torch.arange(input_size)] if input_degrees is None else [input_degrees]\n        )\n        for _ in range(n_hidden + 1):\n            degrees += [torch.arange(hidden_size) % (input_size - 1)]\n        degrees += (\n            [torch.arange(input_size) % input_size - 1]\n            if input_degrees is None\n            else [input_degrees % input_size - 1]\n        )\n\n    elif input_order == \"random\":\n        degrees += (\n            [torch.randperm(input_size)] if input_degrees is None else [input_degrees]\n        )\n        for _ in range(n_hidden + 1):\n            min_prev_degree = min(degrees[-1].min().item(), input_size - 1)\n            degrees += [torch.randint(min_prev_degree, input_size, (hidden_size,))]\n        min_prev_degree = 
min(degrees[-1].min().item(), input_size - 1)\n        degrees += (\n            [torch.randint(min_prev_degree, input_size, (input_size,)) - 1]\n            if input_degrees is None\n            else [input_degrees - 1]\n        )\n\n    # construct masks\n    masks = []\n    for (d0, d1) in zip(degrees[:-1], degrees[1:]):\n        masks += [(d1.unsqueeze(-1) >= d0.unsqueeze(0)).float()]\n\n    return masks, degrees[0]\n\n\nclass MaskedLinear(nn.Linear):\n    \"\"\" MADE building block layer \"\"\"\n\n    def __init__(self, input_size, n_outputs, mask, cond_label_size=None):\n        super().__init__(input_size, n_outputs)\n\n        self.register_buffer(\"mask\", mask)\n\n        self.cond_label_size = cond_label_size\n        if cond_label_size is not None:\n            self.cond_weight = nn.Parameter(\n                torch.rand(n_outputs, cond_label_size) / math.sqrt(cond_label_size)\n            )\n\n    def forward(self, x, y=None):\n        out = F.linear(x, self.weight * self.mask, self.bias)\n        if y is not None:\n            out = out + F.linear(y, self.cond_weight)\n        return out\n\n\nclass MADE(nn.Module):\n    def __init__(\n        self,\n        input_size,\n        hidden_size,\n        n_hidden,\n        cond_label_size=None,\n        activation=\"ReLU\",\n        input_order=\"sequential\",\n        input_degrees=None,\n    ):\n        \"\"\"\n        Args:\n            input_size -- scalar; dim of inputs\n            hidden_size -- scalar; dim of hidden layers\n            n_hidden -- scalar; number of hidden layers\n            activation -- str; activation function to use\n            input_order -- str or tensor; variable order for creating the autoregressive masks (sequential|random)\n                            or the order flipped from the previous layer in a stack of MADEs\n            conditional -- bool; whether model is conditional\n        \"\"\"\n        super().__init__()\n        # base distribution for calculation of log 
prob under the model\n        self.register_buffer(\"base_dist_mean\", torch.zeros(input_size))\n        self.register_buffer(\"base_dist_var\", torch.ones(input_size))\n\n        # create masks\n        masks, self.input_degrees = create_masks(\n            input_size, hidden_size, n_hidden, input_order, input_degrees\n        )\n\n        # setup activation\n        if activation == \"ReLU\":\n            activation_fn = nn.ReLU()\n        elif activation == \"Tanh\":\n            activation_fn = nn.Tanh()\n        else:\n            raise ValueError(\"Check activation function.\")\n\n        # construct model\n        self.net_input = MaskedLinear(\n            input_size, hidden_size, masks[0], cond_label_size\n        )\n        self.net = []\n        for m in masks[1:-1]:\n            self.net += [activation_fn, MaskedLinear(hidden_size, hidden_size, m)]\n        self.net += [\n            activation_fn,\n            MaskedLinear(hidden_size, 2 * input_size, masks[-1].repeat(2, 1)),\n        ]\n        self.net = nn.Sequential(*self.net)\n\n    @property\n    def base_dist(self):\n        return Normal(self.base_dist_mean, self.base_dist_var)\n\n    def forward(self, x, y=None):\n        # MAF eq 4 -- return mean and log std\n        m, loga = self.net(self.net_input(x, y)).chunk(chunks=2, dim=-1)\n        u = (x - m) * torch.exp(-loga)\n        # MAF eq 5\n        log_abs_det_jacobian = -loga\n        return u, log_abs_det_jacobian\n\n    def inverse(self, u, y=None, sum_log_abs_det_jacobians=None):\n        # MAF eq 3\n        # D = u.shape[-1]\n        x = torch.zeros_like(u)\n        # run through reverse model\n        for i in self.input_degrees:\n            m, loga = self.net(self.net_input(x, y)).chunk(chunks=2, dim=-1)\n            x[..., i] = u[..., i] * torch.exp(loga[..., i]) + m[..., i]\n        log_abs_det_jacobian = loga\n        return x, log_abs_det_jacobian\n\n    def log_prob(self, x, y=None):\n        u, log_abs_det_jacobian = 
self.forward(x, y)\n        return torch.sum(self.base_dist.log_prob(u) + log_abs_det_jacobian, dim=-1)\n\n\nclass MAF(FlowModel):\n    def __init__(\n        self,\n        n_blocks,\n        target_dim,\n        hidden_size,\n        n_hidden,\n        f_hidden_size,\n        conditional_length,\n        dequantize,\n        activation=\"ReLU\",\n        input_order=\"sequential\",\n        batch_norm=True,\n    ):\n        super().__init__(target_dim, f_hidden_size, conditional_length, dequantize)\n\n        # construct model\n        modules = []\n        self.input_degrees = None\n        for i in range(n_blocks):\n            modules += [\n                MADE(\n                    target_dim,\n                    hidden_size,\n                    n_hidden,\n                    conditional_length,\n                    activation,\n                    input_order,\n                    self.input_degrees,\n                )\n            ]\n            self.input_degrees = modules[-1].input_degrees.flip(0)\n            modules += batch_norm * [BatchNorm(target_dim)]\n\n        self.net = FlowSequential(*modules)"
  },
  {
    "path": "probts/model/nn/prob/RealNVP.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PyTorch-TS\n# - Source: https://github.com/zalandoresearch/pytorch-ts\n# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows\n# - License: MIT, Apache-2.0 license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport copy\nimport torch\nimport torch.nn as nn\nfrom probts.model.nn.prob.flow_model import FlowModel, BatchNorm, FlowSequential\n\n\nclass LinearMaskedCoupling(nn.Module):\n    \"\"\" Modified RealNVP Coupling Layers per the MAF paper \"\"\"\n\n    def __init__(self, input_size, hidden_size, n_hidden, mask, cond_label_size=None):\n        super().__init__()\n\n        self.register_buffer(\"mask\", mask)\n\n        # scale function\n        s_net = [\n            nn.Linear(\n                input_size + (cond_label_size if cond_label_size is not None else 0),\n                hidden_size,\n            )\n        ]\n        for _ in range(n_hidden):\n            s_net += [nn.Tanh(), nn.Linear(hidden_size, hidden_size)]\n        s_net += [nn.Tanh(), nn.Linear(hidden_size, input_size)]\n        self.s_net = nn.Sequential(*s_net)\n\n        # translation function\n        self.t_net = copy.deepcopy(self.s_net)\n        # replace Tanh with ReLU's per MAF paper\n        for i in range(len(self.t_net)):\n            if not isinstance(self.t_net[i], nn.Linear):\n                self.t_net[i] = nn.ReLU()\n\n    def forward(self, x, y=None):\n        # apply mask\n        mx = x * self.mask\n\n        # run through model\n        s = self.s_net(mx if y is None else torch.cat([y, mx], dim=-1))\n        t = self.t_net(mx if y is None else torch.cat([y, mx], dim=-1)) * (\n            1 - self.mask\n        )\n\n        # cf RealNVP eq 8 where u corresponds to x (here we're modeling u)\n        
log_s = torch.tanh(s) * (1 - self.mask)\n        u = x * torch.exp(log_s) + t\n        # u = (x - t) * torch.exp(log_s)\n        # u = mx + (1 - self.mask) * (x - t) * torch.exp(-s)\n\n        # log det du/dx; cf RealNVP 8 and 6; note, sum over input_size done at model log_prob\n        # log_abs_det_jacobian = -(1 - self.mask) * s\n        # log_abs_det_jacobian = -log_s #.sum(-1, keepdim=True)\n        log_abs_det_jacobian = log_s\n\n        return u, log_abs_det_jacobian\n\n    def inverse(self, u, y=None):\n        # apply mask\n        mu = u * self.mask\n\n        # run through model\n        s = self.s_net(mu if y is None else torch.cat([y, mu], dim=-1))\n        t = self.t_net(mu if y is None else torch.cat([y, mu], dim=-1)) * (\n            1 - self.mask\n        )\n\n        log_s = torch.tanh(s) * (1 - self.mask)\n        x = (u - t) * torch.exp(-log_s)\n        # x = u * torch.exp(log_s) + t\n        # x = mu + (1 - self.mask) * (u * s.exp() + t)  # cf RealNVP eq 7\n\n        # log_abs_det_jacobian = (1 - self.mask) * s  # log det dx/du\n        # log_abs_det_jacobian = log_s #.sum(-1, keepdim=True)\n        log_abs_det_jacobian = -log_s\n\n        return x, log_abs_det_jacobian\n\n\nclass RealNVP(FlowModel):\n    def __init__(\n        self,\n        n_blocks,\n        target_dim,\n        hidden_size,\n        n_hidden,\n        f_hidden_size,\n        conditional_length,\n        dequantize,\n        batch_norm=True\n    ):\n        super().__init__(target_dim, f_hidden_size, conditional_length, dequantize)\n\n        # construct model\n        modules = []\n        mask = torch.arange(target_dim).float() % 2\n        for i in range(n_blocks):\n            modules += [\n                LinearMaskedCoupling(\n                    target_dim, hidden_size, n_hidden, mask, conditional_length\n                )\n            ]\n            mask = 1 - mask\n            modules += batch_norm * [BatchNorm(target_dim)]\n\n        self.net = 
FlowSequential(*modules)"
  },
  {
    "path": "probts/model/nn/prob/__init__.py",
    "content": ""
  },
  {
    "path": "probts/model/nn/prob/diffusion_layers.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PyTorch-TS\n# - Source: https://github.com/zalandoresearch/pytorch-ts\n# - Paper: Autoregressive Denoising Diffusion Models for Multivariate Probabilistic Time Series Forecasting\n# - License: MIT, Apache-2.0 license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom linear_attention_transformer import LinearAttentionTransformer\n\ndef get_torch_trans(heads=8, layers=1, channels=64,linear=False):\n    if linear:\n        encoder_layer = LinearAttentionTransformer(\n            dim = channels,\n            heads = heads,\n            depth = layers,\n            max_seq_len = 4096,\n            n_local_attn_heads = 0\n        )\n        return encoder_layer\n    else:\n        encoder_layer = nn.TransformerEncoderLayer(\n            d_model=channels, nhead=heads, dim_feedforward=64, activation=\"gelu\"\n        )\n        return nn.TransformerEncoder(encoder_layer, num_layers=layers)\n\n\ndef Conv1d_with_init(in_channels, out_channels, kernel_size):\n    layer = nn.Conv1d(in_channels, out_channels, kernel_size)\n    nn.init.kaiming_normal_(layer.weight)\n    return layer\n\n\nclass DiffusionEmbedding(nn.Module):\n    def __init__(self, dim=128, proj_dim=None, max_steps=500):\n        super().__init__()\n        if proj_dim is None:\n            proj_dim = dim\n        self.register_buffer(\n            \"embedding\", self._build_embedding(dim, max_steps), persistent=False\n        )\n        self.projection1 = nn.Linear(dim * 2, proj_dim)\n        self.projection2 = nn.Linear(proj_dim, proj_dim)\n\n    def forward(self, diffusion_step):\n        x = self.embedding[diffusion_step]\n        x = self.projection1(x)\n        x = F.silu(x)\n        x = 
self.projection2(x)\n        x = F.silu(x)\n        return x\n\n    def _build_embedding(self, dim, max_steps):\n        steps = torch.arange(max_steps).unsqueeze(1)  # [T,1]\n        dims = torch.arange(dim).unsqueeze(0)  # [1,dim]\n        table = steps * 10.0 ** (dims * 4.0 / dim)  # [T,dim]\n        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)\n        return table\n\n\nclass diff_CSDI(nn.Module):\n    def __init__(self, channels, diffusion_embedding_dim, side_dim, num_steps, nheads, n_layers, inputdim=2, linear=False):\n        super().__init__()\n        self.channels = channels\n\n        self.diffusion_embedding = DiffusionEmbedding(\n            dim=diffusion_embedding_dim, max_steps=num_steps\n        )\n        self.input_projection = Conv1d_with_init(inputdim, self.channels, 1)\n        self.output_projection1 = Conv1d_with_init(self.channels, self.channels, 1)\n        self.output_projection2 = Conv1d_with_init(self.channels, 1, 1)\n        nn.init.zeros_(self.output_projection2.weight)\n\n        self.residual_layers = nn.ModuleList(\n            [\n                ResidualBlock(\n                    side_dim=side_dim,\n                    channels=self.channels,\n                    diffusion_embedding_dim=diffusion_embedding_dim,\n                    nheads=nheads,\n                    linear=linear,\n                )\n                for _ in range(n_layers)\n            ]\n        )\n\n    def forward(self, x, cond_info, diffusion_step):\n        B, inputdim, K, L = x.shape\n\n        x = x.reshape(B, inputdim, K * L)\n\n        x = self.input_projection(x)\n        x = F.relu(x)\n        x = x.reshape(B, self.channels, K, L)\n\n        diffusion_emb = self.diffusion_embedding(diffusion_step)\n\n        skip = []\n        for layer in self.residual_layers:\n            x, skip_connection = layer(x, cond_info, diffusion_emb)\n            skip.append(skip_connection)\n\n        x = torch.sum(torch.stack(skip), dim=0) / 
math.sqrt(len(self.residual_layers))\n        x = x.reshape(B, self.channels, K * L)\n        x = self.output_projection1(x)  # (B,channel,K*L)\n        x = F.relu(x)\n        x = self.output_projection2(x)  # (B,1,K*L)\n        x = x.reshape(B, K, L)\n        return x\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, side_dim, channels, diffusion_embedding_dim, nheads, linear=False):\n        super().__init__()\n        self.side_dim = side_dim\n        self.diffusion_projection = nn.Linear(diffusion_embedding_dim, channels)\n        self.cond_projection = Conv1d_with_init(side_dim, 2 * channels, 1)\n        self.mid_projection = Conv1d_with_init(channels, 2 * channels, 1)\n        self.output_projection = Conv1d_with_init(channels, 2 * channels, 1)\n\n        self.time_layer = get_torch_trans(heads=nheads, layers=1, channels=channels,linear=linear)\n        self.feature_layer = get_torch_trans(heads=nheads, layers=1, channels=channels,linear=linear)\n\n    def forward_time(self, y, base_shape):\n        B, channel, K, L = base_shape\n        if L == 1:\n            return y\n        y = y.reshape(B, channel, K, L).permute(0, 2, 1, 3).reshape(B * K, channel, L)\n        y = self.time_layer(y.permute(2, 0, 1)).permute(1, 2, 0)\n        y = y.reshape(B, K, channel, L).permute(0, 2, 1, 3).reshape(B, channel, K * L)\n        return y\n\n    def forward_feature(self, y, base_shape):\n        B, channel, K, L = base_shape\n        if K == 1:\n            return y\n        y = y.reshape(B, channel, K, L).permute(0, 3, 1, 2).reshape(B * L, channel, K)\n        y = self.feature_layer(y.permute(2, 0, 1)).permute(1, 2, 0)\n        y = y.reshape(B, L, channel, K).permute(0, 2, 3, 1).reshape(B, channel, K * L)\n        return y\n\n    def forward(self, x, cond_info, diffusion_emb):\n\n        B, channel, K, L = x.shape\n        base_shape = x.shape\n        x = x.reshape(B, channel, K * L)\n\n        diffusion_emb = 
self.diffusion_projection(diffusion_emb).unsqueeze(-1)  # (B,channel,1)\n        y = x + diffusion_emb\n\n        y = self.forward_time(y, base_shape)\n        y = self.forward_feature(y, base_shape)  # (B,channel,K*L)\n        y = self.mid_projection(y)  # (B,2*channel,K*L)\n        _, cond_dim, _, _ = cond_info.shape\n        cond_info = cond_info.reshape(B, cond_dim, K * L)\n        cond_info = self.cond_projection(cond_info)  # (B,2*channel,K*L)\n        y = y + cond_info\n\n        gate, filter = torch.chunk(y, 2, dim=1)\n        y = torch.sigmoid(gate) * torch.tanh(filter)  # (B,channel,K*L)\n        y = self.output_projection(y)\n\n        residual, skip = torch.chunk(y, 2, dim=1)\n        x = x.reshape(base_shape)\n        residual = residual.reshape(base_shape)\n        skip = skip.reshape(base_shape)\n        return (x + residual) / math.sqrt(2.0), skip\n"
  },
  {
    "path": "probts/model/nn/prob/flow_model.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PyTorch-TS\n# - Source: https://github.com/zalandoresearch/pytorch-ts\n# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows\n# - License: MIT, Apache-2.0 license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport torch\nimport torch.nn as nn\nfrom torch.distributions import Normal\n\nclass FlowModel(nn.Module):\n    def __init__(self, target_dim, f_hidden_size, conditional_length, dequantize):\n        super().__init__()\n        self.__scale = None\n        self.net = None\n        self.dequantize = dequantize\n\n        self.dist_args = nn.Linear(\n            in_features=f_hidden_size, out_features=conditional_length\n        )\n\n        # base distribution for calculation of log prob under the model\n        self.register_buffer(\"base_dist_mean\", torch.zeros(target_dim))\n        self.register_buffer(\"base_dist_var\", torch.ones(target_dim))\n\n    @property\n    def base_dist(self):\n        return Normal(self.base_dist_mean, self.base_dist_var)\n\n    @property\n    def scale(self):\n        return self.__scale\n    \n    @scale.setter\n    def scale(self, scale):\n        self.__scale = scale\n\n    def forward(self, x, cond):\n        if self.scale is not None:\n            x /= self.scale\n        u, log_abs_det_jacobian = self.net(x, cond)\n        return u, log_abs_det_jacobian\n\n    def inverse(self, u, cond):\n        x, log_abs_det_jacobian = self.net.inverse(u, cond)\n        if self.scale is not None:\n            x *= self.scale\n            log_abs_det_jacobian += torch.log(torch.abs(self.scale))\n        return x, log_abs_det_jacobian\n\n    def log_prob(self, x, cond):\n        if self.dequantize:\n            x += torch.rand_like(x)\n        u, 
sum_log_abs_det_jacobians = self.forward(x, cond)\n        return torch.sum(self.base_dist.log_prob(u) + sum_log_abs_det_jacobians, dim=-1)\n\n    def loss(self, x, cond):\n        return -self.log_prob(x, cond)\n\n    def sample(self, sample_shape=torch.Size(), cond=None):\n        if cond is not None:\n            shape = cond.shape[:-1]\n        else:\n            shape = sample_shape\n\n        u = self.base_dist.sample(shape)\n        sample, _ = self.inverse(u, cond)\n        return sample\n\n\nclass BatchNorm(nn.Module):\n    \"\"\" Flow Model BatchNorm layer \"\"\"\n\n    def __init__(self, input_size, momentum=0.9, eps=1e-5):\n        super().__init__()\n        self.momentum = momentum\n        self.eps = eps\n\n        self.log_gamma = nn.Parameter(torch.zeros(input_size))\n        self.beta = nn.Parameter(torch.zeros(input_size))\n\n        self.register_buffer(\"running_mean\", torch.zeros(input_size))\n        self.register_buffer(\"running_var\", torch.ones(input_size))\n\n    def forward(self, x, cond_y=None):\n        if self.training:\n            self.batch_mean = x.view(-1, x.shape[-1]).mean(0)\n            # note MAF paper uses biased variance estimate; ie x.var(0, unbiased=False)\n            self.batch_var = x.view(-1, x.shape[-1]).var(0)\n\n            # update running mean\n            self.running_mean.mul_(self.momentum).add_(\n                self.batch_mean.data * (1 - self.momentum)\n            )\n            self.running_var.mul_(self.momentum).add_(\n                self.batch_var.data * (1 - self.momentum)\n            )\n\n            mean = self.batch_mean\n            var = self.batch_var\n        else:\n            mean = self.running_mean\n            var = self.running_var\n\n        # compute normalized input (cf original batch norm paper algo 1)\n        x_hat = (x - mean) / torch.sqrt(var + self.eps)\n        y = self.log_gamma.exp() * x_hat + self.beta\n\n        # compute log_abs_det_jacobian (cf RealNVP paper)\n        
log_abs_det_jacobian = self.log_gamma - 0.5 * torch.log(var + self.eps)\n        \n        return y, log_abs_det_jacobian.expand_as(x)\n\n    def inverse(self, y, cond_y=None):\n        if self.training:\n            mean = self.batch_mean\n            var = self.batch_var\n        else:\n            mean = self.running_mean\n            var = self.running_var\n\n        x_hat = (y - self.beta) * torch.exp(-self.log_gamma)\n        x = x_hat * torch.sqrt(var + self.eps) + mean\n\n        log_abs_det_jacobian = 0.5 * torch.log(var + self.eps) - self.log_gamma\n\n        return x, log_abs_det_jacobian.expand_as(x)\n\n\nclass FlowSequential(nn.Sequential):\n    \"\"\" Container for layers of a normalizing flow \"\"\"\n\n    def forward(self, x, y):\n        sum_log_abs_det_jacobians = 0\n        for module in self:\n            x, log_abs_det_jacobian = module(x, y)\n            sum_log_abs_det_jacobians += log_abs_det_jacobian\n        return x, sum_log_abs_det_jacobians\n\n    def inverse(self, u, y):\n        sum_log_abs_det_jacobians = 0\n        for module in reversed(self):\n            u, log_abs_det_jacobian = module.inverse(u, y)\n            sum_log_abs_det_jacobians += log_abs_det_jacobian\n        return u, sum_log_abs_det_jacobians"
  },
  {
    "path": "probts/model/nn/prob/gaussian_diffusion.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PyTorch-TS\n# - Source: https://github.com/zalandoresearch/pytorch-ts\n# - Paper: Autoregressive Denoising Diffusion Models for Multivariate Probabilistic Time Series Forecasting\n# - License: MIT, Apache-2.0 license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nimport math\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfrom probts.model.nn.prob.diffusion_layers import DiffusionEmbedding\nfrom functools import partial\nfrom inspect import isfunction\n\n\ndef default(val, d):\n    if val is not None:\n        return val\n    return d() if isfunction(d) else d\n\n\ndef extract(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef noise_like(shape, device, repeat=False):\n    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(\n        shape[0], *((1,) * (len(shape) - 1))\n    )\n    noise = lambda: torch.randn(shape, device=device)\n    return repeat_noise() if repeat else noise()\n\n\ndef cosine_beta_schedule(timesteps, s=0.008):\n    \"\"\"\n    cosine schedule\n    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ\n    \"\"\"\n    steps = timesteps + 1\n    x = np.linspace(0, timesteps, steps)\n    alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2\n    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n    return np.clip(betas, 0, 0.999)\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, hidden_size, residual_channels, dilation, target_dim):\n        super().__init__()\n        self.target_dim = target_dim\n        \n        self.diffusion_projection = nn.Linear(hidden_size, 
residual_channels)\n\n        if self.target_dim > 1:\n            self.dilated_conv = nn.Conv1d(\n                residual_channels,\n                2 * residual_channels,\n                3,\n                padding=dilation,\n                dilation=dilation,\n                padding_mode=\"circular\",\n            )\n            self.conditioner_projection = nn.Conv1d(\n                1, 2 * residual_channels, 1, padding=2, padding_mode=\"circular\"\n            )\n        else:\n            self.dilated_conv = nn.Conv1d(residual_channels,2 * residual_channels,1)\n            self.conditioner_projection = nn.Conv1d(1, 2 * residual_channels, 1)\n\n        self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1)\n\n        nn.init.kaiming_normal_(self.conditioner_projection.weight)\n        nn.init.kaiming_normal_(self.output_projection.weight)\n\n    def forward(self, x, conditioner, diffusion_step):\n        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)\n        conditioner = self.conditioner_projection(conditioner)\n\n        y = x + diffusion_step\n        y = self.dilated_conv(y) + conditioner\n\n        gate, filter = torch.chunk(y, 2, dim=1)\n        y = torch.sigmoid(gate) * torch.tanh(filter)\n\n        y = self.output_projection(y)\n        y = F.leaky_relu(y, 0.4)\n        residual, skip = torch.chunk(y, 2, dim=1)\n        return (x + residual) / math.sqrt(2.0), skip\n\n\nclass CondUpsampler(nn.Module):\n    def __init__(self, cond_length, target_dim):\n        super().__init__()\n        self.target_dim = target_dim\n\n        if self.target_dim > 1:\n            self.linear1 = nn.Linear(cond_length, target_dim // 2)\n            self.linear2 = nn.Linear(target_dim // 2, target_dim)\n        else:\n            self.linear = nn.Linear(cond_length, target_dim)\n\n    def forward(self, x):\n        if self.target_dim > 1:\n            x = self.linear1(x)\n            x = F.leaky_relu(x, 0.4)\n          
  x = self.linear2(x)\n            x = F.leaky_relu(x, 0.4)\n        else:\n            x = self.linear(x)\n            x = F.leaky_relu(x, 0.4)\n        return x\n\n\nclass EpsilonTheta(nn.Module):\n    def __init__(\n        self,\n        target_dim,\n        cond_length,\n        time_emb_dim=16,\n        residual_layers=8,\n        residual_channels=8,\n        dilation_cycle_length=2,\n        residual_hidden=64,\n        padding=2\n    ):\n        super().__init__()\n        if target_dim > 1:\n            self.input_projection = nn.Conv1d(\n                1, residual_channels, 1, padding=padding, padding_mode=\"circular\"\n            )\n            self.skip_projection = nn.Conv1d(residual_channels, residual_channels, 3)\n            self.output_projection = nn.Conv1d(residual_channels, 1, 3)\n        else:\n            # self.input_projection = nn.Identity()\n            self.input_projection = nn.Conv1d(1, residual_channels, 1)\n            self.skip_projection = nn.Conv1d(residual_channels, residual_channels, 1)\n            self.output_projection = nn.Conv1d(residual_channels, 1, 1)\n\n        self.diffusion_embedding = DiffusionEmbedding(\n            time_emb_dim, proj_dim=residual_hidden\n        )\n        self.cond_upsampler = CondUpsampler(\n            target_dim=target_dim, cond_length=cond_length\n        )\n        self.residual_layers = nn.ModuleList(\n            [\n                ResidualBlock(\n                    residual_channels=residual_channels,\n                    dilation=2 ** (i % dilation_cycle_length),\n                    hidden_size=residual_hidden,\n                    target_dim=target_dim,\n                )\n                for i in range(residual_layers)\n            ]\n        )\n\n        nn.init.kaiming_normal_(self.input_projection.weight)\n        nn.init.kaiming_normal_(self.skip_projection.weight)\n        nn.init.zeros_(self.output_projection.weight)\n\n    def forward(self, inputs, time, cond):\n        x = 
self.input_projection(inputs)\n        x = F.leaky_relu(x, 0.4)\n\n        diffusion_step = self.diffusion_embedding(time)\n        cond_up = self.cond_upsampler(cond)\n        skip = []\n        for layer in self.residual_layers:\n            x, skip_connection = layer(x, cond_up, diffusion_step)\n            skip.append(skip_connection)\n\n        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))\n        x = self.skip_projection(x)\n        x = F.leaky_relu(x, 0.4)\n        x = self.output_projection(x)\n        return x\n\n\nclass GaussianDiffusion(nn.Module):\n    def __init__(\n        self,\n        target_dim,\n        f_hidden_size,\n        conditional_length,\n        beta_end=0.1,\n        diff_steps=100,\n        loss_type=\"l2\",\n        betas=None,\n        beta_schedule=\"linear\",\n        padding=2,\n        residual_channels=8,\n    ):\n        super().__init__()\n        self.dist_args = nn.Linear(\n            in_features=f_hidden_size, out_features=conditional_length\n        )\n        self.denoise_fn = EpsilonTheta(\n            target_dim=target_dim,\n            cond_length=conditional_length,\n            residual_channels=residual_channels,\n            padding=padding,\n        )\n        self.target_dim = target_dim\n        self.__scale = None\n\n        if betas is not None:\n            betas = (\n                betas.detach().cpu().numpy()\n                if isinstance(betas, torch.Tensor)\n                else betas\n            )\n        else:\n            if beta_schedule == \"linear\":\n                betas = np.linspace(1e-4, beta_end, diff_steps)\n            elif beta_schedule == \"quad\":\n                betas = np.linspace(1e-4 ** 0.5, beta_end ** 0.5, diff_steps) ** 2\n            elif beta_schedule == \"const\":\n                betas = beta_end * np.ones(diff_steps)\n            elif beta_schedule == \"jsd\":  # 1/T, 1/(T-1), 1/(T-2), ..., 1\n                betas = 1.0 / 
np.linspace(diff_steps, 1, diff_steps)\n            elif beta_schedule == \"sigmoid\":\n                betas = np.linspace(-6, 6, diff_steps)\n                betas = (beta_end - 1e-4) / (np.exp(-betas) + 1) + 1e-4\n            elif beta_schedule == \"cosine\":\n                betas = cosine_beta_schedule(diff_steps)\n            else:\n                raise NotImplementedError(beta_schedule)\n\n        alphas = 1.0 - betas\n        alphas_cumprod = np.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])\n\n        (timesteps,) = betas.shape\n        self.num_timesteps = int(timesteps)\n        self.loss_type = loss_type\n\n        to_torch = partial(torch.tensor, dtype=torch.float32)\n\n        self.register_buffer(\"betas\", to_torch(betas))\n        self.register_buffer(\"alphas_cumprod\", to_torch(alphas_cumprod))\n        self.register_buffer(\"alphas_cumprod_prev\", to_torch(alphas_cumprod_prev))\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer(\"sqrt_alphas_cumprod\", to_torch(np.sqrt(alphas_cumprod)))\n        self.register_buffer(\n            \"sqrt_one_minus_alphas_cumprod\", to_torch(np.sqrt(1.0 - alphas_cumprod))\n        )\n        self.register_buffer(\n            \"log_one_minus_alphas_cumprod\", to_torch(np.log(1.0 - alphas_cumprod))\n        )\n        self.register_buffer(\n            \"sqrt_recip_alphas_cumprod\", to_torch(np.sqrt(1.0 / alphas_cumprod))\n        )\n        self.register_buffer(\n            \"sqrt_recipm1_alphas_cumprod\", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))\n        )\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n        posterior_variance = (\n            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)\n        )\n        # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t)\n        self.register_buffer(\"posterior_variance\", to_torch(posterior_variance))\n        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n        self.register_buffer(\n            \"posterior_log_variance_clipped\",\n            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),\n        )\n        self.register_buffer(\n            \"posterior_mean_coef1\",\n            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),\n        )\n        self.register_buffer(\n            \"posterior_mean_coef2\",\n            to_torch(\n                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)\n            ),\n        )\n\n    @property\n    def scale(self):\n        return self.__scale\n\n    @scale.setter\n    def scale(self, scale):\n        self.__scale = scale\n\n    def q_mean_variance(self, x_start, t):\n        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n        variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)\n        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)\n        return mean, variance, log_variance\n\n    def predict_start_from_noise(self, x_t, t, noise):\n        return (\n            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t\n            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise\n        )\n\n    def q_posterior(self, x_start, x_t, t):\n        posterior_mean = (\n            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start\n            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = extract(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = extract(\n            self.posterior_log_variance_clipped, t, x_t.shape\n        )\n        return posterior_mean, posterior_variance, 
posterior_log_variance_clipped\n\n    def p_mean_variance(self, x, cond, t, clip_denoised: bool):\n        x_recon = self.predict_start_from_noise(\n            x, t=t, noise=self.denoise_fn(x, t, cond=cond)\n        )\n\n        if clip_denoised:\n            x_recon.clamp_(-1.0, 1.0)\n\n        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(\n            x_start=x_recon, x_t=x, t=t\n        )\n        return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.no_grad()\n    def p_sample(self, x, cond, t, clip_denoised=False, repeat_noise=False):\n        b, *_, device = *x.shape, x.device\n        model_mean, _, model_log_variance = self.p_mean_variance(\n            x=x, cond=cond, t=t, clip_denoised=clip_denoised\n        )\n        noise = noise_like(x.shape, device, repeat_noise)\n        # no noise when t == 0\n        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n\n    @torch.no_grad()\n    def p_sample_loop(self, shape, cond):\n        device = self.betas.device\n\n        b = shape[0]\n        img = torch.randn(shape, device=device)\n\n        for i in reversed(range(0, self.num_timesteps)):\n            img = self.p_sample(\n                img, cond, torch.full((b,), i, device=device, dtype=torch.long)\n            )\n        return img\n\n    @torch.no_grad()\n    def sample(self, sample_shape=torch.Size(), cond=None):\n        if cond is not None:\n            shape = cond.shape[:-1] + (self.target_dim,)\n            # TODO reshape cond to (B*T, 1, -1)\n        else:\n            shape = sample_shape\n        x_hat = self.p_sample_loop(shape, cond)  # TODO reshape x_hat to (B,T,-1)\n\n        if self.scale is not None:\n            x_hat *= self.scale\n        return x_hat\n\n    @torch.no_grad()\n    def interpolate(self, x1, x2, t=None, lam=0.5):\n        b, *_, device = *x1.shape, 
x1.device\n        t = default(t, self.num_timesteps - 1)\n\n        assert x1.shape == x2.shape\n\n        t_batched = torch.stack([torch.tensor(t, device=device)] * b)\n        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))\n\n        img = (1 - lam) * xt1 + lam * xt2\n        for i in reversed(range(0, t)):\n            img = self.p_sample(\n                img, torch.full((b,), i, device=device, dtype=torch.long)\n            )\n\n        return img\n\n    def q_sample(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n\n        return (\n            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise\n        )\n\n    def p_losses(self, x_start, cond, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n        x_recon = self.denoise_fn(x_noisy, t, cond=cond)\n\n        if self.loss_type == \"l1\":\n            loss = F.l1_loss(x_recon, noise)\n        elif self.loss_type == \"l2\":\n            loss = F.mse_loss(x_recon, noise)\n        elif self.loss_type == \"huber\":\n            loss = F.smooth_l1_loss(x_recon, noise)\n        else:\n            raise NotImplementedError()\n\n        return loss\n\n    def loss(self, x, cond, *args, **kwargs):\n        if self.scale is not None:\n            x /= self.scale\n\n        B, T, _ = x.shape\n\n        time = torch.randint(0, self.num_timesteps, (B * T,), device=x.device).long()\n        loss = self.p_losses(\n            x.reshape(B * T, 1, -1), cond.reshape(B * T, 1, -1), time, *args, **kwargs\n        )\n\n        return loss\n"
  },
  {
    "path": "probts/utils/__init__.py",
    "content": "from .utils import * \nfrom .evaluator import Evaluator"
  },
  {
    "path": "probts/utils/download_datasets.py",
    "content": "import gdown\nimport shutil\nimport os\nimport argparse\n\ndef download_and_extract_zip(output_path, zip_name='all_datasets'):\n    output_path = os.path.normpath(output_path)\n    if not output_path.endswith(os.path.sep):\n        output_path += os.path.sep\n    gdown.download(id='1tSc1WA30CL2aMt5hAW7M-d5_0IBz-lJP', output=output_path, quiet=False)\n    print(f\"Data files are saved to {os.path.dirname(output_path)}\")\n\n    file_path = os.path.join(output_path, zip_name + '.zip')\n    \n    try:\n        shutil.unpack_archive(file_path, os.path.dirname(file_path))\n        print(f\"files are unzipped\")\n    except shutil.ReadError:\n        print(\"is not zip file\")\n        \n    move_files_up_one_level(os.path.join(output_path, zip_name))\n    cleanup_directory(output_path)\n    print(\"datasets prepared done.\")\n    \ndef move_files_up_one_level(directory):\n    for item in os.listdir(directory):\n        if item in ['__MACOSX', '.DS_Store', 'all_datasets.zip']:\n            continue\n        s = os.path.join(directory, item)\n        d = os.path.join(os.path.dirname(directory), item)\n        if not os.path.exists(d):\n            shutil.move(s, d)\n        else:\n            print(f\"skip {item} due to file exist\")\n            delete_path(s)\n    \n    try:\n        delete_path(directory)\n    except:\n        print(f'cannot delete {directory}, skip...')\n    \ndef cleanup_directory(directory):\n    for root, dirs, files in os.walk(directory):\n        for name in dirs:\n            if name in ['__MACOSX']:\n                shutil.rmtree(os.path.join(root, name))\n                \n        for name in files:\n            if name in ['.DS_Store', 'all_datasets.zip']:\n                os.remove(os.path.join(root, name))\n\ndef delete_path(path):\n    if os.path.exists(path):\n        if os.path.isfile(path):\n            os.remove(path)\n        elif os.path.isdir(path):\n            shutil.rmtree(path)\n                \n                
\ndef download_datasets_from_kaggle(output_path):\n    import kagglehub\n    output_path = os.path.join(output_path, 'kaggle/')\n    \n    if not os.path.exists(output_path):\n        os.makedirs(output_path)\n        \n    path = kagglehub.dataset_download(\"dharanikra/electrical-power-demand-in-turkey\")\n    s = os.path.join(path, 'power Generation and consumption.csv')\n    d = os.path.join(os.path.dirname(output_path), 'power Generation and consumption.csv')\n    shutil.move(s, d)\n    print(\"Path to electrical-power-demand-in-turkey files:\", d)\n    delete_path(path)\n    \n    path = kagglehub.dataset_download(\"leonardo00/istanbul-traffic-index\")\n    s = os.path.join(path, 'istanbul_traffic.csv')\n    d = os.path.join(os.path.dirname(output_path), 'istanbul_traffic.csv')\n    shutil.move(s, d)\n    print(\"Path to istanbul-traffic-index files:\", d)\n    delete_path(path)\n    \nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Download and extract zip file from Google Drive')\n    parser.add_argument('--data_path', type=str, required=True, help='Path to store the extracted files')\n    args = parser.parse_args()\n\n    download_and_extract_zip(args.data_path, zip_name='all_datasets')\n    try:\n        download_datasets_from_kaggle(args.data_path)\n    except:\n        print(\"Cannot download datasets from kaggle, skip it.\")"
  },
  {
    "path": "probts/utils/evaluator.py",
    "content": "import numpy as np\nfrom .metrics import *\nimport torch\n\nclass Evaluator:\n    \n    def __init__(self, quantiles_num=10, smooth=False):\n        self.quantiles = (1.0 * np.arange(quantiles_num) / quantiles_num)[1:]\n        self.ignore_invalid_values = True\n        self.smooth = smooth\n\n    def loss_name(self, q):\n        return f\"QuantileLoss[{q}]\"\n\n    def weighted_loss_name(self, q):\n        return f\"wQuantileLoss[{q}]\"\n\n    def coverage_name(self, q):\n        return f\"Coverage[{q}]\"\n\n    def get_sequence_metrics(self, targets, forecasts, seasonal_error=None, samples_dim=1,loss_weights=None):\n        \"\"\"\n        Compute point and probabilistic metrics for one block of sequences.\n\n        The sample mean over `samples_dim` feeds MSE and weighted_ND; the sample\n        median feeds abs_error, MAPE, sMAPE and MASE. `loss_weights` may be a\n        torch.Tensor (on any device) or a numpy.ndarray; it gains a leading and\n        a trailing length-1 axis before weighting the per-element ND terms.\n        \"\"\"\n        mean_forecasts = forecasts.mean(axis=samples_dim)\n        median_forecasts = np.quantile(forecasts, 0.5, axis=samples_dim)\n        metrics = {\n            \"MSE\": mse(targets, mean_forecasts),\n            \"abs_error\": abs_error(targets, median_forecasts),\n            \"abs_target_sum\": abs_target_sum(targets),\n            \"abs_target_mean\": abs_target_mean(targets),\n            \"MAPE\": mape(targets, median_forecasts),\n            \"sMAPE\": smape(targets, median_forecasts),\n        }\n        \n        if seasonal_error is not None:\n            metrics[\"MASE\"] = mase(targets, median_forecasts, seasonal_error)\n        \n        metrics[\"RMSE\"] = np.sqrt(metrics[\"MSE\"])\n        metrics[\"NRMSE\"] = metrics[\"RMSE\"] / metrics[\"abs_target_mean\"]\n        metrics[\"ND\"] = metrics[\"abs_error\"] / metrics[\"abs_target_sum\"]\n        \n        # calculate weighted loss\n        if loss_weights is not None:\n            nd = np.abs(targets - mean_forecasts) / np.sum(np.abs(targets), axis=(1, 2))\n            # process_tensor moves tensors to the CPU before calling .numpy(), so\n            # GPU-resident weights no longer raise; the added axes mirror the\n            # former unsqueeze(0).unsqueeze(-1).\n            weights = process_tensor(loss_weights)[np.newaxis, ..., np.newaxis]\n            weighted_ND = weights * nd\n            metrics['weighted_ND'] = np.sum(weighted_ND)\n        else:\n            metrics['weighted_ND'] = metrics[\"ND\"]\n\n        for q in self.quantiles:\n            q_forecasts = np.quantile(forecasts, q, axis=samples_dim)\n            metrics[self.loss_name(q)] = np.sum(quantile_loss(targets, q_forecasts, q))\n            metrics[self.weighted_loss_name(q)] = \\\n                metrics[self.loss_name(q)] / metrics[\"abs_target_sum\"]\n            metrics[self.coverage_name(q)] = coverage(targets, q_forecasts)\n        \n        metrics[\"mean_absolute_QuantileLoss\"] = np.mean(\n            [metrics[self.loss_name(q)] for q in self.quantiles]\n        )\n        metrics[\"CRPS\"] = np.mean(\n            [metrics[self.weighted_loss_name(q)] for q in self.quantiles]\n        )\n        \n        metrics[\"MAE_Coverage\"] = np.mean(\n            [\n                np.abs(metrics[self.coverage_name(q)] - np.array([q]))\n                for q in self.quantiles\n            ]\n        )\n        return metrics\n\n    def get_metrics(self, targets, forecasts, seasonal_error=None, samples_dim=1, loss_weights=None):\n        metrics = {}\n        seq_metrics = {}\n        \n        # Calculate metrics for each sequence\n        for i in range(targets.shape[0]):\n            single_seq_metrics = self.get_sequence_metrics(\n                np.expand_dims(targets[i], axis=0),\n                np.expand_dims(forecasts[i], axis=0),\n                np.expand_dims(seasonal_error[i], axis=0) if seasonal_error is not None else None,\n                samples_dim,\n                loss_weights\n            )\n            for metric_name, metric_value in single_seq_metrics.items():\n                if metric_name not in seq_metrics:\n                    seq_metrics[metric_name] = []\n                seq_metrics[metric_name].append(metric_value)\n        \n        for metric_name, metric_values in seq_metrics.items():\n            metrics[metric_name] = np.mean(metric_values)\n        return metrics\n\n    @property\n    def selected_metrics(self):\n        return [ \"ND\",'weighted_ND', 'CRPS', \"NRMSE\", \"MSE\", \"MASE\"]\n\n    def __call__(self, targets, forecasts, past_data, freq, loss_weights=None):\n        \"\"\"\n\n        Parameters\n        ----------\n        targets\n            groundtruth in (batch_size, prediction_length, target_dim)\n        forecasts\n            forecasts in (batch_size, num_samples, prediction_length, target_dim)\n        Returns\n        -------\n        Dict[String, float]\n            metrics\n        \"\"\"\n        \n        targets = process_tensor(targets)\n        forecasts = process_tensor(forecasts)\n        past_data = process_tensor(past_data)\n        \n        if self.ignore_invalid_values:\n            targets = np.ma.masked_invalid(targets)\n            forecasts = np.ma.masked_invalid(forecasts)\n        \n        seasonal_error = calculate_seasonal_error(past_data, freq)\n\n        metrics = self.get_metrics(targets, forecasts, seasonal_error=seasonal_error, samples_dim=1, loss_weights=loss_weights)\n        metrics_sum = self.get_metrics(targets.sum(axis=-1), forecasts.sum(axis=-1), samples_dim=1)\n        \n        # select output metrics\n        output_metrics = dict()\n        for k in self.selected_metrics:\n            output_metrics[k] = metrics[k]\n            if k in metrics_sum:\n                output_metrics[f\"{k}-Sum\"] = metrics_sum[k]\n        return output_metrics\n    \ndef process_tensor(targets):\n    if isinstance(targets, torch.Tensor):\n        targets = targets.cpu().detach().numpy()\n    elif isinstance(targets, np.ndarray):\n        pass \n    else:\n        raise TypeError(\"targets must be a torch.Tensor or a numpy.ndarray\")\n    return targets"
  },
  {
    "path": "probts/utils/masking.py",
    "content": "# Code implementation from https://github.com/thuml/iTransformer\nimport torch\n\nclass TriangularCausalMask():\n    def __init__(self, B, L, device=\"cpu\"):\n        mask_shape = [B, 1, L, L]\n        with torch.no_grad():\n            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)\n\n    @property\n    def mask(self):\n        return self._mask\n\n\nclass ProbMask():\n    def __init__(self, B, H, L, index, scores, device=\"cpu\"):\n        _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)\n        _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])\n        indicator = _mask_ex[torch.arange(B)[:, None, None],\n                    torch.arange(H)[None, :, None],\n                    index, :].to(device)\n        self._mask = indicator.view(scores.shape).to(device)\n\n    @property\n    def mask(self):\n        return self._mask\n"
  },
  {
    "path": "probts/utils/metrics.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from gluonts\n# - Source: https://github.com/awslabs/gluonts\n# - Paper: GluonTS: Probabilistic and Neural Time Series Modeling in Python\n# - License: Apache-2.0\n#\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\n\nfrom typing import Optional\nimport numpy as np\nfrom gluonts.time_feature import get_seasonality\n\n\ndef mse(target: np.ndarray, forecast: np.ndarray) -> float:\n    r\"\"\"\n    .. math::\n\n        mse = mean((Y - \\hat{Y})^2)\n    \"\"\"\n    return np.mean(np.square(target - forecast))\n\n\ndef abs_error(target: np.ndarray, forecast: np.ndarray) -> float:\n    r\"\"\"\n    .. math::\n\n        abs\\_error = sum(|Y - \\hat{Y}|)\n    \"\"\"\n    return np.sum(np.abs(target - forecast))\n\n\ndef abs_target_sum(target) -> float:\n    r\"\"\"\n    .. math::\n\n        abs\\_target\\_sum = sum(|Y|)\n    \"\"\"\n    return np.sum(np.abs(target))\n\n\ndef abs_target_mean(target) -> float:\n    r\"\"\"\n    .. math::\n\n        abs\\_target\\_mean = mean(|Y|)\n    \"\"\"\n    return np.mean(np.abs(target))\n\n\ndef mase(\n    target: np.ndarray,\n    forecast: np.ndarray,\n    seasonal_error: np.ndarray,\n) -> float:\n    r\"\"\"\n    .. math::\n\n        mase = mean(|Y - \\hat{Y}|) / seasonal\\_error\n\n    See [HA21]_ for more details.\n    \"\"\"\n    diff = np.mean(np.abs(target - forecast), axis=1)\n    # np.ma.divide masks entries where seasonal_error is 0 and always returns a\n    # MaskedArray, so ``filled`` also works for plain ndarray inputs.\n    # Masked (zero-denominator) entries count as 0.\n    mase = np.ma.divide(diff, seasonal_error).filled(0)\n    return np.mean(mase)\n\ndef calculate_seasonal_error(\n    past_data: np.ndarray,\n    freq: Optional[str] = None,\n):\n    r\"\"\"\n    .. math::\n\n        seasonal\\_error = mean(|Y[t] - Y[t-m]|)\n\n    where m is the seasonal frequency. See [HA21]_ for more details.\n\n    ``past_data`` is indexed as (batch, time, ...): lag-m differences are taken\n    along axis 1, and the result keeps axis 1 with length 1.\n    \"\"\"\n    seasonality = get_seasonality(freq)\n\n    # Compare the lag with the length of the time axis (axis 1); len(past_data)\n    # would measure the batch axis instead.\n    if seasonality < past_data.shape[1]:\n        forecast_freq = seasonality\n    else:\n        # edge case: the seasonal lag is not shorter than the series,\n        # so revert to a lag of 1.\n        forecast_freq = 1\n\n    y_t = past_data[:, :-forecast_freq]\n    y_tm = past_data[:, forecast_freq:]\n\n    mean_diff = np.mean(np.abs(y_t - y_tm), axis=1)\n    mean_diff = np.expand_dims(mean_diff, axis=1)\n\n    return mean_diff\n\n\n\ndef mape(target: np.ndarray, forecast: np.ndarray) -> float:\n    r\"\"\"\n    .. math::\n\n        mape = mean(|Y - \\hat{Y}| / |Y|)\n\n    See [HA21]_ for more details.\n    \"\"\"\n    return np.mean(np.abs(target - forecast) / np.abs(target))\n\n\ndef smape(target: np.ndarray, forecast: np.ndarray) -> float:\n    r\"\"\"\n    .. math::\n\n        smape = 2 * mean(|Y - \\hat{Y}| / (|Y| + |\\hat{Y}|))\n\n    See [HA21]_ for more details.\n    \"\"\"\n    return 2 * np.mean(\n        np.abs(target - forecast) / (np.abs(target) + np.abs(forecast))\n    )\n\ndef quantile_loss(target: np.ndarray, forecast: np.ndarray, q: float) -> float:\n    r\"\"\"\n    .. math::\n\n        quantile\\_loss = 2 * sum(|(Y - \\hat{Y}) * ((Y <= \\hat{Y}) - q)|)\n    \"\"\"\n    return 2 * np.abs((forecast - target) * ((target <= forecast) - q))\n\ndef scaled_quantile_loss(target: np.ndarray, forecast: np.ndarray, q: float, seasonal_error) -> np.ndarray:\n    return quantile_loss(target, forecast, q) / seasonal_error\n\ndef coverage(target: np.ndarray, forecast: np.ndarray) -> float:\n    r\"\"\"\n    .. math::\n\n        coverage = mean(Y < \\hat{Y})\n    \"\"\"\n    return np.mean(target < forecast)"
  },
  {
    "path": "probts/utils/position_emb.py",
    "content": "import torch\nfrom torch import nn\nimport numpy as np\nfrom einops import rearrange, repeat\n\nclass Time_Encoder(nn.Module):\n    def __init__(self, embed_time):\n        super(Time_Encoder, self).__init__()\n        self.periodic = nn.Linear(1, embed_time - 1)\n        self.linear = nn.Linear(1, 1)\n\n    def forward(self, tt):\n        if tt.dim() == 3:  # [B,L,K]\n            tt = rearrange(tt, 'b l k -> b l k 1')\n        else: # [B,L]\n            tt = rearrange(tt, 'b l -> b l 1 1')\n        \n        out2 = torch.sin(self.periodic(tt))\n        out1 = self.linear(tt)\n        out = torch.cat([out1, out2], -1) # [B,L,1,D]\n        return out\n    \ndef sin_cos_encoding(B, K, L, embed_dim):\n    assert embed_dim % 2 == 0\n    \n    omega = np.arange(embed_dim // 2, dtype=np.float64)\n    omega /= embed_dim / 2.\n    omega = 1. / 10000**omega  # (D/2,)\n    pos= [i for i in range(L)]\n    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product\n\n    emb_sin = np.sin(out) # (M, D/2)\n    emb_cos = np.cos(out) # (M, D/2)\n\n    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)\n    \n    emb = repeat(emb, 'l d -> b k l d', b=B, k=K)\n    return torch.tensor(emb, dtype=torch.float64)\n\n\ndef get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n    \"\"\"\n    embed_dim: output dimension for each position\n    pos: a list of positions to be encoded: size (M,)\n    out: (M, D)\n    \"\"\"\n    assert embed_dim % 2 == 0\n    omega = np.arange(embed_dim // 2, dtype=np.float64)\n    omega /= embed_dim / 2.\n    omega = 1. / 10000**omega  # (D/2,)\n\n    pos = pos.reshape(-1)  # (M,)\n    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product\n\n    emb_sin = np.sin(out) # (M, D/2)\n    emb_cos = np.cos(out) # (M, D/2)\n\n    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)\n    return emb"
  },
  {
    "path": "probts/utils/save_utils.py",
    "content": "from typing import Dict\nimport numpy as np\nimport torch\nfrom probts.model.forecaster import Forecaster\nimport importlib\nimport json\nimport pandas as pd\nimport pickle\nimport os\n\ndef update_metrics(new_metrics: Dict, stage: str, key: str = '', target_dict = {}):\n    prefix = stage if key == '' else f'{stage}_{key}'\n    for metric_name, metric_value in new_metrics.items():\n        metric_key = f'{prefix}_{metric_name}'\n        if metric_key not in target_dict:\n            target_dict[metric_key] = []\n            \n        if isinstance(metric_value, list):\n            target_dict[metric_key] = target_dict[metric_key] + metric_value\n        else:\n            target_dict[metric_key].append(metric_value)\n        \n    return target_dict\n\ndef calculate_average(metrics_dict: Dict, hor=''):\n    metrics = {}\n    if hor != '':\n        hor = hor + '/'\n\n    for key, value in metrics_dict.items():\n        metrics[hor+key] = np.mean(value)\n    return metrics\n\n\ndef calculate_weighted_average(metrics_dict: Dict, batch_size: list, hor=''):\n    metrics = {}\n    for key, value in metrics_dict.items():\n        metrics[hor+key] = np.sum(value * np.array(batch_size)) / np.sum(batch_size)\n    return metrics\n\ndef save_point_error(target, predict, input_dict, hor_str):\n    if hor_str not in input_dict:\n        input_dict[hor_str] = {'MAE': [], 'target': [], 'forecast': []}\n    \n    abs_error = np.abs(target - predict)\n\n    input_dict[hor_str]['MAE'].append(abs_error)\n    input_dict[hor_str]['target'].append(target)\n    input_dict[hor_str]['forecast'].append(predict)\n    return input_dict\n\n\ndef load_checkpoint(Model, checkpoint_path, scaler=None, learning_rate=None, no_training=False, **kwargs):\n    # Load the checkpoint\n    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)\n    # Extract the arguments for the forecaster\n    forecaster_args = checkpoint['hyper_parameters']['forecaster']\n\n 
   if isinstance(forecaster_args, Forecaster):\n        forecaster = forecaster_args\n    else:\n        module_path, class_name = forecaster_args['class_path'].rsplit('.', 1)\n        forecaster_class = getattr(importlib.import_module(module_path), class_name)\n        \n        # Add any missing required arguments\n        forecaster_args = forecaster_args['init_args']\n        forecaster_args.update(kwargs)\n        \n        # Create the forecaster\n        forecaster = forecaster_class(**forecaster_args)\n    \n    forecaster.no_training = no_training\n    \n    if learning_rate is None:\n        learning_rate = checkpoint['hyper_parameters'].get('learning_rate', 1e-3)\n    \n    # Create the model instance\n    model = Model(\n        forecaster=forecaster,\n        scaler=scaler,\n        num_samples=checkpoint['hyper_parameters'].get('num_samples', 100),\n        learning_rate=learning_rate,\n        quantiles_num=checkpoint['hyper_parameters'].get('quantiles_num', 10),\n        load_from_ckpt=checkpoint['hyper_parameters'].get('load_from_ckpt', None),\n        **kwargs  # Pass additional arguments here\n    )\n    model.load_state_dict(checkpoint['state_dict'])\n    return model\n\ndef get_hor_str(prediction_length, dataloader_idx):\n    if dataloader_idx is not None:\n        hor_str = str(prediction_length[dataloader_idx])\n    elif type(prediction_length) == list:\n        hor_str = str(prediction_length[0])\n    else:\n        hor_str = str(prediction_length)\n    return hor_str\n\n\ndef save_exp_summary(pl_module, inference=False):\n    exp_summary = {}\n    \n    model_summary = pl_module.model_summary_callback._summary(pl_module.trainer, pl_module.model)\n    exp_summary['total_parameters'] = model_summary.total_parameters\n    exp_summary['trainable_parameters'] = model_summary.trainable_parameters\n    exp_summary['model_size'] = model_summary.model_size\n    \n    memory_summary = pl_module.memory_callback.memory_summary\n    
exp_summary['memory_summary'] = memory_summary\n    \n    time_summary = pl_module.time_callback.time_summary\n    exp_summary['time_summary'] = time_summary\n    for batch_key, batch_time in time_summary.items():\n        if len(batch_time) > 0:\n            exp_summary[f'mean_{batch_key}'] = sum(batch_time) / len(batch_time)\n    \n    exp_summary['sampling_weight_scheme'] = pl_module.model.sampling_weight_scheme\n    \n    if inference:\n        summary_save_path = f\"{pl_module.save_dict}/inference_summary.json\"\n    else:\n        summary_save_path = f\"{pl_module.save_dict}/summary.json\"\n\n    with open(summary_save_path, 'w') as f:\n        json.dump(exp_summary, f, indent=4)\n    print(f\"Summary saved to {summary_save_path}\")\n    \n    \ndef save_csv(save_dict, model, context_length):\n    if len(model.avg_hor_metrics) > 0:\n        horizon_list = []\n        for horizon in model.avg_hor_metrics:\n            horizon_dict = model.avg_hor_metrics[str(horizon)]\n            horizon_dict['horizon'] = horizon\n            horizon_list.append(horizon_dict)\n            \n        df = pd.DataFrame(horizon_list)\n        \n    else:\n        df = pd.DataFrame([model.avg_metrics])\n    \n    if not model.forecaster.no_training:\n        test_result_file = 'horizons_results'\n    else:\n        test_result_file = f'testctx_{context_length}_horizons_results'\n        \n    df.to_csv(f'{save_dict}/{test_result_file}.csv', index='idx')\n    print('horizons result saved to ', f'{save_dict}/{test_result_file}.csv')"
  },
  {
    "path": "probts/utils/utils.py",
    "content": "# ---------------------------------------------------------------------------------\n# Portions of this file are derived from PyTorch-TS\n# - Source: https://github.com/zalandoresearch/pytorch-ts\n# - License: MIT, Apache-2.0 license\n\n# We thank the authors for their contributions.\n# ---------------------------------------------------------------------------------\n\nimport re\nimport os\nimport torch\nimport numpy as np\nfrom typing import Optional, Dict\nimport torch.nn as nn\nimport importlib\n\ndef repeat(tensor: torch.Tensor, n: int, dim: int = 0):\n    return tensor.repeat_interleave(repeats=n, dim=dim)\n\n\ndef extract(a, t, x_shape):\n    \"\"\"\n    Gather ``a[t]`` for a batch of indices and reshape it for broadcasting.\n\n    Args:\n        a: Coefficient tensor indexed along its last dim (e.g. a schedule buffer).\n        t: Index tensor whose first dim is the batch size.\n        x_shape: Shape of the tensor the result will be broadcast against.\n\n    Returns:\n        Tensor of shape (batch_size, 1, ..., 1) with len(x_shape) dims, on t's device.\n    \"\"\"\n    batch_size = t.shape[0]\n    # gather requires the index on the same device as ``a``; moving ``t`` to\n    # a.device (rather than always to the CPU) works for CPU and GPU buffers.\n    out = a.gather(-1, t.to(a.device))\n    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)\n\n\n    \ndef weighted_average(\n    x: torch.Tensor,\n    weights: Optional[torch.Tensor] = None,\n    dim: int = None,\n    reduce: str = 'mean',\n):\n    \"\"\"\n    Computes the weighted average of a given tensor across a given dim, masking\n    values associated with weight zero,\n    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.\n\n    Args:\n        x: Input tensor, of which the average must be computed.\n        weights: Weights tensor, of the same shape as `x`.\n        dim: The dim along which to average `x`. `None` averages over all\n            elements when `weights` is given, and returns `x` unchanged otherwise.\n        reduce: If not 'mean' (and `weights` is given), return the masked\n            weighted tensor without reducing it.\n\n    Returns:\n        Tensor: The tensor with values averaged along the specified `dim`.\n    \"\"\"\n    # `dim` is tested against None explicitly: a truthiness test would treat\n    # dim=0 like None and reduce over every element instead of along dim 0.\n    if weights is not None:\n        weighted_tensor = torch.where(weights != 0, x * weights, torch.zeros_like(x))\n        if reduce != 'mean':\n            return weighted_tensor\n        if dim is None:\n            return weighted_tensor.sum() / torch.clamp(weights.sum(), min=1.0)\n        sum_weights = torch.clamp(weights.sum(dim=dim), min=1.0)\n        return weighted_tensor.sum(dim=dim) / sum_weights\n    else:\n        return x.mean(dim=dim) if dim is not None else x\n    \n    \ndef convert_to_list(s):\n    '''\n    Convert prediction length strings into list\n    e.g., '96-192-336-720' will be converted into [96,192,336,720]\n    Input: str, list, int\n    Returns: list (None for unsupported input types)\n    '''\n    if (type(s).__name__=='int'):\n        return [s]\n    elif (type(s).__name__=='list'):\n        return s\n    elif (type(s).__name__=='str'):\n        # Leading/trailing separators (e.g. ' 96-192') make re.split emit empty\n        # strings, which int() would reject, so they are skipped.\n        elements = re.split(r'\\D+', s)\n        return [int(e) for e in elements if e]\n    else:\n        return None\n    \n\ndef find_best_epoch(ckpt_folder):\n    \"\"\"\n    Find the checkpoint with the lowest validation CRPS in `ckpt_folder`.\n    Only filenames containing `epoch=<int>-val_CRPS=<decimal>` are considered.\n    Returns (best_epoch, best_ckpt_filename); both are None if nothing matches.\n    Thanks to GitHub@Kai-Ref for identifying and fixing the issue with CRPS value comparisons.\n    \"\"\"\n    pattern = r\"epoch=(\\d+)-val_CRPS=([0-9]*\\.[0-9]+)\"\n    ckpt_files = os.listdir(ckpt_folder)  # List of checkpoint files\n    \n    best_ckpt = None\n    best_epoch = None\n    best_crps = float(\"inf\")  # Start with an infinitely large CRPS\n    \n    for filename in ckpt_files:\n        match = re.search(pattern, filename)\n        if match:\n            epoch = int(match.group(1))  # Extract epoch number\n            crps = float(match.group(2))  # Extract CRPS value\n            \n            if crps < best_crps:  # If this is the lowest CRPS found so far\n                best_crps = crps\n                best_ckpt = filename\n                best_epoch = epoch  # Store the best epoch number\n    return best_epoch, best_ckpt\n\ndef ensure_list(input_value, default_value=None):\n    \"\"\"\n    Ensures that the input is converted to a list. If the input is None,\n    it converts the default value to a list instead.\n    \"\"\"\n    result = convert_to_list(input_value)\n    if result is None:\n        result = convert_to_list(default_value)\n    return result\n\n\ndef init_class_helper(class_name):\n    \"\"\"\n    Dynamically imports a module and retrieves a class.\n\n    Args:\n        class_name (str): The fully qualified name of the class in the format \"module_name.ClassName\".\n\n    Returns:\n        type: The class object retrieved from the specified module.\n    \"\"\"\n    module_name, class_name = class_name.rsplit(\".\", 1)\n    module = importlib.import_module(module_name)\n    Class = getattr(module, class_name)\n    return Class"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=66\"]\n\n[project]\nname = \"ProbTS\"\nversion = \"0.1.0\"\ndescription = \"Benchmarking Point and Distributional Forecasting across Diverse Prediction Horizons\"\nauthors = [\n    {name = \"Jiawen Zhang\"},\n    {name = \"Xumeng Wen\"},\n    {name = \"Zhenwei Zhang\"},\n    {name = \"Shun Zhen\"},\n]\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = {text = \"MIT\"}\n\ndependencies = [\n    \"numpy\",\n    \"pandas==2.0.3\",\n    \"einops>=0.6.1\",\n    \"matplotlib\",\n    \"tqdm\",\n    \"PyYAML>=6.0\",\n    \"lightning @ https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip\",\n    \"gluonts~=0.15.1\",\n    \"typeshed-client==2.3.0\",\n    \"docstring-parser==0.15\",\n    \"orjson==3.9.0\",\n    \"pydantic==1.10.8\",\n    \"transformers==4.50.0\",\n    \"linear-attention-transformer==0.19.1\",\n    \"tensorboardx==2.6.2\",\n    \"pyarrow==11.0.0\",\n    \"protobuf>=3.19\",\n    \"jsonargparse[signatures]\",\n    \"opt_einsum\",\n    \"psutil\",\n    \"reformer-pytorch\",\n    \"gdown\",\n    \"kagglehub\",\n    \"python-dotenv>=1.0.0\",\n    \"utilsforecast\",\n    \"jax\",\n    \"scikit-learn\",\n]\n\n[project.optional-dependencies]\ntsfm = [\n    \"timm\",\n    \"accelerate\",\n    \"tokenizers\",\n    \"datasets\",\n    \"jaxtyping\",\n    \"hydra-core==1.3\",\n    \"orjson\",\n    \"tensorboard\",\n    \"multiprocess\",\n    \"huggingface_hub>=0.23.0\",\n    \"safetensors\",\n    \"jax[cpu]\",\n    \"paxml>=1.4.0\", # for timesfm\n    \"praxis>=1.4.0\",\n    \"einshape>=1.0.0\",\n    \"numpy>=1.26.4\",\n    \"pandas==2.0.3\",\n    \"pykeops\",\n]\n\n[tool.setuptools]\npy-modules = []"
  },
  {
    "path": "run.py",
    "content": "import os\nimport torch\nimport logging\nfrom probts.data import ProbTSDataModule\nfrom probts.model.forecast_module import ProbTSForecastModule\nfrom probts.callbacks import MemoryCallback, TimeCallback\nfrom probts.utils import find_best_epoch\nfrom lightning.pytorch.cli import LightningCLI\nfrom lightning.pytorch.loggers import CSVLogger, TensorBoardLogger\nfrom lightning.pytorch.callbacks import ModelCheckpoint\nfrom probts.utils.save_utils import save_exp_summary, save_csv\n\nMULTI_HOR_MODEL = ['ElasTST', 'Autoformer']\n\nimport warnings\nwarnings.filterwarnings('ignore')\n\ntorch.set_float32_matmul_precision('high')\n\nlog = logging.getLogger(__name__)\nlogging.basicConfig(level=logging.INFO)\n\nclass ProbTSCli(LightningCLI):\n    \n    def add_arguments_to_parser(self, parser):\n        data_to_model_link_args = [\n            \"scaler\",\n            \"train_pred_len_list\", \n        ]\n        data_to_forecaster_link_args = [\n            \"target_dim\",\n            \"history_length\",\n            \"context_length\",\n            \"prediction_length\",\n            \"train_pred_len_list\", \n            \"lags_list\",\n            \"freq\",\n            \"time_feat_dim\",\n            \"global_mean\",\n            \"dataset\"\n        ]\n        for arg in data_to_model_link_args:\n            parser.link_arguments(f\"data.data_manager.{arg}\", f\"model.{arg}\", apply_on=\"instantiate\")\n        for arg in data_to_forecaster_link_args:\n            parser.link_arguments(f\"data.data_manager.{arg}\", f\"model.forecaster.init_args.{arg}\", apply_on=\"instantiate\")\n\n    def init_exp(self):\n        config_args = self.parser.parse_args()\n        \n        if self.datamodule.data_manager.multi_hor:\n            assert self.model.forecaster.name in MULTI_HOR_MODEL, f\"Only support multi-horizon setting for {MULTI_HOR_MODEL}\"\n            \n            self.tag = \"_\".join([\n                self.datamodule.data_manager.dataset,\n                self.model.forecaster.name,\n                'TrainCTX','-'.join([str(i) for i in self.datamodule.data_manager.train_ctx_len_list]),\n                'TrainPRED','-'.join([str(i) for i in self.datamodule.data_manager.train_pred_len_list]),\n                'ValCTX','-'.join([str(i) for i in self.datamodule.data_manager.val_ctx_len_list]),\n                'ValPRED','-'.join([str(i) for i in self.datamodule.data_manager.val_pred_len_list]),\n                'seed' + str(config_args.seed_everything)\n            ])\n        else:\n            self.tag = \"_\".join([\n                self.datamodule.data_manager.dataset,\n                self.model.forecaster.name,\n                'CTX' + str(self.datamodule.data_manager.context_length),\n                'PRED' + str(self.datamodule.data_manager.prediction_length),\n                'seed' + str(config_args.seed_everything)\n            ])\n        \n        log.info(f\"Root dir is {self.trainer.default_root_dir}, exp tag is {self.tag}\")\n        \n        if not os.path.exists(self.trainer.default_root_dir):\n            os.makedirs(self.trainer.default_root_dir)\n            \n        self.save_dict = f'{self.trainer.default_root_dir}/{self.tag}'\n        if not os.path.exists(self.save_dict):\n            os.makedirs(self.save_dict)\n\n        if self.model.load_from_ckpt is not None:\n            # if the checkpoint file is not assigned, find the best epoch in the current folder\n            if '.ckpt' not in self.model.load_from_ckpt:\n                _, best_ckpt = find_best_epoch(self.model.load_from_ckpt)\n                # find_best_epoch returns None when no file name matches its\n                # pattern; fail with a clear message instead of a TypeError\n                # from os.path.join.\n                if best_ckpt is None:\n                    raise FileNotFoundError(\n                        f\"No checkpoint matching 'epoch=<n>-val_CRPS=<x>' found in {self.model.load_from_ckpt}\"\n                    )\n                log.info(f\"Found best checkpoint {best_ckpt}\")\n                self.model.load_from_ckpt = os.path.join(self.model.load_from_ckpt, best_ckpt)\n            \n            log.info(f\"Loading pre-trained checkpoint from {self.model.load_from_ckpt}\")\n            self.model = ProbTSForecastModule.load_from_checkpoint(\n                self.model.load_from_ckpt,\n                learning_rate=config_args.model.learning_rate,\n                scaler=self.datamodule.data_manager.scaler,\n                context_length=self.datamodule.data_manager.context_length,\n                target_dim=self.datamodule.data_manager.target_dim,\n                freq=self.datamodule.data_manager.freq,\n                prediction_length=self.datamodule.data_manager.prediction_length,\n                train_pred_len_list=self.datamodule.data_manager.train_pred_len_list,\n                lags_list=self.datamodule.data_manager.lags_list,\n                time_feat_dim=self.datamodule.data_manager.time_feat_dim,\n                no_training=self.model.forecaster.no_training,\n                sampling_weight_scheme=self.model.sampling_weight_scheme,\n            )\n        \n        # Set callbacks\n        self.memory_callback = MemoryCallback()\n        self.time_callback = TimeCallback()\n        \n        callbacks = [\n            self.memory_callback,\n            self.time_callback\n        ]\n        \n        if not self.model.forecaster.no_training:\n            if self.datamodule.dataset_val is None:  # if the validation set is empty\n                monitor = \"train_loss\"\n            else:\n                # not using reweighting scheme for loss\n                if self.model.sampling_weight_scheme in ['none', 'fix']:\n                    monitor = 'val_CRPS'\n                else:\n                    monitor = 'val_weighted_ND'\n            \n            # Set callbacks\n            self.checkpoint_callback = ModelCheckpoint(\n                dirpath=f'{self.save_dict}/ckpt',\n                filename='{epoch}-{val_CRPS:.6f}',\n                every_n_epochs=1,\n                monitor=monitor,\n                save_top_k=-1,\n                save_last=True,\n                enable_version_counter=False\n            )\n\n            callbacks.append(self.checkpoint_callback)\n\n        self.set_callbacks(callbacks)\n\n    def set_callbacks(self, callbacks):\n        # Replace built-in callbacks with custom callbacks of the same class name.\n        custom_callbacks_name = {c.__class__.__name__ for c in callbacks}\n        # Filter in one pass (in place, keeping the same list object) instead of\n        # calling list.remove() while iterating the list, which skips the element\n        # after each removal and can leave duplicate callbacks behind.\n        self.trainer.callbacks[:] = [\n            c for c in self.trainer.callbacks\n            if c.__class__.__name__ not in custom_callbacks_name\n        ]\n        self.trainer.callbacks.extend(callbacks)\n        for c in self.trainer.callbacks:\n            if c.__class__.__name__ == \"ModelSummary\":\n                self.model_summary_callback = c\n\n    def set_fit_mode(self):\n        self.trainer.logger = TensorBoardLogger(\n            save_dir=f'{self.save_dict}/logs',\n            name=self.tag,\n            version='fit'\n        )\n    \n    def set_test_mode(self):\n        self.trainer.logger = CSVLogger(\n            save_dir=f'{self.save_dict}/logs',\n            name=self.tag,\n            version='test'\n        )\n\n        if not self.model.forecaster.no_training:\n            self.ckpt = self.checkpoint_callback.best_model_path\n            log.info(f\"Loading best checkpoint from {self.ckpt}\")\n            self.model = ProbTSForecastModule.load_from_checkpoint(\n                self.ckpt, \n                scaler=self.datamodule.data_manager.scaler,\n                context_length=self.datamodule.data_manager.context_length,\n                target_dim=self.datamodule.data_manager.target_dim,\n                freq=self.datamodule.data_manager.freq,\n                prediction_length=self.datamodule.data_manager.prediction_length,\n                lags_list=self.datamodule.data_manager.lags_list,\n                time_feat_dim=self.datamodule.data_manager.time_feat_dim,\n                sampling_weight_scheme=self.model.sampling_weight_scheme,\n            )\n\n    def run(self):\n        self.init_exp()\n        \n        if not self.model.forecaster.no_training:\n            self.set_fit_mode()\n            if self.datamodule.dataset_val is None:  # if the validation set is empty\n                self.trainer.fit(model=self.model, train_dataloaders=self.datamodule.train_dataloader())\n            else:\n                self.trainer.fit(model=self.model, datamodule=self.datamodule)\n            \n            inference=False\n        else:\n            inference=True\n\n        self.set_test_mode()\n        self.trainer.test(model=self.model, datamodule=self.datamodule)\n        \n        save_exp_summary(self, inference=inference)\n        \n        ctx_len = self.datamodule.data_manager.context_length\n        if self.datamodule.data_manager.multi_hor:\n            ctx_len = ctx_len[0]\n\n        save_csv(self.save_dict, self.model, ctx_len)\n\n\nif __name__ == '__main__':\n    cli = ProbTSCli(\n        datamodule_class=ProbTSDataModule,\n        model_class=ProbTSForecastModule,\n        save_config_kwargs={\"overwrite\": True},\n        run=False\n    )\n    cli.run()"
  },
  {
    "path": "run.sh",
    "content": "#!/bin/sh\n\nMODEL=patchtst\nDATASET=etth1\nCTX_LEN=96\nPRED_LEN=96\n\n# DATA_DIR=/path/to/datasets\n# LOG_DIR=/path/to/log_dir\nDATA_DIR=./datasets\nLOG_DIR=./log_dir\n\n# multivariate datasets:\n# ['exchange_rate_nips', 'solar_nips','electricity_nips', 'traffic_nips','wiki2000_nips']\n\n# Univariate datasets:\n# ['m4_weekly', 'm4_hourly', 'm4_daily', 'm4_monthly', 'm4_quarterly', 'm4_yearly', 'm5', 'tourism_monthly', 'tourism_quarterly', 'tourism_yearly']\n\n# Long-term forecasting:\n# ['etth1', 'etth2','ettm1','ettm2','traffic_ltsf', 'electricity_ltsf', 'exchange_ltsf', 'illness_ltsf', 'weather_ltsf']\n# NOTE: when using long-term forecasting datasets, please explicitly assign context_length and prediction_length, e.g.:\n# --data.data_manager.init_args.context_length 96 \\\n# --data.data_manager.init_args.prediction_length 192 \\\n\n# run pipeline with train and test\n# replace ${MODEL} with target model name, e.g., patchtst\n# replace ${DATASET} with dataset name\n\n# if the dataset path is not specified, the default path is ./datasets\n\n# to run on cpu, set EXTRA_ARGS=\"--trainer.accelerator=cpu --trainer.devices=1\"\n# (left unquoted below on purpose so it expands into separate arguments)\nEXTRA_ARGS=\"\"\n\npython run.py --config config/ltsf/${DATASET}/${MODEL}.yaml --seed_everything 0  \\\n    --data.data_manager.init_args.path ${DATA_DIR} \\\n    --trainer.default_root_dir ${LOG_DIR} \\\n    --data.data_manager.init_args.dataset ${DATASET} \\\n    --data.data_manager.init_args.split_val true \\\n    --trainer.max_epochs 50 \\\n    --data.data_manager.init_args.context_length ${CTX_LEN} \\\n    --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n    ${EXTRA_ARGS}"
  },
  {
    "path": "scripts/prepare_datasets.sh",
    "content": "#!/bin/sh\n\n# Usage: scripts/prepare_datasets.sh <data_path>\nif [ -z \"$1\" ]; then\n    echo \"Usage: $0 <data_path>\" >&2\n    exit 1\nfi\n\n# Check if gdown is installed\nif pip show gdown > /dev/null 2>&1; then\n    echo \"gdown is already installed, skipping installation.\"\nelse\n    echo \"gdown is not installed, installing...\"\n    pip install gdown\nfi\n\n# Quote the path so directories containing spaces are passed as one argument.\npython probts/utils/download_datasets.py --data_path \"$1\""
  },
  {
    "path": "scripts/prepare_tsfm_checkpoints.sh",
    "content": "#!/bin/sh\n\necho \"NOTE! By downloading these checkpoints, you agree to the licenses of the original models and checkpoints.\"\necho \"\"\necho \"- [Timer](https://github.com/thuml/Large-Time-Series-Model) created by thuml. The original model and its checkpoints are licensed under the MIT License. The checkpoints are distributed under the MIT License. You may not use these files except in compliance with the License. You may obtain a copy of the License at: https://github.com/thuml/Large-Time-Series-Model/blob/main/LICENSE.\"\necho \"- [ForecastPFN](https://github.com/abacusai/ForecastPFN) created by abacusai. The original model and its checkpoints are licensed under the MIT License. The checkpoints are distributed under the Apache-2.0 License. You may not use these files except in compliance with the License. You may obtain a copy of the License at: https://github.com/abacusai/ForecastPFN/blob/main/LICENSE.\"\necho \"- [UniTS](https://github.com/mims-harvard/UniTS) created by mims-harvard. The original model and its checkpoints are licensed under the MIT License. The checkpoints are distributed under the MIT License. You may not use these files except in compliance with the License. You may obtain a copy of the License at: https://github.com/mims-harvard/UniTS/blob/main/LICENSE.\"\necho \"- [Lag-Llama](https://github.com/time-series-foundation-models/lag-llama) created by time-series-foundation-models. The original model and its checkpoints are licensed under the MIT License. The checkpoints are distributed under the Apache-2.0 License. You may not use these files except in compliance with the License. You may obtain a copy of the License at: https://github.com/time-series-foundation-models/lag-llama/blob/main/LICENSE.\"\necho \"\"\necho \"NOTE! By downloading these checkpoints, you agree to the licenses of the original models and checkpoints.\"\n# POSIX read has no -p option, so print the prompt separately to stay portable under /bin/sh.\nprintf '%s' \"Do you want to continue? (yes/y to continue): \"\nread -r confirm\n\n# Convert input to lowercase for comparison\nconfirm=$(echo \"$confirm\" | tr '[:upper:]' '[:lower:]')\n\nif [ \"$confirm\" = \"yes\" ] || [ \"$confirm\" = \"y\" ]; then\n    # Check if gdown is installed\n    if pip show gdown > /dev/null 2>&1; then\n        echo \"gdown is already installed, skipping installation.\"\n    else\n        echo \"gdown is not installed, installing...\"\n        pip install gdown\n    fi\n    # Download the folder\n    gdown --folder 1FaCk9Lj9KZGEO09gehNqC4fbTj4wnN8j -O checkpoints\nelse\n    echo \"Download canceled.\"\nfi"
  },
  {
    "path": "scripts/reproduce_ltsf_results.sh",
    "content": "#!/bin/sh\n\n# Use GPU 0 unless the caller already chose devices (an explicitly empty value is kept).\nexport CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES-0}\n\nDATA_DIR=./datasets\nLOG_DIR=./exps\n\n\nCTX_LEN=96\n\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf' 'electricity_ltsf' 'exchange_ltsf' 'traffic_ltsf'\ndo\n    for MODEL in 'dlinear' 'patchtst' 'gru_nvp' 'timegrad' 'csdi'\n    do\n        for PRED_LEN in 96 192 336 720\n        do\n            python run.py --config config/ltsf/${DATASET}/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN}\n        done\n    done\ndone\n\nCTX_LEN=36\n\nfor DATASET in 'illness_ltsf'\ndo\n    for MODEL in 'dlinear' 'patchtst' 'gru_nvp' 'timegrad' 'csdi'\n    do\n        for PRED_LEN in 24 36 48 60\n        do\n            python run.py --config config/ltsf/${DATASET}/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN}\n        done\n    done\ndone"
  },
  {
    "path": "scripts/reproduce_stsf_results.sh",
    "content": "#!/bin/sh\n\n# Use GPU 0 unless the caller already chose devices (an explicitly empty value is kept).\nexport CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES-0}\n\nDATA_DIR=./datasets\nLOG_DIR=./exps\n\nfor DATASET in 'solar' 'electricity' 'exchange' 'traffic' 'wiki'\ndo\n    for MODEL in 'dlinear' 'patchtst' 'gru_nvp' 'gru_maf' 'trans_maf' 'timegrad' 'csdi' 'timesnet'\n    do\n        python run.py --config config/stsf/${DATASET}/${MODEL}.yaml --seed_everything 0  \\\n            --data.data_manager.init_args.path ${DATA_DIR} \\\n            --trainer.default_root_dir ${LOG_DIR} \\\n            --data.data_manager.init_args.split_val true\n    done\ndone\n"
  },
  {
    "path": "scripts/reproduce_tsfm_results.sh",
    "content": "#!/bin/sh\n\n# Use GPU 0 unless the caller already chose devices (an explicitly empty value is kept).\nexport CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES-0}\n\nDATA_DIR=./datasets\nLOG_DIR=./exps\n\n# MOIRAI\nMODEL='moirai'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf' 'electricity_ltsf'; do\n    for CTX_LEN in 5000 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}/context_${CTX_LEN}/${DATASET}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN}\n        done\n    done\ndone\n\nfor DATASET in 'exchange_rate_nips' 'solar_nips' 'electricity_nips'; do\n    for CTX_LEN in 5000 96; do\n        python run.py --config config/tsfm/${MODEL}/context_${CTX_LEN}/${DATASET}.yaml --seed_everything 0  \\\n            --data.data_manager.init_args.path ${DATA_DIR} \\\n            --trainer.default_root_dir ${LOG_DIR} \\\n            --data.data_manager.init_args.dataset ${DATASET}\n    done\ndone\n\n# Chronos\nMODEL='chronos'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do\n    for CTX_LEN in 5000 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --data.test_batch_size 1\n        done\n    done\ndone\n\nfor DATASET in 'exchange_rate_nips' 'traffic_nips'; do\n    for CTX_LEN in 512 96; do\n        for PRED_LEN in 24; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --data.test_batch_size 1\n        done\n    done\ndone\n\n# Lag-Llama\nMODEL='lag_llama'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do\n    for CTX_LEN in 512; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --model.forecaster.init_args.ckpt_path './checkpoints/lag-llama/lag-llama.ckpt' \\\n                --data.test_batch_size 1\n        done\n    done\ndone\n\n# TimesFM\nMODEL='timesfm'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2'; do\n    for CTX_LEN in 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --data.test_batch_size 64\n        done\n    done\ndone\n\n# Timer\nMODEL='timer'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf' 'electricity_ltsf'; do\n    for CTX_LEN in 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --model.forecaster.init_args.ckpt_path './checkpoints/timer/Timer_67M_UTSD_4G.pt' \\\n                --data.test_batch_size 64\n        done\n    done\ndone\n\n# UniTS\nMODEL='units'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2'; do\n    for CTX_LEN in 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --model.forecaster.init_args.ckpt_path './checkpoints/units/units_x128_pretrain_checkpoint.pth' \\\n                --data.test_batch_size 64\n        done\n    done\ndone\n\n# ForecastPFN\nMODEL='forecastpfn'\nfor DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do\n    for CTX_LEN in 96; do\n        for PRED_LEN in 24 48 96 192 336 720; do\n            python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0  \\\n                --data.data_manager.init_args.path ${DATA_DIR} \\\n                --trainer.default_root_dir ${LOG_DIR} \\\n                --data.data_manager.init_args.split_val true \\\n                --data.data_manager.init_args.dataset ${DATASET} \\\n                --data.data_manager.init_args.context_length ${CTX_LEN} \\\n                --data.data_manager.init_args.prediction_length ${PRED_LEN} \\\n                --model.forecaster.init_args.ckpt_path './checkpoints/ForecastPFN/saved_weights' \\\n                --data.test_batch_size 64\n        done\n    done\ndone"
  },
  {
    "path": "scripts/run_elastst.sh",
    "content": "#!/bin/sh\n\n# Replace these placeholders with real paths before running.\nDATA_DIR=/path/to/datasets\nLOG_DIR=/path/to/log_dir\n\n# for varied-horizon forecasting\n\nTRAIN_CTX_LEN=96\nVAL_CTX_LEN=96\nTEST_CTX_LEN=96\n\nTRAIN_PRED_LEN=720\nVAL_PRED_LEN=720\nTEST_PRED_LEN=24-48-96-192-336-720\n\n\nDATASET='exchange_ltsf' # select from ['etth1', 'etth2', 'ettm1', 'ettm2', 'traffic_ltsf', 'electricity_ltsf', 'exchange_ltsf', 'weather_ltsf']\n\nMODEL=elastst\n\npython run.py --config config/multi_hor/${MODEL}.yaml --seed_everything 0  \\\n    --data.data_manager.init_args.path ${DATA_DIR} \\\n    --trainer.default_root_dir ${LOG_DIR} \\\n    --data.data_manager.init_args.split_val true \\\n    --data.data_manager.init_args.dataset ${DATASET} \\\n    --data.data_manager.init_args.context_length ${TEST_CTX_LEN} \\\n    --data.data_manager.init_args.prediction_length ${TEST_PRED_LEN} \\\n    --data.data_manager.init_args.train_pred_len_list ${TRAIN_PRED_LEN} \\\n    --data.data_manager.init_args.train_ctx_len ${TRAIN_CTX_LEN} \\\n    --data.data_manager.init_args.val_ctx_len ${VAL_CTX_LEN} \\\n    --data.data_manager.init_args.val_pred_len_list ${VAL_PRED_LEN} \\\n    --trainer.max_epochs 50"
  },
  {
    "path": "scripts/run_varied_hor_training.sh",
    "content": "#!/bin/sh\n\n# Replace these placeholders with real paths before running.\nDATA_DIR=/path/to/datasets\nLOG_DIR=/path/to/log_dir\n\n# for varied-horizon forecasting\n\nTRAIN_CTX_LEN=96\nVAL_CTX_LEN=96\nTEST_CTX_LEN=96\n\nTRAIN_PRED_LEN=1-720\nVAL_PRED_LEN=720\nTEST_PRED_LEN=24-48-96-192-336-720\n\n\nDATASET='exchange_ltsf' # select from ['etth1', 'etth2', 'ettm1', 'ettm2', 'traffic_ltsf', 'electricity_ltsf', 'exchange_ltsf', 'weather_ltsf']\n\nMODEL=elastst\n\npython run.py --config config/multi_hor/${MODEL}.yaml --seed_everything 0  \\\n    --data.data_manager.init_args.path ${DATA_DIR} \\\n    --trainer.default_root_dir ${LOG_DIR} \\\n    --data.data_manager.init_args.split_val true \\\n    --data.data_manager.init_args.dataset ${DATASET} \\\n    --data.data_manager.init_args.context_length ${TEST_CTX_LEN} \\\n    --data.data_manager.init_args.prediction_length ${TEST_PRED_LEN} \\\n    --data.data_manager.init_args.train_pred_len_list ${TRAIN_PRED_LEN} \\\n    --data.data_manager.init_args.train_ctx_len ${TRAIN_CTX_LEN} \\\n    --data.data_manager.init_args.val_ctx_len ${VAL_CTX_LEN} \\\n    --data.data_manager.init_args.val_pred_len_list ${VAL_PRED_LEN} \\\n    --data.data_manager.init_args.continuous_sample true \\\n    --trainer.max_epochs 50"
  }
]