Repository: amazon-science/unconditional-time-series-diffusion Branch: main Commit: 3eafeffdffef Files: 126 Total size: 289.4 KB Directory structure: gitextract_qvnzuyyo/ ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── THIRD-PARTY-LICENSES.txt ├── bin/ │ ├── guidance_experiment.py │ ├── refinement_experiment.py │ ├── train_cond_model.py │ ├── train_model.py │ └── tstr_experiment.py ├── configs/ │ ├── guidance/ │ │ ├── guidance_electricity.yaml │ │ ├── guidance_exchange.yaml │ │ ├── guidance_kdd_cup.yaml │ │ ├── guidance_m4.yaml │ │ ├── guidance_solar.yaml │ │ ├── guidance_traffic.yaml │ │ ├── guidance_uber_tlc.yaml │ │ └── guidance_wiki.yaml │ ├── guidance.yaml │ ├── refinement/ │ │ ├── electricity_nips-deepar.yaml │ │ ├── electricity_nips-linear.yaml │ │ ├── electricity_nips-seasonal_naive.yaml │ │ ├── electricity_nips-transformer.yaml │ │ ├── exchange_rate_nips-deepar.yaml │ │ ├── exchange_rate_nips-linear.yaml │ │ ├── exchange_rate_nips-seasonal_naive.yaml │ │ ├── exchange_rate_nips-transformer.yaml │ │ ├── kdd_cup_2018_without_missing-deepar.yaml │ │ ├── kdd_cup_2018_without_missing-linear.yaml │ │ ├── kdd_cup_2018_without_missing-seasonal_naive.yaml │ │ ├── kdd_cup_2018_without_missing-transformer.yaml │ │ ├── m4_hourly-deepar.yaml │ │ ├── m4_hourly-linear.yaml │ │ ├── m4_hourly-seasonal_naive.yaml │ │ ├── m4_hourly-transformer.yaml │ │ ├── solar_nips-deepar.yaml │ │ ├── solar_nips-linear.yaml │ │ ├── solar_nips-seasonal_naive.yaml │ │ ├── solar_nips-transformer.yaml │ │ ├── traffic_nips-deepar.yaml │ │ ├── traffic_nips-linear.yaml │ │ ├── traffic_nips-seasonal_naive.yaml │ │ ├── traffic_nips-transformer.yaml │ │ ├── uber_tlc_hourly-deepar.yaml │ │ ├── uber_tlc_hourly-linear.yaml │ │ ├── uber_tlc_hourly-seasonal_naive.yaml │ │ ├── uber_tlc_hourly-transformer.yaml │ │ ├── wiki2000_nips-deepar.yaml │ │ ├── wiki2000_nips-linear.yaml │ │ ├── wiki2000_nips-seasonal_naive.yaml │ │ └── wiki2000_nips-transformer.yaml │ ├── refinement.yaml │ ├── train_tsdiff/ │ │ ├── train_electricity.yaml │ │ ├── train_exchange.yaml │ │ ├── train_kdd_cup.yaml │ │ ├── train_m4.yaml │ │ ├── train_missing_electricity.yaml │ │ ├── train_missing_exchange.yaml │ │ ├── train_missing_kdd_cup.yaml │ │ ├── train_missing_solar.yaml │ │ ├── train_missing_traffic.yaml │ │ ├── train_missing_uber_tlc.yaml │ │ ├── train_solar.yaml │ │ ├── train_traffic.yaml │ │ ├── train_uber_tlc.yaml │ │ └── train_wiki.yaml │ ├── train_tsdiff-cond/ │ │ ├── electricity_nips.yaml │ │ ├── exchange_rate_nips.yaml │ │ ├── kdd_cup_2018_without_missing.yaml │ │ ├── m4_hourly.yaml │ │ ├── missing_BM-B_electricity_nips.yaml │ │ ├── missing_BM-B_exchange_rate_nips.yaml │ │ ├── missing_BM-B_kdd_cup_2018_without_missing.yaml │ │ ├── missing_BM-B_solar_nips.yaml │ │ ├── missing_BM-B_traffic_nips.yaml │ │ ├── missing_BM-B_uber_tlc_hourly.yaml │ │ ├── missing_BM-E_electricity_nips.yaml │ │ ├── missing_BM-E_exchange_rate_nips.yaml │ │ ├── missing_BM-E_kdd_cup_2018_without_missing.yaml │ │ ├── missing_BM-E_solar_nips.yaml │ │ ├── missing_BM-E_traffic_nips.yaml │ │ ├── missing_BM-E_uber_tlc_hourly.yaml │ │ ├── missing_RM_electricity_nips.yaml │ │ ├── missing_RM_exchange_rate_nips.yaml │ │ ├── missing_RM_kdd_cup_2018_without_missing.yaml │ │ ├── missing_RM_solar_nips.yaml │ │ ├── missing_RM_traffic_nips.yaml │ │ ├── missing_RM_uber_tlc_hourly.yaml │ │ ├── solar_nips.yaml │ │ ├── traffic_nips.yaml │ │ ├── uber_tlc_hourly.yaml │ │ └── wiki2000_nips.yaml │ ├── train_tsdiff-cond.yaml │ ├── train_tsdiff.yaml │ ├── 
tstr/ │ │ ├── electricity_nips.yaml │ │ ├── exchange_rate_nips.yaml │ │ ├── kdd_cup_2018_without_missing.yaml │ │ ├── m4_hourly.yaml │ │ ├── solar_nips.yaml │ │ ├── traffic_nips.yaml │ │ ├── uber_tlc_hourly.yaml │ │ └── wiki2000_nips.yaml │ └── tstr.yaml ├── pyproject.toml └── src/ └── uncond_ts_diff/ ├── arch/ │ ├── __init__.py │ ├── backbones.py │ └── s4.py ├── configs.py ├── dataset.py ├── metrics/ │ ├── __init__.py │ └── linear_pred_score.py ├── model/ │ ├── __init__.py │ ├── callback.py │ ├── diffusion/ │ │ ├── _base.py │ │ ├── tsdiff.py │ │ └── tsdiff_cond.py │ └── linear/ │ ├── _estimator.py │ └── _scaler.py ├── predictor.py ├── sampler/ │ ├── __init__.py │ ├── _base.py │ ├── observation_guidance.py │ └── refiner.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ __pycache__ lightning_logs/ .DS_Store *.egg-info /results/ /ckpts/ /saved_samples/ .vscode/ /sm_runs/ /data/ /checkpoints/ ================================================ FILE: CODE_OF_CONDUCT.md ================================================ ## Code of Conduct This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact opensource-codeofconduct@amazon.com with any additional questions or comments. ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing Guidelines Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional documentation, we greatly value feedback and contributions from our community. Please read through this document before submitting any issues or pull requests to ensure we have all the necessary information to effectively respond to your bug report or contribution. ## Reporting Bugs/Feature Requests We welcome you to use the GitHub issue tracker to report bugs or suggest features. When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: * A reproducible test case or series of steps * The version of our code being used * Any modifications you've made relevant to the bug * Anything unusual about your environment or deployment ## Contributing via Pull Requests Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 1. You are working against the latest source on the *main* branch. 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. To send us a pull request, please: 1. Fork the repository. 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 3. Ensure local tests pass. 4. Commit to your fork using clear commit messages. 5. Send us a pull request, answering any default questions in the pull request interface. 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 
GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).

## Finding contributions to work on

Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.

## Code of Conduct

This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact opensource-codeofconduct@amazon.com with any additional questions or comments.

## Security issue notifications

If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.

## Licensing

See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.

================================================ FILE: LICENSE ================================================

Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

================================================ FILE: NOTICE ================================================

Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

================================================ FILE: README.md ================================================

# TSDiff: An Unconditional Diffusion Model for Time Series

[![preprint](https://img.shields.io/static/v1?label=arXiv&message=2307.11494&color=B31B1B)](https://arxiv.org/abs/2307.11494) [![License: Apache-2.0](https://img.shields.io/badge/License-Apache--2.0-yellow.svg)](https://opensource.org/licenses/Apache-2.0) [![Venue: NeurIPS 2023](https://img.shields.io/badge/Venue-NeurIPS%202023-007CFF)](https://neurips.cc/)


Fig. 1: An overview of TSDiff’s use cases. Predict: Using observation self-guidance, TSDiff can be conditioned during inference to perform predictive tasks such as forecasting. Refine: Predictions of base forecasters can be improved by leveraging the implicit probability density of TSDiff. Synthesize: Realistic samples generated by TSDiff can be used to train downstream forecasters that achieve good performance on real test data.

---

This repository contains the official implementation of the NeurIPS 2023 paper [*Predict, Refine, Synthesize: Self-Guiding Diffusion Models for Probabilistic Time Series Forecasting*](https://arxiv.org/abs/2307.11494). In this paper, we propose *TSDiff*, an unconditional diffusion model for time series. Our proposed self-guidance mechanism enables conditioning TSDiff for downstream tasks during inference, without requiring auxiliary networks or altering the training procedure. Furthermore, our refinement scheme leverages the implicit density learned by the diffusion model to iteratively refine the predictions of base forecasters. Finally, we demonstrate the high quality of the synthetic time series by training downstream models solely on generated data, and we introduce the *Linear Predictive Score (LPS)* to quantify it.
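For intuition, observation self-guidance works in the spirit of classifier guidance: each reverse-diffusion step is shifted by the gradient of an observation likelihood evaluated through the model's own denoised estimate, so no auxiliary network is required. Below is a minimal, hypothetical sketch of one guided step; `denoise` and `posterior_mean_std` are assumed helper methods, not the repository's API (the actual implementation lives in `src/uncond_ts_diff/sampler/observation_guidance.py`).

```python
import torch

def guided_ddpm_step(model, x_t, t, y_obs, obs_mask, scale=1.0):
    """One reverse-diffusion step with observation self-guidance (sketch).

    The reverse-step mean is shifted by the gradient of a Gaussian
    quasi-likelihood of the observed values, computed through the model's
    own denoised estimate of x_0, so no auxiliary network is needed.
    """
    x_t = x_t.detach().requires_grad_(True)
    x0_hat = model.denoise(x_t, t)  # assumed helper: E[x_0 | x_t]
    # Log-likelihood of the observed entries under the denoised estimate.
    log_lik = -0.5 * ((obs_mask * (x0_hat - y_obs)) ** 2).sum()
    (grad,) = torch.autograd.grad(log_lik, x_t)
    mean, sigma = model.posterior_mean_std(x_t, x0_hat, t)  # assumed helper
    # Guidance: shift the mean along the likelihood gradient, then sample.
    return (mean + scale * sigma**2 * grad + sigma * torch.randn_like(x_t)).detach()
```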


Fig. 2: Example forecasts generated by TSDiff-Q for time series in Electricity, KDDCup, and Exchange — three datasets with different frequencies and/or prediction lengths.

## Installation

TSDiff requires Python 3.8 or higher.

* Create a conda environment (optional, but recommended).
```sh
conda create --name tsdiff --yes python=3.8 && conda activate tsdiff
```
* Install this package.
```sh
pip install --editable "."
```

> [!TIP]
> We have some updates in the `update` branch. If you're interested in testing out TSDiff or [training it on a custom dataset](https://github.com/amazon-science/unconditional-time-series-diffusion/issues/7), the `update` branch may be faster for training.

## Usage

### Training Models

Train models using the `train_model.py` and `train_cond_model.py` scripts for `TSDiff` and `TSDiff-Cond`, respectively. Sample configurations can be found in `configs/train_tsdiff.yaml` and `configs/train_tsdiff-cond.yaml`. Specific configurations used in the paper can be found in `configs/train_tsdiff` and `configs/train_tsdiff-cond`.

Example commands for regular (i.e., no missing values) forecasting:
```sh
# Train TSDiff on the Uber dataset for regular forecasting
python bin/train_model.py -c configs/train_tsdiff/train_uber_tlc.yaml

# Train TSDiff on the M4 dataset for regular forecasting
python bin/train_model.py -c configs/train_tsdiff/train_m4.yaml

# Train TSDiff-Cond on the Uber dataset for regular forecasting
python bin/train_cond_model.py -c configs/train_tsdiff-cond/uber_tlc_hourly.yaml

# Train TSDiff-Cond on the M4 dataset for regular forecasting
python bin/train_cond_model.py -c configs/train_tsdiff-cond/m4_hourly.yaml
```

Example commands for forecasting with missing values:
```sh
# Train TSDiff on the Uber dataset for the missing values experiment
python bin/train_model.py -c configs/train_tsdiff/train_missing_uber_tlc.yaml

# Train TSDiff on the KDDCup dataset for the missing values experiment
python bin/train_model.py -c configs/train_tsdiff/train_missing_kdd_cup.yaml

# Train TSDiff-Cond on the Uber dataset for the RM missing values experiment
python bin/train_cond_model.py -c configs/train_tsdiff-cond/missing_RM_uber_tlc_hourly.yaml

# Train TSDiff-Cond on the KDDCup dataset for the BM-B missing values experiment
python bin/train_cond_model.py -c configs/train_tsdiff-cond/missing_BM-B_kdd_cup_2018_without_missing.yaml

# Train TSDiff-Cond on the KDDCup dataset for the BM-E missing values experiment
python bin/train_cond_model.py -c configs/train_tsdiff-cond/missing_BM-E_kdd_cup_2018_without_missing.yaml
```

Note that for TSDiff we train only one model; all missing-value scenarios are evaluated using the same unconditional model. For TSDiff-Cond, however, one model is trained per missingness scenario.

### Evaluating Models

The unconditional models trained above can be used for the following tasks.

#### Predict using Observation Self-Guidance

Use the `guidance_experiment.py` script and the `configs/guidance.yaml` config to run the forecasting experiments. Specific configurations used in the paper can be found in `configs/guidance/`. Example commands:
```sh
# Run observation self-guidance on the Solar dataset
python bin/guidance_experiment.py -c configs/guidance/guidance_solar.yaml --ckpt /path/to/ckpt

# Run observation self-guidance on the KDDCup dataset
python bin/guidance_experiment.py -c configs/guidance/guidance_kdd_cup.yaml --ckpt /path/to/ckpt
```

#### Refine Predictions of Base Forecasters

Use the `refinement_experiment.py` script and the `configs/refinement.yaml` config to run the refinement experiments. Specific configurations used in the paper can be found in `configs/refinement/`.
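Conceptually, refinement treats the base forecast as an initialization and iteratively moves it toward regions of higher implicit density under TSDiff. The following is a rough, hypothetical sketch of an energy-descent variant, using the model's denoising loss at a fixed diffusion step as an energy proxy; `step_size` and `reg` are illustrative parameters, and the actual `MostLikelyRefiner` and `MCMCRefiner` live in `src/uncond_ts_diff/sampler/refiner.py`.

```python
import torch

def refine(model, y_init, fixed_t, iterations=10, step_size=0.1, reg=1.0):
    """Nudge a base forecast toward higher implicit density under TSDiff (sketch)."""
    y = y_init.clone().requires_grad_(True)
    t = torch.tensor([fixed_t])
    for _ in range(iterations):
        # Denoising loss at a fixed step acts as an energy proxy for -log p(y).
        loss, _, _ = model.p_losses(y, t)
        # Regularizer keeps the refined forecast close to the base forecast.
        energy = loss + reg * ((y - y_init) ** 2).mean()
        (grad,) = torch.autograd.grad(energy, y)
        with torch.no_grad():
            y = y - step_size * grad  # an MCMC variant would also add noise here
        y.requires_grad_(True)
    return y.detach()
```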
Example commands:
```sh
# Refine predictions from the Linear model on the Solar dataset
python bin/refinement_experiment.py -c configs/refinement/solar_nips-linear.yaml --ckpt /path/to/ckpt

# Refine predictions from the DeepAR model on the M4 dataset
python bin/refinement_experiment.py -c configs/refinement/m4_hourly-deepar.yaml --ckpt /path/to/ckpt
```

#### Train Downstream Models using Synthetic Data

Use the `tstr_experiment.py` script and the `configs/tstr.yaml` config to run the _train on synthetic, test on real_ (TSTR) experiments. Specific configurations used in the paper can be found in `configs/tstr/`. Example commands:
```sh
# TSTR on the Solar Dataset
python bin/tstr_experiment.py -c configs/tstr/solar_nips.yaml --ckpt /path/to/ckpt

# TSTR on the KDDCup Dataset
python bin/tstr_experiment.py -c configs/tstr/kdd_cup_2018_without_missing.yaml --ckpt /path/to/ckpt
```

## BibTeX

If you find this repository or the ideas presented in our paper useful, please consider citing:
```
@inproceedings{kollovieh2023predict,
  author = {Kollovieh, Marcel and Ansari, Abdul Fatir and Bohlke-Schneider, Michael and Zschiegner, Jasper and Wang, Hao and Wang, Yuyang},
  title = {Predict, Refine, Synthesize: Self-Guiding Diffusion Models for Probabilistic Time Series Forecasting},
  booktitle = {Advances in Neural Information Processing Systems},
  year = {2023}
}
```

## Security

See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.

## License

This project is licensed under the Apache-2.0 License.

================================================ FILE: THIRD-PARTY-LICENSES.txt ================================================

** state-spaces; version 1.0 -- https://github.com/HazyResearch/state-spaces

Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non- exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no- charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. * For state-spaces see also this required NOTICE: Copyright 2022 Albert Gu and Karan Goel and Christopher Re ================================================ FILE: bin/guidance_experiment.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 import logging import argparse from pathlib import Path import yaml import torch from tqdm.auto import tqdm from gluonts.dataset.field_names import FieldName from gluonts.evaluation import make_evaluation_predictions, Evaluator from uncond_ts_diff.utils import ( create_transforms, create_splitter, get_next_file_num, add_config_to_argparser, filter_metrics, MaskInput, ) from uncond_ts_diff.model import TSDiff from uncond_ts_diff.dataset import get_gts_dataset from uncond_ts_diff.sampler import ( DDPMGuidance, DDIMGuidance, ) import uncond_ts_diff.configs as diffusion_configs guidance_map = {"ddpm": DDPMGuidance, "ddim": DDIMGuidance} def load_model(config): model = TSDiff( **getattr( diffusion_configs, config.get("diffusion_config", "diffusion_small_config"), ), freq=config["freq"], use_features=config["use_features"], use_lags=config["use_lags"], normalization="mean", context_length=config["context_length"], prediction_length=config["prediction_length"], init_skip=config["init_skip"], ) model.load_state_dict( torch.load(config["ckpt"], map_location="cpu"), strict=True, ) model = model.to(config["device"]) return model def evaluate_guidance( config, model, test_dataset, transformation, num_samples=100 ): logger.info(f"Evaluating with {num_samples} samples.") results = [] if config["setup"] == "forecasting": missing_data_kwargs_list = [ { "missing_scenario": "none", "missing_values": 0, } ] config["missing_data_configs"] = missing_data_kwargs_list elif config["setup"] == "missing_values": missing_data_kwargs_list = config["missing_data_configs"] else: raise ValueError(f"Unknown setup {config['setup']}") Guidance = guidance_map[config["sampler"]] sampler_params = config["sampler_params"] for missing_data_kwargs in missing_data_kwargs_list: logger.info( f"Evaluating scenario '{missing_data_kwargs['missing_scenario']}' " f"with {missing_data_kwargs['missing_values']:.1f} missing_values." 
) sampler = Guidance( model=model, prediction_length=config["prediction_length"], num_samples=num_samples, **missing_data_kwargs, **sampler_params, ) transformed_testdata = transformation.apply( test_dataset, is_train=False ) test_splitter = create_splitter( past_length=config["context_length"] + max(model.lags_seq), future_length=config["prediction_length"], mode="test", ) masking_transform = MaskInput( FieldName.TARGET, FieldName.OBSERVED_VALUES, config["context_length"], missing_data_kwargs["missing_scenario"], missing_data_kwargs["missing_values"], ) test_transform = test_splitter + masking_transform predictor = sampler.get_predictor( test_transform, batch_size=1280 // num_samples, device=config["device"], ) forecast_it, ts_it = make_evaluation_predictions( dataset=transformed_testdata, predictor=predictor, num_samples=num_samples, ) forecasts = list(tqdm(forecast_it, total=len(transformed_testdata))) tss = list(ts_it) evaluator = Evaluator() metrics, _ = evaluator(tss, forecasts) metrics = filter_metrics(metrics) results.append(dict(**missing_data_kwargs, **metrics)) return results def main(config: dict, log_dir: str): # Read global parameters dataset_name = config["dataset"] freq = config["freq"] prediction_length = config["prediction_length"] num_samples = config["num_samples"] # Load dataset and model logger.info("Loading model") model = load_model(config) dataset = get_gts_dataset(dataset_name) assert dataset.metadata.freq == freq assert dataset.metadata.prediction_length == prediction_length # Setup data transformation and loading transformation = create_transforms( num_feat_dynamic_real=0, num_feat_static_cat=0, num_feat_static_real=0, time_features=model.time_features, prediction_length=prediction_length, ) # Run guidance results = evaluate_guidance( config, model, dataset.test, transformation, num_samples=num_samples ) # Save results log_dir = Path(log_dir) / "guidance_logs" log_dir.mkdir(exist_ok=True, parents=True) base_filename = "results" run_num = get_next_file_num( base_filename, log_dir, file_type="yaml", separator="-" ) save_path = log_dir / f"{base_filename}-{run_num}.yaml" with open(save_path, "w") as fp: yaml.safe_dump( {"config": config, "metrics": results}, fp, default_flow_style=False, sort_keys=False, ) if __name__ == "__main__": # Setup Logger logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) # Setup argparse parser = argparse.ArgumentParser() parser.add_argument( "-c", "--config", type=str, required=True, help="Path to yaml config" ) parser.add_argument( "--out_dir", type=str, default="./results", help="Path to results dir" ) args, _ = parser.parse_known_args() with open(args.config, "r") as fp: config = yaml.safe_load(fp) # Update config from command line parser = add_config_to_argparser(config=config, parser=parser) args = parser.parse_args() config_updates = vars(args) for k in config.keys() & config_updates.keys(): orig_val = config[k] updated_val = config_updates[k] if updated_val != orig_val: logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}") config.update(config_updates) main(config=config, log_dir=args.out_dir) ================================================ FILE: bin/refinement_experiment.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 import json import copy import logging import argparse from pathlib import Path import yaml import torch import numpy as np from tqdm.auto import tqdm from gluonts.mx import DeepAREstimator, TransformerEstimator from gluonts.model.seasonal_naive import SeasonalNaivePredictor from gluonts.evaluation import make_evaluation_predictions, Evaluator from gluonts.dataset.loader import TrainDataLoader from gluonts.itertools import Cached from gluonts.torch.batchify import batchify from uncond_ts_diff.utils import ( create_transforms, create_splitter, get_next_file_num, add_config_to_argparser, filter_metrics, ) from uncond_ts_diff.model import TSDiff, LinearEstimator from uncond_ts_diff.dataset import get_gts_dataset from uncond_ts_diff.sampler import ( MostLikelyRefiner, MCMCRefiner, DDPMGuidance, DDIMGuidance, ) import uncond_ts_diff.configs as diffusion_configs guidance_map = {"ddpm": DDPMGuidance, "ddim": DDIMGuidance} refiner_map = {"most_likely": MostLikelyRefiner, "mcmc": MCMCRefiner} def load_model(config): model = TSDiff( **getattr( diffusion_configs, config.get("diffusion_config", "diffusion_small_config"), ), freq=config["freq"], use_features=config["use_features"], use_lags=config["use_lags"], normalization="mean", context_length=config["context_length"], prediction_length=config["prediction_length"], init_skip=config["init_skip"], ) model.load_state_dict( torch.load(config["ckpt"], map_location="cpu"), strict=True, ) model = model.to(config["device"]) return model def get_best_diffusion_step(model: TSDiff, data_loader, device): losses = np.zeros(model.timesteps) batch = { k: v.to(device) for k, v in next(iter(data_loader)).items() if isinstance(v, torch.Tensor) } x, features, scale = model._extract_features(batch) for t in range(model.timesteps): loss, _, _ = model.p_losses( x.to(device), torch.tensor([t], device=device) ) losses[t] = loss best_t = ((losses - losses.mean()) ** 2).argmin() return best_t def train_and_forecast_base_model(dataset, base_model_name, config): base_model_kwargs = config.get("base_model_params", {}) if base_model_name == "deepar": predictor = DeepAREstimator( prediction_length=dataset.metadata.prediction_length, freq=dataset.metadata.freq, **base_model_kwargs, ).train(list(dataset.train), cache_data=True) elif base_model_name == "transformer": predictor = TransformerEstimator( prediction_length=dataset.metadata.prediction_length, freq=dataset.metadata.freq, **base_model_kwargs, ).train(list(dataset.train), cache_data=True) elif base_model_name == "seasonal_naive": predictor = SeasonalNaivePredictor( freq=dataset.metadata.freq, prediction_length=dataset.metadata.prediction_length, **base_model_kwargs, ) elif base_model_name == "linear": num_train_samples = 10000 predictor = LinearEstimator( freq=dataset.metadata.freq, prediction_length=dataset.metadata.prediction_length, context_length=config["context_length"], num_train_samples=num_train_samples, **base_model_kwargs, ).train(list(dataset.train), cache_data=True) else: raise ValueError(f"Unsupported base model {base_model_name}!") fcst_iter, ts_iter = make_evaluation_predictions( dataset=dataset.test, predictor=predictor, num_samples=config["num_samples"], ) fcsts = list(tqdm(fcst_iter, total=len(dataset.test))) tss = list(ts_iter) return fcsts, tss def forecast_guidance( dataset, base_model_name, config, diffusion_model, transformed_testdata, test_splitter, ): assert len(dataset.test) == len(transformed_testdata) base_model_kwargs = config.get("base_model_params", {}) 
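# The "base model" here is itself a guidance sampler (DDPM or DDIM): look up
# the corresponding Guidance class and wrap it in a GluonTS predictor, so its
# self-guided forecasts can serve as the initialization for refinement.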
Guidance = guidance_map[base_model_name] predictor = Guidance( model=diffusion_model, prediction_length=dataset.metadata.prediction_length, num_samples=config["num_samples"], **base_model_kwargs, ).get_predictor( input_transform=test_splitter, batch_size=1280 // config["num_samples"], device=config["device"], ) fcst_iter, ts_iter = make_evaluation_predictions( dataset=transformed_testdata, predictor=predictor, num_samples=config["num_samples"], ) fcsts = list(tqdm(fcst_iter, total=len(dataset.test))) tss = list(ts_iter) return fcsts, tss def main(config: dict, log_dir: str): # Read global parameters dataset_name = config["dataset"] device = config["device"] context_length = config["context_length"] prediction_length = config["prediction_length"] base_model_name = config["base_model"] num_samples = config["num_samples"] # Load dataset and model logger.info("Loading model") dataset = get_gts_dataset(dataset_name) config["freq"] = dataset.metadata.freq assert prediction_length == dataset.metadata.prediction_length model = load_model(config) # Setup data transformation and loading transformation = create_transforms( num_feat_dynamic_real=0, num_feat_static_cat=0, num_feat_static_real=0, time_features=model.time_features, prediction_length=prediction_length, ) transformed_data = transformation.apply(list(dataset.train), is_train=True) transformed_testdata = transformation.apply( list(dataset.test), is_train=False ) training_splitter = create_splitter( past_length=context_length + max(model.lags_seq), future_length=prediction_length, mode="train", ) test_splitter = create_splitter( past_length=context_length + max(model.lags_seq), future_length=prediction_length, mode="test", ) train_dataloader = TrainDataLoader( Cached(transformed_data), batch_size=1024, stack_fn=batchify, transform=training_splitter, num_batches_per_epoch=2048, ) best_t = get_best_diffusion_step(model, train_dataloader, device) # Train base model & get initial forecasts logger.info("Training base model") if base_model_name in {"ddpm", "ddim"}: base_fcsts, tss = forecast_guidance( dataset, base_model_name, config, diffusion_model=model, transformed_testdata=transformed_testdata, test_splitter=test_splitter, ) else: base_fcsts, tss = train_and_forecast_base_model( dataset, base_model_name, config ) # Evaluate base forecasts evaluator = Evaluator() baseline_metrics, _ = evaluator(tss, base_fcsts) baseline_metrics = filter_metrics(baseline_metrics) # Run refinement log_dir = Path(log_dir) / "refinement_logs" log_dir.mkdir(exist_ok=True, parents=True) base_filename = "results" run_num = get_next_file_num( base_filename, log_dir, file_type="yaml", separator="-" ) save_path = log_dir / f"{base_filename}-{run_num}.yaml" results = [ { "model": "baseline", "model_params": { "name": base_model_name, **config.get("base_model_params", {}), }, **baseline_metrics, } ] n_refiner_configs = len(config["refiner_configs"]) for i, ref_config in enumerate(config["refiner_configs"]): logger.info( f"Running refiner ({i+1}/{n_refiner_configs}): {json.dumps(ref_config)}" ) refiner_config = copy.deepcopy(ref_config) refiner_name = refiner_config.pop("refiner_name") Refiner = refiner_map[refiner_name] refiner = Refiner( model, prediction_length, init=iter(base_fcsts), num_samples=num_samples, fixed_t=best_t, iterations=config["iterations"], **refiner_config, ) refiner_predictor = refiner.get_predictor( test_splitter, batch_size=1024 // num_samples, device=device ) forecast_it, ts_it = make_evaluation_predictions( dataset=transformed_testdata, 
predictor=refiner_predictor, num_samples=num_samples, ) evaluator = Evaluator() refined_metrics, _ = evaluator( list(ts_it), list(tqdm(forecast_it, total=len(transformed_testdata))), ) refined_metrics = filter_metrics(refined_metrics) results.append( { "model": refiner_name, "model_params": json.dumps(ref_config), **refined_metrics, } ) with open(save_path, "w") as fp: yaml.safe_dump( {"config": config, "metrics": results}, fp, default_flow_style=False, sort_keys=False, ) if __name__ == "__main__": # Setup Logger logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) # Setup argparse parser = argparse.ArgumentParser() parser.add_argument( "-c", "--config", type=str, required=True, help="Path to yaml config" ) parser.add_argument( "--out_dir", type=str, default="./results", help="Path to results dir" ) args, _ = parser.parse_known_args() with open(args.config, "r") as fp: config = yaml.safe_load(fp) # Update config from command line parser = add_config_to_argparser(config=config, parser=parser) args = parser.parse_args() config_updates = vars(args) for k in config.keys() & config_updates.keys(): orig_val = config[k] updated_val = config_updates[k] if updated_val != orig_val: logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}") config.update(config_updates) main(config=config, log_dir=args.out_dir) ================================================ FILE: bin/train_cond_model.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import logging import argparse from pathlib import Path import yaml import torch from tqdm.auto import tqdm import pytorch_lightning as pl from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar from gluonts.dataset.loader import TrainDataLoader, ValidationDataLoader from gluonts.dataset.split import OffsetSplitter from gluonts.itertools import Cached from gluonts.torch.batchify import batchify from gluonts.evaluation import make_evaluation_predictions, Evaluator from gluonts.dataset.field_names import FieldName import uncond_ts_diff.configs as diffusion_configs from uncond_ts_diff.dataset import get_gts_dataset from uncond_ts_diff.model import TSDiffCond from uncond_ts_diff.utils import ( create_transforms, create_splitter, add_config_to_argparser, filter_metrics, MaskInput, ConcatDataset, ) def create_model(config): model = TSDiffCond( **getattr(diffusion_configs, config["diffusion_config"]), freq=config["freq"], use_features=config["use_features"], use_lags=config["use_lags"], normalization=config["normalization"], context_length=config["context_length"], prediction_length=config["prediction_length"], lr=config["lr"], init_skip=config["init_skip"], noise_observed=config["noise_observed"], ) model.to(config["device"]) return model def evaluate_conditional( config, model: TSDiffCond, test_dataset, transformation, num_samples=100, ): logger.info(f"Evaluating with {num_samples} samples.") logger.info( f"Evaluating scenario '{config['missing_scenario']}' " f"with {config['missing_values']:.1f} missing_values." 
) results = [] transformed_testdata = transformation.apply(test_dataset, is_train=False) test_splitter = create_splitter( past_length=config["context_length"] + max(model.lags_seq), future_length=config["prediction_length"], mode="test", ) masking_transform = MaskInput( FieldName.TARGET, FieldName.OBSERVED_VALUES, config["context_length"], config["missing_scenario"], config["missing_values"], ) test_transform = test_splitter + masking_transform predictor = model.get_predictor( test_transform, batch_size=1280, device=config["device"], ) forecast_it, ts_it = make_evaluation_predictions( dataset=transformed_testdata, predictor=predictor, num_samples=num_samples, ) forecasts = list(tqdm(forecast_it, total=len(transformed_testdata))) tss = list(ts_it) evaluator = Evaluator() metrics, _ = evaluator(tss, forecasts) metrics = filter_metrics(metrics) results.append(dict(**metrics)) return results def main(config, log_dir): # Load parameters dataset_name = config["dataset"] freq = config["freq"] context_length = config["context_length"] prediction_length = config["prediction_length"] total_length = context_length + prediction_length # Create model model = create_model(config) # Setup dataset and data loading dataset = get_gts_dataset(dataset_name) assert dataset.metadata.freq == freq assert dataset.metadata.prediction_length == prediction_length if config["setup"] == "forecasting": training_data = dataset.train elif config["setup"] == "missing_values": missing_values_splitter = OffsetSplitter(offset=-total_length) training_data, _ = missing_values_splitter.split(dataset.train) num_rolling_evals = int(len(dataset.test) / len(dataset.train)) transformation = create_transforms( num_feat_dynamic_real=0, num_feat_static_cat=0, num_feat_static_real=0, time_features=model.time_features, prediction_length=config["prediction_length"], ) training_splitter = create_splitter( past_length=config["context_length"] + max(model.lags_seq), future_length=config["prediction_length"], mode="train", ) if config["setup"] == "forecasting": config["missing_scenario"] = "none" config["missing_values"] = 0 masking_transform = MaskInput( FieldName.TARGET, FieldName.OBSERVED_VALUES, config["context_length"], config.get("train_missing_scenario", config["missing_scenario"]), config["missing_values"], ) train_transform = training_splitter + masking_transform callbacks = [] val_loader = None if config["use_validation_set"]: transformed_data = transformation.apply(training_data, is_train=True) train_val_splitter = OffsetSplitter( offset=-config["prediction_length"] * num_rolling_evals ) _, val_gen = train_val_splitter.split(training_data) val_dataset = ConcatDataset( val_gen.generate_instances( config["prediction_length"], num_rolling_evals ) ) val_splitter = create_splitter( past_length=config["context_length"] + max(model.lags_seq), future_length=config["prediction_length"], mode="val", ) transformed_valdata = transformation.apply(val_dataset, is_train=True) val_loader = ValidationDataLoader( transformed_valdata, batch_size=1280, stack_fn=batchify, transform=val_splitter + masking_transform, ) callbacks = [] log_monitor = "valid_loss" else: transformed_data = transformation.apply(training_data, is_train=True) log_monitor = "train_loss" filename = dataset_name + "-{epoch:03d}-{train_loss:.3f}" data_loader = TrainDataLoader( Cached(transformed_data), batch_size=config["batch_size"], stack_fn=batchify, transform=train_transform, num_batches_per_epoch=config["num_batches_per_epoch"], ) checkpoint_callback = ModelCheckpoint( 
save_top_k=3, monitor=f"{log_monitor}", mode="min", filename=filename, save_last=True, save_weights_only=True, ) callbacks.append(checkpoint_callback) callbacks.append(RichProgressBar()) trainer = pl.Trainer( accelerator="gpu" if torch.cuda.is_available() else None, devices=[int(config["device"].split(":")[-1])], max_epochs=config["max_epochs"], enable_progress_bar=True, num_sanity_val_steps=0, callbacks=callbacks, default_root_dir=log_dir, gradient_clip_val=config.get("gradient_clip_val", None), check_val_every_n_epoch=config["eval_every"], ) logger.info(f"Logging to {trainer.logger.log_dir}") trainer.fit( model, train_dataloaders=data_loader, val_dataloaders=val_loader ) logger.info("Training completed.") best_ckpt_path = Path(trainer.logger.log_dir) / "best_checkpoint.ckpt" if not best_ckpt_path.exists(): torch.save( torch.load(checkpoint_callback.best_model_path)["state_dict"], best_ckpt_path, ) logger.info(f"Loading {best_ckpt_path}.") best_state_dict = torch.load(best_ckpt_path) model.load_state_dict(best_state_dict, strict=True) metrics = ( evaluate_conditional(config, model, dataset.test, transformation) if config.get("do_final_eval", True) else "Final eval not performed" ) with open(Path(trainer.logger.log_dir) / "results.yaml", "w") as fp: yaml.dump( { "config": config, "version": trainer.logger.version, "metrics": metrics, }, fp, ) if __name__ == "__main__": # Setup Logger logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) # Setup argparse parser = argparse.ArgumentParser() parser.add_argument( "-c", "--config", type=str, required=True, help="Path to yaml config" ) parser.add_argument( "--out_dir", type=str, default="./", help="Path to results dir" ) args, _ = parser.parse_known_args() with open(args.config, "r") as fp: config = yaml.safe_load(fp) # Update config from command line parser = add_config_to_argparser(config=config, parser=parser) args = parser.parse_args() config_updates = vars(args) for k in config.keys() & config_updates.keys(): orig_val = config[k] updated_val = config_updates[k] if updated_val != orig_val: logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}") config.update(config_updates) main(config=config, log_dir=args.out_dir) ================================================ FILE: bin/train_model.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
================================================
FILE: bin/train_model.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import argparse
from pathlib import Path

import yaml
import torch
from tqdm.auto import tqdm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from gluonts.dataset.loader import TrainDataLoader
from gluonts.dataset.split import OffsetSplitter
from gluonts.itertools import Cached
from gluonts.torch.batchify import batchify
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.field_names import FieldName

import uncond_ts_diff.configs as diffusion_configs
from uncond_ts_diff.dataset import get_gts_dataset
from uncond_ts_diff.model.callback import EvaluateCallback
from uncond_ts_diff.model import TSDiff
from uncond_ts_diff.sampler import DDPMGuidance, DDIMGuidance
from uncond_ts_diff.utils import (
    create_transforms,
    create_splitter,
    add_config_to_argparser,
    filter_metrics,
    MaskInput,
)

guidance_map = {"ddpm": DDPMGuidance, "ddim": DDIMGuidance}


def create_model(config):
    model = TSDiff(
        **getattr(diffusion_configs, config["diffusion_config"]),
        freq=config["freq"],
        use_features=config["use_features"],
        use_lags=config["use_lags"],
        normalization=config["normalization"],
        context_length=config["context_length"],
        prediction_length=config["prediction_length"],
        lr=config["lr"],
        init_skip=config["init_skip"],
    )
    model.to(config["device"])
    return model


def evaluate_guidance(
    config, model, test_dataset, transformation, num_samples=100
):
    logger.info(f"Evaluating with {num_samples} samples.")
    results = []
    if config["setup"] == "forecasting":
        missing_data_kwargs_list = [
            {
                "missing_scenario": "none",
                "missing_values": 0,
            }
        ]
        config["missing_data_configs"] = missing_data_kwargs_list
    elif config["setup"] == "missing_values":
        missing_data_kwargs_list = config["missing_data_configs"]
    else:
        raise ValueError(f"Unknown setup {config['setup']}")

    Guidance = guidance_map[config["sampler"]]
    sampler_kwargs = config["sampler_params"]
    for missing_data_kwargs in missing_data_kwargs_list:
        logger.info(
            f"Evaluating scenario '{missing_data_kwargs['missing_scenario']}' "
            f"with {missing_data_kwargs['missing_values']:.1f} missing_values."
        )
        sampler = Guidance(
            model=model,
            prediction_length=config["prediction_length"],
            num_samples=num_samples,
            **missing_data_kwargs,
            **sampler_kwargs,
        )

        transformed_testdata = transformation.apply(
            test_dataset, is_train=False
        )
        test_splitter = create_splitter(
            past_length=config["context_length"] + max(model.lags_seq),
            future_length=config["prediction_length"],
            mode="test",
        )
        masking_transform = MaskInput(
            FieldName.TARGET,
            FieldName.OBSERVED_VALUES,
            config["context_length"],
            missing_data_kwargs["missing_scenario"],
            missing_data_kwargs["missing_values"],
        )
        test_transform = test_splitter + masking_transform

        predictor = sampler.get_predictor(
            test_transform,
            batch_size=1280 // num_samples,
            device=config["device"],
        )
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=transformed_testdata,
            predictor=predictor,
            num_samples=num_samples,
        )
        forecasts = list(tqdm(forecast_it, total=len(transformed_testdata)))
        tss = list(ts_it)

        evaluator = Evaluator()
        metrics, _ = evaluator(tss, forecasts)
        metrics = filter_metrics(metrics)
        results.append(dict(**missing_data_kwargs, **metrics))

    return results


def main(config, log_dir):
    # Load parameters
    dataset_name = config["dataset"]
    freq = config["freq"]
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]
    total_length = context_length + prediction_length

    # Create model
    model = create_model(config)

    # Setup dataset and data loading
    dataset = get_gts_dataset(dataset_name)
    assert dataset.metadata.freq == freq
    assert dataset.metadata.prediction_length == prediction_length

    if config["setup"] == "forecasting":
        training_data = dataset.train
    elif config["setup"] == "missing_values":
        missing_values_splitter = OffsetSplitter(offset=-total_length)
        training_data, _ = missing_values_splitter.split(dataset.train)

    num_rolling_evals = int(len(dataset.test) / len(dataset.train))

    transformation = create_transforms(
        num_feat_dynamic_real=0,
        num_feat_static_cat=0,
        num_feat_static_real=0,
        time_features=model.time_features,
        prediction_length=config["prediction_length"],
    )

    training_splitter = create_splitter(
        past_length=config["context_length"] + max(model.lags_seq),
        future_length=config["prediction_length"],
        mode="train",
    )

    callbacks = []
    if config["use_validation_set"]:
        transformed_data = transformation.apply(training_data, is_train=True)
        train_val_splitter = OffsetSplitter(
            offset=-config["prediction_length"] * num_rolling_evals
        )
        _, val_gen = train_val_splitter.split(training_data)
        val_data = val_gen.generate_instances(
            config["prediction_length"], num_rolling_evals
        )

        callbacks = [
            EvaluateCallback(
                context_length=config["context_length"],
                prediction_length=config["prediction_length"],
                sampler=config["sampler"],
                sampler_kwargs=config["sampler_params"],
                num_samples=config["num_samples"],
                model=model,
                transformation=transformation,
                test_dataset=dataset.test,
                val_dataset=val_data,
                eval_every=config["eval_every"],
            )
        ]
        # Monitor the validation loss logged by EvaluateCallback
        log_monitor = "valid_loss"
        filename = dataset_name + "-{epoch:03d}-{valid_loss:.3f}"
    else:
        transformed_data = transformation.apply(training_data, is_train=True)
        log_monitor = "train_loss"
        filename = dataset_name + "-{epoch:03d}-{train_loss:.3f}"

    data_loader = TrainDataLoader(
        Cached(transformed_data),
        batch_size=config["batch_size"],
        stack_fn=batchify,
        transform=training_splitter,
        num_batches_per_epoch=config["num_batches_per_epoch"],
    )

    checkpoint_callback = ModelCheckpoint(
        save_top_k=3,
        monitor=f"{log_monitor}",
        mode="min",
        filename=filename,
        save_last=True,
        save_weights_only=True,
    )
    callbacks.append(checkpoint_callback)
    callbacks.append(RichProgressBar())

    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else None,
        devices=[int(config["device"].split(":")[-1])],
        max_epochs=config["max_epochs"],
        enable_progress_bar=True,
        num_sanity_val_steps=0,
        callbacks=callbacks,
        default_root_dir=log_dir,
        gradient_clip_val=config.get("gradient_clip_val", None),
    )
    logger.info(f"Logging to {trainer.logger.log_dir}")
    trainer.fit(model, train_dataloaders=data_loader)
    logger.info("Training completed.")

    best_ckpt_path = Path(trainer.logger.log_dir) / "best_checkpoint.ckpt"
    if not best_ckpt_path.exists():
        torch.save(
            torch.load(checkpoint_callback.best_model_path)["state_dict"],
            best_ckpt_path,
        )
    logger.info(f"Loading {best_ckpt_path}.")
    best_state_dict = torch.load(best_ckpt_path)
    model.load_state_dict(best_state_dict, strict=True)

    metrics = (
        evaluate_guidance(config, model, dataset.test, transformation)
        if config.get("do_final_eval", True)
        else "Final eval not performed"
    )
    with open(Path(trainer.logger.log_dir) / "results.yaml", "w") as fp:
        yaml.dump(
            {
                "config": config,
                "version": trainer.logger.version,
                "metrics": metrics,
            },
            fp,
        )


if __name__ == "__main__":
    # Setup Logger
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    logger = logging.getLogger(__file__)
    logger.setLevel(logging.INFO)

    # Setup argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, required=True, help="Path to yaml config"
    )
    parser.add_argument(
        "--out_dir", type=str, default="./", help="Path to results dir"
    )
    args, _ = parser.parse_known_args()

    with open(args.config, "r") as fp:
        config = yaml.safe_load(fp)

    # Update config from command line
    parser = add_config_to_argparser(config=config, parser=parser)
    args = parser.parse_args()
    config_updates = vars(args)
    for k in config.keys() & config_updates.keys():
        orig_val = config[k]
        updated_val = config_updates[k]
        if updated_val != orig_val:
            logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}")
    config.update(config_updates)

    main(config=config, log_dir=args.out_dir)
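All three bin/ scripts share this __main__ pattern: parse_known_args first locates the YAML file, then every top-level config key is registered as a flag so it can be overridden from the command line. A self-contained sketch of how such a helper can work (the actual add_config_to_argparser lives in uncond_ts_diff.utils and may differ; booleans in particular need extra care, since bool("False") is True):

import argparse
import yaml

def add_config_to_argparser(config, parser):
    # One optional flag per top-level config key; the YAML value is the
    # default, so flags that are not passed leave the config unchanged.
    for key, value in config.items():
        parser.add_argument(f"--{key}", type=type(value), default=value)
    return parser

config = yaml.safe_load("lr: 1.e-3\nmax_epochs: 1000\n")
parser = add_config_to_argparser(config, argparse.ArgumentParser())
args = parser.parse_args(["--max_epochs", "10"])
config.update(vars(args))
assert config["max_epochs"] == 10 and config["lr"] == 1e-3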
================================================
FILE: bin/tstr_experiment.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
import math
import logging
import argparse
from pathlib import Path

import yaml
import torch
import numpy as np
from tqdm.auto import tqdm
from gluonts.mx import DeepAREstimator, TransformerEstimator
from gluonts.evaluation import Evaluator
from gluonts.dataset.loader import TrainDataLoader
from gluonts.itertools import Cached
from gluonts.torch.batchify import batchify
from gluonts.time_feature import (
    get_lags_for_frequency,
    time_features_from_frequency_str,
)
from gluonts.dataset.split import slice_data_entry
from gluonts.transform import AdhocTransform, Chain

from uncond_ts_diff.utils import (
    ScaleAndAddMeanFeature,
    ScaleAndAddMinMaxFeature,
    GluonTSNumpyDataset,
    create_transforms,
    create_splitter,
    get_next_file_num,
    add_config_to_argparser,
    make_evaluation_predictions_with_scaling,
    filter_metrics,
)
from uncond_ts_diff.model import TSDiff, LinearEstimator
from uncond_ts_diff.dataset import get_gts_dataset
import uncond_ts_diff.configs as diffusion_configs

DOWNSTREAM_MODELS = ["linear", "deepar", "transformer"]


def load_model(config):
    model = TSDiff(
        **getattr(
            diffusion_configs,
            config.get("diffusion_config", "diffusion_small_config"),
        ),
        freq=config["freq"],
        use_features=config["use_features"],
        use_lags=config["use_lags"],
        normalization="mean",
        context_length=config["context_length"],
        prediction_length=config["prediction_length"],
        init_skip=config["init_skip"],
    )
    model.load_state_dict(
        torch.load(config["ckpt"], map_location="cpu"),
        strict=True,
    )
    model = model.to(config["device"])
    return model


def sample_synthetic(
    model: TSDiff,
    num_samples: int = 10_000,
    batch_size: int = 1000,
):
    synth_samples = []

    n_iters = math.ceil(num_samples / batch_size)
    for _ in tqdm(range(n_iters)):
        samples = model.sample_n(num_samples=batch_size)
        synth_samples.append(samples)

    synth_samples = np.concatenate(synth_samples, axis=0)[:num_samples]

    return synth_samples


def sample_real(
    data_loader,
    n_timesteps: int,
    num_samples: int = 10_000,
    batch_size: int = 1000,
):
    real_samples = []
    data_iter = iter(data_loader)

    n_iters = math.ceil(num_samples / batch_size)
    for _ in tqdm(range(n_iters)):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(data_loader)
            batch = next(data_iter)
        ts = np.concatenate(
            [batch["past_target"], batch["future_target"]], axis=-1
        )[:, -n_timesteps:]
        real_samples.append(ts)

    real_samples = np.concatenate(real_samples, axis=0)[:num_samples]

    return real_samples


def evaluate_tstr(
    tstr_predictor,
    test_dataset,
    context_length,
    prediction_length,
    num_samples=100,
    scaling_type="mean",
):
    total_length = context_length + prediction_length
    # Slice the test set to the same length as context_length + prediction_length
    slice_func = partial(slice_data_entry, slice_=slice(-total_length, None))
    if scaling_type == "mean":
        ScaleAndAddScaleFeature = ScaleAndAddMeanFeature
    elif scaling_type == "min-max":
        ScaleAndAddScaleFeature = ScaleAndAddMinMaxFeature
    transformation = Chain(
        [
            AdhocTransform(slice_func),
            # Add scale to data entry for use later during evaluation
            ScaleAndAddScaleFeature("target", "scale", prediction_length),
        ]
    )
    sliced_test_set = transformation.apply(test_dataset)

    fcst_iter, ts_iter = make_evaluation_predictions_with_scaling(
        dataset=sliced_test_set,
        predictor=tstr_predictor,
        num_samples=num_samples,
        scaling_type=scaling_type,
    )
    evaluator = Evaluator()
    metrics, _ = evaluator(list(ts_iter), list(fcst_iter))
    return filter_metrics(metrics)


def train_and_evaluate(
    dataset,
    model_name,
    synth_samples,
    real_samples,
    config,
    scaling_type="mean",
):
    # NOTE: There is no notion of time for synthetic time series; they
    # are just "sequences". A dummy timestamp is used as the start time
    # of each synthetic series, so time_features are set to [] in the
    # models below.
    model_name = model_name.lower()
    freq = dataset.metadata.freq
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]
    total_length = context_length + prediction_length

    assert len(synth_samples) == len(real_samples)
    assert (
        synth_samples.shape[-1] == total_length
        and real_samples.shape[-1] == total_length
    )
    num_samples = len(real_samples)

    synthetic_dataset = GluonTSNumpyDataset(synth_samples)

    if model_name == "linear":
        logger.info(f"Running TSTR for {model_name}")
        tstr_predictor = LinearEstimator(
            freq=freq,  # Not actually used in the estimator
            prediction_length=prediction_length,
            context_length=context_length,
            num_train_samples=num_samples,
            # Synthetic dataset is in the "scaled space"
            scaling=False,
        ).train(synthetic_dataset)
    elif model_name == "deepar":
        logger.info(f"Running TSTR for {model_name}")
        tstr_predictor = DeepAREstimator(
            freq=freq,
            prediction_length=prediction_length,
            # Synthetic dataset is in the "scaled space"
            scaling=False,
            time_features=[],
            lags_seq=get_lags_for_frequency(freq, lag_ub=context_length),
        ).train(synthetic_dataset)
    elif model_name == "transformer":
        logger.info(f"Running TSTR for {model_name}")
        tstr_predictor = TransformerEstimator(
            freq=freq,
            prediction_length=prediction_length,
            # Synthetic dataset is in the "scaled space"
            scaling=False,
            time_features=[],
            lags_seq=get_lags_for_frequency(freq, lag_ub=context_length),
        ).train(synthetic_dataset)

    tstr_metrics = evaluate_tstr(
        tstr_predictor=tstr_predictor,
        test_dataset=dataset.test,
        context_length=context_length,
        prediction_length=prediction_length,
        scaling_type=scaling_type,
    )
    return dict(
        tstr_metrics=tstr_metrics,
    )


def main(config: dict, log_dir: str, samples_path: str):
    # Read global parameters
    dataset_name = config["dataset"]
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]

    # Create log_dir
    log_dir: Path = Path(log_dir)
    base_dirname = "tstr_log"
    run_num = get_next_file_num(
        base_dirname, log_dir, file_type="", separator="-"
    )
    log_dir = log_dir / f"{base_dirname}-{run_num}"
    log_dir.mkdir(exist_ok=True, parents=True)
    logger.info(f"Logging to {log_dir}")

    # Load dataset and model
    logger.info("Loading model")
    dataset = get_gts_dataset(dataset_name)
    config["freq"] = dataset.metadata.freq
    assert prediction_length == dataset.metadata.prediction_length
    model = load_model(config)

    # Setup data transformation and loading
    transformation = create_transforms(
        num_feat_dynamic_real=0,
        num_feat_static_cat=0,
        num_feat_static_real=0,
        time_features=time_features_from_frequency_str(config["freq"]),
        prediction_length=prediction_length,
    )
    transformed_data = transformation.apply(list(dataset.train), is_train=True)
    training_splitter = create_splitter(
        past_length=context_length + max(model.lags_seq),
        future_length=prediction_length,
        mode="train",
    )
    train_dataloader = TrainDataLoader(
        Cached(transformed_data),
        batch_size=1000,
        stack_fn=batchify,
        transform=training_splitter,
    )

    # Generate real samples
    logger.info("Generating real samples")
    real_samples = sample_real(
        train_dataloader,
        n_timesteps=context_length + prediction_length,
        num_samples=10000,
    )
    np.save(log_dir / "real_samples.npy", real_samples)

    if samples_path is None:
        # Generate synthetic samples
        logger.info("Generating synthetic samples")
        synth_samples = sample_synthetic(model, num_samples=10000)
        np.save(log_dir / "synth_samples.npy", synth_samples)
    else:
        logger.info(f"Using synthetic samples from {samples_path}")
        synth_samples = np.load(samples_path)[:10000]
        synth_samples = synth_samples.reshape(
            (10000, context_length + prediction_length)
        )

    # Run the TSTR experiment for each downstream model
    results = []
    for model_name in DOWNSTREAM_MODELS:
        logger.info(f"Training and evaluating {model_name}")
        metrics = train_and_evaluate(
            dataset=dataset,
            model_name=model_name,
            synth_samples=synth_samples,
            real_samples=real_samples,
            config=config,
            scaling_type=config["scaling_type"],
        )
        results.append({"model": model_name, **metrics})

    logger.info("Saving results")
    with open(log_dir / "results.yaml", "w") as fp:
        yaml.safe_dump(
            {"config": config, "metrics": results},
            fp,
            default_flow_style=False,
            sort_keys=False,
        )


if __name__ == "__main__":
    # Setup Logger
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    logger = logging.getLogger(__file__)
    logger.setLevel(logging.INFO)

    # Setup argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, required=True, help="Path to yaml config"
    )
    parser.add_argument(
        "--out_dir", type=str, default="./results", help="Path to results dir"
    )
    parser.add_argument(
        "--samples_path", type=str, help="Path to generated samples"
    )
    args, _ = parser.parse_known_args()

    with open(args.config, "r") as fp:
        config = yaml.safe_load(fp)

    # Update config from command line
    parser = add_config_to_argparser(config=config, parser=parser)
    args = parser.parse_args()
    config_updates = vars(args)
    for k in config.keys() & config_updates.keys():
        orig_val = config[k]
        updated_val = config_updates[k]
        if updated_val != orig_val:
            logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}")
    config.update(config_updates)

    main(config=config, log_dir=args.out_dir, samples_path=args.samples_path)
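tstr_experiment.py wraps the sampled (N, T) arrays in GluonTSNumpyDataset before handing them to the downstream estimators. Below is a plausible minimal stand-in for that wrapper (a hypothetical class; the repository's implementation may differ), illustrating the dummy start timestamp mentioned in the NOTE inside train_and_evaluate:

import numpy as np
import pandas as pd

class NumpyToGluonTS:
    """Wrap an (N, T) array as GluonTS data entries with a dummy start."""

    def __init__(self, array: np.ndarray, freq: str = "H"):
        self.array = array
        # Synthetic samples carry no real time axis, so any fixed
        # timestamp works as the start of every series.
        self.start = pd.Period("2012-01-01 00:00", freq=freq)

    def __len__(self):
        return len(self.array)

    def __iter__(self):
        for ts in self.array:
            yield {"start": self.start, "target": ts}

synthetic = NumpyToGluonTS(np.random.rand(100, 360).astype(np.float32))
print(len(synthetic), next(iter(synthetic))["target"].shape)  # 100 (360,)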
================================================
FILE: configs/guidance/guidance_electricity.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: false
num_samples: 100
prediction_length: 24
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
setup: forecasting
use_features: false
use_lags: true

================================================
FILE: configs/guidance/guidance_exchange.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: B
init_skip: true
num_samples: 100
prediction_length: 30
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 8
setup: forecasting
use_features: false
use_lags: true

================================================
FILE: configs/guidance/guidance_kdd_cup.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: true
num_samples: 100
prediction_length: 48
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 1
setup: forecasting
use_features: false
use_lags: true

================================================
FILE: configs/guidance/guidance_m4.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: false
num_samples: 100
prediction_length: 48
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 2
setup: forecasting
use_features: false
use_lags: false

================================================
FILE: configs/guidance/guidance_solar.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: false
num_samples: 100
prediction_length: 24
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 8
setup: forecasting
use_features: false
use_lags: true

================================================
FILE: configs/guidance/guidance_traffic.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: true
num_samples: 100
prediction_length: 24
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
setup: forecasting
use_features: false
use_lags: true

================================================
FILE: configs/guidance/guidance_uber_tlc.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: false
num_samples: 100
prediction_length: 24
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 2
setup: forecasting
use_features: false
use_lags: true

================================================
FILE: configs/guidance/guidance_wiki.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: 1D
init_skip: false
num_samples: 100
prediction_length: 30
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 2
setup: forecasting
use_features: false
use_lags: false

================================================
FILE: configs/guidance.yaml
================================================
# Model & checkpoint parameters
dataset: solar_nips
freq: H
device: cuda:0
ckpt: ckpts/forecasting/solar_nips/649_.ckpt
diffusion_config: diffusion_small_config
context_length: 336
prediction_length: 24
use_lags: true
use_features: false
init_skip: false
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
num_samples: 100
setup: forecasting
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: BM-B
  missing_values: 168
- missing_scenario: BM-E
  missing_values: 168

================================================
FILE: configs/refinement/electricity_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/electricity_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/electricity_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/electricity_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/exchange_rate_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.01
  refiner_name: most_likely
- guidance: quantile
  lr: 0.01
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
use_features: false
use_lags: true

================================================
FILE: configs/refinement/exchange_rate_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.01
  refiner_name: most_likely
- guidance: quantile
  lr: 0.01
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
use_features: false
use_lags: true

================================================
FILE: configs/refinement/exchange_rate_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.01
  refiner_name: most_likely
- guidance: quantile
  lr: 0.01
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
use_features: false
use_lags: true

================================================
FILE: configs/refinement/exchange_rate_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.01
  refiner_name: most_likely
- guidance: quantile
  lr: 0.01
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
use_features: false
use_lags: true

================================================
FILE: configs/refinement/kdd_cup_2018_without_missing-deepar.yaml
================================================
base_model: deepar
base_model_params: {}
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/kdd_cup_2018_without_missing-linear.yaml
================================================
base_model: linear
base_model_params: {}
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/kdd_cup_2018_without_missing-seasonal_naive.yaml
================================================
base_model: seasonal_naive
base_model_params: {}
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/kdd_cup_2018_without_missing-transformer.yaml
================================================
base_model: transformer
base_model_params: {}
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/m4_hourly-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false

================================================
FILE: configs/refinement/m4_hourly-linear.yaml
================================================
base_model: linear
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false

================================================
FILE: configs/refinement/m4_hourly-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false
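Each refinement config above packs four refiner settings into a single refiner_configs list; under yaml.safe_load they become plain dicts that an experiment script can iterate. For example:

import yaml

text = """
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
"""
for cfg in yaml.safe_load(text)["refiner_configs"]:
    print(cfg["refiner_name"], cfg["guidance"])
# most_likely MSE
# mcmc quantile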
================================================
FILE: configs/refinement/m4_hourly-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false

================================================
FILE: configs/refinement/solar_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/solar_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/solar_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/solar_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/traffic_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/traffic_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/traffic_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/traffic_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/uber_tlc_hourly-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/uber_tlc_hourly-linear.yaml
================================================
base_model: linear
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/uber_tlc_hourly-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/uber_tlc_hourly-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true

================================================
FILE: configs/refinement/wiki2000_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false

================================================
FILE: configs/refinement/wiki2000_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false

================================================
FILE: configs/refinement/wiki2000_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false

================================================
FILE: configs/refinement/wiki2000_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false

================================================
FILE: configs/refinement.yaml
================================================
# Model & checkpoint parameters
dataset: solar_nips
device: cuda:0
ckpt: ckpts/forecasting/solar_nips/649_.ckpt
diffusion_config: diffusion_small_config
context_length: 336
prediction_length: 24
use_lags: true
use_features: false
init_skip: false
# Refinement parameters
base_model: linear
base_model_params: {}
num_samples: 16
iterations: 10
refiner_configs:
- refiner_name: most_likely
  lr: 1.e-1
  guidance: MSE
- refiner_name: most_likely
  lr: 1.e-1
  guidance: quantile
- refiner_name: mcmc
  step_size: 1.e-1
  guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
- refiner_name: mcmc
  step_size: 1.e-1
  guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1

================================================
FILE: configs/train_tsdiff/train_electricity.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: electricity_nips
freq: H
context_length: 336  # 360 for `D`
prediction_length: 24  # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting

================================================
FILE: configs/train_tsdiff/train_exchange.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: exchange_rate_nips
freq: B
context_length: 360  # 360 for `D`
prediction_length: 30  # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 8
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting

================================================
FILE: configs/train_tsdiff/train_kdd_cup.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: kdd_cup_2018_without_missing
freq: H
context_length: 312  # 360 for `D`
prediction_length: 48  # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 1
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting
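A small but load-bearing detail in these training configs: learning rates are written 1.e-3, not 1e-3. PyYAML implements YAML 1.1, whose float literal requires a dot (and a signed exponent), so the dotless form silently loads as a string:

import yaml

print(type(yaml.safe_load("lr: 1.e-3")["lr"]))  # <class 'float'>
print(type(yaml.safe_load("lr: 1e-3")["lr"]))   # <class 'str'> (!)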
================================================
FILE: configs/train_tsdiff/train_m4.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: False
dataset: m4_hourly
freq: H
context_length: 312  # 360 for `D`
prediction_length: 48  # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 2
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting

================================================
FILE: configs/train_tsdiff/train_missing_electricity.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: electricity_nips
freq: H
context_length: 336  # 360 for `D`
prediction_length: 24  # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
  missing_values: 0
- missing_scenario: BM-E
  missing_values: 168
- missing_scenario: BM-B
  missing_values: 168
- missing_scenario: RM
  missing_values: 168

================================================
FILE: configs/train_tsdiff/train_missing_exchange.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: exchange_rate_nips
freq: B
context_length: 360  # 360 for `D`
prediction_length: 30  # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 8
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
  missing_values: 0
- missing_scenario: BM-E
  missing_values: 180
- missing_scenario: BM-B
  missing_values: 180
- missing_scenario: RM
  missing_values: 180

================================================
FILE: configs/train_tsdiff/train_missing_kdd_cup.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: kdd_cup_2018_without_missing
freq: H
context_length: 312  # 360 for `D`
prediction_length: 48  # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 1
use_validation_set: True
do_final_eval: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
  missing_values: 0
- missing_scenario: BM-E
  missing_values: 156
- missing_scenario: BM-B
  missing_values: 156
- missing_scenario: RM
  missing_values: 156

================================================
FILE: configs/train_tsdiff/train_missing_solar.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: solar_nips
freq: H
context_length: 336  # 360 for `D`
prediction_length: 24  # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 8
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
  missing_values: 0
- missing_scenario: BM-E
  missing_values: 168
- missing_scenario: BM-B
  missing_values: 168
- missing_scenario: RM
  missing_values: 168

================================================
FILE: configs/train_tsdiff/train_missing_traffic.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: traffic_nips
freq: H
context_length: 336  # 360 for `D`
prediction_length: 24  # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 4
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
  missing_values: 0
- missing_scenario: BM-E
  missing_values: 168
- missing_scenario: BM-B
  missing_values: 168
- missing_scenario: RM
  missing_values: 168

================================================
FILE: configs/train_tsdiff/train_missing_uber_tlc.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: uber_tlc_hourly
freq: H
context_length: 336  # 360 for `D`
prediction_length: 24  # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 2
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
  missing_values: 0
- missing_scenario: BM-E
  missing_values: 168
- missing_scenario: BM-B
  missing_values: 168
- missing_scenario: RM
  missing_values: 168
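The missing_data_configs entries name three scenarios: RM (random missing), BM-B (blackout at the beginning of the context), and BM-E (blackout at the end). The repository's MaskInput transform implements the actual masking; the sketch below is only an illustrative guess at the scenario semantics:

import numpy as np

def make_mask(context_length, scenario, missing_values, rng=None):
    # 1 = observed, 0 = missing, over the context window only.
    observed = np.ones(context_length)
    if scenario == "RM":        # random points missing
        rng = rng or np.random.default_rng(0)
        idx = rng.choice(context_length, size=missing_values, replace=False)
        observed[idx] = 0
    elif scenario == "BM-B":    # blackout at the beginning of the context
        observed[:missing_values] = 0
    elif scenario == "BM-E":    # blackout at the end of the context
        observed[-missing_values:] = 0
    return observed

print(make_mask(8, "BM-E", 3))  # [1. 1. 1. 1. 1. 0. 0. 0.]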
sampler_params: guidance: quantile scale: 8 use_validation_set: True eval_every: 50 device: cuda:0 setup: forecasting ================================================ FILE: configs/train_tsdiff/train_traffic.yaml ================================================ model: unconditional diffusion_config: diffusion_small_config normalization: mean use_features: False use_lags: True dataset: traffic_nips freq: H context_length: 336 # 360 for `D` prediction_length: 24 # 30 for `D` lr: 1.e-3 init_skip: True gradient_clip_val: 0.5 max_epochs: 1000 num_batches_per_epoch: 128 batch_size: 64 # Used only in callback, # the final evaluation uses 100 samples num_samples: 4 sampler: ddpm sampler_params: guidance: quantile scale: 4 use_validation_set: True eval_every: 50 device: cuda:0 setup: forecasting ================================================ FILE: configs/train_tsdiff/train_uber_tlc.yaml ================================================ model: unconditional diffusion_config: diffusion_small_config normalization: mean use_features: False use_lags: True dataset: uber_tlc_hourly freq: H context_length: 336 # 360 for `D` prediction_length: 24 # 30 for `D` lr: 1.e-3 init_skip: False gradient_clip_val: 0.5 max_epochs: 1000 num_batches_per_epoch: 128 batch_size: 64 # Used only in callback, # the final evaluation uses 100 samples num_samples: 16 sampler: ddpm sampler_params: guidance: quantile scale: 2 use_validation_set: True eval_every: 50 device: cuda:0 setup: forecasting ================================================ FILE: configs/train_tsdiff/train_wiki.yaml ================================================ model: unconditional diffusion_config: diffusion_small_config normalization: mean use_features: False use_lags: False dataset: wiki2000_nips freq: 1D context_length: 360 # 360 for `D` prediction_length: 30 # 30 for `D` lr: 1.e-3 init_skip: False gradient_clip_val: 0.5 max_epochs: 1000 num_batches_per_epoch: 128 batch_size: 64 # Used only in callback, # the final evaluation uses 100 samples num_samples: 4 sampler: ddpm sampler_params: guidance: quantile scale: 2 use_validation_set: True eval_every: 50 device: cuda:0 setup: forecasting ================================================ FILE: configs/train_tsdiff-cond/electricity_nips.yaml ================================================ batch_size: 64 context_length: 336 dataset: electricity_nips device: cuda:0 diffusion_config: diffusion_small_config do_final_eval: true eval_every: 10 freq: H gradient_clip_val: 0.5 init_skip: false lr: 0.001 max_epochs: 100 model: conditional noise_observed: true normalization: mean num_batches_per_epoch: 128 prediction_length: 24 setup: forecasting use_features: false use_lags: true use_validation_set: true ================================================ FILE: configs/train_tsdiff-cond/exchange_rate_nips.yaml ================================================ batch_size: 64 context_length: 360 dataset: exchange_rate_nips device: cuda:0 diffusion_config: diffusion_small_config do_final_eval: true eval_every: 10 freq: B gradient_clip_val: 0.5 init_skip: true lr: 0.001 max_epochs: 100 model: conditional noise_observed: true normalization: mean num_batches_per_epoch: 128 prediction_length: 30 setup: forecasting use_features: false use_lags: true use_validation_set: true ================================================ FILE: configs/train_tsdiff-cond/kdd_cup_2018_without_missing.yaml ================================================ batch_size: 64 context_length: 312 dataset: kdd_cup_2018_without_missing device: cuda:0 
================================================
FILE: configs/train_tsdiff-cond/electricity_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/exchange_rate_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: B
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/kdd_cup_2018_without_missing.yaml
================================================
batch_size: 64
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/m4_hourly.yaml
================================================
batch_size: 64
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: forecasting
use_features: false
use_lags: false
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_electricity_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_exchange_rate_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: B
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 180
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_kdd_cup_2018_without_missing.yaml
================================================
batch_size: 64
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 156
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_solar_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_traffic_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_uber_tlc_hourly.yaml
================================================
batch_size: 64
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_electricity_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_exchange_rate_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: B
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 180
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_kdd_cup_2018_without_missing.yaml
================================================
batch_size: 64
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 156
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_solar_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_traffic_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_uber_tlc_hourly.yaml
================================================
batch_size: 64
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_RM_electricity_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_RM_exchange_rate_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: B
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 180
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_RM_kdd_cup_2018_without_missing.yaml
================================================
batch_size: 64
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 156
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_RM_solar_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_RM_traffic_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_RM_uber_tlc_hourly.yaml
================================================
batch_size: 64
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true
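The `missing_scenario` values used throughout these configs follow a common convention: RM (random missing), BM-B (block missing at the beginning of the context), and BM-E (block missing at the end). A small illustrative mask generator, not the repository's own implementation:

```python
# Illustrative sketch of the three missing-value scenarios configured above.
# `missing_values` counts masked-out steps inside the context window
# (e.g. 168 of 336 hourly steps). Not the repository's actual implementation.
import numpy as np

def make_mask(scenario: str, context_length: int, missing_values: int) -> np.ndarray:
    mask = np.ones(context_length, dtype=bool)  # True = observed
    if scenario == "RM":      # random missing
        idx = np.random.choice(context_length, missing_values, replace=False)
        mask[idx] = False
    elif scenario == "BM-B":  # block missing at the beginning
        mask[:missing_values] = False
    elif scenario == "BM-E":  # block missing at the end
        mask[-missing_values:] = False
    return mask

print(make_mask("BM-E", 336, 168).sum())  # 168 steps remain observed
```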
================================================
FILE: configs/train_tsdiff-cond/solar_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/traffic_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/uber_tlc_hourly.yaml
================================================
batch_size: 64
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/wiki2000_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: 1D
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: forecasting
use_features: false
use_lags: false
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond.yaml
================================================
model: conditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: solar_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 100
num_batches_per_epoch: 128
batch_size: 64
use_validation_set: True
eval_every: 10
device: cuda:0
noise_observed: True
do_final_eval: True
setup: missing_values
# The following keys will be ignored, if the setup is forecasting
train_missing_scenario: BM-E
missing_scenario: BM-E
missing_values: 168


================================================
FILE: configs/train_tsdiff.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: False
dataset: solar_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 100
num_batches_per_epoch: 128
batch_size: 64
scale: 4
# Used only in callback,
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting
do_final_eval: True
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
  - missing_scenario: BM-B
    missing_values: 168
  - missing_scenario: BM-E
    missing_values: 168


================================================
FILE: configs/tstr/electricity_nips.yaml
================================================
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 24
scaling_type: mean
use_features: false
use_lags: true


================================================
FILE: configs/tstr/exchange_rate_nips.yaml
================================================
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
prediction_length: 30
scaling_type: mean
use_features: false
use_lags: true


================================================
FILE: configs/tstr/kdd_cup_2018_without_missing.yaml
================================================
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
prediction_length: 48
scaling_type: mean
use_features: false
use_lags: true
================================================
FILE: configs/tstr/m4_hourly.yaml
================================================
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 48
scaling_type: mean
use_features: false
use_lags: false


================================================
FILE: configs/tstr/solar_nips.yaml
================================================
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 24
scaling_type: mean
use_features: false
use_lags: true


================================================
FILE: configs/tstr/traffic_nips.yaml
================================================
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
prediction_length: 24
scaling_type: mean
use_features: false
use_lags: true


================================================
FILE: configs/tstr/uber_tlc_hourly.yaml
================================================
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 24
scaling_type: mean
use_features: false
use_lags: true


================================================
FILE: configs/tstr/wiki2000_nips.yaml
================================================
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 30
scaling_type: mean
use_features: false
use_lags: false


================================================
FILE: configs/tstr.yaml
================================================
# Model & checkpoint parameters
dataset: solar_nips
device: cuda:0
ckpt: ckpts/solar_nips/version_236/1299_.ckpt
diffusion_config: diffusion_small_config
context_length: 336
prediction_length: 24
use_lags: true
use_features: false
init_skip: true
scaling_type: mean


================================================
FILE: pyproject.toml
================================================
[project]
name = "uncond-ts-diff"
version = "0.1.0"
description = "TSDiff: An Unconditional Diffusion Model for Time Series"
authors = []
dependencies = [
    "torch~=1.13.1",
    "pytorch-lightning~=1.9.4",
    "gluonts[mxnet,pro]~=0.12.3",
    "matplotlib",
    "seaborn",
    "opt_einsum~=3.3.0",
    "einops",
    "black",
    "tqdm",
    "scipy",
    "scikit-learn",
    "numba",
    "jupyter",
    "rich",
    "pykeops==2.1.1",
]
readme = "README.md"
requires-python = ">= 3.8"

[tool.black]
line-length = 79
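The pins above are fairly tight (`torch~=1.13.1`, `pytorch-lightning~=1.9.4`, `gluonts[mxnet,pro]~=0.12.3`). A hedged sketch for checking an existing environment against them; the loose prefix comparison below is illustrative, not a substitute for a real resolver:

```python
# Quick environment check against the pyproject pins above.
import torch
import pytorch_lightning as pl
import gluonts

for name, version, expected in [
    ("torch", torch.__version__, "1.13"),
    ("pytorch-lightning", pl.__version__, "1.9"),
    ("gluonts", gluonts.__version__, "0.12"),
]:
    status = "ok" if version.startswith(expected) else "MISMATCH"
    print(f"{name}: {version} (want {expected}.x) -> {status}")
```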
================================================
FILE: src/uncond_ts_diff/arch/__init__.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .backbones import BackboneModel

__all__ = ["BackboneModel"]


================================================
FILE: src/uncond_ts_diff/arch/backbones.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import math

import torch
from torch import nn

from .s4 import S4


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(
            torch.arange(half_dim, device=device) * -embeddings
        )
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class S4Layer(nn.Module):
    def __init__(
        self,
        d_model,
        dropout=0.0,
    ):
        super().__init__()
        self.layer = S4(
            d_model=d_model,
            d_state=128,
            bidirectional=True,
            dropout=dropout,
            transposed=True,
            postact=None,
        )
        self.norm = nn.LayerNorm(d_model)
        self.dropout = (
            nn.Dropout1d(dropout) if dropout > 0.0 else nn.Identity()
        )

    def forward(self, x):
        """Input x is shape (B, d_input, L)"""
        z = x
        # Prenorm
        z = self.norm(z.transpose(-1, -2)).transpose(-1, -2)
        # Apply layer: we ignore the state input and output for training
        z, _ = self.layer(z)
        # Dropout on the output of the layer
        z = self.dropout(z)
        # Residual connection
        x = z + x
        return x, None

    def default_state(self, *args, **kwargs):
        return self.layer.default_state(*args, **kwargs)

    def step(self, x, state, **kwargs):
        z = x
        # Prenorm
        z = self.norm(z.transpose(-1, -2)).transpose(-1, -2)
        # Apply layer
        z, state = self.layer.step(z, state, **kwargs)
        # Residual connection
        x = z + x
        return x, state


class S4Block(nn.Module):
    def __init__(self, d_model, dropout=0.0, expand=2, num_features=0):
        super().__init__()
        self.s4block = S4Layer(d_model, dropout=dropout)
        self.time_linear = nn.Linear(d_model, d_model)
        self.tanh = nn.Tanh()
        self.sigm = nn.Sigmoid()
        self.out_linear1 = nn.Conv1d(
            in_channels=d_model, out_channels=d_model, kernel_size=1
        )
        self.out_linear2 = nn.Conv1d(
            in_channels=d_model, out_channels=d_model, kernel_size=1
        )
        self.feature_encoder = nn.Conv1d(num_features, d_model, kernel_size=1)

    def forward(self, x, t, features=None):
        t = self.time_linear(t)[:, None, :].repeat(1, x.shape[2], 1)
        t = t.transpose(-1, -2)
        out, _ = self.s4block(x + t)
        if features is not None:
            out = out + self.feature_encoder(features)
        out = self.tanh(out) * self.sigm(out)
        out1 = self.out_linear1(out)
        out2 = self.out_linear2(out)
        return out1 + x, out2


def Conv1dKaiming(in_channels, out_channels, kernel_size):
    layer = nn.Conv1d(in_channels, out_channels, kernel_size)
    nn.init.kaiming_normal_(layer.weight)
    return layer


class BackboneModel(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        step_emb,
        num_residual_blocks,
        num_features,
        residual_block="s4",
        dropout=0.0,
        init_skip=True,
    ):
        super().__init__()
        if residual_block == "s4":
            residual_block = S4Block
        else:
            raise ValueError(f"Unknown residual block {residual_block}")
        self.input_init = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )
        self.time_init = nn.Sequential(
            nn.Linear(step_emb, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )
        self.out_linear = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        residual_blocks = []
        for i in range(num_residual_blocks):
            residual_blocks.append(
                residual_block(
                    hidden_dim, num_features=num_features, dropout=dropout
                )
            )
        self.residual_blocks = nn.ModuleList(residual_blocks)
        self.step_embedding = SinusoidalPositionEmbeddings(step_emb)
        self.init_skip = init_skip

    def forward(self, input, t, features=None):
        x = self.input_init(input)  # B, L, C
        t = self.time_init(self.step_embedding(t))
        x = x.transpose(-1, -2)
        if features is not None:
            features = features.transpose(-1, -2)
        skips = []
        for layer in self.residual_blocks:
            x, skip = layer(x, t, features)
            skips.append(skip)
        skip = torch.stack(skips).sum(0)
        skip = skip.transpose(-1, -2)
        out = self.out_linear(skip)
        if self.init_skip:
            out = out + input
        return out
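A quick smoke test for `BackboneModel`; the dimensions here are illustrative (the real models take them from `diffusion_config`), and this is a sketch, not repository code:

```python
# Smoke-test sketch for BackboneModel; dimensions are illustrative only.
import torch
from uncond_ts_diff.arch import BackboneModel

model = BackboneModel(
    input_dim=1, hidden_dim=64, output_dim=1,
    step_emb=128, num_residual_blocks=3,
    num_features=1, init_skip=True,
)
x = torch.randn(8, 336, 1)       # (batch, sequence length, channels)
t = torch.randint(0, 100, (8,))  # diffusion step per batch element
assert model(x, t).shape == x.shape
```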
================================================
FILE: src/uncond_ts_diff/arch/s4.py
================================================
"""Standalone version of Structured (Sequence) State Space (S4) model."""

import logging
from functools import partial
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.utilities import rank_zero_only
from einops import rearrange, repeat
import opt_einsum as oe

contract = oe.contract
contract_expression = oe.contract_expression


def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    ):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))
    return logger


log = get_logger(__name__)

""" Cauchy and Vandermonde kernels """

try:
    # Try CUDA extension
    from extensions.cauchy.cauchy import cauchy_mult

    has_cauchy_extension = True
except ImportError:
    # log.warning(
    #     "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%"
    # )
    has_cauchy_extension = False

try:
    # Try pykeops
    from pykeops.torch import Genred

    has_pykeops = True
    log.info("Pykeops installation found.")

    def _broadcast_dims(*tensors):
        max_dim = max([len(tensor.shape) for tensor in tensors])
        tensors = [
            tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape)
            for tensor in tensors
        ]
        return tensors

    def cauchy_conj(v, z, w):
        """Pykeops version"""
        expr_num = "z * ComplexReal(v) - Real2Complex(Sum(v * w))"
        expr_denom = "ComplexMult(z-w, z-Conj(w))"
        cauchy_mult = Genred(
            f"ComplexDivide({expr_num}, {expr_denom})",
            [
                "v = Vj(2)",
                "z = Vi(2)",
                "w = Vj(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )
        v, z, w = _broadcast_dims(v, z, w)
        v = _c2r(v)
        z = _c2r(z)
        w = _c2r(w)
        r = 2 * cauchy_mult(v, z, w, backend="GPU")
        return _r2c(r)

    def log_vandermonde(v, x, L):
        expr = "ComplexMult(v, ComplexExp(ComplexMult(x, l)))"
        vandermonde_mult = Genred(
            expr,
            [
                "v = Vj(2)",
                "x = Vj(2)",
                "l = Vi(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )
        l = torch.arange(L).to(x)
        v, x, l = _broadcast_dims(v, x, l)
        v = _c2r(v)
        x = _c2r(x)
        l = _c2r(l)
        r = vandermonde_mult(v, x, l, backend="GPU")
        return 2 * _r2c(r).real

    def log_vandermonde_transpose(u, v, x, L):
        """
        u: ... H L
        v: ... H N
        x: ... H N
        Returns: ... H N

        V = Vandermonde(a, L) : (H N L)
        contract_L(V * u * v)
        """
        expr = "ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))"
        vandermonde_mult = Genred(
            expr,
            [
                "u = Vj(2)",
                "v = Vi(2)",
                "x = Vi(2)",
                "l = Vj(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )
        l = torch.arange(L).to(x)
        u, v, x, l = _broadcast_dims(u, v, x, l)
        u = _c2r(u)
        v = _c2r(v)
        x = _c2r(x)
        l = _c2r(l)
        r = vandermonde_mult(u, v, x, l, backend="GPU")
        return _r2c(r)

except ImportError:
    has_pykeops = False
    if not has_cauchy_extension:
        log.warning(
            "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency."
        )

        def cauchy_naive(v, z, w):
            """
            v, w: (..., N)
            z: (..., L)
            returns: (..., L)
            """
            cauchy_matrix = v.unsqueeze(-1) / (
                z.unsqueeze(-2) - w.unsqueeze(-1)
            )  # (... N L)
            return torch.sum(cauchy_matrix, dim=-2)

    # Vandermonde functions
    log.warning(
        "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency."
    )

    def log_vandermonde(v, x, L):
        """
        v: (..., N)
        x: (..., N)
        returns: (..., L) \sum v x^l
        """
        vandermonde_matrix = torch.exp(
            x.unsqueeze(-1) * torch.arange(L).to(x)
        )  # (... N L)
        vandermonde_prod = contract(
            "... n, ... n l -> ... l", v, vandermonde_matrix
        )  # (... L)
        return 2 * vandermonde_prod.real

    def log_vandermonde_transpose(u, v, x, L):
        vandermonde_matrix = torch.exp(
            x.unsqueeze(-1) * torch.arange(L).to(x)
        )  # (... N L)
        vandermonde_prod = contract(
            "... l, ... n, ... n l -> ... n",
            u.to(x),
            v.to(x),
            vandermonde_matrix,
        )  # (... L)
        return vandermonde_prod


def _conj(x):
    return torch.cat([x, x.conj()], dim=-1)


_c2r = torch.view_as_real
_r2c = torch.view_as_complex

if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10):

    def _resolve_conj(x):
        return x.conj().resolve_conj()

else:

    def _resolve_conj(x):
        return x.conj()


""" Simple nn.Module components """


def Activation(activation=None, dim=-1):
    if activation in [None, "id", "identity", "linear"]:
        return nn.Identity()
    elif activation == "tanh":
        return nn.Tanh()
    elif activation == "relu":
        return nn.ReLU()
    elif activation == "gelu":
        return nn.GELU()
    elif activation in ["swish", "silu"]:
        return nn.SiLU()
    elif activation == "glu":
        return nn.GLU(dim=dim)
    elif activation == "sigmoid":
        return nn.Sigmoid()
    else:
        raise NotImplementedError(
            "hidden activation '{}' is not implemented".format(activation)
        )


def LinearActivation(
    d_input,
    d_output,
    bias=True,
    transposed=False,
    activation=None,
    activate=False,  # Apply activation as part of this module
    **kwargs,
):
    """Returns a linear nn.Module with control over axes order, initialization, and activation"""

    # Construct core module
    linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear
    if activation == "glu":
        d_output *= 2
    linear = linear_cls(d_input, d_output, bias=bias, **kwargs)

    if activate and activation is not None:
        activation = Activation(activation, dim=-2 if transposed else -1)
        linear = nn.Sequential(linear, activation)
    return linear


class DropoutNd(nn.Module):
    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)"""
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError(
                "dropout probability has to be in [0, 1), "
                "but got {}".format(p)
            )
        self.p = p
        self.tie = tie
        self.transposed = transposed
        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)"""
        if self.training:
            if not self.transposed:
                X = rearrange(X, "b d ... -> b ... d")
            mask_shape = (
                X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape
            )
            mask = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p
            X = X * mask * (1.0 / (1 - self.p))
            if not self.transposed:
                X = rearrange(X, "b ... d -> b d ...")
            return X
        return X
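All three Cauchy backends above (CUDA extension, pykeops, naive fallback) compute the same quantity, k(z) = sum_n v_n / (z - w_n); a standalone numeric check of that identity, written so it does not depend on which backend is installed:

```python
# Standalone numeric check of the Cauchy kernel identity used above:
# k(z) = sum_n v_n / (z - w_n). Mirrors cauchy_naive, backend-independent.
import torch

v = torch.randn(4, dtype=torch.cfloat)
w = torch.randn(4, dtype=torch.cfloat)
z = torch.randn(6, dtype=torch.cfloat)

naive = (v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1))).sum(dim=-2)
explicit = torch.stack([sum(v[n] / (zz - w[n]) for n in range(4)) for zz in z])
assert torch.allclose(naive, explicit, atol=1e-5)
```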
""" Misc functional utilities """


def power(L, A, v=None):
    """Compute A^L and the scan sum_i A^i v_i

    A: (..., N, N)
    v: (..., N, L)
    """
    I = torch.eye(A.shape[-1]).to(A)  # , dtype=A.dtype, device=A.device)

    powers = [A]
    l = 1
    while True:
        if L % 2 == 1:
            I = powers[-1] @ I
        L //= 2
        if L == 0:
            break
        l *= 2
        powers.append(powers[-1] @ powers[-1])

    if v is None:
        return I

    # Invariants:
    # powers[-1] := A^l
    # l := largest po2 at most L

    # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A
    # We do this reverse divide-and-conquer for efficiency reasons:
    # 1) it involves fewer padding steps for non-po2 L
    # 2) it involves more contiguous arrays

    # Take care of edge case for non-po2 arrays
    # Note that this initial step is a no-op for the case of power of 2 (l == L)
    k = v.size(-1) - l
    v_ = powers.pop() @ v[..., l:]
    v = v[..., :l]
    v[..., :k] = v[..., :k] + v_

    # Handle reduction for power of 2
    while v.size(-1) > 1:
        v = rearrange(v, "... (z l) -> ... z l", z=2)
        v = v[..., 0, :] + powers.pop() @ v[..., 1, :]
    return I, v.squeeze(-1)


""" HiPPO utilities """


def transition(measure, N):
    """A, B transition matrices for different measures"""
    # Legendre (translated)
    if measure == "legt":
        Q = np.arange(N, dtype=np.float64)
        R = (2 * Q + 1) ** 0.5
        j, i = np.meshgrid(Q, Q)
        A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :]
        B = R[:, None]
        A = -A

        # Halve again for timescale correctness
        A *= 0.5
        B *= 0.5
    # Legendre (scaled)
    elif measure == "legs":
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        B = (
            B.copy()
        )  # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
    elif measure == "legsd":
        # Essentially equivalent to S4D-LegS
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        B = (
            B.copy()
        )  # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
        A += 0.5 * B * B[None, :, 0]
        B = B / 2.0
    elif measure in ["fourier_diag", "foud"]:
        # Essentially equivalent to S4D-Lin
        freqs = np.arange(N // 2)
        d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1]
        A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        A = A - 0.5 * np.eye(N)
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1
        B = B[:, None]
    elif measure in ["fourier", "fout"]:
        freqs = np.arange(N // 2)
        d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
        A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1
        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - B[:, None] * B[None, :]
        B = B[:, None]
    else:
        raise NotImplementedError

    return A, B


def rank_correction(measure, N, rank=1, dtype=torch.float):
    """Return low-rank matrix L such that A + L is normal"""

    if measure == "legs":
        assert rank >= 1
        P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(
            0
        )  # (1 N)
    elif measure == "legt":
        assert rank >= 2
        P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype))  # (N)
        P0 = P.clone()
        P0[0::2] = 0.0
        P1 = P.clone()
        P1[1::2] = 0.0
        P = torch.stack([P0, P1], dim=0)  # (2 N)
        P *= 2 ** (
            -0.5
        )  # Halve the rank correction just like the original matrix was halved
    elif measure in ["fourier", "fout"]:
        P = torch.zeros(N)
        P[0::2] = 2**0.5
        P[0] = 1
        P = P.unsqueeze(0)
    elif measure in ["fourier_diag", "foud", "legsd"]:
        P = torch.zeros(1, N, dtype=dtype)
    else:
        raise NotImplementedError

    d = P.size(0)
    if rank > d:
        P = torch.cat(
            [P, torch.zeros(rank - d, N, dtype=dtype)], dim=0
        )  # (rank N)
    return P
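The `power(L, A)` helper above computes the matrix power by repeated squaring; a quick check against `torch.linalg.matrix_power` (the import path assumes the package is installed as described in pyproject.toml):

```python
# Check that power(L, A) matches torch.linalg.matrix_power.
import torch
from uncond_ts_diff.arch.s4 import power  # assumes the package is installed

A = torch.randn(4, 4) / 4
for L in (1, 2, 5, 13):
    assert torch.allclose(
        power(L, A), torch.linalg.matrix_power(A, L), atol=1e-5
    )
```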
def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True):
    """Return w, p, q, V, B such that
    (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V
    i.e. A = V[w - p q^*]V^*, B = V B
    """
    assert dtype == torch.float or dtype == torch.double
    cdtype = torch.cfloat if dtype == torch.float else torch.cdouble

    A, B = transition(measure, N)
    A = torch.as_tensor(A, dtype=dtype)  # (N, N)
    B = torch.as_tensor(B, dtype=dtype)[:, 0]  # (N,)

    P = rank_correction(measure, N, rank=rank, dtype=dtype)  # (r N)
    AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3)

    # We require AP to be nearly skew-symmetric
    _A = AP + AP.transpose(-1, -2)
    if (
        err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N
    ) > 1e-5:  # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5):
        print("WARNING: HiPPO matrix not skew symmetric", err)

    # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately
    # Imaginary part can use eigh instead of eig
    w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True)

    # Diagonalize in double precision
    if diagonalize_precision:
        AP = AP.to(torch.double)
    w_im, V = torch.linalg.eigh(AP * -1j)  # (..., N) (..., N, N)
    if diagonalize_precision:
        w_im, V = w_im.to(cdtype), V.to(cdtype)
    w = w_re + 1j * w_im
    # Check: V w V^{-1} = A
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))

    # Only keep half of each conjugate pair
    _, idx = torch.sort(w.imag)
    w_sorted = w[idx]
    V_sorted = V[:, idx]

    # There is an edge case when eigenvalues can be 0, which requires some machinery to handle
    # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case)
    V = V_sorted[:, : N // 2]
    w = w_sorted[: N // 2]
    assert (
        w[-2].abs() > 1e-4
    ), "Only 1 zero eigenvalue allowed in diagonal part of A"
    if w[-1].abs() < 1e-4:
        V[:, -1] = 0.0
        V[0, -1] = 2**-0.5
        V[1, -1] = 2**-0.5 * 1j

    _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)
    if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5:
        print(
            "Warning: Diagonalization of A matrix not numerically precise - error",
            err,
        )
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))

    V_inv = V.conj().transpose(-1, -2)

    B = contract("ij, j -> i", V_inv, B.to(V))  # V^* B
    P = contract("ij, ...j -> ...i", V_inv, P.to(V))  # V^* P

    return w, P, B, V


def dplr(
    scaling,
    N,
    rank=1,
    H=1,
    dtype=torch.float,
    real_scale=1.0,
    imag_scale=1.0,
    random_real=False,
    random_imag=False,
    normalize=False,
    diagonal=True,
    random_B=False,
):
    assert dtype == torch.float or dtype == torch.double
    dtype = torch.cfloat if dtype == torch.float else torch.cdouble

    pi = torch.tensor(math.pi)
    if random_real:
        real_part = torch.rand(H, N // 2)
    else:
        real_part = 0.5 * torch.ones(H, N // 2)
    if random_imag:
        imag_part = N // 2 * torch.rand(H, N // 2)
    else:
        imag_part = repeat(torch.arange(N // 2), "n -> h n", h=H)

    real_part = real_scale * real_part
    if scaling == "random":
        imag_part = torch.randn(H, N // 2)
    elif scaling == "real":
        imag_part = 0 * imag_part
        real_part = 1 + repeat(torch.arange(N // 2), "n -> h n", h=H)
    elif scaling in ["linear", "lin"]:
        imag_part = pi * imag_part
    elif scaling in [
        "inverse",
        "inv",
    ]:  # Based on asymptotics of the default HiPPO matrix
        imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1)
    elif scaling in ["inverse2", "inv2"]:
        imag_part = 1 / pi * N * (N / (1 + imag_part) - 1)
    elif scaling in ["quadratic", "quad"]:
        imag_part = 1 / pi * (1 + 2 * imag_part) ** 2
    elif scaling in ["legs", "hippo"]:
        w, _, _, _ = nplr("legsd", N)
        imag_part = w.imag
    else:
        raise NotImplementedError
    imag_part = imag_scale * imag_part
    w = -real_part + 1j * imag_part

    # Initialize B
    if random_B:
        B = torch.randn(H, N // 2, dtype=dtype)
    else:
        B = torch.ones(H, N // 2, dtype=dtype)

    if normalize:
        norm = (
            -B / w
        )  # (H, N) # Result if you integrate the kernel with constant 1 function
        zeta = 2 * torch.sum(
            torch.abs(norm) ** 2, dim=-1, keepdim=True
        )  # Variance with a random C vector
        B = B / zeta**0.5

    P = torch.randn(rank, H, N // 2, dtype=dtype)
    if diagonal:
        P = P * 0.0
    V = torch.eye(N, dtype=dtype)[:: N // 2]  # Only used in testing
    V = repeat(V, "n m -> h n m", h=H)

    return w, P, B, V


def ssm(measure, N, R, H, **ssm_args):
    """Dispatcher to create single SSM initialization

    N: state size
    R: rank (for DPLR parameterization)
    H: number of independent SSM copies
    """
    if measure == "dplr":
        w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args)
    elif measure.startswith("diag"):
        args = measure.split("-")
        assert args[0] == "diag" and len(args) > 1
        scaling = args[1]
        w, P, B, V = dplr(
            scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args
        )
    else:
        w, P, B, V = nplr(measure, N, R, **ssm_args)
        w = repeat(w, "n -> s n", s=H)
        P = repeat(P, "r n -> r s n", s=H)
        B = repeat(B, "n -> s n", s=H)
        V = repeat(V, "n m -> s n m", s=H)
    return w, P, B, V


combinations = {
    "hippo": ["legs", "fourier"],
    "diag": ["diag-inv", "diag-lin"],
    "all": ["legs", "fourier", "diag-inv", "diag-lin"],
}


def combination(measures, N, R, S, **ssm_args):
    if isinstance(measures, str):
        measures = (
            combinations[measures] if measures in combinations else [measures]
        )

    assert (
        S % len(measures) == 0
    ), f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures"
    w, P, B, V = zip(
        *[
            ssm(measure, N, R, S // len(measures), **ssm_args)
            for measure in measures
        ]
    )
    w = torch.cat(w, dim=0)  # (S N)
    P = torch.cat(P, dim=1)  # (R S N)
    B = torch.cat(B, dim=0)  # (S N)
    V = torch.cat(V, dim=0)  # (S N N)
    return w, P, B, V


class OptimModule(nn.Module):
    """Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters"""

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay"""

        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {"weight_decay": 0.0}
            if lr is not None:
                optim["lr"] = lr
            setattr(getattr(self, name), "_optim", optim)
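`OptimModule.register` tags selected parameters with an `_optim` dict (zero weight decay, optional custom lr). One way a training script might consume those hints when building an optimizer; this mirrors the S4 training recipe's pattern, but is not code from this repository:

```python
# Sketch: build AdamW param groups from the `_optim` hints attached by
# OptimModule.register(). Not code from this repository.
import torch

def configure_optimizer(model, lr=1e-3, weight_decay=0.01):
    plain = [p for p in model.parameters() if not hasattr(p, "_optim")]
    groups = [{"params": plain}]
    # Collect the distinct hint dicts, then one param group per hint.
    hint_sets = {
        frozenset(p._optim.items())
        for p in model.parameters() if hasattr(p, "_optim")
    }
    for hint in hint_sets:
        params = [
            p for p in model.parameters()
            if getattr(p, "_optim", None) == dict(hint)
        ]
        groups.append({"params": params, **dict(hint)})
    return torch.optim.AdamW(groups, lr=lr, weight_decay=weight_decay)
```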
class SSKernelNPLR(OptimModule):
    """Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR)"""

    @torch.no_grad()
    def _setup_C(self, L):
        """Construct C~ from C

        Two modes are supported: go directly to length L if self.L is 1, or length is doubled
        """

        if self.L.item() == 0:
            if self.verbose:
                log.info(f"S4: Initializing kernel to length {L}")
            double_length = False
        elif L > self.L.item():  # 2*int(self.L) == L:
            if self.verbose:
                log.info(
                    f"S4: Doubling length from L = {self.L.item()} to {2*self.L.item()}"
                )
            double_length = True
            L = self.L.item()  # Convenience for the math below
        else:
            return

        C = _r2c(self.C)
        dA, _ = self._setup_state()
        dA_L = power(L, dA)
        # Multiply C by I - dA_L
        C_ = _conj(C)
        prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_)
        if double_length:
            prod = -prod  # Multiply by I + dA_L instead
        C_ = C_ - prod
        C_ = C_[..., : self.N]  # Take conjugate pairs again
        self.C.copy_(_c2r(C_))

        self.L = (
            2 * self.L if double_length else self.L + L
        )  # Preserve type/device

    def _omega(self, L, dtype, device, cache=True):
        """Calculate (and cache) FFT nodes and their "unprocessed" version with the bilinear transform
        This should be called every time the internal length self.L changes"""

        # Use cached if available
        if (
            cache
            and hasattr(self, "omega")
            and self.omega.size(-1) == L // 2 + 1
        ):
            return self.omega, self.z

        omega = torch.tensor(
            np.exp(-2j * np.pi / (L)), dtype=dtype, device=device
        )  # \omega_{2L}
        omega = omega ** torch.arange(0, L // 2 + 1, device=device)
        z = 2 * (1 - omega) / (1 + omega)

        # Cache if necessary
        if cache:
            self.omega = omega
            self.z = z
        return omega, z

    def __init__(
        self,
        w,
        P,
        B,
        C,
        log_dt,
        L=None,  # starting/maximum length of kernel
        lr=None,
        verbose=False,
        keops=False,
        real_type="exp",  # ['none' | 'exp' | 'relu' | sigmoid']
        real_tolerance=1e-3,
        bandlimit=None,
    ):
        """
        L: Maximum length; this module computes an SSM kernel of length L
        A is represented by diag(w) - PP^*
        w: (S, N) diagonal part
        P: (R, S, N) low-rank part

        B: (S, N)
        C: (C, H, N)
        dt: (H) timescale per feature
        lr: [dict | float | None] hook to set lr of special parameters (A, B, dt)

        Dimensions:
        N (or d_state): state size
        H (or d_model): total SSM copies
        S (or n_ssm): number of trainable copies of (A, B, dt); must divide H
        R (or rank): rank of low-rank part
        C (or channels): system is 1-dim to C-dim

        The forward pass of this Module returns a tensor of shape (C, H, L)

        Note: tensor shape N here denotes half the true state size, because of conjugate symmetry
        """

        super().__init__()
        self.verbose = verbose
        self.keops = keops
        self.bandlimit = bandlimit
        self.real_type = real_type
        self.real_tolerance = real_tolerance

        # Rank of low-rank correction
        self.rank = P.shape[-3]
        assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1)
        self.H = log_dt.size(-1)
        self.N = w.size(-1)

        # Check different SSM inits
        assert w.size(-2) == P.size(-2) == B.size(-2)  # n_ssm
        assert self.H % w.size(0) == 0
        self.n_ssm = w.size(0)
        self.repeat = self.H // w.size(
            0
        )  # Each trainable SSM needs to be duplicated this many times

        # Broadcast everything to correct shapes
        C = C.expand(
            torch.broadcast_shapes(C.shape, (1, self.H, self.N))
        )  # (C, H, N)
        B = B.unsqueeze(0)  # (1, 1, N)

        # Register parameters
        self.C = nn.Parameter(_c2r(_resolve_conj(C)))
        if lr is None or isinstance(lr, float):
            lr_dict = {}
        else:
            lr_dict, lr = lr, None
        self.register("log_dt", log_dt, lr_dict.get("dt", lr))
        self.register("B", _c2r(B), lr_dict.get("B", lr))
        self.register("P", _c2r(P), lr_dict.get("A", lr))
        self.register("inv_w_real", self._w_init(w.real), lr_dict.get("A", lr))
        self.register("w_imag", w.imag, lr_dict.get("A", lr))

        self.l_max = L
        self.register_buffer("L", torch.tensor(0))  # Internal length

    def _w_init(self, w_real):
        w_real = torch.clamp(w_real, max=-self.real_tolerance)
        if self.real_type == "none":
            return -w_real
        elif self.real_type == "exp":
            return torch.log(
                -w_real
            )  # Some of the HiPPO methods have real part 0
        elif self.real_type == "relu":
            return -w_real
        elif self.real_type == "sigmoid":
            return torch.logit(-w_real)
        elif self.real_type == "softplus":
            return torch.log(torch.exp(-w_real) - 1)
        else:
            raise NotImplementedError

    def _w(self):
        # Get the internal w (diagonal) parameter
        if self.real_type == "none":
            w_real = -self.inv_w_real
        elif self.real_type == "exp":
            w_real = -torch.exp(self.inv_w_real)
        elif self.real_type == "relu":
            w_real = -F.relu(self.inv_w_real)
        elif self.real_type == "sigmoid":
            w_real = -F.sigmoid(self.inv_w_real)
        elif self.real_type == "softplus":
            w_real = -F.softplus(self.inv_w_real)
        else:
            raise NotImplementedError
        w = w_real + 1j * self.w_imag
        return w
    def forward(self, state=None, rate=1.0, L=None):
        """
        state: (B, H, N) initial state
        rate: sampling rate factor
        L: target length

        returns:
        (C, H, L) convolution kernel (generally C=1)
        (B, H, L) output from initial state
        """

        # Initialize C~ if necessary (done in forward pass so it's on the correct device)
        if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:
            self._setup_C(self.l_max)

        # Handle sampling rate logic
        # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate
        if L is None:
            L = round(self.L.item() / rate)

        # Increase the internal length if needed
        continuous_L = round(rate * L)
        while continuous_L > self.L.item():
            self._setup_C(continuous_L)
        discrete_L = round(self.L.item() / rate)

        dt = torch.exp(self.log_dt) * rate
        B = _r2c(self.B)
        C = _r2c(self.C)
        P = _r2c(self.P)
        Q = P.conj()
        w = self._w()  # (n_ssm, N)

        # Address bandlimiting
        if self.bandlimit is not None:
            freqs = w.imag.abs() / (2 * math.pi)  # (H, N)
            freqs = dt[:, None] / rate * freqs  # (H, N)
            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
            C = C * mask

        # Get FFT nodes of right length
        omega, z = self._omega(
            discrete_L, dtype=w.dtype, device=w.device, cache=(rate == 1.0)
        )

        # Broadcast parameters to same hidden features H
        B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat)
        P = repeat(P, "r t n -> r (v t) n", v=self.repeat)
        Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat)
        w = repeat(w, "t n -> (v t) n", v=self.repeat)

        # Augment B
        if state is not None:
            # Have to "unbilinear" the state to put it into the same "type" as B
            # Compute 1/dt * (I + dt/2 A) @ state

            # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way
            s = _conj(state) if state.size(-1) == self.N else state  # (B H N)
            sA = s * _conj(w) - contract(  # (B H N)
                "bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P)
            )
            s = s / dt.unsqueeze(-1) + sA / 2
            s = s[..., : self.N]

            B = torch.cat([s, B], dim=-3)  # (B+1, H, N)

        # Incorporate dt into A
        w = w * dt.unsqueeze(-1)  # (H N)

        # Stack B and p, C and q for convenient batching
        B = torch.cat([B, P], dim=-3)  # (B+1+R, H, N)
        C = torch.cat([C, Q], dim=-3)  # (C+R, H, N)

        # Incorporate B and C batch dimensions
        v = B.unsqueeze(-3) * C.unsqueeze(-4)  # (B+1+R, C+R, H, N)

        # Calculate resolvent at omega
        if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops:
            r = cauchy_mult(v, z, w, symmetric=True)
        elif has_pykeops:
            r = cauchy_conj(v, z, w)
        else:
            r = cauchy_naive(v, z, w)
        r = r * dt[None, None, :, None]  # (B+1+R, C+R, H, L)

        # Low-rank Woodbury correction
        if self.rank == 1:
            k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (
                1 + r[-1:, -1:, :, :]
            )
        elif self.rank == 2:
            r00 = r[: -self.rank, : -self.rank, :, :]
            r01 = r[: -self.rank, -self.rank :, :, :]
            r10 = r[-self.rank :, : -self.rank, :, :]
            r11 = r[-self.rank :, -self.rank :, :, :]
            det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[
                :1, 1:, :, :
            ] * r11[1:, :1, :, :]
            s = (
                r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :]
                + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :]
                - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :]
                - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :]
            )
            s = s / det
            k_f = r00 - s
        else:
            r00 = r[: -self.rank, : -self.rank, :, :]
            r01 = r[: -self.rank, -self.rank :, :, :]
            r10 = r[-self.rank :, : -self.rank, :, :]
            r11 = r[-self.rank :, -self.rank :, :, :]
            r11 = rearrange(r11, "a b h n -> h n a b")
            r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11)
            r11 = rearrange(r11, "h n a b -> a b h n")
            k_f = r00 - torch.einsum(
                "i j h n, j k h n, k l h n -> i l h n", r01, r11, r10
            )

        # Final correction for the bilinear transform
        k_f = k_f * 2 / (1 + omega)

        # Move from frequency to coefficients
        k = torch.fft.irfft(k_f, n=discrete_L)  # (B+1, C, H, L)

        # # Truncate to target length
        k = k[..., :L]

        if state is not None:
            k_state = k[:-1, :, :, :]  # (B, C, H, L)
        else:
            k_state = None
        k_B = k[-1, :, :, :]  # (C H L)
        return k_B, k_state

    @torch.no_grad()
    def _setup_linear(self):
        """Create parameters that allow fast linear stepping of state"""
        w = self._w()
        B = _r2c(self.B)  # (H N)
        P = _r2c(self.P)
        Q = P.conj()

        # Repeat w shape properly
        B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat)
        P = repeat(P, "r t n -> r (v t) n", v=self.repeat)
        Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat)
        w = repeat(w, "t n -> (v t) n", v=self.repeat)

        # Prepare Linear stepping
        dt = torch.exp(self.log_dt)
        D = (2.0 / dt.unsqueeze(-1) - w).reciprocal()  # (H, N)
        R = (
            torch.eye(self.rank, dtype=w.dtype, device=w.device)
            + 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real
        )  # (H R R)
        Q_D = rearrange(Q * D, "r h n -> h r n")
        try:
            R = torch.linalg.solve(R, Q_D)  # (H R N)
        except Exception:
            R = torch.tensor(
                np.linalg.solve(
                    R.to(Q_D).contiguous().detach().cpu(),
                    Q_D.contiguous().detach().cpu(),
                )
            ).to(Q_D)
        R = rearrange(R, "h r n -> r h n")

        self.step_params = {
            "D": D,  # (H N)
            "R": R,  # (R H N)
            "P": P,  # (R H N)
            "Q": Q,  # (R H N)
            "B": B,  # (1 H N)
            "E": 2.0 / dt.unsqueeze(-1) + w,  # (H N)
        }

    def _step_state_linear(self, u=None, state=None):
        """
        Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization.

        Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster

        u: (H) input
        state: (H, N/2) state with conjugate pairs
        Optionally, the state can have last dimension N

        Returns: same shape as state
        """
        C = _r2c(self.C)  # View used for dtype/device

        if u is None:  # Special case used to find dA
            u = torch.zeros(self.H, dtype=C.dtype, device=C.device)
        if state is None:  # Special case used to find dB
            state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device)

        step_params = self.step_params.copy()
        if (
            state.size(-1) == self.N
        ):  # Only store half of the conjugate pairs; should be true by default
            # There should be a slightly faster way using conjugate symmetry
            def contract_fn(p, x, y):
                return contract(
                    "r h n, r h m, ... h m -> ... h n",
                    _conj(p),
                    _conj(x),
                    _conj(y),
                )[
                    ..., : self.N
                ]  # inner outer product

        else:
            assert state.size(-1) == 2 * self.N
            step_params = {k: _conj(v) for k, v in step_params.items()}

            # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping
            def contract_fn(p, x, y):
                return contract(
                    "r h n, r h m, ... h m -> ... h n", p, x, y
                )  # inner outer product

        D = step_params["D"]  # (H N)
        E = step_params["E"]  # (H N)
        R = step_params["R"]  # (R H N)
        P = step_params["P"]  # (R H N)
        Q = step_params["Q"]  # (R H N)
        B = step_params["B"]  # (1 H N)

        new_state = E * state - contract_fn(P, Q, state)  # (B H N)
        new_state = new_state + 2.0 * B * u.unsqueeze(-1)  # (B H N)
        new_state = D * (new_state - contract_fn(P, R, new_state))

        return new_state

    def _setup_state(self):
        """Construct dA and dB for discretized state equation"""

        # Construct dA and dB by using the stepping
        self._setup_linear()
        C = _r2c(
            self.C
        )  # Just returns a view that we use for finding dtype/device

        state = torch.eye(
            2 * self.N, dtype=C.dtype, device=C.device
        ).unsqueeze(
            -2
        )  # (N 1 N)
        dA = self._step_state_linear(state=state)
        dA = rearrange(dA, "n h m -> h m n")

        u = C.new_ones(self.H)
        dB = self._step_state_linear(u=u)
        dB = _conj(dB)
        dB = rearrange(dB, "1 h n -> h n")  # (H N)
        return dA, dB

    def _step_state(self, u, state):
        """Must be called after self.default_state() is used to construct an initial state!"""
        next_state = self.state_contraction(
            self.dA, state
        ) + self.input_contraction(self.dB, u)
        return next_state

    def _setup_step(self, mode="dense"):
        """Set up dA, dB, dC discretized parameters for stepping"""
        self.dA, self.dB = self._setup_state()

        # Calculate original C
        C = _conj(_r2c(self.C))  # (H C N)
        if self.L.item() == 0:
            dC = C
        else:
            # self.C represents C_tilde
            dA_L = power(self.L.item(), self.dA)
            I = torch.eye(self.dA.size(-1)).to(dA_L)

            dC = torch.linalg.solve(
                I - dA_L.transpose(-1, -2),
                C.unsqueeze(-1),
            ).squeeze(-1)
        self.dC = dC

        # Do special preprocessing for different step modes
        self._step_mode = mode
        if mode == "linear":
            # Linear case: special step function for the state, we need to handle output
            # use conjugate symmetry by default, which affects the output projection
            self.dC = 2 * self.dC[:, :, : self.N]
        elif mode == "diagonal":
            # Eigendecomposition of the A matrix
            L, V = torch.linalg.eig(self.dA)
            V_inv = torch.linalg.inv(V)
            # Check that the eigendecomposition is correct
            if self.verbose:
                print(
                    "Diagonalization error:",
                    torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA),
                )

            # Change the parameterization to diagonalize
            self.dA = L
            self.dB = contract("h n m, h m -> h n", V_inv, self.dB)
            self.dC = contract("h n m, c h n -> c h m", V, self.dC)
        elif mode == "dense":
            pass
        else:
            raise NotImplementedError(
                "NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}"
            )

    def default_state(self, *batch_shape):
        C = _r2c(self.C)
        N = C.size(-1)
        H = C.size(-2)

        # Cache the tensor contractions we will later do, for efficiency
        # These are put in this function because they depend on the batch size
        step_mode = getattr(
            self, "_step_mode", "dense"
        )  # Used in default_state, which is called without _setup_step() in forward_state()
        if step_mode != "linear":
            N *= 2

            if step_mode == "diagonal":
                self.state_contraction = contract_expression(
                    "h n, ... h n -> ... h n",
                    (H, N),
                    batch_shape + (H, N),
                )
            else:
                # Dense (quadratic) case: expand all terms
                self.state_contraction = contract_expression(
                    "h m n, ... h n -> ... h m",
                    (H, N, N),
                    batch_shape + (H, N),
                )

            self.input_contraction = contract_expression(
                "h n, ... h -> ... h n",
                (H, N),  # self.dB.shape
                batch_shape + (H,),
            )

        self.output_contraction = contract_expression(
            "c h n, ... h n -> ... c h",
            (C.shape[0], H, N),  # self.dC.shape
            batch_shape + (H, N),
        )

        state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device)
        return state

    def step(self, u, state):
        """Must have called self._setup_step() and created state with self.default_state() before calling this"""
        if self._step_mode == "linear":
            new_state = self._step_state_linear(u, state)
        else:
            new_state = self._step_state(u, state)
        y = self.output_contraction(self.dC, new_state)
        return y.real, new_state
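Per the docstrings above, recurrent stepping is order-sensitive: `_setup_step()` must run before `default_state()`, and both before `step()`. A usage sketch; `kernel` stands for an already-constructed kernel instance, and shapes follow the docstrings:

```python
# Usage sketch for the recurrent stepping API of the kernels above.
# `kernel` is assumed to be an already-constructed SSKernelNPLR instance.
import torch

def roll_out(kernel, u):           # u: (B, H, L) input sequence
    kernel._setup_step()           # build dA, dB, dC first
    state = kernel.default_state(u.size(0))
    ys = []
    for t in range(u.size(-1)):
        y, state = kernel.step(u[..., t], state)  # y: (B, C, H)
        ys.append(y)
    return torch.stack(ys, dim=-1)  # (B, C, H, L)
```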
dt.unsqueeze(-1) ) # or * dtA / A dA = (1.0 + dtA / 2) / (1.0 - dtA / 2) K = log_vandermonde(C, dA.log(), L) elif self.disc == "dss": # Implementation from DSS meant for case when real eigenvalues can be positive P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L] A_gt_0 = A.real > 0 # [N] if A_gt_0.any(): with torch.no_grad(): P_max = dtA * (A_gt_0 * (L - 1)) # [H N] P = P - P_max.unsqueeze(-1) # [H N L] S = P.exp() # [H N L] dtA_neg = dtA * (1 - 2 * A_gt_0) # [H N] num = dtA_neg.exp() - 1 # [H N] den = (dtA_neg * L).exp() - 1 # [H N] # Inline reciprocal function for DSS logic x = den * A x_conj = _resolve_conj(x) r = x_conj / (x * x_conj + 1e-7) C = C * num * r # [C H N] K = contract("chn,hnl->chl", C, S).float() else: assert False, f"{self.disc} not supported" K = K.view(-1, self.channels, self.H, L) # (1+B C H L) if state is not None: K_state = K[:-1, :, :, :] # (B C H L) else: K_state = None K = K[-1, :, :, :] # (C H L) return K, K_state def _setup_step(self): # These methods are organized like this to be compatible with the NPLR kernel interface dt = torch.exp(self.log_dt) # (H) B = _r2c(self.B) # (H N) C = _r2c(self.C) # (C H N) self.dC = C A = self._A() # (H N) A = repeat(A, "t n -> (v t) n", v=self.repeat) B = repeat(B, "t n -> (v t) n", v=self.repeat) # Incorporate dt into A dtA = A * dt.unsqueeze(-1) # (H N) if self.disc == "zoh": self.dA = torch.exp(dtA) # (H N) self.dB = B * (torch.exp(dtA) - 1.0) / A # (C H N) elif self.disc == "bilinear": self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2) self.dB = ( B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1) ) # or * dtA / A def default_state(self, *batch_shape): C = _r2c(self.C) state = torch.zeros( *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device ) return state def step(self, u, state): next_state = contract( "h n, b h n -> b h n", self.dA, state ) + contract("h n, b h -> b h n", self.dB, u) y = contract("c h n, b h n -> b c h", self.dC, next_state) return 2 * y.real, next_state def forward_state(self, u, state): self._setup_step() AL = self.dA ** u.size(-1) u = u.flip(-1).to(self.dA).contiguous() # (B H L) v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1)) next_state = AL * state + v return next_state class SSKernel(nn.Module): """Wrapper around SSKernel parameterizations. The SSKernel is expected to support the interface forward() default_state() _setup_step() step() """ def __init__( self, H, N=64, L=None, measure="legs", rank=1, channels=1, dt_min=0.001, dt_max=0.1, deterministic=False, lr=None, mode="nplr", n_ssm=None, verbose=False, measure_args={}, **kernel_args, ): """State Space Kernel which computes the convolution kernel $\\bar{K}$ H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config. N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doesn't affect speed much. L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known. measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin) rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt" channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. 
This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead dt_min, dt_max: min and max values for the step size dt (\Delta) mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameters (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters. """ super().__init__() self.N = N self.H = H dtype, cdtype = torch.float, torch.cfloat self.channels = channels self.n_ssm = n_ssm if n_ssm is not None else H self.mode = mode self.verbose = verbose self.kernel_args = kernel_args # Generate dt if deterministic: log_dt = torch.exp( torch.linspace(math.log(dt_min), math.log(dt_max), H) ) else: log_dt = torch.rand(self.H, dtype=dtype) * ( math.log(dt_max) - math.log(dt_min) ) + math.log(dt_min) # Compute the preprocessed representation w, P, B, V = combination( measure, self.N, rank, self.n_ssm, **measure_args ) # Broadcast C to have H channels if deterministic: C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype) C[:, :, :1] = 1.0 C = contract( "hmn, chn -> chm", V.conj().transpose(-1, -2), C ) # V^* C C = ( repeat(C, "c t n -> c (v t) n", v=self.n_ssm // C.size(-2)) .clone() .contiguous() ) else: C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype) # Broadcast other parameters to have n_ssm copies assert ( self.n_ssm % B.size(-2) == 0 and self.n_ssm % P.size(-2) == 0 and self.n_ssm % w.size(-2) == 0 ) # Broadcast tensors to n_ssm copies # These will be the parameters, so make sure tensors are materialized and contiguous B = ( repeat(B, "t n -> (v t) n", v=self.n_ssm // B.size(-2)) .clone() .contiguous() ) P = ( repeat(P, "r t n -> r (v t) n", v=self.n_ssm // P.size(-2)) .clone() .contiguous() ) w = ( repeat(w, "t n -> (v t) n", v=self.n_ssm // w.size(-2)) .clone() .contiguous() ) if mode == "nplr": self.kernel = SSKernelNPLR( w, P, B, C, log_dt, L=L, lr=lr, verbose=verbose, **kernel_args, ) elif mode == "diag": if not measure.startswith("diag"): log.warning( "Diagonal kernel (S4D) activated but initialization is not intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or 'diag-legs' for the main variants, or 'diag' for a combination of S4D-Lin and S4D-Inv." ) C = C * repeat(B, "t n -> (v t) n", v=H // self.n_ssm) self.kernel = SSKernelDiag( w, B, C, log_dt, L=L, lr=lr, **kernel_args, ) else: raise NotImplementedError(f"{mode=} is not valid") def forward(self, state=None, L=None, rate=1.0): return self.kernel(state=state, L=L, rate=rate) @torch.no_grad() def forward_state(self, u, state): """Forward the state through a sequence, i.e. 
computes the state after passing chunk through SSM state: (B, H, N) u: (B, H, L) Returns: (B, H, N) """ if hasattr(self.kernel, "forward_state"): return self.kernel.forward_state(u, state) dA, dB = self.kernel._setup_state() # Construct dA, dB matrices # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N) conj = state.size(-1) != dA.size(-1) if conj: state = _conj(state) v = contract( "h n, b h l -> b h n l", dB, u.flip(-1) ) # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2) AL, v = power(u.size(-1), dA, v) next_state = contract("h m n, b h n -> b h m", AL, state) next_state = next_state + v if conj: next_state = next_state[..., : next_state.size(-1) // 2] return next_state def _setup_step(self, **kwargs): # This method is intended to be private so that setting up an S4 module with # ``` # if hasattr(module, 'setup_step'): module.setup_step() # ``` # will not trigger this method multiple times self.kernel._setup_step(**kwargs) def step(self, u, state, **kwargs): y, state = self.kernel.step(u, state, **kwargs) return y, state def default_state(self, *args, **kwargs): return self.kernel.default_state(*args, **kwargs) class S4(nn.Module): def __init__( self, d_model, d_state=64, l_max=None, channels=1, bidirectional=False, # Arguments for position-wise feedforward components activation="gelu", postact="glu", hyper_act=None, dropout=0.0, tie_dropout=False, bottleneck=None, gate=None, transposed=True, verbose=False, # SSM Kernel arguments **kernel_args, ): """ d_state: the dimension of the state, also denoted by N l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models bidirectional: if True, convolution kernel will be two-sided Position-wise feedforward components: -------------------- activation: activation in between SS and FF postact: activation after FF hyper_act: use a "hypernetwork" multiplication (experimental) dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d Other arguments: -------------------- transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension] gate: add gated activation (GSS) bottleneck: reduce SSM dimension (GSS) See the class SSKernel for the kernel constructor which accepts kernel_args. 
Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" Other options are all experimental and should not need to be configured """ super().__init__() if verbose: log.info( f"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})" ) self.d_model = d_model self.H = d_model self.N = d_state self.L = l_max self.bidirectional = bidirectional self.channels = channels self.transposed = transposed self.gate = gate self.bottleneck = bottleneck if bottleneck is not None: self.H = self.H // bottleneck self.input_linear = LinearActivation( self.d_model, self.H, transposed=self.transposed, activation=activation, activate=True, ) if gate is not None: self.input_gate = LinearActivation( self.d_model, self.d_model * gate, transposed=self.transposed, activation=activation, activate=True, ) self.output_gate = LinearActivation( self.d_model * gate, self.d_model, transposed=self.transposed, activation=None, activate=False, ) # optional multiplicative modulation GLU-style # https://arxiv.org/abs/2002.05202 self.hyper = hyper_act is not None if self.hyper: channels *= 2 self.hyper_activation = Activation(hyper_act) self.D = nn.Parameter(torch.randn(channels, self.H)) if self.bidirectional: channels *= 2 # SSM Kernel self.kernel = SSKernel( self.H, N=self.N, L=self.L, channels=channels, verbose=verbose, **kernel_args, ) # Pointwise self.activation = Activation(activation) dropout_fn = DropoutNd if tie_dropout else nn.Dropout self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() # position-wise output transform to mix features self.output_linear = LinearActivation( self.H * self.channels, self.d_model * (1 if self.gate is None else self.gate), transposed=self.transposed, activation=postact, activate=True, ) def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs): """ u: (B H L) if self.transposed else (B L H) state: (H N) never needed unless you know what you're doing Returns: same shape as u """ if not self.transposed: u = u.transpose(-1, -2) L = u.size(-1) # Mask out padding tokens if isinstance(lengths, int): if lengths != L: lengths = torch.tensor( lengths, dtype=torch.long, device=u.device ) else: lengths = None if lengths is not None: assert ( isinstance(lengths, torch.Tensor) and lengths.ndim == 1 and lengths.size(0) in [1, u.size(0)] ) mask = torch.where( torch.arange(L, device=lengths.device) < lengths[:, None, None], 1.0, 0.0, ) u = u * mask if self.gate is not None: v = self.input_gate(u) if self.bottleneck is not None: u = self.input_linear(u) # Compute SS Kernel L_kernel = L if self.L is None else min(L, round(self.L / rate)) k, k_state = self.kernel( L=L_kernel, rate=rate, state=state ) # (C H L) (B C H L) # Convolution if self.bidirectional: k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2) k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0)) k_f = torch.fft.rfft(k, n=L_kernel + L) # (C H L) u_f = torch.fft.rfft(u, n=L_kernel + L) # (B H L) y_f = contract("bhl,chl->bchl", u_f, k_f) y = torch.fft.irfft(y_f, n=L_kernel + L)[..., :L] # (B C H L) # Compute D term in state space equation - essentially a skip connection y = y + contract("bhl,ch->bchl", u, self.D) # Compute state update if state is not None: assert ( not self.bidirectional ), "Bidirectional not supported with state forwarding" y = y + k_state # next_state = self.kernel.forward_state(u, state) else: next_state = None # Optional hyper-network multiplication if self.hyper: y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2) y = self.hyper_activation(yh) * 
y # Reshape to flatten channels y = rearrange(y, "... c h l -> ... (c h) l") y = self.dropout(self.activation(y)) if not self.transposed: y = y.transpose(-1, -2) y = self.output_linear(y) if self.gate is not None: y = self.output_gate(y * v) return y, next_state def setup_step(self, **kwargs): self.kernel._setup_step(**kwargs) def step(self, u, state): """Step one time step as a recurrent model. Intended to be used during validation. u: (B H) state: (B H N) Returns: output (B H), state (B H N) """ assert not self.training y, next_state = self.kernel.step(u, state) # (B C H) y = y + u.unsqueeze(-2) * self.D y = rearrange(y, "b c h -> b (c h)") y = self.activation(y) if self.transposed: y = self.output_linear(y.unsqueeze(-1)).squeeze(-1) else: y = self.output_linear(y) return y, next_state def default_state(self, *batch_shape, device=None): # kernel is not a SequenceModule so it doesn't need to adhere to same interface # the kernel will know the device of its own parameters return self.kernel.default_state(*batch_shape) @property def d_output(self): return self.d_model ================================================ FILE: src/uncond_ts_diff/configs.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from uncond_ts_diff.utils import linear_beta_schedule residual_block_s4_backbone = { "input_dim": 1, "hidden_dim": 128, "output_dim": 1, "step_emb": 128, "num_residual_blocks": 6, "residual_block": "s4", } residual_block_s4_backbone_smallv2 = { "input_dim": 1, "hidden_dim": 512, "output_dim": 1, "step_emb": 128, "num_residual_blocks": 3, "residual_block": "s4", } residual_block_s4_backbone_small = { "input_dim": 1, "hidden_dim": 64, "output_dim": 1, "step_emb": 128, "num_residual_blocks": 3, "residual_block": "s4", } residual_block_s4_backbone_small_dropout01 = { "input_dim": 1, "hidden_dim": 64, "output_dim": 1, "step_emb": 128, "num_residual_blocks": 3, "dropout": 0.1, "residual_block": "s4", } residual_block_s4_backbone_small_dropout02 = { "input_dim": 1, "hidden_dim": 64, "output_dim": 1, "step_emb": 128, "num_residual_blocks": 3, "dropout": 0.2, "residual_block": "s4", } residual_block_s4_backbone_small_dropout03 = { "input_dim": 1, "hidden_dim": 64, "output_dim": 1, "step_emb": 128, "num_residual_blocks": 3, "dropout": 0.3, "residual_block": "s4", } residual_block_s4_backbone_large = { "input_dim": 1, "hidden_dim": 128, "output_dim": 1, "step_emb": 128, "num_residual_blocks": 18, "residual_block": "s4", } diffusion_config = { "backbone_parameters": residual_block_s4_backbone, "timesteps": 100, "diffusion_scheduler": linear_beta_schedule, } diffusion_small_config = { "backbone_parameters": residual_block_s4_backbone_small, "timesteps": 100, "diffusion_scheduler": linear_beta_schedule, } diffusion_small_configv2 = { "backbone_parameters": residual_block_s4_backbone_smallv2, "timesteps": 100, "diffusion_scheduler": linear_beta_schedule, } diffusion_small_config_dropout = { "backbone_parameters": residual_block_s4_backbone_small_dropout01, "timesteps": 100, "diffusion_scheduler": linear_beta_schedule, } diffusion_small_config_dropout02 = { "backbone_parameters": residual_block_s4_backbone_small_dropout02, "timesteps": 100, "diffusion_scheduler": linear_beta_schedule, } diffusion_small_config_dropout03 = { "backbone_parameters": residual_block_s4_backbone_small_dropout03, "timesteps": 100, "diffusion_scheduler": linear_beta_schedule, } diffusion_large_config = { "backbone_parameters": 
residual_block_s4_backbone_large, "timesteps": 100, "diffusion_scheduler": linear_beta_schedule, } ================================================ FILE: src/uncond_ts_diff/dataset.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import os import tarfile from pathlib import Path from urllib import request from gluonts.dataset.common import load_datasets from gluonts.dataset.repository.datasets import get_dataset, get_download_path default_dataset_path: Path = get_download_path() / "datasets" wiki2k_download_link: str = "https://github.com/awslabs/gluonts/raw/b89f203595183340651411a41eeb0ee60570a4d9/datasets/wiki2000_nips.tar.gz" # noqa: E501 def get_gts_dataset(dataset_name): if dataset_name == "wiki2000_nips": wiki_dataset_path = default_dataset_path / dataset_name Path(default_dataset_path).mkdir(parents=True, exist_ok=True) if not wiki_dataset_path.exists(): tar_file_path = wiki_dataset_path.parent / f"{dataset_name}.tar.gz" request.urlretrieve( wiki2k_download_link, tar_file_path, ) with tarfile.open(tar_file_path) as tar: tar.extractall(path=wiki_dataset_path.parent) os.remove(tar_file_path) return load_datasets( metadata=wiki_dataset_path / "metadata", train=wiki_dataset_path / "train", test=wiki_dataset_path / "test", ) else: return get_dataset(dataset_name) ================================================ FILE: src/uncond_ts_diff/metrics/__init__.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from .linear_pred_score import linear_pred_score __all__ = ["linear_pred_score"] ================================================ FILE: src/uncond_ts_diff/metrics/linear_pred_score.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from typing import Tuple from functools import partial import numpy as np from gluonts.evaluation import Evaluator from gluonts.dataset.split import slice_data_entry from gluonts.transform import AdhocTransform, Chain from uncond_ts_diff.model import LinearEstimator from uncond_ts_diff.utils import ( GluonTSNumpyDataset, ScaleAndAddMeanFeature, ScaleAndAddMinMaxFeature, make_evaluation_predictions_with_scaling, ) def linear_pred_score( samples: np.ndarray, context_length: int, prediction_length: int, test_dataset, num_samples: int = 1, scaling_type: str = "mean", ) -> Tuple[dict, list, list]: """Compute the linear predictive score. Uses the `samples` to fit a LinearRegression model and evaluate the forecast performance on the provided `test_dataset`. Parameters ---------- samples The samples used to fit the linear regression model. A numpy array of shape [N, T]. Assumed to be already scaled. context_length The context length for the linear model. prediction_length The prediction length for the linear model. Must be the same as the prediction length of the target `test_dataset`. test_dataset The test dataset on which the linear model will be evaluated. num_samples, optional Number of samples to draw from the linear model. 
Since the linear model is a point forecaster, `num_samples` > 1 would just result in the forecast being repeated `num_samples` times, by default 1 scaling_type, optional Scaling type should be one of {"mean", "min-max"} Min-max scaling is used in TimeGAN, defaults to "mean" Returns ------- Evaluation metrics, target test time series and forecasts """ min_past = context_length + prediction_length assert samples.shape[1] >= min_past dataset = GluonTSNumpyDataset(samples) linear_predictor = LinearEstimator( freq="H", # Not actually used in the estimator prediction_length=prediction_length, context_length=context_length, num_train_samples=len(dataset), # Since `samples` are synthetic samples, they are assumed to be already scaled scaling=False, ).train(dataset) # The linear predictor has been trained on scaled samples, # however, the test dataset is still in the original space. # Therefore, the test time series need to be sliced and # scaled before being fed into the predictor. # After prediction, the time series must be scaled back to # the original space for metric computation. # The following lines of code perform this custom evaluation. # Slice test set to be of the same length as context_length + prediction_length slice_func = partial(slice_data_entry, slice_=slice(-min_past, None)) if scaling_type == "mean": ScaleAndAddScaleFeature = ScaleAndAddMeanFeature elif scaling_type == "min-max": ScaleAndAddScaleFeature = ScaleAndAddMinMaxFeature transformation = Chain( [ AdhocTransform(slice_func), # Add scale to data entry for use later during evaluation ScaleAndAddScaleFeature("target", "scale", prediction_length), ] ) sliced_test_set = transformation.apply(test_dataset) evaluator = Evaluator() forecast_it, ts_it = make_evaluation_predictions_with_scaling( dataset=sliced_test_set, predictor=linear_predictor, num_samples=num_samples, scaling_type=scaling_type, ) forecasts = list(forecast_it) tss = list(ts_it) metrics, _ = evaluator(tss, forecasts) return metrics, tss, forecasts ================================================ FILE: src/uncond_ts_diff/model/__init__.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from .diffusion.tsdiff import TSDiff from .diffusion.tsdiff_cond import TSDiffCond from .linear._estimator import LinearEstimator __all__ = [ "TSDiff", "TSDiffCond", "LinearEstimator", ] ================================================ FILE: src/uncond_ts_diff/model/callback.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
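As a concrete illustration of the API above, here is a minimal, hypothetical invocation of `linear_pred_score`; the sample array below is random stand-in data, and `test_dataset` is assumed to be a GluonTS test split whose prediction length matches (neither is part of the repository):

```python
# Hypothetical usage sketch for linear_pred_score (stand-in inputs).
import numpy as np

from uncond_ts_diff.metrics import linear_pred_score

context_length, prediction_length = 168, 24
# 10k "synthetic" series, assumed already scaled, length >= context + prediction
samples = np.abs(np.random.randn(10_000, context_length + prediction_length))

metrics, tss, forecasts = linear_pred_score(
    samples,
    context_length=context_length,
    prediction_length=prediction_length,
    test_dataset=test_dataset,  # assumed to exist, not defined here
    scaling_type="mean",
)
print(metrics["ND"], metrics["NRMSE"])
```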
# SPDX-License-Identifier: Apache-2.0 from copy import deepcopy import math from pathlib import Path import numpy as np import torch from gluonts.dataset.field_names import FieldName from gluonts.evaluation import make_evaluation_predictions, Evaluator from gluonts.transform import TestSplitSampler, InstanceSplitter from pytorch_lightning import Callback from uncond_ts_diff.sampler import DDPMGuidance, DDIMGuidance from uncond_ts_diff.metrics import linear_pred_score from uncond_ts_diff.utils import ConcatDataset class GradNormCallback(Callback): def __init__(self) -> None: super().__init__() def on_before_optimizer_step( self, trainer, pl_module, optimizer, opt_idx: int, ) -> None: return pl_module.log( "grad_norm", self.grad_norm(pl_module.parameters()), prog_bar=True ) def grad_norm(self, parameters): parameters = [p for p in parameters if p.grad is not None] device = parameters[0].grad.device total_norm = torch.norm( torch.stack( [torch.norm(p.grad.detach(), 2).to(device) for p in parameters] ), 2, ) return total_norm class PredictiveScoreCallback(Callback): def __init__( self, context_length, prediction_length, model, transformation, train_dataloader, train_batch_size, test_dataset, eval_every=10, ): super().__init__() self.context_length = context_length self.prediction_length = prediction_length self.model = model self.transformation = transformation self.train_dataloader = train_dataloader self.train_batch_size = train_batch_size self.test_dataset = test_dataset self.eval_every = eval_every # Number of samples used to train the downstream predictor self.n_pred_samples = 10000 def _generate_real_samples( self, data_loader, num_samples: int, n_timesteps: int, batch_size: int, cache_path: Path, ): if cache_path.exists(): real_samples = np.load(cache_path) if len(real_samples) == num_samples: return real_samples real_samples = [] data_iter = iter(data_loader) n_iters = math.ceil(num_samples / batch_size) for i in range(n_iters): try: batch = next(data_iter) except StopIteration: data_iter = iter(data_loader) batch = next(data_iter) ts = np.concatenate( [batch["past_target"], batch["future_target"]], axis=-1 )[:, -n_timesteps:] real_samples.append(ts) real_samples = np.concatenate(real_samples, axis=0)[:num_samples] np.save(cache_path, real_samples) return real_samples def _generate_synth_samples( self, model, num_samples: int, batch_size: int = 1000 ): synth_samples = [] n_iters = math.ceil(num_samples / batch_size) for _ in range(n_iters): samples = model.sample_n(num_samples=batch_size) synth_samples.append(samples) synth_samples = np.concatenate(synth_samples, axis=0)[:num_samples] return synth_samples def on_train_epoch_end(self, trainer, pl_module): if (pl_module.current_epoch + 1) % self.eval_every == 0: device = next(pl_module.backbone.parameters()).device pl_module.eval() assert pl_module.training is False real_samples = self._generate_real_samples( self.train_dataloader, self.n_pred_samples, self.context_length + self.prediction_length, self.train_batch_size, cache_path=Path(trainer.logger.log_dir) / "real_samples.npy", ) synth_samples = self._generate_synth_samples( self.model, self.n_pred_samples, ) # Train using synthetic samples, test on test set synth_metrics, _, _ = linear_pred_score( synth_samples, self.context_length, self.prediction_length, self.test_dataset, scaling_type="mean", ) # Train using real samples, test on test set scaled_real_samples, _ = self.model.scaler( torch.from_numpy(real_samples).to(device), torch.from_numpy(np.ones_like(real_samples)).to(device), 
) real_metrics, _, _ = linear_pred_score( scaled_real_samples.cpu().numpy(), self.context_length, self.prediction_length, self.test_dataset, scaling_type="mean", ) pl_module.log_dict( { "synth_linear_ND": synth_metrics["ND"], "synth_linear_NRMSE": synth_metrics["NRMSE"], "real_linear_ND": real_metrics["ND"], "real_linear_NRMSE": real_metrics["NRMSE"], } ) pl_module.train() class EvaluateCallback(Callback): def __init__( self, context_length, prediction_length, sampler, sampler_kwargs, num_samples, model, transformation, test_dataset, val_dataset, eval_every=50, ): super().__init__() self.context_length = context_length self.prediction_length = prediction_length self.sampler = sampler self.num_samples = num_samples self.sampler_kwargs = sampler_kwargs self.model = model self.transformation = transformation self.test_dataset = test_dataset self.val_data = val_dataset self.original_state_dict = {} self.eval_every = eval_every self.log_metrics = { "CRPS", "ND", "NRMSE", } if sampler == "ddpm": self.Guidance = DDPMGuidance elif sampler == "ddim": self.Guidance = DDIMGuidance else: raise ValueError(f"Unknown sampler type: {sampler}") def on_train_epoch_end(self, trainer, pl_module): if (pl_module.current_epoch + 1) % self.eval_every == 0: device = next(pl_module.backbone.parameters()).device self.original_state_dict = deepcopy( pl_module.backbone.state_dict() ) pl_module.eval() assert pl_module.training is False for label, state_dict in zip( [""] + [str(rate) for rate in pl_module.ema_rate], [pl_module.backbone.state_dict()] + pl_module.ema_state_dicts, ): pl_module.backbone.load_state_dict(state_dict, strict=True) pl_module.to(device) prediction_splitter = InstanceSplitter( target_field=FieldName.TARGET, is_pad_field=FieldName.IS_PAD, start_field=FieldName.START, forecast_start_field=FieldName.FORECAST_START, instance_sampler=TestSplitSampler(), past_length=self.context_length + max(self.model.lags_seq), future_length=self.prediction_length, time_series_fields=[ FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES, ], ) og = self.Guidance( self.model, self.prediction_length, num_samples=self.num_samples, **self.sampler_kwargs, ) predictor_pytorch = og.get_predictor( prediction_splitter, batch_size=1024 // self.num_samples, device=device, ) evaluator = Evaluator() transformed_valdata = self.transformation.apply( ConcatDataset(self.val_data), is_train=False ) forecast_it, ts_it = make_evaluation_predictions( dataset=transformed_valdata, predictor=predictor_pytorch, num_samples=self.num_samples, ) forecasts_pytorch = list(forecast_it) tss_pytorch = list(ts_it) metrics_pytorch, per_ts = evaluator( tss_pytorch, forecasts_pytorch ) metrics_pytorch["CRPS"] = metrics_pytorch["mean_wQuantileLoss"] if metrics_pytorch["CRPS"] < pl_module.best_crps: pl_module.best_crps = metrics_pytorch["CRPS"] ckpt_path = ( Path(trainer.logger.log_dir) / "best_checkpoint.ckpt" ) torch.save( pl_module.state_dict(), ckpt_path, ) pl_module.log_dict( { f"val_{metric}{label}": metrics_pytorch[metric] for metric in self.log_metrics } ) pl_module.backbone.load_state_dict( self.original_state_dict, strict=True ) pl_module.train() ================================================ FILE: src/uncond_ts_diff/model/diffusion/_base.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
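To show how the callbacks above are meant to be wired up, here is a minimal, hypothetical sketch; `model`, `transformation`, `test_dataset`, `val_dataset`, and `train_loader` are assumed to exist, and the hyperparameter values are illustrative only:

```python
# Hypothetical trainer wiring for EvaluateCallback (all inputs assumed to exist).
import pytorch_lightning as pl

eval_callback = EvaluateCallback(
    context_length=168,
    prediction_length=24,
    sampler="ddpm",     # or "ddim"
    sampler_kwargs={},  # forwarded to DDPMGuidance / DDIMGuidance
    num_samples=100,
    model=model,
    transformation=transformation,
    test_dataset=test_dataset,
    val_dataset=val_dataset,
    eval_every=50,
)
trainer = pl.Trainer(max_epochs=1000, callbacks=[eval_callback])
trainer.fit(model, train_dataloaders=train_loader)
```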
# SPDX-License-Identifier: Apache-2.0 from typing import Optional import numpy as np import pandas as pd import torch import torch.nn.functional as F import pytorch_lightning as pl from torch.optim.lr_scheduler import ReduceLROnPlateau from gluonts.time_feature import time_features_from_frequency_str from gluonts.torch.modules.feature import FeatureEmbedder from gluonts.torch.modules.scaler import MeanScaler, NOPScaler from uncond_ts_diff.utils import extract PREDICTION_INPUT_NAMES = [ "past_target", "past_observed_values", "feat_static_cat", "feat_static_real", "past_time_feat", "future_time_feat", ] class TSDiffBase(pl.LightningModule): def __init__( self, backbone_parameters, timesteps, diffusion_scheduler, context_length, prediction_length, num_feat_dynamic_real: int = 0, num_feat_static_cat: int = 0, num_feat_static_real: int = 0, cardinalities=None, freq=None, normalization="none", use_features=False, use_lags=True, lr: float = 1e-3, ): super().__init__() self.save_hyperparameters() self.timesteps = timesteps self.betas = diffusion_scheduler(timesteps) self.sqrt_one_minus_beta = torch.sqrt(1.0 - self.betas) self.alphas = 1 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, axis=0) self.alphas_cumprod_prev = F.pad( self.alphas_cumprod[:-1], (1, 0), value=1.0 ) self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas) self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = torch.sqrt( 1.0 - self.alphas_cumprod ) self.posterior_variance = ( self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.logs = {} self.normalization = normalization if normalization == "mean": self.scaler = MeanScaler(dim=1, keepdim=True) else: self.scaler = NOPScaler(dim=1, keepdim=True) if cardinalities is None: cardinalities = [1] self.embedder = FeatureEmbedder( cardinalities=cardinalities, embedding_dims=[min(50, (cat + 1) // 2) for cat in cardinalities], ) self.time_features = ( time_features_from_frequency_str(freq) if freq is not None else [] ) self.num_feat_dynamic_real = ( 1 + num_feat_dynamic_real + len(self.time_features) ) self.num_feat_static_cat = max(num_feat_static_cat, 1) self.num_feat_static_real = max(num_feat_static_real, 1) self.use_features = use_features self.use_lags = use_lags self.context_length = context_length self.prediction_length = prediction_length self.losses_running_mean = torch.ones(timesteps, requires_grad=False) self.lr = lr self.best_crps = np.inf def _extract_features(self, data): raise NotImplementedError() def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) scheduler = ReduceLROnPlateau( optimizer, mode="min", factor=0.5, patience=int(1e12) ) return [optimizer], {"scheduler": scheduler, "monitor": "train_loss"} def log(self, name, value, **kwargs): super().log(name, value, **kwargs) if isinstance(value, torch.Tensor): value = value.detach().cpu().item() if name not in self.logs: self.logs[name] = [value] else: self.logs[name].append(value) def get_logs(self): logs = self.logs logs["epochs"] = list(range(self.current_epoch)) return pd.DataFrame.from_dict(logs) def q_sample(self, x_start, t, noise=None): device = next(self.backbone.parameters()).device if noise is None: noise = torch.randn_like(x_start, device=device) sqrt_alphas_cumprod_t = extract( self.sqrt_alphas_cumprod, t, x_start.shape ) sqrt_one_minus_alphas_cumprod_t = extract( self.sqrt_one_minus_alphas_cumprod, t, x_start.shape ) return ( sqrt_alphas_cumprod_t * x_start + 
sqrt_one_minus_alphas_cumprod_t * noise ) def p_losses( self, x_start, t, features=None, noise=None, loss_type="l2", reduction="mean", ): device = next(self.backbone.parameters()).device if noise is None: noise = torch.randn_like(x_start, device=device) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) predicted_noise = self.backbone(x_noisy, t, features) if loss_type == "l1": loss = F.l1_loss(noise, predicted_noise, reduction=reduction) elif loss_type == "l2": loss = F.mse_loss(noise, predicted_noise, reduction=reduction) elif loss_type == "huber": loss = F.smooth_l1_loss( noise, predicted_noise, reduction=reduction ) else: raise NotImplementedError() return loss, x_noisy, predicted_noise @torch.no_grad() def p_sample(self, x, t, t_index, features=None): betas_t = extract(self.betas, t, x.shape) sqrt_one_minus_alphas_cumprod_t = extract( self.sqrt_one_minus_alphas_cumprod, t, x.shape ) sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape) predicted_noise = self.backbone(x, t, features) model_mean = sqrt_recip_alphas_t * ( x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t ) if t_index == 0: return model_mean else: posterior_variance_t = extract(self.posterior_variance, t, x.shape) noise = torch.randn_like(x) return model_mean + torch.sqrt(posterior_variance_t) * noise @torch.no_grad() def p_sample_ddim(self, x, t, features=None, noise=None): if noise is None: noise = self.backbone(x, t, features) sqrt_alphas_cumprod_prev_t = extract( self.alphas_cumprod_prev, t, x.shape ).sqrt() sqrt_one_minus_alphas_cumprod_prev_t = extract( 1 - self.alphas_cumprod_prev, t, x.shape ).sqrt() sqrt_one_minus_alphas_cumprod_t = extract( self.sqrt_one_minus_alphas_cumprod, t, x.shape ) sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x.shape) x0pointer = ( sqrt_alphas_cumprod_prev_t * (x - sqrt_one_minus_alphas_cumprod_t * noise) / sqrt_alphas_cumprod_t ) xtpointer = sqrt_one_minus_alphas_cumprod_prev_t * noise return x0pointer + xtpointer @torch.no_grad() def p_sample_genddim( self, x: torch.Tensor, t: torch.Tensor, t_index: int, t_prev: Optional[torch.Tensor] = None, eta: float = 0.0, features=None, noise: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Generalized DDIM step that interpolates between DDPM (eta=1) and DDIM (eta=0). Args: x (torch.Tensor): _description_ t (torch.Tensor): _description_ features (_type_, optional): _description_. Defaults to None. noise (Optional[torch.Tensor], optional): _description_. Defaults to None. 
Returns: torch.Tensor: _description_ """ if noise is None: noise = self.backbone(x, t, features) if t_prev is None: t_prev = t - 1 alphas_cumprod_t = extract(self.alphas_cumprod, t, x.shape) alphas_cumprod_prev_t = ( extract(self.alphas_cumprod, t_prev, x.shape) if t_index > 0 else torch.ones_like(alphas_cumprod_t) ) sqrt_alphas_cumprod_prev_t = alphas_cumprod_prev_t.sqrt() sqrt_one_minus_alphas_cumprod_t = extract( self.sqrt_one_minus_alphas_cumprod, t, x.shape ) sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x.shape) x0pointer = ( sqrt_alphas_cumprod_prev_t * (x - sqrt_one_minus_alphas_cumprod_t * noise) / sqrt_alphas_cumprod_t ) c1 = ( eta * ( (1 - alphas_cumprod_t / alphas_cumprod_prev_t) * (1 - alphas_cumprod_prev_t) / (1 - alphas_cumprod_t) ).sqrt() ) c2 = ((1 - alphas_cumprod_prev_t) - c1**2).sqrt() return x0pointer + c1 * torch.randn_like(x) + c2 * noise @torch.no_grad() def sample(self, noise, features=None): device = next(self.backbone.parameters()).device batch_size, length, ch = noise.shape seq = noise seqs = [seq.cpu()] for i in reversed(range(0, self.timesteps)): seq = self.p_sample( seq, torch.full((batch_size,), i, device=device, dtype=torch.long), i, features, ) seqs.append(seq.cpu().numpy()) return np.stack(seqs, axis=0) def fast_denoise(self, xt, t, features=None, noise=None): if noise is None: noise = self.backbone(xt, t, features) sqrt_one_minus_alphas_cumprod_t = extract( self.sqrt_one_minus_alphas_cumprod, t, xt.shape ) sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, xt.shape) return ( xt - sqrt_one_minus_alphas_cumprod_t * noise ) / sqrt_alphas_cumprod_t def forward(self, x, mask): raise NotImplementedError() def training_step(self, data, idx): assert self.training is True device = next(self.backbone.parameters()).device if isinstance(data, dict): x, _, features = self._extract_features(data) else: x, _ = self.scaler(data, torch.ones_like(data)) t = torch.randint( 0, self.timesteps, (x.shape[0],), device=device ).long() elbo_loss, xt, noise = self.p_losses(x, t, features, loss_type="l2") return { "loss": elbo_loss, "elbo_loss": elbo_loss, } def training_epoch_end(self, outputs): epoch_loss = sum(x["loss"] for x in outputs) / len(outputs) elbo_loss = sum(x["elbo_loss"] for x in outputs) / len(outputs) self.log("train_loss", epoch_loss) self.log("train_elbo_loss", elbo_loss) def validation_step(self, data, idx): device = next(self.backbone.parameters()).device if isinstance(data, dict): x, _, features = self._extract_features(data) else: x, features = data, None t = torch.randint( 0, self.timesteps, (x.shape[0],), device=device ).long() elbo_loss, xt, noise = self.p_losses(x, t, features, loss_type="l2") return { "loss": elbo_loss, "elbo_loss": elbo_loss, } def validation_epoch_end(self, outputs): epoch_loss = sum(x["loss"] for x in outputs) / len(outputs) elbo_loss = sum(x["elbo_loss"] for x in outputs) / len(outputs) self.log("valid_loss", epoch_loss) self.log("valid_elbo_loss", elbo_loss) ================================================ FILE: src/uncond_ts_diff/model/diffusion/tsdiff.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
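The closed-form forward process used by `q_sample` above, `x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps`, is exactly inverted by `fast_denoise` when the true noise is supplied. The following self-contained check, using a stand-in linear beta schedule in place of the repository's `linear_beta_schedule`, verifies this identity numerically:

```python
# Sanity check of the q_sample / fast_denoise identity (stand-in schedule).
import torch

timesteps = 100
betas = torch.linspace(1e-4, 0.1, timesteps)        # stand-in linear schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # abar_t

x0 = torch.randn(8, 192, 1)  # toy clean series of shape [B, T, C]
t = torch.full((8,), 50, dtype=torch.long)
eps = torch.randn_like(x0)

abar_t = alphas_cumprod[t].view(-1, 1, 1)
xt = abar_t.sqrt() * x0 + (1 - abar_t).sqrt() * eps        # forward: q_sample
x0_hat = (xt - (1 - abar_t).sqrt() * eps) / abar_t.sqrt()  # inverse: fast_denoise
assert torch.allclose(x0_hat, x0, atol=1e-5)
```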
# SPDX-License-Identifier: Apache-2.0 import copy import torch from gluonts.torch.util import lagged_sequence_values from uncond_ts_diff.arch import BackboneModel from uncond_ts_diff.model.diffusion._base import TSDiffBase from uncond_ts_diff.utils import get_lags_for_freq class TSDiff(TSDiffBase): def __init__( self, backbone_parameters, timesteps, diffusion_scheduler, context_length, prediction_length, num_feat_dynamic_real: int = 0, num_feat_static_cat: int = 0, num_feat_static_real: int = 0, cardinalities=None, freq=None, normalization="none", use_features=False, use_lags=True, init_skip=True, lr=1e-3, ): super().__init__( backbone_parameters, timesteps=timesteps, diffusion_scheduler=diffusion_scheduler, context_length=context_length, prediction_length=prediction_length, num_feat_dynamic_real=num_feat_dynamic_real, num_feat_static_cat=num_feat_static_cat, num_feat_static_real=num_feat_static_real, cardinalities=cardinalities, freq=freq, normalization=normalization, use_features=use_features, use_lags=use_lags, lr=lr, ) self.freq = freq if use_lags: self.lags_seq = get_lags_for_freq(freq) backbone_parameters = backbone_parameters.copy() backbone_parameters["input_dim"] += len(self.lags_seq) backbone_parameters["output_dim"] += len(self.lags_seq) else: self.lags_seq = [0] self.input_dim = backbone_parameters["input_dim"] self.backbone = BackboneModel( **backbone_parameters, num_features=( self.num_feat_static_real + self.num_feat_static_cat + self.num_feat_dynamic_real + 1 # log_scale ), init_skip=init_skip, ) self.ema_rate = [] # [0.9999] self.ema_state_dicts = [ copy.deepcopy(self.backbone.state_dict()) for _ in range(len(self.ema_rate)) ] def _extract_features(self, data): prior = data["past_target"][:, : -self.context_length] context = data["past_target"][:, -self.context_length :] context_observed = data["past_observed_values"][ :, -self.context_length : ] if self.normalization == "zscore": scaled_context, scale = self.scaler( context, context_observed, data["stats"] ) else: scaled_context, scale = self.scaler(context, context_observed) features = [] scaled_prior = prior / scale scaled_future = data["future_target"] / scale features.append(scale.log()) x = torch.cat([scaled_context, scaled_future], dim=1) if data["feat_static_cat"] is not None: features.append(self.embedder(data["feat_static_cat"])) if data["feat_static_real"] is not None: features.append(data["feat_static_real"]) static_feat = torch.cat( features, dim=1, ) expanded_static_feat = static_feat.unsqueeze(1).expand( -1, x.shape[1], -1 ) features = [expanded_static_feat] time_features = [] if data["past_time_feat"] is not None: time_features.append( data["past_time_feat"][:, -self.context_length :] ) if data["future_time_feat"] is not None: time_features.append(data["future_time_feat"]) features.append(torch.cat(time_features, dim=1)) features = torch.cat(features, dim=-1) if self.use_lags: lags = lagged_sequence_values( self.lags_seq, scaled_prior, torch.cat([scaled_context, scaled_future], dim=1), dim=1, ) x = torch.cat([x[:, :, None], lags], dim=-1) else: x = x[:, :, None] if not self.use_features: features = None return x, scale[:, :, None], features @torch.no_grad() def sample_n( self, num_samples: int = 1, return_lags: bool = False, ): device = next(self.backbone.parameters()).device seq_len = self.context_length + self.prediction_length samples = torch.randn( (num_samples, seq_len, self.input_dim), device=device ) for i in reversed(range(0, self.timesteps)): t = torch.full((num_samples,), i, device=device, 
dtype=torch.long) samples = self.p_sample(samples, t, i, features=None) samples = samples.cpu().numpy() if return_lags: return samples return samples[..., 0] def on_train_batch_end(self, outputs, batch, batch_idx): for rate, state_dict in zip(self.ema_rate, self.ema_state_dicts): update_ema(state_dict, self.backbone.state_dict(), rate=rate) def update_ema(target_state_dict, source_state_dict, rate=0.99): with torch.no_grad(): for key, value in source_state_dict.items(): ema_value = target_state_dict[key] ema_value.copy_( rate * ema_value + (1.0 - rate) * value.cpu(), non_blocking=True, ) ================================================ FILE: src/uncond_ts_diff/model/diffusion/tsdiff_cond.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import torch from gluonts.torch.model.predictor import PyTorchPredictor from gluonts.torch.util import lagged_sequence_values from uncond_ts_diff.arch import BackboneModel from uncond_ts_diff.model.diffusion._base import TSDiffBase from uncond_ts_diff.model.diffusion._base import PREDICTION_INPUT_NAMES from uncond_ts_diff.utils import get_lags_for_freq PREDICTION_INPUT_NAMES = PREDICTION_INPUT_NAMES + ["orig_past_target"] class TSDiffCond(TSDiffBase): def __init__( self, backbone_parameters, timesteps, diffusion_scheduler, context_length, prediction_length, num_feat_dynamic_real: int = 0, num_feat_static_cat: int = 0, num_feat_static_real: int = 0, cardinalities=None, freq=None, normalization="none", use_features=False, use_lags=True, lr=1e-3, init_skip=True, noise_observed=True, ): super().__init__( backbone_parameters, timesteps=timesteps, diffusion_scheduler=diffusion_scheduler, context_length=context_length, prediction_length=prediction_length, num_feat_dynamic_real=num_feat_dynamic_real, num_feat_static_cat=num_feat_static_cat, num_feat_static_real=num_feat_static_real, cardinalities=cardinalities, freq=freq, normalization=normalization, use_features=use_features, use_lags=use_lags, lr=lr, ) num_features = ( ( self.num_feat_dynamic_real + self.num_feat_static_cat + self.num_feat_static_real + 1 ) if use_features else 0 ) self.freq = freq self.lags_seq = get_lags_for_freq(freq) if use_lags else [0] self.backbone = BackboneModel( **backbone_parameters, num_features=( num_features + 2 + (len(self.lags_seq) if use_lags else 0) ), init_skip=init_skip, ) self.noise_observed = noise_observed def _extract_features(self, data): device = next(self.parameters()).device prior = data["past_target"][:, : -self.context_length] context = data["past_target"][:, -self.context_length :] context_observed = data["past_observed_values"][ :, -self.context_length : ] scaled_context, scale = self.scaler(context, context_observed) features = [] scaled_prior = prior / scale scaled_future = data["future_target"] / scale scaled_orig_context = ( data["orig_past_target"][:, -self.context_length :] ) / scale x = torch.cat([scaled_orig_context, scaled_future], dim=1) observation_mask = torch.zeros_like(x, device=device) observation_mask[:, : -self.prediction_length] = data[ "past_observed_values" ][:, -self.context_length :].data x_past = torch.cat( [scaled_context, torch.zeros_like(scaled_future)], dim=1 ).clone() assert x.size() == x_past.size() if data["feat_static_cat"] is not None: features.append(self.embedder(data["feat_static_cat"])) if data["feat_static_real"] is not None: features.append(data["feat_static_real"]) static_feat = torch.cat( features, dim=1, ) 
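# Broadcast the per-series static features along the time axis so they can be
# concatenated below with the time-varying inputs (time features, lags, the
# masked past target, and the observation mask).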
expanded_static_feat = static_feat.unsqueeze(1).expand( -1, x.shape[1], -1 ) features = [] if self.use_features: features.append(expanded_static_feat) time_features = [] if data["past_time_feat"] is not None: time_features.append( data["past_time_feat"][:, -self.context_length :] ) if data["future_time_feat"] is not None: time_features.append(data["future_time_feat"]) features.append(torch.cat(time_features, dim=1)) lags = lagged_sequence_values( self.lags_seq, scaled_prior, torch.cat([scaled_context, scaled_future], dim=1), dim=1, ) if self.use_lags: features.append(lags) features.append(x_past[..., None]) features.append(observation_mask[..., None]) features = torch.cat(features, dim=-1) return x[..., None], scale[..., None], features def step(self, x, t, features, loss_mask): noise = torch.randn_like(x) if not self.noise_observed: noise = (1 - loss_mask) * x + noise * loss_mask num_eval = loss_mask.sum() sq_err, _, _ = self.p_losses( x, t, features, loss_type="l2", reduction="none", noise=noise, ) if self.noise_observed: elbo_loss = sq_err.mean() else: sq_err = sq_err * loss_mask elbo_loss = sq_err.sum() / (num_eval if num_eval else 1) return elbo_loss def training_step(self, data, idx): assert self.training is True device = next(self.parameters()).device x, _, features = self._extract_features(data) # Last dim of features has the observation mask observation_mask = features[..., -1:] loss_mask = 1 - observation_mask t = torch.randint( 0, self.timesteps, (x.shape[0],), device=device ).long() elbo_loss = self.step(x, t, features, loss_mask) return { "loss": elbo_loss, "elbo_loss": elbo_loss, } def validation_step(self, data, idx): device = next(self.parameters()).device x, _, features = self._extract_features(data) # Last dim of features has the observation mask observation_mask = features[..., -1:] loss_mask = 1 - observation_mask val_loss = 0.0 for i in range(self.timesteps): t = torch.full((x.shape[0],), i, device=device).long() val_loss += self.step(x, t, features, loss_mask) val_loss /= self.timesteps return { "loss": val_loss, "elbo_loss": val_loss, } @torch.no_grad() def forecast(self, observation, observation_mask, features=None): device = next(self.backbone.parameters()).device batch_size, length, ch = observation.shape seq = torch.randn_like(observation) for i in reversed(range(0, self.timesteps)): if not self.noise_observed: seq = observation_mask * observation + seq * ( 1 - observation_mask ) seq = self.p_sample( seq, torch.full((batch_size,), i, device=device, dtype=torch.long), i, features, ) return seq def forward( self, past_target: torch.Tensor, past_observed_values: torch.Tensor, feat_static_cat: torch.Tensor = None, feat_static_real: torch.Tensor = None, past_time_feat: torch.Tensor = None, future_time_feat: torch.Tensor = None, orig_past_target: torch.Tensor = None, ): # This is only used during prediction device = next(self.backbone.parameters()).device data = dict( feat_static_cat=feat_static_cat.to(device) if feat_static_cat is not None else None, feat_static_real=feat_static_real.to(device) if feat_static_real is not None else None, past_time_feat=past_time_feat.to(device) if past_time_feat is not None else None, past_target=past_target.to(device), orig_past_target=orig_past_target.to(device), future_target=torch.zeros( past_target.shape[0], self.prediction_length, device=device ), past_observed_values=past_observed_values.to(device) if past_observed_values is not None else None, future_time_feat=future_time_feat.to(device) if future_time_feat is not None else 
None, ) observation, scale, features = self._extract_features(data) observation = observation.to(device) batch_size, length, ch = observation.shape observation_mask = features[..., -1:] pred = self.forecast( observation=observation, observation_mask=observation_mask, features=features, ) pred = pred * scale return pred[:, None, length - self.prediction_length :, 0] def get_predictor(self, input_transform, batch_size=40, device=None): return PyTorchPredictor( prediction_length=self.prediction_length, input_names=PREDICTION_INPUT_NAMES, prediction_net=self, batch_size=batch_size, input_transform=input_transform, device=device, ) ================================================ FILE: src/uncond_ts_diff/model/linear/_estimator.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from typing import Optional, List import math import numpy as np from sklearn.linear_model import LinearRegression, Ridge from gluonts.model import Estimator, Predictor from gluonts.dataset.common import Dataset from gluonts.dataset.field_names import FieldName from gluonts.transform import ( Transformation, AddObservedValuesIndicator, InstanceSplitter, TestSplitSampler, ExpectedNumInstanceSampler, SelectFields, ) from gluonts.dataset.loader import TrainDataLoader, InferenceDataLoader from gluonts.itertools import Cached from gluonts.model.forecast_generator import ( ForecastGenerator, SampleForecastGenerator, predict_to_numpy, ) from ._scaler import MeanScaler, NOPScaler PREDICTION_INPUT_NAMES = [ "past_target", "past_observed_values", ] TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [ "future_target", "future_observed_values", ] def stack(data): if isinstance(data[0], np.ndarray): data = np.array(data) elif isinstance(data[0], (list, tuple)): return list(stack(t) for t in zip(*data)) return data def batchify(data: List[dict]): return { key: stack(data=[item[key] for item in data]) for key in data[0].keys() } class LinearModel: def __init__(self, weight, bias, scaler, num_parallel_samples=100) -> None: super().__init__() self.scaler = scaler self.weight = weight self.bias = bias self.num_parallel_samples = num_parallel_samples def _linear(self, x, A, b): return x @ A.T + b def __call__(self, x, mask): assert x.ndim == 2 x, scale = self.scaler(x, np.ones_like(x)) out = self._linear(x, self.weight, self.bias) * scale return np.tile(out[:, None], (1, self.num_parallel_samples, 1)) @predict_to_numpy.register(LinearModel) def _(prediction_net, args) -> np.ndarray: return prediction_net(*args) class LinearPredictor(Predictor): def __init__( self, input_names: List[str], prediction_net: LinearModel, batch_size: int, prediction_length: int, input_transform: Transformation, forecast_generator: ForecastGenerator = SampleForecastGenerator(), lead_time: int = 0, ) -> None: super().__init__(prediction_length, lead_time=lead_time) self.input_names = input_names self.prediction_net = prediction_net self.batch_size = batch_size self.input_transform = input_transform self.forecast_generator = forecast_generator def predict(self, dataset: Dataset, num_samples: Optional[int] = None): inference_data_loader = InferenceDataLoader( dataset, transform=self.input_transform, batch_size=self.batch_size, stack_fn=batchify, ) yield from self.forecast_generator( inference_data_loader=inference_data_loader, prediction_net=self.prediction_net, input_names=self.input_names, output_transform=None, num_samples=num_samples, ) class 
LinearEstimator(Estimator): """A Linear regressor that takes inputs of size equal to `context_length` and outputs forecasts of size equal to `prediction_length`. This model uses LinearRegression from scikit-learn under the hood. Example usage: ```python estimator = LinearEstimator( dataset.metadata.freq, prediction_length=dataset.metadata.prediction_length, context_length=24 * 7 * 2, ) predictor = estimator.train(dataset.train) ``` Parameters ---------- freq Frequency of the dataset (not actually used) prediction_length Prediction length context_length, optional Context length for the linear model, by default equal to 4 * prediction_length num_train_samples, optional Number of samples used to fit the LinearRegression model, by default 10000 model, optional Which sklearn linear model to use, one of {"linear", "ridge"}, by default "ridge". scaling, optional Whether to use scaling, by default True batch_size, optional Batch size (only relevant during prediction), by default 64 """ def __init__( self, freq: str, prediction_length: int, context_length: Optional[int] = None, num_train_samples: int = 10000, model: str = "ridge", scaling: bool = True, batch_size: int = 64, **kwargs, ) -> None: super().__init__(**kwargs) assert model in {"linear", "ridge"} self.freq = freq self.prediction_length = prediction_length self.context_length = context_length or 4 * prediction_length self.num_train_samples = num_train_samples self.model = model if scaling: self.scaler = MeanScaler(axis=-1, keepdims=True) else: self.scaler = NOPScaler(axis=-1, keepdims=True) self.batch_size = batch_size def create_transformation(self) -> Transformation: return SelectFields( [ FieldName.ITEM_ID, FieldName.INFO, FieldName.START, FieldName.TARGET, ], allow_missing=True, ) + AddObservedValuesIndicator( target_field=FieldName.TARGET, output_field=FieldName.OBSERVED_VALUES, ) def _create_instance_splitter(self, mode: str): assert mode in ["training", "test"] instance_sampler = { "training": ExpectedNumInstanceSampler( num_instances=1, min_past=self.context_length, min_future=self.prediction_length, ), "test": TestSplitSampler(), }[mode] return InstanceSplitter( target_field=FieldName.TARGET, is_pad_field=FieldName.IS_PAD, start_field=FieldName.START, forecast_start_field=FieldName.FORECAST_START, instance_sampler=instance_sampler, past_length=self.context_length, future_length=self.prediction_length, time_series_fields=[ FieldName.OBSERVED_VALUES, ], ) def _create_training_samples(self, training_data) -> np.ndarray: transformation = self._create_instance_splitter( "training" ) + SelectFields(TRAINING_INPUT_NAMES) num_batches_per_epoch = math.ceil(self.num_train_samples / 100) data_loader = TrainDataLoader( training_data, batch_size=100, stack_fn=batchify, transform=transformation, num_batches_per_epoch=num_batches_per_epoch, ) train_X, train_y = [], [] for batch in data_loader: train_X.append(batch["past_target"]) train_y.append(batch["future_target"]) assert np.all(batch["past_observed_values"] == 1.0) and np.all( batch["future_observed_values"] == 1.0 ), "Missing values not supported!" 
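# Stack the collected batches into a design matrix of shape
# [num_train_samples, context_length] and targets of shape
# [num_train_samples, prediction_length] for the scikit-learn fit below.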
train_X = np.concatenate(train_X, 0) train_y = np.concatenate(train_y, 0) train_X = train_X[: self.num_train_samples] train_y = train_y[: self.num_train_samples] assert len(train_X) == self.num_train_samples return train_X, train_y def create_predictor(self, transformation, model): prediction_splitter = self._create_instance_splitter("test") return LinearPredictor( input_names=PREDICTION_INPUT_NAMES, prediction_net=model, batch_size=self.batch_size, prediction_length=self.prediction_length, input_transform=transformation + prediction_splitter, ) def train( self, training_data: Dataset, validation_data: Optional[Dataset] = None, cache_data: bool = False, ) -> Predictor: transformation = self.create_transformation() transformed_data = transformation.apply(training_data, is_train=True) if cache_data: transformed_data = Cached(transformed_data) train_X, train_y = self._create_training_samples(transformed_data) scaled_train_X, scale = self.scaler(train_X, np.ones_like(train_X)) scaled_train_y = train_y / scale if self.model == "linear": SKLearnLinear = LinearRegression elif self.model == "ridge": SKLearnLinear = Ridge regressor = SKLearnLinear().fit(scaled_train_X, scaled_train_y) model = LinearModel(regressor.coef_, regressor.intercept_, self.scaler) return self.create_predictor( transformation=transformation, model=model ) ================================================ FILE: src/uncond_ts_diff/model/linear/_scaler.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from typing import Optional, Tuple import numpy as np class MeanScaler: """Just like torch MeanScaler, but for numpy.""" def __init__( self, axis: int, keepdims: bool = False, default_scale: Optional[float] = None, minimum_scale: float = 1e-10, ): super().__init__() self.axis = axis self.keepdims = keepdims self.minimum_scale = minimum_scale self.default_scale = default_scale or 0.0 def __call__( self, data: np.ndarray, weights: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: # these will have shape (N, C) total_weight = weights.sum(axis=self.axis) weighted_sum = (np.abs(data) * weights).sum(axis=self.axis) # first compute a global scale per-dimension total_observed = total_weight.sum(axis=0) denominator = np.maximum(total_observed, np.ones_like(total_observed)) if self.default_scale != 0.0: default_scale = self.default_scale else: default_scale = weighted_sum.sum(axis=0) / denominator # then compute a per-item, per-dimension scale denominator = np.maximum(total_weight, np.ones_like(total_weight)) scale = weighted_sum / denominator # use per-batch scale when no element is observed # or when the sequence contains only zeros scale = np.expand_dims( np.maximum( self.minimum_scale, np.where( weighted_sum > np.zeros_like(weighted_sum), scale, default_scale * np.ones_like(total_weight), ), ), axis=self.axis, ) return data / scale, scale if self.keepdims else scale.squeeze( axis=self.axis ) class NOPScaler: """ Just like torch NOPScaler, but for numpy. """ def __init__(self, axis: int, keepdims: bool = False): super().__init__() self.axis = axis self.keepdims = keepdims def __call__( self, data: np.ndarray, weights: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: scale = np.ones_like(data).mean( axis=self.axis, keepdims=self.keepdims, ) return data, scale ================================================ FILE: src/uncond_ts_diff/predictor.py ================================================ # Copyright Amazon.com, Inc. 

================================================
FILE: src/uncond_ts_diff/predictor.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Iterator, Optional

from gluonts.dataset import Dataset
from gluonts.dataset.loader import InferenceDataLoader
from gluonts.model import Forecast
from gluonts.torch.batchify import batchify
from gluonts.torch.model.predictor import PyTorchPredictor


class PyTorchPredictorWGrads(PyTorchPredictor):
    def predict(
        self, dataset: Dataset, num_samples: Optional[int] = None
    ) -> Iterator[Forecast]:
        inference_data_loader = InferenceDataLoader(
            dataset,
            transform=self.input_transform,
            batch_size=self.batch_size,
            stack_fn=lambda data: batchify(data, self.device),
        )

        self.prediction_net.eval()

        # NOTE: prediction is deliberately not wrapped in torch.no_grad()
        # here, so that samplers which need gradients (e.g., guidance)
        # can differentiate through the model.
        yield from self.forecast_generator(
            inference_data_loader=inference_data_loader,
            prediction_net=self.prediction_net,
            input_names=self.input_names,
            output_transform=self.output_transform,
            num_samples=num_samples,
        )


================================================
FILE: src/uncond_ts_diff/sampler/__init__.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .observation_guidance import DDIMGuidance, DDPMGuidance
from .refiner import MostLikelyRefiner, MCMCRefiner

__all__ = [
    "DDIMGuidance",
    "DDPMGuidance",
    "MostLikelyRefiner",
    "MCMCRefiner",
]


================================================
FILE: src/uncond_ts_diff/sampler/_base.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Tuple
from functools import partial

import numpy as np
import torch


def grad_fn(fn, x):
    x.requires_grad_(True)
    return torch.autograd.grad(fn(x), x)[0]


@torch.no_grad()
def langevin_dynamics(
    z0: torch.Tensor,
    energy_func: Callable = None,
    score_func: Callable = None,
    step_size: float = 0.1,
    noise_scale: float = 0.1,
    n_steps: int = 1,
):
    """Overdamped Langevin dynamics.

    Parameters
    ----------
    z0
        Initial guess.
    energy_func, optional
        Energy function; exactly one of `energy_func` or `score_func`
        must be specified, by default None
    score_func, optional
        Score function; exactly one of `energy_func` or `score_func`
        must be specified, by default None
    step_size, optional
        Step size, by default 0.1
    noise_scale, optional
        Scale for Brownian noise, by default 0.1
    n_steps, optional
        Number of Langevin steps, by default 1

    Returns
    -------
        Updated point.
    """
    assert energy_func is not None or score_func is not None
    z = z0
    sqrt_2eta = torch.sqrt(2 * torch.tensor(step_size))
    for _ in range(n_steps):
        if energy_func is not None:
            with torch.enable_grad():
                z.requires_grad_(True)
                Ez = energy_func(z)
                v = -torch.autograd.grad(Ez, z)[0]
        else:
            v = score_func(z)
        z = (
            z.detach()
            + step_size * v
            + sqrt_2eta * noise_scale * torch.randn_like(z)
        )
    return z
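A quick, hypothetical sanity check of `langevin_dynamics` (not from the repo): with the quadratic energy E(z) = 0.5‖z‖², whose score is -z, the stationary distribution is approximately standard normal for small step sizes:

```python
import torch

# Start far from the mode; after many steps the samples should be
# roughly N(0, 1) distributed.
energy = lambda z: 0.5 * (z**2).sum()
z0 = torch.full((10_000,), 5.0)
z = langevin_dynamics(
    z0, energy_func=energy, step_size=0.01, noise_scale=1.0, n_steps=2_000
)
print(z.mean().item(), z.std().item())  # roughly 0 and 1
```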
""" for _ in range(n_steps): pt = pt - (h / 2) * dynamics_p(xt) xt = xt + h * pt / mass pt = pt - (h / 2) * dynamics_p(xt) xt, pt = xt.detach(), pt.detach() return xt, pt @torch.no_grad() def hmc( x0: torch.Tensor, energy_func: Callable[[torch.Tensor], torch.Tensor], step_size: float, mass: float, n_leapfrog_steps: int = 10, n_steps: int = 100, ) -> torch.Tensor: """Hamiltonian Monte Carlo. Parameters ---------- x0 Initial guess of shape [B, T, C]. energy_func Energy function E: [B, T, C] -> [] step_size Step size. mass Mass of particle. n_leapfrog_steps, optional Number of leapfrog integration steps, by default 10 n_steps, optional Number of HMC steps, by default 100 Returns ------- Updated tensor of shape [B, T, C]. """ potential_energy_func = energy_func batch_size, length, ch = x0.shape drift_func = partial(grad_fn, potential_energy_func) xt = x0 for _ in range(n_steps): pt = np.sqrt(mass) * torch.randn_like(xt) xt_prop, pt_prop = leapfrog( xt, pt, drift_func, mass, step_size, n_leapfrog_steps ) xt = xt_prop return xt def linear_midpoint_em_step( zt: torch.Tensor, coeff: float, h: float, sigma: float ): """Midpoint Euler-Maruyama step.""" eta = torch.randn_like(zt) ztp1 = zt - h * coeff * zt / 2 + np.sqrt(h) * sigma * eta ztp1 = ztp1 / (1 + h * coeff / 2) return ztp1.detach() @torch.no_grad() def udld( x0: torch.Tensor, potential_energy_func: Callable[[torch.Tensor], torch.Tensor], step_size: float, friction: float, mass: float, n_leapfrog_steps: int = 1, n_steps: int = 100, ) -> torch.Tensor: """Underdamped Langevin dynamics. Parameters ---------- x0 Initial guess of shape [B, T, C] potential_energy_func Energy function E: [B, T, C] -> [] step_size Step size friction Friction coefficient mass Mass of the particle n_leapfrog_steps, optional Number of leapfrog integration steps, by default 1 n_steps, optional Number of UDLD steps, by default 100 Returns ------- Updated tensor of shape [B, T, C]. """ batch_size, length, ch = x0.shape xt = x0 drift_func = partial(grad_fn, potential_energy_func) pt = np.sqrt(mass) * torch.randn_like(xt) coeff = friction / mass sigma = np.sqrt(2 * friction) for _ in range(n_steps): pt = linear_midpoint_em_step(pt, coeff, step_size / 2, sigma) xt_prop, pt_prop = leapfrog( xt, pt, drift_func, mass, step_size, n_leapfrog_steps ) xt, pt = xt_prop, pt_prop pt = linear_midpoint_em_step(pt, coeff, step_size / 2, sigma) return xt ================================================ FILE: src/uncond_ts_diff/sampler/observation_guidance.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 

================================================
FILE: src/uncond_ts_diff/sampler/observation_guidance.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import torch
import torch.nn.functional as F
from gluonts.torch.util import lagged_sequence_values

from uncond_ts_diff.predictor import PyTorchPredictorWGrads
from uncond_ts_diff.utils import extract
from uncond_ts_diff.model import TSDiff

PREDICTION_INPUT_NAMES = [
    "past_target",
    "past_observed_values",
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "future_time_feat",
    "stats",
]


class Guidance(torch.nn.Module):
    _missing_scenarios = ["none", "RM", "BM-B", "BM-E"]

    def __init__(
        self,
        model: TSDiff,
        prediction_length: int,
        scale: float = 1.0,
        num_samples: int = 1,
        guidance: str = "quantile",
        missing_scenario: str = "none",
        missing_values: int = 0,
    ):
        super().__init__()
        assert missing_scenario in self._missing_scenarios
        self.model = model
        self.prediction_length = prediction_length
        self.scale = scale
        self.num_samples = num_samples
        self.guidance = guidance
        self.missing_scenario = missing_scenario
        self.missing_values = missing_values

    def quantile_loss(self, y_prediction, y_target):
        assert y_target.shape == y_prediction.shape
        device = y_prediction.device
        batch_size_x_num_samples, length, ch = y_target.shape
        batch_size = batch_size_x_num_samples // self.num_samples

        # num_samples uniformly distributed quantiles between 0 and 1,
        # repeated for each element in the batch
        q = (torch.arange(self.num_samples).repeat(batch_size) + 1).to(
            device
        ) / (self.num_samples + 1)
        # (batch_size x num_samples,)
        q = q[:, None, None]  # (batch_size x num_samples, 1, 1)
        e = y_target - y_prediction
        loss = torch.max(q * e, (q - 1) * e)
        return loss

    def energy_func(self, y, t, observation, observation_mask, features):
        if self.guidance == "MSE":
            return F.mse_loss(
                self.model.fast_denoise(y, t, features),
                observation,
                reduction="none",
            )[observation_mask == 1].sum()
        elif self.guidance == "quantile":
            return self.quantile_loss(
                self.model.fast_denoise(y, t, features),
                observation,
            )[observation_mask == 1].sum()
        else:
            raise ValueError(f"Unknown guidance {self.guidance}!")

    def score_func(self, y, t, observation, observation_mask, features):
        with torch.enable_grad():
            y.requires_grad_(True)
            Ey = self.energy_func(
                y, t, observation, observation_mask, features
            )
            return -torch.autograd.grad(Ey, y)[0]

    def scale_func(self, y, t, base_scale):
        raise NotImplementedError("Must be implemented by a subclass!")

    def guide(self, observation, observation_mask, features, scale):
        raise NotImplementedError("Must be implemented by a subclass!")

    def forward(
        self,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        feat_static_cat: torch.Tensor = None,
        feat_static_real: torch.Tensor = None,
        past_time_feat: torch.Tensor = None,
        future_time_feat: torch.Tensor = None,
        stats: torch.Tensor = None,
    ):
        device = next(self.model.parameters()).device
        future_target = torch.zeros(
            past_target.shape[0], self.prediction_length, device=device
        )
        data = dict(
            feat_static_cat=feat_static_cat.to(device)
            if feat_static_cat is not None
            else None,
            feat_static_real=feat_static_real.to(device)
            if feat_static_real is not None
            else None,
            past_time_feat=past_time_feat.to(device)
            if past_time_feat is not None
            else None,
            past_target=past_target.to(device),
            future_target=future_target,
            past_observed_values=past_observed_values.to(device)
            if past_observed_values is not None
            else None,
            future_time_feat=future_time_feat.to(device)
            if future_time_feat is not None
            else None,
            stats=stats.to(device) if stats is not None else None,
        )
        observation, scale_params, features = self.model._extract_features(
            data
        )
        observation = observation.to(device)
        batch_size, length, ch = observation.shape

        prior_mask = past_observed_values[:, : -self.model.context_length]
        context_mask = past_observed_values[:, -self.model.context_length :]
        future_mask = torch.zeros_like(future_target)
        observation_mask = torch.cat([context_mask, future_mask], dim=1)

        if self.model.use_lags:
            lagged_mask = lagged_sequence_values(
                self.model.lags_seq,
                prior_mask,
                observation_mask,
                dim=1,
            )
            observation_mask = torch.cat(
                [observation_mask[:, :, None], lagged_mask], dim=-1
            )
        else:
            observation_mask = observation_mask[:, :, None]

        observation = observation.repeat_interleave(self.num_samples, dim=0)
        observation_mask = observation_mask.repeat_interleave(
            self.num_samples, dim=0
        )
        if features is not None:
            features = features.repeat_interleave(self.num_samples, dim=0)

        # base_scale = self.scale / (
        #     context_mask.sum() / torch.ones_like(context_mask).sum()
        # )
        base_scale = self.scale
        pred = self.guide(observation, observation_mask, features, base_scale)

        pred = pred[:, :, 0].reshape(batch_size, self.num_samples, -1)
        pred = pred * scale_params
        return pred[..., length - self.prediction_length :]

    def get_predictor(self, input_transform, batch_size=40, device=None):
        return PyTorchPredictorWGrads(
            prediction_length=self.prediction_length,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=self,
            batch_size=batch_size,
            input_transform=input_transform,
            device=device,
        )
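The pinball loss above pushes the i-th of `num_samples` parallel chains toward the i-th quantile: under-prediction is penalized with weight q, over-prediction with weight 1-q. A standalone sketch of the formula with `batch_size=1` and `num_samples=3` (toy values, not from the repo):

```python
import torch

# q = [0.25, 0.5, 0.75], as in `quantile_loss` with num_samples=3.
q = (torch.arange(3) + 1.0) / (3 + 1)
q = q[:, None, None]                       # (num_samples, 1, 1)
e = torch.tensor(2.0) - torch.tensor(1.0)  # target - prediction = 1
loss = torch.max(q * e, (q - 1) * e)
print(loss.squeeze())  # tensor([0.2500, 0.5000, 0.7500])
```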

class DDPMGuidance(Guidance):
    def __init__(
        self,
        model: TSDiff,
        prediction_length: int,
        scale: float = 1,
        num_samples: int = 1,
        guidance: str = "quantile",
        missing_scenario: str = "none",
        missing_values: int = 0,
    ):
        super().__init__(
            model,
            prediction_length,
            scale,
            num_samples,
            guidance,
            missing_scenario,
            missing_values,
        )

    def scale_func(self, y, t, base_scale):
        return extract(self.model.posterior_variance, t, y.shape) * base_scale

    @torch.no_grad()
    def _reverse_diffusion(
        self, observation, observation_mask, features, base_scale
    ):
        device = observation.device
        batch_size = observation.shape[0]

        seq = torch.randn_like(observation)
        for i in reversed(range(0, self.model.timesteps)):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            seq = self.model.p_sample(seq, t, i, features)
            scale = self.scale_func(seq, t, base_scale=base_scale)
            seq = seq + scale * self.score_func(
                seq,
                t,
                observation=observation,
                observation_mask=observation_mask,
                features=features,
            )
        return seq

    def guide(self, observation, observation_mask, features, base_scale):
        return self._reverse_diffusion(
            observation, observation_mask, features, base_scale
        )
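Hypothetical wiring of `DDPMGuidance`, assuming a trained `model: TSDiff`, a GluonTS `transformation` pipeline, and a `test_dataset`; the names follow the conventions of bin/guidance_experiment.py but are assumptions here, not repo code:

```python
import torch

# Guided sampling with 100 parallel chains, each nudged toward its
# own quantile via the pinball-loss energy.
sampler = DDPMGuidance(
    model=model,
    prediction_length=24,
    scale=1.0,
    num_samples=100,
    guidance="quantile",
)
predictor = sampler.get_predictor(
    input_transform=transformation,
    batch_size=40,
    device=torch.device("cuda"),
)
forecasts = list(predictor.predict(test_dataset, num_samples=100))
```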

class DDIMGuidance(Guidance):
    _skip_types = ["uniform", "quadratic"]

    def __init__(
        self,
        model: TSDiff,
        prediction_length: int,
        eta: float = 0.0,
        skip_factor: int = 1,
        skip_type: str = "uniform",
        scale: float = 1,
        num_samples: int = 1,
        guidance: str = "quantile",
        missing_scenario: str = "none",
        missing_values: int = 0,
    ):
        super().__init__(
            model,
            prediction_length,
            scale,
            num_samples,
            guidance,
            missing_scenario,
            missing_values,
        )
        assert skip_type in self._skip_types
        self.eta = eta
        self.skip_factor = skip_factor
        self.skip_type = skip_type

    def scale_func(self, y, t, base_scale):
        return (
            extract(self.model.sqrt_one_minus_alphas_cumprod, t, y.shape)
            * base_scale
        )

    def _get_timesteps(self):
        if self.skip_type == "uniform":
            timesteps = list(
                range(0, self.model.timesteps, self.skip_factor)
            )
        elif self.skip_type == "quadratic":
            n_test_timesteps = int(self.model.timesteps / self.skip_factor)
            c = 1 - self.skip_factor / self.model.timesteps
            timesteps = np.square(
                np.linspace(
                    0, np.sqrt(self.model.timesteps * c), n_test_timesteps
                )
            )
            timesteps = timesteps.astype(np.int64).tolist()
            timesteps = sorted(set(timesteps))
        return timesteps

    @torch.no_grad()
    def _reverse_ddim(
        self, observation, observation_mask, features, base_scale
    ):
        device = observation.device
        batch_size = observation.shape[0]

        timesteps = self._get_timesteps()
        timesteps_prev = [-1] + timesteps[:-1]

        seq = torch.randn_like(observation)
        for i, j in zip(reversed(timesteps), reversed(timesteps_prev)):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            t_prev = torch.full(
                (batch_size,), j, device=device, dtype=torch.long
            )
            noise = self.model.backbone(seq, t, features)
            scale = self.scale_func(seq, t, base_scale=base_scale)
            noise = noise - scale * self.score_func(
                seq,
                t,
                observation=observation,
                observation_mask=observation_mask,
                features=features,
            )
            seq = self.model.p_sample_genddim(
                seq,
                t,
                t_index=i,
                t_prev=t_prev,
                eta=self.eta,
                features=features,
                noise=noise,
            )
        return seq

    def guide(self, observation, observation_mask, features, base_scale):
        return self._reverse_ddim(
            observation, observation_mask, features, base_scale
        )
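A standalone check of the quadratic skip schedule above, with `timesteps=100` and `skip_factor=10` (toy values): the selected indices are dense near t=0 and sparse near t=T, with exact values subject to float rounding:

```python
import numpy as np

T, skip = 100, 10
n = int(T / skip)
c = 1 - skip / T
ts = np.square(np.linspace(0, np.sqrt(T * c), n)).astype(np.int64)
print(sorted(set(ts.tolist())))
# approximately [0, 1, 4, 10, 17, 27, 40, 54, 71, 90]
```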

================================================
FILE: src/uncond_ts_diff/sampler/refiner.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gluonts.time_feature import get_seasonality

from uncond_ts_diff.predictor import PyTorchPredictorWGrads
from uncond_ts_diff.sampler._base import (
    langevin_dynamics,
    hmc,
    udld,
)

PREDICTION_INPUT_NAMES = [
    "past_target",
    "past_observed_values",
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "future_time_feat",
    "stats",
]


class Refiner(torch.nn.Module):
    def __init__(
        self,
        model,
        prediction_length,
        fixed_t=20,
        iterations=1,
        init=None,
        num_samples=1,
        guidance="quantile",
        scale=1,
    ):
        super().__init__()
        self.model = model
        self.prediction_length = prediction_length
        self.fixed_t = fixed_t
        self.iterations = iterations
        self.init = init
        self.num_samples = num_samples
        self.guidance = guidance
        self.scale = scale

    def quantile_loss(self, y_prediction, y_target):
        assert y_target.shape == y_prediction.shape
        device = y_prediction.device
        batch_size_x_num_samples, length, ch = y_target.shape
        batch_size = batch_size_x_num_samples // self.num_samples

        # num_samples uniformly distributed quantiles between 0 and 1,
        # repeated for each element in the batch
        q = (torch.arange(self.num_samples).repeat(batch_size) + 1).to(
            device
        ) / (self.num_samples + 1)
        # (batch_size x num_samples,)
        q = q[:, None, None]  # (batch_size x num_samples, 1, 1)
        e = y_target - y_prediction
        loss = torch.max(q * e, (q - 1) * e)
        return loss

    def prior(self, y_prediction, obs, obs_mask):
        if self.guidance == "MSE":
            return (
                self.scale
                * F.mse_loss(y_prediction, obs, reduction="none")[
                    obs_mask == 1
                ].sum()
            )
        elif self.guidance == "quantile":
            return self.scale * self.quantile_loss(y_prediction, obs).sum()
        else:
            raise ValueError(f"Unknown guidance {self.guidance}!")

    def refine(self, observation, observation_mask):
        raise NotImplementedError("Must be implemented by a subclass!")

    def forward(
        self,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        feat_static_cat: torch.Tensor = None,
        feat_static_real: torch.Tensor = None,
        past_time_feat: torch.Tensor = None,
        future_time_feat: torch.Tensor = None,
        stats: torch.Tensor = None,
    ):
        device = next(self.model.backbone.parameters()).device
        data = dict(
            feat_static_cat=feat_static_cat.to(device)
            if feat_static_cat is not None
            else None,
            feat_static_real=feat_static_real.to(device)
            if feat_static_real is not None
            else None,
            past_time_feat=past_time_feat.to(device)
            if past_time_feat is not None
            else None,
            past_target=past_target.to(device),
            future_target=torch.zeros(
                past_target.shape[0], self.prediction_length, device=device
            ),
            past_observed_values=past_observed_values.to(device)
            if past_observed_values is not None
            else None,
            future_time_feat=future_time_feat.to(device)
            if future_time_feat is not None
            else None,
            stats=stats.to(device) if stats is not None else None,
        )
        observation, scale, features = self.model._extract_features(data)
        observation = observation.to(device)
        batch_size, length, ch = observation.shape

        observation_mask = torch.ones_like(observation, device=device)
        observation_mask[:, length - self.prediction_length :, 0] = 0

        observation = observation.repeat_interleave(self.num_samples, dim=0)
        observation_mask = observation_mask.repeat_interleave(
            self.num_samples, dim=0
        )
        if features is not None:
            features = features.repeat_interleave(self.num_samples, dim=0)

        if self.init is not None:
            init_forecasts = np.stack(
                [next(self.init).samples for _ in range(batch_size)]
            )
            if init_forecasts.shape[1] == 1:
                # Single sample, e.g., for SeasonalNaive
                init_forecasts = np.tile(
                    init_forecasts, (1, self.num_samples, 1)
                )
            # Sort the initial forecasts so that each sample is matched
            # to its corresponding quantile
            init = np.sort(init_forecasts, axis=1)
            init = torch.from_numpy(init).to(device)
            # Scale the input
            init = init / scale
            # Reshape from [B, num_samples, prediction_length] to
            # [B * num_samples, prediction_length]
            init = init.reshape(
                batch_size * self.num_samples, self.prediction_length
            )
            # Use it as the initial guess
            observation[:, length - self.prediction_length :, 0] = init
        else:
            season_length = get_seasonality(self.model.freq)
            if (length - self.prediction_length) >= season_length:
                # Initialize using seasonal naive predictions
                indices = [
                    length
                    - self.prediction_length
                    - season_length
                    + k % season_length
                    for k in range(self.prediction_length)
                ]
                observation[
                    :, length - self.prediction_length :, 0
                ] = observation[:, indices, 0]
            else:
                # Initialize using the mean of the context window
                observation[
                    :, length - self.prediction_length :, 0
                ] = torch.mean(
                    observation[:, : length - self.prediction_length, 0],
                    dim=1,
                    keepdim=True,
                )

        pred = self.refine(observation, observation_mask)
        pred = pred[:, :, 0].reshape(batch_size, self.num_samples, -1)
        pred = pred * scale
        return pred[:, :, length - self.prediction_length :]

    def get_predictor(self, input_transform, batch_size=40, device=None):
        return PyTorchPredictorWGrads(
            prediction_length=self.prediction_length,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=self,
            batch_size=batch_size,
            input_transform=input_transform,
            device=device,
        )
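The seasonal-naive initialization in `Refiner.forward` above fills the forecast window with the last observed season. A standalone check of the index arithmetic with toy numbers (not from the repo):

```python
# With length=192, prediction_length=24, season_length=24, the context
# occupies positions 0..167 and the last season is positions 144..167.
length, prediction_length, season_length = 192, 24, 24
indices = [
    length - prediction_length - season_length + k % season_length
    for k in range(prediction_length)
]
print(indices[:3], indices[-1])  # [144, 145, 146] 167
```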

class MostLikelyRefiner(Refiner):
    def __init__(
        self,
        model,
        prediction_length,
        lr=1e-1,
        patience=100,
        fixed_t=20,
        iterations=1,
        init=None,
        num_samples=1,
        guidance="quantile",
        scale=1,
    ):
        super().__init__(
            model,
            prediction_length,
            fixed_t,
            iterations,
            init,
            num_samples,
            guidance,
            scale,
        )
        self.lr = lr
        self.patience = patience

    def _most_likely(self, observation, observation_mask):
        device = next(self.model.backbone.parameters()).device
        observation = observation.to(device)
        seq = nn.Parameter(torch.clone(observation), requires_grad=True)
        optim = torch.optim.SGD([seq], lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, "min", patience=self.patience, factor=0.5
        )
        with torch.enable_grad():
            for _ in range(self.iterations):
                optim.zero_grad()
                t = torch.randint(
                    0, self.model.timesteps, (seq.shape[0],), device=device
                ).long()
                if self.fixed_t != -1:
                    t = t * 0 + self.fixed_t
                loss = self.model.p_losses(
                    seq, t, loss_type="l2", reduction="sum"
                )[0] + self.prior(seq, observation, observation_mask)
                loss.backward()
                optim.step()
                scheduler.step(loss.item())
        return seq.detach()

    def refine(self, observation, observation_mask):
        return self._most_likely(observation, observation_mask)


class MCMCRefiner(Refiner):
    _available_methods = {"lmc", "hmc", "udld", "cdld"}

    def __init__(
        self,
        model,
        prediction_length,
        step_size=1e-1,
        method="lmc",
        method_kwargs={},
        fixed_t=20,
        iterations=1,
        init=None,
        num_samples=1,
        guidance="quantile",
        scale=1,
    ):
        super().__init__(
            model,
            prediction_length,
            fixed_t,
            iterations,
            init,
            num_samples,
            guidance,
            scale,
        )
        assert method in self._available_methods
        self.step_size: float = step_size
        self.method: str = method
        self.method_kwargs: dict = method_kwargs

    def _mcmc(self, observation, observation_mask):
        device = next(self.model.backbone.parameters()).device
        observation = observation.to(device)
        seq = torch.clone(observation)
        for i in range(self.iterations):
            t = torch.randint(
                0, self.model.timesteps, (seq.shape[0],), device=device
            ).long()
            if self.fixed_t != -1:
                t = t * 0 + self.fixed_t
            energy_func = lambda x: self.model.p_losses(  # noqa: E731
                x, t, loss_type="l2", reduction="sum"
            )[0] + self.prior(x, observation, observation_mask)
            if self.method == "lmc":
                method_kwargs = {
                    "noise_scale": 0.1,
                    "n_steps": 1,
                }
                method_kwargs.update(self.method_kwargs)
                seq = langevin_dynamics(
                    seq,
                    energy_func,
                    score_func=None,
                    step_size=self.step_size,
                    **method_kwargs,
                )
            elif self.method == "hmc":
                method_kwargs = {
                    "mass": 1.0,
                    "n_steps": 1,
                    "n_leapfrog_steps": 5,
                }
                method_kwargs.update(self.method_kwargs)
                seq = hmc(
                    seq, energy_func, step_size=self.step_size, **method_kwargs
                )
            elif self.method == "udld":
                method_kwargs = {
                    "mass": 1.0,
                    "friction": 1.0,
                    "n_steps": 1,
                    "n_leapfrog_steps": 5,
                }
                method_kwargs.update(self.method_kwargs)
                seq = udld(
                    seq, energy_func, step_size=self.step_size, **method_kwargs
                )
            elif self.method == "cdld":
                method_kwargs = {
                    "mass": 1.0,
                    "n_steps": 1,
                    "n_leapfrog_steps": 5,
                }
                method_kwargs.update(self.method_kwargs)
                # friction^2 = 4 x mass
                method_kwargs["friction"] = np.sqrt(4 * method_kwargs["mass"])
                seq = udld(
                    seq, energy_func, step_size=self.step_size, **method_kwargs
                )
        return seq.detach()

    def refine(self, observation, observation_mask):
        return self._mcmc(observation, observation_mask)
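Hypothetical wiring of `MCMCRefiner`, assuming a trained `model: TSDiff` and an input `transformation`; the kwargs below are illustrative choices, not defaults taken from the repo's configs:

```python
# Refine forecasts with a few Langevin steps at a fixed noise level.
refiner = MCMCRefiner(
    model=model,
    prediction_length=24,
    step_size=1e-2,
    method="lmc",
    method_kwargs={"noise_scale": 0.05, "n_steps": 5},
    fixed_t=20,
    iterations=10,
    num_samples=100,
    guidance="quantile",
    scale=4.0,
)
predictor = refiner.get_predictor(input_transform=transformation)
```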

================================================
FILE: src/uncond_ts_diff/utils.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
from typing import Type, Dict
from pathlib import Path
from argparse import ArgumentParser, ArgumentTypeError
from functools import partial
import re

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import torch
from torch.utils.data import Dataset
from pandas.tseries.frequencies import to_offset
from gluonts.core.component import validated
from gluonts.dataset import DataEntry
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.split import split
from gluonts.dataset.util import period_index
from gluonts.transform import (
    Chain,
    RemoveFields,
    SetField,
    AsNumpyArray,
    AddObservedValuesIndicator,
    AddTimeFeatures,
    AddAgeFeature,
    VstackFeatures,
    MapTransformation,
    ExpectedNumInstanceSampler,
    InstanceSplitter,
    TestSplitSampler,
    ValidationSplitSampler,
)
from gluonts.model.forecast import SampleForecast

sns.set(
    style="white",
    font_scale=1.1,
    rc={"figure.dpi": 125, "lines.linewidth": 2.5, "axes.linewidth": 1.5},
)


def filter_metrics(metrics, select={"ND", "NRMSE", "mean_wQuantileLoss"}):
    return {m: metrics[m].item() for m in select}


def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule as proposed in https://arxiv.org/abs/2102.09672."""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = (
        torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    )
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)


def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.1
    return torch.linspace(beta_start, beta_end, timesteps)
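A quick sanity check of the schedules above (toy snippet, not from the repo): the betas lie in (0, 1), and the cumulative alpha product decays to near zero by the final step, i.e., the forward process ends close to pure noise:

```python
import torch

betas = cosine_beta_schedule(timesteps=100)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
assert betas.shape == (100,)
assert (betas > 0).all() and (betas < 1).all()
assert alphas_cumprod[-1] < 1e-2  # heavily noised at the final step
```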

def plot_train_stats(df: pd.DataFrame, y_keys=None, skip_first_epoch=True):
    if skip_first_epoch:
        df = df.iloc[1:, :]
    if y_keys is None:
        y_keys = ["train_loss", "valid_loss"]
    fig, ax = plt.subplots(1, 1, figsize=(6.5, 4))
    for y_key in y_keys:
        sns.lineplot(
            ax=ax,
            data=df,
            x="epochs",
            y=y_key,
            label=y_key.replace("_", " ").capitalize(),
        )
    ax.legend()
    ax.set_ylabel("Loss")
    ax.set_xlabel("Epoch")
    plt.show()


def get_lags_for_freq(freq_str: str):
    offset = to_offset(freq_str)
    if offset.n > 1:
        raise NotImplementedError(
            "Lags for freq multiple > 1 are not implemented yet."
        )
    if offset.name == "H":
        lags_seq = [24 * i for i in [1, 2, 3, 4, 5, 6, 7, 14, 21, 28]]
    elif offset.name == "D" or offset.name == "B":
        # TODO: Fix lags for B
        lags_seq = [30 * i for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]
    else:
        raise NotImplementedError(
            f"Lags for {freq_str} are not implemented yet."
        )
    return lags_seq


def create_transforms(
    num_feat_dynamic_real,
    num_feat_static_cat,
    num_feat_static_real,
    time_features,
    prediction_length,
):
    remove_field_names = []
    if num_feat_static_real == 0:
        remove_field_names.append(FieldName.FEAT_STATIC_REAL)
    if num_feat_dynamic_real == 0:
        remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)

    return Chain(
        [RemoveFields(field_names=remove_field_names)]
        + (
            [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
            if not num_feat_static_cat > 0
            else []
        )
        + (
            [SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
            if not num_feat_static_real > 0
            else []
        )
        + [
            AsNumpyArray(
                field=FieldName.FEAT_STATIC_CAT,
                expected_ndim=1,
                dtype=int,
            ),
            AsNumpyArray(
                field=FieldName.FEAT_STATIC_REAL,
                expected_ndim=1,
            ),
            AsNumpyArray(
                field=FieldName.TARGET,
                expected_ndim=1,
            ),
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            ),
            AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=time_features,
                pred_length=prediction_length,
            ),
            AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_AGE,
                pred_length=prediction_length,
                log_scale=True,
            ),
            AddMeanAndStdFeature(
                target_field=FieldName.TARGET,
                output_field="stats",
            ),
            VstackFeatures(
                output_field=FieldName.FEAT_TIME,
                input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                + (
                    [FieldName.FEAT_DYNAMIC_REAL]
                    if num_feat_dynamic_real > 0
                    else []
                ),
            ),
        ]
    )


def create_splitter(past_length: int, future_length: int, mode: str = "train"):
    if mode == "train":
        instance_sampler = ExpectedNumInstanceSampler(
            num_instances=1,
            min_past=past_length,
            min_future=future_length,
        )
    elif mode == "val":
        instance_sampler = ValidationSplitSampler(min_future=future_length)
    elif mode == "test":
        instance_sampler = TestSplitSampler()

    splitter = InstanceSplitter(
        target_field=FieldName.TARGET,
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        instance_sampler=instance_sampler,
        past_length=past_length,
        future_length=future_length,
        time_series_fields=[FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES],
    )
    return splitter
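Two small usage notes for the helpers above (toy values, not from the repo): hourly frequencies produce daily-multiple lags, and the splitter is parameterized purely by window sizes:

```python
# Hourly data uses multiples of 24 as lags.
print(get_lags_for_freq("H"))
# [24, 48, 72, 96, 120, 144, 168, 336, 504, 672]

# Hypothetical training-mode splitter for context_length=168 and
# prediction_length=24 (total window length 192).
splitter = create_splitter(past_length=168, future_length=24, mode="train")
```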

def get_next_file_num(
    base_fname: str,
    base_dir: Path,
    file_type: str = "yaml",
    separator: str = "-",
):
    """Gets the next available file number in a directory.

    e.g., if `base_fname="results"` and `base_dir` has files
    ["results-0.yaml", "results-1.yaml"], this function returns 2.

    Parameters
    ----------
    base_fname
        Base name of the file.
    base_dir
        Base directory where files are located.
    file_type, optional
        File extension to look for; if empty, directories are matched
        instead, by default "yaml"
    separator, optional
        Separator between `base_fname` and the file number,
        by default "-"

    Returns
    -------
        Next available file number
    """
    if file_type == "":
        # Directory
        items = filter(
            lambda x: x.is_dir() and x.name.startswith(base_fname),
            base_dir.glob("*"),
        )
    else:
        # File
        items = filter(
            lambda x: x.name.startswith(base_fname),
            base_dir.glob(f"*.{file_type}"),
        )
    run_nums = list(
        map(lambda x: int(x.stem.replace(base_fname + separator, "")), items)
    ) + [-1]
    return max(run_nums) + 1


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise ArgumentTypeError("Boolean value expected.")


def add_config_to_argparser(config: Dict, parser: ArgumentParser):
    for k, v in config.items():
        sanitized_key = re.sub(r"[^\w\-]", "", k).replace("-", "_")
        val_type = type(v)
        if val_type not in {int, float, str, bool}:
            print(f"WARNING: Skipping key {k}!")
            continue
        if val_type == bool:
            parser.add_argument(f"--{sanitized_key}", type=str2bool, default=v)
        else:
            parser.add_argument(f"--{sanitized_key}", type=val_type, default=v)
    return parser


class AddMeanAndStdFeature(MapTransformation):
    @validated()
    def __init__(
        self,
        target_field: str,
        output_field: str,
        dtype: Type = np.float32,
    ) -> None:
        self.target_field = target_field
        self.feature_name = output_field
        self.dtype = dtype

    def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
        data[self.feature_name] = np.array(
            [data[self.target_field].mean(), data[self.target_field].std()]
        )
        return data


class ScaleAndAddMeanFeature(MapTransformation):
    def __init__(
        self, target_field: str, output_field: str, prediction_length: int
    ) -> None:
        """Scale the time series using mean scaler and add the scale
        to `output_field`.

        Parameters
        ----------
        target_field
            Key for target time series
        output_field
            Key for the mean feature
        prediction_length
            Prediction length; only the time series before the last
            `prediction_length` timesteps is used for scale computation
        """
        self.target_field = target_field
        self.feature_name = output_field
        self.prediction_length = prediction_length

    def map_transform(self, data, is_train: bool):
        scale = np.mean(
            np.abs(data[self.target_field][..., : -self.prediction_length]),
            axis=-1,
            keepdims=True,
        )
        scale = np.maximum(scale, 1e-7)
        scaled_target = data[self.target_field] / scale
        data[self.target_field] = scaled_target
        data[self.feature_name] = scale
        return data
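A toy check (not from the repo) of `ScaleAndAddMeanFeature`: the scale is the mean absolute value of the context, i.e., everything except the last `prediction_length` steps:

```python
import numpy as np

t = ScaleAndAddMeanFeature(
    target_field="target", output_field="scale", prediction_length=2
)
entry = {"target": np.array([2.0, 4.0, 6.0, 100.0, 100.0])}
out = t.map_transform(entry, is_train=False)
print(out["scale"])   # [4.0] = mean(|2|, |4|, |6|)
print(out["target"])  # [0.5, 1.0, 1.5, 25.0, 25.0]
```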

class ScaleAndAddMinMaxFeature(MapTransformation):
    def __init__(
        self, target_field: str, output_field: str, prediction_length: int
    ) -> None:
        """Scale the time series using min-max scaler and add the scale
        to `output_field`.

        Parameters
        ----------
        target_field
            Key for target time series
        output_field
            Key for the min-max feature
        prediction_length
            Prediction length; only the time series before the last
            `prediction_length` timesteps is used for scale computation
        """
        self.target_field = target_field
        self.feature_name = output_field
        self.prediction_length = prediction_length

    def map_transform(self, data, is_train: bool):
        full_seq = data[self.target_field]
        # Only the context (everything before the last
        # `prediction_length` steps) is used to compute the scale.
        context = full_seq[..., : -self.prediction_length]
        min_val = np.min(context, axis=-1, keepdims=True)
        max_val = np.max(context, axis=-1, keepdims=True)
        loc = min_val
        scale = np.maximum(max_val - min_val, 1e-7)
        scaled_target = (full_seq - loc) / scale
        data[self.target_field] = scaled_target
        data[self.feature_name] = (loc, scale)
        return data


def descale(data, scale, scaling_type):
    if scaling_type == "mean":
        return data * scale
    elif scaling_type == "min-max":
        loc, scale = scale
        return data * scale + loc
    else:
        raise ValueError(f"Unknown scaling type: {scaling_type}")


def predict_and_descale(predictor, dataset, num_samples, scaling_type):
    """Generates forecasts using the predictor on the test dataset and
    then scales them back to the original space using the scale feature
    from the `ScaleAndAddMeanFeature` or `ScaleAndAddMinMaxFeature`
    transformation.

    Parameters
    ----------
    predictor
        GluonTS predictor
    dataset
        Test dataset
    num_samples
        Number of forecast samples
    scaling_type
        Scaling type, one of {"mean", "min-max"}; min-max scaling is
        used in TimeGAN, defaults to "mean"

    Yields
    ------
        SampleForecast objects

    Raises
    ------
    ValueError
        If the predictor generates Forecast objects other than
        SampleForecast
    """
    forecasts = predictor.predict(dataset, num_samples=num_samples)
    for input_ts, fcst in zip(dataset, forecasts):
        scale = input_ts["scale"]
        if isinstance(fcst, SampleForecast):
            fcst.samples = descale(
                fcst.samples, scale, scaling_type=scaling_type
            )
        else:
            raise ValueError("Only SampleForecast objects supported!")
        yield fcst


def to_dataframe_and_descale(input_label, scaling_type) -> pd.DataFrame:
    """Glues together "input" and "label" time series and scales them
    back using the scale feature from the transformation.

    Parameters
    ----------
    input_label
        Input-label pair generated from the test template
    scaling_type
        Scaling type, one of {"mean", "min-max"}; min-max scaling is
        used in TimeGAN, defaults to "mean"

    Returns
    -------
        A DataFrame containing the time series
    """
    start = input_label[0][FieldName.START]
    scale = input_label[0]["scale"]
    targets = [entry[FieldName.TARGET] for entry in input_label]
    full_target = np.concatenate(targets, axis=-1)
    full_target = descale(full_target, scale, scaling_type=scaling_type)
    index = period_index(
        {FieldName.START: start, FieldName.TARGET: full_target}
    )
    return pd.DataFrame(full_target.transpose(), index=index)
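A round-trip sketch (toy values, not from the repo) for the two scaling conventions handled by `descale`:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
# Mean scaling: divide by the scale, multiply to undo.
assert np.allclose(descale(x / 2.0, 2.0, "mean"), x)
# Min-max scaling: subtract loc and divide by scale, undo in reverse.
loc, scale = 1.0, 2.0
assert np.allclose(descale((x - loc) / scale, (loc, scale), "min-max"), x)
```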

def make_evaluation_predictions_with_scaling(
    dataset, predictor, num_samples: int = 100, scaling_type="mean"
):
    """A customized version of the `make_evaluation_predictions` utility
    that first scales the test time series, generates the forecasts, and
    then scales them back to the original space.

    Parameters
    ----------
    dataset
        Test dataset
    predictor
        GluonTS predictor
    num_samples, optional
        Number of test samples, by default 100
    scaling_type, optional
        Scaling type, one of {"mean", "min-max"}; min-max scaling is
        used in TimeGAN, by default "mean"

    Returns
    -------
        A tuple of forecast and time series iterators
    """
    window_length = predictor.prediction_length + predictor.lead_time
    _, test_template = split(dataset, offset=-window_length)
    test_data = test_template.generate_instances(window_length)
    input_test_data = list(test_data.input)

    return (
        predict_and_descale(
            predictor,
            input_test_data,
            num_samples=num_samples,
            scaling_type=scaling_type,
        ),
        map(
            partial(to_dataframe_and_descale, scaling_type=scaling_type),
            test_data,
        ),
    )


class PairDataset(Dataset):
    def __init__(self, x, y) -> None:
        super().__init__()
        assert x.shape[0] == y.shape[0]
        self.x = x
        self.y = y

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return self.x.shape[0]


class GluonTSNumpyDataset:
    """GluonTS dataset from a numpy array.

    Parameters
    ----------
    data
        Numpy array of samples with shape [N, T].
    start_date, optional
        Dummy start date field, by default pd.Period("2023", "H")
    """

    def __init__(
        self, data: np.ndarray, start_date: pd.Period = pd.Period("2023", "H")
    ):
        self.data = data
        self.start_date = start_date

    def __iter__(self):
        for ts in self.data:
            item = {"target": ts, "start": self.start_date}
            yield item

    def __len__(self):
        return len(self.data)


class MaskInput(MapTransformation):
    @validated()
    def __init__(
        self,
        target_field: str,
        observed_field: str,
        context_length: int,
        missing_scenario: str,
        missing_values: int,
        dtype: Type = np.float32,
    ) -> None:
        # FIXME: Remove hardcoding of fields
        self.target_field = target_field
        self.observed_field = observed_field
        self.context_length = context_length
        self.missing_scenario = missing_scenario
        self.missing_values = missing_values
        self.dtype = dtype

    def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
        data = deepcopy(data)
        data["orig_past_target"] = data["past_target"].copy()
        if self.missing_scenario == "BM-E" and self.missing_values > 0:
            data["past_target"][-self.missing_values :] = 0
            data["past_observed_values"][-self.missing_values :] = 0
        elif self.missing_scenario == "BM-B" and self.missing_values > 0:
            data["past_target"][
                -self.context_length : -self.context_length
                + self.missing_values
            ] = 0
            data["past_observed_values"][
                -self.context_length : -self.context_length
                + self.missing_values
            ] = 0
        elif self.missing_scenario == "RM" and self.missing_values > 0:
            weights = torch.ones(self.context_length)
            missing_idxs = -self.context_length + torch.multinomial(
                weights, self.missing_values, replacement=False
            )
            data["past_target"][missing_idxs] = 0
            data["past_observed_values"][missing_idxs] = 0
        return data


class ConcatDataset:
    def __init__(self, test_pairs, axis=-1) -> None:
        self.test_pairs = test_pairs
        self.axis = axis

    def _concat(self, test_pairs):
        for t1, t2 in test_pairs:
            yield {
                "target": np.concatenate(
                    [t1["target"], t2["target"]], axis=self.axis
                ),
                "start": t1["start"],
            }

    def __iter__(self):
        yield from self._concat(self.test_pairs)
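A minimal usage sketch of `GluonTSNumpyDataset` (the array shape here is illustrative): wrap an array of generated samples, e.g., draws from TSDiff, so that downstream GluonTS estimators can consume them, as in the TSTR (train-on-synthetic, test-on-real) experiment:

```python
import numpy as np

samples = np.random.randn(100, 192)  # [N, T] synthetic series
synth_dataset = GluonTSNumpyDataset(samples)
for entry in synth_dataset:
    print(entry["start"], entry["target"].shape)  # 2023Q1-ish period, (192,)
    break
```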