Repository: amazon-science/unconditional-time-series-diffusion
Branch: main
Commit: 3eafeffdffef
Files: 126
Total size: 289.4 KB
Directory structure:
gitextract_qvnzuyyo/
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
├── THIRD-PARTY-LICENSES.txt
├── bin/
│ ├── guidance_experiment.py
│ ├── refinement_experiment.py
│ ├── train_cond_model.py
│ ├── train_model.py
│ └── tstr_experiment.py
├── configs/
│ ├── guidance/
│ │ ├── guidance_electricity.yaml
│ │ ├── guidance_exchange.yaml
│ │ ├── guidance_kdd_cup.yaml
│ │ ├── guidance_m4.yaml
│ │ ├── guidance_solar.yaml
│ │ ├── guidance_traffic.yaml
│ │ ├── guidance_uber_tlc.yaml
│ │ └── guidance_wiki.yaml
│ ├── guidance.yaml
│ ├── refinement/
│ │ ├── electricity_nips-deepar.yaml
│ │ ├── electricity_nips-linear.yaml
│ │ ├── electricity_nips-seasonal_naive.yaml
│ │ ├── electricity_nips-transformer.yaml
│ │ ├── exchange_rate_nips-deepar.yaml
│ │ ├── exchange_rate_nips-linear.yaml
│ │ ├── exchange_rate_nips-seasonal_naive.yaml
│ │ ├── exchange_rate_nips-transformer.yaml
│ │ ├── kdd_cup_2018_without_missing-deepar.yaml
│ │ ├── kdd_cup_2018_without_missing-linear.yaml
│ │ ├── kdd_cup_2018_without_missing-seasonal_naive.yaml
│ │ ├── kdd_cup_2018_without_missing-transformer.yaml
│ │ ├── m4_hourly-deepar.yaml
│ │ ├── m4_hourly-linear.yaml
│ │ ├── m4_hourly-seasonal_naive.yaml
│ │ ├── m4_hourly-transformer.yaml
│ │ ├── solar_nips-deepar.yaml
│ │ ├── solar_nips-linear.yaml
│ │ ├── solar_nips-seasonal_naive.yaml
│ │ ├── solar_nips-transformer.yaml
│ │ ├── traffic_nips-deepar.yaml
│ │ ├── traffic_nips-linear.yaml
│ │ ├── traffic_nips-seasonal_naive.yaml
│ │ ├── traffic_nips-transformer.yaml
│ │ ├── uber_tlc_hourly-deepar.yaml
│ │ ├── uber_tlc_hourly-linear.yaml
│ │ ├── uber_tlc_hourly-seasonal_naive.yaml
│ │ ├── uber_tlc_hourly-transformer.yaml
│ │ ├── wiki2000_nips-deepar.yaml
│ │ ├── wiki2000_nips-linear.yaml
│ │ ├── wiki2000_nips-seasonal_naive.yaml
│ │ └── wiki2000_nips-transformer.yaml
│ ├── refinement.yaml
│ ├── train_tsdiff/
│ │ ├── train_electricity.yaml
│ │ ├── train_exchange.yaml
│ │ ├── train_kdd_cup.yaml
│ │ ├── train_m4.yaml
│ │ ├── train_missing_electricity.yaml
│ │ ├── train_missing_exchange.yaml
│ │ ├── train_missing_kdd_cup.yaml
│ │ ├── train_missing_solar.yaml
│ │ ├── train_missing_traffic.yaml
│ │ ├── train_missing_uber_tlc.yaml
│ │ ├── train_solar.yaml
│ │ ├── train_traffic.yaml
│ │ ├── train_uber_tlc.yaml
│ │ └── train_wiki.yaml
│ ├── train_tsdiff-cond/
│ │ ├── electricity_nips.yaml
│ │ ├── exchange_rate_nips.yaml
│ │ ├── kdd_cup_2018_without_missing.yaml
│ │ ├── m4_hourly.yaml
│ │ ├── missing_BM-B_electricity_nips.yaml
│ │ ├── missing_BM-B_exchange_rate_nips.yaml
│ │ ├── missing_BM-B_kdd_cup_2018_without_missing.yaml
│ │ ├── missing_BM-B_solar_nips.yaml
│ │ ├── missing_BM-B_traffic_nips.yaml
│ │ ├── missing_BM-B_uber_tlc_hourly.yaml
│ │ ├── missing_BM-E_electricity_nips.yaml
│ │ ├── missing_BM-E_exchange_rate_nips.yaml
│ │ ├── missing_BM-E_kdd_cup_2018_without_missing.yaml
│ │ ├── missing_BM-E_solar_nips.yaml
│ │ ├── missing_BM-E_traffic_nips.yaml
│ │ ├── missing_BM-E_uber_tlc_hourly.yaml
│ │ ├── missing_RM_electricity_nips.yaml
│ │ ├── missing_RM_exchange_rate_nips.yaml
│ │ ├── missing_RM_kdd_cup_2018_without_missing.yaml
│ │ ├── missing_RM_solar_nips.yaml
│ │ ├── missing_RM_traffic_nips.yaml
│ │ ├── missing_RM_uber_tlc_hourly.yaml
│ │ ├── solar_nips.yaml
│ │ ├── traffic_nips.yaml
│ │ ├── uber_tlc_hourly.yaml
│ │ └── wiki2000_nips.yaml
│ ├── train_tsdiff-cond.yaml
│ ├── train_tsdiff.yaml
│ ├── tstr/
│ │ ├── electricity_nips.yaml
│ │ ├── exchange_rate_nips.yaml
│ │ ├── kdd_cup_2018_without_missing.yaml
│ │ ├── m4_hourly.yaml
│ │ ├── solar_nips.yaml
│ │ ├── traffic_nips.yaml
│ │ ├── uber_tlc_hourly.yaml
│ │ └── wiki2000_nips.yaml
│ └── tstr.yaml
├── pyproject.toml
└── src/
└── uncond_ts_diff/
├── arch/
│ ├── __init__.py
│ ├── backbones.py
│ └── s4.py
├── configs.py
├── dataset.py
├── metrics/
│ ├── __init__.py
│ └── linear_pred_score.py
├── model/
│ ├── __init__.py
│ ├── callback.py
│ ├── diffusion/
│ │ ├── _base.py
│ │ ├── tsdiff.py
│ │ └── tsdiff_cond.py
│ └── linear/
│ ├── _estimator.py
│ └── _scaler.py
├── predictor.py
├── sampler/
│ ├── __init__.py
│ ├── _base.py
│ ├── observation_guidance.py
│ └── refiner.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__
lightning_logs/
.DS_Store
*.egg-info
/results/
/ckpts/
/saved_samples/
.vscode/
/sm_runs/
/data/
/checkpoints/
================================================
FILE: CODE_OF_CONDUCT.md
================================================
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing Guidelines
Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
documentation, we greatly value feedback and contributions from our community.
Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
information to effectively respond to your bug report or contribution.
## Reporting Bugs/Feature Requests
We welcome you to use the GitHub issue tracker to report bugs or suggest features.
When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
* A reproducible test case or series of steps
* The version of our code being used
* Any modifications you've made relevant to the bug
* Anything unusual about your environment or deployment
## Contributing via Pull Requests
Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
1. You are working against the latest source on the *main* branch.
2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
To send us a pull request, please:
1. Fork the repository.
2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
3. Ensure local tests pass.
4. Commit to your fork using clear commit messages.
5. Send us a pull request, answering any default questions in the pull request interface.
6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
## Finding contributions to work on
Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.
## Security issue notifications
If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
## Licensing
See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
================================================
FILE: NOTICE
================================================
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
================================================
FILE: README.md
================================================
# TSDiff: An Unconditional Diffusion Model for Time Series
[](https://arxiv.org/abs/2307.11494)
[](https://opensource.org/licenses/Apache-2.0)
[](https://neurips.cc/)
Fig. 1: An overview of TSDiff’s use cases. Predict: By utilizing observation self-guidance, TSDiff can be
conditioned during inference to perform predictive tasks such as forecasting. Refine: Predictions
of base forecasters can be improved by leveraging the implicit probability density of TSDiff.
Synthesize: Realistic samples generated by TSDiff can be used to train downstream forecasters achieving good
performance on real test data.
---
This repository contains the official implementation of the NeurIPS 2023 paper [*Predict, Refine, Synthesize: Self-Guiding Diffusion Models for Probabilistic Time Series Forecasting*](https://arxiv.org/abs/2307.11494). In this paper, we propose *TSDiff*, an unconditional diffusion model for time series. Our proposed self-guidance mechanism enables conditioning TSDiff for downstream tasks during inference, without requiring auxiliary networks or altering the training procedure. Furthermore, our refinement scheme leverages the implicit density learned by the diffusion model to iteratively refine the predictions of base forecasters. Finally, we demonstrate the high quality of the synthetic time series by training downstream models solely on generated data and introducing the *Linear Predictive Score (LPS)*.
Fig. 2: Example forecasts generated by TSDiff-Q for
time series in Electricity, KDDCup, and Exchange — three datasets with different frequencies and/or prediction lengths.
## Installation
TSDiff requires Python 3.8 or higher.
* Create a conda environment (optional, but recommended).
```sh
conda create --name tsdiff --yes python=3.8 && conda activate tsdiff
```
* Install this package.
```sh
pip install --editable "."
```
> [!TIP]
> We have some updates in the `update` branch. If you're interested in testing out TSDiff or [training it on a custom dataset](https://github.com/amazon-science/unconditional-time-series-diffusion/issues/7), using the `update` branch may be faster for training.
## Usage
### Training Models
Train models using the `train_model.py` and `train_cond_model.py` scripts for `TSDiff` and `TSDiff-Cond`, respectively. Sample configurations can be found in `configs/train_tsdiff.yaml` and `configs/train_tsdiff-cond.yaml`. Specific configurations used in the paper can be found in `configs/train_tsdiff` and `configs/train_tsdiff-cond`.
Example commands for regular (i.e., no missing values) forecasting:
```sh
# Train TSDiff on the Uber dataset for regular forecasting
python bin/train_model.py -c configs/train_tsdiff/train_uber_tlc.yaml
# Train TSDiff on the M4 dataset for regular forecasting
python bin/train_model.py -c configs/train_tsdiff/train_m4.yaml
# Train TSDiff-Cond on the Uber dataset for regular forecasting
python bin/train_cond_model.py -c configs/train_tsdiff-cond/uber_tlc_hourly.yaml
# Train TSDiff-Cond on the M4 dataset for regular forecasting
python bin/train_cond_model.py -c configs/train_tsdiff-cond/m4_hourly.yaml
```
Example commands for forecasting with missing values:
```sh
# Train TSDiff on the Uber dataset for the missing values experiment
python bin/train_model.py -c configs/train_tsdiff/train_missing_uber_tlc.yaml
# Train TSDiff on the KDDCup dataset for the missing values experiment
python bin/train_model.py -c configs/train_tsdiff/train_missing_kdd_cup.yaml
# Train TSDiff-Cond on the Uber dataset for the RM missing values experiment
python bin/train_cond_model.py -c configs/train_tsdiff-cond/missing_RM_uber_tlc_hourly.yaml
# Train TSDiff-Cond on the KDDCup dataset for the BM-B missing values experiment
python bin/train_cond_model.py -c configs/train_tsdiff-cond/missing_BM-B_kdd_cup_2018_without_missing.yaml
# Train TSDiff-Cond on the KDDCup dataset for the BM-E missing values experiment
python bin/train_cond_model.py -c configs/train_tsdiff-cond/missing_BM-E_kdd_cup_2018_without_missing.yaml
```
Note that for TSDiff we train only one model and all the missing value scenarios are evaluated using the same unconditional model. However, for TSDiff-Cond, one model is trained per missingness scenario.
### Evaluating Models
The unconditional models trained above can be used for the following tasks.
#### Predict using Observation Self-Guidance
Use the `guidance_experiment.py` script and `configs/guidance.yaml` config to run the forecasting experiments. Specific configurations used in the paper can be found in `configs/guidance/`.
Example commands:
```sh
# Run observation self-guidance on the Solar dataset
python bin/guidance_experiment.py -c configs/guidance/guidance_solar.yaml --ckpt /path/to/ckpt
# Run observation self-guidance on the KDDCup dataset
python bin/guidance_experiment.py -c configs/guidance/guidance_kdd_cup.yaml --ckpt /path/to/ckpt
```
#### Refine Predictions of Base Forecasters
Use `refinement_experiment.py` script and `configs/refinement.yaml` config to run the refinement experiments. Specific configurations used in the paper can be found in `configs/refinement/`.
Example commands:
```sh
# Refine predictions from the Linear model on the Solar dataset
python bin/refinement_experiment.py -c configs/refinement/solar_nips-linear.yaml --ckpt /path/to/ckpt
# Refine predictions from the DeepAR model on the M4 dataset
python bin/refinement_experiment.py -c configs/refinement/m4_hourly-deepar.yaml --ckpt /path/to/ckpt
```
#### Train Downstream Models using Synthetic Data
Use `tstr_experiment.py` script and `configs/tstr.yaml` config to run the _train on synthetic-test on real_ experiments. Specific configurations used in the paper can be found in `configs/tstr/`.
Example commands:
```sh
# TSTR on the Solar Dataset
python bin/tstr_experiment.py -c configs/tstr/solar_nips.yaml --ckpt /path/to/ckpt
# TSTR on the KDDCup Dataset
python bin/tstr_experiment.py -c configs/tstr/kdd_cup_2018_without_missing.yaml --ckpt /path/to/ckpt
```
## BibTeX
If you find this repository or the ideas presented in our paper useful, please consider citing.
```
@inproceedings{kollovieh2023predict,
author = {Kollovieh, Marcel and Ansari, Abdul Fatir and Bohlke-Schneider, Michael and Zschiegner, Jasper and Wang, Hao and Wang, Yuyang},
title = {Predict, Refine, Synthesize: Self-Guiding Diffusion Models for Probabilistic Time Series Forecasting},
booktitle = {Advances in Neural Information Processing Systems},
year = {2023}
}
```
## Security
See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
## License
This project is licensed under the Apache-2.0 License.
================================================
FILE: THIRD-PARTY-LICENSES.txt
================================================
** state-spaces; version 1.0 -- https://github.com/HazyResearch/state-spaces
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and
distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright
owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities
that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or
indirect, to cause the direction or management of such entity, whether by
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising
permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including
but not limited to software source code, documentation source, and configuration
files.
"Object" form shall mean any form resulting from mechanical transformation or
translation of a Source form, including but not limited to compiled object code,
generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made
available under the License, as indicated by a copyright notice that is included
in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that
is based on (or derived from) the Work and for which the editorial revisions,
annotations, elaborations, or other modifications represent, as a whole, an
original work of authorship. For the purposes of this License, Derivative Works
shall not include works that remain separable from, or merely link (or bind by
name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version
of the Work and any modifications or additions to that Work or Derivative Works
thereof, that is intentionally submitted to Licensor for inclusion in the Work
by the copyright owner or by an individual or Legal Entity authorized to submit
on behalf of the copyright owner. For the purposes of this definition,
"submitted" means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems, and
issue tracking systems that are managed by, or on behalf of, the Licensor for
the purpose of discussing and improving the Work, but excluding communication
that is conspicuously marked or otherwise designated in writing by the copyright
owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
of whom a Contribution has been received by Licensor and subsequently
incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this
License, each Contributor hereby grants to You a perpetual, worldwide, non-
exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce,
prepare Derivative Works of, publicly display, publicly perform, sublicense, and
distribute the Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of this License,
each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-
charge, royalty-free, irrevocable (except as stated in this section) patent
license to make, have made, use, offer to sell, sell, import, and otherwise
transfer the Work, where such license applies only to those patent claims
licensable by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s) with the Work
to which such Contribution(s) was submitted. If You institute patent litigation
against any entity (including a cross-claim or counterclaim in a lawsuit)
alleging that the Work or a Contribution incorporated within the Work
constitutes direct or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate as of the date
such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the Work or
Derivative Works thereof in any medium, with or without modifications, and in
Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a
copy of this License; and
(b) You must cause any modified files to carry prominent notices stating
that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You
distribute, all copyright, patent, trademark, and attribution notices from the
Source form of the Work, excluding those notices that do not pertain to any part
of the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its distribution,
then any Derivative Works that You distribute must include a readable copy of
the attribution notices contained within such NOTICE file, excluding those
notices that do not pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed as part of the
Derivative Works; within the Source form or documentation, if provided along
with the Derivative Works; or, within a display generated by the Derivative
Works, if and wherever such third-party notices normally appear. The contents of
the NOTICE file are for informational purposes only and do not modify the
License. You may add Your own attribution notices within Derivative Works that
You distribute, alongside or as an addendum to the NOTICE text from the Work,
provided that such additional attribution notices cannot be construed as
modifying the License.
You may add Your own copyright statement to Your modifications and may
provide additional or different license terms and conditions for use,
reproduction, or distribution of Your modifications, or for any such Derivative
Works as a whole, provided Your use, reproduction, and distribution of the Work
otherwise complies with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any
Contribution intentionally submitted for inclusion in the Work by You to the
Licensor shall be under the terms and conditions of this License, without any
additional terms or conditions. Notwithstanding the above, nothing herein shall
supersede or modify the terms of any separate license agreement you may have
executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names,
trademarks, service marks, or product names of the Licensor, except as required
for reasonable and customary use in describing the origin of the Work and
reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in
writing, Licensor provides the Work (and each Contributor provides its
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied, including, without limitation, any warranties
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any risks
associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, whether in
tort (including negligence), contract, or otherwise, unless required by
applicable law (such as deliberate and grossly negligent acts) or agreed to in
writing, shall any Contributor be liable to You for damages, including any
direct, indirect, special, incidental, or consequential damages of any character
arising as a result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill, work stoppage,
computer failure or malfunction, or any and all other commercial damages or
losses), even if such Contributor has been advised of the possibility of such
damages.
9. Accepting Warranty or Additional Liability. While redistributing the Work or
Derivative Works thereof, You may choose to offer, and charge a fee for,
acceptance of support, warranty, indemnity, or other liability obligations
and/or rights consistent with this License. However, in accepting such
obligations, You may act only on Your own behalf and on Your sole
responsibility, not on behalf of any other Contributor, and only if You agree to
indemnify, defend, and hold each Contributor harmless for any liability incurred
by, or claims asserted against, such Contributor by reason of your accepting any
such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets "[]" replaced with your own
identifying information. (Don't include the brackets!) The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included on
the same "printed page" as the copyright notice for easier identification within
third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
* For state-spaces see also this required NOTICE:
Copyright 2022 Albert Gu and Karan Goel and Christopher Re
================================================
FILE: bin/guidance_experiment.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import argparse
from pathlib import Path
import yaml
import torch
from tqdm.auto import tqdm
from gluonts.dataset.field_names import FieldName
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from uncond_ts_diff.utils import (
create_transforms,
create_splitter,
get_next_file_num,
add_config_to_argparser,
filter_metrics,
MaskInput,
)
from uncond_ts_diff.model import TSDiff
from uncond_ts_diff.dataset import get_gts_dataset
from uncond_ts_diff.sampler import (
DDPMGuidance,
DDIMGuidance,
)
import uncond_ts_diff.configs as diffusion_configs
guidance_map = {"ddpm": DDPMGuidance, "ddim": DDIMGuidance}
def load_model(config):
    """Build a TSDiff model from ``config``, restore its checkpoint, and
    return it on the configured device."""
    # Fall back to the small architecture when none is specified.
    diffusion_kwargs = getattr(
        diffusion_configs,
        config.get("diffusion_config", "diffusion_small_config"),
    )
    model = TSDiff(
        **diffusion_kwargs,
        freq=config["freq"],
        use_features=config["use_features"],
        use_lags=config["use_lags"],
        normalization="mean",
        context_length=config["context_length"],
        prediction_length=config["prediction_length"],
        init_skip=config["init_skip"],
    )
    # Load weights onto CPU first, then move the model to the target device.
    state_dict = torch.load(config["ckpt"], map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    return model.to(config["device"])
def evaluate_guidance(
    config, model, test_dataset, transformation, num_samples=100
):
    """Evaluate guided sampling from ``model`` on ``test_dataset``.

    For the "forecasting" setup a single scenario without missing values
    is evaluated; for "missing_values" each entry of
    ``config["missing_data_configs"]`` is evaluated in turn.

    Returns a list of dicts, one per scenario, merging the missing-data
    settings with the filtered evaluation metrics.

    Raises
    ------
    ValueError
        If ``config["setup"]`` is neither "forecasting" nor
        "missing_values".
    """
    logger.info(f"Evaluating with {num_samples} samples.")
    results = []
    if config["setup"] == "forecasting":
        missing_data_kwargs_list = [
            {
                "missing_scenario": "none",
                "missing_values": 0,
            }
        ]
        # Written back into config so the scenario appears in saved results.
        config["missing_data_configs"] = missing_data_kwargs_list
    elif config["setup"] == "missing_values":
        missing_data_kwargs_list = config["missing_data_configs"]
    else:
        raise ValueError(f"Unknown setup {config['setup']}")
    Guidance = guidance_map[config["sampler"]]
    sampler_params = config["sampler_params"]
    for missing_data_kwargs in missing_data_kwargs_list:
        logger.info(
            f"Evaluating scenario '{missing_data_kwargs['missing_scenario']}' "
            f"with {missing_data_kwargs['missing_values']:.1f} missing_values."
        )
        # Guidance sampler (DDPM/DDIM) configured for this scenario.
        sampler = Guidance(
            model=model,
            prediction_length=config["prediction_length"],
            num_samples=num_samples,
            **missing_data_kwargs,
            **sampler_params,
        )
        transformed_testdata = transformation.apply(
            test_dataset, is_train=False
        )
        # The past window must cover the context plus the largest lag.
        test_splitter = create_splitter(
            past_length=config["context_length"] + max(model.lags_seq),
            future_length=config["prediction_length"],
            mode="test",
        )
        # Masks observed values in the context according to the scenario.
        masking_transform = MaskInput(
            FieldName.TARGET,
            FieldName.OBSERVED_VALUES,
            config["context_length"],
            missing_data_kwargs["missing_scenario"],
            missing_data_kwargs["missing_values"],
        )
        test_transform = test_splitter + masking_transform
        # Batch size shrinks with num_samples so the total number of
        # series sampled per batch stays around 1280.
        predictor = sampler.get_predictor(
            test_transform,
            batch_size=1280 // num_samples,
            device=config["device"],
        )
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=transformed_testdata,
            predictor=predictor,
            num_samples=num_samples,
        )
        forecasts = list(tqdm(forecast_it, total=len(transformed_testdata)))
        tss = list(ts_it)
        evaluator = Evaluator()
        metrics, _ = evaluator(tss, forecasts)
        metrics = filter_metrics(metrics)
        results.append(dict(**missing_data_kwargs, **metrics))
    return results
def main(config: dict, log_dir: str):
    """Run the guidance experiment described by ``config`` and write the
    resulting metrics to a numbered YAML file under ``log_dir``."""
    # Load the diffusion model and the dataset it was trained on.
    logger.info("Loading model")
    model = load_model(config)
    dataset = get_gts_dataset(config["dataset"])
    # The checkpoint must match the dataset's frequency and horizon.
    assert dataset.metadata.freq == config["freq"]
    assert dataset.metadata.prediction_length == config["prediction_length"]
    # Setup data transformation and loading
    transformation = create_transforms(
        num_feat_dynamic_real=0,
        num_feat_static_cat=0,
        num_feat_static_real=0,
        time_features=model.time_features,
        prediction_length=config["prediction_length"],
    )
    # Run guidance
    results = evaluate_guidance(
        config,
        model,
        dataset.test,
        transformation,
        num_samples=config["num_samples"],
    )
    # Save results under a fresh, numbered file name.
    out_dir = Path(log_dir) / "guidance_logs"
    out_dir.mkdir(exist_ok=True, parents=True)
    base_filename = "results"
    file_num = get_next_file_num(
        base_filename, out_dir, file_type="yaml", separator="-"
    )
    out_path = out_dir / f"{base_filename}-{file_num}.yaml"
    with open(out_path, "w") as fp:
        yaml.safe_dump(
            {"config": config, "metrics": results},
            fp,
            default_flow_style=False,
            sort_keys=False,
        )
if __name__ == "__main__":
# Setup Logger
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
# Setup argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"-c", "--config", type=str, required=True, help="Path to yaml config"
)
parser.add_argument(
"--out_dir", type=str, default="./results", help="Path to results dir"
)
args, _ = parser.parse_known_args()
with open(args.config, "r") as fp:
config = yaml.safe_load(fp)
# Update config from command line
parser = add_config_to_argparser(config=config, parser=parser)
args = parser.parse_args()
config_updates = vars(args)
for k in config.keys() & config_updates.keys():
orig_val = config[k]
updated_val = config_updates[k]
if updated_val != orig_val:
logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}")
config.update(config_updates)
main(config=config, log_dir=args.out_dir)
================================================
FILE: bin/refinement_experiment.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import copy
import logging
import argparse
from pathlib import Path
import yaml
import torch
import numpy as np
from tqdm.auto import tqdm
from gluonts.mx import DeepAREstimator, TransformerEstimator
from gluonts.model.seasonal_naive import SeasonalNaivePredictor
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.loader import TrainDataLoader
from gluonts.itertools import Cached
from gluonts.torch.batchify import batchify
from uncond_ts_diff.utils import (
create_transforms,
create_splitter,
get_next_file_num,
add_config_to_argparser,
filter_metrics,
)
from uncond_ts_diff.model import TSDiff, LinearEstimator
from uncond_ts_diff.dataset import get_gts_dataset
from uncond_ts_diff.sampler import (
MostLikelyRefiner,
MCMCRefiner,
DDPMGuidance,
DDIMGuidance,
)
import uncond_ts_diff.configs as diffusion_configs
guidance_map = {"ddpm": DDPMGuidance, "ddim": DDIMGuidance}
refiner_map = {"most_likely": MostLikelyRefiner, "mcmc": MCMCRefiner}
def load_model(config):
    """Instantiate TSDiff per ``config``, load the weights stored at
    ``config["ckpt"]``, and move the model to ``config["device"]``."""
    # Architecture defaults to the small config when unspecified.
    arch_name = config.get("diffusion_config", "diffusion_small_config")
    model = TSDiff(
        **getattr(diffusion_configs, arch_name),
        freq=config["freq"],
        use_features=config["use_features"],
        use_lags=config["use_lags"],
        normalization="mean",
        context_length=config["context_length"],
        prediction_length=config["prediction_length"],
        init_skip=config["init_skip"],
    )
    # Checkpoint is mapped to CPU before the device transfer below.
    weights = torch.load(config["ckpt"], map_location="cpu")
    model.load_state_dict(weights, strict=True)
    model = model.to(config["device"])
    return model
def get_best_diffusion_step(model: "TSDiff", data_loader, device):
    """Return the diffusion step whose loss is closest to the mean loss
    over all steps, estimated on a single batch from ``data_loader``.

    Parameters
    ----------
    model
        Diffusion model exposing ``timesteps``, ``_extract_features`` and
        ``p_losses``.
    data_loader
        Iterable of batches; only the first batch is used.
    device
        Device on which the loss is evaluated.

    Returns
    -------
    Index (numpy integer) of the selected diffusion step.
    """
    losses = np.zeros(model.timesteps)
    # Keep only tensor fields and move them to the target device.
    batch = {
        k: v.to(device)
        for k, v in next(iter(data_loader)).items()
        if isinstance(v, torch.Tensor)
    }
    # Only the extracted target sequence is needed here; the auxiliary
    # features and scale returned alongside it are unused.
    x, _, _ = model._extract_features(batch)
    # No gradients are required: only the scalar loss values are read,
    # so skip building autograd graphs for every timestep.
    with torch.no_grad():
        for t in range(model.timesteps):
            loss, _, _ = model.p_losses(
                x.to(device), torch.tensor([t], device=device)
            )
            losses[t] = loss
    # argmin of the squared deviation picks the step whose loss is
    # closest to the mean loss across all steps.
    best_t = ((losses - losses.mean()) ** 2).argmin()
    return best_t
def train_and_forecast_base_model(dataset, base_model_name, config):
    """Train (or construct) the requested base forecaster and produce
    forecasts for the test set.

    Returns the tuple ``(forecasts, time_series)`` produced by
    ``make_evaluation_predictions``, both fully materialized as lists.

    Raises
    ------
    ValueError
        If ``base_model_name`` is not one of "deepar", "transformer",
        "seasonal_naive" or "linear".
    """
    base_model_kwargs = config.get("base_model_params", {})
    if base_model_name == "deepar":
        predictor = DeepAREstimator(
            prediction_length=dataset.metadata.prediction_length,
            freq=dataset.metadata.freq,
            **base_model_kwargs,
        ).train(list(dataset.train), cache_data=True)
    elif base_model_name == "transformer":
        predictor = TransformerEstimator(
            prediction_length=dataset.metadata.prediction_length,
            freq=dataset.metadata.freq,
            **base_model_kwargs,
        ).train(list(dataset.train), cache_data=True)
    elif base_model_name == "seasonal_naive":
        # Seasonal naive has no parameters to train.
        predictor = SeasonalNaivePredictor(
            freq=dataset.metadata.freq,
            prediction_length=dataset.metadata.prediction_length,
            **base_model_kwargs,
        )
    elif base_model_name == "linear":
        num_train_samples = 10000
        predictor = LinearEstimator(
            freq=dataset.metadata.freq,
            prediction_length=dataset.metadata.prediction_length,
            context_length=config["context_length"],
            num_train_samples=num_train_samples,
            **base_model_kwargs,
        ).train(list(dataset.train), cache_data=True)
    else:
        raise ValueError(f"Unsupported base model {base_model_name}!")
    fcst_iter, ts_iter = make_evaluation_predictions(
        dataset=dataset.test,
        predictor=predictor,
        num_samples=config["num_samples"],
    )
    fcsts = list(tqdm(fcst_iter, total=len(dataset.test)))
    tss = list(ts_iter)
    return fcsts, tss
def forecast_guidance(
    dataset,
    base_model_name,
    config,
    diffusion_model,
    transformed_testdata,
    test_splitter,
):
    """Forecast the test set using a guidance sampler ("ddpm"/"ddim")
    built on top of the unconditional diffusion model.

    Returns the materialized ``(forecasts, time_series)`` lists from
    ``make_evaluation_predictions``.
    """
    assert len(dataset.test) == len(transformed_testdata)
    num_samples = config["num_samples"]
    guidance_cls = guidance_map[base_model_name]
    # Build the sampler, then wrap it as a GluonTS predictor.
    sampler = guidance_cls(
        model=diffusion_model,
        prediction_length=dataset.metadata.prediction_length,
        num_samples=num_samples,
        **config.get("base_model_params", {}),
    )
    predictor = sampler.get_predictor(
        input_transform=test_splitter,
        batch_size=1280 // num_samples,
        device=config["device"],
    )
    fcst_iter, ts_iter = make_evaluation_predictions(
        dataset=transformed_testdata,
        predictor=predictor,
        num_samples=num_samples,
    )
    forecasts = list(tqdm(fcst_iter, total=len(dataset.test)))
    series = list(ts_iter)
    return forecasts, series
def main(config: dict, log_dir: str):
    """Run the refinement experiment: obtain base-model forecasts,
    refine them with the diffusion model, and save all metrics as YAML
    under ``log_dir``/refinement_logs."""
    # Read global parameters
    dataset_name = config["dataset"]
    device = config["device"]
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]
    base_model_name = config["base_model"]
    num_samples = config["num_samples"]
    # Load dataset and model
    logger.info("Loading model")
    dataset = get_gts_dataset(dataset_name)
    # freq is taken from the dataset metadata, not the config file.
    config["freq"] = dataset.metadata.freq
    assert prediction_length == dataset.metadata.prediction_length
    model = load_model(config)
    # Setup data transformation and loading
    transformation = create_transforms(
        num_feat_dynamic_real=0,
        num_feat_static_cat=0,
        num_feat_static_real=0,
        time_features=model.time_features,
        prediction_length=prediction_length,
    )
    transformed_data = transformation.apply(list(dataset.train), is_train=True)
    transformed_testdata = transformation.apply(
        list(dataset.test), is_train=False
    )
    # The past window must cover the context plus the largest lag.
    training_splitter = create_splitter(
        past_length=context_length + max(model.lags_seq),
        future_length=prediction_length,
        mode="train",
    )
    test_splitter = create_splitter(
        past_length=context_length + max(model.lags_seq),
        future_length=prediction_length,
        mode="test",
    )
    train_dataloader = TrainDataLoader(
        Cached(transformed_data),
        batch_size=1024,
        stack_fn=batchify,
        transform=training_splitter,
        num_batches_per_epoch=2048,
    )
    # Diffusion step whose loss is closest to the mean loss; used as the
    # fixed step by the refiners below.
    best_t = get_best_diffusion_step(model, train_dataloader, device)
    # Train base model & get initial forecasts
    logger.info("Training base model")
    if base_model_name in {"ddpm", "ddim"}:
        # Guidance "base models" reuse the diffusion model itself.
        base_fcsts, tss = forecast_guidance(
            dataset,
            base_model_name,
            config,
            diffusion_model=model,
            transformed_testdata=transformed_testdata,
            test_splitter=test_splitter,
        )
    else:
        base_fcsts, tss = train_and_forecast_base_model(
            dataset, base_model_name, config
        )
    # Evaluate base forecasts
    evaluator = Evaluator()
    baseline_metrics, _ = evaluator(tss, base_fcsts)
    baseline_metrics = filter_metrics(baseline_metrics)
    # Run refinement
    log_dir = Path(log_dir) / "refinement_logs"
    log_dir.mkdir(exist_ok=True, parents=True)
    base_filename = "results"
    run_num = get_next_file_num(
        base_filename, log_dir, file_type="yaml", separator="-"
    )
    save_path = log_dir / f"{base_filename}-{run_num}.yaml"
    # Baseline metrics are stored as the first entry of the results list.
    results = [
        {
            "model": "baseline",
            "model_params": {
                "name": base_model_name,
                **config.get("base_model_params", {}),
            },
            **baseline_metrics,
        }
    ]
    n_refiner_configs = len(config["refiner_configs"])
    for i, ref_config in enumerate(config["refiner_configs"]):
        logger.info(
            f"Running refiner ({i+1}/{n_refiner_configs}): {json.dumps(ref_config)}"
        )
        # Deep-copied so popping "refiner_name" leaves ref_config intact
        # (ref_config is serialized into the results below).
        refiner_config = copy.deepcopy(ref_config)
        refiner_name = refiner_config.pop("refiner_name")
        Refiner = refiner_map[refiner_name]
        refiner = Refiner(
            model,
            prediction_length,
            init=iter(base_fcsts),
            num_samples=num_samples,
            fixed_t=best_t,
            iterations=config["iterations"],
            **refiner_config,
        )
        refiner_predictor = refiner.get_predictor(
            test_splitter, batch_size=1024 // num_samples, device=device
        )
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=transformed_testdata,
            predictor=refiner_predictor,
            num_samples=num_samples,
        )
        evaluator = Evaluator()
        refined_metrics, _ = evaluator(
            list(ts_it),
            list(tqdm(forecast_it, total=len(transformed_testdata))),
        )
        refined_metrics = filter_metrics(refined_metrics)
        results.append(
            {
                "model": refiner_name,
                "model_params": json.dumps(ref_config),
                **refined_metrics,
            }
        )
        # Results are rewritten after each refiner so partial progress
        # survives an interruption.
        with open(save_path, "w") as fp:
            yaml.safe_dump(
                {"config": config, "metrics": results},
                fp,
                default_flow_style=False,
                sort_keys=False,
            )
if __name__ == "__main__":
# Setup Logger
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
# Setup argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"-c", "--config", type=str, required=True, help="Path to yaml config"
)
parser.add_argument(
"--out_dir", type=str, default="./results", help="Path to results dir"
)
args, _ = parser.parse_known_args()
with open(args.config, "r") as fp:
config = yaml.safe_load(fp)
# Update config from command line
parser = add_config_to_argparser(config=config, parser=parser)
args = parser.parse_args()
config_updates = vars(args)
for k in config.keys() & config_updates.keys():
orig_val = config[k]
updated_val = config_updates[k]
if updated_val != orig_val:
logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}")
config.update(config_updates)
main(config=config, log_dir=args.out_dir)
================================================
FILE: bin/train_cond_model.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import argparse
from pathlib import Path
import yaml
import torch
from tqdm.auto import tqdm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from gluonts.dataset.loader import TrainDataLoader, ValidationDataLoader
from gluonts.dataset.split import OffsetSplitter
from gluonts.itertools import Cached
from gluonts.torch.batchify import batchify
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.field_names import FieldName
import uncond_ts_diff.configs as diffusion_configs
from uncond_ts_diff.dataset import get_gts_dataset
from uncond_ts_diff.model import TSDiffCond
from uncond_ts_diff.utils import (
create_transforms,
create_splitter,
add_config_to_argparser,
filter_metrics,
MaskInput,
ConcatDataset,
)
def create_model(config):
    """Construct a conditional diffusion model (TSDiffCond) from
    ``config`` and place it on the configured device."""
    diffusion_kwargs = getattr(diffusion_configs, config["diffusion_config"])
    model = TSDiffCond(
        **diffusion_kwargs,
        freq=config["freq"],
        use_features=config["use_features"],
        use_lags=config["use_lags"],
        normalization=config["normalization"],
        context_length=config["context_length"],
        prediction_length=config["prediction_length"],
        lr=config["lr"],
        init_skip=config["init_skip"],
        noise_observed=config["noise_observed"],
    )
    # nn.Module.to moves parameters in place.
    model.to(config["device"])
    return model
