Repository: microsoft/ProbTS
Branch: main
Commit: 6975a9766995
Files: 299
Total size: 998.7 KB
Directory structure:
gitextract_33gmype6/
├── .gitignore
├── .gitmodules
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── checkpoints/
│ └── README.md
├── config/
│ ├── default/
│ │ ├── autoformer.yaml
│ │ ├── csdi.yaml
│ │ ├── dlinear.yaml
│ │ ├── gru.yaml
│ │ ├── gru_maf.yaml
│ │ ├── gru_nvp.yaml
│ │ ├── itransformer.yaml
│ │ ├── linear.yaml
│ │ ├── mean.yaml
│ │ ├── moderntcn.yaml
│ │ ├── naive.yaml
│ │ ├── nhits.yaml
│ │ ├── nlinear.yaml
│ │ ├── patchtst.yaml
│ │ ├── timegrad.yaml
│ │ ├── timesnet.yaml
│ │ ├── trans_maf.yaml
│ │ ├── transformer.yaml
│ │ ├── tsdiff.yaml
│ │ └── tsmixer.yaml
│ ├── ltsf/
│ │ ├── electricity_ltsf/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ └── timegrad.yaml
│ │ ├── etth1/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ └── timegrad.yaml
│ │ ├── etth2/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ └── timegrad.yaml
│ │ ├── ettm1/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ └── timegrad.yaml
│ │ ├── ettm2/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ └── timegrad.yaml
│ │ ├── exchange_ltsf/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ └── timegrad.yaml
│ │ ├── illness_ltsf/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ └── timegrad.yaml
│ │ ├── traffic_ltsf/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ └── timegrad.yaml
│ │ └── weather_ltsf/
│ │ ├── csdi.yaml
│ │ ├── dlinear.yaml
│ │ ├── gru_nvp.yaml
│ │ ├── patchtst.yaml
│ │ └── timegrad.yaml
│ ├── m4/
│ │ ├── m4_daily/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ └── timegrad.yaml
│ │ ├── m4_weekly/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ └── timegrad.yaml
│ │ ├── m5/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ └── timegrad.yaml
│ │ └── tourism_monthly/
│ │ ├── csdi.yaml
│ │ ├── dlinear.yaml
│ │ ├── gru_nvp.yaml
│ │ ├── patchtst.yaml
│ │ └── timegrad.yaml
│ ├── multi_hor/
│ │ ├── autoformer.yaml
│ │ └── elastst.yaml
│ ├── pipeline_config.yaml
│ ├── stsf/
│ │ ├── electricity/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru.yaml
│ │ │ ├── gru_maf.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ ├── timegrad.yaml
│ │ │ ├── timesnet.yaml
│ │ │ ├── trans_maf.yaml
│ │ │ └── transformer.yaml
│ │ ├── exchange/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru.yaml
│ │ │ ├── gru_maf.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ ├── timegrad.yaml
│ │ │ ├── timesnet.yaml
│ │ │ ├── trans_maf.yaml
│ │ │ └── transformer.yaml
│ │ ├── solar/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru.yaml
│ │ │ ├── gru_maf.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ ├── timegrad.yaml
│ │ │ ├── timesnet.yaml
│ │ │ ├── trans_maf.yaml
│ │ │ └── transformer.yaml
│ │ ├── traffic/
│ │ │ ├── csdi.yaml
│ │ │ ├── dlinear.yaml
│ │ │ ├── gru.yaml
│ │ │ ├── gru_maf.yaml
│ │ │ ├── gru_nvp.yaml
│ │ │ ├── patchtst.yaml
│ │ │ ├── timegrad.yaml
│ │ │ ├── timesnet.yaml
│ │ │ ├── trans_maf.yaml
│ │ │ └── transformer.yaml
│ │ └── wiki/
│ │ ├── csdi.yaml
│ │ ├── dlinear.yaml
│ │ ├── gru.yaml
│ │ ├── gru_maf.yaml
│ │ ├── gru_nvp.yaml
│ │ ├── patchtst.yaml
│ │ ├── timegrad.yaml
│ │ ├── timesnet.yaml
│ │ ├── trans_maf.yaml
│ │ └── transformer.yaml
│ └── tsfm/
│ ├── chronos.yaml
│ ├── forecastpfn.yaml
│ ├── lag_llama.yaml
│ ├── moirai/
│ │ ├── context_5000/
│ │ │ ├── electricity_ltsf.yaml
│ │ │ ├── electricity_nips.yaml
│ │ │ ├── etth1.yaml
│ │ │ ├── etth2.yaml
│ │ │ ├── ettm1.yaml
│ │ │ ├── ettm2.yaml
│ │ │ ├── exchange_rate_nips.yaml
│ │ │ ├── solar_nips.yaml
│ │ │ └── weather_ltsf.yaml
│ │ └── context_96/
│ │ ├── electricity_ltsf.yaml
│ │ ├── electricity_nips.yaml
│ │ ├── etth1.yaml
│ │ ├── etth2.yaml
│ │ ├── ettm1.yaml
│ │ ├── ettm2.yaml
│ │ ├── exchange_rate_nips.yaml
│ │ ├── solar_nips.yaml
│ │ └── weather_ltsf.yaml
│ ├── moirai.yaml
│ ├── time_moe.yaml
│ ├── timer.yaml
│ ├── timesfm.yaml
│ ├── tinytimemixer.yaml
│ └── units.yaml
├── datasets/
│ └── .gitignore
├── docs/
│ ├── benchmark/
│ │ ├── README.md
│ │ ├── foundation_model/
│ │ │ ├── README.md
│ │ │ ├── chronos.md
│ │ │ ├── forecastpfn.md
│ │ │ ├── lag-llama.md
│ │ │ ├── moirai.md
│ │ │ ├── timer.md
│ │ │ ├── timesfm.md
│ │ │ ├── ttm.md
│ │ │ └── units.md
│ │ └── supervised_model/
│ │ └── README.md
│ └── documentation/
│ ├── Gift_eval.md
│ └── README.md
├── exps/
│ └── .gitignore
├── notebook/
│ └── data_characteristics.ipynb
├── probts/
│ ├── __init__.py
│ ├── callbacks/
│ │ ├── __init__.py
│ │ ├── memory_callback.py
│ │ └── time_callback.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── data_manager.py
│ │ ├── data_module.py
│ │ ├── data_utils/
│ │ │ ├── data_scaler.py
│ │ │ ├── data_utils.py
│ │ │ ├── get_datasets.py
│ │ │ └── time_features.py
│ │ ├── data_wrapper.py
│ │ └── datasets/
│ │ ├── gift_eval_datasets.py
│ │ ├── multi_horizon_datasets.py
│ │ └── single_horizon_datasets.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── forecast_module.py
│ │ ├── forecaster/
│ │ │ ├── __init__.py
│ │ │ ├── forecaster.py
│ │ │ ├── point_forecaster/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── autoformer.py
│ │ │ │ ├── dlinear.py
│ │ │ │ ├── elastst.py
│ │ │ │ ├── forecastpfn.py
│ │ │ │ ├── gru.py
│ │ │ │ ├── itransformer.py
│ │ │ │ ├── linear.py
│ │ │ │ ├── mean.py
│ │ │ │ ├── moderntcn.py
│ │ │ │ ├── naive.py
│ │ │ │ ├── nhits.py
│ │ │ │ ├── nlinear.py
│ │ │ │ ├── patchtst.py
│ │ │ │ ├── time_moe.py
│ │ │ │ ├── timer.py
│ │ │ │ ├── timesfm.py
│ │ │ │ ├── timesnet.py
│ │ │ │ ├── tinytimemixer.py
│ │ │ │ ├── transformer.py
│ │ │ │ ├── tsmixer.py
│ │ │ │ └── units.py
│ │ │ └── prob_forecaster/
│ │ │ ├── __init__.py
│ │ │ ├── chronos.py
│ │ │ ├── csdi.py
│ │ │ ├── gru_maf.py
│ │ │ ├── gru_nvp.py
│ │ │ ├── lag_llama.py
│ │ │ ├── moirai.py
│ │ │ ├── timegrad.py
│ │ │ ├── trans_maf.py
│ │ │ └── tsdiff.py
│ │ └── nn/
│ │ ├── __init__.py
│ │ ├── arch/
│ │ │ ├── AutoformerModule/
│ │ │ │ ├── AutoCorrelation.py
│ │ │ │ └── Autoformer_EncDec.py
│ │ │ ├── ChronosModule/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── chronos.py
│ │ │ │ ├── chronos_bolt.py
│ │ │ │ ├── loss.py
│ │ │ │ └── utils.py
│ │ │ ├── Conv_Blocks.py
│ │ │ ├── ElasTSTModule/
│ │ │ │ ├── ElasTST_backbone.py
│ │ │ │ ├── Layers.py
│ │ │ │ ├── Modules.py
│ │ │ │ ├── SubLayers.py
│ │ │ │ ├── TRoPE.py
│ │ │ │ └── __init__.py
│ │ │ ├── ModernTCN_backbone.py
│ │ │ ├── Moirai_backbone.py
│ │ │ ├── PatchTSTModule/
│ │ │ │ ├── PatchTST_backbone.py
│ │ │ │ └── PatchTST_layers.py
│ │ │ ├── RevIN.py
│ │ │ ├── S4/
│ │ │ │ ├── s4.py
│ │ │ │ └── s4_backbones.py
│ │ │ ├── TSMixer_layers.py
│ │ │ ├── TimesFMModule/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── patched_decoder.py
│ │ │ │ ├── pytorch_patched_decoder.py
│ │ │ │ ├── timesfm_base.py
│ │ │ │ ├── timesfm_jax.py
│ │ │ │ ├── timesfm_torch.py
│ │ │ │ └── xreg_lib.py
│ │ │ ├── TransformerModule/
│ │ │ │ ├── Embed.py
│ │ │ │ ├── SelfAttention_Family.py
│ │ │ │ └── Transformer_EncDec.py
│ │ │ ├── __init__.py
│ │ │ └── decomp.py
│ │ └── prob/
│ │ ├── MAF.py
│ │ ├── RealNVP.py
│ │ ├── __init__.py
│ │ ├── diffusion_layers.py
│ │ ├── flow_model.py
│ │ └── gaussian_diffusion.py
│ └── utils/
│ ├── __init__.py
│ ├── download_datasets.py
│ ├── evaluator.py
│ ├── masking.py
│ ├── metrics.py
│ ├── position_emb.py
│ ├── save_utils.py
│ └── utils.py
├── pyproject.toml
├── run.py
├── run.sh
└── scripts/
├── prepare_datasets.sh
├── prepare_tsfm_checkpoints.sh
├── reproduce_ltsf_results.sh
├── reproduce_stsf_results.sh
├── reproduce_tsfm_results.sh
├── run_elastst.sh
└── run_varied_hor_training.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# vscode IDE
.vscode
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
1.sh
log/
.vscode/
*.DS_Store
*.AppleDouble
*.LSOverride
*__MACOSX
# Icon must end with two \r characters
Icon
# Thumbnails / metadata
._*
.Spotlight-V100
.Trashes
.fseventsd
# Volumes / network
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.VolumeIcon.icns
# iCloud placeholders
*.icloud
================================================
FILE: .gitmodules
================================================
[submodule "submodules/uni2ts"]
path = submodules/uni2ts
url = https://github.com/SalesforceAIResearch/uni2ts.git
[submodule "submodules/lag_llama"]
path = submodules/lag_llama
url = https://github.com/time-series-foundation-models/lag-llama.git
[submodule "submodules/timesfm"]
path = submodules/timesfm
url = https://github.com/google-research/timesfm.git
[submodule "submodules/tsfm"]
path = submodules/tsfm
url = https://github.com/ibm-granite/granite-tsfm.git
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Microsoft Open Source Code of Conduct
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
Resources:
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# ProbTS: Benchmarking Point and Distributional Forecasting across Diverse Prediction Horizons
[Paper (arXiv)](https://arxiv.org/abs/2310.07446) | [Benchmarking](./docs/benchmark/README.md) | [Documentation](./docs/documentation/README.md)
## News :tada:
:triangular_flag_on_post: **May 2025**: We have integrated [ModernTCN](https://github.com/luodhhh/ModernTCN/tree/main) into ProbTS. You can find the corresponding configuration file [here](./config/default/moderntcn.yaml).
:triangular_flag_on_post: **Apr 2025**: ProbTS now includes [Time-MoE](https://github.com/Time-MoE/Time-MoE) and offers improved support for foundation models of varying sizes. See [Foundation Models](#foundation-models) for details.
:triangular_flag_on_post: **Dec 2024**: ProbTS now supports [GIFT-EVAL](https://github.com/SalesforceAIResearch/gift-eval?tab=readme-ov-file#installation) benchmark datasets! Visit [this page](./docs/documentation/Gift_eval.md) for detailed instructions. *Please note that this feature is still in beta version and may contain bugs or inconsistencies. We will continue to update and improve it.*
:triangular_flag_on_post: **Dec 2024**: Added quick guides for benchmarking foundation models. Visit [this page](./docs/benchmark/foundation_model/README.md) for detailed instructions.
:triangular_flag_on_post: **Oct 2024**: ProbTS now includes the ElasTST model! Check out the [ElasTST branch](https://github.com/microsoft/ProbTS/tree/elastst) to reproduce all results reported in the paper, or run `bash scripts/run_elastst.sh` for a quick start.
:triangular_flag_on_post: **Oct 2024**: The [camera-ready version](https://arxiv.org/abs/2310.07446) of ProbTS is now available, with more in-depth analyses on the impact of normalization.
## About ProbTS :bulb:
A wide range of industrial applications desire precise point and distributional forecasting for diverse prediction horizons. ProbTS serves as a benchmarking tool to aid in understanding how advanced time-series models fulfill these essential forecasting needs. It also sheds light on their advantages and disadvantages in addressing different challenges and unveils the possibilities for future research.
To achieve these objectives, ProbTS provides a unified pipeline that implements [cutting-edge models](#-available-models) from different research threads, including:
- Supervised long-term point forecasting models, such as [PatchTST](https://arxiv.org/abs/2211.14730), [iTransformer](https://arxiv.org/abs/2310.06625), etc.
- Supervised short-term probabilistic forecasting models, such as [TimeGrad](https://arxiv.org/abs/2101.12072), [CSDI](https://arxiv.org/abs/2107.03502), etc.
- Pre-trained time-series foundation models for zero-shot forecasting, such as [TimesFM](https://arxiv.org/abs/2310.10688), [MOIRAI](https://arxiv.org/abs/2402.02592), etc.
Specifically, ProbTS emphasizes the differences in their primary methodological designs, including:
- Supporting point or distributional forecasts
- Using autoregressive or non-autoregressive decoding schemes for multi-step outputs
## Available Models 🧩
ProbTS includes both classical time-series models, specializing in long-term point forecasting or short-term distributional forecasting, and recent time-series foundation models that offer zero-shot and arbitrary-horizon forecasting capabilities for new time series.
### Classical Time-series Models
| **Model** | **Original Eval. Horizon** | **Estimation** | **Decoding Scheme** | **Class Path** |
| --- | --- | --- | --- | --- |
| Linear | - | Point | AR / NAR | `probts.model.forecaster.point_forecaster.LinearForecaster` |
| [GRU](https://arxiv.org/abs/1412.3555) | - | Point | AR / NAR | `probts.model.forecaster.point_forecaster.GRUForecaster` |
| [Transformer](https://arxiv.org/abs/1706.03762) | - | Point | AR / NAR | `probts.model.forecaster.point_forecaster.TransformerForecaster` |
| [Autoformer](https://arxiv.org/abs/2106.13008) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.Autoformer` |
| [N-HiTS](https://arxiv.org/abs/2201.12886) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.NHiTS` |
| [NLinear](https://arxiv.org/abs/2205.13504) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.NLinear` |
| [DLinear](https://arxiv.org/abs/2205.13504) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.DLinear` |
| [TSMixer](https://arxiv.org/abs/2303.06053) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.TSMixer` |
| [TimesNet](https://arxiv.org/abs/2210.02186) | Short / Long | Point | NAR | `probts.model.forecaster.point_forecaster.TimesNet` |
| [PatchTST](https://arxiv.org/abs/2211.14730) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.PatchTST` |
| [iTransformer](https://arxiv.org/abs/2310.06625) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.iTransformer` |
| [ElasTST](https://arxiv.org/abs/2411.01842) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.ElasTST` |
| [GRU NVP](https://arxiv.org/abs/2002.06103) | Short | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.GRU_NVP` |
| [GRU MAF](https://arxiv.org/abs/2002.06103) | Short | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.GRU_MAF` |
| [Trans MAF](https://arxiv.org/abs/2002.06103) | Short | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.Trans_MAF` |
| [TimeGrad](https://arxiv.org/abs/2101.12072) | Short | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.TimeGrad` |
| [CSDI](https://arxiv.org/abs/2107.03502) | Short | Probabilistic | NAR | `probts.model.forecaster.prob_forecaster.CSDI` |
| [TSDiff](https://arxiv.org/abs/2307.11494) | Short | Probabilistic | NAR | `probts.model.forecaster.prob_forecaster.TSDiffCond` |
### Foundation Models
| **Model** | **Any Horizon** | **Estimation** | **Decoding Scheme** | **Class Path** | **Model Size** |
| --- | --- | --- | --- | --- | --- |
| [Lag-Llama](https://arxiv.org/abs/2310.08278) | ✔ | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.LagLlama` | - |
| [ForecastPFN](https://arxiv.org/abs/2311.01933) | ✔ | Point | NAR | `probts.model.forecaster.point_forecaster.ForecastPFN` | - |
| [TimesFM](https://arxiv.org/abs/2310.10688) | ✔ | Point | AR | `probts.model.forecaster.point_forecaster.TimesFM` | `200m`, `500m` |
| [TTM](https://arxiv.org/abs/2401.03955) | ✘ | Point | NAR | `probts.model.forecaster.point_forecaster.TinyTimeMixer` | - |
| [Timer](https://arxiv.org/abs/2402.02368) | ✔ | Point | AR | `probts.model.forecaster.point_forecaster.Timer` | - |
| [MOIRAI](https://arxiv.org/abs/2402.02592) | ✔ | Probabilistic | NAR | `probts.model.forecaster.prob_forecaster.Moirai` | `small`, `base`, `large` |
| [UniTS](https://arxiv.org/abs/2403.00131) | ✔ | Point | NAR | `probts.model.forecaster.point_forecaster.UniTS` | - |
| [Chronos](https://arxiv.org/abs/2403.07815) | ✔ | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.Chronos` | `tiny`, `mini`, `small`, `base`, `large` |
| [Time-MoE](https://arxiv.org/abs/2409.16040) | ✔ | Point | AR | `probts.model.forecaster.point_forecaster.TimeMoE` | `50M`, `200M` |
See the [tsfm configuration directory](./config/tsfm/) for more details. More models will be added soon—stay tuned!
## Setup :wrench:
### Environment
ProbTS is developed with Python 3.10 and relies on [PyTorch Lightning](https://github.com/Lightning-AI/lightning). To set up the environment:
```bash
# Create a new conda environment
conda create -n probts python=3.10
conda activate probts
# Install required packages
pip install .
pip uninstall -y probts # recommended to uninstall the root package (optional)
```
**Optional: additional setup for reproducing TSFM results**
For time-series foundation models, you need to install basic packages and additional dependencies:
**1. Set Up Environment**
```bash
# Create a new conda environment
conda create -n probts_fm python=3.10
conda activate probts_fm
# Git submodule
git submodule update --init --recursive
# Install additional packages for foundation models
pip install ".[tsfm]"
pip uninstall -y probts # recommended to uninstall the root package (optional)
```
**2. Initialize Submodules**
```bash
# Run the following from the repository root; each step returns to the root afterwards.
# For MOIRAI, we fix the version of the package for better performance
cd submodules/uni2ts
git reset --hard fce6a6f57bc3bc1a57c7feb3abc6c7eb2f264301
cd ../..
# For Lag-Llama, fix the version for reproducibility (optional)
cd submodules/lag_llama
git reset --hard 4ad82d9
cd ../..
# For TinyTimeMixer, fix the version for reproducibility (optional)
cd submodules/tsfm
git reset --hard bb125c14a05e4231636d6b64f8951d5fe96da1dc
cd ../..
```
### Datasets
For a complete dataset list, refer to the [Datasets Overview](./docs/documentation/README.md#datasets-overview).
- **Short-Term Forecasting**: We use datasets from [GluonTS](https://github.com/awslabs/gluonts).
Configure the datasets using `--data.data_manager.init_args.dataset {DATASET_NAME}`. You can choose from multivariate or univariate datasets as per your requirement.
```bash
['exchange_rate_nips', 'electricity_nips', 'traffic_nips', 'solar_nips', 'wiki2000_nips']
```
- **Long-Term Forecasting**: To download the [long-term forecasting datasets](https://drive.google.com/drive/folders/1ZOYpTUa82_jCcxIdTmyr0LXQfvaM9vIy), please follow these steps:
```bash
bash scripts/prepare_datasets.sh "./datasets"
```
Configure the datasets using `--data.data_manager.init_args.dataset {DATASET_NAME}` with the following list of available datasets:
```bash
['etth1', 'etth2','ettm1','ettm2','traffic_ltsf', 'electricity_ltsf', 'exchange_ltsf', 'illness_ltsf', 'weather_ltsf', 'caiso', 'nordpool']
```
*Note: When utilizing long-term forecasting datasets, you must explicitly specify the `context_length` and `prediction_length` parameters. For example, to set a context length of 96 and a prediction length of 192, use the following command-line arguments:*
```bash
--data.data_manager.init_args.context_length 96 \
--data.data_manager.init_args.prediction_length 192 \
```
- **Using Datasets from Monash Time Series Forecasting Repository**: To use datasets from the [Monash Time Series Forecasting Repository](https://forecastingdata.org/), follow these steps:
1. **Download the Dataset**:
- Navigate to the target dataset, such as the [Electricity Hourly Dataset](https://zenodo.org/records/4656140).
- Download the `.tsf` file and place it in your local `datasets` directory (e.g., `./datasets`).
2. **Configure the Dataset**:
- Use the following configuration to specify the dataset, file path, and frequency:
```bash
--data.data_manager.init_args.dataset {DATASET_NAME} \
--data.data_manager.init_args.data_path /path/to/data_file.tsf \
--data.data_manager.init_args.freq {FREQ}
```
- **Example Configuration**:
```bash
--data.data_manager.init_args.dataset monash_electricity_hourly \
--data.data_manager.init_args.data_path ./datasets/electricity_hourly_dataset.tsf \
--data.data_manager.init_args.freq H \
--data.data_manager.init_args.context_length 96 \
--data.data_manager.init_args.prediction_length 96 \
--data.data_manager.init_args.multivariate true
```
*Note 1: Refer to the [Pandas Time Series Offset Aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-offset-aliases) for the correct frequency values (`{FREQ}`) to use in your configuration.*
*Note 2: You can adjust the test instance sampling using the `--data.data_manager.init_args.test_rolling_length` parameter.*
### Checkpoints for Foundation Models
Download the checkpoints with the following command (details can be found [here](./checkpoints/README.md)):
```bash
bash scripts/prepare_tsfm_checkpoints.sh # By downloading, you agree to the original licenses
```
## Quick Start :rocket:
Specify `--config` with a specific configuration file to reproduce results of point or probabilistic models on commonly used long- and short-term forecasting datasets. Configuration files are included in the [config](./config/) folder.
To run models:
```bash
bash run.sh
```
Experimental results reproduction:
- **Long-term Forecasting:**
```bash
bash scripts/reproduce_ltsf_results.sh
```
- **Short-term Forecasting:**
```bash
bash scripts/reproduce_stsf_results.sh
```
- **Time Series Foundation Models:**
```bash
bash scripts/reproduce_tsfm_results.sh
```
### Short-term Forecasting Configuration
For short-term forecasting scenarios, datasets and corresponding `context_length` and `prediction_length` are automatically obtained from [GluonTS](https://github.com/awslabs/gluonts). Use the following command:
```bash
python run.py --config config/path/to/model.yaml \
--data.data_manager.init_args.path /path/to/datasets/ \
--trainer.default_root_dir /path/to/log_dir/ \
--data.data_manager.init_args.dataset {DATASET_NAME}
```
See full `DATASET_NAME` list:
```python
from gluonts.dataset.repository import dataset_names
print(dataset_names)
```
### Long-term Forecasting Configuration
For long-term forecasting scenarios, `context_length` and `prediction_length` must be explicitly assigned:
```bash
python run.py --config config/path/to/model.yaml \
--data.data_manager.init_args.path /path/to/datasets/ \
--trainer.default_root_dir /path/to/log_dir/ \
--data.data_manager.init_args.dataset {DATASET_NAME} \
--data.data_manager.init_args.context_length {CTX_LEN} \
--data.data_manager.init_args.prediction_length {PRED_LEN}
```
`DATASET_NAME` options:
```bash
['etth1', 'etth2','ettm1','ettm2','traffic_ltsf', 'electricity_ltsf', 'exchange_ltsf', 'illness_ltsf', 'weather_ltsf', 'caiso', 'nordpool']
```
### Forecasting with Varied Prediction Lengths
Conventional forecasting models typically require specific training and deployment for each prediction horizon. However, with the growing importance of varied-horizon forecasting, there is a need for models that can deliver robust predictions across multiple inference horizons after a single training phase.
ProbTS has been updated to support varied-horizon forecasting by enabling the specification of distinct context and prediction lengths for the training, validation, and testing phases.
**Quick Start**
To quickly train and evaluate ElasTST:
```bash
bash scripts/run_elastst.sh
```
To quickly set up varied-horizon training:
```bash
bash scripts/run_varied_hor_training.sh
```
For detailed information on the configuration, refer to the [documentation](./docs/documentation/README.md#forecasting-with-varied-prediction-lengths).
*Note: Currently, this feature is only supported by ElasTST, Autoformer, and foundation models.*
## Benchmarking :balance_scale:
By utilizing ProbTS, we conduct a systematic comparison between studies that focus on point forecasting and those aimed at distributional estimation, employing various forecasting horizons and evaluation metrics. For more details, see:
- [Short-term & Long-term Forecasting Benchmarking](./docs/benchmark/README.md)
- [Evaluating Time Series Foundation Models](./docs/benchmark/foundation_model/README.md)
## Documentation :open_book:
For detailed information on configuration parameters and model customization, please refer to the [documentation](./docs/documentation/README.md).
- To print the full pipeline configuration to a file:
```bash
python run.py --print_config > config/pipeline_config.yaml
```
## Acknowledgement 🌟
Special thanks to the following repositories for their open-sourced code bases and datasets.
### Tools/Packages
- [GluonTS](https://github.com/awslabs/gluonts)
- [PyTorch-TS](https://github.com/zalandoresearch/pytorch-ts)
- [TSLib](https://github.com/libts/tslib)
- [NeuralForecast](https://github.com/Nixtla/neuralforecast)
### Official Implementations
**Classical Time-series Models**
- [Autoformer](https://github.com/thuml/Autoformer)
- [N-HiTS](https://github.com/cchallu/n-hits)
- [NLinear, DLinear](https://github.com/cure-lab/LTSF-Linear)
- [TimesNet](https://github.com/thuml/Time-Series-Library)
- [RevIN](https://github.com/ts-kim/RevIN)
- [PatchTST](https://github.com/yuqinie98/PatchTST)
- [iTransformer](https://github.com/thuml/iTransformer)
- [GRU NVP, GRU MAF, Trans MAF, TimeGrad](https://github.com/zalandoresearch/pytorch-ts/tree/master)
- [CSDI](https://github.com/ermongroup/CSDI)
- [TSDiff](https://github.com/amazon-science/unconditional-time-series-diffusion)
**Time-series Foundation Models**
- [MOIRAI](https://github.com/SalesforceAIResearch/uni2ts)
- [Chronos](https://github.com/amazon-science/chronos-forecasting)
- [Lag-Llama](https://github.com/time-series-foundation-models/lag-llama)
- [TimesFM](https://github.com/google-research/timesfm)
- [Timer](https://github.com/thuml/Large-Time-Series-Model)
- [UniTS](https://github.com/mims-harvard/UniTS)
- [ForecastPFN](https://github.com/abacusai/ForecastPFN)
- [TTM](https://github.com/ibm-granite/granite-tsfm)
## Citing ProbTS :beers:
If you have used ProbTS for research or production, please cite it as follows.
```tex
@inproceedings{zhang2024probts,
title={{ProbTS}: Benchmarking Point and Distributional Forecasting across Diverse Prediction Horizons},
author={Zhang, Jiawen and Wen, Xumeng and Zhang, Zhenwei and Zheng, Shun and Li, Jia and Bian, Jiang},
booktitle={NeurIPS Datasets and Benchmarks Track},
year={2024}
}
```
================================================
FILE: SECURITY.md
================================================
## Security
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
## Reporting Security Issues
**Please do not report security vulnerabilities through public GitHub issues.**
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue
This information will help us triage your report more quickly.
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
## Preferred Languages
We prefer all communications to be in English.
## Policy
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
================================================
FILE: checkpoints/README.md
================================================
# Checkpoints for Foundation Models
For full reproducibility, we provide the checkpoints for some foundation models as of the paper completion date.
Download the checkpoints from [Google Drive](https://drive.google.com/drive/folders/1FaCk9Lj9KZGEO09gehNqC4fbTj4wnN8j?usp=sharing) with:
```bash
# By downloading, you agree to the terms of the original license agreements.
sh scripts/prepare_checkpoints.sh # in root directory
```
You can also download the newest checkpoints from the following repositories:
- For `Timer`, download the checkpoint from its [official repository](https://github.com/thuml/Large-Time-Series-Model?tab=readme-ov-file#code-for-fine-tuning) ([Google Drive](https://drive.google.com/drive/folders/15oaiAl4OO5gFqZMJD2lOtX2fxHbpgcU8) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/235e6bfcf5fa440bb119/)) and save it as `./checkpoints/timer/Timer_67M_UTSD_4G.pt`.
- For `ForecastPFN`, download the checkpoints from its [official repository](https://github.com/abacusai/ForecastPFN#installation-) ([Google Drive](https://drive.google.com/file/d/1acp5thS7I4g_6Gw40wNFGnU1Sx14z0cU/view)) under the folder `./checkpoints/ForecastPFN/saved_weights`.
- For `UniTS`, download the checkpoint `units_x128_pretrain_checkpoint.pth` from its [official repository](https://github.com/mims-harvard/UniTS/releases/tag/ckpt) and save it as `./checkpoints/units/units_x128_pretrain_checkpoint.pth`.
- For `Lag-Llama`, download the checkpoint `lag-llama.ckpt` from its [Hugging Face repository](https://huggingface.co/time-series-foundation-models/Lag-Llama/tree/main) and save it as `./checkpoints/lag-llama/lag-llama.ckpt`.
- Checkpoints for the other models are downloaded automatically from Hugging Face during the first run.
| **Model** | **HuggingFace** |
| --- | --- |
| `MOIRAI` | [Link](https://huggingface.co/Salesforce/moirai-1.0-R-small) |
| `Chronos` | [Link](https://huggingface.co/amazon/chronos-t5-large) |
| `TinyTimeMixer` | [Link](https://huggingface.co/ibm-granite/granite-timeseries-ttm-v1) |
| `TimesFM` | [Link](https://huggingface.co/google/timesfm-1.0-200m) |
================================================
FILE: config/default/autoformer.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
# num_sanity_val_steps: 0
# gradient_clip_algorithm: 'norm'
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.Autoformer
init_args:
moving_avg: 25
factor: 1
n_heads: 8
activation: 'gelu'
e_layers: 2
d_layers: 1
output_attention: false
d_ff: 512
f_hidden_size: 512
embed: 'timeF'
use_lags: false
use_feat_idx_emb: false
use_time_feat: true
feat_idx_emb_dim: 1
num_samples: 1
learning_rate: 1e-3
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/default/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 128
emb_feature_dim: 16
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 128
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 4
test_batch_size: 4
num_workers: 8
================================================
FILE: config/default/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 3
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
learning_rate: 0.01
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/default/gru.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.GRUForecaster
init_args:
f_hidden_size: 40
num_layers: 2
dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/gru_maf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_MAF
init_args:
enc_num_layers: 2
enc_hidden_size: 40
enc_dropout: 0.1
n_blocks: 4
hidden_size: 100
n_hidden: 2
batch_norm: false
conditional_length: 200
dequantize: true
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
scaler: identity # identity, standard, temporal
split_val: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 7
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 40
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 4
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 200
dequantize: true
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/itransformer.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.iTransformer
init_args:
factor: 1
n_heads: 8
activation: 'gelu'
e_layers: 2
output_attention: false
f_hidden_size: 256
d_ff: 256
label_len: 48
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 1
learning_rate: 1e-4
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/default/linear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 30
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.LinearForecaster
init_args:
individual: false
use_lags: true
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/mean.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 40
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.MeanForecaster
init_args:
mode: global
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/moderntcn.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.ModernTCN
init_args:
ffn_ratio: 1
patch_size: 8
patch_stride: 4
num_blocks: [1]
large_size: [51]
dims: [64, 64, 64, 64]
dropout: 0.3
kernel_size: 3
small_size: [5]
use_multi_scale: false
small_kernel_merged: false
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth1
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/default/naive.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 40
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.NaiveForecaster
learning_rate: 0.001
quantiles_num: 10
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/nhits.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.NHiTS
init_args:
n_blocks: [1,1,1]
hidden_size: 512
pooling_mode: 'max'
interpolation_mode: 'linear'
activation: 'ReLU'
initialization: 'lecun_normal'
batch_normalization: false
shared_weights: false
naive_level:
dropout: 0
n_layers: 2
use_lags: false
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/nlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.NLinear
init_args:
individual: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.01
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 3
patch_len: 6
dropout: 0.1
f_hidden_size: 32
n_layers: 3
n_heads: 8
fc_dropout: 0.2
head_dropout: 0
individual: false
optimizer_config:
class_name: torch.optim.Adam
init_args:
weight_decay: 0
lr_scheduler_config:
class_name: torch.optim.lr_scheduler.OneCycleLR
init_args:
max_lr: 0.0001
steps_per_epoch: 100
pct_start: 0.3
epochs: 50
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 100
enc_hidden_size: 128
enc_num_layers: 4
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
scaler: identity # identity, standard, temporal
split_val: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/timesnet.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TimesNet
init_args:
n_layers: 2
num_kernels: 6
top_k: 5
d_ff: 32
dropout: 0.1
f_hidden_size: 40
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/trans_maf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Trans_MAF
init_args:
enc_hidden_size: 32
enc_num_heads: 8
enc_num_encoder_layers: 2
enc_num_decoder_layers: 2
enc_dim_feedforward_scale: 4
enc_dropout: 0.1
enc_activation: gelu
n_blocks: 4
hidden_size: 100
n_hidden: 2
batch_norm: false
conditional_length: 200
dequantize: true
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
scaler: identity # identity, standard, temporal
split_val: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/transformer.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TransformerForecaster
init_args:
f_hidden_size: 16
num_heads: 4
num_encoder_layers: 3
num_decoder_layers: 3
dim_feedforward_scale: 4
dropout: 0.1
activation: gelu
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/default/tsdiff.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
check_val_every_n_epoch: 1
default_root_dir: ./results
accumulate_grad_batches: 1
gradient_clip_val: 0.5
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TSDiffCond
init_args:
timesteps: 100
hidden_dim: 64
step_emb: 128
num_residual_blocks: 3
dropout: 0.0
mode: diag # diag, nplr
measure: diag # 'diag', 'diag-lin', 'diag-inv', or 'diag-legs' for diag
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
use_scaling: false
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: temporal # identity, standard, temporal
context_length: 336
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/default/tsmixer.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 40
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TSMixer
init_args:
num_blocks: 6
dropout_rate: 0.7
ff_dim: 64
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth1
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/ltsf/electricity_ltsf/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 3
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 64
emb_feature_dim: 8
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 64
beta_start: 0.001
beta_end: 0.5
sample_size: 16
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_ltsf
scaler: standard # identity, standard, temporal
split_val: true
batch_size: 4
test_batch_size: 8
num_workers: 8
================================================
FILE: config/ltsf/electricity_ltsf/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 200
log_every_n_steps: 1
accumulate_grad_batches: 2
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinearEncoder
init_args:
individual: true
kernel_size: 25
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_ltsf
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/electricity_ltsf/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 128
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 2
hidden_size: 64
n_hidden: 2
batch_norm: false
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_ltsf
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/electricity_ltsf/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 8
patch_len: 16
dropout: 0.2
f_hidden_size: 128
n_layers: 3
n_heads: 16
fc_dropout: 0.2
head_dropout: 0
individual: false
num_samples: 100
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_ltsf
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 8
test_batch_size: 8
num_workers: 8
================================================
FILE: config/ltsf/electricity_ltsf/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 200
enc_hidden_size: 128
enc_num_layers: 3
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_ltsf
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/etth1/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 128
emb_feature_dim: 16
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 128
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth1
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 8
test_batch_size: 8
num_workers: 8
================================================
FILE: config/ltsf/etth1/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
accumulate_grad_batches: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: true
kernel_size: 25
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.005
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth1
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/etth1/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 64
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 4
hidden_size: 64
n_hidden: 3
batch_norm: false
conditional_length: 100
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth1
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/ltsf/etth1/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 8
patch_len: 16
dropout: 0.3
f_hidden_size: 16
n_layers: 3
n_heads: 4
fc_dropout: 0.2
head_dropout: 0
individual: true
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth1
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/etth1/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 200
enc_hidden_size: 128
enc_num_layers: 3
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth1
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/etth2/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 128
emb_feature_dim: 16
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 128
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth2
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 8
test_batch_size: 8
num_workers: 8
================================================
FILE: config/ltsf/etth2/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
accumulate_grad_batches: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 25
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.05
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth2
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/etth2/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 64
enc_num_layers: 4
enc_dropout: 0.1
n_blocks: 2
hidden_size: 128
n_hidden: 3
batch_norm: true
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth2
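path: /home/covpreduser/Blob/v-jiawezhang/data/all_datasets/ # NOTE(review): machine-specific absolute path; the other etth2 configs omit `path` - point this at your local dataset directory or confirm it can be dropped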
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/etth2/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 8
patch_len: 16
dropout: 0.3
f_hidden_size: 16
d_ff: 128
n_layers: 3
n_heads: 4
fc_dropout: 0.2
head_dropout: 0
individual: false
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth2
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/etth2/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 100
enc_hidden_size: 64
enc_num_layers: 4
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth2
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/ettm1/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 128
emb_feature_dim: 16
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 128
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm1
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 8
test_batch_size: 8
num_workers: 8
================================================
FILE: config/ltsf/ettm1/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
accumulate_grad_batches: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: true
kernel_size: 25
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm1
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/ettm1/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 64
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 4
hidden_size: 64
n_hidden: 3
batch_norm: false
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm1
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/ettm1/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 8
patch_len: 16
dropout: 0.2
f_hidden_size: 128
n_layers: 3
n_heads: 16
fc_dropout: 0.2
head_dropout: 0
individual: true
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm1
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/ettm1/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 200
enc_hidden_size: 128
enc_num_layers: 3
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm1
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/ettm2/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 128
emb_feature_dim: 16
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 128
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm2
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 8
test_batch_size: 8
num_workers: 8
================================================
FILE: config/ltsf/ettm2/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
accumulate_grad_batches: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 25
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm2
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/ettm2/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 64
enc_num_layers: 4
enc_dropout: 0.1
n_blocks: 2
hidden_size: 128
n_hidden: 3
batch_norm: false
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm2
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/ettm2/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 8
patch_len: 16
dropout: 0.2
f_hidden_size: 128
n_layers: 3
n_heads: 16
fc_dropout: 0.2
head_dropout: 0
individual: true
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm2
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/ettm2/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 200
enc_hidden_size: 64
enc_num_layers: 2
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm2
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/exchange_ltsf/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 128
emb_feature_dim: 16
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 128
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_ltsf
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 8
test_batch_size: 8
num_workers: 8
================================================
FILE: config/ltsf/exchange_ltsf/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
accumulate_grad_batches: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: true
kernel_size: 25
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.0005
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_ltsf
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/exchange_ltsf/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 128
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 2
hidden_size: 128
n_hidden: 3
batch_norm: false
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_ltsf
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/exchange_ltsf/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 8
patch_len: 16
dropout: 0.2
f_hidden_size: 16
n_layers: 3
n_heads: 4
fc_dropout: 0.2
head_dropout: 0
individual: true
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_ltsf
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/exchange_ltsf/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 200
enc_hidden_size: 64
enc_num_layers: 4
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_ltsf
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/illness_ltsf/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 128
emb_feature_dim: 16
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 128
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: illness_ltsf
split_val: true
scaler: standard # identity, standard, temporal
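context_length: 96 # NOTE(review): the other illness_ltsf configs (dlinear, gru_nvp, patchtst, timegrad) use 36 - confirm 96 is intended
prediction_length: 96 # NOTE(review): the other illness_ltsf configs (dlinear, gru_nvp, patchtst, timegrad) use 36 - confirm 96 is intended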
batch_size: 8
test_batch_size: 8
num_workers: 8
================================================
FILE: config/ltsf/illness_ltsf/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
accumulate_grad_batches: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 25
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.01
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: illness_ltsf
split_val: true
scaler: standard # identity, standard, temporal
context_length: 36
prediction_length: 36
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/illness_ltsf/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 64
enc_num_layers: 4
enc_dropout: 0.1
n_blocks: 4
hidden_size: 128
n_hidden: 2
batch_norm: false
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: illness_ltsf
split_val: true
scaler: identity # identity, standard, temporal
context_length: 36
prediction_length: 36
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/illness_ltsf/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 2
patch_len: 24
dropout: 0.3
f_hidden_size: 16
n_layers: 3
n_heads: 4
fc_dropout: 0.3
head_dropout: 0
individual: true
learning_rate: 0.0025
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: illness_ltsf
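path: /home/covpreduser/Blob/v-jiawezhang/data/all_datasets/ # NOTE(review): machine-specific absolute path; the other illness_ltsf configs omit `path` - point this at your local dataset directory or confirm it can be dropped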
split_val: true
scaler: standard # identity, standard, temporal
context_length: 36
prediction_length: 36
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/illness_ltsf/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 200
enc_hidden_size: 64
enc_num_layers: 2
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: illness_ltsf
split_val: true
scaler: identity # identity, standard, temporal
context_length: 36
prediction_length: 36
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/traffic_ltsf/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 3
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 64
emb_feature_dim: 8
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 64
beta_start: 0.001
beta_end: 0.5
sample_size: 16
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_ltsf
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 4
test_batch_size: 4
num_workers: 8
================================================
FILE: config/ltsf/traffic_ltsf/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
accumulate_grad_batches: 4
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 25
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.05
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_ltsf
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 8
test_batch_size: 8
num_workers: 8
================================================
FILE: config/ltsf/traffic_ltsf/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 128
enc_num_layers: 3
enc_dropout: 0.1
n_blocks: 4
hidden_size: 128
n_hidden: 3
batch_norm: true
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_ltsf
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/traffic_ltsf/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 300
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 3
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 8
patch_len: 16
dropout: 0.2
f_hidden_size: 128
n_layers: 3
n_heads: 16
fc_dropout: 0.2
head_dropout: 0
individual: false
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_ltsf
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 8
test_batch_size: 8
num_workers: 8
================================================
FILE: config/ltsf/traffic_ltsf/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 200
enc_hidden_size: 128
enc_num_layers: 3
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_ltsf
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/weather_ltsf/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 128
emb_feature_dim: 16
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 128
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: weather_ltsf
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 8
test_batch_size: 8
num_workers: 8
================================================
FILE: config/ltsf/weather_ltsf/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
accumulate_grad_batches: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 25
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: weather_ltsf
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/weather_ltsf/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 64
enc_num_layers: 4
enc_dropout: 0.1
n_blocks: 4
hidden_size: 128
n_hidden: 3
batch_norm: false
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: weather_ltsf
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/ltsf/weather_ltsf/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 8
patch_len: 16
dropout: 0.2
f_hidden_size: 128
n_layers: 3
n_heads: 16
fc_dropout: 0.2
head_dropout: 0
individual: false
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: weather_ltsf
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/ltsf/weather_ltsf/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 200
enc_hidden_size: 64
enc_num_layers: 4
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: weather_ltsf
split_val: true
scaler: identity # identity, standard, temporal
context_length: 96
prediction_length: 96
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/m4/m4_daily/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 32
emb_feature_dim: 4
channels: 16
n_layers: 4
num_heads: 4
num_steps: 50
diffusion_embedding_dim: 32
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m4_daily
context_length_factor: 3
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/m4_daily/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 3
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m4_daily
context_length_factor: 3
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/m4_daily/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 40
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 2
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 100
dequantize: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m4_daily
context_length_factor: 3
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/m4_daily/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 2
patch_len: 6
dropout: 0.3
f_hidden_size: 32
d_ff: 128
n_layers: 3
n_heads: 8
fc_dropout: 0.2
head_dropout: 0
individual: true
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m4_daily
context_length_factor: 3
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 1
test_batch_size: 128
num_workers: 8
================================================
FILE: config/m4/m4_daily/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 50
beta_end: 0.1
beta_schedule: linear
conditional_length: 100
enc_hidden_size: 64
enc_num_layers: 4
enc_dropout: 0.1
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m4_daily
context_length_factor: 3
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/m4_weekly/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 32
emb_feature_dim: 4
channels: 16
n_layers: 4
num_heads: 4
num_steps: 50
diffusion_embedding_dim: 32
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m4_weekly
context_length_factor: 3
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/m4_weekly/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 3
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m4_weekly
context_length_factor: 3
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/m4_weekly/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 40
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 2
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 100
dequantize: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m4_weekly
context_length_factor: 3
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/m4_weekly/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 3
patch_len: 6
dropout: 0.3
f_hidden_size: 32
d_ff: 128
n_layers: 3
n_heads: 8
fc_dropout: 0.2
head_dropout: 0
individual: true
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m4_weekly
context_length_factor: 3
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 1
test_batch_size: 128
num_workers: 8
================================================
FILE: config/m4/m4_weekly/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 50
beta_end: 0.1
beta_schedule: linear
conditional_length: 100
enc_hidden_size: 64
enc_num_layers: 4
enc_dropout: 0.1
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m4_weekly
context_length_factor: 3
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/m5/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 32
emb_feature_dim: 4
channels: 16
n_layers: 4
num_heads: 4
num_steps: 50
diffusion_embedding_dim: 32
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m5
context_length_factor: 3
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/m5/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 3
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m5
context_length_factor: 3
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 1
test_batch_size: 256
num_workers: 8
================================================
FILE: config/m4/m5/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 40
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 2
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 100
dequantize: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m5
context_length_factor: 3
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/m5/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 2
patch_len: 4
dropout: 0.3
f_hidden_size: 64
d_ff: 128
n_layers: 3
n_heads: 8
fc_dropout: 0.2
head_dropout: 0
individual: true
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m5
context_length_factor: 3
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 1
test_batch_size: 128
num_workers: 8
================================================
FILE: config/m4/m5/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 30
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 50
beta_end: 0.1
beta_schedule: linear
conditional_length: 100
enc_hidden_size: 64
enc_num_layers: 4
enc_dropout: 0.1
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: m5
context_length_factor: 3
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 1
test_batch_size: 512
num_workers: 8
================================================
FILE: config/m4/tourism_monthly/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 32
emb_feature_dim: 4
channels: 16
n_layers: 4
num_heads: 4
num_steps: 50
diffusion_embedding_dim: 32
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: tourism_monthly
context_length_factor: 3
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/tourism_monthly/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 3
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: tourism_monthly
context_length_factor: 3
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/tourism_monthly/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 40
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 2
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 100
dequantize: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: tourism_monthly
context_length_factor: 3
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/m4/tourism_monthly/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 2
patch_len: 6
dropout: 0.3
f_hidden_size: 64
d_ff: 128
n_layers: 3
n_heads: 8
fc_dropout: 0.2
head_dropout: 0
individual: true
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: tourism_monthly
context_length_factor: 3
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 1
test_batch_size: 128
num_workers: 8
================================================
FILE: config/m4/tourism_monthly/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 2
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 50
beta_end: 0.1
beta_schedule: linear
conditional_length: 100
enc_hidden_size: 64
enc_num_layers: 4
enc_dropout: 0.1
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: tourism_monthly
context_length_factor: 3
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/multi_hor/autoformer.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
# Optional trainer overrides, currently disabled (uncomment to enable):
# num_sanity_val_steps: 0
# gradient_clip_algorithm: 'norm'
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.Autoformer
init_args:
moving_avg: 25
factor: 1
n_heads: 8
activation: 'gelu'
e_layers: 2
d_layers: 1
output_attention: false
d_ff: 512
f_hidden_size: 512
embed: 'timeF'
use_lags: false
use_feat_idx_emb: false
use_time_feat: true
feat_idx_emb_dim: 1
num_samples: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth1
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 24-96-192-336-720-1024
train_ctx_len: 96
train_pred_len_list: 720
val_ctx_len: 96
val_pred_len_list: 720
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/multi_hor/elastst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.ElasTST
init_args:
l_patch_size: '8_16_32'
dropout: 0.0
f_hidden_size: 256
d_inner: 256
t_layers: 2
v_layers: 0
n_heads: 8
d_v: 64
d_k: 64
structured_mask: true
rotate: true
rope_theta_init: 'exp'
learnable_rope: true
min_period: 1
max_period: 1000
addv: false
bin_att: false
learn_tem_emb: false
learning_rate: 0.001
quantiles_num: 20
sampling_weight_scheme: random
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth1
split_val: true
scaler: standard # identity, standard, temporal
context_length: 96
prediction_length: 24-96-192-336-720-1024
train_ctx_len: 96
train_pred_len_list: 720
val_ctx_len: 96
val_pred_len_list: 720
continuous_sample: false
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/pipeline_config.yaml
================================================
# lightning.pytorch==2.3.0.dev0
seed_everything: true
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: null
logger: null
callbacks: null
fast_dev_run: false
max_epochs: null
min_epochs: null
max_steps: -1
min_steps: null
max_time: null
limit_train_batches: null
limit_val_batches: null
limit_test_batches: null
limit_predict_batches: null
overfit_batches: 0.0
val_check_interval: null
check_val_every_n_epoch: 1
num_sanity_val_steps: null
log_every_n_steps: null
enable_checkpointing: null
enable_progress_bar: null
enable_model_summary: null
accumulate_grad_batches: 1
gradient_clip_val: null
gradient_clip_algorithm: null
deterministic: null
benchmark: null
inference_mode: true
use_distributed_sampler: true
profiler: null
detect_anomaly: false
barebones: false
plugins: null
sync_batchnorm: false
reload_dataloaders_every_n_epochs: 0
default_root_dir: null
model:
forecaster: null
num_samples: 100
learning_rate: 0.001
quantiles_num: 10
load_from_ckpt: null
data:
data_manager: null
batch_size: 64
test_batch_size: 8
num_workers: 8
================================================
FILE: config/stsf/electricity/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 128
emb_feature_dim: 16
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 128
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 4
test_batch_size: 4
num_workers: 8
================================================
FILE: config/stsf/electricity/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: true
kernel_size: 3
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.01
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/stsf/electricity/gru.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.GRUForecaster
init_args:
f_hidden_size: 40
num_layers: 2
dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/electricity/gru_maf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_MAF
init_args:
enc_num_layers: 2
enc_hidden_size: 40
enc_dropout: 0.1
n_blocks: 4
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_nips
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/electricity/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 40
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 3
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_nips
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/electricity/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 2
patch_len: 4
dropout: 0.1
f_hidden_size: 64
n_layers: 4
n_heads: 8
fc_dropout: 0.1
head_dropout: 0
individual: true
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/electricity/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 100
enc_hidden_size: 128
enc_num_layers: 4
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/electricity/timesnet.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TimesNet
init_args:
n_layers: 2
num_kernels: 6
top_k: 5
d_ff: 64
dropout: 0.1
f_hidden_size: 64
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/electricity/trans_maf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Trans_MAF
init_args:
enc_hidden_size: 32
enc_num_heads: 8
enc_num_encoder_layers: 2
enc_num_decoder_layers: 2
enc_dim_feedforward_scale: 4
enc_dropout: 0.1
enc_activation: gelu
n_blocks: 4
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/electricity/transformer.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TransformerForecaster
init_args:
f_hidden_size: 32
num_heads: 8
num_encoder_layers: 3
num_decoder_layers: 3
dim_feedforward_scale: 4
dropout: 0.1
activation: gelu
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/exchange/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 128
emb_feature_dim: 16
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 128
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 4
test_batch_size: 4
num_workers: 8
================================================
FILE: config/stsf/exchange/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 3
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.01
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/stsf/exchange/gru.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.GRUForecaster
init_args:
f_hidden_size: 40
num_layers: 2
dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/exchange/gru_maf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_MAF
init_args:
enc_num_layers: 2
enc_hidden_size: 40
enc_dropout: 0.1
n_blocks: 4
hidden_size: 100
n_hidden: 2
batch_norm: false
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/exchange/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 40
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 4
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/exchange/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 3
patch_len: 6
dropout: 0.1
f_hidden_size: 32
n_layers: 3
n_heads: 8
fc_dropout: 0.2
head_dropout: 0
individual: true
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/exchange/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 100
enc_hidden_size: 128
enc_num_layers: 4
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/exchange/timesnet.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TimesNet
init_args:
n_layers: 2
num_kernels: 6
top_k: 5
d_ff: 64
dropout: 0.1
f_hidden_size: 64
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/exchange/trans_maf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Trans_MAF
init_args:
enc_hidden_size: 16
enc_num_heads: 8
enc_num_encoder_layers: 2
enc_num_decoder_layers: 2
enc_dim_feedforward_scale: 4
enc_dropout: 0.1
enc_activation: gelu
n_blocks: 4
hidden_size: 100
n_hidden: 2
batch_norm: false
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/exchange/transformer.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TransformerForecaster
init_args:
f_hidden_size: 32
num_heads: 8
num_encoder_layers: 3
num_decoder_layers: 3
dim_feedforward_scale: 4
dropout: 0.1
activation: gelu
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/solar/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 800
log_every_n_steps: 1
check_val_every_n_epoch: 2
default_root_dir: ./results
accumulate_grad_batches: 8
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 128
emb_feature_dim: 16
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 128
beta_start: 0.001
beta_end: 0.5
sample_size: 64
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 4
test_batch_size: 4
num_workers: 8
================================================
FILE: config/stsf/solar/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 3
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
learning_rate: 0.01
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/stsf/solar/gru.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.GRUForecaster
init_args:
f_hidden_size: 40
num_layers: 2
dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/solar/gru_maf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_MAF
init_args:
enc_num_layers: 2
enc_hidden_size: 40
enc_dropout: 0.1
n_blocks: 4
hidden_size: 100
n_hidden: 2
batch_norm: false
conditional_length: 200
dequantize: true
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/solar/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 40
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 4
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 200
dequantize: true
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/solar/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 3
patch_len: 6
dropout: 0.1
f_hidden_size: 32
n_layers: 3
n_heads: 8
fc_dropout: 0.2
head_dropout: 0
individual: true
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/solar/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 100
enc_hidden_size: 128
enc_num_layers: 4
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/solar/timesnet.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TimesNet
init_args:
n_layers: 2
num_kernels: 6
top_k: 5
d_ff: 16
dropout: 0.1
f_hidden_size: 16
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/solar/trans_maf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Trans_MAF
init_args:
enc_hidden_size: 32
enc_num_heads: 8
enc_num_encoder_layers: 2
enc_num_decoder_layers: 2
enc_dim_feedforward_scale: 4
enc_dropout: 0.1
enc_activation: gelu
n_blocks: 4
hidden_size: 100
n_hidden: 2
batch_norm: false
conditional_length: 200
dequantize: true
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/solar/transformer.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TransformerForecaster
init_args:
f_hidden_size: 16
num_heads: 4
num_encoder_layers: 3
num_decoder_layers: 3
dim_feedforward_scale: 4
dropout: 0.1
activation: gelu
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/traffic/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
check_val_every_n_epoch: 3
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 64
emb_feature_dim: 8
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 64
beta_start: 0.001
beta_end: 0.5
sample_size: 16
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 8
test_batch_size: 8
num_workers: 8
================================================
FILE: config/stsf/traffic/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 3
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/stsf/traffic/gru.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.GRUForecaster
init_args:
f_hidden_size: 128
num_layers: 2
dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/stsf/traffic/gru_maf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_MAF
init_args:
enc_num_layers: 2
enc_hidden_size: 128
enc_dropout: 0.3
n_blocks: 3
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/stsf/traffic/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 128
enc_num_layers: 2
enc_dropout: 0.3
n_blocks: 4
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/stsf/traffic/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 1
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 3
patch_len: 6
dropout: 0.1
f_hidden_size: 32
n_layers: 3
n_heads: 8
fc_dropout: 0.2
head_dropout: 0
individual: false
num_samples: 100
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/traffic/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 100
enc_hidden_size: 128
enc_num_layers: 4
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/stsf/traffic/timesnet.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TimesNet
init_args:
n_layers: 2
num_kernels: 6
top_k: 5
d_ff: 16
dropout: 0.1
f_hidden_size: 16
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/traffic/trans_maf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Trans_MAF
init_args:
enc_hidden_size: 128
enc_num_heads: 4
enc_num_encoder_layers: 2
enc_num_decoder_layers: 2
enc_dim_feedforward_scale: 4
enc_dropout: 0.1
enc_activation: gelu
n_blocks: 3
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 200
dequantize: false
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_nips
split_val: true
scaler: identity # options: identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/stsf/traffic/transformer.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TransformerForecaster
init_args:
f_hidden_size: 32
num_heads: 8
num_encoder_layers: 3
num_decoder_layers: 3
dim_feedforward_scale: 4
dropout: 0.1
activation: gelu
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: traffic_nips
split_val: true
scaler: standard # options: identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/stsf/wiki/csdi.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
check_val_every_n_epoch: 3
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.CSDI
init_args:
emb_time_dim: 64
emb_feature_dim: 8
channels: 64
n_layers: 4
num_heads: 8
num_steps: 50
diffusion_embedding_dim: 64
beta_start: 0.001
beta_end: 0.5
sample_size: 16
linear_trans: false
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
feat_idx_emb_dim: 1
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: wiki2000_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 8
test_batch_size: 8
num_workers: 8
================================================
FILE: config/stsf/wiki/dlinear.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.DLinear
init_args:
individual: false
kernel_size: 3
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: wiki2000_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
================================================
FILE: config/stsf/wiki/gru.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.GRUForecaster
init_args:
f_hidden_size: 40
num_layers: 2
dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: wiki2000_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/wiki/gru_maf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_MAF
init_args:
enc_num_layers: 2
enc_hidden_size: 40
enc_dropout: 0.1
n_blocks: 3
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 200
dequantize: true
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: wiki2000_nips
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/wiki/gru_nvp.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.GRU_NVP
init_args:
enc_hidden_size: 40
enc_num_layers: 2
enc_dropout: 0.1
n_blocks: 3
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 200
dequantize: true
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: wiki2000_nips
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/wiki/patchtst.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 400
log_every_n_steps: 1
default_root_dir: ./results
accumulate_grad_batches: 4
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
stride: 4
patch_len: 8
dropout: 0.1
f_hidden_size: 32
n_layers: 2
n_heads: 8
fc_dropout: 0.2
head_dropout: 0
individual: false
num_samples: 100
learning_rate: 0.0001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: wiki2000_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/stsf/wiki/timegrad.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.TimeGrad
init_args:
loss_type: l2
diff_steps: 100
beta_end: 0.1
beta_schedule: linear
conditional_length: 100
enc_hidden_size: 128
enc_num_layers: 4
enc_dropout: 0.1
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: wiki2000_nips
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/wiki/timesnet.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TimesNet
init_args:
n_layers: 2
num_kernels: 6
top_k: 5
d_ff: 32
dropout: 0.1
f_hidden_size: 32
use_lags: false
use_feat_idx_emb: false
use_time_feat: false
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: wiki2000_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/wiki/trans_maf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Trans_MAF
init_args:
enc_hidden_size: 128
enc_num_heads: 4
enc_num_encoder_layers: 2
enc_num_decoder_layers: 2
enc_dim_feedforward_scale: 4
enc_dropout: 0.1
enc_activation: gelu
n_blocks: 3
hidden_size: 100
n_hidden: 2
batch_norm: true
conditional_length: 200
dequantize: true
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
use_scaling: true
num_samples: 100
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: wiki2000_nips
split_val: true
scaler: identity # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/stsf/wiki/transformer.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 1
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TransformerForecaster
init_args:
f_hidden_size: 32
num_heads: 8
num_encoder_layers: 3
num_decoder_layers: 3
dim_feedforward_scale: 4
dropout: 0.1
activation: gelu
use_lags: true
use_feat_idx_emb: true
use_time_feat: true
feat_idx_emb_dim: 1
learning_rate: 0.001
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: wiki2000_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/chronos.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 40
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Chronos
init_args:
model_size: base # tiny, mini, small, base, large
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 16
test_batch_size: 16
num_workers: 8
================================================
FILE: config/tsfm/forecastpfn.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 40
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.ForecastPFN
init_args:
label_len: 48
ckpt_path: ./checkpoints/ForecastPFN/saved_weights
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
timeenc: 2
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/lag_llama.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 40
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.LagLlama
init_args:
use_rope_scaling: true
ckpt_path: ./checkpoints/lag-llama/lag-llama.ckpt
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # identity, standard, temporal
timeenc: 2
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_5000/electricity_ltsf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: S
patch_size: 128
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_ltsf
split_val: true
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 5000
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_5000/electricity_nips.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: S
patch_size: 64
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_nips
split_val: true
scaler: standard # identity, standard, temporal
var_specific_norm: true
context_length: 3800 # maximum history length
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_5000/etth1.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: M
patch_size: 64
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth1
split_val: true
scaler: standard # identity, standard, temporal
context_length: 5000
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_5000/etth2.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: M
patch_size: 64
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth2
split_val: true
scaler: standard # identity, standard, temporal
context_length: 5000
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_5000/ettm1.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: S
patch_size: 64
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm1
split_val: true
scaler: standard # identity, standard, temporal
context_length: 5000
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_5000/ettm2.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: M
patch_size: 128
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm2
split_val: true
scaler: standard # identity, standard, temporal
context_length: 5000
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_5000/exchange_rate_nips.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: M
patch_size: 128
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 5000
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_5000/solar_nips.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: S
patch_size: auto
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # identity, standard, temporal
var_specific_norm: false
context_length: 5000
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_5000/weather_ltsf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: M
patch_size: 128
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: weather_ltsf
split_val: true
scaler: standard # identity, standard, temporal
var_specific_norm: true
context_length: 5000
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_96/electricity_ltsf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: S
patch_size: auto
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_ltsf
split_val: true
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 96
auto_search: true
batch_size: 4
test_batch_size: 4
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_96/electricity_nips.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: S
patch_size: 64
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: electricity_nips
split_val: true
scaler: standard # identity, standard, temporal
var_specific_norm: true
context_length: 96
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_96/etth1.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: M
patch_size: auto
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth1
split_val: true
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 96
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_96/etth2.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: M
patch_size: auto
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: etth2
split_val: true
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 96
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_96/ettm1.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: M
patch_size: auto
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm1
split_val: true
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 96
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_96/ettm2.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: M
patch_size: auto
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: ettm2
split_val: true
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 96
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_96/exchange_rate_nips.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: M
patch_size: auto
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: exchange_rate_nips
split_val: true
scaler: standard # identity, standard, temporal
var_specific_norm: true
context_length: 96
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_96/solar_nips.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: S
patch_size: auto
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # identity, standard, temporal
var_specific_norm: false
context_length: 96
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
================================================
FILE: config/tsfm/moirai/context_96/weather_ltsf.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 1
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: M
patch_size: auto
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: weather_ltsf
split_val: true
scaler: standard # identity, standard, temporal
var_specific_norm: true
context_length: 96
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/moirai.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 40
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.prob_forecaster.Moirai
init_args:
variate_mode: S
patch_size: auto
model_size: base
scaling: true
num_samples: 100
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # identity, standard, temporal
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/time_moe.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 40
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TimeMoE
init_args:
model_size: 200M # select from ['50M', '200M']
instance_norm: true
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # identity, standard, temporal
var_specific_norm: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/timer.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 40
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.Timer
init_args:
label_len: 96
ckpt_path: ./checkpoints/timer/Timer_67M_UTSD_4G.pt
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/timesfm.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 40
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TimesFM
init_args:
model_size: 200m # select from ['200m', '500m']
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: identity # identity, standard, temporal
var_specific_norm: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/tinytimemixer.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 40
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.TinyTimeMixer
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: config/tsfm/units.yaml
================================================
# lightning==2.3.0.dev0
seed_everything: 0
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 40
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.UniTS
init_args:
ckpt_path: ./checkpoints/units/units_x128_pretrain_checkpoint.pth
quantiles_num: 20
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips
split_val: true
scaler: standard # identity, standard, temporal
# var_specific_norm: true
batch_size: 64
test_batch_size: 64
num_workers: 8
================================================
FILE: datasets/.gitignore
================================================
*
!.gitignore
================================================
FILE: docs/benchmark/README.md
================================================
# Benchmarking :balance_scale:
Accurate point and distributional forecasts across diverse horizons are crucial for time-series forecasting. However, existing research often focuses on isolated aspects, such as long-term point forecasting or short-term probabilistic estimation. This raises a fundamental question: **How do different methodological designs address these diverse forecasting needs?**
In this repository, we:
1. **Provide Detailed Reproduction Guides:** Offer comprehensive instructions for replicating supervised models and pre-trained foundation models.
2. **Evaluate Methods Under a Unified Framework:** Align and assess existing methods across various data scenarios using a consistent benchmarking framework.
3. **Deliver In-Depth Insights:** Present detailed analyses and insights into the experimental results.
## Benchmarking Scripts
- [Supervised Forecasting Models](./supervised_model/README.md)
- [Pre-trained Time-Series Foundation Models](./foundation_model/README.md)
## Methodology Overview

================================================
FILE: docs/benchmark/foundation_model/README.md
================================================
# Time Series Foundation Models Benchmarking
- [Time Series Foundation Models Benchmarking](#time-series-foundation-models-benchmarking)
- [Foundation Models](#foundation-models)
- [Overview](#overview)
- [Results Reproduction](#results-reproduction)
- [Key Insights \& Takeaways](#key-insights--takeaways)
- [Experimental Results](#experimental-results)
- [Comparison Across Horizons](#comparison-across-horizons)
- [Short-term Probabilistic Forecasting](#short-term-probabilistic-forecasting)
## Foundation Models
### Overview
| Model | Backbone | Dec. | Varied Hor. | Dist. Head | Var. | Hyper-param in Inference | Running Guides |
| --- | --- | --- | --- | --- | --- | --- | --- |
| [Lag-Llama](https://github.com/time-series-foundation-models/lag-llama) | Dec-only Trans. | AR | √ | Student's t | Uni | `context len`, `pred len`, `use_rope_scaling` | [Details](./lag-llama.md) |
| [Chronos](https://github.com/amazon-science/chronos-forecasting) | Enc-Dec Trans. | AR | √ | Arbitrary | Uni | `context len`, `pred len`, `num_samples`, `temperature`, `top_k`, `top_p` | [Details](./chronos.md) |
| [TimesFM](https://github.com/google-research/timesfm) | Dec-only Trans. | AR | √ | - | Uni | `context len`, `frequency`, `window size` | [Details](./timesfm.md) |
| [Timer](https://github.com/thuml/Large-Time-Series-Model) | Dec-only Trans. | AR | √ | - | Uni | `context len`, `pred len`, `use_ims` | [Details](./timer.md) |
| [MOIRAI](https://github.com/SalesforceAIResearch/uni2ts) | Enc-only Trans. | NAR | √ | Mixture dist. | Multi | `context len`, `pred len`, `patch size`, `variate_mode` | [Details](./moirai.md) |
| [ForecastPFN](https://github.com/abacusai/ForecastPFN) | Enc-only Trans. | NAR | √ | - | Uni | `context len`, `pred len` | [Details](./forecastpfn.md) |
| [UniTS](https://github.com/mims-harvard/UniTS) | Enc-only Trans. | NAR | √ | - | Multi | `context len`, `pred len` | [Details](./units.md) |
| [Tiny Time Mixers](https://github.com/ibm-granite/granite-tsfm/tree/main/tsfm_public/models/tinytimemixer) | TSMixer | NAR | x | - | Multi | `context len`, `pred len` | [Details](./ttm.md) |
### Results Reproduction
For time-series foundation models, you need to install basic packages and additional dependencies:
**1. Set Up Environment**
```bash
# Create a new conda environment
conda create -n probts_fm python=3.10
conda activate probts_fm
# Git submodule
git submodule update --init --recursive
# Install additional packages for foundation models
pip install ".[tsfm]"
pip uninstall -y probts # recommended to uninstall the root package (optional)
```
**2. Initialize Submodules**
To run MOIRAI, TimesFM, Lag-Llama, or TinyTimeMixer, please run the following commands to pin the corresponding submodules.
```bash
# Run the following from the repository root.
# For MOIRAI, we fix the version of the package for better performance
cd submodules/uni2ts
git reset --hard fce6a6f57bc3bc1a57c7feb3abc6c7eb2f264301
cd ../..
# For TimesFM, fix the version for reproducibility (optional)
cd submodules/timesfm
git reset --hard 5c7b905
cd ../..
# For Lag-Llama, fix the version for reproducibility (optional)
cd submodules/lag_llama
git reset --hard 4ad82d9
cd ../..
# For TinyTimeMixer, fix the version for reproducibility (optional)
cd submodules/tsfm
git reset --hard bb125c14a05e4231636d6b64f8951d5fe96da1dc
cd ../..
```
**3. Download Model Checkpoints**
Download the necessary checkpoints (more details are available [here](../../../checkpoints/README.md)):
```bash
bash scripts/prepare_tsfm_checkpoints.sh
```
Note: By downloading, you agree to the original license terms.
**4. Run Benchmarking:**
Reproduce the results reported in the ProbTS paper:
```bash
bash scripts/reproduce_tsfm_results.sh
```
Configuration files are in [config/tsfm/](../../../config/tsfm/).
**5. Experimental Results Analysis (Coming Soon)** :construction:
Analysis notebooks will be added in a future update.
## Key Insights & Takeaways
**1. Similar Insights in Evaluating Supervised Models Reconfirmed**
- Handling **Varied Forecasting Horizons:** Current AR-based time-series foundation models also encounter error accumulation problems.
- Addressing **Complex Data Distributions:** Predefined distribution heads lack the capability to fully capture complex data distributions.
**2. Supervised Time-Series Models vs. Pre-trained Foundation Models**
- There is no definitive winner yet!

**Takeaways:**
- In practice, you may need to choose the right paradigm based on specific cases:
- Unique data patterns → supervised models
- Scarce training data → pre-trained models, etc.
## Experimental Results
### Comparison Across Horizons

Figure. We use a dashed line to denote the datasets on which the model was pre-trained, e.g., both TimesFM and MOIRAI have leveraged Traffic datasets for their pre-training. The ETT encompasses averaged results from datasets ETTh1, ETTh2, ETTm1, and ETTm2.
Table 3. NMAE of time-series foundation models on diverse prediction horizons. The input sequence length is set to 96 if not specified. For every model, we exclude the evaluation results on its pre-trained datasets.

### Short-term Probabilistic Forecasting
Table 4. Results of probabilistic foundation models on short-term distributional forecasting. For every model, we exclude the evaluation results on its pre-trained datasets.

================================================
FILE: docs/benchmark/foundation_model/chronos.md
================================================
# Running Inference with Chronos
[Original Repository](https://github.com/amazon-science/chronos-forecasting) | [Paper](https://arxiv.org/abs/2403.07815)
Follow these steps to set up and run inference using Chronos:
1. Set up the [environment](./README.md#results-reproduction).
2. Run the inference script with the following commands:
```bash
MODEL='chronos'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do
for CTX_LEN in 5000 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--data.test_batch_size 1
done
done
done
```
## Hyper-param in Inference
`temperature` (default: 1): Controls the randomness of sampling. With `temperature=0` the output is deterministic; larger values produce more diverse samples.
`top_k` (default: 50): Applies softmax only over the top-k logits.
`top_p` (default: 1): Nucleus sampling. The model sums the probabilities of the most likely next values in descending order and stops when the sum reaches p.
================================================
FILE: docs/benchmark/foundation_model/forecastpfn.md
================================================
# Running Inference with ForecastPFN
[Original Repository](https://github.com/abacusai/ForecastPFN) | [Paper](https://arxiv.org/abs/2311.01933)
Follow these steps to set up and run inference using ForecastPFN:
1. Set up the [environment](./README.md#results-reproduction).
2. Run the inference script with the following commands:
```bash
# ForecastPFN
MODEL='forecastpfn'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do
for CTX_LEN in 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--model.forecaster.init_args.ckpt_path './checkpoints/ForecastPFN/saved_weights' \
--data.test_batch_size 64
done
done
done
```
================================================
FILE: docs/benchmark/foundation_model/lag-llama.md
================================================
# Running Inference with Lag-Llama
[Original Repository](https://github.com/time-series-foundation-models/lag-llama) | [Paper](https://arxiv.org/abs/2310.08278)
Follow these steps to set up and run inference using Lag-Llama:
1. Set up the [environment and initialize submodules](./README.md#results-reproduction).
2. Run the inference script with the following commands:
```bash
# Lag-Llama
MODEL='lag_llama'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do
for CTX_LEN in 512; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--model.forecaster.init_args.ckpt_path './checkpoints/lag-llama/lag-llama.ckpt' \
--data.test_batch_size 1
done
done
done
```
================================================
FILE: docs/benchmark/foundation_model/moirai.md
================================================
# Running Inference with MOIRAI
[Original Repository](https://github.com/SalesforceAIResearch/uni2ts) | [Paper](https://arxiv.org/abs/2402.02592)
Follow these steps to set up and run inference using MOIRAI:
1. Set up the [environment and initialize submodules](./README.md#results-reproduction).
2. Run the inference script with the following commands:
```bash
MODEL='moirai'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf' 'electricity_ltsf'; do
for CTX_LEN in 5000 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}/context_${CTX_LEN}/${DATASET}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.prediction_length ${PRED_LEN}
done
done
done
```
## Hyper-param in Inference
`patch size` (default: `auto`): Specifies the patch size used during inference. When set to `auto`, the model selects the patch size that minimizes validation loss based on historical data.
`variate_mode` (default: `S`): Determines whether the model operates in univariate (`S`) or multivariate mode (`M`) during inference.
================================================
FILE: docs/benchmark/foundation_model/timer.md
================================================
# Running Inference with Timer
[Original Repository](https://github.com/thuml/Large-Time-Series-Model) | [Paper](https://arxiv.org/abs/2402.02368)
Follow these steps to set up and run inference using Timer:
1. Set up the [environment](./README.md#results-reproduction).
2. Run the inference script with the following commands:
```bash
MODEL='timer'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf' 'electricity_ltsf'; do
for CTX_LEN in 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--model.forecaster.init_args.ckpt_path './checkpoints/timer/Timer_67M_UTSD_4G.pt' \
--data.test_batch_size 64
done
done
done
```
## Hyper-param in Inference
`use_ims` (default: false): Whether to evaluate in the Iterative Multi-step (IMS) manner, as used for decoder-only models, instead of the Direct Multi-step (DMS) approach used for encoder-only forecasters.
`sub_rand_ratio`: The ratio of training samples in few-shot scenarios.
================================================
FILE: docs/benchmark/foundation_model/timesfm.md
================================================
# Running Inference with TimesFM
[Original Repository](https://github.com/google-research/timesfm) | [Paper](https://arxiv.org/abs/2310.10688)
Follow these steps to set up and run inference using TimesFM:
1. Set up the [environment](./README.md#results-reproduction).
2. Run the inference script with the following commands:
```bash
MODEL='timesfm'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2'; do
for CTX_LEN in 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--data.test_batch_size 64
done
done
done
```
## Hyper-param in Inference
`frequency` (default: 0): Choose from {0, 1, 2}.
- **0 (default):** High frequency, long horizon time series. We recommend using this for time series up to daily granularity.
- **1:** Medium frequency time series. We recommend using this for weekly and monthly data.
- **2:** Low frequency, short horizon time series. We recommend using this for anything beyond monthly, e.g., quarterly or yearly.
`window size` (default: None): Window size of trend + residual decomposition
================================================
FILE: docs/benchmark/foundation_model/ttm.md
================================================
# Running Inference with Tiny Time Mixers
[Original Repository](https://github.com/ibm-granite/granite-tsfm/tree/main/tsfm_public/models/tinytimemixer) | [Paper](https://arxiv.org/abs/2401.03955)
Follow these steps to set up and run inference using Tiny Time Mixers:
1. Set up the [environment and initialize submodules](./README.md#results-reproduction).
2. Run the inference script with the following commands:
```bash
MODEL='tinytimemixer'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do
for CTX_LEN in 5000 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--data.test_batch_size 1
done
done
done
```
================================================
FILE: docs/benchmark/foundation_model/units.md
================================================
# Running Inference with UniTS
[Original Repository](https://github.com/mims-harvard/UniTS) | [Paper](https://arxiv.org/pdf/2403.00131)
Follow these steps to set up and run inference using UniTS:
1. Set up the [environment](./README.md#results-reproduction).
2. Run the inference script with the following commands:
```bash
MODEL='units'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2'; do
for CTX_LEN in 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--model.forecaster.init_args.ckpt_path './checkpoints/units/units_x128_pretrain_checkpoint.pth' \
--data.test_batch_size 64
done
done
done
```
================================================
FILE: docs/benchmark/supervised_model/README.md
================================================
# Supervised Forecasting Models Benchmarking
- [Supervised Forecasting Models Benchmarking](#supervised-forecasting-models-benchmarking)
- [Experimental Results Reproduction](#experimental-results-reproduction)
- [Key Insights \& Takeaways](#key-insights--takeaways)
- [Point vs. Probabilistic Estimation](#point-vs-probabilistic-estimation)
- [Autoregressive vs. Non-autoregressive Decoding Scheme](#autoregressive-vs-non-autoregressive-decoding-scheme)
- [Instance-level Normalization Choice](#instance-level-normalization-choice)
- [Experimental Result Details](#experimental-result-details)
## Experimental Results Reproduction
Reproduce the experimental results using the provided scripts:
- **Long-Term Forecasting:**
```bash
bash scripts/reproduce_ltsf_results.sh
```
Configuration files: [config/ltsf/](../../../config/ltsf/).
- **Short-Term Forecasting:**
```bash
bash scripts/reproduce_stsf_results.sh
```
Configuration files: [config/stsf/](../../../config/stsf/).
## Key Insights & Takeaways
### Point vs. Probabilistic Estimation
**Insights**
- Current supervised long-term point forecasting models (e.g., DLinear, PatchTST, iTransformer) **struggle with intricate data distributions**.
- Current supervised short-term probabilistic forecasting models (e.g., GRU NVP, TimeGrad, CSDI) **face challenges in extended forecasting horizons**.

**Takeaways**
- It is important to consider both long-term and short-term evaluation scenarios.
- Leverage both point and distributional metrics for more comprehensive insights.
### Autoregressive vs. Non-autoregressive Decoding Scheme
**Insights**
- Current Supervised Non-Autoregressive (NAR) Models (e.g., PatchTST, iTransformer, CSDI)
- Primarily developed for long-term forecasting scenarios.
- **Suboptimal for short-term forecasting, and some models are memory-intensive.**
- Current Supervised Autoregressive (AR) Models (e.g., GRU, GRU NVP, TimeGrad)
- Primarily developed for short-term forecasting scenarios
- **Perform well with strong seasonality but struggle with long-term, strong trends**

**Takeaways**
- It is crucial to select the right **methodological design** based on the specific **data characteristics**.
- There are tremendous **re-design opportunities**, given the **comprehensive forecasting needs**.
### Instance-level Normalization Choice
**Insights**
- Reversible Instance Normalization (RevIN): Essential for Long-term Forecasting Scenarios
- Our observation: **AR models in the literature are scarce for long-term forecasting**
- Our finding: RevIN + AR => **A simple yet highly effective baseline that has been overlooked**
- Normalization Choices under Short-term Forecasting Scenarios
- **No dominating normalization strategies**

**Takeaways**
- The **co-design** of **normalization** techniques and **model** architectures warrants further research attention.
- The **challenges and opportunities** in time-series normalization persist in balancing short-term and long-term forecasting needs.
## Experimental Result Details
**Long-Term Forecasting Benchmarking**
Table 1. Results ($\textrm{mean}_{\textrm{std}}$) on long-term forecasting scenarios with the best in $\textbf{bold}$ and the second $\underline{\textrm{underlined}}$, each containing five independent runs with different seeds. The input sequence length is set to 36 for the ILI-L dataset and 96 for the others. Due to the excessive time and memory consumption of CSDI in producing long-term forecasts, its results are unavailable in some datasets.

**Short-Term Forecasting Benchmarking**
Table 2. Results ($\textrm{mean}_{\textrm{std}}$) on short-term forecasting scenarios with the best in $\textbf{bold}$ and the second $\underline{\textrm{underlined}}$, each containing five independent runs with different seeds.

================================================
FILE: docs/documentation/Gift_eval.md
================================================
## How to evaluate the models in ProbTS using the GIFT-EVAL benchmark
Link to the GIFT-EVAL benchmark: [Github Repo](https://github.com/SalesforceAIResearch/gift-eval) [Paper](https://openreview.net/forum?id=9EBSEkFSje)
1. Follow installation instructions in the GIFT-EVAL repository to **download the dataset** from its huggingface dataset repository.
2. Also, set the environment variable `GIFT_EVAL` to the path where the dataset is downloaded.
``` bash
echo "GIFT_EVAL=/path/to/gift-eval" >> .env
```
3. Quick start example:
``` bash
python run.py --config config/default/mean.yaml \
--seed_everything 0 \
--model.forecaster.init_args.mode batch \
--data.data_manager.init_args.dataset gift/ett1/H/long \
--data.data_manager.init_args.path ./datasets \
--trainer.default_root_dir ./exps
```
> [!NOTE]
> The dataset name for the GIFT-EVAL format should be specified as follows: `"gift/" + "dataset_name (main_name/freq)" + "/" + "short/medium/long"`. For example, `gift/ett1/H/long`. More dataset names can be found in the GIFT-EVAL repository (for example [naive.ipynb](https://github.com/SalesforceAIResearch/gift-eval/blob/main/notebooks/naive.ipynb)).
================================================
FILE: docs/documentation/README.md
================================================
# Documentation :open_book:
- [Documentation :open\_book:](#documentation-open_book)
- [Setup](#setup)
- [Configuration Parameters](#configuration-parameters)
- [Trainer](#trainer)
- [Model](#model)
- [Data](#data)
- [Datasets](#datasets)
- [Datasets Overview](#datasets-overview)
- [Short-Term Setting](#short-term-setting)
- [Long-Term Setting](#long-term-setting)
- [Data Processing Pipeline](#data-processing-pipeline)
- [Using Build-in Datasets](#using-build-in-datasets)
- [Using Customized Dataset](#using-customized-dataset)
- [Model](#model-1)
- [Available Models](#available-models)
- [Using Customized Model](#using-customized-model)
- [Training](#training)
- [Configuring Optimizers and Learning Rate Schedulers](#configuring-optimizers-and-learning-rate-schedulers)
- [Forecasting with Varied Prediction Lengths](#forecasting-with-varied-prediction-lengths)
- [Example 1: Varied-Horizon Training](#example-1-varied-horizon-training)
- [Example 2: Validation and Testing with Multiple Horizons](#example-2-validation-and-testing-with-multiple-horizons)
## Setup
ProbTS is developed with Python 3.10 and relies on [PyTorch Lightning](https://github.com/Lightning-AI/lightning). To set up the environment:
```bash
# Create a new conda environment
conda create -n probts python=3.10
conda activate probts
# Install required packages
pip install .
pip uninstall -y probts # recommended to uninstall the root package (optional)
```
[Optional] For time-series foundation models, you need to install basic packages and additional dependencies:
```bash
# Create a new conda environment
conda create -n probts_fm python=3.10
conda activate probts_fm
# Git submodule
git submodule update --init --recursive
# Install additional packages for foundation models
pip install ".[tsfm]"
pip uninstall -y probts # recommended to uninstall the root package (optional)
# For MOIRAI, we fix the version of the package for better performance
cd submodules/uni2ts
git reset --hard fce6a6f57bc3bc1a57c7feb3abc6c7eb2f264301
cd ../..
```
[Optional] Pin the remaining submodules to fixed commits for TSFM reproducibility:
```bash
# Run the following from the repository root.
# For TimesFM, fix the version for reproducibility (optional)
cd submodules/timesfm
git reset --hard 5c7b905
cd ../..
# For Lag-Llama, fix the version for reproducibility (optional)
cd submodules/lag_llama
git reset --hard 4ad82d9
cd ../..
# For TinyTimeMixer, fix the version for reproducibility (optional)
cd submodules/tsfm
git reset --hard bb125c14a05e4231636d6b64f8951d5fe96da1dc
cd ../..
```
## Configuration Parameters
- To print the full pipeline configuration to a file:
```bash
python run.py --print_config > config/pipeline_config.yaml
```
### Trainer
| Config Name | Type | Description |
| --- | --- | --- |
| `trainer.max_epochs` | `int` | Maximum number of training epochs. |
| `trainer.limit_train_batches` | `int` | Limits the number of training batches per epoch. |
| `trainer.check_val_every_n_epoch` | `int` | Perform validation every n training epochs. |
| `trainer.default_root_dir` | `str` | Default path for logs and weights. |
| `trainer.accumulate_grad_batches` | `int` | Number of batches to accumulate gradients before updating. |
### Model
| Config Name | Type | Description |
| --- | --- | --- |
| `model.forecaster.class_path` | `str` | Forecaster module path (e.g., `probts.model.forecaster.point_forecaster.PatchTST`). |
| `model.forecaster.init_args.{ARG}` | - | Model-specific hyperparameters. |
| `model.num_samples` | `int` | Number of samples per distribution during evaluation. |
| `model.learning_rate` | `float` | Learning rate. |
| `model.quantiles_num` | `int` | Number of quantiles for evaluation. |
| `model.sampling_weight_scheme` | `str` | The scheme of training horizon reweighting. Options: ['random', 'none', 'const'].|
| `model.optimizer_config.class_name` | `str` | optimizer module (e.g., `torch.optim.Adam`). |
| `model.optimizer_config.init_args.{ARG}` | - | optimizer hyperparameters. |
| `model.scheduler_config.class_name` | `str` | lr_scheduler module (e.g., `torch.optim.lr_scheduler.OneCycleLR`). |
| `model.scheduler_config.init_args.{ARG}` | - | lr_scheduler hyperparameters. |
### Data
| Config Name | Type | Description |
| --- | --- | --- |
| `data.data_manager.init_args.dataset` | `str` | Dataset for training and evaluation. |
| `data.data_manager.init_args.path` | `str` | Path to the dataset folder. |
| `data.data_manager.init_args.split_val` | `bool` | Whether to split a validation set during training. |
| `data.data_manager.init_args.scaler` | `str` | Scaler type: `identity`, `standard` (z-score normalization), or `temporal` (scale based on average temporal absolute value). |
| `data.data_manager.init_args.target_dim` | `int` | The number of variates. |
| `data.data_manager.init_args.var_specific_norm` | `bool` | Whether to apply per-variate normalization. |
| `data.data_manager.init_args.timeenc` | `int` | Time feature type. Select from `[0,1,2]`. See the explanation below for details. |
| `data.data_manager.init_args.context_length` | `Union[str, int, list]` | Length of observation window in inference phase. |
| `data.data_manager.init_args.prediction_length` | `Union[str, int, list]` | Forecasting horizon length in inference phase. |
| `data.data_manager.init_args.val_pred_len_list` | `Union[str, int, list]` | Forecasting horizon length for performance validation. |
| `data.data_manager.init_args.val_ctx_len` | `Union[str, int, list]` | Length of observation window for performance validation. |
| `data.data_manager.init_args.train_pred_len_list`| `Union[str, int, list]` | Forecasting horizons in training phase. |
| `data.data_manager.init_args.train_ctx_len` | `Union[str, int, list]` | Length of observation window in training phase. |
| `data.data_manager.init_args.continuous_sample` | `bool` | If True, sampling horizons from `[min(train_pred_len_list), max(train_pred_len_list)]`, else sampling within the set `train_pred_len_list`.|
| `data.data_manager.init_args.test_rolling_length` | `Union[str, int]` | Defines the gap window for rolling evaluations during testing. Defaults to `96` if not explicitly specified. If set to `auto`, the value is determined based on the dataset frequency: `{'h': 24, 'd': 7, 'b': 5, 'w': 4, 'min': 60}`. |
| `data.data_manager.init_args.train_ratio` | `float` | Specifies proportion of the dataset used for training. Default value is 0.7.|
| `data.data_manager.init_args.test_ratio` | `float` | Specifies proportion of the dataset used for testing. Default value is 0.2.|
| `data.batch_size` | `int` | Batch size. |
**Temporal Features**
For datasets used in the long-term forecasting scenario, we support three types of time feature encoding:
```bash
--data.data_manager.init_args.timeenc {the encoding type} # select from [0,1,2]
```
- **[timeenc 0] temporal information**
The dimension of time feature is 5, containing `month, day, weekday, hour, minute`.
- **[timeenc 1] time feature based on frequency**
Extract time feature using `time_features_from_frequency_str()` function. The dimensionality follows:
```bash
freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
```
*Note: timeenc = 0 if model.embed != 'timeF' else 1.*
- **[timeenc 2] Raw date information**
The dimension of time feature is 5, using the following code to recover it to date data type:
```bash
data_stamp = batch_data.past_time_feat.cpu().numpy().astype('datetime64[s]')
data_stamp = batch_data.future_time_feat.cpu().numpy().astype('datetime64[s]')
```
## Datasets
### Datasets Overview
#### Short-Term Setting
| Dataset | DATASET_NAME | Domain | Frequency | #Var | time steps | Description |
| --- | --- | --- | --- | --- | --- | --- |
| Exchange | `exchange_rate_nips` | Finance | Busi. Day | 8 | 6,071 | Daily exchange rates of 8 countries |
| Solar | `solar_nips` | Energy | H | 137 | 7,009 | Solar power production records |
| Electricity | `electricity_nips` | Energy | H | 370 | 5,833 | Electricity consumption |
| Traffic | `traffic_nips` | Transport | H | 963 | 4,001 | Road occupancy rates |
| Wikipedia | `wiki2000_nips` | Web | D | 2,000 | 792 | Page views of 2000 Wikipedia pages |
#### Long-Term Setting
| Dataset | DATASET_NAME | Domain | Frequency | #Var | time steps | Description |
| --- | --- | --- | --- | --- | --- | --- |
| ETTh | `etth1` / `etth2` | Energy | H | 7 | 17,420 | Electricity transformer temperature per hour |
| ETTm | `ettm1` / `ettm2` | Energy | 15min | 7 | 69,680 | Electricity transformer temperature every 15 min |
| Electricity | `electricity_ltsf` | Energy | H | 321 | 26,304 | Electricity consumption (Kwh) |
| Weather | `weather_ltsf` | Climate | 10min | 21 | 52,696 | Local climatological data |
| Traffic | `traffic_ltsf` | Transport | H | 862 | 17,544 | Road occupancy rates |
| Exchange | `exchange_ltsf` | Finance | Busi. Day | 8 | 7,588 | Daily exchange rates of 8 countries |
| ILI | `illness_ltsf` | Epidemiology | W | 7 | 966 | Ratio of patients seen with influenza-like illness |
| Caiso | `caiso` | Energy | H | 10 | 74,472 | Electricity load series in different zones of California |
| Nordpool | `nordpool` | Energy | H | 18 | 70,128 | Energy production volume in European countries |
| Turkey Power | `turkey_power` | Energy | H | 18 | 26,304 | Electrical power demand in Turkey |
| Istanbul Traffic | `istanbul_traffic` | Transport | H | 3 | 14,244 | Traffic Index data for Istanbul traffic |
### Data Processing Pipeline
### Using Build-in Datasets
- **Short-Term Forecasting**: We use datasets from [GluonTS](https://github.com/awslabs/gluonts).
Configure the datasets using `--data.data_manager.init_args.dataset {DATASET_NAME}` with available `DATASET_NAME` in [short-term setting](#short-term-setting).
- **Long-Term Forecasting**: To download the [long-term forecasting datasets](https://drive.google.com/drive/folders/1ZOYpTUa82_jCcxIdTmyr0LXQfvaM9vIy), please follow these steps:
```bash
bash scripts/prepare_datasets.sh "./datasets"
```
Configure the datasets using `--data.data_manager.init_args.dataset {DATASET_NAME}` with available `DATASET_NAME` in [long-term setting](#long-term-setting).
*Note: When utilizing long-term forecasting datasets, you must explicitly specify the `context_length` and `prediction_length` parameters. For example, to set a context length of 96 and a prediction length of 192, use the following command-line arguments:*
```bash
--data.data_manager.init_args.context_length 96 \
--data.data_manager.init_args.prediction_length 192 \
```
- **Using Datasets from Monash Time Series Forecasting Repository**: To use datasets from the [Monash Time Series Forecasting Repository](https://forecastingdata.org/), follow these steps:
1. **Download the Dataset**:
- Navigate to the target dataset, such as the [Electricity Hourly Dataset](https://zenodo.org/records/4656140).
- Download the `.tsf` file and place it in your local `datasets` directory (e.g., `./datasets`).
1. **Configure the Dataset**:
- Use the following configuration to specify the dataset, file path, and frequency:
```bash
--data.data_manager.init_args.dataset {DATASET_NAME} \
--data.data_manager.init_args.data_path /path/to/data_file.tsf \
--data.data_manager.init_args.freq {FREQ}
```
- **Example Configuration**:
```bash
--data.data_manager.init_args.dataset monash_electricity_hourly \
--data.data_manager.init_args.data_path ./datasets/electricity_hourly_dataset.tsf \
--data.data_manager.init_args.freq H \
--data.data_manager.init_args.context_length 96 \
--data.data_manager.init_args.prediction_length 96 \
--data.data_manager.init_args.multivariate true
```
*Note: Refer to the [Pandas Time Series Offset Aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-offset-aliases) for the correct frequency values (`{FREQ}`) to use in your configuration.*
- **Using Datasets from GIFT-EVAL Benchmarking**: see [this page](./Gift_eval.md) for detailed instructions.
### Using Customized Dataset
1. **Prepare the Data**:
- Format your dataset as a `.csv` file with the following structure:
| date | VAR1 | VAR2 | ... |
|---------------------|--------|--------|-----|
| 2013-01-01 00:00:00 | 2611.0 | 1539.0 | ... |
| 2013-01-01 01:00:00 | 2132.0 | 1535.0 | ... |
Note1: The date column represents timestamps.
Note2: VAR1, VAR2, etc., represent different variables (features) for each timestamp.
- Place the csv file in your local `datasets` directory (e.g., `./datasets`).
1. **Configure the Dataset**:
- Use the following configuration to specify the dataset, file path, and frequency:
```bash
--data.data_manager.init_args.dataset {DATASET_NAME} \
--data.data_manager.init_args.data_path /path/to/data.csv \
--data.data_manager.init_args.freq {FREQ}
```
- **Example Configuration**:
```bash
--data.data_manager.init_args.dataset my_data \
--data.data_manager.init_args.data_path ./datasets/my_data.csv \
--data.data_manager.init_args.freq H \
--data.data_manager.init_args.context_length 96 \
--data.data_manager.init_args.prediction_length 96 \
--data.data_manager.init_args.multivariate true
```
*Note: You can adjust the test instance sampling using the `--data.data_manager.init_args.test_rolling_length` parameter.*
## Model
### Available Models
ProbTS includes both classical time-series models, specializing in long-term point forecasting or short-term distributional forecasting, and recent time-series foundation models that offer zero-shot and arbitrary-horizon forecasting capabilities for new time series.
**Classical Time-series Models**
| **Model** | **Original Eval. Horizon** | **Estimation** | **Decoding Scheme** | **Class Path** |
| --- | --- | --- | --- | --- |
| Linear | - | Point | Auto / Non-auto | `probts.model.forecaster.point_forecaster.LinearForecaster` |
| [GRU](https://arxiv.org/abs/1412.3555) | - | Point | Auto / Non-auto | `probts.model.forecaster.point_forecaster.GRUForecaster` |
| [Transformer](https://arxiv.org/abs/1706.03762) | - | Point | Auto / Non-auto | `probts.model.forecaster.point_forecaster.TransformerForecaster` |
| [Autoformer](https://arxiv.org/abs/2106.13008) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.Autoformer` |
| [N-HiTS](https://arxiv.org/abs/2201.12886) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.NHiTS` |
| [NLinear](https://arxiv.org/abs/2205.13504) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.NLinear` |
| [DLinear](https://arxiv.org/abs/2205.13504) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.DLinear` |
| [TSMixer](https://arxiv.org/abs/2303.06053) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.TSMixer` |
| [TimesNet](https://arxiv.org/abs/2210.02186) | Short- / Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.TimesNet` |
| [PatchTST](https://arxiv.org/abs/2211.14730) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.PatchTST` |
| [iTransformer](https://arxiv.org/abs/2310.06625) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.iTransformer` |
| [ElasTST](https://arxiv.org/abs/2411.01842) | Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.ElasTST` |
| [GRU NVP](https://arxiv.org/abs/2002.06103) | Short-term | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.GRU_NVP` |
| [GRU MAF](https://arxiv.org/abs/2002.06103) | Short-term | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.GRU_MAF` |
| [Trans MAF](https://arxiv.org/abs/2002.06103) | Short-term | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.Trans_MAF` |
| [TimeGrad](https://arxiv.org/abs/2101.12072) | Short-term | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.TimeGrad` |
| [CSDI](https://arxiv.org/abs/2107.03502) | Short-term | Probabilistic | Non-auto | `probts.model.forecaster.prob_forecaster.CSDI` |
| [TSDiff](https://arxiv.org/abs/2307.11494) | Short-term | Probabilistic | Non-auto | `probts.model.forecaster.prob_forecaster.TSDiffCond` |
**Foundation Models**
| **Model** | **Any Horizon** | **Estimation** | **Decoding Scheme** | **Class Path** | **Model Size** |
| --- | --- | --- | --- | --- | --- |
| [Lag-Llama](https://arxiv.org/abs/2310.08278) | ✔ | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.LagLlama` | - |
| [ForecastPFN](https://arxiv.org/abs/2311.01933) | ✔ | Point | NAR | `probts.model.forecaster.point_forecaster.ForecastPFN` | - |
| [TimesFM](https://arxiv.org/abs/2310.10688) | ✔ | Point | AR | `probts.model.forecaster.point_forecaster.TimesFM` | `200m`, `500m` |
| [TTM](https://arxiv.org/abs/2401.03955) | ✘ | Point | NAR | `probts.model.forecaster.point_forecaster.TinyTimeMixer` | - |
| [Timer](https://arxiv.org/abs/2402.02368) | ✔ | Point | AR | `probts.model.forecaster.point_forecaster.Timer` | - |
| [MOIRAI](https://arxiv.org/abs/2402.02592) | ✔ | Probabilistic | NAR | `probts.model.forecaster.prob_forecaster.Moirai` | `small`, `base`, `large` |
| [UniTS](https://arxiv.org/abs/2403.00131) | ✔ | Point | NAR | `probts.model.forecaster.point_forecaster.UniTS` | - |
| [Chronos](https://arxiv.org/abs/2403.07815) | ✔ | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.Chronos` | `tiny`, `mini`, `small`, `base`, `large` |
| [Time-MoE](https://arxiv.org/abs/2409.16040) | ✔ | Point | AR | `probts.model.forecaster.point_forecaster.TimeMoE` | `50M`, `200M` |
See the [tsfm configuration directory](./config/tsfm/) for more details. More models will be added soon—stay tuned!
### Using Customized Model
With our platform, you can easily evaluate customized models across various datasets. Follow the steps below to create and evaluate your model.
**Step 1: Create a New Python File**
Create a new Python file and follow the structure below to define your custom model:
```python
from probts.model.forecaster import Forecaster
class ModelName(Forecaster):
def __init__(
self,
**kwargs
):
"""
Initialize the model with parameters.
"""
super().__init__(**kwargs)
# Initialize model parameters here
def forward(self, inputs):
"""
Forward pass for the model.
Parameters:
inputs [Tensor]: Input tensor for the model.
Returns:
Tensor: Output tensor.
"""
# Perform the forward pass of the model
return outputs
def loss(self, batch_data):
"""
Compute the loss for the given batch data.
Parameters:
batch_data [dict]: Dictionary containing input data and possibly target data.
Returns:
Tensor: Computed loss.
"""
# Extract inputs and targets from batch_data
inputs = batch_data.past_target_cdf[:, -self.context_length:, :] # [batch_size, context_length, var_num]
target = batch_data.future_target_cdf # [batch_size, prediction_length, var_num]
# Forward pass
outputs = self.forward(inputs)
# Calculate loss using a loss function, e.g., Mean Squared Error
loss = self.loss_function(outputs, target)
return loss
def forecast(self, batch_data, num_samples=None):
"""
Generate forecasts for the given batch data.
Parameters:
batch_data [dict]: Dictionary containing input data.
num_samples [int, optional]: Number of samples per distribution during evaluation. Defaults to None.
Returns:
Tensor: Forecasted outputs.
"""
# Perform the forward pass to get the outputs
outputs = self(batch_data.past_target_cdf[:, -self.context_length:, :])
if num_samples is not None:
# If num_samples is specified, use it to sample from the distribution
outputs = self.sample_from_distribution(outputs, num_samples)
else:
# If perform point estimation, the num_samples is equal to 1
outputs = outputs.unsqueeze(1)
return outputs # [batch_size, num_samples, prediction_length, var_num]
```
**Input Data Format**
The `batch_data` dictionary contains several fields that provide necessary information for the model's operation. Each field is described below:
- **`target_dimension_indicator`**:
- **Shape**: [var_num]
- **Description**: Indicator that specifies which dimension or feature of the target is being referenced.
- **`{past|future}_time_feat`**:
- **Shape**: [batch_size,length,time_feature_dim]
- **Description**: Time features associated with each time step in the past or future. This can include various time-related information such as timestamps, seasonal indicators (e.g., month, day of the week), or other temporal features that provide context to the observations.
- **`{past|future}_target_cdf`**:
- **Shape**: [batch_size,length,var_num]
- **Description**: The observation values of the target variable(s) for past or future time steps.
- **`{past|future}_observed_values`**:
- **Shape**: [batch_size,length,var_num]
- **Description**: Binary masks indicating which values in the past or future target data are observed (1) and which are missing or unobserved (0).
**Step 2: Create YAML Configuration File**
Create a YAML configuration file (`model.yaml`) for the customized model:
```yaml
seed_everything: 1 # random seed
trainer:
accelerator: gpu
devices: 1
strategy: auto
max_epochs: 50
use_distributed_sampler: false
limit_train_batches: 100
log_every_n_steps: 1
default_root_dir: ./results # path to the log folder
model:
forecaster:
class_path: class.path.to.ModelName
init_args:
# init your hyperparameter here
learning_rate: 0.001 # learning rate
data:
data_manager:
class_path: probts.data.data_manager.DataManager
init_args:
dataset: solar_nips # dataset name
split_val: true
scaler: standard # identity, standard, temporal
batch_size: 32
test_batch_size: 32
num_workers: 8
```
**Step 3: Run the Customized Model**
Run the customized model using the configuration file:
```bash
python run.py --config config/path/to/model.yaml
```
## Training
### Configuring Optimizers and Learning Rate Schedulers
ProbTS supports customizable optimizers and learning rate schedulers. You can specify them directly in the YAML configuration file.
**Example Configuration**
```yaml
model:
forecaster:
class_path: probts.model.forecaster.point_forecaster.PatchTST
init_args:
# Add forecaster-specific parameters here
optimizer_config:
class_name: torch.optim.Adam
init_args:
weight_decay: 0 # Add optimizer-specific parameters here
lr_scheduler_config:
class_name: torch.optim.lr_scheduler.OneCycleLR
init_args:
max_lr: 0.0001
steps_per_epoch: 100
pct_start: 0.3
epochs: 50 # Add scheduler-specific parameters here
```
Example configurations can be found in [config/default/patchtst.yaml](../../config/default/patchtst.yaml).
**Notes**
- If no configuration is provided, ProbTS defaults to the Adam optimizer with a constant learning rate.
- Adjust init_args for both the optimizer and scheduler to suit your specific use case.
## Forecasting with Varied Prediction Lengths
**Example:**
```bash
python run.py --config config/multi_hor/elastst.yaml \
--data.data_manager.init_args.path ./datasets \
--trainer.default_root_dir /path/to/log_dir/ \
--data.data_manager.init_args.dataset {DATASET_NAME} \
--data.data_manager.init_args.context_length ${TEST_CTX_LEN} \
--data.data_manager.init_args.prediction_length ${TEST_PRED_LEN} \
--data.data_manager.init_args.train_ctx_len ${TRAIN_CTX_LEN} \
--data.data_manager.init_args.train_pred_len_list ${TRAIN_PRED_LEN} \
--data.data_manager.init_args.val_ctx_len ${VAL_CTX_LEN} \
--data.data_manager.init_args.val_pred_len_list ${VAL_PRED_LEN}
```
- `DATASET_NAME`: Select from datasets used in long-term forecasting scenarios.
- `TEST_CTX_LEN`: Context length in the testing phase.
- `VAL_CTX_LEN` (Default: `TEST_CTX_LEN`): Context length in the validation phase.
- `TRAIN_CTX_LEN` (Default: `TEST_CTX_LEN`): Context length in the training phase.
- `TEST_PRED_LEN`: Forecasting horizons in the testing phase.
- `VAL_PRED_LEN` (Default: `TEST_PRED_LEN`): Forecasting horizons for performance validation.
- `TRAIN_PRED_LEN` (Default: `TEST_PRED_LEN`): Forecasting horizons in the training phase.
The results across multiple horizons will be saved to:
```bash
/path/to/log_dir/{DATASET_NAME}_{MODEL}_{seed}_TrainCTX_{TRAIN_CTX_LEN}_TrainPRED_{TRAIN_PRED_LEN}_ValCTX_{VAL_CTX_LEN}_ValPRED_{VAL_PRED_LEN}/horizons_results.csv
```
### Example 1: Varied-Horizon Training
**Mode 1: Random sampling from a set of horizons**
```bash
# train_pred_len_list: random selection from {96, 192, 336, 720}
python run.py --config config/multi_hor/elastst.yaml \
--data.data_manager.init_args.path ./datasets \
--trainer.default_root_dir /path/to/log_dir/ \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length 96 \
--data.data_manager.init_args.prediction_length 720 \
--data.data_manager.init_args.train_ctx_len 96 \
--data.data_manager.init_args.val_pred_len_list 720 \
--data.data_manager.init_args.train_pred_len_list 96-192-336-720 \
--data.data_manager.init_args.continuous_sample false
```
**Mode 2: Random sampling from a horizon range**
```bash
# train_pred_len_list: random sampling from [1, 720]
python run.py --config config/multi_hor/elastst.yaml \
--data.data_manager.init_args.path ./datasets \
--trainer.default_root_dir /path/to/log_dir/ \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length 96 \
--data.data_manager.init_args.prediction_length 720 \
--data.data_manager.init_args.train_ctx_len 96 \
--data.data_manager.init_args.val_pred_len_list 720 \
--data.data_manager.init_args.train_pred_len_list 1-720 \
--data.data_manager.init_args.continuous_sample true
```
### Example 2: Validation and Testing with Multiple Horizons
```bash
# val_pred_len_list: validation on {96, 192, 336, 720}
# prediction_length: testing on {24, 96, 192, 336, 720, 1024}
python run.py --config config/multi_hor/elastst.yaml \
--data.data_manager.init_args.path ./datasets \
--trainer.default_root_dir /path/to/log_dir/ \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length 96 \
--data.data_manager.init_args.train_pred_len_list 720 \
--data.data_manager.init_args.train_ctx_len 96 \
--data.data_manager.init_args.val_pred_len_list 96-192-336-720 \
--data.data_manager.init_args.prediction_length 24-96-192-336-720-1024
```
================================================
FILE: exps/.gitignore
================================================
*
!.gitignore
================================================
FILE: notebook/data_characteristics.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from gluonts.dataset.repository.datasets import get_dataset\n",
"from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"data_path = 'path/to/datasets/'\n",
"save_path = Path(data_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Decomposition"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from statsmodels.tsa.seasonal import STL\n",
"from tqdm import trange\n",
"\n",
"def measure_strength(df, dataset, win=0):\n",
" \"\"\"\n",
" Measures the strength of trend (F_t) and seasonality (F_s) in time series data.\n",
"\n",
" Parameters:\n",
" - df (pd.DataFrame): The input data containing time series columns.\n",
" - dataset (str): The name of the dataset to identify frequency or specific configurations.\n",
" - win (int): Window size for decomposition; if 0, applies decomposition on the full time series.\n",
"\n",
" Outputs:\n",
" Prints the average strength of trend and seasonality for the dataset.\n",
" \"\"\"\n",
" # Decompose the time series for each dimension\n",
" dim_list = ts_decompose(df, dataset, win=win)\n",
" \n",
" F_t_list = [] # List to store trend strength values\n",
" F_s_list = [] # List to store seasonality strength values\n",
" \n",
" for res in dim_list:\n",
" # Skip calculations if variance of the decomposed components is zero\n",
" if (res.trend + res.resid).var() == 0 or (res.seasonal + res.resid).var() == 0:\n",
" continue\n",
" \n",
" # Calculate trend strength (F_t)\n",
" F_t = max(0, 1 - (res.resid.var() / (res.trend + res.resid).var()))\n",
" F_t_list.append(F_t)\n",
" \n",
" # Calculate seasonality strength (F_s)\n",
" F_s = max(0, 1 - (res.resid.var() / (res.seasonal + res.resid).var()))\n",
" F_s_list.append(F_s)\n",
" \n",
" # Print summary of results\n",
" print('dataset: {dataset}, \\t win. size: {win},\\t Avg. F_t: {avg_ft:2.4f},\\t Avg. F_s: {avg_fs:2.4f}'.format(\n",
" dataset=dataset, win=win, avg_ft=np.mean(F_t_list), avg_fs=np.mean(F_s_list)\n",
" ))\n",
"\n",
"def ts_decompose(df, dataset, win=0):\n",
" \"\"\"\n",
" Decomposes time series data into trend, seasonal, and residual components.\n",
"\n",
" Parameters:\n",
" - df (pd.DataFrame): The input data containing time series columns.\n",
" - dataset (str): The name of the dataset to identify frequency or specific configurations.\n",
" - win (int): Window size for decomposition; if 0, applies decomposition on the full time series.\n",
"\n",
" Returns:\n",
" - dim_list (list): A list of decomposition results for each dimension of the time series.\n",
" \"\"\"\n",
" # Define frequency mapping for datasets\n",
" freq_dict = {\n",
" 'ETT-small/ETTh1': 'H', 'ETT-small/ETTh2': 'H', 'ETT-small/ETTm1': 'T', 'ETT-small/ETTm2': 'T',\n",
" 'electricity/electricity': 'H', 'exchange_rate/exchange_rate': 'B',\n",
" 'illness/national_illness': 'W', 'traffic/traffic': 'H', 'weather/weather': 'T',\n",
" 'exchange_rate_nips': 'B', 'solar_nips': 'H', 'electricity_nips': 'H',\n",
" 'traffic_nips': 'H', 'wiki2000_nips': 'D'\n",
" }\n",
" \n",
" # Define minimum period mapping for datasets\n",
" min_dict = {\n",
" 'ETT-small/ETTm1': (24 * 60) // 15, 'ETT-small/ETTm2': (24 * 60) // 15,\n",
" 'weather/weather': (24 * 60) // 10\n",
" }\n",
"\n",
" dim = len(df.iloc[0]) # Number of dimensions (columns) in the data\n",
" dim_list = [] # List to store decomposition results for each dimension\n",
" \n",
" for i in trange(dim): # Iterate over each column in the dataset\n",
" if win == 0:\n",
" # Standardize the time series column\n",
" tmp_df = (df.iloc[:, i] - df.iloc[:, i].mean()) / (df.iloc[:, i].std())\n",
" \n",
" # Perform STL decomposition with appropriate frequency settings\n",
" if dataset in freq_dict and freq_dict[dataset] == 'T':\n",
" stl = STL(tmp_df.fillna(0), period=7, robust=True)\n",
" else:\n",
" stl = STL(tmp_df.fillna(0), robust=True)\n",
" \n",
" res = stl.fit() # Fit the decomposition model\n",
" dim_list.append(res) # Store the result\n",
" else:\n",
" # Perform windowed decomposition\n",
" right = win # Initialize the right boundary of the window\n",
" while right < len(df.iloc[1:, i]):\n",
" tmp_df = df.iloc[right - win:right, i] # Extract the windowed data\n",
" tmp_df = (tmp_df - tmp_df.mean()) / (tmp_df.std()) # Standardize the windowed data\n",
" \n",
" # Perform STL decomposition with appropriate frequency settings\n",
" if dataset in freq_dict and freq_dict[dataset] == 'T':\n",
" stl = STL(tmp_df.fillna(0), period=7, robust=True)\n",
" else:\n",
" stl = STL(tmp_df.fillna(0), robust=True)\n",
" \n",
" res = stl.fit() # Fit the decomposition model\n",
" right += win # Move the window forward\n",
" dim_list.append(res) # Store the result\n",
" \n",
" return dim_list # Return the list of decomposition results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Normality"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from scipy.stats import normaltest\n",
"import numpy as np\n",
"import scipy.stats\n",
"from scipy.stats import norm\n",
"\n",
"def test_normal(df, dataset, win=0):\n",
" dim = len(df.iloc[0])\n",
" score_list = []\n",
" gaussian_count = 0\n",
" count = 0\n",
" for i in range(dim):\n",
" # z-score\n",
" # df.iloc[:,i]=(df.iloc[:,i]-df.iloc[:,i].mean())/(df.iloc[:,i].std())\n",
" value = df.iloc[:,i].dropna().values\n",
" if len(value) < 10:\n",
" continue\n",
" \n",
" right = win\n",
" pvalue = []\n",
" if win > 0:\n",
" while right < len(value):\n",
" res = normaltest(value[right-win:right])[1]\n",
" pvalue.append(res)\n",
" right += win\n",
" res = np.mean(pvalue)\n",
" else:\n",
" res = normaltest(value)[1]\n",
" # res = kstest(value, 'norm')[1]\n",
" if sum(value) == 0:\n",
" continue\n",
" \n",
" if res >= 0.05:\n",
" gaussian_count += 1\n",
" count += 1\n",
" \n",
" score_list.append(res)\n",
"\n",
" \n",
" print(dataset, \" gaussian pvalue: \", str(np.mean(score_list)), '\\t gaussian ratio: ', str(gaussian_count/count))\n",
"\n",
"\n",
"def JS_divergence(p,q):\n",
" M=(p+q)/2\n",
" return 0.5*scipy.stats.entropy(p, M, base=2)+0.5*scipy.stats.entropy(q, M, base=2)\n",
"\n",
"def JS_div(arr1,arr2,num_bins):\n",
" max0 = max(np.max(arr1),np.max(arr2))\n",
" min0 = min(np.min(arr1),np.min(arr2))\n",
" bins = np.linspace(min0-1e-4, max0-1e-4, num=num_bins)\n",
" \n",
" PDF1 = pd.cut(arr1,bins,duplicates='drop').value_counts()\n",
" PDF2 = pd.cut(arr2,bins, duplicates='drop').value_counts()\n",
" \n",
" if sum(PDF1) > 0 and sum(PDF2) > 0:\n",
" PDF1 = PDF1 / len(arr1)\n",
" PDF2 = PDF2 / len(arr2)\n",
" return JS_divergence(PDF1.values,PDF2.values)\n",
" else:\n",
" return None\n",
"\n",
"\n",
"def cal_JS_divergence(df, dataset, win=0):\n",
" \n",
" dim = len(df.iloc[0])\n",
" js_list = []\n",
" for i in range(1, dim):\n",
" \n",
" # z-score\n",
" global_mu = df.iloc[:,i].mean()\n",
" global_std = df.iloc[:,i].std()\n",
" df.iloc[:,i]=(df.iloc[:,i]-global_mu) / global_std\n",
" value = df.iloc[:,i].dropna().values\n",
" \n",
" if sum(value) == 0:\n",
" continue\n",
" \n",
" right = win\n",
" dim_list = []\n",
" if win > 0:\n",
" while right < len(value):\n",
" tmp_value = value[right-win:right]\n",
" mu = tmp_value.mean()\n",
" std = tmp_value.std()\n",
"\n",
" norm_dist = norm.rvs(loc=mu, scale=std, size=len(tmp_value))\n",
" res = JS_div(tmp_value,norm_dist,num_bins=20)\n",
" if res is not None:\n",
" dim_list.append(res)\n",
" right += win\n",
" \n",
" js_div = np.mean(dim_list)\n",
"\n",
" else:\n",
" norm_dist = norm.rvs(loc=global_mu, scale=global_std, size=len(value))\n",
" js_div = JS_div(value,norm_dist,num_bins=20)\n",
" \n",
" if js_div is not None:\n",
" js_list.append(js_div)\n",
" \n",
" print(\"window size: \", win, \"\\t dataset: \", dataset, \"\\t JS DIV avg: \", str(np.mean(js_list)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Long-term Datasets"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def load_csv_data(filename, dataset):\n",
" \"\"\"\n",
" Loads time series data from a CSV file and processes it based on dataset-specific requirements.\n",
"\n",
" Parameters:\n",
" - filename (str): Path to the directory containing the CSV file.\n",
" - dataset (str): Name of the dataset to be loaded, used for specific handling.\n",
"\n",
" Returns:\n",
" - df (pd.DataFrame): Processed DataFrame with time series data, indexed by date.\n",
" \"\"\"\n",
" # Dictionary to map dataset names to their respective data frequency\n",
" freq_dict = {\n",
" 'ETT-small/ETTh1': 'H', 'ETT-small/ETTh2': 'H', 'ETT-small/ETTm1': 'T', 'ETT-small/ETTm2': 'T',\n",
" 'electricity/electricity': 'H', 'exchange_rate/exchange_rate': 'D',\n",
" 'illness/national_illness': 'D', 'traffic/traffic': 'H', 'weather/weather': 'T'\n",
" }\n",
"\n",
" # Special handling for 'caiso' dataset\n",
" if 'caiso' in dataset:\n",
" # Load the dataset and convert the 'Date' column to datetime\n",
" data = pd.read_csv(filename + dataset + '.csv')\n",
" data['Date'] = data['Date'].astype('datetime64[ns]')\n",
" \n",
" # Names of zones in the dataset\n",
" names = ['PGE', 'SCE', 'SDGE', 'VEA', 'CA ISO', 'PACE', 'PACW', 'NEVP', 'AZPS', 'PSEI']\n",
" \n",
" # Create a DataFrame with a complete hourly date range\n",
" df = pd.DataFrame(pd.date_range('20130101', '20210630', freq='H')[:-1], columns=['Date'])\n",
" \n",
" # Process each zone's data and merge into a single DataFrame\n",
" for name in names:\n",
" current_df = (\n",
" data[data['zone'] == name]\n",
" .drop_duplicates(subset='Date', keep='last') # Remove duplicate entries, keeping the last\n",
" .rename(columns={'load': name}) # Rename 'load' column to the zone name\n",
" .drop(columns=['zone']) # Drop the 'zone' column\n",
" )\n",
" df = df.merge(current_df, on='Date', how='outer') # Merge with the main DataFrame\n",
" \n",
" # Rename the 'Date' column to 'date'\n",
" df = df.rename(columns={'Date': 'date'})\n",
" elif 'nordpool' in dataset:\n",
" # Special handling for 'nordpool' dataset: Parse the 'Time' column as datetime\n",
" df = pd.read_csv(filename + dataset + '.csv', parse_dates=['Time'])\n",
" df = df.rename(columns={'Time': 'date'}) # Rename the 'Time' column to 'date'\n",
" else:\n",
" # General case: Load the dataset as-is\n",
" df = pd.read_csv(filename + dataset + '.csv')\n",
" \n",
" # Convert the 'date' column to datetime format and set it as the index\n",
" df['date'] = pd.to_datetime(df['date'])\n",
" df = df.set_index('date')\n",
"\n",
" # Drop the first column (usually an index column or non-relevant column)\n",
" df = df.iloc[:, 1:]\n",
" \n",
" return df # Return the processed DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 6/6 [00:10<00:00, 1.67s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dataset: ETT-small/ETTh1, \t win. size: 0,\t Avg. F_t: 0.7728,\t Avg. F_s: 0.4772\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"dataset = 'ETT-small/ETTh1' # 'exchange_rate/exchange_rate'\n",
"win_len = 0\n",
"df = load_csv_data(data_path, dataset)\n",
"measure_strength(df, dataset, win=win_len)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"window size: 336 \t dataset: ETT-small/ETTh1 \t JS DIV avg: 0.0719988819816385\n"
]
}
],
"source": [
"dataset = 'ETT-small/ETTh1' # 'exchange_rate/exchange_rate'\n",
"win_len = 336\n",
"df = load_csv_data(data_path, dataset)\n",
"cal_JS_divergence(df, dataset, win=win_len)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Short-term Datasets"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def load_prob_data(dataset, win=0):\n",
" freq_dict = {'exchange_rate_nips':'B','solar_nips':'H','electricity_nips':'H','traffic_nips':'H', 'wiki2000_nips':'D'}\n",
" \n",
" idx = 0\n",
" dataname = dataset\n",
" dataset = get_dataset(dataset, path=save_path, regenerate=False)\n",
" dim = int(dataset.metadata.feat_static_cat[0].cardinality)\n",
" train_grouper = MultivariateGrouper(max_target_dim=dim)\n",
" dataset_train = train_grouper(dataset.train)\n",
" data = list(dataset_train)[0]['target']\n",
" start_date = dataset_train[0]['start'].to_timestamp()\n",
" \n",
" # multi\n",
" idx = [i for i in range(dim)]\n",
"\n",
" data = data.transpose(1,0)\n",
" df = pd.DataFrame(data,columns=idx,dtype=float)\n",
"\n",
" df['date'] = pd.date_range(start_date,periods=len(data),freq=freq_dict[dataname]) \n",
" df = df.set_index('date')\n",
" \n",
" return df\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/v-zhangjiaw/miniconda3/envs/probts/lib/python3.10/site-packages/gluonts/dataset/common.py:263: FutureWarning: Period with BDay freq is deprecated and will be removed in a future version. Use a DatetimeIndex with BDay freq instead.\n",
" return pd.Period(val, freq)\n",
"/home/v-zhangjiaw/miniconda3/envs/probts/lib/python3.10/site-packages/gluonts/dataset/multivariate_grouper.py:114: FutureWarning: Period with BDay freq is deprecated and will be removed in a future version. Use a DatetimeIndex with BDay freq instead.\n",
" timestamp + len(data[FieldName.TARGET]) - 1,\n",
"/home/v-zhangjiaw/miniconda3/envs/probts/lib/python3.10/site-packages/gluonts/dataset/multivariate_grouper.py:243: FutureWarning: Period with BDay freq is deprecated and will be removed in a future version. Use a DatetimeIndex with BDay freq instead.\n",
" index=pd.period_range(\n",
"/home/v-zhangjiaw/miniconda3/envs/probts/lib/python3.10/site-packages/gluonts/dataset/multivariate_grouper.py:243: FutureWarning: PeriodDtype[B] is deprecated and will be removed in a future version. Use a DatetimeIndex with freq='B' instead\n",
" index=pd.period_range(\n",
"/home/v-zhangjiaw/miniconda3/envs/probts/lib/python3.10/site-packages/gluonts/dataset/multivariate_grouper.py:188: FutureWarning: Period with BDay freq is deprecated and will be removed in a future version. Use a DatetimeIndex with BDay freq instead.\n",
" pd.period_range(\n",
"/home/v-zhangjiaw/miniconda3/envs/probts/lib/python3.10/site-packages/gluonts/dataset/multivariate_grouper.py:188: FutureWarning: PeriodDtype[B] is deprecated and will be removed in a future version. Use a DatetimeIndex with freq='B' instead\n",
" pd.period_range(\n",
"/tmp/ipykernel_1399510/2105741496.py:11: FutureWarning: Period with BDay freq is deprecated and will be removed in a future version. Use a DatetimeIndex with BDay freq instead.\n",
" start_date = dataset_train[0]['start'].to_timestamp()\n",
"100%|██████████| 8/8 [00:01<00:00, 5.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dataset: exchange_rate_nips, \t win. size: 0,\t Avg. F_t: 0.9982,\t Avg. F_s: 0.1256\n",
"window size: 30 \t dataset: exchange_rate_nips \t JS DIV avg: 0.2964380648448922\n"
]
}
],
"source": [
"# \"exchange_rate_nips\", \"solar_nips\", \"electricity_nips\", \"traffic_nips\", \"taxi_30min\", \"wiki2000_nips\"\n",
"dataset = \"exchange_rate_nips\"\n",
"df = load_prob_data(dataset, win=0)\n",
"\n",
"measure_strength(df, dataset, win=0)\n",
"cal_JS_divergence(df, dataset, win=30)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "probts",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: probts/__init__.py
================================================
from .data import *
from .model import *
from .utils import *
================================================
FILE: probts/callbacks/__init__.py
================================================
from .memory_callback import MemoryCallback
from .time_callback import TimeCallback
================================================
FILE: probts/callbacks/memory_callback.py
================================================
import gc
import threading
import psutil
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
def byte2gb(x):
    """Convert a byte count to gibibytes (x / 2**30) and return it as a float."""
    return float(x / 2**30)


class MemoryTrace:
    """
    Snapshot CUDA and host (RSS) memory between construction and ``__exit__``.

    Construction records the starting CUDA allocation and process RSS and
    launches a daemon thread that keeps polling RSS to catch its peak.
    Calling ``__exit__`` stops the polling and fills in the result attributes.

    Units: ``begin``, ``end``, ``peak``, ``used``, ``peaked``, ``max_reserved``,
    ``peak_active_gb``, ``cpu_begin``, ``cpu_used`` and ``cpu_peaked`` are in GiB.
    ``cpu_end`` and ``cpu_peak`` are kept as raw byte counts.
    """

    def __init__(self):
        gc.collect()
        torch.cuda.empty_cache()
        # Reset the CUDA peak counters so the peaks read in __exit__ belong to
        # this trace only. reset_max_memory_allocated() is deprecated and simply
        # forwards to reset_peak_memory_stats().
        torch.cuda.reset_peak_memory_stats()
        self.begin = byte2gb(torch.cuda.memory_allocated())
        self.process = psutil.Process()

        cpu_begin_bytes = self.cpu_mem_used()
        self.cpu_begin = byte2gb(cpu_begin_bytes)
        # Seed the peak before the thread starts so __exit__ can never observe
        # a missing attribute if it runs before the monitor's first iteration.
        self.cpu_peak = cpu_begin_bytes

        self.peak_monitoring = True
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True
        peak_monitor_thread.start()

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        """Poll RSS (in bytes) until ``peak_monitoring`` is cleared, keeping the maximum."""
        # NOTE(review): this is a deliberate busy loop with no sleep; it keeps
        # one CPU core occupied while a trace is active.
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), getattr(self, "cpu_peak", -1))
            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        """Stop monitoring and compute the memory statistics of the traced span."""
        self.peak_monitoring = False

        gc.collect()
        torch.cuda.empty_cache()
        self.end = byte2gb(torch.cuda.memory_allocated())
        self.peak = byte2gb(torch.cuda.max_memory_allocated())

        cuda_info = torch.cuda.memory_stats()
        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
        self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
        self.m_cuda_ooms = cuda_info.get("num_ooms", 0)

        # begin/end/peak are already in GiB: subtract directly instead of
        # converting the difference a second time.
        self.used = self.end - self.begin
        self.peaked = self.peak - self.begin
        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())

        # cpu_end / cpu_peak are bytes while cpu_begin is GiB: convert the byte
        # quantities first so both operands share the same unit.
        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = byte2gb(self.cpu_end) - self.cpu_begin
        self.cpu_peaked = byte2gb(self.cpu_peak) - self.cpu_begin
class MemoryCallback(Callback):
    """
    Trace the memory usage.

    A ``MemoryTrace`` is opened at the start of every train / validation / test epoch
    (only when CUDA is available) and closed at the end of that epoch. For each stage,
    ``memory_summary`` keeps the running maximum of every tracked statistic.
    """

    def __init__(self):
        # One (initially empty) statistics dict per stage.
        self.memory_summary = {stage: {} for stage in ('train', 'val', 'test')}

    def update_memory_summary(self, key, memtrace):
        """Merge a finished ``memtrace`` into ``memory_summary[key]``, keeping per-field maxima."""
        previous = self.memory_summary[key]
        latest = {
            "mem_peak": memtrace.peak,
            "max_reserved": memtrace.max_reserved,
            "peak_active_gb": memtrace.peak_active_gb,
            "cuda_malloc_retires": memtrace.cuda_malloc_retires,
            "cpu_total_peaked": memtrace.cpu_peaked + memtrace.cpu_begin,
        }
        self.memory_summary[key] = {
            field: max(value, previous.get(field, 0)) for field, value in latest.items()
        }

    def on_train_epoch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule"
    ) -> None:
        """Called when the train epoch begins"""
        if not torch.cuda.is_available():
            return
        self.train_memtrace = MemoryTrace()

    def on_train_epoch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule"
    ) -> None:
        """Called when the train epoch ends"""
        if not torch.cuda.is_available():
            return
        self.train_memtrace.__exit__()
        self.update_memory_summary('train', self.train_memtrace)

    def on_validation_epoch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule"
    ) -> None:
        """Called when the validation epoch begins"""
        if not torch.cuda.is_available():
            return
        self.val_memtrace = MemoryTrace()

    def on_validation_epoch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule"
    ) -> None:
        """Called when the validation epoch ends"""
        if not torch.cuda.is_available():
            return
        self.val_memtrace.__exit__()
        self.update_memory_summary('val', self.val_memtrace)

    def on_test_epoch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule"
    ) -> None:
        """Called when the test epoch begins"""
        if not torch.cuda.is_available():
            return
        self.test_memtrace = MemoryTrace()

    def on_test_epoch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule"
    ) -> None:
        """Called when the test epoch ends"""
        if not torch.cuda.is_available():
            return
        self.test_memtrace.__exit__()
        self.update_memory_summary('test', self.test_memtrace)
================================================
FILE: probts/callbacks/time_callback.py
================================================
import time
from typing import Any
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning.pytorch.callbacks.callback import Callback
class TimeCallback(Callback):
    """
    Trace the computation time.

    The elapsed time of every train / validation / test batch (in seconds) is appended
    to the matching list in ``time_summary``. Durations are taken with
    ``time.perf_counter()``, a monotonic high-resolution clock, so they are not affected
    by system clock adjustments the way ``time.time()`` differences can be.
    """

    def __init__(self):
        self.time_summary = {
            'train_batch_time': [],
            'val_batch_time': [],
            'test_batch_time': []
        }

    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        """Called when the train batch begins."""
        self.train_start_time = time.perf_counter()

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Called when the train batch ends"""
        self.time_summary['train_batch_time'].append(time.perf_counter() - self.train_start_time)

    def on_validation_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the validation batch begins"""
        self.val_start_time = time.perf_counter()

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the validation batch ends"""
        self.time_summary['val_batch_time'].append(time.perf_counter() - self.val_start_time)

    def on_test_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the test batch begins"""
        self.test_start_time = time.perf_counter()

    def on_test_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the test batch ends"""
        self.time_summary['test_batch_time'].append(time.perf_counter() - self.test_start_time)
================================================
FILE: probts/data/__init__.py
================================================
from .data_module import *
from .data_manager import *
from .data_utils.time_features import *
================================================
FILE: probts/data/data_manager.py
================================================
import torch
from pathlib import Path
from functools import cached_property
from gluonts.dataset.repository import dataset_names, datasets
from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from probts.data.data_utils.get_datasets import get_dataset_info, get_dataset_borders, load_dataset
from probts.data.datasets.single_horizon_datasets import SingleHorizonDataset
from probts.data.datasets.multi_horizon_datasets import MultiHorizonDataset
from probts.data.datasets.gift_eval_datasets import GiftEvalDataset
from probts.data.data_utils.time_features import get_lags
from probts.data.data_utils.data_utils import split_train_val, truncate_test, get_rolling_test, df_to_mvds
from probts.data.data_wrapper import ProbTSBatchData
from probts.utils.utils import ensure_list
from probts.data.data_utils.data_scaler import StandardScaler, TemporalScaler, IdentityScaler
from typing import Union
# Dataset names that DataManager.prepare_STSF_dataset groups into a single multivariate
# series with MultivariateGrouper. Any other short-term dataset is handled there as a
# univariate dataset (target_dim = 1, no validation split).
MULTI_VARIATE_DATASETS = [
    'exchange_rate_nips',
    'solar_nips',
    'electricity_nips',
    'traffic_nips',
    'taxi_30min',
    'wiki-rolling_nips',
    'wiki2000_nips'
]
class DataManager:
    # Loads a dataset and exposes ``train_iter_dataset`` / ``val_iter_dataset`` /
    # ``test_iter_dataset`` together with the fitted ``scaler``.
    #
    # The constructor dispatches on ``dataset``:
    #   * a name contained in GluonTS' ``dataset_names``  -> ``_load_short_term_dataset``
    #   * a name starting with ``"gift/"``                -> ``_load_gift_eval_dataset``
    #   * anything else                                   -> ``_load_long_term_dataset``
    def __init__(
        self,
        dataset: str,
        path: str = './datasets',
        history_length: int = None,
        context_length: int = None,
        prediction_length: Union[list,int,str] = None,
        train_ctx_len: int = None,
        train_pred_len_list: Union[list,int,str] = None,
        val_ctx_len: int = None,
        val_pred_len_list: Union[list,int,str] = None,
        test_rolling_length: int = 96,
        split_val: bool = True,
        scaler: str = 'none',
        context_length_factor: int = 1,
        timeenc: int = 1,
        var_specific_norm: bool = True,
        data_path: str = None,
        freq: str = None,
        multivariate: bool = True,
        continuous_sample: bool = False,
        train_ratio: float = 0.7,
        test_ratio: float = 0.2,
        auto_search: bool = False,
    ):
        """
        DataManager class for handling datasets and preparing data for time-series models.

        Parameters
        ----------
        dataset : str
            Name of the dataset to load. Examples include "etth1", "electricity_ltsf", etc.
        path : str, optional, default='./datasets'
            Root directory path where datasets are stored.
        history_length : int, optional, default=None
            Length of the historical input window for the model.
            If not specified, it is automatically calculated based on `context_length` and lag features.
        context_length : int, optional, default=None
            Length of the input context for the model.
        prediction_length : Union[list, int, str], optional, default=None
            Length of the prediction horizon for the model. Can be:
            - int: Fixed prediction length.
            - list: Variable prediction lengths for multi-horizon training.
            - str: The string format of multiple prediction length. E.g., '96-192-336-720' represents [96, 192, 336, 720]
        train_ctx_len : int, optional, default=None
            Context length for the training dataset.
            If not specified, defaults to the value of `context_length`.
        train_pred_len_list : Union[list, int, str], optional, default=None
            List of prediction lengths for the training dataset.
            If not specified, defaults to the value of `prediction_length`.
        val_ctx_len : int, optional, default=None
            Context length for the validation dataset.
            If not specified, defaults to the value of `context_length`.
        val_pred_len_list : Union[list, int, str], optional, default=None
            List of prediction lengths for the validation dataset.
            If not specified, defaults to the value of `prediction_length`.
        test_rolling_length : int, optional, default=96
            Gap window size used for rolling predictions in the testing phase.
            - If set to `auto`, it is looked up from the lower-cased dataset frequency
              ('h' -> 24, 'd' -> 7, 'b' -> 5, 'w' -> 4, 'min' -> 60), falling back to 24.
        split_val : bool, optional, default=True
            Whether to split the training dataset into training and validation sets.
        scaler : str, optional, default='none'
            Type of normalization or scaling applied to the dataset. Options include:
            - 'standard': Standard normalization (z-score).
            - 'temporal': Mean-scaling normalization.
            - any other value (e.g. 'none'): No scaling.
        context_length_factor : int, optional, default=1
            Scaling factor for context length, allowing dynamic adjustment of `context_length`.
        timeenc : int, optional, default=1
            Time encoding strategy. Options include:
            - 0: The dimension of time feature is 5, containing `month, day, weekday, hour, minute`
            - 1: Cyclic time features (e.g., sine/cosine of timestamps).
            - 2: Raw Timestamp information.
        var_specific_norm : bool, optional, default=True
            Whether to normalize variables independently. Only applies when `scaler='standard'`.
        data_path : str, optional, default=None
            Specific path to the dataset file.
        freq : str, optional, default=None
            Data frequency (e.g., 'H' for hourly, 'D' for daily).
        multivariate : bool, optional, default=True
            Whether the dataset is multivariate.
        continuous_sample : bool, optional, default=False
            Whether to enable continuous sampling for forecasting horizons during training phase.
        train_ratio : float, optional, default=0.7
            Proportion of the dataset used for training. Default is 70% of the data.
        test_ratio : float, optional, default=0.2
            Proportion of the dataset used for testing. Default is 20% of the data.
        auto_search : bool, optional, default=False
            Make past_len=ctx_len+pred_len, enabling post training search.
        """
        self.dataset = dataset
        self.path = path
        self.history_length = history_length
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.train_ctx_len = train_ctx_len if train_ctx_len is not None else context_length
        self.val_ctx_len = val_ctx_len if val_ctx_len is not None else context_length
        self.train_pred_len_list = train_pred_len_list if train_pred_len_list is not None else prediction_length
        self.val_pred_len_list = val_pred_len_list if val_pred_len_list is not None else prediction_length
        self.test_rolling_length = test_rolling_length
        self.split_val = split_val
        self.scaler_type = scaler
        self.context_length_factor = context_length_factor
        self.timeenc = timeenc
        self.var_specific_norm = var_specific_norm
        self.data_path = data_path
        self.freq = freq
        self.multivariate = multivariate
        self.continuous_sample = continuous_sample
        self.train_ratio = train_ratio
        self.test_ratio = test_ratio
        self.auto_search = auto_search
        # Lookup used when test_rolling_length == 'auto' (keyed by lower-cased frequency).
        self.test_rolling_dict = {'h': 24, 'd': 7, 'b':5, 'w':4, 'min': 60}
        # Stays None on the GIFT-eval path; set by prepare_dataset and by the
        # multivariate branch of prepare_STSF_dataset.
        self.global_mean = None

        # Configure scaler
        self.scaler = self._configure_scaler(self.scaler_type)

        # Load dataset and prepare for processing
        if dataset in dataset_names:
            self.multi_hor = False
            self._load_short_term_dataset()
        elif self.is_gift_eval:
            self.multi_hor = False
            # Load GIFT eval datasets from salesforce
            self._load_gift_eval_dataset()
        else:
            # Process context and prediction lengths
            self._process_context_and_prediction_lengths()
            self._load_long_term_dataset()

            # Print configuration details
            # NOTE(review): indentation was reconstructed. This call is kept inside the
            # else-branch because the *_ctx_len_list / test_pred_len_list attributes it
            # prints are only created by _process_context_and_prediction_lengths above.
            self._print_configurations()

    def _configure_scaler(self, scaler_type: str):
        """Configure the scaler: 'standard' -> StandardScaler, 'temporal' -> TemporalScaler, otherwise IdentityScaler."""
        if scaler_type == "standard":
            return StandardScaler(var_specific=self.var_specific_norm)
        elif scaler_type == "temporal":
            return TemporalScaler()
        return IdentityScaler()

    def _load_gift_eval_dataset(self):
        """
        Load a GIFT-eval dataset.

        ``self.dataset`` is expected to look like ``gift/<name...>/<term>``: the leading
        ``gift/`` is dropped, the last path component is passed as ``term`` and the rest
        becomes the dataset name (``self.dataset`` is overwritten with it).
        """
        parts = self.dataset[5:].split('/')  # Remove first 'gift/'
        self.dataset = '/'.join(parts[:-1])  # Join all parts except last one with '/'
        gift_term = parts[-1]  # corresponding to "term" parameter in GiftEvalDataset
        TO_UNIVARIATE = False
        self.dataset_raw = GiftEvalDataset(self.dataset, term=gift_term, to_univariate=TO_UNIVARIATE)
        self._set_meta_parameters(self.dataset_raw.target_dim, self.dataset_raw.freq, self.dataset_raw.prediction_length)
        dataset_loader = SingleHorizonDataset(
            ProbTSBatchData.input_names_,
            self.history_length,
            self.context_length,
            self.prediction_length,
            self.freq,
            self.multivariate
        )
        self.train_iter_dataset = dataset_loader.get_iter_dataset(self.dataset_raw.training_dataset, mode='train')
        self.val_iter_dataset = dataset_loader.get_iter_dataset(self.dataset_raw.validation_dataset, mode='val')
        self.test_iter_dataset = dataset_loader.get_iter_dataset(self.dataset_raw.test_dataset, mode='test')
        self.time_feat_dim = dataset_loader.time_feat_dim
        # NOTE: the scaler is not fitted on this path.
        # TODO: Implement global mean for GIFT eval datasets
        # self.global_mean = torch.mean(torch.tensor(self.dataset_raw.training_dataset[0]['target']), dim=-1)

    def _load_short_term_dataset(self):
        """Load short-term dataset using GluonTS."""
        print(f"Loading Short-term Dataset: {self.dataset}")
        # regenerate=True is passed on every load.
        self.dataset_raw = datasets.get_dataset(self.dataset, path=Path(self.path), regenerate=True)
        metadata = self.dataset_raw.metadata
        if self.is_univar_dataset:
            target_dim = 1
        else:
            # presumably the cardinality of the first static categorical feature equals
            # the number of series in the dataset -- confirm against GluonTS metadata
            target_dim = metadata.feat_static_cat[0].cardinality
        self._set_meta_parameters(target_dim, metadata.freq.upper(), metadata.prediction_length)
        self.prepare_STSF_dataset(self.dataset)

    def _set_meta_parameters(self, target_dim, freq, prediction_length):
        """
        Set meta parameters from base dataset.

        Overwrites ``target_dim``, ``multivariate``, ``freq``, ``lags_list`` and
        ``prediction_length`` with the dataset's values. ``context_length`` and
        ``history_length`` are only filled in when they are currently falsy.
        """
        self.target_dim = int(target_dim)
        self.multivariate = self.target_dim > 1
        self.freq = freq
        self.lags_list = get_lags(self.freq)
        self.prediction_length = prediction_length
        self.context_length = self.context_length or self.prediction_length * self.context_length_factor
        self.history_length = self.history_length or (self.context_length + max(self.lags_list))

    def _process_context_and_prediction_lengths(self):
        """Convert context and prediction lengths to lists for multi-horizon processing."""
        # NOTE(review): ensure_list presumably turns ints / 'a-b-c' strings into lists;
        # verify against probts.utils.utils.
        self.train_ctx_len_list = ensure_list(self.train_ctx_len, default_value=self.context_length)
        self.val_ctx_len_list = ensure_list(self.val_ctx_len, default_value=self.context_length)
        self.test_ctx_len_list = ensure_list(self.context_length)
        self.train_pred_len_list = ensure_list(self.train_pred_len_list, default_value=self.prediction_length)
        self.val_pred_len_list = ensure_list(self.val_pred_len_list, default_value=self.prediction_length)
        self.test_pred_len_list = ensure_list(self.prediction_length)

        # Validate context length support
        assert len(self.train_ctx_len_list) == 1, "Assign a single context length for training."
        assert len(self.val_ctx_len_list) == 1, "Assign a single context length for validation."
        assert len(self.test_ctx_len_list) == 1, "Assign a single context length for testing."

        # Multi-horizon mode is enabled as soon as any split has more than one horizon.
        self.multi_hor = len(self.train_pred_len_list) > 1 or \
            len(self.val_pred_len_list) > 1 or \
            len(self.test_pred_len_list) > 1

    def _load_long_term_dataset(self):
        """
        Load long-term dataset or customized dataset.

        Raises
        ------
        ValueError
            If ``context_length`` or ``prediction_length`` is falsy.
        """
        print(f"Loading Long-term Dataset: {self.dataset}")
        if not self.context_length or not self.prediction_length:
            raise ValueError("context_length or prediction_length must be specified.")
        data_path, self.freq = get_dataset_info(self.dataset, data_path=self.data_path, freq=self.freq)
        self.dataset_raw, self.data_stamp, self.target_dim, data_size = load_dataset(
            self.path, data_path, freq=self.freq, timeenc=self.timeenc, multivariate=self.multivariate
        )
        self.border_begin, self.border_end = get_dataset_borders(
            self.dataset, data_size, train_ratio=self.train_ratio, test_ratio=self.test_ratio
        )
        self._set_meta_parameters_from_raw(data_size)
        self.prepare_dataset()

    def _set_meta_parameters_from_raw(self, data_size):
        """
        Set meta parameters directly from raw dataset.

        Raises
        ------
        NotImplementedError
            If ``self.multivariate`` is False.
        AssertionError
            If ``border_end[2]`` exceeds ``data_size``.
        """
        self.lags_list = get_lags(self.freq)
        # In multi-horizon mode both lengths are kept as lists, otherwise left untouched.
        self.prediction_length = ensure_list(self.prediction_length) if self.multi_hor else self.prediction_length
        self.context_length = ensure_list(self.context_length) if self.multi_hor else self.context_length
        self.history_length = self.history_length or (
            max(self.context_length) + max(self.lags_list) if self.multi_hor else self.context_length + max(self.lags_list)
        )
        if not self.multivariate:
            self.target_dim = 1
            raise NotImplementedError("Customized univariate datasets are not yet supported.")
        assert data_size >= self.border_end[2], "border_end index exceeds dataset size!"

        # define the test_rolling_length
        if self.test_rolling_length == 'auto':
            if self.freq.lower() in self.test_rolling_dict:
                self.test_rolling_length = self.test_rolling_dict[self.freq.lower()]
            else:
                self.test_rolling_length = 24

    def prepare_dataset(self):
        """Prepare datasets for training, validation, and testing."""
        # Split raw data into train, validation, and test sets.
        # Each split is a prefix of the raw data that ends at the split's border, so the
        # validation and test slices also contain all preceding rows.
        train_data = self.dataset_raw[: self.border_end[0]]
        val_data = self.dataset_raw[: self.border_end[1]]
        test_data = self.dataset_raw[: self.border_end[2]]

        # Calculate statistics using training data
        self.scaler.fit(torch.tensor(train_data.values))

        # Convert dataframes to multivariate datasets
        train_set = df_to_mvds(train_data, freq=self.freq)
        val_set = df_to_mvds(val_data,freq=self.freq)
        test_set = df_to_mvds(test_data,freq=self.freq)
        train_grouper = MultivariateGrouper(max_target_dim=self.target_dim)
        # The same grouper instance is reused for the validation and the test set.
        test_grouper = MultivariateGrouper(max_target_dim=self.target_dim)
        group_train_set = train_grouper(train_set)
        group_val_set = test_grouper(val_set)
        group_test_set = test_grouper(test_set)

        if self.multi_hor:
            # Handle multi-horizon datasets
            dataset_loader = self._prepare_multi_horizon_datasets(group_val_set, group_test_set)
        else:
            # Handle single-horizon datasets
            dataset_loader = self._prepare_single_horizon_datasets(group_val_set, group_test_set)

        self.train_iter_dataset = dataset_loader.get_iter_dataset(group_train_set, mode='train', data_stamp=self.data_stamp[: self.border_end[0]])
        self.time_feat_dim = dataset_loader.time_feat_dim
        # Mean over the last axis of the grouped training target
        # (presumably the time axis, i.e. one mean per variate -- TODO confirm).
        self.global_mean = torch.mean(torch.tensor(group_train_set[0]['target']), dim=-1)

    def _prepare_multi_horizon_datasets(self, group_val_set, group_test_set):
        """
        Prepare multi-horizon datasets for validation and testing.

        ``val_iter_dataset`` and ``test_iter_dataset`` become dicts keyed by
        ``str(pred_len)``. Returns the ``MultiHorizonDataset`` loader so the caller can
        build the training dataset from it.
        """
        self.val_iter_dataset = {}
        self.test_iter_dataset = {}
        dataset_loader = MultiHorizonDataset(
            input_names = ProbTSBatchData.input_names_,
            freq = self.freq,
            train_ctx_range = self.train_ctx_len_list,
            train_pred_range = self.train_pred_len_list,
            val_ctx_range = self.val_ctx_len_list,
            val_pred_range = self.val_pred_len_list,
            test_ctx_range = self.test_ctx_len_list,
            test_pred_range = self.test_pred_len_list,
            multivariate = self.multivariate,
            continuous_sample = self.continuous_sample
        )

        # Prepare validation datasets
        for pred_len in self.val_pred_len_list:
            local_group_val_set = get_rolling_test(
                'val', group_val_set, self.border_begin[1], self.border_end[1],
                rolling_length=self.test_rolling_length, pred_len=pred_len, freq=self.freq
            )
            self.val_iter_dataset[str(pred_len)] = dataset_loader.get_iter_dataset(
                local_group_val_set, mode='val', data_stamp=self.data_stamp[:self.border_end[1]], pred_len=[pred_len]
            )

        # Prepare testing datasets
        for pred_len in self.test_pred_len_list:
            local_group_test_set = get_rolling_test(
                'test', group_test_set, self.border_begin[2], self.border_end[2],
                rolling_length=self.test_rolling_length, pred_len=pred_len, freq=self.freq
            )
            self.test_iter_dataset[str(pred_len)] = dataset_loader.get_iter_dataset(
                local_group_test_set, mode='test', data_stamp=self.data_stamp[:self.border_end[2]], pred_len=[pred_len], auto_search=self.auto_search,
            )

        return dataset_loader

    def _prepare_single_horizon_datasets(self, group_val_set, group_test_set):
        """
        Prepare single-horizon datasets for validation and testing.

        Sets ``val_iter_dataset`` and ``test_iter_dataset`` and returns the
        ``SingleHorizonDataset`` loader; the training dataset is built by the caller.
        """
        dataset_loader = SingleHorizonDataset(
            ProbTSBatchData.input_names_,
            self.history_length,
            self.context_length,
            self.prediction_length,
            self.freq,
            self.multivariate,
        )

        # Validation dataset (uses the first, and in this mode only, validation horizon)
        local_group_val_set = get_rolling_test(
            'val', group_val_set, self.border_begin[1], self.border_end[1],
            rolling_length=self.test_rolling_length, pred_len=self.val_pred_len_list[0], freq=self.freq
        )
        self.val_iter_dataset = dataset_loader.get_iter_dataset(local_group_val_set, mode='val', data_stamp=self.data_stamp[:self.border_end[1]])

        # Testing dataset
        local_group_test_set = get_rolling_test(
            'test', group_test_set, self.border_begin[2], self.border_end[2],
            rolling_length=self.test_rolling_length, pred_len=self.prediction_length, freq=self.freq
        )
        self.test_iter_dataset = dataset_loader.get_iter_dataset(local_group_test_set, mode='test', data_stamp=self.data_stamp[:self.border_end[2]], auto_search=self.auto_search)

        return dataset_loader

    def prepare_STSF_dataset(self, dataset: str):
        """
        Prepare datasets for short-term series forecasting.

        Datasets listed in ``MULTI_VARIATE_DATASETS`` are grouped into a multivariate
        series, used to fit the scaler and optionally split into train/validation.
        All other datasets are treated as univariate and get no validation set.
        """
        if dataset in MULTI_VARIATE_DATASETS:
            self.num_test_dates = int(len(self.dataset_raw.test)/len(self.dataset_raw.train))
            train_grouper = MultivariateGrouper(max_target_dim=int(self.target_dim))
            test_grouper = MultivariateGrouper(
                num_test_dates=self.num_test_dates,
                max_target_dim=int(self.target_dim)
            )
            train_set = train_grouper(self.dataset_raw.train)
            test_set = test_grouper(self.dataset_raw.test)
            # The target is transposed before fitting
            # (presumably from (C, T) to (T, C) so variates are the last axis -- TODO confirm).
            self.scaler.fit(torch.tensor(train_set[0]['target'].transpose(1, 0)))
            self.global_mean = torch.mean(torch.tensor(train_set[0]['target']), dim=-1)

            # split_val
            if self.split_val:
                train_set, val_set = split_train_val(train_set, self.num_test_dates, self.context_length, self.prediction_length, self.freq)
            else:
                val_set = None
        else:
            self.target_dim = 1
            self.multivariate = False
            self.num_test_dates = 1
            train_set = self.dataset_raw.train
            test_set = self.dataset_raw.test
            test_set = truncate_test(test_set, self.context_length, self.prediction_length, self.freq)
            # for univariate dataset, e.g., M4 and M5, no validation set is used
            val_set = None

        if val_set is None:
            print('No validation set is used.')

        dataset_loader = SingleHorizonDataset(
            ProbTSBatchData.input_names_,
            self.history_length,
            self.context_length,
            self.prediction_length,
            self.freq,
            self.multivariate
        )
        self.train_iter_dataset = dataset_loader.get_iter_dataset(train_set, mode='train')
        if val_set is not None:
            self.val_iter_dataset = dataset_loader.get_iter_dataset(val_set, mode='val')
        else:
            self.val_iter_dataset = None
        self.test_iter_dataset = dataset_loader.get_iter_dataset(test_set, mode='test')
        self.time_feat_dim = dataset_loader.time_feat_dim

    def _print_configurations(self):
        """
        Print dataset and configuration details.

        Reads the ``*_ctx_len_list`` / ``*_pred_len_list`` attributes, which are created
        by ``_process_context_and_prediction_lengths`` (long-term path only).
        """
        print(f"Test context length: {self.test_ctx_len_list}, prediction length: {self.test_pred_len_list}")
        print(f"Validation context length: {self.val_ctx_len_list}, prediction length: {self.val_pred_len_list}")
        print(f"Training context length: {self.train_ctx_len_list}, prediction lengths: {self.train_pred_len_list}")
        print(f"Test rolling length: {self.test_rolling_length}")
        if self.scaler_type == "standard":
            print(f"Variable-specific normalization: {self.var_specific_norm}")

    @cached_property
    def is_gift_eval(self) -> bool:
        """True when the dataset name starts with ``"gift/"``.

        The value is cached on first access, so it is not affected by
        ``_load_gift_eval_dataset`` later rewriting ``self.dataset``.
        """
        return self.dataset[:5] == "gift/"

    @cached_property
    def is_univar_dataset(self) -> bool:
        """True when the dataset name contains ``'m4'`` or ``'m5'`` (substring match); cached on first access."""
        if 'm4' in self.dataset or 'm5' in self.dataset:
            return True
        return False
================================================
FILE: probts/data/data_module.py
================================================
import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from probts.data.data_manager import DataManager
from probts.data.data_wrapper import ProbTSBatchData
class EmptyDataset(Dataset):
    """Placeholder dataset with no samples; used when no validation set is available."""

    def __len__(self) -> int:
        # There is never anything to iterate over.
        return 0

    def __getitem__(self, idx):
        # Every index is out of range for an empty dataset.
        raise IndexError("This dataset is empty.")
class ProbTSDataModule(pl.LightningDataModule):
    r"""
    DataModule for probabilistic time series datasets.

    Wraps the iterable datasets prepared by a ``DataManager`` into PyTorch data loaders.
    When ``data_manager.multi_hor`` is set, the validation/test datasets are dicts keyed
    by horizon and are served through a sequential ``CombinedLoader``.

    Args:
        data_manager: Source of ``train_iter_dataset``, ``val_iter_dataset`` and ``test_iter_dataset``.
        batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the validation, test and predict loaders.
        num_workers: Worker processes of the single-horizon training loader.
    """

    def __init__(
        self,
        data_manager: DataManager,
        batch_size: int = 64,
        test_batch_size: int = 8,
        num_workers: int = 8
    ):
        super().__init__()
        self.data_manager = data_manager
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
        self.save_hyperparameters()
        self.dataset_train = self.data_manager.train_iter_dataset
        self.dataset_val = self.data_manager.val_iter_dataset
        self.dataset_test = self.data_manager.test_iter_dataset

    def train_dataloader(self):
        """Return the training loader (padded collation in multi-horizon mode)."""
        if self.data_manager.multi_hor:
            return DataLoader(
                self.dataset_train,
                batch_size=self.batch_size,
                num_workers=0,
                pin_memory=True,
                collate_fn=self.train_collate_fn
            )
        else:
            return DataLoader(
                self.dataset_train,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                # DataLoader raises ValueError for persistent_workers=True with num_workers == 0.
                persistent_workers=self.num_workers > 0,
                pin_memory=True
            )

    def val_dataloader(self):
        """Return the validation loader, or a loader over an empty dataset if there is no validation set."""
        # if no validation set available
        if self.dataset_val is None:
            return DataLoader(EmptyDataset(), batch_size=1)
        if self.data_manager.multi_hor:
            val_dataloader = self.combine_dataloader(self.dataset_val)
        else:
            val_dataloader = DataLoader(self.dataset_val, batch_size=self.test_batch_size, num_workers=1)
        return val_dataloader

    def test_dataloader(self):
        """Return the test loader (one loader per horizon, combined, in multi-horizon mode)."""
        if self.data_manager.multi_hor:
            return self.combine_dataloader(self.dataset_test)
        else:
            return DataLoader(self.dataset_test, batch_size=self.test_batch_size, num_workers=1)

    def predict_dataloader(self):
        """Return the prediction loader over the test data."""
        if self.data_manager.multi_hor:
            # In multi-horizon mode dataset_test is a dict keyed by horizon and cannot be
            # wrapped in a single DataLoader; mirror test_dataloader instead.
            return self.combine_dataloader(self.dataset_test)
        return DataLoader(self.dataset_test, batch_size=self.test_batch_size, num_workers=0)

    def combine_dataloader(self, dataset_dict):
        """Build one loader per entry of ``dataset_dict`` and chain them sequentially."""
        dataloader_dict = {
            hor: DataLoader(
                dataset_dict[hor],
                batch_size=self.test_batch_size,
                num_workers=0,
                persistent_workers=False,
            )
            for hor in dataset_dict
        }
        return CombinedLoader(dataloader_dict, mode="sequential")

    def train_collate_fn(self, batch):
        '''
        Training with varied horizons is achieved by padding horizons in the training phase.
        The look-back window of each sample can differ within a batch.

        Past fields are right-aligned (zero-padded at the front) to the longest past
        window in the batch; future fields are left-aligned (zero-padded at the end) to
        the longest horizon. The unpadded lengths are returned per sample in
        ``context_length`` / ``prediction_length``.
        '''
        past_fields_3d = ['past_target_cdf', 'past_observed_values', 'past_time_feat']
        past_fields = past_fields_3d + ['past_is_pad']
        future_fields = ['future_target_cdf', 'future_observed_values', 'future_time_feat']

        max_past_length = max(len(x['past_target_cdf']) for x in batch)
        max_future_length = max(len(x['future_target_cdf']) for x in batch)

        B = len(batch)
        batch_dict = {}
        batch_dict['context_length'] = []
        batch_dict['prediction_length'] = []
        batch_dict['target_dimension_indicator'] = []

        for idx in range(len(batch)):
            local_past_len = len(batch[idx]['past_target_cdf'])
            local_future_len = len(batch[idx]['future_target_cdf'])

            for name in ProbTSBatchData.input_names_:
                # Size of the last axis, taken from the first sample of the batch.
                K = batch[0][name].shape[-1]

                if name in past_fields:
                    # Buffers are allocated lazily the first time a field is seen.
                    if name not in batch_dict and name in past_fields_3d:
                        batch_dict[name] = torch.zeros([B, max_past_length, K])
                    if name not in batch_dict and name in ['past_is_pad']:
                        batch_dict[name] = torch.zeros([B, max_past_length])
                    batch_dict[name][idx][-local_past_len:] = torch.tensor(batch[idx][name])[:local_past_len]
                elif name in future_fields:
                    if name not in batch_dict:
                        batch_dict[name] = torch.zeros([B, max_future_length, K])
                    batch_dict[name][idx][:local_future_len] = torch.tensor(batch[idx][name])[:local_future_len]

            batch_dict['target_dimension_indicator'].append(batch[idx]['target_dimension_indicator'])
            batch_dict['context_length'].append(local_past_len)
            batch_dict['prediction_length'].append(local_future_len)

        batch_dict['target_dimension_indicator'] = torch.tensor(batch_dict['target_dimension_indicator'])
        batch_dict['max_context_length'] = max_past_length
        batch_dict['max_prediction_length'] = max_future_length
        return batch_dict
================================================
FILE: probts/data/data_utils/data_scaler.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PyTorch-TS
# - Source: https://github.com/zalandoresearch/pytorch-ts
# - License: MIT, Apache-2.0 license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
class Scaler:
    """Common interface of the data scalers; every operation must be provided by a subclass."""

    def __init__(self):
        super().__init__()

    def fit(self, values):
        """Estimate the scaling statistics from ``values``."""
        raise NotImplementedError()

    def transform(self, values):
        """Scale ``values`` with the fitted statistics."""
        raise NotImplementedError()

    def fit_transform(self, values):
        """Fit on ``values`` and return them scaled."""
        raise NotImplementedError()

    def inverse_transform(self, values):
        """Map scaled ``values`` back to the original range."""
        raise NotImplementedError()
class StandardScaler(Scaler):
    def __init__(
        self,
        mean: float = None,
        std: float = None,
        epsilon: float = 1e-9,
        var_specific: bool = True
    ):
        """
        The class can be used to normalize PyTorch Tensors using native functions. The module does not expect the
        tensors to be of any specific shape; as long as the features are the last dimension in the tensor, the module
        will work fine.

        Args:
            mean: The mean of the features (Python number or tensor). The property will be set after a call to fit.
            std: The standard deviation of the features (Python number or tensor). The property will be set after a call to fit.
            epsilon: Used to avoid a Division-By-Zero exception.
            var_specific: If True, the mean and standard deviation will be computed per variate.
        """
        super().__init__()
        self.mean = mean
        self.scale = std
        self.epsilon = epsilon
        self.var_specific = var_specific

    @staticmethod
    def _on_device(stat, device):
        """Move a tensor statistic to ``device``; plain Python numbers are returned unchanged."""
        return stat.to(device) if isinstance(stat, torch.Tensor) else stat

    def fit(self, values):
        """
        Compute and store the mean and standard deviation of ``values``.

        Args:
            values: Input values should be a PyTorch tensor of shape (T, C) or (N, T, C),
                where N is the batch size, T is the timesteps and C is the number of variates.
        """
        # Reduce over every axis except the last (feature) one.
        dims = list(range(values.dim() - 1))
        if not self.var_specific:
            self.mean = torch.mean(values)
            self.scale = torch.std(values)
        else:
            self.mean = torch.mean(values, dim=dims)
            self.scale = torch.std(values, dim=dims)

    def transform(self, values):
        """Standardise ``values`` and return them as float32; a no-op while no mean is set."""
        if self.mean is None:
            return values
        mean = self._on_device(self.mean, values.device)
        scale = self._on_device(self.scale, values.device)
        values = (values - mean) / (scale + self.epsilon)
        return values.to(torch.float32)

    def fit_transform(self, values):
        """Fit on ``values`` and return them standardised."""
        self.fit(values)
        return self.transform(values)

    def inverse_transform(self, values):
        """Undo ``transform``; a no-op while no mean is set."""
        if self.mean is None:
            return values
        mean = self._on_device(self.mean, values.device)
        scale = self._on_device(self.scale, values.device)
        values = values * (scale + self.epsilon)
        values = values + mean
        return values
class TemporalScaler(Scaler):
    def __init__(
        self,
        minimum_scale: float = 1e-10,
        time_first: bool = True
    ):
        """
        Mean-absolute-value scaler.

        For every item the scale is the average absolute value over time, taken only
        over observed entries. Items whose observed absolute sum is zero fall back to
        a scale computed from all items together.

        Args:
            minimum_scale: lower bound applied to the scale (used e.g. for all-zero series).
            time_first: if True, the input tensor has shape (N, T, C), otherwise (N, C, T).
        """
        super().__init__()
        self.scale = None
        self.minimum_scale = torch.tensor(minimum_scale)
        self.time_first = time_first

    def fit(
        self,
        data: torch.Tensor,
        observed_indicator: torch.Tensor = None
    ):
        """
        Compute the scale from ``data`` and store it in ``self.scale``.

        Args:
            data: tensor of shape (N, T, C) if ``time_first == True`` or (N, C, T)
                if ``time_first == False`` containing the data to be scaled
            observed_indicator: binary tensor with the same shape as ``data``, that has
                1 in correspondence of observed data points, and 0 in correspondence of
                missing data points. Defaults to "everything observed".

        Note:
            Nothing is returned; ``self.scale`` keeps the reduced time axis as a
            size-1 dimension, i.e. (N, 1, C) or (N, C, 1) for the shapes above.
        """
        time_axis = -2 if self.time_first else -1
        mask = torch.ones_like(data) if observed_indicator is None else observed_indicator

        # Per-item statistics: the time axis is reduced away.
        count_per_item = mask.sum(dim=time_axis)
        abs_sum_per_item = (data.abs() * mask).sum(dim=time_axis)

        # Fallback scale pooled over the first axis; denominators are clamped to >= 1.
        count_total = count_per_item.sum(dim=0)
        fallback_scale = abs_sum_per_item.sum(dim=0) / torch.max(count_total, torch.ones_like(count_total))

        # Per-item, per-dimension scale.
        item_scale = abs_sum_per_item / torch.max(count_per_item, torch.ones_like(count_per_item))

        # Items with nothing observed (or only zeros) take the pooled fallback instead.
        has_signal = abs_sum_per_item > torch.zeros_like(abs_sum_per_item)
        chosen = torch.where(
            has_signal,
            item_scale,
            fallback_scale * torch.ones_like(count_per_item),
        )
        bounded = torch.max(chosen, self.minimum_scale)
        self.scale = bounded.unsqueeze(dim=time_axis).detach()

    def transform(self, data):
        """Divide ``data`` by the fitted scale."""
        scale = self.scale.to(data.device)
        return data / scale

    def fit_transform(self, data, observed_indicator=None):
        """Fit on ``data`` and return it scaled."""
        self.fit(data, observed_indicator)
        return self.transform(data)

    def inverse_transform(self, data):
        """Multiply ``data`` by the fitted scale."""
        scale = self.scale.to(data.device)
        return data * scale
class IdentityScaler(Scaler):
    """
    No scaling is applied upon calling the ``IdentityScaler``.

    Every operation returns its input unchanged. ``time_first`` is accepted for
    signature compatibility with the other scalers but is not used.
    """

    def __init__(self, time_first: bool = True):
        super().__init__()
        self.scale = None

    def fit(self, data, observed_indicator=None):
        """Nothing to estimate; ``observed_indicator`` is accepted (and ignored) like in ``TemporalScaler.fit``."""
        pass

    def transform(self, data):
        """Return ``data`` unchanged."""
        return data

    def fit_transform(self, data, observed_indicator=None):
        """Return ``data`` unchanged (previously fell through to the base class and raised ``NotImplementedError``)."""
        self.fit(data, observed_indicator)
        return self.transform(data)

    def inverse_transform(self, data):
        """Return ``data`` unchanged."""
        return data
class InstanceNorm(nn.Module):
    """
    Normalizes an input with statistics computed from that same input and can
    later undo the normalization with the statistics it remembered.
    """

    def __init__(self, eps=1e-5):
        """
        :param eps: a value added for numerical stability
        """
        super().__init__()
        self.eps = eps

    def forward(self, x, mode: str):
        """
        :param x: input tensor
        :param mode: ``'norm'`` to compute statistics from ``x`` and normalize it,
            ``'denorm'`` to reverse the normalization using the stored statistics
        :raises NotImplementedError: for any other ``mode``
        """
        if mode == 'norm':
            self._get_statistics(x)
            return self._normalize(x)
        if mode == 'denorm':
            return self._denormalize(x)
        raise NotImplementedError

    def _get_statistics(self, x):
        """Store mean and standard deviation over every axis except the first and last."""
        reduce_dims = tuple(range(1, x.ndim - 1))
        mean = torch.mean(x, dim=reduce_dims, keepdim=True)
        variance = torch.var(x, dim=reduce_dims, keepdim=True, unbiased=False)
        self.mean = mean.detach()
        # eps is added under the square root so the divisor is never zero.
        self.stdev = torch.sqrt(variance + self.eps).detach()

    def _normalize(self, x):
        """Center by the stored mean, then divide by the stored standard deviation."""
        centered = x - self.mean
        return centered / self.stdev

    def _denormalize(self, x):
        """Multiply by the stored standard deviation, then add the stored mean."""
        rescaled = x * self.stdev
        return rescaled + self.mean
================================================
FILE: probts/data/data_utils/data_utils.py
================================================
from copy import deepcopy
import math
import pandas as pd
import numpy as np
from datetime import datetime
from distutils.util import strtobool
from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName
def split_train_val(train_set, num_test_windows, context_length, prediction_length, freq):
    """
    Splits a training dataset into a truncated training set and a validation set.

    The last ``num_test_windows * prediction_length`` steps of every series are
    removed from the training set and used to build rolling validation windows.

    Parameters:
    - train_set: The input training dataset (iterable of entries with a target array).
    - num_test_windows: Number of rolling windows for validation.
    - context_length: Context length for the model.
    - prediction_length: Prediction horizon for the model.
    - freq: Data frequency (e.g., 'H' for hourly).

    Returns:
    - trunc_train_set: Truncated training dataset (ListDataset).
    - val_set: Validation dataset (ListDataset).

    Raises:
    - ValueError: if a target array is neither 1- nor 2-dimensional.
    """
    trunc_train_list = []
    val_set_list = []
    univariate = False
    # Number of trailing steps held out of every series; identical for all series.
    offset = num_test_windows * prediction_length

    for train_seq in iter(train_set):
        target = train_seq[FieldName.TARGET]

        # truncate train set
        trunc_train_seq = deepcopy(train_seq)
        if len(target.shape) == 1:
            trunc_train_len = target.shape[0] - offset
            trunc_train_seq[FieldName.TARGET] = target[:trunc_train_len]
            univariate = True
        elif len(target.shape) == 2:
            trunc_train_len = target.shape[1] - offset
            trunc_train_seq[FieldName.TARGET] = target[:, :trunc_train_len]
        else:
            raise ValueError(f"Invalid Data Shape: {str(len(target.shape))}")

        trunc_train_list.append(trunc_train_seq)

        # construct val set by rolling
        for i in range(num_test_windows):
            val_seq = deepcopy(train_seq)
            rolling_len = trunc_train_len + prediction_length * (i + 1)
            if univariate:
                # Keep only the trailing context + 2 * prediction steps of the window.
                # Clamp at 0: for short series the start would otherwise be negative
                # and the slice would silently wrap around from the end of the array.
                window_start = max(
                    0, trunc_train_len + prediction_length * (i - 1) - context_length
                )
                val_seq[FieldName.TARGET] = val_seq[FieldName.TARGET][window_start:rolling_len]
            else:
                val_seq[FieldName.TARGET] = val_seq[FieldName.TARGET][:, :rolling_len]

            val_set_list.append(val_seq)

    trunc_train_set = ListDataset(
        trunc_train_list, freq=freq, one_dim_target=univariate
    )

    val_set = ListDataset(
        val_set_list, freq=freq, one_dim_target=univariate
    )

    return trunc_train_set, val_set
def truncate_test(test_set, context_length, prediction_length, freq):
    """
    Keep only the tail of every test series.

    Each target is sliced along its first axis to its last
    ``context_length + 2 * prediction_length`` steps and the result is wrapped
    in a ListDataset built with ``one_dim_target=True``.

    Parameters:
    - test_set: The input test dataset.
    - context_length: Context length for the model.
    - prediction_length: Prediction horizon for the model.
    - freq: Data frequency.

    Returns:
    - trunc_test_set: Truncated test dataset (ListDataset).
    """
    keep_steps = prediction_length * 2 + context_length

    clipped_entries = []
    for entry in iter(test_set):
        # truncate test series; work on a copy so the input entry stays intact
        clipped = deepcopy(entry)
        clipped[FieldName.TARGET] = clipped[FieldName.TARGET][-keep_steps:]
        clipped_entries.append(clipped)

    return ListDataset(clipped_entries, freq=freq, one_dim_target=True)
def get_rolling_test(stage, test_set, border_begin_idx, border_end_idx, rolling_length, pred_len, freq=None):
    """
    Build a rolling-window evaluation set from the first entry of ``test_set``.

    Window ``i`` keeps the target columns up to
    ``border_begin_idx + pred_len + i * rolling_length``.

    Parameters:
    - stage: Stage name (e.g., 'test', 'val'); only used in the printed summary.
    - test_set: The test dataset; only its first entry is used.
    - border_begin_idx: Start index for rolling windows.
    - border_end_idx: End index for rolling windows.
    - rolling_length: Gap length of each rolling window.
    - pred_len: Prediction length.
    - freq: Data frequency.

    Returns:
    - rolling_test_set: Rolling test dataset (ListDataset, two-dimensional targets).
    """
    usable_span = border_end_idx - border_begin_idx - pred_len
    num_test_windows = math.ceil(usable_span / rolling_length)
    print(f"{stage} pred_len: {pred_len} : num_test_windows: {num_test_windows}")

    base_entry = next(iter(test_set))
    first_end = border_begin_idx + pred_len

    windows = []
    for window_idx in range(num_test_windows):
        window_entry = deepcopy(base_entry)
        window_end = first_end + window_idx * rolling_length
        window_entry[FieldName.TARGET] = window_entry[FieldName.TARGET][:, :window_end]
        windows.append(window_entry)

    return ListDataset(windows, freq=freq, one_dim_target=False)
def get_rolling_test_of_gift_eval(dataset, prediction_length, windows):
    """
    Using rolling windows to build the test dataset for GiftEval.
    https://github.com/SalesforceAIResearch/gift-eval/blob/61ec5e563188bc4b2d7e86f6a7fcc78270607ae7/src/gift_eval/data.py#L213

    Only the first entry of ``dataset`` is used. Windows are cut from the back
    of the series; with N time points, window ``i`` (0-based) keeps time points
    up to ``N - prediction_length * (windows - i)``:
    - The first window ends at N - prediction_length * windows.
    - The last window ends at N - prediction_length.

    Parameters:
    - dataset: The input dataset; its first entry must carry a 'freq' key.
    - prediction_length: Prediction length.
    - windows: Number of rolling windows.

    Returns:
    - rolling_test_set: Rolling test dataset (ListDataset).

    Raises:
    - ValueError: if the entry has no 'freq' key, or if a window is built from
      a target with more than two dimensions.
    """
    windowed_entries = list()
    base_entry = next(iter(dataset))
    if "freq" not in base_entry.keys():
        raise ValueError("The dataset must contain the 'freq' key.")
    freq = base_entry["freq"]

    full_target = base_entry[FieldName.TARGET]
    target_rank = len(full_target.shape)
    is_univariate = target_rank == 1

    for window_idx in range(windows):
        window_entry = deepcopy(base_entry)
        remaining_windows = windows - window_idx
        window_end = full_target.shape[-1] - prediction_length * remaining_windows
        if is_univariate:
            window_entry[FieldName.TARGET] = full_target[:window_end]
        elif target_rank == 2:
            window_entry[FieldName.TARGET] = full_target[:, :window_end]
        else:
            raise ValueError(f"Invalid Data Shape: expected 1 or 2 dimensions, got {target_rank}")
        windowed_entries.append(window_entry)

    return ListDataset(windowed_entries, freq=freq, one_dim_target=is_univariate)
def df_to_mvds(df, freq='H'):
    """
    Turn every column of a DataFrame into one entry of a GluonTS ListDataset.

    Parameters:
    - df: Input DataFrame where columns represent time series variables; the
      first index value is used (as a string) for every entry's start.
    - freq: Data frequency (e.g., 'H' for hourly).

    Returns:
    - dataset: ListDataset with one entry per column.
    """
    start = str(df.index[0])
    entries = [
        {"item_id": column, "target": df[column], "start": start}
        for column in df.keys()
    ]
    return ListDataset(entries, freq=freq)
def _parse_bool_flag(text):
    """Parse a textual boolean, accepting the spellings ``distutils.util.strtobool`` did."""
    lowered = text.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {lowered!r}")


def convert_monash_data_to_dataframe(
    full_file_path_and_name,
    replace_missing_vals_with="NaN",
    value_column_name="series_value",
):
    """
    Parse a Monash ``.tsf`` file into a pandas DataFrame.

    The file is read line by line: lines starting with ``@`` carry meta-data
    (``@attribute``, ``@frequency``, ``@horizon``, ``@missing``,
    ``@equallength``, ``@data``), lines starting with ``#`` are ignored, and every
    other non-empty line after ``@data`` is one series written as
    ``attr_1:...:attr_k:v1,v2,...`` where ``?`` marks a missing value.

    Parameters:
    - full_file_path_and_name: Path of the ``.tsf`` file (opened with cp1252 encoding).
    - replace_missing_vals_with: Value stored in place of every ``?``.
    - value_column_name: Name of the DataFrame column that holds the series values.

    Returns:
    - tuple ``(loaded_data, frequency, forecast_horizon, contain_missing_values,
      contain_equal_length)``; the last four are ``None`` when the corresponding
      tag is absent from the file.

    Raises:
    - Exception: on an empty file, missing attribute/data sections, malformed
      meta-data, or a series whose values are all missing.
    - ValueError: if a ``@missing`` / ``@equallength`` flag is not a recognised boolean.
    """
    col_names = []
    col_types = []
    all_data = {}
    line_count = 0
    frequency = None
    forecast_horizon = None
    contain_missing_values = None
    contain_equal_length = None
    found_data_tag = False
    found_data_section = False
    started_reading_data_section = False

    with open(full_file_path_and_name, "r", encoding="cp1252") as file:
        for line in file:
            # Strip white space from start/end of line
            line = line.strip()

            if line:
                if line.startswith("@"):  # Read meta-data
                    if not line.startswith("@data"):
                        line_content = line.split(" ")
                        if line.startswith("@attribute"):
                            if (
                                len(line_content) != 3
                            ):  # Attributes have both name and type
                                raise Exception("Invalid meta-data specification.")

                            col_names.append(line_content[1])
                            col_types.append(line_content[2])
                        else:
                            if (
                                len(line_content) != 2
                            ):  # Other meta-data have only values
                                raise Exception("Invalid meta-data specification.")

                            if line.startswith("@frequency"):
                                frequency = line_content[1]
                            elif line.startswith("@horizon"):
                                forecast_horizon = int(line_content[1])
                            elif line.startswith("@missing"):
                                # Parsed locally: distutils (strtobool) is gone in Python 3.12+.
                                contain_missing_values = _parse_bool_flag(line_content[1])
                            elif line.startswith("@equallength"):
                                contain_equal_length = _parse_bool_flag(line_content[1])
                    else:
                        if len(col_names) == 0:
                            raise Exception(
                                "Missing attribute section. Attribute section must come before data."
                            )

                        found_data_tag = True
                elif not line.startswith("#"):
                    if len(col_names) == 0:
                        raise Exception(
                            "Missing attribute section. Attribute section must come before data."
                        )
                    elif not found_data_tag:
                        raise Exception("Missing @data tag.")
                    else:
                        if not started_reading_data_section:
                            # First series line: set up one list per attribute column.
                            started_reading_data_section = True
                            found_data_section = True
                            all_series = []

                            for col in col_names:
                                all_data[col] = []

                        full_info = line.split(":")

                        # One field per attribute plus the trailing value list.
                        if len(full_info) != (len(col_names) + 1):
                            raise Exception("Missing attributes/values in series.")

                        series = full_info[len(full_info) - 1]
                        series = series.split(",")

                        if len(series) == 0:
                            raise Exception(
                                "A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series. Missing values should be indicated with ? symbol"
                            )

                        numeric_series = []

                        for val in series:
                            if val == "?":
                                numeric_series.append(replace_missing_vals_with)
                            else:
                                numeric_series.append(float(val))

                        if numeric_series.count(replace_missing_vals_with) == len(
                            numeric_series
                        ):
                            raise Exception(
                                "All series values are missing. A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series."
                            )

                        all_series.append(pd.Series(numeric_series).array)

                        for i in range(len(col_names)):
                            att_val = None
                            if col_types[i] == "numeric":
                                att_val = int(full_info[i])
                            elif col_types[i] == "string":
                                att_val = str(full_info[i])
                            elif col_types[i] == "date":
                                att_val = datetime.strptime(
                                    full_info[i], "%Y-%m-%d %H-%M-%S"
                                )
                            else:
                                raise Exception(
                                    "Invalid attribute type."
                                )  # Currently, the code supports only numeric, string and date types. Extend this as required.

                            if att_val is None:
                                raise Exception("Invalid attribute value.")
                            else:
                                all_data[col_names[i]].append(att_val)

                # Only non-empty lines are counted.
                line_count = line_count + 1

        if line_count == 0:
            raise Exception("Empty file.")
        if len(col_names) == 0:
            raise Exception("Missing attribute section.")
        if not found_data_section:
            raise Exception("Missing series information under data section.")

        all_data[value_column_name] = all_series
        loaded_data = pd.DataFrame(all_data)

        return (
            loaded_data,
            frequency,
            forecast_horizon,
            contain_missing_values,
            contain_equal_length,
        )
def monash_format_convert(loaded_data, frequency, multivariate):
    """
    Reshape a loaded Monash DataFrame for downstream use.

    Parameters:
    - loaded_data: DataFrame with 'series_name', 'start_timestamp' and
      'series_value' columns.
    - frequency: '10_minutes' and 'daily' are mapped to the pandas aliases
      '10min' and 'D'; any other value is passed through unchanged.
    - multivariate: if True, return one wide DataFrame with a 'date' column
      followed by one column per series (the first series fixes the start
      and length of the date range); otherwise return one row per series
      with 'target', 'start', 'feat_static_cat' and 'item_id' fields.

    Returns:
    - result_df: the converted DataFrame.
    """
    series_names = loaded_data['series_name'].values

    alias_by_name = {'10_minutes': '10min', 'daily': 'D'}
    freq = alias_by_name.get(str(frequency), frequency)

    if multivariate:
        values = loaded_data['series_value']
        timestamps = pd.date_range(
            start=loaded_data['start_timestamp'][0],
            periods=len(values[0]),
            freq=freq,
        )
        date_frame = pd.DataFrame({'date': timestamps})
        value_frame = pd.DataFrame(
            {name: values[pos] for pos, name in enumerate(series_names)}
        )
        return pd.concat([date_frame, value_frame], axis=1)

    records = [
        {
            'target': np.array(row['series_value'], dtype=np.float32),
            'start': pd.Period(row['start_timestamp'], freq=freq),
            'feat_static_cat': np.array([idx], dtype=np.int32),
            'item_id': idx,
        }
        for idx, row in loaded_data.iterrows()
    ]
    return pd.DataFrame(records)
================================================
FILE: probts/data/data_utils/get_datasets.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from Autoformer
# - Source: https://github.com/thuml/Autoformer/tree/main
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import os
import pandas as pd
from probts.data.data_utils.time_features import time_features
from probts.data.data_utils.data_utils import convert_monash_data_to_dataframe, monash_format_convert
import numpy as np
def get_dataset_info(dataset, data_path=None, freq=None):
    """
    Resolve the data file and sampling frequency for a dataset name.

    For a built-in dataset name the predefined path and frequency are returned
    and any ``data_path`` / ``freq`` passed by the caller are ignored. For any
    other name both arguments must be supplied.

    Parameters:
        dataset (str): The name of the dataset.
        data_path (str): Data path, required for custom datasets.
        freq (str): Frequency, required for custom datasets.

    Returns:
        tuple: A tuple containing the data path and frequency.
    """
    builtin = {
        'etth1': ('ETT-small/ETTh1.csv', 'H'),
        'etth2': ('ETT-small/ETTh2.csv', 'H'),
        'ettm1': ('ETT-small/ETTm1.csv', 'min'),
        'ettm2': ('ETT-small/ETTm2.csv', 'min'),
        'traffic_ltsf': ('traffic/traffic.csv', 'H'),
        'electricity_ltsf': ('electricity/electricity.csv', 'H'),
        'exchange_ltsf': ('exchange_rate/exchange_rate.csv', 'B'),
        'illness_ltsf': ('illness/national_illness.csv', 'W'),
        'weather_ltsf': ('weather/weather.csv', 'min'),
        'caiso': ('caiso/caiso_20130101_20210630.csv', 'H'),
        'nordpool': ('nordpool/production.csv', 'H'),
        'turkey_power': ('kaggle/power Generation and consumption.csv', 'H'),
        'istanbul_traffic': ('kaggle/istanbul_traffic.csv', 'H')
    }

    if dataset not in builtin:
        # Custom dataset: the caller has to describe it completely.
        assert data_path is not None, f'Invalid dataset name: {dataset}! Provide --data.data_manager.init_args.data_path for custom datasets.'
        assert freq is not None, 'Provide --data.data_manager.init_args.freq for custom datasets.'
        return data_path, freq

    data_path, freq = builtin[dataset]
    return data_path, freq
def get_dataset_borders(dataset, data_size, train_ratio=0.7, test_ratio=0.2):
    """
    Compute the start and end indices for train, validation, and test splits.

    The ETT datasets use fixed borders of 12 / 4 / 4 blocks of 30 days
    (``data_size`` and the ratios are not used for them, although the ratios
    are still validated). Every other dataset is split by ratio, with the
    validation split taking whatever remains between train and test.

    Parameters:
        dataset (str): The name of the dataset.
        data_size (int): Total number of time points in the dataset.
        train_ratio (float): Proportion of the dataset used for training.
        test_ratio (float): Proportion of the dataset used for testing.

    Returns:
        tuple: Two lists representing the start and end indices of each split.
    """
    # Validate ratios
    assert 0 < train_ratio <= 1, "train_ratio must be between 0 and 1 (exclusive of 0)."
    assert 0 < test_ratio <= 1, "test_ratio must be between 0 and 1 (exclusive of 0)."
    assert train_ratio + test_ratio <= 1, "The sum of train_ratio and test_ratio must not exceed 1."

    # Samples per hour for the datasets with predefined borders.
    if dataset in ('etth1', 'etth2'):
        steps_per_hour = 1
    elif dataset in ('ettm1', 'ettm2'):
        steps_per_hour = 4
    else:
        steps_per_hour = None

    if steps_per_hour is not None:
        steps_per_block = 30 * 24 * steps_per_hour
        train_end = 12 * steps_per_block
        val_end = train_end + 4 * steps_per_block
        test_end = train_end + 8 * steps_per_block
        return [0, train_end, val_end], [train_end, val_end, test_end]

    # Calculate borders for custom datasets
    num_train = int(data_size * train_ratio)
    num_test = int(data_size * test_ratio)
    num_vali = data_size - num_train - num_test
    test_begin = data_size - num_test
    return [0, num_train, test_begin], [num_train, num_train + num_vali, data_size]
def load_dataset(root_path, data_path,freq='h', timeenc=1, multivariate=True):
    """
    Load and process datasets.

    The loader is chosen from substrings of ``data_path`` ('.tsf', 'caiso',
    'nordpool', 'power Generation and consumption', 'istanbul_traffic');
    anything else is read as a CSV file with a 'date' column.

    Parameters:
        root_path (str): Root directory for datasets (not applied to '.tsf' paths,
            which are opened as given).
        data_path (str): Path to the specific dataset.
        freq (str): Frequency of the dataset (e.g., 'H', 'min').
        timeenc (int): Time encoding method: 0 for integer calendar fields
            (month, day, weekday, hour, quarter-hour), 1 for frequency-based
            time features, 2 for raw timestamps as ``datetime64[s]``.
        multivariate (bool): Whether the dataset is multivariate.

    Returns:
        df_raw: the processed DataFrame (indexed by 'date' when multivariate)
        data_stamp: time features, or None when not multivariate
        target_dim: target dimensions
        data_size: total length of timestamps.

    Raises:
        ValueError: if ``timeenc`` is not 0, 1 or 2 (multivariate only).
    """
    data_format = None
    if '.tsf' in data_path:
        # Load Monash time series dataset
        df_raw, _, _, _, _ = convert_monash_data_to_dataframe(data_path)
        df_raw = monash_format_convert(df_raw, freq, multivariate)
        if multivariate:
            if freq.lower() == 'h':
                df_raw.set_index('date', inplace=True)
                df_raw = df_raw.resample(freq).mean().reset_index()
    elif 'caiso' in data_path:
        # Load and process CAISO dataset
        data = pd.read_csv(os.path.join(root_path, data_path))
        data['Date'] = data['Date'].astype('datetime64[ns]')
        names = ['PGE','SCE','SDGE','VEA','CA ISO','PACE','PACW','NEVP','AZPS','PSEI']
        df_raw = pd.DataFrame(pd.date_range('20130101','20210630',freq='H')[:-1], columns=['Date'])
        # One column per zone, aligned on the shared hourly 'Date' column.
        for name in names:
            current_df = data[data['zone'] == name].drop_duplicates(subset='Date', keep='last').rename(columns={'load':name}).drop(columns=['zone'])
            df_raw = df_raw.merge(current_df, on='Date', how='outer')
        df_raw = df_raw.rename(columns={'Date': 'date'})
    elif 'nordpool' in data_path:
        # Load and process Nordpool dataset
        df_raw = pd.read_csv(os.path.join(root_path, data_path), parse_dates=['Time'])
        df_raw = df_raw.rename(columns={'Time': 'date'})
    elif 'power Generation and consumption' in data_path:
        # Load and process Turkey Power dataset
        df_raw = pd.read_csv(os.path.join(root_path, data_path), parse_dates=['Date_Time'])
        df_raw = df_raw.rename(columns={'Date_Time': 'date'})
        data_format = "%d.%m.%Y %H:%M"
    elif 'istanbul_traffic' in data_path:
        # Load and process Istanbul Traffic dataset
        df_raw = pd.read_csv(os.path.join(root_path, data_path), parse_dates=['datetime'])
        df_raw = df_raw.rename(columns={'datetime': 'date'})
        df_raw.set_index('date', inplace=True)
        df_raw = df_raw.resample(freq).mean().reset_index()
    else:
        # Load customized dataset
        df_raw = pd.read_csv(os.path.join(root_path, data_path), parse_dates=['date'])

    # Process time encoding
    if multivariate:
        # Explicit copy: columns are added below, and writing into a slice of
        # df_raw triggers pandas' chained-assignment warning.
        df_stamp = df_raw[['date']].copy()
        df_stamp['date'] = pd.to_datetime(df_stamp.date, format=data_format)
        if timeenc == 0:
            # Vectorised calendar fields (cast to int64, the dtype the previous
            # element-wise apply produced).
            stamp = df_stamp['date'].dt
            df_stamp['month'] = stamp.month.astype('int64')
            df_stamp['day'] = stamp.day.astype('int64')
            df_stamp['weekday'] = stamp.weekday.astype('int64')
            df_stamp['hour'] = stamp.hour.astype('int64')
            # Minutes are bucketed into quarter hours (0-3).
            df_stamp['minute'] = stamp.minute.astype('int64') // 15
            data_stamp = df_stamp.drop(labels='date', axis=1).values
        elif timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=freq)
            data_stamp = data_stamp.transpose(1, 0)
        elif timeenc == 2:
            data_stamp = pd.to_datetime(df_stamp['date'].values)
            data_stamp = np.array(data_stamp, dtype='datetime64[s]')
        else:
            raise ValueError('Invalid timeenc value. timeenc should be sellected within [0, 1, 2].')
        df_raw = df_raw.set_index(keys='date')
    else:
        data_stamp = None

    # Replace missing values with 0
    df_raw = df_raw.fillna(0)

    # Determine target dimension and dataset size
    target_dim = len(df_raw.columns) if multivariate else 1
    data_size = len(df_raw)
    return df_raw, data_stamp, target_dim, data_size
================================================
FILE: probts/data/data_utils/time_features.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from GluonTS
# - Source: https://github.com/awslabs/gluonts
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
from typing import List
import numpy as np
import pandas as pd
from pandas.tseries import offsets
from pandas.tseries.frequencies import to_offset
from gluonts.core.component import validated
from gluonts.dataset.common import DataEntry
from gluonts.transform import MapTransformation
from typing import List, Type
class TimeFeature:
    """Base class for callables that turn a datetime index into a numeric feature."""

    def __init__(self):
        return None

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Placeholder: concrete features override this; the base returns None.
        return None

    def __repr__(self):
        return f"{self.__class__.__name__}()"
class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        fraction = index.second / 59.0
        return fraction - 0.5
class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        fraction = index.minute / 59.0
        return fraction - 0.5
class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        fraction = index.hour / 23.0
        return fraction - 0.5
class DayOfWeek(TimeFeature):
    """Day of week encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        fraction = index.dayofweek / 6.0
        return fraction - 0.5
class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Days are 1-based; shift to 0-based before scaling.
        zero_based = index.day - 1
        return zero_based / 30.0 - 0.5
class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Days are 1-based; shift to 0-based before scaling.
        zero_based = index.dayofyear - 1
        return zero_based / 365.0 - 0.5
class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Months are 1-based; shift to 0-based before scaling.
        zero_based = index.month - 1
        return zero_based / 11.0 - 0.5
class WeekOfYear(TimeFeature):
    """Week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # ISO week numbers are 1-based; shift to 0-based before scaling.
        iso_week = index.isocalendar().week
        return (iso_week - 1) / 52.0 - 0.5
def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Pick the time features that suit a sampling frequency.

    The frequency string is parsed into a pandas offset and matched, in order,
    against a table of offset types; the features of the first match are
    instantiated and returned.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

    Raises
    ------
    RuntimeError
        If the parsed offset matches none of the supported offset types.
    """
    day_level = [DayOfWeek, DayOfMonth, DayOfYear]
    hour_level = [HourOfDay] + day_level
    minute_level = [MinuteOfHour] + hour_level
    second_level = [SecondOfMinute] + minute_level

    # Ordered (offset type, feature classes) pairs; the first isinstance match wins.
    feature_table = (
        (offsets.YearEnd, []),
        (offsets.QuarterEnd, [MonthOfYear]),
        (offsets.MonthEnd, [MonthOfYear]),
        (offsets.Week, [DayOfMonth, WeekOfYear]),
        (offsets.Day, day_level),
        (offsets.BusinessDay, day_level),
        (offsets.Hour, hour_level),
        (offsets.Minute, minute_level),
        (offsets.Second, second_level),
    )

    parsed_offset = to_offset(freq_str)
    for offset_cls, feature_types in feature_table:
        if isinstance(parsed_offset, offset_cls):
            return [feature_type() for feature_type in feature_types]

    supported_freq_msg = f"""
Unsupported frequency {freq_str}
The following frequencies are supported:
Y - yearly
alias: A
M - monthly
W - weekly
D - daily
B - business days
H - hourly
T - minutely
alias: min
S - secondly
"""
    raise RuntimeError(supported_freq_msg)
def time_features(dates, freq='h'):
    """Stack the features selected for ``freq`` into one array, one row per feature."""
    extractors = time_features_from_frequency_str(freq)
    rows = [extract(dates) for extract in extractors]
    return np.vstack(rows)
class FourierDateFeatures(TimeFeature):
    """
    Encodes one calendar attribute of a datetime index as a (cos, sin) pair.

    ``freq`` names the attribute to read from the index (e.g. "hour"); the
    attribute values are mapped to angles and returned as two stacked rows.
    """

    def __init__(self, freq: str) -> None:
        super().__init__()
        # reocurring freq
        freqs = [
            "month",
            "day",
            "hour",
            "minute",
            "weekofyear",
            "weekday",
            "dayofweek",
            "dayofyear",
            "daysinmonth",
        ]
        assert freq in freqs
        self.freq = freq

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        if self.freq == "weekofyear" and not hasattr(index, "weekofyear"):
            # pandas 2.0 removed DatetimeIndex.weekofyear; the ISO calendar
            # week is its documented replacement.
            values = np.asarray(index.isocalendar().week, dtype=np.int64)
        else:
            values = getattr(index, self.freq)
        # The period is taken from the data itself: largest observed value + 1.
        num_values = max(values) + 1
        steps = [x * 2.0 * np.pi / num_values for x in values]
        # Row 0 holds the cosines, row 1 the sines.
        return np.vstack([np.cos(steps), np.sin(steps)])
def norm_freq_str(freq_str: str) -> str:
    """
    Reduce a pandas frequency name to its base form.

    Everything from the first "-" onwards is dropped, and a trailing "S" is
    removed from names of two or more characters.
    """
    base_freq, _, _ = freq_str.partition("-")
    # Pandas has start and end frequencies, e.g `AS` and `A` for yearly start
    # and yearly end frequencies. We don't make that difference and instead
    # rely only on the end frequencies which don't have the `S` prefix.
    # Note: Secondly ("S") frequency exists, where we don't want to remove the
    # "S"! Hence the length check.
    has_start_marker = len(base_freq) >= 2 and base_freq.endswith("S")
    return base_freq[:-1] if has_start_marker else base_freq
def fourier_time_features_from_frequency(freq_str: str) -> List[TimeFeature]:
    """
    Select the Fourier date features that suit a sampling frequency.

    The frequency is parsed with pandas, normalized with ``norm_freq_str`` and
    upper-cased before the lookup, so the table below uses upper-case keys only.

    Raises
    ------
    AssertionError
        If the normalized granularity has no entry in the table.
    """
    offset = to_offset(freq_str)
    granularity = norm_freq_str(offset.name)
    granularity = granularity.upper()

    minutely = ["minute", "hour", "dayofweek"]
    features = {
        "M": ["weekofyear"],
        # Newer pandas names the month-end offset "ME" instead of "M".
        "ME": ["weekofyear"],
        "W": ["daysinmonth", "weekofyear"],
        "D": ["dayofweek"],
        "B": ["dayofweek", "dayofyear"],
        "H": ["hour", "dayofweek"],
        # "min" is upper-cased above, so it must be listed as "MIN" to be reachable.
        "MIN": minutely,
        "T": minutely,
    }

    assert granularity in features, f"freq {granularity} not supported"

    feature_classes: List[TimeFeature] = [
        FourierDateFeatures(freq=freq) for freq in features[granularity]
    ]
    return feature_classes
def get_lags(freq_str:str):
    """
    Calculate appropriate lag values for time series forecasting based on data frequency.

    Parameters
    ----------
    freq_str : str
        The frequency of the time series data, matched case-insensitively.
        Supported values include "M", "D", "B", "H" and "T" / "min";
        any other value falls back to a single lag of 1.

    Returns
    -------
    lags : list[int]
        A list of lag values, representing the offsets of past observations to include in the model.
        The lags are tailored to capture autocorrelation and seasonality patterns for the specified frequency.

    Examples
    --------
    >>> get_lags("H")
    [1, 24, 168]  # Captures hourly, daily, and weekly seasonality
    >>> get_lags("D")
    [1, 7, 14]    # Captures daily, weekly, and bi-weekly seasonality
    >>> get_lags("min")
    [1, 4, 12, 24, 48]
    """
    freq_str = freq_str.upper()
    if freq_str == "M":
        lags = [1, 12]
    elif freq_str == "D":
        lags = [1, 7, 14]
    elif freq_str == "B":
        lags = [1, 2]
    elif freq_str == "H":
        lags = [1, 24, 168]
    # The input was upper-cased above, so the "min" alias must be compared as "MIN".
    elif freq_str in ("T", "MIN"):
        lags = [1, 4, 12, 24, 48]
    else:
        lags = [1]
    return lags
def target_transformation_length(
    target: np.ndarray, pred_length: int, is_train: bool
) -> int:
    """Return the size of ``target``'s last axis, extended by ``pred_length`` outside training."""
    observed_steps = target.shape[-1]
    if is_train:
        return observed_steps
    return observed_steps + pred_length
class AddCustomizedTimeFeatures(MapTransformation):
    """
    Adds a set of precomputed time features.

    If `is_train=True` the feature matrix has the same length as the `target`
    field. If `is_train=False` the feature matrix has length
    `len(target) + pred_length`

    Parameters
    ----------
    start_field
        Field with the start time stamp of the time series
    target_field
        Field with the array containing the time series values
    output_field
        Field name for result.
    time_features
        array of precomputed time features, sliced along its first axis.
    pred_length
        Prediction length
    dtype
        Type the features are cast to when ``time_features`` is two-dimensional.
    """

    @validated()
    def __init__(
        self,
        start_field: str,
        target_field: str,
        output_field: str,
        time_features,
        pred_length: int,
        dtype: Type = np.float32,
    ) -> None:
        self.date_features = time_features
        self.pred_length = pred_length
        self.start_field = start_field
        self.target_field = target_field
        self.output_field = output_field
        self.dtype = dtype

    def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
        """Write the first ``length`` rows of the features, transposed, into ``output_field``."""
        length = target_transformation_length(
            data[self.target_field], self.pred_length, is_train=is_train
        )

        if len(self.date_features.shape) == 2:
            data[self.output_field] = self.date_features[:length].astype(self.dtype)
        else:
            # Features that are not two-dimensional are always cast to float64,
            # regardless of ``dtype``.
            data[self.output_field] = self.date_features[:length].astype(np.float64)

        data[self.output_field] = np.transpose(data[self.output_field])
        return data
================================================
FILE: probts/data/data_wrapper.py
================================================
import torch
class ProbTSBatchData:
    """
    Attribute-style wrapper around one batch dictionary.

    Every key of the input dictionary becomes an attribute. Univariate targets
    get a trailing channel axis, tensors are moved to ``device``, inputs listed
    in ``input_names_`` that are absent are set to ``None``, and padded
    positions are marked as unobserved.
    """

    # Inputs that are guaranteed to exist as attributes (possibly as None).
    input_names_ = [
        'target_dimension_indicator',
        'past_time_feat',
        'past_target_cdf',
        'past_observed_values',
        'past_is_pad',
        'future_time_feat',
        'future_target_cdf',
        'future_observed_values',
    ]

    def __init__(self, data_dict, device):
        """
        :param data_dict: mapping of input names to values; must contain
            'past_target_cdf'.
        :param device: device the tensors are moved to.
        """
        # Initialize attributes from the provided data dictionary
        self.__dict__.update(data_dict)
        self.__dict__['context_length'] = data_dict.get('context_length', None)
        self.__dict__['prediction_length'] = data_dict.get('prediction_length', None)
        self.__dict__['max_context_length'] = data_dict.get('max_context_length', None)
        self.__dict__['max_prediction_length'] = data_dict.get('max_prediction_length', None)

        # Expand dimensions for univariate data
        if len(self.__dict__['past_target_cdf'].shape) == 2:
            self._expand_dimensions()

        # Set tensors to the specified device
        self._set_device(device)

        # Fill missing inputs with None
        self._ensure_all_inputs_present()

        # Process padding for observed values
        self._process_padding()

    def _ensure_all_inputs_present(self):
        """Ensure all expected inputs are present in the data."""
        for input in self.input_names_:
            if input not in self.__dict__:
                self.__dict__[input] = None

    def _set_device(self, device):
        """Move all tensors to the specified device."""
        # Tensor.to returns a new tensor instead of moving in place, so the
        # result has to be stored back. Iterate over a snapshot of the items
        # because the dictionary is written to inside the loop.
        for k, v in list(self.__dict__.items()):
            if v is not None and torch.is_tensor(v):
                self.__dict__[k] = v.to(device)
        self.device = device

    def _expand_dimensions(self):
        """Expand dimensions for target-related tensors if necessary."""
        # Keep a single dimension indicator and add a trailing channel axis.
        self.__dict__["target_dimension_indicator"] = self.__dict__["target_dimension_indicator"][:, :1]
        for input in ['past_target_cdf','past_observed_values','future_target_cdf','future_observed_values']:
            self.__dict__[input] = self.__dict__[input].unsqueeze(-1)

    def _process_padding(self):
        """Adjust observed values based on the padding indicator."""
        if self.__dict__['past_is_pad'] is not None:
            # A padded step (past_is_pad == 1) can never count as observed.
            self.__dict__["past_observed_values"] = torch.min(
                self.__dict__["past_observed_values"],
                1 - self.__dict__["past_is_pad"].unsqueeze(-1)
            )
================================================
FILE: probts/data/datasets/gift_eval_datasets.py
================================================
# Copyright (c) 2023, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
from functools import cached_property
from enum import Enum
from pathlib import Path
from typing import Iterable, Iterator
import datasets
from dotenv import load_dotenv
from gluonts.dataset import DataEntry
from gluonts.dataset.common import ProcessDataEntry
from gluonts.dataset.split import TestData, TrainingDataset, split
from gluonts.itertools import Map
from gluonts.time_feature import norm_freq_str
from gluonts.transform import Transformation
from pandas.tseries.frequencies import to_offset
import pyarrow.compute as pc
from toolz import compose
# add for probts transform
from probts.data.data_utils.data_utils import get_rolling_test_of_gift_eval
# Fraction of the shortest series length used when deriving the number of
# rolling windows (see GiftEvalDataset.windows).
TEST_SPLIT = 0.1
# Upper bound on the number of rolling windows.
MAX_WINDOW = 20
# Base prediction length per normalized frequency name, used for datasets
# whose name contains "m4".
M4_PRED_LENGTH_MAP = {
    "A": 6,
    "Q": 8,
    "M": 18,
    "W": 13,
    "D": 14,
    "H": 48,
}
# Base prediction length per normalized frequency name, used for all other datasets.
PRED_LENGTH_MAP = {
    "M": 12,
    "W": 8,
    "D": 30,
    "H": 48,
    "T": 48,
    "S": 60,
}
# NOTE(review): not referenced by the code visible in this part of the file;
# presumably the TFB benchmark prediction lengths — confirm before relying on it.
TFB_PRED_LENGTH_MAP = {
    "A": 6,
    "H": 48,
    "Q": 8,
    "D": 14,
    "M": 18,
    "W": 13,
    "U": 8,
    "T": 8,
}
class Term(Enum):
    """Forecast term; each term scales the base prediction length by ``multiplier``."""

    SHORT = "short"
    MEDIUM = "medium"
    LONG = "long"

    @property
    def multiplier(self) -> int:
        """Factor applied to the base prediction length: 1, 10 or 15."""
        # Built inside the property: a dict in the class body would become an enum member.
        factor_by_value = {"short": 1, "medium": 10, "long": 15}
        return factor_by_value.get(self.value)
def itemize_start(data_entry: DataEntry) -> DataEntry:
    """Replace the entry's "start" value with the result of its ``.item()`` call, in place."""
    start_value = data_entry["start"]
    data_entry["start"] = start_value.item()
    return data_entry
class MultivariateToUnivariate(Transformation):
    """Splits every multivariate entry into one univariate entry per dimension."""

    def __init__(self, field):
        """
        :param field: name of the entry field that holds the multivariate values.
        """
        self.field = field

    def __call__(
        self, data_it: Iterable[DataEntry], is_train: bool = False
    ) -> Iterator:
        """
        Yield one entry per dimension of ``field`` for every input entry.

        Each yielded entry carries a single dimension in ``field`` and an
        ``item_id`` of the form ``<original id>_dim<index>``.
        """
        for data_entry in data_it:
            item_id = data_entry["item_id"]
            val_ls = list(data_entry[self.field])
            for dim_index, val in enumerate(val_ls):
                # Yield a fresh shallow copy per dimension. Re-yielding the same
                # mutated dict would make every collected entry alias the last
                # dimension once the iterator is materialised.
                univariate_entry = dict(data_entry)
                univariate_entry[self.field] = val
                univariate_entry["item_id"] = item_id + "_dim" + str(dim_index)
                yield univariate_entry
class GiftEvalDataset:
def __init__(
    self,
    name: str,
    term: Term | str = Term.SHORT,
    to_univariate: bool = False,
    storage_env_var: str = "GIFT_EVAL",
):
    """
    Load a dataset stored on disk below the directory named by an environment variable.

    :param name: dataset name; also the sub-directory loaded from the storage root.
    :param term: forecast term, as a ``Term`` or its string value.
    :param to_univariate: whether multivariate entries are split per dimension.
    :param storage_env_var: name of the environment variable holding the storage root.
    :raises ValueError: if ``storage_env_var`` is not set.
    """
    self.term = Term(term)
    self.name = name
    self.to_univariate = to_univariate

    # Variables from a local .env file are loaded before the environment is read.
    load_dotenv()
    storage_root = os.getenv(storage_env_var)
    if storage_root is None:
        # Fail with an actionable message instead of an opaque Path(None) TypeError.
        raise ValueError(
            f"Environment variable '{storage_env_var}' is not set; "
            "it must point to the dataset storage directory."
        )
    storage_path = Path(storage_root)
    self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format(
        "numpy"
    )
@cached_property
def gluonts_dataset(self):
    """Dataset entries mapped through ``itemize_start`` and ``ProcessDataEntry``, optionally split per dimension."""
    freq = self.freq
    is_one_dim = self.target_dim == 1
    to_entry = ProcessDataEntry(freq, one_dim_target=is_one_dim)

    # compose applies right-to-left: itemize_start runs before ProcessDataEntry.
    pipeline = compose(to_entry, itemize_start)
    entries = Map(pipeline, self.hf_dataset)

    if not self.to_univariate:
        return entries
    return MultivariateToUnivariate("target").apply(entries)
@cached_property
def prediction_length(self) -> int:
freq = norm_freq_str(to_offset(self.freq).name)
pred_len = (
M4_PRED_LENGTH_MAP[freq] if "m4" in self.name else PRED_LENGTH_MAP[freq]
)
return self.term.multiplier * pred_len
@cached_property
def freq(self) -> str:
return self.hf_dataset[0]["freq"]
@cached_property
def target_dim(self) -> int:
return (
target.shape[0]
if len((target := self.hf_dataset[0]["target"]).shape) > 1
else 1
)
@cached_property
def target_ndim(self) -> int:
return 1 if self.target_dim == 1 else 2
@cached_property
def past_feat_dynamic_real_dim(self) -> int:
if "past_feat_dynamic_real" not in self.hf_dataset[0]:
return 0
elif (
len(
(
past_feat_dynamic_real := self.hf_dataset[0][
"past_feat_dynamic_real"
]
).shape
)
> 1
):
return past_feat_dynamic_real.shape[0]
else:
return 1
@cached_property
def windows(self) -> int:
if "m4" in self.name:
return 1
w = math.ceil(TEST_SPLIT * self._min_series_length / self.prediction_length)
return min(max(1, w), MAX_WINDOW)
@cached_property
def _min_series_length(self) -> int:
if self.hf_dataset[0]["target"].ndim > 1:
lengths = pc.list_value_length(
pc.list_flatten(
pc.list_slice(self.hf_dataset.data.column("target"), 0, 1)
)
)
else:
lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
return min(lengths.to_numpy())
@cached_property
def sum_series_length(self) -> int:
if self.hf_dataset[0]["target"].ndim > 1:
lengths = pc.list_value_length(
pc.list_flatten(self.hf_dataset.data.column("target"))
)
else:
lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
return sum(lengths.to_numpy())
@property
def training_dataset(self) -> TrainingDataset:
training_dataset, _ = split(
self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1)
)
return training_dataset
@property
def validation_dataset(self) -> TrainingDataset:
validation_dataset, _ = split(
self.gluonts_dataset, offset=-self.prediction_length * self.windows
)
return validation_dataset
@property
def test_dataset(self) -> TrainingDataset:
print(f"BETA version: generating test datasets for gift eval, should contain {self.windows} windows.")
test_dataset = get_rolling_test_of_gift_eval(
dataset=self.gluonts_dataset,
prediction_length=self.prediction_length,
windows=self.windows,
)
return test_dataset
@property
def test_data(self) -> TestData:
_, test_template = split(
self.gluonts_dataset, offset=-self.prediction_length * self.windows
)
test_data = test_template.generate_instances(
prediction_length=self.prediction_length,
windows=self.windows,
distance=self.prediction_length,
)
return test_data
================================================
FILE: probts/data/datasets/multi_horizon_datasets.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from GluonTS
# - Source: https://github.com/awslabs/gluonts
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
from torch.utils.data import IterableDataset
from gluonts.env import env
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.transform import (
SelectFields,
Transformation,
Chain,
ValidationSplitSampler,
ExpectedNumInstanceSampler,
RenameFields,
AsNumpyArray,
ExpandDimArray,
AddObservedValuesIndicator,
AddTimeFeatures,
VstackFeatures,
SetFieldIfNotPresent,
TargetDimIndicator,
InstanceSplitter
)
from gluonts.dataset.common import DataEntry
from gluonts.transform import InstanceSampler
from gluonts.zebras._util import pad_axis
from gluonts.dataset.common import DataEntry
from gluonts.transform._base import FlatMapTransformation
from probts.data.data_utils.time_features import fourier_time_features_from_frequency, AddCustomizedTimeFeatures
from probts.data.datasets.single_horizon_datasets import TransformedIterableDataset
from typing import Union
from typing import Iterator, List, Optional, Tuple, Union
import numpy as np
import random
class MultiHorizonDataset():
    """
    MultiHorizonDataset: Supports multi-horizon forecasting by enabling flexible context and prediction lengths.

    Every ``*_range`` argument may be given either as a single int or as a list
    of ints; a single int is treated as a one-element list.

    Parameters:
    ----------
    input_names : list
        Names of input fields required by the model.
    freq : str
        Frequency of the data (e.g., 'H' for hourly, 'D' for daily).
    train_ctx_range : Union[int, list]
        Range of context lengths for the training dataset.
    train_pred_range : Union[int, list]
        Range of prediction lengths for the training dataset.
    val_ctx_range : Union[int, list]
        Range of context lengths for the validation dataset.
    val_pred_range : Union[int, list]
        Range of prediction lengths for the validation dataset.
    test_ctx_range : Union[int, list]
        Range of context lengths for the testing dataset.
    test_pred_range : Union[int, list]
        Range of prediction lengths for the testing dataset.
    multivariate : bool, optional, default=True
        Whether the dataset contains multiple target variables.
    continuous_sample : bool, optional, default=False
        Whether to enable continuous sampling horizons from the train_pred_range.
    """

    def __init__(
        self,
        input_names: list,
        freq: str,
        train_ctx_range: Union[int, list],
        train_pred_range: Union[int, list],
        val_ctx_range: Union[int, list],
        val_pred_range: Union[int, list],
        test_ctx_range: Union[int, list],
        test_pred_range: Union[int, list],
        multivariate: bool = True,
        continuous_sample: bool = False,
    ):
        super().__init__()
        self.input_names_ = input_names
        self.train_ctx_range = train_ctx_range
        self.train_pred_range = train_pred_range
        self.val_ctx_range = val_ctx_range
        self.val_pred_range = val_pred_range
        self.test_ctx_range = test_ctx_range
        self.test_pred_range = test_pred_range
        self.continuous_sample = continuous_sample
        self.freq = freq
        # Targets are (dim, time) arrays when multivariate, (time,) otherwise.
        self.expected_ndim = 2 if multivariate else 1

    @staticmethod
    def _as_list(value) -> list:
        """Return ``value`` as a list so that a single int and a list of ints are handled alike."""
        if isinstance(value, (int, np.integer)):
            return [value]
        return list(value)

    def get_sampler(self):
        """
        Creates samplers for training, validation, and testing datasets.
        Samplers control how data instances are selected for each mode.

        Training requires only the shortest configured context/prediction
        lengths, validation and testing require the longest ones.
        """
        # for training
        train_min_past = min(self._as_list(self.train_ctx_range))
        train_min_future = min(self._as_list(self.train_pred_range))
        # for validation
        val_min_past = max(self._as_list(self.val_ctx_range))
        val_min_future = max(self._as_list(self.val_pred_range))
        # for testing
        test_min_past = max(self._as_list(self.test_ctx_range))
        test_min_future = max(self._as_list(self.test_pred_range))

        self.train_sampler = ExpectedNumInstanceSampler(
            num_instances=1.0,
            min_past=train_min_past,
            min_future=train_min_future,
        )
        self.val_sampler = ValidationSplitSampler(
            min_past=val_min_past,
            min_future=val_min_future,
        )
        self.test_sampler = ValidationSplitSampler(
            min_past=test_min_past,
            min_future=test_min_future,
        )

    def create_transformation(self, data_stamp=None, pred_len=None) -> Transformation:
        """
        Creates a transformation pipeline for data preprocessing.

        Also sets ``self.time_feat_dim`` as a side effect.

        Parameters:
        ----------
        data_stamp : np.array, optional
            Precomputed time features. If None, features are generated based on the frequency.
        pred_len : int or list, optional
            Prediction length(s) for the transformation; the maximum is used.
            If None, uses the maximum training prediction range.

        Returns:
        ----------
        Chain : Transformation
            A chain of transformations applied to the dataset.
        """
        if data_stamp is None:
            # Frequencies outside this list fall back to daily time features.
            if self.freq in ["M", "W", "D", "B", "H", "min", "T"]:
                time_features = fourier_time_features_from_frequency(self.freq)
            else:
                time_features = fourier_time_features_from_frequency('D')
            self.time_feat_dim = len(time_features) * 2
            time_feature_func = AddTimeFeatures
        else:
            self.time_feat_dim = data_stamp.shape[-1]
            time_features = data_stamp
            time_feature_func = AddCustomizedTimeFeatures

        if pred_len is None:
            pred_len = self.train_pred_range
        # Accept a single int as well as a list of candidate horizons.
        pred_len = max(self._as_list(pred_len))

        return Chain(
            [
                AsNumpyArray(
                    field=FieldName.TARGET,
                    expected_ndim=self.expected_ndim,
                ),
                ExpandDimArray(
                    field=FieldName.TARGET,
                    axis=None,
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                time_feature_func(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=time_features,
                    pred_length=pred_len,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME],
                ),
                SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]),
                TargetDimIndicator(
                    field_name="target_dimension_indicator",
                    target_field=FieldName.TARGET,
                ),
                AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1),
            ]
        )

    def create_instance_splitter(self, mode: str, pred_len=None, auto_search=False):
        """
        Creates an instance splitter for slicing data sequences.

        Parameters:
        ----------
        mode : str
            Dataset mode. Must be one of ['train', 'val', 'test'].
        pred_len : int or list, optional
            Prediction length(s) for validation or testing. If None, defaults to the predefined ranges.
            Ignored in 'train' mode.
        auto_search : bool, optional, default=False
            Only used in 'test' mode: if True, the past window becomes a single
            length equal to max(test_ctx_range) + max(future_length).

        Returns:
        ----------
        MultiHorizonSplitter : Transformation
            Transformation that slices time series sequences.
        """
        assert mode in ["train", "val", "test"]

        self.get_sampler()
        instance_sampler = {
            "train": self.train_sampler,
            "val": self.val_sampler,
            "test": self.test_sampler,
        }[mode]

        if mode == "train":
            past_length = self.train_ctx_range
            future_length = self.train_pred_range
        elif mode == 'val':
            past_length = self.val_ctx_range
            future_length = self.val_pred_range if pred_len is None else pred_len
        else:
            future_length = self.test_pred_range if pred_len is None else pred_len
            if auto_search:
                past_length = [
                    max(self._as_list(self.test_ctx_range))
                    + max(self._as_list(future_length))
                ]
            else:
                past_length = self.test_ctx_range

        # The splitter takes min/max/choice over these, so always hand it lists.
        past_length = self._as_list(past_length)
        future_length = self._as_list(future_length)

        return MultiHorizonSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=past_length,
            future_length=future_length,
            mode=mode,
            continuous_sample=self.continuous_sample,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
        ) + (
            RenameFields(
                {
                    f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf",
                    f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf",
                }
            )
        )

    def get_iter_dataset(self, dataset: Dataset, mode: str, data_stamp=None, pred_len=None, auto_search=False) -> IterableDataset:
        """
        Creates an iterable dataset with applied transformations and splitters.

        Parameters:
        ----------
        dataset : Dataset
            Input dataset to transform.
        mode : str
            Mode of operation. Must be one of ['train', 'val', 'test'].
        data_stamp : np.array, optional
            Precomputed time features.
        pred_len : int or list, optional
            Prediction length(s) for validation or testing.
        auto_search : bool, optional, default=False
            Forwarded to ``create_instance_splitter`` for 'val' and 'test' modes.

        Returns:
        ----------
        IterableDataset : TransformedIterableDataset
            Transformed dataset ready for model training or evaluation.
        """
        assert mode in ["train", "val", "test"]

        transform = self.create_transformation(data_stamp, pred_len=pred_len)
        if mode == 'train':
            with env._let(max_idle_transforms=100):
                instance_splitter = self.create_instance_splitter(mode)
        else:
            instance_splitter = self.create_instance_splitter(mode, pred_len=pred_len, auto_search=auto_search)

        input_names = self.input_names_

        iter_dataset = TransformedIterableDataset(
            dataset,
            transform=transform
            + instance_splitter
            + SelectFields(input_names),
            is_train=(mode == 'train'),
        )
        return iter_dataset
class MultiHorizonSplitter(FlatMapTransformation):
    """
    Split instances from a dataset, by slicing the target and other time series
    fields at points in time selected by the specified sampler. The assumption
    is that all time series fields start at the same time point.

    It is assumed that time axis is always the last axis.

    The ``target_field`` and each field in ``time_series_fields`` are removed and
    replaced by two new fields, with prefix `past_` and `future_` respectively.

    A ``past_is_pad`` is also added, that indicates whether values at a given
    time point are padding or not. The chosen window lengths are stored in the
    ``context_length`` and ``prediction_length`` fields of every instance.

    Parameters
    ----------
    target_field
        field containing the target
    is_pad_field
        output field indicating whether padding happened
    start_field
        field containing the start date of the time series
    forecast_start_field
        output field that will contain the time point where the forecast starts
    instance_sampler
        instance sampler that provides sampling indices given a time series
    past_length
        candidate lengths of the target seen before making prediction
        (a single int or a list of ints)
    future_length
        candidate lengths of the target that must be predicted
        (a single int or a list of ints)
    mode
        dataset mode string; stored on the instance and not otherwise used here
    lead_time
        gap between the past and future windows (default: 0)
    output_NTC
        whether to have time series output in (time, dimension) or in
        (dimension, time) layout (default: True)
    time_series_fields
        fields that contains time series, they are split in the same interval
        as the target (default: no extra fields)
    dummy_value
        Value to use for padding. (default: 0.0)
    continuous_sample
        during training, draw lengths uniformly from the integer interval
        [min, max] of the candidates instead of choosing one of the listed
        candidates (default: False)
    """

    # @validated()
    def __init__(
        self,
        target_field: str,
        is_pad_field: str,
        start_field: str,
        forecast_start_field: str,
        instance_sampler: InstanceSampler,
        past_length: Union[int, list],
        future_length: Union[int, list],
        mode: str,
        lead_time: int = 0,
        output_NTC: bool = True,
        time_series_fields: List[str] = [],
        dummy_value: float = 0.0,
        continuous_sample: bool = False,
    ) -> None:
        super().__init__()

        # assert future_length > 0, "The value of `future_length` should be > 0"
        self.instance_sampler = instance_sampler
        # Lengths are consumed through min/max/random.choice, so a bare int is
        # wrapped into a one-element list.
        self.past_length = self._as_list(past_length)
        self.future_length = self._as_list(future_length)
        self.continuous_sample = continuous_sample
        self.lead_time = lead_time
        self.output_NTC = output_NTC
        # Copy so the shared default list (or the caller's list) is never aliased.
        self.ts_fields = list(time_series_fields)
        self.target_field = target_field
        self.is_pad_field = is_pad_field
        self.start_field = start_field
        self.forecast_start_field = forecast_start_field
        self.dummy_value = dummy_value
        self.mode = mode

    @staticmethod
    def _as_list(value) -> list:
        """Return ``value`` as a list so that a single int and a list of ints are handled alike."""
        if isinstance(value, (int, np.integer)):
            return [value]
        return list(value)

    def _past(self, col_name):
        """Name of the output field holding the past part of ``col_name``."""
        return f"past_{col_name}"

    def _future(self, col_name):
        """Name of the output field holding the future part of ``col_name``."""
        return f"future_{col_name}"

    def _split_array(
        self, array: np.ndarray, idx: int, past_length: int, future_length: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Cut ``array`` along its last axis into a past and a future piece around ``idx``.

        The past piece always has ``past_length`` steps; if fewer than
        ``past_length`` steps precede ``idx`` it is left-padded with
        ``dummy_value``. The future piece starts ``lead_time`` steps after
        ``idx`` and has at most ``future_length`` steps.
        """
        if idx >= past_length:
            past_piece = array[..., idx - past_length : idx]
        else:
            past_piece = pad_axis(
                array[..., :idx],
                axis=-1,
                left=past_length - idx,
                value=self.dummy_value,
            )

        future_start = idx + self.lead_time
        future_slice = slice(future_start, future_start + future_length)
        future_piece = array[..., future_slice]

        return past_piece, future_piece

    def _split_instance(self, entry: DataEntry, idx: int, is_train) -> DataEntry:
        """Build one instance from ``entry`` with the forecast starting at ``idx``.

        In training the past/future lengths are sampled from the configured
        candidates; otherwise the maximum candidate lengths are used.
        """
        slice_cols = self.ts_fields + [self.target_field]
        dtype = entry[self.target_field].dtype

        entry = entry.copy()

        if is_train:
            if self.continuous_sample:
                past_len = random.randint(min(self.past_length), max(self.past_length))
                pred_len = random.randint(min(self.future_length), max(self.future_length))
            else:
                past_len = random.choice(self.past_length)
                pred_len = random.choice(self.future_length)
        else:
            past_len = max(self.past_length)
            pred_len = max(self.future_length)

        for ts_field in slice_cols:
            past_piece, future_piece = self._split_array(
                entry[ts_field], idx, past_length=past_len, future_length=pred_len
            )

            if self.output_NTC:
                past_piece = past_piece.transpose()
                future_piece = future_piece.transpose()

            entry[self._past(ts_field)] = past_piece
            entry[self._future(ts_field)] = future_piece
            del entry[ts_field]

        # Mark the left-padded positions of the past window with 1.
        pad_indicator = np.zeros(past_len, dtype=dtype)
        pad_length = max(past_len - idx, 0)
        pad_indicator[:pad_length] = 1

        entry[self._past(self.is_pad_field)] = pad_indicator
        entry[self.forecast_start_field] = (
            entry[self.start_field] + idx + self.lead_time
        )

        entry['context_length'] = past_len
        entry['prediction_length'] = pred_len
        return entry

    def flatmap_transform(
        self, entry: DataEntry, is_train: bool
    ) -> Iterator[DataEntry]:
        """Yield one split instance for every index the sampler selects on the target."""
        sampled_indices = self.instance_sampler(entry[self.target_field])

        for idx in sampled_indices:
            yield self._split_instance(entry, idx, is_train)
================================================
FILE: probts/data/datasets/single_horizon_datasets.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PyTorch-TS
# - Source: https://github.com/zalandoresearch/pytorch-ts
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
from torch.utils.data import IterableDataset
from gluonts.env import env
from gluonts.dataset.common import Dataset
from gluonts.itertools import Cyclic
from gluonts.dataset.field_names import FieldName
from gluonts.transform import (
SelectFields,
Transformation,
Chain,
InstanceSplitter,
ValidationSplitSampler,
ExpectedNumInstanceSampler,
RenameFields,
AsNumpyArray,
ExpandDimArray,
AddObservedValuesIndicator,
AddTimeFeatures,
VstackFeatures,
SetFieldIfNotPresent,
TargetDimIndicator,
TransformedDataset,
)
from probts.data.data_utils.time_features import fourier_time_features_from_frequency, AddCustomizedTimeFeatures
class SingleHorizonDataset():
    """
    SingleHorizonDataset: Builds the transformation and instance-splitting pipeline for single-horizon forecasting.

    Parameters:
    ----------
    input_names : list
        Field names that are kept in the final instances (passed to SelectFields).
    history_length : int
        Length of the past window; also the minimum past required by all samplers.
    context_length : int
        Used only with ``auto_search``: the past window then has
        ``context_length + prediction_length`` steps.
    prediction_length : int
        Length of the forecasting horizon.
    freq : str
        Data frequency (e.g., 'H' for hourly, 'D' for daily).
    multivariate : bool, optional, default=True
        Whether the target is expected to be 2-dimensional (True) or 1-dimensional (False).
    """

    def __init__(
        self,
        input_names: list,
        history_length: int,
        context_length: int,
        prediction_length: int,
        freq: str,
        multivariate: bool = True
    ):
        super().__init__()
        self.input_names_ = input_names
        self.history_length = history_length
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.freq = freq
        self.expected_ndim = 2 if multivariate else 1

    def get_sampler(self):
        """
        Instantiate ``train_sampler``, ``val_sampler`` and ``test_sampler``.

        All three require ``history_length`` past steps and
        ``prediction_length`` future steps; training uses an
        ExpectedNumInstanceSampler, validation and testing use a
        ValidationSplitSampler each.
        """
        window = dict(
            min_past=self.history_length,
            min_future=self.prediction_length,
        )
        self.train_sampler = ExpectedNumInstanceSampler(num_instances=1.0, **window)
        self.val_sampler = ValidationSplitSampler(**window)
        self.test_sampler = ValidationSplitSampler(**window)

    def create_transformation(self, data_stamp=None) -> Transformation:
        """
        Build the preprocessing chain (target conversion, observed-value
        indicator, time features, static categorical default, dimension indicator).

        Sets ``self.time_feat_dim`` as a side effect.

        Parameters:
        ----------
        data_stamp : np.array, optional
            Precomputed time features. If None, Fourier time features are derived
            from the data frequency (daily features for unlisted frequencies).

        Returns:
        ----------
        Chain : Transformation
            The chained transformation.
        """
        if data_stamp is not None:
            self.time_feat_dim = data_stamp.shape[-1]
            time_features = data_stamp
            time_feature_func = AddCustomizedTimeFeatures
        else:
            supported = ["M", "W", "D", "B", "H", "min", "T"]
            feature_freq = self.freq if self.freq in supported else 'D'
            time_features = fourier_time_features_from_frequency(feature_freq)
            self.time_feat_dim = len(time_features) * 2
            time_feature_func = AddTimeFeatures

        steps = [
            AsNumpyArray(
                field=FieldName.TARGET,
                expected_ndim=self.expected_ndim,
            ),
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            ),
            time_feature_func(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=time_features,
                pred_length=self.prediction_length,
            ),
            VstackFeatures(
                output_field=FieldName.FEAT_TIME,
                input_fields=[FieldName.FEAT_TIME],
            ),
            SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]),
            TargetDimIndicator(
                field_name="target_dimension_indicator",
                target_field=FieldName.TARGET,
            ),
            AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1),
        ]
        return Chain(steps)

    def create_instance_splitter(self, mode: str, auto_search=False):
        """
        Build the instance splitter for the given mode, followed by a renaming
        of the past/future target fields to their ``*_cdf`` names.

        Parameters:
        ----------
        mode : str
            Mode of operation. Must be one of ['train', 'val', 'test'].
        auto_search : bool, optional, default=False
            If True the past window is ``context_length + prediction_length``
            long, otherwise ``history_length``.

        Returns:
        ----------
        Transformation
            InstanceSplitter chained with RenameFields.
        """
        assert mode in ["train", "val", "test"]

        self.get_sampler()
        samplers = {
            "train": self.train_sampler,
            "val": self.val_sampler,
            "test": self.test_sampler,
        }
        chosen_sampler = samplers[mode]

        past_length = (
            self.context_length + self.prediction_length
            if auto_search
            else self.history_length
        )

        splitter = InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=chosen_sampler,
            past_length=past_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
        )
        rename_targets = RenameFields(
            {
                f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf",
                f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf",
            }
        )
        return splitter + (rename_targets)

    def get_iter_dataset(self, dataset: Dataset, mode: str, data_stamp=None, auto_search=False) -> IterableDataset:
        """
        Wrap ``dataset`` into an iterable dataset that applies the full pipeline on the fly.

        Parameters:
        ----------
        dataset : Dataset
            Input dataset to transform.
        mode : str
            Mode of operation. Must be one of ['train', 'val', 'test'].
        data_stamp : np.array, optional
            Precomputed time features.
        auto_search : bool, optional, default=False
            Forwarded to ``create_instance_splitter`` for 'val' and 'test' modes.

        Returns:
        ----------
        IterableDataset : TransformedIterableDataset
            Dataset yielding transformed, split instances restricted to ``input_names``.
        """
        assert mode in ["train", "val", "test"]

        transform = self.create_transformation(data_stamp)
        if mode == 'train':
            with env._let(max_idle_transforms=100):
                instance_splitter = self.create_instance_splitter(mode)
        else:
            instance_splitter = self.create_instance_splitter(mode, auto_search=auto_search)

        pipeline = transform + instance_splitter + SelectFields(self.input_names_)
        return TransformedIterableDataset(
            dataset,
            transform=pipeline,
            is_train=(mode == 'train'),
        )
class TransformedIterableDataset(IterableDataset):
    """
    Iterable dataset that applies a transformation pipeline lazily while iterating.

    Parameters:
    ----------
    dataset : Dataset
        The dataset to read entries from.
    transform : Transformation
        The transformation pipeline applied to the entries.
    is_train : bool, optional, default=True
        Passed on to the pipeline; when True the source dataset is additionally
        wrapped in ``Cyclic`` before being transformed.
    """

    def __init__(
        self,
        dataset: Dataset,
        transform: Transformation,
        is_train: bool = True
    ):
        super().__init__()

        source = Cyclic(dataset) if is_train else dataset
        self.transformed_dataset = TransformedDataset(
            source,
            transform,
            is_train=is_train,
        )

    def __iter__(self):
        """Return an iterator over the transformed entries."""
        return iter(self.transformed_dataset)
================================================
FILE: probts/model/__init__.py
================================================
from .forecast_module import *
================================================
FILE: probts/model/forecast_module.py
================================================
import numpy as np
import torch
from torch import optim
from typing import Dict
import lightning.pytorch as pl
import sys
from probts.data import ProbTSBatchData
from probts.data.data_utils.data_scaler import Scaler
from probts.model.forecaster import Forecaster
from probts.utils.evaluator import Evaluator
from probts.utils.metrics import *
from probts.utils.save_utils import update_metrics, calculate_weighted_average, load_checkpoint, get_hor_str
from probts.utils.utils import init_class_helper
def get_weights(sampling_weight_scheme, max_hor):
    '''
    Build per-horizon weights for the given scheme.

    - 'none'  : no weights, returns None
    - 'const' : every horizon gets 1 / max_hor
    - 'random': (1 / max_hor) * (log(max_hor) - log(i)) for max_hor points i
                evenly spaced between 1 + 1e-5 and max_hor - 1e-3

    return: w [max_hor] as a torch tensor, or None
    raises: ValueError for any other scheme name
    '''
    if sampling_weight_scheme == 'none':
        return None

    if sampling_weight_scheme == 'random':
        positions = np.linspace(1 + 1e-5, max_hor - 1e-3, max_hor)
        weights = (1 / max_hor) * (np.log(max_hor) - np.log(positions))
    elif sampling_weight_scheme == 'const':
        uniform = 1 / max_hor
        weights = np.array([uniform] * max_hor)
    else:
        raise ValueError(f"Invalid sampling scheme {sampling_weight_scheme}.")

    return torch.tensor(weights)
class ProbTSForecastModule(pl.LightningModule):
    """Lightning module that trains and evaluates a ProbTS ``Forecaster``.

    Targets are passed through ``scaler.transform`` before they reach the
    forecaster; during evaluation the forecasts are mapped back with
    ``scaler.inverse_transform`` and metrics are computed on both scales.

    Parameters
    ----------
    forecaster : Forecaster
        Model providing ``loss(batch_data)`` and ``forecast(batch_data, num_samples)``.
    scaler : Scaler
        Object providing ``transform`` / ``inverse_transform``.
    train_pred_len_list : list
        Stored on the module; not used by the code in this class.
    num_samples : int
        Number of samples requested from ``forecaster.forecast``.
    learning_rate : float
        Learning rate of the default Adam optimizer.
    quantiles_num : int
        Passed to the ``Evaluator``.
    load_from_ckpt : str
        Stored on the module; not used by the code in this class.
    sampling_weight_scheme : str
        Horizon weighting scheme handed to ``get_weights`` ('none', 'const' or 'random').
    optimizer_config, lr_scheduler_config : dict, optional
        Each with keys ``'class_name'`` and ``'init_args'``; when given they
        replace the default optimizer / add a learning-rate scheduler.
    """

    def __init__(
        self,
        forecaster: Forecaster,
        scaler: Scaler = None,
        train_pred_len_list: list = None,
        num_samples: int = 100,
        learning_rate: float = 1e-3,
        quantiles_num: int = 10,
        load_from_ckpt: str = None,
        sampling_weight_scheme: str = 'none',
        optimizer_config = None,
        lr_scheduler_config = None,
        **kwargs
    ):
        super().__init__()
        self.num_samples = num_samples
        self.learning_rate = learning_rate
        self.load_from_ckpt = load_from_ckpt
        self.train_pred_len_list = train_pred_len_list
        self.forecaster = forecaster

        self.optimizer_config = optimizer_config
        self.scheduler_config = lr_scheduler_config

        if self.optimizer_config is not None:
            print("optimizer config: ", self.optimizer_config)

        if self.scheduler_config is not None:
            print("lr_scheduler config: ", self.scheduler_config)

        self.scaler = scaler
        self.evaluator = Evaluator(quantiles_num=quantiles_num)

        # init the parameter for sampling
        self.sampling_weight_scheme = sampling_weight_scheme
        print(f'sampling_weight_scheme: {sampling_weight_scheme}')
        self.save_hyperparameters()

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, scaler=None, learning_rate=None, no_training=False, **kwargs):
        """Restore a module from ``checkpoint_path`` by delegating to ``load_checkpoint``."""
        model = load_checkpoint(cls, checkpoint_path, scaler=scaler, learning_rate=learning_rate, no_training=no_training, **kwargs)
        return model

    def training_forward(self, batch_data):
        """Scale the targets, compute the forecaster loss and reduce it to a scalar.

        A loss with more than one dimension is treated as having the forecast
        horizon on dim 1: it is combined with the per-horizon weights of the
        configured scheme (summing over dim 1) and then averaged. When the
        scheme yields no weights ('none') the loss is simply averaged.
        """
        batch_data.past_target_cdf = self.scaler.transform(batch_data.past_target_cdf)
        batch_data.future_target_cdf = self.scaler.transform(batch_data.future_target_cdf)
        loss = self.forecaster.loss(batch_data)

        if loss.dim() > 1:
            loss_weights = get_weights(self.sampling_weight_scheme, loss.shape[1])
            # get_weights returns None for the 'none' scheme; fall through to a
            # plain mean instead of failing on None.
            if loss_weights is not None:
                # Shape the weights as [1, L, 1, ...] so they broadcast along
                # dim 1 whatever the rank of the loss is.
                weight_shape = [1, -1] + [1] * (loss.dim() - 2)
                loss_weights = loss_weights.detach().to(loss.device).view(*weight_shape)
                loss = (loss_weights * loss).sum(dim=1)

        # Always hand a scalar back to Lightning.
        return loss.mean()

    def training_step(self, batch, batch_idx):
        """Compute and log ``train_loss`` for one batch."""
        batch_data = ProbTSBatchData(batch, self.device)
        loss = self.training_forward(batch_data)
        self.log("train_loss", loss, on_step=True, prog_bar=True, logger=True)
        return loss

    def evaluate(self, batch, stage='',dataloader_idx=None):
        """Forecast one batch and accumulate metrics on the original and normalized scale.

        Returns the horizon metrics computed on the original scale; in the
        'test' stage these are additionally accumulated per horizon string.
        """
        batch_data = ProbTSBatchData(batch, self.device)
        pred_len = batch_data.future_target_cdf.shape[1]
        orin_past_data = batch_data.past_target_cdf[:]
        orin_future_data = batch_data.future_target_cdf[:]

        norm_past_data = self.scaler.transform(batch_data.past_target_cdf)
        norm_future_data = self.scaler.transform(batch_data.future_target_cdf)
        self.batch_size.append(orin_past_data.shape[0])

        batch_data.past_target_cdf = self.scaler.transform(batch_data.past_target_cdf)
        # Keep only the first pred_len steps along dim 2 of the forecasts.
        forecasts = self.forecaster.forecast(batch_data, self.num_samples)[:,:, :pred_len]

        # Calculate denorm metrics
        denorm_forecasts = self.scaler.inverse_transform(forecasts)
        metrics = self.evaluator(orin_future_data, denorm_forecasts, past_data=orin_past_data, freq=self.forecaster.freq)
        self.metrics_dict = update_metrics(metrics, stage, target_dict=self.metrics_dict)

        # Calculate norm metrics
        norm_metrics = self.evaluator(norm_future_data, forecasts, past_data=norm_past_data, freq=self.forecaster.freq)
        self.metrics_dict = update_metrics(norm_metrics, stage, 'norm', target_dict=self.metrics_dict)

        horizon_len = orin_future_data.shape[1]

        if stage != 'test' and self.sampling_weight_scheme not in ['fix', 'none']:
            loss_weights = get_weights('random', horizon_len)
        else:
            loss_weights = None

        hor_metrics = self.evaluator(orin_future_data, denorm_forecasts, past_data=orin_past_data, freq=self.forecaster.freq, loss_weights=loss_weights)

        if stage == 'test':
            hor_str = get_hor_str(self.forecaster.prediction_length, dataloader_idx)
            if hor_str not in self.hor_metrics:
                self.hor_metrics[hor_str] = {}

            self.hor_metrics[hor_str] = update_metrics(hor_metrics, stage, target_dict=self.hor_metrics[hor_str])

        return hor_metrics

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        """Evaluate one validation batch."""
        metrics = self.evaluate(batch, stage='val',dataloader_idx=dataloader_idx)
        return metrics

    def on_validation_epoch_start(self):
        """Reset the metric accumulators before a validation epoch."""
        self.metrics_dict = {}
        self.hor_metrics = {}
        self.batch_size = []

    def on_validation_epoch_end(self):
        """Log the batch-size-weighted average of the accumulated validation metrics."""
        avg_metrics = calculate_weighted_average(self.metrics_dict, self.batch_size)
        self.log_dict(avg_metrics, prog_bar=True)

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        """Evaluate one test batch."""
        metrics = self.evaluate(batch, stage='test',dataloader_idx=dataloader_idx)
        return metrics

    def on_test_epoch_start(self):
        """Reset the metric accumulators before a test epoch."""
        self.metrics_dict = {}
        self.hor_metrics = {}
        self.avg_metrics = {}
        self.avg_hor_metrics = {}
        self.batch_size = []

    def on_test_epoch_end(self):
        """Average the accumulated test metrics and log them for single-horizon runs."""
        if len(self.hor_metrics) > 0:
            for hor_str, metric in self.hor_metrics.items():
                self.avg_hor_metrics[hor_str] = calculate_weighted_average(metric, batch_size=self.batch_size)
                self.avg_metrics.update(calculate_weighted_average(metric, batch_size=self.batch_size, hor=hor_str+'_'))
        else:
            self.avg_metrics = calculate_weighted_average(self.metrics_dict, self.batch_size)

        # Only log when a single prediction length is configured.
        if isinstance(self.forecaster.prediction_length, int) or len(self.forecaster.prediction_length) < 2:
            self.log_dict(self.avg_metrics, logger=True)

    def predict_step(self, batch, batch_idx):
        """Return the forecaster's samples for one batch (no scaling is applied here)."""
        batch_data = ProbTSBatchData(batch, self.device)
        forecasts = self.forecaster.forecast(batch_data, self.num_samples)
        return forecasts

    def configure_optimizers(self):
        """Create the optimizer (Adam by default) and, if configured, an epoch-wise LR scheduler."""
        if self.optimizer_config is None:
            optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        else:
            optimizer = init_class_helper(self.optimizer_config['class_name'])
            params = self.optimizer_config['init_args']
            optimizer = optimizer(self.parameters(), **params)

        if self.scheduler_config is not None:
            scheduler = init_class_helper(self.scheduler_config['class_name'])
            params = self.scheduler_config['init_args']
            scheduler = scheduler(optimizer=optimizer, **params)

            lr_scheduler = {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "monitor": "val_loss",
                "strict": True,
                "name": None,
            }
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

        return optimizer
================================================
FILE: probts/model/forecaster/__init__.py
================================================
from .forecaster import Forecaster
from .point_forecaster import *
from .prob_forecaster import *
================================================
FILE: probts/model/forecaster/forecaster.py
================================================
import torch
from torch import nn
from typing import List
from probts.utils import weighted_average
from probts.data.data_utils.data_scaler import TemporalScaler
from typing import Union
class Forecaster(nn.Module):
def __init__(
    self,
    target_dim: int,
    context_length: Union[list,int],
    prediction_length: Union[list,int],
    freq: str ,
    use_lags: bool = False,
    use_feat_idx_emb: bool = False,
    use_time_feat: bool = False,
    lags_list: List[int] = [],
    feat_idx_emb_dim: int = 1,
    time_feat_dim: int = 1,
    use_scaling: bool = False,
    autoregressive: bool = False,
    no_training: bool = False,
    dataset: str = None,
    **kwargs
):
    """Store the forecasting configuration and derive the input feature size.

    ``context_length`` and ``prediction_length`` may be an int or a list; for a
    list, the largest value is kept in ``max_context_length`` /
    ``max_prediction_length``. A ``TemporalScaler`` is created only when
    ``use_scaling`` is set and an index embedding only when
    ``use_feat_idx_emb`` is set; otherwise the attribute is None.
    """
    super().__init__()

    # Window lengths (raw value plus the maximum when a list is given).
    self.context_length = context_length
    self.prediction_length = prediction_length
    self.max_context_length = (
        max(context_length) if isinstance(context_length, list) else context_length
    )
    self.max_prediction_length = (
        max(prediction_length) if isinstance(prediction_length, list) else prediction_length
    )

    # Plain configuration values.
    self.target_dim = target_dim
    self.freq = freq
    self.use_lags = use_lags
    self.use_feat_idx_emb = use_feat_idx_emb
    self.use_time_feat = use_time_feat
    self.feat_idx_emb_dim = feat_idx_emb_dim
    self.time_feat_dim = time_feat_dim
    self.autoregressive = autoregressive
    self.no_training = no_training
    self.use_scaling = use_scaling
    self.dataset = dataset

    # Lag parameters
    self.lags_list = lags_list

    self.scaler = TemporalScaler() if self.use_scaling else None

    self.lags_dim = len(self.lags_list) * target_dim

    if not use_feat_idx_emb:
        self.feat_idx_emb = None
    else:
        self.feat_idx_emb = nn.Embedding(
            num_embeddings=self.target_dim, embedding_dim=self.feat_idx_emb_dim
        )

    self.input_size = self.get_input_size()
@property
def name(self):
    """Name of the concrete class of this instance."""
    concrete_class = self.__class__
    return concrete_class.__name__
def get_input_size(self):
input_size = self.target_dim if not self.use_lags else self.lags_dim
if self.use_feat_idx_emb:
input_size += self.use_feat_idx_emb * self.target_dim
if self.use_time_feat:
input_size += self.time_feat_dim
return input_size
def get_lags(self, sequence, lags_list, lags_length=1):
"""
Get several lags from the sequence of shape (B, L, C) to (B, L', C*N),
where L' = lag_length and N = len(lag_list).
"""
assert max(lags_list) + lags_length <= sequence.shape[1]
lagged_values = []
for lag_index in lags_list:
begin_index = -lag_index - lags_length
end_index = -lag_index if lag_index > 0 else None
lagged_value = sequence[:, begin_index:end_index, ...]
if self.use_scaling:
lagged_value = lagged_value / self.scaler.scale
lagged_values.append(lagged_value)
return torch.cat(lagged_values, dim=-1)
def get_input_sequence(
self,
past_target_cdf,
future_target_cdf,
mode
):
if mode == 'all':
sequence = torch.cat((past_target_cdf, future_target_cdf), dim=1)
seq_length = self.max_context_length + self.max_prediction_length
elif mode == 'encode':
sequence = past_target_cdf
seq_length = self.max_context_length
elif mode == 'decode':
sequence = past_target_cdf
seq_length = 1
else:
raise ValueError(f"Unsupported input mode: {mode}")
if self.use_lags:
input_seq = self.get_lags(sequence, self.lags_list, seq_length)
else:
input_seq = sequence[:, -seq_length:, ...]
if self.use_scaling:
input_seq = input_seq / self.scaler.scale
return input_seq
def get_input_feat_idx_emb(self, target_dimension_indicator, input_length):
input_feat_idx_emb = self.feat_idx_emb(target_dimension_indicator) # [B K D]
input_feat_idx_emb = (
input_feat_idx_emb.unsqueeze(1)
.expand(-1, input_length, -1, -1)
.reshape(-1, input_length, self.target_dim * self.feat_idx_emb_dim)
)
return input_feat_idx_emb # [B L K*D]
def get_input_time_feat(
self,
past_time_feat,
future_time_feat,
mode
):
if mode == 'all':
time_feat = torch.cat(
(past_time_feat[:, -self.max_context_length:, ...], future_time_feat), dim=1)
elif mode == 'encode':
time_feat = past_time_feat[:, -self.max_context_length:, ...]
elif mode == 'decode':
time_feat = future_time_feat
return time_feat
def get_inputs(self, batch_data, mode):
inputs_list = []
input_seq = self.get_input_sequence(
batch_data.past_target_cdf, batch_data.future_target_cdf, mode=mode)
input_length = input_seq.shape[1] # [B L n_lags*K]
inputs_list.append(input_seq)
if self.use_feat_idx_emb:
input_feat_idx_emb = self.get_input_feat_idx_emb(
batch_data.target_dimension_indicator, input_length) # [B L K*D]
inputs_list.append(input_feat_idx_emb)
if self.use_time_feat:
input_time_feat = self.get_input_time_feat(
batch_data.past_time_feat, batch_data.future_time_feat, mode=mode) # [B L Dt]
inputs_list.append(input_time_feat)
return torch.cat(inputs_list, dim=-1).to(dtype=torch.float32)
def get_scale(self, batch_data):
self.scaler.fit(
batch_data.past_target_cdf[:, -self.max_context_length:, ...],
batch_data.past_observed_values[:, -self.max_context_length:, ...]
)
def get_weighted_loss(self, batch_data, loss):
observed_values = batch_data.future_observed_values
loss_weights, _ = observed_values.min(dim=-1, keepdim=True)
loss = weighted_average(loss, weights=loss_weights, dim=1)
return loss
def loss(self, batch_data):
raise NotImplementedError
def forecast(self, batch_data=None, num_samples=None):
raise NotImplementedError
================================================
FILE: probts/model/forecaster/point_forecaster/__init__.py
================================================
from .mean import MeanForecaster
from .naive import NaiveForecaster
from .linear import LinearForecaster
from .patchtst import PatchTST
from .transformer import TransformerForecaster
from .gru import GRUForecaster
from .dlinear import DLinear
from .nlinear import NLinear
from .nhits import NHiTS
from .timesnet import TimesNet
from .itransformer import iTransformer
from .autoformer import Autoformer
from .tsmixer import TSMixer
from .elastst import ElasTST
from .time_moe import TimeMoE
from .timesfm import TimesFM
from .moderntcn import ModernTCN
# ------- add timesfm to sys.path ----------
# Best effort: expose the `submodules/timesfm/src` directory, located relative
# to this file, on sys.path. Any failure is reported and otherwise ignored.
try:
    import os, sys

    current_dir = os.path.dirname(os.path.realpath(__file__))
    # Walk four directory levels up from this package to reach the project root.
    project_root = os.path.abspath(os.path.join(current_dir, *(['..'] * 4)))
    timesfm_path = os.path.join(project_root, 'submodules', 'timesfm', 'src')

    if timesfm_path not in sys.path:
        sys.path.append(timesfm_path)
except Exception as e:
    print(f"Warning: Unable to add timesfm to sys.path. {e}")
# ------------------------------------------

import importlib

# (sub-module name, class exported from it) pairs that are imported only when
# their dependencies can be imported.
modules = [
    ('timer', 'Timer'),
    ('units', 'UniTS'),
    ('forecastpfn', 'ForecastPFN'),
    ('tinytimemixer', 'TinyTimeMixer'),
]

for module, class_name in modules:
    try:
        mod = importlib.import_module(f".{module}", package=__package__)
        # Publish the class under its own name in this package's namespace.
        globals()[class_name] = getattr(mod, class_name)
    except ImportError:
        # Deliberately silent: the forecaster is simply left unavailable.
        # print(f"Warning: {class_name} is not available due to missing dependencies.")
        pass
================================================
FILE: probts/model/forecaster/point_forecaster/autoformer.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from Autoformer
# - Source: https://github.com/thuml/Autoformer
# - Paper: Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting
# - License: MIT License
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
from probts.model.forecaster import Forecaster
from probts.model.nn.arch.TransformerModule.Embed import DataEmbedding_wo_pos
from probts.model.nn.arch.AutoformerModule.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer
from probts.model.nn.arch.AutoformerModule.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp
class Autoformer(Forecaster):
    """
    Autoformer point forecaster (series decomposition + auto-correlation
    encoder/decoder) wrapped in the ProbTS `Forecaster` interface.
    """

    def __init__(
        self,
        moving_avg: int = 25,
        factor: int = 1,
        n_heads: int = 8,
        activation: str = 'gelu',
        e_layers: int = 2,
        d_layers: int = 1,
        output_attention: bool = False,
        d_ff: int = 256,
        label_len: int = 48,
        embed: str = 'timeF',
        dropout: float = 0.1,
        f_hidden_size: int = 256,
        **kwargs
    ):
        """
        Parameters
        ----------
        moving_avg : int
            Kernel size of the series decomposition.
        factor, n_heads, activation, d_ff, dropout, output_attention
            Forwarded to the auto-correlation / encoder / decoder layers.
        e_layers, d_layers : int
            Number of encoder and decoder layers.
        label_len : int
            Accepted for config compatibility; the value is not used because
            `self.label_len` is always set to the context length.
        embed : str
            Embedding type forwarded to `DataEmbedding_wo_pos`.
        f_hidden_size : int
            Model width.
        **kwargs
            Forwarded to `Forecaster.__init__`.
        """
        super().__init__(**kwargs)

        # A list of context lengths collapses to its maximum (already computed
        # by the base class as `max_context_length`).
        if isinstance(self.context_length, list):
            self.context_length = self.max_context_length

        self.label_len = self.context_length

        # Decomp
        kernel_size = moving_avg
        self.decomp = series_decomp(kernel_size)

        # Embedding
        # The series-wise connection inherently contains the sequential information.
        # Thus, we can discard the position embedding of transformers.
        self.enc_embedding = DataEmbedding_wo_pos(self.target_dim, f_hidden_size, embed, self.freq.lower(),
                                                  dropout)
        self.dec_embedding = DataEmbedding_wo_pos(self.target_dim, f_hidden_size, embed, self.freq.lower(),
                                                  dropout)

        # Encoder
        self.model_encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(False, factor, attention_dropout=dropout,
                                        output_attention=output_attention),
                        f_hidden_size, n_heads),
                    f_hidden_size,
                    d_ff,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation
                ) for _ in range(e_layers)
            ],
            norm_layer=my_Layernorm(f_hidden_size)
        )

        # Decoder
        self.model_decoder = Decoder(
            [
                DecoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(True, factor, attention_dropout=dropout,
                                        output_attention=False),
                        f_hidden_size, n_heads),
                    AutoCorrelationLayer(
                        AutoCorrelation(False, factor, attention_dropout=dropout,
                                        output_attention=False),
                        f_hidden_size, n_heads),
                    f_hidden_size,
                    self.target_dim,
                    d_ff,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(d_layers)
            ],
            norm_layer=my_Layernorm(f_hidden_size),
            projection=nn.Linear(f_hidden_size, self.target_dim, bias=True)
        )
        self.loss_fn = nn.MSELoss(reduction='none')

    def forward(self, inputs, pred_len, enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None, *args, **kwargs):
        """
        Run the encoder/decoder and return the last `pred_len` steps.

        `inputs` is the tensor built by `get_inputs(..., 'all')`: the first
        `context_length` steps are the history, and the first `target_dim`
        columns are the targets. When time features are enabled, every column
        after the targets is treated as a time mark (history marks for the
        encoder, the full-length marks for the decoder).
        """
        B, _, _ = inputs.shape

        if self.use_time_feat:
            past_target = inputs[:, :self.context_length, :self.target_dim]
            x_mark_enc = inputs[:, :self.context_length, self.target_dim:]
            time_feat = inputs[:, :, self.target_dim:]
        else:
            past_target = inputs[:, :self.context_length, :self.target_dim]
            x_mark_enc = None
            time_feat = None

        # decomp init: the future trend starts from the history mean, the
        # future seasonal part from zeros.
        mean = torch.mean(past_target, dim=1).unsqueeze(1).repeat(1, pred_len, 1)
        zeros = torch.zeros([B, pred_len, self.target_dim], device=past_target.device)
        seasonal_init, trend_init = self.decomp(past_target)

        # decoder input
        trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
        seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)

        enc_out = self.enc_embedding(past_target, x_mark_enc)
        enc_out, attns = self.model_encoder(enc_out, attn_mask=enc_self_mask)

        # dec
        dec_out = self.dec_embedding(seasonal_init, time_feat)
        seasonal_part, trend_part = self.model_decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask,
                                                       trend=trend_init)
        # final
        dec_out = trend_part + seasonal_part
        return dec_out[:, -pred_len:, :]

    def loss(self, batch_data):
        """
        Mean of the observation-weighted squared error on the future targets.

        The horizon is `batch_data.max_prediction_length` when provided,
        otherwise the forecaster's own `max_prediction_length`.
        """
        if batch_data.max_prediction_length is not None:
            max_pred_len = batch_data.max_prediction_length
        else:
            # `train_prediction_length` is never defined on this class or its
            # base; use the attribute the base class actually provides.
            max_pred_len = self.max_prediction_length

        inputs = self.get_inputs(batch_data, 'all')
        outputs = self(inputs, max_pred_len)

        loss = self.loss_fn(batch_data.future_target_cdf, outputs)
        loss = self.get_weighted_loss(batch_data, loss)
        return loss.mean()

    def forecast(self, batch_data, num_samples=None):
        """
        Point forecast over the length of `batch_data.future_target_cdf`,
        returned with an extra dimension inserted at position 1.
        `num_samples` is ignored.
        """
        max_pred_len = batch_data.future_target_cdf.shape[1]
        inputs = self.get_inputs(batch_data, 'all')
        outputs = self(inputs, max_pred_len)
        return outputs.unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/dlinear.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from LTSF-Linear
# - Source: https://github.com/cure-lab/LTSF-Linear
# - Paper: Are Transformers Effective for Time Series Forecasting?
# - License: Apache-2.0
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
from probts.model.forecaster import Forecaster
from probts.model.nn.arch.decomp import series_decomp
class DLinear(Forecaster):
    """
    DLinear point forecaster: the input is decomposed into seasonal and trend
    parts, each part is mapped from the context window to the prediction
    window by a linear layer, and the two results are summed.
    """

    def __init__(
        self,
        kernel_size: int,
        individual: bool,
        **kwargs
    ):
        """
        kernel_size : int
            Kernel size handed to `series_decomp`.
        individual : bool
            If True, one seasonal and one trend linear layer per target
            variate; otherwise a single pair shared by all variates.
        **kwargs
            Forwarded to `Forecaster.__init__`.
        """
        super().__init__(**kwargs)

        # Project the assembled input down to `target_dim` columns only when
        # extra features make it wider or narrower than the targets.
        if self.input_size == self.target_dim:
            self.enc_linear = nn.Identity()
        else:
            self.enc_linear = nn.Linear(
                in_features=self.input_size, out_features=self.target_dim
            )

        # Decomposition kernel size
        self.kernel_size = kernel_size
        self.decompsition = series_decomp(kernel_size)
        self.individual = individual

        if not self.individual:
            self.Linear_Seasonal = nn.Linear(self.context_length, self.prediction_length)
            self.Linear_Trend = nn.Linear(self.context_length, self.prediction_length)
        else:
            self.Linear_Seasonal = nn.ModuleList()
            self.Linear_Trend = nn.ModuleList()
            # Seasonal/trend layers are created alternately, one pair per variate.
            for _ in range(self.target_dim):
                self.Linear_Seasonal.append(nn.Linear(self.context_length, self.prediction_length))
                self.Linear_Trend.append(nn.Linear(self.context_length, self.prediction_length))

        self.loss_fn = nn.MSELoss(reduction='none')

    def encoder(self, inputs):
        """
        Decompose `inputs`, apply the linear maps along the time axis and
        return the summed result with the time axis back at dim 1.
        """
        seasonal_part, trend_part = self.decompsition(inputs)
        # [B,C,L]
        seasonal_part = seasonal_part.permute(0, 2, 1)
        trend_part = trend_part.permute(0, 2, 1)

        if not self.individual:
            seasonal_pred = self.Linear_Seasonal(seasonal_part)
            trend_pred = self.Linear_Trend(trend_part)
        else:
            seasonal_pred = torch.zeros(
                [seasonal_part.size(0), seasonal_part.size(1), self.prediction_length],
                dtype=seasonal_part.dtype,
            ).to(seasonal_part.device)
            trend_pred = torch.zeros(
                [trend_part.size(0), trend_part.size(1), self.prediction_length],
                dtype=trend_part.dtype,
            ).to(trend_part.device)
            # Fill the per-variate predictions channel by channel.
            for channel in range(self.target_dim):
                seasonal_pred[:, channel, :] = self.Linear_Seasonal[channel](seasonal_part[:, channel, :])
                trend_pred[:, channel, :] = self.Linear_Trend[channel](trend_part[:, channel, :])

        combined = seasonal_pred + trend_pred  # [B,C,L]
        return combined.permute(0, 2, 1)

    def loss(self, batch_data):
        """Mean of the observation-weighted squared error on the future targets."""
        prediction = self.encoder(self.enc_linear(self.get_inputs(batch_data, 'encode')))
        per_element = self.loss_fn(batch_data.future_target_cdf, prediction)
        weighted = self.get_weighted_loss(batch_data, per_element)
        return weighted.mean()

    def forecast(self, batch_data, num_samples=None):
        """
        Point forecast with an extra dimension inserted at position 1.
        `num_samples` is ignored.
        """
        encoded = self.enc_linear(self.get_inputs(batch_data, 'encode'))
        return self.encoder(encoded).unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/elastst.py
================================================
import torch
import torch.nn as nn
from typing import Union
from probts.model.forecaster import Forecaster
from probts.model.nn.arch.ElasTSTModule.ElasTST_backbone import ElasTST_backbone
from probts.utils import convert_to_list, weighted_average
from probts.data.data_utils.data_scaler import InstanceNorm
class ElasTST(Forecaster):
    def __init__(
        self,
        l_patch_size: Union[str, int, list] = '8_16_32',
        k_patch_size: int = 1,
        stride: int = None,
        rotate: bool = True,
        addv: bool = False,
        bin_att: bool = False,
        rope_theta_init: str = 'exp',
        min_period: float = 1,
        max_period: float = 1000,
        learn_tem_emb: bool = False,
        learnable_rope: bool = True,
        abs_tem_emb: bool = False,
        structured_mask: bool = True,
        max_seq_len: int = 1024,
        theta_base: float = 10000,
        t_layers: int = 1,
        v_layers: int = 0,
        patch_share_backbone: bool = True,
        n_heads: int = 16,
        d_k: int = 8,
        d_v: int = 8,
        d_inner: int = 256,
        dropout: float = 0.,
        in_channels: int = 1,
        f_hidden_size: int = 40,
        use_norm: bool = True,
        **kwargs
    ):
        """
        ElasTST point forecaster built on `ElasTST_backbone`.

        Parameters
        ----------
        l_patch_size : Union[str, int, list]
            Temporal patch sizes; normalised to a list with `convert_to_list`.
        k_patch_size : int
            Patch size along the variable axis.
        stride : int
            Stride used when splitting patches. None falls back to the patch size.
        rotate : bool
            Enable rotary positional embeddings.
        addv : bool
            Also add RoPE information to the attention values. When False only
            keys and queries are rotated.
        bin_att : bool
            Encode variate indices with binary attention biases
            (any-variate attention).
        rope_theta_init : str
            TRoPE initialisation; 'exp' is the paper's choice.
            Options: ['exp', 'linear', 'uniform', 'rope'].
        min_period : float
            Smallest initial period coefficient of the rotary embeddings.
        max_period : float
            Largest initial period coefficient of the rotary embeddings.
        learn_tem_emb : bool
            Use learnable temporal embeddings.
        learnable_rope : bool
            Make the TRoPE period coefficients trainable.
        abs_tem_emb : bool
            Use absolute temporal embeddings.
        structured_mask : bool
            Apply the structured attention mask.
        max_seq_len : int
            Maximum input sequence length.
        theta_base : float
            Base frequency of vanilla RoPE.
        t_layers : int
            Number of temporal attention layers.
        v_layers : int
            Number of variable attention layers.
        patch_share_backbone : bool
            Share one Transformer backbone across patch sizes.
        n_heads : int
            Number of attention heads.
        d_k : int
            Key embedding size in attention.
        d_v : int
            Value embedding size in attention.
        d_inner : int
            Inner size of the feed-forward network.
        dropout : float
            Dropout rate.
        in_channels : int
            Number of input channels; only the univariate case is considered.
        f_hidden_size : int
            Hidden size of the backbone.
        use_norm : bool
            Apply instance normalisation to the history and undo it on the output.
        **kwargs : dict
            Forwarded to `Forecaster.__init__`.
        """
        super().__init__(**kwargs)

        self.l_patch_size = convert_to_list(l_patch_size)
        self.use_norm = use_norm

        # Model
        backbone_config = dict(
            l_patch_size=self.l_patch_size,
            stride=stride,
            k_patch_size=k_patch_size,
            in_channels=in_channels,
            t_layers=t_layers,
            v_layers=v_layers,
            hidden_size=f_hidden_size,
            d_inner=d_inner,
            n_heads=n_heads,
            d_k=d_k,
            d_v=d_v,
            dropout=dropout,
            rotate=rotate,
            max_seq_len=max_seq_len,
            theta=theta_base,
            addv=addv,
            bin_att=bin_att,
            learn_tem_emb=learn_tem_emb,
            abs_tem_emb=abs_tem_emb,
            learnable_theta=learnable_rope,
            structured_mask=structured_mask,
            rope_theta_init=rope_theta_init,
            min_period=min_period,
            max_period=max_period,
            patch_share_backbone=patch_share_backbone,
        )
        self.model = ElasTST_backbone(**backbone_config)

        self.loss_fn = nn.MSELoss(reduction='none')
        self.instance_norm = InstanceNorm()

    def forward(self, batch_data, pred_len, dataset_name=None):
        """
        Predict the future window of `batch_data`.

        `pred_len` is rounded up so that it is a multiple of each patch size
        (applied in turn); the backbone is run on that padded horizon and the
        result is cut back to the length of `batch_data.future_observed_values`.
        """
        padded_len = pred_len
        for patch in self.l_patch_size:
            padded_len = self.check_divisibility(padded_len, patch)

        history = batch_data.past_target_cdf
        batch_size, _, num_vars = history.shape
        history_mask = batch_data.past_observed_values

        if self.use_norm:
            history = self.instance_norm(history, 'norm')

        # Mask marking which future positions hold a value; the padded tail stays 0.
        horizon = batch_data.future_observed_values.shape[1]
        future_mask = torch.zeros([batch_size, padded_len, num_vars]).to(
            batch_data.future_observed_values.device)
        future_mask[:, :horizon] = batch_data.future_observed_values

        # All-zero placeholder for the targets to be predicted.
        placeholder = torch.zeros([batch_size, padded_len, num_vars]).to(
            batch_data.past_target_cdf.device)

        decoded, _ = self.model(history, placeholder, history_mask, future_mask,
                                dataset_name=dataset_name)
        decoded = decoded[:, :horizon]

        if self.use_norm:
            decoded = self.instance_norm(decoded, 'denorm')

        return decoded  # [b l k]

    def loss(self, batch_data, reduce='none'):
        """
        Observation-weighted squared error on the future targets.

        `reduce` is passed through to `weighted_average`; an additional
        `.mean()` is applied only when `reduce == 'mean'`.
        """
        if batch_data.max_prediction_length is not None:
            horizon = batch_data.max_prediction_length
        else:
            horizon = self.max_prediction_length

        prediction = self(batch_data, horizon, dataset_name=None)
        per_element = self.loss_fn(batch_data.future_target_cdf, prediction)
        result = self.get_weighted_loss(
            batch_data.future_observed_values, per_element, reduce=reduce)

        if reduce == 'mean':
            result = result.mean()
        return result

    def forecast(self, batch_data, num_samples=None):
        """
        Point forecast over the length of `batch_data.future_target_cdf`, with
        an extra dimension inserted at position 1. `num_samples` is ignored.
        """
        horizon = batch_data.future_target_cdf.shape[1]
        return self(batch_data, horizon, dataset_name=None).unsqueeze(1)

    def check_divisibility(self, pred_len, patch_size):
        """Return the smallest multiple of `patch_size` that is >= `pred_len`."""
        return -(-pred_len // patch_size) * patch_size

    def get_weighted_loss(self, observed_values, loss, reduce='mean'):
        """
        Reduce `loss` with `weighted_average` over dim 1, weighting each step
        by the minimum of `observed_values` over the last axis.
        """
        step_weights, _ = observed_values.min(dim=-1, keepdim=True)
        return weighted_average(loss, weights=step_weights, dim=1, reduce=reduce)
================================================
FILE: probts/model/forecaster/point_forecaster/forecastpfn.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from ForecastPFN
# - Source: https://github.com/abacusai/ForecastPFN
# - Paper: ForecastPFN: Synthetically-Trained Zero-Shot Forecasting
# - License: Apache License 2.0
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import datetime
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
from keras import backend
from sklearn.preprocessing import StandardScaler
from probts.model.forecaster import Forecaster
def smape(y_true, y_pred):
    """Symmetric mean absolute percentage error between `y_true` and `y_pred`.

    `loss = 200 * mean(abs((y_true - y_pred) / max(y_true + y_pred, epsilon)), axis=-1)`

    Args:
      y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
      y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
      Symmetric mean absolute percentage error values. shape = `[batch_size, d0, ..
      dN-1]`.
    """
    predicted = tf.convert_to_tensor(y_pred)
    actual = tf.cast(y_true, predicted.dtype)

    # The denominator is floored at the backend epsilon.
    denominator = backend.maximum(actual + predicted, backend.epsilon())
    relative_error = tf.abs((actual - predicted) / denominator)
    return 200.0 * backend.mean(relative_error, axis=-1)
class ForecastPFN(Forecaster):
    """
    Zero-shot wrapper around a pre-trained ForecastPFN Keras model.

    The model is loaded from `ckpt_path` and never trained here
    (`no_training` is forced to True). Every target variate is forecast
    independently as a univariate series.
    """

    def __init__(
        self,
        label_len: int = 48,
        ckpt_path: str = None,
        **kwargs
    ):
        """
        label_len : int
            Number of trailing history steps prepended to the decoder marks.
        ckpt_path : str
            Path given to `tf.keras.models.load_model`.
        **kwargs
            Forwarded to `Forecaster.__init__`.
        """
        super().__init__(**kwargs)
        self.no_training = True
        self.label_len = label_len
        self.model = tf.keras.models.load_model(ckpt_path, custom_objects={'smape': smape})

    def _ForecastPFN_time_features(self, x_mark_enc: np.ndarray, x_mark_dec: np.ndarray):
        """
        Convert encoder/decoder timestamp arrays into the five calendar
        features (year, month, day, day of week starting at 1, day of year).
        """
        def extract_time_features(ts):
            original_shape = ts.shape
            ts = ts.reshape(-1)  # Flatten the array

            if type(ts[0]) == datetime.datetime:
                year = np.array([x.year for x in ts])
                month = np.array([x.month for x in ts])
                day = np.array([x.day for x in ts])
                day_of_week = np.array([x.weekday() + 1 for x in ts])
                day_of_year = np.array([x.timetuple().tm_yday for x in ts])
            else:
                ts = pd.to_datetime(ts)
                year = ts.year.values
                month = ts.month.values
                day = ts.day.values
                day_of_week = ts.day_of_week.values + 1
                day_of_year = ts.day_of_year.values

            features = np.stack([year, month, day, day_of_week, day_of_year], axis=-1)
            # NOTE(review): squeeze() drops every size-1 axis, including a
            # batch axis of size 1 - confirm this is what callers expect.
            return features.reshape(*original_shape, -1).squeeze()

        # Process the encoder and decoder inputs
        x_mark_enc_features = extract_time_features(x_mark_enc)
        x_mark_dec_features = extract_time_features(x_mark_dec)
        return x_mark_enc_features, x_mark_dec_features

    def _process_tuple(self, x, x_mark, y_mark, horizon):
        """
        Forecast a single univariate series.

        x: tensor of shape (n, 1)
        x_mark: tensor of shape (n, d)
        y_mark: tensor of shape (horizon, d)

        where
            n is the input sequence length
            horizon is the output sequence length
            d is the dimensionality of the time_stamp (5 for ForecastPFN)

        Returns the prediction mapped back through the fitted StandardScaler.
        """
        # A constant series would have zero variance: nudge the last value.
        if tf.reduce_all(x == x[0]):
            x = tf.concat([x[:-1], x[-1:] + 1], axis=0)

        history = x.numpy()
        scaler = StandardScaler()
        scaler.fit(history)
        history = scaler.transform(history)

        # Rescale by the level of the last six points and clip to [0, 1].
        history_mean = np.nanmean(history[-6:])
        history_std = np.nanstd(history[-6:])
        local_scale = history_mean + history_std + 1e-4
        history = np.clip(history / local_scale, a_min=0, a_max=1)

        # The history is truncated or left-padded to exactly 100 steps.
        if x.shape[0] != 100:
            if x.shape[0] > 100:
                target = x_mark[-100:, :]
                history = history[-100:, :]
            else:
                target = tf.pad(x_mark, [[100 - x.shape[0], 0], [0, 0]])
                history = tf.pad(history, [[100 - x.shape[0], 0], [0, 0]])

            history = tf.repeat(tf.expand_dims(history, axis=0), horizon, axis=0)[:, :, 0]
            ts = tf.repeat(tf.expand_dims(target, axis=0), horizon, axis=0)
        else:
            ts = tf.repeat(tf.expand_dims(x_mark, axis=0), horizon, axis=0)
            history = tf.convert_to_tensor(history, dtype=tf.float32)

        task = tf.fill([horizon], 1)
        y_mark_tensor = tf.convert_to_tensor(y_mark[-horizon:, :], dtype=tf.int64)
        target_ts = tf.expand_dims(y_mark_tensor, axis=1)

        model_input = {'ts': ts, 'history': history, 'target_ts': target_ts, 'task': task}
        pred_vals = self.model(model_input)

        scaled_vals = pred_vals['result'].numpy().T.reshape(-1) * pred_vals['scale'].numpy().reshape(-1)
        scaled_vals = scaler.inverse_transform([scaled_vals])
        return scaled_vals

    def _process_batch(self, batch_x, batch_y, batch_x_mark, batch_y_mark):
        """
        Run `_process_tuple` on every series of the batch and return the list
        of predictions. `batch_y` is only iterated alongside the other inputs;
        its values are not used.
        """
        preds = []
        for x, _, x_mark, y_mark in zip(batch_x, batch_y, batch_x_mark, batch_y_mark):
            preds.append(self._process_tuple(x, x_mark, y_mark, self.prediction_length))
        return preds

    def forecast(self, batch_data, num_samples=None):
        """
        Point forecast with an extra dimension inserted at position 1.
        `num_samples` is ignored.

        Each of the K variates is turned into its own univariate series
        ([B, L, K] -> [B*K, L, 1]) and the per-series predictions are put back
        into a [B, horizon, K] layout.
        """
        # For now, we only support batch_size=1
        B, _, K = batch_data.past_target_cdf.shape
        inputs = batch_data.past_target_cdf[:, -self.context_length:, ...].cpu()
        x_mark_enc = batch_data.past_time_feat[:, -self.context_length:, ...].cpu().numpy().astype('datetime64[s]')
        x_mark_dec = batch_data.future_time_feat.cpu().numpy().astype('datetime64[s]')

        x_mark_enc, x_mark_dec = self._ForecastPFN_time_features(x_mark_enc, x_mark_dec)
        x_mark_dec = tf.concat([x_mark_enc[:, -self.label_len:, :], x_mark_dec], axis=1)

        # Bring the variate axis next to the batch axis before flattening.
        # Reshape is row-major, so flattening [B, L, K] directly would mix
        # time steps and variates whenever K > 1.
        inputs = tf.transpose(inputs, perm=[0, 2, 1])  # [B, K, L]
        inputs = tf.reshape(inputs, [-1, self.context_length, 1])  # [B*K, L, 1]

        # One copy of the time marks per variate, in the same (b, k) order.
        x_mark_enc = tf.repeat(x_mark_enc, repeats=K, axis=0)
        x_mark_dec = tf.repeat(x_mark_dec, repeats=K, axis=0)

        dec_inp = tf.zeros_like(inputs[:, -self.prediction_length:, :])
        dec_inp = tf.concat([inputs[:, -self.label_len:, :], dec_inp], axis=1)

        x_mark_enc = tf.cast(x_mark_enc, tf.int64)
        x_mark_dec = tf.cast(x_mark_dec, tf.int64)

        outputs = self._process_batch(inputs, dec_inp, x_mark_enc, x_mark_dec)
        outputs = tf.concat(outputs, axis=0)  # one row per (b, k) series

        # Undo the flattening: rows are ordered (b, k), so recover [B, K, H]
        # first and then move the horizon axis in front of the variates.
        outputs = tf.reshape(outputs, [B, K, -1])
        outputs = tf.transpose(outputs, perm=[0, 2, 1])  # [B, H, K]
        outputs = outputs[:, :self.prediction_length, :].numpy()
        outputs = torch.tensor(outputs)
        return outputs.unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/gru.py
================================================
import torch
import torch.nn as nn
from probts.data import ProbTSBatchData
from probts.utils import repeat
from probts.model.forecaster import Forecaster
class GRUForecaster(Forecaster):
    """
    Autoregressive point forecaster: a GRU over the assembled inputs followed
    by a linear read-out to `target_dim` values per step.
    """

    def __init__(
        self,
        num_layers: int = 2,
        f_hidden_size: int = 40,
        dropout: float = 0.1,
        **kwargs
    ):
        """
        num_layers, f_hidden_size, dropout
            Configuration of the `nn.GRU`.
        **kwargs
            Forwarded to `Forecaster.__init__`.
        """
        super().__init__(**kwargs)
        self.autoregressive = True

        gru_config = dict(
            input_size=self.input_size,
            hidden_size=f_hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.model = nn.GRU(**gru_config)
        self.linear = nn.Linear(f_hidden_size, self.target_dim)
        self.loss_fn = nn.MSELoss(reduction='none')

    def loss(self, batch_data):
        """
        Teacher-forced training loss.

        The GRU is run over history and future together; the hidden outputs
        taken one step before each future position are read out and compared
        with the future targets.
        """
        full_inputs = self.get_inputs(batch_data, 'all')
        hidden_seq, _ = self.model(full_inputs)

        # Shift by one: the output at step t predicts the target at step t + 1.
        shifted = hidden_seq[:, -self.prediction_length - 1:-1, ...]
        prediction = self.linear(shifted)

        per_element = self.loss_fn(batch_data.future_target_cdf, prediction)
        weighted = self.get_weighted_loss(batch_data, per_element)
        return weighted.mean()

    def forecast(self, batch_data, num_samples=None):
        """
        Roll the GRU forward one step at a time, feeding each prediction back
        as the newest history value. Returns the forecasts with an extra
        dimension inserted at position 1. `num_samples` is ignored.
        """
        hidden = self.encode(batch_data)
        history = batch_data.past_target_cdf
        step_predictions = []

        for step in range(self.prediction_length):
            step_batch = ProbTSBatchData({
                'target_dimension_indicator': batch_data.target_dimension_indicator,
                'past_target_cdf': history,
                'future_time_feat': batch_data.future_time_feat[:, step:step + 1, ...]
            }, device=batch_data.device)

            decoded, hidden = self.decode(step_batch, hidden)
            decoded = self.linear(decoded)
            step_predictions.append(decoded)

            # The new prediction becomes the last observed value for the next step.
            history = torch.cat((history, decoded), dim=1)

        stacked = torch.cat(step_predictions, dim=1)
        stacked = stacked.reshape(-1, self.prediction_length, self.target_dim)
        return stacked.unsqueeze(1)

    def encode(self, batch_data):
        """Run the GRU over the context window and return its final states."""
        _, final_states = self.model(self.get_inputs(batch_data, 'encode'))
        return final_states

    def decode(self, batch_data, states=None):
        """Advance the GRU on the 'decode' inputs, starting from `states`."""
        step_inputs = self.get_inputs(batch_data, 'decode')
        step_outputs, next_states = self.model(step_inputs, states)
        return step_outputs, next_states
================================================
FILE: probts/model/forecaster/point_forecaster/itransformer.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from iTransformer
# - Source: https://github.com/thuml/iTransformer
# - Paper: iTransformer: Inverted Transformers Are Effective for Time Series Forecasting
# - License: MIT License
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
from probts.model.forecaster import Forecaster
from probts.model.nn.arch.TransformerModule.Transformer_EncDec import Encoder, EncoderLayer
from probts.model.nn.arch.TransformerModule.SelfAttention_Family import FullAttention, AttentionLayer
from probts.model.nn.arch.TransformerModule.Embed import DataEmbedding_inverted
class iTransformer(Forecaster):
    """
    iTransformer point forecaster: every variate's whole context window is
    embedded as one token, an encoder attends across those tokens, and a
    linear projector maps each token to the prediction window.
    """

    def __init__(
        self,
        factor: int = 1,
        n_heads: int = 8,
        activation: str = 'gelu',
        e_layers: int = 2,
        output_attention: bool = False,
        d_ff: int = 512,
        label_len: int = 48,
        use_norm: bool = True,
        class_strategy: str = 'projection',
        dropout: float = 0.1,
        f_hidden_size: int = 512,
        **kwargs
    ):
        """
        Parameters
        ----------
        factor, n_heads, activation, d_ff, dropout, output_attention
            Forwarded to the attention / encoder layers.
        e_layers : int
            Number of encoder layers.
        label_len, class_strategy
            Stored as attributes; not otherwise used in this class.
        use_norm : bool
            Normalise each series over the context window before the model
            and undo it on the output.
        f_hidden_size : int
            Model width.
        **kwargs
            Forwarded to `Forecaster.__init__`.
        """
        super().__init__(**kwargs)
        self.label_len = label_len
        self.use_norm = use_norm

        # Embedding
        self.enc_embedding = DataEmbedding_inverted(self.context_length, f_hidden_size,
                                                    dropout)
        self.class_strategy = class_strategy

        # Encoder-only architecture
        self.model_encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, factor, attention_dropout=dropout,
                                      output_attention=output_attention), f_hidden_size, n_heads),
                    f_hidden_size,
                    d_ff,
                    dropout=dropout,
                    activation=activation
                ) for _ in range(e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(f_hidden_size)
        )
        self.projector = nn.Linear(f_hidden_size, self.prediction_length, bias=True)
        self.loss_fn = nn.MSELoss(reduction='none')

    def forward(self, inputs):
        """
        Map the 'encode' inputs to a forecast over `prediction_length` steps.

        The first `target_dim` columns of `inputs` are the targets. When time
        features are enabled, every remaining column is passed to the
        embedding as a covariate; covariate tokens are dropped again after the
        projection.
        """
        if self.use_time_feat:
            past_target = inputs[:, :, :self.target_dim]
            # Covariates are the columns *after* the targets. Slicing the last
            # `target_dim` columns instead would pick the wrong columns
            # whenever the covariate width differs from `target_dim`.
            x_mark_enc = inputs[:, :, self.target_dim:]
        else:
            past_target = inputs
            x_mark_enc = None

        if self.use_norm:
            # Normalization from Non-stationary Transformer
            means = past_target.mean(1, keepdim=True).detach()
            past_target = past_target - means
            stdev = torch.sqrt(torch.var(past_target, dim=1, keepdim=True, unbiased=False) + 1e-5)
            past_target /= stdev

        _, _, N = past_target.shape  # B L N
        # B: batch_size;    E: d_model;
        # L: seq_len;       S: pred_len;
        # N: number of variate (tokens), can also includes covariates

        # Embedding
        # B L N -> B N E                (B L N -> B L E in the vanilla Transformer)
        enc_out = self.enc_embedding(past_target, x_mark_enc)  # covariates (e.g timestamp) can be also embedded as tokens

        # B N E -> B N E                (B L E -> B L E in the vanilla Transformer)
        # the dimensions of embedded time series has been inverted, and then processed by native attn, layernorm and ffn modules
        enc_out, attns = self.model_encoder(enc_out, attn_mask=None)

        # B N E -> B N S -> B S N
        dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N]  # filter the covariates

        if self.use_norm:
            # De-Normalization from Non-stationary Transformer
            dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.prediction_length, 1))
            dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.prediction_length, 1))

        return dec_out[:, -self.prediction_length:, :]

    def forecast(self, batch_data, num_samples=None):
        """
        Point forecast with an extra dimension inserted at position 1.
        `num_samples` is ignored.
        """
        inputs = self.get_inputs(batch_data, 'encode')
        output = self(inputs)
        return output.unsqueeze(1)

    def loss(self, batch_data):
        """Mean of the observation-weighted squared error on the future targets."""
        inputs = self.get_inputs(batch_data, 'encode')
        outputs = self(inputs)

        loss = self.loss_fn(batch_data.future_target_cdf, outputs)
        loss = self.get_weighted_loss(batch_data, loss)
        return loss.mean()
================================================
FILE: probts/model/forecaster/point_forecaster/linear.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from LTSF-Linear
# - Source: https://github.com/cure-lab/LTSF-Linear
# - Paper: Are Transformers Effective for Time Series Forecasting?
# - License: Apache-2.0
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
from probts.model.forecaster import Forecaster
class LinearForecaster(Forecaster):
    """
    Linear point forecaster: a linear map along the time axis (context window
    to prediction window) followed by a linear map from the input columns to
    the target variates.
    """

    def __init__(
        self,
        individual: bool = True,
        **kwargs
    ):
        """
        individual : bool
            If True, each input column gets its own temporal linear layer;
            otherwise one layer is shared by all columns.
        **kwargs
            Forwarded to `Forecaster.__init__`.
        """
        super().__init__(**kwargs)

        self.individual = individual
        if not self.individual:
            self.linear = nn.Linear(self.context_length, self.prediction_length)
        else:
            self.linear = nn.ModuleList()
            for _ in range(self.input_size):
                self.linear.append(nn.Linear(self.context_length, self.prediction_length))

        self.out_linear = nn.Linear(self.input_size, self.target_dim)
        self.loss_fn = nn.MSELoss(reduction='none')

    def forward(self, x):
        """Apply the temporal map(s) to `x`, then the column read-out."""
        if not self.individual:
            # Move time to the last axis for nn.Linear, then move it back.
            projected = self.linear(x.permute(0, 2, 1)).permute(0, 2, 1)
        else:
            projected = torch.zeros(
                [x.size(0), self.prediction_length, x.size(2)], dtype=x.dtype
            ).to(x.device)
            for column, layer in enumerate(self.linear):
                projected[:, :, column] = layer(x[:, :, column])

        return self.out_linear(projected)

    def forecast(self, batch_data, num_samples=None):
        """
        Point forecast with an extra dimension inserted at position 1.
        `num_samples` is ignored.
        """
        return self(self.get_inputs(batch_data, 'encode')).unsqueeze(1)

    def loss(self, batch_data):
        """Mean of the observation-weighted squared error on the future targets."""
        prediction = self(self.get_inputs(batch_data, 'encode'))
        per_element = self.loss_fn(batch_data.future_target_cdf, prediction)
        weighted = self.get_weighted_loss(batch_data, per_element)
        return weighted.mean()
================================================
FILE: probts/model/forecaster/point_forecaster/mean.py
================================================
import torch
from einops import repeat
from probts.model.forecaster import Forecaster
class MeanForecaster(Forecaster):
    """Baseline that predicts one constant mean vector for every future step.

    Parameters
    ----------
    global_mean: torch.Tensor
        Per-dimension mean used when ``mode == 'global'``.
    mode: str
        ``'global'`` uses ``global_mean``; ``'batch'`` averages
        ``past_target_cdf`` over time and over the batch. Any other value makes
        ``forecast`` raise ``ValueError``.
    **kwargs
        Forwarded unchanged to ``Forecaster.__init__``.
    """

    def __init__(
        self,
        global_mean: torch.Tensor,
        mode: str = 'batch',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.global_mean = global_mean
        self.mode = mode
        # Flag read outside this class; presumably tells the runner to skip fitting.
        self.no_training = True

    @property
    def name(self):
        """Mode-prefixed class name, e.g. ``'batchMeanForecaster'``."""
        return self.mode + self.__class__.__name__

    def forecast(self, batch_data, num_samples=None):
        """Broadcast the selected mean to shape [B, 1, prediction_length, D].

        Raises
        ------
        ValueError
            If ``self.mode`` is neither ``'global'`` nor ``'batch'``.
        """
        past = batch_data.past_target_cdf
        B = past.shape[0]

        if self.mode == 'global':
            # ``global_mean`` is a plain attribute, so it does not follow the
            # module across devices; align it with the batch explicitly.
            outputs = self.global_mean.clone().to(past.device)
        elif self.mode == 'batch':
            outputs = torch.mean(past, dim=1)      # average over time
            outputs = torch.mean(outputs, dim=0)   # average over the batch
        else:
            raise ValueError(f"Unsupported mode: {self.mode}")

        outputs = repeat(outputs, 'd -> b n l d', b=B, n=1, l=self.prediction_length)
        return outputs
================================================
FILE: probts/model/forecaster/point_forecaster/moderntcn.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from ModernTCN
# - Source: https://github.com/luodhhh/ModernTCN/tree/main
# - Paper: ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis
# - License: MIT License
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import sys
import torch
import torch.nn as nn
from typing import List
from probts.model.forecaster import Forecaster
from probts.model.nn.arch.decomp import series_decomp
from probts.model.nn.arch.ModernTCN_backbone import ModernTCNModel
# torch.backends.cudnn.enabled = False
class ModernTCN(Forecaster):
    """ModernTCN point forecaster.

    Inputs with ``input_size`` channels are first mapped to ``target_dim``
    channels by ``enc_linear`` (identity when both sizes match) and then passed
    to one ``ModernTCNModel`` backbone, or to two of them (residual + trend)
    when ``decomposition`` is enabled.

    Parameters
    ----------
    kernel_size: int
        Window passed to ``series_decomp`` (used only when ``decomposition``).
    decomposition: int
        If truthy, split the series with ``series_decomp`` and forecast the two
        components with separate backbones whose outputs are summed.
    stem_ratio, downsample_ratio, ffn_ratio, num_blocks, large_size, small_size,
    dims, dw_dims, small_kernel_merged, use_multi_scale, revin, affine,
    subtract_last, individual, patch_size, patch_stride:
        Forwarded unchanged to ``ModernTCNModel``.
    dropout: float
        Passed to the backbone as ``backbone_dropout``.
    head_dropout: float
        Passed to the backbone as ``head_dropout``.
    **kwargs
        Forwarded unchanged to ``Forecaster.__init__``.
    """

    def __init__(
        self,
        kernel_size: int = 25,
        decomposition: int = 0,
        stem_ratio: int = 6,
        downsample_ratio: int = 2,
        ffn_ratio: int = 2,
        num_blocks: List[int] = [1, 1, 1, 1],
        large_size: List[int] = [31, 29, 27, 13],
        small_size: List[int] = [5, 5, 5, 5],
        dims: List[int] = [256, 256, 256, 256],
        dw_dims: List[int] = [256, 256, 256, 256],
        small_kernel_merged: bool = False,
        use_multi_scale: bool = True,
        revin: int = 1,
        affine: int = 0,
        subtract_last: int = 0,
        individual: int = 0,
        patch_size: int = 16,
        patch_stride: int = 8,
        dropout: float = 0.05,
        head_dropout: float = 0.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.stem_ratio = stem_ratio
        self.downsample_ratio = downsample_ratio
        self.ffn_ratio = ffn_ratio
        self.num_blocks = num_blocks
        self.large_size = large_size
        self.small_size = small_size
        self.dims = dims
        self.dw_dims = dw_dims

        self.nvars = self.target_dim
        self.small_kernel_merged = small_kernel_merged
        self.drop_backbone = dropout
        self.drop_head = head_dropout
        self.use_multi_scale = use_multi_scale
        self.revin = revin
        self.affine = affine
        self.subtract_last = subtract_last

        self.seq_len = self.context_length
        # Fixed: a stray trailing comma used to turn this into the tuple ``(nvars,)``.
        self.c_in = self.nvars
        self.individual = individual
        self.target_window = self.prediction_length

        self.kernel_size = kernel_size
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        self.decomposition = decomposition
        if self.decomposition:
            self.decomp_module = series_decomp(self.kernel_size)
            # Same creation order as before: residual backbone, then trend backbone.
            self.model_res = self._build_backbone()
            self.model_trend = self._build_backbone()
        else:
            self.model = self._build_backbone()

        self.loss_fn = nn.MSELoss(reduction='none')

        if self.input_size != self.target_dim:
            self.enc_linear = nn.Linear(
                in_features=self.input_size, out_features=self.target_dim
            )
        else:
            self.enc_linear = nn.Identity()

    def _build_backbone(self):
        """Create one ``ModernTCNModel`` from the hyper-parameters stored on ``self``."""
        return ModernTCNModel(
            patch_size=self.patch_size,
            patch_stride=self.patch_stride,
            stem_ratio=self.stem_ratio,
            downsample_ratio=self.downsample_ratio,
            ffn_ratio=self.ffn_ratio,
            num_blocks=self.num_blocks,
            large_size=self.large_size,
            small_size=self.small_size,
            dims=self.dims,
            dw_dims=self.dw_dims,
            nvars=self.nvars,
            small_kernel_merged=self.small_kernel_merged,
            backbone_dropout=self.drop_backbone,
            head_dropout=self.drop_head,
            use_multi_scale=self.use_multi_scale,
            revin=self.revin,
            affine=self.affine,
            subtract_last=self.subtract_last,
            freq=self.freq,
            seq_len=self.seq_len,
            c_in=self.c_in,
            individual=self.individual,
            target_window=self.target_window,
        )

    def _predict(self, batch_data):
        """Shared pipeline of ``loss`` and ``forecast``: inputs -> enc_linear -> encoder."""
        inputs = self.get_inputs(batch_data, 'encode')
        inputs = self.enc_linear(inputs)
        return self.encoder(inputs)

    def encoder(self, x, te=None):
        """Run the backbone(s) on ``x``.

        ``x`` (and ``te`` when given) are permuted with ``(0, 2, 1)`` before the
        backbone call and the result is permuted back, so the backbone sees the
        last two axes swapped.
        """
        if te is not None:
            te = te.permute(0, 2, 1)

        if self.decomposition:
            res_init, trend_init = self.decomp_module(x)
            res_init, trend_init = res_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
            res = self.model_res(res_init, te)
            trend = self.model_trend(trend_init, te)
            x = res + trend
        else:
            x = self.model(x.permute(0, 2, 1), te)

        return x.permute(0, 2, 1)

    def loss(self, batch_data):
        """Mean of the (weighted) element-wise squared error against ``future_target_cdf``."""
        outputs = self._predict(batch_data)
        loss = self.loss_fn(batch_data.future_target_cdf, outputs)
        loss = self.get_weighted_loss(batch_data, loss)
        return loss.mean()

    def forecast(self, batch_data, num_samples=None):
        """Return the point forecast with an extra axis inserted at position 1."""
        return self._predict(batch_data).unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/naive.py
================================================
import torch
from einops import repeat
from probts.model.forecaster import Forecaster
import sys
class NaiveForecaster(Forecaster):
    """Baseline that repeats the most recent observation over the whole horizon."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Flag read outside this class; presumably tells the runner to skip fitting.
        self.no_training = True

    def forecast(self, batch_data, num_samples=None):
        """Tile the last step of ``past_target_cdf`` to [B, 1, prediction_length, K]."""
        history = batch_data.past_target_cdf
        most_recent = history[:, -1, :]
        return repeat(
            most_recent,
            'b k -> b n l k',
            n=1,
            l=self.prediction_length,
        )
================================================
FILE: probts/model/forecaster/point_forecaster/nhits.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from NeuralForecast
# - Source: https://github.com/Nixtla/neuralforecast
# - Paper: N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting
# - License: Apache-2.0
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from functools import partial
from typing import List, Tuple
from probts.model.forecaster import Forecaster
class StaticFeaturesEncoder(nn.Module):
    """Encoder for static features: Dropout(p=0.5) -> Linear -> ReLU."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Layer order is kept so that state-dict keys stay ``encoder.1.*``.
        self.encoder = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=in_features, out_features=out_features),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the three-layer stack to ``x`` and return the result."""
        return self.encoder(x)
class IdentityBasis(nn.Module):
    """Basis that returns the backcast as-is and interpolates forecast knots.

    Parameters
    ----------
    backcast_size: int
        Stored on the module; not used by ``forward``.
    forecast_size: int
        Number of time steps the knots are interpolated to.
    interpolation_mode: str
        ``"linear"``, ``"nearest"``, or any string containing ``"cubic"``.
        A cubic mode may end in ``-<N>`` (e.g. ``"cubic-128"``) to interpolate
        ``N`` rows at a time; without a positive integer suffix all rows are
        interpolated in a single call.
    """

    def __init__(self, backcast_size: int, forecast_size: int, interpolation_mode: str):
        super().__init__()
        assert (interpolation_mode in ["linear", "nearest"]) or ("cubic" in interpolation_mode)
        self.forecast_size = forecast_size
        self.backcast_size = backcast_size
        self.interpolation_mode = interpolation_mode

    def forward(
        self,
        backcast_theta: torch.Tensor,
        forecast_theta: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(backcast_theta, interpolated forecast)``.

        ``forecast_theta`` is indexed as a 2-D tensor [rows, knots]; the forecast
        has shape [rows, forecast_size].
        """
        backcast = backcast_theta
        knots = forecast_theta

        if self.interpolation_mode in ("nearest", "linear"):
            # Both modes share the same 1-D interpolation call.
            forecast = F.interpolate(
                knots[:, None, :], size=self.forecast_size, mode=self.interpolation_mode
            )
            forecast = forecast[:, 0, :]
        elif "cubic" in self.interpolation_mode:
            n_rows = len(knots)
            # Optional "-<N>" suffix gives the chunk size. A missing or
            # non-positive suffix used to crash; fall back to one chunk.
            suffix = self.interpolation_mode.split("-")[-1]
            chunk = int(suffix) if suffix.isdigit() else 0
            if chunk <= 0:
                chunk = max(n_rows, 1)

            knots = knots[:, None, None, :]
            # Match the knots' dtype and device so the cubic branch behaves like
            # the linear/nearest branches (previously always float32).
            forecast = torch.zeros(
                (n_rows, self.forecast_size), dtype=knots.dtype, device=knots.device
            )
            for start in range(0, n_rows, chunk):
                stop = start + chunk
                # On 4-D input an int ``size`` applies to both spatial axes;
                # only the first interpolated row is kept.
                forecast_i = F.interpolate(
                    knots[start:stop], size=self.forecast_size, mode="bicubic"
                )
                forecast[start:stop] += forecast_i[:, 0, 0, :]

        return backcast, forecast
def init_weights(module, initialization):
    """Initialise the weight of an ``nn.Linear`` according to ``initialization``.

    Modules whose exact type is not ``torch.nn.Linear`` are left untouched.
    ``"lecun_normal"`` is accepted but intentionally performs no
    re-initialisation. Unknown names trigger an ``AssertionError``.
    """
    if type(module) is not torch.nn.Linear:
        return

    # (name, in-place initialiser) pairs; ``None`` means "keep the current weight".
    known_schemes = (
        ("orthogonal", torch.nn.init.orthogonal_),
        ("he_uniform", torch.nn.init.kaiming_uniform_),
        ("he_normal", torch.nn.init.kaiming_normal_),
        ("glorot_uniform", torch.nn.init.xavier_uniform_),
        ("glorot_normal", torch.nn.init.xavier_normal_),
        ("lecun_normal", None),
    )

    for scheme, initializer in known_schemes:
        if initialization == scheme:
            if initializer is not None:
                initializer(module.weight)
            break
    else:
        assert False, f"Initialization {initialization} not found"
# Names of ``torch.nn`` activation classes that NHiTSBlock accepts.
ACTIVATIONS = [
    "ReLU",
    "Softplus",
    "Tanh",
    "SELU",
    "LeakyReLU",
    "PReLU",
    "Sigmoid",
]
class NHiTSBlock(nn.Module):
    """
    One N-HiTS block: pool the input window, run an MLP, and let the supplied
    basis module turn the MLP output into a backcast and a forecast.
    """

    def __init__(
        self,
        context_length: int,
        prediction_length: int,
        output_size: int,
        covariate_size: int,
        static_size: int,
        static_hidden_size: int,
        n_theta: int,
        hidden_size: List[int],
        pooling_sizes: int,
        pooling_mode: str,
        basis: nn.Module,
        n_layers: int,
        batch_normalization: bool,
        dropout: float,
        activation: str,
    ):
        super().__init__()

        assert pooling_mode in ["max", "average"]

        # Length of the window after pooling (the pooling layers use ceil_mode).
        self.context_length_pooled = int(np.ceil(context_length / pooling_sizes))

        # Without static inputs the static encoder contributes no features.
        if static_size == 0:
            static_hidden_size = 0

        self.context_length = context_length
        self.output_size = [output_size]
        self.n_theta = n_theta
        self.prediction_length = prediction_length
        self.static_size = static_size
        self.static_hidden_size = static_hidden_size
        self.covariate_size = covariate_size
        self.pooling_sizes = pooling_sizes
        self.batch_normalization = batch_normalization
        self.dropout = dropout

        # Width of the MLP input: pooled targets + flattened covariates + static code.
        n_targets = len(self.output_size)
        mlp_input_width = (
            self.context_length_pooled * n_targets
            + (self.context_length + self.prediction_length) * self.covariate_size
            + self.static_hidden_size
        )
        self.hidden_size = [mlp_input_width] + hidden_size

        assert activation in ACTIVATIONS, f"{activation} is not in {ACTIVATIONS}"
        # A single activation instance is reused after every hidden layer.
        activ = getattr(nn, activation)()

        pooling_types = {"max": nn.MaxPool1d, "average": nn.AvgPool1d}
        if pooling_mode in pooling_types:
            self.pooling_layer = pooling_types[pooling_mode](
                kernel_size=self.pooling_sizes, stride=self.pooling_sizes, ceil_mode=True
            )

        mlp = []
        for depth in range(n_layers):
            fan_in = self.hidden_size[depth]
            fan_out = self.hidden_size[depth + 1]
            mlp.append(nn.Linear(in_features=fan_in, out_features=fan_out))
            mlp.append(activ)

            if self.batch_normalization:
                mlp.append(nn.BatchNorm1d(num_features=fan_out))

            if self.dropout > 0:
                mlp.append(nn.Dropout(p=self.dropout))

        # Final projection emits backcast coefficients followed by forecast knots.
        mlp.append(
            nn.Linear(
                in_features=self.hidden_size[-1],
                out_features=context_length * n_targets + n_theta * sum(self.output_size),
            )
        )

        # static_size comes from the data, static_hidden_size from the user;
        # the encoder is only built when both are positive.
        if (self.static_size > 0) and (self.static_hidden_size > 0):
            self.static_encoder = StaticFeaturesEncoder(in_features=static_size, out_features=static_hidden_size)

        self.layers = nn.Sequential(*mlp)
        self.basis = basis

    def forward(
        self, encoder_y: torch.Tensor, encoder_x_t: torch.Tensor, decoder_x_t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(backcast, forecast)`` for one block.

        ``encoder_y`` is pooled along its second axis; the covariate tensors are
        only read when ``covariate_size > 0``.
        """
        batch_size = len(encoder_y)

        # Pooling acts on the last axis, so move time there, downsample, move it back.
        pooled = self.pooling_layer(encoder_y.transpose(1, 2))
        features = pooled.transpose(1, 2).reshape(batch_size, -1)

        if self.covariate_size > 0:
            features = torch.cat(
                (
                    features,
                    encoder_x_t.reshape(batch_size, -1),
                    decoder_x_t.reshape(batch_size, -1),
                ),
                1,
            )

        # Split the MLP output into backcast coefficients and forecast knots.
        theta = self.layers(features)
        split_at = self.context_length * len(self.output_size)
        backcast_theta = theta[:, :split_at].reshape(-1, self.context_length)
        forecast_theta = theta[:, split_at:].reshape(-1, self.n_theta)

        backcast, forecast = self.basis(backcast_theta, forecast_theta)
        backcast = backcast.reshape(-1, len(self.output_size), self.context_length).transpose(1, 2)
        forecast = forecast.reshape(-1, sum(self.output_size), self.prediction_length).transpose(1, 2)

        return backcast, forecast
class NHiTS(Forecaster):
    """N-HiTS point forecaster applied to every target channel independently.

    Each of the ``target_dim`` channels is folded into the batch axis and
    forecast as a univariate series by a sequence of ``NHiTSBlock`` modules.

    Parameters
    ----------
    n_blocks: list
        Number of blocks in each stack; ``len(n_blocks)`` is the number of stacks.
    pooling_mode: str
        ``"max"`` or ``"average"`` (checked by ``NHiTSBlock``).
    interpolation_mode: str
        Mode handed to ``IdentityBasis``.
    dropout: float
        A ``Dropout`` layer is added after each hidden layer when this is > 0.
    activation: str
        Name of a ``torch.nn`` activation listed in ``ACTIVATIONS``.
    initialization: str
        Weight-initialisation scheme understood by ``init_weights``.
    batch_normalization: bool
        If true, batch normalisation is used in the very first block only.
    shared_weights: bool
        If true, every block after the first within a stack reuses the previous
        block object.
    output_size: int
        Number of outputs per series. ``loss`` and ``forecast`` rearrange the
        result with a trailing axis of 1, so they expect the default of 1.
    hidden_size: int
        Width of the two hidden layers configured for every block.
    naive_level: bool
        If true, the forecast starts from the last observed value repeated over
        the horizon; otherwise it starts from zeros.
    static_size, static_hidden_size: int
        Forwarded to ``NHiTSBlock``.
    n_layers: int
        Number of hidden layers per block (same for every stack).
    pooling_sizes: list, optional
        Pooling kernel size per stack; derived from ``prediction_length`` when
        omitted.
    downsample_frequencies: list, optional
        Per-stack divisor of ``prediction_length`` that fixes the number of
        forecast knots; derived from ``pooling_sizes`` when omitted.
    **kwargs
        Forwarded unchanged to ``Forecaster.__init__``.
    """

    def __init__(
        self,
        n_blocks: list,
        pooling_mode,
        interpolation_mode,
        dropout,
        activation,
        initialization,
        batch_normalization,
        shared_weights,
        output_size: int = 1,
        hidden_size: int = 512,
        naive_level: bool = True,
        static_size: int = 0,
        static_hidden_size: int = 0,
        n_layers: int = 2,
        pooling_sizes: list = None,
        downsample_frequencies: list = None,
        **kwargs
    ):
        super().__init__(**kwargs)

        n_stacks = len(n_blocks)

        # Width of the per-step covariate vector each block receives.
        covariate_size = 0
        if self.use_feat_idx_emb:
            covariate_size = covariate_size + self.feat_idx_emb_dim
        if self.use_time_feat:
            covariate_size = covariate_size + self.time_feat_dim
        self.covariate_size = covariate_size

        self.output_size = output_size
        self.naive_level = naive_level

        n_layers = [n_layers] * n_stacks
        hidden_size = n_stacks * [2 * [hidden_size]]

        if pooling_sizes is None:
            pooling_sizes = np.exp2(np.round(np.linspace(0.49, np.log2(self.prediction_length / 2), n_stacks)))
            # Clamp to 1: for prediction_length == 1 the formula yields 0.5 -> 0,
            # which would later divide by zero.
            pooling_sizes = [max(int(x), 1) for x in pooling_sizes[::-1]]
        if downsample_frequencies is None:
            # Clamp to 1 for the same reason (it is used as a divisor).
            downsample_frequencies = [
                max(min(self.prediction_length, int(np.power(x, 1.5))), 1) for x in pooling_sizes
            ]

        blocks = self.create_stack(
            n_blocks=n_blocks,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            output_size=output_size,
            covariate_size=covariate_size,
            static_size=static_size,
            static_hidden_size=static_hidden_size,
            n_layers=n_layers,
            hidden_size=hidden_size,
            pooling_sizes=pooling_sizes,
            downsample_frequencies=downsample_frequencies,
            pooling_mode=pooling_mode,
            interpolation_mode=interpolation_mode,
            batch_normalization=batch_normalization,
            dropout=dropout,
            activation=activation,
            shared_weights=shared_weights,
            initialization=initialization,
        )
        self.blocks = torch.nn.ModuleList(blocks)
        self.loss_fn = nn.MSELoss(reduction='none')

    def create_stack(
        self,
        n_blocks,
        context_length,
        prediction_length,
        output_size,
        covariate_size,
        static_size,
        static_hidden_size,
        n_layers,
        hidden_size,
        pooling_sizes,
        downsample_frequencies,
        pooling_mode,
        interpolation_mode,
        batch_normalization,
        dropout,
        activation,
        shared_weights,
        initialization,
    ):
        """Build the flat list of blocks for all stacks.

        Per-stack arguments (``n_layers``, ``hidden_size``, ``pooling_sizes``,
        ``downsample_frequencies``) are indexed by the stack position.
        """
        init_function = partial(init_weights, initialization=initialization)

        block_list = []
        for i in range(len(n_blocks)):
            for block_id in range(n_blocks[i]):
                # Batch norm only on the first block overall.
                batch_normalization_block = bool((len(block_list) == 0) and batch_normalization)

                if shared_weights and block_id > 0:
                    # Reuse the previous block object (shared parameters).
                    nbeats_block = block_list[-1]
                else:
                    n_theta = max(prediction_length // downsample_frequencies[i], 1)
                    basis = IdentityBasis(
                        backcast_size=context_length,
                        forecast_size=prediction_length,
                        interpolation_mode=interpolation_mode,
                    )
                    nbeats_block = NHiTSBlock(
                        context_length=context_length,
                        prediction_length=prediction_length,
                        output_size=output_size,
                        covariate_size=covariate_size,
                        static_size=static_size,
                        static_hidden_size=static_hidden_size,
                        n_theta=n_theta,
                        hidden_size=hidden_size[i],
                        pooling_sizes=pooling_sizes[i],
                        pooling_mode=pooling_mode,
                        basis=basis,
                        n_layers=n_layers[i],
                        batch_normalization=batch_normalization_block,
                        dropout=dropout,
                        activation=activation,
                    )

                # Apply the chosen initialisation to every layer of the block
                # (kept for shared blocks too, as before).
                nbeats_block.layers.apply(init_function)
                block_list.append(nbeats_block)
        return block_list

    def encoder(self, encoder_y, encoder_x_t, decoder_x_t):
        """Run all blocks with residual backcasting and sum their forecasts.

        ``encoder_y`` is [B, L, D]; each block's backcast is subtracted from the
        running residual before the next block sees it.
        """
        residuals = encoder_y

        # Naive level: last observation repeated over the horizon.
        level = encoder_y[:, -1:].repeat(1, self.prediction_length, 1)
        forecast_level = level.repeat_interleave(self.output_size, dim=2)

        if self.naive_level:
            forecast = forecast_level
        else:
            forecast = torch.zeros_like(forecast_level)

        for block in self.blocks:
            block_backcast, block_forecast = block(
                encoder_y=residuals, encoder_x_t=encoder_x_t, decoder_x_t=decoder_x_t
            )
            residuals = residuals - block_backcast
            forecast = forecast + block_forecast

        return forecast

    def get_cov(self, inputs):
        """Extract encoder/decoder covariates, one copy per target channel.

        The last axis of ``inputs`` is read as: ``target_dim`` target columns,
        then ``target_dim * feat_idx_emb_dim`` feature-index embedding columns
        (if enabled), then ``time_feat_dim`` time-feature columns (if enabled).
        Returns ``(None, None)`` when no covariate is enabled.
        """
        use_idx = bool(self.use_feat_idx_emb)
        # Time features are only sliced when enabled *and* non-empty; a width of
        # 0 would make the negative slices below select the wrong columns.
        use_time = bool(self.use_time_feat and self.time_feat_dim)

        encoder_window = inputs[:, : self.context_length]
        decoder_window = inputs[:, -self.prediction_length:]

        if use_idx:
            emb_stop = -self.time_feat_dim if use_time else None
            encoder_dim_fea = rearrange(
                encoder_window[:, :, self.target_dim:emb_stop],
                "b l (k d) -> (b k) l d", k=self.target_dim, d=self.feat_idx_emb_dim,
            )
            decoder_dim_fea = rearrange(
                decoder_window[:, :, self.target_dim:emb_stop],
                "b l (k d) -> (b k) l d", k=self.target_dim, d=self.feat_idx_emb_dim,
            )

        if use_time:
            # Time features are shared by all channels, so copy them per channel.
            encoder_time_fea = repeat(
                encoder_window[:, :, -self.time_feat_dim:], 'b l d -> (b k) l d', k=self.target_dim
            )
            decoder_time_fea = repeat(
                decoder_window[:, :, -self.time_feat_dim:], 'b l d -> (b k) l d', k=self.target_dim
            )

        if use_idx and use_time:
            encoder_x_t = torch.cat([encoder_dim_fea, encoder_time_fea], dim=-1)
            decoder_x_t = torch.cat([decoder_dim_fea, decoder_time_fea], dim=-1)
        elif use_idx:
            encoder_x_t, decoder_x_t = encoder_dim_fea, decoder_dim_fea
        elif use_time:
            encoder_x_t, decoder_x_t = encoder_time_fea, decoder_time_fea
        else:
            encoder_x_t, decoder_x_t = None, None

        return encoder_x_t, decoder_x_t

    def _predict(self, batch_data):
        """Shared pipeline of ``loss`` and ``forecast``; returns a [B, L, K] forecast."""
        inputs = self.get_inputs(batch_data, 'all')  # [B L D]

        # Fold the K target channels into the batch axis: one univariate series each.
        encoder_y = inputs[:, : self.context_length, :self.target_dim]  # [B L K]
        encoder_y = rearrange(encoder_y, "b l k -> (b k) l 1")

        encoder_x_t, decoder_x_t = self.get_cov(inputs)
        outputs = self.encoder(encoder_y, encoder_x_t, decoder_x_t)
        return rearrange(outputs, "(b k) l 1 -> b l k", k=self.target_dim)

    def loss(self, batch_data):
        """Mean of the (weighted) element-wise squared error against ``future_target_cdf``."""
        outputs = self._predict(batch_data)
        loss = self.loss_fn(batch_data.future_target_cdf, outputs)
        loss = self.get_weighted_loss(batch_data, loss)
        return loss.mean()

    def forecast(self, batch_data, num_samples=None):
        """Return the point forecast with an extra axis inserted at position 1."""
        return self._predict(batch_data).unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/nlinear.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from LTSF-Linear
# - Source: https://github.com/cure-lab/LTSF-Linear
# - Paper: Are Transformers Effective for Time Series Forecasting?
# - License: Apache-2.0
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
from probts.model.forecaster import Forecaster
class NLinear(Forecaster):
    """NLinear point forecaster (LTSF-Linear family).

    The last observed step is subtracted from the context window, a linear map
    along time produces the horizon, and the subtracted value is added back.

    Parameters
    ----------
    individual: bool
        If true, one temporal layer per target channel; otherwise a single
        layer shared by all channels.
    **kwargs
        Forwarded unchanged to ``Forecaster.__init__``.
    """

    def __init__(
        self,
        individual: bool,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Map input channels onto the target channels only when they differ.
        channels_differ = self.input_size != self.target_dim
        if channels_differ:
            self.enc_linear = nn.Linear(
                in_features=self.input_size, out_features=self.target_dim
            )
        else:
            self.enc_linear = nn.Identity()

        self.individual = individual

        if individual:
            self.Linear = nn.ModuleList(
                [
                    nn.Linear(self.context_length, self.prediction_length)
                    for _ in range(self.target_dim)
                ]
            )
        else:
            self.Linear = nn.Linear(self.context_length, self.prediction_length)
        self.loss_fn = nn.MSELoss(reduction='none')

    def forward(self, inputs):
        """Forecast from ``inputs`` laid out as [batch, time, channel]."""
        # The last step acts as a level offset and is excluded from the graph.
        anchor = inputs[:, -1:, :].detach()
        centered = inputs - anchor

        if not self.individual:
            # Linear acts on the last axis, so move time there and back again.
            projected = self.Linear(centered.permute(0, 2, 1)).permute(0, 2, 1)
        else:
            # Buffer shares dtype and device with the input.
            projected = centered.new_zeros(
                (centered.size(0), self.prediction_length, centered.size(2))
            )
            for channel in range(self.target_dim):
                projected[:, :, channel] = self.Linear[channel](centered[:, :, channel])

        return projected + anchor

    def loss(self, batch_data):
        """Mean of the (weighted) element-wise squared error against ``future_target_cdf``."""
        full_window = self.get_inputs(batch_data, 'all')
        # Keep only the first ``context_length`` steps of the full window.
        context = full_window[:, : self.context_length, ...]
        prediction = self(self.enc_linear(context))

        elementwise = self.loss_fn(batch_data.future_target_cdf, prediction)
        weighted = self.get_weighted_loss(batch_data, elementwise)
        return weighted.mean()

    def forecast(self, batch_data, num_samples=None):
        """Return the point forecast with an extra axis inserted at position 1."""
        encoded = self.get_inputs(batch_data, 'encode')
        prediction = self(self.enc_linear(encoded))
        return prediction.unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/patchtst.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PatchTST
# - Source: https://github.com/yuqinie98/PatchTST/tree/main
# - Paper: PatchTST: A Time Series is Worth 64 Words: Long-term Forecasting with Transformers
# - License: Apache-2.0
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch.nn as nn
from torch import Tensor
from typing import Optional
from probts.model.forecaster import Forecaster
from probts.model.nn.arch.PatchTSTModule.PatchTST_backbone import PatchTST_backbone
from probts.model.nn.arch.PatchTSTModule.PatchTST_layers import series_decomp
class PatchTST(Forecaster):
    """PatchTST point forecaster.

    Inputs with ``input_size`` channels are first mapped to ``target_dim``
    channels by ``enc_linear`` (identity when both sizes match) and then passed
    to one ``PatchTST_backbone``, or to two of them (residual + trend) when
    ``decomposition`` is enabled.

    Parameters
    ----------
    stride, patch_len, padding_patch, max_seq_len, n_layers, n_heads, d_k, d_v,
    d_ff, attn_dropout, dropout, act, res_attention, pre_norm, store_attn, pe,
    learn_pe, attn_mask, individual, head_type, padding_var, revin,
    key_padding_mask, affine, subtract_last, fc_dropout, head_dropout:
        Forwarded unchanged to ``PatchTST_backbone``.
    decomposition: bool
        If true, split the series with ``series_decomp`` and forecast the two
        components with separate backbones whose outputs are summed.
    kernel_size: int
        Window passed to ``series_decomp`` (used only when ``decomposition``).
    f_hidden_size: int
        Passed to the backbone as ``d_model``.
    **kwargs
        Forwarded unchanged to ``Forecaster.__init__``.
    """

    def __init__(
        self,
        stride: int,
        patch_len: int,
        padding_patch: str = None,
        max_seq_len: int = 1024,
        n_layers:int = 3,
        n_heads = 16,
        d_k: int = None,
        d_v: int = None,
        d_ff: int = 256,
        attn_dropout: float = 0.,
        dropout: float = 0.,
        act: str = "gelu",
        res_attention: bool = True,
        pre_norm: bool = False,
        store_attn: bool = False,
        pe: str = 'zeros',
        learn_pe: bool = True,
        attn_mask: Optional[Tensor] = None,
        individual: bool = False,
        head_type: str = 'flatten',
        padding_var: Optional[int] = None,
        revin: bool = True,
        key_padding_mask: str = 'auto',
        affine: bool = False,
        subtract_last: bool = False,
        decomposition: bool = False,
        kernel_size: int = 3,
        fc_dropout: float = 0.,
        head_dropout: float = 0.,
        f_hidden_size: int = 40,
        **kwargs
    ):
        super().__init__(**kwargs)

        if self.input_size != self.target_dim:
            self.enc_linear = nn.Linear(
                in_features=self.input_size, out_features=self.target_dim
            )
        else:
            self.enc_linear = nn.Identity()

        # One shared argument set for every backbone built below.
        backbone_kwargs = dict(
            # The backbone consumes the output of ``enc_linear``, which has
            # ``target_dim`` channels (previously ``input_size`` was passed,
            # which only matched when both sizes were equal).
            c_in=self.target_dim,
            context_window=self.context_length,
            target_window=self.prediction_length,
            patch_len=patch_len,
            stride=stride,
            max_seq_len=max_seq_len,
            n_layers=n_layers,
            d_model=f_hidden_size,
            n_heads=n_heads,
            d_k=d_k,
            d_v=d_v,
            d_ff=d_ff,
            attn_dropout=attn_dropout,
            dropout=dropout,
            act=act,
            key_padding_mask=key_padding_mask,
            padding_var=padding_var,
            attn_mask=attn_mask,
            res_attention=res_attention,
            pre_norm=pre_norm,
            store_attn=store_attn,
            pe=pe,
            learn_pe=learn_pe,
            fc_dropout=fc_dropout,
            head_dropout=head_dropout,
            padding_patch=padding_patch,
            pretrain_head=False,
            head_type=head_type,
            individual=individual,
            revin=revin,
            affine=affine,
            subtract_last=subtract_last,
        )

        self.decomposition = decomposition
        if self.decomposition:
            self.decomp_module = series_decomp(kernel_size)
            # Same creation order as before: trend backbone, then residual backbone.
            self.model_trend = PatchTST_backbone(**backbone_kwargs)
            self.model_res = PatchTST_backbone(**backbone_kwargs)
        else:
            self.model = PatchTST_backbone(**backbone_kwargs)

        self.loss_fn = nn.MSELoss(reduction='none')

    def forward(self, x):
        """Forecast from ``x`` laid out as [Batch, Input length, Channel]."""
        if self.decomposition:
            res_init, trend_init = self.decomp_module(x)
            # Backbones expect [Batch, Channel, Input length].
            res_init, trend_init = res_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
            x = self.model_res(res_init) + self.model_trend(trend_init)
        else:
            x = self.model(x.permute(0, 2, 1))
        # Back to [Batch, length, Channel].
        return x.permute(0, 2, 1)

    def _predict(self, batch_data):
        """Shared pipeline of ``loss`` and ``forecast``: inputs -> enc_linear -> forward."""
        inputs = self.get_inputs(batch_data, 'encode')
        inputs = self.enc_linear(inputs)
        return self(inputs)

    def loss(self, batch_data):
        """Mean of the (weighted) element-wise squared error against ``future_target_cdf``."""
        outputs = self._predict(batch_data)
        loss = self.loss_fn(batch_data.future_target_cdf, outputs)
        loss = self.get_weighted_loss(batch_data, loss)
        return loss.mean()

    def forecast(self, batch_data, num_samples=None):
        """Return the point forecast with an extra axis inserted at position 1."""
        return self._predict(batch_data).unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/time_moe.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from Time-MoE
# - Source: https://github.com/Time-MoE/Time-MoE
# - Paper: Time-MoE: Billion-Scale Time Series Foundation Models with Mixture of Experts
# - License: Apache License 2.0
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from transformers import AutoModelForCausalLM
from probts.model.forecaster import Forecaster
import sys
from probts.data.data_utils.data_scaler import InstanceNorm
class TimeMoE(Forecaster):
    """Wrapper around the pre-trained Time-MoE causal model.

    Every target channel is forecast independently: channels are folded into
    the batch axis before calling ``generate``.

    Parameters
    ----------
    model_size: str
        ``'50M'`` or ``'200M'``; selects the ``Maple728/TimeMoE-<size>`` checkpoint.
        Any other value terminates the process via ``sys.exit``.
    instance_norm: bool
        If true, inputs are normalised with ``InstanceNorm`` before generation
        and the forecast is de-normalised afterwards.
    **kwargs
        Forwarded unchanged to ``Forecaster.__init__``.
    """

    def __init__(
        self,
        model_size: str = '50M',
        instance_norm=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Flag read outside this class; presumably tells the runner to skip fitting.
        self.no_training = True

        # NOTE: the previous constructor tried to resolve dict-valued
        # ``target_dim``/``freq`` and list-valued window lengths into locals,
        # but read those locals before assignment (UnboundLocalError) and never
        # used the results. Window lengths are now resolved in ``forecast``;
        # the channel count is taken from the input shape there.

        if model_size not in ['50M', '200M']:
            # Passing the message to sys.exit prints it to stderr and yields a
            # non-zero exit status (a bare sys.exit() reported success).
            sys.exit('Invalid model size. Please choose from 50M or 200M')

        if instance_norm:
            self.normalization = InstanceNorm()
        else:
            self.normalization = None

        self.model = AutoModelForCausalLM.from_pretrained(
            f'Maple728/TimeMoE-{model_size}',
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        print(f"loaded TimeMoE-{model_size} model")

    @staticmethod
    def _max_length(length):
        """Return ``length`` itself, or its maximum when a list/tuple is configured."""
        if isinstance(length, (list, tuple)):
            return max(length)
        return length

    def forecast(self, batch_data, num_samples=None):
        """Generate a point forecast of shape [B, 1, prediction_length, K]."""
        context_length = self._max_length(self.context_length)
        prediction_length = self._max_length(self.prediction_length)

        inputs = batch_data.past_target_cdf[:, -context_length:]
        B, _, K = inputs.shape

        # The checkpoint is loaded in bfloat16, so inputs are cast to match.
        inputs = inputs.to(dtype=torch.bfloat16)
        inputs = rearrange(inputs, 'b l k -> (b k) l')

        if self.normalization:
            inputs = self.normalization(inputs, mode='norm')

        forecasts = self.model.generate(inputs, max_new_tokens=prediction_length)
        # Only the trailing ``prediction_length`` steps of the output are kept.
        point_forecast = forecasts[:, -prediction_length:]

        if self.normalization:
            point_forecast = self.normalization(point_forecast, mode='denorm')

        point_forecast = point_forecast.to(dtype=torch.float32)
        point_forecast = rearrange(point_forecast, '(b k) l -> b l k', b=B, k=K)
        point_forecast = point_forecast[:, :prediction_length]
        return point_forecast.unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/timer.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from Large-Time-Series-Model
# - Source: https://github.com/thuml/Large-Time-Series-Model
# - Paper: Timer: Generative Pre-trained Transformers Are Large Time Series Models
# - License: MIT License
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
from einops import rearrange, repeat
from torch import nn
from probts.model.forecaster import Forecaster
class Model(nn.Module):
    """
    Thin wrapper around a TorchScript (``.pt``) Timer checkpoint.

    Paper link: https://arxiv.org/pdf/2402.02368.pdf

    Args:
        ckpt_path: Path to a TorchScript archive ending in ``.pt``. ``None``
            or an empty string builds the wrapper without a backbone; in that
            case ``forward`` raises a ``RuntimeError`` instead of an opaque
            ``AttributeError``.

    Raises:
        NotImplementedError: If ``ckpt_path`` is given but does not end in
            ``.pt``.
    """

    def __init__(self, ckpt_path):
        super().__init__()
        if ckpt_path:
            if not ckpt_path.endswith('.pt'):
                raise NotImplementedError(
                    f"Unsupported Timer checkpoint '{ckpt_path}': only TorchScript '.pt' files can be loaded."
                )
            self.timer = torch.jit.load(ckpt_path)
        else:
            # Make the "no backbone" state explicit so forward() can report it.
            self.timer = None

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Delegate to the loaded Timer module; ``mask`` is accepted but unused."""
        if self.timer is None:
            raise RuntimeError(
                "Timer backbone is not loaded: pass a valid '.pt' ckpt_path when constructing the model."
            )
        return self.timer(x_enc, x_mark_enc, x_dec, x_mark_dec)
class Timer(Forecaster):
    """Zero-shot forecaster built on a pre-trained Timer checkpoint.

    Args:
        label_len: Number of trailing context steps copied into the decoder
            input and decoder time features.
        ckpt_path: Path of the TorchScript checkpoint handed to ``Model``.
        ckpt_path_finetune: Optional path of a state dict loaded on top of
            the wrapped model.
        **kwargs: Forwarded to ``Forecaster``.
    """

    def __init__(
        self,
        label_len: int = 576,
        ckpt_path: str = None,
        ckpt_path_finetune: str = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.no_training = True
        self.output_patch_len = 96  # fixed by the pre-trained model
        self.label_len = label_len

        # Load Timer
        self.model = Model(ckpt_path)
        if ckpt_path_finetune:
            print(f"Loading Timer finetune model from {ckpt_path_finetune}")
            self.model.load_state_dict(torch.load(ckpt_path_finetune))

    def forecast(self, batch_data, num_samples=None):
        """Roll the model forward in patches of ``output_patch_len`` steps.

        Args:
            batch_data: Batch exposing ``past_target_cdf``, ``past_time_feat``
                and ``future_time_feat``.
            num_samples: Unused; kept for interface compatibility.

        Returns:
            Forecast truncated to ``prediction_length`` steps, with an extra
            axis inserted at position 1.
        """
        # Original authors' note: "for now, we only support batch_size=1".
        B, _, K = batch_data.past_target_cdf.shape

        inputs = batch_data.past_target_cdf[:, -self.context_length:, ...]
        x_mark_enc = batch_data.past_time_feat[:, -self.context_length:, ...]
        x_mark_dec = batch_data.future_time_feat
        x_mark_dec = torch.cat([x_mark_enc[:, -self.label_len:, :], x_mark_dec], dim=1)

        # Every variate becomes its own univariate sample.
        inputs = rearrange(inputs, 'b l k -> (b k) l 1')
        x_mark_enc = repeat(x_mark_enc, 'b l f -> (b k) l f', k=K)
        x_mark_dec = repeat(x_mark_dec, 'b l f -> (b k) l f', k=K)

        dec_inp = torch.zeros_like(inputs[:, -self.prediction_length:, :]).float()
        dec_inp = torch.cat((inputs[:, -self.label_len:, ...], dec_inp), dim=1).float()

        # Ceil division: number of patches needed to cover the horizon.
        inference_steps = -(-self.prediction_length // self.output_patch_len)

        pred_y = []
        for j in range(inference_steps):
            if pred_y:
                # Slide the window: drop the oldest patch, append the newest prediction.
                inputs = torch.cat([inputs[:, self.output_patch_len:, :], pred_y[-1]], dim=1)
                tmp = x_mark_dec[:, j - 1:j, :]
                x_mark_enc = torch.cat([x_mark_enc[:, 1:, :], tmp], dim=1)

            outputs = self.model(inputs, x_mark_enc, dec_inp, x_mark_dec)
            pred_y.append(outputs[:, -self.output_patch_len:, :])

        pred_y = torch.cat(pred_y, dim=1)
        pred_y = rearrange(pred_y, '(b k) l 1 -> b l k', b=B, k=K)

        # Cut the overshoot of the last patch. The previous code first dropped
        # `prediction_length % output_patch_len` trailing steps, which removed
        # too much whenever that remainder exceeded half a patch and returned
        # fewer than `prediction_length` steps.
        pred_y = pred_y[:, :self.prediction_length, :]
        return pred_y.unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/timesfm.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from timesfm
# - Source: https://github.com/google-research/timesfm
# - Paper: A decoder-only foundation model for time-series forecasting
# - License: Apache License 2.0
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import numpy as np
import torch
from einops import rearrange
import sys
from probts.model.forecaster import Forecaster
from probts.model.nn.arch.TimesFMModule import TimesFm, TimesFmCheckpoint, TimesFmHparams
# from submodules.timesfm.src.timesfm import TimesFm
class TimesFM(Forecaster):
    """Zero-shot point forecaster backed by a pre-trained TimesFM checkpoint.

    Args:
        model_size: ``'200m'`` or ``'500m'``; selects the Hugging Face
            checkpoint and matching hyper-parameters. Any other value prints
            a message and exits the process.
        **kwargs: Forwarded to ``Forecaster``.
    """

    def __init__(
        self,
        model_size: str = '200m',
        # input_patch_len: int = 32,
        # output_patch_len: int = 128,
        # num_layers: int = 20,
        # model_dims: int = 1280,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.no_training = True

        # Resolve the frequency string. The previous code read unbound locals
        # (`target_dim`, `freq`) here and raised UnboundLocalError for
        # dict-valued configurations.
        freq = self.freq
        if isinstance(self.target_dim, dict) and isinstance(freq, dict):
            # Multi-dataset configuration: as in the original loop, the entry
            # of the last dataset wins.
            for dataset_name in self.target_dim:
                freq = self.freq[dataset_name]

        # Collapse list-valued lengths to their maximum so that the slicing in
        # `forecast` receives plain integers (same convention as TSMixer).
        if isinstance(self.context_length, list):
            self.context_length = max(self.context_length)
        if isinstance(self.prediction_length, list):
            self.prediction_length = max(self.prediction_length)

        if model_size not in ['200m', '500m']:
            print('Invalid model size. Please choose from 200m or 500m')
            sys.exit()

        # NOTE(review): horizon_len is fixed to 128 for both sizes; confirm
        # that the backend returns enough steps when prediction_length > 128.
        if model_size == '200m':
            self.tfm = TimesFm(
                hparams=TimesFmHparams(
                    backend="gpu",
                    per_core_batch_size=32,
                    horizon_len=128,
                ),
                checkpoint=TimesFmCheckpoint(
                    huggingface_repo_id="google/timesfm-1.0-200m-pytorch"),
            )
        elif model_size == '500m':
            self.tfm = TimesFm(
                hparams=TimesFmHparams(
                    backend="gpu",
                    per_core_batch_size=32,
                    horizon_len=128,
                    num_layers=50,
                    use_positional_embedding=False,
                    context_len=2048,
                ),
                checkpoint=TimesFmCheckpoint(
                    huggingface_repo_id="google/timesfm-2.0-500m-pytorch"),
            )

        # Map the dataset frequency to the integer category passed to
        # `tfm.forecast`; unknown frequencies fall back to 0.
        freq_dict = {'h': 0, 'min': 0, 'd': 0, 'b': 0, 'u': 0, 'w': 1, 'm': 1, 'q': 2, 'y': 2}
        freq = freq.lower()
        self.freq_int = freq_dict.get(freq, 0)

        print(f"TimesFM-{model_size} - frequency: {freq}, freq_num: {self.freq_int}")

    def forecast(self, batch_data, num_samples=None):
        """Run TimesFM on every variate independently.

        Args:
            batch_data: Batch consumed by ``self.get_inputs(..., 'encode')``.
            num_samples: Unused; kept for interface compatibility.

        Returns:
            A tensor built from the backend's NumPy output, truncated to
            ``prediction_length`` steps, with an extra axis at position 1.
        """
        inputs = self.get_inputs(batch_data, 'encode')
        inputs = inputs[:, -self.context_length:].cpu()
        B, _, K = inputs.shape

        # The backend consumes NumPy arrays with one row per univariate series.
        inputs = np.array(rearrange(inputs, 'b l k -> (b k) l'))
        frequency_input = [self.freq_int] * inputs.shape[0]

        _, out = self.tfm.forecast(
            inputs,
            freq=frequency_input,
        )
        # Index 5 of the last axis is used as the point forecast
        # (presumably the median quantile -- TODO confirm against TimesFM).
        point_forecast = out[:, :, 5]

        point_forecast = rearrange(point_forecast, '(b k) l -> b l k', b=B, k=K)
        point_forecast = torch.tensor(point_forecast[:, :self.prediction_length])
        return point_forecast.unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/timesnet.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from TSLib
# - Source: https://github.com/libts/tslib
# - Paper: TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis
# - License: LGPL-2.1
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from probts.model.forecaster import Forecaster
from probts.model.nn.arch.TransformerModule.Embed import DataEmbedding
from probts.model.nn.arch.Conv_Blocks import Inception_Block_V1
def FFT_for_Period(x, k=2):
    """Pick the ``k`` dominant periods of ``x`` from its real FFT.

    Args:
        x: Tensor laid out as [B, T, C]; the FFT is taken along dim 1.
        k: Number of strongest frequency bins to keep.

    Returns:
        A pair ``(period, weight)``: ``period`` is a NumPy array holding
        ``T // index`` for each selected bin, and ``weight`` is the
        channel-averaged amplitude of those bins for every batch element.
    """
    spectrum = torch.fft.rfft(x, dim=1)
    amplitude = spectrum.abs()

    # Rank frequencies by amplitude averaged over batch and channels,
    # ignoring the DC component.
    mean_amplitude = amplitude.mean(0).mean(-1)
    mean_amplitude[0] = 0

    top_bins = torch.topk(mean_amplitude, k).indices.detach().cpu().numpy()
    periods = x.shape[1] // top_bins
    return periods, amplitude.mean(-1)[:, top_bins]
class TimesBlock(nn.Module):
    """TimesNet block: fold the sequence by its dominant periods, run a 2-D
    inception convolution on every folding and blend the results."""

    def __init__(self, context_length, prediction_length, top_k, d_model, d_ff, num_kernels):
        super().__init__()
        self.seq_len = context_length
        self.pred_len = prediction_length
        self.k = top_k
        # Parameter-efficient design: two inception blocks around a GELU.
        self.conv = nn.Sequential(
            Inception_Block_V1(d_model, d_ff, num_kernels=num_kernels),
            nn.GELU(),
            Inception_Block_V1(d_ff, d_model, num_kernels=num_kernels),
        )

    def forward(self, x):
        """Apply the block to ``x`` (unpacked as ``B, T, N``) and return a
        tensor of the same shape, including the residual connection."""
        batch, steps, channels = x.size()
        total_len = self.seq_len + self.pred_len
        periods, amplitudes = FFT_for_Period(x, self.k)

        branches = []
        for idx in range(self.k):
            period = periods[idx]

            # Zero-pad so that the length is a multiple of the period.
            if total_len % period == 0:
                padded_len = total_len
                folded = x
            else:
                padded_len = ((total_len // period) + 1) * period
                filler = torch.zeros(
                    [x.shape[0], (padded_len - total_len), x.shape[2]]
                ).to(x.device)
                folded = torch.cat([x, filler], dim=1)

            # 1-D variation -> 2-D variation: [B, N, length // period, period].
            folded = folded.reshape(batch, padded_len // period, period, channels)
            folded = folded.permute(0, 3, 1, 2).contiguous()
            folded = self.conv(folded)

            # Back to [B, length, N], then drop the padding.
            unfolded = folded.permute(0, 2, 3, 1).reshape(batch, -1, channels)
            branches.append(unfolded[:, :total_len, :])

        stacked = torch.stack(branches, dim=-1)

        # Adaptive aggregation weighted by the softmax of the period amplitudes.
        weights = F.softmax(amplitudes, dim=1)
        weights = weights.unsqueeze(1).unsqueeze(1).repeat(1, steps, channels, 1)
        aggregated = torch.sum(stacked * weights, -1)

        # Residual connection.
        return aggregated + x
class TimesNet(Forecaster):
    """TimesNet point forecaster.

    Args:
        n_layers: Number of stacked ``TimesBlock`` modules.
        num_kernels: Kernel count passed to every inception block.
        top_k: Number of dominant periods used per block.
        d_ff: Hidden width inside the inception convolution.
        embed: Embedding type passed to ``DataEmbedding``.
        dropout: Dropout rate passed to ``DataEmbedding``.
        f_hidden_size: Model width after embedding.
        **kwargs: Forwarded to ``Forecaster``.
    """

    def __init__(
        self,
        n_layers: int = 2,
        num_kernels: int = 6,
        top_k: int = 5,
        d_ff: int = 32,
        embed: str = 'timeF',
        dropout: float = 0.1,
        f_hidden_size: int = 40,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.seq_len = self.context_length
        self.pred_len = self.prediction_length

        blocks = [
            TimesBlock(self.context_length, self.prediction_length, top_k, f_hidden_size, d_ff, num_kernels)
            for _ in range(n_layers)
        ]
        self.model = nn.ModuleList(blocks)
        self.enc_embedding = DataEmbedding(self.target_dim, f_hidden_size, embed, self.freq.lower(), dropout)
        self.layer = n_layers
        self.layer_norm = nn.LayerNorm(f_hidden_size)
        # Stretches the time axis from seq_len to seq_len + pred_len.
        self.predict_linear = nn.Linear(self.seq_len, self.pred_len + self.seq_len)
        self.projection = nn.Linear(f_hidden_size, self.target_dim, bias=True)

        # Project the raw inputs onto target_dim only when the sizes differ.
        if self.input_size == self.target_dim:
            self.enc_linear = nn.Identity()
        else:
            self.enc_linear = nn.Linear(
                in_features=self.input_size, out_features=self.target_dim
            )
        self.loss_fn = nn.MSELoss(reduction='none')

    def forward(self, x_enc, x_mark_enc=None):
        """Return the last ``pred_len`` steps of the de-normalised output."""
        total_len = self.pred_len + self.seq_len

        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        centered = x_enc - means
        stdev = torch.sqrt(
            torch.var(centered, dim=1, keepdim=True, unbiased=False) + 1e-5
        )
        normalized = centered / stdev

        # Embedding, then align the temporal dimension.
        hidden = self.enc_embedding(normalized, x_mark_enc)  # [B,T,C]
        hidden = self.predict_linear(hidden.permute(0, 2, 1)).permute(0, 2, 1)

        # TimesNet blocks, each followed by layer normalisation.
        for idx in range(self.layer):
            hidden = self.layer_norm(self.model[idx](hidden))

        # Project back to the target dimension.
        dec_out = self.projection(hidden)

        # De-Normalization from Non-stationary Transformer
        scale = stdev[:, 0, :].unsqueeze(1).repeat(1, total_len, 1)
        shift = means[:, 0, :].unsqueeze(1).repeat(1, total_len, 1)
        dec_out = dec_out * scale
        dec_out = dec_out + shift
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]

    def loss(self, batch_data):
        """Mean of the weighted element-wise MSE against ``future_target_cdf``."""
        history = self.get_inputs(batch_data, 'all')[:, : self.context_length, ...]
        # [Batch, Input length, Channel]
        predictions = self(self.enc_linear(history))
        elementwise = self.loss_fn(batch_data.future_target_cdf, predictions)
        weighted = self.get_weighted_loss(batch_data, elementwise)
        return weighted.mean()

    def forecast(self, batch_data, num_samples=None):
        """Point forecast with an extra axis at position 1; ``num_samples`` is unused."""
        encoded = self.enc_linear(self.get_inputs(batch_data, 'encode'))
        return self(encoded).unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/tinytimemixer.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from granite-tsfm
# - Source: https://github.com/ibm-granite/granite-tsfm
# - Paper: Tiny Time Mixers (TTMs): Fast Pre-trained Models for Enhanced Zero/Few-Shot Forecasting of Multivariate Time Series
# - License: Apache License 2.0
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
from probts.model.forecaster import Forecaster
from submodules.tsfm.tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction
class TinyTimeMixer(Forecaster):
    """
    TinyTimeMixer from https://github.com/ibm-granite/granite-tsfm/blob/main/notebooks/hfdemo/ttm_getting_started.ipynb

    prediction length originally 96
    context length originally 512
    changes might cause degradation in performance
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.no_training = True

        # TTM model branch
        # Use main for 512-96 model
        # Use "1024_96_v1" for 1024-96 model
        TTM_MODEL_REVISION = "main"

        # Collapse list-valued lengths to their maximum so that the slicing in
        # `forecast` receives a plain integer. The previous code assigned to
        # unbound locals here and raised UnboundLocalError for list configs.
        if isinstance(self.context_length, list):
            self.context_length = max(self.context_length)
        if isinstance(self.prediction_length, list):
            self.prediction_length = max(self.prediction_length)

        self.zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(
            "ibm/TTM", revision=TTM_MODEL_REVISION
        )

    def forecast(self, batch_data, num_samples=None):
        """Run the pre-trained model on the last ``context_length`` steps.

        Args:
            batch_data: Batch consumed by ``self.get_inputs(..., 'encode')``.
            num_samples: Unused; kept for interface compatibility.

        Returns:
            The model's ``prediction_outputs`` with an extra axis at position 1.
        """
        inputs = self.get_inputs(batch_data, 'encode')
        inputs = inputs[:, -self.context_length:]

        self.zeroshot_model.eval()
        point_forecast = self.zeroshot_model.forward(inputs).prediction_outputs
        return point_forecast.unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/transformer.py
================================================
import torch
import torch.nn as nn
from probts.data import ProbTSBatchData
from probts.model.forecaster import Forecaster
class TransformerForecaster(Forecaster):
    """Encoder-decoder Transformer point forecaster.

    Training decodes the whole horizon at once under a causal mask
    (teacher forcing); inference decodes step by step and appends each
    prediction to the running target history.

    Args:
        f_hidden_size: Model width (``d_model``).
        num_heads: Number of attention heads.
        num_encoder_layers: Encoder depth.
        num_decoder_layers: Decoder depth.
        dim_feedforward_scale: Feed-forward width as a multiple of ``f_hidden_size``.
        dropout: Dropout rate of the transformer.
        activation: Activation name passed to ``nn.Transformer``.
        **kwargs: Forwarded to ``Forecaster``.
    """

    def __init__(
        self,
        f_hidden_size: int = 32,
        num_heads: int = 8,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        dim_feedforward_scale: int = 4,
        dropout: float = 0.1,
        activation: str = 'gelu',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.autoregressive = True
        self.f_hidden_size = f_hidden_size

        # Separate input projections for the encoder and decoder streams.
        self.enc_linear = nn.Linear(self.input_size, self.f_hidden_size)
        self.dec_linear = nn.Linear(self.input_size, self.f_hidden_size)
        self.model = nn.Transformer(
            d_model=self.f_hidden_size,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward_scale * self.f_hidden_size,
            dropout=dropout,
            activation=activation
        )
        # Causal mask over the horizon, stored as a buffer so that it moves
        # with the module across devices.
        causal_mask = self.model.generate_square_subsequent_mask(self.prediction_length)
        self.register_buffer("tgt_mask", causal_mask)
        self.linear = nn.Linear(self.f_hidden_size, self.target_dim)
        self.loss_fn = nn.MSELoss(reduction='none')

    def loss(self, batch_data):
        """Mean of the weighted element-wise MSE over the horizon."""
        full_inputs = self.get_inputs(batch_data, 'all')  # [B L D]

        # Encode the context window (sequence-first layout for nn.Transformer).
        context = full_inputs[:, :self.context_length, ...]
        memory = self.model.encoder(self.enc_linear(context).permute(1, 0, 2))  # [L_in B H]

        # Decode from the inputs shifted one step back.
        shifted = full_inputs[:, -self.prediction_length - 1:-1, ...]
        decoded = self.model.decoder(
            self.dec_linear(shifted).permute(1, 0, 2), memory, tgt_mask=self.tgt_mask
        )
        decoded = decoded.permute(1, 0, 2)  # back to batch-first
        predictions = self.linear(decoded)

        elementwise = self.loss_fn(batch_data.future_target_cdf, predictions)
        weighted = self.get_weighted_loss(batch_data, elementwise)
        return weighted.mean()

    def forecast(self, batch_data, num_samples=None):
        """Autoregressive point forecast; ``num_samples`` is unused."""
        states = self.encode(batch_data)
        running_target = batch_data.past_target_cdf
        step_outputs = []

        for step in range(self.prediction_length):
            step_batch = ProbTSBatchData({
                'target_dimension_indicator': batch_data.target_dimension_indicator,
                'past_target_cdf': running_target,
                'future_time_feat': batch_data.future_time_feat[:, step:step + 1, ...]
            }, device=batch_data.device)

            hidden, states = self.decode(step_batch, states)
            prediction = self.linear(hidden)
            step_outputs.append(prediction)

            # Feed the prediction back as the newest observation.
            running_target = torch.cat((running_target, prediction), dim=1)

        forecasts = torch.cat(step_outputs, dim=1).reshape(
            -1, self.prediction_length, self.target_dim)
        return forecasts.unsqueeze(1)

    def encode(self, batch_data):
        """Return the encoder memory for the batch's ``'encode'`` inputs."""
        enc_inputs = self.get_inputs(batch_data, 'encode')
        projected = self.enc_linear(enc_inputs).permute(1, 0, 2)
        return self.model.encoder(projected)

    def decode(self, batch_data, states=None):
        """Decode the batch's ``'decode'`` inputs against ``states`` (no mask).

        Returns:
            ``(outputs, states)`` with ``outputs`` permuted back to
            batch-first and ``states`` passed through unchanged.
        """
        dec_inputs = self.get_inputs(batch_data, 'decode')
        projected = self.dec_linear(dec_inputs).permute(1, 0, 2)
        decoded = self.model.decoder(projected, states, tgt_mask=None)
        return decoded.permute(1, 0, 2), states
================================================
FILE: probts/model/forecaster/point_forecaster/tsmixer.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from TSMixer
# - Source: https://github.com/google-research/google-research/tree/master/tsmixer
# - Paper: TSMixer: An All-MLP Architecture for Time Series Forecasting
# - License: Apache-2.0 license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from probts.model.nn.arch.TSMixer_layers import MixerLayer, TimeBatchNorm2d, feature_to_time, time_to_feature
from probts.model.forecaster import Forecaster
import sys
class TSMixer(Forecaster):
    """TSMixer model for time series forecasting.

    A stack of mixer layers processes the input window; a linear layer along
    the time axis then maps ``context_length`` steps to ``prediction_length``
    steps.

    Attributes:
        mixer_layers: Sequential container of mixer layers.
        temporal_projection: Linear layer for temporal projection.

    Args:
        activation_fn: Name of a function in ``torch.nn.functional``. Defaults to "relu".
        num_blocks: Block count passed to ``_build_mixer``. Defaults to 2.
        dropout_rate: Dropout rate for regularization. Defaults to 0.1.
        ff_dim: Dimension of feedforward network inside mixer layer. Defaults to 64.
        normalize_before: Whether to apply normalization before or after mixer layer.
        norm_type: Type of normalization to use. "batch" or "layer". Defaults to "batch".
        **kwargs: Forwarded to ``Forecaster``.
    """

    def __init__(
        self,
        activation_fn: str = "relu",
        num_blocks: int = 2,
        dropout_rate: float = 0.1,
        ff_dim: int = 64,
        normalize_before: bool = True,
        norm_type: str = "batch",
        **kwargs
    ):
        super().__init__(**kwargs)

        # Resolve the activation name to the callable in torch.nn.functional.
        activation = getattr(F, activation_fn)

        # List-valued lengths are collapsed to their maximum.
        if type(self.prediction_length) is list:
            self.prediction_length = max(self.prediction_length)
        if type(self.context_length) is list:
            self.context_length = max(self.context_length)

        # Resolve the normalisation name to its class.
        assert norm_type in {
            "batch",
            "layer",
        }, f"Invalid norm_type: {norm_type}, must be one of batch, layer."
        norm_cls = nn.LayerNorm if norm_type != "batch" else TimeBatchNorm2d

        # Input and output channel counts both equal target_dim.
        self.mixer_layers = self._build_mixer(
            num_blocks,
            self.target_dim,
            self.target_dim,
            ff_dim=ff_dim,
            activation_fn=activation,
            dropout_rate=dropout_rate,
            sequence_length=self.context_length,
            normalize_before=normalize_before,
            norm_type=norm_cls,
        )

        # Temporal projection layer
        self.temporal_projection = nn.Linear(self.context_length, self.prediction_length)
        self.loss_fn = nn.MSELoss(reduction='none')

    def _build_mixer(
        self, num_blocks: int, input_channels: int, output_channels: int, **kwargs
    ):
        """Build the mixer blocks for the model.

        Args:
            num_blocks (int): Length of the channel plan.
            input_channels (int): Number of input channels for the first block.
            output_channels (int): Number of output channels for the last block;
                ``None`` falls back to ``input_channels``.
            **kwargs: Additional keyword arguments for mixer layer configuration.

        Returns:
            nn.Sequential: Sequential container of mixer layers.
        """
        if output_channels is None:
            output_channels = input_channels

        # The plan holds `num_blocks` channel counts; one layer is created per
        # adjacent pair, i.e. `num_blocks - 1` layers in total.
        plan = [input_channels] * (num_blocks - 1) + [output_channels]
        layers = []
        for in_ch, out_ch in zip(plan[:-1], plan[1:]):
            layers.append(MixerLayer(input_channels=in_ch, output_channels=out_ch, **kwargs))
        return nn.Sequential(*layers)

    def forward(self, x_hist: torch.Tensor) -> torch.Tensor:
        """Forward pass of the TSMixer model.

        Args:
            x_hist (torch.Tensor): Input time series tensor.

        Returns:
            torch.Tensor: The output tensor after processing by the model.
        """
        mixed = self.mixer_layers(x_hist)
        # The projection acts on the time axis, so convert layouts around it.
        projected = self.temporal_projection(feature_to_time(mixed))
        return time_to_feature(projected)

    def loss(self, batch_data):
        """Mean of the weighted element-wise MSE against ``future_target_cdf``."""
        predictions = self(self.get_inputs(batch_data, 'encode'))
        elementwise = self.loss_fn(batch_data.future_target_cdf, predictions)
        weighted = self.get_weighted_loss(batch_data, elementwise)
        return weighted.mean()

    def forecast(self, batch_data, num_samples=None):
        """Point forecast with an extra axis at position 1; ``num_samples`` is unused."""
        predictions = self(self.get_inputs(batch_data, 'encode'))
        return predictions.unsqueeze(1)
================================================
FILE: probts/model/forecaster/point_forecaster/units.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from UniTS
# - Source: https://github.com/mims-harvard/UniTS
# - Paper: UNITS: A Unified Multi-Task Time Series Model
# - License: MIT License
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import math
import torch
import torch.nn.functional as F
from timm.layers import DropPath, Mlp
from timm.layers.helpers import to_2tuple
from torch import nn
from probts.model.forecaster import Forecaster
def calculate_unfold_output_length(input_length, size, step):
    """Return the number of windows of length ``size`` taken every ``step``
    positions from a sequence of ``input_length`` elements."""
    # Full strides available after the first window, plus the first window.
    strides_after_first = (input_length - size) // step
    return strides_after_first + 1
class CrossAttention(nn.Module):
    """Multi-head cross attention.

    Keys and values come from ``x``. Queries come either from an explicit
    ``query`` tensor or, when none is given, from a learned template of
    ``var_num`` rows (created only when ``var_num`` is not ``None``).
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.,
        proj_drop=0.,
        norm_layer=nn.LayerNorm,
        var_num=None,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        # Optional per-head normalisation of queries and keys.
        self.q_norm = nn.Identity() if not qk_norm else norm_layer(self.head_dim)
        self.k_norm = nn.Identity() if not qk_norm else norm_layer(self.head_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        if var_num is not None:
            # Learned query template used when forward() receives no query.
            self.template = nn.Parameter(
                torch.zeros(var_num, dim), requires_grad=True)
            torch.nn.init.normal_(self.template, std=.02)
        self.var_num = var_num

    def forward(self, x, query=None):
        """Attend from ``query`` (or the template) to ``x`` (unpacked as
        ``B, N, C``) and return a ``(B, num_queries, C)`` tensor."""
        B, N, C = x.shape
        heads, head_dim = self.num_heads, self.head_dim

        if query is None:
            var_num = self.var_num
            q = self.q(self.template).reshape(1, var_num, heads, head_dim)
            q = self.q_norm(q.permute(0, 2, 1, 3))
            # Share the template queries across the batch.
            q = q.repeat(B, 1, 1, 1)
        else:
            var_num = query.shape[1]
            q = self.q(query).reshape(B, var_num, heads, head_dim)
            q = self.q_norm(q.permute(0, 2, 1, 3))

        packed = self.kv(x).reshape(B, N, 2, heads, head_dim)
        k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        k = self.k_norm(k)

        drop_p = self.attn_drop.p if self.training else 0.
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)

        merged = attended.transpose(1, 2).reshape(B, var_num, C)
        return self.proj_drop(self.proj(merged))
class DynamicLinear(nn.Module):
    """
    Linear layer whose weight matrix is bilinearly interpolated at call time
    to fit the incoming feature size and the requested output size.

    The first ``fixed_in`` input columns are only resized along the output
    axis; the remaining columns are resized along both axes.
    """

    def __init__(self, in_features=None, out_features=None, fixed_in=0, bias=True):
        super().__init__()
        assert fixed_in < in_features, "fixed_in < in_features is required !!!"
        self.in_features = in_features
        self.out_features = out_features
        self.weights = nn.Parameter(torch.Tensor(out_features, in_features))
        # A bias vector is always created; the `bias` argument is not consulted.
        self.bias = nn.Parameter(torch.Tensor(out_features))
        self.fixed_in = fixed_in

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weights and a uniform bias bounded by 1/sqrt(fan_in)."""
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x, out_features):
        """
        Apply the layer to ``x``, producing ``out_features`` outputs.

        The stored weights (and bias) are interpolated whenever the last
        dimension of ``x`` or ``out_features`` differs from the stored shape.
        """
        stored_out, stored_in = self.weights.size(0), self.weights.size(1)
        w_fixed = self.weights[:, :self.fixed_in]
        w_dynamic = self.weights[:, self.fixed_in:]
        bias = self.bias

        in_features = x.shape[-1]
        shape_changed = in_features != stored_in or out_features != stored_out

        if shape_changed:
            w_dynamic = F.interpolate(
                w_dynamic.unsqueeze(0).unsqueeze(0),
                size=(out_features, in_features - self.fixed_in),
                mode='bilinear',
                align_corners=False,
            ).squeeze(0).squeeze(0)
            if self.fixed_in != 0:
                w_fixed = F.interpolate(
                    w_fixed.unsqueeze(0).unsqueeze(0),
                    size=(out_features, self.fixed_in),
                    mode='bilinear',
                    align_corners=False,
                ).squeeze(0).squeeze(0)

        if out_features != stored_out:
            bias = F.interpolate(
                bias.unsqueeze(0).unsqueeze(0).unsqueeze(0),
                size=(1, out_features),
                mode='bilinear',
                align_corners=False,
            ).squeeze(0).squeeze(0).squeeze(0)

        weight = torch.cat((w_fixed, w_dynamic), dim=1)
        return F.linear(x, weight, bias)
class DynamicLinearMlp(nn.Module):
    """MLP block combining a 1-D convolution over the token axis, a dynamic
    linear mixing of the first half of the hidden channels, and a final
    linear projection."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
        prefix_token_length=None,
        group=1,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Scalars are expanded to (first layer, second layer) pairs.
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)
        quarter = hidden_features // 4

        self.fc1 = nn.Conv1d(in_features, hidden_features,
                             3, groups=group, bias=bias_pair[0], padding=1)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])

        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_features)

        self.seq_fc = DynamicLinear(
            quarter, quarter, bias=bias_pair[1], fixed_in=prefix_token_length)
        self.prompt_fc = DynamicLinear(
            quarter, prefix_token_length, bias=bias_pair[1], fixed_in=prefix_token_length)

        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias_pair[1])
        self.drop2 = nn.Dropout(drop_pair[1])

        self.hidden_features = hidden_features
        self.prefix_token_length = prefix_token_length

    def dynamic_linear(self, x, prefix_seq_len):
        """Re-mix the first ``prefix_seq_len`` positions of the last axis.

        Those positions feed both ``prompt_fc`` (producing
        ``prefix_token_length`` outputs) and ``seq_fc`` (producing the
        remaining ``prefix_seq_len - prefix_token_length`` outputs); the
        positions after ``prefix_seq_len`` pass through unchanged.
        """
        head = x[:, :, :prefix_seq_len]
        tail = x[:, :, prefix_seq_len:]
        seq_out = self.seq_fc(head, head.shape[-1] - self.prefix_token_length)
        prompt_out = self.prompt_fc(head, self.prefix_token_length)
        return torch.cat((prompt_out, seq_out, tail), dim=-1)

    def split_dynamic_linear(self, x, prefix_seq_len):
        """Apply ``dynamic_linear`` to the first half of dim -2 only."""
        first_half, second_half = x.chunk(2, dim=-2)
        first_half = self.dynamic_linear(first_half, prefix_seq_len)
        return torch.cat((first_half, second_half), dim=-2)

    def forward(self, x, prefix_seq_len, dim=2):
        """Process ``x`` (unpacked as ``n, var, l, c``); ``dim`` is unused."""
        n, var, l, c = x.shape

        # Merge the two leading axes and put channels first for the convolution.
        hidden = x.view(-1, l, c).transpose(-1, -2)
        hidden = self.fc1(hidden)
        hidden = self.split_dynamic_linear(hidden, prefix_seq_len)
        hidden = self.drop1(self.act(hidden))

        # Back to channels-last for normalisation and the output projection.
        hidden = self.norm(hidden.transpose(1, 2))
        out = self.fc2(hidden).view(n, var, l, c)
        return self.drop2(out)
class LearnablePositionalEmbedding(nn.Module):
    """Trainable positional table initialised with sinusoidal encodings."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.pe = nn.Parameter(torch.zeros(
            1, 1, max_len, d_model), requires_grad=True)

        # Sinusoidal initial values, with the frequency term computed in log space.
        positions = torch.arange(0, max_len).float().unsqueeze(1)
        inv_freq = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        table = torch.zeros(max_len, d_model).float()
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)

        # Copy into the parameter with two leading singleton axes.
        self.pe.data.copy_(table.unsqueeze(0).unsqueeze(0).float())
        del table

    def forward(self, x, offset=0):
        """Return the ``x.size(2)`` table rows starting at ``offset``."""
        stop = offset + x.size(2)
        return self.pe[:, :, offset:stop]
class SeqAttention(nn.Module):
    """Multi-head self-attention over the token axis of a ``(B, N, C)`` tensor."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.,
        proj_drop=0.,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Optional per-head normalisation of queries and keys.
        self.q_norm = nn.Identity() if not qk_norm else norm_layer(self.head_dim)
        self.k_norm = nn.Identity() if not qk_norm else norm_layer(self.head_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask=None):
        """Self-attend over ``x``.

        ``attn_mask`` is accepted for interface compatibility but is not
        passed on to the attention kernel.
        """
        B, N, C = x.shape
        packed = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        drop_p = self.attn_drop.p if self.training else 0.
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)

        merged = attended.transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))
class VarAttention(nn.Module):
    """Multi-head attention across the second axis of a ``(B, N, P, C)`` tensor.

    Queries and keys are averaged over the ``P`` axis, so one attention map
    per head is shared by all ``P`` positions; values keep that axis.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.,
        proj_drop=0.,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Optional per-head normalisation of queries and keys.
        self.q_norm = nn.Identity() if not qk_norm else norm_layer(self.head_dim)
        self.k_norm = nn.Identity() if not qk_norm else norm_layer(self.head_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend across axis 1 of ``x`` and return a tensor of shape
        ``(B, N, P, num_heads * head_dim)`` after the output projection."""
        B, N, P, C = x.shape
        heads, head_dim = self.num_heads, self.head_dim

        # (3, B, P, heads, N, head_dim) -> q, k, v
        packed = self.qkv(x).reshape(B, N, P, 3, heads, head_dim)
        q, k, v = packed.permute(3, 0, 2, 4, 1, 5).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Average queries and keys over P: (B, heads, N, head_dim).
        q = q.mean(dim=1, keepdim=False)
        k = k.mean(dim=1, keepdim=False)
        # Fold P into the value feature axis: (B, heads, N, head_dim * P).
        v = v.permute(0, 2, 3, 4, 1).reshape(B, heads, N, -1)

        drop_p = self.attn_drop.p if self.training else 0.
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)

        # Unfold P again and merge the heads.
        attended = attended.view(B, heads, N, -1, P)
        merged = attended.permute(0, 2, 4, 1, 3).reshape(B, N, P, -1)
        return self.proj_drop(self.proj(merged))
class GateLayer(nn.Module):
    """Scale the input by a sigmoid gate computed from the input itself.

    ``init_values`` is accepted but not used; ``inplace`` is only stored.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        # One gate logit per position, computed from its `dim` features.
        self.gate = nn.Linear(dim, 1)

    def forward(self, x):
        """Return ``sigmoid(gate(x)) * x``."""
        gate = torch.sigmoid(self.gate(x))
        return gate * x
class SeqAttBlock(nn.Module):
    """Pre-norm residual block applying ``SeqAttention`` followed by a gate."""

    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=False,
        qk_norm=False,
        proj_drop=0.,
        attn_drop=0.,
        init_values=None,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn_seq = SeqAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = GateLayer(dim, init_values=init_values)
        # Stochastic depth is only instantiated for a positive rate.
        if drop_path > 0.:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()
        # Not referenced in forward(); retained as part of the module's parameters.
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, attn_mask):
        """Attend over the second-to-last axis of the 4-D input ``x`` and add
        the gated result back onto ``x``."""
        residual = x
        normed = self.norm1(x)
        n_vars, n_seqs = normed.shape[1], normed.shape[2]

        # Merge the two leading axes so attention sees a 3-D tensor.
        flat = torch.reshape(normed, (-1, normed.shape[-2], normed.shape[-1]))
        attended = self.attn_seq(flat, attn_mask)
        attended = torch.reshape(attended, (-1, n_vars, n_seqs, attended.shape[-1]))

        return residual + self.drop_path1(self.ls1(attended))
class VarAttBlock(nn.Module):
    """Pre-norm residual block applying ``VarAttention`` followed by a gate."""

    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=False,
        qk_norm=False,
        proj_drop=0.,
        attn_drop=0.,
        init_values=None,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn_var = VarAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = GateLayer(dim, init_values=init_values)
        # Stochastic depth is only instantiated for a positive rate.
        if drop_path > 0.:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()
        # Not referenced in forward(); retained as part of the module's parameters.
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Return ``x`` plus the gated attention output of its normalised copy."""
        attended = self.attn_var(self.norm1(x))
        gated = self.ls1(attended)
        return x + self.drop_path1(gated)
class MLPBlock(nn.Module):
    """Pre-norm residual MLP block with a gate on the branch output.

    ``mlp_layer`` is the class used to build the MLP; ``prefix_token_length``
    is forwarded to it only when it is ``DynamicLinearMlp``.
    """

    def __init__(
        self,
        dim,
        mlp_ratio=4.,
        proj_drop=0.,
        init_values=None,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        mlp_layer=None,
        prefix_token_length=0,
    ):
        super().__init__()
        self.norm2 = norm_layer(dim)

        mlp_kwargs = dict(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        if mlp_layer is DynamicLinearMlp:
            mlp_kwargs['prefix_token_length'] = prefix_token_length
        self.mlp = mlp_layer(**mlp_kwargs)

        self.ls2 = GateLayer(dim, init_values=init_values)
        # Stochastic depth is only instantiated for a positive rate.
        if drop_path > 0.:
            self.drop_path2 = DropPath(drop_path)
        else:
            self.drop_path2 = nn.Identity()

    def forward(self, x, prefix_seq_len=None):
        """Return ``x`` plus the gated MLP output; ``prefix_seq_len`` is
        passed to the MLP only when it is not ``None``."""
        normed = self.norm2(x)
        if prefix_seq_len is None:
            branch = self.mlp(normed)
        else:
            branch = self.mlp(normed, prefix_seq_len=prefix_seq_len)
        return x + self.drop_path2(self.ls2(branch))
class BasicBlock(nn.Module):
    """One backbone block: sequence attention, variable attention, then a
    dynamic MLP, each wrapped in its own residual sub-block."""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=8.,
        qkv_bias=False,
        qk_norm=False,
        proj_drop=0.,
        attn_drop=0.,
        init_values=None,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        prefix_token_length=0,
    ):
        super().__init__()
        # Both attention sub-blocks share the same configuration.
        attention_cfg = dict(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            init_values=init_values,
            proj_drop=proj_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
        )
        self.seq_att_block = SeqAttBlock(**attention_cfg)
        self.var_att_block = VarAttBlock(**attention_cfg)

        self.dynamic_mlp = MLPBlock(
            dim=dim,
            mlp_ratio=mlp_ratio,
            mlp_layer=DynamicLinearMlp,
            proj_drop=proj_drop,
            init_values=init_values,
            drop_path=drop_path,
            act_layer=act_layer,
            norm_layer=norm_layer,
            prefix_token_length=prefix_token_length,
        )

    def forward(self, x, prefix_seq_len, attn_mask):
        """Run the three sub-blocks in order and return the result."""
        after_seq = self.seq_att_block(x, attn_mask)
        after_var = self.var_att_block(after_seq)
        return self.dynamic_mlp(after_var, prefix_seq_len=prefix_seq_len)
class PatchEmbedding(nn.Module):
    """Split the last axis into non-overlapping patches and embed each one.

    Args:
        d_model: Output size of the bias-free linear patch projection.
        patch_len: Number of points per patch.
        stride: Step between patches; must equal ``patch_len``.
        padding: Accepted for interface compatibility; not used here.
        dropout: Dropout probability applied to the embedded patches.
    """

    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super().__init__()
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        assert self.patch_len == self.stride, "non-overlap"
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Embed patches of ``x``.

        Returns:
            A tuple ``(embedded, n_vars)`` where ``n_vars`` is ``x.shape[1]``
            and ``embedded`` has its first two axes merged, i.e. shape
            ``(x.shape[0] * n_vars, num_patches, d_model)``.
        """
        n_vars = x.shape[1]
        patches = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        dim0, dim1, num_patches, patch_size = (
            patches.shape[0], patches.shape[1], patches.shape[2], patches.shape[3])
        # Fold the first two axes together so each row is one (sample, variable) series.
        flat = torch.reshape(patches, (dim0 * dim1, num_patches, patch_size))
        embedded = self.value_embedding(flat)
        return self.dropout(embedded), n_vars
class CLSHead(nn.Module):
    """Classification head built from a cross-attention step and an MLP block.

    The last token along the sequence axis is used as the query of the
    cross-attention over all tokens; the result is refined by an ``MLPBlock``
    and then either returned as a feature or scored against category tokens.
    """

    def __init__(self, d_model, head_dropout=0):
        super().__init__()
        d_mid = d_model
        self.proj_in = nn.Linear(d_model, d_mid)
        self.cross_att = CrossAttention(d_mid)
        self.mlp = MLPBlock(
            dim=d_mid,
            mlp_ratio=8,
            mlp_layer=Mlp,
            proj_drop=head_dropout,
            init_values=None,
            drop_path=0.0,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            prefix_token_length=None,
        )

    def forward(self, x, category_token=None, return_feature=False):
        """Compute class scores (or the class-token feature).

        Args:
            x: 4-D tensor unpacked as ``(B, V, L, C)`` after the input projection.
            category_token: Tensor whose axis 2 indexes the categories. It is
                required unless ``return_feature`` is true.
            return_feature: When true, return the refined class token of shape
                ``(B, V, -1, C)`` instead of scores.

        Returns:
            The class-token feature, or scores averaged over axis 1.
        """
        projected = self.proj_in(x)
        B, V, L, C = projected.shape
        tokens = projected.view(-1, L, C)

        # The final token of every sequence acts as the attention query.
        query = tokens[:, -1:]
        feature = self.cross_att(tokens, query=query).reshape(B, V, -1, C)
        feature = self.mlp(feature)
        if return_feature:
            return feature

        num_categories = category_token.shape[2]
        feature = feature.expand(B, V, num_categories, C)
        # Contracts both the expanded token axis (k) and the channel axis (c).
        scores = torch.einsum('nvkc,nvmc->nvm', feature, category_token)
        return scores.mean(dim=1)
class ForecastHead(nn.Module):
    """Map backbone tokens back to a time series by un-patching.

    Tokens are projected, mixed along the token axis by a ``DynamicLinear``
    layer, combined with the trailing ``token_len`` tokens, passed through an
    MLP, projected to ``patch_len`` values each and finally folded into a
    sequence of length ``pred_len``.
    """

    def __init__(self, d_model, patch_len, stride, pad, head_dropout=0, prefix_token_length=None):
        super().__init__()
        d_mid = d_model
        self.proj_in = nn.Linear(d_model, d_mid)
        self.mlp = Mlp(
            in_features=d_model,
            hidden_features=int(d_model * 4),
            act_layer=nn.GELU,
            drop=head_dropout,
        )
        self.proj_out = nn.Linear(d_model, patch_len)
        self.pad = pad
        self.patch_len = patch_len
        self.stride = stride
        # NOTE(review): the 128/128 feature sizes are hard-coded here.
        self.pos_proj = DynamicLinear(
            in_features=128, out_features=128, fixed_in=prefix_token_length)

    def forward(self, x_full, pred_len, token_len):
        """Decode tokens into a series.

        Args:
            x_full: Token tensor; its first two axes are treated as
                (batch, variables) when reshaping the output.
            pred_len: Length of the folded output sequence.
            token_len: Number of trailing tokens used for the residual and
                passed to ``pos_proj``.

        Returns:
            Tensor of shape ``(batch, length, n_vars)``.
        """
        projected = self.proj_in(x_full)
        tail_tokens = projected[:, :, -token_len:]

        # Mix along the token axis: move it last, apply pos_proj, move it back.
        mixed = self.pos_proj(projected.transpose(-1, -2), token_len).transpose(-1, -2)
        hidden = mixed + tail_tokens
        patches = self.proj_out(self.mlp(hidden))

        bs, n_vars = patches.shape[0], patches.shape[1]
        patches = patches.reshape(-1, patches.shape[-2], patches.shape[-1])
        patches = patches.permute(0, 2, 1)
        series = torch.nn.functional.fold(
            patches,
            output_size=(pred_len, 1),
            kernel_size=(self.patch_len, 1),
            stride=(self.stride, 1),
        )
        series = series.squeeze(dim=-1)
        series = series.reshape(bs, n_vars, -1)
        return series.permute(0, 2, 1)
class Model(nn.Module):
    """
    UniTS: Building a Unified Time Series Model

    Backbone plus task heads. Per-dataset prompt/mask tokens and per-task
    class/category tokens are stored in ``nn.ParameterDict`` containers.
    In this version ``forward`` always runs the forecasting path for task 0;
    the other task methods remain available as regular methods.

    Args:
        args: Object exposing ``prompt_num``, ``d_model``, ``patch_len``,
            ``stride``, ``dropout``, ``e_layers`` and ``n_heads``.
        configs_list: Sequence of ``[task_data_name, config_dict]`` pairs.
            Each ``config_dict`` is read for ``dataset``, ``enc_in`` and
            ``task_name`` (plus ``num_class`` for classification and
            ``seq_len``/``pred_len`` for long-term forecasting).
        pretrain: When true, class tokens are created for every task and an
            extra ``pretrain_head`` is built.
    """

    def __init__(self, args, configs_list, pretrain=False):
        super().__init__()
        # (zhenwei) we do not pretrain the model in this stage
        # if pretrain:
        #     self.right_prob = args.right_prob
        #     self.min_mask_ratio = args.min_mask_ratio
        #     self.max_mask_ratio = args.max_mask_ratio
        # Tokens settings
        self.num_task = len(configs_list)
        self.prompt_tokens = nn.ParameterDict({})
        self.mask_tokens = nn.ParameterDict({})
        self.cls_tokens = nn.ParameterDict({})
        self.category_tokens = nn.ParameterDict({})
        for i in range(self.num_task):
            dataset_name = configs_list[i][1]['dataset']
            task_data_name = configs_list[i][0]
            # Prompt and mask tokens are shared by all tasks of one dataset.
            if dataset_name not in self.prompt_tokens:
                self.prompt_tokens[dataset_name] = torch.zeros(
                    1, configs_list[i][1]['enc_in'], args.prompt_num, args.d_model)
                torch.nn.init.normal_(
                    self.prompt_tokens[dataset_name], std=.02)
                # Mask tokens are left at zero (no random init).
                self.mask_tokens[dataset_name] = torch.zeros(
                    1, configs_list[i][1]['enc_in'], 1, args.d_model)
            # Classification tasks get one category token per class and one class token.
            if configs_list[i][1]['task_name'] == 'classification':
                self.category_tokens[task_data_name] = torch.zeros(
                    1, configs_list[i][1]['enc_in'], configs_list[i][1]['num_class'], args.d_model)
                torch.nn.init.normal_(
                    self.category_tokens[task_data_name], std=.02)
                self.cls_tokens[task_data_name] = torch.zeros(
                    1, configs_list[i][1]['enc_in'], 1, args.d_model)
                torch.nn.init.normal_(self.cls_tokens[task_data_name], std=.02)
            # In pretraining mode every task gets a (re-)initialised class token.
            if pretrain:
                self.cls_tokens[task_data_name] = torch.zeros(
                    1, configs_list[i][1]['enc_in'], 1, args.d_model)
                torch.nn.init.normal_(self.cls_tokens[task_data_name], std=.02)
        # cls_nums[task] is the class count for classification tasks, or
        # [pred_token_len, pred_len, seq_len + pred_len] for long-term forecasting.
        self.cls_nums = {}
        for i in range(self.num_task):
            task_data_name = configs_list[i][0]
            if configs_list[i][1]['task_name'] == 'classification':
                self.cls_nums[task_data_name] = configs_list[i][1]['num_class']
            elif configs_list[i][1]['task_name'] == 'long_term_forecast':
                # Same right-padding rule as tokenize(): pad seq_len up to a multiple of patch_len.
                remainder = configs_list[i][1]['seq_len'] % args.patch_len
                if remainder == 0:
                    padding = 0
                else:
                    padding = args.patch_len - remainder
                input_token_len = calculate_unfold_output_length(
                    configs_list[i][1]['seq_len']+padding, args.stride, args.patch_len)
                # Number of points by which the input tokens extend past seq_len.
                input_pad = args.stride * \
                    (input_token_len - 1) + args.patch_len - \
                    configs_list[i][1]['seq_len']
                pred_token_len = calculate_unfold_output_length(
                    configs_list[i][1]['pred_len']-input_pad, args.stride, args.patch_len)
                real_len = configs_list[i][1]['seq_len'] + \
                    configs_list[i][1]['pred_len']
                self.cls_nums[task_data_name] = [pred_token_len,
                                                 configs_list[i][1]['pred_len'], real_len]
        self.configs_list = configs_list
        ### model settings ###
        self.prompt_num = args.prompt_num
        self.stride = args.stride
        self.pad = args.stride
        self.patch_len = args.patch_len
        # input processing
        self.patch_embeddings = PatchEmbedding(
            args.d_model, args.patch_len, args.stride, args.stride, args.dropout)
        self.position_embedding = LearnablePositionalEmbedding(args.d_model)
        # NOTE(review): attribute name is spelled "prompt2forecat"; it is part of the
        # state-dict key space, so renaming it would break checkpoint loading.
        self.prompt2forecat = DynamicLinear(128, 128, fixed_in=args.prompt_num)
        # basic blocks
        self.block_num = args.e_layers
        self.blocks = nn.ModuleList(
            [BasicBlock(dim=args.d_model, num_heads=args.n_heads, qkv_bias=False, qk_norm=False,
                        mlp_ratio=8., proj_drop=args.dropout, attn_drop=0., drop_path=0.,
                        init_values=None, prefix_token_length=args.prompt_num) for l in range(args.e_layers)]
        )
        # output processing
        self.cls_head = CLSHead(args.d_model, head_dropout=args.dropout)
        self.forecast_head = ForecastHead(
            args.d_model, args.patch_len, args.stride, args.stride, prefix_token_length=args.prompt_num, head_dropout=args.dropout)
        if pretrain:
            self.pretrain_head = ForecastHead(
                args.d_model, args.patch_len, args.stride, args.stride, prefix_token_length=1, head_dropout=args.dropout)

    def tokenize(self, x, mask=None):
        """Normalise ``x`` along axis 1, pad it and turn it into patch tokens.

        Axis 1 of ``x`` is the axis that is normalised and patched; axis 2
        becomes ``n_vars`` after the permute. When ``mask`` is given, entries
        where ``mask == 0`` are zeroed and the standard deviation is computed
        from the remaining entries only.

        Returns:
            ``(tokens, means, stdev, n_vars, padding)`` where ``padding`` is
            the number of zeros appended on the right of the patched axis.
        """
        # Normalization from Non-stationary Transformer
        means = x.mean(1, keepdim=True).detach()
        x = x - means
        if mask is not None:
            x = x.masked_fill(mask == 0, 0)
            # NOTE(review): the denominator is the count of mask == 1 entries;
            # a fully masked series would divide by zero here.
            stdev = torch.sqrt(torch.sum(x * x, dim=1) /
                               torch.sum(mask == 1, dim=1) + 1e-5)
            stdev = stdev.unsqueeze(dim=1)
        else:
            stdev = torch.sqrt(
                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # In-place division is applied to the tensor created above, not to the caller's input.
        x /= stdev
        x = x.permute(0, 2, 1)
        # Right-pad so the patched axis is a multiple of patch_len.
        remainder = x.shape[2] % self.patch_len
        if remainder != 0:
            padding = self.patch_len - remainder
            x = F.pad(x, (0, padding))
        else:
            padding = 0
        x, n_vars = self.patch_embeddings(x)
        return x, means, stdev, n_vars, padding

    def prepare_prompt(self, x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name=None, mask=None):
        """Assemble the token sequence fed to the backbone for one task.

        ``x`` is first reshaped to ``(-1, n_vars, tokens, dim)`` and the
        dataset prompt is repeated over the batch. The layout along axis 2
        depends on ``task_name``:

        * ``'forecast'``: prompt, data tokens, ``task_prompt_num`` tokens
          produced by ``prompt2forecat``; position embeddings are added to
          everything after the first ``self.prompt_num`` tokens.
        * ``'classification'``: prompt, position-embedded data tokens, task prompt.
        * ``'imputation'``: masked data tokens are replaced, then prompt + data.
        * ``'anomaly_detection'``: prompt + position-embedded data tokens.

        Any other ``task_name`` returns the reshaped ``x`` unchanged.
        """
        x = torch.reshape(
            x, (-1, n_vars, x.shape[-2], x.shape[-1]))
        # append prompt tokens
        this_prompt = prefix_prompt.repeat(x.shape[0], 1, 1, 1)
        if task_name == 'forecast':
            this_mask_prompt = task_prompt.repeat(
                x.shape[0], 1, task_prompt_num, 1)
            init_full_input = torch.cat(
                (this_prompt, x, this_mask_prompt), dim=-2)
            # prompt2forecat operates on the token axis, hence the transposes around it.
            init_mask_prompt = self.prompt2forecat(init_full_input.transpose(
                -1, -2), init_full_input.shape[2]-prefix_prompt.shape[2]).transpose(-1, -2)
            this_function_prompt = init_mask_prompt[:, :, -task_prompt_num:]
            x = torch.cat((this_prompt, x, this_function_prompt), dim=2)
            # Position embeddings are added to the non-prompt tokens only (slice assignment).
            x[:, :, self.prompt_num:] = x[:, :, self.prompt_num:] + \
                self.position_embedding(x[:, :, self.prompt_num:])
        elif task_name == 'classification':
            this_function_prompt = task_prompt.repeat(x.shape[0], 1, 1, 1)
            x = x + self.position_embedding(x)
            x = torch.cat((this_prompt, x, this_function_prompt), dim=2)
        elif task_name == 'imputation':
            # fill the masked parts with mask tokens
            # for imputation, masked is 0, unmasked is 1, so here to reverse mask
            mask = 1-mask
            mask = mask.permute(0, 2, 1)
            mask = self.mark2token(mask)
            mask_repeat = mask.unsqueeze(dim=-1)
            mask_token = task_prompt
            mask_repeat = mask_repeat.repeat(1, 1, 1, x.shape[-1])
            x = x * (1-mask_repeat) + mask_token * mask_repeat
            init_full_input = torch.cat((this_prompt, x), dim=-2)
            init_mask_prompt = self.prompt2forecat(
                init_full_input.transpose(-1, -2), x.shape[2]).transpose(-1, -2)
            # keep the unmasked tokens and fill the masked ones with init_mask_prompt.
            x = x * (1-mask_repeat) + init_mask_prompt * mask_repeat
            x = x + self.position_embedding(x)
            x = torch.cat((this_prompt, x), dim=2)
        elif task_name == 'anomaly_detection':
            x = x + self.position_embedding(x)
            x = torch.cat((this_prompt, x), dim=2)
        return x

    def mark2token(self, x_mark):
        """Reduce a point-level mark to patch level.

        The last axis is unfolded into patches of ``patch_len`` with step
        ``stride``; a patch is 1.0 when its mean is > 0 and 0.0 otherwise.
        """
        x_mark = x_mark.unfold(
            dimension=-1, size=self.patch_len, step=self.stride)
        x_mark = x_mark.mean(dim=-1)
        x_mark = (x_mark > 0).float()
        return x_mark

    def backbone(self, x, prefix_len, seq_len):
        """Run every block in ``self.blocks`` with ``prefix_seq_len = prefix_len + seq_len``.

        No attention mask is used (``attn_mask`` is always ``None``).
        """
        attn_mask = None
        for block in self.blocks:
            x = block(x, prefix_seq_len=prefix_len +
                      seq_len, attn_mask=attn_mask)
        return x

    def forecast(self, x, x_mark, task_id):
        """Forecast path: tokenize, add prompts, run backbone and head, de-normalise.

        ``x_mark`` is accepted but not used. The head output has length
        ``seq_len + pred_len`` (``cls_nums[...][2]``); only the last
        ``pred_len`` steps (``cls_nums[...][1]``) are returned.
        """
        dataset_name = self.configs_list[task_id][1]['dataset']
        task_data_name = self.configs_list[task_id][0]
        prefix_prompt = self.prompt_tokens[dataset_name]
        task_prompt = self.mask_tokens[dataset_name]
        task_prompt_num = self.cls_nums[task_data_name][0]
        task_seq_num = self.cls_nums[task_data_name][1]
        real_seq_len = self.cls_nums[task_data_name][2]
        x, means, stdev, n_vars, _ = self.tokenize(x)
        x = self.prepare_prompt(
            x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name='forecast')
        # Token count excluding the dataset prompt.
        seq_token_len = x.shape[-2]-prefix_prompt.shape[2]
        x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)
        x = self.forecast_head(
            x, real_seq_len, seq_token_len)
        x = x[:, -task_seq_num:]
        # De-Normalization from Non-stationary Transformer
        x = x * (stdev[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))
        x = x + (means[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))
        return x

    def classification(self, x, x_mark, task_id):
        """Classification path; returns the output of ``cls_head``.

        ``x_mark`` is accepted but not used. The normalisation statistics
        returned by ``tokenize`` are not needed here.
        """
        dataset_name = self.configs_list[task_id][1]['dataset']
        task_data_name = self.configs_list[task_id][0]
        prefix_prompt = self.prompt_tokens[dataset_name]
        task_prompt = self.cls_tokens[task_data_name]
        task_prompt_num = 1
        category_token = self.category_tokens[task_data_name]
        x, means, stdev, n_vars, _ = self.tokenize(x)
        seq_len = x.shape[-2]
        x = self.prepare_prompt(
            x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name='classification')
        x = self.backbone(x, prefix_prompt.shape[2], seq_len)
        x = self.cls_head(x, category_token)
        return x

    def imputation(self, x, x_mark, mask, task_id):
        """Imputation path: reconstruct the input, cropped to its original length.

        ``mask`` uses 0 for masked and 1 for observed points (see the comment
        in ``prepare_prompt``). ``x_mark`` is accepted but not used.
        """
        dataset_name = self.configs_list[task_id][1]['dataset']
        prefix_prompt = self.prompt_tokens[dataset_name]
        task_prompt = self.mask_tokens[dataset_name]
        seq_len = x.shape[1]
        x, means, stdev, n_vars, padding = self.tokenize(x, mask)
        x = self.prepare_prompt(
            x, n_vars, prefix_prompt, task_prompt, None, mask=mask, task_name='imputation')
        seq_token_len = x.shape[-2]-prefix_prompt.shape[2]
        x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)
        x = self.forecast_head(
            x, seq_len+padding, seq_token_len)
        # Drop the right padding added by tokenize().
        x = x[:, :seq_len]
        # De-Normalization from Non-stationary Transformer
        x = x * (stdev[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))
        x = x + (means[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))
        return x

    def anomaly_detection(self, x, x_mark, task_id):
        """Anomaly-detection path: reconstruct the input, cropped to its original length.

        ``x_mark`` is accepted but not used.
        """
        dataset_name = self.configs_list[task_id][1]['dataset']
        prefix_prompt = self.prompt_tokens[dataset_name]
        seq_len = x.shape[1]
        x, means, stdev, n_vars, padding = self.tokenize(x)
        x = self.prepare_prompt(x, n_vars, prefix_prompt,
                                None, None, task_name='anomaly_detection')
        seq_token_len = x.shape[-2]-prefix_prompt.shape[2]
        x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)
        x = self.forecast_head(
            x, seq_len+padding, seq_token_len)
        # Drop the right padding added by tokenize().
        x = x[:, :seq_len]
        # De-Normalization from Non-stationary Transformer
        x = x * (stdev[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))
        x = x + (means[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))
        return x

    def random_masking(self, x, min_mask_ratio, max_mask_ratio):
        """
        Perform per-sample random masking.

        A mask ratio is drawn uniformly from ``[min_mask_ratio, max_mask_ratio)``
        for each sample and that fraction of token positions is chosen at
        random. Returns a float tensor of shape ``(N, L)`` with 0 = keep and
        1 = remove; the same mask applies to every variable.
        """
        N, V, L, D = x.shape  # batch, var, length, dim
        # Calculate mask ratios and lengths to keep for each sample in the batch
        mask_ratios = torch.rand(N, device=x.device) * \
            (max_mask_ratio - min_mask_ratio) + min_mask_ratio
        len_keeps = (L * (1 - mask_ratios)).long()
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        # sort noise for each sample
        # ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # generate the binary mask: 0 is keep, 1 is remove
        # NOTE(review): this tensor is overwritten two statements below and never read.
        mask = torch.ones([N, L], device=x.device)
        # Create a range tensor and compare with len_keeps for mask generation
        range_tensor = torch.arange(L, device=x.device).expand(N, L)
        mask = (range_tensor >= len_keeps.unsqueeze(1))
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        mask = mask.float()
        return mask

    def right_masking(self, x, min_mask_ratio, max_mask_ratio):
        """Mask a contiguous suffix of the token axis for each sample.

        The kept prefix length is ``floor(L * (1 - ratio))`` with ``ratio``
        drawn per sample. Returns a float ``(N, L)`` tensor, 1 = masked.
        """
        N, V, L, D = x.shape  # batch, var, length, dim
        # Randomly choose a mask ratio for each sample within the specified range
        mask_ratios = torch.rand(N, device=x.device) * \
            (max_mask_ratio - min_mask_ratio) + min_mask_ratio
        len_keeps = (L * (1 - mask_ratios)).long()
        # Binary mask creation without a for loop
        len_keeps_matrix = len_keeps.unsqueeze(1).expand(N, L)
        indices = torch.arange(L, device=x.device).expand_as(len_keeps_matrix)
        mask = indices >= len_keeps_matrix
        mask = mask.float()
        return mask

    def choose_masking(self, x, right_prob, min_mask_ratio, max_mask_ratio):
        """Pick ``right_masking`` with probability ``right_prob``, else ``random_masking``."""
        # Generate a random number to decide which masking function to use
        if torch.rand(1).item() > right_prob:
            return self.random_masking(x, min_mask_ratio, max_mask_ratio)
        else:
            return self.right_masking(x, min_mask_ratio, max_mask_ratio)

    def get_mask_seq(self, mask, seq_len):
        """Expand a token-level mask to a point-level 0/1 mask of length ``seq_len``.

        Each token value is repeated ``patch_len`` times and folded back into
        a sequence; zeros are temporarily replaced by -1e9 so that the final
        ``> 0`` threshold yields 1 only for masked positions.
        """
        mask_seq = mask.unsqueeze(dim=-1).repeat(1, 1, self.patch_len)
        mask_seq = mask_seq.permute(0, 2, 1)
        mask_seq = mask_seq.masked_fill(mask_seq == 0, -1e9)
        # Fold operation
        mask_seq = torch.nn.functional.fold(mask_seq, output_size=(
            seq_len, 1), kernel_size=(self.patch_len, 1), stride=(self.stride, 1))
        # Apply threshold to bring back to 0/1 values
        mask_seq = (mask_seq > 0).float()
        mask_seq = mask_seq.squeeze(dim=-1).squeeze(dim=1)
        return mask_seq

    def pretraining(self, x, x_mark, task_id, enable_mask=False):
        """Pretraining path (masked reconstruction plus class-token reconstruction).

        NOTE(review): this method reads ``self.right_prob``,
        ``self.min_mask_ratio`` and ``self.max_mask_ratio``, whose assignments
        are commented out in ``__init__``, and it needs ``self.pretrain_head``
        and per-task ``cls_tokens``, which exist only when the model was built
        with ``pretrain=True``. With ``enable_mask=False`` the final ``else``
        branch returns ``cls_dec_out``, which is only assigned in the
        ``enable_mask`` branch. Confirm before re-enabling pretraining.

        Returns (when ``enable_mask`` is true):
            ``(cls_dec_out, mask_dec_out, mask_seq)``.
        """
        dataset_name = self.configs_list[task_id][1]['dataset']
        task_data_name = self.configs_list[task_id][0]
        prefix_prompt = self.prompt_tokens[dataset_name]
        mask_token = self.mask_tokens[dataset_name]
        cls_token = self.cls_tokens[task_data_name]
        seq_len = x.shape[1]
        x, means, stdev, n_vars, padding = self.tokenize(x)
        seq_token_len = x.shape[-2]
        # append prompt tokens
        x = torch.reshape(
            x, (-1, n_vars, x.shape[-2], x.shape[-1]))
        # prepare prompts
        this_prompt = prefix_prompt.repeat(x.shape[0], 1, 1, 1)
        if enable_mask:
            mask = self.choose_masking(x, self.right_prob,
                                       self.min_mask_ratio, self.max_mask_ratio)
            # Broadcast the (N, L) mask over the variable and feature axes.
            mask_repeat = mask.unsqueeze(dim=1).unsqueeze(dim=-1)
            mask_repeat = mask_repeat.repeat(1, x.shape[1], 1, x.shape[-1])
            x = x * (1-mask_repeat) + mask_token * mask_repeat  # todo
            init_full_input = torch.cat((this_prompt, x), dim=-2)
            init_mask_prompt = self.prompt2forecat(
                init_full_input.transpose(-1, -2), x.shape[2]).transpose(-1, -2)
            # keep the unmasked tokens and fill the masked ones with init_mask_prompt.
            x = x * (1-mask_repeat) + init_mask_prompt * mask_repeat
            x = x + self.position_embedding(x)
            mask_seq = self.get_mask_seq(mask, seq_len+padding)
            mask_seq = mask_seq[:, :seq_len]
        this_function_prompt = cls_token.repeat(x.shape[0], 1, 1, 1)
        x = torch.cat((this_prompt, x, this_function_prompt), dim=2)
        x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)
        if enable_mask:
            # The trailing class token is excluded from the reconstruction head.
            mask_dec_out = self.forecast_head(
                x[:, :, :-1], seq_len+padding, seq_token_len)
            mask_dec_out = mask_dec_out[:, :seq_len]
            # De-Normalization from Non-stationary Transformer
            mask_dec_out = mask_dec_out * \
                (stdev[:, 0, :].unsqueeze(1).repeat(
                    1, mask_dec_out.shape[1], 1))
            mask_dec_out = mask_dec_out + \
                (means[:, 0, :].unsqueeze(1).repeat(
                    1, mask_dec_out.shape[1], 1))
            cls_dec_out = self.cls_head(x, return_feature=True)
            # detach grad of the forecasting on tokens
            fused_dec_out = torch.cat(
                (cls_dec_out, x[:, :, self.prompt_num:-1].detach()), dim=2)
            cls_dec_out = self.pretrain_head(
                fused_dec_out, seq_len+padding, seq_token_len)
            cls_dec_out = cls_dec_out[:, :seq_len]
            cls_dec_out = cls_dec_out * \
                (stdev[:, 0, :].unsqueeze(1).repeat(
                    1, cls_dec_out.shape[1], 1))
            cls_dec_out = cls_dec_out + \
                (means[:, 0, :].unsqueeze(1).repeat(
                    1, cls_dec_out.shape[1], 1))
            return cls_dec_out, mask_dec_out, mask_seq
        else:
            return cls_dec_out

    def forward(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None,
                mask=None, task_id=None, task_name=None, enable_mask=None):
        """Run the forecasting path for the first configured task.

        NOTE(review): ``task_id`` is overwritten with 0 and ``x_dec``,
        ``x_mark_dec``, ``mask``, ``task_name`` and ``enable_mask`` are
        ignored; the multi-task dispatch is kept below as commented-out code.
        """
        task_id = 0
        # if task_name == 'long_term_forecast' or task_name == 'short_term_forecast':
        dec_out = self.forecast(x_enc, x_mark_enc, task_id)
        return dec_out  # [B, L, D]
        # if task_name == 'imputation':
        #     dec_out = self.imputation(
        #         x_enc, x_mark_enc, mask, task_id)
        #     return dec_out  # [B, L, D]
        # if task_name == 'anomaly_detection':
        #     dec_out = self.anomaly_detection(x_enc, x_mark_enc, task_id)
        #     return dec_out  # [B, L, D]
        # if task_name == 'classification':
        #     dec_out = self.classification(x_enc, x_mark_enc, task_id)
        #     return dec_out  # [B, N]
        # if 'pretrain' in task_name:
        #     dec_out = self.pretraining(x_enc, x_mark_enc, task_id,
        #                                enable_mask=enable_mask)
        #     return dec_out
        # return None
class UniTS(Forecaster):
    """Forecaster wrapper around a pretrained UniTS ``Model``.

    The wrapper builds a single long-term-forecast task configuration from
    the forecaster's own settings, loads the ``'student'`` weights from a
    checkpoint and produces point forecasts without any training
    (``no_training`` is set to ``True``).

    Args:
        ckpt_path: Path to the pretrained UniTS checkpoint. Required.
        **kwargs: Forwarded to ``Forecaster``.

    Raises:
        ValueError: If ``ckpt_path`` is not provided.
    """

    def __init__(
        self,
        ckpt_path: str = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.no_training = True

        if ckpt_path is None:
            raise ValueError("UniTS requires `ckpt_path` pointing to a pretrained checkpoint.")

        # When several lengths are configured, size the model for the largest.
        # (The previous code assigned to unbound locals here, which raised
        # UnboundLocalError and left the attributes as lists.)
        if isinstance(self.context_length, list):
            self.context_length = max(self.context_length)
        if isinstance(self.prediction_length, list):
            self.prediction_length = max(self.prediction_length)

        args, configs_list = self.generate_units_default_args(self.dataset)
        self.model = Model(args, configs_list, pretrain=False)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(ckpt_path, map_location=device)['student']

        # Drop 'cls_prompts' entries and strip the DataParallel 'module.' prefix.
        ckpt = {
            key.replace('module.', ''): value
            for key, value in state_dict.items()
            if 'cls_prompts' not in key
        }
        msg = self.model.load_state_dict(ckpt, strict=False)
        if len(msg.missing_keys) > 0:
            print(f"""Warning: There are missing keys in the pretrained model: {msg.missing_keys},
                which may cause prediction results less accurate.""")

    def generate_units_default_args(self, dataset_name='ETTh1'):
        """Build the UniTS hyper-parameters and the single-task config list.

        Args:
            dataset_name: Name of the dataset being forecast. It is matched
                case-insensitively against known aliases to select the UniTS
                dataset key; unmatched names map to ``'DEFAULT'``.

        Returns:
            ``(args, task_data_config_list)`` where ``args`` carries the model
            hyper-parameters and the list holds one
            ``[task_name, task_config]`` pair.
        """
        class Args:
            def __init__(self):
                self.d_model = 128
                self.n_heads = 8
                self.e_layers = 3
                self.prompt_num = 10
                self.dropout = 0.1
                self.patch_len = 16
                self.stride = 16
                self.batch_size = 32

        args = Args()

        # parse dataset names - ECL, ETTh1, Exchange, ILI, Traffic, Weather
        units_valid_dataset_map = {
            'ECL': ['ECL', 'electricity'],
            'ETTh1': ['ETT'],
            'Exchange': ['Exchange'],
            'ILI': ['ILI'],
            'Traffic': ['Traffic'],
            'Weather': ['Weather']
        }
        # Compare in lower case on both sides so that e.g. 'ETTh1' matches 'ETT'.
        normalized_name = (dataset_name or '').lower()
        units_dataset_name = 'DEFAULT'
        for key, value_list in units_valid_dataset_map.items():
            if any(substring.lower() in normalized_name for substring in value_list):
                units_dataset_name = key
                break

        task_name = f"LTF_{units_dataset_name}_p{self.prediction_length}"
        task_config = {
            "task_name": "long_term_forecast",
            "dataset": units_dataset_name,
            "data": units_dataset_name,
            "embed": "timeF",
            "features": "M",
            "seq_len": self.context_length,
            "label_len": 48,
            "pred_len": self.prediction_length,
            "enc_in": self.target_dim,
            "dec_in": self.target_dim,
            "c_out": self.target_dim,
            "max_batch": args.batch_size,
        }
        task_data_config_list = [[task_name, task_config]]
        return args, task_data_config_list

    def forecast(self, batch_data, pred_len=None, dataset_name=None, *args, **kwargs):
        """Return a point forecast with a singleton sample axis inserted at dim 1.

        Only the last ``context_length`` steps of the encoder input are used.
        ``pred_len``, ``dataset_name`` and extra arguments are accepted for
        interface compatibility and ignored.
        """
        inputs = self.get_inputs(batch_data, 'encode')
        inputs = inputs[:, -self.context_length:]
        point_forecast = self.model.forward(inputs, None)
        return point_forecast.unsqueeze(1)
================================================
FILE: probts/model/forecaster/prob_forecaster/__init__.py
================================================
from .gru_nvp import GRU_NVP
from .gru_maf import GRU_MAF
from .timegrad import TimeGrad
from .trans_maf import Trans_MAF
from .csdi import CSDI
from .tsdiff import TSDiffCond
# ------- make the bundled submodule sources importable ---------
# Appends <project_root>/submodules/lag_llama and
# <project_root>/submodules/uni2ts/src to sys.path (if not already present).
try:
    import os
    import sys
    current_dir = os.path.dirname(os.path.realpath(__file__))
    project_root = os.path.abspath(os.path.join(current_dir, *(['..'] * 4)))
    lag_llama_path = os.path.join(project_root, 'submodules', 'lag_llama')
    moirai_path = os.path.join(project_root, 'submodules', 'uni2ts', 'src')
    for _submodule_path in (lag_llama_path, moirai_path):
        if _submodule_path not in sys.path:
            sys.path.append(_submodule_path)
except Exception as e:
    print(f"Warning: Unable to add lag_llama to sys.path. {e}")
# ----------------------------------------------------------------

import importlib

# Optional forecasters as (submodule, exported class) pairs. Each one is
# exposed at package level only if its submodule imports successfully.
modules = [
    ('moirai', 'Moirai'),
    ('chronos', 'Chronos'),
    ('lag_llama', 'LagLlama'),
]

for module, class_name in modules:
    try:
        mod = importlib.import_module(f".{module}", package=__package__)
        globals()[class_name] = getattr(mod, class_name)
    except ImportError:
        # Best effort: a forecaster with missing dependencies is simply not exported.
        # print(f"Warning: {class_name} is not available due to missing dependencies.")
        pass
================================================
FILE: probts/model/forecaster/prob_forecaster/chronos.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from Chronos
# - Source: https://github.com/amazon-science/chronos-forecasting
# - Paper: Chronos: Learning the Language of Time Series
# - License: Apache License 2.0
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
# from chronos import ChronosPipeline
from einops import rearrange
from probts.model.nn.arch.ChronosModule.base import BaseChronosPipeline
from probts.model.forecaster import Forecaster
class Chronos(Forecaster):
    """Forecaster wrapper around a pretrained Chronos-T5 pipeline.

    No training is performed (``no_training`` is ``True``); forecasts are
    sampled from the pretrained pipeline one univariate series at a time.

    Args:
        model_size: Suffix of the ``amazon/chronos-t5-*`` checkpoint to load.
        **kwargs: Forwarded to ``Forecaster``.
    """

    def __init__(
        self,
        model_size: str = 'base',
        **kwargs
    ):
        super().__init__(**kwargs)

        # When several lengths are configured, use the largest one.
        if isinstance(self.prediction_length, list):
            self.prediction_length = max(self.prediction_length)
        if isinstance(self.context_length, list):
            self.context_length = max(self.context_length)
        self.pred_len = self.prediction_length

        # Load pretrained model
        self.no_training = True
        # Fall back to CPU when CUDA is unavailable instead of failing on the
        # previously hard-coded "cuda" device map.
        device_map = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipeline = BaseChronosPipeline.from_pretrained(
            f"amazon/chronos-t5-{model_size}",  # use "amazon/chronos-bolt-small" for the corresponding Chronos-Bolt model
            device_map=device_map,
            torch_dtype=torch.bfloat16,
        )
        self.q = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]  # Quantile levels

    def forecast(self, batch_data, num_samples=None):
        """Sample forecasts for every variable of every batch element.

        The encoder input is cropped to the last ``context_length`` steps and
        flattened so that each (batch, variable) pair becomes one univariate
        context. Contexts are sent to the pipeline in small chunks.

        Returns:
            Tensor laid out as ``(batch, samples, length, variables)``.
        """
        inputs = self.get_inputs(batch_data, 'encode')
        inputs = inputs[:, -self.context_length:]
        B, _, K = inputs.shape
        inputs = rearrange(inputs, 'b l k -> (b k) l')
        context = [inputs[i] for i in range(B * K)]

        inner_batch_size = 12  # for 80G gpu
        forecast_samples = []
        # Process in batches of size `inner_batch_size`
        for start in range(0, len(context), inner_batch_size):
            batch_forecast_samples = self.pipeline.predict(
                context[start:start + inner_batch_size],
                prediction_length=self.pred_len,
                num_samples=num_samples,
                limit_prediction_length=False
            )
            forecast_samples.append(batch_forecast_samples)

        forecast_samples = torch.cat(forecast_samples, dim=0)
        prob_forecast = rearrange(forecast_samples, '(b k) s l -> b s l k', b=B, k=K)
        return prob_forecast
================================================
FILE: probts/model/forecaster/prob_forecaster/csdi.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from CSDI
# - Source: https://github.com/ermongroup/CSDI
# - Paper: CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation
# - License: MIT license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from probts.model.forecaster import Forecaster
from probts.model.nn.prob.diffusion_layers import diff_CSDI
class CSDI(Forecaster):
    """Conditional score-based diffusion forecaster (CSDI).

    The history window is treated as the observed (conditioning) part and the
    forecast horizon as the part to be generated by reverse diffusion.

    Args:
        channels, diffusion_embedding_dim, num_heads, n_layers, linear_trans:
            Forwarded to ``diff_CSDI``.
        emb_time_dim: Size of the sinusoidal time embedding.
        emb_feature_dim: Size of the learned per-variable embedding.
        num_steps: Number of diffusion steps.
        schedule: Noise schedule, ``"quad"`` or ``"linear"``.
        beta_start, beta_end: End points of the beta schedule.
        sample_size: Maximum number of variables used per sample in ``loss``;
            larger inputs are randomly sub-sampled.
        **kwargs: Forwarded to ``Forecaster``.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
    """

    def __init__(
        self,
        channels: int = 64,
        emb_time_dim: int = 128,
        emb_feature_dim: int = 16,
        num_steps: int = 50,
        schedule: str = "quad",
        beta_start: float = 0.0001,
        beta_end: float = 0.5,
        diffusion_embedding_dim: int = 128,
        num_heads: int = 8,
        n_layers: int = 4,
        sample_size: int = 64,
        linear_trans: bool = False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.autoregressive = False
        self.dist_args = nn.Identity()
        self.emb_time_dim = emb_time_dim
        self.emb_feature_dim = emb_feature_dim
        self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.emb_total_dim += 1  # for conditional mask
        self.embed_layer = nn.Embedding(
            num_embeddings=self.target_dim, embedding_dim=self.emb_feature_dim
        )
        side_dim = self.emb_total_dim
        self.sample_size = sample_size
        input_dim = 2
        self.diffmodel = diff_CSDI(channels, diffusion_embedding_dim, side_dim, num_steps, num_heads, n_layers, inputdim=input_dim, linear=linear_trans)

        # parameters for diffusion models
        self.num_steps = num_steps
        if schedule == "quad":
            self.beta = np.linspace(
                beta_start ** 0.5, beta_end ** 0.5, self.num_steps
            ) ** 2
        elif schedule == "linear":
            self.beta = np.linspace(
                beta_start, beta_end, self.num_steps
            )
        else:
            # Previously an unknown schedule left `self.beta` undefined and
            # failed later with an unrelated AttributeError.
            raise ValueError(
                f"Unknown schedule {schedule!r}; expected 'quad' or 'linear'."
            )
        self.alpha_hat = 1 - self.beta
        self.alpha = np.cumprod(self.alpha_hat)
        self.alpha_torch = torch.tensor(self.alpha).float().unsqueeze(1).unsqueeze(1).to(self.device)

    def time_embedding(self, pos, device, d_model=128):
        """Sinusoidal embedding of the positions ``pos`` (2-D) into ``d_model`` channels.

        Even channels hold sines and odd channels cosines.
        """
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(device)
        position = pos.unsqueeze(2)
        div_term = 1 / torch.pow(
            10000.0, torch.arange(0, d_model, 2).to(device) / d_model
        )
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)
        return pe

    def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask):
        """Stack the conditioning observations and the noisy targets on a new axis 1."""
        cond_obs = (cond_mask * observed_data).unsqueeze(1)
        noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1)
        total_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)
        return total_input

    def get_masks(self, batch_data):
        """Return ``(observed_mask, cond_mask)`` over history + horizon.

        ``cond_mask`` equals the history mask followed by zeros for the
        horizon, i.e. nothing in the future is conditioned on.
        """
        hist_observed_mask = batch_data.past_observed_values[:, -self.context_length:, ...]
        target_observed_mask = batch_data.future_observed_values
        observed_mask = torch.cat((hist_observed_mask, target_observed_mask), dim=1)
        cond_mask = torch.cat((hist_observed_mask, torch.zeros_like(target_observed_mask)), dim=1)
        return observed_mask, cond_mask  # [B L K]

    def get_side_info(self, observed_data, cond_mask, target_dimension_indicator, observed_tp=None):
        """Build the side information: time embedding, variable embedding and mask.

        When ``observed_tp`` is ``None`` the time points default to
        ``0, 1, ..., L-1`` for every batch element.
        """
        B, K, L = observed_data.shape
        if observed_tp is None:
            observed_tp = torch.arange(L) * 1.0
            observed_tp = repeat(observed_tp, 'l -> b l', b=B).to(observed_data.device)
        time_embed = self.time_embedding(observed_tp, observed_data.device, self.emb_time_dim)  # (B,L,emb)
        time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1)  # (B,L,K, emb)
        feature_embed = self.embed_layer(target_dimension_indicator)  # (B, K,emb)
        feature_embed = feature_embed.unsqueeze(1).expand(-1, L, -1, -1)  # (B,L,K, emb)
        side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B,L,K,*)
        side_info = side_info.permute(0, 3, 2, 1)  # (B,*,K,L)
        side_mask = cond_mask.unsqueeze(1)  # (B,1,K,L)
        side_info = torch.cat([side_info, side_mask], dim=1)
        return side_info  # (B,D,K,L)

    def loss(self, batch_data, observed_tp=None):
        """Noise-prediction loss at a random diffusion step for each batch element.

        The squared error is taken only over observed positions that are not
        part of the conditioning mask and is normalised by their count.
        """
        past_target_cdf = batch_data.past_target_cdf[:, -self.context_length:, ...]
        future_target_cdf = batch_data.future_target_cdf
        observed_data = torch.cat([past_target_cdf, future_target_cdf], dim=1)
        B, L, K = observed_data.shape
        t = torch.randint(0, self.num_steps, [B]).to(past_target_cdf.device)
        observed_mask, gt_mask = self.get_masks(batch_data)
        feature_id = batch_data.target_dimension_indicator

        if K > self.sample_size:
            # sample subset: an independent random set of variables per batch element
            sampled_data = []
            sampled_mask = []
            sampled_feature_id = []
            sampled_gt_mask = []
            for i in range(len(observed_data)):
                ind = np.arange(K)
                np.random.shuffle(ind)
                sampled_data.append(observed_data[i, ..., ind[:self.sample_size]])
                sampled_mask.append(observed_mask[i, ..., ind[:self.sample_size]])
                sampled_feature_id.append(feature_id[i, ind[:self.sample_size]])
                sampled_gt_mask.append(gt_mask[i, ..., ind[:self.sample_size]])
            observed_data = torch.stack(sampled_data, 0)
            observed_mask = torch.stack(sampled_mask, 0)
            feature_id = torch.stack(sampled_feature_id, 0)
            gt_mask = torch.stack(sampled_gt_mask, 0)

        observed_data = observed_data.permute(0, 2, 1)  # [B K L]
        observed_mask = observed_mask.permute(0, 2, 1)  # [B K L]
        cond_mask = gt_mask.permute(0, 2, 1)  # [B K L]
        side_info = self.get_side_info(observed_data, cond_mask, feature_id, observed_tp)
        target_mask = observed_mask - cond_mask

        # `alpha_torch` was placed on a device chosen at construction time;
        # move it next to `t` so indexing works wherever the batch lives.
        current_alpha = self.alpha_torch.to(t.device)[t]  # (B,1,1)
        noise = torch.randn_like(observed_data).to(observed_data.device)
        noisy_data = (current_alpha ** 0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise
        total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)
        predicted = self.diffmodel(total_input, side_info, t)  # (B,K,L)

        residual = (noise - predicted) * target_mask
        num_eval = target_mask.sum()
        loss = (residual ** 2).sum() / (num_eval if num_eval > 0 else 1)
        loss = self.get_weighted_loss(batch_data, loss)
        return loss.mean()

    def forecast(self, batch_data, num_samples):
        """Draw ``num_samples`` forecasts; returns the last ``prediction_length`` steps."""
        observed_data = torch.cat([batch_data.past_target_cdf[:, -self.context_length:, ...], torch.zeros_like(batch_data.future_target_cdf)], dim=1).permute(0, 2, 1)
        _, cond_mask = self.get_masks(batch_data)
        cond_mask = cond_mask.permute(0, 2, 1)
        side_info = self.get_side_info(observed_data, cond_mask, batch_data.target_dimension_indicator)
        sample = self.sample(observed_data, cond_mask, side_info, num_samples)
        sample = sample.permute(0, 1, 3, 2)
        return sample[:, :, -self.prediction_length:, :]  # [B N L K]

    def sample(self, observed_data, cond_mask, side_info, n_samples):
        """Run the reverse diffusion ``n_samples`` times, starting from Gaussian noise.

        Returns a tensor of shape ``(B, n_samples, K, L)``.
        """
        B, K, L = observed_data.shape
        imputed_samples = torch.zeros(B, n_samples, K, L).to(observed_data.device)
        for i in range(n_samples):
            current_sample = torch.randn_like(observed_data).to(observed_data.device)
            for t in range(self.num_steps - 1, -1, -1):
                cond_obs = (cond_mask * observed_data).unsqueeze(1)
                noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1)  # [B 1 K L]
                diff_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)
                predicted = self.diffmodel(diff_input, side_info, torch.tensor([t]).to(observed_data.device))
                coeff1 = 1 / self.alpha_hat[t] ** 0.5
                coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5
                current_sample = coeff1 * (current_sample - coeff2 * predicted)
                # No noise is added at the final step (t == 0).
                if t > 0:
                    noise = torch.randn_like(current_sample).to(observed_data.device)
                    sigma = (
                        (1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]
                    ) ** 0.5
                    current_sample += sigma * noise
            imputed_samples[:, i] = current_sample.detach()
        return imputed_samples
================================================
FILE: probts/model/forecaster/prob_forecaster/gru_maf.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PyTorch-TS
# - Source: https://github.com/zalandoresearch/pytorch-ts
# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows
# - License: MIT, Apache-2.0 license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
from probts.data import ProbTSBatchData
from probts.utils import repeat
from probts.model.forecaster import Forecaster
from probts.model.nn.prob.MAF import MAF
class GRU_MAF(Forecaster):
    """Autoregressive probabilistic forecaster: GRU encoder + ``MAF`` density head.

    Per-step GRU outputs are converted by ``prob_model.dist_args`` into the
    conditioning arguments used by the ``MAF`` model for both the training
    loss and sampling.
    """

    def __init__(
        self,
        enc_num_layers: int = 2,
        enc_hidden_size: int = 40,
        enc_dropout: float = 0.1,
        n_blocks: int = 4,
        hidden_size: int = 100,
        n_hidden: int = 2,
        conditional_length: int = 200,
        dequantize: bool = False,
        batch_norm: bool = True,
        **kwargs
    ):
        """Build the GRU encoder and the MAF head.

        ``enc_*`` arguments configure the GRU; the remaining named arguments
        are forwarded to ``MAF``. ``**kwargs`` goes to ``Forecaster``.
        """
        super().__init__(**kwargs)
        self.autoregressive = True

        gru_options = dict(
            input_size=self.input_size,
            hidden_size=enc_hidden_size,
            num_layers=enc_num_layers,
            dropout=enc_dropout,
            batch_first=True,
        )
        self.encoder = nn.GRU(**gru_options)

        # The flow is conditioned on GRU outputs, hence f_hidden_size = enc_hidden_size.
        flow_options = dict(
            n_blocks=n_blocks,
            target_dim=self.target_dim,
            hidden_size=hidden_size,
            n_hidden=n_hidden,
            f_hidden_size=enc_hidden_size,
            conditional_length=conditional_length,
            dequantize=dequantize,
            batch_norm=batch_norm,
        )
        self.prob_model = MAF(**flow_options)

    def loss(self, batch_data):
        """Return the mean weighted loss of the future targets for one batch."""
        if self.use_scaling:
            self.get_scale(batch_data)
            self.prob_model.scale = self.scaler.scale

        sequence = self.get_inputs(batch_data, 'all')
        hidden_seq, _ = self.encoder(sequence)
        # Keep the prediction_length outputs that end one step before the
        # sequence end (the window is shifted back by one position).
        horizon = self.prediction_length
        hidden_seq = hidden_seq[:, -horizon - 1:-1, ...]

        cond = self.prob_model.dist_args(hidden_seq)
        per_step = self.prob_model.loss(batch_data.future_target_cdf, cond)
        weighted = self.get_weighted_loss(batch_data, per_step.unsqueeze(-1))
        return weighted.mean()

    def forecast(self, batch_data, num_samples=None):
        """Sample ``num_samples`` forecast paths step by step.

        Returns a tensor reshaped to
        ``(-1, num_samples, prediction_length, target_dim)``.
        """
        if self.use_scaling:
            self.get_scale(batch_data)

        encoder_state = self.encode(batch_data)

        # Replicate inputs num_samples times via the project `repeat` helper
        # (presumably along the batch axis; the state uses dim=1).
        indicator_rep = repeat(batch_data.target_dimension_indicator, num_samples)
        history_rep = repeat(batch_data.past_target_cdf, num_samples)
        time_feat_rep = repeat(batch_data.future_time_feat, num_samples)
        state_rep = repeat(encoder_state, num_samples, dim=1)

        if self.use_scaling:
            scale_rep = repeat(self.scaler.scale, num_samples)
            self.scaler.scale = scale_rep
            self.prob_model.scale = scale_rep

        drawn_steps = []
        for step in range(self.prediction_length):
            step_batch = ProbTSBatchData(
                {
                    'target_dimension_indicator': indicator_rep,
                    'past_target_cdf': history_rep,
                    'future_time_feat': time_feat_rep[:, step:step + 1, ...],
                },
                device=batch_data.device,
            )
            step_outs, state_rep = self.decode(step_batch, state_rep)

            # Draw the next step and append it to the history for the next iteration.
            cond = self.prob_model.dist_args(step_outs)
            drawn = self.prob_model.sample(cond=cond)
            drawn_steps.append(drawn)
            history_rep = torch.cat((history_rep, drawn), dim=1)

        stacked = torch.cat(drawn_steps, dim=1)
        return stacked.reshape(
            -1, num_samples, self.prediction_length, self.target_dim
        )

    def encode(self, batch_data):
        """Run the GRU over the 'encode' inputs and return its final hidden state."""
        _, final_state = self.encoder(self.get_inputs(batch_data, 'encode'))
        return final_state

    def decode(self, batch_data, states=None):
        """Advance the GRU over the 'decode' inputs starting from ``states``.

        Returns ``(outputs, new_states)``.
        """
        step_inputs = self.get_inputs(batch_data, 'decode')
        outputs, new_states = self.encoder(step_inputs, states)
        return outputs, new_states
================================================
FILE: probts/model/forecaster/prob_forecaster/gru_nvp.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PyTorch-TS
# - Source: https://github.com/zalandoresearch/pytorch-ts
# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows
# - License: MIT, Apache-2.0 license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
from probts.data import ProbTSBatchData
from probts.utils import repeat
from probts.model.forecaster import Forecaster
from probts.model.nn.prob.RealNVP import RealNVP
class GRU_NVP(Forecaster):
    """Autoregressive probabilistic forecaster: GRU encoder + ``RealNVP`` density head.

    The GRU's per-step outputs are mapped by ``prob_model.dist_args`` to the
    conditioning arguments consumed by the ``RealNVP`` model.
    """

    def __init__(
        self,
        enc_num_layers: int = 2,
        enc_hidden_size: int = 40,
        enc_dropout: float = 0.1,
        n_blocks: int = 4,
        hidden_size: int = 100,
        n_hidden: int = 2,
        conditional_length: int = 200,
        dequantize: bool = False,
        batch_norm: bool = True,
        **kwargs
    ):
        """Create the GRU encoder and the RealNVP head.

        ``enc_*`` arguments configure the GRU; the other named arguments are
        forwarded to ``RealNVP``. ``**kwargs`` is passed to ``Forecaster``.
        """
        super().__init__(**kwargs)
        self.autoregressive = True

        encoder_cfg = {
            'input_size': self.input_size,
            'hidden_size': enc_hidden_size,
            'num_layers': enc_num_layers,
            'dropout': enc_dropout,
            'batch_first': True,
        }
        self.encoder = nn.GRU(**encoder_cfg)

        # Conditioning features come from the GRU, so f_hidden_size matches its width.
        head_cfg = {
            'n_blocks': n_blocks,
            'target_dim': self.target_dim,
            'hidden_size': hidden_size,
            'n_hidden': n_hidden,
            'f_hidden_size': enc_hidden_size,
            'conditional_length': conditional_length,
            'dequantize': dequantize,
            'batch_norm': batch_norm,
        }
        self.prob_model = RealNVP(**head_cfg)

    def loss(self, batch_data):
        """Compute the mean weighted training loss for ``batch_data``."""
        if self.use_scaling:
            self.get_scale(batch_data)
            self.prob_model.scale = self.scaler.scale

        full_inputs = self.get_inputs(batch_data, 'all')
        gru_outputs, _ = self.encoder(full_inputs)
        # Select prediction_length outputs, shifted back by one position
        # relative to the end of the sequence.
        start = -self.prediction_length - 1
        gru_outputs = gru_outputs[:, start:-1, ...]

        flow_args = self.prob_model.dist_args(gru_outputs)
        raw_loss = self.prob_model.loss(batch_data.future_target_cdf, flow_args)
        raw_loss = raw_loss.unsqueeze(-1)
        return self.get_weighted_loss(batch_data, raw_loss).mean()

    def forecast(self, batch_data, num_samples=None):
        """Autoregressively draw ``num_samples`` sample paths.

        The result is reshaped to
        ``(-1, num_samples, prediction_length, target_dim)``.
        """
        if self.use_scaling:
            self.get_scale(batch_data)

        states = self.encode(batch_data)

        # Replicate every input num_samples times with the project `repeat`
        # helper (its exact tiling layout is defined in probts.utils).
        tiled_indicator = repeat(batch_data.target_dimension_indicator, num_samples)
        tiled_history = repeat(batch_data.past_target_cdf, num_samples)
        tiled_time_feat = repeat(batch_data.future_time_feat, num_samples)
        tiled_states = repeat(states, num_samples, dim=1)

        if self.use_scaling:
            tiled_scale = repeat(self.scaler.scale, num_samples)
            self.scaler.scale = tiled_scale
            self.prob_model.scale = tiled_scale

        collected = []
        horizon = self.prediction_length
        for k in range(horizon):
            payload = {
                'target_dimension_indicator': tiled_indicator,
                'past_target_cdf': tiled_history,
                'future_time_feat': tiled_time_feat[:, k:k + 1, ...],
            }
            step_batch = ProbTSBatchData(payload, device=batch_data.device)
            outs, tiled_states = self.decode(step_batch, tiled_states)

            # Sample one step and feed it back as history.
            next_values = self.prob_model.sample(cond=self.prob_model.dist_args(outs))
            collected.append(next_values)
            tiled_history = torch.cat((tiled_history, next_values), dim=1)

        forecasts = torch.cat(collected, dim=1)
        forecasts = forecasts.reshape(-1, num_samples, horizon, self.target_dim)
        return forecasts

    def encode(self, batch_data):
        """Return the GRU's final hidden state for the 'encode' inputs."""
        encoder_inputs = self.get_inputs(batch_data, 'encode')
        _unused_outputs, states = self.encoder(encoder_inputs)
        return states

    def decode(self, batch_data, states=None):
        """Run the GRU on the 'decode' inputs from ``states``; return ``(outputs, states)``."""
        decoder_inputs = self.get_inputs(batch_data, 'decode')
        return self.encoder(decoder_inputs, states)
================================================
FILE: probts/model/forecaster/prob_forecaster/lag_llama.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from lag-llama
# - Source: https://github.com/time-series-foundation-models/lag-llama
# - Paper: Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting
# - License: Apache License 2.0
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import numpy as np
import torch
from gluonts.dataset.common import ListDataset
from probts.model.forecaster import Forecaster
from submodules.lag_llama.lag_llama.gluon.estimator import LagLlamaEstimator
class LagLlama(Forecaster):
    """Inference-only wrapper around a pretrained Lag-Llama checkpoint.

    The checkpoint's stored hyper-parameters configure a
    ``LagLlamaEstimator``; the resulting GluonTS predictor is used in
    :meth:`forecast`. ``self.no_training`` is set to ``True``.
    """

    def __init__(
        self,
        use_rope_scaling: bool = True,
        ckpt_path: str = None,
        **kwargs
    ):
        """
        Args:
            use_rope_scaling: when True, pass linear RoPE scaling arguments
                to the estimator; otherwise pass ``None``.
            ckpt_path: path to the pretrained Lag-Llama checkpoint (required).
            **kwargs: forwarded to ``Forecaster``.

        Raises:
            ValueError: if ``ckpt_path`` is not provided.
        """
        super().__init__(**kwargs)

        # Fail with a clear message instead of an opaque error from torch.load(None).
        if ckpt_path is None:
            raise ValueError(
                "LagLlama requires `ckpt_path` to point to a pretrained checkpoint."
            )

        # When several horizons / contexts are configured, size the model for the largest.
        if isinstance(self.prediction_length, (list, tuple)):
            self.prediction_length = max(self.prediction_length)
        if isinstance(self.context_length, (list, tuple)):
            self.context_length = max(self.context_length)

        self.ctx_len = self.context_length
        self.pred_len = self.prediction_length

        # Load pretrained model
        self.no_training = True
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        ckpt = torch.load(ckpt_path, map_location=device)
        estimator_args = ckpt["hyper_parameters"]["model_kwargs"]

        # Linear RoPE factor: ratio of the requested window to the checkpoint's
        # stored context length, never below 1.
        rope_scaling_arguments = {
            "type": "linear",
            "factor": max(1.0, (self.ctx_len + self.pred_len) / estimator_args["context_length"]),
        }

        # Load model checkpoint
        estimator = LagLlamaEstimator(
            ckpt_path=ckpt_path,
            prediction_length=self.pred_len,
            context_length=self.ctx_len,
            # estimator args, taken from the checkpoint
            input_size=estimator_args["input_size"],
            n_layer=estimator_args["n_layer"],
            n_embd_per_head=estimator_args["n_embd_per_head"],
            n_head=estimator_args["n_head"],
            scaling=estimator_args["scaling"],
            time_feat=estimator_args["time_feat"],
            rope_scaling=rope_scaling_arguments if use_rope_scaling else None,
            batch_size=4,
            num_parallel_samples=100,
            device=device,
        )
        lightning_module = estimator.create_lightning_module()
        transformation = estimator.create_transformation()
        self.predictor = estimator.create_predictor(transformation, lightning_module)

    def forecast(self, batch_data, num_samples=None):
        """Forecast a single batch element with the pretrained predictor.

        Each of the K target dimensions is submitted as a separate
        univariate series.

        Returns:
            A CPU tensor with a leading batch dimension of 1
            (documented in the original code as ``b s l k``).

        Raises:
            ValueError: if the batch contains more than one element; the
                output is always built with a batch dimension of 1.
        """
        inputs = self.get_inputs(batch_data, 'encode')
        inputs = inputs[:, -self.context_length:]

        batch_size, _, K = inputs.shape
        # for now, we only support batch_size=1
        if batch_size != 1:
            raise ValueError(
                f"LagLlama.forecast only supports batch_size=1, got {batch_size}."
            )

        datastamps = batch_data.past_time_feat.cpu().numpy().astype('datetime64[s]')
        start_time = datastamps.reshape(-1)[0]

        # Index the single batch element explicitly so the target stays 1-D
        # even when the context has length 1.
        data = [
            {"start": start_time, "target": inputs[0, :, i].cpu()}
            for i in range(K)
        ]
        # NOTE(review): the frequency is hard-coded to hourly — confirm for non-hourly data.
        dataset = ListDataset(data, freq='1h')

        forecasts = self.predictor.predict(dataset, num_samples=num_samples)
        samples = [fs.samples for fs in forecasts]
        # One entry per series; move the series axis last, then add the batch axis.
        forecasts = np.array(samples).transpose(1, 2, 0)
        prob_forecast = forecasts[np.newaxis, :, :]
        prob_forecast = torch.tensor(prob_forecast)  # shape: b s l k
        return prob_forecast
================================================
FILE: probts/model/forecaster/prob_forecaster/moirai.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from uni2ts
# - Source: https://github.com/SalesforceAIResearch/uni2ts
# - Paper: Unified Training of Universal Time Series Forecasting Transformers
# - License: Apache License 2.0
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
from typing import Union
from probts.model.forecaster import Forecaster
from einops import rearrange, repeat
from probts.model.nn.arch.Moirai_backbone import MoiraiBackbone
from uni2ts.model.moirai.module import MoiraiModule
import sys
class Moirai(Forecaster):
    """Inference-only wrapper around a pretrained Moirai module.

    ``variate_mode`` selects how a multivariate batch is presented:
    ``'M'`` passes all target dimensions together, ``'S'`` flattens the
    target dimensions into the batch axis and forecasts each series alone.
    ``self.no_training`` is set to ``True``.
    """

    def __init__(
        self,
        variate_mode: str = 'M',
        patch_size: Union[str, int] = 'auto',
        model_size: str = 'base',
        scaling: bool = True,
        **kwargs
    ):
        """
        Args:
            variate_mode: ``'M'`` or ``'S'`` (see class docstring).
            patch_size: ``'auto'`` or a value convertible to ``int``.
            model_size: suffix of the ``Salesforce/moirai-1.0-R-*`` checkpoint name.
            scaling: forwarded to ``MoiraiBackbone``.
            **kwargs: forwarded to ``Forecaster``.

        Raises:
            ValueError: if ``variate_mode`` is neither ``'M'`` nor ``'S'``.
        """
        super().__init__(**kwargs)

        # Reject unknown modes before loading weights; otherwise the backbone
        # would be built with target_dim=1 and the error would only surface in forecast().
        if variate_mode not in ('M', 'S'):
            raise ValueError(f"Unknown variate mode: {variate_mode}")
        self.variate_mode = variate_mode
        self.patch_size = patch_size if patch_size == 'auto' else int(patch_size)

        # When several horizons / contexts are configured, size the model for the largest.
        if isinstance(self.prediction_length, (list, tuple)):
            self.prediction_length = max(self.prediction_length)
        if isinstance(self.context_length, (list, tuple)):
            self.context_length = max(self.context_length)

        # Load pretrained model
        self.no_training = True
        self.moirai = MoiraiBackbone(
            module=MoiraiModule.from_pretrained(f"Salesforce/moirai-1.0-R-{model_size}"),
            prediction_length=self.prediction_length,
            context_length=self.context_length,
            patch_size=self.patch_size,
            target_dim=self.target_dim if self.variate_mode == 'M' else 1,
            scaling=scaling
        )

    def forecast(self, batch_data, num_samples=None):
        """Return forecasts from the Moirai backbone for ``batch_data``.

        In ``'S'`` mode the backbone output is rearranged from
        ``(b k) n l`` back to ``b n l k``.

        Raises:
            ValueError: if ``self.variate_mode`` is not ``'M'`` or ``'S'``.
        """
        if self.variate_mode == 'M':
            forecasts = self.moirai(
                past_target=batch_data.past_target_cdf,
                past_observed_target=batch_data.past_observed_values,
                past_is_pad=batch_data.past_is_pad,
                num_samples=num_samples
            )
        elif self.variate_mode == 'S':
            B, L, K = batch_data.past_target_cdf.shape
            # Fold the K target dimensions into the batch axis so every series is univariate.
            forecasts = self.moirai(
                past_target=rearrange(batch_data.past_target_cdf, 'b l k -> (b k) l').unsqueeze(-1),
                past_observed_target=rearrange(batch_data.past_observed_values, 'b l k -> (b k) l').unsqueeze(-1),
                past_is_pad=repeat(batch_data.past_is_pad, 'b l -> (b k) l', k=K),
                num_samples=num_samples
            )
            forecasts = forecasts.squeeze(-1)
            forecasts = rearrange(forecasts, '(b k) n l -> b n l k', b=B, k=K)
        else:
            # Kept for the case where the attribute is changed after construction.
            raise ValueError(f"Unknown variate mode: {self.variate_mode}")
        return forecasts
================================================
FILE: probts/model/forecaster/prob_forecaster/timegrad.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PyTorch-TS
# - Source: https://github.com/zalandoresearch/pytorch-ts
# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows
# - License: MIT, Apache-2.0 license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
from probts.data import ProbTSBatchData
from probts.utils import repeat
from probts.model.forecaster import Forecaster
from probts.model.nn.prob.gaussian_diffusion import GaussianDiffusion
class TimeGrad(Forecaster):
    """Autoregressive probabilistic forecaster: GRU encoder + ``GaussianDiffusion`` head.

    GRU outputs are converted by ``prob_model.dist_args`` into the
    conditioning used by the diffusion model for its loss and for sampling.
    """

    def __init__(
        self,
        enc_num_layers: int = 2,
        enc_hidden_size: int = 40,
        enc_dropout: float = 0.1,
        conditional_length: int = 100,
        beta_end: float = 0.1,
        diff_steps: int = 100,
        loss_type: str = "l2",
        beta_schedule: str = "linear",
        **kwargs
    ):
        """Set up the GRU encoder and the diffusion head.

        ``enc_*`` arguments configure the GRU; the remaining named arguments
        are forwarded to ``GaussianDiffusion``. ``**kwargs`` goes to ``Forecaster``.
        """
        super().__init__(**kwargs)
        self.autoregressive = True

        rnn_kwargs = dict(
            input_size=self.input_size,
            hidden_size=enc_hidden_size,
            num_layers=enc_num_layers,
            dropout=enc_dropout,
            batch_first=True,
        )
        self.encoder = nn.GRU(**rnn_kwargs)

        # The diffusion model is conditioned on features of width enc_hidden_size.
        diffusion_kwargs = dict(
            target_dim=self.target_dim,
            f_hidden_size=enc_hidden_size,
            conditional_length=conditional_length,
            beta_end=beta_end,
            diff_steps=diff_steps,
            loss_type=loss_type,
            beta_schedule=beta_schedule,
        )
        self.prob_model = GaussianDiffusion(**diffusion_kwargs)

    def loss(self, batch_data):
        """Return the mean weighted diffusion loss over the forecast window."""
        if self.use_scaling:
            self.get_scale(batch_data)
            self.prob_model.scale = self.scaler.scale

        all_inputs = self.get_inputs(batch_data, 'all')
        rnn_outputs, _ = self.encoder(all_inputs)
        # Window of prediction_length outputs ending one step before the last one.
        window = slice(-self.prediction_length - 1, -1)
        rnn_outputs = rnn_outputs[:, window, ...]

        conditioning = self.prob_model.dist_args(rnn_outputs)
        step_loss = self.prob_model.loss(batch_data.future_target_cdf, conditioning)
        step_loss = self.get_weighted_loss(batch_data, step_loss.unsqueeze(-1))
        return step_loss.mean()

    def forecast(self, batch_data, num_samples=None):
        """Generate ``num_samples`` sample paths one step at a time.

        Returns a tensor reshaped to
        ``(-1, num_samples, prediction_length, target_dim)``.
        """
        if self.use_scaling:
            self.get_scale(batch_data)

        context_state = self.encode(batch_data)

        # Replicate per-series tensors num_samples times through the project
        # `repeat` helper (the hidden state is replicated along dim=1).
        rep_indicator = repeat(batch_data.target_dimension_indicator, num_samples)
        rep_past = repeat(batch_data.past_target_cdf, num_samples)
        rep_time_feat = repeat(batch_data.future_time_feat, num_samples)
        rep_state = repeat(context_state, num_samples, dim=1)

        if self.use_scaling:
            rep_scale = repeat(self.scaler.scale, num_samples)
            self.scaler.scale = rep_scale
            self.prob_model.scale = rep_scale

        sampled_steps = []
        for idx in range(self.prediction_length):
            fields = {
                'target_dimension_indicator': rep_indicator,
                'past_target_cdf': rep_past,
                'future_time_feat': rep_time_feat[:, idx:idx + 1, ...],
            }
            current = ProbTSBatchData(fields, device=batch_data.device)
            rnn_out, rep_state = self.decode(current, rep_state)

            # Sample this step, then extend the history with it.
            conditioning = self.prob_model.dist_args(rnn_out)
            step_values = self.prob_model.sample(cond=conditioning)
            sampled_steps.append(step_values)
            rep_past = torch.cat((rep_past, step_values), dim=1)

        joined = torch.cat(sampled_steps, dim=1)
        return joined.reshape(
            -1, num_samples, self.prediction_length, self.target_dim
        )

    def encode(self, batch_data):
        """Encode the 'encode' inputs and return the GRU's final hidden state."""
        context_inputs = self.get_inputs(batch_data, 'encode')
        _, hidden = self.encoder(context_inputs)
        return hidden

    def decode(self, batch_data, states=None):
        """Run the GRU on the 'decode' inputs from ``states``; return ``(outputs, states)``."""
        outputs, next_states = self.encoder(
            self.get_inputs(batch_data, 'decode'), states
        )
        return outputs, next_states
================================================
FILE: probts/model/forecaster/prob_forecaster/trans_maf.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PyTorch-TS
# - Source: https://github.com/zalandoresearch/pytorch-ts
# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows
# - License: MIT, Apache-2.0 license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
from probts.data import ProbTSBatchData
from probts.utils import repeat
from probts.model.forecaster import Forecaster
from probts.model.nn.prob.MAF import MAF
class Trans_MAF(Forecaster):
    """Autoregressive probabilistic forecaster: Transformer encoder-decoder + ``MAF`` head.

    Inputs are projected to the transformer width by two separate linear
    layers (encoder side / decoder side). Decoder outputs are turned into
    conditioning arguments for the ``MAF`` model.
    """

    def __init__(
        self,
        enc_hidden_size: int = 32,
        enc_num_heads: int = 8,
        enc_num_encoder_layers: int = 3,
        enc_num_decoder_layers: int = 3,
        enc_dim_feedforward_scale: int = 4,
        enc_dropout: float = 0.1,
        enc_activation: str = 'gelu',
        n_blocks: int = 4,
        hidden_size: int = 100,
        n_hidden: int = 2,
        conditional_length: int = 200,
        dequantize: bool = False,
        batch_norm: bool = True,
        **kwargs
    ):
        """Build the input projections, the transformer and the MAF head.

        ``enc_*`` arguments configure ``nn.Transformer`` (the feed-forward
        width is ``enc_dim_feedforward_scale * enc_hidden_size``); the other
        named arguments are forwarded to ``MAF``.
        """
        super().__init__(**kwargs)
        self.autoregressive = True

        width = enc_hidden_size
        self.enc_linear = nn.Linear(self.input_size, width)
        self.dec_linear = nn.Linear(self.input_size, width)

        transformer_cfg = dict(
            d_model=width,
            nhead=enc_num_heads,
            num_encoder_layers=enc_num_encoder_layers,
            num_decoder_layers=enc_num_decoder_layers,
            dim_feedforward=enc_dim_feedforward_scale * width,
            dropout=enc_dropout,
            activation=enc_activation,
        )
        self.model = nn.Transformer(**transformer_cfg)

        # Causal mask for teacher-forced decoding over the forecast window.
        causal_mask = self.model.generate_square_subsequent_mask(self.prediction_length)
        self.register_buffer("tgt_mask", causal_mask)

        flow_cfg = dict(
            n_blocks=n_blocks,
            target_dim=self.target_dim,
            hidden_size=hidden_size,
            n_hidden=n_hidden,
            f_hidden_size=width,
            conditional_length=conditional_length,
            dequantize=dequantize,
            batch_norm=batch_norm,
        )
        self.prob_model = MAF(**flow_cfg)

    def loss(self, batch_data):
        """Return the mean weighted loss using teacher-forced decoding."""
        if self.use_scaling:
            self.get_scale(batch_data)
            self.prob_model.scale = self.scaler.scale

        sequence = self.get_inputs(batch_data, 'all')  # [B L D]

        # Encoder: first context_length steps, moved to sequence-first layout.
        memory_in = self.enc_linear(sequence[:, :self.context_length, ...])
        memory = self.model.encoder(memory_in.permute(1, 0, 2))  # [L_in B H]

        # Decoder: the forecast window shifted back by one position.
        shifted = sequence[:, -self.prediction_length - 1:-1, ...]
        target_in = self.dec_linear(shifted).permute(1, 0, 2)
        decoded = self.model.decoder(target_in, memory, tgt_mask=self.tgt_mask)
        decoded = decoded.permute(1, 0, 2)  # back to batch-first: [B L_out H]

        cond = self.prob_model.dist_args(decoded)
        nll = self.prob_model.loss(batch_data.future_target_cdf, cond).unsqueeze(-1)
        nll = self.get_weighted_loss(batch_data, nll)
        return nll.mean()

    def forecast(self, batch_data, num_samples=None):
        """Sample ``num_samples`` forecast paths step by step.

        Returns a tensor reshaped to
        ``(-1, num_samples, prediction_length, target_dim)``.
        """
        if self.use_scaling:
            self.get_scale(batch_data)

        memory = self.encode(batch_data)

        # Replicate inputs num_samples times with the project `repeat` helper;
        # the encoder memory is sequence-first, so it is replicated along dim=1.
        indicator_rep = repeat(batch_data.target_dimension_indicator, num_samples)
        past_rep = repeat(batch_data.past_target_cdf, num_samples)
        time_feat_rep = repeat(batch_data.future_time_feat, num_samples)
        memory_rep = repeat(memory, num_samples, dim=1)

        if self.use_scaling:
            scale_rep = repeat(self.scaler.scale, num_samples)
            self.scaler.scale = scale_rep
            self.prob_model.scale = scale_rep

        generated = []
        for step in range(self.prediction_length):
            step_batch = ProbTSBatchData(
                {
                    'target_dimension_indicator': indicator_rep,
                    'past_target_cdf': past_rep,
                    'future_time_feat': time_feat_rep[:, step:step + 1, ...],
                },
                device=batch_data.device,
            )
            dec_out, memory_rep = self.decode(step_batch, memory_rep)

            # Draw the next step and append it to the running history.
            cond = self.prob_model.dist_args(dec_out)
            drawn = self.prob_model.sample(cond=cond)
            generated.append(drawn)
            past_rep = torch.cat((past_rep, drawn), dim=1)

        stacked = torch.cat(generated, dim=1)
        return stacked.reshape(
            -1, num_samples, self.prediction_length, self.target_dim
        )

    def encode(self, batch_data):
        """Project the 'encode' inputs and return the transformer encoder output (sequence-first)."""
        projected = self.enc_linear(self.get_inputs(batch_data, 'encode'))
        return self.model.encoder(projected.permute(1, 0, 2))

    def decode(self, batch_data, states=None):
        """Decode the 'decode' inputs against ``states`` without a target mask.

        Returns ``(batch-first decoder outputs, states)``; ``states`` is
        passed through unchanged.
        """
        projected = self.dec_linear(self.get_inputs(batch_data, 'decode'))
        decoded = self.model.decoder(projected.permute(1, 0, 2), states, tgt_mask=None)
        return decoded.permute(1, 0, 2), states
================================================
FILE: probts/model/forecaster/prob_forecaster/tsdiff.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# ---------------------------------------------------------------------------------
# Portions of this file are derived from TSDiff
# - Source: https://github.com/amazon-science/unconditional-time-series-diffusion
# - Paper: Predict, Refine, Synthesize: Self-Guiding Diffusion Models for Probabilistic Time Series Forecasting
# - License: Apache-2.0
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from probts.utils import extract
from probts.model.forecaster import Forecaster
from probts.model.nn.arch.S4.s4_backbones import BackboneModel
from probts.utils import repeat
import sys
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.1):
    """Return a linearly spaced noise schedule.

    Args:
        timesteps: number of diffusion steps (length of the returned tensor).
        beta_start: value of the first beta (default 0.0001).
        beta_end: value of the last beta (default 0.1).

    Returns:
        1-D tensor of ``timesteps`` values from ``beta_start`` to ``beta_end``.
    """
    return torch.linspace(beta_start, beta_end, timesteps)
class TSDiffCond(Forecaster):
    """Conditional diffusion forecaster built on the ``BackboneModel`` noise predictor.

    The full sequence (context + horizon) is diffused jointly. Unless
    ``noise_observed`` is set, the context part is kept at its observed
    values and only the horizon is noised and denoised.
    """

    def __init__(
        self,
        hidden_dim: int,
        step_emb: int,
        timesteps: int,
        num_residual_blocks: int,
        dropout: float = 0,
        # use_features: bool = False,
        init_skip=True,
        noise_observed=False,  # reconstruct past
        mode="diag",
        measure="diag",
        **kwargs
    ):
        """Precompute the diffusion schedule and build the backbone.

        Args:
            hidden_dim, step_emb, num_residual_blocks, mode, measure,
            init_skip, dropout: forwarded to ``BackboneModel``.
            timesteps: number of diffusion steps.
            noise_observed: when True the whole sequence is noised and the
                loss is averaged over every position.
            **kwargs: forwarded to ``Forecaster``.
        """
        super().__init__(**kwargs)

        backbone_parameters = dict(
            input_dim=self.target_dim,
            hidden_dim=hidden_dim,
            output_dim=self.target_dim,
            step_emb=step_emb,
            num_residual_blocks=num_residual_blocks,
            residual_block="s4",
            mode=mode,
            measure=measure,
        )

        self.timesteps = timesteps

        # Schedule tensors, all indexed by diffusion step.
        betas = linear_beta_schedule(timesteps)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted right by one, with 1.0 for the first step.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        self.betas = betas
        self.sqrt_one_minus_beta = torch.sqrt(1.0 - betas)
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        self.posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

        self.backbone = BackboneModel(
            **backbone_parameters,
            num_features=self.target_dim,
            init_skip=init_skip,
            dropout=dropout,
        )
        self.noise_observed = noise_observed

    def _extract_features(self, batch_data):
        """Split the batch into targets, conditioning features and an observation mask.

        Returns:
            ``(x, features, observation_mask)`` where ``x`` is the first
            ``target_dim`` channels of the full input, ``features`` is a copy
            with the target channels zeroed after ``context_length``, and
            ``observation_mask`` is 1 on the context and 0 afterwards.
        """
        ctx = self.context_length
        dim = self.target_dim

        full = self.get_inputs(batch_data, 'all')
        x = full[:, :, :dim]

        features = full.clone()
        if not self.use_time_feat:
            # Only target channels are used; blank everything after the context.
            features = features[:, :, :dim]
            features[:, ctx:] = 0
        else:
            # Keep the extra channels; blank only the future target values.
            features[:, ctx:, :dim] = 0

        observation_mask = torch.zeros_like(x, device=x.device)
        observation_mask[:, :ctx] = 1
        return x, features, observation_mask

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion: ``sqrt(a_bar_t) * x_start + sqrt(1 - a_bar_t) * noise``."""
        device = next(self.backbone.parameters()).device
        if noise is None:
            noise = torch.randn_like(x_start, device=device)

        signal_scale = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise

    def p_losses(
        self,
        x_start,
        t,
        features=None,
        noise=None,
        loss_type="l2",
        reduction="none",
    ):
        """Noise ``x_start`` at step ``t`` and score the backbone's noise prediction.

        Returns:
            ``(loss, x_noisy, predicted_noise)``.

        Raises:
            NotImplementedError: if ``loss_type`` is not ``"l1"``, ``"l2"`` or ``"huber"``.
        """
        device = next(self.backbone.parameters()).device
        if noise is None:
            noise = torch.randn_like(x_start, device=device)

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = self.backbone(x_noisy, t, features)

        criteria = {
            "l1": F.l1_loss,
            "l2": F.mse_loss,
            "huber": F.smooth_l1_loss,
        }
        criterion = criteria.get(loss_type)
        if criterion is None:
            raise NotImplementedError()
        loss = criterion(noise, predicted_noise, reduction=reduction)
        return loss, x_noisy, predicted_noise

    @torch.no_grad()
    def p_sample(self, x, t, t_index, features=None):
        """Perform one reverse-diffusion step from ``x`` at step ``t``.

        No noise is added when ``t_index`` is 0 (the final step).
        """
        betas_t = extract(self.betas, t, x.shape)
        noise_scale_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
        recip_sqrt_alpha_t = extract(self.sqrt_recip_alphas, t, x.shape)

        predicted_noise = self.backbone(x, t, features)
        model_mean = recip_sqrt_alpha_t * (
            x - betas_t * predicted_noise / noise_scale_t
        )

        if t_index == 0:
            return model_mean

        variance_t = extract(self.posterior_variance, t, x.shape)
        fresh_noise = torch.randn_like(x)
        return model_mean + torch.sqrt(variance_t) * fresh_noise

    def step(self, x, t, features, loss_mask):
        """Compute the training loss for one batch at diffusion steps ``t``.

        With ``noise_observed`` the squared error is averaged over all
        positions; otherwise it is summed over positions where
        ``loss_mask`` is 1 and divided by their count (or 1 if there are none).
        """
        noise = torch.randn_like(x)
        if not self.noise_observed:
            # Outside the loss mask the "noise" is replaced by x itself.
            noise = (1 - loss_mask) * x + noise * loss_mask

        num_eval = loss_mask.sum()
        sq_err, _, _ = self.p_losses(
            x, t, features, loss_type="l2", reduction="none", noise=noise
        )

        if self.noise_observed:
            return sq_err.mean()

        masked_err = sq_err * loss_mask
        denominator = num_eval if num_eval else 1
        return masked_err.sum() / denominator

    def loss(self, batch_data):
        """Return the diffusion training loss for ``batch_data``.

        One diffusion step is drawn uniformly per batch element. The process
        exits with status 1 if the loss is NaN.
        """
        x, features, observation_mask = self._extract_features(batch_data)
        # The loss is evaluated on the positions that are not observed.
        loss_mask = 1 - observation_mask

        t = torch.randint(0, self.timesteps, [x.shape[0]], device=x.device).long()
        loss = self.step(x, t, features, loss_mask)

        if torch.isnan(loss):
            print("Loss is NaN, exiting.")
            sys.exit(1)
        return loss

    def forecast(self, batch_data, num_samples):
        """Sample ``num_samples`` sequences and return their last ``prediction_length`` steps."""
        observation, features, observation_mask = self._extract_features(batch_data)
        pred = self.sample(
            observation=observation,
            observation_mask=observation_mask,
            n_samples=num_samples,
            features=features,
        )
        return pred[:, :, -self.prediction_length:, :]

    @torch.no_grad()
    def sample(self, observation, observation_mask, n_samples, features=None):
        """Run the full reverse chain and return samples reshaped to ``(-1, n_samples, length, ch)``."""
        obs_rep = repeat(observation, n_samples)
        mask_rep = repeat(observation_mask, n_samples)
        feat_rep = repeat(features, n_samples)

        batch_size, length, ch = obs_rep.shape
        device = obs_rep.device

        seq = torch.randn_like(obs_rep)
        for step_idx in range(self.timesteps - 1, -1, -1):
            if not self.noise_observed:
                # Re-impose the observed values before every denoising step.
                seq = mask_rep * obs_rep + seq * (1 - mask_rep)

            step_tensor = torch.full(
                (batch_size,), step_idx, device=device, dtype=torch.long
            )
            seq = self.p_sample(seq, step_tensor, step_idx, feat_rep)

        return seq.reshape(-1, n_samples, length, ch)
================================================
FILE: probts/model/nn/__init__.py
================================================
================================================
FILE: probts/model/nn/arch/AutoformerModule/AutoCorrelation.py
================================================
import torch
import torch.nn as nn
import math
class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.
    """

    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        """
        Args:
            mask_flag, scale: stored but not used by this module's computations.
            factor: multiplier on ``log(length)`` for the number of delays kept.
            attention_dropout: probability for ``self.dropout`` (created but
                not applied in ``forward``).
            output_attention: when True, ``forward`` also returns the correlation.
        """
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _num_delays(self, length):
        """Number of delays to aggregate: ``int(factor * log(length))`` clamped to ``[1, length]``.

        Without the lower bound, lengths below 3 give zero delays (empty
        ``torch.stack`` in training, all-zero output in inference); without
        the upper bound a large ``factor`` makes ``torch.topk`` fail.
        """
        top_k = int(self.factor * math.log(length))
        return max(1, min(top_k, length))

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        The same delays (chosen from the batch-averaged correlation) are
        shared by every batch element. ``values`` and ``corr`` are indexed as
        ``[batch, head, channel, length]``.
        """
        length = values.shape[3]
        # find top k
        top_k = self._num_delays(length)
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)  # [B, L]
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = mean_value[:, index]  # [B, top_k]
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(values, -int(index[i]), -1)
            # Per-batch weight broadcast over head, channel and length.
            delays_agg = delays_agg + pattern * tmp_corr[:, i, None, None, None]
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.

        Delays are selected per batch element. ``values`` and ``corr`` are
        indexed as ``[batch, head, channel, length]``.
        """
        batch, head, channel, length = values.shape[:4]
        # index init
        init_index = torch.arange(length, device=values.device)\
            .view(1, 1, 1, length).expand(batch, head, channel, length)
        # find top k
        top_k = self._num_delays(length)
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)  # [B, L]
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: values are doubled along time so index + delay never runs off the end
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i, None, None, None]
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * tmp_corr[:, i, None, None, None]
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation

        Delays are selected independently for every batch/head/channel entry.
        """
        batch, head, channel, length = values.shape[:4]
        # index init
        init_index = torch.arange(length, device=values.device)\
            .view(1, 1, 1, length).expand(batch, head, channel, length)
        # find top k
        top_k = self._num_delays(length)
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        """Aggregate ``values`` along the delays with the strongest query/key correlation.

        Args:
            queries: ``[B, L, H, E]``.
            keys: ``[B, S, H, E]``.
            values: ``[B, S, H, D]``.
            attn_mask: accepted for interface compatibility; not used.

        Returns:
            ``(V, corr)`` with ``V`` of shape ``[B, L, H, D]``; ``corr`` is
            ``None`` unless ``output_attention`` is set.
        """
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            # Zero-pad keys/values in time up to the query length. Each pad
            # uses its own tensor's feature size so D != E also works.
            pad_len = L - S
            key_pad = torch.zeros_like(queries[:, :pad_len, :]).float()
            value_pad = values.new_zeros(B, pad_len, values.shape[2], D).float()
            values = torch.cat([values, value_pad], dim=1)
            keys = torch.cat([keys, key_pad], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        corr = torch.fft.irfft(res, n=L, dim=-1)

        # time delay agg
        if self.training:
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)
class AutoCorrelationLayer(nn.Module):
    """Multi-head wrapper that projects inputs, applies a correlation module and projects back."""

    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        """
        Args:
            correlation: callable taking ``(queries, keys, values, attn_mask)``
                and returning ``(output, attention)``.
            d_model: model width of the inputs and outputs.
            n_heads: number of heads.
            d_keys, d_values: per-head sizes; default to ``d_model // n_heads``.
        """
        super().__init__()

        per_head = d_model // n_heads
        if not d_keys:
            d_keys = per_head
        if not d_values:
            d_values = per_head

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        """Return ``(projected output, attention)`` for batch-first 3-D inputs."""
        batch, q_len, _ = queries.shape
        kv_len = keys.shape[1]
        heads = self.n_heads

        # Split the projected features into heads: [B, T, H, -1].
        q_heads = self.query_projection(queries).view(batch, q_len, heads, -1)
        k_heads = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v_heads = self.value_projection(values).view(batch, kv_len, heads, -1)

        mixed, attn = self.inner_correlation(q_heads, k_heads, v_heads, attn_mask)

        # Merge the heads again before the output projection.
        merged = mixed.view(batch, q_len, -1)
        return self.out_projection(merged), attn
================================================
FILE: probts/model/nn/arch/AutoformerModule/Autoformer_EncDec.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class my_Layernorm(nn.Module):
    """
    Layer normalisation for the seasonal part: applies ``nn.LayerNorm`` and
    then removes the mean taken along dim 1 of the normalised tensor.
    """

    def __init__(self, channels):
        super().__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        normed = self.layernorm(x)
        # Mean over dim 1, tiled back to the input's dim-1 length.
        dim1_mean = normed.mean(dim=1).unsqueeze(1)
        dim1_mean = dim1_mean.repeat(1, x.shape[1], 1)
        return normed - dim1_mean
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    The input is edge-padded along dim 1 with ``kernel_size - 1`` extra
    steps in total, so with ``stride=1`` the output keeps the input length
    for both odd and even ``kernel_size``.
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """
        Average ``x`` over sliding windows along dim 1.

        ``x`` is indexed as a 3-D tensor whose dim 1 is the axis being
        smoothed; the result has the same layout.
        """
        # Padding on both ends of the series by repeating the edge values.
        # For odd kernels both sides get (k - 1) // 2 steps (unchanged
        # behaviour). For even kernels the front gets one extra step so the
        # total is still k - 1; symmetric padding would shorten the output
        # by one step and break the `x - moving_mean` subtraction downstream.
        end_pad = (self.kernel_size - 1) // 2
        front_pad = self.kernel_size - 1 - end_pad
        front = x[:, 0:1, :].repeat(1, front_pad, 1)
        end = x[:, -1:, :].repeat(1, end_pad, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last axis, so move dim 1 there and back.
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
class series_decomp(nn.Module):
    """
    Series decomposition block: splits the input into the moving-average
    component and the remainder left after subtracting it.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        """Return ``(x - moving_average(x), moving_average(x))``."""
        trend = self.moving_avg(x)
        remainder = x - trend
        return remainder, trend
class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture:
    attention with a residual connection, a decomposition, a two-layer 1x1
    convolutional feed-forward, and a second decomposition.
    """

    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super().__init__()
        # Feed-forward width defaults to four times the model width.
        ff_width = d_ff if d_ff else 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=ff_width, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=ff_width, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        # Any value other than "relu" selects GELU.
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, attn_mask=None):
        """
        Run the layer on ``x`` and return ``(output, attn)``; only the first
        element of each decomposition is kept, the second is discarded.
        """
        attn_out, attn = self.attention(x, x, x, attn_mask=attn_mask)
        seasonal, _ = self.decomp1(x + self.dropout(attn_out))

        # Conv1d works on (batch, channels, length): swap axes in and out.
        hidden = self.conv1(seasonal.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        ff_out = self.dropout(self.conv2(hidden).transpose(-1, 1))

        res, _ = self.decomp2(seasonal + ff_out)
        return res, attn
class Encoder(nn.Module):
    """
    Autoformer encoder: a stack of attention layers, optionally interleaved
    with convolution layers and followed by an optional normalisation layer.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        """
        Parameters
        ----------
        attn_layers
            Iterable of modules called as ``layer(x, attn_mask=...)`` and
            returning ``(x, attn)``.
        conv_layers
            Optional iterable of modules applied after each paired attention
            layer; the last attention layer runs after the paired ones.
        norm_layer
            Optional callable applied to the final output.
        """
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        """Return ``(encoded_x, attns)`` with one ``attn`` entry per layer call."""
        attns = []
        if self.conv_layers is not None:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            # The final attention layer has no paired conv layer. Forward the
            # mask here as well so it is honoured consistently by every layer.
            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)
        if self.norm is not None:
            x = self.norm(x)
        return x, attns
class DecoderLayer(nn.Module):
    """
    Autoformer decoder layer with the progressive decomposition architecture.

    Self-attention, cross-attention and a 1x1 convolutional feed-forward are
    each followed by a decomposition; the three second outputs of those
    decompositions are summed and projected to ``c_out`` channels.
    """

    def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
                 moving_avg=25, dropout=0.1, activation="relu"):
        super().__init__()
        # Feed-forward width defaults to four times the model width.
        ff_width = d_ff if d_ff else 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=ff_width, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=ff_width, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Conv1d(
            in_channels=d_model,
            out_channels=c_out,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode='circular',
            bias=False,
        )
        # Any value other than "relu" selects GELU.
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        """
        Return ``(x, residual_trend)``: the refined first decomposition output
        and the projected sum of the three second decomposition outputs.
        """
        # Self-attention + residual, then first decomposition.
        self_out = self.self_attention(x, x, x, attn_mask=x_mask)[0]
        x, trend1 = self.decomp1(x + self.dropout(self_out))

        # Cross-attention against `cross` + residual, then second decomposition.
        cross_out = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        x, trend2 = self.decomp2(x + self.dropout(cross_out))

        # Feed-forward; Conv1d works on (batch, channels, length).
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ff_out = self.dropout(self.conv2(hidden).transpose(-1, 1))
        x, trend3 = self.decomp3(x + ff_out)

        trend_sum = trend1 + trend2 + trend3
        residual_trend = self.projection(trend_sum.permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend
class Decoder(nn.Module):
    """
    Autoformer decoder: a stack of decoder layers whose per-layer residual
    trends are accumulated, followed by optional normalisation and projection.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        """
        Parameters
        ----------
        layers
            Iterable of modules called as
            ``layer(x, cross, x_mask=..., cross_mask=...)`` and returning
            ``(x, residual_trend)``.
        norm_layer
            Optional callable applied to ``x`` after the last layer.
        projection
            Optional callable applied to ``x`` after ``norm_layer``.
        """
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        """
        Return ``(x, trend)`` where ``trend`` is the initial ``trend`` plus
        the residual trend of every layer. When ``trend`` is ``None`` the
        accumulation starts from the first layer's residual trend.
        """
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            # The signature defaults `trend` to None; adding a tensor to None
            # would raise, so start the running sum from the first residual.
            trend = residual_trend if trend is None else trend + residual_trend
        if self.norm is not None:
            x = self.norm(x)
        if self.projection is not None:
            x = self.projection(x)
        return x, trend
================================================
FILE: probts/model/nn/arch/ChronosModule/__init__.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .base import BaseChronosPipeline, ForecastType
from .chronos import (
ChronosConfig,
ChronosModel,
ChronosPipeline,
ChronosTokenizer,
MeanScaleUniformBins,
)
from .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline
__all__ = [
"BaseChronosPipeline",
"ForecastType",
"ChronosConfig",
"ChronosModel",
"ChronosPipeline",
"ChronosTokenizer",
"MeanScaleUniformBins",
"ChronosBoltConfig",
"ChronosBoltPipeline",
]
================================================
FILE: probts/model/nn/arch/ChronosModule/base.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# Authors: Caner Turkmen , Abdul Fatir Ansari , Lorenzo Stella
# Original source:
# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/src/autogluon/timeseries/models/chronos/pipeline/base.py
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
if TYPE_CHECKING:
from transformers import PreTrainedModel
from .utils import left_pad_and_stack_1D
class ForecastType(Enum):
    """Enumerates the forecast kinds a pipeline can declare via ``forecast_type``."""

    # Member values are the lowercase string names of the members.
    SAMPLES = "samples"
    QUANTILES = "quantiles"
class PipelineRegistry(type):
    """
    Metaclass that records every class it creates in ``REGISTRY``, keyed by
    the class name, so classes can later be looked up by name.
    """

    # Shared, class-level mapping: class name -> class object.
    REGISTRY: Dict[str, "PipelineRegistry"] = {}

    def __new__(cls, name, bases, attrs):
        """See, https://github.com/faif/python-patterns."""
        created = super().__new__(cls, name, bases, attrs)
        if name is not None:
            # A later class with the same name replaces the earlier entry.
            cls.REGISTRY[name] = created
        return created
class BaseChronosPipeline(metaclass=PipelineRegistry):
    """
    Common interface of the Chronos pipelines.

    Classes deriving from this one are created through the
    ``PipelineRegistry`` metaclass; ``from_pretrained`` looks the concrete
    pipeline class up in ``PipelineRegistry.REGISTRY`` by name.
    """

    forecast_type: ForecastType
    # String aliases accepted for the ``torch_dtype`` keyword of ``from_pretrained``.
    dtypes = {"bfloat16": torch.bfloat16, "float32": torch.float32}

    def __init__(self, inner_model: "PreTrainedModel"):
        """
        Parameters
        ----------
        inner_model : PreTrainedModel
            A hugging-face transformers PreTrainedModel, e.g., T5ForConditionalGeneration
        """
        # kept on the pipeline for easy access to the inner HF-style model
        self.inner_model = inner_model

    def _prepare_and_validate_context(
        self, context: Union[torch.Tensor, List[torch.Tensor]]
    ):
        """
        Normalise ``context`` to a 2D tensor: a list is stacked with
        ``left_pad_and_stack_1D`` and a 1D tensor gains a leading axis.
        Anything that is not a tensor afterwards, or not 2D, fails an assertion.
        """
        if isinstance(context, list):
            context = left_pad_and_stack_1D(context)
        assert isinstance(context, torch.Tensor)
        if context.ndim == 1:
            context = context.unsqueeze(0)
        assert context.ndim == 2

        return context

    def predict(
        self,
        context: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
        **kwargs,
    ):
        """
        Produce forecasts for the given time series; must be implemented by
        subclasses. Predictions are expected to be returned in fp32 on the cpu.

        Parameters
        ----------
        context
            Input series: a 1D tensor, a list of 1D tensors, or a 2D tensor
            whose first dimension is batch. For the 2D case, left-pad with
            ``torch.nan`` to align series of different lengths.
        prediction_length
            Number of time steps to predict. A model-dependent value is
            used when omitted.

        Returns
        -------
        forecasts
            Tensor of forecasts whose layout and meaning depend on
            ``self.forecast_type``.
        """
        raise NotImplementedError()

    def predict_quantiles(
        self,
        context: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
        quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Produce quantile and mean forecasts for the given time series; must
        be implemented by subclasses. Predictions are expected to be
        returned in fp32 on the cpu.

        Parameters
        ----------
        context : Union[torch.Tensor, List[torch.Tensor]]
            Input series: a 1D tensor, a list of 1D tensors, or a 2D tensor
            whose first dimension is batch. For the 2D case, left-pad with
            ``torch.nan`` to align series of different lengths.
        prediction_length : Optional[int], optional
            Number of time steps to predict. A model-dependent value is
            used when omitted.
        quantile_levels : List[float], optional
            Quantile levels to compute, by default [0.1, 0.2, ..., 0.9]

        Returns
        -------
        quantiles
            Quantile forecasts, shape
            (batch_size, prediction_length, num_quantiles)
        mean
            Mean (point) forecasts, shape (batch_size, prediction_length)
        """
        raise NotImplementedError()

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        *model_args,
        **kwargs,
    ):
        """
        Load the model, either from a local path or from the HuggingFace Hub.
        Supports the same arguments as ``AutoConfig`` and ``AutoModel``
        from ``transformers``.

        Raises ``ValueError`` when the loaded config carries neither
        ``chronos_pipeline_class`` nor ``chronos_config``, or when it names
        a pipeline class that is not in the registry.
        """
        from transformers import AutoConfig

        # Translate a string dtype alias (other than "auto") to a torch dtype.
        requested_dtype = kwargs.get("torch_dtype", "auto")
        if isinstance(requested_dtype, str) and requested_dtype != "auto":
            kwargs["torch_dtype"] = cls.dtypes[requested_dtype]

        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        looks_like_chronos = any(
            hasattr(config, marker)
            for marker in ("chronos_pipeline_class", "chronos_config")
        )
        if not looks_like_chronos:
            raise ValueError("Not a Chronos config file")

        # Configs without an explicit class name fall back to "ChronosPipeline".
        pipeline_class_name = getattr(
            config, "chronos_pipeline_class", "ChronosPipeline"
        )
        class_ = PipelineRegistry.REGISTRY.get(pipeline_class_name)
        if class_ is None:
            raise ValueError(
                f"Trying to load unknown pipeline class: {pipeline_class_name}"
            )

        return class_.from_pretrained(  # type: ignore[attr-defined]
            pretrained_model_name_or_path, *model_args, **kwargs
        )
================================================
FILE: probts/model/nn/arch/ChronosModule/chronos.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# Authors: Abdul Fatir Ansari , Lorenzo Stella , Caner Turkmen
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from einops import rearrange
import sys
from .loss import LabelSmoother
import torch
import torch.nn as nn
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
GenerationConfig,
PreTrainedModel,
)
# import chronos
from probts.model.nn.arch import ChronosModule
from .base import BaseChronosPipeline, ForecastType
from .utils import left_pad_and_stack_1D
logger = logging.getLogger(__file__)
@dataclass
class ChronosConfig:
    """
    Holds every configuration parameter consumed by ``ChronosTokenizer``
    and ``ChronosModel``.
    """

    # Tokenizer class name (resolved in ``create_tokenizer``) and its kwargs.
    tokenizer_class: str
    tokenizer_kwargs: Dict[str, Any]
    # Sequence lengths.
    context_length: int
    prediction_length: int
    # Vocabulary layout; special-token ids are validated in ``__post_init__``.
    n_tokens: int
    n_special_tokens: int
    pad_token_id: int
    eos_token_id: int
    use_eos_token: bool
    model_type: Literal["causal", "seq2seq"]
    # Sampling settings.
    num_samples: int
    temperature: float
    top_k: int
    top_p: float

    def __post_init__(self):
        # Both special ids must lie below the number of special tokens.
        special_ids_in_range = (
            self.pad_token_id < self.n_special_tokens
            and self.eos_token_id < self.n_special_tokens
        )
        assert special_ids_in_range, f"Special token id's must be smaller than {self.n_special_tokens=}"

    def create_tokenizer(self) -> "ChronosTokenizer":
        """
        Instantiate the class named ``tokenizer_class`` (looked up on
        ``ChronosModule``) with ``tokenizer_kwargs`` and this config.
        """
        tokenizer_cls = getattr(ChronosModule, self.tokenizer_class)
        return tokenizer_cls(**self.tokenizer_kwargs, config=self)
class ChronosTokenizer:
    """
    A ``ChronosTokenizer`` defines how time series are mapped into token IDs
    and back.

    Concrete classes must implement ``context_input_transform``,
    ``label_input_transform`` and ``output_transform``; the versions here
    only raise ``NotImplementedError``.
    """

    def context_input_transform(
        self,
        context: torch.Tensor,
    ) -> Tuple:
        """
        Convert a batch of time series into token IDs, an attention mask and
        a tokenizer state.

        Parameters
        ----------
        context
            Tensor shaped (batch_size, time_length) holding the series to
            forecast. Left-pad with ``torch.nan`` to align series of
            different lengths.

        Returns
        -------
        token_ids
            Integer tensor with the token IDs of the input series, shaped
            (batch_size, time_length + 1) if ``config.use_eos_token`` and
            (batch_size, time_length) otherwise.
        attention_mask
            Boolean tensor of the same shape as ``token_ids`` marking the
            observations that are not ``torch.nan`` (i.e. neither missing
            nor padding).
        tokenizer_state
            Object to hand to ``label_input_transform`` and
            ``output_transform``; it carries what is needed to decode
            output samples into real values, such as location and scale
            parameters.
        """
        raise NotImplementedError()

    def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:
        """
        Convert a batch of label slices into token IDs and an attention mask,
        reusing the ``tokenizer_state`` produced by ``context_input_transform``.

        Parameters
        ----------
        label
            Tensor shaped (batch_size, time_length) holding the label slices
            of the series. Left-pad with ``torch.nan`` to align series of
            different lengths.
        tokenizer_state
            Object returned by ``context_input_transform`` with the
            information needed to preprocess data, such as location and
            scale (its nature depends on the specific tokenizer). It lets
            the label be tokenized with the same scaling as the context.

        Returns
        -------
        token_ids
            Integer tensor with the token IDs of the label series, shaped
            (batch_size, time_length + 1) if ``config.use_eos_token`` and
            (batch_size, time_length) otherwise.
        attention_mask
            Boolean tensor of the same shape as ``token_ids`` marking the
            observations that are not ``torch.nan`` (i.e. neither missing
            nor padding).
        """
        raise NotImplementedError()

    def output_transform(
        self, samples: torch.Tensor, tokenizer_state: Any
    ) -> torch.Tensor:
        """
        Convert a batch of sampled token IDs back into real values.

        Parameters
        ----------
        samples
            Integer tensor shaped (batch_size, num_samples, time_length)
            with the token IDs of the sample trajectories.
        tokenizer_state
            Object returned by ``context_input_transform`` with the context
            needed to decode samples, such as location and scale (its
            nature depends on the specific tokenizer).

        Returns
        -------
        forecasts
            Real-valued tensor shaped (batch_size, num_samples, time_length)
            with the forecasted sample paths.
        """
        raise NotImplementedError()
class MeanScaleUniformBins(ChronosTokenizer):
    """
    Tokenizer that divides each series by its mean absolute value and
    quantises the result into uniformly spaced bins between ``low_limit``
    and ``high_limit``.
    """

    def __init__(
        self, low_limit: float, high_limit: float, config: ChronosConfig,
    ) -> None:
        """
        Parameters
        ----------
        low_limit, high_limit
            Range covered by the bin centers.
        config
            Provides ``n_tokens``, ``n_special_tokens`` and the other
            settings read by the transforms.
        """
        self.config = config
        # Bin tensors are created on CUDA when available; the transforms move
        # them to the device of their inputs at use time.
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.centers = torch.linspace(
            low_limit,
            high_limit,
            config.n_tokens - config.n_special_tokens - 1,
        ).to(device)
        # Bin edges: midpoints between adjacent centers, open-ended at both sides.
        self.boundaries = torch.concat(
            (
                torch.tensor([-1e20], device=self.centers.device),
                (self.centers[1:] + self.centers[:-1]) / 2,
                torch.tensor([1e20], device=self.centers.device),
            )
        )

    def _input_transform(
        self, context: torch.Tensor, scale: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Scale and bucketize ``context``.

        When ``scale`` is ``None`` it is computed per series as the mean of
        the absolute non-NaN values, with non-positive or NaN scales replaced
        by 1.0. Returns ``(token_ids, attention_mask, scale)`` where NaN
        positions receive ``config.pad_token_id`` and a ``False`` mask.
        """
        context = context.to(dtype=torch.float32)
        attention_mask = ~torch.isnan(context)

        if scale is None:
            scale = torch.nansum(
                torch.abs(context) * attention_mask, dim=-1
            ) / torch.nansum(attention_mask, dim=-1)
            scale[~(scale > 0)] = 1.0

        scaled_context = context / scale.unsqueeze(dim=-1)
        token_ids = (
            torch.bucketize(
                input=scaled_context,
                # bucketize needs both tensors on one device; the boundaries
                # may have been created on CUDA while the input is on CPU.
                boundaries=self.boundaries.to(scaled_context.device),
                # buckets are open to the right, see:
                # https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize
                right=True,
            )
            + self.config.n_special_tokens
        )

        token_ids.clamp_(0, self.config.n_tokens - 1)

        token_ids[~attention_mask] = self.config.pad_token_id

        return token_ids, attention_mask, scale

    def _append_eos_token(
        self, token_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append one EOS token (and a ``True`` mask entry) to every row."""
        batch_size = token_ids.shape[0]
        eos_tokens = torch.full(
            (batch_size, 1), fill_value=self.config.eos_token_id, device=token_ids.device
        )
        token_ids = torch.concat((token_ids, eos_tokens), dim=1)
        eos_mask = torch.full(
            (batch_size, 1), fill_value=True, device=attention_mask.device
        )
        attention_mask = torch.concat((attention_mask, eos_mask), dim=1)

        return token_ids, attention_mask

    def context_input_transform(
        self, context: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Tokenize the context, keeping only its last ``config.context_length``
        steps. An EOS token is appended only for seq2seq models with
        ``config.use_eos_token`` set. Returns ``(token_ids, attention_mask, scale)``.
        """
        length = context.shape[-1]

        if length > self.config.context_length:
            context = context[..., -self.config.context_length :]

        token_ids, attention_mask, scale = self._input_transform(context=context)

        if self.config.use_eos_token and self.config.model_type == "seq2seq":
            token_ids, attention_mask = self._append_eos_token(
                token_ids=token_ids, attention_mask=attention_mask
            )

        return token_ids, attention_mask, scale

    def label_input_transform(
        self, label: torch.Tensor, scale: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Tokenize ``label`` with the ``scale`` obtained from the context.
        The label length must equal ``config.prediction_length``. An EOS
        token is appended when ``config.use_eos_token`` is set.
        """
        length = label.shape[-1]

        assert length == self.config.prediction_length
        token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale)

        if self.config.use_eos_token:
            token_ids, attention_mask = self._append_eos_token(
                token_ids=token_ids, attention_mask=attention_mask
            )

        return token_ids, attention_mask

    def output_transform(
        self, samples: torch.Tensor, scale: torch.Tensor
    ) -> torch.Tensor:
        """
        Map sampled token IDs back to real values: look up the bin center of
        each token (indices clamped to the valid range) and multiply by the
        per-series ``scale``.
        """
        scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
        indices = torch.clamp(
            samples - self.config.n_special_tokens - 1,
            min=0,
            max=len(self.centers) - 1,
        )
        # Look the centers up on the samples' device so the product below
        # does not mix a CUDA lookup table with CPU samples/scales.
        centers = self.centers.to(samples.device)
        return centers[indices] * scale_unsqueezed
class ChronosModel(nn.Module):
    """
    Wraps a ``PreTrainedModel`` from ``transformers`` and uses it to predict
    sample paths for time series tokens.

    Parameters
    ----------
    config
        The configuration to use.
    model
        The pretrained model to use.
    """

    def __init__(self, config: ChronosConfig, model: PreTrainedModel) -> None:
        super().__init__()
        self.config = config
        self.model = model

    @property
    def device(self):
        """Device reported by the wrapped model."""
        return self.model.device

    def encode(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ):
        """
        Return the encoder embedding of the given token sequences.

        Parameters
        ----------
        input_ids
            Indices of the input sequence tokens in the vocabulary, shaped
            (batch_size, sequence_length).
        attention_mask
            Mask of the same shape as ``input_ids`` that keeps attention
            away from padding or missing tokens.

        Returns
        -------
        embedding
            Encoder embeddings shaped (batch_size, sequence_length, d_model).
        """
        assert (
            self.config.model_type == "seq2seq"
        ), "Encoder embeddings are only supported for encoder-decoder models"
        encoder_output = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )
        return encoder_output.last_hidden_state

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        prediction_length: Optional[int] = None,
        num_samples: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Sample future tokens for the given token sequences.

        ``prediction_length``, ``num_samples``, ``temperature``, ``top_k``
        and ``top_p`` customise inference; each one falls back to the
        attribute of the same name on ``self.config`` when left as ``None``.

        Returns
        -------
        samples
            Integer tensor shaped (batch_size, num_samples, time_length)
            with the forecasted sample paths.
        """
        cfg = self.config
        prediction_length = cfg.prediction_length if prediction_length is None else prediction_length
        num_samples = cfg.num_samples if num_samples is None else num_samples
        temperature = cfg.temperature if temperature is None else temperature
        top_k = cfg.top_k if top_k is None else top_k
        top_p = cfg.top_p if top_p is None else top_p

        # min == max forces exactly `prediction_length` new tokens per sequence.
        generation_config = GenerationConfig(
            min_new_tokens=prediction_length,
            max_new_tokens=prediction_length,
            do_sample=True,
            num_return_sequences=num_samples,
            eos_token_id=cfg.eos_token_id,
            pad_token_id=cfg.pad_token_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        preds = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
        )

        if cfg.model_type == "seq2seq":
            preds = preds[..., 1:]  # remove the decoder start token
        else:
            assert cfg.model_type == "causal"
            # The output holds the prompt followed by the new tokens; keep the latter.
            assert preds.size(-1) == input_ids.size(-1) + prediction_length
            preds = preds[..., -prediction_length:]

        batch_size = input_ids.size(0)
        return preds.reshape(batch_size, num_samples, -1)
class ChronosPipeline(BaseChronosPipeline):
    """
    A ``ChronosPipeline`` uses the given tokenizer and model to forecast
    input time series.

    Use the ``from_pretrained`` class method to load serialized models.
    Use the ``predict`` method to get forecasts.

    Parameters
    ----------
    tokenizer
        The tokenizer object to use.
    model
        The model to use.
    """

    tokenizer: ChronosTokenizer
    model: ChronosModel
    forecast_type: ForecastType = ForecastType.SAMPLES

    def __init__(self, tokenizer, model):
        super().__init__(inner_model=model.model)
        self.tokenizer = tokenizer
        self.model = model
        self.loss_func = LabelSmoother()

    def _prepare_and_validate_context(
        self, context: Union[torch.Tensor, List[torch.Tensor]]
    ):
        """
        Normalise ``context`` to a 2D tensor: a list is stacked with
        ``left_pad_and_stack_1D`` and a 1D tensor gains a leading batch axis.
        """
        if isinstance(context, list):
            context = left_pad_and_stack_1D(context)
        assert isinstance(context, torch.Tensor)
        if context.ndim == 1:
            context = context.unsqueeze(0)
        assert context.ndim == 2

        return context

    @torch.no_grad()
    def embed(
        self, context: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Tuple[torch.Tensor, Any]:
        """
        Get encoder embeddings for the given time series.

        Parameters
        ----------
        context
            Input series. This is either a 1D tensor, or a list
            of 1D tensors, or a 2D tensor whose first dimension
            is batch. In the latter case, use left-padding with
            ``torch.nan`` to align series of different lengths.

        Returns
        -------
        embeddings, tokenizer_state
            A tuple of two tensors: the encoder embeddings and the tokenizer_state,
            e.g., the scale of the time series in the case of mean scaling.
            The encoder embeddings are shaped (batch_size, context_length, d_model)
            or (batch_size, context_length + 1, d_model), where context_length
            is the size of the context along the time axis if a 2D tensor was provided
            or the length of the longest time series, if a list of 1D tensors was
            provided, and the extra 1 is for EOS. The embeddings are moved to the cpu.
        """
        context_tensor = self._prepare_and_validate_context(context=context)
        token_ids, attention_mask, tokenizer_state = (
            self.tokenizer.context_input_transform(context_tensor)
        )
        embeddings = self.model.encode(
            input_ids=token_ids.to(self.model.device),
            attention_mask=attention_mask.to(self.model.device),
        ).cpu()
        return embeddings, tokenizer_state

    def predict(  # type: ignore[override]
        self,
        context: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
        num_samples: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        limit_prediction_length: bool = False,
    ) -> torch.Tensor:
        """
        Get forecasts for the given time series.

        Refer to the base method (``BaseChronosPipeline.predict``)
        for details on shared parameters.

        Requests longer than ``self.model.config.prediction_length`` are
        served in several rounds: after each round the per-step median of
        the samples is appended to the context for the next round.

        Additional parameters
        ---------------------
        num_samples
            Number of sample paths to predict. Defaults to what
            specified in ``self.model.config``.
        temperature
            Temperature to use for generating sample tokens.
            Defaults to what specified in ``self.model.config``.
        top_k
            Top-k parameter to use for generating sample tokens.
            Defaults to what specified in ``self.model.config``.
        top_p
            Top-p parameter to use for generating sample tokens.
            Defaults to what specified in ``self.model.config``.
        limit_prediction_length
            Accepted for API compatibility. The length check that consulted
            it is commented out below, so it currently has no effect and
            longer predictions are always allowed.

        Returns
        -------
        samples
            Tensor of sample forecasts, of shape
            (batch_size, num_samples, prediction_length), in fp32 on the cpu.
        """
        context_tensor = self._prepare_and_validate_context(context=context)

        if prediction_length is None:
            prediction_length = self.model.config.prediction_length

        # NOTE: prediction-length check disabled; kept for reference.
        # if prediction_length > self.model.config.prediction_length:
        #     msg = (
        #         f"We recommend keeping prediction length <= {self.model.config.prediction_length}. "
        #         "The quality of longer predictions may degrade since the model is not optimized for it. "
        #     )
        #     if limit_prediction_length:
        #         msg += "You can turn off this check by setting `limit_prediction_length=False`."
        #         raise ValueError(msg)
        #     logger.warning(msg)

        predictions = []
        remaining = prediction_length

        while remaining > 0:
            token_ids, attention_mask, scale = self.tokenizer.context_input_transform(
                context_tensor
            )
            # Each round generates at most the model's built-in horizon.
            samples = self.model(
                token_ids.to(self.model.device),
                attention_mask.to(self.model.device),
                min(remaining, self.model.config.prediction_length),
                num_samples,
                temperature,
                top_k,
                top_p,
            )
            prediction = self.tokenizer.output_transform(
                samples.to(scale.device), scale
            )

            predictions.append(prediction)
            remaining -= prediction.shape[-1]

            if remaining <= 0:
                break

            # Extend the context with the median sample path for the next round.
            context_tensor = torch.cat(
                [context_tensor, prediction.median(dim=1).values], dim=-1
            )

        return torch.cat(predictions, dim=-1).to(dtype=torch.float32, device="cpu")

    def predict_quantiles(
        self,
        context: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
        quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        **predict_kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute mean and quantile forecasts from the sample paths returned
        by ``predict``.

        In addition to the inputs accepted by ``predict`` (1D tensor, list
        of 1D tensors, 2D tensor), a 3D tensor ``(b, k, l)`` is accepted: it
        is flattened to ``(b * k, l)`` for prediction and the outputs are
        reshaped back to a leading ``(b, k)``.

        Returns
        -------
        mean
            Mean over the sample paths, shape (batch_size, prediction_length),
            or (b, k, prediction_length) for 3D input.
        quantiles
            Quantiles over the sample paths, shape
            (batch_size, prediction_length, num_quantiles), or
            (b, k, prediction_length, num_quantiles) for 3D input.

        Note that the tuple is returned as ``(mean, quantiles)``, i.e. in the
        opposite order to the one listed in the docstring of
        ``BaseChronosPipeline.predict_quantiles``.
        """
        # Only tensors can carry the extra (b, k) leading axes; lists go
        # straight to `predict`, which stacks them itself.
        is_3d = isinstance(context, torch.Tensor) and context.ndim == 3
        if is_3d:
            shape_dim = context.shape
            context = rearrange(context, 'b k l -> (b k) l')

        # (batch, num_samples, length) -> (batch, length, num_samples)
        prediction_samples = (
            self.predict(context, prediction_length=prediction_length, **predict_kwargs)
            .detach()
            .swapaxes(1, 2)
        )
        mean = prediction_samples.mean(dim=-1)
        # torch.quantile puts the quantile axis first; move it last.
        quantiles = torch.quantile(
            prediction_samples,
            q=torch.tensor(quantile_levels, dtype=prediction_samples.dtype),
            dim=-1,
        ).permute(1, 2, 0)

        if is_3d:
            quantiles = rearrange(quantiles, '(b k) l q -> b k l q', b=shape_dim[0])
            mean = rearrange(mean, '(b k) l -> b k l', b=shape_dim[0])

        return mean, quantiles

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """
        Load the model, either from a local path or from the HuggingFace Hub.
        Supports the same arguments as ``AutoConfig`` and ``AutoModel``
        from ``transformers``.

        The inner model is loaded with ``AutoModelForSeq2SeqLM`` or
        ``AutoModelForCausalLM`` depending on ``chronos_config.model_type``.
        """
        config = AutoConfig.from_pretrained(*args, **kwargs)

        assert hasattr(config, "chronos_config"), "Not a Chronos config file"

        chronos_config = ChronosConfig(**config.chronos_config)

        if chronos_config.model_type == "seq2seq":
            inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
        else:
            assert chronos_config.model_type == "causal"
            inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)

        return cls(
            tokenizer=chronos_config.create_tokenizer(),
            model=ChronosModel(config=chronos_config, model=inner_model),
        )
================================================
FILE: probts/model/nn/arch/ChronosModule/chronos_bolt.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# Authors: Abdul Fatir Ansari , Caner Turkmen , Lorenzo Stella
# Original source:
# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/src/autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py
import copy
import logging
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoConfig
from transformers.models.t5.modeling_t5 import (
ACT2FN,
T5Config,
T5LayerNorm,
T5PreTrainedModel,
T5Stack,
)
from transformers.utils import ModelOutput
from .base import BaseChronosPipeline, ForecastType
logger = logging.getLogger(__file__)
@dataclass
class ChronosBoltConfig:
    """Plain settings container for the Chronos-Bolt model."""

    # Sequence lengths.
    context_length: int
    prediction_length: int
    # Patching of the input series (size and stride of the patches).
    input_patch_size: int
    input_patch_stride: int
    # Quantile levels.
    quantiles: List[float]
    # Optional REG token; disabled unless explicitly requested.
    use_reg_token: bool = False
@dataclass
class ChronosBoltOutput(ModelOutput):
    """``ModelOutput`` container for Chronos-Bolt; every field defaults to ``None``."""

    # Loss tensor, when one is provided.
    loss: Optional[torch.Tensor] = None
    # Quantile predictions.
    quantile_preds: Optional[torch.Tensor] = None
    # Attention tensors, when provided.
    attentions: Optional[torch.Tensor] = None
    cross_attentions: Optional[torch.Tensor] = None
class Patch(nn.Module):
    """
    Cuts the last dimension of a tensor into windows of ``patch_size``
    taken every ``patch_stride`` steps.
    """

    def __init__(self, patch_size: int, patch_stride: int) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.patch_stride = patch_stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Left-pad the last dimension with NaN up to a multiple of
        ``patch_size`` and unfold it into patches (a new trailing axis).
        """
        remainder = x.shape[-1] % self.patch_size

        if remainder != 0:
            # NaN padding goes in front so the most recent values stay aligned
            # with the end of the last patch.
            pad_shape = (*x.shape[:-1], self.patch_size - remainder)
            nan_pad = torch.full(
                size=pad_shape, fill_value=torch.nan, dtype=x.dtype, device=x.device
            )
            x = torch.concat((nan_pad, x), dim=-1)

        return x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
class InstanceNorm(nn.Module):
    """
    See, also, RevIN. Standardises along the last dimension, ignoring NaN
    entries when computing the statistics.
    """

    def __init__(self, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        loc_scale: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Return ``((x - loc) / scale, (loc, scale))``.

        ``loc``/``scale`` are taken from ``loc_scale`` when given; otherwise
        they are the NaN-aware mean and standard deviation over the last
        dimension. A NaN mean becomes 0, a NaN deviation becomes 1, and a
        zero deviation is replaced by ``|loc| + eps``.
        """
        if loc_scale is not None:
            loc, scale = loc_scale
        else:
            loc = torch.nanmean(x, dim=-1, keepdim=True).nan_to_num(nan=0.0)
            variance = torch.nanmean((x - loc).square(), dim=-1, keepdim=True)
            scale = variance.sqrt().nan_to_num(nan=1.0)
            # Constant series: avoid dividing by zero.
            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)

        normalised = (x - loc) / scale
        return normalised, (loc, scale)

    def inverse(
        self, x: torch.Tensor, loc_scale: Tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """Undo the standardisation: ``x * scale + loc``."""
        loc, scale = loc_scale
        return x * scale + loc
class ResidualBlock(nn.Module):
    """
    Two-layer MLP with a linear skip connection from input to output and an
    optional ``T5LayerNorm`` on the result.
    """

    def __init__(
        self,
        in_dim: int,
        h_dim: int,
        out_dim: int,
        act_fn_name: str,
        dropout_p: float = 0.0,
        use_layer_norm: bool = False,
    ) -> None:
        super().__init__()

        self.dropout = nn.Dropout(dropout_p)
        self.hidden_layer = nn.Linear(in_dim, h_dim)
        # Activation is looked up by name in the transformers ACT2FN table.
        self.act = ACT2FN[act_fn_name]
        self.output_layer = nn.Linear(h_dim, out_dim)
        # Linear skip path: maps the input straight to the output width.
        self.residual_layer = nn.Linear(in_dim, out_dim)

        self.use_layer_norm = use_layer_norm
        if use_layer_norm:
            self.layer_norm = T5LayerNorm(out_dim)

    def forward(self, x: torch.Tensor):
        """Return ``dropout(output(act(hidden(x)))) + residual(x)``, layer-normed if enabled."""
        mlp_out = self.output_layer(self.act(self.hidden_layer(x)))
        mlp_out = self.dropout(mlp_out)
        combined = mlp_out + self.residual_layer(x)

        if not self.use_layer_norm:
            return combined
        return self.layer_norm(combined)
class ChronosBoltModelForForecasting(T5PreTrainedModel):
    """
    T5-based quantile forecaster operating on patches of a univariate series.

    Pipeline implemented by this class: instance-normalize the context, cut it
    into patches, embed each ``[patch values, patch mask]`` vector with a
    ``ResidualBlock``, run a T5 encoder, decode a single start token with a T5
    decoder, and map the decoder state to
    ``num_quantiles * prediction_length`` outputs that are finally un-scaled.
    """

    _keys_to_ignore_on_load_missing = [
        r"input_patch_embedding\.",
        r"output_patch_embedding\.",
    ]
    _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: T5Config):
        """
        Build the model from a ``T5Config`` carrying a ``chronos_config`` dict.

        Note that ``config`` is modified in place: ``vocab_size`` is overwritten
        and ``reg_token_id`` is set when the REG token is enabled.
        """
        assert hasattr(config, "chronos_config"), "Not a Chronos config file"

        super().__init__(config)
        self.model_dim = config.d_model

        self.chronos_config = ChronosBoltConfig(**config.chronos_config)

        # Only decoder_start_id (and optionally REG token)
        if self.chronos_config.use_reg_token:
            config.reg_token_id = 1

        config.vocab_size = 2 if self.chronos_config.use_reg_token else 1
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Input patch embedding layer; the input is patch values concatenated
        # with the patch mask, hence ``input_patch_size * 2``.
        self.input_patch_embedding = ResidualBlock(
            in_dim=self.chronos_config.input_patch_size * 2,
            h_dim=config.d_ff,
            out_dim=config.d_model,
            act_fn_name=config.dense_act_fn,
            dropout_p=config.dropout_rate,
        )

        # patching layer
        self.patch = Patch(
            patch_size=self.chronos_config.input_patch_size,
            patch_stride=self.chronos_config.input_patch_stride,
        )

        # instance normalization, also referred to as "scaling" in Chronos and GluonTS
        self.instance_norm = InstanceNorm()

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self._init_decoder(config)

        self.num_quantiles = len(self.chronos_config.quantiles)
        quantiles = torch.tensor(self.chronos_config.quantiles, dtype=self.dtype)
        # Non-persistent: rebuilt from the config, never stored in checkpoints.
        self.register_buffer("quantiles", quantiles, persistent=False)

        self.output_patch_embedding = ResidualBlock(
            in_dim=config.d_model,
            h_dim=config.d_ff,
            out_dim=self.num_quantiles * self.chronos_config.prediction_length,
            act_fn_name=config.dense_act_fn,
            dropout_p=config.dropout_rate,
        )

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def _init_weights(self, module):
        """Initialize the weights"""
        super()._init_weights(module)
        factor = self.config.initializer_factor
        if isinstance(module, (self.__class__)):
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, ResidualBlock):
            # NOTE(review): the std below is derived from ``input_patch_size * 2``
            # for every ResidualBlock, including ``output_patch_embedding`` whose
            # input width is ``d_model`` — confirm this is intended.
            module.hidden_layer.weight.data.normal_(
                mean=0.0,
                std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
            )
            if (
                hasattr(module.hidden_layer, "bias")
                and module.hidden_layer.bias is not None
            ):
                module.hidden_layer.bias.data.zero_()

            module.residual_layer.weight.data.normal_(
                mean=0.0,
                std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
            )
            if (
                hasattr(module.residual_layer, "bias")
                and module.residual_layer.bias is not None
            ):
                module.residual_layer.bias.data.zero_()

            module.output_layer.weight.data.normal_(
                mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)
            )
            if (
                hasattr(module.output_layer, "bias")
                and module.output_layer.bias is not None
            ):
                module.output_layer.bias.data.zero_()

    def encode(
        self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> Tuple[
        torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor
    ]:
        """
        Scale, patch, embed and encode the context.

        Parameters
        ----------
        context
            2D tensor ``(batch_size, context_length)``; NaNs mark missing values.
        mask
            Optional observation mask aligned with ``context`` (any dtype; it is
            cast to the context dtype). Defaults to "not NaN".

        Returns
        -------
        encoder hidden states, ``(loc, scale)`` from the instance norm, the
        patch (and optional REG) embeddings fed to the encoder, and the
        patch-level attention mask.
        """
        mask = (
            mask.to(context.dtype)
            if mask is not None
            else torch.isnan(context).logical_not().to(context.dtype)
        )

        batch_size, _ = context.shape
        # Keep only the most recent ``context_length`` steps.
        if context.shape[-1] > self.chronos_config.context_length:
            context = context[..., -self.chronos_config.context_length :]
            mask = mask[..., -self.chronos_config.context_length :]

        # scaling
        context, loc_scale = self.instance_norm(context)

        # the scaling op above is done in 32-bit precision,
        # then the context is moved to model's dtype
        context = context.to(self.dtype)
        mask = mask.to(self.dtype)

        # patching
        patched_context = self.patch(context)
        # NaNs in the patched mask are treated as "not observed".
        patched_mask = torch.nan_to_num(self.patch(mask), nan=0.0)
        # Zero out every unobserved value so no NaN reaches the embedding layer.
        patched_context = torch.where(patched_mask > 0.0, patched_context, 0.0)
        # concat context and mask along patch dim
        patched_context = torch.cat([patched_context, patched_mask], dim=-1)

        # attention_mask = 1 if at least one item in the patch is observed
        attention_mask = (
            patched_mask.sum(dim=-1) > 0
        )  # (batch_size, patched_seq_length)

        input_embeds = self.input_patch_embedding(patched_context)

        if self.chronos_config.use_reg_token:
            # Append [REG]
            reg_input_ids = torch.full(
                (batch_size, 1),
                self.config.reg_token_id,
                device=input_embeds.device,
            )
            reg_embeds = self.shared(reg_input_ids)
            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
            # The REG position is always attended to.
            attention_mask = torch.cat(
                [
                    attention_mask.to(self.dtype),
                    torch.ones_like(reg_input_ids).to(self.dtype),
                ],
                dim=-1,
            )

        encoder_outputs = self.encoder(
            attention_mask=attention_mask,
            inputs_embeds=input_embeds,
        )

        return encoder_outputs[0], loc_scale, input_embeds, attention_mask

    def forward(
        self,
        context: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        target: Optional[torch.Tensor] = None,
        target_mask: Optional[torch.Tensor] = None,
    ) -> ChronosBoltOutput:
        """
        Produce quantile forecasts and, when ``target`` is given, the quantile loss.

        Parameters
        ----------
        context, mask
            Forwarded to ``encode``.
        target
            Optional ground truth of length <= the model's ``prediction_length``;
            it is normalized with the context's ``(loc, scale)``.
        target_mask
            Optional observation mask for ``target`` (non-zero / True = observed).
            Defaults to "not NaN".

        Returns
        -------
        ChronosBoltOutput
            ``quantile_preds`` has shape
            ``(batch_size, num_quantiles, prediction_length)`` in the original
            data scale; ``loss`` is ``None`` when no target is supplied.
        """
        batch_size = context.size(0)

        hidden_states, loc_scale, input_embeds, attention_mask = self.encode(
            context=context, mask=mask
        )
        sequence_output = self.decode(input_embeds, attention_mask, hidden_states)

        quantile_preds_shape = (
            batch_size,
            self.num_quantiles,
            self.chronos_config.prediction_length,
        )
        quantile_preds = self.output_patch_embedding(sequence_output).view(
            *quantile_preds_shape
        )

        loss = None
        if target is not None:
            # normalize target
            target, _ = self.instance_norm(target, loc_scale)
            target = target.unsqueeze(1)  # type: ignore
            assert self.chronos_config.prediction_length >= target.shape[-1]

            target = target.to(quantile_preds.device)
            if target_mask is not None:
                # Coerce to bool before inverting: ``~`` on an integer mask is a
                # bitwise NOT (yielding -1/-2, i.e. integer indexing below) and
                # it raises on a floating-point mask.
                target_mask = (
                    target_mask.unsqueeze(1).to(quantile_preds.device).bool()
                )
            else:
                target_mask = ~torch.isnan(target)
            target[~target_mask] = 0.0

            # pad target and target_mask if they are shorter than model's prediction_length
            if self.chronos_config.prediction_length > target.shape[-1]:
                padding_shape = (
                    *target.shape[:-1],
                    self.chronos_config.prediction_length - target.shape[-1],
                )
                target = torch.cat(
                    [target, torch.zeros(padding_shape).to(target)], dim=-1
                )
                target_mask = torch.cat(
                    [target_mask, torch.zeros(padding_shape).to(target_mask)], dim=-1
                )

            # Pinball (quantile) loss, shape (batch, num_quantiles, prediction_length);
            # masked-out / padded positions contribute zero.
            loss = (
                2
                * torch.abs(
                    (target - quantile_preds)
                    * (
                        (target <= quantile_preds).float()
                        - self.quantiles.view(1, self.num_quantiles, 1)
                    )
                )
                * target_mask.float()
            )
            # dim -2 is the quantile axis and dim -1 the horizon axis.
            loss = loss.mean(dim=-2)  # Mean over quantile levels
            loss = loss.sum(dim=-1)  # Sum over prediction horizon
            loss = loss.mean()  # Mean over batch

        # Unscale predictions
        quantile_preds = self.instance_norm.inverse(
            quantile_preds.view(batch_size, -1),
            loc_scale,
        ).view(*quantile_preds_shape)

        return ChronosBoltOutput(
            loss=loss,
            quantile_preds=quantile_preds,
        )

    def _init_decoder(self, config):
        """Create the T5 decoder stack (sharing ``self.shared`` embeddings)."""
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

    def decode(
        self,
        input_embeds,
        attention_mask,
        hidden_states,
        output_attentions=False,
    ):
        """
        Run the decoder on a single start token, cross-attending to the encoder.

        Parameters
        ----------
        input_embeds: torch.Tensor
            Patched and embedded inputs. Shape (batch_size, patched_context_length, d_model).
            Only its batch size and device are used here.
        attention_mask: torch.Tensor
            Attention mask for the patched context as returned by ``encode``.
            Shape (batch_size, patched_context_length); boolean, or the model
            dtype when a [REG] position has been appended.
        hidden_states: torch.Tensor
            Hidden states returned by the encoder. Shape (batch_size, patched_context_length, d_model)
        output_attentions: bool
            Forwarded to the decoder stack.

        Returns
        -------
        last_hidden_state
            Last hidden state returned by the decoder, of shape (batch_size, 1, d_model)
        """
        batch_size = input_embeds.shape[0]
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self.config.decoder_start_token_id,
            device=input_embeds.device,
        )
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            output_attentions=output_attentions,
            return_dict=True,
        )

        return decoder_outputs.last_hidden_state  # sequence_outputs, b x 1 x d_model
class ChronosBoltPipeline(BaseChronosPipeline):
    """
    Inference wrapper around ``ChronosBoltModelForForecasting``.

    Handles context preparation/truncation, device placement, unrolling of
    forecasts longer than the model's built-in horizon, and quantile selection
    or interpolation.
    """

    forecast_type: ForecastType = ForecastType.QUANTILES
    default_context_length: int = 2048

    def __init__(self, model: ChronosBoltModelForForecasting):
        super().__init__(inner_model=model)
        self.model = model

    @property
    def quantiles(self) -> List[float]:
        """Quantile levels the wrapped model was configured with."""
        return self.model.config.chronos_config["quantiles"]

    @torch.no_grad()
    def embed(
        self, context: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Get encoder embeddings for the given time series.

        Parameters
        ----------
        context
            Input series. This is either a 1D tensor, or a list
            of 1D tensors, or a 2D tensor whose first dimension
            is batch. In the latter case, use left-padding with
            ``torch.nan`` to align series of different lengths.

        Returns
        -------
        embeddings, loc_scale
            A tuple of two items: the encoder embeddings and the loc_scale,
            i.e., the mean and std of the original time series.
            The encoder embeddings are shaped (batch_size, num_patches + 1, d_model),
            where num_patches is the number of patches in the time series
            and the extra 1 is for the [REG] token (if used by the model).
        """
        context_tensor = self._prepare_and_validate_context(context=context)
        model_context_length = self.model.config.chronos_config["context_length"]

        if context_tensor.shape[-1] > model_context_length:
            context_tensor = context_tensor[..., -model_context_length:]

        context_tensor = context_tensor.to(
            device=self.model.device,
            dtype=torch.float32,
        )
        embeddings, loc_scale, *_ = self.model.encode(context=context_tensor)
        return embeddings.cpu(), (
            loc_scale[0].squeeze(-1).cpu(),
            loc_scale[1].squeeze(-1).cpu(),
        )

    def predict(  # type: ignore[override]
        self,
        context: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
        limit_prediction_length: bool = False,
    ) -> torch.Tensor:
        """
        Get forecasts for the given time series.

        Refer to the base method (``BaseChronosPipeline.predict``)
        for details on shared parameters.

        Additional parameters
        ---------------------
        limit_prediction_length
            Force prediction length smaller or equal than the
            built-in prediction length from the model. False by
            default. When true, fail loudly if longer predictions
            are requested, otherwise longer predictions are allowed.

        Returns
        -------
        torch.Tensor
            Forecasts of shape (batch_size, num_quantiles, prediction_length)
            where num_quantiles is the number of quantiles the model has been
            trained to output. For official Chronos-Bolt models, the value of
            num_quantiles is 9 for [0.1, 0.2, ..., 0.9]-quantiles.

        Raises
        ------
        ValueError
            When limit_prediction_length is True and the prediction_length is
            greater than model's training prediction_length.
        """
        context_tensor = self._prepare_and_validate_context(context=context)

        model_context_length = self.model.config.chronos_config["context_length"]
        model_prediction_length = self.model.config.chronos_config["prediction_length"]
        if prediction_length is None:
            prediction_length = model_prediction_length

        if prediction_length > model_prediction_length:
            msg = (
                f"We recommend keeping prediction length <= {model_prediction_length}. "
                "The quality of longer predictions may degrade since the model is not optimized for it. "
            )
            if limit_prediction_length:
                msg += "You can turn off this check by setting `limit_prediction_length=False`."
                raise ValueError(msg)
            warnings.warn(msg)

        predictions = []
        remaining = prediction_length

        # We truncate the context here because otherwise batches with very long
        # context could take up large amounts of GPU memory unnecessarily.
        if context_tensor.shape[-1] > model_context_length:
            context_tensor = context_tensor[..., -model_context_length:]

        # TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
        # horizon that the model was trained with (i.e., 64). This results in variance collapsing
        # every 64 steps.
        context_tensor = context_tensor.to(
            device=self.model.device,
            dtype=torch.float32,
        )

        # Index of the trained quantile level closest to the median. It depends
        # only on the model configuration, so it is computed once rather than on
        # every unrolling step.
        central_idx = torch.abs(torch.tensor(self.quantiles) - 0.5).argmin()

        while remaining > 0:
            with torch.no_grad():
                prediction = self.model(
                    context=context_tensor,
                ).quantile_preds.to(context_tensor)

            predictions.append(prediction)
            remaining -= prediction.shape[-1]

            if remaining <= 0:
                break

            # Feed the central (median-like) forecast back in as extra context.
            central_prediction = prediction[:, central_idx]
            context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)

        return torch.cat(predictions, dim=-1)[..., :prediction_length].to(
            dtype=torch.float32, device="cpu"
        )

    def predict_quantiles(
        self,
        context: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
        quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        **predict_kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Refer to the base method (``BaseChronosPipeline.predict_quantiles``).

        Returns a ``(quantiles, mean)`` tuple where ``quantiles`` has shape
        ``(batch_size, prediction_length, len(quantile_levels))`` and ``mean``
        is the 0.5-quantile forecast of shape ``(batch_size, prediction_length)``.
        ``quantile_levels`` is only read, never modified.
        """
        # shape (batch_size, prediction_length, len(training_quantile_levels))
        predictions = (
            self.predict(context, prediction_length=prediction_length, **predict_kwargs)
            .detach()
            .swapaxes(1, 2)
        )

        training_quantile_levels = self.quantiles

        if set(quantile_levels).issubset(set(training_quantile_levels)):
            # no need to perform intra/extrapolation
            quantiles = predictions[
                ..., [training_quantile_levels.index(q) for q in quantile_levels]
            ]
        else:
            # we rely on torch for interpolating quantiles if quantiles that
            # Chronos Bolt was trained on were not provided
            if min(quantile_levels) < min(training_quantile_levels) or max(
                quantile_levels
            ) > max(training_quantile_levels):
                logger.warning(
                    f"\tQuantiles to be predicted ({quantile_levels}) are not within the range of "
                    f"quantiles that Chronos-Bolt was trained on ({training_quantile_levels}). "
                    "Quantile predictions will be set to the minimum/maximum levels at which Chronos-Bolt "
                    "was trained on. This may significantly affect the quality of the predictions."
                )

            # TODO: this is a hack that assumes the model's quantiles during training (training_quantile_levels)
            # made up an equidistant grid along the quantile dimension. i.e., they were (0.1, 0.2, ..., 0.9).
            # While this holds for official Chronos-Bolt models, this may not be true in the future, and this
            # function may have to be revised.
            # Duplicate the lowest/highest trained quantile so levels outside the
            # trained range clamp to them.
            augmented_predictions = torch.cat(
                [predictions[..., [0]], predictions, predictions[..., [-1]]],
                dim=-1,
            )
            quantiles = torch.quantile(
                augmented_predictions,
                q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
                dim=-1,
            ).permute(1, 2, 0)
        # NOTE: the median is returned as the mean here
        mean = predictions[:, :, training_quantile_levels.index(0.5)]
        return quantiles, mean

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """
        Load the model, either from a local path or from the HuggingFace Hub.
        Supports the same arguments as ``AutoConfig`` and ``AutoModel``
        from ``transformers``.

        The model class is looked up by name (``config.architectures[0]``) in
        this module's globals; an unknown name falls back to
        ``ChronosBoltModelForForecasting`` with a warning.
        """
        config = AutoConfig.from_pretrained(*args, **kwargs)
        assert hasattr(config, "chronos_config"), "Not a Chronos config file"

        architecture = config.architectures[0]
        class_ = globals().get(architecture)

        if class_ is None:
            logger.warning(
                f"Unknown architecture: {architecture}, defaulting to ChronosBoltModelForForecasting"
            )
            class_ = ChronosBoltModelForForecasting

        model = class_.from_pretrained(*args, **kwargs)
        return cls(model=model)
================================================
FILE: probts/model/nn/arch/ChronosModule/loss.py
================================================
import torch
import torch.nn as nn
# from huggingface transformers/trainer_pt_utils.py
class LabelSmoother:
    """
    Label-smoothed cross-entropy computed from already-produced logits.

    Args:
        epsilon (`float`, *optional*, defaults to 0.1):
            The label smoothing factor.
        ignore_index (`int`, *optional*, defaults to -100):
            The index in the labels to ignore when computing the loss.
    """

    epsilon: float = 0.1
    ignore_index: int = -100

    def __call__(self, model_output, labels):
        """
        Return ``(1 - epsilon) * nll + epsilon * uniform`` averaged over the
        positions whose label is not ``ignore_index``.

        ``model_output`` is either the logits tensor itself or a dict holding it
        under the ``"logits"`` key.
        """
        raw_logits = (
            model_output["logits"] if isinstance(model_output, dict) else model_output
        )
        # Negative log-probabilities, always computed in fp32.
        neg_log_probs = -nn.functional.log_softmax(raw_logits.to(torch.float32), dim=-1)

        if labels.dim() == neg_log_probs.dim() - 1:
            labels = labels.unsqueeze(-1)

        ignored = labels.eq(self.ignore_index)
        # gather() cannot take negative indices, so ignored labels are mapped to
        # class 0 here; their contribution is zeroed out right after.
        gather_index = torch.clamp(labels, min=0)

        picked = neg_log_probs.gather(dim=-1, index=gather_index)
        picked = picked.masked_fill(ignored, 0.0)

        # Sum over all classes = loss against the uniform target distribution.
        per_position_total = neg_log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
        per_position_total = per_position_total.masked_fill(ignored, 0.0)

        # Number of positions that actually count towards the average.
        active_count = ignored.numel() - ignored.long().sum()
        vocab_size = neg_log_probs.shape[-1]

        nll_term = picked.sum() / active_count
        uniform_term = per_position_total.sum() / (active_count * vocab_size)
        return (1 - self.epsilon) * nll_term + self.epsilon * uniform_term
================================================
FILE: probts/model/nn/arch/ChronosModule/utils.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import List
import torch
def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
    """
    Stack 1D tensors of differing lengths into one 2D tensor.

    Shorter tensors are left-padded with ``torch.nan`` up to the length of the
    longest one, so the result has shape ``(len(tensors), max_len)``. Every
    element must be a 1D ``torch.Tensor`` (checked with assertions).
    """
    longest = max(len(series) for series in tensors)

    rows = []
    for series in tensors:
        assert isinstance(series, torch.Tensor)
        assert series.ndim == 1
        n_missing = longest - len(series)
        # NaN filler lives on the same device as the series it is prepended to.
        filler = torch.full(
            size=(n_missing,), fill_value=torch.nan, device=series.device
        )
        rows.append(torch.concat((filler, series), dim=-1))

    return torch.stack(rows)
================================================
FILE: probts/model/nn/arch/Conv_Blocks.py
================================================
import torch
import torch.nn as nn
class Inception_Block_V1(nn.Module):
    """
    Bank of parallel square 2D convolutions whose outputs are averaged.

    Branch ``i`` uses kernel size ``2 * i + 1`` with padding ``i``, so every
    branch keeps the spatial size of its input.
    """

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super(Inception_Block_V1, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels

        # One "same"-padded convolution per odd kernel size 1, 3, 5, ...
        self.kernels = nn.ModuleList(
            [
                nn.Conv2d(in_channels, out_channels, kernel_size=2 * half + 1, padding=half)
                for half in range(self.num_kernels)
            ]
        )

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, relu) weights and zero biases for every conv."""
        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        """Apply all branches to ``x`` and return their element-wise mean."""
        branch_outputs = [self.kernels[idx](x) for idx in range(self.num_kernels)]
        return torch.stack(branch_outputs, dim=-1).mean(-1)
class Inception_Block_V2(nn.Module):
    """
    Bank of parallel 1xN / Nx1 2D convolutions plus a 1x1 convolution, averaged.

    For each ``i`` in ``range(num_kernels // 2)`` a horizontal ``1 x (2i + 3)``
    and a vertical ``(2i + 3) x 1`` convolution are created (both padded to keep
    the spatial size), followed by a single 1x1 convolution. The block therefore
    holds ``2 * (num_kernels // 2) + 1`` branches.
    """

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super(Inception_Block_V2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        kernels = []
        for i in range(self.num_kernels // 2):
            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[1, 2 * i + 3], padding=[0, i + 1]))
            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[2 * i + 3, 1], padding=[i + 1, 0]))
        kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
        self.kernels = nn.ModuleList(kernels)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, relu) weights and zero biases for every conv."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Apply all branches to ``x`` and return their element-wise mean."""
        # Iterate over the branches that actually exist. Indexing with
        # ``range(num_kernels + 1)`` only matches the branch count for even
        # ``num_kernels`` and raised IndexError for odd values.
        res_list = [kernel(x) for kernel in self.kernels]
        res = torch.stack(res_list, dim=-1).mean(-1)
        return res
================================================
FILE: probts/model/nn/arch/ElasTSTModule/ElasTST_backbone.py
================================================
# Export the class this module actually defines; the previous entry named a
# symbol that does not exist here, which makes `from <module> import *` fail.
__all__ = ['ElasTST_backbone']
# Cell
from typing import Callable, Optional
import torch
from torch import nn
from torch import Tensor
import numpy as np
from einops import rearrange, repeat
from probts.utils.position_emb import Time_Encoder, sin_cos_encoding
from probts.model.nn.arch.ElasTSTModule.Layers import EncoderLayer
# Cell
class ElasTST_backbone(nn.Module):
    """
    Multi-patch-size forecasting backbone.

    For every patch length in ``l_patch_size`` the concatenated
    ``[past, future placeholder]`` series is patch-embedded, passed through a
    ``DoublyAtt`` stack (shared across patch sizes or one per patch size) and
    projected back to the time axis. The per-branch forecasts are averaged.
    """

    def __init__(self,
                 l_patch_size: list,
                 stride: int = None,
                 k_patch_size: int = 1,
                 in_channels: int = 1,
                 n_layers: int = 0,
                 t_layers: int = 1,
                 v_layers: int = 1,
                 hidden_size: int = 256,
                 n_heads: int = 16,
                 d_k: Optional[int] = None,
                 d_v: Optional[int] = None,
                 d_inner: int = 256,
                 dropout: float = 0.,
                 rotate: bool = False,
                 max_seq_len = 1000,
                 theta = 10000,
                 learnable_theta = False,
                 addv: bool = False,
                 bin_att: bool = False,
                 abs_tem_emb: bool = False,
                 learn_tem_emb: bool = False,
                 structured_mask: bool = True,
                 rope_theta_init: str = 'exp',
                 min_period: float = 1,
                 max_period: float = 1000,
                 patch_share_backbone: bool = True,):
        super().__init__()

        # Configuration summary printed at construction time.
        if rotate:
            print(f'Using Rotary Embedding... [theta init]: {rope_theta_init}, [period range]: [{min_period},{max_period}], [learnable]: {learnable_theta}')
        print("[Binary Att.]: ", bin_att, " [Learned time emb]: ", learn_tem_emb, " [Abs time emb]: ", abs_tem_emb)
        print("[Multi Patch Share Backbone]: ", patch_share_backbone)
        # NOTE(review): this prints the negation of ``structured_mask`` — confirm
        # whether that is the intended value to report.
        print("[Structured Mask]: ", not structured_mask)

        # Patching
        self.l_patch_size = l_patch_size
        self.k_patch_size = k_patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_share_backbone = patch_share_backbone
        self.abs_tem_emb= abs_tem_emb
        self.hidden_size = hidden_size
        self.stride = stride if stride is not None else self.l_patch_size

        # Arguments shared by every attention stack built below.
        attention_kwargs = dict(
            d_model=self.hidden_size, n_layers=n_layers, t_layers=t_layers, v_layers=v_layers,
            d_inner=d_inner, n_heads=n_heads, d_k=d_k, d_v=d_v, dropout=dropout,
            rotate=rotate, max_seq_len=max_seq_len, theta=theta, addv=addv, bin_att=bin_att,
            learnable_theta=learnable_theta, structured_mask=structured_mask,
            rope_theta_init=rope_theta_init, min_period=min_period, max_period=max_period,
        )

        embedders = []
        output_heads = []
        branch_backbones = []
        for patch_len in self.l_patch_size:
            print(f"=== Patch {patch_len} Branch ===")
            # Non-overlapping patches: the embedder's stride equals the patch length.
            embedders.append(TimePatchEmbed(patch_len, self.k_patch_size, self.in_channels, self.hidden_size, bias=True,stride=patch_len))
            output_heads.append(MLP_FinalLayer(self.hidden_size, patch_len, self.k_patch_size, self.out_channels))
            if not patch_share_backbone:
                branch_backbones.append(DoublyAtt(**attention_kwargs))

        self.x_embedder = nn.ModuleList(embedders)
        self.final_layer = nn.ModuleList(output_heads)

        if patch_share_backbone:
            self.backbone = DoublyAtt(**attention_kwargs)
        else:
            self.backbone = nn.ModuleList(branch_backbones)

        self.learn_tem_emb = learn_tem_emb
        if self.learn_tem_emb:
            self.learn_time_embedding = Time_Encoder(self.hidden_size)

    def get_patch_num(self, dim_size, len_size, l_patch_size):
        """Return ``(num_k_patches, num_l_patches)`` for non-overlapping patching."""
        k_count = int((dim_size - self.k_patch_size)/self.k_patch_size + 1)
        l_count = int((len_size - l_patch_size)/l_patch_size + 1)
        return k_count, l_count

    def forward(self, past_target, future_placeholder, past_observed_values, future_observed_values, dataset_name=None): # z: [bs x nvars x seq_len]
        """
        Forecast the horizon described by ``future_placeholder``.

        Returns the mean over all patch-size branches and the stacked per-branch
        forecasts (branches along the last dimension).
        """
        pred_shape = future_placeholder.shape

        # The future part carries no past information, so its "past" indicator is all zeros.
        future_zeros = torch.zeros(future_observed_values.shape).to(future_observed_values.device)

        x = torch.cat((past_target, future_placeholder), dim=1) # B L+T K
        past_value_indicator = torch.cat((past_observed_values, future_zeros), dim=1) # B L+T K
        observed_value_indicator = torch.cat((past_observed_values, future_observed_values), dim=1) # B L+T K

        branch_preds = []
        for branch, patch_len in enumerate(self.l_patch_size):
            series = x.clone()
            _, num_l_patches = self.get_patch_num(series.shape[-1], series.shape[-2], patch_len)

            # do patching
            tokens, past_ind_p, obs_ind_p = self.x_embedder[branch](series, past_value_indicator, observed_value_indicator) # b k l d

            if self.learn_tem_emb:
                # Learned embedding of the patch index 0..num_l_patches-1.
                patch_index = torch.arange(num_l_patches, dtype=torch.float32, device=x.device).unsqueeze(0)
                time_emb = repeat(patch_index, '1 l -> b l', b=pred_shape[0])
                time_emb = self.learn_time_embedding(time_emb) # b l 1 d
                time_emb = rearrange(time_emb, 'b l 1 d -> b 1 l d')
                tokens = tokens + time_emb

            # use a absolute position embedding
            if self.abs_tem_emb:
                B, K, L, embed_dim = tokens.shape
                abs_emb = sin_cos_encoding(B, K, L, embed_dim).float() # b k l d
                tokens = tokens + abs_emb.to(tokens.device)

            # model
            attention = self.backbone if self.patch_share_backbone else self.backbone[branch]
            tokens = attention(tokens, past_ind_p, obs_ind_p) # b k l d

            decoded = self.final_layer[branch](tokens) # b k l p
            decoded = rearrange(decoded, 'b k t p -> b (t p) k')
            # Keep only the trailing forecast horizon.
            decoded = decoded[:,-pred_shape[1]:,:]
            branch_preds.append(decoded)

        pred_list = torch.stack(branch_preds, dim=-1)
        multi_patch_mean_res = torch.mean(pred_list, dim=-1)
        return multi_patch_mean_res, pred_list
class DoublyAtt(nn.Module):
    """
    Stack of ``EncoderLayer`` modules mixing temporal and variate attention.

    ``min(t_layers, v_layers)`` layers use both attention types; the surplus of
    whichever count is larger is realised as single-attention layers placed
    before them.
    """

    def __init__(self, d_model,n_layers, d_inner, n_heads, d_k, d_v, dropout,
                 rotate=False, max_seq_len=1024, theta=10000, t_layers=2, v_layers=1,
                 bin_att=False, addv=False, learnable_theta=False, structured_mask=True,
                 rope_theta_init='exp',min_period=0.1, max_period=10):
        super().__init__()
        # assert n_layers <= (t_layers + v_layers) <= 2*n_layers , "Sum of t_layers and n_layers must be between 1 and 2"

        self.layer_stack = nn.ModuleList()

        def build_layer(use_tem, use_var):
            # All layers share the same hyper-parameters; only the attention
            # types that are switched on differ.
            return EncoderLayer(d_model, d_inner, n_heads, d_k, d_v, dropout=dropout, tem_att=use_tem, type_att=use_var,
                                structured_mask=structured_mask, rotate=rotate, max_seq_len=max_seq_len,theta=theta, addv=addv,
                                learnable_theta=learnable_theta, bin_att=bin_att,rope_theta_init=rope_theta_init, min_period=min_period, max_period=max_period)

        # Configuration based on temporal and variate ratios
        n_joint = min(t_layers, v_layers)
        n_tem_only = t_layers - n_joint
        n_var_only = v_layers - n_joint

        tem_built = 0
        var_built = 0
        for _ in range(n_tem_only + n_var_only):
            if tem_built < n_tem_only:
                self.layer_stack.append(build_layer(True, False))
                tem_built += 1
                print(f"[Encoder Layer {tem_built+var_built}] Use tem att")
            if var_built < n_var_only:
                self.layer_stack.append(build_layer(False, True))
                var_built += 1
                print(f"[Encoder Layer {tem_built+var_built}] Use var att")

        for idx in range(n_joint):
            self.layer_stack.append(build_layer(True, True))
            print(f"[Encoder Layer {idx+tem_built+var_built}] Use tem and var att")

    def forward(self, x, past_value_indicator, observed_indicator) -> Tensor:
        """Run ``x`` through every layer in order, passing both indicators along."""
        hidden = x
        for layer in self.layer_stack:
            hidden = layer(hidden, past_value_indicator=past_value_indicator, observed_indicator=observed_indicator)
        return hidden
class MLP_FinalLayer(nn.Module):
    """
    Output head: parameter-free LayerNorm followed by a linear projection from
    ``hidden_size`` to ``l_patch_size * k_patch_size * out_channels``.
    """

    def __init__(self, hidden_size, l_patch_size, k_patch_size, out_channels):
        super().__init__()
        projected_dim = l_patch_size * k_patch_size * out_channels
        # elementwise_affine=False: the normalization has no learnable scale/shift.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, projected_dim, bias=True)

    def forward(self, x):
        """Normalize the last dimension of ``x`` and project it."""
        return self.linear(self.norm_final(x))
class TimePatchEmbed(nn.Module):
    """
    Patch embedding for time series via a strided 2D convolution.

    Alongside the value projection, an all-ones, bias-free convolution with the
    same geometry is applied to the two indicator masks, which sums each mask
    over every patch.
    """

    def __init__(
        self,
        l_patch_size: int = 16,
        k_patch_size = 1,
        in_chans: int = 1,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = False,
        bias: bool = True,
        # padding_patch = None,
        stride = None,
        # strict_img_size: bool = True,
    ):
        super().__init__()
        self.l_patch_size = l_patch_size
        self.k_patch_size = k_patch_size
        self.flatten = flatten

        # Without an explicit stride the patches do not overlap.
        time_step = l_patch_size if stride is None else stride
        conv_geometry = dict(
            kernel_size=(l_patch_size, k_patch_size),
            stride=(time_step, k_patch_size),
            padding=0,
        )

        self.proj = nn.Conv2d(in_chans, embed_dim, bias=bias, **conv_geometry)

        # Mask "projection": weights fixed to 1 so the output is the per-patch mask sum.
        self.mask_proj = nn.Conv2d(1, 1, bias=False, **conv_geometry)
        self.mask_proj.weight.data.fill_(1.0)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x, future_mask, obv_mask):
        '''
        future_mask: only past values are set to 1
        obv_mask: past values and values to be predicted are set to 1

        Returns the patch embeddings arranged as ``b k l d`` and the two
        patch-level masks arranged as ``b k l``.
        '''
        # 3D inputs get a singleton channel axis so Conv2d can consume them.
        if len(x.shape) == 3:
            x, future_mask, obv_mask = (
                rearrange(t, 'b l k -> b 1 l k') for t in (x, future_mask, obv_mask)
            )

        x = self.proj(x) # B C L K -> B C L' K

        # The mask sums are bookkeeping only; no gradient is tracked for them.
        with torch.no_grad():
            future_mask = self.mask_proj(future_mask)
            obv_mask = self.mask_proj(obv_mask)

        if self.flatten:
            x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
            future_mask = future_mask.flatten(2).transpose(1, 2) # NCHW -> NLC
            obv_mask = obv_mask.flatten(2).transpose(1, 2) # NCHW -> NLC

        x = self.norm(x)

        x = rearrange(x, 'b d l k -> b k l d')
        future_mask = rearrange(future_mask, 'b 1 l k -> b k l')
        obv_mask = rearrange(obv_mask, 'b 1 l k -> b k l')
        return x, future_mask, obv_mask
================================================
FILE: probts/model/nn/arch/ElasTSTModule/Layers.py
================================================
import torch.nn as nn
import sys
import torch
from probts.model.nn.arch.ElasTSTModule.SubLayers import PositionwiseFeedForward, MultiHeadAttention_tem_bias, MultiHeadAttention_type_bias
from einops import rearrange, repeat
PAD = 0
def get_attn_key_pad_mask_K(past_value_indicator, observed_indicator , transpose=False, structured_mask=False):
    """
    Build the key padding mask for attention.

    The source indicator is ``past_value_indicator`` when ``structured_mask`` is
    set and ``observed_indicator`` otherwise. With ``transpose=True`` its last
    two axes are swapped first. The indicator is then repeated along a new
    query axis and compared with ``PAD``, so ``True`` marks keys to mask out.

    Input indicator: 3D; output: 4D boolean tensor whose last two axes both
    have the size of the (possibly transposed) indicator's last axis.
    """
    key_indicator = past_value_indicator if structured_mask else observed_indicator

    if transpose:
        key_indicator = rearrange(key_indicator, 'b l k -> b k l')

    n_keys = key_indicator.shape[-1]
    expanded = repeat(key_indicator, 'b k l1 -> b k l2 l1', l2=n_keys)
    return expanded.eq(PAD)
class EncoderLayer(nn.Module):
    """
    Encoder layer with optional temporal attention, optional variate ("type")
    attention and a position-wise feed-forward block.

    A single ``LayerNorm`` instance is reused before each sub-block and once
    more on the final output.
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1,
                 tem_att=True, type_att=False, structured_mask=True,
                 rotate=False, max_seq_len=100, theta=10000,
                 addv=False, learnable_theta=False, bin_att=False,
                 rope_theta_init='exp',min_period=0.1, max_period=10):
        super(EncoderLayer, self).__init__()

        self.structured_mask = structured_mask
        self.tem_att = tem_att
        self.type_att = type_att

        if tem_att:
            self.slf_tem_attn = MultiHeadAttention_tem_bias(
                n_head, d_model, d_k, d_v, dropout=dropout, rotate=rotate, max_seq_len=max_seq_len, theta=theta, addv=addv,
                learnable_theta=learnable_theta, bin_att=bin_att,rope_theta_init=rope_theta_init, min_period=min_period, max_period=max_period)

        if type_att:
            # Rotary embedding is always disabled for attention across variates.
            self.slf_type_attn = MultiHeadAttention_type_bias(
                n_head, d_model, d_k, d_v, dropout=dropout, rotate=False, max_seq_len=max_seq_len, bin_att=bin_att)

        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, input, past_value_indicator=None, observed_indicator=None):
        """
        Apply the enabled attention blocks and the feed-forward block.

        ``input`` is laid out as ``b k l d``; the output uses the same layout.
        """
        mask_kwargs = dict(
            past_value_indicator=past_value_indicator,
            observed_indicator=observed_indicator,
            structured_mask=self.structured_mask,
        )

        # time attention
        # [B, K, L, D]
        hidden = input
        if self.tem_att:
            tem_mask = get_attn_key_pad_mask_K(transpose=False, **mask_kwargs)
            normed = self.layer_norm(input)
            attended, _ = self.slf_tem_attn(normed, normed, normed, mask=tem_mask)
            hidden = attended + input

        hidden = rearrange(hidden, 'b k l d -> b l k d')

        # type attention
        # [B, L, K, D]
        if self.type_att:
            type_mask = get_attn_key_pad_mask_K(transpose=True, **mask_kwargs)
            normed = self.layer_norm(hidden)
            attended, _ = self.slf_type_attn(normed, normed, normed, mask=type_mask)
            hidden = attended + hidden

        # FFNN
        ffn_out = self.pos_ffn(self.layer_norm(hidden))
        ffn_out = ffn_out + hidden

        restored = rearrange(ffn_out, 'b l k d -> b k l d')

        # optional
        return self.layer_norm(restored) #, enc_tem_attn, enc_type_attn
================================================
FILE: probts/model/nn/arch/ElasTSTModule/Modules.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from probts.model.nn.arch.ElasTSTModule.TRoPE import RotaryEmbedding
class ScaledDotProductAttention(nn.Module):
    """Batched scaled dot-product attention built on ``torch.bmm``.

    Args:
        temperature: divisor applied to the queries before the dot product.
        attn_dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, temperature, attn_dropout=0.2):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """Return ``(output, attention_weights)`` for queries, keys and values.

        Positions where ``mask`` is True are filled with -1e9 before the
        softmax. A 5-D mask has its axes 2 and 4 swapped first.
        """
        scores = torch.bmm(q / self.temperature, k.transpose(-2, -1))

        if mask is not None:
            if mask.dim() == 5:
                mask = mask.transpose(2, 4)
            scores = scores.masked_fill(mask, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.bmm(weights, v), weights
class ScaledDotProductAttention_bias(nn.Module):
    """Multi-head scaled dot-product attention over 4-D inputs.

    Owns the Q/K/V projections. Optionally rotates Q/K/V with a
    ``RotaryEmbedding`` (``rotate``) and optionally adds a learnable per-head
    bias to diagonal (``alpha``) and off-diagonal (``beta``) attention logits
    (``bin_att``).

    NOTE(review): the rotary embedding is sized with ``d_v`` but is also
    applied to queries/keys of width ``d_k`` — this presumably requires
    ``d_k == d_v`` when ``rotate=True``; confirm against callers.
    """

    def __init__(self, d_model, n_head, d_k, d_v, temperature,
                 attn_dropout=0.2, rotate=False, max_seq_len=100,
                 theta=10000, addv=False, learnable_theta=False,
                 bin_att=False, rope_theta_init='exp',
                 min_period=0.1, max_period=10):
        super().__init__()

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)

        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.n_head = n_head
        self.bin_att = bin_att
        self.rotate = rotate
        self.addv = addv

        self.trope = RotaryEmbedding(d_v, max_seq_len, base=theta, learnable=learnable_theta,
                                     init=rope_theta_init, min_period=min_period, max_period=max_period)

        if self.bin_att:
            # Per-head additive biases for self (diagonal) and cross (off-diagonal) logits.
            self.alpha = nn.Parameter(torch.zeros([1, 1, n_head, 1, 1]))
            self.beta = nn.Parameter(torch.zeros([1, 1, n_head, 1, 1]))

    def forward(self, q, k, v, mask=None):
        """Compute attention.

        Args:
            q, k, v: 4-D tensors whose last axis is ``d_model``; attention is
                taken over the third axis.
            mask: optional boolean tensor; True entries are filled with -1e9
                before the softmax. A mask with fewer dims than the logits
                gets a head axis inserted at position 2.

        Returns:
            ``(values, attn)`` where ``values`` has the heads merged back into
            the last axis and ``attn`` holds the (dropped-out) weights.
        """
        # Project and split heads. Keys are laid out pre-transposed (d before l).
        q = rearrange(self.w_qs(q), 'b k l (n d) -> b k n l d', n=self.n_head)
        k = rearrange(self.w_ks(k), 'b k l (n d) -> b k n d l', n=self.n_head)
        v = rearrange(self.w_vs(v), 'b k l (n d) -> b k n l d', n=self.n_head)

        B, K, N, L, D = q.shape

        if self.rotate:
            xq = rearrange(q, 'b k n l d -> (b k n) l d')
            xk = rearrange(k, 'b k n d l -> (b k n) l d')
            xv = rearrange(v, 'b k n l d -> (b k n) l d')

            xq, xk, xv = self.trope(xq, xk, xv)

            attn = torch.matmul(xq, xk.transpose(1, 2)) / self.temperature
            attn = rearrange(attn, '(b k n) l t -> b k n l t', b=B, k=K, n=N)

            if self.addv:
                # Use the rotated values for the weighted sum as well.
                v = rearrange(xv, '(b k n) l d -> b k n l d', b=B, k=K, n=N)
        else:
            attn = torch.matmul(q, k) / self.temperature

        if self.bin_att:
            # Build the identity on the logits' device so this also works when
            # no mask is supplied (previously it dereferenced mask.device).
            # Broadcasting over [B, K, N, L, L] replaces the explicit repeat.
            self_mask = torch.eye(L, device=attn.device)
            attn = attn + self_mask * self.alpha + (1 - self_mask) * self.beta

        if mask is not None:
            if attn.dim() > mask.dim():
                mask = mask.unsqueeze(2).expand(attn.shape)
            attn = attn.masked_fill(mask, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))

        v = torch.matmul(attn, v)
        v = rearrange(v, 'b k n l d -> b k l (n d)')

        return v, attn
class Attention(nn.Module):
    """Additive attention pooling over the second-to-last axis.

    Scores are ``W(tanh(linear(x)))``; the softmax is taken over axis -2 and
    used to form a weighted sum of ``x`` along that axis.
    """

    def __init__(self, hin_d, d_model):
        super().__init__()
        self.linear = nn.Linear(d_model, hin_d)
        self.W = nn.Linear(hin_d, 1, bias=False)

    def forward(self, x, mask=None, mask_value=-1e30):
        """Pool ``x`` and return ``(pooled, weights)``.

        Where ``mask`` is 0 the score is replaced by ``mask_value`` before the
        softmax; where it is 1 the score is kept.
        """
        scores = self.W(torch.tanh(self.linear(x)))  # trailing axis has size 1

        if mask is not None:
            scores = mask * scores + (1 - mask) * mask_value

        weights = F.softmax(scores, dim=-2)
        pooled = torch.matmul(x.transpose(-1, -2), weights).squeeze(-1)
        return pooled, weights
================================================
FILE: probts/model/nn/arch/ElasTSTModule/SubLayers.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
from probts.model.nn.arch.ElasTSTModule.Modules import ScaledDotProductAttention_bias
class MultiHeadAttention_tem_bias(nn.Module):
    """Multi-head attention wrapper used for the temporal axis.

    Delegates projection and attention to ``ScaledDotProductAttention_bias``
    (temperature ``sqrt(d_k)``), then applies an output linear layer followed
    by dropout.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, rotate=False, max_seq_len=100, theta=10000, addv=False,
                 learnable_theta=False, bin_att=False, rope_theta_init='exp', min_period=0.1, max_period=10):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Output projection that merges the concatenated heads back to d_model.
        self.fc = nn.Linear(d_v * n_head, d_model)

        self.attention = ScaledDotProductAttention_bias(
            d_model, n_head, d_k, d_v,
            temperature=d_k ** 0.5,
            attn_dropout=dropout,
            rotate=rotate,
            max_seq_len=max_seq_len,
            theta=theta,
            addv=addv,
            learnable_theta=learnable_theta,
            bin_att=bin_att,
            rope_theta_init=rope_theta_init,
            min_period=min_period,
            max_period=max_period,
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Return ``(projected_output, attention_weights)``."""
        heads, weights = self.attention(q, k, v, mask=mask)
        projected = self.dropout(self.fc(heads))
        return projected, weights
class MultiHeadAttention_type_bias(nn.Module):
    """Multi-head attention wrapper used for the type (variate) axis.

    The inner ``ScaledDotProductAttention_bias`` is always created with
    ``rotate=False``; the ``rotate`` argument is accepted but not forwarded.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, rotate=False, max_seq_len=1024, bin_att=False):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Output projection that merges the concatenated heads back to d_model.
        self.fc = nn.Linear(d_v * n_head, d_model)

        self.attention = ScaledDotProductAttention_bias(
            d_model, n_head, d_k, d_v,
            temperature=d_k ** 0.5,
            attn_dropout=dropout,
            rotate=False,
            max_seq_len=max_seq_len,
            bin_att=bin_att,
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Return ``(projected_output, attention_weights)``."""
        heads, weights = self.attention(q, k, v, mask=mask)
        projected = self.dropout(self.fc(heads))
        return projected, weights
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> dropout -> Linear -> dropout.

    Args:
        d_in: input and output feature size.
        d_hid: hidden feature size.
        dropout: dropout probability used after both linear layers.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward transform to the last axis of ``x``."""
        hidden = self.dropout(F.gelu(self.w_1(x)))
        return self.dropout(self.w_2(hidden))
================================================
FILE: probts/model/nn/arch/ElasTSTModule/TRoPE.py
================================================
import torch
from typing import Tuple
import torch
import torch.nn as nn
import numpy as np
import sys
class RotaryEmbedding(nn.Module):
    """Rotary position embedding with configurable (optionally learnable) frequencies.

    Args:
        dim: feature width of the tensors to rotate; ``dim // 2`` frequencies
            are created.
        seq_len: stored on the module; not used by ``forward``.
        base: base of the classic RoPE schedule (only used when ``init='rope'``).
        learnable: if True the frequencies are an ``nn.Parameter``, otherwise
            a registered buffer.
        init: one of ``'linear'``, ``'uniform'``, ``'exp'`` or ``'rope'``.
        min_period, max_period: period range used by the ``'linear'``,
            ``'uniform'`` and ``'exp'`` schedules.

    Raises:
        ValueError: if ``init`` is not one of the supported schedules.
    """

    def __init__(self, dim: int, seq_len: int, base: float = 10000.0, learnable=False, init="exp", min_period=0.01, max_period=1000):
        super(RotaryEmbedding, self).__init__()

        if init == 'linear':
            theta = get_linear_period(min_period, max_period, dim)
        elif init == 'uniform':
            periods = torch.nn.init.uniform_(torch.ones([dim // 2]), a=min_period, b=max_period)
            theta = 2 * np.pi / periods
        elif init == 'exp':
            theta = get_exp_period(min_period, max_period, dim)
        elif init == 'rope':
            theta = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        else:
            # Fail loudly instead of printing and exiting the interpreter with
            # a success status code.
            raise ValueError(
                f"invalid theta init {init!r}; expected 'linear', 'uniform', 'exp' or 'rope'"
            )

        if learnable:
            self.freqs = nn.Parameter(theta)
        else:
            # theta is already a tensor; copy it without re-wrapping via torch.tensor().
            self.register_buffer('freqs', torch.as_tensor(theta).clone().detach())

        self.dim = dim
        self.seq_len = seq_len
        self.learnable = learnable

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor):
        """Rotate ``xq``, ``xk`` and ``xv`` by position-dependent angles.

        The position index is taken from axis -2 of ``xq``; adjacent feature
        pairs on the last axis are treated as complex numbers and multiplied
        by ``exp(i * position * freq)``. The ``flatten(2)`` below means the
        inputs are expected to be 3-D.

        Returns:
            The three rotated tensors, each cast back to its input's type.
        """
        L = xq.shape[-2]
        # Build positions directly in the frequency dtype so the outer product
        # does not depend on integer/float type promotion.
        t = torch.arange(L, device=xq.device, dtype=self.freqs.dtype)
        angles = torch.outer(t, self.freqs).float()  # position * theta
        freqs_cis = torch.polar(torch.ones_like(angles), angles)
        device = xq.device

        def _rotate(x):
            # Pair up the last axis, rotate in the complex plane, then flatten back.
            as_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)).to(device)
            rotated = torch.view_as_real(as_complex * freqs_cis).flatten(2).to(device)
            return rotated.type_as(x)

        return _rotate(xq), _rotate(xk), _rotate(xv)
def get_linear_period(min_period, max_period, dim):
    """Return angular frequencies whose periods grow linearly from ``min_period``.

    Periods are ``min_period + (max_period - min_period) / dim * i`` for the
    even indices ``i = 0, 2, ...`` below ``dim``; each is converted to a
    frequency ``2 * pi / period``.
    """
    even_idx = torch.arange(0, dim, 2)[: (dim // 2)]
    slope = (max_period - min_period) / dim
    return 2 * np.pi / (min_period + slope * even_idx)
def get_exp_period(min_period, max_period, dim):
    """Return angular frequencies decaying exponentially from max to min.

    The first frequency is ``2 * pi / min_period`` and, for ``dim > 2``, the
    last one is ``2 * pi / max_period``; intermediate values are spaced
    geometrically over the even indices ``0, 2, ..., dim - 2``.

    With ``dim <= 2`` there is at most one frequency, so no decay rate is
    needed (the original formula divided by ``dim - 2`` and raised
    ZeroDivisionError for ``dim == 2``).
    """
    i = torch.arange(0, dim, 2)[: (dim // 2)]
    max_theta = 2 * np.pi / min_period
    min_theta = 2 * np.pi / max_period

    if dim <= 2:
        alpha = 0.0
    else:
        alpha = float(np.log(max_theta / min_theta)) / (dim - 2)

    # Stay in torch for the exponential instead of routing a tensor through numpy.
    thetas = max_theta * torch.exp(-alpha * i.float())
    return thetas
# generate rotation matrix
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Build the complex rotation table ``exp(i * position * freq)``.

    Frequencies follow ``1 / theta ** (j / dim)`` for even ``j`` below ``dim``.
    The result has shape ``[seq_len, dim // 2]`` and unit magnitude.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    # Token positions 0 .. seq_len - 1.
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    return torch.polar(torch.ones_like(angles), angles)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    xv: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Rotate ``xq``, ``xk`` and ``xv`` with a precomputed complex table.

    Adjacent pairs on the last axis are viewed as complex numbers and
    multiplied by ``freqs_cis`` (moved to ``xq``'s device first). The
    ``flatten(2)`` below means the inputs are expected to be 3-D.

    Returns:
        Three tensors (the annotation previously claimed two), each cast back
        to the type of its corresponding input.
    """
    device = xq.device
    freqs_cis = freqs_cis.to(device)

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # Pair up the last axis, rotate in the complex plane, then flatten back.
        as_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)).to(device)
        rotated = torch.view_as_real(as_complex * freqs_cis).flatten(2).to(device)
        return rotated.type_as(x)

    return _rotate(xq), _rotate(xk), _rotate(xv)
================================================
FILE: probts/model/nn/arch/ElasTSTModule/__init__.py
================================================
================================================
FILE: probts/model/nn/arch/ModernTCN_backbone.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
from probts.model.nn.arch.RevIN import RevIN
from probts.model.nn.arch.decomp import series_decomp
# forecast task head
class Flatten_Head(nn.Module):
    """Forecast head: flatten the last two axes, then Linear -> Dropout.

    With ``individual=True`` every variable gets its own flatten/linear/dropout
    triple; otherwise a single shared triple is applied to all variables.
    """

    def __init__(self, individual, n_vars, nf, target_window, head_dropout=0):
        super(Flatten_Head, self).__init__()

        self.individual = individual
        self.n_vars = n_vars

        if self.individual:
            per_var = range(self.n_vars)
            self.linears = nn.ModuleList([nn.Linear(nf, target_window) for _ in per_var])
            self.dropouts = nn.ModuleList([nn.Dropout(head_dropout) for _ in per_var])
            self.flattens = nn.ModuleList([nn.Flatten(start_dim=-2) for _ in per_var])
        else:
            self.flatten = nn.Flatten(start_dim=-2)
            self.linear = nn.Linear(nf, target_window)
            self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]
        """Map backbone features to ``[bs x nvars x target_window]``."""
        if not self.individual:
            return self.dropout(self.linear(self.flatten(x)))

        per_variable = []
        for idx in range(self.n_vars):
            feat = self.flattens[idx](x[:, idx, :, :])  # [bs x d_model * patch_num]
            pred = self.dropouts[idx](self.linears[idx](feat))  # [bs x target_window]
            per_variable.append(pred)
        return torch.stack(per_variable, dim=1)
class LayerNorm(nn.Module):
    """Layer normalisation over axis 2 (``D``) of a ``[B, M, D, N]`` tensor.

    Args:
        channels: size of the normalised axis ``D``.
        eps: numerical-stability term passed to ``nn.LayerNorm``.
        data_format: accepted for interface compatibility; not used.
    """

    def __init__(self, channels, eps=1e-6, data_format="channels_last"):
        super(LayerNorm, self).__init__()
        # The class is nn.LayerNorm (capital N); the previous spelling
        # ``nn.Layernorm`` raised AttributeError on construction.
        self.norm = nn.LayerNorm(channels, eps=eps)

    def forward(self, x):
        """Normalise ``x`` along ``D`` and return it in the input layout."""
        B, M, D, N = x.shape
        # Move D to the last axis, where nn.LayerNorm operates.
        x = x.permute(0, 1, 3, 2).reshape(B * M, N, D)
        x = self.norm(x)
        return x.reshape(B, M, N, D).permute(0, 1, 3, 2)
def get_conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias):
    """Create an ``nn.Conv1d`` from the given hyper-parameters."""
    conv_kwargs = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    return nn.Conv1d(**conv_kwargs)
def get_bn(channels):
    """Create a 1-D batch-norm layer for ``channels`` features."""
    batch_norm = nn.BatchNorm1d(channels)
    return batch_norm
def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1, bias=False):
    """Build ``Sequential(conv=Conv1d, bn=BatchNorm1d)``.

    When ``padding`` is None it defaults to ``kernel_size // 2``.
    """
    effective_padding = kernel_size // 2 if padding is None else padding

    conv = get_conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                      stride=stride, padding=effective_padding, dilation=dilation, groups=groups, bias=bias)

    block = nn.Sequential()
    block.add_module('conv', conv)
    block.add_module('bn', get_bn(out_channels))
    return block
def fuse_bn(conv, bn):
    """Fold batch-norm running statistics into a convolution.

    Returns ``(kernel, bias)`` such that a bias-free ``conv`` followed by
    ``bn`` in eval mode equals a single convolution with these tensors.
    """
    std = (bn.running_var + bn.eps).sqrt()
    per_channel_scale = (bn.weight / std).reshape(-1, 1, 1)
    fused_kernel = conv.weight * per_channel_scale
    fused_bias = bn.bias - bn.running_mean * bn.weight / std
    return fused_kernel, fused_bias
class ReparamLargeKernelConv(nn.Module):
    """Large-kernel 1-D convolution with an optional parallel small-kernel branch.

    In training form the module holds ``lkb_origin`` (conv + BN) and, when
    ``small_kernel`` is given, ``small_conv`` (conv + BN); their outputs are
    summed. ``merge_kernel`` folds both branches and their batch norms into a
    single biased convolution ``lkb_reparam``. With
    ``small_kernel_merged=True`` the module is created directly in merged form.

    ``nvars`` is accepted for interface compatibility and not used here.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, groups,
                 small_kernel,
                 small_kernel_merged=False, nvars=7):
        super(ReparamLargeKernelConv, self).__init__()
        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        # We assume the conv does not change the feature map size, so padding = k//2.
        # Otherwise, configure padding as needed and change the padding of small_conv accordingly.
        padding = kernel_size // 2
        if small_kernel_merged:
            self.lkb_reparam = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                         stride=stride, padding=padding, dilation=1, groups=groups, bias=True)
        else:
            self.lkb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                      stride=stride, padding=padding, dilation=1, groups=groups, bias=False)
            if small_kernel is not None:
                assert small_kernel <= kernel_size, 'The kernel size for re-param cannot be larger than the large kernel!'
                self.small_conv = conv_bn(in_channels=in_channels, out_channels=out_channels,
                                          kernel_size=small_kernel,
                                          stride=stride, padding=small_kernel // 2, groups=groups, dilation=1, bias=False)

    def forward(self, inputs):
        """Apply the merged convolution if present, else the sum of both branches."""
        if hasattr(self, 'lkb_reparam'):
            out = self.lkb_reparam(inputs)
        else:
            out = self.lkb_origin(inputs)
            if hasattr(self, 'small_conv'):
                out += self.small_conv(inputs)
        return out

    def PaddingTwoEdge1d(self, x, pad_length_left, pad_length_right, pad_values=0):
        """Pad the last axis of ``x`` on both sides with a constant.

        ``F.pad`` keeps the padding on ``x``'s device and dtype. The previous
        implementation built the pads on CPU and called ``torch.cat`` with an
        invalid ``dims=`` keyword, which raised TypeError.
        """
        return F.pad(x, (pad_length_left, pad_length_right), value=pad_values)

    def get_equivalent_kernel_bias(self):
        """Return the single ``(kernel, bias)`` equivalent to all unmerged branches."""
        eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, 'small_conv'):
            small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn)
            eq_b = eq_b + small_b
            # Offset that lines the small kernel up with the large one given
            # paddings of k // 2 and s // 2. Equals (k - s) // 2 on both sides
            # when k and s have the same parity.
            left = self.kernel_size // 2 - self.small_kernel // 2
            right = self.kernel_size - self.small_kernel - left
            eq_k = eq_k + self.PaddingTwoEdge1d(small_k, left, right, 0)
        return eq_k, eq_b

    def merge_kernel(self):
        """Replace the branches by one fused convolution (no-op if already merged)."""
        if not hasattr(self, 'lkb_origin'):
            # Already in merged form (built with small_kernel_merged=True or
            # merged earlier); nothing to fold.
            return
        eq_k, eq_b = self.get_equivalent_kernel_bias()
        origin = self.lkb_origin.conv
        self.lkb_reparam = nn.Conv1d(in_channels=origin.in_channels,
                                     out_channels=origin.out_channels,
                                     kernel_size=origin.kernel_size, stride=origin.stride,
                                     padding=origin.padding, dilation=origin.dilation,
                                     groups=origin.groups, bias=True)
        # Detach so the new parameters do not carry the fusion graph.
        self.lkb_reparam.weight.data = eq_k.detach()
        self.lkb_reparam.bias.data = eq_b.detach()
        self.__delattr__('lkb_origin')
        if hasattr(self, 'small_conv'):
            self.__delattr__('small_conv')
class Block(nn.Module):
    """ModernTCN block: depth-wise large-kernel conv, BatchNorm, two grouped 1x1 conv FFNs.

    The first FFN uses ``groups=nvars`` and the second ``groups=dmodel``; the
    input is added back as a residual at the end.
    """

    def __init__(self, large_size, small_size, dmodel, dff, nvars, small_kernel_merged=False, drop=0.1):
        super(Block, self).__init__()
        channels = nvars * dmodel
        hidden = nvars * dff

        self.dw = ReparamLargeKernelConv(in_channels=channels, out_channels=channels,
                                         kernel_size=large_size, stride=1, groups=channels,
                                         small_kernel=small_size, small_kernel_merged=small_kernel_merged, nvars=nvars)
        self.norm = nn.BatchNorm1d(dmodel)

        # convffn1: point-wise convs grouped by variable
        self.ffn1pw1 = nn.Conv1d(in_channels=channels, out_channels=hidden, kernel_size=1, stride=1,
                                 padding=0, dilation=1, groups=nvars)
        self.ffn1act = nn.GELU()
        self.ffn1pw2 = nn.Conv1d(in_channels=hidden, out_channels=channels, kernel_size=1, stride=1,
                                 padding=0, dilation=1, groups=nvars)
        self.ffn1drop1 = nn.Dropout(drop)
        self.ffn1drop2 = nn.Dropout(drop)

        # convffn2: point-wise convs grouped by feature
        self.ffn2pw1 = nn.Conv1d(in_channels=channels, out_channels=hidden, kernel_size=1, stride=1,
                                 padding=0, dilation=1, groups=dmodel)
        self.ffn2act = nn.GELU()
        self.ffn2pw2 = nn.Conv1d(in_channels=hidden, out_channels=channels, kernel_size=1, stride=1,
                                 padding=0, dilation=1, groups=dmodel)
        self.ffn2drop1 = nn.Dropout(drop)
        self.ffn2drop2 = nn.Dropout(drop)

        self.ffn_ratio = dff // dmodel

    def forward(self, x):
        """Process a ``[B, M, D, N]`` tensor and return one of the same shape."""
        residual = x
        B, M, D, N = x.shape

        # Depth-wise convolution over the merged (variable, feature) channels.
        h = self.dw(x.reshape(B, M * D, N))

        # Batch norm over the feature axis, one variable at a time.
        h = self.norm(h.reshape(B * M, D, N))

        # First FFN, channels ordered variable-major.
        h = h.reshape(B, M * D, N)
        h = self.ffn1drop1(self.ffn1pw1(h))
        h = self.ffn1act(h)
        h = self.ffn1drop2(self.ffn1pw2(h))

        # Second FFN, channels re-ordered feature-major.
        h = h.reshape(B, M, D, N).permute(0, 2, 1, 3).reshape(B, D * M, N)
        h = self.ffn2drop1(self.ffn2pw1(h))
        h = self.ffn2act(h)
        h = self.ffn2drop2(self.ffn2pw2(h))

        # Back to [B, M, D, N] and add the skip connection.
        h = h.reshape(B, D, M, N).permute(0, 2, 1, 3)
        return residual + h
class Stage(nn.Module):
    """A stack of ``num_blocks`` identical ``Block`` modules applied in sequence.

    The FFN width of each block is ``dmodel * ffn_ratio``. ``dw_model`` is
    accepted but not used.
    """

    def __init__(self, ffn_ratio, num_blocks, large_size, small_size, dmodel, dw_model, nvars,
                 small_kernel_merged=False, drop=0.1):
        super(Stage, self).__init__()
        d_ffn = dmodel * ffn_ratio
        self.blocks = nn.ModuleList([
            Block(large_size=large_size, small_size=small_size, dmodel=dmodel, dff=d_ffn, nvars=nvars,
                  small_kernel_merged=small_kernel_merged, drop=drop)
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        """Feed ``x`` through every block in order."""
        out = x
        for block in self.blocks:
            out = block(out)
        return out
class ModernTCNModel(nn.Module):
    """ModernTCN forecasting backbone: RevIN -> stem/downsample + stages -> flatten head.

    Changes relative to the previous version (all backward-compatible):
      * ``freq=None`` (the default) no longer crashes on ``freq.lower()``;
        anything other than ``'h'``/``'H'`` uses 5 time features.
      * The number of extra downsample layers is ``min(3, len(dims) - 1)``,
        so ``dims`` with fewer than four entries no longer raises IndexError.
        Configurations with four or more entries build exactly the same layers.
      * ``up_sample`` no longer passes both ``size`` and ``scale_factor``
        (which always raised) and uses ``F.interpolate``.
    """

    def __init__(self, patch_size, patch_stride, stem_ratio, downsample_ratio, ffn_ratio, num_blocks, large_size, small_size, dims, dw_dims,
                 nvars, small_kernel_merged=False, backbone_dropout=0.1, head_dropout=0.1, use_multi_scale=True, revin=True, affine=True,
                 subtract_last=False, freq=None, seq_len=512, c_in=7, individual=False, target_window=96):
        super(ModernTCNModel, self).__init__()

        # RevIN
        self.revin = revin
        if self.revin:
            self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)

        # stem layer & down sampling layers (if needed)
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv1d(1, dims[0], kernel_size=patch_size, stride=patch_stride),
            nn.BatchNorm1d(dims[0])
        )
        self.downsample_layers.append(stem)
        # Historically three extra layers were always built; cap by what
        # ``dims`` can actually support so short ``dims`` lists work too.
        for i in range(min(3, len(dims) - 1)):
            downsample_layer = nn.Sequential(
                nn.BatchNorm1d(dims[i]),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=downsample_ratio, stride=downsample_ratio),
            )
            self.downsample_layers.append(downsample_layer)

        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.downsample_ratio = downsample_ratio

        # Hourly data has 4 time features, everything else (incl. unspecified) 5.
        if freq is not None and freq.lower() == 'h':
            time_feature_num = 4
        else:
            time_feature_num = 5

        self.te_patch = nn.Sequential(
            nn.Conv1d(time_feature_num, time_feature_num, kernel_size=patch_size, stride=patch_stride, groups=time_feature_num),
            nn.Conv1d(time_feature_num, dims[0], kernel_size=1, stride=1, groups=1),
            nn.BatchNorm1d(dims[0]))

        # backbone
        self.num_stage = len(num_blocks)
        self.stages = nn.ModuleList()
        for stage_idx in range(self.num_stage):
            layer = Stage(ffn_ratio, num_blocks[stage_idx], large_size[stage_idx], small_size[stage_idx], dmodel=dims[stage_idx],
                          dw_model=dw_dims[stage_idx], nvars=nvars, small_kernel_merged=small_kernel_merged, drop=backbone_dropout)
            self.stages.append(layer)

        # Multi scale fusing (if needed)
        self.use_multi_scale = use_multi_scale
        self.up_sample_ratio = downsample_ratio

        self.lat_layer = nn.ModuleList()
        self.smooth_layer = nn.ModuleList()
        self.up_sample_conv = nn.ModuleList()
        for i in range(self.num_stage):
            align_dim = dims[-1]
            lat = nn.Conv1d(dims[i], align_dim, kernel_size=1,
                            stride=1)
            self.lat_layer.append(lat)
            smooth = nn.Conv1d(align_dim, align_dim, kernel_size=3, stride=1, padding=1)
            self.smooth_layer.append(smooth)

            up_conv = nn.Sequential(
                nn.ConvTranspose1d(align_dim, align_dim, kernel_size=self.up_sample_ratio, stride=self.up_sample_ratio),
                nn.BatchNorm1d(align_dim))
            self.up_sample_conv.append(up_conv)

        # head
        patch_num = seq_len // patch_stride

        self.n_vars = c_in
        self.individual = individual
        d_model = dims[-1]

        if use_multi_scale:
            self.head_nf = d_model * patch_num
        else:
            # Token count after (num_stage - 1) downsamplings, rounded up.
            total_ratio = pow(downsample_ratio, (self.num_stage - 1))
            if patch_num % total_ratio == 0:
                self.head_nf = d_model * patch_num // total_ratio
            else:
                self.head_nf = d_model * (patch_num // total_ratio + 1)
        self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window,
                                 head_dropout=head_dropout)

    def up_sample(self, x, upsample_ratio):
        """Bilinearly upsample the last axis of a 4-D tensor by ``upsample_ratio``.

        NOTE(review): this helper is not called anywhere in this class. The
        previous body always raised because ``size`` and ``scale_factor`` were
        both given; scaling only the last axis is presumably what was
        intended — confirm before relying on it.
        """
        return F.interpolate(x, scale_factor=(1, upsample_ratio), mode='bilinear', align_corners=False)

    def forward_feature(self, x, te=None):
        """Run stem, downsampling layers and stages on a ``[B, M, L]`` input.

        ``te`` is accepted but not used.
        """
        B, M, _ = x.shape
        x = x.unsqueeze(-2)

        for i in range(self.num_stage):
            B, M, D, N = x.shape
            x = x.reshape(B * M, D, N)
            if i == 0:
                if self.patch_size != self.patch_stride:
                    # stem layer padding: repeat the last time step
                    pad_len = self.patch_size - self.patch_stride
                    pad = x[:, :, -1:].repeat(1, 1, pad_len)
                    x = torch.cat([x, pad], dim=-1)
            else:
                if N % self.downsample_ratio != 0:
                    # pad with the trailing tokens so N divides the ratio
                    pad_len = self.downsample_ratio - (N % self.downsample_ratio)
                    x = torch.cat([x, x[:, :, -pad_len:]], dim=-1)
            x = self.downsample_layers[i](x)
            _, D_, N_ = x.shape
            x = x.reshape(B, M, D_, N_)
            x = self.stages[i](x)
        return x

    def forward(self, x, te=None):
        """Forecast from ``x``; RevIN normalisation/denormalisation is applied if enabled."""
        # instance norm
        if self.revin:
            x = x.permute(0, 2, 1)
            x = self.revin_layer(x, 'norm')
            x = x.permute(0, 2, 1)
        x = self.forward_feature(x, te)
        x = self.head(x)
        # de-instance norm
        if self.revin:
            x = x.permute(0, 2, 1)
            x = self.revin_layer(x, 'denorm')
            x = x.permute(0, 2, 1)
        return x

    def structural_reparam(self):
        """Call ``merge_kernel`` on every sub-module that defines it."""
        for m in self.modules():
            if hasattr(m, 'merge_kernel'):
                m.merge_kernel()
================================================
FILE: probts/model/nn/arch/Moirai_backbone.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from uni2ts
# - Source: https://github.com/SalesforceAIResearch/uni2ts
# - Paper: Unified Training of Universal Time Series Forecasting Transformers
# - License: Apache License 2.0
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import math
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Generator, Optional
import sys
import lightning as L
import torch
from einops import rearrange, reduce, repeat
from jaxtyping import Bool, Float, Int
from torch.distributions import Distribution
from uni2ts.common.torch_util import safe_div
from uni2ts.loss.packed import PackedNLLLoss as _PackedNLLLoss
from uni2ts.model.moirai.module import MoiraiModule
from uni2ts.module.packed_scaler import PackedNOPScaler, PackedStdScaler
class SampleNLLLoss(_PackedNLLLoss):
    """Packed NLL loss reduced to one value per batch element."""

    def reduce_loss(
        self,
        loss: Float[torch.Tensor, "batch seq_len #dim"],
        prediction_mask: Optional[Bool[torch.Tensor, "batch seq_len"]],
        observed_mask: Optional[Bool[torch.Tensor, "batch seq_len #dim"]],
        sample_id: Optional[Int[torch.Tensor, "batch seq_len"]],
        variate_id: Optional[Int[torch.Tensor, "batch seq_len"]],
    ) -> Float[torch.Tensor, "batch"]:
        """Normalise the loss by per-(sample, variate) observation counts and sum.

        For every token, ``tobs`` is the number of masked-in values over all
        tokens sharing its sample id and variate id. The loss is divided by
        that count with ``safe_div``, masked, and summed over the last two axes.
        """
        # Pairwise "same sample and same variate" indicator between tokens.
        same_sample = torch.eq(sample_id.unsqueeze(-1), sample_id.unsqueeze(-2))
        same_variate = torch.eq(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2))
        id_mask = torch.logical_and(same_sample, same_variate)

        mask = prediction_mask.unsqueeze(-1) * observed_mask

        # Count of masked-in values per token, laid out as a row vector.
        per_token_count = reduce(
            mask,
            "... seq dim -> ... 1 seq",
            "sum",
        )
        tobs = reduce(
            id_mask * per_token_count,
            "... seq1 seq2 -> ... seq1 1",
            "sum",
        )

        normalised = safe_div(loss, tobs)
        return (normalised * mask).sum(dim=(-1, -2))
class MoiraiBackbone(L.LightningModule):
def __init__(
    self,
    prediction_length: int,
    target_dim: int,
    context_length: int,
    module_kwargs: Optional[dict[str, Any]] = None,
    module: Optional[MoiraiModule] = None,
    patch_size: int | str = "auto",
    num_samples: int = 100,
    scaling: bool = True,
):
    """Wrap a ``MoiraiModule`` for forecasting.

    Either an existing ``module`` or the ``module_kwargs`` needed to build one
    must be supplied. All arguments except ``module`` are stored through
    ``save_hyperparameters``. The module's scaler is a ``PackedStdScaler``
    when ``scaling`` is True and a ``PackedNOPScaler`` otherwise.
    """
    assert (module is not None) or (
        module_kwargs is not None
    ), "if module is not provided, module_kwargs is required"
    super().__init__()
    self.save_hyperparameters(ignore=["module"])

    if module is None:
        self.module = MoiraiModule(**module_kwargs)
    else:
        self.module = module

    self.module.scaling = scaling
    if scaling:
        self.module.scaler = PackedStdScaler()
    else:
        self.module.scaler = PackedNOPScaler()

    self.per_sample_loss_func = SampleNLLLoss()
@contextmanager
def hparams_context(
    self,
    prediction_length: Optional[int] = None,
    target_dim: Optional[int] = None,
    context_length: Optional[int] = None,
    patch_size: Optional[int | str] = None,
    num_samples: Optional[int] = None,
) -> Generator["MoiraiForecast", None, None]:
    """Temporarily override selected hyper-parameters.

    Every argument that is not None replaces the corresponding entry of
    ``self.hparams`` for the duration of the ``with`` block. The previous
    values are restored on exit, including when the block raises.

    Yields:
        ``self`` with the overrides applied.
    """
    kwargs = {
        "prediction_length": prediction_length,
        "target_dim": target_dim,
        "context_length": context_length,
        "patch_size": patch_size,
        "num_samples": num_samples,
    }
    old_hparams = deepcopy(self.hparams)
    for kw, arg in kwargs.items():
        if arg is not None:
            self.hparams[kw] = arg

    try:
        yield self
    finally:
        # Restore even if the body raised, so a failed call cannot leave the
        # model with overridden hyper-parameters.
        for kw in kwargs:
            self.hparams[kw] = old_hparams[kw]
@property
def past_length(self) -> int:
    """Number of past time steps consumed by ``forward``.

    This is the context length, plus the prediction length when
    ``patch_size`` is ``"auto"``.
    """
    length = self.hparams.context_length
    if self.hparams.patch_size == "auto":
        length = length + self.hparams.prediction_length
    return length
def context_token_length(self, patch_size: int) -> int:
    """Number of patches needed to cover the context window (rounded up)."""
    n_tokens = math.ceil(self.hparams.context_length / patch_size)
    return n_tokens
def prediction_token_length(self, patch_size) -> int:
    """Number of patches needed to cover the prediction window (rounded up)."""
    n_tokens = math.ceil(self.hparams.prediction_length / patch_size)
    return n_tokens
@property
def max_patch_size(self) -> int:
    """Largest patch size offered by the wrapped module."""
    available = self.module.patch_sizes
    return max(available)
def forward(
    self,
    past_target: Float[torch.Tensor, "batch past_time tgt"],
    past_observed_target: Bool[torch.Tensor, "batch past_time tgt"],
    past_is_pad: Bool[torch.Tensor, "batch past_time"],
    num_samples: Optional[int] = None,
) -> Float[torch.Tensor, "batch sample future_time *tgt"]:
    """Draw forecast samples.

    With a fixed ``patch_size`` the distribution is computed once on the last
    ``context_length`` steps and sampled. With ``patch_size == "auto"`` every
    patch size of the module is tried: a validation loss is computed on the
    first ``past_length`` steps, samples are drawn from the last
    ``context_length`` steps, and for each batch element the samples of the
    patch size with the lowest validation loss are returned.
    """
    ctx_len = self.hparams.context_length
    n_draws = torch.Size((num_samples or self.hparams.num_samples,))
    n_targets = past_target.shape[-1]

    if self.hparams.patch_size != "auto":
        distr = self._get_distr(
            self.hparams.patch_size,
            past_target[..., -ctx_len:, :],
            past_observed_target[..., -ctx_len:, :],
            past_is_pad[..., -ctx_len:],
        )
        preds = distr.sample(n_draws)
        return self._format_preds(self.hparams.patch_size, preds, n_targets)

    losses = []
    candidates = []
    for patch_size in self.module.patch_sizes:
        # Score this patch size on the leading window.
        losses.append(
            self._val_loss(
                patch_size=patch_size,
                target=past_target[..., : self.past_length, :],
                observed_target=past_observed_target[..., : self.past_length, :],
                is_pad=past_is_pad[..., : self.past_length],
            )
        )
        # Forecast from the trailing context window.
        distr = self._get_distr(
            patch_size,
            past_target[..., -ctx_len:, :],
            past_observed_target[..., -ctx_len:, :],
            past_is_pad[..., -ctx_len:],
        )
        candidates.append(
            self._format_preds(patch_size, distr.sample(n_draws), n_targets)
        )

    losses = torch.stack(losses)
    candidates = torch.stack(candidates)
    # Pick, per batch element, the candidate with the smallest loss.
    best = losses.argmin(dim=0)
    return candidates[best, torch.arange(len(best), device=best.device)]
def _val_loss(
    self,
    patch_size: int,
    target: Float[torch.Tensor, "batch time tgt"],
    observed_target: Bool[torch.Tensor, "batch time tgt"],
    is_pad: Bool[torch.Tensor, "batch time"]
) -> Float[torch.Tensor, "batch"]:
    """Per-sample loss of the module for one patch size.

    The first ``context_length`` steps of the inputs are used as history and
    the remainder as the future; the result of ``per_sample_loss_func`` on
    the module's predicted distribution is returned.
    """
    split = self.hparams.context_length

    # convert format
    (
        packed_target,
        packed_observed,
        packed_sample_id,
        packed_time_id,
        packed_variate_id,
        packed_prediction_mask,
    ) = self._convert(
        patch_size,
        past_target=target[..., :split, :],
        past_observed_target=observed_target[..., :split, :],
        past_is_pad=is_pad[..., :split],
        future_target=target[..., split:, :],
        future_observed_target=observed_target[..., split:, :],
        future_is_pad=is_pad[..., split:]
    )

    # get predictions
    patch_ids = torch.ones_like(packed_time_id, dtype=torch.long) * patch_size
    distr = self.module(
        packed_target,
        packed_observed,
        packed_sample_id,
        packed_time_id,
        packed_variate_id,
        packed_prediction_mask,
        patch_ids,
    )

    return self.per_sample_loss_func(
        pred=distr,
        target=packed_target,
        prediction_mask=packed_prediction_mask,
        observed_mask=packed_observed,
        sample_id=packed_sample_id,
        variate_id=packed_variate_id,
    )
def _get_distr(
    self,
    patch_size: int,
    past_target: Float[torch.Tensor, "batch past_time tgt"],
    past_observed_target: Bool[torch.Tensor, "batch past_time tgt"],
    past_is_pad: Bool[torch.Tensor, "batch past_time"]
) -> Distribution:
    """Return the module's predictive distribution for the given history.

    The inputs are packed with ``_convert`` (no future values supplied) and
    passed to the module together with a constant patch-size tensor.
    """
    # convert format
    (
        packed_target,
        packed_observed,
        packed_sample_id,
        packed_time_id,
        packed_variate_id,
        packed_prediction_mask,
    ) = self._convert(
        patch_size,
        past_target,
        past_observed_target,
        past_is_pad
    )

    # get predictions
    patch_ids = torch.ones_like(packed_time_id, dtype=torch.long) * patch_size
    return self.module(
        packed_target,
        packed_observed,
        packed_sample_id,
        packed_time_id,
        packed_variate_id,
        packed_prediction_mask,
        patch_ids,
    )
@staticmethod
def _patched_seq_pad(
    patch_size: int,
    x: torch.Tensor,
    dim: int,
    left: bool = True,
    value: Optional[float] = None,
) -> torch.Tensor:
    """Pad axis ``dim`` of ``x`` up to the next multiple of ``patch_size``.

    Padding goes in front when ``left`` is True, otherwise at the end, and is
    filled with ``value`` (forwarded to ``torch.nn.functional.pad``).
    """
    # Work with a negative axis index: F.pad counts axes from the last one.
    neg_dim = dim - x.ndim if dim >= 0 else dim
    pad_length = -x.size(neg_dim) % patch_size

    target_axis_pad = (pad_length, 0) if left else (0, pad_length)
    # Leave every axis after the target one unpadded.
    trailing_axes_pad = (0, 0) * (abs(neg_dim) - 1)

    return torch.nn.functional.pad(x, trailing_axes_pad + target_axis_pad, value=value)
def _generate_time_id(
    self,
    patch_size: int,
    past_observed_target: Bool[torch.Tensor, "batch past_seq tgt"],
    future_target: Float[torch.Tensor, "batch future_seq tgt"],
) -> tuple[
    Int[torch.Tensor, "batch past_token"], Int[torch.Tensor, "batch future_token"]
]:
    """Build int32 time indices for past and future patches.

    A past patch advances the index when any value in it is observed (the
    running count minus one, clamped at 0). Future patches continue counting
    from the largest past index plus one.
    """
    # One flag per past patch: max over the patch and over the variates.
    padded_observed = self._patched_seq_pad(patch_size, past_observed_target, -2, left=True)
    patch_observed = reduce(
        padded_observed,
        "... (seq patch) dim -> ... seq",
        "max",
        patch=patch_size,
    )
    past_seq_id = torch.clamp(patch_observed.cumsum(dim=-1) - 1, min=0)

    # Leading batch sizes spelled out for the einops pattern below.
    batch_shape = " ".join(map(str, past_observed_target.shape[:-2]))
    n_future_tokens = math.ceil(future_target.shape[-2] / patch_size)
    future_offsets = repeat(
        torch.arange(
            n_future_tokens,
            device=past_observed_target.device,
        ),
        f"prediction -> {batch_shape} prediction",
    )
    last_past_id = past_seq_id.max(dim=-1, keepdim=True).values
    future_seq_id = future_offsets + last_past_id + 1

    return past_seq_id.to(dtype=torch.int32), future_seq_id.to(dtype=torch.int32)
def _convert(
    self,
    patch_size: int,
    past_target: Float[torch.Tensor, "batch past_time tgt"],
    past_observed_target: Bool[torch.Tensor, "batch past_time tgt"],
    past_is_pad: Bool[torch.Tensor, "batch past_time"],
    future_target: Optional[Float[torch.Tensor, "batch future_time tgt"]] = None,
    future_observed_target: Optional[
        Bool[torch.Tensor, "batch future_time tgt"]
    ] = None,
    future_is_pad: Optional[Bool[torch.Tensor, "batch future_time"]] = None
) -> tuple[
    Float[torch.Tensor, "batch combine_seq patch"],  # target
    Bool[torch.Tensor, "batch combine_seq patch"],  # observed_mask
    Int[torch.Tensor, "batch combine_seq"],  # sample_id
    Int[torch.Tensor, "batch combine_seq"],  # time_id
    Int[torch.Tensor, "batch combine_seq"],  # variate_id
    Bool[torch.Tensor, "batch combine_seq"],  # prediction_mask
]:
    """Pack context and forecast windows into one flat token sequence.

    Each window is padded to a multiple of ``patch_size`` (context on the
    left, forecast on the right), cut into patches and laid out variate by
    variate (``(dim seq)`` order).  All context tokens come first, followed
    by all forecast tokens.  The patch axis is zero-padded on the right up to
    ``self.max_patch_size``.

    Missing future tensors are synthesised:
      * ``future_target``          -> zeros of length ``hparams.prediction_length``
      * ``future_observed_target`` -> all ``True``
      * ``future_is_pad``          -> all zeros
    The two default masks take their length from ``future_target`` so that
    they always yield the same number of forecast tokens as the target,
    ``variate_id`` and ``prediction_mask`` (which are all sized from
    ``future_target.shape[-2]``).

    Returns:
        ``(target, observed_mask, sample_id, time_id, variate_id,
        prediction_mask)``; ``prediction_mask`` is False on context tokens and
        True on forecast tokens.
    """
    batch_shape = past_target.shape[:-2]
    device = past_target.device
    num_variates = past_target.shape[-1]

    if future_target is None:
        future_target = torch.zeros(
            batch_shape + (self.hparams.prediction_length, num_variates),
            dtype=past_target.dtype,
            device=device,
        )
    # Single source of truth for the forecast length from here on.
    future_length = future_target.shape[-2]
    if future_observed_target is None:
        future_observed_target = torch.ones(
            batch_shape + (future_length, past_observed_target.shape[-1]),
            dtype=torch.bool,
            device=device,
        )
    if future_is_pad is None:
        future_is_pad = torch.zeros(
            batch_shape + (future_length,),
            dtype=torch.long,
            device=device,
        )

    past_seq_id, future_seq_id = self._generate_time_id(
        patch_size, past_observed_target, future_target
    )

    def _patchify(series: torch.Tensor, left: bool) -> torch.Tensor:
        # (time, dim) series -> (dim * tokens, max_patch_size) patches.
        patches = rearrange(
            self._patched_seq_pad(patch_size, series, -2, left=left),
            "... (seq patch) dim -> ... (dim seq) patch",
            patch=patch_size,
        )
        return torch.nn.functional.pad(
            patches, (0, self.max_patch_size - patch_size)
        )

    def _token_sample_id(is_pad: torch.Tensor, left: bool) -> torch.Tensor:
        # 1 for a patch holding at least one non-pad step, else 0; the
        # per-patch flags are then repeated once per variate.
        not_pad = (
            self._patched_seq_pad(patch_size, is_pad, -1, left=left, value=1) == 0
        ).int()
        per_patch = reduce(
            not_pad, "... (seq patch) -> ... seq", "max", patch=patch_size
        )
        return repeat(per_patch, "... seq -> ... (dim seq)", dim=num_variates)

    target = torch.cat(
        [_patchify(past_target, True), _patchify(future_target, False)],
        dim=-2,
    )
    observed_mask = torch.cat(
        [
            _patchify(past_observed_target, True),
            _patchify(future_observed_target, False),
        ],
        dim=-2,
    )
    sample_id = torch.cat(
        [_token_sample_id(past_is_pad, True), _token_sample_id(future_is_pad, False)],
        dim=-1,
    )
    # The same time ids are reused for every variate.
    time_id = torch.cat(
        [past_seq_id] * num_variates + [future_seq_id] * num_variates,
        dim=-1,
    )

    past_tokens = self.context_token_length(patch_size)
    future_tokens = math.ceil(future_length / patch_size)
    batch_axes = " ".join(map(str, batch_shape))
    variate_index = torch.arange(num_variates, device=device)
    variate_id = torch.cat(
        [
            repeat(
                variate_index,
                f"dim -> {batch_axes} (dim past)",
                past=past_tokens,
            ),
            repeat(
                variate_index,
                f"dim -> {batch_axes} (dim future)",
                future=future_tokens,
            ),
        ],
        dim=-1,
    )

    prediction_mask = torch.cat(
        [
            torch.zeros(
                batch_shape + (past_tokens * num_variates,),
                dtype=torch.bool,
                device=device,
            ),
            torch.ones(
                batch_shape + (future_tokens * num_variates,),
                dtype=torch.bool,
                device=device,
            ),
        ],
        dim=-1,
    )

    return (
        target,
        observed_mask,
        sample_id,
        time_id,
        variate_id,
        prediction_mask,
    )
def _format_preds(
    self,
    patch_size: int,
    preds: Float[torch.Tensor, "sample batch combine_seq patch"],
    target_dim: int,
) -> Float[torch.Tensor, "batch sample future_time *tgt"]:
    """Cut the forecast tokens out of ``preds`` and restore a time axis.

    The forecast tokens are taken to start right after the
    ``target_dim * context_token_length`` context tokens.  Only the first
    ``patch_size`` entries of each patch are kept, patches are concatenated
    back along time, the result is truncated to
    ``hparams.prediction_length`` steps, and a trailing variate axis of
    size 1 is squeezed away.
    """
    first_token = target_dim * self.context_token_length(patch_size)
    last_token = first_token + target_dim * self.prediction_token_length(patch_size)

    forecast = preds[..., first_token:last_token, :patch_size]
    forecast = rearrange(
        forecast,
        "sample ... (dim seq) patch -> ... sample (seq patch) dim",
        dim=target_dim,
    )
    forecast = forecast[..., : self.hparams.prediction_length, :]
    return forecast.squeeze(-1)
================================================
FILE: probts/model/nn/arch/PatchTSTModule/PatchTST_backbone.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PatchTST
# - Source: https://github.com/yuqinie98/PatchTST/tree/main
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
__all__ = ['PatchTST_backbone']
# Cell
from typing import Callable, Optional
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
import numpy as np
#from collections import OrderedDict
from probts.model.nn.arch.PatchTSTModule.PatchTST_layers import *
from probts.model.nn.arch.RevIN import RevIN
# Cell
class PatchTST_backbone(nn.Module):
    """PatchTST backbone: optional RevIN -> patching -> TSTiEncoder -> head.

    Input to ``forward`` is ``[bs x nvars x seq_len]`` and the output is
    ``[bs x nvars x target_window]`` (per the shape comments in ``forward``).

    Notes:
        * ``norm`` is forwarded to the encoder; previously it was accepted but
          dropped, so the encoder always used its own default.
        * A head is only created when ``pretrain_head`` is true or
          ``head_type == 'flatten'``; any other combination leaves
          ``self.head`` undefined.
    """

    def __init__(self, c_in:int, context_window:int, target_window:int, patch_len:int, stride:int, max_seq_len:Optional[int]=1024,
                 n_layers:int=3, d_model=128, n_heads=16, d_k:Optional[int]=None, d_v:Optional[int]=None,
                 d_ff:int=256, norm:str='BatchNorm', attn_dropout:float=0., dropout:float=0., act:str="gelu", key_padding_mask:bool='auto',
                 padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,
                 pe:str='zeros', learn_pe:bool=True, fc_dropout:float=0., head_dropout = 0, padding_patch = None,
                 pretrain_head:bool=False, head_type = 'flatten', individual = False, revin = True, affine = True, subtract_last = False,
                 verbose:bool=False):
        super().__init__()

        # RevIN (reversible instance normalisation), applied per variable
        self.revin = revin
        if self.revin:
            self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)

        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        patch_num = int((context_window - patch_len)/stride + 1)
        if padding_patch == 'end':  # can be modified to general case
            # Replicating `stride` values at the end yields one extra patch.
            self.padding_patch_layer = nn.ReplicationPad1d((0, stride))
            patch_num += 1

        # Backbone (norm is now passed through instead of being ignored)
        self.backbone = TSTiEncoder(c_in, patch_num=patch_num, patch_len=patch_len, max_seq_len=max_seq_len,
                                    n_layers=n_layers, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff,
                                    norm=norm, attn_dropout=attn_dropout, dropout=dropout, act=act,
                                    key_padding_mask=key_padding_mask, padding_var=padding_var,
                                    attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
                                    pe=pe, learn_pe=learn_pe, verbose=verbose)

        # Head
        self.head_nf = d_model * patch_num
        self.n_vars = c_in
        self.pretrain_head = pretrain_head
        self.head_type = head_type
        self.individual = individual

        if self.pretrain_head:
            self.head = self.create_pretrain_head(self.head_nf, c_in, fc_dropout)
        elif head_type == 'flatten':
            self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window, head_dropout=head_dropout)

    def forward(self, z):                                       # z: [bs x nvars x seq_len]
        # norm (RevIN works on [bs x seq_len x nvars], hence the permutes)
        if self.revin:
            z = z.permute(0,2,1)
            z = self.revin_layer(z, 'norm')
            z = z.permute(0,2,1)

        # do patching
        if self.padding_patch == 'end':
            z = self.padding_patch_layer(z)
        z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride)   # z: [bs x nvars x patch_num x patch_len]
        z = z.permute(0,1,3,2)                                              # z: [bs x nvars x patch_len x patch_num]

        # model
        z = self.backbone(z)                                    # z: [bs x nvars x d_model x patch_num]
        z = self.head(z)                                        # z: [bs x nvars x target_window]

        # denorm (uses the statistics stored by the 'norm' call above)
        if self.revin:
            z = z.permute(0,2,1)
            z = self.revin_layer(z, 'denorm')
            z = z.permute(0,2,1)
        return z

    def create_pretrain_head(self, head_nf, vars, dropout):
        """Return a dropout followed by a kernel-size-1 Conv1d (head_nf -> vars)."""
        return nn.Sequential(nn.Dropout(dropout),
                             nn.Conv1d(head_nf, vars, 1)
                             )
class Flatten_Head(nn.Module):
    """Flatten the last two axes and project them to ``target_window`` values.

    With ``individual`` set, every variable owns its own flatten/linear/dropout
    triple; otherwise one shared triple is applied to all variables at once.
    """

    def __init__(self, individual, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.individual = individual
        self.n_vars = n_vars

        if not self.individual:
            # One projection shared by every variable.
            self.flatten = nn.Flatten(start_dim=-2)
            self.linear = nn.Linear(nf, target_window)
            self.dropout = nn.Dropout(head_dropout)
            return

        # One projection per variable, stored in parallel module lists.
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.flattens = nn.ModuleList()
        for _ in range(self.n_vars):
            self.flattens.append(nn.Flatten(start_dim=-2))
            self.linears.append(nn.Linear(nf, target_window))
            self.dropouts.append(nn.Dropout(head_dropout))

    def forward(self, x):                                 # x: [bs x nvars x d_model x patch_num]
        if not self.individual:
            return self.dropout(self.linear(self.flatten(x)))

        per_variable = []
        for idx in range(self.n_vars):
            flat = self.flattens[idx](x[:, idx, :, :])    # [bs x d_model * patch_num]
            projected = self.linears[idx](flat)           # [bs x target_window]
            per_variable.append(self.dropouts[idx](projected))
        return torch.stack(per_variable, dim=1)           # [bs x nvars x target_window]
class TSTiEncoder(nn.Module):  # "i" stands for channel-independent
    """Embed patches, add a positional term and run the transformer encoder.

    Variables are folded into the batch axis before the encoder is applied,
    so every variable is processed as its own sequence of patches.
    """

    def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024,
                 n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None,
                 d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0., act="gelu", store_attn=False,
                 key_padding_mask='auto', padding_var=None, attn_mask=None, res_attention=True, pre_norm=False,
                 pe='zeros', learn_pe=True, verbose=False):
        super().__init__()

        self.patch_num = patch_num
        self.patch_len = patch_len

        # The encoder sequence length is the number of patches.
        q_len = patch_num
        # Patch embedding: patch_len -> d_model
        self.W_P = nn.Linear(patch_len, d_model)
        self.seq_len = q_len

        # Positional term added to the embedded patches
        self.W_pos = positional_encoding(pe, learn_pe, q_len, d_model)

        # Dropout applied after adding the positional term
        self.dropout = nn.Dropout(dropout)

        self.encoder = TSTEncoder(q_len, d_model, n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm,
                                  attn_dropout=attn_dropout, dropout=dropout,
                                  pre_norm=pre_norm, activation=act, res_attention=res_attention,
                                  n_layers=n_layers, store_attn=store_attn)

    def forward(self, x) -> Tensor:                      # x: [bs x nvars x patch_len x patch_num]
        n_vars = x.shape[1]

        embedded = self.W_P(x.permute(0, 1, 3, 2))       # [bs x nvars x patch_num x d_model]

        # Merge batch and variable axes so each variable is its own sequence.
        merged_shape = (embedded.shape[0] * embedded.shape[1], embedded.shape[2], embedded.shape[3])
        u = torch.reshape(embedded, merged_shape)        # [bs * nvars x patch_num x d_model]
        u = self.dropout(u + self.W_pos)

        z = self.encoder(u)                              # [bs * nvars x patch_num x d_model]
        z = torch.reshape(z, (-1, n_vars, z.shape[-2], z.shape[-1]))   # [bs x nvars x patch_num x d_model]
        return z.permute(0, 1, 3, 2)                     # [bs x nvars x d_model x patch_num]
# Cell
class TSTEncoder(nn.Module):
    """Stack of ``n_layers`` ``TSTEncoderLayer`` modules applied in sequence.

    With ``res_attention`` the attention scores returned by one layer are fed
    to the next layer as ``prev``; only the final hidden states are returned.
    """

    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None,
                 norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                 res_attention=False, n_layers=1, pre_norm=False, store_attn=False):
        super().__init__()

        stack = [
            TSTEncoderLayer(q_len, d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm,
                            attn_dropout=attn_dropout, dropout=dropout,
                            activation=activation, res_attention=res_attention,
                            pre_norm=pre_norm, store_attn=store_attn)
            for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(stack)
        self.res_attention = res_attention

    def forward(self, src:Tensor, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        hidden = src
        if not self.res_attention:
            for layer in self.layers:
                hidden = layer(hidden, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
            return hidden

        # Residual attention: thread each layer's scores into the next one.
        carried_scores = None
        for layer in self.layers:
            hidden, carried_scores = layer(hidden, prev=carried_scores,
                                           key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return hidden
class TSTEncoderLayer(nn.Module):
    """One encoder layer: self-attention sublayer followed by a feed-forward sublayer.

    Each sublayer adds a dropout-regularised residual.  Normalisation is
    applied before the sublayer when ``pre_norm`` is set and after the
    residual otherwise.  ``norm`` selects BatchNorm1d (any value containing
    "batch", case-insensitively) or LayerNorm.
    """

    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=256, store_attn=False,
                 norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation="gelu", res_attention=False, pre_norm=False):
        super().__init__()
        assert not d_model%n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        if d_k is None:
            d_k = d_model // n_heads
        if d_v is None:
            d_v = d_model // n_heads

        # Attention sublayer
        self.res_attention = res_attention
        self.self_attn = _MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout,
                                             proj_dropout=dropout, res_attention=res_attention)
        self.dropout_attn = nn.Dropout(dropout)
        self.norm_attn = self._build_norm(norm, d_model)

        # Feed-forward sublayer
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias),
                                get_activation_fn(activation),
                                nn.Dropout(dropout),
                                nn.Linear(d_ff, d_model, bias=bias))
        self.dropout_ffn = nn.Dropout(dropout)
        self.norm_ffn = self._build_norm(norm, d_model)

        self.pre_norm = pre_norm
        self.store_attn = store_attn

    @staticmethod
    def _build_norm(norm, d_model):
        """Return the normalisation module selected by ``norm``."""
        if "batch" in norm.lower():
            # BatchNorm1d normalises axis 1, so d_model is moved there and back.
            return nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        return nn.LayerNorm(d_model)

    def forward(self, src:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None) -> Tensor:
        # ---- attention sublayer ----
        if self.pre_norm:
            src = self.norm_attn(src)
        if self.res_attention:
            attended, attn, scores = self.self_attn(src, src, src, prev,
                                                    key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        else:
            attended, attn = self.self_attn(src, src, src,
                                            key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        if self.store_attn:
            self.attn = attn
        src = src + self.dropout_attn(attended)           # residual
        # Post-norm closes the attention sublayer; pre-norm opens the FF sublayer.
        src = self.norm_ffn(src) if self.pre_norm else self.norm_attn(src)

        # ---- feed-forward sublayer ----
        src = src + self.dropout_ffn(self.ff(src))        # residual
        if not self.pre_norm:
            src = self.norm_ffn(src)

        if self.res_attention:
            return src, scores
        return src
class _MultiheadAttention(nn.Module):
    """Multi-head attention layer.

    Input shape:
        Q:       [batch_size (bs) x max_q_len x d_model]
        K, V:    [batch_size (bs) x q_len x d_model]
        mask:    [q_len x q_len]

    ``forward`` returns ``(output, attn_weights)`` and, when ``res_attention``
    is enabled, additionally the pre-softmax ``attn_scores``.
    """

    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True):
        super().__init__()
        if d_k is None:
            d_k = d_model // n_heads
        if d_v is None:
            d_v = d_model // n_heads
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v

        # Per-head query/key/value projections, packed into single linears
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        # Scaled dot-product attention over all heads at once
        self.res_attention = res_attention
        self.sdp_attn = _ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout,
                                                   res_attention=self.res_attention)

        # Output projection back to d_model
        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))

    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,
                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        bs = Q.size(0)
        # Self-attention by default: keys/values fall back to the queries.
        K = Q if K is None else K
        V = Q if V is None else V

        # Project and split into heads
        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)      # [bs x n_heads x max_q_len x d_k]
        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0, 2, 3, 1)  # [bs x n_heads x d_k x q_len]
        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1, 2)      # [bs x n_heads x q_len x d_v]

        if self.res_attention:
            attended = self.sdp_attn(q_s, k_s, v_s, prev=prev,
                                     key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        else:
            attended = self.sdp_attn(q_s, k_s, v_s,
                                     key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        # attended = (output, attn_weights[, attn_scores])
        output, extras = attended[0], attended[1:]

        # Merge the heads again and project
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)   # [bs x q_len x n_heads * d_v]
        output = self.to_out(output)

        return (output, *extras)
class _ScaledDotProductAttention(nn.Module):
r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer
(Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets
by Lee et al, 2021)"""
def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False):
super().__init__()
self.attn_dropout = nn.Dropout(attn_dropout)
self.res_attention = res_attention
head_dim = d_model // n_heads
self.scale = torch.tensor(head_dim ** -0.5)
def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
'''
Input shape:
q : [bs x n_heads x max_q_len x d_k]
k : [bs x n_heads x d_k x seq_len]
v : [bs x n_heads x seq_len x d_v]
prev : [bs x n_heads x q_len x seq_len]
key_padding_mask: [bs x seq_len]
attn_mask : [1 x seq_len x seq_len]
Output shape:
output: [bs x n_heads x q_len x d_v]
attn : [bs x n_heads x q_len x seq_len]
scores : [bs x n_heads x q_len x seq_len]
'''
# Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
attn_scores = torch.matmul(q, k) * self.scale # attn_scores : [bs x n_heads x max_q_len x q_len]
# Add pre-softmax attention scores from the previous layer (optional)
if prev is not None: attn_scores = attn_scores + prev
# Attention mask (optional)
if attn_mask is not None: # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len
if attn_mask.dtype == torch.bool:
attn_scores.masked_fill_(attn_mask, -np.inf)
else:
attn_scores += attn_mask
# Key padding mask (optional)
if key_padding_mask is not None: # mask with shape [bs x q_len] (only when max_w_len == q_len)
attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)
# normalize the attention weights
attn_weights = F.softmax(attn_scores, dim=-1) # attn_weights : [bs x n_heads x max_q_len x q_len]
attn_weights = self.attn_dropout(attn_weights)
# compute the new values given the attention weights
output = torch.matmul(attn_weights, v) # output: [bs x n_heads x max_q_len x d_v]
if self.res_attention: return output, attn_weights, attn_scores
else: return output, attn_weights
================================================
FILE: probts/model/nn/arch/PatchTSTModule/PatchTST_layers.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PatchTST
# - Source: https://github.com/yuqinie98/PatchTST/tree/main
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
__all__ = ['Transpose', 'get_activation_fn', 'moving_avg', 'series_decomp', 'PositionalEncoding', 'SinCosPosEncoding', 'Coord2dPosEncoding', 'Coord1dPosEncoding', 'positional_encoding']
import torch
from torch import nn
import math
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose(*dims)``.

    With ``contiguous=True`` the transposed tensor is additionally made
    contiguous before it is returned.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        swapped = x.transpose(*self.dims)
        return swapped.contiguous() if self.contiguous else swapped
def get_activation_fn(activation):
    """Return an activation module.

    A callable is invoked with no arguments and its result returned; the
    strings "relu" and "gelu" (case-insensitive) map to ``nn.ReLU`` and
    ``nn.GELU``.  Any other string raises ``ValueError``.
    """
    if callable(activation):
        return activation()
    name = activation.lower()
    if name == "relu":
        return nn.ReLU()
    if name == "gelu":
        return nn.GELU()
    raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable')
# decomposition
class moving_avg(nn.Module):
    """
    Moving-average block used to extract the trend of a time series.

    The series is extended on both ends by repeating its first/last step
    ``(kernel_size - 1) // 2`` times before average pooling along time.
    """

    def __init__(self, kernel_size, stride):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # x is indexed as [batch, time, channel]
        edge = (self.kernel_size - 1) // 2
        head = x[:, :1, :].repeat(1, edge, 1)
        tail = x[:, -1:, :].repeat(1, edge, 1)
        padded = torch.cat((head, x, tail), dim=1)
        # AvgPool1d pools the last axis, so time is moved there and back.
        pooled = self.avg(padded.permute(0, 2, 1))
        return pooled.permute(0, 2, 1)
class series_decomp(nn.Module):
    """
    Series decomposition block: splits the input into a moving-average
    component and the remainder.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        # Returned as (remainder, moving average).
        return x - trend, trend
# pos_encoding
def PositionalEncoding(q_len, d_model, normalize=True):
    """Return a ``(q_len, d_model)`` sine/cosine positional encoding.

    Even columns hold sines and odd columns cosines of ``position * div_term``.
    Odd ``d_model`` is supported: there is then one more sine column than
    cosine columns, so the cosine half uses only the first ``d_model // 2``
    frequencies.

    Args:
        q_len: Number of positions (rows).
        d_model: Encoding width (columns).
        normalize: If True, subtract the overall mean and divide by
            ``10 * std`` of the whole table.
    """
    pe = torch.zeros(q_len, d_model)
    position = torch.arange(0, q_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    angles = position * div_term                 # (q_len, ceil(d_model / 2))
    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model the last frequency has no cosine column; drop it.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    if normalize:
        pe = pe - pe.mean()
        pe = pe / (pe.std() * 10)
    return pe


SinCosPosEncoding = PositionalEncoding
def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False):
    """Return a ``(q_len, d_model)`` coordinate-based positional encoding.

    The table is ``2 * r**x * c**x - 1`` with ``r``/``c`` linearly spaced in
    [0, 1] along rows/columns.  The exponent ``x`` (starting at 0.5 when
    ``exponential`` else 1) is nudged by 0.001 per iteration, for at most 100
    iterations, until the table mean is within ``eps`` of zero.

    Args:
        q_len: Number of rows.
        d_model: Number of columns.
        exponential: Start the search from exponent 0.5 instead of 1.
        normalize: If True, subtract the mean and divide by ``10 * std``.
        eps: Tolerance on the absolute mean that stops the search early.
        verbose: Print the iteration, exponent and mean at every step.
    """
    x = .5 if exponential else 1
    # The grids do not depend on the exponent, so build them once.
    rows = torch.linspace(0, 1, q_len).reshape(-1, 1)
    cols = torch.linspace(0, 1, d_model).reshape(1, -1)
    for i in range(100):
        cpe = 2 * (rows ** x) * (cols ** x) - 1
        if verbose:
            # Replaces a call to an undefined `pv` helper (NameError at runtime).
            print(f'{i:4.0f} {x:5.3f} {cpe.mean():+6.3f}')
        if abs(cpe.mean()) <= eps:
            break
        elif cpe.mean() > eps:
            x += .001
        else:
            x -= .001
    if normalize:
        cpe = cpe - cpe.mean()
        cpe = cpe / (cpe.std() * 10)
    return cpe
def Coord1dPosEncoding(q_len, exponential=False, normalize=True):
    """Return a ``(q_len, 1)`` column rising from -1 to 1.

    The column is ``2 * t**p - 1`` with ``t`` linearly spaced in [0, 1] and
    ``p`` equal to 0.5 when ``exponential`` else 1.  With ``normalize`` the
    mean is removed and the result divided by ``10 * std``.
    """
    exponent = .5 if exponential else 1
    ramp = torch.linspace(0, 1, q_len).reshape(-1, 1)
    cpe = 2 * (ramp ** exponent) - 1
    if not normalize:
        return cpe
    centered = cpe - cpe.mean()
    return centered / (centered.std() * 10)
def positional_encoding(pe, learn_pe, q_len, d_model):
    """Build the positional-encoding parameter selected by ``pe``.

    Args:
        pe: Encoding type. ``None`` and ``'zeros'`` give a ``(q_len, d_model)``
            table drawn uniformly from [-0.02, 0.02]; ``'zero'`` gives a
            ``(q_len, 1)`` column from the same range; ``'normal'``/``'gauss'``
            and ``'uniform'`` give ``(q_len, 1)`` random columns; ``'lin1d'``,
            ``'exp1d'``, ``'lin2d'``, ``'exp2d'`` and ``'sincos'`` delegate to
            the corresponding encoding functions.
        learn_pe: Whether the returned parameter requires gradients.  Forced
            to False when ``pe`` is None.
        q_len: Number of positions.
        d_model: Model width.

    Returns:
        ``nn.Parameter`` wrapping the encoding.

    Raises:
        ValueError: If ``pe`` is not one of the supported types.
    """
    if pe is None:
        # pe = None and learn_pe = False can be used to measure impact of pe
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
        learn_pe = False
    elif pe == 'zero':
        W_pos = torch.empty((q_len, 1))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'zeros':
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'normal' or pe == 'gauss':
        W_pos = torch.zeros((q_len, 1))
        torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)
    elif pe == 'uniform':
        W_pos = torch.zeros((q_len, 1))
        nn.init.uniform_(W_pos, a=0.0, b=0.1)
    elif pe == 'lin1d':
        W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)
    elif pe == 'exp1d':
        W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)
    elif pe == 'lin2d':
        W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True)
    elif pe == 'exp2d':
        W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True)
    elif pe == 'sincos':
        W_pos = PositionalEncoding(q_len, d_model, normalize=True)
    else:
        raise ValueError(
            f"{pe} is not a valid pe (positional encoder). Available types: 'gauss'=='normal', "
            "'zeros', 'zero', 'uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None."
        )
    return nn.Parameter(W_pos, requires_grad=learn_pe)
================================================
FILE: probts/model/nn/arch/RevIN.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from RevIN
# - Source: https://github.com/ts-kim/RevIN
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
class RevIN(nn.Module):
    """Reversible instance normalisation.

    ``forward(x, 'norm')`` stores statistics of ``x`` and returns the
    normalised tensor; ``forward(x, 'denorm')`` reverses the transformation
    using the statistics stored by the most recent 'norm' call.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, centre on the last time step instead of the mean
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        if self.affine:
            self._init_params()

    def forward(self, x, mode:str):
        if mode == 'norm':
            self._get_statistics(x)
            return self._normalize(x)
        if mode == 'denorm':
            return self._denormalize(x)
        raise NotImplementedError

    def _init_params(self):
        # Learnable per-feature scale and shift, shape (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # Reduce over every axis except the first (batch) and last (feature).
        reduce_axes = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=reduce_axes, keepdim=True).detach()
        variance = torch.var(x, dim=reduce_axes, keepdim=True, unbiased=False)
        self.stdev = torch.sqrt(variance + self.eps).detach()

    def _normalize(self, x):
        centre = self.last if self.subtract_last else self.mean
        x = (x - centre) / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            # eps**2 guards against division by a zero weight
            x = x / (self.affine_weight + self.eps*self.eps)
        x = x * self.stdev
        centre = self.last if self.subtract_last else self.mean
        return x + centre
================================================
FILE: probts/model/nn/arch/S4/s4.py
================================================
"""Standalone version of Structured (Sequence) State Space (S4) model."""
import logging
from functools import partial
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.utilities import rank_zero_only
from einops import rearrange, repeat
import opt_einsum as oe
contract = oe.contract
contract_expression = oe.contract_expression
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Every logging method of the returned logger is wrapped with
    ``rank_zero_only`` so that, in a multi-GPU run, messages are not repeated
    once per process.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    wrapped_methods = (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    )
    for method_name in wrapped_methods:
        original = getattr(logger, method_name)
        setattr(logger, method_name, rank_zero_only(original))

    return logger


log = get_logger(__name__)
""" Cauchy and Vandermonde kernels """
# Probe for the optional CUDA Cauchy-multiplication extension and record the
# result in `has_cauchy_extension`; a missing extension is not an error.
try:  # Try CUDA extension
    from extensions.cauchy.cauchy import cauchy_mult

    has_cauchy_extension = True
except ImportError:
    # The warning below is intentionally disabled.
    # log.warning(
    #     "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%"
    # )
    has_cauchy_extension = False
# Probe for pykeops. When it is importable, the Cauchy/Vandermonde kernels are
# built from pykeops `Genred` reductions; otherwise plain-PyTorch fallbacks
# are defined in the `except` branch.
# NOTE(review): the nesting below is reconstructed; confirm that the
# Vandermonde fallbacks sit at the `except` level and not under the
# `if not has_cauchy_extension` guard.
try:  # Try pykeops
    from pykeops.torch import Genred

    has_pykeops = True
    # log.info("Pykeops installation found.")

    def _broadcast_dims(*tensors):
        """Prepend size-1 axes so that all tensors have the same number of dims."""
        max_dim = max([len(tensor.shape) for tensor in tensors])
        tensors = [
            tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape)
            for tensor in tensors
        ]
        return tensors

    def cauchy_conj(v, z, w):
        """Pykeops version"""
        expr_num = "z * ComplexReal(v) - Real2Complex(Sum(v * w))"
        expr_denom = "ComplexMult(z-w, z-Conj(w))"

        cauchy_mult = Genred(
            f"ComplexDivide({expr_num}, {expr_denom})",
            [
                "v = Vj(2)",
                "z = Vi(2)",
                "w = Vj(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )

        v, z, w = _broadcast_dims(v, z, w)
        # Complex tensors are passed to pykeops as real views (and back below).
        v = _c2r(v)
        z = _c2r(z)
        w = _c2r(w)

        r = 2 * cauchy_mult(v, z, w, backend="GPU")
        return _r2c(r)

    def log_vandermonde(v, x, L):
        """Pykeops version: sums ``v * exp(x * l)`` for l in 0..L-1 and returns twice the real part."""
        expr = "ComplexMult(v, ComplexExp(ComplexMult(x, l)))"
        vandermonde_mult = Genred(
            expr,
            [
                "v = Vj(2)",
                "x = Vj(2)",
                "l = Vi(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )

        l = torch.arange(L).to(x)
        v, x, l = _broadcast_dims(v, x, l)
        v = _c2r(v)
        x = _c2r(x)
        l = _c2r(l)

        r = vandermonde_mult(v, x, l, backend="GPU")
        return 2 * _r2c(r).real

    def log_vandermonde_transpose(u, v, x, L):
        """
        u: ... H L
        v: ... H N
        x: ... H N
        Returns: ... H N

        V = Vandermonde(a, L) : (H N L)
        contract_L(V * u * v)
        """
        expr = "ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))"
        vandermonde_mult = Genred(
            expr,
            [
                "u = Vj(2)",
                "v = Vi(2)",
                "x = Vi(2)",
                "l = Vj(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )

        l = torch.arange(L).to(x)
        u, v, x, l = _broadcast_dims(u, v, x, l)
        u = _c2r(u)
        v = _c2r(v)
        x = _c2r(x)
        l = _c2r(l)

        r = vandermonde_mult(u, v, x, l, backend="GPU")
        return _r2c(r)

except ImportError:
    has_pykeops = False
    if not has_cauchy_extension:
        # Neither accelerated Cauchy kernel is available (warning disabled).
        # log.warning(
        #     "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency."
        # )

        def cauchy_naive(v, z, w):
            """
            v, w: (..., N)
            z: (..., L)
            returns: (..., L)
            """
            cauchy_matrix = v.unsqueeze(-1) / (
                z.unsqueeze(-2) - w.unsqueeze(-1)
            )  # (... N L)
            return torch.sum(cauchy_matrix, dim=-2)

    # Vandermonde functions (plain-PyTorch fallbacks; warning disabled)
    # log.warning(
    #     "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency."
    # )

    def log_vandermonde(v, x, L):
        """
        v: (..., N)
        x: (..., N)
        returns: (..., L), twice the real part of sum_n v_n * exp(x_n * l)
        """
        vandermonde_matrix = torch.exp(
            x.unsqueeze(-1) * torch.arange(L).to(x)
        )  # (... N L)
        vandermonde_prod = contract(
            "... n, ... n l -> ... l", v, vandermonde_matrix
        )  # (... L)
        return 2 * vandermonde_prod.real

    def log_vandermonde_transpose(u, v, x, L):
        """Contract ``u`` (..., L) and ``v`` (..., N) with exp(x * l) over l; returns (..., N)."""
        vandermonde_matrix = torch.exp(
            x.unsqueeze(-1) * torch.arange(L).to(x)
        )  # (... N L)
        vandermonde_prod = contract(
            "... l, ... n, ... n l -> ... n",
            u.to(x),
            v.to(x),
            vandermonde_matrix,
        )  # (... N)
        return vandermonde_prod
def _conj(x):
return torch.cat([x, x.conj()], dim=-1)
_c2r = torch.view_as_real
_r2c = torch.view_as_complex
if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10):
def _resolve_conj(x):
return x.conj().resolve_conj()
else:
def _resolve_conj(x):
return x.conj()
""" Simple nn.Module components """
def Activation(activation=None, dim=-1):
    """Return the activation module named by ``activation``.

    ``None``, "id", "identity" and "linear" give ``nn.Identity``; "glu" gives
    ``nn.GLU`` over axis ``dim``.  Unknown names raise ``NotImplementedError``.
    """
    # (accepted names, zero-argument factory), checked in order.
    choices = (
        ((None, "id", "identity", "linear"), nn.Identity),
        (("tanh",), nn.Tanh),
        (("relu",), nn.ReLU),
        (("gelu",), nn.GELU),
        (("swish", "silu"), nn.SiLU),
        (("glu",), lambda: nn.GLU(dim=dim)),
        (("sigmoid",), nn.Sigmoid),
    )
    for names, factory in choices:
        if activation in names:
            return factory()
    raise NotImplementedError(
        "hidden activation '{}' is not implemented".format(activation)
    )
def LinearActivation(
    d_input,
    d_output,
    bias=True,
    transposed=False,
    activation=None,
    activate=False,  # Apply activation as part of this module
    **kwargs,
):
    """Returns a linear nn.Module with control over axes order, initialization, and activation

    ``transposed`` selects a kernel-size-1 ``nn.Conv1d`` instead of
    ``nn.Linear``.  For "glu" the output width is doubled.  The activation is
    appended only when ``activate`` is true and ``activation`` is not None.
    """
    if transposed:
        make_linear = partial(nn.Conv1d, kernel_size=1)
    else:
        make_linear = nn.Linear

    out_width = d_output * 2 if activation == "glu" else d_output
    layer = make_linear(d_input, out_width, bias=bias, **kwargs)

    if not activate or activation is None:
        return layer

    # With a transposed layout the feature axis is second to last.
    act_module = Activation(activation, dim=-2 if transposed else -1)
    return nn.Sequential(layer, act_module)
class DropoutNd(nn.Module):
    """Dropout whose mask can be shared ("tied") across trailing axes."""

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError(
                "dropout probability has to be in [0, 1), "
                "but got {}".format(p)
            )
        self.p = p
        self.tie = tie
        self.transposed = transposed
        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)"""
        # Identity outside of training.
        if not self.training:
            return X

        if not self.transposed:
            X = rearrange(X, "b d ... -> b ... d")

        if self.tie:
            # One mask entry per leading pair of axes, broadcast over the rest.
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2)
        else:
            mask_shape = X.shape
        keep = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p
        # Rescale kept entries by 1 / (1 - p).
        X = X * keep * (1.0 / (1 - self.p))

        if not self.transposed:
            X = rearrange(X, "b ... d -> b d ...")
        return X
""" Misc functional utilities """
def power(L, A, v=None):
    """Compute A^L and the scan sum_i A^i v_i

    A: (..., N, N)
    v: (..., N, L)

    Returns A^L when ``v`` is None, otherwise the pair
    (A^L, sum_i A^i v[..., i]). The input ``v`` is never modified.
    """
    I = torch.eye(A.shape[-1]).to(A)  # , dtype=A.dtype, device=A.device)

    # Binary exponentiation; powers[j] holds A^(2^j).
    powers = [A]
    l = 1
    while True:
        if L % 2 == 1:
            I = powers[-1] @ I
        L //= 2
        if L == 0:
            break
        l *= 2
        powers.append(powers[-1] @ powers[-1])

    if v is None:
        return I

    # Invariants:
    # powers[-1] := A^l
    # l := largest po2 at most L

    # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A
    # We do this reverse divide-and-conquer for efficiency reasons:
    # 1) it involves fewer padding steps for non-po2 L
    # 2) it involves more contiguous arrays

    # Take care of edge case for non-po2 arrays
    # Note that this initial step is a no-op for the case of power of 2 (l == L)
    k = v.size(-1) - l
    v_ = powers.pop() @ v[..., l:]
    # Clone before the in-place add: slicing alone yields a view, and writing
    # through it would silently overwrite the caller's tensor.
    v = v[..., :l].clone()
    v[..., :k] = v[..., :k] + v_

    # Handle reduction for power of 2: split the last axis into two halves and
    # fold the second half onto the first using the matching power of A.
    while v.size(-1) > 1:
        v = v.reshape(*v.shape[:-1], 2, v.size(-1) // 2)
        v = v[..., 0, :] + powers.pop() @ v[..., 1, :]
    return I, v.squeeze(-1)
""" HiPPO utilities """
def transition(measure, N):
    """Build the (A, B) transition matrices for a HiPPO measure.

    measure: one of "legt", "legs", "legsd", "fourier_diag"/"foud",
        "fourier"/"fout".
    N: state size.

    Returns numpy arrays A of shape (N, N) and B of shape (N, 1).
    Raises NotImplementedError for an unknown measure.
    """
    # Legendre (translated)
    if measure == "legt":
        idx = np.arange(N, dtype=np.float64)
        scale = (2 * idx + 1) ** 0.5
        j, i = np.meshgrid(idx, idx)
        signs = np.where(i < j, (-1.0) ** (i - j), 1)
        A = scale[:, None] * signs * scale[None, :]
        B = scale[:, None]
        A = -A
        # Halve again for timescale correctness
        A *= 0.5
        B *= 0.5
        return A, B

    # Legendre (scaled); "legsd" is essentially equivalent to S4D-LegS
    if measure == "legs" or measure == "legsd":
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        # Copy: otherwise "UserWarning: given NumPY array is not writeable..."
        # after torch.as_tensor(B)
        B = np.diag(T)[:, None].copy()
        if measure == "legsd":
            A += 0.5 * B * B[None, :, 0]
            B = B / 2.0
        return A, B

    # Essentially equivalent to S4D-Lin
    if measure in ["fourier_diag", "foud"]:
        half = N // 2
        freqs = np.arange(half)
        d = np.stack([freqs, np.zeros(half)], axis=-1).reshape(-1)[:-1]
        A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        A = A - 0.5 * np.eye(N)
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1
        return A, B[:, None]

    if measure in ["fourier", "fout"]:
        half = N // 2
        freqs = np.arange(half)
        d = np.stack([np.zeros(half), freqs], axis=-1).reshape(-1)[1:]
        A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1
        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - B[:, None] * B[None, :]
        return A, B[:, None]

    raise NotImplementedError
def rank_correction(measure, N, rank=1, dtype=torch.float):
    """Return low-rank matrix L such that A + L is normal

    measure: "legs", "legt", "fourier"/"fout", or one of
        "fourier_diag"/"foud"/"legsd" (which get an all-zero correction).
    N: state size.
    rank: requested number of rows; "legs" needs rank >= 1 and "legt" needs
        rank >= 2. If ``rank`` exceeds the natural rank of the measure, zero
        rows are appended.
    dtype: real dtype of the returned tensor (honoured by every branch).

    Returns P with shape (max(rank, natural_rank), N).
    Raises NotImplementedError for an unknown measure.
    """
    if measure == "legs":
        assert rank >= 1
        P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(
            0
        )  # (1 N)
    elif measure == "legt":
        assert rank >= 2
        P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype))  # (N)
        P0 = P.clone()
        P0[0::2] = 0.0
        P1 = P.clone()
        P1[1::2] = 0.0
        P = torch.stack([P0, P1], dim=0)  # (2 N)
        P *= 2 ** (
            -0.5
        )  # Halve the rank correct just like the original matrix was halved
    elif measure in ["fourier", "fout"]:
        # Create with the requested dtype, like the other branches; without it
        # a double request returned float and mixed dtypes in the cat below.
        P = torch.zeros(N, dtype=dtype)
        P[0::2] = 2**0.5
        P[0] = 1
        P = P.unsqueeze(0)
    elif measure in ["fourier_diag", "foud", "legsd"]:
        P = torch.zeros(1, N, dtype=dtype)
    else:
        raise NotImplementedError

    # Pad with zero rows when a higher rank than necessary was requested.
    d = P.size(0)
    if rank > d:
        P = torch.cat(
            [P, torch.zeros(rank - d, N, dtype=dtype)], dim=0
        )  # (rank N)
    return P
def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True):
    """Return w, P, B, V for the Normal-Plus-Low-Rank form of a HiPPO matrix

    The intent is that (w - p q^*, B) is unitarily equivalent to the original
    HiPPO A, B by the matrix V, i.e. A = V[w - p q^*]V^*, B = V B

    measure, N: forwarded to ``transition`` and ``rank_correction``.
    rank: forwarded to ``rank_correction``.
    dtype: real dtype (torch.float or torch.double); the matching complex
        dtype is used for the diagonalization outputs.
    diagonalize_precision: when True, the eigendecomposition runs in double
        precision and its results are cast back afterwards.

    Returns, in this order:
        w: eigenvalues, only the first N // 2 after sorting by imaginary part
        P: low-rank term in the eigenbasis (V^* P)
        B: input vector in the eigenbasis (V^* B)
        V: the N // 2 kept eigenvector columns, shape (N, N // 2)
    """
    assert dtype == torch.float or dtype == torch.double
    cdtype = torch.cfloat if dtype == torch.float else torch.cdouble
    A, B = transition(measure, N)
    A = torch.as_tensor(A, dtype=dtype)  # (N, N)
    B = torch.as_tensor(B, dtype=dtype)[:, 0]  # (N,)
    P = rank_correction(measure, N, rank=rank, dtype=dtype)  # (r N)
    # Add the sum of the outer products of the rows of P to A
    AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3)
    # We require AP to be nearly skew-symmetric
    # (AP + AP^T should be a multiple of the identity; only a warning is
    # printed when it is not)
    _A = AP + AP.transpose(-1, -2)
    if (
        err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N
    ) > 1e-5:  # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5):
        print("WARNING: HiPPO matrix not skew symmetric", err)
    # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately
    # Imaginary part can use eigh instead of eig
    w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True)
    # Diagonalize in double precision
    if diagonalize_precision:
        AP = AP.to(torch.double)
    w_im, V = torch.linalg.eigh(AP * -1j)  # (..., N) (..., N, N)
    if diagonalize_precision:
        w_im, V = w_im.to(cdtype), V.to(cdtype)
    w = w_re + 1j * w_im
    # Check: V w V^{-1} = A
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))
    # Only keep half of each conjugate pair
    _, idx = torch.sort(w.imag)
    w_sorted = w[idx]
    V_sorted = V[:, idx]
    # There is an edge case when eigenvalues can be 0, which requires some machinery to handle
    # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case)
    V = V_sorted[:, : N // 2]
    w = w_sorted[: N // 2]
    # NOTE(review): w[-2] requires at least two kept eigenvalues, i.e. N >= 4
    assert (
        w[-2].abs() > 1e-4
    ), "Only 1 zero eigenvalue allowed in diagonal part of A"
    if w[-1].abs() < 1e-4:
        # Overwrite the eigenvector of the (near-)zero eigenvalue by hand
        V[:, -1] = 0.0
        V[0, -1] = 2**-0.5
        V[1, -1] = 2**-0.5 * 1j
    # Reconstruct from the kept half and compare with AP; warn only
    _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)
    if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5:
        print(
            "Warning: Diagonalization of A matrix not numerically precise - error",
            err,
        )
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))
    V_inv = V.conj().transpose(-1, -2)
    B = contract("ij, j -> i", V_inv, B.to(V))  # V^* B
    P = contract("ij, ...j -> ...i", V_inv, P.to(V))  # V^* P
    return w, P, B, V
def dplr(
    scaling,
    N,
    rank=1,
    H=1,
    dtype=torch.float,
    real_scale=1.0,
    imag_scale=1.0,
    random_real=False,
    random_imag=False,
    normalize=False,
    diagonal=True,
    random_B=False,
):
    """Directly construct a Diagonal-Plus-Low-Rank initialization.

    scaling: rule for the imaginary part of the diagonal: "random", "real",
        "linear"/"lin", "inverse"/"inv", "inverse2"/"inv2",
        "quadratic"/"quad" or "legs"/"hippo".
    N: state size; N // 2 diagonal entries are produced.
    rank: number of low-rank terms in P.
    H: number of independent copies.
    dtype: torch.float or torch.double; selects the complex dtype of B, P, V.
    real_scale, imag_scale: multipliers for the real / imaginary parts.
    random_real, random_imag: draw the parts at random instead of using the
        deterministic defaults (0.5 and 0..N/2-1).
    normalize: rescale B as described in the inline comments.
    diagonal: zero out P.
    random_B: draw B from a normal distribution instead of all ones.

    Returns w (H, N//2), P (rank, H, N//2), B (H, N//2), V (H, N, N//2).
    Raises NotImplementedError for an unknown ``scaling``.
    """
    assert dtype == torch.float or dtype == torch.double
    dtype = torch.cfloat if dtype == torch.float else torch.cdouble

    pi = torch.tensor(math.pi)
    if random_real:
        real_part = torch.rand(H, N // 2)
    else:
        real_part = 0.5 * torch.ones(H, N // 2)
    if random_imag:
        imag_part = N // 2 * torch.rand(H, N // 2)
    else:
        imag_part = torch.arange(N // 2).unsqueeze(0).repeat(H, 1)

    real_part = real_scale * real_part
    if scaling == "random":
        imag_part = torch.randn(H, N // 2)
    elif scaling == "real":
        imag_part = 0 * imag_part
        real_part = 1 + torch.arange(N // 2).unsqueeze(0).repeat(H, 1)
    elif scaling in ["linear", "lin"]:
        imag_part = pi * imag_part
    elif scaling in [
        "inverse",
        "inv",
    ]:  # Based on asymptotics of the default HiPPO matrix
        imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1)
    elif scaling in ["inverse2", "inv2"]:
        imag_part = 1 / pi * N * (N / (1 + imag_part) - 1)
    elif scaling in ["quadratic", "quad"]:
        imag_part = 1 / pi * (1 + 2 * imag_part) ** 2
    elif scaling in ["legs", "hippo"]:
        w, _, _, _ = nplr("legsd", N)
        imag_part = w.imag
    else:
        raise NotImplementedError
    imag_part = imag_scale * imag_part
    w = -real_part + 1j * imag_part

    # Initialize B
    if random_B:
        B = torch.randn(H, N // 2, dtype=dtype)
    else:
        B = torch.ones(H, N // 2, dtype=dtype)

    if normalize:
        norm = (
            -B / w
        )  # (H, N) # Result if you integrate the kernel with constant 1 function
        zeta = 2 * torch.sum(
            torch.abs(norm) ** 2, dim=-1, keepdim=True
        )  # Variance with a random C vector
        B = B / zeta**0.5

    P = torch.randn(rank, H, N // 2, dtype=dtype)
    if diagonal:
        P = P * 0.0

    # Keep the first N // 2 COLUMNS of the identity so that V has the same
    # (N, N // 2) layout that nplr() returns; the earlier row-stride slice
    # produced a (2, N) matrix that cannot be concatenated with NPLR inits.
    V = torch.eye(N, dtype=dtype)[:, : N // 2]  # Only used in testing
    V = V.unsqueeze(0).repeat(H, 1, 1)  # (H, N, N // 2)

    return w, P, B, V
def ssm(measure, N, R, H, **ssm_args):
    """Create one SSM initialization, choosing the constructor from ``measure``.

    measure: "dplr", a "diag-<scaling>" string, or any measure name that
        ``nplr`` understands
    N: state size
    R: rank (for DPLR parameterization)
    H: number of independent SSM copies
    ssm_args: extra keyword arguments handed to the chosen constructor

    Returns the tuple (w, P, B, V).
    """
    # Direct DPLR construction already produces H copies.
    if measure == "dplr":
        return dplr(N=N, rank=R, H=H, **ssm_args)

    # "diag-<scaling>": diagonal initialization with the named scaling rule.
    if measure.startswith("diag"):
        pieces = measure.split("-")
        assert pieces[0] == "diag" and len(pieces) > 1
        return dplr(
            scaling=pieces[1], N=N, rank=R, H=H, diagonal=True, **ssm_args
        )

    # NPLR yields a single copy; tile it H times along a new copies axis.
    w, P, B, V = nplr(measure, N, R, **ssm_args)
    return (
        repeat(w, "n -> s n", s=H),
        repeat(P, "r n -> r s n", s=H),
        repeat(B, "n -> s n", s=H),
        repeat(V, "n m -> s n m", s=H),
    )
# Named presets that expand to a list of individual measures.
combinations = {
    "hippo": ["legs", "fourier"],
    "diag": ["diag-inv", "diag-lin"],
    "all": ["legs", "fourier", "diag-inv", "diag-lin"],
}


def combination(measures, N, R, S, **ssm_args):
    """Build S SSM copies split evenly across several measures.

    measures: a preset name from ``combinations``, a single measure name, or
        an explicit sequence of measure names
    N: state size
    R: rank
    S: total number of copies; must be a multiple of the number of measures
    ssm_args: forwarded to ``ssm`` for every measure

    Returns (w, P, B, V) with the per-measure results concatenated along the
    copies axis.
    """
    if isinstance(measures, str):
        # A preset expands to its list; any other string is a single measure.
        measures = combinations.get(measures, [measures])

    n_measures = len(measures)
    assert (
        S % n_measures == 0
    ), f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures"

    copies_each = S // n_measures
    ws, Ps, Bs, Vs = [], [], [], []
    for name in measures:
        w_i, P_i, B_i, V_i = ssm(name, N, R, copies_each, **ssm_args)
        ws.append(w_i)
        Ps.append(P_i)
        Bs.append(B_i)
        Vs.append(V_i)

    w = torch.cat(tuple(ws), dim=0)  # (S N)
    P = torch.cat(tuple(Ps), dim=1)  # (R S N)
    B = torch.cat(tuple(Bs), dim=0)  # (S N)
    V = torch.cat(tuple(Vs), dim=0)  # (S N N)
    return w, P, B, V
class OptimModule(nn.Module):
    """Base class for modules whose tensors carry per-tensor optimizer settings."""

    def register(self, name, tensor, lr=None):
        """Attach ``tensor`` to the module under ``name``.

        A learning rate of exactly 0.0 stores the tensor as a buffer.
        Otherwise it becomes an ``nn.Parameter`` tagged with an ``_optim``
        dict holding ``weight_decay = 0.0`` and, when ``lr`` is given, ``lr``.
        """
        if lr == 0.0:
            # Frozen tensor: a buffer, no optimizer hints attached.
            self.register_buffer(name, tensor)
            return

        self.register_parameter(name, nn.Parameter(tensor))
        if lr is None:
            hints = {"weight_decay": 0.0}
        else:
            hints = {"weight_decay": 0.0, "lr": lr}
        setattr(getattr(self, name), "_optim", hints)
class SSKernelNPLR(OptimModule):
    """Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR)"""

    @torch.no_grad()
    def _setup_C(self, L):
        """Construct C~ from C

        Two modes are supported: go directly to length L if self.L is 0, or length is doubled
        Does nothing when L does not exceed the current internal length.
        """
        if self.L.item() == 0:
            if self.verbose:
                log.info(f"S4: Initializing kernel to length {L}")
            double_length = False
        elif L > self.L.item():  # 2*int(self.L) == L:
            if self.verbose:
                log.info(
                    f"S4: Doubling length from L = {self.L.item()} to {2*self.L.item()}"
                )
            double_length = True
            L = self.L.item()  # Convenience for the math below
        else:
            return

        C = _r2c(self.C)
        dA, _ = self._setup_state()
        dA_L = power(L, dA)
        # Multiply C by I - dA_L
        C_ = _conj(C)
        prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_)
        if double_length:
            prod = -prod  # Multiply by I + dA_L instead
        C_ = C_ - prod
        C_ = C_[..., : self.N]  # Take conjugate pairs again
        self.C.copy_(_c2r(C_))

        self.L = (
            2 * self.L if double_length else self.L + L
        )  # Preserve type/device

    def _omega(self, L, dtype, device, cache=True):
        """Calculate (and cache) FFT nodes and their "unprocessed" version with the bilinear transform
        This should be called everytime the internal length self.L changes"""
        # Use cached if available
        if (
            cache
            and hasattr(self, "omega")
            and self.omega.size(-1) == L // 2 + 1
        ):
            return self.omega, self.z

        omega = torch.tensor(
            np.exp(-2j * np.pi / (L)), dtype=dtype, device=device
        )  # \omega_{2L}
        omega = omega ** torch.arange(0, L // 2 + 1, device=device)
        z = 2 * (1 - omega) / (1 + omega)

        # Cache if necessary
        if cache:
            self.omega = omega
            self.z = z
        return omega, z

    def __init__(
        self,
        w,
        P,
        B,
        C,
        log_dt,
        L=None,  # starting/maximum length of kernel
        lr=None,
        verbose=False,
        keops=False,
        real_type="exp",  # ['none' | 'exp' | 'relu' | sigmoid']
        real_tolerance=1e-3,
        bandlimit=None,
    ):
        """
        L: Maximum length; this module computes an SSM kernel of length L
        A is represented by diag(w) - PP^*
        w: (S, N) diagonal part
        P: (R, S, N) low-rank part
        B: (S, N)
        C: (C, H, N)
        dt: (H) timescale per feature
        lr: [dict | float | None] hook to set lr of special parameters (A, B, dt)

        Dimensions:
        N (or d_state): state size
        H (or d_model): total SSM copies
        S (or n_ssm): number of trainable copies of (A, B, dt); must divide H
        R (or rank): rank of low-rank part
        C (or channels): system is 1-dim to C-dim

        The forward pass of this Module returns a tensor of shape (C, H, L)

        Note: tensor shape N here denotes half the true state size, because of conjugate symmetry
        """
        super().__init__()
        self.verbose = verbose
        self.keops = keops
        self.bandlimit = bandlimit
        self.real_type = real_type
        self.real_tolerance = real_tolerance

        # Rank of low-rank correction
        self.rank = P.shape[-3]
        assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1)
        self.H = log_dt.size(-1)
        self.N = w.size(-1)

        # Check different SSM inits
        assert w.size(-2) == P.size(-2) == B.size(-2)  # n_ssm
        assert self.H % w.size(0) == 0
        self.n_ssm = w.size(0)
        self.repeat = self.H // w.size(
            0
        )  # Each trainable SSM needs to be duplicated this many times

        # Broadcast everything to correct shapes
        C = C.expand(
            torch.broadcast_shapes(C.shape, (1, self.H, self.N))
        )  # (C, H, N)
        B = B.unsqueeze(0)  # (1, 1, N)

        # Register parameters
        self.C = nn.Parameter(_c2r(_resolve_conj(C)))
        if lr is None or isinstance(lr, float):
            lr_dict = {}
        else:
            lr_dict, lr = lr, None
        self.register("log_dt", log_dt, lr_dict.get("dt", lr))
        self.register("B", _c2r(B), lr_dict.get("B", lr))
        self.register("P", _c2r(P), lr_dict.get("A", lr))
        self.register("inv_w_real", self._w_init(w.real), lr_dict.get("A", lr))
        self.register("w_imag", w.imag, lr_dict.get("A", lr))

        self.l_max = L
        self.register_buffer("L", torch.tensor(0))  # Internal length

    def _w_init(self, w_real):
        """Map the real part of w to the stored parameter that ``_w`` inverts."""
        w_real = torch.clamp(w_real, max=-self.real_tolerance)
        if self.real_type == "none":
            return -w_real
        elif self.real_type == "exp":
            return torch.log(
                -w_real
            )  # Some of the HiPPO methods have real part 0
        elif self.real_type == "relu":
            return -w_real
        elif self.real_type == "sigmoid":
            return torch.logit(-w_real)
        elif self.real_type == "softplus":
            return torch.log(torch.exp(-w_real) - 1)
        else:
            raise NotImplementedError

    def _w(self):
        # Get the internal w (diagonal) parameter
        if self.real_type == "none":
            w_real = -self.inv_w_real
        elif self.real_type == "exp":
            w_real = -torch.exp(self.inv_w_real)
        elif self.real_type == "relu":
            w_real = -F.relu(self.inv_w_real)
        elif self.real_type == "sigmoid":
            w_real = -F.sigmoid(self.inv_w_real)
        elif self.real_type == "softplus":
            w_real = -F.softplus(self.inv_w_real)
        else:
            raise NotImplementedError
        w = w_real + 1j * self.w_imag
        return w

    def forward(self, state=None, rate=1.0, L=None):
        """
        state: (B, H, N) initial state
        rate: sampling rate factor
        L: target length

        returns:
        (C, H, L) convolution kernel (generally C=1)
        (B, H, L) output from initial state
        """
        # Initialize C~ if necessary (done in forward pass so it's on the correct device)
        if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:
            self._setup_C(self.l_max)

        # Handle sampling rate logic
        # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate
        if L is None:
            L = round(self.L.item() / rate)

        # Increase the internal length if needed
        continuous_L = round(rate * L)
        while continuous_L > self.L.item():
            self._setup_C(continuous_L)
        discrete_L = round(self.L.item() / rate)

        dt = torch.exp(self.log_dt) * rate
        B = _r2c(self.B)
        C = _r2c(self.C)
        P = _r2c(self.P)
        Q = P.conj()
        w = self._w()  # (n_ssm, N)

        # Get FFT nodes of right length
        omega, z = self._omega(
            discrete_L, dtype=w.dtype, device=w.device, cache=(rate == 1.0)
        )

        # Broadcast parameters to same hidden features H
        B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat)
        P = repeat(P, "r t n -> r (v t) n", v=self.repeat)
        Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat)
        w = repeat(w, "t n -> (v t) n", v=self.repeat)  # (H, N)

        # Address bandlimiting
        # Done after the broadcast above so that w is (H, N) and lines up with
        # dt of shape (H,) even when n_ssm < H.
        if self.bandlimit is not None:
            freqs = w.imag.abs() / (2 * math.pi)  # (H, N)
            freqs = dt[:, None] / rate * freqs  # (H, N)
            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
            C = C * mask

        # Augment B
        if state is not None:
            # Have to "unbilinear" the state to put it into the same "type" as B
            # Compute 1/dt * (I + dt/2 A) @ state

            # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way
            s = _conj(state) if state.size(-1) == self.N else state  # (B H N)
            sA = s * _conj(w) - contract(  # (B H N)
                "bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P)
            )
            s = s / dt.unsqueeze(-1) + sA / 2
            s = s[..., : self.N]

            B = torch.cat([s, B], dim=-3)  # (B+1, H, N)

        # Incorporate dt into A
        w = w * dt.unsqueeze(-1)  # (H N)

        # Stack B and p, C and q for convenient batching
        B = torch.cat([B, P], dim=-3)  # (B+1+R, H, N)
        C = torch.cat([C, Q], dim=-3)  # (C+R, H, N)

        # Incorporate B and C batch dimensions
        v = B.unsqueeze(-3) * C.unsqueeze(-4)  # (B+1+R, C+R, H, N)

        # Calculate resolvent at omega
        if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops:
            r = cauchy_mult(v, z, w, symmetric=True)
        elif has_pykeops:
            r = cauchy_conj(v, z, w)
        else:
            r = cauchy_naive(v, z, w)
        r = r * dt[None, None, :, None]  # (B+1+R, C+R, H, L)

        # Low-rank Woodbury correction
        if self.rank == 1:
            k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (
                1 + r[-1:, -1:, :, :]
            )
        elif self.rank == 2:
            r00 = r[: -self.rank, : -self.rank, :, :]
            r01 = r[: -self.rank, -self.rank :, :, :]
            r10 = r[-self.rank :, : -self.rank, :, :]
            r11 = r[-self.rank :, -self.rank :, :, :]
            det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[
                :1, 1:, :, :
            ] * r11[1:, :1, :, :]
            s = (
                r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :]
                + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :]
                - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :]
                - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :]
            )
            s = s / det
            k_f = r00 - s
        else:
            r00 = r[: -self.rank, : -self.rank, :, :]
            r01 = r[: -self.rank, -self.rank :, :, :]
            r10 = r[-self.rank :, : -self.rank, :, :]
            r11 = r[-self.rank :, -self.rank :, :, :]
            r11 = rearrange(r11, "a b h n -> h n a b")
            r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11)
            r11 = rearrange(r11, "h n a b -> a b h n")
            k_f = r00 - torch.einsum(
                "i j h n, j k h n, k l h n -> i l h n", r01, r11, r10
            )

        # Final correction for the bilinear transform
        k_f = k_f * 2 / (1 + omega)

        # Move from frequency to coefficients
        k = torch.fft.irfft(k_f, n=discrete_L)  # (B+1, C, H, L)

        # # Truncate to target length
        k = k[..., :L]

        if state is not None:
            k_state = k[:-1, :, :, :]  # (B, C, H, L)
        else:
            k_state = None
        k_B = k[-1, :, :, :]  # (C H L)

        return k_B, k_state

    @torch.no_grad()
    def _setup_linear(self):
        """Create parameters that allow fast linear stepping of state"""
        w = self._w()
        B = _r2c(self.B)  # (H N)
        P = _r2c(self.P)
        Q = P.conj()

        # Repeat w shape properly
        B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat)
        P = repeat(P, "r t n -> r (v t) n", v=self.repeat)
        Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat)
        w = repeat(w, "t n -> (v t) n", v=self.repeat)

        # Prepare Linear stepping
        dt = torch.exp(self.log_dt)
        D = (2.0 / dt.unsqueeze(-1) - w).reciprocal()  # (H, N)
        R = (
            torch.eye(self.rank, dtype=w.dtype, device=w.device)
            + 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real
        )  # (H R R)
        Q_D = rearrange(Q * D, "r h n -> h r n")
        try:
            R = torch.linalg.solve(R, Q_D)  # (H R N)
        except Exception:
            # Fall back to a NumPy solve on CPU when the torch solve fails
            R = torch.tensor(
                np.linalg.solve(
                    R.to(Q_D).contiguous().detach().cpu(),
                    Q_D.contiguous().detach().cpu(),
                )
            ).to(Q_D)
        R = rearrange(R, "h r n -> r h n")

        self.step_params = {
            "D": D,  # (H N)
            "R": R,  # (R H N)
            "P": P,  # (R H N)
            "Q": Q,  # (R H N)
            "B": B,  # (1 H N)
            "E": 2.0 / dt.unsqueeze(-1) + w,  # (H N)
        }

    def _step_state_linear(self, u=None, state=None):
        """
        Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization.

        Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster

        u: (H) input
        state: (H, N/2) state with conjugate pairs
          Optionally, the state can have last dimension N
        Returns: same shape as state
        """
        C = _r2c(self.C)  # View used for dtype/device

        if u is None:  # Special case used to find dA
            u = torch.zeros(self.H, dtype=C.dtype, device=C.device)
        if state is None:  # Special case used to find dB
            state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device)

        step_params = self.step_params.copy()
        if (
            state.size(-1) == self.N
        ):  # Only store half of the conjugate pairs; should be true by default
            # There should be a slightly faster way using conjugate symmetry
            def contract_fn(p, x, y):
                return contract(
                    "r h n, r h m, ... h m -> ... h n",
                    _conj(p),
                    _conj(x),
                    _conj(y),
                )[
                    ..., : self.N
                ]  # inner outer product

        else:
            assert state.size(-1) == 2 * self.N
            step_params = {k: _conj(v) for k, v in step_params.items()}

            # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping
            def contract_fn(p, x, y):
                return contract(
                    "r h n, r h m, ... h m -> ... h n", p, x, y
                )  # inner outer product

        D = step_params["D"]  # (H N)
        E = step_params["E"]  # (H N)
        R = step_params["R"]  # (R H N)
        P = step_params["P"]  # (R H N)
        Q = step_params["Q"]  # (R H N)
        B = step_params["B"]  # (1 H N)

        new_state = E * state - contract_fn(P, Q, state)  # (B H N)
        new_state = new_state + 2.0 * B * u.unsqueeze(-1)  # (B H N)
        new_state = D * (new_state - contract_fn(P, R, new_state))

        return new_state

    def _setup_state(self):
        """Construct dA and dB for discretized state equation"""
        # Construct dA and dB by using the stepping
        self._setup_linear()
        C = _r2c(
            self.C
        )  # Just returns a view that we use for finding dtype/device

        state = torch.eye(
            2 * self.N, dtype=C.dtype, device=C.device
        ).unsqueeze(
            -2
        )  # (N 1 N)
        dA = self._step_state_linear(state=state)
        dA = rearrange(dA, "n h m -> h m n")

        u = C.new_ones(self.H)
        dB = self._step_state_linear(u=u)
        dB = _conj(dB)
        dB = rearrange(dB, "1 h n -> h n")  # (H N)
        return dA, dB

    def _step_state(self, u, state):
        """Must be called after self.default_state() is used to construct an initial state!"""
        next_state = self.state_contraction(
            self.dA, state
        ) + self.input_contraction(self.dB, u)
        return next_state

    def _setup_step(self, mode="dense"):
        """Set up dA, dB, dC discretized parameters for stepping"""
        self.dA, self.dB = self._setup_state()

        # Calculate original C
        C = _conj(_r2c(self.C))  # (H C N)
        if self.L.item() == 0:
            dC = C
        else:
            # self.C represents C_tilde
            dA_L = power(self.L.item(), self.dA)
            I = torch.eye(self.dA.size(-1)).to(dA_L)

            dC = torch.linalg.solve(
                I - dA_L.transpose(-1, -2),
                C.unsqueeze(-1),
            ).squeeze(-1)
        self.dC = dC

        # Do special preprocessing for different step modes

        self._step_mode = mode
        if mode == "linear":
            # Linear case: special step function for the state, we need to handle output
            # use conjugate symmetry by default, which affects the output projection
            self.dC = 2 * self.dC[:, :, : self.N]
        elif mode == "diagonal":
            # Eigendecomposition of the A matrix
            L, V = torch.linalg.eig(self.dA)
            V_inv = torch.linalg.inv(V)
            # Check that the eigendedecomposition is correct
            if self.verbose:
                print(
                    "Diagonalization error:",
                    torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA),
                )

            # Change the parameterization to diagonalize
            self.dA = L
            self.dB = contract("h n m, h m -> h n", V_inv, self.dB)
            self.dC = contract("h n m, c h n -> c h m", V, self.dC)

        elif mode == "dense":
            pass
        else:
            raise NotImplementedError(
                "NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}"
            )

    def default_state(self, *batch_shape):
        """Return a zero state and cache the contractions used by ``step``."""
        C = _r2c(self.C)
        N = C.size(-1)
        H = C.size(-2)

        # Cache the tensor contractions we will later do, for efficiency
        # These are put in this function because they depend on the batch size
        step_mode = getattr(
            self, "_step_mode", "dense"
        )  # Used in default_state, which is called without _setup_step() in forward_state()
        if step_mode != "linear":
            N *= 2

            if step_mode == "diagonal":
                self.state_contraction = contract_expression(
                    "h n, ... h n -> ... h n",
                    (H, N),
                    batch_shape + (H, N),
                )
            else:
                # Dense (quadratic) case: expand all terms
                self.state_contraction = contract_expression(
                    "h m n, ... h n -> ... h m",
                    (H, N, N),
                    batch_shape + (H, N),
                )

            self.input_contraction = contract_expression(
                "h n, ... h -> ... h n",
                (H, N),  # self.dB.shape
                batch_shape + (H,),
            )

        self.output_contraction = contract_expression(
            "c h n, ... h n -> ... c h",
            (C.shape[0], H, N),  # self.dC.shape
            batch_shape + (H, N),
        )

        state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device)
        return state

    def step(self, u, state):
        """Must have called self._setup_step() and created state with self.default_state() before calling this"""

        if self._step_mode == "linear":
            new_state = self._step_state_linear(u, state)
        else:
            new_state = self._step_state(u, state)
        y = self.output_contraction(self.dC, new_state)
        return y.real, new_state
class SSKernelDiag(OptimModule):
    """Version using (complex) diagonal state matrix (S4D)"""

    def __init__(
        self,
        A,
        B,
        C,
        log_dt,
        L=None,
        disc="bilinear",
        real_type="exp",
        lr=None,
        bandlimit=None,
    ):
        """
        A: (n_ssm, N) complex diagonal state matrix
        B: (n_ssm, N)
        C: (C, H, N)
        log_dt: (H) log of the step size per feature
        L: stored on the module as ``self.L``
        disc: discretization used by forward ('bilinear' | 'zoh' | 'dss')
        real_type: parameterization of the real part of A
        lr: [dict | float | None] hook to set lr of special parameters (A, B, dt)
        bandlimit: if not None, mask out state frequencies above this fraction
        """
        super().__init__()
        self.L = L
        self.disc = disc
        self.bandlimit = bandlimit
        self.real_type = real_type

        # Rank of low-rank correction
        assert A.size(-1) == C.size(-1)
        self.H = log_dt.size(-1)
        self.N = A.size(-1)
        assert A.size(-2) == B.size(-2)  # Number of independent SSMs trained
        assert self.H % A.size(-2) == 0
        self.n_ssm = A.size(-2)
        self.repeat = self.H // A.size(0)

        self.channels = C.shape[0]
        self.C = nn.Parameter(_c2r(_resolve_conj(C)))

        # Register parameters
        if lr is None or isinstance(lr, float):
            lr_dict = {}
        else:
            lr_dict, lr = lr, None

        self.register("log_dt", log_dt, lr_dict.get("dt", lr))
        self.register("B", _c2r(B), lr_dict.get("B", lr))
        self.register("inv_A_real", self._A_init(A.real), lr_dict.get("A", lr))
        self.register("A_imag", A.imag, lr_dict.get("A", lr))

    def _A_init(self, A_real):
        """Map the real part of A to the stored parameter that ``_A`` inverts."""
        A_real = torch.clamp(A_real, max=-1e-4)
        if self.real_type == "none":
            return -A_real
        elif self.real_type == "exp":
            return torch.log(
                -A_real
            )  # Some of the HiPPO methods have real part 0
        elif self.real_type == "relu":
            return -A_real
        elif self.real_type == "sigmoid":
            return torch.logit(-A_real)
        elif self.real_type == "softplus":
            return torch.log(torch.exp(-A_real) - 1)
        else:
            raise NotImplementedError

    def _A(self):
        # Get the internal A (diagonal) parameter
        if self.real_type == "none":
            A_real = -self.inv_A_real
        elif self.real_type == "exp":
            A_real = -torch.exp(self.inv_A_real)
        elif self.real_type == "relu":
            # JAX version seems to NaN if you allow 0's, although this code was fine without it
            A_real = -F.relu(self.inv_A_real) - 1e-4
        elif self.real_type == "sigmoid":
            A_real = -F.sigmoid(self.inv_A_real)
        elif self.real_type == "softplus":
            A_real = -F.softplus(self.inv_A_real)
        else:
            raise NotImplementedError
        A = A_real + 1j * self.A_imag
        return A

    def forward(self, L, state=None, rate=1.0, u=None):
        """
        state: (B, H, N) initial state
        rate: sampling rate factor
        L: target length

        returns:
        (C, H, L) convolution kernel (generally C=1)
        (B, H, L) output from initial state
        """

        dt = torch.exp(self.log_dt) * rate  # (H)
        C = _r2c(self.C)  # (C H N)
        A = self._A()  # (n_ssm N)

        B = _r2c(self.B)
        B = repeat(B, "t n -> 1 (v t) n", v=self.repeat)

        # Broadcast A to all H features first, so that the band-limit mask
        # below lines up with dt of shape (H,) even when n_ssm < H.
        A = repeat(A, "t n -> (v t) n", v=self.repeat)  # (H N)

        if self.bandlimit is not None:
            freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi)  # (H, N)
            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
            C = C * mask

        # Incorporate dt into A
        dtA = A * dt.unsqueeze(-1)  # (H N)

        # Augment B with state
        if state is not None:
            s = state / dt.unsqueeze(-1)
            if self.disc == "bilinear":
                s = s * (1.0 + dtA / 2)
            elif self.disc == "zoh":
                s = s * dtA * dtA.exp() / (dtA.exp() - 1.0)
            B = torch.cat([s, B], dim=-3)  # (1+B H N)

        C = (B[:, None, :, :] * C).view(-1, self.H, self.N)
        if self.disc == "zoh":
            # Power up
            C = C * (torch.exp(dtA) - 1.0) / A
            K = log_vandermonde(C, dtA, L)  # (H L)
        elif self.disc == "bilinear":
            C = (
                C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
            )  # or * dtA / A
            dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
            K = log_vandermonde(C, dA.log(), L)
        elif self.disc == "dss":
            # Implementation from DSS meant for case when real eigenvalues can be positive
            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device)  # [H N L]
            A_gt_0 = A.real > 0  # [N]
            if A_gt_0.any():
                with torch.no_grad():
                    P_max = dtA * (A_gt_0 * (L - 1))  # [H N]
                P = P - P_max.unsqueeze(-1)  # [H N L]
            S = P.exp()  # [H N L]

            dtA_neg = dtA * (1 - 2 * A_gt_0)  # [H N]
            num = dtA_neg.exp() - 1  # [H N]
            den = (dtA_neg * L).exp() - 1  # [H N]

            # Inline reciprocal function for DSS logic
            x = den * A
            x_conj = _resolve_conj(x)
            r = x_conj / (x * x_conj + 1e-7)

            C = C * num * r  # [C H N]
            K = contract("chn,hnl->chl", C, S).float()
        else:
            assert False, f"{self.disc} not supported"

        K = K.view(-1, self.channels, self.H, L)  # (1+B C H L)
        if state is not None:
            K_state = K[:-1, :, :, :]  # (B C H L)
        else:
            K_state = None
        K = K[-1, :, :, :]  # (C H L)
        return K, K_state

    def _setup_step(self):
        """Compute dA, dB, dC used by ``step`` and ``forward_state``.

        Only the 'zoh' and 'bilinear' discretizations assign dA / dB here;
        for any other ``disc`` value only dC is set.
        """
        # These methods are organized like this to be compatible with the NPLR kernel interface
        dt = torch.exp(self.log_dt)  # (H)
        B = _r2c(self.B)  # (H N)
        C = _r2c(self.C)  # (C H N)
        self.dC = C
        A = self._A()  # (H N)

        A = repeat(A, "t n -> (v t) n", v=self.repeat)
        B = repeat(B, "t n -> (v t) n", v=self.repeat)

        # Incorporate dt into A
        dtA = A * dt.unsqueeze(-1)  # (H N)
        if self.disc == "zoh":
            self.dA = torch.exp(dtA)  # (H N)
            self.dB = B * (torch.exp(dtA) - 1.0) / A  # (C H N)
        elif self.disc == "bilinear":
            self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
            self.dB = (
                B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
            )  # or * dtA / A

    def default_state(self, *batch_shape):
        """Return an all-zero state of shape (*batch_shape, H, N)."""
        C = _r2c(self.C)
        state = torch.zeros(
            *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device
        )
        return state

    def step(self, u, state):
        """Advance one step; returns (2 * real part of the output, next state)."""
        next_state = contract(
            "h n, b h n -> b h n", self.dA, state
        ) + contract("h n, b h -> b h n", self.dB, u)
        y = contract("c h n, b h n -> b c h", self.dC, next_state)
        return 2 * y.real, next_state

    def forward_state(self, u, state):
        """Advance ``state`` through the whole input ``u`` (last axis = time)."""
        self._setup_step()
        AL = self.dA ** u.size(-1)
        u = u.flip(-1).to(self.dA).contiguous()  # (B H L)
        v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))
        next_state = AL * state + v
        return next_state
class SSKernel(nn.Module):
    """Wrapper around SSKernel parameterizations.

    The SSKernel is expected to support the interface
    forward()
    default_state()
    _setup_step()
    step()
    """

    def __init__(
        self,
        H,
        N=64,
        L=None,
        measure="legs",
        rank=1,
        channels=1,
        dt_min=0.001,
        dt_max=0.1,
        deterministic=False,
        lr=None,
        mode="nplr",
        n_ssm=None,
        verbose=False,
        measure_args=None,
        **kernel_args,
    ):
        """State Space Kernel which computes the convolution kernel $\\bar{K}$

        H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config.
        N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doesn't affect speed much.
        L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known.
        measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin)
        rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt"
        channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead
        dt_min, dt_max: min and max values for the step size dt (\\Delta)
        deterministic: If True, use a fixed (non-random) initialization of dt and C.
        mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing
        n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H
        lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameters (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters.
        verbose: Stored on the module and forwarded to the NPLR kernel.
        measure_args: Optional dict of extra keyword arguments for `combination`; None means no extra arguments.
        kernel_args: Extra keyword arguments forwarded to the selected kernel class.

        Raises NotImplementedError if `mode` is neither "nplr" nor "diag".
        """
        super().__init__()
        # Avoid a shared mutable default: build a fresh dict per instance.
        if measure_args is None:
            measure_args = {}

        self.N = N
        self.H = H
        dtype, cdtype = torch.float, torch.cfloat
        self.channels = channels
        self.n_ssm = n_ssm if n_ssm is not None else H
        self.mode = mode
        self.verbose = verbose
        self.kernel_args = kernel_args

        # Generate dt (stored in log space; the kernels exponentiate it).
        if deterministic:
            # Log-spaced grid between dt_min and dt_max. The values are kept
            # in log space, matching the random branch below; wrapping this
            # in exp() would make the kernels exponentiate dt twice.
            log_dt = torch.linspace(math.log(dt_min), math.log(dt_max), H)
        else:
            log_dt = torch.rand(self.H, dtype=dtype) * (
                math.log(dt_max) - math.log(dt_min)
            ) + math.log(dt_min)

        # Compute the preprocessed representation
        w, P, B, V = combination(
            measure, self.N, rank, self.n_ssm, **measure_args
        )

        # Broadcast C to have H channels
        if deterministic:
            C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype)
            C[:, :, :1] = 1.0
            C = contract(
                "hmn, chn -> chm", V.conj().transpose(-1, -2), C
            )  # V^* C
            # NOTE(review): the repeat factor below uses n_ssm rather than H,
            # so C keeps n_ssm rows here -- confirm this is intended when
            # n_ssm != H.
            C = (
                repeat(C, "c t n -> c (v t) n", v=self.n_ssm // C.size(-2))
                .clone()
                .contiguous()
            )
        else:
            C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype)

        # Broadcast other parameters to have n_ssm copies
        assert (
            self.n_ssm % B.size(-2) == 0
            and self.n_ssm % P.size(-2) == 0
            and self.n_ssm % w.size(-2) == 0
        )
        # Broadcast tensors to n_ssm copies
        # These will be the parameters, so make sure tensors are materialized and contiguous
        B = (
            repeat(B, "t n -> (v t) n", v=self.n_ssm // B.size(-2))
            .clone()
            .contiguous()
        )
        P = (
            repeat(P, "r t n -> r (v t) n", v=self.n_ssm // P.size(-2))
            .clone()
            .contiguous()
        )
        w = (
            repeat(w, "t n -> (v t) n", v=self.n_ssm // w.size(-2))
            .clone()
            .contiguous()
        )

        if mode == "nplr":
            self.kernel = SSKernelNPLR(
                w,
                P,
                B,
                C,
                log_dt,
                L=L,
                lr=lr,
                verbose=verbose,
                **kernel_args,
            )
        elif mode == "diag":
            if not measure.startswith("diag"):
                log.warning(
                    "Diagonal kernel (S4D) activated but initialization is not intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or 'diag-legs' for the main variants, or 'diag' for a combination of S4D-Lin and S4D-Inv."
                )
            # The diagonal kernel receives C already multiplied by B.
            C = C * repeat(B, "t n -> (v t) n", v=H // self.n_ssm)
            self.kernel = SSKernelDiag(
                w,
                B,
                C,
                log_dt,
                L=L,
                lr=lr,
                **kernel_args,
            )
        else:
            raise NotImplementedError(f"{mode=} is not valid")

    def forward(self, state=None, L=None, rate=1.0):
        """Delegate to the wrapped kernel and return whatever it returns."""
        return self.kernel(state=state, L=L, rate=rate)

    @torch.no_grad()
    def forward_state(self, u, state):
        """Forward the state through a sequence, i.e. computes the state after passing chunk through SSM

        state: (B, H, N)
        u: (B, H, L)

        Returns: (B, H, N)
        """
        # Prefer the kernel's own implementation when it provides one.
        if hasattr(self.kernel, "forward_state"):
            return self.kernel.forward_state(u, state)

        dA, dB = self.kernel._setup_state()  # Construct dA, dB matrices
        # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N)

        # If the state's last dim differs from dA's, expand it with `_conj`
        # and slice the result back to the first half at the end.
        conj = state.size(-1) != dA.size(-1)
        if conj:
            state = _conj(state)

        v = contract(
            "h n, b h l -> b h n l", dB, u.flip(-1)
        )  # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2)
        AL, v = power(u.size(-1), dA, v)
        next_state = contract("h m n, b h n -> b h m", AL, state)
        next_state = next_state + v

        if conj:
            next_state = next_state[..., : next_state.size(-1) // 2]
        return next_state

    def _setup_step(self, **kwargs):
        # This method is intended to be private so that setting up an S4 module with
        # ```
        # if hasattr(module, 'setup_step'): module.setup_step()
        # ```
        # will not trigger this method multiple times
        self.kernel._setup_step(**kwargs)

    def step(self, u, state, **kwargs):
        """Run one recurrent step on the wrapped kernel; returns (y, state)."""
        y, state = self.kernel.step(u, state, **kwargs)
        return y, state

    def default_state(self, *args, **kwargs):
        """Return the wrapped kernel's default state."""
        return self.kernel.default_state(*args, **kwargs)
class S4(nn.Module):
    """S4 sequence layer: an SSM convolution kernel followed by position-wise
    activation, dropout and an output projection, with optional gating and
    bottleneck projections."""

    def __init__(
        self,
        d_model,
        d_state=64,
        l_max=None,
        channels=1,
        mode="nplr",
        measure="legs",
        bidirectional=False,
        # Arguments for position-wise feedforward components
        activation="gelu",
        postact="glu",
        hyper_act=None,
        dropout=0.0,
        tie_dropout=False,
        bottleneck=None,
        gate=None,
        transposed=True,
        verbose=False,
        # SSM Kernel arguments
        **kernel_args,
    ):
        """
        d_state: the dimension of the state, also denoted by N
        l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel
        channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models
        bidirectional: if True, convolution kernel will be two-sided

        Position-wise feedforward components:
        --------------------
        activation: activation in between SS and FF
        postact: activation after FF
        hyper_act: use a "hypernetwork" multiplication (experimental)
        dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d

        Other arguments:
        --------------------
        transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension]
        gate: add gated activation (GSS)
        bottleneck: reduce SSM dimension (GSS)

        See the class SSKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr"

        Other options are all experimental and should not need to be configured
        """
        super().__init__()
        if verbose:
            log.info(
                f"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})"
            )

        self.d_model = d_model
        self.H = d_model
        self.N = d_state
        self.L = l_max
        self.bidirectional = bidirectional
        self.channels = channels
        self.transposed = transposed

        self.gate = gate
        self.bottleneck = bottleneck

        if bottleneck is not None:
            # The SSM runs at a reduced width; inputs are projected down to it.
            self.H = self.H // bottleneck
            self.input_linear = LinearActivation(
                self.d_model,
                self.H,
                transposed=self.transposed,
                activation=activation,
                activate=True,
            )

        if gate is not None:
            self.input_gate = LinearActivation(
                self.d_model,
                self.d_model * gate,
                transposed=self.transposed,
                activation=activation,
                activate=True,
            )
            self.output_gate = LinearActivation(
                self.d_model * gate,
                self.d_model,
                transposed=self.transposed,
                activation=None,
                activate=False,
            )

        # optional multiplicative modulation GLU-style
        # https://arxiv.org/abs/2002.05202
        self.hyper = hyper_act is not None
        if self.hyper:
            channels *= 2
            self.hyper_activation = Activation(hyper_act)

        # D is created before the bidirectional doubling, so it has the
        # (possibly hyper-doubled) channel count, not the kernel's.
        self.D = nn.Parameter(torch.randn(channels, self.H))

        if self.bidirectional:
            channels *= 2

        # SSM Kernel
        self.kernel = SSKernel(
            self.H,
            N=self.N,
            L=self.L,
            channels=channels,
            verbose=verbose,
            mode=mode,
            measure=measure,
            **kernel_args,
        )

        # Pointwise
        self.activation = Activation(activation)
        dropout_fn = DropoutNd if tie_dropout else nn.Dropout
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()
        # position-wise output transform to mix features
        self.output_linear = LinearActivation(
            self.H * self.channels,
            self.d_model * (1 if self.gate is None else self.gate),
            transposed=self.transposed,
            activation=postact,
            activate=True,
        )

    def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs):
        """
        u: (B H L) if self.transposed else (B L H)
        state: (H N) never needed unless you know what you're doing
        rate: forwarded to the kernel; also scales the maximum kernel length
        lengths: optional int or 1-D tensor (size 1 or B) of valid lengths;
            positions at or beyond the length are zeroed before the SSM

        Returns: (y, next_state); y has the same shape as u, next_state is
        None when no state was passed in
        """
        if not self.transposed:
            u = u.transpose(-1, -2)
        L = u.size(-1)

        # Mask out padding tokens
        if isinstance(lengths, int):
            if lengths != L:
                # Wrap the scalar in a list: the masking code below asserts a
                # 1-D tensor and indexes it with [:, None, None]; a 0-dim
                # tensor would fail both.
                lengths = torch.tensor(
                    [lengths], dtype=torch.long, device=u.device
                )
            else:
                lengths = None
        if lengths is not None:
            assert (
                isinstance(lengths, torch.Tensor)
                and lengths.ndim == 1
                and lengths.size(0) in [1, u.size(0)]
            )
            mask = torch.where(
                torch.arange(L, device=lengths.device)
                < lengths[:, None, None],
                1.0,
                0.0,
            )
            u = u * mask

        if self.gate is not None:
            v = self.input_gate(u)
        if self.bottleneck is not None:
            u = self.input_linear(u)

        # Compute SS Kernel
        L_kernel = L if self.L is None else min(L, round(self.L / rate))
        k, k_state = self.kernel(
            L=L_kernel, rate=rate, state=state
        )  # (C H L) (B C H L)

        # Convolution
        if self.bidirectional:
            # First half of the channels is the forward kernel, second half
            # the backward one (flipped and placed after the forward part).
            k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2)
            k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0))
        # FFT convolution, zero-padded to L_kernel + L.
        k_f = torch.fft.rfft(k, n=L_kernel + L)  # (C H L)
        u_f = torch.fft.rfft(u, n=L_kernel + L)  # (B H L)
        y_f = contract("bhl,chl->bchl", u_f, k_f)
        y = torch.fft.irfft(y_f, n=L_kernel + L)[..., :L]  # (B C H L)

        # Compute D term in state space equation - essentially a skip connection
        y = y + contract("bhl,ch->bchl", u, self.D)

        # Compute state update
        if state is not None:
            assert (
                not self.bidirectional
            ), "Bidirectional not supported with state forwarding"
            y = y + k_state
            next_state = self.kernel.forward_state(u, state)
        else:
            next_state = None

        # Optional hyper-network multiplication
        if self.hyper:
            y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2)
            y = self.hyper_activation(yh) * y

        # Reshape to flatten channels
        y = rearrange(y, "... c h l -> ... (c h) l")

        y = self.dropout(self.activation(y))

        if not self.transposed:
            y = y.transpose(-1, -2)

        y = self.output_linear(y)

        if self.gate is not None:
            y = self.output_gate(y * v)

        return y, next_state

    def setup_step(self, **kwargs):
        """Prepare the kernel for recurrent stepping."""
        self.kernel._setup_step(**kwargs)

    def step(self, u, state):
        """Step one time step as a recurrent model. Intended to be used during validation.

        u: (B H)
        state: (B H N)
        Returns: output (B H), state (B H N)
        """
        assert not self.training

        y, next_state = self.kernel.step(u, state)  # (B C H)
        y = y + u.unsqueeze(-2) * self.D
        y = rearrange(y, "b c h -> b (c h)")
        y = self.activation(y)
        if self.transposed:
            # Add and remove a trailing length-1 axis around the projection.
            y = self.output_linear(y.unsqueeze(-1)).squeeze(-1)
        else:
            y = self.output_linear(y)
        return y, next_state

    def default_state(self, *batch_shape, device=None):
        # kernel is not a SequenceModule so it doesn't need to adhere to same interface
        # the kernel will know the device of its own parameters
        return self.kernel.default_state(*batch_shape)

    @property
    def d_output(self):
        """Output feature dimension (equal to d_model)."""
        return self.d_model
================================================
FILE: probts/model/nn/arch/S4/s4_backbones.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import math
import torch
from torch import nn
from probts.model.nn.arch.S4.s4 import S4
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of (diffusion) time steps.

    Args:
        dim: Requested embedding size. The output has ``2 * (dim // 2)``
            columns: ``dim // 2`` sines followed by ``dim // 2`` cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time`` of shape (B,) into a (B, 2 * (dim // 2)) tensor."""
        device = time.device
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000. Guard the
        # denominator so dim in {2, 3} (half_dim == 1) yields the single
        # frequency 1 instead of raising ZeroDivisionError.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(
            torch.arange(half_dim, device=device) * -log_step
        )
        angles = time[:, None] * frequencies[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class S4Layer(nn.Module):
    """Pre-norm residual wrapper around a bidirectional S4 layer.

    The wrapped S4 is built with d_state=128, bidirectional=True,
    transposed=True and no post-activation.
    """

    def __init__(
        self,
        d_model,
        dropout=0.0,
        mode="nplr",
        l_max=None,
        measure="legs"
    ):
        super().__init__()
        s4_options = dict(
            d_model=d_model,
            d_state=128,
            bidirectional=True,
            dropout=dropout,
            transposed=True,
            postact=None,
            mode=mode,
            l_max=l_max,
            measure=measure,
        )
        self.layer = S4(**s4_options)
        self.norm = nn.LayerNorm(d_model)
        if dropout > 0.0:
            self.dropout = nn.Dropout1d(dropout)
        else:
            self.dropout = nn.Identity()

    def _prenorm(self, x):
        # LayerNorm acts on the last axis, so move features there and back.
        return self.norm(x.transpose(-1, -2)).transpose(-1, -2)

    def forward(self, x):
        """
        Input x is shape (B, d_input, L)

        Returns (x + dropout(S4(norm(x))), None); the S4 state is discarded.
        """
        # Apply layer: we ignore the state input and output for training
        branch, _ = self.layer(self._prenorm(x))
        # Dropout on the layer output, then the residual connection
        return self.dropout(branch) + x, None

    def default_state(self, *args, **kwargs):
        """Forward to the wrapped S4 layer's default_state."""
        return self.layer.default_state(*args, **kwargs)

    def step(self, x, state, **kwargs):
        """Recurrent counterpart of forward (no dropout is applied here).

        NOTE(review): the same last-two-axes transpose as in forward is used
        for the pre-norm, which assumes x has a trailing length axis --
        confirm against the (B H) input documented on S4.step.
        """
        branch, state = self.layer.step(self._prenorm(x), state, **kwargs)
        # Residual connection
        return branch + x, state
class S4Block(nn.Module):
    """Residual block: S4 layer conditioned on a time embedding, followed by
    a tanh * sigmoid gate and two 1x1 convolutions (residual and skip)."""

    def __init__(self, d_model, dropout=0.0, expand=2, num_features=0,mode="nplr",l_max=None,measure="legs"):
        super().__init__()
        self.s4block = S4Layer(
            d_model,
            dropout=dropout,
            mode=mode,
            l_max=l_max,
            measure=measure,
        )
        self.time_linear = nn.Linear(d_model, d_model)
        self.tanh = nn.Tanh()
        self.sigm = nn.Sigmoid()
        pointwise = dict(
            in_channels=d_model, out_channels=d_model, kernel_size=1
        )
        self.out_linear1 = nn.Conv1d(**pointwise)
        self.out_linear2 = nn.Conv1d(**pointwise)
        self.feature_encoder = nn.Conv1d(num_features, d_model, kernel_size=1)

    def forward(self, x, t, features=None):
        """Return (residual output, skip output).

        x is indexed with the sequence length on axis 2; t is projected by
        `time_linear` and tiled along that axis before being added to x.
        `features`, when given, are encoded by a 1x1 conv and added after S4.
        """
        seq_len = x.shape[2]
        time_bias = self.time_linear(t)[:, None, :].repeat(1, seq_len, 1)
        time_bias = time_bias.transpose(-1, -2)

        hidden, _ = self.s4block(x + time_bias)
        if features is not None:
            hidden = hidden + self.feature_encoder(features)

        gated = self.tanh(hidden) * self.sigm(hidden)
        residual_out = self.out_linear1(gated) + x
        skip_out = self.out_linear2(gated)
        return residual_out, skip_out
def Conv1dKaiming(in_channels, out_channels, kernel_size):
    """Build an nn.Conv1d whose weight is re-initialized with Kaiming-normal.

    The bias keeps nn.Conv1d's default initialization.
    """
    conv = nn.Conv1d(in_channels, out_channels, kernel_size)
    nn.init.kaiming_normal_(conv.weight)
    return conv
class BackboneModel(nn.Module):
    """Stack of residual S4 blocks with a step embedding and summed skips.

    Only residual_block="s4" is supported; any other value raises ValueError.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        step_emb,
        num_residual_blocks,
        num_features,
        residual_block="s4",
        mode="nplr",
        measure="legs",
        l_max=None,
        dropout=0.0,
        init_skip=True,
    ):
        super().__init__()
        if residual_block != "s4":
            raise ValueError(f"Unknown residual block {residual_block}")
        block_cls = S4Block

        # Input projection into the hidden width.
        self.input_init = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )
        # MLP applied to the sinusoidal step embedding.
        self.time_init = nn.Sequential(
            nn.Linear(step_emb, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )
        # Output head applied to the summed skip connections.
        self.out_linear = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.residual_blocks = nn.ModuleList(
            [
                block_cls(
                    hidden_dim,
                    num_features=num_features,
                    dropout=dropout,
                    mode=mode,
                    l_max=l_max,
                    measure=measure,
                )
                for _ in range(num_residual_blocks)
            ]
        )
        self.step_embedding = SinusoidalPositionEmbeddings(step_emb)
        self.init_skip = init_skip

    def forward(self, input, t, features=None):
        """Run the backbone.

        input: tensor with features on the last axis (original note: B, L, C)
        t: step indices fed to the sinusoidal embedding
        features: optional tensor; transposed on its last two axes before
            being handed to every block

        Returns the head output, plus `input` when init_skip is set.
        """
        hidden = self.input_init(input)  # B, L ,C
        t_emb = self.time_init(self.step_embedding(t))
        # The blocks take the sequence length on the last axis.
        hidden = hidden.transpose(-1, -2)
        if features is not None:
            features = features.transpose(-1, -2)

        collected = []
        for block in self.residual_blocks:
            hidden, skip_out = block(hidden, t_emb, features)
            collected.append(skip_out)

        total_skip = torch.stack(collected).sum(0).transpose(-1, -2)
        result = self.out_linear(total_skip)
        return result + input if self.init_skip else result
================================================
FILE: probts/model/nn/arch/TSMixer_layers.py
================================================
from __future__ import annotations
from collections.abc import Callable
import torch
import torch.nn.functional as F
from torch import Tensor, nn
import sys
class TimeBatchNorm2d(nn.BatchNorm1d):
    """Batch normalization over the flattened (time, channel) dimensions.

    Subclasses nn.BatchNorm1d with ``num_time_steps * num_channels``
    features, so each (time step, channel) position has its own statistics
    and affine parameters.

    Attributes:
        num_time_steps (int): Number of time steps in the input.
        num_channels (int): Number of channels in the input.
    """

    def __init__(self, normalized_shape: tuple[int, int]):
        """Create the layer.

        Args:
            normalized_shape (tuple[int, int]): ``(num_time_steps, num_channels)``.
        """
        steps, channels = normalized_shape
        super().__init__(channels * steps)
        self.num_time_steps = steps
        self.num_channels = channels

    def forward(self, x: Tensor) -> Tensor:
        """Normalize a batch of sequences.

        Args:
            x (Tensor): Tensor of shape (N, S, C).

        Returns:
            Tensor: Normalized tensor of shape (N, num_time_steps, num_channels).

        Raises:
            ValueError: If the input tensor is not 3D.
        """
        if x.ndim != 3:
            raise ValueError(f"Expected 3D input tensor, but got {x.ndim}D tensor instead.")

        batch = x.shape[0]
        # Flatten time and channels into BatchNorm1d's feature axis, with a
        # dummy length-1 trailing axis.
        flattened = x.reshape(batch, -1, 1)
        normalized = super().forward(flattened)
        # Restore the (N, S, C) layout.
        return normalized.reshape(batch, self.num_time_steps, self.num_channels)
class FeatureMixing(nn.Module):
    """Two-layer feed-forward mixing over the channel axis with a shortcut.

    Normalization is applied either before the feed-forward branch or after
    the residual sum, never both.

    Args:
        sequence_length: The length of the sequences to be transformed.
        input_channels: The number of input channels to the module.
        output_channels: The number of output channels from the module.
        ff_dim: The dimension of the feed-forward network internal to the module.
        activation_fn: The activation function used within the feed-forward network.
        dropout_rate: The dropout probability used for regularization.
        normalize_before: If True, normalize the input of the feed-forward
            branch; otherwise normalize the output of the residual sum.
        norm_type: Normalization class, called with a
            ``(sequence_length, channels)`` tuple.
    """

    def __init__(
        self,
        sequence_length: int,
        input_channels: int,
        output_channels: int,
        ff_dim: int,
        activation_fn: Callable[[torch.Tensor], torch.Tensor] = F.relu,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
        norm_type: type[nn.Module] = TimeBatchNorm2d,
    ):
        """Build the normalization, feed-forward and shortcut sub-modules."""
        super().__init__()

        if normalize_before:
            self.norm_before = norm_type((sequence_length, input_channels))
        else:
            self.norm_before = nn.Identity()

        if normalize_before:
            self.norm_after = nn.Identity()
        else:
            self.norm_after = norm_type((sequence_length, output_channels))

        self.activation_fn = activation_fn
        self.dropout = nn.Dropout(dropout_rate)
        self.fc1 = nn.Linear(input_channels, ff_dim)
        self.fc2 = nn.Linear(ff_dim, output_channels)

        # A learned shortcut is only needed when the channel count changes.
        if input_channels == output_channels:
            self.projection = nn.Identity()
        else:
            self.projection = nn.Linear(input_channels, output_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply feature mixing.

        Args:
            x: Tensor whose last dimension is ``input_channels`` (the linear
                layers act on the last axis).

        Returns:
            The mixed tensor, with ``output_channels`` on the last axis.
        """
        shortcut = self.projection(x)

        hidden = self.norm_before(x)
        hidden = self.fc1(hidden)
        hidden = self.activation_fn(hidden)
        hidden = self.dropout(hidden)
        hidden = self.fc2(hidden)
        hidden = self.dropout(hidden)

        return self.norm_after(shortcut + hidden)
class ConditionalFeatureMixing(nn.Module):
    """Feature mixing conditioned on static features.

    Static features are projected to ``output_channels``, tiled along the
    time axis, concatenated to the dynamic features on the last axis and
    passed through a FeatureMixing block.

    Args:
        sequence_length: The length of the sequences to be transformed.
        input_channels: The number of input channels of the dynamic features.
        output_channels: The number of output channels after feature mixing.
        static_channels: The number of channels in the static feature input.
        ff_dim: The inner dimension of the feedforward network used in feature mixing.
        activation_fn: The activation function used in feature mixing.
        dropout_rate: The dropout probability used in the feature mixing operation.
        normalize_before: Forwarded to FeatureMixing.
        norm_type: Forwarded to FeatureMixing.
    """

    def __init__(
        self,
        sequence_length: int,
        input_channels: int,
        output_channels: int,
        static_channels: int,
        ff_dim: int,
        activation_fn: Callable = F.relu,
        dropout_rate: float = 0.1,
        normalize_before: bool = False,
        norm_type: type[nn.Module] = nn.LayerNorm,
    ):
        super().__init__()

        self.fr_static = nn.Linear(static_channels, output_channels)
        # The mixer sees the dynamic channels plus the projected static ones.
        mixed_channels = input_channels + output_channels
        self.fm = FeatureMixing(
            sequence_length,
            mixed_channels,
            output_channels,
            ff_dim,
            activation_fn,
            dropout_rate,
            normalize_before=normalize_before,
            norm_type=norm_type,
        )

    def forward(
        self, x: torch.Tensor, x_static: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Mix dynamic features together with projected static features.

        Args:
            x: Dynamic features, typically
                [batch_size, time_steps, input_channels].
            x_static: Static features, typically
                [batch_size, static_channels].

        Returns:
            A tuple ``(mixed, static)`` where ``mixed`` is the FeatureMixing
            output and ``static`` is the tiled static projection, detached
            from the autograd graph.
        """
        num_steps = x.shape[1]
        static_proj = self.fr_static(x_static)
        # Tile the projected static features along the time axis.
        static_tiled = static_proj.unsqueeze(1).repeat(1, num_steps, 1)

        mixed = self.fm(torch.cat([x, static_tiled], dim=-1))
        return mixed, static_tiled.detach()
class TimeMixing(nn.Module):
    """Residual linear mixing along the time axis.

    The input is permuted so time is last, passed through a square linear
    layer, activation and dropout, permuted back, added to the input and
    normalized.

    Args:
        sequence_length: The length of the sequences to be transformed.
        input_channels: The number of input channels to the module.
        activation_fn: The activation function to be used after the linear transformation.
        dropout_rate: The dropout probability to be used after the activation function.
        norm_type: Normalization class, called with a
            ``(sequence_length, input_channels)`` tuple.
    """

    def __init__(
        self,
        sequence_length: int,
        input_channels: int,
        activation_fn: Callable = F.relu,
        dropout_rate: float = 0.1,
        norm_type: type[nn.Module] = TimeBatchNorm2d,
    ):
        """Build the normalization, dropout and time-axis linear layer."""
        super().__init__()
        self.norm = norm_type((sequence_length, input_channels))
        self.activation_fn = activation_fn
        self.dropout = nn.Dropout(dropout_rate)
        self.fc1 = nn.Linear(sequence_length, sequence_length)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply time mixing.

        Args:
            x: A 3D tensor whose axis 1 has length ``sequence_length`` (that
                axis is moved last before the linear layer is applied).

        Returns:
            The normalized sum of the input and its time-mixed branch.
        """
        # Move time to the last axis so the linear layer mixes across steps.
        time_last = feature_to_time(x)
        mixed = self.fc1(time_last)
        mixed = self.activation_fn(mixed)
        mixed = self.dropout(mixed)
        # Back to the input layout for the residual sum.
        branch = time_to_feature(mixed)
        return self.norm(x + branch)
class MixerLayer(nn.Module):
    """Mixer block: time mixing followed by feature mixing.

    Args:
        sequence_length: The length of the input sequences.
        input_channels: The number of input channels to the module.
        output_channels: The number of output channels from the module.
        ff_dim: The inner dimension of the feedforward network used in feature mixing.
        activation_fn: The activation function used in both time and feature mixing.
        dropout_rate: The dropout probability used in both mixing operations.
        normalize_before: Forwarded to the feature-mixing block.
        norm_type: Normalization class used by both sub-blocks.
    """

    def __init__(
        self,
        sequence_length: int,
        input_channels: int,
        output_channels: int,
        ff_dim: int,
        activation_fn: Callable = F.relu,
        dropout_rate: float = 0.1,
        normalize_before: bool = False,
        norm_type: type[nn.Module] = nn.LayerNorm,
    ):
        """Create the time-mixing and feature-mixing sub-blocks."""
        super().__init__()

        self.time_mixing = TimeMixing(
            sequence_length=sequence_length,
            input_channels=input_channels,
            activation_fn=activation_fn,
            dropout_rate=dropout_rate,
            norm_type=norm_type,
        )
        self.feature_mixing = FeatureMixing(
            sequence_length=sequence_length,
            input_channels=input_channels,
            output_channels=output_channels,
            ff_dim=ff_dim,
            activation_fn=activation_fn,
            dropout_rate=dropout_rate,
            normalize_before=normalize_before,
            norm_type=norm_type,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run time mixing, then feature mixing, on a 3D tensor ``x``.

        Returns:
            The output of the feature-mixing block.
        """
        return self.feature_mixing(self.time_mixing(x))
class ConditionalMixerLayer(nn.Module):
    """Mixer block whose feature mixing is conditioned on static features.

    Time mixing is applied first; the result is then mixed together with the
    static features by a ConditionalFeatureMixing block.

    Args:
        sequence_length: The length of the input sequences.
        input_channels: The number of input channels of the dynamic features.
        output_channels: The number of output channels after feature mixing.
        static_channels: The number of channels in the static feature input.
        ff_dim: The inner dimension of the feedforward network used in feature mixing.
        activation_fn: The activation function used in both mixing operations.
        dropout_rate: The dropout probability used in both mixing operations.
        normalize_before: Forwarded to the conditional feature-mixing block.
        norm_type: Normalization class used by both sub-blocks.
    """

    def __init__(
        self,
        sequence_length: int,
        input_channels: int,
        output_channels: int,
        static_channels: int,
        ff_dim: int,
        activation_fn: Callable = F.relu,
        dropout_rate: float = 0.1,
        normalize_before: bool = False,
        norm_type: type[nn.Module] = nn.LayerNorm,
    ):
        super().__init__()

        self.time_mixing = TimeMixing(
            sequence_length=sequence_length,
            input_channels=input_channels,
            activation_fn=activation_fn,
            dropout_rate=dropout_rate,
            norm_type=norm_type,
        )
        self.feature_mixing = ConditionalFeatureMixing(
            sequence_length=sequence_length,
            input_channels=input_channels,
            output_channels=output_channels,
            static_channels=static_channels,
            ff_dim=ff_dim,
            activation_fn=activation_fn,
            dropout_rate=dropout_rate,
            normalize_before=normalize_before,
            norm_type=norm_type,
        )

    def forward(self, x: torch.Tensor, x_static: torch.Tensor) -> torch.Tensor:
        """Apply time mixing and then conditional feature mixing.

        Args:
            x: Dynamic features, typically
                [batch_size, time_steps, input_channels].
            x_static: Static features, typically
                [batch_size, static_channels].

        Returns:
            The mixed dynamic features; the static projection returned by
            the feature-mixing block is discarded.
        """
        time_mixed = self.time_mixing(x)
        mixed, _static = self.feature_mixing(time_mixed, x_static)
        return mixed
def time_to_feature(x: torch.Tensor) -> torch.Tensor:
    """Swap the last two axes of a 3D tensor: (N, A, B) -> (N, B, A)."""
    return torch.permute(x, (0, 2, 1))


# The swap is its own inverse, so the same function serves both directions.
feature_to_time = time_to_feature
================================================
FILE: probts/model/nn/arch/TimesFMModule/__init__.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM init file."""
# print(
# "TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs."
# )
from probts.model.nn.arch.TimesFMModule.timesfm_base import freq_map, TimesFmCheckpoint, TimesFmHparams, TimesFmBase
# print("Loaded PyTorch TimesFM.")
from probts.model.nn.arch.TimesFMModule.timesfm_torch import TimesFmTorch as TimesFm
================================================
FILE: probts/model/nn/arch/TimesFMModule/patched_decoder.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pax ML model for patched time-series decoder.
The file implements Residual MLPs, Patched Decoder layers and PAX ML models.
"""
import dataclasses
from typing import Optional, Tuple
import einshape as es
from jax import lax
import jax.numpy as jnp
from praxis import base_layer
from praxis import base_model
from praxis import layers
from praxis import pax_fiddle
from praxis import py_utils
from praxis import pytypes
from praxis.layers import activations
from praxis.layers import embedding_softmax
from praxis.layers import linears
from praxis.layers import normalizations
from praxis.layers import stochastics
from praxis.layers import transformers
# PAX shortcuts
NestedMap = py_utils.NestedMap
JTensor = pytypes.JTensor
LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
template_field = base_layer.template_field
PAD_VAL = 1123581321.0
DEFAULT_QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# NestedMap keys
_INPUT_TS = "input_ts"
_TARGET_FUTURE = "actual_ts"
_INPUT_PADDING = "input_padding"
_OUTPUT_TS = "output_ts"
_FREQ = "freq"
_OUTPUT_TOKENS = "output_tokens"
_STATS = "stats"
# Small numerical value.
_TOLERANCE = 1e-7
def _shift_padded_seq(mask: JTensor, seq: JTensor) -> JTensor:
  """Circularly shifts each row of `seq` right by a per-row offset.

  The offset for a row is `argmin` of the matching `mask` row, i.e. the index
  of the first minimal entry (the first 0 when the mask holds 0/1 values and
  contains a 0).

  Args:
    mask: array indexed as [row, position].
    seq: array whose axis 1 has the length being shifted over.

  Returns:
    The shifted rows, stacked along axis 0.
  """
  seq_len = seq.shape[1]
  positions = jnp.arange(seq_len)
  shifts = jnp.argmin(mask, axis=1)

  def _roll_row(carry, row_and_shift):
    row, shift = row_and_shift
    # Output position i reads from position (i - shift) modulo the length.
    source_idx = (positions - shift) % seq_len
    return carry, row[source_idx]

  # Scan over rows; the carry is unused, only the stacked outputs matter.
  _, shifted = lax.scan(_roll_row, None, (seq, shifts))
  return shifted
class ResidualBlock(base_layer.BaseLayer):
  """Two-layer feedforward block with a linear residual branch.

  Computes `dropout(output_layer(hidden_layer(x))) + residual_layer(x)` and
  optionally applies layer norm to the sum.

  Attributes:
    input_dims: input dimension.
    hidden_dims: hidden dimension.
    output_dims: output dimension.
    dropout_prob: dropout probability.
    layer_norm: whether to use layer norm or not.
    dropout_tpl: config for dropout.
    ln_tpl: config for layer norm.
    act_tpl: config for activation in hidden layer.
  """

  input_dims: int = 0
  hidden_dims: int = 0
  output_dims: int = 0
  dropout_prob: float = 0.0
  layer_norm: bool = False
  dropout_tpl: LayerTpl = template_field(stochastics.Dropout)
  ln_tpl: LayerTpl = template_field(normalizations.LayerNorm)
  act_tpl: LayerTpl = template_field(activations.Swish)

  def setup(self):
    # Layer norm child is always created; it is only used if `layer_norm`.
    norm_cfg = self.ln_tpl.clone()
    norm_cfg.dim = self.output_dims
    self.create_child("ln_layer", norm_cfg)

    drop_cfg = self.dropout_tpl.clone()
    drop_cfg.keep_prob = 1.0 - self.dropout_prob
    self.create_child("dropout", drop_cfg)

    # (child name, input dims, output dims, activation config) for the three
    # feedforward children; only the hidden layer has a non-identity
    # activation.
    feedforward_specs = (
        ("hidden_layer", self.input_dims, self.hidden_dims,
         self.act_tpl.clone()),
        ("output_layer", self.hidden_dims, self.output_dims,
         pax_fiddle.Config(activations.Identity)),
        ("residual_layer", self.input_dims, self.output_dims,
         pax_fiddle.Config(activations.Identity)),
    )
    for child_name, in_dims, out_dims, act_cfg in feedforward_specs:
      self.create_child(
          child_name,
          pax_fiddle.Config(
              linears.FeedForward,
              input_dims=in_dims,
              output_dims=out_dims,
              activation_tpl=act_cfg,
          ),
      )

  def __call__(self, inputs: JTensor) -> JTensor:
    main_branch = self.hidden_layer(inputs)
    main_branch = self.output_layer(main_branch)
    main_branch = self.dropout(main_branch)
    shortcut = self.residual_layer(inputs)
    combined = main_branch + shortcut
    if not self.layer_norm:
      return combined
    return self.ln_layer(combined)
def _masked_mean_std(inputs: JTensor,
padding: JTensor) -> Tuple[JTensor, JTensor]:
"""Calculates mean and standard deviation of arr across axis 1.
It should exclude values where pad is 1.
Args:
inputs: A JAX array of shape [b, n, p].
padding: A JAX array of shape [b, n, p] with values 0 or 1.
Returns:
A tuple containing the mean and standard deviation of arr. We return the
statistics of the first patch with more than three non-padded values.
"""
# Selecting the first pad with more than 3 unpadded values.
pad_sum = jnp.sum(1 - padding, axis=2)
def _get_patch_index(arr: JTensor):
indices = jnp.argmax(arr >= 3, axis=1)
row_sum = (arr >= 3).sum(axis=1)
return jnp.where(row_sum == 0, arr.shape[1] - 1, indices)
patch_indices = _get_patch_index(pad_sum)
bidxs = jnp.arange(inputs.shape[0])
arr = inputs[bidxs, patch_indices, :]
pad = padding[bidxs, patch_indices, :]
# Create a mask where P is 0
mask = 1 - pad
# Calculate the number of valid elements
num_valid_elements = jnp.sum(mask, axis=1)
num_valid_elements = jnp.where(num_valid_elements == 0, 1, num_valid_elements)
# Calculate the masked sum and squared sum of M
masked_sum = jnp.sum(arr * mask, axis=1)
masked_squared_sum = jnp.sum((arr * mask)**2, axis=1)
# Calculate the masked mean and standard deviation
masked_mean = masked_sum / num_valid_elements
masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
masked_var = jnp.where(masked_var < 0.0, 0.0, masked_var)
masked_std = jnp.sqrt(masked_var)
return masked_mean, masked_std
def _create_quantiles() -> list[float]:
    """Returns the default quantile levels used for forecasting.

    A fresh copy is returned on every call so that instances using this as a
    `default_factory` never share (and cannot accidentally mutate) the
    module-level `DEFAULT_QUANTILES` list.
    """
    return list(DEFAULT_QUANTILES)
class PatchedTimeSeriesDecoder(base_layer.BaseLayer):
    """Patch decoder layer for time-series foundation model.

    Attributes:
      patch_len: length of input patches.
      horizon_len: length of output patches. Referred to as `output_patch_len`
        during inference.
      model_dims: model dimension of stacked transformer layer.
      hidden_dims: hidden dimensions in fully connected layers.
      quantiles: list of quantiles for non prob model.
      residual_block_tpl: config for residual block.
      stacked_transformer_params_tpl: config for stacked transformer.
      use_freq: whether to use frequency encoding.
      use_pos_emb: whether positional embeddings are added to the patch
        embeddings in `_preprocess_input`.

    In all of what follows, unless specified otherwise, B is batch size, T is
    sequence length of time-series. N is the number of input patches that can be
    obtained from T. P is the input patch length and H is the horizon length. Q
    is number of output logits. D is model dimension.
    """
    patch_len: int = 0
    horizon_len: int = 0
    model_dims: int = 0
    hidden_dims: int = 0
    quantiles: list[float] = dataclasses.field(default_factory=_create_quantiles)
    residual_block_tpl: LayerTpl = template_field(ResidualBlock)
    stacked_transformer_params_tpl: LayerTpl = template_field(
        transformers.StackedTransformer)
    use_freq: bool = True
    use_pos_emb: bool = True

    def setup(self) -> None:
        """Construct the model."""
        # One output channel for the mean forecast plus one per quantile.
        num_outputs = len(self.quantiles) + 1
        stl = self.stacked_transformer_params_tpl.clone()
        stl.model_dims = self.model_dims
        stl.hidden_dims = self.hidden_dims
        stl.mask_self_attention = True
        self.create_child("stacked_transformer_layer", stl)
        input_resl = self.residual_block_tpl.clone()
        # The input block consumes a patch concatenated with its padding mask,
        # hence twice the patch length.
        ff_in_dims = 2 * self.patch_len
        input_resl.input_dims = ff_in_dims
        input_resl.hidden_dims = self.hidden_dims
        input_resl.output_dims = self.model_dims
        self.create_child(
            "input_ff_layer",
            input_resl,
        )
        horizon_resl = self.residual_block_tpl.clone()
        horizon_resl.input_dims = self.model_dims
        horizon_resl.hidden_dims = self.hidden_dims
        # Flattened [H, Q] output; reshaped in `_postprocess_output`.
        horizon_resl.output_dims = self.horizon_len * num_outputs
        self.create_child(
            "horizon_ff_layer",
            horizon_resl,
        )
        # Created unconditionally; only used when `use_pos_emb` is True.
        self.create_child(
            "position_emb",
            pax_fiddle.Config(layers.PositionalEmbedding,
                              embedding_dims=self.model_dims),
        )
        if self.use_freq:
            self.create_child(
                "freq_emb",
                pax_fiddle.Config(
                    embedding_softmax.Embedding,
                    num_classes=3,
                    input_dims=self.model_dims,
                ),
            )

    def transform_decode_state(
            self, transform_fn: base_layer.DecodeStateTransformFn) -> None:
        """Transforms all decode state variables based on transform_fn."""
        self.stacked_transformer_layer.transform_decode_state(transform_fn)

    def _forward_transform(
            self, inputs: JTensor,
            patched_pads: JTensor) -> Tuple[JTensor, Tuple[JTensor, JTensor]]:
        """Normalizes patches with per-series statistics.

        Input is of shape [B, N, P]. Returns the normalized patches together
        with the `(mu, sigma)` statistics needed to undo the normalization.
        """
        mu, sigma = _masked_mean_std(inputs, patched_pads)
        # Avoid dividing by a (near) zero standard deviation.
        sigma = jnp.where(sigma < _TOLERANCE, 1.0, sigma)
        # Normalize each patch.
        outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
        # Entries equal to the PAD_VAL sentinel are kept as PAD_VAL.
        outputs = jnp.where(
            jnp.abs(inputs - PAD_VAL) < _TOLERANCE, PAD_VAL, outputs)
        return outputs, (mu, sigma)

    def _reverse_transform(self, outputs: JTensor,
                           stats: Tuple[JTensor, JTensor]) -> JTensor:
        """Undoes `_forward_transform`. Output is of shape [B, N, P, Q]."""
        mu, sigma = stats
        return outputs * sigma[:, None, None, None] + mu[:, None, None, None]

    def _preprocess_input(
        self,
        input_ts: JTensor,
        input_padding: JTensor,
        pos_emb: Optional[JTensor] = None,
    ) -> Tuple[JTensor, JTensor, Optional[Tuple[JTensor, JTensor]], JTensor]:
        """Preprocess input for stacked transformer.

        Args:
          input_ts: input series, reshaped here into patches of `patch_len`.
          input_padding: padding indicators aligned with `input_ts`.
          pos_emb: optional precomputed positional embedding; when None it is
            generated by `self.position_emb`.

        Returns:
          A tuple `(model_input, patched_padding, stats, patched_inputs)`.
        """
        # Reshape into patches.
        patched_inputs = es.jax_einshape("b(np)->bnp", input_ts, p=self.patch_len)
        patched_pads = es.jax_einshape("b(np)->bnp",
                                       input_padding,
                                       p=self.patch_len)
        # Zero out values at padded positions.
        patched_inputs = jnp.where(
            jnp.abs(patched_pads - 1.0) < _TOLERANCE, 0.0, patched_inputs)
        # Values equal to the PAD_VAL sentinel are marked as padded.
        patched_pads = jnp.where(
            jnp.abs(patched_inputs - PAD_VAL) < _TOLERANCE, 1, patched_pads)
        patched_inputs, stats = self._forward_transform(patched_inputs,
                                                        patched_pads)
        # Mask padded entries again after normalization; `model_input` below
        # is B x N x D.
        patched_inputs = patched_inputs * (1.0 - patched_pads)
        concat_inputs = jnp.concatenate([patched_inputs, patched_pads], axis=-1)
        model_input = self.input_ff_layer(concat_inputs)
        # A patch counts as padded only if every entry in it is padded (the
        # minimum over the patch is 1).
        patched_padding = jnp.min(patched_pads, axis=-1)
        if self.use_pos_emb:
            if pos_emb is None:
                position_emb = self.position_emb(seq_length=model_input.shape[1])
            else:
                position_emb = pos_emb
            if self.do_eval:
                # Only in eval mode: align the embedding with the first
                # non-padded patch of every series.
                if position_emb.shape[0] != model_input.shape[0]:
                    position_emb = jnp.repeat(position_emb, model_input.shape[0], axis=0)
                position_emb = _shift_padded_seq(patched_padding, position_emb)
            model_input += position_emb
        return model_input, patched_padding, stats, patched_inputs

    def _postprocess_output(
        self,
        model_output: JTensor,
        num_outputs: int,
        stats: Tuple[JTensor, JTensor],
    ) -> JTensor:
        """Postprocess output of stacked transformer."""
        # B x N x (H.Q)
        output_ts = self.horizon_ff_layer(model_output)
        output_ts = es.jax_einshape("bn(hq)->bnhq",
                                    output_ts,
                                    q=num_outputs,
                                    h=self.horizon_len)
        # Map the normalized forecasts back to the original scale.
        return self._reverse_transform(output_ts, stats)

    def __call__(self, inputs: NestedMap) -> NestedMap:
        """Runs the patched decoder on a batch of series.

        Args:
          inputs: A NestedMap containing (1) input_ts: input sequence of shape
            [B, T] where T must be multiple of patch_length; (2) input_padding:
            that contains padding map; (3) freq: read only when `use_freq` is
            True.

        Returns:
          A nested map with three keys:
          (1) 'output_tokens' of shape [B, N, D].
          (2) 'output_ts' of shape [B, N, H, Q]
          (3) 'stats' a Tuple of statistics for renormalization.
        """
        input_ts, input_padding = inputs[_INPUT_TS], inputs[_INPUT_PADDING]
        num_outputs = len(self.quantiles) + 1
        model_input, patched_padding, stats, _ = self._preprocess_input(
            input_ts=input_ts,
            input_padding=input_padding,
        )
        if self.use_freq:
            freq = inputs[_FREQ].astype(jnp.int32)
            f_emb = self.freq_emb(freq)  # B x 1 x D
            # Broadcast the frequency embedding over all N patches.
            f_emb = jnp.repeat(f_emb, model_input.shape[1], axis=1)
            model_input += f_emb
        model_output = self.stacked_transformer_layer(model_input, patched_padding)
        output_ts = self._postprocess_output(model_output, num_outputs, stats)
        return NestedMap({
            _OUTPUT_TOKENS: model_output,
            _OUTPUT_TS: output_ts,
            _STATS: stats
        })

    def decode(
        self,
        inputs: NestedMap,
        horizon_len: int,
        output_patch_len: Optional[int] = None,
        max_len: int = 512,
        return_forecast_on_context: bool = False,
    ) -> tuple[JTensor, JTensor]:
        """Auto-regressive decoding without caching.

        Args:
          inputs: input time-series and paddings. Time-series shape B x C,
            padding shape B x (C + H) where H is the prediction length.
          horizon_len: prediction length.
          output_patch_len: output length to be fetched from one step of
            auto-regressive decoding. Defaults to `self.horizon_len`.
          max_len: maximum training context length.
          return_forecast_on_context: whether to return the model forecast on
            the context except the first input patch.

        Returns:
          Tuple of two forecasting results:
          - Point (mean) output predictions as a tensor with shape B x H'.
          - Full predictions (mean and quantiles) as a tensor with shape
            B x H' x (1 + # quantiles).
          In particular, if return_forecast_on_context is True, H' is H plus
          the forecastable context length, i.e. context_len - (first) patch_len.

        Raises:
          ValueError: if the padding length differs from context length plus
            `horizon_len`.
        """
        final_out = inputs[_INPUT_TS]
        context_len = final_out.shape[1]
        paddings = inputs[_INPUT_PADDING]
        if self.use_freq:
            freq = inputs[_FREQ].astype(jnp.int32)
        else:
            freq = jnp.zeros([final_out.shape[0], 1], dtype=jnp.int32)
        full_outputs = []
        if paddings.shape[1] != final_out.shape[1] + horizon_len:
            raise ValueError(
                "Length of paddings must match length of input + horizon_len:"
                f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}")
        if output_patch_len is None:
            output_patch_len = self.horizon_len
        # Ceiling division: number of decode steps needed to cover the horizon.
        num_decode_patches = (horizon_len + output_patch_len -
                              1) // output_patch_len
        for step_index in range(num_decode_patches):
            current_padding = paddings[:, 0:final_out.shape[1]]
            # Only the most recent `max_len` points are fed to the model.
            input_ts = final_out[:, -max_len:]
            input_padding = current_padding[:, -max_len:]
            model_input = NestedMap(
                input_ts=input_ts,
                input_padding=input_padding,
                freq=freq,
            )
            fprop_outputs = self(model_input)[_OUTPUT_TS]
            if return_forecast_on_context and step_index == 0:
                # For the first decoding step, collect the model forecast on the
                # context except the unavailable first input patch forecast.
                new_full_ts = fprop_outputs[:, :-1, :self.patch_len, :]
                new_full_ts = es.jax_einshape("bnph->b(np)h", new_full_ts)
                full_outputs.append(new_full_ts)
            # (full batch, last patch, output_patch_len, index of mean forecast = 0)
            new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
            new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
            # (full batch, last patch, output_patch_len, all output indices)
            full_outputs.append(new_full_ts)
            # Feed the mean forecast back in as context for the next step.
            final_out = jnp.concatenate([final_out, new_ts], axis=-1)
        if return_forecast_on_context:
            # `full_outputs` indexing starts at after the first input patch.
            full_outputs = jnp.concatenate(full_outputs,
                                           axis=1)[:, :(context_len - self.patch_len +
                                                        horizon_len), :]
        else:
            # `full_outputs` indexing starts at the forecast horizon.
            full_outputs = jnp.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :]
        return (full_outputs[:, :, 0], full_outputs)
class PatchedDecoderFinetuneModel(base_model.BaseModel):
    """Model wrapper used to finetune the patched time-series decoder.

    Attributes:
      core_layer_tpl: config for core layer.
      freq: freq to finetune on.
    """

    core_layer_tpl: LayerTpl = template_field(PatchedTimeSeriesDecoder)
    freq: int = 0

    def setup(self) -> None:
        """Instantiates the core decoder layer."""
        self.create_child("core_layer", self.core_layer_tpl)

    def compute_predictions(self, input_batch: NestedMap) -> NestedMap:
        """Left-pads the context to a whole number of patches and runs the core layer.

        Args:
          input_batch: NestedMap holding the context under `input_ts`.

        Returns:
          The NestedMap produced by the core layer.
        """
        context = input_batch[_INPUT_TS]
        padding = jnp.zeros_like(context)

        patch_len = self.core_layer_tpl.patch_len
        context_len = context.shape[1]
        # Round the context length up to the next multiple of the patch length.
        num_patches = (context_len + patch_len - 1) // patch_len
        left_pad = num_patches * patch_len - context_len

        pad_widths = [(0, 0), (left_pad, 0)]
        context = jnp.pad(context, pad_widths)
        # The prepended positions are flagged as padding.
        padding = jnp.pad(padding, pad_widths, constant_values=1)

        freq = jnp.ones([context.shape[0], 1], dtype=jnp.int32) * self.freq
        core_inputs = NestedMap(
            input_ts=context,
            input_padding=padding,
            freq=freq,
        )
        return self.core_layer(core_inputs)

    def _quantile_loss(self, pred: JTensor, actual: JTensor,
                       quantile: float) -> JTensor:
        """Calculates quantile loss.

        Args:
          pred: B x T
          actual: B x T
          quantile: quantile at which loss is computed.

        Returns:
          per coordinate loss.
        """
        residual = actual - pred
        under_penalty = residual * quantile
        over_penalty = -residual * (1.0 - quantile)
        return 2 * jnp.where(under_penalty >= 0, under_penalty, over_penalty)

    def compute_loss(self, prediction_output: NestedMap,
                     input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]:
        """Computes squared error on the mean plus quantile losses.

        Only the forecast of the last patch is compared against the target.

        Args:
          prediction_output: output of `compute_predictions`.
          input_batch: NestedMap holding the target under `actual_ts`.

        Returns:
          A pair of (weighted scalar losses, per-example outputs).
        """
        target = input_batch[_TARGET_FUTURE]
        horizon = target.shape[1]
        last_patch_pred = prediction_output[_OUTPUT_TS][:, -1, 0:horizon, :]

        # Channel 0 is the mean forecast; channels 1.. follow the quantiles.
        loss = jnp.square(last_patch_pred[:, :, 0] - target)
        for channel, quantile in enumerate(self.core_layer.quantiles, start=1):
            loss += self._quantile_loss(last_patch_pred[:, :, channel], target,
                                        quantile)
        loss = loss.mean()

        weight = jnp.array(1.0, dtype=jnp.float32)
        return {"avg_qloss": (loss, weight)}, NestedMap()
================================================
FILE: probts/model/nn/arch/TimesFMModule/pytorch_patched_decoder.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pytorch version of patched decoder."""
import dataclasses
import math
from typing import List, Tuple
import torch
from torch import nn
import torch.nn.functional as F
def _create_quantiles() -> list[float]:
return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
@dataclasses.dataclass
class TimesFMConfig:
    """Config for initializing timesfm patched_decoder class.

    All fields have defaults, so `TimesFMConfig()` is a valid configuration.
    """
    # The number of blocks in the model.
    num_layers: int = 20
    # The number of attention heads used in the attention layers of the model.
    num_heads: int = 16
    # The number of key-value heads for implementing attention.
    num_kv_heads: int = 16
    # The hidden size of the model.
    hidden_size: int = 1280
    # The dimension of the MLP representations.
    intermediate_size: int = 1280
    # The number of head dimensions.
    head_dim: int = 80
    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-6
    # Patch length
    patch_len: int = 32
    # Horizon length
    horizon_len: int = 128
    # Quantile levels; one output channel is produced per quantile in addition
    # to the mean.
    quantiles: List[float] = dataclasses.field(default_factory=_create_quantiles)
    # Padding value: inputs within `tolerance` of this sentinel are treated as
    # padded.
    pad_val: float = 1123581321.0
    # Tolerance used for approximate float comparisons.
    tolerance: float = 1e-6
    # The dtype of the weights.
    # NOTE(review): "bfloat32" is not a torch dtype name; no reader of this
    # field is visible in this part of the file - confirm before relying on it.
    dtype: str = "bfloat32"
    # use positional embedding
    use_positional_embedding: bool = True
def _masked_mean_std(
inputs: torch.Tensor,
padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculates mean and standard deviation of `inputs` across axis 1.
It excludes values where `padding` is 1.
Args:
inputs: A PyTorch tensor of shape [b, n, p].
padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
Returns:
A tuple containing the mean and standard deviation.
We return the statistics of the first patch with more than three non-padded
values.
"""
# Selecting the first patch with more than 3 unpadded values.
pad_sum = torch.sum(1 - padding, dim=2)
def _get_patch_index(arr: torch.Tensor):
indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
patch_indices = _get_patch_index(pad_sum)
bidxs = torch.arange(inputs.shape[0])
arr = inputs[bidxs, patch_indices, :]
pad = padding[bidxs, patch_indices, :]
# Create a mask where padding is 0
mask = 1 - pad
# Calculate the number of valid elements
num_valid_elements = torch.sum(mask, dim=1)
num_valid_elements = torch.where(
num_valid_elements == 0,
torch.tensor(1,
dtype=num_valid_elements.dtype,
device=num_valid_elements.device),
num_valid_elements,
)
# Calculate the masked sum and squared sum
masked_sum = torch.sum(arr * mask, dim=1)
masked_squared_sum = torch.sum((arr * mask)**2, dim=1)
# Calculate the masked mean and standard deviation
masked_mean = masked_sum / num_valid_elements
masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
masked_var = torch.where(
masked_var < 0.0,
torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
masked_var,
)
masked_std = torch.sqrt(masked_var)
return masked_mean, masked_std
def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
"""Shifts rows of seq based on the first 0 in each row of the mask.
Args:
mask: mask tensor of shape [B, N]
seq: seq tensor of shape [B, N, P]
Returns:
Returns the shifted sequence.
"""
batch_size, num_seq, feature_dim = seq.shape
new_mask: torch.BoolTensor = mask == 0
# Use argmax to find the first True value in each row
indices = new_mask.to(torch.int32).argmax(dim=1)
# Handle rows with all zeros
indices[~new_mask.any(dim=1)] = -1
# Create index ranges for each sequence in the batch
idx_range = (torch.arange(num_seq).to(
seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1,
feature_dim))
# Calculate shifted indices for each element in each sequence
shifted_idx = (idx_range - indices[:, None, None]) % num_seq
# Gather values from seq using shifted indices
shifted_seq = seq.gather(1, shifted_idx)
return shifted_seq
def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor:
    """Returns a large negative value for the given dtype.

    The value is -0.7 times the largest value representable in `dtype`,
    returned as a 0-dim tensor of that dtype.
    """
    if dtype.is_floating_point:
        dtype_max = torch.finfo(dtype).max
    else:
        dtype_max = torch.iinfo(dtype).max
    return torch.tensor(-0.7 * dtype_max, dtype=dtype)


def apply_mask_to_logits(logits: torch.Tensor,
                         mask: torch.Tensor) -> torch.Tensor:
    """Applies a floating-point mask to a set of logits.

    Positions whose mask value is below half of the large negative number for
    `logits.dtype` are replaced by that large negative number; all other
    positions keep their logit.

    Args:
      logits: A torch.Tensor of logit values.
      mask: A torch.Tensor of mask values, broadcastable against `logits`, using
        0 for visible positions and a large negative value for masked ones.

    Returns:
      Masked logits.
    """
    min_value = get_large_negative_number(logits.dtype)
    return torch.where((mask >= min_value * 0.5), logits, min_value)


def convert_paddings_to_mask(
        paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Converts binary paddings to a logit mask ready to add to attention matrix.

    Args:
      paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding
        token. Any numeric or boolean dtype is accepted; the input is not
        modified.
      dtype: data type of the returned mask (normally the dtype of the
        attention inputs).

    Returns:
      A torch.Tensor of shape [B, 1, 1, T] and dtype `dtype`, ready to add to
      attention logits.
    """
    # Cast a private copy to the target dtype *before* scaling. Multiplying in
    # the paddings' own dtype fails for integer/bool paddings and overflows to
    # -inf (giving NaN for unpadded positions) when `dtype` is wider than the
    # paddings' dtype.
    attention_mask = paddings.detach().to(dtype=dtype, copy=True)
    attention_mask = attention_mask[:, None, None, :]  # Equivalent to jnp.newaxis
    attention_mask *= get_large_negative_number(dtype)
    return attention_mask
def causal_mask(input_t: torch.Tensor) -> torch.Tensor:
    """Builds a causal attention mask for a sequence tensor.

    Args:
      input_t: A floating-point torch.Tensor of shape [B, T, D].

    Returns:
      An attention_mask torch.Tensor of shape [1, 1, T, T] on the device of
      `input_t`. Entry [i, j] holds a large negative value when j > i (key after
      query) and zero otherwise.
    """
    assert input_t.dtype.is_floating_point, input_t.dtype
    blocked_value = get_large_negative_number(input_t.dtype)
    seq_len = input_t.shape[1]
    positions = torch.arange(seq_len)
    # is_future[i, j] is True when key position j lies after query position i.
    is_future = positions[:, None] < positions[None, :]
    mask = is_future.to(input_t.dtype) * blocked_value
    # Add the leading batch and head axes, then move to the input's device.
    return mask[None, None, :, :].to(input_t.device)
def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Combines two attention masks by taking their element-wise minimum.

    Log-scale masks are expected, but 0/1 masks work as well. When exactly one
    mask has a singleton query axis it is first expanded to a square mask.

    Args:
      a: torch.Tensor of shape [1|B, 1, 1|T, S].
      b: torch.Tensor of shape [1|B, 1, 1|T, S].

    Returns:
      torch.Tensor of shape [1|B, 1, 1|T, S].
    """

    def _square_from_key_mask(key_mask):
        # Pair every query position with every key position and keep the more
        # restrictive (smaller) of the two values.
        return torch.minimum(key_mask.transpose(-1, -2), key_mask)

    query_dims_match = a.shape[2] == b.shape[2]
    if not query_dims_match:
        if a.shape[2] == 1:
            a = _square_from_key_mask(a)
        else:
            assert b.shape[2] == 1
            b = _square_from_key_mask(b)
    assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}."
    return torch.minimum(a, b)
class ResidualBlock(nn.Module):
    """TimesFM residual block.

    Computes `output_layer(SiLU(Linear(x))) + residual_layer(x)`, i.e. a
    two-layer MLP plus a linear projection of the input.
    """

    def __init__(
        self,
        input_dims,
        hidden_dims,
        output_dims,
    ):
        super().__init__()
        self.input_dims = input_dims
        self.hidden_dims = hidden_dims
        self.output_dims = output_dims

        # Main branch: linear + SiLU, followed by a linear projection.
        self.hidden_layer = nn.Sequential(
            nn.Linear(input_dims, hidden_dims),
            nn.SiLU(),
        )
        self.output_layer = nn.Linear(hidden_dims, output_dims)
        # Skip branch: linear projection of the raw input.
        self.residual_layer = nn.Linear(input_dims, output_dims)

    def forward(self, x):
        """Returns the sum of the MLP branch and the linear skip branch."""
        mlp_out = self.output_layer(self.hidden_layer(x))
        skip_out = self.residual_layer(x)
        return mlp_out + skip_out
class RMSNorm(torch.nn.Module):
    """Pax rms norm in pytorch.

    Normalizes the last dimension by its root mean square and scales the
    result by `weight` (or by `1 + weight` when `add_unit_offset` is set).
    The learnable weight starts at zero.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = False,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divides `x` by the RMS of its last dimension (stabilised by eps)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalizes in float32 and returns a tensor with the dtype of `x`."""
        normalized = self._norm(x.float())
        gain = self.weight.float()
        if self.add_unit_offset:
            gain = 1 + gain
        return (normalized * gain).type_as(x)
class TransformerMLP(nn.Module):
    """Pax transformer MLP in pytorch.

    Applies layer norm, a ReLU-activated up projection and a down projection,
    then adds the input back as a residual.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)

    def forward(self, x, paddings=None):
        """Runs the MLP.

        Args:
          x: input tensor whose last dimension is `hidden_size`.
          paddings: optional tensor indexed as [batch, position]; the MLP
            output is multiplied by `1 - paddings` before the residual is added.

        Returns:
          Tensor with the same shape as `x`.
        """
        activated = F.relu(self.gate_proj(self.layer_norm(x)))
        projected = self.down_proj(activated)
        if paddings is None:
            return projected + x
        # Zero the MLP contribution at padded positions; the residual passes.
        keep = 1.0 - paddings[:, :, None]
        return projected * keep + x
class TimesFMAttention(nn.Module):
"""Implements the attention used in TimesFM."""
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.hidden_size = hidden_size
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = nn.Parameter(
torch.empty((self.head_dim,), dtype=torch.float32),)
self.qkv_proj = nn.Linear(
self.hidden_size,
(self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor:
# [batch_size, n_local_heads, input_len, head_dim]
r_softplus_0 = 1.442695041
softplus_func = torch.nn.Softplus()
scale = r_softplus_0 / math.sqrt(self.head_dim)
scale = scale * softplus_func(self.scaling)
return query * scale[None, None, None, :]
def forward(
self,
hidden_states: torch.Tensor,
mask: torch.Tensor,
kv_write_indices: torch.Tensor | None = None,
kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
hidden_states_shape = hidden_states.shape
assert len(hidden_states_shape) == 3
batch_size, input_len, _ = hidden_states_shape
qkv = self.qkv_proj(hidden_states)
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xq = self._per_dim_scaling(xq)
# Write new kv cache.
# [batch_size, input_len, n_local_kv_heads, head_dim]
if kv_cache is not None and kv_write_indices is not None:
k_cache, v_cache = kv_cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)
key = k_cache
value = v_cache
else:
key = xk
value = xv
if self.num_kv_heads != self.num_heads:
# [batch_size, max_seq_len, n_local_heads, head_dim]
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2)
# [batch_size, n_local_heads, input_len, head_dim]
q = xq.transpose(1, 2)
# [batch_size, n_local_heads, max_seq_len, head_dim]
k = key.transpose(1, 2)
v = value.transpose(1, 2)
# [batch_size, n_local_heads, input_len, max_seq_len]
scores = torch.matmul(q, k.transpose(2, 3))
scores = scores + mask
scores = F.softmax(scores.float(), dim=-1).type_as(q)
# [batch_size, n_local_heads, input_len, head_dim]
output = torch.matmul(scores, v)
# return scores, output.transpose(1, 2).contiguous()
# [batch_size, input_len, hidden_dim]
output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)
output = self.o_proj(output)
return scores, output
class TimesFMDecoderLayer(nn.Module):
    """Transformer layer.

    Applies RMS norm, self attention with a residual connection, and then the
    transformer MLP (which adds its own residual).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rms_norm_eps: float = 1e-6,
    ):
        super().__init__()
        self.self_attn = TimesFMAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
        )
        self.mlp = TransformerMLP(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
        )
        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask: torch.Tensor,
        paddings: torch.Tensor,
        kv_write_indices: torch.Tensor | None = None,
        kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Runs one decoder layer.

        Args:
          hidden_states: layer input.
          mask: additive attention mask passed to the attention module.
          paddings: padding indicators passed to the MLP.
          kv_write_indices: forwarded to the attention module.
          kv_cache: forwarded to the attention module.

        Returns:
          A tuple `(scores, hidden_states)`: the attention weights returned by
          the attention module and the layer output.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        scores, hidden_states = self.self_attn(
            hidden_states=hidden_states,
            mask=mask,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
        )
        hidden_states = residual + hidden_states

        # MLP
        hidden_states = self.mlp(hidden_states, paddings=paddings)

        return scores, hidden_states
class StackedDecoder(nn.Module):
    """Stacked transformer layer.

    A sequence of `num_layers` identical decoder layers sharing one combined
    padding + causal attention mask.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        num_layers: int,
        rms_norm_eps: float = 1e-6,
    ):
        super().__init__()

        self.layers = nn.ModuleList([
            TimesFMDecoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                rms_norm_eps=rms_norm_eps,
            ) for _ in range(num_layers)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        paddings: torch.Tensor,
        kv_write_indices: torch.Tensor | None = None,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,
    ) -> torch.Tensor:
        """Runs all decoder layers in order.

        Args:
          hidden_states: input to the first layer.
          paddings: padding indicators, used for the attention mask and passed
            to every layer.
          kv_write_indices: forwarded to every layer.
          kv_caches: optional list with one `(k_cache, v_cache)` pair per layer.

        Returns:
          The hidden states produced by the last layer.
        """
        # Combine the padding mask with the causal mask once for all layers.
        padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype)
        atten_mask = causal_mask(hidden_states)
        mask = merge_masks(padding_mask, atten_mask)

        for layer_idx, layer in enumerate(self.layers):
            layer_cache = None if kv_caches is None else kv_caches[layer_idx]
            # Attention scores are not needed here.
            _, hidden_states = layer(
                hidden_states=hidden_states,
                mask=mask,
                paddings=paddings,
                kv_write_indices=kv_write_indices,
                kv_cache=layer_cache,
            )
        return hidden_states
class PositionalEmbedding(torch.nn.Module):
    """Generates position embedding for a given 1-d sequence.

    Attributes:
      min_timescale: Start of the geometric index. Determines the periodicity of
        the added signal.
      max_timescale: End of the geometric index. Determines the frequency of the
        added signal.
      embedding_dims: Dimension of the embedding to be generated.
    """

    def __init__(
        self,
        embedding_dims: int,
        min_timescale: int = 1,
        max_timescale: int = 10_000,
    ) -> None:
        super().__init__()
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale
        self.embedding_dims = embedding_dims

    def forward(self, seq_length=None, position=None):
        """Generates a Tensor of sinusoids with different frequencies.

        Args:
          seq_length: an optional Python int defining the output sequence
            length. Required when `position` is None and ignored otherwise.
          position: [B, seq_length], optional position for each token in the
            sequence, only required when the sequence is packed.

        Returns:
          [B, seqlen, D] if `position` is specified, else [1, seqlen, D], where
          D is `embedding_dims`. The result lives on the device of `position`
          (CPU when `position` is not given).
        """
        if position is None:
            assert seq_length is not None
            # [1, seqlen]
            position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0)
        else:
            assert position.ndim == 2, position.shape

        num_timescales = self.embedding_dims // 2
        log_timescale_increment = math.log(
            float(self.max_timescale) / float(self.min_timescale)) / max(
                num_timescales - 1, 1)
        # Created on the device of `position` so caller-supplied positions on
        # an accelerator do not cause a device mismatch.
        inv_timescales = self.min_timescale * torch.exp(
            torch.arange(
                num_timescales, dtype=torch.float32, device=position.device) *
            -log_timescale_increment)
        scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(
            0)
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        # For odd `embedding_dims`, append one zero column to the *embedding*
        # (last) axis. F.pad's tuple starts from the last dimension.
        signal = F.pad(signal, (0, self.embedding_dims % 2))
        return signal
class PatchedTimeSeriesDecoder(nn.Module):
"""Patched time-series decoder."""
def __init__(self, config: TimesFMConfig):
    """Builds the input/output residual blocks, embeddings and transformer.

    Args:
      config: model hyper-parameters.
    """
    super().__init__()
    self.config = config
    cfg = self.config

    # Maps a patch concatenated with its padding mask to the model dimension.
    self.input_ff_layer = ResidualBlock(
        input_dims=2 * cfg.patch_len,
        output_dims=cfg.hidden_size,
        hidden_dims=cfg.intermediate_size,
    )
    self.freq_emb = nn.Embedding(num_embeddings=3,
                                 embedding_dim=cfg.hidden_size)
    # One output channel for the mean plus one per quantile, per horizon step.
    num_outputs = 1 + len(cfg.quantiles)
    self.horizon_ff_layer = ResidualBlock(
        input_dims=cfg.hidden_size,
        output_dims=cfg.horizon_len * num_outputs,
        hidden_dims=cfg.intermediate_size,
    )
    self.stacked_transformer = StackedDecoder(
        hidden_size=cfg.hidden_size,
        intermediate_size=cfg.intermediate_size,
        num_heads=cfg.num_heads,
        num_kv_heads=cfg.num_kv_heads,
        head_dim=cfg.head_dim,
        num_layers=cfg.num_layers,
        rms_norm_eps=cfg.rms_norm_eps,
    )
    if cfg.use_positional_embedding:
        self.position_emb = PositionalEmbedding(cfg.hidden_size)
def _forward_transform(
    self, inputs: torch.Tensor, patched_pads: torch.Tensor
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Normalizes patches of shape [B, N, P] with per-series statistics.

    Returns the normalized patches and the `(mu, sigma)` pair needed to undo
    the normalization.
    """
    tolerance = self.config.tolerance
    pad_val = self.config.pad_val

    mu, sigma = _masked_mean_std(inputs, patched_pads)
    # Avoid dividing by a (near) zero standard deviation.
    sigma = torch.where(sigma < tolerance, torch.ones_like(sigma), sigma)

    normalized = (inputs - mu[:, None, None]) / sigma[:, None, None]
    # Entries equal to the padding sentinel keep the sentinel value.
    is_pad_val = torch.abs(inputs - pad_val) < tolerance
    normalized = torch.where(
        is_pad_val, torch.full_like(normalized, pad_val), normalized)
    return normalized, (mu, sigma)
def _reverse_transform(
self, outputs: torch.Tensor, stats: tuple[torch.Tensor,
torch.Tensor]) -> torch.Tensor:
"""Output is of shape [B, N, P, Q]."""
mu, sigma = stats
return outputs * sigma[:, None, None, None] + mu[:, None, None, None]
def _preprocess_input(
self,
input_ts: torch.Tensor,
input_padding: torch.Tensor,
) -> tuple[
torch.Tensor,
torch.Tensor,
tuple[torch.Tensor, torch.Tensor] | None,
torch.Tensor,
]:
"""Preprocess input for stacked transformer."""
# Reshape into patches (using view for efficiency)
bsize = input_ts.shape[0]
patched_inputs = input_ts.view(bsize, -1, self.config.patch_len)
patched_pads = input_padding.view(bsize, -1, self.config.patch_len)
patched_inputs = torch.where(
torch.abs(patched_pads - 1.0) < self.config.tolerance,
torch.tensor(0.0,
dtype=patched_inputs.dtype,
device=patched_inputs.device),
patched_inputs,
)
patched_pads = torch.where(
torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
patched_pads,
)
patched_inputs, stats = self._forward_transform(patched_inputs,
patched_pads)
# B x N x D
patched_inputs = patched_inputs * (1.0 - patched_pads)
concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
model_input = self.input_ff_layer(concat_inputs)
# A patch should not be padded even if there is at least one zero.
patched_padding = torch.min(patched_pads,
dim=-1)[0] # Get the values from the min result
if self.config.use_positional_embedding:
pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device)
pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
pos_emb = _shift_padded_seq(patched_padding, pos_emb)
model_input += pos_emb
return model_input, patched_padding, stats, patched_inputs
def _postprocess_output(
self,
model_output: torch.Tensor,
num_outputs: int,
stats: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
"""Postprocess output of stacked transformer."""
# B x N x (H.Q)
output_ts = self.horizon_ff_layer(model_output)
# Reshape using view
b, n, _ = output_ts.shape
output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs)
return self._reverse_transform(output_ts, stats)
def forward(
self,
input_ts: torch.Tensor,
input_padding: torch.LongTensor,
freq: torch.Tensor,
) -> torch.Tensor:
num_outputs = len(self.config.quantiles) + 1
model_input, patched_padding, stats, _ = self._preprocess_input(
input_ts=input_ts,
input_padding=input_padding,
)
f_emb = self.freq_emb(freq) # B x 1 x D
model_input += f_emb
model_output = self.stacked_transformer(model_input, patched_padding)
output_ts = self._postprocess_output(model_output, num_outputs, stats)
return output_ts
def decode(
self,
input_ts: torch.Tensor,
paddings: torch.Tensor,
freq: torch.LongTensor,
horizon_len: int,
output_patch_len: int | None = None,
max_len: int = 512,
return_forecast_on_context: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Auto-regressive decoding without caching.
Args:
input_ts: input time-series and paddings. Time-series shape B x C.
paddings: padding shape B x (C + H) where H is the prediction length.
freq: frequency shape B x 1
horizon_len: prediction length.
output_patch_len: output length to be fetched from one step of
auto-regressive decoding.
max_len: maximum training context length.
return_forecast_on_context: whether to return the model forecast on the
context except the first input patch.
Returns:
Tuple of two forecasting results:
- Point (mean) output predictions as a tensor with shape B x H'.
- Full predictions (mean and quantiles) as a tensor with shape
B x H' x (1 + # quantiles).
In particular, if return_forecast_on_context is True, H' is H plus
the forecastable context length, i.e. context_len - (first) patch_len.
"""
final_out = input_ts
context_len = final_out.shape[1]
full_outputs = []
if paddings.shape[1] != final_out.shape[1] + horizon_len:
raise ValueError(
"Length of paddings must match length of input + horizon_len:"
f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}")
if output_patch_len is None:
output_patch_len = self.config.horizon_len
num_decode_patches = (horizon_len + output_patch_len -
1) // output_patch_len
for step_index in range(num_decode_patches):
current_padding = paddings[:, 0:final_out.shape[1]]
input_ts = final_out[:, -max_len:]
input_padding = current_padding[:, -max_len:]
fprop_outputs = self(input_ts, input_padding, freq)
if return_forecast_on_context and step_index == 0:
# For the first decodings step, collect the model forecast on the
# context except the unavailable first input batch forecast.
new_full_ts = fprop_outputs[:, :-1, :self.config.patch_len, :]
new_full_ts = fprop_outputs.view(new_full_ts.size(0), -1,
new_full_ts.size(3))
full_outputs.append(new_full_ts)
# (full batch, last patch, output_patch_len, index of mean forecast = 0)
new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
# (full batch, last patch, output_patch_len, all output indices)
full_outputs.append(new_full_ts)
final_out = torch.concatenate([final_out, new_ts], axis=-1)
if return_forecast_on_context:
# `full_outputs` indexing starts at after the first input patch.
full_outputs = torch.concatenate(
full_outputs,
axis=1)[:, :(context_len - self.config.patch_len + horizon_len), :]
else:
# `full_outputs` indexing starts at the forecast horizon.
full_outputs = torch.concatenate(full_outputs, axis=1)[:,
0:horizon_len, :]
return (full_outputs[:, :, 0], full_outputs)
================================================
FILE: probts/model/nn/arch/TimesFMModule/timesfm_base.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for TimesFM inference. This will be common to PAX and Pytorch."""
import collections
import dataclasses
import logging
import multiprocessing
from typing import Any, Literal, Sequence
import numpy as np
import pandas as pd
from utilsforecast.processing import make_future_dataframe
from probts.model.nn.arch.TimesFMModule import xreg_lib
# Re-export the covariate type aliases from `xreg_lib` so that signatures in
# this module can refer to them directly.
Category = xreg_lib.Category
XRegMode = xreg_lib.XRegMode
# Standard deviations that are not above this threshold are replaced by 1.0 in
# `_normalize`, which avoids dividing by (near) zero.
_TOL = 1e-6
# Default value of `TimesFmHparams.quantiles`.
DEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
def process_group(key, group, value_name, forecast_context_len):
  """Extracts the forecast context of a single series from its group.

  Args:
    key: identifier of the group; returned unchanged.
    group: dataframe holding the rows of one series.
    value_name: name of the column that holds the series values.
    forecast_context_len: number of trailing rows to keep.

  Returns:
    A tuple of the trailing values as a float32 array and `key`.
  """
  recent_rows = group.tail(forecast_context_len)
  context_values = np.array(recent_rows[value_name], dtype=np.float32)
  return context_values, key
def moving_average(arr, window_size):
  """Splits a series into a trailing moving average and the remainder.

  The series is padded with `window_size - 1` zeros at the front, so the
  result has the same length as `arr`; the first windows therefore average
  over those zeros as well.

  Args:
    arr: 1D array to smooth.
    window_size: number of points in each averaging window.

  Returns:
    A two-element list `[smoothed, arr - smoothed]`.
  """
  kernel = np.ones(window_size)
  front_padded = np.pad(arr, (window_size - 1, 0), "constant")
  trend = np.convolve(front_padded, kernel, "valid") / window_size
  residual = arr - trend
  return [trend, residual]
def freq_map(freq: str):
  """Returns the frequency map for the given frequency string.

  The string is matched by its suffix, so multiples such as "15min" work.

  Args:
    freq: a pandas-style frequency alias.

  Returns:
    0 for high frequency (sub-daily, daily and business-daily aliases),
    1 for medium frequency (weekly and monthly aliases),
    2 for low frequency (quarterly and yearly aliases).

  Raises:
    ValueError: if the alias is not recognized.
  """
  upper = str.upper(freq)
  # "ms" (milliseconds) and "MS" (month start) only differ by case, so the
  # millisecond alias has to be recognized before the case is discarded.
  if freq.endswith("ms"):
    return 0
  # Start/end anchored aliases must be tested before the generic "S"
  # (seconds) suffix, otherwise e.g. "MS" is classified as high frequency.
  if upper.endswith(("MS", "ME")):
    return 1
  if upper.endswith(("QS", "QE", "YS", "YE", "AS")):
    return 2
  if upper.endswith(("H", "T", "MIN", "D", "B", "U", "S")):
    return 0
  if upper.endswith(("W", "M")):
    return 1
  if upper.endswith(("Y", "Q", "A")):
    return 2
  raise ValueError(f"Invalid frequency: {upper}")
def strip_leading_nans(arr):
  """
  Removes contiguous NaN values from the beginning of a NumPy array.

  Args:
    arr: The input NumPy array.

  Returns:
    The array starting at its first non-NaN value. NaN values after that
    point are kept. An empty array is returned as is. An array that contains
    only NaNs is returned unchanged (there is no valid value to start from),
    so the caller decides how to fill it.
  """
  # `np.argmax` raises on an empty array, so handle that case up front.
  if np.size(arr) == 0:
    return arr
  isnan = np.isnan(arr)
  # Index of the first non-NaN entry; 0 when every entry is NaN.
  first_valid_index = np.argmax(~isnan)
  return arr[first_valid_index:]
def linear_interpolation(arr):
  """
  Performs linear interpolation to fill NaN values in a 1D numpy array.

  NaNs before the first or after the last valid value take the value of the
  nearest valid entry, which is how `np.interp` treats points outside the
  sample range. The input array is never modified.

  Args:
    arr: The 1D numpy array containing NaN values.

  Returns:
    The original array object if there are no NaN values; otherwise a new
    array with the NaN values filled. If every value is NaN, a new array of
    zeros is returned.
  """
  nans = np.isnan(arr)
  if not np.any(nans):  # Check if there are any NaNs
    return arr

  valid = ~nans
  if not np.any(valid):
    # Nothing to interpolate from: fall back to zeros.
    return np.zeros_like(arr)

  # Work on a copy so the caller's array keeps its NaNs.
  filled = arr.copy()
  filled[nans] = np.interp(nans.nonzero()[0], valid.nonzero()[0], arr[valid])
  return filled
# Per time series normalization: forward.
def _normalize(batch):
stats = [
(np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch
]
new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)]
return new_batch, stats
# Per time series normalization: inverse.
def _renormalize(batch, stats):
return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)]
@dataclasses.dataclass(kw_only=True)
class TimesFmHparams:
  """Inference-time hyperparameters of a TimesFM model.

  This is the subset of hyperparameters needed to configure inference,
  independent of the checkpoint version. The values do not have to match the
  ones the checkpoint was trained with.

  Attributes:
    context_len: Largest context length accepted by a single decode call. Any
      large value is technically allowed, but in practice it should be set to
      the context length the checkpoint was trained with.
    horizon_len: Forecast horizon.
    input_patch_len: Length of one input patch.
    output_patch_len: Length of one output patch, i.e. how many time points
      are taken from a single step of autoregressive decoding. Can be set to
      the training horizon of the checkpoint.
    num_layers: Number of transformer layers in the model.
    num_heads: Number of attention heads.
    model_dims: Model dimension.
    per_core_batch_size: Batch size on each core for data parallelism.
    backend: One of "cpu", "gpu" or "tpu".
    quantiles: Quantiles output by the model.
    use_positional_embedding: Whether the model uses positional embeddings.
    point_forecast_mode: Which output `forecast` returns as the point
      forecast, either "mean" or "median".
  """

  context_len: int = 512
  horizon_len: int = 128
  input_patch_len: int = 32
  output_patch_len: int = 128
  num_layers: int = 20
  num_heads: int = 16
  model_dims: int = 1280
  per_core_batch_size: int = 32
  backend: Literal["cpu", "gpu", "tpu"] = "cpu"
  quantiles: Sequence[float] | None = DEFAULT_QUANTILES
  use_positional_embedding: bool = True
  # Hparams beyond the model.
  point_forecast_mode: Literal["mean", "median"] = "median"
@dataclasses.dataclass(kw_only=True)
class TimesFmCheckpoint:
"""Checkpoint used to initialize a TimesFM model for inference.
Attributes:
version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc.
The factory will create the corresponding TimesFm inference class based on
this version.
path: Path to the checkpoint.
type: If provided, type of the checkpoint used by the specific checkpoint
loader per version.
step: If provided, step of the checkpoint.
"""
version: str = "jax"
path: str | None = None
huggingface_repo_id: str | None = None
type: Any = None
step: int | None = None
local_dir: str | None = None
class TimesFmBase:
  """Base TimesFM forecast API for inference.

  This class is the scaffolding for calling TimesFM forecast. To properly use:
    1. Create an instance with the correct hyperparameters of a TimesFM model.
    2. Call `load_from_checkpoint` to load a compatible checkpoint.
    3. Call `forecast` for inference.

  Subclasses have to implement `load_from_checkpoint` and `_forecast`.
  """

  def _logging(self, s):
    """Prints a status message."""
    print(s)

  def __post_init__(self) -> None:
    """Additional initialization for subclasses before checkpoint loading."""
    pass

  def __init__(self, hparams: TimesFmHparams,
               checkpoint: TimesFmCheckpoint) -> None:
    """Initializes the TimesFM forecast API.

    Args:
      hparams: Hyperparameters of the model.
      checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide
        which TimesFM version to use.
    """
    self.hparams = hparams

    # Expand hparams for conciseness within the model code.
    self.context_len = hparams.context_len
    self.horizon_len = hparams.horizon_len
    self.input_patch_len = hparams.input_patch_len
    self.output_patch_len = hparams.output_patch_len
    self.num_layers = hparams.num_layers
    self.model_dims = hparams.model_dims
    self.backend = hparams.backend
    self.quantiles = hparams.quantiles
    self.num_heads = hparams.num_heads
    self.use_pos_emb = hparams.use_positional_embedding

    # Rewrite these values in __post_init__ for SPMD.
    self.num_cores = 1
    self.per_core_batch_size = hparams.per_core_batch_size
    self.global_batch_size = hparams.per_core_batch_size

    # Offset of the first horizon step inside a forecast that also covers the
    # context (the first input patch has no forecast).
    self._horizon_start = self.context_len - self.input_patch_len
    # Position of the 0.5 quantile in `self.quantiles`, resolved lazily by
    # `forecast`; -1 means "not looked up yet". Set here so that `forecast`
    # works even if a subclass's `__post_init__` does not define it.
    self._median_index = -1
    self.__post_init__()
    self.load_from_checkpoint(checkpoint)

  def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None:
    """Loads a checkpoint and compiles the decoder."""
    raise NotImplementedError("`load_from_checkpoint` is not implemented.")

  def _preprocess(
      self, inputs: Sequence[np.ndarray],
      freq: Sequence[int]) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
    """Formats and pads raw inputs to feed into the model.

    This function both pads each time series to match the context length, and
    pads the inputs to meet the SPMD shape requirement.

    Args:
      inputs: A list of 1d arrays. Each array is the context time series of a
        single forecast task.
      freq: list of frequencies

    Returns:
      A tuple of:
      - the padded input time series to meet the model required context.
      - the padding indicator.
      - the frequency of each input time series.
      - the number of padded examples for SPMD so that each core has the same
          number (a multiple of `batch_size`) of examples.
    """
    input_ts, input_padding, inp_freq = [], [], []

    # Number of dummy examples needed to reach a multiple of the global batch.
    pmap_pad = ((len(inputs) - 1) // self.global_batch_size +
                1) * self.global_batch_size - len(inputs)

    for i, ts in enumerate(inputs):
      input_len = ts.shape[0]
      padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float)
      if input_len < self.context_len:
        # Short series: pad with zeros at the front and mark them as padding.
        num_front_pad = self.context_len - input_len
        ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts],
                            axis=0)
        padding = np.concatenate(
            [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0)
      elif input_len > self.context_len:
        # Long series: keep only the most recent `context_len` points.
        ts = ts[-self.context_len:]
        padding = padding[-(self.context_len + self.horizon_len):]

      input_ts.append(ts)
      input_padding.append(padding)
      inp_freq.append(freq[i])

    # Padding the remainder batch.
    for _ in range(pmap_pad):
      input_ts.append(input_ts[-1])
      input_padding.append(input_padding[-1])
      inp_freq.append(inp_freq[-1])

    return (
        np.stack(input_ts, axis=0),
        np.stack(input_padding, axis=0),
        np.array(inp_freq).astype(np.int32).reshape(-1, 1),
        pmap_pad,
    )

  def _forecast(
      self,
      inputs: Sequence[Any],
      freq: Sequence[int] | None = None,
      window_size: int | None = None,
      forecast_context_len: int | None = None,
      return_forecast_on_context: bool = False,
  ) -> tuple[np.ndarray, np.ndarray]:
    """Forecasts on a list of time series.

    Backend-specific implementation hook; the base class only defines the
    contract.

    Args:
      inputs: list of time series forecast contexts. Each context time series
        should be in a format convertible to JTensor by `jnp.array`.
      freq: frequency of each context time series. 0 for high frequency
        (default), 1 for medium, and 2 for low. Notice this is different from
        the `freq` required by `forecast_on_df`.
      window_size: window size of trend + residual decomposition. If None then
        we do not do decomposition.
      forecast_context_len: optional max context length.
      return_forecast_on_context: True to return the forecast on the context
        when available, i.e. after the first input patch.

    Returns:
    A tuple for np.array:
    - the mean forecast of size (# inputs, # forecast horizon),
    - the full forecast (mean + quantiles) of size
        (# inputs,  # forecast horizon, 1 + # quantiles).

    Raises:
      ValueError: If the checkpoint is not properly loaded.
    """
    raise NotImplementedError("`_forecast` is not implemented.")

  def forecast(
      self,
      inputs: Sequence[Any],
      freq: Sequence[int] | None = None,
      window_size: int | None = None,
      forecast_context_len: int | None = None,
      return_forecast_on_context: bool = False,
      normalize: bool = False,
  ) -> tuple[np.ndarray, np.ndarray]:
    """Forecasts on a list of time series.

    Non-finite values are treated as missing: leading ones are dropped and
    the remaining ones are filled by linear interpolation before forecasting.

    Args:
      inputs: list of time series forecast contexts. Each context time series
        should be in a format convertible to JTensor by `jnp.array`.
      freq: frequency of each context time series. 0 for high frequency
        (default), 1 for medium, and 2 for low. Notice this is different from
        the `freq` required by `forecast_on_df`.
      window_size: window size of trend + residual decomposition. If None then
        we do not do decomposition.
      forecast_context_len: optional max context length.
      return_forecast_on_context: True to return the forecast on the context
        when available, i.e. after the first input patch.
      normalize: If True, then we normalize the inputs before forecasting and
        the outputs are then renormalized to the original scale.

    Returns:
    A tuple for np.array:
    - the point forecast of size (# inputs, # forecast horizon); the mean or
        the median, depending on `hparams.point_forecast_mode`,
    - the full forecast (mean + quantiles) of size
        (# inputs,  # forecast horizon, 1 + # quantiles).

    Raises:
      ValueError: If the checkpoint is not properly loaded, if the median is
        requested but 0.5 is not among the quantiles, or if the point
        forecast mode is unknown.
    """
    stats = None

    tmp_inputs = []
    for each_input in inputs:
      arr = np.array(each_input)
      if not np.isfinite(arr).all():
        # Turn +/-inf into NaN so that they are handled as missing values.
        arr = np.where(np.isfinite(arr), arr, np.nan)
      arr = strip_leading_nans(arr)
      arr = linear_interpolation(arr)
      tmp_inputs.append(arr)

    inputs = tmp_inputs
    if normalize:
      inputs, stats = _normalize(inputs)
    mean_forecast, quantile_forecast = self._forecast(
        inputs,
        freq,
        window_size,
        forecast_context_len,
        return_forecast_on_context,
    )
    if stats is not None:
      # Map the forecasts back to the original scale of every series.
      stats = np.array(stats)
      mu = stats[:, 0]
      sigma = stats[:, 1]
      mean_forecast = mean_forecast * sigma[:, None] + mu[:, None]
      quantile_forecast = (quantile_forecast * sigma[:, None, None] +
                           mu[:, None, None])
    if self.hparams.point_forecast_mode == "mean":
      return mean_forecast, quantile_forecast
    elif self.hparams.point_forecast_mode == "median":
      if self._median_index == -1:
        # Resolve the position of the median once and cache it.
        for i, quantile in enumerate(self.quantiles):
          if quantile == 0.5:
            self._median_index = i
            break
        if self._median_index == -1:
          raise ValueError("Median (0.5) is not found in the model quantiles:"
                           f" {self.quantiles}. Please check the hparams.")
      # Index 0 of the last axis is the mean, quantiles start at index 1.
      return (
          quantile_forecast[:, :, 1 + self._median_index],
          quantile_forecast,
      )
    else:
      raise ValueError(
          "Unsupported point forecast mode:"
          f" {self.hparams.point_forecast_mode}. Use 'mean' or 'median'.")

  def forecast_with_covariates(
      self,
      inputs: list[Sequence[float]],
      dynamic_numerical_covariates: (dict[str, Sequence[Sequence[float]]] |
                                     None) = None,
      dynamic_categorical_covariates: (dict[str, Sequence[Sequence[Category]]] |
                                       None) = None,
      static_numerical_covariates: dict[str, Sequence[float]] | None = None,
      static_categorical_covariates: (dict[str, Sequence[Category]] |
                                      None) = None,
      freq: Sequence[int] | None = None,
      window_size: int | None = None,
      forecast_context_len: int | None = None,
      xreg_mode: XRegMode = "xreg + timesfm",
      normalize_xreg_target_per_input: bool = True,
      ridge: float = 0.0,
      max_rows_per_col: int = 0,
      force_on_cpu: bool = False,
  ):
    """Forecasts on a list of time series with covariates.

    To optimize inference speed, avoid string valued categorical covariates.

    Args:
      inputs: A list of time series forecast contexts. Each context time series
        should be in a format convertible to JTensor by `jnp.array`.
      dynamic_numerical_covariates: A dict of dynamic numerical covariates.
      dynamic_categorical_covariates: A dict of dynamic categorical covariates.
      static_numerical_covariates: A dict of static numerical covariates.
      static_categorical_covariates: A dict of static categorical covariates.
      freq: frequency of each context time series. 0 for high frequency
        (default), 1 for medium, and 2 for low. Notice this is different from
        the `freq` required by `forecast_on_df`.
      window_size: window size of trend + residual decomposition. If None then
        we do not do decomposition.
      forecast_context_len: optional max context length.
      xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "timesfm + xreg"
        forecasts with TimesFM first and fits the linear model on the
        residuals of that forecast. "xreg + timesfm" fits the linear model on
        the targets and then forecasts its residuals with TimesFM.
      normalize_xreg_target_per_input: whether to normalize the xreg target per
        input in the given batch.
      ridge: ridge penalty for the linear model.
      max_rows_per_col: max number of rows per column for the linear model.
      force_on_cpu: whether to force running on cpu for the linear model.

    Returns:
      A tuple of two lists. The first is the outputs of the model. The second is
      the outputs of the xreg.

    Raises:
      ValueError: if no covariate is given, if `xreg_mode` is unknown, or if
        the covariates extend beyond `horizon_len` steps after an input.
    """

    # Verify and bookkeep covariates.
    if not (dynamic_numerical_covariates or dynamic_categorical_covariates or
            static_numerical_covariates or static_categorical_covariates):
      raise ValueError(
          "At least one of dynamic_numerical_covariates,"
          " dynamic_categorical_covariates, static_numerical_covariates,"
          " static_categorical_covariates must be set.")

    # Track the lengths of (1) each input, (2) the part that can be used in the
    # linear model, and (3) the horizon.
    input_lens, train_lens, test_lens = [], [], []

    for i, input_ts in enumerate(inputs):
      input_len = len(input_ts)
      input_lens.append(input_len)

      if xreg_mode == "timesfm + xreg":
        # For fitting residuals, no TimesFM forecast on the first patch.
        train_lens.append(max(0, input_len - self.input_patch_len))
      elif xreg_mode == "xreg + timesfm":
        train_lens.append(input_len)
      else:
        raise ValueError(f"Unsupported mode: {xreg_mode}")

      # Dynamic covariates cover context + horizon, so whatever extends past
      # the input is the requested horizon for this series.
      if dynamic_numerical_covariates:
        test_lens.append(
            len(list(dynamic_numerical_covariates.values())[0][i]) - input_len)
      elif dynamic_categorical_covariates:
        test_lens.append(
            len(list(dynamic_categorical_covariates.values())[0][i]) -
            input_len)
      else:
        test_lens.append(self.horizon_len)

      if test_lens[-1] > self.horizon_len:
        raise ValueError(
            "Forecast requested longer horizon than the model definition "
            f"supports: {test_lens[-1]} vs {self.horizon_len}.")

    # Prepare the covariates into train and test.
    train_dynamic_numerical_covariates = collections.defaultdict(list)
    test_dynamic_numerical_covariates = collections.defaultdict(list)
    train_dynamic_categorical_covariates = collections.defaultdict(list)
    test_dynamic_categorical_covariates = collections.defaultdict(list)
    for covariates, train_covariates, test_covariates in (
        (
            dynamic_numerical_covariates,
            train_dynamic_numerical_covariates,
            test_dynamic_numerical_covariates,
        ),
        (
            dynamic_categorical_covariates,
            train_dynamic_categorical_covariates,
            test_dynamic_categorical_covariates,
        ),
    ):
      if not covariates:
        continue
      for covariate_name, covariate_values in covariates.items():
        for input_len, train_len, covariate_value in zip(
            input_lens, train_lens, covariate_values):
          train_covariates[covariate_name].append(
              covariate_value[(input_len - train_len):input_len])
          test_covariates[covariate_name].append(covariate_value[input_len:])

    # Fit models.
    if xreg_mode == "timesfm + xreg":
      # Forecast via TimesFM then fit a model on the residuals.
      mean_outputs, _ = self.forecast(
          inputs,
          freq,
          window_size,
          forecast_context_len,
          return_forecast_on_context=True,
      )
      targets = []
      for input_ts, mean_output, train_len in zip(inputs, mean_outputs,
                                                  train_lens):
        context = np.array(input_ts)
        # Slice from an explicit start index: `context[-train_len:]` would
        # return the whole series (not an empty one) when train_len is 0.
        targets.append(
            context[len(context) - train_len:] -
            mean_output[(self._horizon_start - train_len):self._horizon_start])
      per_instance_stats = None
      if normalize_xreg_target_per_input:
        targets, per_instance_stats = _normalize(targets)
      xregs = xreg_lib.BatchedInContextXRegLinear(
          targets=targets,
          train_lens=train_lens,
          test_lens=test_lens,
          train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
          test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
          train_dynamic_categorical_covariates=
          train_dynamic_categorical_covariates,
          test_dynamic_categorical_covariates=
          test_dynamic_categorical_covariates,
          static_numerical_covariates=static_numerical_covariates,
          static_categorical_covariates=static_categorical_covariates,
      ).fit(
          ridge=ridge,
          one_hot_encoder_drop=None if ridge > 0 else "first",
          max_rows_per_col=max_rows_per_col,
          force_on_cpu=force_on_cpu,
          debug_info=False,
          assert_covariates=True,
          assert_covariate_shapes=True,
      )
      if normalize_xreg_target_per_input:
        xregs = _renormalize(xregs, per_instance_stats)
      outputs = [
          (mean_output[self._horizon_start:(self._horizon_start + test_len)] +
           xreg)
          for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs)
      ]

    else:
      # Fit a model on the targets then forecast on the residuals via TimesFM.
      targets = []
      for input_ts, train_len in zip(inputs, train_lens):
        context = np.array(input_ts)
        # Explicit start index, see the note in the branch above.
        targets.append(context[len(context) - train_len:])
      per_instance_stats = None
      if normalize_xreg_target_per_input:
        targets, per_instance_stats = _normalize(targets)
      xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear(
          targets=targets,
          train_lens=train_lens,
          test_lens=test_lens,
          train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
          test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
          train_dynamic_categorical_covariates=
          train_dynamic_categorical_covariates,
          test_dynamic_categorical_covariates=
          test_dynamic_categorical_covariates,
          static_numerical_covariates=static_numerical_covariates,
          static_categorical_covariates=static_categorical_covariates,
      ).fit(
          ridge=ridge,
          one_hot_encoder_drop=None if ridge > 0 else "first",
          max_rows_per_col=max_rows_per_col,
          force_on_cpu=force_on_cpu,
          debug_info=True,
          assert_covariates=True,
          assert_covariate_shapes=True,
      )
      mean_outputs, _ = self.forecast(
          [
              target - xreg_on_context
              for target, xreg_on_context in zip(targets, xregs_on_context)
          ],
          freq,
          window_size,
          forecast_context_len,
          return_forecast_on_context=True,
      )
      outputs = [
          (mean_output[self._horizon_start:(self._horizon_start + test_len)] +
           xreg)
          for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs)
      ]
      if normalize_xreg_target_per_input:
        outputs = _renormalize(outputs, per_instance_stats)

    return outputs, xregs

  def forecast_on_df(
      self,
      inputs: pd.DataFrame,
      freq: str,
      forecast_context_len: int = 0,
      value_name: str = "values",
      model_name: str = "timesfm",
      window_size: int | None = None,
      num_jobs: int = 1,
      verbose: bool = True,
  ) -> pd.DataFrame:
    """Forecasts on a list of time series.

    Args:
      inputs: A pd.DataFrame of all time series. The dataframe should have a
        `unique_id` column for identifying the time series, a `ds` column for
        timestamps and a value column for the time series values.
      freq: string valued `freq` of data. Notice this is different from the
        `freq` required by `forecast`. See `freq_map` for allowed values.
      forecast_context_len: If provided none zero, we take the last
        `forecast_context_len` time-points from each series as the forecast
        context instead of the `context_len` set by the model.
      value_name: The name of the value column.
      model_name: name of the model to be written into future df.
      window_size: window size of trend + residual decomposition. If None then
        we do not do decomposition.
      num_jobs: number of parallel processes to use for dataframe processing.
        -1 uses one process per CPU.
      verbose: output model states in terminal.

    Returns:
      Future forecasts dataframe.

    Raises:
      ValueError: if a required column is missing from `inputs`.
    """
    if not ("unique_id" in inputs.columns and "ds" in inputs.columns and
            value_name in inputs.columns):
      raise ValueError(
          f"DataFrame must have unique_id, ds and {value_name} columns.")
    if not forecast_context_len:
      forecast_context_len = self.context_len
    logging.info("Preprocessing dataframe.")
    df_sorted = inputs.sort_values(by=["unique_id", "ds"])
    new_inputs = []
    uids = []
    if num_jobs == 1:
      if verbose:
        print("Processing dataframe with single process.")
      for key, group in df_sorted.groupby("unique_id"):
        inp, uid = process_group(
            key,
            group,
            value_name,
            forecast_context_len,
        )
        new_inputs.append(inp)
        uids.append(uid)
    else:
      if num_jobs == -1:
        num_jobs = multiprocessing.cpu_count()
      if verbose:
        print("Processing dataframe with multiple processes.")
      with multiprocessing.Pool(processes=num_jobs) as pool:
        results = pool.starmap(
            process_group,
            [(key, group, value_name, forecast_context_len)
             for key, group in df_sorted.groupby("unique_id")],
        )
      new_inputs, uids = zip(*results)
    if verbose:
      print("Finished preprocessing dataframe.")
    # Every series in the dataframe shares the same frequency category.
    freq_inps = [freq_map(freq)] * len(new_inputs)
    _, full_forecast = self.forecast(new_inputs,
                                     freq=freq_inps,
                                     window_size=window_size)
    if verbose:
      print("Finished forecasting.")
    fcst_df = make_future_dataframe(
        uids=uids,
        last_times=df_sorted.groupby("unique_id")["ds"].tail(1),
        h=self.horizon_len,
        freq=freq,
    )
    # Index 0 of the last axis is the mean forecast.
    fcst_df[model_name] = full_forecast[:, 0:self.horizon_len, 0].reshape(-1, 1)

    for i, q in enumerate(self.quantiles):
      q_col = f"{model_name}-q-{q}"
      fcst_df[q_col] = full_forecast[:, 0:self.horizon_len,
                                     1 + i].reshape(-1, 1)
      if q == 0.5:
        # When the median is available it replaces the mean as point forecast.
        fcst_df[model_name] = fcst_df[q_col]
    logging.info("Finished creating output dataframe.")
    return fcst_df
================================================
FILE: probts/model/nn/arch/TimesFMModule/timesfm_jax.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM JAX forecast API for inference."""
import logging
import multiprocessing
import time
from os import path
from typing import Any, Sequence
import einshape as es
import jax
import jax.numpy as jnp
import numpy as np
from huggingface_hub import snapshot_download
from paxml import checkpoints, tasks_lib
from praxis import base_hyperparams, base_layer, pax_fiddle, py_utils, pytypes
from praxis.layers import normalizations, transformers
from probts.model.nn.arch.TimesFMModule import timesfm_base
from probts.model.nn.arch.TimesFMModule import patched_decoder
# Short module-level aliases for Praxis helpers and types.
instantiate = base_hyperparams.instantiate
NestedMap = py_utils.NestedMap
JTensor = pytypes.JTensor
# Small numerical tolerance; where it is used is not visible in this part of
# the file.
_TOL = 1e-6
class TimesFmJax(timesfm_base.TimesFmBase):
  """TimesFM forecast API for inference (JAX / pax backend).

  This class is the scaffolding for calling TimesFM forecast. To properly use:
    1. Create an instance with the correct hyperparameters of a TimesFM model.
    2. Call `load_from_checkpoint` to load a compatible checkpoint.
    3. Call `forecast` for inference.

  Given the model size, this API does not shard the model weights for SPMD. All
  parallelism happens on the data dimension.

  Compilation is triggered by `load_from_checkpoint`, which calls `jit_decode`
  and runs one warm-up batch shaped by `per_core_batch_size`. That warm-up sets
  and freezes the input signature, so later forecast calls with the same
  shapes reflect the actual inference latency.
  """

  def _get_sample_inputs(self):
    """Returns all-zero inputs used only to trace the model's variable shapes."""
    sample_len = self.context_len + self.output_patch_len
    return {
        "input_ts":
            jnp.zeros(
                (self.per_core_batch_size, sample_len),
                dtype=jnp.float32,
            ),
        "input_padding":
            jnp.zeros(
                (self.per_core_batch_size, sample_len),
                dtype=jnp.float32,
            ),
        "freq":
            jnp.zeros(
                (self.per_core_batch_size, 1),
                dtype=jnp.int32,
            ),
    }

  def __post_init__(self):
    """Derives batch sizes from the local device count and resets state."""
    self.num_cores = jax.local_device_count(self.backend)
    self.global_batch_size = self.per_core_batch_size * self.num_cores
    self._eval_context = base_layer.JaxContext.HParams(do_eval=True)
    self._pmapped_decode = None
    self._model = None
    self._train_state = None
    self._median_index = -1

  def load_from_checkpoint(
      self,
      checkpoint: timesfm_base.TimesFmCheckpoint,
  ) -> None:
    """Loads a checkpoint and compiles the decoder.

    Args:
      checkpoint: Checkpoint descriptor. When `checkpoint.path` is None the
        snapshot of `checkpoint.huggingface_repo_id` is downloaded (into
        `checkpoint.local_dir` when that is set) and its `checkpoints`
        sub-directory is used.
    """
    checkpoint_type = (checkpoints.CheckpointType.FLAX
                       if checkpoint.type is None else checkpoint.type)
    checkpoint_path = checkpoint.path
    step = checkpoint.step
    repo_id = checkpoint.huggingface_repo_id
    # Download the checkpoint from Hugging Face Hub if no local path is given.
    # `local_dir` is forwarded so this backend honours the same download
    # location as the PyTorch backend; None keeps the default cache location.
    if checkpoint_path is None:
      checkpoint_path = path.join(
          snapshot_download(repo_id, local_dir=checkpoint.local_dir),
          "checkpoints")

    # Rewrite the devices for Jax.
    self.mesh_shape = [1, self.num_cores, 1]
    self.mesh_name = ["replica", "data", "mdl"]

    self.model_p = pax_fiddle.Config(
        patched_decoder.PatchedTimeSeriesDecoder,
        name="patched_decoder",
        horizon_len=self.output_patch_len,
        patch_len=self.input_patch_len,
        model_dims=self.model_dims,
        hidden_dims=self.model_dims,
        residual_block_tpl=pax_fiddle.Config(patched_decoder.ResidualBlock),
        quantiles=self.quantiles,
        use_freq=True,
        use_pos_emb=self.use_pos_emb,
        stacked_transformer_params_tpl=pax_fiddle.Config(
            transformers.StackedTransformer,
            num_heads=self.num_heads,
            num_layers=self.num_layers,
            transformer_layer_params_tpl=pax_fiddle.Config(
                transformers.Transformer,
                ln_tpl=pax_fiddle.Config(normalizations.RmsNorm,),
            ),
        ),
    )

    self._key1, self._key2 = jax.random.split(jax.random.PRNGKey(42))
    self._model = None
    self._train_state = None
    self._pmapped_decode = None
    self._eval_context = base_layer.JaxContext.HParams(do_eval=True)
    try:
      multiprocessing.set_start_method("spawn")
    except RuntimeError:
      # The start method can only be set once per process; keep the existing
      # one instead of failing.
      print("Multiprocessing context has already been set.")

    # Initialize the model weights.
    self._logging("Constructing model weights.")
    start_time = time.time()
    self._model = instantiate(self.model_p)
    var_weight_hparams = self._model.abstract_init_with_metadata(
        self._get_sample_inputs(), do_eval=True)
    train_state_partition_specs = tasks_lib.create_state_partition_specs(
        var_weight_hparams,
        mesh_shape=self.mesh_shape,
        mesh_axis_names=self.mesh_name,
        discard_opt_states=True,
        learners=None,
    )
    train_state_local_shapes = tasks_lib.create_state_unpadded_shapes(
        var_weight_hparams,
        discard_opt_states=True,
        learners=None,
    )
    self._logging(
        f"Constructed model weights in {time.time() - start_time:.2f} seconds.")

    # Load the model weights.
    self._logging(f"Restoring checkpoint from {checkpoint_path}.")
    start_time = time.time()
    self._train_state = checkpoints.restore_checkpoint(
        train_state_local_shapes,
        checkpoint_dir=checkpoint_path,
        checkpoint_type=checkpoint_type,
        state_specs=train_state_partition_specs,
        step=step,
    )
    self._logging(
        f"Restored checkpoint in {time.time() - start_time:.2f} seconds.")
    self.jit_decode()

  def jit_decode(self):
    """Builds the pmapped decode function and runs one warm-up batch."""

    # Initialize and jit the decode fn.
    def _decode(inputs):
      assert self._model is not None
      assert self._train_state is not None
      return self._model.apply(
          self._train_state.mdl_vars,
          inputs,
          horizon_len=self.horizon_len,
          output_patch_len=self.output_patch_len,
          max_len=self.context_len,
          # Always decode on the context too; `_forecast` slices it off when
          # the caller did not ask for it.
          return_forecast_on_context=True,
          rngs={
              base_layer.PARAMS: self._key1,
              base_layer.RANDOM: self._key2,
          },
          method=self._model.decode,
      )

    self._logging("Jitting decoding.")
    start_time = time.time()
    self._pmapped_decode = jax.pmap(
        _decode,
        axis_name="batch",
        devices=jax.devices(self.backend),
        backend=self.backend,
        axis_size=self.num_cores,
    )
    # Warm-up call on all-zero inputs so compilation cost is paid here.
    with base_layer.JaxContext.new_context(hparams=self._eval_context):
      _ = self._pmapped_decode(
          NestedMap({
              "input_ts":
                  jnp.zeros(
                      (
                          self.num_cores,
                          self.per_core_batch_size,
                          self.context_len,
                      ),
                      dtype=jnp.float32,
                  ),
              "input_padding":
                  jnp.zeros(
                      (
                          self.num_cores,
                          self.per_core_batch_size,
                          self.context_len + self.horizon_len,
                      ),
                      dtype=jnp.float32,
                  ),
              "date_features":
                  None,
              "freq":
                  jnp.zeros(
                      (self.num_cores, self.per_core_batch_size, 1),
                      dtype=jnp.int32,
                  ),
          }))
    self._logging(f"Jitted decoding in {time.time() - start_time:.2f} seconds.")

  def _forecast(
      self,
      inputs: Sequence[Any],
      freq: Sequence[int] | None = None,
      window_size: int | None = None,
      forecast_context_len: int | None = None,
      return_forecast_on_context: bool = False,
  ) -> tuple[np.ndarray, np.ndarray]:
    """Forecasts on a list of time series.

    Args:
      inputs: list of time series forecast contexts. Each context time series
        should be in a format convertible by `np.array`.
      freq: frequency of each context time series. 0 for high frequency
        (default), 1 for medium, and 2 for low. Notice this is different from
        the `freq` required by `forecast_on_df`.
      window_size: window size of trend + residual decomposition. If None then
        we do not do decomposition.
      forecast_context_len: optional max context length.
      return_forecast_on_context: True to return the forecast on the context
        when available, i.e. after the first input patch.

    Returns:
      A tuple of numpy arrays:
      - the mean forecast of size (# inputs, # forecast horizon),
      - the full forecast (mean + quantiles) of size
        (# inputs, # forecast horizon, 1 + # quantiles).

    Raises:
      ValueError: If the checkpoint is not properly loaded.
    """
    if not self._train_state or not self._model:
      raise ValueError(
          "Checkpoint not loaded. Call `load_from_checkpoint` before"
          " `forecast`.")
    if forecast_context_len is None:
      fcontext_len = self.context_len
    else:
      fcontext_len = forecast_context_len
    # Keep only the most recent `fcontext_len` points of every series.
    inputs = [np.array(ts)[-fcontext_len:] for ts in inputs]

    if window_size is not None:
      # Each series is replaced by the components returned by
      # `moving_average`; their forecasts are summed pairwise at the end.
      new_inputs = []
      for ts in inputs:
        new_inputs.extend(timesfm_base.moving_average(ts, window_size))
      inputs = new_inputs

    if freq is None:
      logging.info("No frequency provided via `freq`. Default to high (0).")
      freq = [0] * len(inputs)

    input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq)
    with base_layer.JaxContext.new_context(hparams=self._eval_context):
      mean_outputs = []
      full_outputs = []
      assert input_ts.shape[0] % self.global_batch_size == 0
      for i in range(input_ts.shape[0] // self.global_batch_size):
        start = i * self.global_batch_size
        stop = (i + 1) * self.global_batch_size
        input_ts_in = jnp.array(input_ts[start:stop])
        input_padding_in = jnp.array(input_padding[start:stop],)
        inp_freq_in = jnp.array(
            inp_freq[start:stop, :],
            dtype=jnp.int32,
        )
        # Split the leading batch axis into (device, per-device batch).
        pmapped_inputs = NestedMap({
            "input_ts":
                es.jax_einshape(
                    "(db)...->db...",
                    input_ts_in,
                    d=self.num_cores,
                ),
            "input_padding":
                es.jax_einshape(
                    "(db)...->db...",
                    input_padding_in,
                    d=self.num_cores,
                ),
            "date_features":
                None,
            "freq":
                es.jax_einshape(
                    "(db)...->db...",
                    inp_freq_in,
                    d=self.num_cores,
                ),
        })
        mean_output, full_output = self._pmapped_decode(pmapped_inputs)
        if not return_forecast_on_context:
          # `_horizon_start` is provided outside this class (presumably by the
          # base class) -- it marks where the horizon begins in the output.
          mean_output = mean_output[:, :, self._horizon_start:, ...]
          full_output = full_output[:, :, self._horizon_start:, ...]
        # Merge (device, per-device batch) back into one batch axis.
        mean_output = es.jax_einshape("db...->(db)...",
                                      mean_output,
                                      d=self.num_cores)
        full_output = es.jax_einshape("db...->(db)...",
                                      full_output,
                                      d=self.num_cores)
        mean_outputs.append(np.array(mean_output))
        full_outputs.append(np.array(full_output))

    mean_outputs = np.concatenate(mean_outputs, axis=0)
    full_outputs = np.concatenate(full_outputs, axis=0)

    # Drop the rows `_preprocess` appended to fill the last global batch.
    if pmap_pad > 0:
      mean_outputs = mean_outputs[:-pmap_pad, ...]
      full_outputs = full_outputs[:-pmap_pad, ...]

    if window_size is not None:
      # Recombine the two decomposition components of every original series.
      mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
      full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
    return mean_outputs, full_outputs
================================================
FILE: probts/model/nn/arch/TimesFMModule/timesfm_torch.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM pytorch forecast API for inference."""
import logging
from os import path
from typing import Any, Sequence
import numpy as np
import torch
from huggingface_hub import snapshot_download
from probts.model.nn.arch.TimesFMModule import timesfm_base
from probts.model.nn.arch.TimesFMModule import pytorch_patched_decoder as ppd
# Numerical tolerance constant. NOTE(review): no use of it is visible in the
# TimesFmTorch class below -- confirm whether it is still needed.
_TOL = 1e-6
class TimesFmTorch(timesfm_base.TimesFmBase):
  """TimesFM forecast API for inference (PyTorch backend)."""

  def __post_init__(self):
    """Builds the decoder config and selects the execution device."""
    self._model_config = ppd.TimesFMConfig(
        num_layers=self.num_layers,
        num_heads=self.num_heads,
        hidden_size=self.model_dims,
        intermediate_size=self.model_dims,
        patch_len=self.input_patch_len,
        horizon_len=self.output_patch_len,
        head_dim=self.model_dims // self.num_heads,
        quantiles=self.quantiles,
        use_positional_embedding=self.use_pos_emb,
    )
    self._model = None
    # This backend runs on a single device, so the global batch equals the
    # per-core batch.
    self.num_cores = 1
    self.global_batch_size = self.per_core_batch_size
    self._device = torch.device("cuda:0" if (
        torch.cuda.is_available() and self.backend == "gpu") else "cpu")
    self._median_index = -1

  def load_from_checkpoint(
      self,
      checkpoint: timesfm_base.TimesFmCheckpoint,
  ) -> None:
    """Loads the model weights from a checkpoint and moves them to the device.

    Args:
      checkpoint: Checkpoint descriptor. When `checkpoint.path` is None the
        snapshot of `checkpoint.huggingface_repo_id` is downloaded into
        `checkpoint.local_dir` and its `torch_model.ckpt` file is used.
    """
    checkpoint_path = checkpoint.path
    repo_id = checkpoint.huggingface_repo_id
    if checkpoint_path is None:
      checkpoint_path = path.join(
          snapshot_download(repo_id, local_dir=checkpoint.local_dir),
          "torch_model.ckpt")
    # Build into a local first: `self._model` is only published once the
    # weights are in place, so a failed load cannot leave a randomly
    # initialised model that passes the "checkpoint loaded" guard.
    model = ppd.PatchedTimeSeriesDecoder(self._model_config)
    logging.info("Loading checkpoint from %s", checkpoint_path)
    # `map_location="cpu"` lets a checkpoint saved from a GPU be read on a
    # CPU-only host; the model is moved to the target device right after.
    loaded_checkpoint = torch.load(checkpoint_path,
                                   map_location="cpu",
                                   weights_only=True)
    model.load_state_dict(loaded_checkpoint)
    logging.info("Sending checkpoint to device %s", f"{self._device}")
    model.to(self._device)
    model.eval()
    self._model = model
    # TODO: add compilation.

  def _forecast(
      self,
      inputs: Sequence[Any],
      freq: Sequence[int] | None = None,
      window_size: int | None = None,
      forecast_context_len: int | None = None,
      return_forecast_on_context: bool = False,
  ) -> tuple[np.ndarray, np.ndarray]:
    """Forecasts on a list of time series.

    Args:
      inputs: list of time series forecast contexts. Each context time series
        should be in a format convertible by `np.array`.
      freq: frequency of each context time series. 0 for high frequency
        (default), 1 for medium, and 2 for low. Notice this is different from
        the `freq` required by `forecast_on_df`.
      window_size: window size of trend + residual decomposition. If None then
        we do not do decomposition.
      forecast_context_len: optional max context length.
      return_forecast_on_context: True to return the forecast on the context
        when available, i.e. after the first input patch.

    Returns:
      A tuple of numpy arrays:
      - the mean forecast of size (# inputs, # forecast horizon),
      - the full forecast (mean + quantiles) of size
        (# inputs, # forecast horizon, 1 + # quantiles).

    Raises:
      ValueError: If the checkpoint is not properly loaded.
    """
    if self._model is None:
      raise ValueError(
          "Checkpoint not loaded. Call `load_from_checkpoint` before"
          " `forecast`.")
    if forecast_context_len is None:
      fcontext_len = self.context_len
    else:
      fcontext_len = forecast_context_len
    # Keep only the most recent `fcontext_len` points of every series.
    inputs = [np.array(ts)[-fcontext_len:] for ts in inputs]

    if window_size is not None:
      # Each series is replaced by the components returned by
      # `moving_average`; their forecasts are summed pairwise at the end.
      new_inputs = []
      for ts in inputs:
        new_inputs.extend(timesfm_base.moving_average(ts, window_size))
      inputs = new_inputs

    if freq is None:
      logging.info("No frequency provided via `freq`. Default to high (0).")
      freq = [0] * len(inputs)

    input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq)
    with torch.no_grad():
      mean_outputs = []
      full_outputs = []
      assert input_ts.shape[0] % self.global_batch_size == 0
      for i in range(input_ts.shape[0] // self.global_batch_size):
        start = i * self.global_batch_size
        stop = (i + 1) * self.global_batch_size
        input_ts_in = torch.from_numpy(
            np.array(input_ts[start:stop],
                     dtype=np.float32)).to(self._device)
        input_padding_in = torch.from_numpy(
            np.array(input_padding[start:stop],
                     dtype=np.float32)).to(self._device)
        inp_freq_in = torch.from_numpy(
            np.array(inp_freq[start:stop, :],
                     dtype=np.int32)).long().to(self._device)
        mean_output, full_output = self._model.decode(
            input_ts=input_ts_in,
            paddings=input_padding_in,
            freq=inp_freq_in,
            horizon_len=self.horizon_len,
            return_forecast_on_context=return_forecast_on_context,
        )
        mean_outputs.append(mean_output.detach().cpu().numpy())
        full_outputs.append(full_output.detach().cpu().numpy())

    mean_outputs = np.concatenate(mean_outputs, axis=0)
    full_outputs = np.concatenate(full_outputs, axis=0)

    # Drop the rows `_preprocess` appended to fill the last batch.
    if pmap_pad > 0:
      mean_outputs = mean_outputs[:-pmap_pad, ...]
      full_outputs = full_outputs[:-pmap_pad, ...]

    if window_size is not None:
      # Recombine the two decomposition components of every original series.
      mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
      full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
    return mean_outputs, full_outputs
================================================
FILE: probts/model/nn/arch/TimesFMModule/xreg_lib.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for in-context covariates and regression."""
import itertools
import math
from typing import Any, Iterable, Literal, Mapping, Sequence
import jax
import jax.numpy as jnp
import numpy as np
from sklearn import preprocessing
# A categorical covariate value: either an integer code or a string label.
Category = int | str
# Standard deviations at or below this threshold are replaced by 1.0 when
# numerical covariates are normalized in `create_covariate_matrix`.
_TOL = 1e-6
# The two accepted mode strings. NOTE(review): the consumer of this type is
# not in this module -- confirm what each ordering means before relying on it.
XRegMode = Literal["timesfm + xreg", "xreg + timesfm"]
def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:
return np.array(list(itertools.chain.from_iterable(nested)))
def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:
return np.array(
list(
itertools.chain.from_iterable(map(itertools.repeat, elements,
counts))))
def _to_padded_jax_array(x: np.ndarray) -> jax.Array:
if x.ndim == 1:
(i,) = x.shape
di = 2**math.ceil(math.log2(i)) - i
return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0)
elif x.ndim == 2:
i, j = x.shape
di = 2**math.ceil(math.log2(i)) - i
dj = 2**math.ceil(math.log2(j)) - j
return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0)
else:
raise ValueError(f"Unsupported array shape: {x.shape}")
class BatchedInContextXRegBase:
  """Formats batched covariates for in-context regression.

  The context is referred to as 'train' and the horizon as 'test'.

  Attributes:
    targets: List of targets (responses) of the in-context regression.
    train_lens: List of lengths of each target vector from the context.
    test_lens: List of lengths of each forecast horizon.
    train_dynamic_numerical_covariates: Dict of covariate names mapping to the
      dynamic numerical covariates of each forecast task on the context. Their
      lengths should match the corresponding lengths in `train_lens`.
    train_dynamic_categorical_covariates: Dict of covariate names mapping to the
      dynamic categorical covariates of each forecast task on the context. Their
      lengths should match the corresponding lengths in `train_lens`.
    test_dynamic_numerical_covariates: Dict of covariate names mapping to the
      dynamic numerical covariates of each forecast task on the horizon. Their
      lengths should match the corresponding lengths in `test_lens`.
    test_dynamic_categorical_covariates: Dict of covariate names mapping to the
      dynamic categorical covariates of each forecast task on the horizon. Their
      lengths should match the corresponding lengths in `test_lens`.
    static_numerical_covariates: Dict of covariate names mapping to the static
      numerical covariates of each forecast task.
    static_categorical_covariates: Dict of covariate names mapping to the static
      categorical covariates of each forecast task.
  """

  def __init__(
      self,
      targets: Sequence[Sequence[float]],
      train_lens: Sequence[int],
      test_lens: Sequence[int],
      train_dynamic_numerical_covariates: (
          Mapping[str, Sequence[Sequence[float]]] | None) = None,
      train_dynamic_categorical_covariates: (
          Mapping[str, Sequence[Sequence[Category]]] | None) = None,
      test_dynamic_numerical_covariates: (
          Mapping[str, Sequence[Sequence[float]]] | None) = None,
      test_dynamic_categorical_covariates: (
          Mapping[str, Sequence[Sequence[Category]]] | None) = None,
      static_numerical_covariates: Mapping[str, Sequence[float]] | None = None,
      static_categorical_covariates: (Mapping[str, Sequence[Category]] |
                                      None) = None,
  ) -> None:
    """Stores the targets and the exogenous covariate inputs.

    Model fitting language is used: the context is 'train' and the horizon is
    'test'. Inputs are batched. To properly format the request:

    - `train_lens` represents the contexts in the batch. Targets and all train
      dynamic covariates should have the same lengths as the corresponding
      elements in `train_lens`. Notice each `train_len` can be different from
      the exact length of the corresponding context depending on how much of
      the context is used for fitting the in-context model.
    - `test_lens` represents the horizon lengths in the batch. All test
      dynamic covariates should have the same lengths as the corresponding
      elements in `test_lens`.
    - Static covariates should be one for each input.
    - Train and test dynamic covariates should have the same covariate names.

    Pass an empty dict {} (or None) for a covariate type if it is not present.

    Example:
      Here is a set of valid inputs whose schema can be used for reference.
      ```
      targets = [
          [0.0, 0.1, 0.2],
          [0.0, 0.1, 0.2, 0.3],
      ]  # Two inputs in this batch.
      train_lens = [3, 4]
      test_lens = [2, 5]  # Forecast horizons 2 and 5 respectively.
      train_dynamic_numerical_covariates = {
          "cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]],
          "cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]],
      }  # Each train dynamic covariate has 3 and 4 elements respectively.
      test_dynamic_numerical_covariates = {
          "cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]],
          "cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]],
      }  # Each test dynamic covariate has 2 and 5 elements respectively.
      train_dynamic_categorical_covariates = {
          "cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]],
          "cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad",
          "bad"]],
      }
      test_dynamic_categorical_covariates = {
          "cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]],
          "cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]],
      }
      static_numerical_covariates = {
          "cov_1_sn": [0.0, 3.0],
          "cov_2_sn": [2.0, 1.0],
          "cov_3_sn": [1.0, 2.0],
      }  # Each static covariate has 1 element for each input.
      static_categorical_covariates = {
          "cov_1_sc": ["apple", "orange"],
          "cov_2_sc": [2, 3],
      }
      ```

    Args:
      targets: List of targets (responses) of the in-context regression.
      train_lens: List of lengths of each target vector from the context.
      test_lens: List of lengths of each forecast horizon.
      train_dynamic_numerical_covariates: Dict of covariate names mapping to the
        dynamic numerical covariates of each forecast task on the context.
      train_dynamic_categorical_covariates: Dict of covariate names mapping to
        the dynamic categorical covariates of each forecast task on the context.
      test_dynamic_numerical_covariates: Dict of covariate names mapping to the
        dynamic numerical covariates of each forecast task on the horizon.
      test_dynamic_categorical_covariates: Dict of covariate names mapping to
        the dynamic categorical covariates of each forecast task on the horizon.
      static_numerical_covariates: Dict of covariate names mapping to the static
        numerical covariates of each forecast task.
      static_categorical_covariates: Dict of covariate names mapping to the
        static categorical covariates of each forecast task.
    """
    self.targets = targets
    self.train_lens = train_lens
    self.test_lens = test_lens
    # Absent (None) or empty covariate groups are stored as empty dicts so the
    # rest of the class can iterate over them unconditionally.
    self.train_dynamic_numerical_covariates = (
        train_dynamic_numerical_covariates
        if train_dynamic_numerical_covariates else {})
    self.train_dynamic_categorical_covariates = (
        train_dynamic_categorical_covariates
        if train_dynamic_categorical_covariates else {})
    self.test_dynamic_numerical_covariates = (
        test_dynamic_numerical_covariates
        if test_dynamic_numerical_covariates else {})
    self.test_dynamic_categorical_covariates = (
        test_dynamic_categorical_covariates
        if test_dynamic_categorical_covariates else {})
    self.static_numerical_covariates = (
        static_numerical_covariates if static_numerical_covariates else {})
    self.static_categorical_covariates = (
        static_categorical_covariates if static_categorical_covariates else {})

  def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None:
    """Validates the covariate inputs, raising ValueError on the first issue.

    Args:
      assert_covariate_shapes: When True, also checks example counts and
        per-example lengths against `train_lens` / `test_lens`.

    Raises:
      ValueError: If a train/test dynamic group is present on one side only,
        if their key sets differ, or (optionally) if any length mismatches.
    """
    dynamic_pairs = (
        (
            "train_dynamic_numerical_covariates",
            self.train_dynamic_numerical_covariates,
            "test_dynamic_numerical_covariates",
            self.test_dynamic_numerical_covariates,
        ),
        (
            "train_dynamic_categorical_covariates",
            self.train_dynamic_categorical_covariates,
            "test_dynamic_categorical_covariates",
            self.test_dynamic_categorical_covariates,
        ),
    )

    # Presence: a dynamic group must be given for both context and horizon.
    for train_name, train_cov, test_name, test_cov in dynamic_pairs:
      if bool(train_cov) != bool(test_cov):
        raise ValueError(
            f"{train_name} and {test_name} must be both present or both"
            " absent.")

    # Keys: context and horizon must name exactly the same covariates.
    for train_name, train_cov, test_name, test_cov in dynamic_pairs:
      train_keys = set(train_cov.keys())
      test_keys = set(test_cov.keys())
      only_in_train = train_keys - test_keys
      if only_in_train:
        raise ValueError(
            f"{train_name} has keys not present in {test_name}:"
            f" {only_in_train}")
      only_in_test = test_keys - train_keys
      if only_in_test:
        raise ValueError(
            f"{test_name} has keys not present in {train_name}:"
            f" {only_in_test}")

    if not assert_covariate_shapes:
      return

    # Shapes.
    if len(self.targets) != len(self.train_lens):
      raise ValueError(
          "targets and train_lens must have the same number of elements.")
    if len(self.train_lens) != len(self.test_lens):
      raise ValueError(
          "train_lens and test_lens must have the same number of elements.")

    for i, pair in enumerate(zip(self.targets, self.train_lens)):
      target, train_len = pair
      if len(target) != train_len:
        raise ValueError(
            f"targets[{i}] has length {len(target)} != expected {train_len}.")

    static_groups = (
        ("static_numerical_covariates", self.static_numerical_covariates),
        ("static_categorical_covariates", self.static_categorical_covariates),
    )
    for group_name, group in static_groups:
      for key, values in group.items():
        if len(values) != len(self.train_lens):
          raise ValueError(
              f"{group_name} has key {key} with number of"
              f" examples {len(values)} != expected {len(self.train_lens)}.")

    dynamic_groups = (
        (
            self.train_lens,
            self.train_dynamic_numerical_covariates,
            "train_dynamic_numerical_covariates",
        ),
        (
            self.train_lens,
            self.train_dynamic_categorical_covariates,
            "train_dynamic_categorical_covariates",
        ),
        (
            self.test_lens,
            self.test_dynamic_numerical_covariates,
            "test_dynamic_numerical_covariates",
        ),
        (
            self.test_lens,
            self.test_dynamic_categorical_covariates,
            "test_dynamic_categorical_covariates",
        ),
    )
    for lens, group, group_name in dynamic_groups:
      for key, cov_values in group.items():
        if len(cov_values) != len(lens):
          raise ValueError(
              f"{group_name} has key {key} with number of examples"
              f" {len(cov_values)} != expected {len(lens)}.")
        for i, cov_value in enumerate(cov_values):
          expected_len = lens[i]
          if len(cov_value) != expected_len:
            raise ValueError(
                f"{group_name} has key {key} with its {i}-th example"
                f" length {len(cov_value)} != expected {expected_len}.")

  def create_covariate_matrix(
      self,
      one_hot_encoder_drop: str | None = "first",
      use_intercept: bool = True,
      assert_covariates: bool = False,
      assert_covariate_shapes: bool = False,
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Builds the flat target vector and the train/test design matrices.

    The context is referred to as 'train' and the horizon as 'test'. Columns
    are laid out as: [intercept], normalized numerical covariates (dynamic
    in sorted name order, then static), one-hot dynamic categorical
    covariates (sorted name order), one-hot static categorical covariates.

    Args:
      one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.
      use_intercept: Whether to prepare an intercept (all 1) column in the
        matrices.
      assert_covariates: Whether to assert the validity of the covariate inputs.
      assert_covariate_shapes: Whether to assert the shapes of the covariate
        inputs when `assert_covariates` is True.

    Returns:
      A tuple of the target vector, the covariate matrix for the context, and
      the covariate matrix for the horizon.
    """
    if assert_covariates:
      self._assert_covariates(assert_covariate_shapes)

    # Numerical features, one column each.
    numeric_train, numeric_test = [], []
    for name in sorted(self.train_dynamic_numerical_covariates):
      train_values = _unnest(self.train_dynamic_numerical_covariates[name])
      numeric_train.append(train_values[:, np.newaxis])
      test_values = _unnest(self.test_dynamic_numerical_covariates[name])
      numeric_test.append(test_values[:, np.newaxis])

    for covs in self.static_numerical_covariates.values():
      # A static value is repeated once per time step of its example.
      numeric_train.append(_repeat(covs, self.train_lens)[:, np.newaxis])
      numeric_test.append(_repeat(covs, self.test_lens)[:, np.newaxis])

    train_blocks, test_blocks = [], []
    if numeric_train:
      stacked_train = np.concatenate(numeric_train, axis=1)
      stacked_test = np.concatenate(numeric_test, axis=1)

      # Normalize for robustness, using statistics of the context only.
      # Near-constant columns keep a unit scale to avoid dividing by ~0.
      center = np.mean(stacked_train, axis=0, keepdims=True)
      spread = np.std(stacked_train, axis=0, keepdims=True)
      scale = np.where(spread > _TOL, spread, 1.0)
      train_blocks.append((stacked_train - center) / scale)
      test_blocks.append((stacked_test - center) / scale)

    # Categorical features. Encode one by one, refitting the encoder each time.
    one_hot_encoder = preprocessing.OneHotEncoder(
        drop=one_hot_encoder_drop,
        sparse_output=False,
        handle_unknown="ignore",
    )
    for name in sorted(self.train_dynamic_categorical_covariates.keys()):
      train_column = _unnest(
          self.train_dynamic_categorical_covariates[name])[:, np.newaxis]
      test_column = _unnest(
          self.test_dynamic_categorical_covariates[name])[:, np.newaxis]
      train_blocks.append(np.array(one_hot_encoder.fit_transform(train_column)))
      test_blocks.append(np.array(one_hot_encoder.transform(test_column)))

    for covs in self.static_categorical_covariates.values():
      encoded = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis])
      train_blocks.append(_repeat(encoded, self.train_lens))
      test_blocks.append(_repeat(encoded, self.test_lens))

    x_train = np.concatenate(train_blocks, axis=1)
    x_test = np.concatenate(test_blocks, axis=1)

    if use_intercept:
      # Prepend an all-ones column.
      x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0)
      x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0)

    return _unnest(self.targets), x_train, x_test

  def fit(self) -> Any:
    """Fits the in-context model; must be provided by subclasses."""
    raise NotImplementedError("Fit is not implemented.")
class BatchedInContextXRegLinear(BatchedInContextXRegBase):
  """Linear in-context regression model."""

  def fit(
      self,
      ridge: float = 0.0,
      one_hot_encoder_drop: str | None = "first",
      use_intercept: bool = True,
      force_on_cpu: bool = False,
      max_rows_per_col: int = 0,
      max_rows_per_col_sample_seed: int = 42,
      debug_info: bool = False,
      assert_covariates: bool = False,
      assert_covariate_shapes: bool = False,
  ) -> (list[np.ndarray] | tuple[list[np.ndarray], list[np.ndarray], jax.Array,
                                 jax.Array, jax.Array]):
    """Fits a linear model for in-context regression.

    Args:
      ridge: A non-negative value for specifying the ridge regression penalty.
        If 0 is provided, fallback to ordinary least squares. Note this penalty
        is added to the normalized covariate matrix.
      one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.
      use_intercept: Whether to prepare an intercept (all 1) column in the
        matrices.
      force_on_cpu: Whether to force execution on cpu for accelerator machines.
      max_rows_per_col: How many rows to subsample per column. 0 for no
        subsampling. This is for speeding up model fitting.
      max_rows_per_col_sample_seed: The seed for the subsampling if needed by
        `max_rows_per_col`.
      debug_info: Whether to return debug info.
      assert_covariates: Whether to assert the validity of the covariate inputs.
      assert_covariate_shapes: Whether to assert the shapes of the covariate
        inputs when `assert_covariates` is True.

    Returns:
      If `debug_info` is False:
        The linear fits on the horizon.
      If `debug_info` is True:
        A tuple of:
        - the linear fits on the horizon,
        - the linear fits on the context,
        - the flattened target vector,
        - the covariate matrix for the context, and
        - the covariate matrix for the horizon.
        The last three are the zero-padded (and, for the first two of them,
        possibly subsampled) jax arrays that were used by the solver.
    """
    flat_targets, x_train_raw, x_test = self.create_covariate_matrix(
        one_hot_encoder_drop=one_hot_encoder_drop,
        use_intercept=use_intercept,
        assert_covariates=assert_covariates,
        assert_covariate_shapes=assert_covariate_shapes,
    )

    # Optionally fit on a random subset of rows; the untouched matrix is kept
    # for computing fits on the full context.
    x_train = x_train_raw.copy()
    if max_rows_per_col:
      nrows, ncols = x_train.shape
      row_budget = ncols * max_rows_per_col
      if nrows > row_budget:
        picked_rows = jax.random.choice(
            jax.random.PRNGKey(max_rows_per_col_sample_seed),
            nrows,
            (row_budget,),
            replace=False,
        )
        x_train = x_train[picked_rows]
        flat_targets = flat_targets[picked_rows]

    if force_on_cpu:
      device = jax.devices("cpu")[0]
    else:
      device = None

    # Runs jitted version of the solvers which are quicker at the cost of
    # running jitting during the first time calling. Re-jitting happens whenever
    # new (padded) shapes are encountered.
    # Occasionally it helps with the speed and the accuracy if we force single
    # thread execution on cpu for accelerator machines:
    # 1. Avoid moving data to accelerator memory.
    # 2. Avoid precision loss if any.
    with jax.default_device(device):
      x_train_raw = _to_padded_jax_array(x_train_raw)
      x_train = _to_padded_jax_array(x_train)
      flat_targets = _to_padded_jax_array(flat_targets)
      x_test = _to_padded_jax_array(x_test)
      # Normal equations with an optional ridge term on the diagonal.
      gram = x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1])
      beta_hat = jnp.linalg.pinv(gram, hermitian=True) @ x_train.T @ flat_targets
      y_hat = x_test @ beta_hat
      y_hat_context = x_train_raw @ beta_hat if debug_info else None

    # Reconstruct the ragged 2-dim batched forecasts from flattened linear fits.
    outputs = []
    outputs_context = []
    train_start, test_start = 0, 0
    for train_len, test_len in zip(self.train_lens, self.test_lens):
      test_stop = test_start + test_len
      outputs.append(np.array(y_hat[test_start:test_stop]))
      if debug_info:
        train_stop = train_start + train_len
        outputs_context.append(np.array(y_hat_context[train_start:train_stop]))
      train_start += train_len
      test_start += test_len

    if not debug_info:
      return outputs
    return outputs, outputs_context, flat_targets, x_train, x_test
================================================
FILE: probts/model/nn/arch/TransformerModule/Embed.py
================================================
import torch
import torch.nn as nn
import math
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    A `(1, max_len, d_model)` table is precomputed and stored as a buffer.
    Even columns hold sines and odd columns hold cosines of the position
    scaled by geometrically decreasing frequencies.
    """

    def __init__(self, d_model, max_len=5000):
        """Precomputes the encoding table.

        Args:
            d_model: Width of the encoding (even or odd).
            max_len: Number of positions to precompute.
        """
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        # The attribute is `requires_grad`; the previous `require_grad`
        # spelling only attached an unused attribute to the tensor.
        pe.requires_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so the frequency vector is truncated to match (no-op when even).
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Returns the first `x.size(1)` rows of the table, shape (1, L, d_model)."""
        return self.pe[:, :x.size(1)]
class TokenEmbedding(nn.Module):
    """Embeds each time step with a circular, bias-free 1-D convolution."""

    def __init__(self, c_in, d_model):
        """Creates the width-3 convolution mapping `c_in` to `d_model` channels."""
        super(TokenEmbedding, self).__init__()
        # Padding amount depends on the installed torch version.
        if torch.__version__ >= '1.5.0':
            padding = 1
        else:
            padding = 2
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=padding,
            padding_mode='circular',
            bias=False,
        )
        # Re-initialize the weight of every Conv1d sub-module.
        conv_layers = [m for m in self.modules() if isinstance(m, nn.Conv1d)]
        for conv in conv_layers:
            nn.init.kaiming_normal_(
                conv.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        """Applies the convolution along dim 1 of `x` and restores the layout."""
        channels_first = x.permute(0, 2, 1)
        convolved = self.tokenConv(channels_first)
        return convolved.transpose(1, 2)
class FixedEmbedding(nn.Module):
    """Embedding lookup whose table is a frozen sinusoidal encoding.

    Row `i` of the `(c_in, d_model)` table holds sines (even columns) and
    cosines (odd columns) of `i` at geometrically decreasing frequencies.
    """

    def __init__(self, c_in, d_model):
        """Builds the frozen table.

        Args:
            c_in: Number of distinct indices (rows of the table).
            d_model: Width of each embedding vector (even or odd).
        """
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()
        # The attribute is `requires_grad`; the previous `require_grad`
        # spelling only attached an unused attribute to the tensor.
        w.requires_grad = False

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so the frequency vector is truncated to match (no-op when even).
        w[:, 1::2] = torch.cos(position * div_term[:d_model // 2])

        self.emb = nn.Embedding(c_in, d_model)
        # Replace the randomly initialized weight with the frozen table.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """Looks up `x` in the table and detaches the result from autograd."""
        return self.emb(x).detach()
class TemporalEmbedding(nn.Module):
    """Sum of per-field embeddings for integer calendar features.

    The last axis of the input is indexed as
    ``[month, day, weekday, hour, minute]`` (positions 0..4). The minute table
    is only created when ``freq == 't'``.

    Args:
        d_model: embedding width.
        embed_type: ``'fixed'`` uses ``FixedEmbedding``; any other value uses
            a trainable ``nn.Embedding``.
        freq: frequency code; only ``'t'`` changes behaviour here.
    """

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super().__init__()

        if embed_type == 'fixed':
            embed_cls = FixedEmbedding
        else:
            embed_cls = nn.Embedding

        # Table sizes for each calendar field.
        n_minute, n_hour, n_weekday, n_day, n_month = 4, 24, 7, 32, 13

        if freq == 't':
            self.minute_embed = embed_cls(n_minute, d_model)
        self.hour_embed = embed_cls(n_hour, d_model)
        self.weekday_embed = embed_cls(n_weekday, d_model)
        self.day_embed = embed_cls(n_day, d_model)
        self.month_embed = embed_cls(n_month, d_model)

    def forward(self, x):
        """Cast ``x`` to long, embed each field and add the results."""
        idx = x.long()

        if hasattr(self, 'minute_embed'):
            minute_part = self.minute_embed(idx[:, :, 4])
        else:
            # No minute table for this frequency: contribute a plain zero.
            minute_part = 0.

        hour_part = self.hour_embed(idx[:, :, 3])
        weekday_part = self.weekday_embed(idx[:, :, 2])
        day_part = self.day_embed(idx[:, :, 1])
        month_part = self.month_embed(idx[:, :, 0])

        return hour_part + weekday_part + day_part + month_part + minute_part
class TimeFeatureEmbedding(nn.Module):
    """Bias-free linear projection of time features to ``d_model``.

    The number of input features is looked up from the frequency code
    (``'min'`` is treated as ``'t'``); an unknown code raises ``KeyError``.

    Args:
        d_model: output width.
        embed_type: accepted for signature compatibility; not used here.
        freq: frequency code selecting the input feature count.
    """

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super().__init__()

        # Number of time features per frequency code.
        features_per_freq = {
            'h': 4, 't': 5, 's': 6,
            'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3,
        }
        freq_key = 't' if freq == 'min' else freq
        n_features = features_per_freq[freq_key]
        self.embed = nn.Linear(n_features, d_model, bias=False)

    def forward(self, x):
        """Project the last axis of ``x`` with the linear layer."""
        return self.embed(x)
class DataEmbedding(nn.Module):
    """Value + temporal + positional embedding followed by dropout.

    Args:
        c_in: number of input channels for the value embedding.
        d_model: embedding width.
        embed_type: ``'timeF'`` selects ``TimeFeatureEmbedding``; anything
            else selects ``TemporalEmbedding``.
        freq: frequency code forwarded to the temporal embedding.
        dropout: dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == 'timeF':
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x``; the temporal term is skipped when ``x_mark`` is None."""
        embedded = self.value_embedding(x)
        if x_mark is not None:
            embedded = embedded + self.temporal_embedding(x_mark)
        embedded = embedded + self.position_embedding(x)
        return self.dropout(embedded)
class DataEmbedding_wo_pos(nn.Module):
    """Value + temporal embedding followed by dropout (no positional term).

    A ``position_embedding`` submodule is still constructed, but ``forward``
    never calls it.

    Args:
        c_in: number of input channels for the value embedding.
        d_model: embedding width.
        embed_type: ``'timeF'`` selects ``TimeFeatureEmbedding``; anything
            else selects ``TemporalEmbedding``.
        freq: frequency code forwarded to the temporal embedding.
        dropout: dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        # Constructed but unused in forward.
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == 'timeF':
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x``; the temporal term is skipped when ``x_mark`` is None."""
        embedded = self.value_embedding(x)
        if x_mark is not None:
            embedded = embedded + self.temporal_embedding(x_mark)
        return self.dropout(embedded)
class PatchEmbedding(nn.Module):
    """Cut the last axis into patches and embed each patch.

    Args:
        d_model: embedding width.
        patch_len: length of each patch (window size of ``unfold``).
        stride: step between consecutive patches.
        padding: number of replicated values appended on the right of the
            last axis before patching.
        dropout: dropout probability applied to the embedded patches.
    """

    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super().__init__()
        # Patching configuration.
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

        # Linear projection of each patch onto a d_model-dimensional space.
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)

        # Positional term added to the projected patches.
        self.position_embedding = PositionalEmbedding(d_model)

        # Residual dropout.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(embedded_patches, n_vars)`` where ``n_vars = x.shape[1]``."""
        n_vars = x.shape[1]

        # Pad on the right, then slide a window over the last axis.
        padded = self.padding_patch_layer(x)
        patches = padded.unfold(dimension=-1, size=self.patch_len, step=self.stride)

        # Merge the first two axes so every (sample, variable) pair is one row.
        merged_shape = (
            patches.shape[0] * patches.shape[1],
            patches.shape[2],
            patches.shape[3],
        )
        patches = torch.reshape(patches, merged_shape)

        # Input encoding.
        encoded = self.value_embedding(patches) + self.position_embedding(patches)
        return self.dropout(encoded), n_vars
# Code implementation from https://github.com/thuml/iTransformer
class DataEmbedding_inverted(nn.Module):
    """Embed whole series as tokens: a linear layer over the swapped axis 1.

    Args:
        c_in: size of axis 1 of the input (the axis that gets projected).
        d_model: embedding width.
        dropout: dropout probability applied to the result.
    """

    def __init__(self, c_in, d_model, dropout=0.1):
        super().__init__()
        self.value_embedding = nn.Linear(c_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Swap axes 1 and 2, optionally append ``x_mark`` tokens, and project."""
        tokens = x.permute(0, 2, 1)
        if x_mark is not None:
            # Covariates (e.g. timestamps) are appended as extra tokens along
            # axis 1 after the same axis swap.
            tokens = torch.cat([tokens, x_mark.permute(0, 2, 1)], 1)
        return self.dropout(self.value_embedding(tokens))
================================================
FILE: probts/model/nn/arch/TransformerModule/SelfAttention_Family.py
================================================
import torch
import torch.nn as nn
import numpy as np
from math import sqrt
from probts.utils.masking import TriangularCausalMask, ProbMask
from reformer_pytorch import LSHSelfAttention
from einops import rearrange
# Code implementation from https://github.com/thuml/Flowformer
class FlowAttention(nn.Module):
    """Flow-style linear attention using a sigmoid kernel on queries and keys.

    ``forward`` swaps axes 1 and 2 of its inputs, computes row/column
    normalisers from the kernelised queries and keys, and returns the
    re-weighted aggregation of ``values`` with axes 1 and 2 swapped back.
    ``attn_mask``, ``tau`` and ``delta`` are accepted but not used.
    """

    def __init__(self, attention_dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(attention_dropout)

    def kernel_method(self, x):
        """Non-negative feature map applied to queries and keys."""
        return torch.sigmoid(x)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        eps = 1e-6

        # Swap axes 1 and 2, then apply the kernel to queries and keys.
        q = self.kernel_method(queries.transpose(1, 2))
        k = self.kernel_method(keys.transpose(1, 2))
        v = values.transpose(1, 2)

        q_eps = q + eps
        k_eps = k + eps

        # Incoming and outgoing normalisers.
        row_norm = 1.0 / (torch.einsum("nhld,nhd->nhl", q_eps, k.sum(dim=2) + eps))
        col_norm = 1.0 / (torch.einsum("nhsd,nhd->nhs", k_eps, q.sum(dim=2) + eps))

        # Re-weighting with the opposite side's normaliser.
        row_refine = torch.einsum(
            "nhld,nhd->nhl", q_eps, (k * col_norm[:, :, :, None]).sum(dim=2) + eps)
        col_refine = torch.einsum(
            "nhsd,nhd->nhs", k_eps, (q * row_norm[:, :, :, None]).sum(dim=2) + eps)

        # Competition (softmax over keys) and allocation (sigmoid over queries).
        n_queries = float(q.shape[2])
        n_keys = float(k.shape[2])
        row_refine = torch.sigmoid(row_refine * (n_queries / n_keys))
        col_refine = torch.softmax(col_refine, dim=-1) * k.shape[2]

        # Aggregate values, then apply both row factors.
        kv = k.transpose(-2, -1) @ (v * col_refine[:, :, :, None])
        aggregated = (q @ kv) * row_norm[:, :, :, None]
        aggregated = aggregated * row_refine[:, :, :, None]

        return aggregated.transpose(1, 2).contiguous(), None
# Code implementation from https://github.com/shreyansh26/FlashAttention-PyTorch
class FlashAttention(nn.Module):
    """Block-wise softmax attention computed with running max / sum statistics.

    This is a plain-PyTorch tiled implementation: keys/values are processed in
    blocks and the output is renormalised incrementally, so the full score
    matrix is never materialised at once.

    The constructor arguments mirror the other attention classes in this file;
    ``forward`` only uses the inputs themselves (``scale``, ``mask_flag``,
    ``output_attention`` and ``dropout`` are stored but not applied).
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FlashAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def flash_attention_forward(self, Q, K, V, mask=None):
        """Tiled attention over 4-D ``Q``, ``K``, ``V`` (blocks are cut along axis 2).

        Args:
            Q, K, V: tensors whose axis 2 is the sequence axis and whose last
                axis is the feature axis.
            mask: optional 2-D tensor split along axis 1 in key blocks; key
                positions with ``mask <= 0`` are ignored.

        Returns:
            ``(O, l, m)``: the attention output, the running softmax
            denominators and the running row maxima.
        """
        BLOCK_SIZE = 32
        NEG_INF = -1e10  # stands in for -infinity
        EPSILON = 1e-10

        # Allocate the accumulators on the same device as the inputs instead
        # of assuming CUDA, so the layer also works on CPU (or any device).
        # The output accumulator is a plain buffer; it does not need to be an
        # autograd leaf.
        device = Q.device
        O = torch.zeros_like(Q)
        l = torch.zeros(Q.shape[:-1], device=device)[..., None]
        m = torch.ones(Q.shape[:-1], device=device)[..., None] * NEG_INF

        Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1])
        KV_BLOCK_SIZE = BLOCK_SIZE

        Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
        K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
        V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
        if mask is not None:
            mask_BLOCKS = list(torch.split(mask, KV_BLOCK_SIZE, dim=1))

        O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
        l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
        m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))

        # The softmax temperature depends only on the feature width.
        scale = Q.shape[-1] ** -0.5

        for j, (Kj, Vj) in enumerate(zip(K_BLOCKS, V_BLOCKS)):
            if mask is not None:
                # (b, j) -> (b, 1, 1, j) so it broadcasts over heads and queries.
                keep_j = mask_BLOCKS[j][:, None, None, :] > 0

            for i, Qi in enumerate(Q_BLOCKS):
                Oi = O_BLOCKS[i]
                li = l_BLOCKS[i]
                mi = m_BLOCKS[i]

                S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi * scale, Kj)
                if mask is not None:
                    S_ij = S_ij.masked_fill(~keep_j, NEG_INF)

                m_block_ij, _ = torch.max(S_ij, dim=-1, keepdim=True)
                P_ij = torch.exp(S_ij - m_block_ij)
                if mask is not None:
                    # A fully masked block would otherwise contribute exp(0).
                    P_ij = P_ij.masked_fill(~keep_j, 0.)

                l_block_ij = torch.sum(P_ij, dim=-1, keepdim=True) + EPSILON
                P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)

                # Merge this block's statistics into the running ones.
                mi_new = torch.maximum(m_block_ij, mi)
                li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * l_block_ij

                O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi + (
                    torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
                l_BLOCKS[i] = li_new
                m_BLOCKS[i] = mi_new

        O = torch.cat(O_BLOCKS, dim=2)
        l = torch.cat(l_BLOCKS, dim=2)
        m = torch.cat(m_BLOCKS, dim=2)
        return O, l, m

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Swap axes 1 and 2 of the inputs, run the tiled kernel, swap back.

        Returns ``(output, None)``; attention weights are never returned.
        """
        res = self.flash_attention_forward(
            queries.permute(0, 2, 1, 3),
            keys.permute(0, 2, 1, 3),
            values.permute(0, 2, 1, 3),
            attn_mask,
        )[0]
        return res.permute(0, 2, 1, 3).contiguous(), None
class FullAttention(nn.Module):
    """Standard scaled dot-product attention over 4-D queries/keys/values.

    Args:
        mask_flag: when true, scores are masked; a ``TriangularCausalMask`` is
            built if no ``attn_mask`` is supplied.
        factor: accepted for signature compatibility; not used here.
        scale: softmax temperature; a falsy value means ``1/sqrt(E)``.
        attention_dropout: dropout probability applied to the weights.
        output_attention: when true, the weights are returned as well.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super().__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Return ``(context, weights_or_None)``; ``tau``/``delta`` are unused."""
        B, L, H, E = queries.shape
        _, S, _, D = values.shape

        temperature = self.scale if self.scale else 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            mask = attn_mask
            if mask is None:
                mask = TriangularCausalMask(B, L, device=queries.device)
            # In-place: masked positions get -inf so softmax gives them zero.
            scores.masked_fill_(mask.mask, -np.inf)

        weights = self.dropout(torch.softmax(temperature * scores, dim=-1))
        context = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()

        returned_weights = weights if self.output_attention else None
        return (context, returned_weights)
# Code implementation from https://github.com/zhouhaoyi/Informer2020
class ProbAttention(nn.Module):
    """ProbSparse-style attention: only the top-scoring queries get real attention.

    A random subset of keys is sampled to score every query; the ``n_top``
    queries with the largest score are attended exactly, while all other
    query positions keep a default context (mean of values, or a cumulative
    sum of values when ``mask_flag`` is set).

    Args:
        mask_flag: when true, a ``ProbMask`` is applied and the default
            context is the cumulative sum of ``V`` (requires ``L_Q == L_V``).
        factor: multiplier for the number of sampled keys / selected queries.
        scale: softmax temperature; a falsy value means ``1/sqrt(D)``.
        attention_dropout: stored as ``self.dropout``; not applied in forward.
        output_attention: when true, a dense attention tensor is returned too.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
        """Score queries on sampled keys and return scores of the top queries.

        Args:
            Q: queries, unpacked as ``[B, H, L_Q, E]``.
            K: keys, unpacked as ``[B, H, L_K, E]``.
            sample_k: number of keys sampled per query.
            n_top: number of queries to keep.

        Returns:
            ``(Q_K, M_top)``: scores of the selected queries against all keys
            and the indices of the selected queries.
        """
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # Calculate the sampled Q_K.
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        # real U = U_part(factor*ln(L_k))*L_q
        index_sample = torch.randint(L_K, (L_Q, sample_k))
        K_sample = K_expand[:, :, torch.arange(
            L_Q).unsqueeze(1), index_sample, :]
        # Remove only the singleton axis introduced by Q.unsqueeze(-2). A bare
        # squeeze() would also drop B, H, L_Q or sample_k whenever one of them
        # is 1, which misaligns the index tensors used below.
        Q_K_sample = torch.matmul(
            Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)

        # Find the top-k queries with the sparsity measurement.
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]

        # Use the reduced Q to calculate Q_K.
        Q_reduce = Q[torch.arange(B)[:, None, None],
                     torch.arange(H)[None, :, None],
                     M_top, :]  # factor*ln(L_q)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k

        return Q_K, M_top

    def _get_initial_context(self, V, L_Q):
        """Build the default context used for queries that are not selected."""
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            V_sum = V.mean(dim=-2)
            contex = V_sum.unsqueeze(-2).expand(B, H,
                                                L_Q, V_sum.shape[-1]).clone()
        else:  # use mask
            # requires that L_Q == L_V, i.e. for self-attention only
            assert (L_Q == L_V)
            contex = V.cumsum(dim=-2)
        return contex

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        """Overwrite the rows of ``context_in`` selected by ``index`` with real attention.

        The incoming ``attn_mask`` is ignored; when ``mask_flag`` is set a
        ``ProbMask`` is built from ``index`` and ``scores`` instead.
        """
        B, H, L_V, D = V.shape

        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        attn = torch.softmax(scores, dim=-1)  # nn.Softmax(dim=-1)(scores)

        context_in[torch.arange(B)[:, None, None],
                   torch.arange(H)[None, :, None],
                   index, :] = torch.matmul(attn, V).type_as(context_in)
        if self.output_attention:
            # Unselected rows are reported as uniform attention.
            attns = (torch.ones([B, H, L_V, L_V]) /
                     L_V).type_as(attn).to(attn.device)
            attns[torch.arange(B)[:, None, None], torch.arange(H)[
                None, :, None], index, :] = attn
            return (context_in, attns)
        else:
            return (context_in, None)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Run ProbSparse attention; ``tau`` and ``delta`` are unused.

        Inputs are unpacked as ``[B, L, H, D]``. The returned context keeps
        the internal ``[B, H, L_Q, D]`` axis order.
        """
        # NOTE(review): the context is returned without swapping axes 1 and 2
        # back, unlike the other attention classes here - confirm callers
        # expect this layout.
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = queries.transpose(2, 1)
        keys = keys.transpose(2, 1)
        values = values.transpose(2, 1)

        U_part = self.factor * \
            np.ceil(np.log(L_K)).astype('int').item()  # c*ln(L_k)
        u = self.factor * \
            np.ceil(np.log(L_Q)).astype('int').item()  # c*ln(L_q)

        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        scores_top, index = self._prob_QK(
            queries, keys, sample_k=U_part, n_top=u)

        # Add the scale factor (never None: a falsy self.scale falls back to
        # 1/sqrt(D)).
        scale = self.scale or 1. / sqrt(D)
        scores_top = scores_top * scale

        # Get the default context.
        context = self._get_initial_context(values, L_Q)
        # Update the context with the selected top-k queries.
        context, attn = self._update_context(
            context, values, scores_top, index, L_Q, attn_mask)

        return context.contiguous(), attn
class AttentionLayer(nn.Module):
    """Multi-head wrapper: project, split heads, run the inner attention, merge.

    Args:
        attention: callable invoked as
            ``attention(q, k, v, attn_mask, tau=..., delta=...)`` returning
            ``(output, attention_weights)``.
        d_model: input/output width.
        n_heads: number of heads.
        d_keys: per-head query/key width (defaults to ``d_model // n_heads``).
        d_values: per-head value width (defaults to ``d_model // n_heads``).
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None):
        super().__init__()

        head_dim = d_model // n_heads
        d_keys = d_keys or head_dim
        d_values = d_values or head_dim

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Return ``(projected_output, attention_weights)``."""
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        # Project and split the last axis into (heads, per-head width).
        q_heads = self.query_projection(queries).view(batch, q_len, heads, -1)
        k_heads = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v_heads = self.value_projection(values).view(batch, kv_len, heads, -1)

        attended, attn = self.inner_attention(
            q_heads, k_heads, v_heads, attn_mask, tau=tau, delta=delta)

        # Merge the heads back into one feature axis.
        merged = attended.view(batch, q_len, -1)
        return self.out_projection(merged), attn
class ReformerLayer(nn.Module):
    """Thin wrapper around ``LSHSelfAttention`` with length padding.

    ``attention``, ``d_keys`` and ``d_values`` are accepted for signature
    compatibility with ``AttentionLayer`` but are not used.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None, causal=False, bucket_size=4, n_hashes=4):
        super().__init__()
        self.bucket_size = bucket_size
        self.attn = LSHSelfAttention(
            dim=d_model,
            heads=n_heads,
            bucket_size=bucket_size,
            n_hashes=n_hashes,
            causal=causal,
        )

    def fit_length(self, queries):
        """Zero-pad axis 1 up to the next multiple of ``2 * bucket_size``.

        The input is returned unchanged when its length already is a multiple.
        """
        # inside reformer: assert N % (bucket_size * 2) == 0
        B, N, C = queries.shape
        multiple = self.bucket_size * 2
        remainder = N % multiple
        if remainder == 0:
            return queries

        # Fill the time series with trailing zeros.
        padding = torch.zeros([B, multiple - remainder, C]).to(queries.device)
        return torch.cat([queries, padding], dim=1)

    def forward(self, queries, keys, values, attn_mask, tau, delta):
        """Self-attend over ``queries`` only; all other arguments are unused."""
        # in Reformer: queries are used as keys by default
        B, N, C = queries.shape
        padded = self.fit_length(queries)
        # Drop the padded positions again.
        attended = self.attn(padded)[:, :N, :]
        return attended, None
================================================
FILE: probts/model/nn/arch/TransformerModule/Transformer_EncDec.py
================================================
import torch.nn as nn
import torch.nn.functional as F
class ConvLayer(nn.Module):
    """Conv1d -> BatchNorm1d -> ELU -> MaxPool1d(stride 2) along axis 1.

    Args:
        c_in: number of channels (kept unchanged by the convolution).
    """

    def __init__(self, c_in):
        super().__init__()
        self.downConv = nn.Conv1d(
            c_in, c_in, 3, padding=2, padding_mode='circular')
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Apply the stack with channels moved to axis 1 and back."""
        hidden = x.permute(0, 2, 1)
        for stage in (self.downConv, self.norm, self.activation, self.maxPool):
            hidden = stage(hidden)
        return hidden.transpose(1, 2)
class EncoderLayer(nn.Module):
    """Self-attention block followed by a position-wise feed-forward block.

    Both sub-blocks use a residual connection and a post-LayerNorm; the
    feed-forward part is two kernel-size-1 Conv1d layers.

    Args:
        attention: callable invoked as ``attention(x, x, x, attn_mask=...,
            tau=..., delta=...)`` returning ``(output, attention_weights)``.
        d_model: feature width.
        d_ff: hidden width of the feed-forward block (default ``4 * d_model``).
        dropout: dropout probability.
        activation: ``"relu"`` selects ReLU; anything else selects GELU.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        hidden_width = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden_width, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden_width, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Return ``(output, attention_weights)``."""
        attended, attn = self.attention(
            x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)

        # Residual + norm around the attention output.
        residual = self.norm1(x + self.dropout(attended))

        # Feed-forward: Conv1d works on axis 1, hence the transposes.
        hidden = self.conv1(residual.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm2(residual + hidden), attn
class Encoder(nn.Module):
    """Stack of attention layers, optionally interleaved with conv layers.

    Args:
        attn_layers: iterable of attention layers.
        conv_layers: optional iterable of layers applied after each paired
            attention layer; the last attention layer is then run on its own.
        norm_layer: optional module applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        if conv_layers is None:
            self.conv_layers = None
        else:
            self.conv_layers = nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Return ``(output, attns)`` where ``attns`` has one entry per attention call."""
        # x [B, L, D]
        attns = []

        if self.conv_layers is None:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)
        else:
            paired = zip(self.attn_layers, self.conv_layers)
            for position, (attn_layer, conv_layer) in enumerate(paired):
                # Only the first paired layer receives delta.
                layer_delta = delta if position == 0 else None
                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=layer_delta)
                x = conv_layer(x)
                attns.append(attn)
            # The final attention layer runs without a mask or delta.
            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
class DecoderLayer(nn.Module):
    """Self-attention, cross-attention and feed-forward, each with residual + LayerNorm.

    Args:
        self_attention: callable invoked as ``f(x, x, x, attn_mask=...,
            tau=..., delta=None)``; only element 0 of its result is used.
        cross_attention: callable invoked as ``f(x, cross, cross,
            attn_mask=..., tau=..., delta=...)``; only element 0 is used.
        d_model: feature width.
        d_ff: hidden width of the feed-forward block (default ``4 * d_model``).
        dropout: dropout probability.
        activation: ``"relu"`` selects ReLU; anything else selects GELU.
    """

    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu"):
        super().__init__()
        hidden_width = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden_width, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden_width, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        """Return the decoded tensor (attention weights are discarded)."""
        # Self-attention never receives delta.
        self_out = self.self_attention(
            x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
        x = self.norm1(x + self.dropout(self_out))

        cross_out = self.cross_attention(
            x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta)[0]
        residual = self.norm2(x + self.dropout(cross_out))

        # Feed-forward: Conv1d works on axis 1, hence the transposes.
        hidden = self.conv1(residual.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm3(residual + hidden)
class Decoder(nn.Module):
    """Sequential decoder: layers, then optional norm, then optional projection.

    Args:
        layers: iterable of decoder layers, each called as
            ``layer(x, cross, x_mask=..., cross_mask=..., tau=..., delta=...)``.
        norm_layer: optional module applied after all layers.
        projection: optional module applied last.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        """Run every layer in order and return the final tensor."""
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(
                hidden, cross,
                x_mask=x_mask, cross_mask=cross_mask,
                tau=tau, delta=delta,
            )

        # Optional post-processing stages, applied in this order.
        for stage in (self.norm, self.projection):
            if stage is not None:
                hidden = stage(hidden)
        return hidden
================================================
FILE: probts/model/nn/arch/__init__.py
================================================
================================================
FILE: probts/model/nn/arch/decomp.py
================================================
import torch
from torch import nn
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    The series (axis 1) is padded by repeating its first and last values so
    that, with ``stride=1``, the output has the same length as the input for
    both odd and even ``kernel_size``.

    Args:
        kernel_size: averaging window length.
        stride: stride of the average pooling.
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """Average over axis 1 of a 3-D tensor and return a tensor in the same layout."""
        # Total padding must be kernel_size - 1 to keep the length. For odd
        # kernels both sides get (kernel_size - 1) // 2 as before; for even
        # kernels the end side takes the one extra step, which previously was
        # missing and made the output one step shorter than the input.
        front_len = (self.kernel_size - 1) // 2
        end_len = self.kernel_size - 1 - front_len
        front = x[:, 0:1, :].repeat(1, front_len, 1)
        end = x[:, -1:, :].repeat(1, end_len, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last axis, so move axis 1 there and back.
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
class series_decomp(nn.Module):
    """
    Series decomposition block: splits the input into a moving-average part
    and the remainder.

    Args:
        kernel_size: window length of the moving average (stride is 1).
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        """Return ``(x - trend, trend)`` where ``trend`` is the moving average of ``x``."""
        trend = self.moving_avg(x)
        return x - trend, trend
================================================
FILE: probts/model/nn/prob/MAF.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PyTorch-TS
# - Source: https://github.com/zalandoresearch/pytorch-ts
# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows
# - License: MIT, Apache-2.0 license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from probts.model.nn.prob.flow_model import FlowModel, BatchNorm, FlowSequential
def create_masks(
    input_size, hidden_size, n_hidden, input_order="sequential", input_degrees=None
):
    """Build the autoregressive connectivity masks of a MADE network.

    Each unit is assigned a degree; a connection from a unit with degree
    ``d0`` to a unit with degree ``d1`` is kept when ``d1 >= d0``. Output
    degrees are shifted by -1 so that an output never depends on its own
    input.

    Args:
        input_size: number of inputs (and outputs).
        hidden_size: width of every hidden layer.
        n_hidden: number of hidden-to-hidden layers; ``n_hidden + 1`` hidden
            degree vectors are generated.
        input_order: ``"sequential"`` or ``"random"``.
        input_degrees: optional tensor of input degrees (e.g. the flipped
            order of the previous MADE in a stack); overrides the order
            implied by ``input_order``.

    Returns:
        ``(masks, input_degrees)``: a list of float masks, one per layer, and
        the degrees used for the input layer.

    Raises:
        ValueError: if ``input_order`` is not one of the supported strings.
    """
    # MADE paper sec 4:
    # degrees of connections between layers -- ensure at most in_degree - 1 connections
    degrees = []

    # set input degrees to what is provided in args (the flipped order of the previous layer in a stack of mades);
    # else init input degrees based on strategy in input_order (sequential or random)
    if input_order == "sequential":
        degrees += (
            [torch.arange(input_size)] if input_degrees is None else [input_degrees]
        )
        # With a single input, input_size - 1 is 0 and the modulo below would
        # divide by zero; a modulus of 1 gives every hidden unit degree 0, and
        # the output (degree -1) is then disconnected as required.
        hidden_modulus = max(input_size - 1, 1)
        for _ in range(n_hidden + 1):
            degrees += [torch.arange(hidden_size) % hidden_modulus]
        degrees += (
            [torch.arange(input_size) % input_size - 1]
            if input_degrees is None
            else [input_degrees % input_size - 1]
        )

    elif input_order == "random":
        degrees += (
            [torch.randperm(input_size)] if input_degrees is None else [input_degrees]
        )
        for _ in range(n_hidden + 1):
            min_prev_degree = min(degrees[-1].min().item(), input_size - 1)
            degrees += [torch.randint(min_prev_degree, input_size, (hidden_size,))]
        min_prev_degree = min(degrees[-1].min().item(), input_size - 1)
        degrees += (
            [torch.randint(min_prev_degree, input_size, (input_size,)) - 1]
            if input_degrees is None
            else [input_degrees - 1]
        )

    else:
        # Previously this fell through and failed later with an opaque
        # IndexError on the empty degree list.
        raise ValueError(
            "input_order must be 'sequential' or 'random', got {!r}".format(input_order)
        )

    # construct masks
    masks = []
    for (d0, d1) in zip(degrees[:-1], degrees[1:]):
        masks += [(d1.unsqueeze(-1) >= d0.unsqueeze(0)).float()]

    return masks, degrees[0]
class MaskedLinear(nn.Linear):
    """MADE building block: a linear layer whose weight is multiplied by a fixed mask.

    Args:
        input_size: number of input features.
        n_outputs: number of output features.
        mask: tensor multiplied element-wise with the weight on every call;
            stored as the buffer ``mask``.
        cond_label_size: when given, an extra un-masked weight of shape
            ``(n_outputs, cond_label_size)`` is created for a conditioning
            input.
    """

    def __init__(self, input_size, n_outputs, mask, cond_label_size=None):
        super().__init__(input_size, n_outputs)

        self.register_buffer("mask", mask)
        self.cond_label_size = cond_label_size

        if cond_label_size is None:
            return
        # Uniform [0, 1) samples scaled by 1/sqrt(cond_label_size).
        init = torch.rand(n_outputs, cond_label_size) / math.sqrt(cond_label_size)
        self.cond_weight = nn.Parameter(init)

    def forward(self, x, y=None):
        """Return the masked linear map of ``x``, plus ``y @ cond_weight.T`` if ``y`` is given."""
        masked_weight = self.weight * self.mask
        result = F.linear(x, masked_weight, self.bias)
        if y is None:
            return result
        return result + F.linear(y, self.cond_weight)
class MADE(nn.Module):
    """Masked autoregressive network producing a mean and log-scale per input."""

    def __init__(
        self,
        input_size,
        hidden_size,
        n_hidden,
        cond_label_size=None,
        activation="ReLU",
        input_order="sequential",
        input_degrees=None,
    ):
        """
        Args:
            input_size -- scalar; dim of inputs
            hidden_size -- scalar; dim of hidden layers
            n_hidden -- scalar; number of hidden layers
            cond_label_size -- scalar or None; dim of the conditioning input
                               fed to the first layer
            activation -- str; activation function to use ("ReLU" or "Tanh")
            input_order -- str; variable order used to create the
                           autoregressive masks (sequential|random)
            input_degrees -- tensor or None; explicit input degrees, e.g. the
                             order flipped from the previous layer in a stack
                             of MADEs
        """
        super().__init__()

        # Parameters of the base distribution used by log_prob.
        self.register_buffer("base_dist_mean", torch.zeros(input_size))
        self.register_buffer("base_dist_var", torch.ones(input_size))

        # Autoregressive masks; the input degrees are kept for inverse().
        masks, self.input_degrees = create_masks(
            input_size, hidden_size, n_hidden, input_order, input_degrees
        )

        # Pick the activation module (one instance shared by all layers).
        activations = {"ReLU": nn.ReLU, "Tanh": nn.Tanh}
        if activation not in activations:
            raise ValueError("Check activation function.")
        activation_fn = activations[activation]()

        # First layer takes the optional conditioning input.
        self.net_input = MaskedLinear(
            input_size, hidden_size, masks[0], cond_label_size
        )

        # Hidden layers, then an output layer with 2 * input_size units
        # (its mask is stacked twice to cover both halves).
        layers = []
        for hidden_mask in masks[1:-1]:
            layers.append(activation_fn)
            layers.append(MaskedLinear(hidden_size, hidden_size, hidden_mask))
        layers.append(activation_fn)
        layers.append(
            MaskedLinear(hidden_size, 2 * input_size, masks[-1].repeat(2, 1))
        )
        self.net = nn.Sequential(*layers)

    @property
    def base_dist(self):
        """Normal distribution built from the two registered buffers."""
        return Normal(self.base_dist_mean, self.base_dist_var)

    def forward(self, x, y=None):
        """Map ``x`` to ``u`` and return ``(u, log_abs_det_jacobian)``."""
        # MAF eq 4 -- the network output is split into mean and log std.
        params = self.net(self.net_input(x, y))
        mean, log_scale = params.chunk(chunks=2, dim=-1)
        u = (x - mean) * torch.exp(-log_scale)
        # MAF eq 5
        return u, -log_scale

    def inverse(self, u, y=None, sum_log_abs_det_jacobians=None):
        """Invert ``forward`` one dimension at a time.

        ``sum_log_abs_det_jacobians`` is accepted but not used.
        """
        # MAF eq 3
        x = torch.zeros_like(u)
        # Fill x dimension by dimension, re-running the network each time so
        # later dimensions see the values already reconstructed.
        for dim in self.input_degrees:
            params = self.net(self.net_input(x, y))
            mean, log_scale = params.chunk(chunks=2, dim=-1)
            x[..., dim] = u[..., dim] * torch.exp(log_scale[..., dim]) + mean[..., dim]
        # The log-scale from the last pass is returned as the log-determinant.
        log_abs_det_jacobian = log_scale
        return x, log_abs_det_jacobian

    def log_prob(self, x, y=None):
        """Log-density of ``x``: base log-prob of ``u`` plus the log-determinant, summed over the last axis."""
        u, log_abs_det_jacobian = self.forward(x, y)
        per_dim = self.base_dist.log_prob(u) + log_abs_det_jacobian
        return torch.sum(per_dim, dim=-1)
class MAF(FlowModel):
    """Masked autoregressive flow: a stack of MADE blocks with optional BatchNorm.

    Each block after the first uses the flipped input degrees of the previous
    block.
    """

    def __init__(
        self,
        n_blocks,
        target_dim,
        hidden_size,
        n_hidden,
        f_hidden_size,
        conditional_length,
        dequantize,
        activation="ReLU",
        input_order="sequential",
        batch_norm=True,
    ):
        super().__init__(target_dim, f_hidden_size, conditional_length, dequantize)

        # Assemble the flow block by block.
        flow_layers = []
        self.input_degrees = None
        for _ in range(n_blocks):
            made = MADE(
                target_dim,
                hidden_size,
                n_hidden,
                conditional_length,
                activation,
                input_order,
                self.input_degrees,
            )
            flow_layers.append(made)
            # Reverse the variable order for the next block.
            self.input_degrees = made.input_degrees.flip(0)
            # Multiplying the list by the flag keeps or drops the BatchNorm.
            flow_layers.extend([BatchNorm(target_dim)] * batch_norm)

        self.net = FlowSequential(*flow_layers)
================================================
FILE: probts/model/nn/prob/RealNVP.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PyTorch-TS
# - Source: https://github.com/zalandoresearch/pytorch-ts
# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows
# - License: MIT, Apache-2.0 license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import copy
import torch
import torch.nn as nn
from probts.model.nn.prob.flow_model import FlowModel, BatchNorm, FlowSequential
class LinearMaskedCoupling(nn.Module):
    """Modified RealNVP coupling layer per the MAF paper.

    Inputs where ``mask == 1`` pass through unchanged and drive the scale and
    translation networks; inputs where ``mask == 0`` are transformed.

    Args:
        input_size: dimension of the data.
        hidden_size: width of the hidden layers of both networks.
        n_hidden: number of hidden-to-hidden layers in each network.
        mask: tensor of 0/1 values, stored as the buffer ``mask``.
        cond_label_size: dimension of the optional conditioning input that is
            concatenated in front of the masked data.
    """

    def __init__(self, input_size, hidden_size, n_hidden, mask, cond_label_size=None):
        super().__init__()

        self.register_buffer("mask", mask)

        # Scale network: Linear layers separated by Tanh.
        cond_width = 0 if cond_label_size is None else cond_label_size
        scale_layers = [nn.Linear(input_size + cond_width, hidden_size)]
        for _ in range(n_hidden):
            scale_layers.extend([nn.Tanh(), nn.Linear(hidden_size, hidden_size)])
        scale_layers.extend([nn.Tanh(), nn.Linear(hidden_size, input_size)])
        self.s_net = nn.Sequential(*scale_layers)

        # Translation network: a deep copy of the scale network with every
        # non-Linear module replaced by ReLU (per the MAF paper).
        self.t_net = copy.deepcopy(self.s_net)
        for position, module in enumerate(self.t_net):
            if not isinstance(module, nn.Linear):
                self.t_net[position] = nn.ReLU()

    def forward(self, x, y=None):
        """Return ``(u, log_abs_det_jacobian)`` with ``u = x * exp(log_s) + t``."""
        free = 1 - self.mask

        # Only the masked-in part of x is visible to the networks.
        visible = x * self.mask
        net_input = visible if y is None else torch.cat([y, visible], dim=-1)

        s = self.s_net(net_input)
        t = self.t_net(net_input) * free

        # cf RealNVP eq 8; the log-scale is bounded by tanh and zeroed on the
        # pass-through dimensions.
        log_s = torch.tanh(s) * free
        u = x * torch.exp(log_s) + t

        # log det du/dx; the sum over input_size is done at model log_prob.
        return u, log_s

    def inverse(self, u, y=None):
        """Return ``(x, log_abs_det_jacobian)`` with ``x = (u - t) * exp(-log_s)``."""
        free = 1 - self.mask

        visible = u * self.mask
        net_input = visible if y is None else torch.cat([y, visible], dim=-1)

        s = self.s_net(net_input)
        t = self.t_net(net_input) * free

        log_s = torch.tanh(s) * free
        x = (u - t) * torch.exp(-log_s)

        # log det dx/du
        return x, -log_s
class RealNVP(FlowModel):
    """RealNVP flow: coupling layers with alternating masks and optional BatchNorm."""

    def __init__(
        self,
        n_blocks,
        target_dim,
        hidden_size,
        n_hidden,
        f_hidden_size,
        conditional_length,
        dequantize,
        batch_norm=True
    ):
        super().__init__(target_dim, f_hidden_size, conditional_length, dequantize)

        # Assemble the flow; the first mask is 0, 1, 0, 1, ... and it is
        # inverted after every block.
        flow_layers = []
        mask = torch.arange(target_dim).float() % 2
        for _ in range(n_blocks):
            coupling = LinearMaskedCoupling(
                target_dim, hidden_size, n_hidden, mask, conditional_length
            )
            flow_layers.append(coupling)
            mask = 1 - mask
            # Multiplying the list by the flag keeps or drops the BatchNorm.
            flow_layers.extend([BatchNorm(target_dim)] * batch_norm)

        self.net = FlowSequential(*flow_layers)
================================================
FILE: probts/model/nn/prob/__init__.py
================================================
================================================
FILE: probts/model/nn/prob/diffusion_layers.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PyTorch-TS
# - Source: https://github.com/zalandoresearch/pytorch-ts
# - Paper: Autoregressive Denoising Diffusion Models for Multivariate Probabilistic Time Series Forecasting
# - License: MIT, Apache-2.0 license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from linear_attention_transformer import LinearAttentionTransformer
def get_torch_trans(heads=8, layers=1, channels=64,linear=False):
    """Build the sequence encoder used inside the residual blocks.

    Args:
        heads: number of attention heads.
        layers: number of encoder layers (depth).
        channels: model width.
        linear: when true, return a ``LinearAttentionTransformer``; otherwise
            a standard ``nn.TransformerEncoder``.

    Returns:
        The constructed encoder module.
    """
    if not linear:
        # Standard encoder with a 64-wide feed-forward block and GELU.
        standard_layer = nn.TransformerEncoderLayer(
            d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu"
        )
        return nn.TransformerEncoder(standard_layer, num_layers=layers)

    return LinearAttentionTransformer(
        dim=channels,
        heads=heads,
        depth=layers,
        max_seq_len=4096,
        n_local_attn_heads=0,
    )
def Conv1d_with_init(in_channels, out_channels, kernel_size):
    """Create an ``nn.Conv1d`` whose weight is Kaiming-normal initialised.

    The bias keeps the default ``nn.Conv1d`` initialisation.
    """
    conv = nn.Conv1d(in_channels, out_channels, kernel_size)
    # In-place re-initialisation of the weight only.
    nn.init.kaiming_normal_(conv.weight)
    return conv
class DiffusionEmbedding(nn.Module):
    """Sinusoidal lookup table for diffusion steps followed by a two-layer MLP.

    Args:
        dim: number of frequencies; the table has ``2 * dim`` columns
            (sines followed by cosines).
        proj_dim: width of both projection layers (defaults to ``dim``).
        max_steps: number of rows in the table.
    """

    def __init__(self, dim=128, proj_dim=None, max_steps=500):
        super().__init__()
        out_dim = dim if proj_dim is None else proj_dim
        # Non-persistent: the table is rebuilt on construction and is not
        # saved in the state dict.
        table = self._build_embedding(dim, max_steps)
        self.register_buffer("embedding", table, persistent=False)
        self.projection1 = nn.Linear(dim * 2, out_dim)
        self.projection2 = nn.Linear(out_dim, out_dim)

    def forward(self, diffusion_step):
        """Look up ``diffusion_step`` in the table and apply Linear -> SiLU twice."""
        hidden = self.embedding[diffusion_step]
        for projection in (self.projection1, self.projection2):
            hidden = F.silu(projection(hidden))
        return hidden

    def _build_embedding(self, dim, max_steps):
        """Return a ``(max_steps, 2 * dim)`` table of sines and cosines."""
        step_column = torch.arange(max_steps).unsqueeze(1)  # [T,1]
        dim_row = torch.arange(dim).unsqueeze(0)  # [1,dim]
        # Frequencies grow geometrically from 10^0 towards 10^4.
        angles = step_column * 10.0 ** (dim_row * 4.0 / dim)  # [T,dim]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
class diff_CSDI(nn.Module):
    """Denoising network: input projection, residual blocks, output projection.

    Args:
        channels: hidden channel count.
        diffusion_embedding_dim: width of the diffusion-step embedding.
        side_dim: channel count of the conditioning information passed to
            every residual block.
        num_steps: number of diffusion steps in the embedding table.
        nheads: attention heads used inside the residual blocks.
        n_layers: number of residual blocks.
        inputdim: channel count of the input tensor.
        linear: forwarded to the residual blocks to select the encoder type.
    """

    def __init__(self, channels, diffusion_embedding_dim, side_dim, num_steps, nheads, n_layers, inputdim=2, linear=False):
        super().__init__()
        self.channels = channels

        self.diffusion_embedding = DiffusionEmbedding(
            dim=diffusion_embedding_dim, max_steps=num_steps
        )

        self.input_projection = Conv1d_with_init(inputdim, self.channels, 1)
        self.output_projection1 = Conv1d_with_init(self.channels, self.channels, 1)
        self.output_projection2 = Conv1d_with_init(self.channels, 1, 1)
        # The final projection starts with an all-zero weight.
        nn.init.zeros_(self.output_projection2.weight)

        blocks = []
        for _ in range(n_layers):
            blocks.append(
                ResidualBlock(
                    side_dim=side_dim,
                    channels=self.channels,
                    diffusion_embedding_dim=diffusion_embedding_dim,
                    nheads=nheads,
                    linear=linear,
                )
            )
        self.residual_layers = nn.ModuleList(blocks)

    def forward(self, x, cond_info, diffusion_step):
        """Map a ``(B, inputdim, K, L)`` tensor to a ``(B, K, L)`` tensor."""
        B, inputdim, K, L = x.shape

        # The 1x1 convolutions work on a flattened (K * L) axis.
        hidden = x.reshape(B, inputdim, K * L)
        hidden = F.relu(self.input_projection(hidden))
        hidden = hidden.reshape(B, self.channels, K, L)

        diffusion_emb = self.diffusion_embedding(diffusion_step)

        # Every block returns its running state and a skip connection.
        skip_connections = []
        for block in self.residual_layers:
            hidden, skip_connection = block(hidden, cond_info, diffusion_emb)
            skip_connections.append(skip_connection)

        # Sum the skips and normalise by sqrt(number of blocks).
        merged = torch.sum(torch.stack(skip_connections), dim=0) / math.sqrt(len(self.residual_layers))
        merged = merged.reshape(B, self.channels, K * L)
        merged = F.relu(self.output_projection1(merged))  # (B,channel,K*L)
        merged = self.output_projection2(merged)  # (B,1,K*L)
        return merged.reshape(B, K, L)
class ResidualBlock(nn.Module):
    """Gated residual block that attends along the L axis, then along the K axis."""

    def __init__(self, side_dim, channels, diffusion_embedding_dim, nheads, linear=False):
        super().__init__()
        self.side_dim = side_dim
        self.diffusion_projection = nn.Linear(diffusion_embedding_dim, channels)
        # These three emit 2 * channels so their output can be chunked in two.
        self.cond_projection = Conv1d_with_init(side_dim, 2 * channels, 1)
        self.mid_projection = Conv1d_with_init(channels, 2 * channels, 1)
        self.output_projection = Conv1d_with_init(channels, 2 * channels, 1)

        self.time_layer = get_torch_trans(heads=nheads, layers=1, channels=channels,linear=linear)
        self.feature_layer = get_torch_trans(heads=nheads, layers=1, channels=channels,linear=linear)

    def forward_time(self, y, base_shape):
        """Apply ``time_layer`` along L for every (batch, K) pair; no-op when ``L == 1``."""
        B, channel, K, L = base_shape
        if L == 1:
            return y
        # (B, C, K*L) -> (B*K, C, L)
        per_feature = y.reshape(B, channel, K, L).permute(0, 2, 1, 3).reshape(B * K, channel, L)
        # The layer is fed (L, B*K, C) and the result is permuted back.
        attended = self.time_layer(per_feature.permute(2, 0, 1)).permute(1, 2, 0)
        return attended.reshape(B, K, channel, L).permute(0, 2, 1, 3).reshape(B, channel, K * L)

    def forward_feature(self, y, base_shape):
        """Apply ``feature_layer`` along K for every (batch, L) pair; no-op when ``K == 1``."""
        B, channel, K, L = base_shape
        if K == 1:
            return y
        # (B, C, K*L) -> (B*L, C, K)
        per_step = y.reshape(B, channel, K, L).permute(0, 3, 1, 2).reshape(B * L, channel, K)
        attended = self.feature_layer(per_step.permute(2, 0, 1)).permute(1, 2, 0)
        return attended.reshape(B, L, channel, K).permute(0, 2, 3, 1).reshape(B, channel, K * L)

    def forward(self, x, cond_info, diffusion_emb):
        """Return ``(residual_output, skip)``, both shaped like ``x`` (``(B, channel, K, L)``)."""
        base_shape = x.shape
        B, channel, K, L = base_shape
        flat_x = x.reshape(B, channel, K * L)

        step_bias = self.diffusion_projection(diffusion_emb).unsqueeze(-1)  # (B,channel,1)
        y = flat_x + step_bias

        y = self.forward_time(y, base_shape)
        y = self.forward_feature(y, base_shape)  # (B,channel,K*L)
        y = self.mid_projection(y)  # (B,2*channel,K*L)

        cond_dim = cond_info.shape[1]
        side = self.cond_projection(cond_info.reshape(B, cond_dim, K * L))  # (B,2*channel,K*L)
        y = y + side

        gate, filt = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filt)  # (B,channel,K*L)
        y = self.output_projection(y)

        residual, skip = torch.chunk(y, 2, dim=1)
        residual = residual.reshape(base_shape)
        skip = skip.reshape(base_shape)
        return (flat_x.reshape(base_shape) + residual) / math.sqrt(2.0), skip
================================================
FILE: probts/model/nn/prob/flow_model.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PyTorch-TS
# - Source: https://github.com/zalandoresearch/pytorch-ts
# - Paper: Multi-variate Probabilistic Time Series Forecasting via Conditioned Normalizing Flows
# - License: MIT, Apache-2.0 license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.distributions import Normal
class FlowModel(nn.Module):
    """Base class for conditional normalizing flows.

    ``self.net`` is ``None`` here and must be assigned before use. The code
    below calls it as ``net(x, cond)`` and ``net.inverse(u, cond)`` and expects
    each to return ``(tensor, log_abs_det_jacobian)``.
    """

    def __init__(self, target_dim, f_hidden_size, conditional_length, dequantize):
        super().__init__()
        self.__scale = None
        self.net = None
        self.dequantize = dequantize

        self.dist_args = nn.Linear(
            in_features=f_hidden_size, out_features=conditional_length
        )

        # base distribution for calculation of log prob under the model
        self.register_buffer("base_dist_mean", torch.zeros(target_dim))
        self.register_buffer("base_dist_var", torch.ones(target_dim))

    @property
    def base_dist(self):
        # NOTE: ``Normal`` takes a standard deviation as its second argument;
        # the buffer is all ones, so variance and std coincide here.
        return Normal(self.base_dist_mean, self.base_dist_var)

    @property
    def scale(self):
        return self.__scale

    @scale.setter
    def scale(self, scale):
        self.__scale = scale

    def forward(self, x, cond):
        """Map data ``x`` to the base space; returns ``(u, log_abs_det_jacobian)``."""
        if self.scale is not None:
            # Out-of-place division so the caller's tensor is never modified.
            x = x / self.scale
        u, log_abs_det_jacobian = self.net(x, cond)
        return u, log_abs_det_jacobian

    def inverse(self, u, cond):
        """Map base-space ``u`` back to data space; returns ``(x, log_abs_det_jacobian)``."""
        x, log_abs_det_jacobian = self.net.inverse(u, cond)
        if self.scale is not None:
            # Out-of-place so a ``scale`` that broadcasts to a larger shape also works.
            x = x * self.scale
            log_abs_det_jacobian = log_abs_det_jacobian + torch.log(torch.abs(self.scale))
        return x, log_abs_det_jacobian

    def log_prob(self, x, cond):
        """Log-likelihood of ``x`` given ``cond``, summed over the last dimension."""
        if self.dequantize:
            # Add uniform noise to a new tensor instead of ``x += ...`` so the
            # caller's data is left untouched.
            x = x + torch.rand_like(x)
        u, sum_log_abs_det_jacobians = self.forward(x, cond)
        return torch.sum(self.base_dist.log_prob(u) + sum_log_abs_det_jacobians, dim=-1)

    def loss(self, x, cond):
        """Negative log-likelihood of ``x`` given ``cond``."""
        return -self.log_prob(x, cond)

    def sample(self, sample_shape=torch.Size(), cond=None):
        """Draw base-distribution samples and push them through ``inverse``.

        With ``cond`` given, the sample shape is ``cond.shape[:-1]``;
        otherwise ``sample_shape`` is used.
        """
        if cond is not None:
            shape = cond.shape[:-1]
        else:
            shape = sample_shape

        u = self.base_dist.sample(shape)
        sample, _ = self.inverse(u, cond)
        return sample
class BatchNorm(nn.Module):
    """ Flow Model BatchNorm layer """

    def __init__(self, input_size, momentum=0.9, eps=1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps

        # Learnable affine parameters; the scale is stored in log-space.
        self.log_gamma = nn.Parameter(torch.zeros(input_size))
        self.beta = nn.Parameter(torch.zeros(input_size))

        self.register_buffer("running_mean", torch.zeros(input_size))
        self.register_buffer("running_var", torch.ones(input_size))

    def forward(self, x, cond_y=None):
        """Normalise ``x`` over all leading dims; returns ``(y, log_abs_det_jacobian)``.

        ``cond_y`` is accepted but not used.
        """
        if not self.training:
            mean, var = self.running_mean, self.running_var
        else:
            # Statistics are kept on the module so ``inverse`` can reuse them.
            self.batch_mean = x.view(-1, x.shape[-1]).mean(0)
            # note MAF paper uses biased variance estimate; ie x.var(0, unbiased=False)
            self.batch_var = x.view(-1, x.shape[-1]).var(0)

            # Exponential moving average of the batch statistics.
            keep = self.momentum
            self.running_mean.mul_(keep).add_(self.batch_mean.data * (1 - keep))
            self.running_var.mul_(keep).add_(self.batch_var.data * (1 - keep))

            mean, var = self.batch_mean, self.batch_var

        # compute normalized input (cf original batch norm paper algo 1)
        normalised = (x - mean) / torch.sqrt(var + self.eps)
        y = self.log_gamma.exp() * normalised + self.beta

        # compute log_abs_det_jacobian (cf RealNVP paper)
        ldj = self.log_gamma - 0.5 * torch.log(var + self.eps)
        return y, ldj.expand_as(x)

    def inverse(self, y, cond_y=None):
        """Undo ``forward``; returns ``(x, log_abs_det_jacobian)``.

        In training mode this reads the batch statistics stored by the most
        recent ``forward`` call.
        """
        if self.training:
            mean, var = self.batch_mean, self.batch_var
        else:
            mean, var = self.running_mean, self.running_var

        normalised = (y - self.beta) * torch.exp(-self.log_gamma)
        x = normalised * torch.sqrt(var + self.eps) + mean

        ldj = 0.5 * torch.log(var + self.eps) - self.log_gamma
        return x, ldj.expand_as(x)
class FlowSequential(nn.Sequential):
    """ Container for layers of a normalizing flow """

    def forward(self, x, y):
        """Run every layer in order, accumulating each layer's log|det J|."""
        total_ldj = 0
        for layer in self:
            x, layer_ldj = layer(x, y)
            total_ldj += layer_ldj
        return x, total_ldj

    def inverse(self, u, y):
        """Run every layer's ``inverse`` in reverse order, accumulating log|det J|."""
        total_ldj = 0
        for layer in reversed(self):
            u, layer_ldj = layer.inverse(u, y)
            total_ldj += layer_ldj
        return u, total_ldj
================================================
FILE: probts/model/nn/prob/gaussian_diffusion.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PyTorch-TS
# - Source: https://github.com/zalandoresearch/pytorch-ts
# - Paper: Autoregressive Denoising Diffusion Models for Multivariate Probabilistic Time Series Forecasting
# - License: MIT, Apache-2.0 license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, einsum
from probts.model.nn.prob.diffusion_layers import DiffusionEmbedding
from functools import partial
from inspect import isfunction
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    If ``d`` is a plain Python function (per ``inspect.isfunction``) it is
    called and its result returned; any other ``d`` is returned as-is.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
def extract(a, t, x_shape):
    """Gather ``a[t]`` along the last axis and reshape to ``(b, 1, ..., 1)``.

    ``b`` is the first dimension of ``t``. The result has ``len(x_shape)``
    dimensions so it broadcasts against a tensor of shape ``x_shape``.
    """
    b, *_ = t.shape
    picked = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return picked.reshape(b, *singleton_dims)
def noise_like(shape, device, repeat=False):
    """Sample standard-normal noise of ``shape`` on ``device``.

    With ``repeat=True`` a single sample of shape ``(1, *shape[1:])`` is drawn
    and tiled along the first axis, so every batch entry gets the same noise.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        tile_spec = (1,) * (len(shape) - 1)
        return single.repeat(shape[0], *tile_spec)
    return torch.randn(shape, device=device)
def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Returns an array of ``timesteps`` betas clipped to ``[0, 0.999]``.
    """
    grid = np.linspace(0, timesteps, timesteps + 1)
    angle = ((grid / timesteps) + s) / (1 + s) * np.pi * 0.5
    cumulative = np.cos(angle) ** 2
    # Normalise so the cumulative product starts at exactly 1.
    cumulative = cumulative / cumulative[0]
    betas = 1 - (cumulative[1:] / cumulative[:-1])
    return np.clip(betas, 0, 0.999)
class ResidualBlock(nn.Module):
    """Gated residual block conditioned on a diffusion-step embedding and a conditioner signal."""

    def __init__(self, hidden_size, residual_channels, dilation, target_dim):
        super().__init__()
        self.target_dim = target_dim
        doubled = 2 * residual_channels  # gate + filter, or residual + skip

        self.diffusion_projection = nn.Linear(hidden_size, residual_channels)
        if self.target_dim > 1:
            # Kernel 3 with padding == dilation keeps the sequence length.
            self.dilated_conv = nn.Conv1d(
                residual_channels,
                doubled,
                3,
                padding=dilation,
                dilation=dilation,
                padding_mode="circular",
            )
            # Circular padding of 2 per side lengthens the conditioner by 4.
            self.conditioner_projection = nn.Conv1d(
                1, doubled, 1, padding=2, padding_mode="circular"
            )
        else:
            # A single target dimension uses pointwise convolutions, no padding.
            self.dilated_conv = nn.Conv1d(residual_channels, doubled, 1)
            self.conditioner_projection = nn.Conv1d(1, doubled, 1)
        self.output_projection = nn.Conv1d(residual_channels, doubled, 1)

        nn.init.kaiming_normal_(self.conditioner_projection.weight)
        nn.init.kaiming_normal_(self.output_projection.weight)

    def forward(self, x, conditioner, diffusion_step):
        """Return ``(residual_output, skip)``; both have the same shape as ``x``."""
        step_bias = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        cond = self.conditioner_projection(conditioner)

        hidden = self.dilated_conv(x + step_bias) + cond

        gate, filt = torch.chunk(hidden, 2, dim=1)
        hidden = torch.sigmoid(gate) * torch.tanh(filt)

        hidden = F.leaky_relu(self.output_projection(hidden), 0.4)
        residual, skip = torch.chunk(hidden, 2, dim=1)
        return (x + residual) / math.sqrt(2.0), skip
class CondUpsampler(nn.Module):
    """Map a conditioning vector of length ``cond_length`` to length ``target_dim``.

    For ``target_dim > 1`` two linear layers with a ``target_dim // 2``
    bottleneck are used; otherwise a single linear layer. Every linear layer
    is followed by ``leaky_relu`` with slope 0.4.
    """

    def __init__(self, cond_length, target_dim):
        super().__init__()
        self.target_dim = target_dim
        if self.target_dim > 1:
            bottleneck = target_dim // 2
            self.linear1 = nn.Linear(cond_length, bottleneck)
            self.linear2 = nn.Linear(bottleneck, target_dim)
        else:
            self.linear = nn.Linear(cond_length, target_dim)

    def forward(self, x):
        if self.target_dim > 1:
            stages = (self.linear1, self.linear2)
        else:
            stages = (self.linear,)
        for stage in stages:
            x = F.leaky_relu(stage(x), 0.4)
        return x
class EpsilonTheta(nn.Module):
    """Denoising network: input projection, a stack of ``ResidualBlock`` layers, and two output convolutions."""

    def __init__(
        self,
        target_dim,
        cond_length,
        time_emb_dim=16,
        residual_layers=8,
        residual_channels=8,
        dilation_cycle_length=2,
        residual_hidden=64,
        padding=2
    ):
        super().__init__()
        if target_dim > 1:
            # Circular padding lengthens the input by 2 * padding; each of the
            # two unpadded kernel-3 convolutions below shortens it by 2.
            self.input_projection = nn.Conv1d(
                1, residual_channels, 1, padding=padding, padding_mode="circular"
            )
            self.skip_projection = nn.Conv1d(residual_channels, residual_channels, 3)
            self.output_projection = nn.Conv1d(residual_channels, 1, 3)
        else:
            # Single target dimension: pointwise convolutions throughout.
            self.input_projection = nn.Conv1d(1, residual_channels, 1)
            self.skip_projection = nn.Conv1d(residual_channels, residual_channels, 1)
            self.output_projection = nn.Conv1d(residual_channels, 1, 1)

        self.diffusion_embedding = DiffusionEmbedding(
            time_emb_dim, proj_dim=residual_hidden
        )
        self.cond_upsampler = CondUpsampler(
            target_dim=target_dim, cond_length=cond_length
        )

        # Dilations cycle through 1, 2, ..., 2 ** (dilation_cycle_length - 1).
        blocks = []
        for layer_idx in range(residual_layers):
            blocks.append(
                ResidualBlock(
                    residual_channels=residual_channels,
                    dilation=2 ** (layer_idx % dilation_cycle_length),
                    hidden_size=residual_hidden,
                    target_dim=target_dim,
                )
            )
        self.residual_layers = nn.ModuleList(blocks)

        nn.init.kaiming_normal_(self.input_projection.weight)
        nn.init.kaiming_normal_(self.skip_projection.weight)
        # Zero weight on the last layer: the initial output is its bias only.
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, inputs, time, cond):
        """Run the network on ``inputs`` for diffusion step ``time`` under conditioning ``cond``."""
        hidden = F.leaky_relu(self.input_projection(inputs), 0.4)

        step_emb = self.diffusion_embedding(time)
        cond_up = self.cond_upsampler(cond)

        skips = []
        for block in self.residual_layers:
            hidden, skip_out = block(hidden, cond_up, step_emb)
            skips.append(skip_out)

        # Sum the skip outputs and rescale by sqrt(number of layers).
        norm = math.sqrt(len(self.residual_layers))
        out = torch.sum(torch.stack(skips), dim=0) / norm
        out = F.leaky_relu(self.skip_projection(out), 0.4)
        return self.output_projection(out)
class GaussianDiffusion(nn.Module):
    """Denoising diffusion model built around an ``EpsilonTheta`` noise predictor.

    The constructor precomputes every schedule-dependent coefficient as a
    float32 buffer. ``loss`` trains the noise predictor; ``sample`` runs the
    reverse chain from pure noise.
    """

    def __init__(
        self,
        target_dim,
        f_hidden_size,
        conditional_length,
        beta_end=0.1,
        diff_steps=100,
        loss_type="l2",
        betas=None,
        beta_schedule="linear",
        padding=2,
        residual_channels=8,
    ):
        super().__init__()
        self.dist_args = nn.Linear(
            in_features=f_hidden_size, out_features=conditional_length
        )
        self.denoise_fn = EpsilonTheta(
            target_dim=target_dim,
            cond_length=conditional_length,
            residual_channels=residual_channels,
            padding=padding,
        )
        self.target_dim = target_dim
        self.__scale = None

        if betas is not None:
            # An explicit schedule overrides ``beta_schedule``/``diff_steps``.
            if isinstance(betas, torch.Tensor):
                betas = betas.detach().cpu().numpy()
            # Accept plain sequences too; arrays pass through unchanged.
            betas = np.asarray(betas)
        else:
            if beta_schedule == "linear":
                betas = np.linspace(1e-4, beta_end, diff_steps)
            elif beta_schedule == "quad":
                betas = np.linspace(1e-4 ** 0.5, beta_end ** 0.5, diff_steps) ** 2
            elif beta_schedule == "const":
                betas = beta_end * np.ones(diff_steps)
            elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
                betas = 1.0 / np.linspace(diff_steps, 1, diff_steps)
            elif beta_schedule == "sigmoid":
                betas = np.linspace(-6, 6, diff_steps)
                betas = (beta_end - 1e-4) / (np.exp(-betas) + 1) + 1e-4
            elif beta_schedule == "cosine":
                betas = cosine_beta_schedule(diff_steps)
            else:
                raise NotImplementedError(beta_schedule)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

    @property
    def scale(self):
        return self.__scale

    @scale.setter
    def scale(self, scale):
        self.__scale = scale

    def q_mean_variance(self, x_start, t):
        """Mean, variance and log-variance of ``q(x_t | x_0)`` at steps ``t``."""
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        """Reconstruct ``x_0`` from ``x_t`` and the (predicted) noise."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean, variance and clipped log-variance of ``q(x_{t-1} | x_t, x_0)``."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, cond, t, clip_denoised: bool):
        """Posterior parameters for one reverse step, using the network's noise estimate."""
        x_recon = self.predict_start_from_noise(
            x, t=t, noise=self.denoise_fn(x, t, cond=cond)
        )

        if clip_denoised:
            x_recon.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, cond, t, clip_denoised=False, repeat_noise=False):
        """Draw ``x_{t-1}`` given ``x_t``."""
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(
            x=x, cond=cond, t=t, clip_denoised=clip_denoised
        )
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, cond):
        """Run the full reverse chain starting from standard-normal noise of ``shape``."""
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device=device)

        for i in reversed(range(0, self.num_timesteps)):
            img = self.p_sample(
                img, cond, torch.full((b,), i, device=device, dtype=torch.long)
            )
        return img

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size(), cond=None):
        """Sample from the model; the result is multiplied by ``scale`` when one is set."""
        if cond is not None:
            shape = cond.shape[:-1] + (self.target_dim,)
            # TODO reshape cond to (B*T, 1, -1)
        else:
            shape = sample_shape
        x_hat = self.p_sample_loop(shape, cond)  # TODO reshape x_hat to (B,T,-1)

        if self.scale is not None:
            x_hat = x_hat * self.scale
        return x_hat

    @torch.no_grad()
    def interpolate(self, x1, x2, t=None, lam=0.5, cond=None):
        """Blend ``x1`` and ``x2`` at diffusion step ``t`` and denoise the mixture.

        ``cond`` is the conditioning passed to every reverse step. It was
        added as a trailing keyword so existing positional calls keep their
        meaning; previously ``p_sample`` was called without it and raised
        ``TypeError``.
        """
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in reversed(range(0, t)):
            img = self.p_sample(
                img, cond, torch.full((b,), i, device=device, dtype=torch.long)
            )

        return img

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to step ``t`` (fresh Gaussian noise unless ``noise`` is given)."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, cond, t, noise=None):
        """Loss between the true noise and the network's estimate at step ``t``."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_recon = self.denoise_fn(x_noisy, t, cond=cond)

        if self.loss_type == "l1":
            loss = F.l1_loss(x_recon, noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(x_recon, noise)
        elif self.loss_type == "huber":
            loss = F.smooth_l1_loss(x_recon, noise)
        else:
            raise NotImplementedError()

        return loss

    def loss(self, x, cond, *args, **kwargs):
        """Training loss for ``x`` of shape ``(B, T, D)`` with one random step per (B, T) entry."""
        if self.scale is not None:
            # Out-of-place division so the caller's tensor is never modified.
            x = x / self.scale

        B, T, _ = x.shape

        time = torch.randint(0, self.num_timesteps, (B * T,), device=x.device).long()
        loss = self.p_losses(
            x.reshape(B * T, 1, -1), cond.reshape(B * T, 1, -1), time, *args, **kwargs
        )

        return loss
================================================
FILE: probts/utils/__init__.py
================================================
from .utils import *
from .evaluator import Evaluator
================================================
FILE: probts/utils/download_datasets.py
================================================
import gdown
import shutil
import os
import argparse
def download_and_extract_zip(output_path, zip_name='all_datasets'):
    """Download the dataset archive from Google Drive and unpack it under ``output_path``.

    The archive is expected at ``<output_path>/<zip_name>.zip`` after the
    download. If it cannot be unpacked, a message is printed and the function
    returns without moving files or cleaning up. Continuing would only fail
    later, when listing the directory that extraction should have created.
    """
    output_path = os.path.normpath(output_path)
    # A trailing separator marks the target as a directory rather than a file name.
    if not output_path.endswith(os.path.sep):
        output_path += os.path.sep

    gdown.download(id='1tSc1WA30CL2aMt5hAW7M-d5_0IBz-lJP', output=output_path, quiet=False)
    print(f"Data files are saved to {os.path.dirname(output_path)}")

    file_path = os.path.join(output_path, zip_name + '.zip')
    try:
        shutil.unpack_archive(file_path, os.path.dirname(file_path))
    except shutil.ReadError:
        print("is not zip file")
        return
    print("files are unzipped")

    # The archive unpacks into <output_path>/<zip_name>/; flatten that level.
    move_files_up_one_level(os.path.join(output_path, zip_name))
    cleanup_directory(output_path)
    print("datasets prepared done.")
def move_files_up_one_level(directory):
    """Move every entry of ``directory`` into its parent, then remove ``directory``.

    Entries named ``__MACOSX``, ``.DS_Store`` or ``all_datasets.zip`` are left
    in place. An entry whose name already exists in the parent is not moved;
    the copy inside ``directory`` is deleted instead.
    """
    ignored = ('__MACOSX', '.DS_Store', 'all_datasets.zip')
    parent = os.path.dirname(directory)

    for item in os.listdir(directory):
        if item in ignored:
            continue

        s = os.path.join(directory, item)
        d = os.path.join(parent, item)

        if not os.path.exists(d):
            shutil.move(s, d)
        else:
            print(f"skip {item} due to file exist")
            delete_path(s)

    try:
        delete_path(directory)
    except Exception:
        # Best effort: a leftover directory is not fatal. ``Exception`` rather
        # than a bare ``except`` so Ctrl-C / SystemExit still propagate.
        print(f'cannot delete {directory}, skip...')
def cleanup_directory(directory):
    """Walk ``directory`` and delete archive leftovers.

    Removes every sub-directory named ``__MACOSX`` and every file named
    ``.DS_Store`` or ``all_datasets.zip`` found anywhere below ``directory``.
    """
    junk_dirs = ('__MACOSX',)
    junk_files = ('.DS_Store', 'all_datasets.zip')

    for current, subdirs, filenames in os.walk(directory):
        for subdir in subdirs:
            if subdir in junk_dirs:
                shutil.rmtree(os.path.join(current, subdir))
        for filename in filenames:
            if filename in junk_files:
                os.remove(os.path.join(current, filename))
def delete_path(path):
    """Delete ``path`` if it exists: files are unlinked, directories removed recursively."""
    if not os.path.exists(path):
        return
    if os.path.isfile(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
def download_datasets_from_kaggle(output_path):
    """Download two Kaggle datasets and move one CSV from each into ``<output_path>/kaggle``.

    ``kagglehub`` is imported lazily, so the rest of the script works without
    it. After each CSV is moved, the path returned by ``kagglehub`` is deleted.
    """
    import kagglehub

    output_path = os.path.join(output_path, 'kaggle/')
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    target_dir = os.path.dirname(output_path)

    # (kagglehub handle, CSV inside the download, message printed afterwards)
    datasets = (
        (
            "dharanikra/electrical-power-demand-in-turkey",
            'power Generation and consumption.csv',
            "Path to electrical-power-demand-in-turkey files:",
        ),
        (
            "leonardo00/istanbul-traffic-index",
            'istanbul_traffic.csv',
            "Path to istanbul-traffic-index files:",
        ),
    )

    for handle, csv_name, message in datasets:
        download_dir = kagglehub.dataset_download(handle)
        src = os.path.join(download_dir, csv_name)
        dst = os.path.join(target_dir, csv_name)
        shutil.move(src, dst)
        print(message, dst)
        delete_path(download_dir)
if __name__ == '__main__':
    # CLI entry point: fetch the main archive, then try the optional Kaggle datasets.
    parser = argparse.ArgumentParser(description='Download and extract zip file from Google Drive')
    parser.add_argument('--data_path', type=str, required=True, help='Path to store the extracted files')

    args = parser.parse_args()

    download_and_extract_zip(args.data_path, zip_name='all_datasets')

    try:
        download_datasets_from_kaggle(args.data_path)
    except Exception:
        # The Kaggle datasets are optional (missing package, credentials,
        # network...). ``Exception`` rather than a bare ``except`` so Ctrl-C
        # still aborts the script.
        print("Cannot download datasets from kaggle, skip it.")
================================================
FILE: probts/utils/evaluator.py
================================================
import numpy as np
from .metrics import *
import torch
class Evaluator:
    """Compute point and quantile-based forecast metrics from sampled forecasts."""

    def __init__(self, quantiles_num=10, smooth=False):
        # Evenly spaced levels 1/n, 2/n, ..., (n-1)/n; level 0 is dropped.
        self.quantiles = (1.0 * np.arange(quantiles_num) / quantiles_num)[1:]
        self.ignore_invalid_values = True
        self.smooth = smooth

    def loss_name(self, q):
        return f"QuantileLoss[{q}]"

    def weighted_loss_name(self, q):
        return f"wQuantileLoss[{q}]"

    def coverage_name(self, q):
        return f"Coverage[{q}]"

    def get_sequence_metrics(self, targets, forecasts, seasonal_error=None, samples_dim=1,loss_weights=None):
        """Compute all metrics for one batch of targets and sampled forecasts.

        ``forecasts`` carries an extra sample axis at ``samples_dim``; the mean
        and the quantiles are taken over it. ``loss_weights`` may be a
        ``torch.Tensor`` on any device or an array-like. It gets a leading and
        a trailing singleton axis so it broadcasts against the per-element
        normalized deviation.
        """
        mean_forecasts = forecasts.mean(axis=samples_dim)
        median_forecasts = np.quantile(forecasts, 0.5, axis=samples_dim)
        metrics = {
            "MSE": mse(targets, mean_forecasts),
            "abs_error": abs_error(targets, median_forecasts),
            "abs_target_sum": abs_target_sum(targets),
            "abs_target_mean": abs_target_mean(targets),
            "MAPE": mape(targets, median_forecasts),
            "sMAPE": smape(targets, median_forecasts),
        }

        if seasonal_error is not None:
            metrics["MASE"] = mase(targets, median_forecasts, seasonal_error)

        metrics["RMSE"] = np.sqrt(metrics["MSE"])
        metrics["NRMSE"] = metrics["RMSE"] / metrics["abs_target_mean"]
        metrics["ND"] = metrics["abs_error"] / metrics["abs_target_sum"]

        # calculate weighted loss
        if loss_weights is not None:
            nd = np.abs(targets - mean_forecasts) / np.sum(np.abs(targets), axis=(1, 2))
            if isinstance(loss_weights, torch.Tensor):
                # ``.cpu()`` is required before ``.numpy()`` for CUDA tensors.
                weights = loss_weights.detach().cpu().numpy()
            else:
                weights = np.asarray(loss_weights)
            weights = weights[np.newaxis, ..., np.newaxis]
            metrics['weighted_ND'] = np.sum(weights * nd)
        else:
            metrics['weighted_ND'] = metrics["ND"]

        for q in self.quantiles:
            q_forecasts = np.quantile(forecasts, q, axis=samples_dim)
            metrics[self.loss_name(q)] = np.sum(quantile_loss(targets, q_forecasts, q))
            metrics[self.weighted_loss_name(q)] = \
                metrics[self.loss_name(q)] / metrics["abs_target_sum"]
            metrics[self.coverage_name(q)] = coverage(targets, q_forecasts)

        metrics["mean_absolute_QuantileLoss"] = np.mean(
            [metrics[self.loss_name(q)] for q in self.quantiles]
        )
        # CRPS here is the mean of the weighted quantile losses over all levels.
        metrics["CRPS"] = np.mean(
            [metrics[self.weighted_loss_name(q)] for q in self.quantiles]
        )
        metrics["MAE_Coverage"] = np.mean(
            [
                np.abs(metrics[self.coverage_name(q)] - np.array([q]))
                for q in self.quantiles
            ]
        )
        return metrics

    def get_metrics(self, targets, forecasts, seasonal_error=None, samples_dim=1, loss_weights=None):
        """Compute metrics per sequence (first axis) and average each metric over sequences."""
        seq_metrics = {}

        # Calculate metrics for each sequence
        for i in range(targets.shape[0]):
            single_seq_metrics = self.get_sequence_metrics(
                np.expand_dims(targets[i], axis=0),
                np.expand_dims(forecasts[i], axis=0),
                np.expand_dims(seasonal_error[i], axis=0) if seasonal_error is not None else None,
                samples_dim,
                loss_weights
            )
            for metric_name, metric_value in single_seq_metrics.items():
                seq_metrics.setdefault(metric_name, []).append(metric_value)

        return {
            metric_name: np.mean(metric_values)
            for metric_name, metric_values in seq_metrics.items()
        }

    @property
    def selected_metrics(self):
        return [ "ND",'weighted_ND', 'CRPS', "NRMSE", "MSE", "MASE"]

    def __call__(self, targets, forecasts, past_data, freq, loss_weights=None):
        """
        Parameters
        ----------
        targets
            groundtruth in (batch_size, prediction_length, target_dim)
        forecasts
            forecasts in (batch_size, num_samples, prediction_length, target_dim)
        past_data
            history used to compute the seasonal error for MASE
        freq
            frequency string passed to ``calculate_seasonal_error``
        loss_weights
            optional weights forwarded to ``get_metrics`` for ``weighted_ND``
        Returns
        -------
        Dict[String, float]
            the selected metrics, plus a ``"<name>-Sum"`` entry for each one
            that is also available on the series summed over the last axis
        """

        targets = process_tensor(targets)
        forecasts = process_tensor(forecasts)
        past_data = process_tensor(past_data)

        if self.ignore_invalid_values:
            # Mask NaN/inf so they do not enter the reductions.
            targets = np.ma.masked_invalid(targets)
            forecasts = np.ma.masked_invalid(forecasts)

        seasonal_error = calculate_seasonal_error(past_data, freq)

        metrics = self.get_metrics(targets, forecasts, seasonal_error=seasonal_error, samples_dim=1, loss_weights=loss_weights)
        # Same metrics on the series aggregated over the last axis (no MASE, no weights).
        metrics_sum = self.get_metrics(targets.sum(axis=-1), forecasts.sum(axis=-1), samples_dim=1)

        # select output metrics
        output_metrics = dict()
        for k in self.selected_metrics:
            output_metrics[k] = metrics[k]
            if k in metrics_sum:
                output_metrics[f"{k}-Sum"] = metrics_sum[k]
        return output_metrics
def process_tensor(targets):
    """Return ``targets`` as a numpy array.

    Torch tensors are moved to CPU, detached and converted; numpy arrays are
    returned unchanged. Any other type raises ``TypeError``.
    """
    if isinstance(targets, torch.Tensor):
        return targets.cpu().detach().numpy()
    if isinstance(targets, np.ndarray):
        return targets
    raise TypeError("targets must be a torch.Tensor or a numpy.ndarray")
================================================
FILE: probts/utils/masking.py
================================================
# Code implementation from https://github.com/thuml/iTransformer
import torch
class TriangularCausalMask():
    """Boolean mask of shape ``[B, 1, L, L]`` that is ``True`` strictly above the diagonal."""

    def __init__(self, B, L, device="cpu"):
        with torch.no_grad():
            all_true = torch.ones([B, 1, L, L], dtype=torch.bool)
            self._mask = torch.triu(all_true, diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask
class ProbMask():
    """Rows of a strictly-upper-triangular ``[L, S]`` mask, picked per batch/head by ``index``.

    ``S`` is ``scores.shape[-1]``; the stored mask is viewed to ``scores.shape``.
    """

    def __init__(self, B, H, L, index, scores, device="cpu"):
        n_keys = scores.shape[-1]
        upper = torch.ones(L, n_keys, dtype=torch.bool).to(device).triu(1)
        # Broadcast the single [L, S] mask across batch and heads (no copy).
        expanded = upper[None, None, :].expand(B, H, L, n_keys)
        batch_idx = torch.arange(B)[:, None, None]
        head_idx = torch.arange(H)[None, :, None]
        # Advanced indexing picks, per (batch, head), the rows listed in ``index``.
        picked = expanded[batch_idx, head_idx, index, :].to(device)
        self._mask = picked.view(scores.shape).to(device)

    @property
    def mask(self):
        return self._mask
================================================
FILE: probts/utils/metrics.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from gluonts
# - Source: https://github.com/awslabs/gluonts
# - Paper: GluonTS: Probabilistic and Neural Time Series Modeling in Python
# - License: Apache-2.0
#
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
from typing import Optional
import numpy as np
from gluonts.time_feature import get_seasonality
def mse(target: np.ndarray, forecast: np.ndarray) -> float:
    r"""
    Mean squared error over all elements.

    .. math::
        mse = mean((Y - \hat{Y})^2)
    """
    residual = target - forecast
    return np.mean(np.square(residual))
def abs_error(target: np.ndarray, forecast: np.ndarray) -> float:
    r"""
    Sum of absolute errors over all elements.

    .. math::
        abs\_error = sum(|Y - \hat{Y}|)
    """
    residual = target - forecast
    return np.sum(np.abs(residual))
def abs_target_sum(target) -> float:
    r"""
    Sum of the absolute target values.

    .. math::
        abs\_target\_sum = sum(|Y|)
    """
    magnitudes = np.abs(target)
    return np.sum(magnitudes)
def abs_target_mean(target) -> float:
    r"""
    Mean of the absolute target values.

    .. math::
        abs\_target\_mean = mean(|Y|)
    """
    magnitudes = np.abs(target)
    return np.mean(magnitudes)
def mase(
    target: np.ndarray,
    forecast: np.ndarray,
    seasonal_error: np.ndarray,
) -> float:
    r"""
    Mean absolute scaled error.

    .. math::
        mase = mean(|Y - \hat{Y}|) / seasonal\_error

    See [HA21]_ for more details.

    The absolute error is averaged over axis 1 before scaling. Entries whose
    ``seasonal_error`` is zero contribute 0 instead of ``inf``/``nan``. Both
    plain and masked arrays are accepted.
    """
    diff = np.mean(np.abs(target - forecast), axis=1)
    # ``np.ma.true_divide`` masks divisions by zero for plain ndarrays as
    # well, so the ``filled`` step below no longer requires masked input.
    scaled = np.ma.true_divide(diff, seasonal_error)
    # if seasonal_error is 0, set mase to 0
    return np.mean(np.ma.filled(scaled, 0))
def calculate_seasonal_error(
    past_data: np.ndarray,
    freq: Optional[str] = None,
):
    r"""
    .. math::
        seasonal\_error = mean(|Y[t] - Y[t-m]|)

    where m is the seasonal frequency. See [HA21]_ for more details.

    The lagged difference is taken along axis 1 of ``past_data``, so axis 1 is
    treated as time. The result keeps a singleton axis 1 in place of the
    reduced time axis.
    """
    seasonality = get_seasonality(freq)

    # Compare against the length of the axis that is actually sliced below
    # (axis 1). ``len(past_data)`` would be the size of axis 0 instead.
    series_length = past_data.shape[1]
    if seasonality < series_length:
        forecast_freq = seasonality
    else:
        # edge case: the seasonal freq is larger than the length of ts
        # revert to freq=1
        forecast_freq = 1

    y_t = past_data[:, :-forecast_freq]
    y_tm = past_data[:, forecast_freq:]

    mean_diff = np.mean(np.abs(y_t - y_tm), axis=1)
    mean_diff = np.expand_dims(mean_diff, axis=1)

    return mean_diff
def mape(target: np.ndarray, forecast: np.ndarray) -> float:
    r"""
    Mean absolute percentage error (as a fraction, not multiplied by 100).

    .. math::
        mape = mean(|Y - \hat{Y}| / |Y|))

    See [HA21]_ for more details.
    """
    abs_residual = np.abs(target - forecast)
    relative = abs_residual / np.abs(target)
    return np.mean(relative)
def smape(target: np.ndarray, forecast: np.ndarray) -> float:
    r"""
    Symmetric mean absolute percentage error.

    .. math::
        smape = 2 * mean(|Y - \hat{Y}| / (|Y| + |\hat{Y}|))

    See [HA21]_ for more details.
    """
    abs_residual = np.abs(target - forecast)
    magnitude = np.abs(target) + np.abs(forecast)
    return 2 * np.mean(abs_residual / magnitude)
def quantile_loss(target: np.ndarray, forecast: np.ndarray, q: float) -> float:
    r"""
    Element-wise quantile (pinball) loss, scaled by 2.

    .. math::
        quantile\_loss = 2 * |(Y - \hat{Y}) * ((Y <= \hat{Y}) - q)|

    No reduction is applied here; the result has the broadcast shape of the inputs.
    """
    residual = forecast - target
    # Indicator is 1 where the forecast is at or above the target, 0 elsewhere.
    indicator_minus_q = (target <= forecast) - q
    return 2 * np.abs(residual * indicator_minus_q)
def scaled_quantile_loss(target: np.ndarray, forecast: np.ndarray, q: float, seasonal_error) -> np.ndarray:
    """Quantile loss from ``quantile_loss`` divided by ``seasonal_error``."""
    unscaled = quantile_loss(target, forecast, q)
    return unscaled / seasonal_error
def coverage(target: np.ndarray, forecast: np.ndarray) -> float:
    r"""
    Fraction of targets lying strictly below the forecast.

    .. math::
        coverage = mean(Y < \hat{Y})
    """
    below = target < forecast
    return np.mean(below)
================================================
FILE: probts/utils/position_emb.py
================================================
import torch
from torch import nn
import numpy as np
from einops import rearrange, repeat
class Time_Encoder(nn.Module):
    """Embed scalar time values into ``embed_time`` features: one linear term plus ``embed_time - 1`` sine terms."""

    def __init__(self, embed_time):
        super(Time_Encoder, self).__init__()
        self.periodic = nn.Linear(1, embed_time - 1)
        self.linear = nn.Linear(1, 1)

    def forward(self, tt):
        """Return the embedding with a new trailing feature axis of size ``embed_time``."""
        # 3-D input gets one trailing axis; any other input is treated as
        # [B,L] and gets two.
        pattern = 'b l k -> b l k 1' if tt.dim() == 3 else 'b l -> b l 1 1'
        tt = rearrange(tt, pattern)

        periodic_part = torch.sin(self.periodic(tt))
        linear_part = self.linear(tt)
        return torch.cat([linear_part, periodic_part], -1)  # [B,L,1,D]
def sin_cos_encoding(B, K, L, embed_dim):
    """Sinusoidal encoding of positions ``0..L-1``, tiled to shape ``(B, K, L, embed_dim)``.

    The first half of the last axis holds sines and the second half cosines.
    The result is a float64 tensor.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    omega = np.arange(half, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    positions = list(range(L))
    angles = np.einsum('m,d->md', positions, omega)  # (M, D/2), outer product

    emb = np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
    emb = repeat(emb, 'l d -> b k l d', b=B, k=K)
    return torch.tensor(emb, dtype=torch.float64)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine embedding for an arbitrary array of positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded; flattened to size (M,)
    out: (M, D) array, sine half followed by cosine half
    """
    assert embed_dim % 2 == 0
    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.)
    freqs = 1. / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.einsum('m,d->md', flat_pos, freqs)  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
================================================
FILE: probts/utils/save_utils.py
================================================
from typing import Dict
import numpy as np
import torch
from probts.model.forecaster import Forecaster
import importlib
import json
import pandas as pd
import pickle
import os
def update_metrics(new_metrics: Dict, stage: str, key: str = '', target_dict = None):
    """
    Append freshly computed metric values to an accumulator dictionary.

    Each entry of ``new_metrics`` is stored under ``f'{stage}_{name}'`` (or
    ``f'{stage}_{key}_{name}'`` when ``key`` is given). Scalar values are
    appended to the stored list; list values are concatenated onto it.

    Args:
        new_metrics: Mapping of metric name to a scalar or a list of values.
        stage: Prefix for the stored keys (e.g. 'val', 'test').
        key: Optional extra component inserted between stage and metric name.
        target_dict: Accumulator to update in place. When omitted, a fresh
            dictionary is created for this call. (The previous ``{}`` default
            was a single shared object, so values leaked between calls.)

    Returns:
        The updated accumulator (the same object as ``target_dict`` when one
        was supplied).
    """
    # `is None` rather than truthiness: an explicitly passed empty dict must
    # still be updated in place.
    if target_dict is None:
        target_dict = {}

    prefix = stage if key == '' else f'{stage}_{key}'

    for metric_name, metric_value in new_metrics.items():
        metric_key = f'{prefix}_{metric_name}'
        if metric_key not in target_dict:
            target_dict[metric_key] = []

        if isinstance(metric_value, list):
            # Build a new list so the caller's list is never aliased.
            target_dict[metric_key] = target_dict[metric_key] + metric_value
        else:
            target_dict[metric_key].append(metric_value)

    return target_dict
def calculate_average(metrics_dict: Dict, hor=''):
    """
    Average every list of metric values.

    When ``hor`` is non-empty, each output key is prefixed with ``hor + '/'``.
    """
    prefix = hor + '/' if hor != '' else hor
    return {prefix + name: np.mean(values) for name, values in metrics_dict.items()}
def calculate_weighted_average(metrics_dict: Dict, batch_size: list, hor=''):
    """
    Batch-size-weighted average of every list of metric values.

    Output keys are ``hor + name``; unlike ``calculate_average`` no '/'
    separator is inserted, so callers must include it in ``hor`` if wanted.
    """
    def _weighted_mean(values):
        return np.sum(values * np.array(batch_size)) / np.sum(batch_size)

    return {hor + name: _weighted_mean(values) for name, values in metrics_dict.items()}
def save_point_error(target, predict, input_dict, hor_str):
    """
    Record the absolute error, target and forecast of one batch.

    Values are appended to the 'MAE', 'target' and 'forecast' lists stored
    under ``input_dict[hor_str]`` (created on first use). ``input_dict`` is
    modified in place and also returned.
    """
    bucket = input_dict.setdefault(hor_str, {'MAE': [], 'target': [], 'forecast': []})

    bucket['MAE'].append(np.abs(target - predict))
    bucket['target'].append(target)
    bucket['forecast'].append(predict)

    return input_dict
def load_checkpoint(Model, checkpoint_path, scaler=None, learning_rate=None, no_training=False, **kwargs):
# Rebuild a `Model` instance from a checkpoint file and return it with the
# checkpoint's 'state_dict' loaded.
#
# Model:           class (or callable) used to construct the returned object.
# checkpoint_path: path handed to torch.load.
# scaler:          forwarded unchanged to Model(...).
# learning_rate:   when None, falls back to the value stored in the
#                  checkpoint's hyper-parameters (default 1e-3 if absent).
# no_training:     assigned to `forecaster.no_training`.
# **kwargs:        merged into the forecaster's init args (when the forecaster
#                  is re-instantiated) AND forwarded to Model(...).
#
# Load the checkpoint
# map_location returns every storage unchanged instead of moving it to the
# device recorded in the file.
# NOTE(review): torch.load unpickles the file; only load checkpoints from
# trusted sources.
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
# Extract the arguments for the forecaster
forecaster_args = checkpoint['hyper_parameters']['forecaster']
# The stored value is either an already-built Forecaster object or a dict
# with 'class_path' / 'init_args' entries describing how to build one.
if isinstance(forecaster_args, Forecaster):
forecaster = forecaster_args
else:
module_path, class_name = forecaster_args['class_path'].rsplit('.', 1)
forecaster_class = getattr(importlib.import_module(module_path), class_name)
# Add any missing required arguments
forecaster_args = forecaster_args['init_args']
# kwargs override any init arg of the same name stored in the checkpoint.
forecaster_args.update(kwargs)
# Create the forecaster
forecaster = forecaster_class(**forecaster_args)
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether the next assignment belongs to the `else` branch only or applies to
# both branches - confirm against the original file.
forecaster.no_training = no_training
if learning_rate is None:
learning_rate = checkpoint['hyper_parameters'].get('learning_rate', 1e-3)
# Create the model instance
# Hyper-parameters missing from the checkpoint fall back to the defaults
# given in the .get(...) calls below.
model = Model(
forecaster=forecaster,
scaler=scaler,
num_samples=checkpoint['hyper_parameters'].get('num_samples', 100),
learning_rate=learning_rate,
quantiles_num=checkpoint['hyper_parameters'].get('quantiles_num', 10),
load_from_ckpt=checkpoint['hyper_parameters'].get('load_from_ckpt', None),
**kwargs # Pass additional arguments here
)
model.load_state_dict(checkpoint['state_dict'])
return model
def get_hor_str(prediction_length, dataloader_idx):
    """
    Return the prediction horizon as a string.

    With a dataloader index, the horizon is ``prediction_length[dataloader_idx]``.
    Otherwise a list contributes its first element and any other value is
    stringified directly.
    """
    if dataloader_idx is not None:
        return str(prediction_length[dataloader_idx])
    if type(prediction_length) == list:
        return str(prediction_length[0])
    return str(prediction_length)
def save_exp_summary(pl_module, inference=False):
    """
    Write a JSON summary of the experiment into ``pl_module.save_dict``.

    The summary holds the parameter counts and model size from the model
    summary callback, the memory and time summaries from their callbacks, a
    ``mean_<key>`` entry for every non-empty timing list, and the model's
    ``sampling_weight_scheme``. The file is ``inference_summary.json`` when
    ``inference`` is true and ``summary.json`` otherwise.
    """
    model_info = pl_module.model_summary_callback._summary(pl_module.trainer, pl_module.model)
    report = {
        'total_parameters': model_info.total_parameters,
        'trainable_parameters': model_info.trainable_parameters,
        'model_size': model_info.model_size,
        'memory_summary': pl_module.memory_callback.memory_summary,
    }

    timings = pl_module.time_callback.time_summary
    report['time_summary'] = timings
    # Empty timing lists are skipped so no division by zero occurs.
    report.update({
        f'mean_{batch_key}': sum(batch_time) / len(batch_time)
        for batch_key, batch_time in timings.items()
        if len(batch_time) > 0
    })

    report['sampling_weight_scheme'] = pl_module.model.sampling_weight_scheme

    summary_save_path = (
        f"{pl_module.save_dict}/inference_summary.json"
        if inference
        else f"{pl_module.save_dict}/summary.json"
    )

    with open(summary_save_path, 'w') as f:
        json.dump(report, f, indent=4)
    print(f"Summary saved to {summary_save_path}")
def save_csv(save_dict, model, context_length):
    """
    Dump the model's averaged test metrics to a CSV file in ``save_dict``.

    If ``model.avg_hor_metrics`` is non-empty, one row per horizon is written
    and each per-horizon dict gains a 'horizon' entry (in place). Otherwise a
    single row built from ``model.avg_metrics`` is written. The file is named
    ``horizons_results.csv`` for trained models and
    ``testctx_<context_length>_horizons_results.csv`` when
    ``model.forecaster.no_training`` is set.
    """
    if len(model.avg_hor_metrics) == 0:
        df = pd.DataFrame([model.avg_metrics])
    else:
        rows = []
        for horizon in model.avg_hor_metrics:
            row = model.avg_hor_metrics[str(horizon)]
            row['horizon'] = horizon
            rows.append(row)
        df = pd.DataFrame(rows)

    if model.forecaster.no_training:
        test_result_file = f'testctx_{context_length}_horizons_results'
    else:
        test_result_file = 'horizons_results'

    # NOTE(review): `index` is a boolean flag in pandas; the truthy string
    # 'idx' just enables the index column (it does not label it).
    df.to_csv(f'{save_dict}/{test_result_file}.csv', index='idx')
    print('horizons result saved to ', f'{save_dict}/{test_result_file}.csv')
================================================
FILE: probts/utils/utils.py
================================================
# ---------------------------------------------------------------------------------
# Portions of this file are derived from PyTorch-TS
# - Source: https://github.com/zalandoresearch/pytorch-ts
# - License: MIT, Apache-2.0 license
# We thank the authors for their contributions.
# ---------------------------------------------------------------------------------
import re
import os
import torch
import numpy as np
from typing import Optional, Dict
import torch.nn as nn
import importlib
def repeat(tensor: torch.Tensor, n: int, dim: int = 0):
    """Repeat every slice of ``tensor`` along ``dim`` ``n`` times, keeping repeats adjacent."""
    return torch.repeat_interleave(tensor, repeats=n, dim=dim)
def extract(a, t, x_shape):
    """
    Pick the entries of ``a`` indexed by ``t`` and shape them for broadcasting.

    The gather runs against a CPU copy of ``t``; the result is reshaped to
    ``(len(t), 1, ..., 1)`` with as many axes as ``x_shape`` and then moved
    to ``t``'s device.
    """
    gathered = a.gather(-1, t.cpu())
    trailing_ones = (1,) * (len(x_shape) - 1)
    return gathered.reshape(t.shape[0], *trailing_ones).to(t.device)
def weighted_average(
    x: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    dim: Optional[int] = None,
    reduce: str = 'mean',
):
    """
    Computes the weighted average of a given tensor across a given dim, masking
    values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    Args:
        x: Input tensor, of which the average must be computed.
        weights: Weights tensor, of the same shape as `x`. When None, no
            weighting is applied (see Returns).
        dim: The dim along which to average `x`. None means "all elements"
            for the weighted case. `dim=0` is a valid axis: the check is
            `dim is not None`, not truthiness, so axis 0 is no longer
            silently treated as "no dim".
        reduce: 'mean' to return the weighted average; any other value
            returns the masked, weighted tensor without reduction.

    Returns:
        Tensor: With weights and reduce == 'mean', the weighted sum divided
        by the sum of weights (clamped to at least 1.0 to avoid division by
        zero). With weights and another `reduce`, the un-reduced weighted
        tensor. Without weights, `x.mean(dim)` when `dim` is given and `x`
        itself otherwise.
    """
    has_dim = dim is not None

    if weights is None:
        return x.mean(dim=dim) if has_dim else x

    # torch.where (rather than a plain product) keeps NaN/inf entries of `x`
    # out of the result wherever the weight is zero.
    weighted_tensor = torch.where(weights != 0, x * weights, torch.zeros_like(x))
    if reduce != 'mean':
        return weighted_tensor

    if has_dim:
        total = weighted_tensor.sum(dim=dim)
        sum_weights = weights.sum(dim=dim)
    else:
        total = weighted_tensor.sum()
        sum_weights = weights.sum()

    return total / torch.clamp(sum_weights, min=1.0)
def convert_to_list(s):
    '''
    Convert a prediction-length specification into a list of ints.
    e.g., '96-192-336-720' will be converted into [96, 192, 336, 720]

    Input: str, list, int
    Returns: list, or None for any other input type (including None and bool).

    Strings are parsed by collecting every run of digits, so leading or
    trailing separators and surrounding whitespace (e.g. '96-192-' or ' 24 ')
    are tolerated instead of failing on an empty fragment.

    Raises:
        ValueError: if a string contains no digits at all.
    '''
    # bool is a subclass of int; keep treating it as an unsupported type.
    if isinstance(s, bool):
        return None
    if isinstance(s, int):
        return [s]
    if isinstance(s, list):
        return s
    if isinstance(s, str):
        numbers = re.findall(r'\d+', s)
        if not numbers:
            raise ValueError(f"no integer found in {s!r}")
        return [int(number) for number in numbers]
    return None
def find_best_epoch(ckpt_folder):
    """
    Find the checkpoint with the lowest validation CRPS in a folder.

    File names are expected to contain ``epoch=<int>-val_CRPS=<float>``;
    files that do not match are ignored. On ties the first file returned by
    ``os.listdir`` wins.

    Thanks to GitHub@Kai-Ref for identifying and fixing the issue with CRPS value comparisons.

    Returns:
        (best_epoch, best_ckpt): epoch number and file name of the best
        checkpoint, or (None, None) when no file matches.
    """
    pattern = r"epoch=(\d+)-val_CRPS=([0-9]*\.[0-9]+)"

    best_epoch, best_ckpt = None, None
    lowest_crps = float("inf")

    for filename in os.listdir(ckpt_folder):
        found = re.search(pattern, filename)
        if not found:
            continue

        epoch, crps = int(found.group(1)), float(found.group(2))
        if crps < lowest_crps:
            lowest_crps, best_epoch, best_ckpt = crps, epoch, filename

    return best_epoch, best_ckpt
def ensure_list(input_value, default_value=None):
    """
    Convert ``input_value`` to a list via ``convert_to_list``.

    If that conversion yields None (e.g. the input is None), the converted
    ``default_value`` is returned instead.
    """
    converted = convert_to_list(input_value)
    return convert_to_list(default_value) if converted is None else converted
def init_class_helper(class_name):
    """
    Resolve a dotted path to the object it names.

    Args:
        class_name (str): Fully qualified name in the form "module_name.ClassName".
            The text before the last dot is imported as a module; the text
            after it is looked up as an attribute of that module.

    Returns:
        type: The attribute retrieved from the imported module.
    """
    module_path, attribute = class_name.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attribute)
================================================
FILE: pyproject.toml
================================================
# Packaging metadata for ProbTS: core runtime dependencies plus an optional
# "tsfm" extra.
[build-system]
requires = ["setuptools>=66"]
[project]
name = "ProbTS"
version = "0.1.0"
description = "Benchmarking Point and Distributional Forecasting across Diverse Prediction Horizons"
authors = [
{name = "Jiawen Zhang"},
{name = "Xumeng Wen"},
{name = "Zhenwei Zhang"},
{name = "Shun Zhen"},
]
readme = "README.md"
requires-python = ">=3.10"
license = {text = "MIT"}
dependencies = [
"numpy",
"pandas==2.0.3",
# NOTE(review): "einops" is listed twice in this array (here unpinned and
# again below as "einops>=0.6.1").
"einops",
"matplotlib",
"tqdm",
"PyYAML>=6.0",
# NOTE(review): lightning is installed from the master-branch archive rather
# than a pinned release, so the installed version changes over time.
"lightning @ https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip",
"gluonts~=0.15.1",
"typeshed-client==2.3.0",
"docstring-parser==0.15",
"orjson==3.9.0",
"einops>=0.6.1",
"pydantic==1.10.8",
"transformers==4.50.0",
"linear-attention-transformer==0.19.1",
"tensorboardx==2.6.2",
"pyarrow==11.0.0",
"protobuf>=3.19",
"jsonargparse[signatures]",
"opt_einsum",
"psutil",
"reformer-pytorch",
"gdown",
"kagglehub",
"python-dotenv>=1.0.0",
"utilsforecast",
"jax",
"scikit-learn",
]
[project.optional-dependencies]
# Optional extra; installed with `pip install ".[tsfm]"`.
# NOTE(review): several entries (orjson, numpy, pandas, jax) also appear in the
# core dependencies above with the same or different constraints.
tsfm = [
"timm",
"accelerate",
"tokenizers",
"datasets",
"jaxtyping",
"hydra-core==1.3",
"orjson",
"tensorboard",
"multiprocess",
"huggingface_hub>=0.23.0",
"safetensors",
"jax[cpu]",
"paxml>=1.4.0", # for timesfm
"praxis>=1.4.0",
"einshape>=1.0.0",
"numpy>=1.26.4",
"pandas==2.0.3",
"pykeops",
]
[tool.setuptools]
# An empty py-modules list is set explicitly for setuptools.
py-modules = []
================================================
FILE: run.py
================================================
import os
import torch
import logging
from probts.data import ProbTSDataModule
from probts.model.forecast_module import ProbTSForecastModule
from probts.callbacks import MemoryCallback, TimeCallback
from probts.utils import find_best_epoch
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from probts.utils.save_utils import save_exp_summary, save_csv
MULTI_HOR_MODEL = ['ElasTST', 'Autoformer']
import warnings
warnings.filterwarnings('ignore')
torch.set_float32_matmul_precision('high')
log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class ProbTSCli(LightningCLI):
    """
    LightningCLI subclass that drives a ProbTS experiment.

    It links data-manager attributes to the model/forecaster arguments, builds
    an experiment tag and output directory, installs the memory/time/checkpoint
    callbacks, and runs fit (unless the forecaster needs no training) followed
    by test and result export.
    """

    def add_arguments_to_parser(self, parser):
        """Link data-manager attributes to model and forecaster init arguments."""
        data_to_model_link_args = [
            "scaler",
            "train_pred_len_list",
        ]
        data_to_forecaster_link_args = [
            "target_dim",
            "history_length",
            "context_length",
            "prediction_length",
            "train_pred_len_list",
            "lags_list",
            "freq",
            "time_feat_dim",
            "global_mean",
            "dataset"
        ]
        for arg in data_to_model_link_args:
            parser.link_arguments(f"data.data_manager.{arg}", f"model.{arg}", apply_on="instantiate")
        for arg in data_to_forecaster_link_args:
            parser.link_arguments(f"data.data_manager.{arg}", f"model.forecaster.init_args.{arg}", apply_on="instantiate")

    def init_exp(self):
        """
        Prepare the experiment: tag, output directory, optional checkpoint
        restore, and callbacks.

        Raises:
            FileNotFoundError: if ``load_from_ckpt`` points to a folder in
                which no checkpoint file name matches the expected pattern.
        """
        config_args = self.parser.parse_args()
        data_manager = self.datamodule.data_manager

        if data_manager.multi_hor:
            assert self.model.forecaster.name in MULTI_HOR_MODEL, f"Only support multi-horizon setting for {MULTI_HOR_MODEL}"

            self.tag = "_".join([
                data_manager.dataset,
                self.model.forecaster.name,
                'TrainCTX', '-'.join([str(i) for i in data_manager.train_ctx_len_list]),
                'TrainPRED', '-'.join([str(i) for i in data_manager.train_pred_len_list]),
                'ValCTX', '-'.join([str(i) for i in data_manager.val_ctx_len_list]),
                'ValPRED', '-'.join([str(i) for i in data_manager.val_pred_len_list]),
                'seed' + str(config_args.seed_everything)
            ])
        else:
            self.tag = "_".join([
                data_manager.dataset,
                self.model.forecaster.name,
                'CTX' + str(data_manager.context_length),
                'PRED' + str(data_manager.prediction_length),
                'seed' + str(config_args.seed_everything)
            ])

        log.info(f"Root dir is {self.trainer.default_root_dir}, exp tag is {self.tag}")

        # exist_ok avoids the check-then-create race when several runs start
        # at the same time with the same root directory.
        os.makedirs(self.trainer.default_root_dir, exist_ok=True)

        self.save_dict = f'{self.trainer.default_root_dir}/{self.tag}'
        os.makedirs(self.save_dict, exist_ok=True)

        if self.model.load_from_ckpt is not None:
            # if the checkpoint file is not assigned, find the best epoch in the current folder
            if '.ckpt' not in self.model.load_from_ckpt:
                _, best_ckpt = find_best_epoch(self.model.load_from_ckpt)
                if best_ckpt is None:
                    # Fail with a clear message instead of a TypeError from
                    # os.path.join(folder, None).
                    raise FileNotFoundError(
                        f"No checkpoint matching 'epoch=<n>-val_CRPS=<x>' found in {self.model.load_from_ckpt}"
                    )
                print("find best ckpt ", best_ckpt)
                self.model.load_from_ckpt = os.path.join(self.model.load_from_ckpt, best_ckpt)

            log.info(f"Loading pre-trained checkpoint from {self.model.load_from_ckpt}")
            self.model = ProbTSForecastModule.load_from_checkpoint(
                self.model.load_from_ckpt,
                learning_rate=config_args.model.learning_rate,
                scaler=data_manager.scaler,
                context_length=data_manager.context_length,
                target_dim=data_manager.target_dim,
                freq=data_manager.freq,
                prediction_length=data_manager.prediction_length,
                train_pred_len_list=data_manager.train_pred_len_list,
                lags_list=data_manager.lags_list,
                time_feat_dim=data_manager.time_feat_dim,
                no_training=self.model.forecaster.no_training,
                sampling_weight_scheme=self.model.sampling_weight_scheme,
            )

        # Set callbacks
        self.memory_callback = MemoryCallback()
        self.time_callback = TimeCallback()

        callbacks = [
            self.memory_callback,
            self.time_callback
        ]

        if not self.model.forecaster.no_training:
            if self.datamodule.dataset_val is None:  # if the validation set is empty
                monitor = "train_loss"
            else:
                # not using reweighting scheme for loss
                if self.model.sampling_weight_scheme in ['none', 'fix']:
                    monitor = 'val_CRPS'
                else:
                    monitor = 'val_weighted_ND'

            # Set callbacks
            # NOTE(review): the file name always formats val_CRPS, even when
            # `monitor` is 'train_loss' or 'val_weighted_ND'.
            self.checkpoint_callback = ModelCheckpoint(
                dirpath=f'{self.save_dict}/ckpt',
                filename='{epoch}-{val_CRPS:.6f}',
                every_n_epochs=1,
                monitor=monitor,
                save_top_k=-1,
                save_last=True,
                enable_version_counter=False
            )

            callbacks.append(self.checkpoint_callback)

        self.set_callbacks(callbacks)

    def set_callbacks(self, callbacks):
        """
        Replace built-in trainer callbacks with the given custom ones.

        Every trainer callback whose class name equals that of a custom
        callback is dropped, then the custom callbacks are appended. If a
        callback of class "ModelSummary" is present afterwards, it is stored
        as ``self.model_summary_callback``.
        """
        # Replace built-in callbacks with custom callbacks
        custom_callbacks_name = {c.__class__.__name__ for c in callbacks}

        # Build the filtered list first: removing items from a list while
        # iterating over it skips the element after each removal.
        kept = [
            c for c in self.trainer.callbacks
            if c.__class__.__name__ not in custom_callbacks_name
        ]
        # Slice assignment keeps the same list object on the trainer.
        self.trainer.callbacks[:] = kept + list(callbacks)

        for c in self.trainer.callbacks:
            if c.__class__.__name__ == "ModelSummary":
                self.model_summary_callback = c

    def set_fit_mode(self):
        """Attach a TensorBoard logger (version 'fit') under the experiment directory."""
        self.trainer.logger = TensorBoardLogger(
            save_dir=f'{self.save_dict}/logs',
            name=self.tag,
            version='fit'
        )

    def set_test_mode(self):
        """
        Attach a CSV logger (version 'test') and, for trained models, reload
        the best checkpoint recorded by the checkpoint callback.
        """
        self.trainer.logger = CSVLogger(
            save_dir=f'{self.save_dict}/logs',
            name=self.tag,
            version='test'
        )

        if not self.model.forecaster.no_training:
            self.ckpt = self.checkpoint_callback.best_model_path
            log.info(f"Loading best checkpoint from {self.ckpt}")
            self.model = ProbTSForecastModule.load_from_checkpoint(
                self.ckpt,
                scaler=self.datamodule.data_manager.scaler,
                context_length=self.datamodule.data_manager.context_length,
                target_dim=self.datamodule.data_manager.target_dim,
                freq=self.datamodule.data_manager.freq,
                prediction_length=self.datamodule.data_manager.prediction_length,
                lags_list=self.datamodule.data_manager.lags_list,
                time_feat_dim=self.datamodule.data_manager.time_feat_dim,
                sampling_weight_scheme=self.model.sampling_weight_scheme,
            )

    def run(self):
        """Run the full pipeline: setup, optional fit, test, then export of the summary and CSV."""
        self.init_exp()

        if not self.model.forecaster.no_training:
            self.set_fit_mode()
            if self.datamodule.dataset_val is None:  # if the validation set is empty
                self.trainer.fit(model=self.model, train_dataloaders=self.datamodule.train_dataloader())
            else:
                self.trainer.fit(model=self.model, datamodule=self.datamodule)
            inference = False
        else:
            inference = True

        self.set_test_mode()
        self.trainer.test(model=self.model, datamodule=self.datamodule)

        save_exp_summary(self, inference=inference)

        ctx_len = self.datamodule.data_manager.context_length
        if self.datamodule.data_manager.multi_hor:
            # In the multi-horizon setting context_length is indexable; the
            # first entry is used for the result file name.
            ctx_len = ctx_len[0]

        save_csv(self.save_dict, self.model, ctx_len)
# Script entry point: build the CLI from the ProbTS data module and forecast
# module, then start the pipeline explicitly.
if __name__ == '__main__':
cli = ProbTSCli(
datamodule_class=ProbTSDataModule,
model_class=ProbTSForecastModule,
save_config_kwargs={"overwrite": True},
# run=False: the pipeline is started by the explicit cli.run() call below.
run=False
)
cli.run()
================================================
FILE: run.sh
================================================
# Example: train and test one model on one long-term forecasting dataset.
MODEL=patchtst
DATASET=etth1
CTX_LEN=96
PRED_LEN=96
# DATA_DIR=/path/to/datasets
# LOG_DIR=/path/to/log_dir
DATA_DIR=./datasets
LOG_DIR=./log_dir
# Multivariate datasets:
# ['exchange_rate_nips', 'solar_nips','electricity_nips', 'traffic_nips','wiki2000_nips']
# Univariate datasets:
# ['m4_weekly', 'm4_hourly', 'm4_daily', 'm4_monthly', 'm4_quarterly', 'm4_yearly', 'm5', 'tourism_monthly', 'tourism_quarterly', 'tourism_yearly']
# Long-term forecasting:
# ['etth1', 'etth2','ettm1','ettm2','traffic_ltsf', 'electricity_ltsf', 'exchange_ltsf', 'illness_ltsf', 'weather_ltsf']
# NOTE: when using long-term forecasting datasets, please explicitly assign context_length and prediction_length, e.g.:
# --data.data_manager.init_args.context_length 96 \
# --data.data_manager.init_args.prediction_length 192 \
# Run the pipeline with train and test.
# Replace ${MODEL} with the target model name, e.g., patchtst.
# Replace ${DATASET} with the dataset name.
# If the dataset path is not specified, the default path is ./datasets.
# To run on CPU, uncomment the last line and add a trailing `\` to the line before it.
python run.py --config config/ltsf/${DATASET}/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.split_val true \
--trainer.max_epochs 50 \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN}
# --trainer.accelerator=cpu --trainer.devices=1
================================================
FILE: scripts/prepare_datasets.sh
================================================
# Download the benchmark datasets into the given directory.
# Usage: bash scripts/prepare_datasets.sh <data_path>

# Fail early with a usage message instead of passing an empty --data_path.
if [ -z "$1" ]; then
    echo "Usage: $0 <data_path>" >&2
    exit 1
fi

# Check if gdown is installed
if pip show gdown > /dev/null 2>&1; then
    echo "gdown is already installed, skipping installation."
else
    echo "gdown is not installed, installing..."
    pip install gdown
fi

# Quote the path so directories containing spaces are passed as one argument.
python probts/utils/download_datasets.py --data_path "$1"
================================================
FILE: scripts/prepare_tsfm_checkpoints.sh
================================================
#!/bin/sh
# Download the pre-trained foundation-model checkpoints into ./checkpoints
# after the user has accepted the upstream licenses.
echo "NOTE! By downloading these checkpoints, you agree to the licenses of the original models and checkpoints."
echo ""
echo "- [Timer](https://github.com/thuml/Large-Time-Series-Model) created by thuml. The original model and its checkpoints are licensed under the MIT License. The checkpoints are distributed under the MIT License. You may not use these files except in compliance with the License. You may obtain a copy of the License at: https://github.com/thuml/Large-Time-Series-Model/blob/main/LICENSE."
echo "- [ForecastPFN](https://github.com/abacusai/ForecastPFN) created by abacusai. The original model and its checkpoints are licensed under the MIT License. The checkpoints are distributed under the Apache-2.0 License. You may not use these files except in compliance with the License. You may obtain a copy of the License at: https://github.com/abacusai/ForecastPFN/blob/main/LICENSE."
echo "- [UniTS](https://github.com/mims-harvard/UniTS) created by mims-harvard. The original model and its checkpoints are licensed under the MIT License. The checkpoints are distributed under the MIT License. You may not use these files except in compliance with the License. You may obtain a copy of the License at: https://github.com/mims-harvard/UniTS/blob/main/LICENSE."
echo "- [Lag-Llama](https://github.com/time-series-foundation-models/lag-llama) created by time-series-foundation-models. The original model and its checkpoints are licensed under the MIT License. The checkpoints are distributed under the Apache-2.0 License. You may not use these files except in compliance with the License. You may obtain a copy of the License at: https://github.com/time-series-foundation-models/lag-llama/blob/main/LICENSE."
echo ""
echo "NOTE! By downloading these checkpoints, you agree to the licenses of the original models and checkpoints."

# `read -p` is not defined by POSIX sh, so print the prompt with printf.
# -r keeps backslashes in the answer literal.
printf "Do you want to continue? (yes/y to continue): "
read -r confirm

# Convert input to lowercase for comparison
confirm=$(echo "$confirm" | tr '[:upper:]' '[:lower:]')

if [ "$confirm" = "yes" ] || [ "$confirm" = "y" ]; then
    # Check if gdown is installed
    if pip show gdown > /dev/null 2>&1; then
        echo "gdown is already installed, skipping installation."
    else
        echo "gdown is not installed, installing..."
        pip install gdown
    fi
    # Download the folder
    gdown --folder 1FaCk9Lj9KZGEO09gehNqC4fbTj4wnN8j -O checkpoints
else
    echo "Download canceled."
fi
================================================
FILE: scripts/reproduce_ltsf_results.sh
================================================
# Run the long-term forecasting sweep: every dataset x model x prediction
# length combination below is trained and tested with seed 0.
# Only GPU 0 is made visible to the runs.
export CUDA_VISIBLE_DEVICES=0
DATA_DIR=./datasets
LOG_DIR=./exps
# Sweep 1: context length 96, horizons 96/192/336/720.
CTX_LEN=96
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf' 'electricity_ltsf' 'exchange_ltsf' 'traffic_ltsf'
do
for MODEL in 'dlinear' 'patchtst' 'gru_nvp' 'timegrad' 'csdi'
do
for PRED_LEN in 96 192 336 720
do
python run.py --config config/ltsf/${DATASET}/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN}
done
done
done
# Sweep 2: illness_ltsf uses a shorter context (36) and horizons 24/36/48/60.
CTX_LEN=36
for DATASET in 'illness_ltsf'
do
for MODEL in 'dlinear' 'patchtst' 'gru_nvp' 'timegrad' 'csdi'
do
for PRED_LEN in 24 36 48 60
do
python run.py --config config/ltsf/${DATASET}/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN}
done
done
done
================================================
FILE: scripts/reproduce_stsf_results.sh
================================================
# Run the short-term forecasting sweep: every dataset x model combination
# below is run with seed 0.
# Only GPU 0 is made visible to the runs.
export CUDA_VISIBLE_DEVICES=0
DATA_DIR=./datasets
LOG_DIR=./exps
# No dataset name, context length or prediction length is passed on the
# command line here; presumably they come from the per-dataset YAML config
# (verify against config/stsf/).
for DATASET in 'solar' 'electricity' 'exchange' 'traffic' 'wiki'
do
for MODEL in 'dlinear' 'patchtst' 'gru_nvp' 'gru_maf' 'trans_maf' 'timegrad' 'csdi' 'timesnet'
do
python run.py --config config/stsf/${DATASET}/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true
done
done
================================================
FILE: scripts/reproduce_tsfm_results.sh
================================================
# Evaluate the time-series foundation models on the benchmark datasets.
# Each section sets MODEL and loops over datasets, context lengths and
# prediction lengths; all runs use seed 0.
# Only GPU 0 is made visible to the runs.
export CUDA_VISIBLE_DEVICES=0
DATA_DIR=./datasets
LOG_DIR=./exps
# MOIRAI
# The config file is selected per context length and dataset
# (config/tsfm/moirai/context_<CTX_LEN>/<DATASET>.yaml); context_length is not
# passed on the command line in this section.
MODEL='moirai'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf' 'electricity_ltsf'; do
for CTX_LEN in 5000 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}/context_${CTX_LEN}/${DATASET}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.prediction_length ${PRED_LEN}
done
done
done
# MOIRAI on the *_nips datasets: no prediction length is passed either.
for DATASET in 'exchange_rate_nips' 'solar_nips' 'electricity_nips'; do
for CTX_LEN in 5000 96; do
python run.py --config config/tsfm/${MODEL}/context_${CTX_LEN}/${DATASET}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.dataset ${DATASET}
done
done
# Chronos
# From here on every model uses a single config file (config/tsfm/<MODEL>.yaml)
# and receives dataset, context length and prediction length as CLI overrides.
MODEL='chronos'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do
for CTX_LEN in 5000 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--data.test_batch_size 1
done
done
done
# Chronos on the *_nips datasets: context 512 or 96, horizon 24.
for DATASET in 'exchange_rate_nips' 'traffic_nips'; do
for CTX_LEN in 512 96; do
for PRED_LEN in 24; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--data.test_batch_size 1
done
done
done
# Lag-Llama
# Expects the checkpoint at ./checkpoints/lag-llama/lag-llama.ckpt.
MODEL='lag_llama'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do
for CTX_LEN in 512; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--model.forecaster.init_args.ckpt_path './checkpoints/lag-llama/lag-llama.ckpt' \
--data.test_batch_size 1
done
done
done
# TimesFM
# No ckpt_path is passed for this model.
MODEL='timesfm'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2'; do
for CTX_LEN in 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--data.test_batch_size 64
done
done
done
# Timer
# Expects the checkpoint at ./checkpoints/timer/Timer_67M_UTSD_4G.pt.
MODEL='timer'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf' 'electricity_ltsf'; do
for CTX_LEN in 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--model.forecaster.init_args.ckpt_path './checkpoints/timer/Timer_67M_UTSD_4G.pt' \
--data.test_batch_size 64
done
done
done
# UniTS
# Expects the checkpoint at ./checkpoints/units/units_x128_pretrain_checkpoint.pth.
MODEL='units'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2'; do
for CTX_LEN in 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--model.forecaster.init_args.ckpt_path './checkpoints/units/units_x128_pretrain_checkpoint.pth' \
--data.test_batch_size 64
done
done
done
# ForecastPFN
# Expects the weights at ./checkpoints/ForecastPFN/saved_weights.
MODEL='forecastpfn'
for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do
for CTX_LEN in 96; do
for PRED_LEN in 24 48 96 192 336 720; do
python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${CTX_LEN} \
--data.data_manager.init_args.prediction_length ${PRED_LEN} \
--model.forecaster.init_args.ckpt_path './checkpoints/ForecastPFN/saved_weights' \
--data.test_batch_size 64
done
done
done
================================================
FILE: scripts/run_elastst.sh
================================================
# Train ElasTST with a single training horizon and test it on several
# horizons in one run.
# Placeholders: replace both paths before running.
DATA_DIR=/path/to/datasets
LOG_DIR=/path/to/log_dir
# for varied-horizon forecasting
TRAIN_CTX_LEN=96
VAL_CTX_LEN=96
TEST_CTX_LEN=96
TRAIN_PRED_LEN=720
VAL_PRED_LEN=720
# Several test horizons are given as one '-'-separated string.
TEST_PRED_LEN=24-48-96-192-336-720
DATASET='exchange_ltsf' # select from ['etth1', 'etth2', 'ettm1', 'ettm2', 'traffic_ltsf', 'electricity_ltsf', 'exchange_ltsf', 'weather_ltsf']
MODEL=elastst
# context_length / prediction_length receive the TEST_* values; the train and
# validation settings are passed through their own arguments.
python run.py --config config/multi_hor/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${TEST_CTX_LEN} \
--data.data_manager.init_args.prediction_length ${TEST_PRED_LEN} \
--data.data_manager.init_args.train_pred_len_list ${TRAIN_PRED_LEN} \
--data.data_manager.init_args.train_ctx_len ${TRAIN_CTX_LEN} \
--data.data_manager.init_args.val_ctx_len ${VAL_CTX_LEN} \
--data.data_manager.init_args.val_pred_len_list ${VAL_PRED_LEN} \
--trainer.max_epochs 50
================================================
FILE: scripts/run_varied_hor_training.sh
================================================
# Train ElasTST with varied training horizons and test it on several
# horizons in one run.
# Placeholders: replace both paths before running.
DATA_DIR=/path/to/datasets
LOG_DIR=/path/to/log_dir
# for varied-horizon forecasting
TRAIN_CTX_LEN=96
VAL_CTX_LEN=96
TEST_CTX_LEN=96
# Used together with continuous_sample below; presumably read as the range
# 1..720 of training horizons rather than the two values 1 and 720 (confirm in
# the data manager).
TRAIN_PRED_LEN=1-720
VAL_PRED_LEN=720
# Several test horizons are given as one '-'-separated string.
TEST_PRED_LEN=24-48-96-192-336-720
DATASET='exchange_ltsf' # select from ['etth1', 'etth2', 'ettm1', 'ettm2', 'traffic_ltsf', 'electricity_ltsf', 'exchange_ltsf', 'weather_ltsf']
MODEL=elastst
# context_length / prediction_length receive the TEST_* values; the train and
# validation settings are passed through their own arguments.
python run.py --config config/multi_hor/${MODEL}.yaml --seed_everything 0 \
--data.data_manager.init_args.path ${DATA_DIR} \
--trainer.default_root_dir ${LOG_DIR} \
--data.data_manager.init_args.split_val true \
--data.data_manager.init_args.dataset ${DATASET} \
--data.data_manager.init_args.context_length ${TEST_CTX_LEN} \
--data.data_manager.init_args.prediction_length ${TEST_PRED_LEN} \
--data.data_manager.init_args.train_pred_len_list ${TRAIN_PRED_LEN} \
--data.data_manager.init_args.train_ctx_len ${TRAIN_CTX_LEN} \
--data.data_manager.init_args.val_ctx_len ${VAL_CTX_LEN} \
--data.data_manager.init_args.val_pred_len_list ${VAL_PRED_LEN} \
--data.data_manager.init_args.continuous_sample true \
--trainer.max_epochs 50