def evaluate_conditional(
    config,
    model: TSDiffCond,
    test_dataset,
    transformation,
    num_samples=100,
):
    """Evaluate the conditional model on ``test_dataset`` under the
    missing-data scenario specified in ``config``.

    Returns a single-element list containing the filtered metrics dict.
    """
    logger.info(f"Evaluating with {num_samples} samples.")
    logger.info(
        f"Evaluating scenario '{config['missing_scenario']}' "
        f"with {config['missing_values']:.1f} missing_values."
    )
    results = []
    transformed_testdata = transformation.apply(test_dataset, is_train=False)
    # The past window must cover the context plus the largest lag.
    test_splitter = create_splitter(
        past_length=config["context_length"] + max(model.lags_seq),
        future_length=config["prediction_length"],
        mode="test",
    )
    # Masks observed values in the context according to the scenario.
    masking_transform = MaskInput(
        FieldName.TARGET,
        FieldName.OBSERVED_VALUES,
        config["context_length"],
        config["missing_scenario"],
        config["missing_values"],
    )
    test_transform = test_splitter + masking_transform
    predictor = model.get_predictor(
        test_transform,
        batch_size=1280,
        device=config["device"],
    )
    forecast_it, ts_it = make_evaluation_predictions(
        dataset=transformed_testdata,
        predictor=predictor,
        num_samples=num_samples,
    )
    forecasts = list(tqdm(forecast_it, total=len(transformed_testdata)))
    tss = list(ts_it)
    evaluator = Evaluator()
    metrics, _ = evaluator(tss, forecasts)
    metrics = filter_metrics(metrics)
    results.append(dict(**metrics))
    return results
def main(config, log_dir):
    """Train the conditional diffusion model and evaluate the best
    checkpoint on the test set, writing results.yaml to the log dir."""
    # Load parameters
    dataset_name = config["dataset"]
    freq = config["freq"]
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]
    total_length = context_length + prediction_length
    # Create model
    model = create_model(config)
    # Setup dataset and data loading
    dataset = get_gts_dataset(dataset_name)
    assert dataset.metadata.freq == freq
    assert dataset.metadata.prediction_length == prediction_length
    if config["setup"] == "forecasting":
        training_data = dataset.train
    elif config["setup"] == "missing_values":
        # Hold out the last total_length points of each training series.
        missing_values_splitter = OffsetSplitter(offset=-total_length)
        training_data, _ = missing_values_splitter.split(dataset.train)
    # NOTE(review): any other "setup" value leaves training_data undefined
    # and fails later with a NameError — consider raising ValueError here.
    # Number of rolling evaluation windows contained in the test split.
    num_rolling_evals = int(len(dataset.test) / len(dataset.train))
    transformation = create_transforms(
        num_feat_dynamic_real=0,
        num_feat_static_cat=0,
        num_feat_static_real=0,
        time_features=model.time_features,
        prediction_length=config["prediction_length"],
    )
    training_splitter = create_splitter(
        past_length=config["context_length"] + max(model.lags_seq),
        future_length=config["prediction_length"],
        mode="train",
    )
    if config["setup"] == "forecasting":
        config["missing_scenario"] = "none"
        config["missing_values"] = 0
    masking_transform = MaskInput(
        FieldName.TARGET,
        FieldName.OBSERVED_VALUES,
        config["context_length"],
        # Training may use a different masking scenario than evaluation.
        config.get("train_missing_scenario", config["missing_scenario"]),
        config["missing_values"],
    )
    train_transform = training_splitter + masking_transform
    callbacks = []
    val_loader = None
    if config["use_validation_set"]:
        transformed_data = transformation.apply(training_data, is_train=True)
        # Carve the last num_rolling_evals horizons off the training data
        # to build a validation set.
        train_val_splitter = OffsetSplitter(
            offset=-config["prediction_length"] * num_rolling_evals
        )
        _, val_gen = train_val_splitter.split(training_data)
        val_dataset = ConcatDataset(
            val_gen.generate_instances(
                config["prediction_length"], num_rolling_evals
            )
        )
        val_splitter = create_splitter(
            past_length=config["context_length"] + max(model.lags_seq),
            future_length=config["prediction_length"],
            mode="val",
        )
        transformed_valdata = transformation.apply(val_dataset, is_train=True)
        val_loader = ValidationDataLoader(
            transformed_valdata,
            batch_size=1280,
            stack_fn=batchify,
            transform=val_splitter + masking_transform,
        )
        # NOTE(review): callbacks is already [] here; this reset is redundant.
        callbacks = []
        log_monitor = "valid_loss"
    else:
        transformed_data = transformation.apply(training_data, is_train=True)
        log_monitor = "train_loss"
    filename = dataset_name + "-{epoch:03d}-{train_loss:.3f}"
    data_loader = TrainDataLoader(
        Cached(transformed_data),
        batch_size=config["batch_size"],
        stack_fn=batchify,
        transform=train_transform,
        num_batches_per_epoch=config["num_batches_per_epoch"],
    )
    checkpoint_callback = ModelCheckpoint(
        save_top_k=3,
        monitor=f"{log_monitor}",
        mode="min",
        filename=filename,
        save_last=True,
        save_weights_only=True,
    )
    callbacks.append(checkpoint_callback)
    callbacks.append(RichProgressBar())
    trainer = pl.Trainer(
        # Fall back to CPU when CUDA is unavailable.
        accelerator="gpu" if torch.cuda.is_available() else None,
        devices=[int(config["device"].split(":")[-1])],
        max_epochs=config["max_epochs"],
        enable_progress_bar=True,
        num_sanity_val_steps=0,
        callbacks=callbacks,
        default_root_dir=log_dir,
        gradient_clip_val=config.get("gradient_clip_val", None),
        check_val_every_n_epoch=config["eval_every"],
    )
    logger.info(f"Logging to {trainer.logger.log_dir}")
    trainer.fit(
        model, train_dataloaders=data_loader, val_dataloaders=val_loader
    )
    logger.info("Training completed.")
    # Materialize the best checkpoint's weights under a fixed name so that
    # downstream scripts can find them.
    best_ckpt_path = Path(trainer.logger.log_dir) / "best_checkpoint.ckpt"
    if not best_ckpt_path.exists():
        torch.save(
            torch.load(checkpoint_callback.best_model_path)["state_dict"],
            best_ckpt_path,
        )
    logger.info(f"Loading {best_ckpt_path}.")
    best_state_dict = torch.load(best_ckpt_path)
    model.load_state_dict(best_state_dict, strict=True)
    metrics = (
        evaluate_conditional(config, model, dataset.test, transformation)
        if config.get("do_final_eval", True)
        else "Final eval not performed"
    )
    with open(Path(trainer.logger.log_dir) / "results.yaml", "w") as fp:
        yaml.dump(
            {
                "config": config,
                "version": trainer.logger.version,
                "metrics": metrics,
            },
            fp,
        )
if __name__ == "__main__":
# Setup Logger
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
# Setup argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"-c", "--config", type=str, required=True, help="Path to yaml config"
)
parser.add_argument(
"--out_dir", type=str, default="./", help="Path to results dir"
)
args, _ = parser.parse_known_args()
with open(args.config, "r") as fp:
config = yaml.safe_load(fp)
# Update config from command line
parser = add_config_to_argparser(config=config, parser=parser)
args = parser.parse_args()
config_updates = vars(args)
for k in config.keys() & config_updates.keys():
orig_val = config[k]
updated_val = config_updates[k]
if updated_val != orig_val:
logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}")
config.update(config_updates)
main(config=config, log_dir=args.out_dir)
================================================
FILE: bin/train_model.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import argparse
from pathlib import Path
import yaml
import torch
from tqdm.auto import tqdm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from gluonts.dataset.loader import TrainDataLoader
from gluonts.dataset.split import OffsetSplitter
from gluonts.itertools import Cached
from gluonts.torch.batchify import batchify
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.field_names import FieldName
import uncond_ts_diff.configs as diffusion_configs
from uncond_ts_diff.dataset import get_gts_dataset
from uncond_ts_diff.model.callback import EvaluateCallback
from uncond_ts_diff.model import TSDiff
from uncond_ts_diff.sampler import DDPMGuidance, DDIMGuidance
from uncond_ts_diff.utils import (
create_transforms,
create_splitter,
add_config_to_argparser,
filter_metrics,
MaskInput,
)
# Maps the sampler name used in config files to its guidance class.
guidance_map = {"ddpm": DDPMGuidance, "ddim": DDIMGuidance}
def create_model(config):
    """Build an unconditional TSDiff model as specified by ``config``
    and move it onto the configured device."""
    arch = config["diffusion_config"]
    model = TSDiff(
        **getattr(diffusion_configs, arch),
        freq=config["freq"],
        use_features=config["use_features"],
        use_lags=config["use_lags"],
        normalization=config["normalization"],
        context_length=config["context_length"],
        prediction_length=config["prediction_length"],
        lr=config["lr"],
        init_skip=config["init_skip"],
    )
    # nn.Module.to moves parameters in place.
    model.to(config["device"])
    return model
def evaluate_guidance(
    config, model, test_dataset, transformation, num_samples=100
):
    """Evaluate guided sampling from ``model`` on ``test_dataset``.

    For the "forecasting" setup a single scenario without missing values
    is evaluated; for "missing_values" each entry of
    ``config["missing_data_configs"]`` is evaluated in turn.

    Returns a list of dicts, one per scenario, merging the missing-data
    settings with the filtered evaluation metrics.

    Raises
    ------
    ValueError
        If ``config["setup"]`` is neither "forecasting" nor
        "missing_values".
    """
    logger.info(f"Evaluating with {num_samples} samples.")
    results = []
    if config["setup"] == "forecasting":
        missing_data_kwargs_list = [
            {
                "missing_scenario": "none",
                "missing_values": 0,
            }
        ]
        # Written back into config so the scenario appears in saved results.
        config["missing_data_configs"] = missing_data_kwargs_list
    elif config["setup"] == "missing_values":
        missing_data_kwargs_list = config["missing_data_configs"]
    else:
        raise ValueError(f"Unknown setup {config['setup']}")
    Guidance = guidance_map[config["sampler"]]
    sampler_kwargs = config["sampler_params"]
    for missing_data_kwargs in missing_data_kwargs_list:
        logger.info(
            f"Evaluating scenario '{missing_data_kwargs['missing_scenario']}' "
            f"with {missing_data_kwargs['missing_values']:.1f} missing_values."
        )
        # Guidance sampler (DDPM/DDIM) configured for this scenario.
        sampler = Guidance(
            model=model,
            prediction_length=config["prediction_length"],
            num_samples=num_samples,
            **missing_data_kwargs,
            **sampler_kwargs,
        )
        transformed_testdata = transformation.apply(
            test_dataset, is_train=False
        )
        # The past window must cover the context plus the largest lag.
        test_splitter = create_splitter(
            past_length=config["context_length"] + max(model.lags_seq),
            future_length=config["prediction_length"],
            mode="test",
        )
        # Masks observed values in the context according to the scenario.
        masking_transform = MaskInput(
            FieldName.TARGET,
            FieldName.OBSERVED_VALUES,
            config["context_length"],
            missing_data_kwargs["missing_scenario"],
            missing_data_kwargs["missing_values"],
        )
        test_transform = test_splitter + masking_transform
        # Batch size shrinks with num_samples so the total number of
        # series sampled per batch stays around 1280.
        predictor = sampler.get_predictor(
            test_transform,
            batch_size=1280 // num_samples,
            device=config["device"],
        )
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=transformed_testdata,
            predictor=predictor,
            num_samples=num_samples,
        )
        forecasts = list(tqdm(forecast_it, total=len(transformed_testdata)))
        tss = list(ts_it)
        evaluator = Evaluator()
        metrics, _ = evaluator(tss, forecasts)
        metrics = filter_metrics(metrics)
        results.append(dict(**missing_data_kwargs, **metrics))
    return results
def main(config, log_dir):
    """Train the unconditional TSDiff model and evaluate the best
    checkpoint with guidance, writing results.yaml to the log dir."""
    # Load parameters
    dataset_name = config["dataset"]
    freq = config["freq"]
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]
    total_length = context_length + prediction_length
    # Create model
    model = create_model(config)
    # Setup dataset and data loading
    dataset = get_gts_dataset(dataset_name)
    assert dataset.metadata.freq == freq
    assert dataset.metadata.prediction_length == prediction_length
    if config["setup"] == "forecasting":
        training_data = dataset.train
    elif config["setup"] == "missing_values":
        # Hold out the last total_length points of each training series.
        missing_values_splitter = OffsetSplitter(offset=-total_length)
        training_data, _ = missing_values_splitter.split(dataset.train)
    # NOTE(review): any other "setup" value leaves training_data undefined
    # and fails later with a NameError — consider raising ValueError here.
    # Number of rolling evaluation windows contained in the test split.
    num_rolling_evals = int(len(dataset.test) / len(dataset.train))
    transformation = create_transforms(
        num_feat_dynamic_real=0,
        num_feat_static_cat=0,
        num_feat_static_real=0,
        time_features=model.time_features,
        prediction_length=config["prediction_length"],
    )
    training_splitter = create_splitter(
        past_length=config["context_length"] + max(model.lags_seq),
        future_length=config["prediction_length"],
        mode="train",
    )
    callbacks = []
    if config["use_validation_set"]:
        transformed_data = transformation.apply(training_data, is_train=True)
        # Carve the last num_rolling_evals horizons off the training data
        # to build a validation set.
        train_val_splitter = OffsetSplitter(
            offset=-config["prediction_length"] * num_rolling_evals
        )
        _, val_gen = train_val_splitter.split(training_data)
        val_data = val_gen.generate_instances(
            config["prediction_length"], num_rolling_evals
        )
        # Validation/test evaluation runs periodically via a callback.
        callbacks = [
            EvaluateCallback(
                context_length=config["context_length"],
                prediction_length=config["prediction_length"],
                sampler=config["sampler"],
                sampler_kwargs=config["sampler_params"],
                num_samples=config["num_samples"],
                model=model,
                transformation=transformation,
                test_dataset=dataset.test,
                val_dataset=val_data,
                eval_every=config["eval_every"],
            )
        ]
    else:
        transformed_data = transformation.apply(training_data, is_train=True)
    log_monitor = "train_loss"
    filename = dataset_name + "-{epoch:03d}-{train_loss:.3f}"
    data_loader = TrainDataLoader(
        Cached(transformed_data),
        batch_size=config["batch_size"],
        stack_fn=batchify,
        transform=training_splitter,
        num_batches_per_epoch=config["num_batches_per_epoch"],
    )
    checkpoint_callback = ModelCheckpoint(
        save_top_k=3,
        monitor=f"{log_monitor}",
        mode="min",
        filename=filename,
        save_last=True,
        save_weights_only=True,
    )
    callbacks.append(checkpoint_callback)
    callbacks.append(RichProgressBar())
    trainer = pl.Trainer(
        # Fall back to CPU when CUDA is unavailable.
        accelerator="gpu" if torch.cuda.is_available() else None,
        devices=[int(config["device"].split(":")[-1])],
        max_epochs=config["max_epochs"],
        enable_progress_bar=True,
        num_sanity_val_steps=0,
        callbacks=callbacks,
        default_root_dir=log_dir,
        gradient_clip_val=config.get("gradient_clip_val", None),
    )
    logger.info(f"Logging to {trainer.logger.log_dir}")
    trainer.fit(model, train_dataloaders=data_loader)
    logger.info("Training completed.")
    # Materialize the best checkpoint's weights under a fixed name so that
    # downstream scripts can find them.
    best_ckpt_path = Path(trainer.logger.log_dir) / "best_checkpoint.ckpt"
    if not best_ckpt_path.exists():
        torch.save(
            torch.load(checkpoint_callback.best_model_path)["state_dict"],
            best_ckpt_path,
        )
    logger.info(f"Loading {best_ckpt_path}.")
    best_state_dict = torch.load(best_ckpt_path)
    model.load_state_dict(best_state_dict, strict=True)
    metrics = (
        evaluate_guidance(config, model, dataset.test, transformation)
        if config.get("do_final_eval", True)
        else "Final eval not performed"
    )
    with open(Path(trainer.logger.log_dir) / "results.yaml", "w") as fp:
        yaml.dump(
            {
                "config": config,
                "version": trainer.logger.version,
                "metrics": metrics,
            },
            fp,
        )
if __name__ == "__main__":
# Setup Logger
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
# Setup argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"-c", "--config", type=str, required=True, help="Path to yaml config"
)
parser.add_argument(
"--out_dir", type=str, default="./", help="Path to results dir"
)
args, _ = parser.parse_known_args()
with open(args.config, "r") as fp:
config = yaml.safe_load(fp)
# Update config from command line
parser = add_config_to_argparser(config=config, parser=parser)
args = parser.parse_args()
config_updates = vars(args)
for k in config.keys() & config_updates.keys():
orig_val = config[k]
updated_val = config_updates[k]
if updated_val != orig_val:
logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}")
config.update(config_updates)
main(config=config, log_dir=args.out_dir)
================================================
FILE: bin/tstr_experiment.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
import math
import logging
import argparse
from pathlib import Path
import yaml
import torch
import numpy as np
from tqdm.auto import tqdm
from gluonts.mx import DeepAREstimator, TransformerEstimator
from gluonts.evaluation import Evaluator
from gluonts.dataset.loader import TrainDataLoader
from gluonts.itertools import Cached
from gluonts.torch.batchify import batchify
from gluonts.time_feature import (
get_lags_for_frequency,
time_features_from_frequency_str,
)
from gluonts.dataset.split import slice_data_entry
from gluonts.transform import AdhocTransform, Chain
from uncond_ts_diff.utils import (
ScaleAndAddMeanFeature,
ScaleAndAddMinMaxFeature,
GluonTSNumpyDataset,
create_transforms,
create_splitter,
get_next_file_num,
add_config_to_argparser,
make_evaluation_predictions_with_scaling,
filter_metrics,
)
from uncond_ts_diff.model import TSDiff, LinearEstimator
from uncond_ts_diff.dataset import get_gts_dataset
import uncond_ts_diff.configs as diffusion_configs
# Downstream forecasters trained during the TSTR experiment.
DOWNSTREAM_MODELS = ["linear", "deepar", "transformer"]
def load_model(config):
    """Build a TSDiff model from ``config``, restore the checkpoint at
    ``config["ckpt"]``, and return the model on ``config["device"]``."""
    model = TSDiff(
        **getattr(
            diffusion_configs,
            # Fall back to the small architecture when unspecified.
            config.get("diffusion_config", "diffusion_small_config"),
        ),
        freq=config["freq"],
        use_features=config["use_features"],
        use_lags=config["use_lags"],
        normalization="mean",
        context_length=config["context_length"],
        prediction_length=config["prediction_length"],
        init_skip=config["init_skip"],
    )
    # Weights are mapped to CPU first, then moved to the target device.
    model.load_state_dict(
        torch.load(config["ckpt"], map_location="cpu"),
        strict=True,
    )
    model = model.to(config["device"])
    return model
def sample_synthetic(
    model: "TSDiff",
    num_samples: int = 10_000,
    batch_size: int = 1000,
):
    """Draw ``num_samples`` synthetic series from the diffusion model.

    Sampling is done in batches of at most ``batch_size``.  The final
    batch requests only the remaining number of samples instead of a
    full batch that would then be truncated — this avoids wasted
    (expensive) diffusion sampling when ``num_samples`` is not a
    multiple of ``batch_size``.

    Returns a numpy array of shape ``(num_samples, ...)``.
    """
    synth_samples = []
    n_iters = math.ceil(num_samples / batch_size)
    for i in tqdm(range(n_iters)):
        # The last batch may need fewer than batch_size samples.
        remaining = num_samples - i * batch_size
        samples = model.sample_n(num_samples=min(batch_size, remaining))
        synth_samples.append(samples)
    # The slice is now a no-op safeguard: exactly num_samples were drawn.
    synth_samples = np.concatenate(synth_samples, axis=0)[:num_samples]
    return synth_samples
def sample_real(
    data_loader,
    n_timesteps: int,
    num_samples: int = 10_000,
    batch_size: int = 1000,
):
    """Collect ``num_samples`` real series of length ``n_timesteps`` by
    concatenating past and future targets from ``data_loader`` batches,
    restarting the loader whenever it is exhausted."""
    collected = []
    batch_iter = iter(data_loader)
    n_iters = math.ceil(num_samples / batch_size)
    for _ in tqdm(range(n_iters)):
        try:
            batch = next(batch_iter)
        except StopIteration:
            # Loader exhausted: restart from the beginning.
            batch_iter = iter(data_loader)
            batch = next(batch_iter)
        # Join context and horizon, then keep the trailing n_timesteps.
        joined = np.concatenate(
            [batch["past_target"], batch["future_target"]], axis=-1
        )
        collected.append(joined[:, -n_timesteps:])
    return np.concatenate(collected, axis=0)[:num_samples]
def evaluate_tstr(
    tstr_predictor,
    test_dataset,
    context_length,
    prediction_length,
    num_samples=100,
    scaling_type="mean",
):
    """Evaluate a TSTR (train-on-synthetic, test-on-real) predictor.

    The test set is sliced to the last ``context_length +
    prediction_length`` timesteps, a per-series scale feature is attached
    (mean or min-max, per ``scaling_type``), forecasts are generated with
    scaling applied, and GluonTS ``Evaluator`` metrics are returned
    (filtered via ``filter_metrics``).

    Raises:
        ValueError: if ``scaling_type`` is neither "mean" nor "min-max".
    """
    total_length = context_length + prediction_length
    # Slice test set to be of the same length as context_length + prediction_length
    slice_func = partial(slice_data_entry, slice_=slice(-total_length, None))
    if scaling_type == "mean":
        ScaleAndAddScaleFeature = ScaleAndAddMeanFeature
    elif scaling_type == "min-max":
        ScaleAndAddScaleFeature = ScaleAndAddMinMaxFeature
    else:
        # Fail fast: the original fall-through left ScaleAndAddScaleFeature
        # unbound and produced a confusing NameError below.
        raise ValueError(f"Unknown scaling_type: {scaling_type}")
    transformation = Chain(
        [
            AdhocTransform(slice_func),
            # Add scale to data entry for use later during evaluation
            ScaleAndAddScaleFeature("target", "scale", prediction_length),
        ]
    )
    sliced_test_set = transformation.apply(test_dataset)

    fcst_iter, ts_iter = make_evaluation_predictions_with_scaling(
        dataset=sliced_test_set,
        predictor=tstr_predictor,
        num_samples=num_samples,
        scaling_type=scaling_type,
    )
    evaluator = Evaluator()
    metrics, _ = evaluator(list(ts_iter), list(fcst_iter))
    return filter_metrics(metrics)
def train_and_evaluate(
    dataset,
    model_name,
    synth_samples,
    real_samples,
    config,
    scaling_type="mean",
):
    """Train a downstream forecaster on synthetic samples and evaluate it
    on the real test set (the TSTR protocol).

    Args:
        dataset: GluonTS dataset providing ``metadata`` and ``test``.
        model_name: One of "linear", "deepar", or "transformer".
        synth_samples: Synthetic series, shape (N, context + prediction).
        real_samples: Real series with the same shape (used for sanity
            checks and to size the linear model's training set).
        config: Experiment config with ``context_length`` and
            ``prediction_length``.
        scaling_type: "mean" or "min-max", forwarded to evaluation.

    Returns:
        Dict with a single key ``tstr_metrics``.

    Raises:
        ValueError: if ``model_name`` is not a known downstream model.
    """
    # NOTE: There's no notion of time for synthetic time series,
    # they are just "sequences".
    # A dummy timestamp is used for start time in synthetic time series.
    # Hence, time_features are set to [] in the models below.
    model_name = model_name.lower()
    freq = dataset.metadata.freq
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]
    total_length = context_length + prediction_length

    assert len(synth_samples) == len(real_samples)
    assert (
        synth_samples.shape[-1] == total_length
        and real_samples.shape[-1] == total_length
    )
    num_samples = len(real_samples)

    synthetic_dataset = GluonTSNumpyDataset(synth_samples)

    if model_name == "linear":
        logger.info(f"Running TSTR for {model_name}")
        tstr_predictor = LinearEstimator(
            freq=freq,  # Not actually used in the estimator
            prediction_length=prediction_length,
            context_length=context_length,
            num_train_samples=num_samples,
            # Synthetic dataset is in the "scaled space"
            scaling=False,
        ).train(synthetic_dataset)
    elif model_name == "deepar":
        logger.info(f"Running TSTR for {model_name}")
        tstr_predictor = DeepAREstimator(
            freq=freq,
            prediction_length=prediction_length,
            # Synthetic dataset is in the "scaled space"
            scaling=False,
            time_features=[],
            lags_seq=get_lags_for_frequency(freq, lag_ub=context_length),
        ).train(synthetic_dataset)
    elif model_name == "transformer":
        logger.info(f"Running TSTR for {model_name}")
        tstr_predictor = TransformerEstimator(
            freq=freq,
            prediction_length=prediction_length,
            # Synthetic dataset is in the "scaled space"
            scaling=False,
            time_features=[],
            lags_seq=get_lags_for_frequency(freq, lag_ub=context_length),
        ).train(synthetic_dataset)
    else:
        # Fail fast: the original fall-through left tstr_predictor unbound
        # and raised a confusing NameError below.
        raise ValueError(f"Unknown downstream model: {model_name}")

    tstr_metrics = evaluate_tstr(
        tstr_predictor=tstr_predictor,
        test_dataset=dataset.test,
        context_length=context_length,
        prediction_length=prediction_length,
        scaling_type=scaling_type,
    )
    return dict(
        tstr_metrics=tstr_metrics,
    )
def main(config: dict, log_dir: str, samples_path: str):
    """Run the full TSTR experiment.

    Loads the dataset and diffusion model, gathers 10k real training
    windows, obtains 10k synthetic windows (either by sampling the model
    or by loading ``samples_path``), trains/evaluates each downstream
    model in ``DOWNSTREAM_MODELS``, and writes ``results.yaml`` plus the
    sample arrays into a fresh ``tstr_log-<n>`` directory under
    ``log_dir``.
    """
    # Read global parameters
    dataset_name = config["dataset"]
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]

    # Create log_dir
    log_dir: Path = Path(log_dir)
    base_dirname = "tstr_log"
    run_num = get_next_file_num(
        base_dirname, log_dir, file_type="", separator="-"
    )
    log_dir = log_dir / f"{base_dirname}-{run_num}"
    log_dir.mkdir(exist_ok=True, parents=True)
    logger.info(f"Logging to {log_dir}")

    # Load dataset and model
    logger.info("Loading model")
    dataset = get_gts_dataset(dataset_name)
    config["freq"] = dataset.metadata.freq
    assert prediction_length == dataset.metadata.prediction_length
    model = load_model(config)

    # Setup data transformation and loading
    transformation = create_transforms(
        num_feat_dynamic_real=0,
        num_feat_static_cat=0,
        num_feat_static_real=0,
        time_features=time_features_from_frequency_str(config["freq"]),
        prediction_length=prediction_length,
    )
    transformed_data = transformation.apply(list(dataset.train), is_train=True)
    training_splitter = create_splitter(
        # Extra history is needed to compute the model's lag features.
        past_length=context_length + max(model.lags_seq),
        future_length=prediction_length,
        mode="train",
    )
    train_dataloader = TrainDataLoader(
        Cached(transformed_data),
        batch_size=1000,
        stack_fn=batchify,
        transform=training_splitter,
    )

    # Generate real samples
    logger.info("Generating real samples")
    real_samples = sample_real(
        train_dataloader,
        n_timesteps=context_length + prediction_length,
        num_samples=10000,
    )
    np.save(log_dir / "real_samples.npy", real_samples)

    if samples_path is None:
        # Generate synthetic samples
        logger.info("Generating synthetic samples")
        synth_samples = sample_synthetic(model, num_samples=10000)
        np.save(log_dir / "synth_samples.npy", synth_samples)
    else:
        logger.info(f"Using synthetic samples from {samples_path}")
        synth_samples = np.load(samples_path)[:10000]
        # Use -1 for the leading dim so a file with fewer than 10000
        # samples does not crash the reshape (the original hard-coded
        # 10000 and failed on smaller sample files).
        synth_samples = synth_samples.reshape(
            (-1, context_length + prediction_length)
        )

    # Run TSTR experiment for each downstream model
    results = []
    for model_name in DOWNSTREAM_MODELS:
        logger.info(f"Training and evaluating {model_name}")
        metrics = train_and_evaluate(
            dataset=dataset,
            model_name=model_name,
            synth_samples=synth_samples,
            real_samples=real_samples,
            config=config,
            scaling_type=config["scaling_type"],
        )
        results.append({"model": model_name, **metrics})

    logger.info("Saving results")
    with open(log_dir / "results.yaml", "w") as fp:
        yaml.safe_dump(
            {"config": config, "metrics": results},
            fp,
            default_flow_style=False,
            sort_keys=False,
        )
if __name__ == "__main__":
    # Setup Logger
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    logger = logging.getLogger(__file__)
    logger.setLevel(logging.INFO)

    # Setup argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, required=True, help="Path to yaml config"
    )
    parser.add_argument(
        "--out_dir", type=str, default="./results", help="Path to results dir"
    )
    parser.add_argument(
        "--samples_path", type=str, help="Path to generated samples"
    )
    # parse_known_args: config-derived flags are only registered below,
    # so unknown arguments must be tolerated on this first pass.
    args, _ = parser.parse_known_args()

    with open(args.config, "r") as fp:
        config = yaml.safe_load(fp)

    # Update config from command line: every key in the yaml config is
    # exposed as a CLI flag, then re-parsed so flags override the file.
    parser = add_config_to_argparser(config=config, parser=parser)
    args = parser.parse_args()
    config_updates = vars(args)
    # Log only the keys whose values actually changed.
    for k in config.keys() & config_updates.keys():
        orig_val = config[k]
        updated_val = config_updates[k]
        if updated_val != orig_val:
            logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}")
    config.update(config_updates)

    main(config=config, log_dir=args.out_dir, samples_path=args.samples_path)
================================================
FILE: configs/guidance/guidance_electricity.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: false
num_samples: 100
prediction_length: 24
sampler: ddpm
sampler_params:
guidance: quantile
scale: 4
setup: forecasting
use_features: false
use_lags: true
================================================
FILE: configs/guidance/guidance_exchange.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: B
init_skip: true
num_samples: 100
prediction_length: 30
sampler: ddpm
sampler_params:
guidance: quantile
scale: 8
setup: forecasting
use_features: false
use_lags: true
================================================
FILE: configs/guidance/guidance_kdd_cup.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: true
num_samples: 100
prediction_length: 48
sampler: ddpm
sampler_params:
guidance: quantile
scale: 1
setup: forecasting
use_features: false
use_lags: true
================================================
FILE: configs/guidance/guidance_m4.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: false
num_samples: 100
prediction_length: 48
sampler: ddpm
sampler_params:
guidance: quantile
scale: 2
setup: forecasting
use_features: false
use_lags: false
================================================
FILE: configs/guidance/guidance_solar.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: false
num_samples: 100
prediction_length: 24
sampler: ddpm
sampler_params:
guidance: quantile
scale: 8
setup: forecasting
use_features: false
use_lags: true
================================================
FILE: configs/guidance/guidance_traffic.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: true
num_samples: 100
prediction_length: 24
sampler: ddpm
sampler_params:
guidance: quantile
scale: 4
setup: forecasting
use_features: false
use_lags: true
================================================
FILE: configs/guidance/guidance_uber_tlc.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: false
num_samples: 100
prediction_length: 24
sampler: ddpm
sampler_params:
guidance: quantile
scale: 2
setup: forecasting
use_features: false
use_lags: true
================================================
FILE: configs/guidance/guidance_wiki.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: 1D
init_skip: false
num_samples: 100
prediction_length: 30
sampler: ddpm
sampler_params:
guidance: quantile
scale: 2
setup: forecasting
use_features: false
use_lags: false
================================================
FILE: configs/guidance.yaml
================================================
# Model & checkpoint parameters
dataset: solar_nips
freq: H
device: cuda:0
ckpt: ckpts/forecasting/solar_nips/649_.ckpt
diffusion_config: diffusion_small_config
context_length: 336
prediction_length: 24
use_lags: true
use_features: false
init_skip: false
sampler: ddpm
sampler_params:
guidance: quantile
scale: 4
num_samples: 100
setup: forecasting
# The following key is ignored
# when the setup is forecasting
missing_data_configs:
- missing_scenario: BM-B
missing_values: 168
- missing_scenario: BM-E
missing_values: 168
================================================
FILE: configs/refinement/electricity_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/electricity_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/electricity_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/electricity_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/exchange_rate_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
lr: 0.01
refiner_name: most_likely
- guidance: quantile
lr: 0.01
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.01
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.01
use_features: false
use_lags: true
================================================
FILE: configs/refinement/exchange_rate_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
lr: 0.01
refiner_name: most_likely
- guidance: quantile
lr: 0.01
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.01
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.01
use_features: false
use_lags: true
================================================
FILE: configs/refinement/exchange_rate_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
lr: 0.01
refiner_name: most_likely
- guidance: quantile
lr: 0.01
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.01
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.01
use_features: false
use_lags: true
================================================
FILE: configs/refinement/exchange_rate_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
lr: 0.01
refiner_name: most_likely
- guidance: quantile
lr: 0.01
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.01
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.01
use_features: false
use_lags: true
================================================
FILE: configs/refinement/kdd_cup_2018_without_missing-deepar.yaml
================================================
base_model: deepar
base_model_params: {}
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/kdd_cup_2018_without_missing-linear.yaml
================================================
base_model: linear
base_model_params: {}
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/kdd_cup_2018_without_missing-seasonal_naive.yaml
================================================
base_model: seasonal_naive
base_model_params: {}
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/kdd_cup_2018_without_missing-transformer.yaml
================================================
base_model: transformer
base_model_params: {}
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/m4_hourly-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: false
================================================
FILE: configs/refinement/m4_hourly-linear.yaml
================================================
base_model: linear
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: false
================================================
FILE: configs/refinement/m4_hourly-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: false
================================================
FILE: configs/refinement/m4_hourly-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: false
================================================
FILE: configs/refinement/solar_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/solar_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/solar_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/solar_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/traffic_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/traffic_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/traffic_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/traffic_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/uber_tlc_hourly-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/uber_tlc_hourly-linear.yaml
================================================
base_model: linear
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/uber_tlc_hourly-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/uber_tlc_hourly-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: true
================================================
FILE: configs/refinement/wiki2000_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: false
================================================
FILE: configs/refinement/wiki2000_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: false
================================================
FILE: configs/refinement/wiki2000_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: false
================================================
FILE: configs/refinement/wiki2000_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
lr: 0.1
refiner_name: most_likely
- guidance: quantile
lr: 0.1
refiner_name: most_likely
- guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
- guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
refiner_name: mcmc
step_size: 0.1
use_features: false
use_lags: false
================================================
FILE: configs/refinement.yaml
================================================
# Model & checkpoint parameters
dataset: solar_nips
device: cuda:0
ckpt: ckpts/forecasting/solar_nips/649_.ckpt
diffusion_config: diffusion_small_config
context_length: 336
prediction_length: 24
use_lags: true
use_features: false
init_skip: false
# Refinement parameters
base_model: linear
base_model_params: {}
num_samples: 16
iterations: 10
refiner_configs:
- refiner_name: most_likely
lr: 1.e-1
guidance: MSE
- refiner_name: most_likely
lr: 1.e-1
guidance: quantile
- refiner_name: mcmc
step_size: 1.e-1
guidance: MSE
method: lmc
method_kwargs:
noise_scale: 0.1
- refiner_name: mcmc
step_size: 1.e-1
guidance: quantile
method: lmc
method_kwargs:
noise_scale: 0.1
================================================
FILE: configs/train_tsdiff/train_electricity.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: electricity_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
guidance: quantile
scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting
================================================
FILE: configs/train_tsdiff/train_exchange.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: exchange_rate_nips
freq: B
context_length: 360 # 360 for `D`
prediction_length: 30 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
guidance: quantile
scale: 8
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting
================================================
FILE: configs/train_tsdiff/train_kdd_cup.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: kdd_cup_2018_without_missing
freq: H
context_length: 312 # 360 for `D`
prediction_length: 48 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
guidance: quantile
scale: 1
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting
================================================
FILE: configs/train_tsdiff/train_m4.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: False
dataset: m4_hourly
freq: H
context_length: 312 # 360 for `D`
prediction_length: 48 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
guidance: quantile
scale: 2
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting
================================================
FILE: configs/train_tsdiff/train_missing_electricity.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: electricity_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
guidance: quantile
scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
missing_values: 0
- missing_scenario: BM-E
missing_values: 168
- missing_scenario: BM-B
missing_values: 168
- missing_scenario: RM
missing_values: 168
================================================
FILE: configs/train_tsdiff/train_missing_exchange.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: exchange_rate_nips
freq: B
context_length: 360 # 360 for `D`
prediction_length: 30 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
guidance: quantile
scale: 8
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
missing_values: 0
- missing_scenario: BM-E
missing_values: 180
- missing_scenario: BM-B
missing_values: 180
- missing_scenario: RM
missing_values: 180
================================================
FILE: configs/train_tsdiff/train_missing_kdd_cup.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: kdd_cup_2018_without_missing
freq: H
context_length: 312 # 360 for `D`
prediction_length: 48 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
guidance: quantile
scale: 1
use_validation_set: True
do_final_eval: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
missing_values: 0
- missing_scenario: BM-E
missing_values: 156
- missing_scenario: BM-B
missing_values: 156
- missing_scenario: RM
missing_values: 156
================================================
FILE: configs/train_tsdiff/train_missing_solar.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: solar_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
guidance: quantile
scale: 8
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
missing_values: 0
- missing_scenario: BM-E
missing_values: 168
- missing_scenario: BM-B
missing_values: 168
- missing_scenario: RM
missing_values: 168
================================================
FILE: configs/train_tsdiff/train_missing_traffic.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: traffic_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 4
sampler: ddpm
sampler_params:
guidance: quantile
scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
missing_values: 0
- missing_scenario: BM-E
missing_values: 168
- missing_scenario: BM-B
missing_values: 168
- missing_scenario: RM
missing_values: 168
================================================
FILE: configs/train_tsdiff/train_missing_uber_tlc.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: uber_tlc_hourly
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
guidance: quantile
scale: 2
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
missing_values: 0
- missing_scenario: BM-E
missing_values: 168
- missing_scenario: BM-B
missing_values: 168
- missing_scenario: RM
missing_values: 168
================================================
FILE: configs/train_tsdiff/train_solar.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: solar_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
guidance: quantile
scale: 8
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting
================================================
FILE: configs/train_tsdiff/train_traffic.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: traffic_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 4
sampler: ddpm
sampler_params:
guidance: quantile
scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting
================================================
FILE: configs/train_tsdiff/train_uber_tlc.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: uber_tlc_hourly
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
guidance: quantile
scale: 2
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting
================================================
FILE: configs/train_tsdiff/train_wiki.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: False
dataset: wiki2000_nips
freq: 1D
context_length: 360 # 360 for `D`
prediction_length: 30 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 4
sampler: ddpm
sampler_params:
guidance: quantile
scale: 2
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting
================================================
FILE: configs/train_tsdiff-cond/electricity_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/exchange_rate_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: B
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/kdd_cup_2018_without_missing.yaml
================================================
batch_size: 64
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/m4_hourly.yaml
================================================
batch_size: 64
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: forecasting
use_features: false
use_lags: false
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_electricity_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_exchange_rate_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: B
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 180
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_kdd_cup_2018_without_missing.yaml
================================================
batch_size: 64
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 156
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_solar_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_traffic_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_uber_tlc_hourly.yaml
================================================
batch_size: 64
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_electricity_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_exchange_rate_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: B
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 180
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_kdd_cup_2018_without_missing.yaml
================================================
batch_size: 64
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 156
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_solar_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_traffic_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_uber_tlc_hourly.yaml
================================================
batch_size: 64
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_RM_electricity_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_RM_exchange_rate_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: B
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 180
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_RM_kdd_cup_2018_without_missing.yaml
================================================
batch_size: 64
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 156
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_RM_solar_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_RM_traffic_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/missing_RM_uber_tlc_hourly.yaml
================================================
batch_size: 64
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/solar_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/traffic_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/uber_tlc_hourly.yaml
================================================
batch_size: 64
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond/wiki2000_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: 1D
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: forecasting
use_features: false
use_lags: false
use_validation_set: true
================================================
FILE: configs/train_tsdiff-cond.yaml
================================================
model: conditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: solar_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 100
num_batches_per_epoch: 128
batch_size: 64
use_validation_set: True
eval_every: 10
device: cuda:0
noise_observed: True
do_final_eval: True
setup: missing_values
# The following keys will be ignored, if the setup is forecasting
train_missing_scenario: BM-E
missing_scenario: BM-E
missing_values: 168
================================================
FILE: configs/train_tsdiff.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: False
dataset: solar_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 100
num_batches_per_epoch: 128
batch_size: 64
scale: 4
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
guidance: quantile
scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting
do_final_eval: True
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: BM-B
missing_values: 168
- missing_scenario: BM-E
missing_values: 168
================================================
FILE: configs/tstr/electricity_nips.yaml
================================================
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 24
scaling_type: mean
use_features: false
use_lags: true
================================================
FILE: configs/tstr/exchange_rate_nips.yaml
================================================
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
prediction_length: 30
scaling_type: mean
use_features: false
use_lags: true
================================================
FILE: configs/tstr/kdd_cup_2018_without_missing.yaml
================================================
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
prediction_length: 48
scaling_type: mean
use_features: false
use_lags: true
================================================
FILE: configs/tstr/m4_hourly.yaml
================================================
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 48
scaling_type: mean
use_features: false
use_lags: false
================================================
FILE: configs/tstr/solar_nips.yaml
================================================
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 24
scaling_type: mean
use_features: false
use_lags: true
================================================
FILE: configs/tstr/traffic_nips.yaml
================================================
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
prediction_length: 24
scaling_type: mean
use_features: false
use_lags: true
================================================
FILE: configs/tstr/uber_tlc_hourly.yaml
================================================
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 24
scaling_type: mean
use_features: false
use_lags: true
================================================
FILE: configs/tstr/wiki2000_nips.yaml
================================================
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 30
scaling_type: mean
use_features: false
use_lags: false
================================================
FILE: configs/tstr.yaml
================================================
# Model & checkpoint parameters
dataset: solar_nips
device: cuda:0
ckpt: ckpts/solar_nips/version_236/1299_.ckpt
diffusion_config: diffusion_small_config
context_length: 336
prediction_length: 24
use_lags: true
use_features: false
init_skip: true
scaling_type: mean
================================================
FILE: pyproject.toml
================================================
[project]
name = "uncond-ts-diff"
version = "0.1.0"
description = "TSDiff: An Unconditional Diffusion Model for Time Series"
authors = []
dependencies = [
"torch~=1.13.1",
"pytorch-lightning~=1.9.4",
"gluonts[mxnet,pro]~=0.12.3",
"matplotlib",
"seaborn",
"opt_einsum~=3.3.0",
"einops",
"black",
"tqdm",
"scipy",
"scikit-learn",
"numba",
"jupyter",
"rich",
"pykeops==2.1.1",
]
readme = "README.md"
requires-python = ">= 3.8"
[tool.black]
line-length = 79
================================================
FILE: src/uncond_ts_diff/arch/__init__.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .backbones import BackboneModel
__all__ = ["BackboneModel"]
================================================
FILE: src/uncond_ts_diff/arch/backbones.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import math
import torch
from torch import nn
from .s4 import S4
class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embeddings for diffusion step indices.

    Maps a batch of scalar timesteps ``(B,)`` to a ``(B, dim)`` tensor whose
    first half holds sine components and second half cosine components.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # embedding width (assumed even — TODO confirm)

    def forward(self, time):
        half = self.dim // 2
        # Geometric frequency ladder: exp(-log(10000) * k / (half - 1)).
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=time.device))
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class S4Layer(nn.Module):
    """Pre-norm bidirectional S4 layer with dropout and a residual connection."""

    def __init__(
        self,
        d_model,
        dropout=0.0,
    ):
        super().__init__()
        self.layer = S4(
            d_model=d_model,
            d_state=128,
            bidirectional=True,
            dropout=dropout,
            transposed=True,
            postact=None,
        )
        self.norm = nn.LayerNorm(d_model)
        # Dropout1d ties the dropout mask across the length dimension.
        self.dropout = (
            nn.Dropout1d(dropout) if dropout > 0.0 else nn.Identity()
        )

    def forward(self, x):
        """x: (B, d_model, L) -> ((B, d_model, L), None)."""
        # LayerNorm normalizes the last axis, so transpose around it (prenorm).
        normed = self.norm(x.transpose(-1, -2)).transpose(-1, -2)
        # The recurrent state output is not needed during training.
        out, _ = self.layer(normed)
        out = self.dropout(out)
        # Residual connection around the whole layer.
        return out + x, None

    def default_state(self, *args, **kwargs):
        # Delegate initial-state construction to the wrapped S4 module.
        return self.layer.default_state(*args, **kwargs)

    def step(self, x, state, **kwargs):
        """Single-step (recurrent) version of forward for stepping inference."""
        normed = self.norm(x.transpose(-1, -2)).transpose(-1, -2)
        out, state = self.layer.step(normed, state, **kwargs)
        return out + x, state
class S4Block(nn.Module):
    """Residual S4 block with a gated (tanh * sigmoid) activation.

    Returns a residual output (fed to the next block) and a skip output
    (summed over blocks by the caller), in the WaveNet style.
    """

    def __init__(self, d_model, dropout=0.0, expand=2, num_features=0):
        # NOTE(review): `expand` is accepted for API compatibility but unused.
        super().__init__()
        self.s4block = S4Layer(d_model, dropout=dropout)
        self.time_linear = nn.Linear(d_model, d_model)
        self.tanh = nn.Tanh()
        self.sigm = nn.Sigmoid()
        # 1x1 convolutions act as per-timestep linear projections.
        self.out_linear1 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.out_linear2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.feature_encoder = nn.Conv1d(num_features, d_model, kernel_size=1)

    def forward(self, x, t, features=None):
        """x: (B, d_model, L); t: (B, d_model); features: (B, num_features, L) or None."""
        # Broadcast the diffusion-step embedding across the length axis.
        step = self.time_linear(t)[:, None, :].repeat(1, x.shape[2], 1)
        step = step.transpose(-1, -2)
        hidden, _ = self.s4block(x + step)
        if features is not None:
            hidden = hidden + self.feature_encoder(features)
        gated = self.tanh(hidden) * self.sigm(hidden)
        residual = self.out_linear1(gated)
        skip = self.out_linear2(gated)
        return residual + x, skip
def Conv1dKaiming(in_channels, out_channels, kernel_size):
    """Build an ``nn.Conv1d`` whose weight is Kaiming-normal initialized."""
    conv = nn.Conv1d(in_channels, out_channels, kernel_size)
    nn.init.kaiming_normal_(conv.weight)
    return conv
class BackboneModel(nn.Module):
    """Noise-prediction backbone for the diffusion model.

    Pipeline: input projection -> stack of residual S4 blocks conditioned on
    the diffusion step (and optional features) -> output head applied to the
    summed skip connections. With ``init_skip`` the network predicts a
    correction that is added back onto its input.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        step_emb,
        num_residual_blocks,
        num_features,
        residual_block="s4",
        dropout=0.0,
        init_skip=True,
    ):
        super().__init__()
        if residual_block != "s4":
            raise ValueError(f"Unknown residual block {residual_block}")
        block_cls = S4Block
        self.input_init = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )
        self.time_init = nn.Sequential(
            nn.Linear(step_emb, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )
        self.out_linear = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.residual_blocks = nn.ModuleList(
            block_cls(hidden_dim, num_features=num_features, dropout=dropout)
            for _ in range(num_residual_blocks)
        )
        self.step_embedding = SinusoidalPositionEmbeddings(step_emb)
        self.init_skip = init_skip

    def forward(self, input, t, features=None):
        """input: (B, L, input_dim); t: (B,); features: (B, L, F) or None."""
        x = self.input_init(input)  # (B, L, hidden)
        t_emb = self.time_init(self.step_embedding(t))
        # Residual blocks operate channel-first: (B, hidden, L).
        x = x.transpose(-1, -2)
        if features is not None:
            features = features.transpose(-1, -2)
        skips = []
        for block in self.residual_blocks:
            x, skip = block(x, t_emb, features)
            skips.append(skip)
        # Sum all skip outputs, then go back to (B, L, hidden) for the head.
        skip_sum = torch.stack(skips).sum(0).transpose(-1, -2)
        out = self.out_linear(skip_sum)
        if self.init_skip:
            out = out + input
        return out
================================================
FILE: src/uncond_ts_diff/arch/s4.py
================================================
"""Standalone version of Structured (Sequence) State Space (S4) model."""
import logging
from functools import partial
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.utilities import rank_zero_only
from einops import rearrange, repeat
import opt_einsum as oe
contract = oe.contract
contract_expression = oe.contract_expression
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Wrap every emit method with the rank-zero decorator so that in a
    # multi-GPU run only process 0 actually writes log records.
    emit_methods = (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    )
    for method in emit_methods:
        setattr(logger, method, rank_zero_only(getattr(logger, method)))
    return logger
log = get_logger(__name__)
""" Cauchy and Vandermonde kernels """
try: # Try CUDA extension
from extensions.cauchy.cauchy import cauchy_mult
has_cauchy_extension = True
except ImportError:
# log.warning(
# "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%"
# )
has_cauchy_extension = False
try:  # Try pykeops
    from pykeops.torch import Genred

    has_pykeops = True
    log.info("Pykeops installation found.")

    def _broadcast_dims(*tensors):
        # Left-pad every tensor's shape with 1s so all have the same rank,
        # matching pykeops' broadcasting expectations.
        max_dim = max([len(tensor.shape) for tensor in tensors])
        tensors = [
            tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape)
            for tensor in tensors
        ]
        return tensors

    def cauchy_conj(v, z, w):
        """Pykeops version of the Cauchy kernel with conjugate symmetry.

        Computes sum_n (z * Re(v_n) - Re(v_n w_n*)) / ((z - w_n)(z - w_n*))
        as a symbolic reduction; complex tensors are passed as real pairs.
        """
        expr_num = "z * ComplexReal(v) - Real2Complex(Sum(v * w))"
        expr_denom = "ComplexMult(z-w, z-Conj(w))"
        cauchy_mult = Genred(
            f"ComplexDivide({expr_num}, {expr_denom})",
            [
                "v = Vj(2)",
                "z = Vi(2)",
                "w = Vj(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )
        v, z, w = _broadcast_dims(v, z, w)
        # pykeops works on real-valued views of the complex inputs.
        v = _c2r(v)
        z = _c2r(z)
        w = _c2r(w)
        # Factor 2 accounts for only half of each conjugate pair being stored.
        r = 2 * cauchy_mult(v, z, w, backend="GPU")
        return _r2c(r)

    def log_vandermonde(v, x, L):
        """Pykeops Vandermonde product: 2 * Re(sum_n v_n exp(x_n * l))."""
        expr = "ComplexMult(v, ComplexExp(ComplexMult(x, l)))"
        vandermonde_mult = Genred(
            expr,
            [
                "v = Vj(2)",
                "x = Vj(2)",
                "l = Vi(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )
        l = torch.arange(L).to(x)
        v, x, l = _broadcast_dims(v, x, l)
        v = _c2r(v)
        x = _c2r(x)
        l = _c2r(l)
        r = vandermonde_mult(v, x, l, backend="GPU")
        return 2 * _r2c(r).real

    def log_vandermonde_transpose(u, v, x, L):
        """
        u: ... H L
        v: ... H N
        x: ... H N
        Returns: ... H N

        V = Vandermonde(a, L) : (H N L)
        contract_L(V * u * v)
        """
        expr = "ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))"
        vandermonde_mult = Genred(
            expr,
            [
                "u = Vj(2)",
                "v = Vi(2)",
                "x = Vi(2)",
                "l = Vj(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )
        l = torch.arange(L).to(x)
        u, v, x, l = _broadcast_dims(u, v, x, l)
        u = _c2r(u)
        v = _c2r(v)
        x = _c2r(x)
        l = _c2r(l)
        r = vandermonde_mult(u, v, x, l, backend="GPU")
        return _r2c(r)

except ImportError:
    has_pykeops = False
    if not has_cauchy_extension:
        log.warning(
            "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency."
        )

    def cauchy_naive(v, z, w):
        """
        v, w: (..., N)
        z: (..., L)
        returns: (..., L)
        """
        # Materializes the full (..., N, L) Cauchy matrix — memory-heavy
        # compared to the pykeops/CUDA paths, hence the warning above.
        cauchy_matrix = v.unsqueeze(-1) / (
            z.unsqueeze(-2) - w.unsqueeze(-1)
        )  # (... N L)
        return torch.sum(cauchy_matrix, dim=-2)

    # Vandermonde functions
    log.warning(
        "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency."
    )

    def log_vandermonde(v, x, L):
        """
        v: (..., N)
        x: (..., N)
        returns: (..., L) \sum v x^l
        """
        vandermonde_matrix = torch.exp(
            x.unsqueeze(-1) * torch.arange(L).to(x)
        )  # (... N L)
        vandermonde_prod = contract(
            "... n, ... n l -> ... l", v, vandermonde_matrix
        )  # (... L)
        # Factor 2 restores the contribution of the conjugate pairs.
        return 2 * vandermonde_prod.real

    def log_vandermonde_transpose(u, v, x, L):
        # Transposed Vandermonde contraction; see the pykeops version's
        # docstring for the shape conventions.
        vandermonde_matrix = torch.exp(
            x.unsqueeze(-1) * torch.arange(L).to(x)
        )  # (... N L)
        vandermonde_prod = contract(
            "... l, ... n, ... n l -> ... n",
            u.to(x),
            v.to(x),
            vandermonde_matrix,
        )  # (... L)
        return vandermonde_prod
def _conj(x):
    # Append the complex conjugate along the last dim: (..., N) -> (..., 2N).
    return torch.cat([x, x.conj()], dim=-1)


# Zero-copy views between the real pair representation (..., 2) and complex.
_c2r = torch.view_as_real
_r2c = torch.view_as_complex

if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10):
    # torch >= 1.10 tracks conjugation lazily; materialize it so subsequent
    # real views see the actual conjugated values.
    def _resolve_conj(x):
        return x.conj().resolve_conj()

else:

    def _resolve_conj(x):
        return x.conj()
""" Simple nn.Module components """
def Activation(activation=None, dim=-1):
    """Map an activation name to the corresponding ``nn.Module``.

    ``None``/"id"/"identity"/"linear" give an identity; "glu" splits along
    ``dim``. Raises NotImplementedError for unknown names.
    """
    if activation in [None, "id", "identity", "linear"]:
        return nn.Identity()
    if activation == "glu":
        return nn.GLU(dim=dim)
    simple = {
        "tanh": nn.Tanh,
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "swish": nn.SiLU,
        "silu": nn.SiLU,
        "sigmoid": nn.Sigmoid,
    }
    if activation in simple:
        return simple[activation]()
    raise NotImplementedError(
        "hidden activation '{}' is not implemented".format(activation)
    )
def LinearActivation(
    d_input,
    d_output,
    bias=True,
    transposed=False,
    activation=None,
    activate=False,  # Apply activation as part of this module
    **kwargs,
):
    """Returns a linear nn.Module with control over axes order, initialization, and activation"""
    # Channel-first ("transposed") data uses a 1x1 conv as the linear map.
    if transposed:
        linear_cls = partial(nn.Conv1d, kernel_size=1)
    else:
        linear_cls = nn.Linear
    # GLU halves its input, so produce twice the features up front.
    if activation == "glu":
        d_output *= 2
    linear = linear_cls(d_input, d_output, bias=bias, **kwargs)
    if activate and activation is not None:
        act = Activation(activation, dim=-2 if transposed else -1)
        linear = nn.Sequential(linear, act)
    return linear
class DropoutNd(nn.Module):
    """Dropout over (batch, dim, lengths...) inputs.

    With ``tie=True`` one mask per (batch, dim) entry is shared across all
    length dimensions, like ``Dropout1d/2d/3d``.
    """

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError(
                "dropout probability has to be in [0, 1), "
                "but got {}".format(p)
            )
        self.p = p
        self.tie = tie
        self.transposed = transposed
        # NOTE(review): this distribution is never used; masks are drawn via
        # torch.rand in forward. Kept for attribute compatibility.
        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)"""
        if not self.training:
            # Inference: identity, no rescaling needed (inverted dropout).
            return X
        if not self.transposed:
            X = rearrange(X, "b d ... -> b ... d")
        if self.tie:
            # One mask entry per (batch, dim), broadcast over lengths.
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2)
        else:
            mask_shape = X.shape
        keep = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p
        # Inverted dropout: rescale kept activations so E[output] == input.
        X = X * keep * (1.0 / (1 - self.p))
        if not self.transposed:
            X = rearrange(X, "b ... d -> b d ...")
        return X
""" Misc functional utilities """
def power(L, A, v=None):
    """Compute A^L and the scan sum_i A^i v_i

    A: (..., N, N)
    v: (..., N, L)
    Returns A^L, or (A^L, sum_i A^i v_i) when v is given.
    """
    # Square-and-multiply: I accumulates A^L bit by bit while `powers`
    # caches A, A^2, A^4, ... for the optional scan below.
    I = torch.eye(A.shape[-1]).to(A)  # , dtype=A.dtype, device=A.device)

    powers = [A]
    l = 1
    while True:
        if L % 2 == 1:
            I = powers[-1] @ I
        L //= 2
        if L == 0:
            break
        l *= 2
        powers.append(powers[-1] @ powers[-1])

    if v is None:
        return I

    # Invariants:
    # powers[-1] := A^l
    # l := largest po2 at most L

    # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A
    # We do this reverse divide-and-conquer for efficiency reasons:
    # 1) it involves fewer padding steps for non-po2 L
    # 2) it involves more contiguous arrays

    # Take care of edge case for non-po2 arrays
    # Note that this initial step is a no-op for the case of power of 2 (l == L)
    k = v.size(-1) - l
    v_ = powers.pop() @ v[..., l:]
    v = v[..., :l]
    # NOTE(review): v[..., :l] is a view of the caller's tensor, so this
    # assignment mutates the input in place — confirm callers don't reuse v.
    v[..., :k] = v[..., :k] + v_

    # Handle reduction for power of 2
    while v.size(-1) > 1:
        v = rearrange(v, "... (z l) -> ... z l", z=2)
        # Pairwise combine: left half + A^(2^j) @ right half.
        v = v[..., 0, :] + powers.pop() @ v[..., 1, :]
    return I, v.squeeze(-1)
""" HiPPO utilities """
def transition(measure, N):
    """A, B transition matrices for different measures

    measure: HiPPO basis family ("legt", "legs", "legsd",
        "fourier"/"fout", "fourier_diag"/"foud").
    N: state dimension.
    Returns numpy arrays A of shape (N, N) and B of shape (N, 1).
    """
    # Legendre (translated)
    if measure == "legt":
        Q = np.arange(N, dtype=np.float64)
        R = (2 * Q + 1) ** 0.5
        j, i = np.meshgrid(Q, Q)
        A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :]
        B = R[:, None]
        A = -A

        # Halve again for timescale correctness
        A *= 0.5
        B *= 0.5
    # Legendre (scaled)
    elif measure == "legs":
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        # Lower-triangular HiPPO-LegS matrix before symmetrization by T.
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        B = (
            B.copy()
        )  # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
    elif measure == "legsd":
        # Essentially equivalent to S4D-LegS
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        B = (
            B.copy()
        )  # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
        # Rank-1 correction and halved B distinguish legsd from legs.
        A += 0.5 * B * B[None, :, 0]
        B = B / 2.0
    elif measure in ["fourier_diag", "foud"]:
        # Essentially equivalent to S4D-Lin
        freqs = np.arange(N // 2)
        d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1]
        A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        A = A - 0.5 * np.eye(N)
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1
        B = B[:, None]
    elif measure in ["fourier", "fout"]:
        freqs = np.arange(N // 2)
        d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
        A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1

        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - B[:, None] * B[None, :]
        B = B[:, None]
    else:
        raise NotImplementedError

    return A, B
def rank_correction(measure, N, rank=1, dtype=torch.float):
    """Return low-rank matrix L such that A + L is normal

    Output shape is (rank, N); rows beyond the measure's natural rank are
    zero-padded.
    """
    if measure == "legs":
        assert rank >= 1
        P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(0)  # (1 N)
    elif measure == "legt":
        assert rank >= 2
        base = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype))  # (N)
        p0 = base.clone()
        p0[0::2] = 0.0
        p1 = base.clone()
        p1[1::2] = 0.0
        # Halve the rank correction just like the original matrix was halved.
        P = torch.stack([p0, p1], dim=0) * 2 ** (-0.5)  # (2 N)
    elif measure in ["fourier", "fout"]:
        P = torch.zeros(N)
        P[0::2] = 2**0.5
        P[0] = 1
        P = P.unsqueeze(0)
    elif measure in ["fourier_diag", "foud", "legsd"]:
        # Diagonal variants need no correction.
        P = torch.zeros(1, N, dtype=dtype)
    else:
        raise NotImplementedError

    d = P.size(0)
    if rank > d:
        # Pad with zero rows when more rank than the measure provides.
        P = torch.cat(
            [P, torch.zeros(rank - d, N, dtype=dtype)], dim=0
        )  # (rank N)
    return P
def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True):
    """Return w, p, q, V, B such that
    (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V
    i.e. A = V[w - p q^*]V^*, B = V B

    Only half of each conjugate eigenvalue pair is returned (conjugate
    symmetry), so w and the last dims of P, B, V have size N // 2.
    """
    assert dtype == torch.float or dtype == torch.double
    cdtype = torch.cfloat if dtype == torch.float else torch.cdouble

    A, B = transition(measure, N)
    A = torch.as_tensor(A, dtype=dtype)  # (N, N)
    B = torch.as_tensor(B, dtype=dtype)[:, 0]  # (N,)

    P = rank_correction(measure, N, rank=rank, dtype=dtype)  # (r N)
    # AP = A + P P^* should be (nearly) normal: identity plus skew-symmetric.
    AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3)

    # We require AP to be nearly skew-symmetric
    _A = AP + AP.transpose(-1, -2)
    if (
        err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N
    ) > 1e-5:  # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5):
        print("WARNING: HiPPO matrix not skew symmetric", err)

    # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately
    # Imaginary part can use eigh instead of eig
    w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True)

    # Diagonalize in double precision
    if diagonalize_precision:
        AP = AP.to(torch.double)
    # -1j * (skew-symmetric) is Hermitian, so eigh applies.
    w_im, V = torch.linalg.eigh(AP * -1j)  # (..., N) (..., N, N)
    if diagonalize_precision:
        w_im, V = w_im.to(cdtype), V.to(cdtype)
    w = w_re + 1j * w_im
    # Check: V w V^{-1} = A
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))

    # Only keep half of each conjugate pair
    _, idx = torch.sort(w.imag)
    w_sorted = w[idx]
    V_sorted = V[:, idx]

    # There is an edge case when eigenvalues can be 0, which requires some machinery to handle
    # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case)
    V = V_sorted[:, : N // 2]
    w = w_sorted[: N // 2]
    assert (
        w[-2].abs() > 1e-4
    ), "Only 1 zero eigenvalue allowed in diagonal part of A"
    if w[-1].abs() < 1e-4:
        V[:, -1] = 0.0
        V[0, -1] = 2**-0.5
        V[1, -1] = 2**-0.5 * 1j

    # Sanity check: 2 * Re(V diag(w) V^*) should reconstruct AP.
    _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)
    if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5:
        print(
            "Warning: Diagonalization of A matrix not numerically precise - error",
            err,
        )
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))

    V_inv = V.conj().transpose(-1, -2)

    # Rotate B and P into the eigenbasis of AP.
    B = contract("ij, j -> i", V_inv, B.to(V))  # V^* B
    P = contract("ij, ...j -> ...i", V_inv, P.to(V))  # V^* P

    return w, P, B, V
def dplr(
    scaling,
    N,
    rank=1,
    H=1,
    dtype=torch.float,
    real_scale=1.0,
    imag_scale=1.0,
    random_real=False,
    random_imag=False,
    normalize=False,
    diagonal=True,
    random_B=False,
):
    """Initialize (w, P, B, V) for a diagonal / DPLR SSM parameterization.

    ``scaling`` selects how the imaginary parts of the eigenvalues are laid
    out; shapes follow the conjugate-pair convention (last dim N // 2).
    """
    assert dtype == torch.float or dtype == torch.double
    cdtype = torch.cfloat if dtype == torch.float else torch.cdouble
    pi = torch.tensor(math.pi)

    # Real part of the eigenvalues: -0.5 by default, optionally randomized.
    if random_real:
        real_part = torch.rand(H, N // 2)
    else:
        real_part = 0.5 * torch.ones(H, N // 2)
    real_part = real_scale * real_part

    # Imaginary part: increasing frequencies by default.
    if random_imag:
        imag_part = N // 2 * torch.rand(H, N // 2)
    else:
        imag_part = repeat(torch.arange(N // 2), "n -> h n", h=H)

    if scaling == "random":
        imag_part = torch.randn(H, N // 2)
    elif scaling == "real":
        # Purely real spectrum; overrides the real part as well.
        imag_part = 0 * imag_part
        real_part = 1 + repeat(torch.arange(N // 2), "n -> h n", h=H)
    elif scaling in ["linear", "lin"]:
        imag_part = pi * imag_part
    elif scaling in ["inverse", "inv"]:
        # Based on asymptotics of the default HiPPO matrix
        imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1)
    elif scaling in ["inverse2", "inv2"]:
        imag_part = 1 / pi * N * (N / (1 + imag_part) - 1)
    elif scaling in ["quadratic", "quad"]:
        imag_part = 1 / pi * (1 + 2 * imag_part) ** 2
    elif scaling in ["legs", "hippo"]:
        # Use the exact HiPPO-LegS(D) spectrum for the imaginary parts.
        w, _, _, _ = nplr("legsd", N)
        imag_part = w.imag
    else:
        raise NotImplementedError
    imag_part = imag_scale * imag_part

    w = -real_part + 1j * imag_part

    # Initialize B
    if random_B:
        B = torch.randn(H, N // 2, dtype=cdtype)
    else:
        B = torch.ones(H, N // 2, dtype=cdtype)

    if normalize:
        # Result if you integrate the kernel with constant 1 function
        norm = -B / w  # (H, N)
        # Variance with a random C vector
        zeta = 2 * torch.sum(torch.abs(norm) ** 2, dim=-1, keepdim=True)
        B = B / zeta**0.5

    P = torch.randn(rank, H, N // 2, dtype=cdtype)
    if diagonal:
        # Pure-diagonal parameterization: zero out the low-rank part.
        P = P * 0.0
    V = torch.eye(N, dtype=cdtype)[:: N // 2]  # Only used in testing
    V = repeat(V, "n m -> h n m", h=H)

    return w, P, B, V
def ssm(measure, N, R, H, **ssm_args):
    """Dispatcher to create single SSM initialization

    N: state size
    R: rank (for DPLR parameterization)
    H: number of independent SSM copies
    """
    if measure == "dplr":
        w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args)
    elif measure.startswith("diag"):
        # "diag-<scaling>" selects a purely diagonal init with that scaling.
        parts = measure.split("-")
        assert parts[0] == "diag" and len(parts) > 1
        w, P, B, V = dplr(
            scaling=parts[1], N=N, rank=R, H=H, diagonal=True, **ssm_args
        )
    else:
        # NPLR inits return a single copy; tile them across the H copies.
        w, P, B, V = nplr(measure, N, R, **ssm_args)
        w = repeat(w, "n -> s n", s=H)
        P = repeat(P, "r n -> r s n", s=H)
        B = repeat(B, "n -> s n", s=H)
        V = repeat(V, "n m -> s n m", s=H)
    return w, P, B, V
# Named shorthands for mixtures of measures; `combination` splits a model's
# SSM copies evenly across the listed measures.
combinations = {
    "hippo": ["legs", "fourier"],
    "diag": ["diag-inv", "diag-lin"],
    "all": ["legs", "fourier", "diag-inv", "diag-lin"],
}
def combination(measures, N, R, S, **ssm_args):
    """Create S SSM copies split evenly across one or more measures.

    measures: a measure name, a key of ``combinations``, or a list of names.
    """
    if isinstance(measures, str):
        # Expand shorthand names; a plain measure becomes a singleton list.
        measures = combinations.get(measures, [measures])

    assert (
        S % len(measures) == 0
    ), f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures"
    # One init per measure, each contributing S / len(measures) copies.
    inits = [
        ssm(measure, N, R, S // len(measures), **ssm_args)
        for measure in measures
    ]
    w, P, B, V = zip(*inits)
    w = torch.cat(w, dim=0)  # (S N)
    P = torch.cat(P, dim=1)  # (R S N)
    B = torch.cat(B, dim=0)  # (S N)
    V = torch.cat(V, dim=0)  # (S N N)
    return w, P, B, V
class OptimModule(nn.Module):
    """Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters"""

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay"""
        if lr == 0.0:
            # Frozen: store as a buffer so the optimizer never sees it.
            self.register_buffer(name, tensor)
            return
        self.register_parameter(name, nn.Parameter(tensor))
        # Attach per-parameter optimizer settings; the training loop reads
        # `_optim` when building parameter groups.
        hyperparams = {"weight_decay": 0.0}
        if lr is not None:
            hyperparams["lr"] = lr
        setattr(getattr(self, name), "_optim", hyperparams)
class SSKernelNPLR(OptimModule):
"""Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR)"""
@torch.no_grad()
def _setup_C(self, L):
"""Construct C~ from C
Two modes are supported: go directly to length L if self.L is 1, or length is doubled
"""
if self.L.item() == 0:
if self.verbose:
log.info(f"S4: Initializing kernel to length {L}")
double_length = False
elif L > self.L.item(): # 2*int(self.L) == L:
if self.verbose:
log.info(
f"S4: Doubling length from L = {self.L.item()} to {2*self.L.item()}"
)
double_length = True
L = self.L.item() # Convenience for the math below
else:
return
C = _r2c(self.C)
dA, _ = self._setup_state()
dA_L = power(L, dA)
# Multiply C by I - dA_L
C_ = _conj(C)
prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_)
if double_length:
prod = -prod # Multiply by I + dA_L instead
C_ = C_ - prod
C_ = C_[..., : self.N] # Take conjugate pairs again
self.C.copy_(_c2r(C_))
self.L = (
2 * self.L if double_length else self.L + L
) # Preserve type/device
def _omega(self, L, dtype, device, cache=True):
"""Calculate (and cache) FFT nodes and their "unprocessed" version with the bilinear transform
This should be called everytime the internal length self.L changes"""
# Use cached if available
if (
cache
and hasattr(self, "omega")
and self.omega.size(-1) == L // 2 + 1
):
return self.omega, self.z
omega = torch.tensor(
np.exp(-2j * np.pi / (L)), dtype=dtype, device=device
) # \omega_{2L}
omega = omega ** torch.arange(0, L // 2 + 1, device=device)
z = 2 * (1 - omega) / (1 + omega)
# Cache if necessary
if cache:
self.omega = omega
self.z = z
return omega, z
def __init__(
self,
w,
P,
B,
C,
log_dt,
L=None, # starting/maximum length of kernel
lr=None,
verbose=False,
keops=False,
real_type="exp", # ['none' | 'exp' | 'relu' | sigmoid']
real_tolerance=1e-3,
bandlimit=None,
):
"""
L: Maximum length; this module computes an SSM kernel of length L
A is represented by diag(w) - PP^*
w: (S, N) diagonal part
P: (R, S, N) low-rank part
B: (S, N)
C: (C, H, N)
dt: (H) timescale per feature
lr: [dict | float | None] hook to set lr of special parameters (A, B, dt)
Dimensions:
N (or d_state): state size
H (or d_model): total SSM copies
S (or n_ssm): number of trainable copies of (A, B, dt); must divide H
R (or rank): rank of low-rank part
C (or channels): system is 1-dim to C-dim
The forward pass of this Module returns a tensor of shape (C, H, L)
Note: tensor shape N here denotes half the true state size, because of conjugate symmetry
"""
super().__init__()
self.verbose = verbose
self.keops = keops
self.bandlimit = bandlimit
self.real_type = real_type
self.real_tolerance = real_tolerance
# Rank of low-rank correction
self.rank = P.shape[-3]
assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1)
self.H = log_dt.size(-1)
self.N = w.size(-1)
# Check different SSM inits
assert w.size(-2) == P.size(-2) == B.size(-2) # n_ssm
assert self.H % w.size(0) == 0
self.n_ssm = w.size(0)
self.repeat = self.H // w.size(
0
) # Each trainable SSM needs to be duplicated this many times
# Broadcast everything to correct shapes
C = C.expand(
torch.broadcast_shapes(C.shape, (1, self.H, self.N))
) # (C, H, N)
B = B.unsqueeze(0) # (1, 1, N)
# Register parameters
self.C = nn.Parameter(_c2r(_resolve_conj(C)))
if lr is None or isinstance(lr, float):
lr_dict = {}
else:
lr_dict, lr = lr, None
self.register("log_dt", log_dt, lr_dict.get("dt", lr))
self.register("B", _c2r(B), lr_dict.get("B", lr))
self.register("P", _c2r(P), lr_dict.get("A", lr))
self.register("inv_w_real", self._w_init(w.real), lr_dict.get("A", lr))
self.register("w_imag", w.imag, lr_dict.get("A", lr))
self.l_max = L
self.register_buffer("L", torch.tensor(0)) # Internal length
def _w_init(self, w_real):
w_real = torch.clamp(w_real, max=-self.real_tolerance)
if self.real_type == "none":
return -w_real
elif self.real_type == "exp":
return torch.log(
-w_real
) # Some of the HiPPO methods have real part 0
elif self.real_type == "relu":
return -w_real
elif self.real_type == "sigmoid":
return torch.logit(-w_real)
elif self.real_type == "softplus":
return torch.log(torch.exp(-w_real) - 1)
else:
raise NotImplementedError
def _w(self):
# Get the internal w (diagonal) parameter
if self.real_type == "none":
w_real = -self.inv_w_real
elif self.real_type == "exp":
w_real = -torch.exp(self.inv_w_real)
elif self.real_type == "relu":
w_real = -F.relu(self.inv_w_real)
elif self.real_type == "sigmoid":
w_real = -F.sigmoid(self.inv_w_real)
elif self.real_type == "softplus":
w_real = -F.softplus(self.inv_w_real)
else:
raise NotImplementedError
w = w_real + 1j * self.w_imag
return w
def forward(self, state=None, rate=1.0, L=None):
"""
state: (B, H, N) initial state
rate: sampling rate factor
L: target length
returns:
(C, H, L) convolution kernel (generally C=1)
(B, H, L) output from initial state
"""
# Initialize C~ if necessary (done in forward pass so it's on the correct device)
if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:
self._setup_C(self.l_max)
# Handle sampling rate logic
# The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate
if L is None:
L = round(self.L.item() / rate)
# Increase the internal length if needed
continuous_L = round(rate * L)
while continuous_L > self.L.item():
self._setup_C(continuous_L)
discrete_L = round(self.L.item() / rate)
dt = torch.exp(self.log_dt) * rate
B = _r2c(self.B)
C = _r2c(self.C)
P = _r2c(self.P)
Q = P.conj()
w = self._w() # (n_ssm, N)
# Address bandlimiting
if self.bandlimit is not None:
freqs = w.imag.abs() / (2 * math.pi) # (H, N)
freqs = dt[:, None] / rate * freqs # (H, N)
mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
C = C * mask
# Get FFT nodes of right length
omega, z = self._omega(
discrete_L, dtype=w.dtype, device=w.device, cache=(rate == 1.0)
)
# Broadcast parameters to same hidden features H
B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat)
P = repeat(P, "r t n -> r (v t) n", v=self.repeat)
Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat)
w = repeat(w, "t n -> (v t) n", v=self.repeat)
# Augment B
if state is not None:
# Have to "unbilinear" the state to put it into the same "type" as B
# Compute 1/dt * (I + dt/2 A) @ state
# Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way
s = _conj(state) if state.size(-1) == self.N else state # (B H N)
sA = s * _conj(w) - contract( # (B H N)
"bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P)
)
s = s / dt.unsqueeze(-1) + sA / 2
s = s[..., : self.N]
B = torch.cat([s, B], dim=-3) # (B+1, H, N)
# Incorporate dt into A
w = w * dt.unsqueeze(-1) # (H N)
# Stack B and p, C and q for convenient batching
B = torch.cat([B, P], dim=-3) # (B+1+R, H, N)
C = torch.cat([C, Q], dim=-3) # (C+R, H, N)
# Incorporate B and C batch dimensions
v = B.unsqueeze(-3) * C.unsqueeze(-4) # (B+1+R, C+R, H, N)
# Calculate resolvent at omega
if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops:
r = cauchy_mult(v, z, w, symmetric=True)
elif has_pykeops:
r = cauchy_conj(v, z, w)
else:
r = cauchy_naive(v, z, w)
r = r * dt[None, None, :, None] # (B+1+R, C+R, H, L)
# Low-rank Woodbury correction
if self.rank == 1:
k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (
1 + r[-1:, -1:, :, :]
)
elif self.rank == 2:
r00 = r[: -self.rank, : -self.rank, :, :]
r01 = r[: -self.rank, -self.rank :, :, :]
r10 = r[-self.rank :, : -self.rank, :, :]
r11 = r[-self.rank :, -self.rank :, :, :]
det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[
:1, 1:, :, :
] * r11[1:, :1, :, :]
s = (
r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :]
+ r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :]
- r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :]
- r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :]
)
s = s / det
k_f = r00 - s
else:
r00 = r[: -self.rank, : -self.rank, :, :]
r01 = r[: -self.rank, -self.rank :, :, :]
r10 = r[-self.rank :, : -self.rank, :, :]
r11 = r[-self.rank :, -self.rank :, :, :]
r11 = rearrange(r11, "a b h n -> h n a b")
r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11)
r11 = rearrange(r11, "h n a b -> a b h n")
k_f = r00 - torch.einsum(
"i j h n, j k h n, k l h n -> i l h n", r01, r11, r10
)
# Final correction for the bilinear transform
k_f = k_f * 2 / (1 + omega)
# Move from frequency to coefficients
k = torch.fft.irfft(k_f, n=discrete_L) # (B+1, C, H, L)
# # Truncate to target length
k = k[..., :L]
if state is not None:
k_state = k[:-1, :, :, :] # (B, C, H, L)
else:
k_state = None
k_B = k[-1, :, :, :] # (C H L)
return k_B, k_state
@torch.no_grad()
def _setup_linear(self):
"""Create parameters that allow fast linear stepping of state"""
w = self._w()
B = _r2c(self.B) # (H N)
P = _r2c(self.P)
Q = P.conj()
# Repeat w shape properly
B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat)
P = repeat(P, "r t n -> r (v t) n", v=self.repeat)
Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat)
w = repeat(w, "t n -> (v t) n", v=self.repeat)
# Prepare Linear stepping
dt = torch.exp(self.log_dt)
D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N)
R = (
torch.eye(self.rank, dtype=w.dtype, device=w.device)
+ 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real
) # (H R R)
Q_D = rearrange(Q * D, "r h n -> h r n")
try:
R = torch.linalg.solve(R, Q_D) # (H R N)
except Exception:
R = torch.tensor(
np.linalg.solve(
R.to(Q_D).contiguous().detach().cpu(),
Q_D.contiguous().detach().cpu(),
)
).to(Q_D)
R = rearrange(R, "h r n -> r h n")
self.step_params = {
"D": D, # (H N)
"R": R, # (R H N)
"P": P, # (R H N)
"Q": Q, # (R H N)
"B": B, # (1 H N)
"E": 2.0 / dt.unsqueeze(-1) + w, # (H N)
}
def _step_state_linear(self, u=None, state=None):
"""
Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization.
Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster
u: (H) input
state: (H, N/2) state with conjugate pairs
Optionally, the state can have last dimension N
Returns: same shape as state
"""
C = _r2c(self.C) # View used for dtype/device
if u is None: # Special case used to find dA
u = torch.zeros(self.H, dtype=C.dtype, device=C.device)
if state is None: # Special case used to find dB
state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device)
step_params = self.step_params.copy()
if (
state.size(-1) == self.N
): # Only store half of the conjugate pairs; should be true by default
# There should be a slightly faster way using conjugate symmetry
def contract_fn(p, x, y):
return contract(
"r h n, r h m, ... h m -> ... h n",
_conj(p),
_conj(x),
_conj(y),
)[
..., : self.N
] # inner outer product
else:
assert state.size(-1) == 2 * self.N
step_params = {k: _conj(v) for k, v in step_params.items()}
# TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping
def contract_fn(p, x, y):
return contract(
"r h n, r h m, ... h m -> ... h n", p, x, y
) # inner outer product
D = step_params["D"] # (H N)
E = step_params["E"] # (H N)
R = step_params["R"] # (R H N)
P = step_params["P"] # (R H N)
Q = step_params["Q"] # (R H N)
B = step_params["B"] # (1 H N)
new_state = E * state - contract_fn(P, Q, state) # (B H N)
new_state = new_state + 2.0 * B * u.unsqueeze(-1) # (B H N)
new_state = D * (new_state - contract_fn(P, R, new_state))
return new_state
def _setup_state(self):
    """Build the dense discretized transition matrix dA and input matrix dB
    by driving the O(N) linear stepper with basis inputs."""
    self._setup_linear()
    # View of C used only to obtain the matching complex dtype/device.
    C_view = _r2c(self.C)

    # Stepping the identity through the zero-input update recovers dA,
    # one column per basis vector.
    basis = torch.eye(
        2 * self.N, dtype=C_view.dtype, device=C_view.device
    ).unsqueeze(-2)  # (N 1 N)
    dA = rearrange(self._step_state_linear(state=basis), "n h m -> h m n")

    # Stepping a unit input from the zero state recovers dB.
    unit_input = C_view.new_ones(self.H)
    dB = _conj(self._step_state_linear(u=unit_input))
    dB = rearrange(dB, "1 h n -> h n")  # (H N)
    return dA, dB
def _step_state(self, u, state):
    """Dense recurrent update x' = dA @ x + dB * u.

    self.default_state() must have been called first, since that is where
    the cached contraction expressions are constructed.
    """
    propagated = self.state_contraction(self.dA, state)
    driven = self.input_contraction(self.dB, u)
    return propagated + driven
def _setup_step(self, mode="dense"):
    """Set up dA, dB, dC discretized parameters for stepping.

    mode: one of {"dense", "linear", "diagonal"} — controls how subsequent
    step() calls advance the recurrence.
    """
    self.dA, self.dB = self._setup_state()

    # Calculate original C
    C = _conj(_r2c(self.C))  # (H C N)
    if self.L.item() == 0:
        # L == 0 appears to encode "no fixed kernel length": C is used as-is.
        dC = C
    else:
        # self.C represents C_tilde; unfold the (I - dA^L) factor so that dC
        # is the true output matrix for recurrent stepping.
        dA_L = power(self.L.item(), self.dA)
        I = torch.eye(self.dA.size(-1)).to(dA_L)
        dC = torch.linalg.solve(
            I - dA_L.transpose(-1, -2),
            C.unsqueeze(-1),
        ).squeeze(-1)
    self.dC = dC

    # Do special preprocessing for different step modes
    self._step_mode = mode
    if mode == "linear":
        # Linear case: special step function for the state, we need to handle output
        # use conjugate symmetry by default, which affects the output projection
        self.dC = 2 * self.dC[:, :, : self.N]
    elif mode == "diagonal":
        # Eigendecomposition of the A matrix
        L, V = torch.linalg.eig(self.dA)
        V_inv = torch.linalg.inv(V)
        # Check that the eigendecomposition is correct
        if self.verbose:
            print(
                "Diagonalization error:",
                torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA),
            )

        # Change the parameterization to diagonalize
        self.dA = L
        self.dB = contract("h n m, h m -> h n", V_inv, self.dB)
        self.dC = contract("h n m, c h n -> c h m", V, self.dC)
    elif mode == "dense":
        pass
    else:
        raise NotImplementedError(
            "NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}"
        )
def default_state(self, *batch_shape):
    """Create a zero initial state of shape (*batch_shape, H, N) and cache
    the einsum contraction expressions used by the step functions.

    The cached expressions depend on the batch shape, which is why they are
    built here rather than in _setup_step().
    """
    C = _r2c(self.C)
    N = C.size(-1)
    H = C.size(-2)

    # Cache the tensor contractions we will later do, for efficiency
    # These are put in this function because they depend on the batch size
    step_mode = getattr(
        self, "_step_mode", "dense"
    )  # Used in default_state, which is called without _setup_step() in forward_state()
    if step_mode != "linear":
        # Non-linear modes materialize both halves of the conjugate pairs.
        N *= 2

        if step_mode == "diagonal":
            self.state_contraction = contract_expression(
                "h n, ... h n -> ... h n",
                (H, N),
                batch_shape + (H, N),
            )
        else:
            # Dense (quadratic) case: expand all terms
            self.state_contraction = contract_expression(
                "h m n, ... h n -> ... h m",
                (H, N, N),
                batch_shape + (H, N),
            )

        self.input_contraction = contract_expression(
            "h n, ... h -> ... h n",
            (H, N),  # self.dB.shape
            batch_shape + (H,),
        )

    self.output_contraction = contract_expression(
        "c h n, ... h n -> ... c h",
        (C.shape[0], H, N),  # self.dC.shape
        batch_shape + (H, N),
    )

    state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device)
    return state
def step(self, u, state):
    """Advance the recurrence by one time step and emit the output.

    self._setup_step() must have been called, and `state` must come from
    self.default_state(). Returns (output.real, next_state).
    """
    if self._step_mode == "linear":
        stepper = self._step_state_linear
    else:
        stepper = self._step_state
    advanced = stepper(u, state)
    out = self.output_contraction(self.dC, advanced)
    return out.real, advanced
class SSKernelDiag(OptimModule):
    """Version using (complex) diagonal state matrix (S4D)."""

    def __init__(
        self,
        A,
        B,
        C,
        log_dt,
        L=None,
        disc="bilinear",
        real_type="exp",
        lr=None,
        bandlimit=None,
    ):
        # A: (n_ssm, N) complex diagonal state matrix
        # B: (n_ssm, N) complex input matrix
        # C: (channels, H, N) complex output matrix
        # log_dt: (H,) log step sizes
        # disc: discretization, one of {"bilinear", "zoh", "dss"}
        # real_type: parameterization of the real part of A (see _A_init)
        # lr: optional learning-rate override (float or per-parameter dict)
        # bandlimit: optional frequency cutoff used to mask C in forward()
        super().__init__()
        self.L = L
        self.disc = disc
        self.bandlimit = bandlimit
        self.real_type = real_type

        assert A.size(-1) == C.size(-1)
        self.H = log_dt.size(-1)
        self.N = A.size(-1)
        assert A.size(-2) == B.size(-2)  # Number of independent SSMs trained
        assert self.H % A.size(-2) == 0
        self.n_ssm = A.size(-2)
        # Each of the n_ssm (A, B) pairs is tiled `repeat` times to cover H.
        self.repeat = self.H // A.size(0)

        self.channels = C.shape[0]
        self.C = nn.Parameter(_c2r(_resolve_conj(C)))

        # Register parameters, with optional per-parameter learning rates.
        if lr is None or isinstance(lr, float):
            lr_dict = {}
        else:
            lr_dict, lr = lr, None

        self.register("log_dt", log_dt, lr_dict.get("dt", lr))
        self.register("B", _c2r(B), lr_dict.get("B", lr))
        self.register("inv_A_real", self._A_init(A.real), lr_dict.get("A", lr))
        self.register("A_imag", A.imag, lr_dict.get("A", lr))

    def _A_init(self, A_real):
        """Map the real part of A into the internal parameter space
        selected by self.real_type (inverse of the transform in _A)."""
        # Clamp away from 0 so transforms like log/logit stay finite.
        A_real = torch.clamp(A_real, max=-1e-4)
        if self.real_type == "none":
            return -A_real
        elif self.real_type == "exp":
            return torch.log(
                -A_real
            )  # Some of the HiPPO methods have real part 0
        elif self.real_type == "relu":
            return -A_real
        elif self.real_type == "sigmoid":
            return torch.logit(-A_real)
        elif self.real_type == "softplus":
            return torch.log(torch.exp(-A_real) - 1)
        else:
            raise NotImplementedError

    def _A(self):
        """Reconstruct the complex diagonal A from its internal
        parameterization; the real part is kept negative (stable)."""
        if self.real_type == "none":
            A_real = -self.inv_A_real
        elif self.real_type == "exp":
            A_real = -torch.exp(self.inv_A_real)
        elif self.real_type == "relu":
            # JAX version seems to NaN if you allow 0's, although this code was fine without it
            A_real = -F.relu(self.inv_A_real) - 1e-4
        elif self.real_type == "sigmoid":
            A_real = -F.sigmoid(self.inv_A_real)
        elif self.real_type == "softplus":
            A_real = -F.softplus(self.inv_A_real)
        else:
            raise NotImplementedError
        A = A_real + 1j * self.A_imag
        return A

    def forward(self, L, state=None, rate=1.0, u=None):
        """
        state: (B, H, N) initial state
        rate: sampling rate factor
        L: target length

        returns:
        (C, H, L) convolution kernel (generally C=1)
        (B, H, L) output from initial state
        """
        dt = torch.exp(self.log_dt) * rate  # (H)
        C = _r2c(self.C)  # (C H N)
        A = self._A()  # (H N)

        B = _r2c(self.B)
        B = repeat(B, "t n -> 1 (v t) n", v=self.repeat)

        if self.bandlimit is not None:
            # Zero the entries of C whose discrete frequency exceeds the band limit.
            freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi)  # (H, N)
            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
            C = C * mask

        # Incorporate dt into A
        A = repeat(A, "t n -> (v t) n", v=self.repeat)
        dtA = A * dt.unsqueeze(-1)  # (H N)

        # Augment B with state (so the state response is computed alongside the kernel)
        if state is not None:
            s = state / dt.unsqueeze(-1)
            if self.disc == "bilinear":
                s = s * (1.0 + dtA / 2)
            elif self.disc == "zoh":
                s = s * dtA * dtA.exp() / (dtA.exp() - 1.0)
            B = torch.cat([s, B], dim=-3)  # (1+B H N)

        # Fold B into C so the kernel reduces to a single Vandermonde product.
        C = (B[:, None, :, :] * C).view(-1, self.H, self.N)
        if self.disc == "zoh":
            # Power up
            C = C * (torch.exp(dtA) - 1.0) / A
            K = log_vandermonde(C, dtA, L)  # (H L)
        elif self.disc == "bilinear":
            C = (
                C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
            )  # or * dtA / A
            dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
            K = log_vandermonde(C, dA.log(), L)
        elif self.disc == "dss":
            # Implementation from DSS meant for case when real eigenvalues can be positive
            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device)  # [H N L]
            A_gt_0 = A.real > 0  # [N]
            if A_gt_0.any():
                with torch.no_grad():
                    P_max = dtA * (A_gt_0 * (L - 1))  # [H N]
                # Shift exponents for numerical stability before exponentiating.
                P = P - P_max.unsqueeze(-1)  # [H N L]
            S = P.exp()  # [H N L]

            dtA_neg = dtA * (1 - 2 * A_gt_0)  # [H N]
            num = dtA_neg.exp() - 1  # [H N]
            den = (dtA_neg * L).exp() - 1  # [H N]

            # Inline reciprocal function for DSS logic (epsilon avoids division by zero)
            x = den * A
            x_conj = _resolve_conj(x)
            r = x_conj / (x * x_conj + 1e-7)

            C = C * num * r  # [C H N]
            K = contract("chn,hnl->chl", C, S).float()
        else:
            assert False, f"{self.disc} not supported"

        K = K.view(-1, self.channels, self.H, L)  # (1+B C H L)
        if state is not None:
            K_state = K[:-1, :, :, :]  # (B C H L)
        else:
            K_state = None
        K = K[-1, :, :, :]  # (C H L)
        return K, K_state

    def _setup_step(self):
        """Precompute discretized dA, dB, dC for recurrent stepping."""
        # These methods are organized like this to be compatible with the NPLR kernel interface
        dt = torch.exp(self.log_dt)  # (H)
        B = _r2c(self.B)  # (H N)
        C = _r2c(self.C)  # (C H N)
        self.dC = C
        A = self._A()  # (H N)

        A = repeat(A, "t n -> (v t) n", v=self.repeat)
        B = repeat(B, "t n -> (v t) n", v=self.repeat)

        # Incorporate dt into A
        dtA = A * dt.unsqueeze(-1)  # (H N)
        if self.disc == "zoh":
            self.dA = torch.exp(dtA)  # (H N)
            self.dB = B * (torch.exp(dtA) - 1.0) / A  # (C H N)
        elif self.disc == "bilinear":
            self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
            self.dB = (
                B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
            )  # or * dtA / A

    def default_state(self, *batch_shape):
        """Return a zero initial state of shape (*batch_shape, H, N)."""
        C = _r2c(self.C)
        state = torch.zeros(
            *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device
        )
        return state

    def step(self, u, state):
        """One recurrent step: x' = dA * x + dB * u, y = dC x'.

        The factor 2 in the return presumably compensates for storing only
        half of each conjugate pair (matches the NPLR "linear" mode) —
        verify against upstream.
        """
        next_state = contract(
            "h n, b h n -> b h n", self.dA, state
        ) + contract("h n, b h -> b h n", self.dB, u)
        y = contract("c h n, b h n -> b c h", self.dC, next_state)
        return 2 * y.real, next_state

    def forward_state(self, u, state):
        """Propagate `state` through the input chunk u (B H L) and return
        the resulting state (without computing outputs)."""
        self._setup_step()
        AL = self.dA ** u.size(-1)
        u = u.flip(-1).to(self.dA).contiguous()  # (B H L)
        v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))
        next_state = AL * state + v
        return next_state
class SSKernel(nn.Module):
    """Wrapper around SSKernel parameterizations.

    The SSKernel is expected to support the interface
    forward()
    default_state()
    _setup_step()
    step()
    """

    def __init__(
        self,
        H,
        N=64,
        L=None,
        measure="legs",
        rank=1,
        channels=1,
        dt_min=0.001,
        dt_max=0.1,
        deterministic=False,
        lr=None,
        mode="nplr",
        n_ssm=None,
        verbose=False,
        measure_args={},  # NOTE(review): mutable default; only read here (passed as **kwargs), never mutated
        **kernel_args,
    ):
        """State Space Kernel which computes the convolution kernel $\\bar{K}$

        H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config.
        N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doesn't affect speed much.
        L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known.
        measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin)
        rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt"
        channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead
        dt_min, dt_max: min and max values for the step size dt (\\Delta)
        mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing
        n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H
        lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameters (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters.
        """
        super().__init__()
        self.N = N
        self.H = H
        dtype, cdtype = torch.float, torch.cfloat
        self.channels = channels
        self.n_ssm = n_ssm if n_ssm is not None else H
        self.mode = mode
        self.verbose = verbose
        self.kernel_args = kernel_args

        # Generate dt
        if deterministic:
            # NOTE(review): this assigns exp(linspace(log dt_min, log dt_max)),
            # i.e. dt rather than log(dt), to `log_dt`, yet consumers apply
            # exp() again. Looks suspicious; verify against upstream before
            # changing (only reachable with deterministic=True).
            log_dt = torch.exp(
                torch.linspace(math.log(dt_min), math.log(dt_max), H)
            )
        else:
            # Log-uniform random step sizes in [dt_min, dt_max].
            log_dt = torch.rand(self.H, dtype=dtype) * (
                math.log(dt_max) - math.log(dt_min)
            ) + math.log(dt_min)

        # Compute the preprocessed representation (w diagonal, P low-rank, B input, V basis)
        w, P, B, V = combination(
            measure, self.N, rank, self.n_ssm, **measure_args
        )

        # Broadcast C to have H channels
        if deterministic:
            C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype)
            C[:, :, :1] = 1.0
            C = contract(
                "hmn, chn -> chm", V.conj().transpose(-1, -2), C
            )  # V^* C
            C = (
                repeat(C, "c t n -> c (v t) n", v=self.n_ssm // C.size(-2))
                .clone()
                .contiguous()
            )
        else:
            C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype)

        # Broadcast other parameters to have n_ssm copies
        assert (
            self.n_ssm % B.size(-2) == 0
            and self.n_ssm % P.size(-2) == 0
            and self.n_ssm % w.size(-2) == 0
        )

        # Broadcast tensors to n_ssm copies
        # These will be the parameters, so make sure tensors are materialized and contiguous
        B = (
            repeat(B, "t n -> (v t) n", v=self.n_ssm // B.size(-2))
            .clone()
            .contiguous()
        )
        P = (
            repeat(P, "r t n -> r (v t) n", v=self.n_ssm // P.size(-2))
            .clone()
            .contiguous()
        )
        w = (
            repeat(w, "t n -> (v t) n", v=self.n_ssm // w.size(-2))
            .clone()
            .contiguous()
        )

        if mode == "nplr":
            self.kernel = SSKernelNPLR(
                w,
                P,
                B,
                C,
                log_dt,
                L=L,
                lr=lr,
                verbose=verbose,
                **kernel_args,
            )
        elif mode == "diag":
            if not measure.startswith("diag"):
                log.warning(
                    "Diagonal kernel (S4D) activated but initialization is not intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or 'diag-legs' for the main variants, or 'diag' for a combination of S4D-Lin and S4D-Inv."
                )
            # S4D folds B into C at initialization.
            C = C * repeat(B, "t n -> (v t) n", v=H // self.n_ssm)
            self.kernel = SSKernelDiag(
                w,
                B,
                C,
                log_dt,
                L=L,
                lr=lr,
                **kernel_args,
            )
        else:
            raise NotImplementedError(f"{mode=} is not valid")

    def forward(self, state=None, L=None, rate=1.0):
        """Delegate kernel computation to the wrapped parameterization."""
        return self.kernel(state=state, L=L, rate=rate)

    @torch.no_grad()
    def forward_state(self, u, state):
        """Forward the state through a sequence, i.e. computes the state after passing chunk through SSM

        state: (B, H, N)
        u: (B, H, L)

        Returns: (B, H, N)
        """
        # Prefer the kernel's own (typically faster) implementation when available.
        if hasattr(self.kernel, "forward_state"):
            return self.kernel.forward_state(u, state)

        dA, dB = self.kernel._setup_state()  # Construct dA, dB matrices
        # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N)

        # Expand to full conjugate-pair state if the kernel stores only half.
        conj = state.size(-1) != dA.size(-1)
        if conj:
            state = _conj(state)

        v = contract(
            "h n, b h l -> b h n l", dB, u.flip(-1)
        )  # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2)
        AL, v = power(u.size(-1), dA, v)
        next_state = contract("h m n, b h n -> b h m", AL, state)
        next_state = next_state + v

        if conj:
            next_state = next_state[..., : next_state.size(-1) // 2]
        return next_state

    def _setup_step(self, **kwargs):
        # This method is intended to be private so that setting up an S4 module with
        # ```
        # if hasattr(module, 'setup_step'): module.setup_step()
        # ```
        # will not trigger this method multiple times
        self.kernel._setup_step(**kwargs)

    def step(self, u, state, **kwargs):
        """Single recurrent step; delegates to the wrapped kernel."""
        y, state = self.kernel.step(u, state, **kwargs)
        return y, state

    def default_state(self, *args, **kwargs):
        """Construct an initial state; delegates to the wrapped kernel."""
        return self.kernel.default_state(*args, **kwargs)
class S4(nn.Module):
    """S4 sequence layer: SSM convolution kernel + skip term + position-wise
    output transform, with optional GSS-style gating/bottleneck."""

    def __init__(
        self,
        d_model,
        d_state=64,
        l_max=None,
        channels=1,
        bidirectional=False,
        # Arguments for position-wise feedforward components
        activation="gelu",
        postact="glu",
        hyper_act=None,
        dropout=0.0,
        tie_dropout=False,
        bottleneck=None,
        gate=None,
        transposed=True,
        verbose=False,
        # SSM Kernel arguments
        **kernel_args,
    ):
        """
        d_state: the dimension of the state, also denoted by N
        l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel
        channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models
        bidirectional: if True, convolution kernel will be two-sided

        Position-wise feedforward components:
        --------------------
        activation: activation in between SS and FF
        postact: activation after FF
        hyper_act: use a "hypernetwork" multiplication (experimental)
        dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d

        Other arguments:
        --------------------
        transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension]
        gate: add gated activation (GSS)
        bottleneck: reduce SSM dimension (GSS)

        See the class SSKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr"

        Other options are all experimental and should not need to be configured
        """
        super().__init__()
        if verbose:
            log.info(
                f"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})"
            )
        self.d_model = d_model
        self.H = d_model
        self.N = d_state
        self.L = l_max
        self.bidirectional = bidirectional
        self.channels = channels
        self.transposed = transposed

        self.gate = gate
        self.bottleneck = bottleneck

        if bottleneck is not None:
            # GSS-style bottleneck: run the SSM at reduced width H // bottleneck.
            self.H = self.H // bottleneck
            self.input_linear = LinearActivation(
                self.d_model,
                self.H,
                transposed=self.transposed,
                activation=activation,
                activate=True,
            )

        if gate is not None:
            # GSS-style gating branch, multiplied into the output at the end of forward().
            self.input_gate = LinearActivation(
                self.d_model,
                self.d_model * gate,
                transposed=self.transposed,
                activation=activation,
                activate=True,
            )
            self.output_gate = LinearActivation(
                self.d_model * gate,
                self.d_model,
                transposed=self.transposed,
                activation=None,
                activate=False,
            )

        # optional multiplicative modulation GLU-style
        # https://arxiv.org/abs/2002.05202
        self.hyper = hyper_act is not None
        if self.hyper:
            channels *= 2
            self.hyper_activation = Activation(hyper_act)

        self.D = nn.Parameter(torch.randn(channels, self.H))

        if self.bidirectional:
            channels *= 2

        # SSM Kernel
        self.kernel = SSKernel(
            self.H,
            N=self.N,
            L=self.L,
            channels=channels,
            verbose=verbose,
            **kernel_args,
        )

        # Pointwise
        self.activation = Activation(activation)
        dropout_fn = DropoutNd if tie_dropout else nn.Dropout
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        # position-wise output transform to mix features
        self.output_linear = LinearActivation(
            self.H * self.channels,
            self.d_model * (1 if self.gate is None else self.gate),
            transposed=self.transposed,
            activation=postact,
            activate=True,
        )

    def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs):
        """
        u: (B H L) if self.transposed else (B L H)
        state: (H N) never needed unless you know what you're doing

        Returns: same shape as u
        """
        if not self.transposed:
            u = u.transpose(-1, -2)
        L = u.size(-1)

        # Mask out padding tokens
        if isinstance(lengths, int):
            if lengths != L:
                lengths = torch.tensor(
                    lengths, dtype=torch.long, device=u.device
                )
            else:
                # All positions valid: no masking needed.
                lengths = None
        if lengths is not None:
            assert (
                isinstance(lengths, torch.Tensor)
                and lengths.ndim == 1
                and lengths.size(0) in [1, u.size(0)]
            )
            mask = torch.where(
                torch.arange(L, device=lengths.device)
                < lengths[:, None, None],
                1.0,
                0.0,
            )
            u = u * mask

        if self.gate is not None:
            v = self.input_gate(u)
        if self.bottleneck is not None:
            u = self.input_linear(u)

        # Compute SS Kernel
        L_kernel = L if self.L is None else min(L, round(self.L / rate))
        k, k_state = self.kernel(
            L=L_kernel, rate=rate, state=state
        )  # (C H L) (B C H L)

        # Convolution
        if self.bidirectional:
            # Split channels into a causal and an anti-causal kernel half.
            k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2)
            k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0))

        # FFT-based (circular) convolution, zero-padded to avoid wrap-around.
        k_f = torch.fft.rfft(k, n=L_kernel + L)  # (C H L)
        u_f = torch.fft.rfft(u, n=L_kernel + L)  # (B H L)
        y_f = contract("bhl,chl->bchl", u_f, k_f)
        y = torch.fft.irfft(y_f, n=L_kernel + L)[..., :L]  # (B C H L)

        # Compute D term in state space equation - essentially a skip connection
        y = y + contract("bhl,ch->bchl", u, self.D)

        # Compute state update
        if state is not None:
            assert (
                not self.bidirectional
            ), "Bidirectional not supported with state forwarding"
            y = y + k_state
            next_state = self.kernel.forward_state(u, state)
        else:
            next_state = None

        # Optional hyper-network multiplication
        if self.hyper:
            y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2)
            y = self.hyper_activation(yh) * y

        # Reshape to flatten channels
        y = rearrange(y, "... c h l -> ... (c h) l")

        y = self.dropout(self.activation(y))

        if not self.transposed:
            y = y.transpose(-1, -2)

        y = self.output_linear(y)

        if self.gate is not None:
            y = self.output_gate(y * v)

        return y, next_state

    def setup_step(self, **kwargs):
        """Prepare the kernel for recurrent stepping (see SSKernel._setup_step)."""
        self.kernel._setup_step(**kwargs)

    def step(self, u, state):
        """Step one time step as a recurrent model. Intended to be used during validation.

        u: (B H)
        state: (B H N)
        Returns: output (B H), state (B H N)
        """
        assert not self.training
        y, next_state = self.kernel.step(u, state)  # (B C H)
        # Skip (D) term, then flatten channels into features.
        y = y + u.unsqueeze(-2) * self.D
        y = rearrange(y, "b c h -> b (c h)")
        y = self.activation(y)
        if self.transposed:
            # output_linear operates on (..., H, L); add/remove a dummy length dim.
            y = self.output_linear(y.unsqueeze(-1)).squeeze(-1)
        else:
            y = self.output_linear(y)
        return y, next_state

    def default_state(self, *batch_shape, device=None):
        # kernel is not a SequenceModule so it doesn't need to adhere to same interface
        # the kernel will know the device of its own parameters
        return self.kernel.default_state(*batch_shape)

    @property
    def d_output(self):
        """Output feature dimension (equals d_model)."""
        return self.d_model
================================================
FILE: src/uncond_ts_diff/configs.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from uncond_ts_diff.utils import linear_beta_schedule
def _s4_backbone(hidden_dim, num_residual_blocks, dropout=None):
    """Build a backbone-parameter dict for an S4 residual-block network.

    All variants below share input_dim/output_dim = 1 and step_emb = 128 and
    differ only in width, depth, and (optionally) dropout; this factory
    removes the copy-pasted literals.

    Parameters
    ----------
    hidden_dim
        Width of the hidden layers.
    num_residual_blocks
        Number of stacked S4 residual blocks.
    dropout, optional
        Dropout probability; the key is omitted entirely when None,
        matching the original configs.
    """
    config = {
        "input_dim": 1,
        "hidden_dim": hidden_dim,
        "output_dim": 1,
        "step_emb": 128,
        "num_residual_blocks": num_residual_blocks,
    }
    # Insert dropout before "residual_block" to preserve the original key order.
    if dropout is not None:
        config["dropout"] = dropout
    config["residual_block"] = "s4"
    return config


residual_block_s4_backbone = _s4_backbone(128, 6)
residual_block_s4_backbone_smallv2 = _s4_backbone(512, 3)
residual_block_s4_backbone_small = _s4_backbone(64, 3)
residual_block_s4_backbone_small_dropout01 = _s4_backbone(64, 3, dropout=0.1)
residual_block_s4_backbone_small_dropout02 = _s4_backbone(64, 3, dropout=0.2)
residual_block_s4_backbone_small_dropout03 = _s4_backbone(64, 3, dropout=0.3)
residual_block_s4_backbone_large = _s4_backbone(128, 18)
def _diffusion_config(backbone_parameters):
    """Pair a backbone config with the shared diffusion settings:
    100 timesteps and the linear beta schedule."""
    return {
        "backbone_parameters": backbone_parameters,
        "timesteps": 100,
        "diffusion_scheduler": linear_beta_schedule,
    }


diffusion_config = _diffusion_config(residual_block_s4_backbone)
diffusion_small_config = _diffusion_config(residual_block_s4_backbone_small)
diffusion_small_configv2 = _diffusion_config(residual_block_s4_backbone_smallv2)
diffusion_small_config_dropout = _diffusion_config(
    residual_block_s4_backbone_small_dropout01
)
diffusion_small_config_dropout02 = _diffusion_config(
    residual_block_s4_backbone_small_dropout02
)
diffusion_small_config_dropout03 = _diffusion_config(
    residual_block_s4_backbone_small_dropout03
)
diffusion_large_config = _diffusion_config(residual_block_s4_backbone_large)
================================================
FILE: src/uncond_ts_diff/dataset.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
import tarfile
from pathlib import Path
from urllib import request
from gluonts.dataset.common import load_datasets
from gluonts.dataset.repository.datasets import get_dataset, get_download_path
# Local cache directory where GluonTS stores downloaded datasets.
default_dataset_path: Path = get_download_path() / "datasets"
# URL of the wiki2000_nips archive, pinned to a specific gluonts commit.
wiki2k_download_link: str = "https://github.com/awslabs/gluonts/raw/b89f203595183340651411a41eeb0ee60570a4d9/datasets/wiki2000_nips.tar.gz"  # noqa: E501
def get_gts_dataset(dataset_name):
    """Load a GluonTS dataset by name.

    ``wiki2000_nips`` is not available through the standard GluonTS dataset
    repository, so on first use it is downloaded from a pinned gluonts
    commit and extracted into the default dataset cache; afterwards it is
    loaded from disk. Every other name is delegated to
    ``gluonts.dataset.repository.datasets.get_dataset``.

    Parameters
    ----------
    dataset_name
        Name of the dataset, e.g. "wiki2000_nips" or any name known to
        GluonTS.

    Returns
    -------
    The loaded GluonTS dataset.
    """
    if dataset_name != "wiki2000_nips":
        return get_dataset(dataset_name)

    wiki_dataset_path = default_dataset_path / dataset_name
    # default_dataset_path is already a Path; no re-wrapping needed.
    default_dataset_path.mkdir(parents=True, exist_ok=True)
    if not wiki_dataset_path.exists():
        tar_file_path = wiki_dataset_path.parent / f"{dataset_name}.tar.gz"
        request.urlretrieve(
            wiki2k_download_link,
            tar_file_path,
        )
        try:
            # NOTE(review): tarfile.extractall performs no path sanitization;
            # the archive comes from a pinned commit, but consider passing
            # filter="data" once Python >= 3.12 is required.
            with tarfile.open(tar_file_path) as tar:
                tar.extractall(path=wiki_dataset_path.parent)
        finally:
            # Remove the archive even if extraction fails, so a corrupt
            # download does not block a clean retry.
            os.remove(tar_file_path)
    return load_datasets(
        metadata=wiki_dataset_path / "metadata",
        train=wiki_dataset_path / "train",
        test=wiki_dataset_path / "test",
    )
================================================
FILE: src/uncond_ts_diff/metrics/__init__.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .linear_pred_score import linear_pred_score
# Public API of the metrics subpackage.
__all__ = ["linear_pred_score"]
================================================
FILE: src/uncond_ts_diff/metrics/linear_pred_score.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Tuple
from functools import partial
import numpy as np
from gluonts.evaluation import Evaluator
from gluonts.dataset.split import slice_data_entry
from gluonts.transform import AdhocTransform, Chain
from uncond_ts_diff.model import LinearEstimator
from uncond_ts_diff.utils import (
GluonTSNumpyDataset,
ScaleAndAddMeanFeature,
ScaleAndAddMinMaxFeature,
make_evaluation_predictions_with_scaling,
)
def linear_pred_score(
    samples: np.ndarray,
    context_length: int,
    prediction_length: int,
    test_dataset,
    num_samples: int = 1,
    scaling_type: str = "mean",
) -> Tuple[dict, list, list]:
    """Compute the linear predictive score.

    Uses the `samples` to fit a LinearRegression model
    and evaluate the forecast performance on the provided
    `test_dataset`.

    Parameters
    ----------
    samples
        The samples used to fit the linear regression model.
        A numpy array of shape [N, T].
        Assumed to be already scaled.
    context_length
        The context length for the linear model.
    prediction_length
        The prediction length for the linear model.
        Must be the same as the prediction length of the
        target `test_dataset`.
    test_dataset
        The test dataset on which the linear model will
        be evaluated.
    num_samples, optional
        Number of samples to draw from the linear model.
        Since the linear model is a point forecaster,
        `num_samples` > 1 would just result in the forecast
        being repeated `num_samples` times, by default 1
    scaling_type, optional
        Scaling type should be one of {"mean", "min-max"}
        Min-max scaling is used in TimeGAN, defaults to "mean"

    Returns
    -------
    Evaluation metrics, target test time series and forecasts

    Raises
    ------
    ValueError
        If `scaling_type` is not one of {"mean", "min-max"}.
    """
    min_past = context_length + prediction_length
    assert samples.shape[1] >= min_past
    dataset = GluonTSNumpyDataset(samples)

    # Validate scaling_type up front: previously an unknown value surfaced
    # as an opaque NameError after the (expensive) model training.
    if scaling_type == "mean":
        ScaleAndAddScaleFeature = ScaleAndAddMeanFeature
    elif scaling_type == "min-max":
        ScaleAndAddScaleFeature = ScaleAndAddMinMaxFeature
    else:
        raise ValueError(
            f"Unknown scaling_type: {scaling_type!r}; "
            "expected 'mean' or 'min-max'."
        )

    linear_predictor = LinearEstimator(
        freq="H",  # Not actually used in the estimator
        prediction_length=prediction_length,
        context_length=context_length,
        num_train_samples=len(dataset),
        # Since `samples` are synthetic samples, they are assumed to be already scaled
        scaling=False,
    ).train(dataset)

    # The linear predictor has been trained on scaled samples,
    # however, the test dataset is still in the original space.
    # Therefore, the test time series need to be sliced and
    # scaled before being fed into the predictor.
    # After prediction, the time series must be scaled back to
    # the original space for metric computation.
    # The following lines of code perform this custom evaluation.

    # Slice test set to be of the same length as context_length + prediction_length
    slice_func = partial(slice_data_entry, slice_=slice(-min_past, None))
    transformation = Chain(
        [
            AdhocTransform(slice_func),
            # Add scale to data entry for use later during evaluation
            ScaleAndAddScaleFeature("target", "scale", prediction_length),
        ]
    )
    sliced_test_set = transformation.apply(test_dataset)

    evaluator = Evaluator()
    forecast_it, ts_it = make_evaluation_predictions_with_scaling(
        dataset=sliced_test_set,
        predictor=linear_predictor,
        num_samples=num_samples,
        scaling_type=scaling_type,
    )
    forecasts = list(forecast_it)
    tss = list(ts_it)
    metrics, _ = evaluator(tss, forecasts)
    return metrics, tss, forecasts
================================================
FILE: src/uncond_ts_diff/model/__init__.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .diffusion.tsdiff import TSDiff
from .diffusion.tsdiff_cond import TSDiffCond
from .linear._estimator import LinearEstimator
# Public API of the model subpackage.
__all__ = [
    "TSDiff",
    "TSDiffCond",
    "LinearEstimator",
]
================================================
FILE: src/uncond_ts_diff/model/callback.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
import math
from pathlib import Path
import numpy as np
import torch
from gluonts.dataset.field_names import FieldName
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.transform import TestSplitSampler, InstanceSplitter
from pytorch_lightning import Callback
from uncond_ts_diff.sampler import DDPMGuidance, DDIMGuidance
from uncond_ts_diff.metrics import linear_pred_score
from uncond_ts_diff.utils import ConcatDataset
class GradNormCallback(Callback):
    """Lightning callback that logs the global gradient 2-norm to the
    progress bar before every optimizer step."""

    def __init__(self) -> None:
        super().__init__()

    def on_before_optimizer_step(
        self,
        trainer,
        pl_module,
        optimizer,
        opt_idx: int,
    ) -> None:
        # Log under "grad_norm" so it shows up in the progress bar.
        return pl_module.log(
            "grad_norm", self.grad_norm(pl_module.parameters()), prog_bar=True
        )

    def grad_norm(self, parameters):
        """Return the 2-norm over the concatenated gradients of `parameters`.

        Parameters without a gradient are skipped; if none of them carries a
        gradient yet (e.g. before the first backward pass), 0.0 is returned
        instead of raising an IndexError.
        """
        parameters = [p for p in parameters if p.grad is not None]
        if not parameters:
            return torch.tensor(0.0)
        device = parameters[0].grad.device
        # Norm-of-norms equals the norm of the full concatenated gradient.
        total_norm = torch.norm(
            torch.stack(
                [torch.norm(p.grad.detach(), 2).to(device) for p in parameters]
            ),
            2,
        )
        return total_norm
class PredictiveScoreCallback(Callback):
    """Lightning callback that periodically evaluates the generative model
    with a downstream linear predictive score, both train-on-synthetic and
    train-on-real, and logs the resulting ND/NRMSE metrics."""

    def __init__(
        self,
        context_length,
        prediction_length,
        model,
        transformation,
        train_dataloader,
        train_batch_size,
        test_dataset,
        eval_every=10,
    ):
        # model: the diffusion model used to draw synthetic samples
        # train_dataloader: source of real training windows
        # eval_every: run the evaluation every `eval_every` epochs
        super().__init__()
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.model = model
        self.transformation = transformation
        self.train_dataloader = train_dataloader
        self.train_batch_size = train_batch_size
        self.test_dataset = test_dataset
        self.eval_every = eval_every
        # Number of samples used to train the downstream predictor
        self.n_pred_samples = 10000

    def _generate_real_samples(
        self,
        data_loader,
        num_samples: int,
        n_timesteps: int,
        batch_size: int,
        cache_path: Path,
    ):
        """Collect `num_samples` real windows of length `n_timesteps` from
        `data_loader`, caching the result as a .npy file at `cache_path`."""
        if cache_path.exists():
            real_samples = np.load(cache_path)
            # Only reuse the cache when it matches the requested size.
            if len(real_samples) == num_samples:
                return real_samples
        real_samples = []
        data_iter = iter(data_loader)
        n_iters = math.ceil(num_samples / batch_size)
        for i in range(n_iters):
            try:
                batch = next(data_iter)
            except StopIteration:
                # Loader exhausted: restart from the beginning.
                data_iter = iter(data_loader)
                batch = next(data_iter)
            # Concatenate context and horizon, keep the last n_timesteps.
            ts = np.concatenate(
                [batch["past_target"], batch["future_target"]], axis=-1
            )[:, -n_timesteps:]
            real_samples.append(ts)
        real_samples = np.concatenate(real_samples, axis=0)[:num_samples]
        np.save(cache_path, real_samples)
        return real_samples

    def _generate_synth_samples(
        self, model, num_samples: int, batch_size: int = 1000
    ):
        """Draw `num_samples` synthetic series from `model.sample_n`,
        generated in batches of `batch_size`."""
        synth_samples = []
        n_iters = math.ceil(num_samples / batch_size)
        for _ in range(n_iters):
            samples = model.sample_n(num_samples=batch_size)
            synth_samples.append(samples)
        synth_samples = np.concatenate(synth_samples, axis=0)[:num_samples]
        return synth_samples

    def on_train_epoch_end(self, trainer, pl_module):
        """Every `eval_every` epochs, compute linear predictive scores on
        synthetic and real samples and log them via the Lightning module."""
        if (pl_module.current_epoch + 1) % self.eval_every == 0:
            device = next(pl_module.backbone.parameters()).device
            pl_module.eval()
            assert pl_module.training is False
            real_samples = self._generate_real_samples(
                self.train_dataloader,
                self.n_pred_samples,
                self.context_length + self.prediction_length,
                self.train_batch_size,
                cache_path=Path(trainer.logger.log_dir) / "real_samples.npy",
            )
            synth_samples = self._generate_synth_samples(
                self.model,
                self.n_pred_samples,
            )
            # Train using synthetic samples, test on test set
            synth_metrics, _, _ = linear_pred_score(
                synth_samples,
                self.context_length,
                self.prediction_length,
                self.test_dataset,
                scaling_type="mean",
            )
            # Train using real samples, test on test set.
            # Real samples are scaled first so both predictors are trained
            # on data in the same (scaled) space.
            scaled_real_samples, _ = self.model.scaler(
                torch.from_numpy(real_samples).to(device),
                torch.from_numpy(np.ones_like(real_samples)).to(device),
            )
            real_metrics, _, _ = linear_pred_score(
                scaled_real_samples.cpu().numpy(),
                self.context_length,
                self.prediction_length,
                self.test_dataset,
                scaling_type="mean",
            )
            pl_module.log_dict(
                {
                    "synth_linear_ND": synth_metrics["ND"],
                    "synth_linear_NRMSE": synth_metrics["NRMSE"],
                    "real_linear_ND": real_metrics["ND"],
                    "real_linear_NRMSE": real_metrics["NRMSE"],
                }
            )
            pl_module.train()
class EvaluateCallback(Callback):
    """Periodically evaluates forecasting metrics on a validation set.

    Every ``eval_every`` epochs the callback runs guidance-based forecasting
    (DDPM or DDIM) with the current backbone weights and, additionally, with
    each EMA copy of the weights, logging CRPS/ND/NRMSE for each variant and
    checkpointing whenever the CRPS improves.
    """

    def __init__(
        self,
        context_length,
        prediction_length,
        sampler,
        sampler_kwargs,
        num_samples,
        model,
        transformation,
        test_dataset,
        val_dataset,
        eval_every=50,
    ):
        super().__init__()
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.sampler = sampler
        self.num_samples = num_samples
        self.sampler_kwargs = sampler_kwargs
        self.model = model
        self.transformation = transformation
        self.test_dataset = test_dataset
        self.val_data = val_dataset
        # Backbone weights are stashed here while EMA weights are evaluated.
        self.original_state_dict = {}
        self.eval_every = eval_every
        # Metrics forwarded to the logger (suffixed with the EMA-rate label).
        self.log_metrics = {
            "CRPS",
            "ND",
            "NRMSE",
        }
        if sampler == "ddpm":
            self.Guidance = DDPMGuidance
        elif sampler == "ddim":
            self.Guidance = DDIMGuidance
        else:
            raise ValueError(f"Unknown sampler type: {sampler}")

    def on_train_epoch_end(self, trainer, pl_module):
        """Run guided forecasting on the validation data every ``eval_every`` epochs."""
        if (pl_module.current_epoch + 1) % self.eval_every == 0:
            device = next(pl_module.backbone.parameters()).device
            # Keep a copy of the live weights: they are overwritten below
            # when EMA weight sets are loaded, and restored at the end.
            self.original_state_dict = deepcopy(
                pl_module.backbone.state_dict()
            )
            pl_module.eval()
            assert pl_module.training is False
            # First iteration evaluates the raw weights (empty label ""),
            # subsequent iterations evaluate each EMA copy (label = rate).
            for label, state_dict in zip(
                [""] + [str(rate) for rate in pl_module.ema_rate],
                [pl_module.backbone.state_dict()] + pl_module.ema_state_dicts,
            ):
                pl_module.backbone.load_state_dict(state_dict, strict=True)
                pl_module.to(device)
                prediction_splitter = InstanceSplitter(
                    target_field=FieldName.TARGET,
                    is_pad_field=FieldName.IS_PAD,
                    start_field=FieldName.START,
                    forecast_start_field=FieldName.FORECAST_START,
                    instance_sampler=TestSplitSampler(),
                    past_length=self.context_length + max(self.model.lags_seq),
                    future_length=self.prediction_length,
                    time_series_fields=[
                        FieldName.FEAT_TIME,
                        FieldName.OBSERVED_VALUES,
                    ],
                )
                og = self.Guidance(
                    self.model,
                    self.prediction_length,
                    num_samples=self.num_samples,
                    **self.sampler_kwargs,
                )
                # Batch size is shrunk so batch * num_samples stays ~1024.
                predictor_pytorch = og.get_predictor(
                    prediction_splitter,
                    batch_size=1024 // self.num_samples,
                    device=device,
                )
                evaluator = Evaluator()
                transformed_valdata = self.transformation.apply(
                    ConcatDataset(self.val_data), is_train=False
                )
                forecast_it, ts_it = make_evaluation_predictions(
                    dataset=transformed_valdata,
                    predictor=predictor_pytorch,
                    num_samples=self.num_samples,
                )
                forecasts_pytorch = list(forecast_it)
                tss_pytorch = list(ts_it)
                # per_ts (per-series metrics) is intentionally unused here.
                metrics_pytorch, per_ts = evaluator(
                    tss_pytorch, forecasts_pytorch
                )
                metrics_pytorch["CRPS"] = metrics_pytorch["mean_wQuantileLoss"]
                # Checkpoint whenever any weight variant improves CRPS.
                if metrics_pytorch["CRPS"] < pl_module.best_crps:
                    pl_module.best_crps = metrics_pytorch["CRPS"]
                    ckpt_path = (
                        Path(trainer.logger.log_dir) / "best_checkpoint.ckpt"
                    )
                    torch.save(
                        pl_module.state_dict(),
                        ckpt_path,
                    )
                pl_module.log_dict(
                    {
                        f"val_{metric}{label}": metrics_pytorch[metric]
                        for metric in self.log_metrics
                    }
                )
            # Restore the live (non-EMA) weights and training mode.
            pl_module.backbone.load_state_dict(
                self.original_state_dict, strict=True
            )
            pl_module.train()
================================================
FILE: src/uncond_ts_diff/model/diffusion/_base.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.optim.lr_scheduler import ReduceLROnPlateau
from gluonts.time_feature import time_features_from_frequency_str
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
from uncond_ts_diff.utils import extract
# Batch fields the prediction network consumes, in positional order;
# used by GluonTS predictors to select inputs from each data batch.
PREDICTION_INPUT_NAMES = [
    "past_target",
    "past_observed_values",
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "future_time_feat",
]
class TSDiffBase(pl.LightningModule):
    """Base LightningModule for TSDiff diffusion models.

    Precomputes the standard DDPM noise-schedule constants from the betas
    produced by ``diffusion_scheduler`` and provides the forward-noising
    (``q_sample``), training-loss (``p_losses``) and reverse-sampling
    (``p_sample`` / ``p_sample_ddim`` / ``p_sample_genddim``) primitives,
    plus the Lightning training/validation loops.

    Subclasses must set ``self.backbone`` (the denoising network) and
    implement ``_extract_features``.
    """

    def __init__(
        self,
        backbone_parameters,
        timesteps,
        diffusion_scheduler,
        context_length,
        prediction_length,
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinalities=None,
        freq=None,
        normalization="none",
        use_features=False,
        use_lags=True,
        lr: float = 1e-3,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.timesteps = timesteps
        # Noise schedule and the standard set of derived DDPM quantities.
        # NOTE(review): these are plain attributes, not registered buffers,
        # so they stay on CPU when the module moves to GPU — `extract` is
        # assumed to handle the indexing device; confirm.
        self.betas = diffusion_scheduler(timesteps)
        self.sqrt_one_minus_beta = torch.sqrt(1.0 - self.betas)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        # alpha-bar_{t-1}, with alpha-bar at t=-1 defined as 1.
        self.alphas_cumprod_prev = F.pad(
            self.alphas_cumprod[:-1], (1, 0), value=1.0
        )
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(
            1.0 - self.alphas_cumprod
        )
        # Variance of the posterior q(x_{t-1} | x_t, x_0).
        self.posterior_variance = (
            self.betas
            * (1.0 - self.alphas_cumprod_prev)
            / (1.0 - self.alphas_cumprod)
        )
        self.logs = {}
        self.normalization = normalization
        if normalization == "mean":
            self.scaler = MeanScaler(dim=1, keepdim=True)
        else:
            self.scaler = NOPScaler(dim=1, keepdim=True)
        if cardinalities is None:
            cardinalities = [1]
        self.embedder = FeatureEmbedder(
            cardinalities=cardinalities,
            embedding_dims=[min(50, (cat + 1) // 2) for cat in cardinalities],
        )
        self.time_features = (
            time_features_from_frequency_str(freq) if freq is not None else []
        )
        # The +1 accounts for the log-scale feature appended by subclasses.
        self.num_feat_dynamic_real = (
            1 + num_feat_dynamic_real + len(self.time_features)
        )
        self.num_feat_static_cat = max(num_feat_static_cat, 1)
        self.num_feat_static_real = max(num_feat_static_real, 1)
        self.use_features = use_features
        self.use_lags = use_lags
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.losses_running_mean = torch.ones(timesteps, requires_grad=False)
        self.lr = lr
        # Updated by evaluation callbacks when a better CRPS is found.
        self.best_crps = np.inf

    def _extract_features(self, data):
        """Map a data batch to (x, scale, features); implemented by subclasses."""
        raise NotImplementedError()

    def configure_optimizers(self):
        """Adam optimizer with an (effectively disabled) plateau scheduler.

        The huge ``patience`` means the learning rate is never actually
        reduced; the scheduler exists so "train_loss" is monitored.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=int(1e12)
        )
        return [optimizer], {"scheduler": scheduler, "monitor": "train_loss"}

    def log(self, name, value, **kwargs):
        """Extend Lightning's ``log`` to also keep an in-memory history."""
        super().log(name, value, **kwargs)
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().item()
        if name not in self.logs:
            self.logs[name] = [value]
        else:
            self.logs[name].append(value)

    def get_logs(self):
        """Return the accumulated metric history as a DataFrame.

        NOTE(review): assumes every logged metric has exactly one entry per
        epoch; mismatched lengths would make DataFrame construction fail.
        """
        logs = self.logs
        logs["epochs"] = list(range(self.current_epoch))
        return pd.DataFrame.from_dict(logs)

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion: noise ``x_start`` to timestep ``t`` in closed form."""
        device = next(self.backbone.parameters()).device
        if noise is None:
            noise = torch.randn_like(x_start, device=device)
        sqrt_alphas_cumprod_t = extract(
            self.sqrt_alphas_cumprod, t, x_start.shape
        )
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return (
            sqrt_alphas_cumprod_t * x_start
            + sqrt_one_minus_alphas_cumprod_t * noise
        )

    def p_losses(
        self,
        x_start,
        t,
        features=None,
        noise=None,
        loss_type="l2",
        reduction="mean",
    ):
        """Epsilon-prediction loss at timestep ``t``.

        Returns (loss, noised input, predicted noise). ``loss_type`` selects
        l1 / l2 / huber; ``reduction`` is forwarded to the torch loss.
        """
        device = next(self.backbone.parameters()).device
        if noise is None:
            noise = torch.randn_like(x_start, device=device)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = self.backbone(x_noisy, t, features)
        if loss_type == "l1":
            loss = F.l1_loss(noise, predicted_noise, reduction=reduction)
        elif loss_type == "l2":
            loss = F.mse_loss(noise, predicted_noise, reduction=reduction)
        elif loss_type == "huber":
            loss = F.smooth_l1_loss(
                noise, predicted_noise, reduction=reduction
            )
        else:
            raise NotImplementedError()
        return loss, x_noisy, predicted_noise

    @torch.no_grad()
    def p_sample(self, x, t, t_index, features=None):
        """One ancestral DDPM reverse step from ``x_t`` to ``x_{t-1}``."""
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)
        predicted_noise = self.backbone(x, t, features)
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
        )
        if t_index == 0:
            # Final step: no noise is added.
            return model_mean
        else:
            posterior_variance_t = extract(self.posterior_variance, t, x.shape)
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_ddim(self, x, t, features=None, noise=None):
        """One deterministic DDIM reverse step (eta = 0)."""
        if noise is None:
            noise = self.backbone(x, t, features)
        sqrt_alphas_cumprod_prev_t = extract(
            self.alphas_cumprod_prev, t, x.shape
        ).sqrt()
        sqrt_one_minus_alphas_cumprod_prev_t = extract(
            1 - self.alphas_cumprod_prev, t, x.shape
        ).sqrt()
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x.shape)
        # Predicted x_0 rescaled to the previous timestep ...
        x0pointer = (
            sqrt_alphas_cumprod_prev_t
            * (x - sqrt_one_minus_alphas_cumprod_t * noise)
            / sqrt_alphas_cumprod_t
        )
        # ... plus the direction pointing back toward x_t.
        xtpointer = sqrt_one_minus_alphas_cumprod_prev_t * noise
        return x0pointer + xtpointer

    @torch.no_grad()
    def p_sample_genddim(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        t_index: int,
        t_prev: Optional[torch.Tensor] = None,
        eta: float = 0.0,
        features=None,
        noise: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Generalized DDIM step that interpolates between
        DDPM (eta=1) and DDIM (eta=0).

        Args:
            x: Current noisy sample ``x_t``.
            t: Current timestep indices, one per batch item.
            t_index: Scalar step counter; 0 means this is the last step.
            t_prev: Previous timestep indices; defaults to ``t - 1``.
            eta: Stochasticity knob in [0, 1].
            features: Optional conditioning features for the backbone.
            noise: Optional precomputed noise prediction; if None, the
                backbone is queried.

        Returns:
            The sample at the previous timestep.
        """
        if noise is None:
            noise = self.backbone(x, t, features)
        if t_prev is None:
            t_prev = t - 1
        alphas_cumprod_t = extract(self.alphas_cumprod, t, x.shape)
        alphas_cumprod_prev_t = (
            extract(self.alphas_cumprod, t_prev, x.shape)
            if t_index > 0
            else torch.ones_like(alphas_cumprod_t)
        )
        sqrt_alphas_cumprod_prev_t = alphas_cumprod_prev_t.sqrt()
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x.shape)
        x0pointer = (
            sqrt_alphas_cumprod_prev_t
            * (x - sqrt_one_minus_alphas_cumprod_t * noise)
            / sqrt_alphas_cumprod_t
        )
        # c1 scales fresh noise (eta=0 disables it); c2 scales the predicted
        # noise so the total variance matches the schedule.
        c1 = (
            eta
            * (
                (1 - alphas_cumprod_t / alphas_cumprod_prev_t)
                * (1 - alphas_cumprod_prev_t)
                / (1 - alphas_cumprod_t)
            ).sqrt()
        )
        c2 = ((1 - alphas_cumprod_prev_t) - c1**2).sqrt()
        return x0pointer + c1 * torch.randn_like(x) + c2 * noise

    @torch.no_grad()
    def sample(self, noise, features=None):
        """Full ancestral sampling chain starting from ``noise``.

        Returns the whole trajectory stacked along axis 0 (initial noise
        first, denoised sample last).
        """
        device = next(self.backbone.parameters()).device
        batch_size, length, ch = noise.shape
        seq = noise
        seqs = [seq.cpu()]
        for i in reversed(range(0, self.timesteps)):
            seq = self.p_sample(
                seq,
                torch.full((batch_size,), i, device=device, dtype=torch.long),
                i,
                features,
            )
            seqs.append(seq.cpu().numpy())
        return np.stack(seqs, axis=0)

    def fast_denoise(self, xt, t, features=None, noise=None):
        """Single-shot estimate of x_0 from ``x_t`` using the noise prediction."""
        if noise is None:
            noise = self.backbone(xt, t, features)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, xt.shape
        )
        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, xt.shape)
        return (
            xt - sqrt_one_minus_alphas_cumprod_t * noise
        ) / sqrt_alphas_cumprod_t

    def forward(self, x, mask):
        raise NotImplementedError()

    def training_step(self, data, idx):
        """One training step: random timestep per item, epsilon-MSE loss."""
        assert self.training is True
        device = next(self.backbone.parameters()).device
        if isinstance(data, dict):
            x, _, features = self._extract_features(data)
        else:
            x, _ = self.scaler(data, torch.ones_like(data))
            # BUGFIX: `features` was previously left undefined on this
            # branch, raising NameError at the p_losses call below. Plain
            # tensor batches carry no conditioning features, mirroring
            # `validation_step`.
            features = None
        t = torch.randint(
            0, self.timesteps, (x.shape[0],), device=device
        ).long()
        elbo_loss, xt, noise = self.p_losses(x, t, features, loss_type="l2")
        return {
            "loss": elbo_loss,
            "elbo_loss": elbo_loss,
        }

    def training_epoch_end(self, outputs):
        """Aggregate and log epoch-mean losses (pre-2.0 Lightning hook)."""
        epoch_loss = sum(x["loss"] for x in outputs) / len(outputs)
        elbo_loss = sum(x["elbo_loss"] for x in outputs) / len(outputs)
        self.log("train_loss", epoch_loss)
        self.log("train_elbo_loss", elbo_loss)

    def validation_step(self, data, idx):
        """Validation step mirroring `training_step` (no grad bookkeeping)."""
        device = next(self.backbone.parameters()).device
        if isinstance(data, dict):
            x, _, features = self._extract_features(data)
        else:
            x, features = data, None
        t = torch.randint(
            0, self.timesteps, (x.shape[0],), device=device
        ).long()
        elbo_loss, xt, noise = self.p_losses(x, t, features, loss_type="l2")
        return {
            "loss": elbo_loss,
            "elbo_loss": elbo_loss,
        }

    def validation_epoch_end(self, outputs):
        """Aggregate and log epoch-mean validation losses (pre-2.0 hook)."""
        epoch_loss = sum(x["loss"] for x in outputs) / len(outputs)
        elbo_loss = sum(x["elbo_loss"] for x in outputs) / len(outputs)
        self.log("valid_loss", epoch_loss)
        self.log("valid_elbo_loss", elbo_loss)
================================================
FILE: src/uncond_ts_diff/model/diffusion/tsdiff.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import copy
import torch
from gluonts.torch.util import lagged_sequence_values
from uncond_ts_diff.arch import BackboneModel
from uncond_ts_diff.model.diffusion._base import TSDiffBase
from uncond_ts_diff.utils import get_lags_for_freq
class TSDiff(TSDiffBase):
    """Unconditional TSDiff model: diffusion over whole windows of length
    ``context_length + prediction_length``, optionally with lagged values as
    extra input channels and static/time features as conditioning.
    """

    def __init__(
        self,
        backbone_parameters,
        timesteps,
        diffusion_scheduler,
        context_length,
        prediction_length,
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinalities=None,
        freq=None,
        normalization="none",
        use_features=False,
        use_lags=True,
        init_skip=True,
        lr=1e-3,
    ):
        super().__init__(
            backbone_parameters,
            timesteps=timesteps,
            diffusion_scheduler=diffusion_scheduler,
            context_length=context_length,
            prediction_length=prediction_length,
            num_feat_dynamic_real=num_feat_dynamic_real,
            num_feat_static_cat=num_feat_static_cat,
            num_feat_static_real=num_feat_static_real,
            cardinalities=cardinalities,
            freq=freq,
            normalization=normalization,
            use_features=use_features,
            use_lags=use_lags,
            lr=lr,
        )
        self.freq = freq
        if use_lags:
            # Each lag becomes an extra input/output channel of the backbone.
            self.lags_seq = get_lags_for_freq(freq)
            backbone_parameters = backbone_parameters.copy()
            backbone_parameters["input_dim"] += len(self.lags_seq)
            backbone_parameters["output_dim"] += len(self.lags_seq)
        else:
            self.lags_seq = [0]
        self.input_dim = backbone_parameters["input_dim"]
        self.backbone = BackboneModel(
            **backbone_parameters,
            num_features=(
                self.num_feat_static_real
                + self.num_feat_static_cat
                + self.num_feat_dynamic_real
                + 1  # log_scale
            ),
            init_skip=init_skip,
        )
        # EMA tracking of backbone weights; empty list disables it.
        self.ema_rate = []  # [0.9999]
        self.ema_state_dicts = [
            copy.deepcopy(self.backbone.state_dict())
            for _ in range(len(self.ema_rate))
        ]

    def _extract_features(self, data):
        """Build the diffusion input from a GluonTS-style batch dict.

        Returns (x, scale, features) where x is the scaled context+future
        window (with lag channels if enabled), scale is the per-item scale,
        and features are the conditioning features (or None when disabled).
        """
        # past_target = [prior | context]; prior feeds the lag computation.
        prior = data["past_target"][:, : -self.context_length]
        context = data["past_target"][:, -self.context_length :]
        context_observed = data["past_observed_values"][
            :, -self.context_length :
        ]
        if self.normalization == "zscore":
            scaled_context, scale = self.scaler(
                context, context_observed, data["stats"]
            )
        else:
            scaled_context, scale = self.scaler(context, context_observed)
        features = []
        # Prior and future are scaled with the context-derived scale.
        scaled_prior = prior / scale
        scaled_future = data["future_target"] / scale
        features.append(scale.log())
        x = torch.cat([scaled_context, scaled_future], dim=1)
        if data["feat_static_cat"] is not None:
            features.append(self.embedder(data["feat_static_cat"]))
        if data["feat_static_real"] is not None:
            features.append(data["feat_static_real"])
        static_feat = torch.cat(
            features,
            dim=1,
        )
        # Broadcast static features over the time axis.
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, x.shape[1], -1
        )
        features = [expanded_static_feat]
        time_features = []
        if data["past_time_feat"] is not None:
            time_features.append(
                data["past_time_feat"][:, -self.context_length :]
            )
        if data["future_time_feat"] is not None:
            time_features.append(data["future_time_feat"])
        features.append(torch.cat(time_features, dim=1))
        features = torch.cat(features, dim=-1)
        if self.use_lags:
            # Lagged copies of the (scaled) series become extra channels.
            lags = lagged_sequence_values(
                self.lags_seq,
                scaled_prior,
                torch.cat([scaled_context, scaled_future], dim=1),
                dim=1,
            )
            x = torch.cat([x[:, :, None], lags], dim=-1)
        else:
            x = x[:, :, None]
        if not self.use_features:
            features = None
        return x, scale[:, :, None], features

    @torch.no_grad()
    def sample_n(
        self,
        num_samples: int = 1,
        return_lags: bool = False,
    ):
        """Generate ``num_samples`` unconditional windows via ancestral sampling.

        Returns an array of shape (num_samples, seq_len) — or with the lag
        channels kept as a trailing dimension if ``return_lags`` is True.
        """
        device = next(self.backbone.parameters()).device
        seq_len = self.context_length + self.prediction_length
        samples = torch.randn(
            (num_samples, seq_len, self.input_dim), device=device
        )
        for i in reversed(range(0, self.timesteps)):
            t = torch.full((num_samples,), i, device=device, dtype=torch.long)
            samples = self.p_sample(samples, t, i, features=None)
        samples = samples.cpu().numpy()
        if return_lags:
            return samples
        # Channel 0 holds the actual (non-lagged) series.
        return samples[..., 0]

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Keep each EMA copy of the backbone weights up to date.
        for rate, state_dict in zip(self.ema_rate, self.ema_state_dicts):
            update_ema(state_dict, self.backbone.state_dict(), rate=rate)
def update_ema(target_state_dict, source_state_dict, rate=0.99):
    """Blend an exponential moving average of model weights in place.

    Each tensor in ``target_state_dict`` is updated as
    ``ema <- rate * ema + (1 - rate) * value``; the source value is moved to
    CPU first, so the EMA copy always lives on CPU.
    """
    with torch.no_grad():
        for name, new_value in source_state_dict.items():
            ema_value = target_state_dict[name]
            blended = rate * ema_value + (1.0 - rate) * new_value.cpu()
            ema_value.copy_(blended, non_blocking=True)
================================================
FILE: src/uncond_ts_diff/model/diffusion/tsdiff_cond.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.util import lagged_sequence_values
from uncond_ts_diff.arch import BackboneModel
from uncond_ts_diff.model.diffusion._base import TSDiffBase
from uncond_ts_diff.model.diffusion._base import PREDICTION_INPUT_NAMES
from uncond_ts_diff.utils import get_lags_for_freq
# The conditional model additionally needs the un-noised past target at
# prediction time (see TSDiffCond.forward / get_predictor).
PREDICTION_INPUT_NAMES = PREDICTION_INPUT_NAMES + ["orig_past_target"]
class TSDiffCond(TSDiffBase):
    """Conditional TSDiff variant: the backbone is conditioned on the
    observed past (and an observation mask) so the reverse process fills in
    the future portion of the window. When ``noise_observed`` is False, the
    observed region is clamped to its known values during training noise
    construction and sampling.
    """

    def __init__(
        self,
        backbone_parameters,
        timesteps,
        diffusion_scheduler,
        context_length,
        prediction_length,
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinalities=None,
        freq=None,
        normalization="none",
        use_features=False,
        use_lags=True,
        lr=1e-3,
        init_skip=True,
        noise_observed=True,
    ):
        super().__init__(
            backbone_parameters,
            timesteps=timesteps,
            diffusion_scheduler=diffusion_scheduler,
            context_length=context_length,
            prediction_length=prediction_length,
            num_feat_dynamic_real=num_feat_dynamic_real,
            num_feat_static_cat=num_feat_static_cat,
            num_feat_static_real=num_feat_static_real,
            cardinalities=cardinalities,
            freq=freq,
            normalization=normalization,
            use_features=use_features,
            use_lags=use_lags,
            lr=lr,
        )
        num_features = (
            (
                self.num_feat_dynamic_real
                + self.num_feat_static_cat
                + self.num_feat_static_real
                + 1
            )
            if use_features
            else 0
        )
        self.freq = freq
        self.lags_seq = get_lags_for_freq(freq) if use_lags else [0]
        # +2 for the two conditioning channels always appended in
        # `_extract_features`: the masked past series and the observation mask.
        self.backbone = BackboneModel(
            **backbone_parameters,
            num_features=(
                num_features + 2 + (len(self.lags_seq) if use_lags else 0)
            ),
            init_skip=init_skip,
        )
        self.noise_observed = noise_observed

    def _extract_features(self, data):
        """Build diffusion input and conditioning features from a batch dict.

        Returns (x, scale, features); the LAST feature channel is always the
        observation mask (1 = observed past, 0 = to-be-predicted), which the
        training/validation steps rely on.
        """
        device = next(self.parameters()).device
        # past_target = [prior | context]; prior feeds the lag computation.
        prior = data["past_target"][:, : -self.context_length]
        context = data["past_target"][:, -self.context_length :]
        context_observed = data["past_observed_values"][
            :, -self.context_length :
        ]
        scaled_context, scale = self.scaler(context, context_observed)
        features = []
        scaled_prior = prior / scale
        scaled_future = data["future_target"] / scale
        # The diffusion target x uses the ORIGINAL (un-imputed) past values.
        scaled_orig_context = (
            data["orig_past_target"][:, -self.context_length :]
        ) / scale
        x = torch.cat([scaled_orig_context, scaled_future], dim=1)
        # Mask is 1 where the past is observed; the future part stays 0.
        observation_mask = torch.zeros_like(x, device=device)
        observation_mask[:, : -self.prediction_length] = data[
            "past_observed_values"
        ][:, -self.context_length :].data
        # Conditioning channel: known past, zeros over the forecast horizon.
        x_past = torch.cat(
            [scaled_context, torch.zeros_like(scaled_future)], dim=1
        ).clone()
        assert x.size() == x_past.size()
        if data["feat_static_cat"] is not None:
            features.append(self.embedder(data["feat_static_cat"]))
        if data["feat_static_real"] is not None:
            features.append(data["feat_static_real"])
        static_feat = torch.cat(
            features,
            dim=1,
        )
        # Broadcast static features over the time axis.
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, x.shape[1], -1
        )
        features = []
        if self.use_features:
            features.append(expanded_static_feat)
            time_features = []
            if data["past_time_feat"] is not None:
                time_features.append(
                    data["past_time_feat"][:, -self.context_length :]
                )
            if data["future_time_feat"] is not None:
                time_features.append(data["future_time_feat"])
            features.append(torch.cat(time_features, dim=1))
        lags = lagged_sequence_values(
            self.lags_seq,
            scaled_prior,
            torch.cat([scaled_context, scaled_future], dim=1),
            dim=1,
        )
        if self.use_lags:
            features.append(lags)
        features.append(x_past[..., None])
        features.append(observation_mask[..., None])
        features = torch.cat(features, dim=-1)
        return x[..., None], scale[..., None], features

    def step(self, x, t, features, loss_mask):
        """Compute the (possibly masked) epsilon loss for timestep batch ``t``.

        When ``noise_observed`` is False, observed positions keep their clean
        values in the noise target and the loss is averaged only over the
        masked (unobserved) positions.
        """
        noise = torch.randn_like(x)
        if not self.noise_observed:
            noise = (1 - loss_mask) * x + noise * loss_mask
            num_eval = loss_mask.sum()
        sq_err, _, _ = self.p_losses(
            x,
            t,
            features,
            loss_type="l2",
            reduction="none",
        noise=noise,
        )
        if self.noise_observed:
            elbo_loss = sq_err.mean()
        else:
            sq_err = sq_err * loss_mask
            # Guard against division by zero when nothing is masked.
            elbo_loss = sq_err.sum() / (num_eval if num_eval else 1)
        return elbo_loss

    def training_step(self, data, idx):
        """One training step with a random diffusion timestep per item."""
        assert self.training is True
        device = next(self.parameters()).device
        x, _, features = self._extract_features(data)
        # Last dim of features has the observation mask
        observation_mask = features[..., -1:]
        loss_mask = 1 - observation_mask
        t = torch.randint(
            0, self.timesteps, (x.shape[0],), device=device
        ).long()
        elbo_loss = self.step(x, t, features, loss_mask)
        return {
            "loss": elbo_loss,
            "elbo_loss": elbo_loss,
        }

    def validation_step(self, data, idx):
        """Validation averages the loss over ALL diffusion timesteps."""
        device = next(self.parameters()).device
        x, _, features = self._extract_features(data)
        # Last dim of features has the observation mask
        observation_mask = features[..., -1:]
        loss_mask = 1 - observation_mask
        val_loss = 0.0
        for i in range(self.timesteps):
            t = torch.full((x.shape[0],), i, device=device).long()
            val_loss += self.step(x, t, features, loss_mask)
        val_loss /= self.timesteps
        return {
            "loss": val_loss,
            "elbo_loss": val_loss,
        }

    @torch.no_grad()
    def forecast(self, observation, observation_mask, features=None):
        """Run the reverse chain; optionally clamping observed values each step."""
        device = next(self.backbone.parameters()).device
        batch_size, length, ch = observation.shape
        seq = torch.randn_like(observation)
        for i in reversed(range(0, self.timesteps)):
            if not self.noise_observed:
                # Re-impose the known past before every denoising step.
                seq = observation_mask * observation + seq * (
                    1 - observation_mask
                )
            seq = self.p_sample(
                seq,
                torch.full((batch_size,), i, device=device, dtype=torch.long),
                i,
                features,
            )
        return seq

    def forward(
        self,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        feat_static_cat: torch.Tensor = None,
        feat_static_real: torch.Tensor = None,
        past_time_feat: torch.Tensor = None,
        future_time_feat: torch.Tensor = None,
        orig_past_target: torch.Tensor = None,
    ):
        # This is only used during prediction
        device = next(self.backbone.parameters()).device
        data = dict(
            feat_static_cat=feat_static_cat.to(device)
            if feat_static_cat is not None
            else None,
            feat_static_real=feat_static_real.to(device)
            if feat_static_real is not None
            else None,
            past_time_feat=past_time_feat.to(device)
            if past_time_feat is not None
            else None,
            past_target=past_target.to(device),
            orig_past_target=orig_past_target.to(device),
            # The unknown future is filled with zeros; it is ignored via the
            # observation mask during sampling.
            future_target=torch.zeros(
                past_target.shape[0], self.prediction_length, device=device
            ),
            past_observed_values=past_observed_values.to(device)
            if past_observed_values is not None
            else None,
            future_time_feat=future_time_feat.to(device)
            if future_time_feat is not None
            else None,
        )
        observation, scale, features = self._extract_features(data)
        observation = observation.to(device)
        batch_size, length, ch = observation.shape
        observation_mask = features[..., -1:]
        pred = self.forecast(
            observation=observation,
            observation_mask=observation_mask,
            features=features,
        )
        # Undo scaling; return only the forecast horizon, shaped
        # (batch, 1, prediction_length) for the sample-forecast generator.
        pred = pred * scale
        return pred[:, None, length - self.prediction_length :, 0]

    def get_predictor(self, input_transform, batch_size=40, device=None):
        """Wrap this module as a GluonTS ``PyTorchPredictor``."""
        return PyTorchPredictor(
            prediction_length=self.prediction_length,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=self,
            batch_size=batch_size,
            input_transform=input_transform,
            device=device,
        )
================================================
FILE: src/uncond_ts_diff/model/linear/_estimator.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, List
import math
import numpy as np
from sklearn.linear_model import LinearRegression, Ridge
from gluonts.model import Estimator, Predictor
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.transform import (
Transformation,
AddObservedValuesIndicator,
InstanceSplitter,
TestSplitSampler,
ExpectedNumInstanceSampler,
SelectFields,
)
from gluonts.dataset.loader import TrainDataLoader, InferenceDataLoader
from gluonts.itertools import Cached
from gluonts.model.forecast_generator import (
ForecastGenerator,
SampleForecastGenerator,
predict_to_numpy,
)
from ._scaler import MeanScaler, NOPScaler
PREDICTION_INPUT_NAMES = [
"past_target",
"past_observed_values",
]
TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
"future_target",
"future_observed_values",
]
def stack(data):
    """Recursively collate a list of items into array form.

    Lists of ndarrays become one stacked ndarray; lists of lists/tuples are
    collated element-wise (recursively); anything else is passed through.
    """
    first = data[0]
    if isinstance(first, np.ndarray):
        return np.array(data)
    if isinstance(first, (list, tuple)):
        return [stack(group) for group in zip(*data)]
    return data


def batchify(data: List[dict]):
    """Collate a list of sample dicts into a single dict of stacked fields."""
    keys = data[0].keys()
    return {key: stack([entry[key] for entry in data]) for key in keys}
class LinearModel:
    """Fitted linear forecaster used as a GluonTS prediction network.

    Holds frozen regression coefficients; `__call__` scales the input
    window, applies the affine map, rescales the output, and replicates the
    point forecast ``num_parallel_samples`` times along a sample axis.
    """

    def __init__(self, weight, bias, scaler, num_parallel_samples=100) -> None:
        super().__init__()
        self.scaler = scaler
        self.weight = weight
        self.bias = bias
        self.num_parallel_samples = num_parallel_samples

    def _linear(self, x, A, b):
        # Affine map: each row of x against the fitted coefficient matrix.
        return x @ A.T + b

    def __call__(self, x, mask):
        assert x.ndim == 2
        scaled_x, scale = self.scaler(x, np.ones_like(x))
        point_forecast = self._linear(scaled_x, self.weight, self.bias) * scale
        # Replicate the deterministic forecast into (N, samples, horizon).
        return np.tile(point_forecast[:, None], (1, self.num_parallel_samples, 1))
# Register LinearModel with GluonTS' `predict_to_numpy` dispatcher so the
# SampleForecastGenerator can invoke it like a network and receive ndarray
# sample paths.
@predict_to_numpy.register(LinearModel)
def _(prediction_net, args) -> np.ndarray:
    return prediction_net(*args)
class LinearPredictor(Predictor):
    """GluonTS predictor wrapping a fitted :class:`LinearModel`.

    Streams the (transformed) dataset through an ``InferenceDataLoader`` and
    delegates forecast assembly to the configured forecast generator.
    """

    def __init__(
        self,
        input_names: List[str],
        prediction_net: LinearModel,
        batch_size: int,
        prediction_length: int,
        input_transform: Transformation,
        forecast_generator: ForecastGenerator = SampleForecastGenerator(),
        lead_time: int = 0,
    ) -> None:
        super().__init__(prediction_length, lead_time=lead_time)
        self.prediction_net = prediction_net
        self.input_names = input_names
        self.input_transform = input_transform
        self.forecast_generator = forecast_generator
        self.batch_size = batch_size

    def predict(self, dataset: Dataset, num_samples: Optional[int] = None):
        """Lazily yield one forecast per dataset entry."""
        loader = InferenceDataLoader(
            dataset,
            transform=self.input_transform,
            batch_size=self.batch_size,
            stack_fn=batchify,
        )
        for forecast in self.forecast_generator(
            inference_data_loader=loader,
            prediction_net=self.prediction_net,
            input_names=self.input_names,
            output_transform=None,
            num_samples=num_samples,
        ):
            yield forecast
class LinearEstimator(Estimator):
    """A Linear regressor that takes inputs of size equal to `context_length`
    and outputs forecasts of size equal to `prediction_length`. This model uses
    LinearRegression from scikit-learn under the hood.
    Example usage:
    ```python
    estimator = LinearEstimator(
        dataset.metadata.freq,
        prediction_length=dataset.metadata.prediction_length,
        context_length=24 * 7 * 2,
    )
    predictor = estimator.train(dataset.train)
    ```
    Parameters
    ----------
    freq
        Frequency of the dataset (not actually used)
    prediction_length
        Prediction length
    context_length, optional
        Context length for the linear model,
        by default equal to 4 * prediction_length
    num_train_samples, optional
        Number of samples used to fit the LinearRegression model,
        by default 10000
    model, optional
        Which sklearn linear model to use, one of {"linear", "ridge"},
        by default "ridge".
    scaling, optional
        Whether to use scaling, by default True
    batch_size, optional
        Batch size (only relevant during prediction), by default 64
    """

    def __init__(
        self,
        freq: str,
        prediction_length: int,
        context_length: Optional[int] = None,
        num_train_samples: int = 10000,
        model: str = "ridge",
        scaling: bool = True,
        batch_size: int = 64,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert model in {"linear", "ridge"}
        self.freq = freq
        self.prediction_length = prediction_length
        self.context_length = context_length or 4 * prediction_length
        self.num_train_samples = num_train_samples
        self.model = model
        if scaling:
            self.scaler = MeanScaler(axis=-1, keepdims=True)
        else:
            self.scaler = NOPScaler(axis=-1, keepdims=True)
        self.batch_size = batch_size

    def create_transformation(self) -> Transformation:
        """Select the base fields and attach an observed-values indicator."""
        return SelectFields(
            [
                FieldName.ITEM_ID,
                FieldName.INFO,
                FieldName.START,
                FieldName.TARGET,
            ],
            allow_missing=True,
        ) + AddObservedValuesIndicator(
            target_field=FieldName.TARGET,
            output_field=FieldName.OBSERVED_VALUES,
        )

    def _create_instance_splitter(self, mode: str):
        """Build the context/future window splitter for training or testing."""
        assert mode in ["training", "test"]
        instance_sampler = {
            "training": ExpectedNumInstanceSampler(
                num_instances=1,
                min_past=self.context_length,
                min_future=self.prediction_length,
            ),
            "test": TestSplitSampler(),
        }[mode]
        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.context_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.OBSERVED_VALUES,
            ],
        )

    def _create_training_samples(self, training_data) -> np.ndarray:
        """Draw ``num_train_samples`` (context, future) training pairs.

        NOTE(review): actually returns a ``(train_X, train_y)`` tuple of
        arrays; the ``np.ndarray`` annotation above understates that.
        Raises an assertion error if any window has missing values.
        """
        transformation = self._create_instance_splitter(
            "training"
        ) + SelectFields(TRAINING_INPUT_NAMES)
        # Batches of 100 windows until the sample budget is reached.
        num_batches_per_epoch = math.ceil(self.num_train_samples / 100)
        data_loader = TrainDataLoader(
            training_data,
            batch_size=100,
            stack_fn=batchify,
            transform=transformation,
            num_batches_per_epoch=num_batches_per_epoch,
        )
        train_X, train_y = [], []
        for batch in data_loader:
            train_X.append(batch["past_target"])
            train_y.append(batch["future_target"])
            assert np.all(batch["past_observed_values"] == 1.0) and np.all(
                batch["future_observed_values"] == 1.0
            ), "Missing values not supported!"
        train_X = np.concatenate(train_X, 0)
        train_y = np.concatenate(train_y, 0)
        train_X = train_X[: self.num_train_samples]
        train_y = train_y[: self.num_train_samples]
        assert len(train_X) == self.num_train_samples
        return train_X, train_y

    def create_predictor(self, transformation, model):
        """Wrap a fitted ``LinearModel`` in a ``LinearPredictor``."""
        prediction_splitter = self._create_instance_splitter("test")
        return LinearPredictor(
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=model,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            input_transform=transformation + prediction_splitter,
        )

    def train(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        cache_data: bool = False,
    ) -> Predictor:
        """Fit the sklearn linear model on scaled windows; return a predictor.

        ``validation_data`` is accepted for interface compatibility but
        unused.
        """
        transformation = self.create_transformation()
        transformed_data = transformation.apply(training_data, is_train=True)
        if cache_data:
            transformed_data = Cached(transformed_data)
        train_X, train_y = self._create_training_samples(transformed_data)
        # Targets are scaled with the context-derived scale.
        scaled_train_X, scale = self.scaler(train_X, np.ones_like(train_X))
        scaled_train_y = train_y / scale
        if self.model == "linear":
            SKLearnLinear = LinearRegression
        elif self.model == "ridge":
            SKLearnLinear = Ridge
        regressor = SKLearnLinear().fit(scaled_train_X, scaled_train_y)
        model = LinearModel(regressor.coef_, regressor.intercept_, self.scaler)
        return self.create_predictor(
            transformation=transformation, model=model
        )
================================================
FILE: src/uncond_ts_diff/model/linear/_scaler.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import numpy as np
class MeanScaler:
    """NumPy port of the torch ``MeanScaler``: scales each item by its mean
    absolute (observed) value, falling back to a batch-level default scale
    for items that are all-zero or entirely unobserved.
    """

    def __init__(
        self,
        axis: int,
        keepdims: bool = False,
        default_scale: Optional[float] = None,
        minimum_scale: float = 1e-10,
    ):
        super().__init__()
        self.axis = axis
        self.keepdims = keepdims
        self.minimum_scale = minimum_scale
        # None maps to 0.0, which signals "derive the default from the batch".
        self.default_scale = default_scale or 0.0

    def __call__(
        self, data: np.ndarray, weights: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(data / scale, scale)``; scale keeps or drops the scaled
        axis according to ``keepdims``."""
        # Per-item observed count and weighted absolute sum along the axis.
        observed_count = weights.sum(axis=self.axis)
        abs_sum = (np.abs(data) * weights).sum(axis=self.axis)

        # Fallback scale: either the fixed default, or the batch-wide mean
        # absolute value (guarding against an empty batch).
        if self.default_scale != 0.0:
            fallback = self.default_scale
        else:
            batch_observed = observed_count.sum(axis=0)
            safe_batch_count = np.maximum(
                batch_observed, np.ones_like(batch_observed)
            )
            fallback = abs_sum.sum(axis=0) / safe_batch_count

        # Per-item mean absolute value, guarding against zero counts.
        safe_count = np.maximum(observed_count, np.ones_like(observed_count))
        per_item_scale = abs_sum / safe_count

        # Use the fallback where nothing was observed or everything is zero,
        # and clamp from below by the minimum scale.
        chosen = np.where(
            abs_sum > np.zeros_like(abs_sum),
            per_item_scale,
            fallback * np.ones_like(observed_count),
        )
        scale = np.expand_dims(
            np.maximum(self.minimum_scale, chosen), axis=self.axis
        )
        scaled = data / scale
        if self.keepdims:
            return scaled, scale
        return scaled, scale.squeeze(axis=self.axis)
class NOPScaler:
    """No-op numpy scaler: leaves the data untouched and returns a unit scale."""

    def __init__(self, axis: int, keepdims: bool = False):
        super().__init__()
        self.axis = axis
        self.keepdims = keepdims

    def __call__(
        self, data: np.ndarray, weights: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        # Mean of an all-ones array along `axis` is just ones of the
        # reduced shape (keeping dims if requested).
        ones = np.ones_like(data)
        scale = ones.mean(axis=self.axis, keepdims=self.keepdims)
        return data, scale
================================================
FILE: src/uncond_ts_diff/predictor.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Iterator, Optional
from gluonts.dataset import Dataset
from gluonts.dataset.loader import InferenceDataLoader
from gluonts.model import Forecast
from gluonts.torch.batchify import batchify
from gluonts.torch.model.predictor import PyTorchPredictor
class PyTorchPredictorWGrads(PyTorchPredictor):
    """Predictor variant whose ``predict`` does not disable gradients.

    NOTE(review): presumably the base ``PyTorchPredictor.predict`` wraps
    forecasting in ``torch.no_grad()``; this override omits that guard so
    gradient-based samplers (guidance/refinement) can backpropagate through
    the network at inference time — confirm against the installed GluonTS
    version.
    """

    def predict(
        self, dataset: Dataset, num_samples: Optional[int] = None
    ) -> Iterator[Forecast]:
        """Yield forecasts for ``dataset``, with gradients enabled.

        Parameters
        ----------
        dataset
            Dataset to forecast.
        num_samples
            Number of sample paths per forecast (forwarded to the
            forecast generator).
        """
        # Apply the input transform and batch entries for inference.
        inference_data_loader = InferenceDataLoader(
            dataset,
            transform=self.input_transform,
            batch_size=self.batch_size,
            stack_fn=lambda data: batchify(data, self.device),
        )
        # Eval mode only — gradients intentionally remain enabled.
        self.prediction_net.eval()
        yield from self.forecast_generator(
            inference_data_loader=inference_data_loader,
            prediction_net=self.prediction_net,
            input_names=self.input_names,
            output_transform=self.output_transform,
            num_samples=num_samples,
        )
================================================
FILE: src/uncond_ts_diff/sampler/__init__.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .observation_guidance import DDIMGuidance, DDPMGuidance
from .refiner import MostLikelyRefiner, MCMCRefiner
__all__ = [
"DDIMGuidance",
"DDPMGuidance",
"MostLikelyRefiner",
"MCMCRefiner",
]
================================================
FILE: src/uncond_ts_diff/sampler/_base.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Tuple
from functools import partial
import numpy as np
import torch
def grad_fn(fn, x):
    """Gradient of ``fn`` at ``x`` (marks ``x`` as requiring grad in place)."""
    x.requires_grad_(True)
    (gradient,) = torch.autograd.grad(fn(x), x)
    return gradient
@torch.no_grad()
def langevin_dynamics(
    z0: torch.Tensor,
    energy_func: Callable = None,
    score_func: Callable = None,
    step_size: float = 0.1,
    noise_scale: float = 0.1,
    n_steps: int = 1,
):
    """Overdamped Langevin dynamics.

    Exactly one of ``energy_func`` or ``score_func`` should be supplied:
    the drift is the negative energy gradient or the score, respectively.

    Parameters
    ----------
    z0
        Initial guess.
    energy_func, optional
        Energy function, by default None.
    score_func, optional
        Score function, by default None.
    step_size, optional
        Step size, by default 0.1.
    noise_scale, optional
        Scale for the Brownian noise term, by default 0.1.
    n_steps, optional
        Number of Langevin steps, by default 1.

    Returns
    -------
    Updated point.
    """
    assert energy_func is not None or score_func is not None
    # Diffusion coefficient sqrt(2 * step_size) for the Brownian term.
    brownian_coeff = torch.sqrt(2 * torch.tensor(step_size))
    z = z0
    for _ in range(n_steps):
        if energy_func is None:
            drift = score_func(z)
        else:
            # Drift = -dE/dz; grad must be re-enabled locally since the
            # whole function runs under no_grad.
            with torch.enable_grad():
                z.requires_grad_(True)
                drift = -torch.autograd.grad(energy_func(z), z)[0]
        z = (
            z.detach()
            + step_size * drift
            + brownian_coeff * noise_scale * torch.randn_like(z)
        )
    return z
@torch.enable_grad()
def leapfrog(
    xt: torch.Tensor,
    pt: torch.Tensor,
    dynamics_p: Callable[[torch.Tensor], torch.Tensor],
    mass: float,
    h: float,
    n_steps: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Leapfrog integrator for Hamiltonian dynamics.

    Parameters
    ----------
    xt
        Position.
    pt
        Momentum.
    dynamics_p
        Dynamics function for the momentum (force term).
    mass
        Mass of the particle.
    h
        Step size.
    n_steps
        Number of leapfrog integration steps.

    Returns
    -------
    Updated (position, momentum) pair.
    """
    position, momentum = xt, pt
    for _ in range(n_steps):
        # Half-kick, full drift, half-kick.
        momentum = momentum - (h / 2) * dynamics_p(position)
        position = position + h * momentum / mass
        momentum = momentum - (h / 2) * dynamics_p(position)
        position, momentum = position.detach(), momentum.detach()
    return position, momentum
@torch.no_grad()
def hmc(
    x0: torch.Tensor,
    energy_func: Callable[[torch.Tensor], torch.Tensor],
    step_size: float,
    mass: float,
    n_leapfrog_steps: int = 10,
    n_steps: int = 100,
) -> torch.Tensor:
    """Hamiltonian Monte Carlo (without Metropolis correction).

    Cleanup: removed the unused ``potential_energy_func`` alias and the
    unused shape unpacking, which also needlessly restricted the input to
    rank-3 tensors.

    Parameters
    ----------
    x0
        Initial guess, e.g. of shape [B, T, C].
    energy_func
        Energy function E: tensor -> scalar.
    step_size
        Step size.
    mass
        Mass of particle.
    n_leapfrog_steps, optional
        Number of leapfrog integration steps, by default 10.
    n_steps, optional
        Number of HMC steps, by default 100.

    Returns
    -------
    Updated tensor with the same shape as ``x0``.
    """
    drift_func = partial(grad_fn, energy_func)
    xt = x0
    for _ in range(n_steps):
        # Resample momentum from N(0, mass * I) every step.
        pt = np.sqrt(mass) * torch.randn_like(xt)
        xt_prop, pt_prop = leapfrog(
            xt, pt, drift_func, mass, step_size, n_leapfrog_steps
        )
        # NOTE(review): the proposal is always accepted — there is no
        # Metropolis-Hastings correction, so this is approximate HMC.
        xt = xt_prop
    return xt
def linear_midpoint_em_step(
    zt: torch.Tensor, coeff: float, h: float, sigma: float
):
    """Semi-implicit (midpoint) Euler-Maruyama step for a linear drift."""
    noise = torch.randn_like(zt)
    # Explicit half-step drift plus diffusion ...
    numerator = zt - h * coeff * zt / 2 + np.sqrt(h) * sigma * noise
    # ... followed by the implicit half-step correction.
    return (numerator / (1 + h * coeff / 2)).detach()
@torch.no_grad()
def udld(
    x0: torch.Tensor,
    potential_energy_func: Callable[[torch.Tensor], torch.Tensor],
    step_size: float,
    friction: float,
    mass: float,
    n_leapfrog_steps: int = 1,
    n_steps: int = 100,
) -> torch.Tensor:
    """Underdamped Langevin dynamics.

    Alternates a stochastic (Ornstein-Uhlenbeck-style) half-step on the
    momentum with a deterministic leapfrog update of position and momentum.

    Parameters
    ----------
    x0
        Initial guess of shape [B, T, C]
    potential_energy_func
        Energy function E: [B, T, C] -> []
    step_size
        Step size
    friction
        Friction coefficient
    mass
        Mass of the particle
    n_leapfrog_steps, optional
        Number of leapfrog integration steps, by default 1
    n_steps, optional
        Number of UDLD steps, by default 100

    Returns
    -------
    Updated tensor of shape [B, T, C].
    """
    # Unpacking enforces a rank-3 input; the values themselves are unused.
    batch_size, length, ch = x0.shape
    xt = x0
    drift_func = partial(grad_fn, potential_energy_func)
    # Initial momentum drawn from N(0, mass * I).
    pt = np.sqrt(mass) * torch.randn_like(xt)
    coeff = friction / mass
    sigma = np.sqrt(2 * friction)
    for _ in range(n_steps):
        # Stochastic half-step on the momentum (friction + noise).
        pt = linear_midpoint_em_step(pt, coeff, step_size / 2, sigma)
        # Deterministic Hamiltonian update.
        xt_prop, pt_prop = leapfrog(
            xt, pt, drift_func, mass, step_size, n_leapfrog_steps
        )
        xt, pt = xt_prop, pt_prop
        # Second stochastic half-step to complete the splitting scheme.
        pt = linear_midpoint_em_step(pt, coeff, step_size / 2, sigma)
    return xt
================================================
FILE: src/uncond_ts_diff/sampler/observation_guidance.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import torch
import torch.nn.functional as F
from gluonts.torch.util import lagged_sequence_values
from uncond_ts_diff.predictor import PyTorchPredictorWGrads
from uncond_ts_diff.utils import extract
from uncond_ts_diff.model import TSDiff
# Batch keys the predictor extracts and forwards (in this order) as
# positional inputs to `Guidance.forward`.
PREDICTION_INPUT_NAMES = [
    "past_target",
    "past_observed_values",
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "future_time_feat",
    "stats",
]
class Guidance(torch.nn.Module):
    """Base class for observation (self-)guidance sampling.

    Steers the reverse diffusion of an unconditional `TSDiff` model with
    the gradient of an observation loss (MSE or quantile/pinball) so that
    samples agree with the observed context. Subclasses implement
    `scale_func` (guidance step-size schedule) and `guide` (the actual
    guided reverse process).
    """

    # "none" = fully observed context; "RM" = random missing;
    # "BM-B"/"BM-E" = missing block at beginning/end of context.
    _missing_scenarios = ["none", "RM", "BM-B", "BM-E"]

    def __init__(
        self,
        model: TSDiff,
        prediction_length: int,
        scale: float = 1.0,
        num_samples: int = 1,
        guidance: str = "quantile",
        missing_scenario: str = "none",
        missing_values: int = 0,
    ):
        """
        Parameters
        ----------
        model
            Trained unconditional diffusion model.
        prediction_length
            Number of future time steps to forecast.
        scale
            Base magnitude of the guidance term.
        num_samples
            Number of forecast sample paths per series.
        guidance
            Observation loss: "MSE" or "quantile".
        missing_scenario
            One of `_missing_scenarios`.
        missing_values
            Number of missing values (scenario-dependent).
        """
        super().__init__()
        assert missing_scenario in self._missing_scenarios
        self.model = model
        self.prediction_length = prediction_length
        self.scale = scale
        self.num_samples = num_samples
        self.guidance = guidance
        self.missing_scenario = missing_scenario
        self.missing_values = missing_values

    def quantile_loss(self, y_prediction, y_target):
        """Pinball loss where each of the `num_samples` replicas of a
        batch element is assigned its own quantile level."""
        assert y_target.shape == y_prediction.shape
        device = y_prediction.device
        batch_size_x_num_samples, length, ch = y_target.shape
        batch_size = batch_size_x_num_samples // self.num_samples
        # num_samples uniformly distributed quantiles between 0 and 1
        # repeat for each element in the batch
        q = (torch.arange(self.num_samples).repeat(batch_size) + 1).to(
            device
        ) / (self.num_samples + 1)
        # (batch_size x num_samples,)
        q = q[:, None, None]  # (batch_size x num_samples, 1, 1)
        e = y_target - y_prediction
        loss = torch.max(q * e, (q - 1) * e)
        return loss

    def energy_func(self, y, t, observation, observation_mask, features):
        """Observation loss between the fast-denoised `y` and the
        observation, summed over observed positions (mask == 1)."""
        if self.guidance == "MSE":
            return F.mse_loss(
                self.model.fast_denoise(y, t, features),
                observation,
                reduction="none",
            )[observation_mask == 1].sum()
        elif self.guidance == "quantile":
            return self.quantile_loss(
                self.model.fast_denoise(y, t, features),
                observation,
            )[observation_mask == 1].sum()
        else:
            raise ValueError(f"Unknown guidance {self.guidance}!")

    def score_func(self, y, t, observation, observation_mask, features):
        """Negative gradient of the observation energy w.r.t. `y`."""
        with torch.enable_grad():
            y.requires_grad_(True)
            Ey = self.energy_func(
                y, t, observation, observation_mask, features
            )
            return -torch.autograd.grad(Ey, y)[0]

    def scale_func(self, y, t, base_scale):
        # Step-size schedule of the guidance term at diffusion step t.
        raise NotImplementedError("Must be implemented by a subclass!")

    def guide(self, observation, observation_mask, features, scale):
        # Run the guided reverse process and return the sampled sequence.
        raise NotImplementedError("Must be implemented by a subclass!")

    def forward(
        self,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        feat_static_cat: torch.Tensor = None,
        feat_static_real: torch.Tensor = None,
        past_time_feat: torch.Tensor = None,
        future_time_feat: torch.Tensor = None,
        stats: torch.Tensor = None,
    ):
        """Generate guided forecast samples for a batch of series."""
        device = next(self.model.parameters()).device
        # Placeholder zeros for the (unknown) forecast horizon.
        future_target = torch.zeros(
            past_target.shape[0], self.prediction_length, device=device
        )
        data = dict(
            feat_static_cat=feat_static_cat.to(device)
            if feat_static_cat is not None
            else None,
            feat_static_real=feat_static_real.to(device)
            if feat_static_real is not None
            else None,
            past_time_feat=past_time_feat.to(device)
            if past_time_feat is not None
            else None,
            past_target=past_target.to(device),
            future_target=future_target,
            past_observed_values=past_observed_values.to(device)
            if past_observed_values is not None
            else None,
            future_time_feat=future_time_feat.to(device)
            if future_time_feat is not None
            else None,
            stats=stats.to(device) if stats is not None else None,
        )
        observation, scale_params, features = self.model._extract_features(
            data
        )
        observation = observation.to(device)
        batch_size, length, ch = observation.shape
        # Split observed-values mask into the part before the model's
        # context window (used only for lag features) and the context.
        prior_mask = past_observed_values[:, : -self.model.context_length]
        context_mask = past_observed_values[:, -self.model.context_length :]
        # Forecast horizon counts as unobserved: no guidance loss there.
        future_mask = torch.zeros_like(future_target)
        observation_mask = torch.cat([context_mask, future_mask], dim=1)
        if self.model.use_lags:
            # Build the matching mask for the lagged feature channels.
            lagged_mask = lagged_sequence_values(
                self.model.lags_seq,
                prior_mask,
                observation_mask,
                dim=1,
            )
            observation_mask = torch.cat(
                [observation_mask[:, :, None], lagged_mask], dim=-1
            )
        else:
            observation_mask = observation_mask[:, :, None]
        # Replicate each series once per requested sample path.
        observation = observation.repeat_interleave(self.num_samples, dim=0)
        observation_mask = observation_mask.repeat_interleave(
            self.num_samples, dim=0
        )
        if features is not None:
            features = features.repeat_interleave(self.num_samples, dim=0)
        # base_scale = self.scale / (
        #     context_mask.sum() / torch.ones_like(context_mask).sum()
        # )
        base_scale = self.scale
        pred = self.guide(observation, observation_mask, features, base_scale)
        # Keep only the target channel, regroup sample paths, and undo
        # the feature scaling.
        pred = pred[:, :, 0].reshape(batch_size, self.num_samples, -1)
        pred = pred * scale_params
        # Return only the forecast-horizon part.
        return pred[..., length - self.prediction_length :]

    def get_predictor(self, input_transform, batch_size=40, device=None):
        """Wrap this module into a gradient-enabled GluonTS predictor."""
        return PyTorchPredictorWGrads(
            prediction_length=self.prediction_length,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=self,
            batch_size=batch_size,
            input_transform=input_transform,
            device=device,
        )
class DDPMGuidance(Guidance):
    """Observation guidance on top of ancestral DDPM sampling."""

    def __init__(
        self,
        model: TSDiff,
        prediction_length: int,
        scale: float = 1,
        num_samples: int = 1,
        guidance: str = "quantile",
        missing_scenario: str = "none",
        missing_values: int = 0,
    ):
        super().__init__(
            model,
            prediction_length,
            scale,
            num_samples,
            guidance,
            missing_scenario,
            missing_values,
        )

    def scale_func(self, y, t, base_scale):
        """Guidance step size: posterior variance at t times base_scale."""
        posterior_var = extract(self.model.posterior_variance, t, y.shape)
        return posterior_var * base_scale

    @torch.no_grad()
    def _reverse_diffusion(
        self, observation, observation_mask, features, base_scale
    ):
        """Ancestral sampling with a guidance correction after each step."""
        device = observation.device
        batch_size = observation.shape[0]
        seq = torch.randn_like(observation)
        # Walk the diffusion steps from T-1 down to 0.
        for step in range(self.model.timesteps - 1, -1, -1):
            t = torch.full(
                (batch_size,), step, device=device, dtype=torch.long
            )
            seq = self.model.p_sample(seq, t, step, features)
            step_scale = self.scale_func(seq, t, base_scale=base_scale)
            score = self.score_func(
                seq,
                t,
                observation=observation,
                observation_mask=observation_mask,
                features=features,
            )
            seq = seq + step_scale * score
        return seq

    def guide(self, observation, observation_mask, features, base_scale):
        return self._reverse_diffusion(
            observation, observation_mask, features, base_scale
        )
class DDIMGuidance(Guidance):
    """Observation guidance on top of generalized DDIM sampling.

    Supports step skipping ("uniform" or "quadratic" spacing) and the
    DDIM interpolation parameter ``eta`` (0 = deterministic DDIM).
    """

    _skip_types = ["uniform", "quadratic"]

    def __init__(
        self,
        model: TSDiff,
        prediction_length: int,
        eta: float = 0.0,
        skip_factor: int = 1,
        skip_type: str = "uniform",
        scale: float = 1,
        num_samples: int = 1,
        guidance: str = "quantile",
        missing_scenario: str = "none",
        missing_values: int = 0,
    ):
        super().__init__(
            model,
            prediction_length,
            scale,
            num_samples,
            guidance,
            missing_scenario,
            missing_values,
        )
        assert skip_type in self._skip_types
        self.eta = eta
        self.skip_factor = skip_factor
        self.skip_type = skip_type

    def scale_func(self, y, t, base_scale):
        """Guidance step size: sqrt(1 - alpha_bar_t) * base_scale."""
        return (
            extract(self.model.sqrt_one_minus_alphas_cumprod, t, y.shape)
            * base_scale
        )

    def _get_timesteps(self):
        """Return the increasing list of diffusion steps to visit.

        FIX: the "uniform" branch previously returned a `range` object,
        so `[-1] + timesteps[:-1]` in `_reverse_ddim` raised
        `TypeError: can only concatenate list (not "range") to list`.
        It is now materialized as a list. An explicit guard replaces the
        silent fall-through for unknown skip types.
        """
        if self.skip_type == "uniform":
            timesteps = list(
                range(0, self.model.timesteps, self.skip_factor)
            )
        elif self.skip_type == "quadratic":
            n_test_timesteps = int(self.model.timesteps / self.skip_factor)
            c = 1 - self.skip_factor / self.model.timesteps
            timesteps = np.square(
                np.linspace(
                    0, np.sqrt(self.model.timesteps * c), n_test_timesteps
                )
            )
            # De-duplicate after integer truncation.
            timesteps = timesteps.astype(np.int64).tolist()
            timesteps = sorted(set(timesteps))
        else:
            raise ValueError(f"Unknown skip type {self.skip_type}!")
        return timesteps

    @torch.no_grad()
    def _reverse_ddim(
        self, observation, observation_mask, features, base_scale
    ):
        """DDIM sampling with the guidance correction folded into the
        predicted noise at every visited step."""
        device = observation.device
        batch_size = observation.shape[0]
        timesteps = self._get_timesteps()
        # Pair each visited step i with its predecessor j (-1 at the end).
        timesteps_prev = [-1] + timesteps[:-1]
        seq = torch.randn_like(observation)
        for i, j in zip(reversed(timesteps), reversed(timesteps_prev)):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            t_prev = torch.full(
                (batch_size,), j, device=device, dtype=torch.long
            )
            noise = self.model.backbone(seq, t, features)
            scale = self.scale_func(seq, t, base_scale=base_scale)
            # Shift the predicted noise by the (scaled) observation score.
            noise = noise - scale * self.score_func(
                seq,
                t,
                observation=observation,
                observation_mask=observation_mask,
                features=features,
            )
            seq = self.model.p_sample_genddim(
                seq,
                t,
                t_index=i,
                t_prev=t_prev,
                eta=self.eta,
                features=features,
                noise=noise,
            )
        return seq

    def guide(self, observation, observation_mask, features, base_scale):
        return self._reverse_ddim(
            observation, observation_mask, features, base_scale
        )
================================================
FILE: src/uncond_ts_diff/sampler/refiner.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gluonts.time_feature import get_seasonality
from uncond_ts_diff.predictor import PyTorchPredictorWGrads
from uncond_ts_diff.sampler._base import (
langevin_dynamics,
hmc,
udld,
)
# Batch keys the predictor extracts and forwards (in this order) as
# positional inputs to `Refiner.forward`.
PREDICTION_INPUT_NAMES = [
    "past_target",
    "past_observed_values",
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "future_time_feat",
    "stats",
]
class Refiner(torch.nn.Module):
    """Base class for refining an initial forecast with a diffusion model.

    Builds a full (context + horizon) sequence whose horizon part is
    initialized either from supplied base forecasts (`init`) or a
    seasonal-naive / context-mean fallback, then hands it to `refine`
    together with a mask marking the fixed context. Subclasses implement
    `refine` (optimization or MCMC).
    """

    def __init__(
        self,
        model,
        prediction_length,
        fixed_t=20,
        iterations=1,
        init=None,
        num_samples=1,
        guidance="quantile",
        scale=1,
    ):
        """
        Parameters
        ----------
        model
            Trained diffusion model (must expose `backbone`, `p_losses`,
            `_extract_features`, `timesteps`, `freq`).
        prediction_length
            Number of future steps to forecast.
        fixed_t
            Diffusion step used in the model loss; -1 samples t randomly.
        iterations
            Number of refinement iterations.
        init
            Optional iterator of base forecasts, one per series; None
            enables the seasonal-naive/mean fallback.
        num_samples
            Number of forecast sample paths per series.
        guidance
            Observation prior: "MSE" or "quantile".
        scale
            Weight of the observation prior term.
        """
        super().__init__()
        self.model = model
        self.prediction_length = prediction_length
        self.fixed_t = fixed_t
        self.iterations = iterations
        self.init = init
        self.num_samples = num_samples
        self.guidance = guidance
        self.scale = scale

    def quantile_loss(self, y_prediction, y_target):
        """Pinball loss where each of the `num_samples` replicas of a
        batch element is assigned its own quantile level."""
        assert y_target.shape == y_prediction.shape
        device = y_prediction.device
        batch_size_x_num_samples, length, ch = y_target.shape
        batch_size = batch_size_x_num_samples // self.num_samples
        # num_samples uniformly distributed quantiles between 0 and 1
        # repeat for each element in the batch
        q = (torch.arange(self.num_samples).repeat(batch_size) + 1).to(
            device
        ) / (self.num_samples + 1)
        # (batch_size x num_samples,)
        q = q[:, None, None]
        # (batch_size x num_samples, 1, 1)
        e = y_target - y_prediction
        loss = torch.max(q * e, (q - 1) * e)
        return loss

    def prior(self, y_prediction, obs, obs_mask):
        """Weighted observation prior tying the prediction to the
        observed context (MSE restricted to the mask, or pinball)."""
        if self.guidance == "MSE":
            return (
                self.scale
                * F.mse_loss(y_prediction, obs, reduction="none")[
                    obs_mask == 1
                ].sum()
            )
        elif self.guidance == "quantile":
            return self.scale * self.quantile_loss(y_prediction, obs).sum()
        else:
            raise ValueError(f"Unknown guidance {self.guidance}!")

    def refine(self, observation, observation_mask):
        # Refine the initialized sequence; implemented by subclasses.
        raise NotImplementedError("Must be implemented by a subclass!")

    def forward(
        self,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        feat_static_cat: torch.Tensor = None,
        feat_static_real: torch.Tensor = None,
        past_time_feat: torch.Tensor = None,
        future_time_feat: torch.Tensor = None,
        stats: torch.Tensor = None,
    ):
        """Initialize the forecast horizon and return refined samples."""
        # NOTE(review): `stats` is accepted (and listed in
        # PREDICTION_INPUT_NAMES) but, unlike Guidance.forward, is not
        # forwarded into `data` below — confirm `_extract_features`
        # tolerates the missing key here.
        device = next(self.model.backbone.parameters()).device
        data = dict(
            feat_static_cat=feat_static_cat.to(device)
            if feat_static_cat is not None
            else None,
            feat_static_real=feat_static_real.to(device)
            if feat_static_real is not None
            else None,
            past_time_feat=past_time_feat.to(device)
            if past_time_feat is not None
            else None,
            past_target=past_target.to(device),
            future_target=torch.zeros(
                past_target.shape[0], self.prediction_length, device=device
            ),
            past_observed_values=past_observed_values.to(device)
            if past_observed_values is not None
            else None,
            future_time_feat=future_time_feat.to(device)
            if future_time_feat is not None
            else None,
        )
        observation, scale, features = self.model._extract_features(data)
        observation = observation.to(device)
        batch_size, length, ch = observation.shape
        # Mask: 1 = fixed context, 0 = horizon that is free to refine.
        observation_mask = torch.ones_like(observation, device=device)
        observation_mask[:, length - self.prediction_length :, 0] = 0
        # Replicate each series once per requested sample path.
        observation = observation.repeat_interleave(self.num_samples, dim=0)
        observation_mask = observation_mask.repeat_interleave(
            self.num_samples, dim=0
        )
        if features is not None:
            features = features.repeat_interleave(self.num_samples, dim=0)
        if self.init is not None:
            # Initialize the horizon from the supplied base forecaster.
            init_forecasts = np.stack(
                [next(self.init).samples for _ in range(batch_size)]
            )
            if init_forecasts.shape[1] == 1:
                # Single sample, e.g., for SeasonalNaive
                init_forecasts = np.tile(
                    init_forecasts, (1, self.num_samples, 1)
                )
            # create numpy array out of list and sort them to
            # match to their corresponding quantile
            init = np.sort(init_forecasts, axis=1)
            init = torch.from_numpy(init).to(device)
            # scale input
            init = init / scale
            # reshape from B x num_samples x prediction_length to
            # B * self.num_samples x prediction_length
            init = init.reshape(
                batch_size * self.num_samples, self.prediction_length
            )
            # use it as initial guess
            observation[:, length - self.prediction_length :, 0] = init
        else:
            season_length = get_seasonality(self.model.freq)
            # Initialize using Seasonal Naive predictions
            if (length - self.prediction_length) >= season_length:
                indices = [
                    length
                    - self.prediction_length
                    - season_length
                    + k % season_length
                    for k in range(self.prediction_length)
                ]
                observation[
                    :, length - self.prediction_length :, 0
                ] = observation[:, indices, 0]
            # Initialize using the mean of the context window
            else:
                observation[
                    :, length - self.prediction_length :, 0
                ] = torch.mean(
                    observation[:, : length - self.prediction_length, 0],
                    dim=1,
                    keepdim=True,
                )
        pred = self.refine(observation, observation_mask)
        # Keep the target channel, regroup sample paths, undo scaling,
        # and return only the forecast-horizon part.
        pred = pred[:, :, 0].reshape(batch_size, self.num_samples, -1)
        pred = pred * scale
        return pred[:, :, length - self.prediction_length :]

    def get_predictor(self, input_transform, batch_size=40, device=None):
        """Wrap this module into a gradient-enabled GluonTS predictor."""
        return PyTorchPredictorWGrads(
            prediction_length=self.prediction_length,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=self,
            batch_size=batch_size,
            input_transform=input_transform,
            device=device,
        )
class MostLikelyRefiner(Refiner):
    """Refines the initial guess by SGD on the model loss plus the
    observation prior, with an LR scheduler that halves the step on
    plateaus."""

    def __init__(
        self,
        model,
        prediction_length,
        lr=1e-1,
        patience=100,
        fixed_t=20,
        iterations=1,
        init=None,
        num_samples=1,
        guidance="quantile",
        scale=1,
    ):
        super().__init__(
            model,
            prediction_length,
            fixed_t,
            iterations,
            init,
            num_samples,
            guidance,
            scale,
        )
        # SGD learning rate and ReduceLROnPlateau patience.
        self.lr = lr
        self.patience = patience

    def _most_likely(self, observation, observation_mask):
        """Gradient-descent refinement of the full sequence.

        The entire sequence is optimized; the observation prior (via
        `observation_mask`) keeps the context part close to the data.
        """
        device = next(self.model.backbone.parameters()).device
        observation = observation.to(device)
        seq = nn.Parameter(torch.clone(observation), requires_grad=True)
        optim = torch.optim.SGD([seq], lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, "min", patience=self.patience, factor=0.5
        )
        with torch.enable_grad():
            for i in range(self.iterations):
                optim.zero_grad()
                # Random diffusion step unless fixed_t pins it (-1 = random).
                t = torch.randint(
                    0, self.model.timesteps, (seq.shape[0],), device=device
                ).long()
                if self.fixed_t != -1:
                    t = t * 0 + self.fixed_t
                # Objective: diffusion training loss + observation prior.
                loss = self.model.p_losses(
                    seq, t, loss_type="l2", reduction="sum"
                )[0] + self.prior(seq, observation, observation_mask)
                loss.backward()
                optim.step()
                scheduler.step(loss.item())
        return seq.detach()

    def refine(self, observation, observation_mask):
        return self._most_likely(observation, observation_mask)
class MCMCRefiner(Refiner):
    """Refines the initial guess by MCMC under the energy given by the
    model loss plus the observation prior.

    Methods: "lmc" (overdamped Langevin), "hmc" (Hamiltonian Monte
    Carlo), "udld" (underdamped Langevin), "cdld" (critically damped
    Langevin, i.e. udld with friction^2 = 4 * mass).
    """

    _available_methods = {"lmc", "hmc", "udld", "cdld"}

    def __init__(
        self,
        model,
        prediction_length,
        step_size=1e-1,
        method="lmc",
        method_kwargs={},
        fixed_t=20,
        iterations=1,
        init=None,
        num_samples=1,
        guidance="quantile",
        scale=1,
    ):
        super().__init__(
            model,
            prediction_length,
            fixed_t,
            iterations,
            init,
            num_samples,
            guidance,
            scale,
        )
        assert method in self._available_methods
        self.step_size: float = step_size
        self.method: str = method
        # Extra keyword arguments merged over the per-method defaults.
        self.method_kwargs: dict = method_kwargs

    def _mcmc(self, observation, observation_mask):
        """Run `iterations` MCMC updates of the full sequence."""
        device = next(self.model.backbone.parameters()).device
        observation = observation.to(device)
        seq = torch.clone(observation)
        for i in range(self.iterations):
            # Random diffusion step unless fixed_t pins it (-1 = random).
            t = torch.randint(
                0, self.model.timesteps, (seq.shape[0],), device=device
            ).long()
            if self.fixed_t != -1:
                t = t * 0 + self.fixed_t
            energy_func = lambda x: self.model.p_losses(  # noqa: E731
                x, t, loss_type="l2", reduction="sum"
            )[0] + self.prior(x, observation, observation_mask)
            if self.method == "lmc":
                method_kwargs = {
                    "noise_scale": 0.1,
                    "n_steps": 1,
                }
                method_kwargs.update(self.method_kwargs)
                # FIX: forward the merged `method_kwargs` (defaults +
                # user overrides) like the other branches. Previously
                # `**self.method_kwargs` was passed, silently discarding
                # the defaults dict built above (harmless only because
                # they happened to match langevin_dynamics' defaults).
                seq = langevin_dynamics(
                    seq,
                    energy_func,
                    score_func=None,
                    step_size=self.step_size,
                    **method_kwargs,
                )
            elif self.method == "hmc":
                method_kwargs = {
                    "mass": 1.0,
                    "n_steps": 1,
                    "n_leapfrog_steps": 5,
                }
                method_kwargs.update(self.method_kwargs)
                seq = hmc(
                    seq, energy_func, step_size=self.step_size, **method_kwargs
                )
            elif self.method == "udld":
                method_kwargs = {
                    "mass": 1.0,
                    "friction": 1.0,
                    "n_steps": 1,
                    "n_leapfrog_steps": 5,
                }
                method_kwargs.update(self.method_kwargs)
                seq = udld(
                    seq, energy_func, step_size=self.step_size, **method_kwargs
                )
            elif self.method == "cdld":
                method_kwargs = {
                    "mass": 1.0,
                    "n_steps": 1,
                    "n_leapfrog_steps": 5,
                }
                method_kwargs.update(self.method_kwargs)
                # Critical damping: friction^2 = 4 x mass
                method_kwargs["friction"] = np.sqrt(4 * method_kwargs["mass"])
                seq = udld(
                    seq, energy_func, step_size=self.step_size, **method_kwargs
                )
        return seq.detach()

    def refine(self, observation, observation_mask):
        return self._mcmc(observation, observation_mask)
================================================
FILE: src/uncond_ts_diff/utils.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
from typing import Type, Dict
from pathlib import Path
from argparse import ArgumentParser, ArgumentTypeError
from functools import partial
import re
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import torch
from torch.utils.data import Dataset
from pandas.tseries.frequencies import to_offset
from gluonts.core.component import validated
from gluonts.dataset import DataEntry
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.split import split
from gluonts.dataset.util import period_index
from gluonts.transform import (
Chain,
RemoveFields,
SetField,
AsNumpyArray,
AddObservedValuesIndicator,
AddTimeFeatures,
AddAgeFeature,
VstackFeatures,
MapTransformation,
ExpectedNumInstanceSampler,
InstanceSplitter,
TestSplitSampler,
ValidationSplitSampler,
)
from gluonts.model.forecast import SampleForecast
# Global seaborn/matplotlib style applied on import; used by the
# plotting helpers below (e.g. plot_train_stats).
sns.set(
    style="white",
    font_scale=1.1,
    rc={"figure.dpi": 125, "lines.linewidth": 2.5, "axes.linewidth": 1.5},
)
def filter_metrics(
    metrics, select=frozenset({"ND", "NRMSE", "mean_wQuantileLoss"})
):
    """Return ``{name: float}`` for the selected metric names.

    Each selected entry in ``metrics`` must expose ``.item()`` (e.g. a
    numpy/torch scalar). The default is a ``frozenset`` rather than a
    mutable ``set`` so the shared default argument cannot be mutated.
    """
    return {m: metrics[m].item() for m in select}
def extract(a, t, x_shape):
    """Gather ``a[t]`` per batch element and reshape so it broadcasts
    against a tensor of shape ``x_shape``."""
    batch_size = t.shape[0]
    gathered = a.gather(-1, t.cpu())
    # (B, 1, 1, ...) with one trailing singleton per remaining dim.
    broadcast_shape = (batch_size,) + (1,) * (len(x_shape) - 1)
    return gathered.reshape(broadcast_shape).to(t.device)
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine noise schedule as proposed in https://arxiv.org/abs/2102.09672.

    Returns ``timesteps`` betas clipped to [0.0001, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    # Cumulative alpha products follow a squared-cosine profile,
    # normalized so the schedule starts at 1.
    alpha_bar = (
        torch.cos(((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    )
    alpha_bar = alpha_bar / alpha_bar[0]
    betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])
    return torch.clip(betas, 0.0001, 0.9999)
def linear_beta_schedule(timesteps):
    """Betas increasing linearly from 1e-4 to 0.1 over ``timesteps`` steps."""
    return torch.linspace(0.0001, 0.1, timesteps)
def plot_train_stats(df: pd.DataFrame, y_keys=None, skip_first_epoch=True):
    """Plot training-curve columns of ``df`` (x axis: "epochs") on one figure."""
    if skip_first_epoch:
        # Drop the first row, whose loss is typically off-scale.
        df = df.iloc[1:, :]
    if y_keys is None:
        y_keys = ["train_loss", "valid_loss"]
    _, axis = plt.subplots(1, 1, figsize=(6.5, 4))
    for key in y_keys:
        label = key.replace("_", " ").capitalize()
        sns.lineplot(ax=axis, data=df, x="epochs", y=key, label=label)
    axis.legend()
    axis.set_ylabel("Loss")
    axis.set_xlabel("Epoch")
    plt.show()
def get_lags_for_freq(freq_str: str):
    """Return the lag indices to use for the given frequency string.

    Only single-multiple hourly ("H") and daily/business-daily ("D"/"B")
    frequencies are supported.
    """
    offset = to_offset(freq_str)
    if offset.n > 1:
        raise NotImplementedError(
            "Lags for freq multiple > 1 are not implemented yet."
        )
    if offset.name == "H":
        # Daily seasonality: 1-7, 14, 21 and 28 days back, in hours.
        day_multiples = [1, 2, 3, 4, 5, 6, 7, 14, 21, 28]
        return [24 * m for m in day_multiples]
    if offset.name == "D" or offset.name == "B":
        # TODO: Fix lags for B
        # Monthly seasonality: 1-12 (30-day) months back, in days.
        return [30 * m for m in range(1, 13)]
    raise NotImplementedError(
        f"Lags for {freq_str} are not implemented yet."
    )
def create_transforms(
    num_feat_dynamic_real,
    num_feat_static_cat,
    num_feat_static_real,
    time_features,
    prediction_length,
):
    """Build the GluonTS transformation chain for training and inference.

    The chain removes unused static/dynamic fields, adds dummy static
    features when absent, converts fields to numpy arrays, adds an
    observed-values indicator, time and age features, a per-series
    [mean, std] "stats" field, and stacks all dynamic features into
    FEAT_TIME.

    Parameters
    ----------
    num_feat_dynamic_real
        Number of dynamic real features (0 removes the field).
    num_feat_static_cat
        Number of static categorical features (0 adds a dummy [0]).
    num_feat_static_real
        Number of static real features (0 adds a dummy [0.0]).
    time_features
        Time-feature callables passed to AddTimeFeatures.
    prediction_length
        Forecast horizon used by the time/age features.

    Returns
    -------
    A `Chain` of GluonTS transformations.
    """
    remove_field_names = []
    if num_feat_static_real == 0:
        remove_field_names.append(FieldName.FEAT_STATIC_REAL)
    if num_feat_dynamic_real == 0:
        remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
    return Chain(
        [RemoveFields(field_names=remove_field_names)]
        # Dummy static features when none are provided.
        + (
            [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
            if not num_feat_static_cat > 0
            else []
        )
        + (
            [SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
            if not num_feat_static_real > 0
            else []
        )
        + [
            AsNumpyArray(
                field=FieldName.FEAT_STATIC_CAT,
                expected_ndim=1,
                dtype=int,
            ),
            AsNumpyArray(
                field=FieldName.FEAT_STATIC_REAL,
                expected_ndim=1,
            ),
            AsNumpyArray(
                field=FieldName.TARGET,
                expected_ndim=1,
            ),
            # 1/0 indicator of observed (non-NaN) target values.
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            ),
            AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=time_features,
                pred_length=prediction_length,
            ),
            AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_AGE,
                pred_length=prediction_length,
                log_scale=True,
            ),
            # Per-series [mean, std], consumed as the "stats" input.
            AddMeanAndStdFeature(
                target_field=FieldName.TARGET,
                output_field="stats",
            ),
            VstackFeatures(
                output_field=FieldName.FEAT_TIME,
                input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                + (
                    [FieldName.FEAT_DYNAMIC_REAL]
                    if num_feat_dynamic_real > 0
                    else []
                ),
            ),
        ]
    )
def create_splitter(past_length: int, future_length: int, mode: str = "train"):
    """Create a GluonTS ``InstanceSplitter`` for the given mode.

    Parameters
    ----------
    past_length
        Length of the context window.
    future_length
        Length of the forecast horizon.
    mode
        "train" (random instances), "val" (validation split), or
        "test" (split at the end of each series).

    Returns
    -------
    A configured ``InstanceSplitter``.

    Raises
    ------
    ValueError
        If ``mode`` is unknown (previously this surfaced as a NameError
        on ``instance_sampler``).
    """
    if mode == "train":
        instance_sampler = ExpectedNumInstanceSampler(
            num_instances=1,
            min_past=past_length,
            min_future=future_length,
        )
    elif mode == "val":
        instance_sampler = ValidationSplitSampler(min_future=future_length)
    elif mode == "test":
        instance_sampler = TestSplitSampler()
    else:
        raise ValueError(f"Unknown mode {mode}, expected train/val/test!")
    splitter = InstanceSplitter(
        target_field=FieldName.TARGET,
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        instance_sampler=instance_sampler,
        past_length=past_length,
        future_length=future_length,
        time_series_fields=[FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES],
    )
    return splitter
def get_next_file_num(
    base_fname: str,
    base_dir: Path,
    file_type: str = "yaml",
    separator: str = "-",
):
    """Gets the next available file number in a directory.

    e.g., if `base_fname="results"` and `base_dir` has
    files ["results-0.yaml", "results-1.yaml"],
    this function returns 2.

    Parameters
    ----------
    base_fname
        Base name of the file.
    base_dir
        Base directory where files are located.
    file_type
        File extension; "" numbers directories instead of files.
    separator
        Separator between the base name and the number.

    Returns
    -------
    Next available file number
    """
    if file_type == "":
        # Number directories rather than files.
        candidates = (
            p
            for p in base_dir.glob("*")
            if p.is_dir() and p.name.startswith(base_fname)
        )
    else:
        candidates = (
            p
            for p in base_dir.glob(f"*.{file_type}")
            if p.name.startswith(base_fname)
        )
    existing_nums = [
        int(p.stem.replace(base_fname + separator, "")) for p in candidates
    ]
    # max(..., default=-1) + 1 yields 0 when nothing matches.
    return max(existing_nums, default=-1) + 1
def str2bool(v):
    """Parse a boolean from a string for argparse (bools pass through)."""
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise ArgumentTypeError("Boolean value expected.")
def add_config_to_argparser(config: Dict, parser: ArgumentParser):
    """Register one CLI option per scalar config entry.

    Only int/float/str/bool values are supported; other values are
    skipped with a warning. Keys are sanitized into valid option names
    (dashes become underscores) and the config value becomes the default.
    """
    scalar_types = {int, float, str, bool}
    for key, default in config.items():
        option_name = re.sub(r"[^\w\-]", "", key).replace("-", "_")
        value_type = type(default)
        if value_type not in scalar_types:
            print(f"WARNING: Skipping key {key}!")
            continue
        # Booleans need a parser that accepts yes/no-style strings.
        arg_type = str2bool if value_type == bool else value_type
        parser.add_argument(f"--{option_name}", type=arg_type, default=default)
    return parser
class AddMeanAndStdFeature(MapTransformation):
    """Add a length-2 array ``[mean, std]`` of the target under
    ``output_field``.

    FIX: the ``dtype`` argument was stored but never applied; the stats
    array is now created with the requested dtype.
    """

    @validated()
    def __init__(
        self,
        target_field: str,
        output_field: str,
        dtype: Type = np.float32,
    ) -> None:
        self.target_field = target_field
        self.feature_name = output_field
        self.dtype = dtype

    def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
        # Mean/std over the whole target series, cast to the configured dtype.
        data[self.feature_name] = np.array(
            [data[self.target_field].mean(), data[self.target_field].std()],
            dtype=self.dtype,
        )
        return data
class ScaleAndAddMeanFeature(MapTransformation):
    def __init__(
        self, target_field: str, output_field: str, prediction_length: int
    ) -> None:
        """Scale the time series using mean scaler and
        add the scale to `output_field`.

        Parameters
        ----------
        target_field
            Key for target time series
        output_field
            Key for the mean feature
        prediction_length
            prediction length, only the time series before the
            last `prediction_length` timesteps is used for
            scale computation
        """
        self.target_field = target_field
        self.feature_name = output_field
        self.prediction_length = prediction_length

    def map_transform(self, data, is_train: bool):
        # Mean absolute value over the context only (horizon excluded),
        # floored to avoid division by zero.
        context = data[self.target_field][..., : -self.prediction_length]
        mean_scale = np.mean(np.abs(context), axis=-1, keepdims=True)
        mean_scale = np.maximum(mean_scale, 1e-7)
        # The full series (context + horizon) is rescaled.
        data[self.target_field] = data[self.target_field] / mean_scale
        data[self.feature_name] = mean_scale
        return data
class ScaleAndAddMinMaxFeature(MapTransformation):
    def __init__(
        self, target_field: str, output_field: str, prediction_length: int
    ) -> None:
        """Scale the time series using min-max scaler and
        add the scale to `output_field`.

        Parameters
        ----------
        target_field
            Key for target time series
        output_field
            Key for the min-max feature, stored as a (loc, scale) tuple
        prediction_length
            prediction length, only the time series before the
            last `prediction_length` timesteps is used for
            scale computation
        """
        self.target_field = target_field
        self.feature_name = output_field
        self.prediction_length = prediction_length

    def map_transform(self, data, is_train: bool):
        # Statistics come from the history only (forecast window excluded),
        # mirroring ScaleAndAddMeanFeature.
        history = data[self.target_field][..., : -self.prediction_length]
        min_val = np.min(history, axis=-1, keepdims=True)
        max_val = np.max(history, axis=-1, keepdims=True)
        loc = min_val
        # Floor the spread to avoid division by ~0 for constant series.
        scale = np.maximum(max_val - min_val, 1e-7)
        # BUG FIX: scale the *full* series, not just the history slice.
        # The previous code assigned the scaled history back to the
        # target, silently dropping the last `prediction_length`
        # timesteps — inconsistent with ScaleAndAddMeanFeature, which
        # preserves the full series length.
        data[self.target_field] = (data[self.target_field] - loc) / scale
        data[self.feature_name] = (loc, scale)
        return data
def descale(data, scale, scaling_type):
    """Invert the scaling applied by the Scale-and-add transformations.

    Parameters
    ----------
    data
        Scaled values (array-like or scalar).
    scale
        A scalar scale for "mean" scaling, or a (loc, scale) tuple for
        "min-max" scaling.
    scaling_type
        One of {"mean", "min-max"}.

    Raises
    ------
    ValueError
        If `scaling_type` is not recognized.
    """
    if scaling_type == "mean":
        return data * scale
    if scaling_type == "min-max":
        loc, spread = scale
        return data * spread + loc
    raise ValueError(f"Unknown scaling type: {scaling_type}")
def predict_and_descale(predictor, dataset, num_samples, scaling_type):
    """Yield forecasts from `predictor`, rescaled to the original space.

    Forecasts are generated on the (scaled) test dataset and mapped
    back to the original space using the per-series "scale" feature
    produced by `ScaleAndAddMeanFeature` / `ScaleAndAddMinMaxFeature`.

    Parameters
    ----------
    predictor
        GluonTS predictor
    dataset
        Test dataset
    num_samples
        Number of forecast samples
    scaling_type
        Scaling type should be one of {"mean", "min-max"}
        Min-max scaling is used in TimeGAN, defaults to "mean"

    Yields
    ------
    SampleForecast objects

    Raises
    ------
    ValueError
        If the predictor generates Forecast objects other than SampleForecast
    """
    forecast_iter = predictor.predict(dataset, num_samples=num_samples)
    for entry, forecast in zip(dataset, forecast_iter):
        scale = entry["scale"]
        # Only sample-based forecasts can be rescaled elementwise.
        if not isinstance(forecast, SampleForecast):
            raise ValueError("Only SampleForecast objects supported!")
        forecast.samples = descale(
            forecast.samples, scale, scaling_type=scaling_type
        )
        yield forecast
def to_dataframe_and_descale(input_label, scaling_type) -> pd.DataFrame:
    """Join an (input, label) pair into one DataFrame in original scale.

    The two series are concatenated along time and descaled using the
    "scale" feature attached by the scaling transformation.

    Parameters
    ----------
    input_label
        Input-Label pair generated from the test template
    scaling_type
        Scaling type should be one of {"mean", "min-max"}
        Min-max scaling is used in TimeGAN, defaults to "mean"

    Returns
    -------
    A DataFrame containing the time series
    """
    first_entry = input_label[0]
    start = first_entry[FieldName.START]
    scale = first_entry["scale"]
    pieces = [entry[FieldName.TARGET] for entry in input_label]
    series = np.concatenate(pieces, axis=-1)
    series = descale(series, scale, scaling_type=scaling_type)
    index = period_index({FieldName.START: start, FieldName.TARGET: series})
    return pd.DataFrame(series.transpose(), index=index)
def make_evaluation_predictions_with_scaling(
    dataset, predictor, num_samples: int = 100, scaling_type="mean"
):
    """Variant of `make_evaluation_predictions` that undoes scaling.

    Splits the test series, generates forecasts on the scaled data,
    and maps both forecasts and ground-truth series back to the
    original space.

    Parameters
    ----------
    dataset
        Test dataset
    predictor
        GluonTS predictor
    num_samples, optional
        Number of test samples, by default 100
    scaling_type, optional
        Scaling type should be one of {"mean", "min-max"}
        Min-max scaling is used in TimeGAN, defaults to "mean"

    Returns
    -------
    A tuple of forecast and time series iterators
    """
    window_length = predictor.prediction_length + predictor.lead_time
    _, test_template = split(dataset, offset=-window_length)
    test_data = test_template.generate_instances(window_length)
    forecast_it = predict_and_descale(
        predictor,
        list(test_data.input),
        num_samples=num_samples,
        scaling_type=scaling_type,
    )
    ts_it = map(
        partial(to_dataframe_and_descale, scaling_type=scaling_type),
        test_data,
    )
    return forecast_it, ts_it
class PairDataset(Dataset):
    """Minimal map-style dataset pairing inputs `x` with targets `y`.

    Both arrays must share the same leading (sample) dimension.
    """

    def __init__(self, x, y) -> None:
        super().__init__()
        # Pairing only makes sense with one target per input sample.
        assert x.shape[0] == y.shape[0]
        self.x, self.y = x, y

    def __getitem__(self, index):
        return (self.x[index], self.y[index])

    def __len__(self):
        return self.x.shape[0]
class GluonTSNumpyDataset:
    """Wrap a numpy array of series as an iterable GluonTS-style dataset.

    Parameters
    ----------
    data
        Numpy array of samples with shape [N, T].
    start_date, optional
        Dummy start date field, by default pd.Period("2023", "H")
    """

    def __init__(
        self, data: np.ndarray, start_date: pd.Period = pd.Period("2023", "H")
    ):
        self.data = data
        self.start_date = start_date

    def __iter__(self):
        # Every series shares the same (dummy) start timestamp.
        for series in self.data:
            yield {"target": series, "start": self.start_date}

    def __len__(self):
        return len(self.data)
class MaskInput(MapTransformation):
    """Zero out part of ``past_target`` to simulate missing values.

    Supported scenarios (no-op unless ``missing_values > 0``):
    - "BM-E": contiguous block at the end of the context window
    - "BM-B": contiguous block at the beginning of the context window
    - "RM":   randomly placed positions within the context window

    Masked positions are set to 0 in both ``past_target`` and
    ``past_observed_values``; the unmasked target is preserved under
    ``"orig_past_target"``. The input entry is deep-copied, so the
    original data is never mutated.
    """

    @validated()
    def __init__(
        self,
        target_field: str,
        observed_field: str,
        context_length: int,
        missing_scenario: str,
        missing_values: int,
        dtype: Type = np.float32,
    ) -> None:
        # FIXME: Remove hardcoding of fields
        self.target_field = target_field
        self.observed_field = observed_field
        self.context_length = context_length
        self.missing_scenario = missing_scenario
        self.missing_values = missing_values
        self.dtype = dtype

    def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
        data = deepcopy(data)
        data["orig_past_target"] = data["past_target"].copy()
        if self.missing_scenario == "BM-E" and self.missing_values > 0:
            data["past_target"][-self.missing_values :] = 0
            data["past_observed_values"][-self.missing_values :] = 0
        elif self.missing_scenario == "BM-B" and self.missing_values > 0:
            # BUG FIX: use explicit positive indices. The previous slice
            # `[-C : -C + V]` silently became empty when V == C (upper
            # bound 0), masking nothing at all.
            start = data["past_target"].shape[-1] - self.context_length
            stop = start + self.missing_values
            data["past_target"][start:stop] = 0
            data["past_observed_values"][start:stop] = 0
        elif self.missing_scenario == "RM" and self.missing_values > 0:
            # Sample `missing_values` distinct positions uniformly from
            # the last `context_length` timesteps.
            weights = torch.ones(self.context_length)
            missing_idxs = -self.context_length + torch.multinomial(
                weights, self.missing_values, replacement=False
            )
            data["past_target"][missing_idxs] = 0
            data["past_observed_values"][missing_idxs] = 0
        return data
class ConcatDataset:
    """Lazily fuse pairs of dataset entries into single entries.

    Each element of `test_pairs` is a (t1, t2) pair of GluonTS-style
    dicts; iterating yields one dict per pair whose "target" is the
    concatenation of both targets along `axis`, keeping t1's "start".
    """

    def __init__(self, test_pairs, axis=-1) -> None:
        self.test_pairs = test_pairs
        self.axis = axis

    def _concat(self, test_pairs):
        for first, second in test_pairs:
            combined = np.concatenate(
                [first["target"], second["target"]], axis=self.axis
            )
            yield {"target": combined, "start": first["start"]}

    def __iter__(self):
        yield from self._concat(self.test_pairs)