Repository: amazon-science/unconditional-time-series-diffusion
Branch: main
Commit: 3eafeffdffef
Files: 126
Total size: 289.4 KB

Directory structure:
gitextract_qvnzuyyo/

├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
├── THIRD-PARTY-LICENSES.txt
├── bin/
│   ├── guidance_experiment.py
│   ├── refinement_experiment.py
│   ├── train_cond_model.py
│   ├── train_model.py
│   └── tstr_experiment.py
├── configs/
│   ├── guidance/
│   │   ├── guidance_electricity.yaml
│   │   ├── guidance_exchange.yaml
│   │   ├── guidance_kdd_cup.yaml
│   │   ├── guidance_m4.yaml
│   │   ├── guidance_solar.yaml
│   │   ├── guidance_traffic.yaml
│   │   ├── guidance_uber_tlc.yaml
│   │   └── guidance_wiki.yaml
│   ├── guidance.yaml
│   ├── refinement/
│   │   ├── electricity_nips-deepar.yaml
│   │   ├── electricity_nips-linear.yaml
│   │   ├── electricity_nips-seasonal_naive.yaml
│   │   ├── electricity_nips-transformer.yaml
│   │   ├── exchange_rate_nips-deepar.yaml
│   │   ├── exchange_rate_nips-linear.yaml
│   │   ├── exchange_rate_nips-seasonal_naive.yaml
│   │   ├── exchange_rate_nips-transformer.yaml
│   │   ├── kdd_cup_2018_without_missing-deepar.yaml
│   │   ├── kdd_cup_2018_without_missing-linear.yaml
│   │   ├── kdd_cup_2018_without_missing-seasonal_naive.yaml
│   │   ├── kdd_cup_2018_without_missing-transformer.yaml
│   │   ├── m4_hourly-deepar.yaml
│   │   ├── m4_hourly-linear.yaml
│   │   ├── m4_hourly-seasonal_naive.yaml
│   │   ├── m4_hourly-transformer.yaml
│   │   ├── solar_nips-deepar.yaml
│   │   ├── solar_nips-linear.yaml
│   │   ├── solar_nips-seasonal_naive.yaml
│   │   ├── solar_nips-transformer.yaml
│   │   ├── traffic_nips-deepar.yaml
│   │   ├── traffic_nips-linear.yaml
│   │   ├── traffic_nips-seasonal_naive.yaml
│   │   ├── traffic_nips-transformer.yaml
│   │   ├── uber_tlc_hourly-deepar.yaml
│   │   ├── uber_tlc_hourly-linear.yaml
│   │   ├── uber_tlc_hourly-seasonal_naive.yaml
│   │   ├── uber_tlc_hourly-transformer.yaml
│   │   ├── wiki2000_nips-deepar.yaml
│   │   ├── wiki2000_nips-linear.yaml
│   │   ├── wiki2000_nips-seasonal_naive.yaml
│   │   └── wiki2000_nips-transformer.yaml
│   ├── refinement.yaml
│   ├── train_tsdiff/
│   │   ├── train_electricity.yaml
│   │   ├── train_exchange.yaml
│   │   ├── train_kdd_cup.yaml
│   │   ├── train_m4.yaml
│   │   ├── train_missing_electricity.yaml
│   │   ├── train_missing_exchange.yaml
│   │   ├── train_missing_kdd_cup.yaml
│   │   ├── train_missing_solar.yaml
│   │   ├── train_missing_traffic.yaml
│   │   ├── train_missing_uber_tlc.yaml
│   │   ├── train_solar.yaml
│   │   ├── train_traffic.yaml
│   │   ├── train_uber_tlc.yaml
│   │   └── train_wiki.yaml
│   ├── train_tsdiff-cond/
│   │   ├── electricity_nips.yaml
│   │   ├── exchange_rate_nips.yaml
│   │   ├── kdd_cup_2018_without_missing.yaml
│   │   ├── m4_hourly.yaml
│   │   ├── missing_BM-B_electricity_nips.yaml
│   │   ├── missing_BM-B_exchange_rate_nips.yaml
│   │   ├── missing_BM-B_kdd_cup_2018_without_missing.yaml
│   │   ├── missing_BM-B_solar_nips.yaml
│   │   ├── missing_BM-B_traffic_nips.yaml
│   │   ├── missing_BM-B_uber_tlc_hourly.yaml
│   │   ├── missing_BM-E_electricity_nips.yaml
│   │   ├── missing_BM-E_exchange_rate_nips.yaml
│   │   ├── missing_BM-E_kdd_cup_2018_without_missing.yaml
│   │   ├── missing_BM-E_solar_nips.yaml
│   │   ├── missing_BM-E_traffic_nips.yaml
│   │   ├── missing_BM-E_uber_tlc_hourly.yaml
│   │   ├── missing_RM_electricity_nips.yaml
│   │   ├── missing_RM_exchange_rate_nips.yaml
│   │   ├── missing_RM_kdd_cup_2018_without_missing.yaml
│   │   ├── missing_RM_solar_nips.yaml
│   │   ├── missing_RM_traffic_nips.yaml
│   │   ├── missing_RM_uber_tlc_hourly.yaml
│   │   ├── solar_nips.yaml
│   │   ├── traffic_nips.yaml
│   │   ├── uber_tlc_hourly.yaml
│   │   └── wiki2000_nips.yaml
│   ├── train_tsdiff-cond.yaml
│   ├── train_tsdiff.yaml
│   ├── tstr/
│   │   ├── electricity_nips.yaml
│   │   ├── exchange_rate_nips.yaml
│   │   ├── kdd_cup_2018_without_missing.yaml
│   │   ├── m4_hourly.yaml
│   │   ├── solar_nips.yaml
│   │   ├── traffic_nips.yaml
│   │   ├── uber_tlc_hourly.yaml
│   │   └── wiki2000_nips.yaml
│   └── tstr.yaml
├── pyproject.toml
└── src/
    └── uncond_ts_diff/
        ├── arch/
        │   ├── __init__.py
        │   ├── backbones.py
        │   └── s4.py
        ├── configs.py
        ├── dataset.py
        ├── metrics/
        │   ├── __init__.py
        │   └── linear_pred_score.py
        ├── model/
        │   ├── __init__.py
        │   ├── callback.py
        │   ├── diffusion/
        │   │   ├── _base.py
        │   │   ├── tsdiff.py
        │   │   └── tsdiff_cond.py
        │   └── linear/
        │       ├── _estimator.py
        │       └── _scaler.py
        ├── predictor.py
        ├── sampler/
        │   ├── __init__.py
        │   ├── _base.py
        │   ├── observation_guidance.py
        │   └── refiner.py
        └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__
lightning_logs/
.DS_Store
*.egg-info
/results/
/ckpts/
/saved_samples/
.vscode/
/sm_runs/
/data/
/checkpoints/

================================================
FILE: CODE_OF_CONDUCT.md
================================================
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing Guidelines

Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
documentation, we greatly value feedback and contributions from our community.

Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
information to effectively respond to your bug report or contribution.


## Reporting Bugs/Feature Requests

We welcome you to use the GitHub issue tracker to report bugs or suggest features.

When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:

* A reproducible test case or series of steps
* The version of our code being used
* Any modifications you've made relevant to the bug
* Anything unusual about your environment or deployment


## Contributing via Pull Requests
Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:

1. You are working against the latest source on the *main* branch.
2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
3. You open an issue to discuss any significant work - we would hate for your time to be wasted.

To send us a pull request, please:

1. Fork the repository.
2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
3. Ensure local tests pass.
4. Commit to your fork using clear commit messages.
5. Send us a pull request, answering any default questions in the pull request interface.
6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.

GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).


## Finding contributions to work on
Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.


## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.


## Security issue notifications
If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.


## Licensing

See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.


================================================
FILE: LICENSE
================================================

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.


================================================
FILE: NOTICE
================================================
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.


================================================
FILE: README.md
================================================
# TSDiff: An Unconditional Diffusion Model for Time Series

[![preprint](https://img.shields.io/static/v1?label=arXiv&message=2307.11494&color=B31B1B)](https://arxiv.org/abs/2307.11494)
[![License: Apache-2.0](https://img.shields.io/badge/License-Apache--2.0-yellow.svg)](https://opensource.org/licenses/Apache-2.0)
[![Venue:NeurIPS 2023](https://img.shields.io/badge/Venue-NeurIPS%202023-007CFF)](https://neurips.cc/)
<p align="center">
  <img src="./assets/overview.png" width="100%">
  <br />
  <span>Fig. 1: An overview of TSDiff’s use cases. <b>Predict:</b> By utilizing observation self-guidance, TSDiff can be
conditioned during inference to perform predictive tasks such as forecasting. <b>Refine:</b> Predictions
of base forecasters can be improved by leveraging the implicit probability density of TSDiff.
<b>Synthesize:</b> Realistic samples generated by TSDiff can be used to train downstream forecasters, achieving good
performance on real test data.</span>
</p>

---

This repository contains the official implementation of the NeurIPS 2023 paper [*Predict, Refine, Synthesize: Self-Guiding Diffusion Models for Probabilistic Time Series Forecasting*](https://arxiv.org/abs/2307.11494). In this paper, we propose *TSDiff*, an unconditional diffusion model for time series. Our proposed self-guidance mechanism enables conditioning TSDiff for downstream tasks during inference, without requiring auxiliary networks or altering the training procedure. Furthermore, our refinement scheme leverages the implicit density learned by the diffusion model to iteratively refine the predictions of base forecasters. Finally, we demonstrate the high quality of the synthetic time series by training downstream models solely on generated data, and we introduce the *Linear Predictive Score (LPS)* to quantify synthetic sample quality.

<p align="center">
  <img src="./assets/forecasts.png" width="60%">
  <br />
  <span>Fig. 2: Example forecasts generated by TSDiff-Q for
time series in Electricity, KDDCup, and Exchange — three datasets with different frequencies and/or prediction lengths.</span>
</p>

## Installation

TSDiff requires Python 3.8 or higher.

* Create a conda environment (optional, but recommended).
```sh
conda create --name tsdiff --yes python=3.8 && conda activate tsdiff
```
* Install this package.
```sh
pip install --editable "."
```
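
To check that the editable install worked, try the imports used by the experiment scripts in `bin/` (a minimal sanity check, nothing more):

```python
# Sanity check: these are the entry points used by the scripts in bin/.
from uncond_ts_diff.model import TSDiff
from uncond_ts_diff.sampler import DDPMGuidance, DDIMGuidance

print("uncond_ts_diff imports OK")
```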

> [!TIP]  
> We have some updates in the `update` branch. If you're interested in trying out TSDiff or [training it on a custom dataset](https://github.com/amazon-science/unconditional-time-series-diffusion/issues/7), the `update` branch may be faster for training.

## Usage

### Training Models

Train models using the `train_model.py` and `train_cond_model.py` scripts for `TSDiff` and `TSDiff-Cond`, respectively. Sample configurations can be found in `configs/train_tsdiff.yaml` and `configs/train_tsdiff-cond.yaml`. Specific configurations used in the paper can be found in `configs/train_tsdiff` and `configs/train_tsdiff-cond`.

Example commands for regular (i.e., no missing values) forecasting:
```sh
# Train TSDiff on the Uber dataset for regular forecasting
python bin/train_model.py -c configs/train_tsdiff/train_uber_tlc.yaml

# Train TSDiff on the M4 dataset for regular forecasting
python bin/train_model.py -c configs/train_tsdiff/train_m4.yaml

# Train TSDiff-Cond on the Uber dataset for regular forecasting
python bin/train_cond_model.py -c configs/train_tsdiff-cond/uber_tlc_hourly.yaml

# Train TSDiff-Cond on the M4 dataset for regular forecasting
python bin/train_cond_model.py -c configs/train_tsdiff-cond/m4_hourly.yaml
```

Example commands for forecasting with missing values:
```sh
# Train TSDiff on the Uber dataset for the missing values experiment
python bin/train_model.py -c configs/train_tsdiff/train_missing_uber_tlc.yaml

# Train TSDiff on the KDDCup dataset for the missing values experiment
python bin/train_model.py -c configs/train_tsdiff/train_missing_kdd_cup.yaml

# Train TSDiff-Cond on the Uber dataset for the RM missing values experiment
python bin/train_cond_model.py -c configs/train_tsdiff-cond/missing_RM_uber_tlc_hourly.yaml

# Train TSDiff-Cond on the KDDCup dataset for the BM-B missing values experiment
python bin/train_cond_model.py -c configs/train_tsdiff-cond/missing_BM-B_kdd_cup_2018_without_missing.yaml

# Train TSDiff-Cond on the KDDCup dataset for the BM-E missing values experiment
python bin/train_cond_model.py -c configs/train_tsdiff-cond/missing_BM-E_kdd_cup_2018_without_missing.yaml
```
Note that for TSDiff only one model is trained per dataset; all missing-value scenarios are evaluated with the same unconditional model. For TSDiff-Cond, by contrast, one model is trained per missingness scenario.
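
Concretely, a single unconditional checkpoint can be evaluated under several missingness scenarios by swapping only the inference-time mask. The sketch below uses the `missing_data_configs` structure consumed by `bin/guidance_experiment.py`; the `"none"` entry is taken verbatim from that script, while the other scenario identifiers are illustrative placeholders, not verified against `MaskInput`:

```python
# One unconditional TSDiff checkpoint, many missing-value scenarios:
# only the inference-time masking configuration changes (see
# evaluate_guidance in bin/guidance_experiment.py).
missing_data_configs = [
    {"missing_scenario": "none", "missing_values": 0},
    {"missing_scenario": "RM", "missing_values": 0.5},    # random missing
    {"missing_scenario": "BM-B", "missing_values": 0.5},  # blackout, beginning
    {"missing_scenario": "BM-E", "missing_values": 0.5},  # blackout, end
]
```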

### Evaluating Models
The unconditional models trained above can be used for the following tasks.

#### Predict using Observation Self-Guidance
Use the `guidance_experiment.py` script and the `configs/guidance.yaml` config to run the forecasting experiments. Specific configurations used in the paper can be found in `configs/guidance/`.

Example commands:
```sh
# Run observation self-guidance on the Solar dataset
python bin/guidance_experiment.py -c configs/guidance/guidance_solar.yaml --ckpt /path/to/ckpt

# Run observation self-guidance on the KDDCup dataset
python bin/guidance_experiment.py -c configs/guidance/guidance_kdd_cup.yaml --ckpt /path/to/ckpt
```
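
Programmatically, the script rebuilds the model and wraps it in a guidance sampler. The sketch below mirrors `load_model` and `evaluate_guidance` from `bin/guidance_experiment.py`; the frequency, lengths, and flag values are illustrative, not the exact settings from the paper's configs:

```python
import torch

import uncond_ts_diff.configs as diffusion_configs
from uncond_ts_diff.model import TSDiff
from uncond_ts_diff.sampler import DDPMGuidance

# Rebuild the architecture used at training time, then load the checkpoint
# (mirrors load_model() in bin/guidance_experiment.py).
model = TSDiff(
    **diffusion_configs.diffusion_small_config,
    freq="H",                 # illustrative; must match the training config
    use_features=False,       # illustrative flag values
    use_lags=True,
    normalization="mean",
    context_length=336,       # illustrative lengths
    prediction_length=24,
    init_skip=True,
)
model.load_state_dict(torch.load("/path/to/ckpt", map_location="cpu"), strict=True)

# Observation self-guidance conditions the unconditional model at inference
# time; any extra sampler parameters (e.g., a guidance scale) come from the
# YAML config's `sampler_params`.
sampler = DDPMGuidance(
    model=model,
    prediction_length=24,
    num_samples=100,
    missing_scenario="none",  # plain forecasting, no missing values
    missing_values=0,
)
```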

#### Refine Predictions of Base Forecasters
Use the `refinement_experiment.py` script and the `configs/refinement.yaml` config to run the refinement experiments. Specific configurations used in the paper can be found in `configs/refinement/`.

Example commands:
```sh
# Refine predictions from the Linear model on the Solar dataset
python bin/refinement_experiment.py -c configs/refinement/solar_nips-linear.yaml --ckpt /path/to/ckpt

# Refine predictions from the DeepAR model on the M4 dataset
python bin/refinement_experiment.py -c configs/refinement/m4_hourly-deepar.yaml --ckpt /path/to/ckpt
```
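
Under the hood, the script trains a base forecaster, picks a diffusion step via `get_best_diffusion_step`, and wraps the base forecasts in a refiner. A minimal sketch of that last step, mirroring `bin/refinement_experiment.py` (the default `iterations` value here is illustrative):

```python
from uncond_ts_diff.sampler import MostLikelyRefiner

def build_refining_predictor(model, base_fcsts, best_t, test_splitter,
                             prediction_length, num_samples=100,
                             iterations=5, device="cuda"):
    """Wrap a trained TSDiff and base-model forecasts into a refining
    predictor, as done in bin/refinement_experiment.py."""
    refiner = MostLikelyRefiner(
        model,
        prediction_length,
        init=iter(base_fcsts),  # forecasts from the base model (e.g., linear)
        num_samples=num_samples,
        fixed_t=best_t,         # diffusion step from get_best_diffusion_step()
        iterations=iterations,
    )
    # Batch size shrinks as num_samples grows, bounding total parallel samples.
    return refiner.get_predictor(
        test_splitter, batch_size=1024 // num_samples, device=device
    )
```
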
#### Train Downstream Models using Synthetic Data
Use the `tstr_experiment.py` script and the `configs/tstr.yaml` config to run the _train on synthetic, test on real_ (TSTR) experiments. Specific configurations used in the paper can be found in `configs/tstr/`.

Example commands:
```sh
# TSTR on the Solar Dataset
python bin/tstr_experiment.py -c configs/tstr/solar_nips.yaml --ckpt /path/to/ckpt

# TSTR on the KDDCup Dataset
python bin/tstr_experiment.py -c configs/tstr/kdd_cup_2018_without_missing.yaml --ckpt /path/to/ckpt
```
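
The TSTR loop itself is simple: train a cheap forecaster on synthetic samples, then score it on the real test set. The sketch below assumes the synthetic series have already been drawn from TSDiff (sampling is handled inside `tstr_experiment.py`); the start date is arbitrary, and the `LinearEstimator` call mirrors its use in `bin/refinement_experiment.py`:

```python
from gluonts.dataset.common import ListDataset
from gluonts.evaluation import Evaluator, make_evaluation_predictions

from uncond_ts_diff.model import LinearEstimator

def tstr_metrics(synthetic_series, real_test, freq, prediction_length,
                 context_length, num_samples=100):
    """Train a linear forecaster on synthetic samples and evaluate it on the
    real test set -- the recipe behind the Linear Predictive Score."""
    # Wrap the raw synthetic arrays into a GluonTS dataset; the start
    # timestamp is arbitrary for this purpose.
    train_ds = ListDataset(
        [{"target": ts, "start": "2020-01-01"} for ts in synthetic_series],
        freq=freq,
    )
    predictor = LinearEstimator(
        freq=freq,
        prediction_length=prediction_length,
        context_length=context_length,
        num_train_samples=10000,
    ).train(list(train_ds), cache_data=True)
    fcst_iter, ts_iter = make_evaluation_predictions(
        dataset=real_test, predictor=predictor, num_samples=num_samples
    )
    metrics, _ = Evaluator()(list(ts_iter), list(fcst_iter))
    return metrics
```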

## BibTeX

If you find this repository or the ideas presented in our paper useful, please consider citing.

```
@inproceedings{kollovieh2023predict,
 author    = {Kollovieh, Marcel and Ansari, Abdul Fatir and Bohlke-Schneider, Michael and Zschiegner, Jasper and Wang, Hao and Wang, Yuyang},
 title     = {Predict, Refine, Synthesize: Self-Guiding Diffusion Models for Probabilistic Time Series Forecasting},
 booktitle = {Advances in Neural Information Processing Systems},
 year      = {2023}
}
```

## Security

See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.

## License

This project is licensed under the Apache-2.0 License.



================================================
FILE: THIRD-PARTY-LICENSES.txt
================================================
** state-spaces; version 1.0 -- https://github.com/HazyResearch/state-spaces
 
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and
distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright
owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities
that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or
indirect, to cause the direction or management of such entity, whether by
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising
permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including
but not limited to software source code, documentation source, and configuration
files.

"Object" form shall mean any form resulting from mechanical transformation or
translation of a Source form, including but not limited to compiled object code,
generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made
available under the License, as indicated by a copyright notice that is included
in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that
is based on (or derived from) the Work and for which the editorial revisions,
annotations, elaborations, or other modifications represent, as a whole, an
original work of authorship. For the purposes of this License, Derivative Works
shall not include works that remain separable from, or merely link (or bind by
name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version
of the Work and any modifications or additions to that Work or Derivative Works
thereof, that is intentionally submitted to Licensor for inclusion in the Work
by the copyright owner or by an individual or Legal Entity authorized to submit
on behalf of the copyright owner. For the purposes of this definition,
"submitted" means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems, and
issue tracking systems that are managed by, or on behalf of, the Licensor for
the purpose of discussing and improving the Work, but excluding communication
that is conspicuously marked or otherwise designated in writing by the copyright
owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
of whom a Contribution has been received by Licensor and subsequently
incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this
License, each Contributor hereby grants to You a perpetual, worldwide, non-
exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce,
prepare Derivative Works of, publicly display, publicly perform, sublicense, and
distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License,
each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-
charge, royalty-free, irrevocable (except as stated in this section) patent
license to make, have made, use, offer to sell, sell, import, and otherwise
transfer the Work, where such license applies only to those patent claims
licensable by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s) with the Work
to which such Contribution(s) was submitted. If You institute patent litigation
against any entity (including a cross-claim or counterclaim in a lawsuit)
alleging that the Work or a Contribution incorporated within the Work
constitutes direct or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate as of the date
such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or
Derivative Works thereof in any medium, with or without modifications, and in
Source or Object form, provided that You meet the following conditions:

     (a) You must give any other recipients of the Work or Derivative Works a
copy of this License; and

     (b) You must cause any modified files to carry prominent notices stating
that You changed the files; and

     (c) You must retain, in the Source form of any Derivative Works that You
distribute, all copyright, patent, trademark, and attribution notices from the
Source form of the Work, excluding those notices that do not pertain to any part
of the Derivative Works; and

     (d) If the Work includes a "NOTICE" text file as part of its distribution,
then any Derivative Works that You distribute must include a readable copy of
the attribution notices contained within such NOTICE file, excluding those
notices that do not pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed as part of the
Derivative Works; within the Source form or documentation, if provided along
with the Derivative Works; or, within a display generated by the Derivative
Works, if and wherever such third-party notices normally appear. The contents of
the NOTICE file are for informational purposes only and do not modify the
License. You may add Your own attribution notices within Derivative Works that
You distribute, alongside or as an addendum to the NOTICE text from the Work,
provided that such additional attribution notices cannot be construed as
modifying the License.

     You may add Your own copyright statement to Your modifications and may
provide additional or different license terms and conditions for use,
reproduction, or distribution of Your modifications, or for any such Derivative
Works as a whole, provided Your use, reproduction, and distribution of the Work
otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any
Contribution intentionally submitted for inclusion in the Work by You to the
Licensor shall be under the terms and conditions of this License, without any
additional terms or conditions. Notwithstanding the above, nothing herein shall
supersede or modify the terms of any separate license agreement you may have
executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names,
trademarks, service marks, or product names of the Licensor, except as required
for reasonable and customary use in describing the origin of the Work and
reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in
writing, Licensor provides the Work (and each Contributor provides its
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied, including, without limitation, any warranties
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any risks
associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in
tort (including negligence), contract, or otherwise, unless required by
applicable law (such as deliberate and grossly negligent acts) or agreed to in
writing, shall any Contributor be liable to You for damages, including any
direct, indirect, special, incidental, or consequential damages of any character
arising as a result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill, work stoppage,
computer failure or malfunction, or any and all other commercial damages or
losses), even if such Contributor has been advised of the possibility of such
damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or
Derivative Works thereof, You may choose to offer, and charge a fee for,
acceptance of support, warranty, indemnity, or other liability obligations
and/or rights consistent with this License. However, in accepting such
obligations, You may act only on Your own behalf and on Your sole
responsibility, not on behalf of any other Contributor, and only if You agree to
indemnify, defend, and hold each Contributor harmless for any liability incurred
by, or claims asserted against, such Contributor by reason of your accepting any
such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets "[]" replaced with your own
identifying information. (Don't include the brackets!)  The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included on
the same "printed page" as the copyright notice for easier identification within
third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

* For state-spaces see also this required NOTICE:
    Copyright 2022 Albert Gu and Karan Goel and Christopher Re


================================================
FILE: bin/guidance_experiment.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import argparse
from pathlib import Path

import yaml
import torch
from tqdm.auto import tqdm
from gluonts.dataset.field_names import FieldName
from gluonts.evaluation import make_evaluation_predictions, Evaluator

from uncond_ts_diff.utils import (
    create_transforms,
    create_splitter,
    get_next_file_num,
    add_config_to_argparser,
    filter_metrics,
    MaskInput,
)
from uncond_ts_diff.model import TSDiff
from uncond_ts_diff.dataset import get_gts_dataset
from uncond_ts_diff.sampler import (
    DDPMGuidance,
    DDIMGuidance,
)
import uncond_ts_diff.configs as diffusion_configs

guidance_map = {"ddpm": DDPMGuidance, "ddim": DDIMGuidance}


def load_model(config):
    model = TSDiff(
        **getattr(
            diffusion_configs,
            config.get("diffusion_config", "diffusion_small_config"),
        ),
        freq=config["freq"],
        use_features=config["use_features"],
        use_lags=config["use_lags"],
        normalization="mean",
        context_length=config["context_length"],
        prediction_length=config["prediction_length"],
        init_skip=config["init_skip"],
    )
    model.load_state_dict(
        torch.load(config["ckpt"], map_location="cpu"),
        strict=True,
    )
    model = model.to(config["device"])
    return model


def evaluate_guidance(
    config, model, test_dataset, transformation, num_samples=100
):
    logger.info(f"Evaluating with {num_samples} samples.")
    results = []
    if config["setup"] == "forecasting":
        missing_data_kwargs_list = [
            {
                "missing_scenario": "none",
                "missing_values": 0,
            }
        ]
        config["missing_data_configs"] = missing_data_kwargs_list
    elif config["setup"] == "missing_values":
        missing_data_kwargs_list = config["missing_data_configs"]
    else:
        raise ValueError(f"Unknown setup {config['setup']}")

    Guidance = guidance_map[config["sampler"]]
    sampler_params = config["sampler_params"]
    for missing_data_kwargs in missing_data_kwargs_list:
        logger.info(
            f"Evaluating scenario '{missing_data_kwargs['missing_scenario']}' "
            f"with {missing_data_kwargs['missing_values']:.1f} missing_values."
        )

        sampler = Guidance(
            model=model,
            prediction_length=config["prediction_length"],
            num_samples=num_samples,
            **missing_data_kwargs,
            **sampler_params,
        )

        transformed_testdata = transformation.apply(
            test_dataset, is_train=False
        )
        test_splitter = create_splitter(
            past_length=config["context_length"] + max(model.lags_seq),
            future_length=config["prediction_length"],
            mode="test",
        )

        masking_transform = MaskInput(
            FieldName.TARGET,
            FieldName.OBSERVED_VALUES,
            config["context_length"],
            missing_data_kwargs["missing_scenario"],
            missing_data_kwargs["missing_values"],
        )
        test_transform = test_splitter + masking_transform

        predictor = sampler.get_predictor(
            test_transform,
            # batch_size * num_samples ≈ 1280, bounding the number of
            # samples generated in parallel regardless of num_samples
            batch_size=1280 // num_samples,
            device=config["device"],
        )
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=transformed_testdata,
            predictor=predictor,
            num_samples=num_samples,
        )
        forecasts = list(tqdm(forecast_it, total=len(transformed_testdata)))
        tss = list(ts_it)
        evaluator = Evaluator()
        metrics, _ = evaluator(tss, forecasts)
        metrics = filter_metrics(metrics)
        results.append(dict(**missing_data_kwargs, **metrics))

    return results


def main(config: dict, log_dir: str):
    # Read global parameters
    dataset_name = config["dataset"]
    freq = config["freq"]
    prediction_length = config["prediction_length"]
    num_samples = config["num_samples"]

    # Load dataset and model
    logger.info("Loading model")
    model = load_model(config)
    dataset = get_gts_dataset(dataset_name)
    assert dataset.metadata.freq == freq
    assert dataset.metadata.prediction_length == prediction_length

    # Setup data transformation and loading
    transformation = create_transforms(
        num_feat_dynamic_real=0,
        num_feat_static_cat=0,
        num_feat_static_real=0,
        time_features=model.time_features,
        prediction_length=prediction_length,
    )

    # Run guidance
    results = evaluate_guidance(
        config, model, dataset.test, transformation, num_samples=num_samples
    )

    # Save results
    log_dir = Path(log_dir) / "guidance_logs"
    log_dir.mkdir(exist_ok=True, parents=True)
    base_filename = "results"
    run_num = get_next_file_num(
        base_filename, log_dir, file_type="yaml", separator="-"
    )
    save_path = log_dir / f"{base_filename}-{run_num}.yaml"

    with open(save_path, "w") as fp:
        yaml.safe_dump(
            {"config": config, "metrics": results},
            fp,
            default_flow_style=False,
            sort_keys=False,
        )


if __name__ == "__main__":
    # Setup Logger
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    logger = logging.getLogger(__file__)
    logger.setLevel(logging.INFO)

    # Setup argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, required=True, help="Path to yaml config"
    )
    parser.add_argument(
        "--out_dir", type=str, default="./results", help="Path to results dir"
    )
    args, _ = parser.parse_known_args()

    with open(args.config, "r") as fp:
        config = yaml.safe_load(fp)

    # Update config from command line
    parser = add_config_to_argparser(config=config, parser=parser)
    args = parser.parse_args()
    config_updates = vars(args)
    for k in config.keys() & config_updates.keys():
        orig_val = config[k]
        updated_val = config_updates[k]
        if updated_val != orig_val:
            logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}")
    config.update(config_updates)

    main(config=config, log_dir=args.out_dir)


================================================
FILE: bin/refinement_experiment.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import copy
import logging
import argparse
from pathlib import Path

import yaml
import torch
import numpy as np
from tqdm.auto import tqdm
from gluonts.mx import DeepAREstimator, TransformerEstimator
from gluonts.model.seasonal_naive import SeasonalNaivePredictor
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.loader import TrainDataLoader
from gluonts.itertools import Cached
from gluonts.torch.batchify import batchify

from uncond_ts_diff.utils import (
    create_transforms,
    create_splitter,
    get_next_file_num,
    add_config_to_argparser,
    filter_metrics,
)
from uncond_ts_diff.model import TSDiff, LinearEstimator
from uncond_ts_diff.dataset import get_gts_dataset
from uncond_ts_diff.sampler import (
    MostLikelyRefiner,
    MCMCRefiner,
    DDPMGuidance,
    DDIMGuidance,
)
import uncond_ts_diff.configs as diffusion_configs

guidance_map = {"ddpm": DDPMGuidance, "ddim": DDIMGuidance}
refiner_map = {"most_likely": MostLikelyRefiner, "mcmc": MCMCRefiner}


def load_model(config):
    model = TSDiff(
        **getattr(
            diffusion_configs,
            config.get("diffusion_config", "diffusion_small_config"),
        ),
        freq=config["freq"],
        use_features=config["use_features"],
        use_lags=config["use_lags"],
        normalization="mean",
        context_length=config["context_length"],
        prediction_length=config["prediction_length"],
        init_skip=config["init_skip"],
    )
    model.load_state_dict(
        torch.load(config["ckpt"], map_location="cpu"),
        strict=True,
    )
    model = model.to(config["device"])
    return model


def get_best_diffusion_step(model: TSDiff, data_loader, device):
    """Find the diffusion step whose loss on one large batch is closest to
    the mean loss across all steps; used as `fixed_t` by the refiners."""
    losses = np.zeros(model.timesteps)
    batch = {
        k: v.to(device)
        for k, v in next(iter(data_loader)).items()
        if isinstance(v, torch.Tensor)
    }
    x, features, scale = model._extract_features(batch)
    for t in range(model.timesteps):
        loss, _, _ = model.p_losses(
            x.to(device), torch.tensor([t], device=device)
        )
        losses[t] = loss

    # Squared deviation from the mean loss; argmin picks the most "typical" step.
    best_t = ((losses - losses.mean()) ** 2).argmin()
    return best_t


def train_and_forecast_base_model(dataset, base_model_name, config):
    base_model_kwargs = config.get("base_model_params", {})
    if base_model_name == "deepar":
        predictor = DeepAREstimator(
            prediction_length=dataset.metadata.prediction_length,
            freq=dataset.metadata.freq,
            **base_model_kwargs,
        ).train(list(dataset.train), cache_data=True)
    elif base_model_name == "transformer":
        predictor = TransformerEstimator(
            prediction_length=dataset.metadata.prediction_length,
            freq=dataset.metadata.freq,
            **base_model_kwargs,
        ).train(list(dataset.train), cache_data=True)
    elif base_model_name == "seasonal_naive":
        predictor = SeasonalNaivePredictor(
            freq=dataset.metadata.freq,
            prediction_length=dataset.metadata.prediction_length,
            **base_model_kwargs,
        )
    elif base_model_name == "linear":
        num_train_samples = 10000
        predictor = LinearEstimator(
            freq=dataset.metadata.freq,
            prediction_length=dataset.metadata.prediction_length,
            context_length=config["context_length"],
            num_train_samples=num_train_samples,
            **base_model_kwargs,
        ).train(list(dataset.train), cache_data=True)
    else:
        raise ValueError(f"Unsupported base model {base_model_name}!")

    fcst_iter, ts_iter = make_evaluation_predictions(
        dataset=dataset.test,
        predictor=predictor,
        num_samples=config["num_samples"],
    )
    fcsts = list(tqdm(fcst_iter, total=len(dataset.test)))
    tss = list(ts_iter)

    return fcsts, tss


def forecast_guidance(
    dataset,
    base_model_name,
    config,
    diffusion_model,
    transformed_testdata,
    test_splitter,
):
    assert len(dataset.test) == len(transformed_testdata)
    base_model_kwargs = config.get("base_model_params", {})

    Guidance = guidance_map[base_model_name]
    predictor = Guidance(
        model=diffusion_model,
        prediction_length=dataset.metadata.prediction_length,
        num_samples=config["num_samples"],
        **base_model_kwargs,
    ).get_predictor(
        input_transform=test_splitter,
        batch_size=1280 // config["num_samples"],
        device=config["device"],
    )

    fcst_iter, ts_iter = make_evaluation_predictions(
        dataset=transformed_testdata,
        predictor=predictor,
        num_samples=config["num_samples"],
    )
    fcsts = list(tqdm(fcst_iter, total=len(dataset.test)))
    tss = list(ts_iter)

    return fcsts, tss


def main(config: dict, log_dir: str):
    # Read global parameters
    dataset_name = config["dataset"]
    device = config["device"]
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]
    base_model_name = config["base_model"]
    num_samples = config["num_samples"]

    # Load dataset and model
    logger.info("Loading model")
    dataset = get_gts_dataset(dataset_name)
    config["freq"] = dataset.metadata.freq

    assert prediction_length == dataset.metadata.prediction_length

    model = load_model(config)

    # Setup data transformation and loading
    transformation = create_transforms(
        num_feat_dynamic_real=0,
        num_feat_static_cat=0,
        num_feat_static_real=0,
        time_features=model.time_features,
        prediction_length=prediction_length,
    )
    transformed_data = transformation.apply(list(dataset.train), is_train=True)

    transformed_testdata = transformation.apply(
        list(dataset.test), is_train=False
    )

    training_splitter = create_splitter(
        past_length=context_length + max(model.lags_seq),
        future_length=prediction_length,
        mode="train",
    )
    test_splitter = create_splitter(
        past_length=context_length + max(model.lags_seq),
        future_length=prediction_length,
        mode="test",
    )

    train_dataloader = TrainDataLoader(
        Cached(transformed_data),
        batch_size=1024,
        stack_fn=batchify,
        transform=training_splitter,
        num_batches_per_epoch=2048,
    )

    best_t = get_best_diffusion_step(model, train_dataloader, device)

    # Train base model & get initial forecasts
    logger.info("Training base model")
    if base_model_name in {"ddpm", "ddim"}:
        base_fcsts, tss = forecast_guidance(
            dataset,
            base_model_name,
            config,
            diffusion_model=model,
            transformed_testdata=transformed_testdata,
            test_splitter=test_splitter,
        )
    else:
        base_fcsts, tss = train_and_forecast_base_model(
            dataset, base_model_name, config
        )

    # Evaluate base forecasts
    evaluator = Evaluator()
    baseline_metrics, _ = evaluator(tss, base_fcsts)
    baseline_metrics = filter_metrics(baseline_metrics)

    # Run refinement
    log_dir = Path(log_dir) / "refinement_logs"
    log_dir.mkdir(exist_ok=True, parents=True)
    base_filename = "results"
    run_num = get_next_file_num(
        base_filename, log_dir, file_type="yaml", separator="-"
    )
    save_path = log_dir / f"{base_filename}-{run_num}.yaml"

    results = [
        {
            "model": "baseline",
            "model_params": {
                "name": base_model_name,
                **config.get("base_model_params", {}),
            },
            **baseline_metrics,
        }
    ]

    n_refiner_configs = len(config["refiner_configs"])
    for i, ref_config in enumerate(config["refiner_configs"]):
        logger.info(
            f"Running refiner ({i+1}/{n_refiner_configs}): {json.dumps(ref_config)}"
        )

        refiner_config = copy.deepcopy(ref_config)
        refiner_name = refiner_config.pop("refiner_name")
        Refiner = refiner_map[refiner_name]
        refiner = Refiner(
            model,
            prediction_length,
            init=iter(base_fcsts),
            num_samples=num_samples,
            fixed_t=best_t,
            iterations=config["iterations"],
            **refiner_config,
        )
        refiner_predictor = refiner.get_predictor(
            test_splitter, batch_size=1024 // num_samples, device=device
        )
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=transformed_testdata,
            predictor=refiner_predictor,
            num_samples=num_samples,
        )
        evaluator = Evaluator()
        refined_metrics, _ = evaluator(
            list(ts_it),
            list(tqdm(forecast_it, total=len(transformed_testdata))),
        )
        refined_metrics = filter_metrics(refined_metrics)

        results.append(
            {
                "model": refiner_name,
                "model_params": json.dumps(ref_config),
                **refined_metrics,
            }
        )

    with open(save_path, "w") as fp:
        yaml.safe_dump(
            {"config": config, "metrics": results},
            fp,
            default_flow_style=False,
            sort_keys=False,
        )


if __name__ == "__main__":
    # Setup Logger
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    logger = logging.getLogger(__file__)
    logger.setLevel(logging.INFO)

    # Setup argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, required=True, help="Path to yaml config"
    )
    parser.add_argument(
        "--out_dir", type=str, default="./results", help="Path to results dir"
    )
    args, _ = parser.parse_known_args()

    with open(args.config, "r") as fp:
        config = yaml.safe_load(fp)

    # Update config from command line
    parser = add_config_to_argparser(config=config, parser=parser)
    args = parser.parse_args()
    config_updates = vars(args)
    for k in config.keys() & config_updates.keys():
        orig_val = config[k]
        updated_val = config_updates[k]
        if updated_val != orig_val:
            logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}")
    config.update(config_updates)

    main(config=config, log_dir=args.out_dir)


================================================
FILE: bin/train_cond_model.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import argparse
from pathlib import Path

import yaml
import torch
from tqdm.auto import tqdm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar

from gluonts.dataset.loader import TrainDataLoader, ValidationDataLoader
from gluonts.dataset.split import OffsetSplitter
from gluonts.itertools import Cached
from gluonts.torch.batchify import batchify
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.field_names import FieldName

import uncond_ts_diff.configs as diffusion_configs
from uncond_ts_diff.dataset import get_gts_dataset
from uncond_ts_diff.model import TSDiffCond
from uncond_ts_diff.utils import (
    create_transforms,
    create_splitter,
    add_config_to_argparser,
    filter_metrics,
    MaskInput,
    ConcatDataset,
)


def create_model(config):
    model = TSDiffCond(
        **getattr(diffusion_configs, config["diffusion_config"]),
        freq=config["freq"],
        use_features=config["use_features"],
        use_lags=config["use_lags"],
        normalization=config["normalization"],
        context_length=config["context_length"],
        prediction_length=config["prediction_length"],
        lr=config["lr"],
        init_skip=config["init_skip"],
        noise_observed=config["noise_observed"],
    )
    model.to(config["device"])
    return model


def evaluate_conditional(
    config,
    model: TSDiffCond,
    test_dataset,
    transformation,
    num_samples=100,
):
    logger.info(f"Evaluating with {num_samples} samples.")
    logger.info(
        f"Evaluating scenario '{config['missing_scenario']}' "
        f"with {config['missing_values']:.1f} missing_values."
    )

    results = []

    transformed_testdata = transformation.apply(test_dataset, is_train=False)
    test_splitter = create_splitter(
        past_length=config["context_length"] + max(model.lags_seq),
        future_length=config["prediction_length"],
        mode="test",
    )

    masking_transform = MaskInput(
        FieldName.TARGET,
        FieldName.OBSERVED_VALUES,
        config["context_length"],
        config["missing_scenario"],
        config["missing_values"],
    )
    test_transform = test_splitter + masking_transform

    predictor = model.get_predictor(
        test_transform,
        batch_size=1280,
        device=config["device"],
    )
    forecast_it, ts_it = make_evaluation_predictions(
        dataset=transformed_testdata,
        predictor=predictor,
        num_samples=num_samples,
    )
    forecasts = list(tqdm(forecast_it, total=len(transformed_testdata)))
    tss = list(ts_it)
    evaluator = Evaluator()
    metrics, _ = evaluator(tss, forecasts)
    metrics = filter_metrics(metrics)
    results.append(dict(**metrics))

    return results


def main(config, log_dir):
    # Load parameters
    dataset_name = config["dataset"]
    freq = config["freq"]
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]
    total_length = context_length + prediction_length

    # Create model
    model = create_model(config)

    # Setup dataset and data loading
    dataset = get_gts_dataset(dataset_name)
    assert dataset.metadata.freq == freq
    assert dataset.metadata.prediction_length == prediction_length

    if config["setup"] == "forecasting":
        training_data = dataset.train
    elif config["setup"] == "missing_values":
        missing_values_splitter = OffsetSplitter(offset=-total_length)
        training_data, _ = missing_values_splitter.split(dataset.train)

    num_rolling_evals = len(dataset.test) // len(dataset.train)

    transformation = create_transforms(
        num_feat_dynamic_real=0,
        num_feat_static_cat=0,
        num_feat_static_real=0,
        time_features=model.time_features,
        prediction_length=config["prediction_length"],
    )

    training_splitter = create_splitter(
        past_length=config["context_length"] + max(model.lags_seq),
        future_length=config["prediction_length"],
        mode="train",
    )

    if config["setup"] == "forecasting":
        config["missing_scenario"] = "none"
        config["missing_values"] = 0

    masking_transform = MaskInput(
        FieldName.TARGET,
        FieldName.OBSERVED_VALUES,
        config["context_length"],
        config.get("train_missing_scenario", config["missing_scenario"]),
        config["missing_values"],
    )
    train_transform = training_splitter + masking_transform

    callbacks = []
    val_loader = None
    if config["use_validation_set"]:
        transformed_data = transformation.apply(training_data, is_train=True)
        train_val_splitter = OffsetSplitter(
            offset=-config["prediction_length"] * num_rolling_evals
        )
        _, val_gen = train_val_splitter.split(training_data)

        val_dataset = ConcatDataset(
            val_gen.generate_instances(
                config["prediction_length"], num_rolling_evals
            )
        )
        val_splitter = create_splitter(
            past_length=config["context_length"] + max(model.lags_seq),
            future_length=config["prediction_length"],
            mode="val",
        )
        transformed_valdata = transformation.apply(val_dataset, is_train=True)
        val_loader = ValidationDataLoader(
            transformed_valdata,
            batch_size=1280,
            stack_fn=batchify,
            transform=val_splitter + masking_transform,
        )

        log_monitor = "valid_loss"
    else:
        transformed_data = transformation.apply(training_data, is_train=True)
        log_monitor = "train_loss"

    # Embed the monitored metric in the checkpoint filename
    filename = dataset_name + "-{epoch:03d}-{" + log_monitor + ":.3f}"

    data_loader = TrainDataLoader(
        Cached(transformed_data),
        batch_size=config["batch_size"],
        stack_fn=batchify,
        transform=train_transform,
        num_batches_per_epoch=config["num_batches_per_epoch"],
    )

    checkpoint_callback = ModelCheckpoint(
        save_top_k=3,
        monitor=f"{log_monitor}",
        mode="min",
        filename=filename,
        save_last=True,
        save_weights_only=True,
    )

    callbacks.append(checkpoint_callback)
    callbacks.append(RichProgressBar())

    trainer = pl.Trainer(
        # Fall back to CPU when CUDA is unavailable
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=[int(config["device"].split(":")[-1])]
        if torch.cuda.is_available()
        else 1,
        max_epochs=config["max_epochs"],
        enable_progress_bar=True,
        num_sanity_val_steps=0,
        callbacks=callbacks,
        default_root_dir=log_dir,
        gradient_clip_val=config.get("gradient_clip_val", None),
        check_val_every_n_epoch=config["eval_every"],
    )
    logger.info(f"Logging to {trainer.logger.log_dir}")
    trainer.fit(
        model, train_dataloaders=data_loader, val_dataloaders=val_loader
    )
    logger.info("Training completed.")

    best_ckpt_path = Path(trainer.logger.log_dir) / "best_checkpoint.ckpt"

    if not best_ckpt_path.exists():
        torch.save(
            torch.load(checkpoint_callback.best_model_path)["state_dict"],
            best_ckpt_path,
        )
    logger.info(f"Loading {best_ckpt_path}.")
    best_state_dict = torch.load(best_ckpt_path)
    model.load_state_dict(best_state_dict, strict=True)

    metrics = (
        evaluate_conditional(config, model, dataset.test, transformation)
        if config.get("do_final_eval", True)
        else "Final eval not performed"
    )
    with open(Path(trainer.logger.log_dir) / "results.yaml", "w") as fp:
        yaml.dump(
            {
                "config": config,
                "version": trainer.logger.version,
                "metrics": metrics,
            },
            fp,
        )


if __name__ == "__main__":
    # Setup Logger
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    logger = logging.getLogger(__file__)
    logger.setLevel(logging.INFO)

    # Setup argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, required=True, help="Path to yaml config"
    )
    parser.add_argument(
        "--out_dir", type=str, default="./", help="Path to results dir"
    )
    args, _ = parser.parse_known_args()

    with open(args.config, "r") as fp:
        config = yaml.safe_load(fp)

    # Update config from command line
    parser = add_config_to_argparser(config=config, parser=parser)
    args = parser.parse_args()
    config_updates = vars(args)
    for k in config.keys() & config_updates.keys():
        orig_val = config[k]
        updated_val = config_updates[k]
        if updated_val != orig_val:
            logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}")
    config.update(config_updates)

    main(config=config, log_dir=args.out_dir)
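
# Example invocation (illustrative; any config under configs/train_tsdiff-cond/
# should work):
#   python bin/train_cond_model.py -c configs/train_tsdiff-cond/electricity_nips.yaml --out_dir ./results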


================================================
FILE: bin/train_model.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import argparse
from pathlib import Path

import yaml
import torch
from tqdm.auto import tqdm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar

from gluonts.dataset.loader import TrainDataLoader
from gluonts.dataset.split import OffsetSplitter
from gluonts.itertools import Cached
from gluonts.torch.batchify import batchify
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.field_names import FieldName

import uncond_ts_diff.configs as diffusion_configs
from uncond_ts_diff.dataset import get_gts_dataset
from uncond_ts_diff.model.callback import EvaluateCallback
from uncond_ts_diff.model import TSDiff
from uncond_ts_diff.sampler import DDPMGuidance, DDIMGuidance
from uncond_ts_diff.utils import (
    create_transforms,
    create_splitter,
    add_config_to_argparser,
    filter_metrics,
    MaskInput,
)

guidance_map = {"ddpm": DDPMGuidance, "ddim": DDIMGuidance}
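# Maps the `sampler` key in the config (e.g., "ddpm") to the guidance
# sampler class used during evaluation.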


def create_model(config):
    model = TSDiff(
        **getattr(diffusion_configs, config["diffusion_config"]),
        freq=config["freq"],
        use_features=config["use_features"],
        use_lags=config["use_lags"],
        normalization=config["normalization"],
        context_length=config["context_length"],
        prediction_length=config["prediction_length"],
        lr=config["lr"],
        init_skip=config["init_skip"],
    )
    model.to(config["device"])
    return model


def evaluate_guidance(
    config, model, test_dataset, transformation, num_samples=100
):
    logger.info(f"Evaluating with {num_samples} samples.")
    results = []
    if config["setup"] == "forecasting":
        missing_data_kwargs_list = [
            {
                "missing_scenario": "none",
                "missing_values": 0,
            }
        ]
        config["missing_data_configs"] = missing_data_kwargs_list
    elif config["setup"] == "missing_values":
        missing_data_kwargs_list = config["missing_data_configs"]
    else:
        raise ValueError(f"Unknown setup {config['setup']}")

    Guidance = guidance_map[config["sampler"]]
    sampler_kwargs = config["sampler_params"]
    for missing_data_kwargs in missing_data_kwargs_list:
        logger.info(
            f"Evaluating scenario '{missing_data_kwargs['missing_scenario']}' "
            f"with {missing_data_kwargs['missing_values']:.1f} missing_values."
        )
        sampler = Guidance(
            model=model,
            prediction_length=config["prediction_length"],
            num_samples=num_samples,
            **missing_data_kwargs,
            **sampler_kwargs,
        )

        transformed_testdata = transformation.apply(
            test_dataset, is_train=False
        )
        test_splitter = create_splitter(
            past_length=config["context_length"] + max(model.lags_seq),
            future_length=config["prediction_length"],
            mode="test",
        )

        masking_transform = MaskInput(
            FieldName.TARGET,
            FieldName.OBSERVED_VALUES,
            config["context_length"],
            missing_data_kwargs["missing_scenario"],
            missing_data_kwargs["missing_values"],
        )
        test_transform = test_splitter + masking_transform

        predictor = sampler.get_predictor(
            test_transform,
            batch_size=1280 // num_samples,
            device=config["device"],
        )
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=transformed_testdata,
            predictor=predictor,
            num_samples=num_samples,
        )
        forecasts = list(tqdm(forecast_it, total=len(transformed_testdata)))
        tss = list(ts_it)
        evaluator = Evaluator()
        metrics, _ = evaluator(tss, forecasts)
        metrics = filter_metrics(metrics)
        results.append(dict(**missing_data_kwargs, **metrics))

    return results


def main(config, log_dir):
    # Load parameters
    dataset_name = config["dataset"]
    freq = config["freq"]
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]
    total_length = context_length + prediction_length

    # Create model
    model = create_model(config)

    # Setup dataset and data loading
    dataset = get_gts_dataset(dataset_name)
    assert dataset.metadata.freq == freq
    assert dataset.metadata.prediction_length == prediction_length

    if config["setup"] == "forecasting":
        training_data = dataset.train
    elif config["setup"] == "missing_values":
        missing_values_splitter = OffsetSplitter(offset=-total_length)
        training_data, _ = missing_values_splitter.split(dataset.train)

    num_rolling_evals = len(dataset.test) // len(dataset.train)

    transformation = create_transforms(
        num_feat_dynamic_real=0,
        num_feat_static_cat=0,
        num_feat_static_real=0,
        time_features=model.time_features,
        prediction_length=config["prediction_length"],
    )

    training_splitter = create_splitter(
        past_length=config["context_length"] + max(model.lags_seq),
        future_length=config["prediction_length"],
        mode="train",
    )

    callbacks = []
    if config["use_validation_set"]:
        transformed_data = transformation.apply(training_data, is_train=True)
        train_val_splitter = OffsetSplitter(
            offset=-config["prediction_length"] * num_rolling_evals
        )
        _, val_gen = train_val_splitter.split(training_data)
        val_data = val_gen.generate_instances(
            config["prediction_length"], num_rolling_evals
        )

        callbacks = [
            EvaluateCallback(
                context_length=config["context_length"],
                prediction_length=config["prediction_length"],
                sampler=config["sampler"],
                sampler_kwargs=config["sampler_params"],
                num_samples=config["num_samples"],
                model=model,
                transformation=transformation,
                test_dataset=dataset.test,
                val_dataset=val_data,
                eval_every=config["eval_every"],
            )
        ]
    else:
        transformed_data = transformation.apply(training_data, is_train=True)

    log_monitor = "train_loss"
    filename = dataset_name + "-{epoch:03d}-{train_loss:.3f}"

    data_loader = TrainDataLoader(
        Cached(transformed_data),
        batch_size=config["batch_size"],
        stack_fn=batchify,
        transform=training_splitter,
        num_batches_per_epoch=config["num_batches_per_epoch"],
    )

    checkpoint_callback = ModelCheckpoint(
        save_top_k=3,
        monitor=f"{log_monitor}",
        mode="min",
        filename=filename,
        save_last=True,
        save_weights_only=True,
    )

    callbacks.append(checkpoint_callback)
    callbacks.append(RichProgressBar())

    trainer = pl.Trainer(
        # Fall back to CPU when CUDA is unavailable
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=[int(config["device"].split(":")[-1])]
        if torch.cuda.is_available()
        else 1,
        max_epochs=config["max_epochs"],
        enable_progress_bar=True,
        num_sanity_val_steps=0,
        callbacks=callbacks,
        default_root_dir=log_dir,
        gradient_clip_val=config.get("gradient_clip_val", None),
    )
    logger.info(f"Logging to {trainer.logger.log_dir}")
    trainer.fit(model, train_dataloaders=data_loader)
    logger.info("Training completed.")

    best_ckpt_path = Path(trainer.logger.log_dir) / "best_checkpoint.ckpt"

    if not best_ckpt_path.exists():
        torch.save(
            torch.load(checkpoint_callback.best_model_path)["state_dict"],
            best_ckpt_path,
        )
    logger.info(f"Loading {best_ckpt_path}.")
    best_state_dict = torch.load(best_ckpt_path)
    model.load_state_dict(best_state_dict, strict=True)

    metrics = (
        evaluate_guidance(config, model, dataset.test, transformation)
        if config.get("do_final_eval", True)
        else "Final eval not performed"
    )
    with open(Path(trainer.logger.log_dir) / "results.yaml", "w") as fp:
        yaml.dump(
            {
                "config": config,
                "version": trainer.logger.version,
                "metrics": metrics,
            },
            fp,
        )


if __name__ == "__main__":
    # Setup Logger
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    logger = logging.getLogger(__file__)
    logger.setLevel(logging.INFO)

    # Setup argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, required=True, help="Path to yaml config"
    )
    parser.add_argument(
        "--out_dir", type=str, default="./", help="Path to results dir"
    )
    args, _ = parser.parse_known_args()

    with open(args.config, "r") as fp:
        config = yaml.safe_load(fp)

    # Update config from command line
    parser = add_config_to_argparser(config=config, parser=parser)
    args = parser.parse_args()
    config_updates = vars(args)
    for k in config.keys() & config_updates.keys():
        orig_val = config[k]
        updated_val = config_updates[k]
        if updated_val != orig_val:
            logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}")
    config.update(config_updates)

    main(config=config, log_dir=args.out_dir)
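
# Example invocation (illustrative; any config under configs/train_tsdiff/
# should work):
#   python bin/train_model.py -c configs/train_tsdiff/train_solar.yaml --out_dir ./results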


================================================
FILE: bin/tstr_experiment.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
import math
import logging
import argparse
from pathlib import Path
from typing import Optional

import yaml
import torch
import numpy as np
from tqdm.auto import tqdm

from gluonts.mx import DeepAREstimator, TransformerEstimator
from gluonts.evaluation import Evaluator
from gluonts.dataset.loader import TrainDataLoader
from gluonts.itertools import Cached
from gluonts.torch.batchify import batchify
from gluonts.time_feature import (
    get_lags_for_frequency,
    time_features_from_frequency_str,
)
from gluonts.dataset.split import slice_data_entry
from gluonts.transform import AdhocTransform, Chain

from uncond_ts_diff.utils import (
    ScaleAndAddMeanFeature,
    ScaleAndAddMinMaxFeature,
    GluonTSNumpyDataset,
    create_transforms,
    create_splitter,
    get_next_file_num,
    add_config_to_argparser,
    make_evaluation_predictions_with_scaling,
    filter_metrics,
)
from uncond_ts_diff.model import TSDiff, LinearEstimator
from uncond_ts_diff.dataset import get_gts_dataset
import uncond_ts_diff.configs as diffusion_configs

DOWNSTREAM_MODELS = ["linear", "deepar", "transformer"]


def load_model(config):
    model = TSDiff(
        **getattr(
            diffusion_configs,
            config.get("diffusion_config", "diffusion_small_config"),
        ),
        freq=config["freq"],
        use_features=config["use_features"],
        use_lags=config["use_lags"],
        normalization="mean",
        context_length=config["context_length"],
        prediction_length=config["prediction_length"],
        init_skip=config["init_skip"],
    )
    model.load_state_dict(
        torch.load(config["ckpt"], map_location="cpu"),
        strict=True,
    )
    model = model.to(config["device"])
    return model


def sample_synthetic(
    model: TSDiff,
    num_samples: int = 10_000,
    batch_size: int = 1000,
):
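    # Draws `num_samples` synthetic series from the unconditional model in
    # batches of `batch_size`; each series spans the model's full window
    # (context_length + prediction_length) in its normalized space.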
    synth_samples = []

    n_iters = math.ceil(num_samples / batch_size)
    for _ in tqdm(range(n_iters)):
        samples = model.sample_n(num_samples=batch_size)
        synth_samples.append(samples)

    synth_samples = np.concatenate(synth_samples, axis=0)[:num_samples]

    return synth_samples


def sample_real(
    data_loader,
    n_timesteps: int,
    num_samples: int = 10_000,
    batch_size: int = 1000,
):
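    # Draws `num_samples` real windows from the training data loader,
    # concatenating past and future targets and keeping the last
    # `n_timesteps` steps; the iterator is restarted if it is exhausted.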
    real_samples = []
    data_iter = iter(data_loader)
    n_iters = math.ceil(num_samples / batch_size)
    for _ in tqdm(range(n_iters)):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(data_loader)
            batch = next(data_iter)
        ts = np.concatenate(
            [batch["past_target"], batch["future_target"]], axis=-1
        )[:, -n_timesteps:]
        real_samples.append(ts)

    real_samples = np.concatenate(real_samples, axis=0)[:num_samples]

    return real_samples


def evaluate_tstr(
    tstr_predictor,
    test_dataset,
    context_length,
    prediction_length,
    num_samples=100,
    scaling_type="mean",
):
    total_length = context_length + prediction_length
    # Slice test set to be of the same length as context_length + prediction_length
    slice_func = partial(slice_data_entry, slice_=slice(-total_length, None))
    if scaling_type == "mean":
        ScaleAndAddScaleFeature = ScaleAndAddMeanFeature
    elif scaling_type == "min-max":
        ScaleAndAddScaleFeature = ScaleAndAddMinMaxFeature
    else:
        raise ValueError(f"Unknown scaling_type {scaling_type}")
    transformation = Chain(
        [
            AdhocTransform(slice_func),
            # Add scale to data entry for use later during evaluation
            ScaleAndAddScaleFeature("target", "scale", prediction_length),
        ]
    )
    sliced_test_set = transformation.apply(test_dataset)

    fcst_iter, ts_iter = make_evaluation_predictions_with_scaling(
        dataset=sliced_test_set,
        predictor=tstr_predictor,
        num_samples=num_samples,
        scaling_type=scaling_type,
    )
    evaluator = Evaluator()
    metrics, _ = evaluator(list(ts_iter), list(fcst_iter))
    return filter_metrics(metrics)


def train_and_evaluate(
    dataset,
    model_name,
    synth_samples,
    real_samples,
    config,
    scaling_type="mean",
):
    # NOTE: There's no notion of time for synthetic time series; they are
    # just "sequences", so a dummy timestamp is used as the start time.
    # Hence, time_features are set to [] in the models below.
    model_name = model_name.lower()
    freq = dataset.metadata.freq
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]
    total_length = context_length + prediction_length

    assert len(synth_samples) == len(real_samples)
    assert (
        synth_samples.shape[-1] == total_length
        and real_samples.shape[-1] == total_length
    )
    num_samples = len(real_samples)
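    # The real samples are only used for consistency checks here; per the TSTR
    # protocol (train on synthetic, test on real), the downstream models below
    # are trained on the synthetic samples and evaluated on the real test set.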

    synthetic_dataset = GluonTSNumpyDataset(synth_samples)

    if model_name == "linear":
        logger.info(f"Running TSTR for {model_name}")
        tstr_predictor = LinearEstimator(
            freq=freq,  # Not actually used in the estimator
            prediction_length=prediction_length,
            context_length=context_length,
            num_train_samples=num_samples,
            # Synthetic dataset is in the "scaled space"
            scaling=False,
        ).train(synthetic_dataset)
    elif model_name == "deepar":
        logger.info(f"Running TSTR for {model_name}")
        tstr_predictor = DeepAREstimator(
            freq=freq,
            prediction_length=prediction_length,
            # Synthetic dataset is in the "scaled space"
            scaling=False,
            time_features=[],
            lags_seq=get_lags_for_frequency(freq, lag_ub=context_length),
        ).train(synthetic_dataset)
    elif model_name == "transformer":
        logger.info(f"Running TSTR for {model_name}")
        tstr_predictor = TransformerEstimator(
            freq=freq,
            prediction_length=prediction_length,
            # Synthetic dataset is in the "scaled space"
            scaling=False,
            time_features=[],
            lags_seq=get_lags_for_frequency(freq, lag_ub=context_length),
        ).train(synthetic_dataset)
    else:
        raise ValueError(f"Unknown downstream model {model_name}")

    tstr_metrics = evaluate_tstr(
        tstr_predictor=tstr_predictor,
        test_dataset=dataset.test,
        context_length=context_length,
        prediction_length=prediction_length,
        scaling_type=scaling_type,
    )

    return dict(
        tstr_metrics=tstr_metrics,
    )


def main(config: dict, log_dir: str, samples_path: Optional[str]):
    # Read global parameters
    dataset_name = config["dataset"]
    context_length = config["context_length"]
    prediction_length = config["prediction_length"]

    # Create log_dir
    log_dir: Path = Path(log_dir)
    base_dirname = "tstr_log"
    run_num = get_next_file_num(
        base_dirname, log_dir, file_type="", separator="-"
    )
    log_dir = log_dir / f"{base_dirname}-{run_num}"
    log_dir.mkdir(exist_ok=True, parents=True)
    logger.info(f"Logging to {log_dir}")

    # Load dataset and model
    logger.info("Loading model")
    dataset = get_gts_dataset(dataset_name)
    config["freq"] = dataset.metadata.freq
    assert prediction_length == dataset.metadata.prediction_length

    model = load_model(config)

    # Setup data transformation and loading
    transformation = create_transforms(
        num_feat_dynamic_real=0,
        num_feat_static_cat=0,
        num_feat_static_real=0,
        time_features=time_features_from_frequency_str(config["freq"]),
        prediction_length=prediction_length,
    )
    transformed_data = transformation.apply(list(dataset.train), is_train=True)
    training_splitter = create_splitter(
        past_length=context_length + max(model.lags_seq),
        future_length=prediction_length,
        mode="train",
    )
    train_dataloader = TrainDataLoader(
        Cached(transformed_data),
        batch_size=1000,
        stack_fn=batchify,
        transform=training_splitter,
    )

    # Generate real samples
    logger.info("Generating real samples")
    real_samples = sample_real(
        train_dataloader,
        n_timesteps=context_length + prediction_length,
        num_samples=10000,
    )
    np.save(log_dir / "real_samples.npy", real_samples)

    if samples_path is None:
        # Generate synthetic samples
        logger.info("Generating synthetic samples")
        synth_samples = sample_synthetic(model, num_samples=10000)
        np.save(log_dir / "synth_samples.npy", synth_samples)
    else:
        logger.info(f"Using synthetic samples from {samples_path}")
        synth_samples = np.load(samples_path)[:10000]
        synth_samples = synth_samples.reshape(
            (10000, context_length + prediction_length)
        )

    # Run TSTR experiment for each downstream model
    results = []

    for model_name in DOWNSTREAM_MODELS:
        logger.info(f"Training and evaluating {model_name}")
        metrics = train_and_evaluate(
            dataset=dataset,
            model_name=model_name,
            synth_samples=synth_samples,
            real_samples=real_samples,
            config=config,
            scaling_type=config["scaling_type"],
        )
        results.append({"model": model_name, **metrics})

    logger.info("Saving results")
    with open(log_dir / "results.yaml", "w") as fp:
        yaml.safe_dump(
            {"config": config, "metrics": results},
            fp,
            default_flow_style=False,
            sort_keys=False,
        )


if __name__ == "__main__":
    # Setup Logger
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    logger = logging.getLogger(__file__)
    logger.setLevel(logging.INFO)

    # Setup argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, required=True, help="Path to yaml config"
    )
    parser.add_argument(
        "--out_dir", type=str, default="./results", help="Path to results dir"
    )
    parser.add_argument(
        "--samples_path", type=str, help="Path to generated samples"
    )
    args, _ = parser.parse_known_args()

    with open(args.config, "r") as fp:
        config = yaml.safe_load(fp)

    # Update config from command line
    parser = add_config_to_argparser(config=config, parser=parser)
    args = parser.parse_args()
    config_updates = vars(args)
    for k in config.keys() & config_updates.keys():
        orig_val = config[k]
        updated_val = config_updates[k]
        if updated_val != orig_val:
            logger.info(f"Updated key '{k}': {orig_val} -> {updated_val}")
    config.update(config_updates)

    main(config=config, log_dir=args.out_dir, samples_path=args.samples_path)
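
# Example invocation (the config path is hypothetical; any YAML providing the
# keys read above, e.g. dataset, ckpt, context_length, prediction_length and
# scaling_type, should work):
#   python bin/tstr_experiment.py -c path/to/tstr_config.yaml --out_dir ./results
# Pass --samples_path path/to/synth_samples.npy to reuse previously
# generated samples instead of sampling from the model.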


================================================
FILE: configs/guidance/guidance_electricity.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: false
num_samples: 100
prediction_length: 24
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
setup: forecasting
use_features: false
use_lags: true


================================================
FILE: configs/guidance/guidance_exchange.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: B
init_skip: true
num_samples: 100
prediction_length: 30
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 8
setup: forecasting
use_features: false
use_lags: true


================================================
FILE: configs/guidance/guidance_kdd_cup.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: true
num_samples: 100
prediction_length: 48
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 1
setup: forecasting
use_features: false
use_lags: true


================================================
FILE: configs/guidance/guidance_m4.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: false
num_samples: 100
prediction_length: 48
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 2
setup: forecasting
use_features: false
use_lags: false


================================================
FILE: configs/guidance/guidance_solar.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: false
num_samples: 100
prediction_length: 24
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 8
setup: forecasting
use_features: false
use_lags: true


================================================
FILE: configs/guidance/guidance_traffic.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: true
num_samples: 100
prediction_length: 24
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
setup: forecasting
use_features: false
use_lags: true


================================================
FILE: configs/guidance/guidance_uber_tlc.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
freq: H
init_skip: false
num_samples: 100
prediction_length: 24
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 2
setup: forecasting
use_features: false
use_lags: true


================================================
FILE: configs/guidance/guidance_wiki.yaml
================================================
ckpt: dummy/path.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
freq: 1D
init_skip: false
num_samples: 100
prediction_length: 30
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 2
setup: forecasting
use_features: false
use_lags: false


================================================
FILE: configs/guidance.yaml
================================================
# Model & checkpoint parameters
dataset: solar_nips
freq: H
device: cuda:0
ckpt: ckpts/forecasting/solar_nips/649_.ckpt
diffusion_config: diffusion_small_config
context_length: 336
prediction_length: 24
use_lags: true
use_features: false
init_skip: false
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
num_samples: 100
setup: forecasting
# The following key is ignored
# if the setup is forecasting
missing_data_configs:
- missing_scenario: BM-B
  missing_values: 168
- missing_scenario: BM-E
  missing_values: 168

================================================
FILE: configs/refinement/electricity_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/electricity_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/electricity_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/electricity_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/exchange_rate_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.01
  refiner_name: most_likely
- guidance: quantile
  lr: 0.01
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
use_features: false
use_lags: true


================================================
FILE: configs/refinement/exchange_rate_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.01
  refiner_name: most_likely
- guidance: quantile
  lr: 0.01
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
use_features: false
use_lags: true


================================================
FILE: configs/refinement/exchange_rate_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.01
  refiner_name: most_likely
- guidance: quantile
  lr: 0.01
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
use_features: false
use_lags: true


================================================
FILE: configs/refinement/exchange_rate_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.01
  refiner_name: most_likely
- guidance: quantile
  lr: 0.01
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.01
use_features: false
use_lags: true


================================================
FILE: configs/refinement/kdd_cup_2018_without_missing-deepar.yaml
================================================
base_model: deepar
base_model_params: {}
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/kdd_cup_2018_without_missing-linear.yaml
================================================
base_model: linear
base_model_params: {}
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/kdd_cup_2018_without_missing-seasonal_naive.yaml
================================================
base_model: seasonal_naive
base_model_params: {}
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/kdd_cup_2018_without_missing-transformer.yaml
================================================
base_model: transformer
base_model_params: {}
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/m4_hourly-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false


================================================
FILE: configs/refinement/m4_hourly-linear.yaml
================================================
base_model: linear
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false


================================================
FILE: configs/refinement/m4_hourly-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false


================================================
FILE: configs/refinement/m4_hourly-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 48
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false


================================================
FILE: configs/refinement/solar_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/solar_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/solar_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/solar_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/traffic_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/traffic_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/traffic_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/traffic_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
init_skip: true
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/uber_tlc_hourly-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/uber_tlc_hourly-linear.yaml
================================================
base_model: linear
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/uber_tlc_hourly-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/uber_tlc_hourly-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 24
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: true


================================================
FILE: configs/refinement/wiki2000_nips-deepar.yaml
================================================
base_model: deepar
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false


================================================
FILE: configs/refinement/wiki2000_nips-linear.yaml
================================================
base_model: linear
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false


================================================
FILE: configs/refinement/wiki2000_nips-seasonal_naive.yaml
================================================
base_model: seasonal_naive
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false


================================================
FILE: configs/refinement/wiki2000_nips-transformer.yaml
================================================
base_model: transformer
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
iterations: 20
num_samples: 100
prediction_length: 30
refiner_configs:
- guidance: MSE
  lr: 0.1
  refiner_name: most_likely
- guidance: quantile
  lr: 0.1
  refiner_name: most_likely
- guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
- guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1
  refiner_name: mcmc
  step_size: 0.1
use_features: false
use_lags: false


================================================
FILE: configs/refinement.yaml
================================================
# Model & checkpoint parameters
dataset: solar_nips
device: cuda:0
ckpt: ckpts/forecasting/solar_nips/649_.ckpt
diffusion_config: diffusion_small_config
context_length: 336
prediction_length: 24
use_lags: true
use_features: false
init_skip: false
# Refinement parameters
base_model: linear
base_model_params: {}
num_samples: 16
iterations: 10
refiner_configs:
- refiner_name: most_likely
  lr: 1.e-1
  guidance: MSE
- refiner_name: most_likely
  lr: 1.e-1
  guidance: quantile
- refiner_name: mcmc
  step_size: 1.e-1
  guidance: MSE
  method: lmc
  method_kwargs:
    noise_scale: 0.1
- refiner_name: mcmc
  step_size: 1.e-1
  guidance: quantile
  method: lmc
  method_kwargs:
    noise_scale: 0.1

================================================
FILE: configs/train_tsdiff/train_electricity.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: electricity_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting

================================================
FILE: configs/train_tsdiff/train_exchange.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: exchange_rate_nips
freq: B
context_length: 360 # 360 for `D`
prediction_length: 30 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 8
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting

================================================
FILE: configs/train_tsdiff/train_kdd_cup.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: kdd_cup_2018_without_missing
freq: H
context_length: 312 # 360 for `D`
prediction_length: 48 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 1
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting

================================================
FILE: configs/train_tsdiff/train_m4.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: False
dataset: m4_hourly
freq: H
context_length: 312 # 360 for `D`
prediction_length: 48 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 2
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting

================================================
FILE: configs/train_tsdiff/train_missing_electricity.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: electricity_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
  missing_values: 0
- missing_scenario: BM-E
  missing_values: 168
- missing_scenario: BM-B
  missing_values: 168
- missing_scenario: RM
  missing_values: 168

================================================
FILE: configs/train_tsdiff/train_missing_exchange.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: exchange_rate_nips
freq: B
context_length: 360 # 360 for `D`
prediction_length: 30 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 8
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
  missing_values: 0
- missing_scenario: BM-E
  missing_values: 180
- missing_scenario: BM-B
  missing_values: 180
- missing_scenario: RM
  missing_values: 180

================================================
FILE: configs/train_tsdiff/train_missing_kdd_cup.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: kdd_cup_2018_without_missing
freq: H
context_length: 312 # 360 for `D`
prediction_length: 48 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 1
guidance: quantile
use_validation_set: True
do_final_eval: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
  missing_values: 0
- missing_scenario: BM-E
  missing_values: 156
- missing_scenario: BM-B
  missing_values: 156
- missing_scenario: RM
  missing_values: 156

================================================
FILE: configs/train_tsdiff/train_missing_solar.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: solar_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 8
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
  missing_values: 0
- missing_scenario: BM-E
  missing_values: 168
- missing_scenario: BM-B
  missing_values: 168
- missing_scenario: RM
  missing_values: 168

================================================
FILE: configs/train_tsdiff/train_missing_traffic.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: traffic_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 4
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
  missing_values: 0
- missing_scenario: BM-E
  missing_values: 168
- missing_scenario: BM-B
  missing_values: 168
- missing_scenario: RM
  missing_values: 168

================================================
FILE: configs/train_tsdiff/train_missing_uber_tlc.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: uber_tlc_hourly
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 2
use_validation_set: True
eval_every: 50
device: cuda:0
setup: missing_values
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: none
  missing_values: 0
- missing_scenario: BM-E
  missing_values: 168
- missing_scenario: BM-B
  missing_values: 168
- missing_scenario: RM
  missing_values: 168

================================================
FILE: configs/train_tsdiff/train_solar.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: solar_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 8
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting

================================================
FILE: configs/train_tsdiff/train_traffic.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: traffic_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 4
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting

================================================
FILE: configs/train_tsdiff/train_uber_tlc.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: uber_tlc_hourly
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 2
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting

================================================
FILE: configs/train_tsdiff/train_wiki.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: False
dataset: wiki2000_nips
freq: 1D
context_length: 360 # 360 for `D`
prediction_length: 30 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 1000
num_batches_per_epoch: 128
batch_size: 64
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 4
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 2
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting

================================================
FILE: configs/train_tsdiff-cond/electricity_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/exchange_rate_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: B
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/kdd_cup_2018_without_missing.yaml
================================================
batch_size: 64
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/m4_hourly.yaml
================================================
batch_size: 64
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: forecasting
use_features: false
use_lags: false
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_electricity_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_exchange_rate_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: B
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 180
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_kdd_cup_2018_without_missing.yaml
================================================
batch_size: 64
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 156
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_solar_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_traffic_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-B_uber_tlc_hourly.yaml
================================================
batch_size: 64
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-B
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-B
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_electricity_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_exchange_rate_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: B
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 180
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_kdd_cup_2018_without_missing.yaml
================================================
batch_size: 64
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 156
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_solar_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_traffic_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_BM-E_uber_tlc_hourly.yaml
================================================
batch_size: 64
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: BM-E
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: BM-E
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_RM_electricity_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_RM_exchange_rate_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: B
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 180
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_RM_kdd_cup_2018_without_missing.yaml
================================================
batch_size: 64
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 156
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 48
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_RM_solar_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_RM_traffic_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/missing_RM_uber_tlc_hourly.yaml
================================================
batch_size: 64
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
missing_scenario: RM
missing_values: 168
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: missing_values
train_missing_scenario: RM
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/solar_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/traffic_nips.yaml
================================================
batch_size: 64
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: true
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/uber_tlc_hourly.yaml
================================================
batch_size: 64
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: H
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 24
setup: forecasting
use_features: false
use_lags: true
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond/wiki2000_nips.yaml
================================================
batch_size: 64
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
do_final_eval: true
eval_every: 10
freq: 1D
gradient_clip_val: 0.5
init_skip: false
lr: 0.001
max_epochs: 100
model: conditional
noise_observed: true
normalization: mean
num_batches_per_epoch: 128
prediction_length: 30
setup: forecasting
use_features: false
use_lags: false
use_validation_set: true


================================================
FILE: configs/train_tsdiff-cond.yaml
================================================
model: conditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: True
dataset: solar_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: False
gradient_clip_val: 0.5
max_epochs: 100
num_batches_per_epoch: 128
batch_size: 64
use_validation_set: True
eval_every: 10
device: cuda:0
noise_observed: True
do_final_eval: True
setup: missing_values
# The following keys will be ignored, if the setup is forecasting
train_missing_scenario: BM-E
missing_scenario: BM-E
missing_values: 168


================================================
FILE: configs/train_tsdiff.yaml
================================================
model: unconditional
diffusion_config: diffusion_small_config
normalization: mean
use_features: False
use_lags: False
dataset: solar_nips
freq: H
context_length: 336 # 360 for `D`
prediction_length: 24 # 30 for `D`
lr: 1.e-3
init_skip: True
gradient_clip_val: 0.5
max_epochs: 100
num_batches_per_epoch: 128
batch_size: 64
scale: 4
# Used only in the callback;
# the final evaluation uses 100 samples
num_samples: 16
sampler: ddpm
sampler_params:
  guidance: quantile
  scale: 4
use_validation_set: True
eval_every: 50
device: cuda:0
setup: forecasting
do_final_eval: True
# The following key will be ignored,
# if the setup is forecasting
missing_data_configs:
- missing_scenario: BM-B
  missing_values: 168
- missing_scenario: BM-E
  missing_values: 168


================================================
FILE: configs/tstr/electricity_nips.yaml
================================================
ckpt: dummy/electricity_nips.ckpt
context_length: 336
dataset: electricity_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 24
scaling_type: mean
use_features: false
use_lags: true


================================================
FILE: configs/tstr/exchange_rate_nips.yaml
================================================
ckpt: dummy/exchange_rate_nips.ckpt
context_length: 360
dataset: exchange_rate_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
prediction_length: 30
scaling_type: mean
use_features: false
use_lags: true


================================================
FILE: configs/tstr/kdd_cup_2018_without_missing.yaml
================================================
ckpt: dummy/kdd_cup_2018_without_missing.ckpt
context_length: 312
dataset: kdd_cup_2018_without_missing
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
prediction_length: 48
scaling_type: mean
use_features: false
use_lags: true


================================================
FILE: configs/tstr/m4_hourly.yaml
================================================
ckpt: dummy/m4_hourly.ckpt
context_length: 312
dataset: m4_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 48
scaling_type: mean
use_features: false
use_lags: false


================================================
FILE: configs/tstr/solar_nips.yaml
================================================
ckpt: dummy/solar_nips.ckpt
context_length: 336
dataset: solar_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 24
scaling_type: mean
use_features: false
use_lags: true


================================================
FILE: configs/tstr/traffic_nips.yaml
================================================
ckpt: dummy/traffic_nips.ckpt
context_length: 336
dataset: traffic_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: true
prediction_length: 24
scaling_type: mean
use_features: false
use_lags: true


================================================
FILE: configs/tstr/uber_tlc_hourly.yaml
================================================
ckpt: dummy/uber_tlc_hourly.ckpt
context_length: 336
dataset: uber_tlc_hourly
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 24
scaling_type: mean
use_features: false
use_lags: true


================================================
FILE: configs/tstr/wiki2000_nips.yaml
================================================
ckpt: dummy/wiki2000_nips.ckpt
context_length: 360
dataset: wiki2000_nips
device: cuda:0
diffusion_config: diffusion_small_config
init_skip: false
prediction_length: 30
scaling_type: mean
use_features: false
use_lags: false


================================================
FILE: configs/tstr.yaml
================================================
# Model & checkpoint parameters
dataset: solar_nips
device: cuda:0
ckpt: ckpts/solar_nips/version_236/1299_.ckpt
diffusion_config: diffusion_small_config
context_length: 336
prediction_length: 24
use_lags: true
use_features: false
init_skip: true
scaling_type: mean

================================================
FILE: pyproject.toml
================================================
[project]
name = "uncond-ts-diff"
version = "0.1.0"
description = "TSDiff: An Unconditional Diffusion Model for Time Series"
authors = []
dependencies = [
    "torch~=1.13.1",
    "pytorch-lightning~=1.9.4",
    "gluonts[mxnet,pro]~=0.12.3",
    "matplotlib",
    "seaborn",
    "opt_einsum~=3.3.0",
    "einops",
    "black",
    "tqdm",
    "scipy",
    "scikit-learn",
    "numba",
    "jupyter",
    "rich",
    "pykeops==2.1.1",
]
readme = "README.md"
requires-python = ">= 3.8"

[tool.black]
line-length = 79


================================================
FILE: src/uncond_ts_diff/arch/__init__.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .backbones import BackboneModel

__all__ = ["BackboneModel"]


================================================
FILE: src/uncond_ts_diff/arch/backbones.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import math

import torch
from torch import nn

from .s4 import S4


class SinusoidalPositionEmbeddings(nn.Module):
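    """Sinusoidal embedding of the diffusion step: half of the output
    dimensions are sines and half are cosines of the step at geometrically
    spaced frequencies, producing a (batch, dim) tensor."""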
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(
            torch.arange(half_dim, device=device) * -embeddings
        )
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class S4Layer(nn.Module):
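    """Pre-norm residual wrapper around a bidirectional S4 layer, with
    optional channel-wise dropout (nn.Dropout1d) on the layer output."""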
    def __init__(
        self,
        d_model,
        dropout=0.0,
    ):
        super().__init__()
        self.layer = S4(
            d_model=d_model,
            d_state=128,
            bidirectional=True,
            dropout=dropout,
            transposed=True,
            postact=None,
        )
        self.norm = nn.LayerNorm(d_model)
        self.dropout = (
            nn.Dropout1d(dropout) if dropout > 0.0 else nn.Identity()
        )

    def forward(self, x):
        """
        Input x is shape (B, d_input, L)
        """
        z = x
        # Prenorm
        z = self.norm(z.transpose(-1, -2)).transpose(-1, -2)
        # Apply layer: we ignore the state input and output for training
        z, _ = self.layer(z)
        # Dropout on the output of the layer
        z = self.dropout(z)
        # Residual connection
        x = z + x
        return x, None

    def default_state(self, *args, **kwargs):
        return self.layer.default_state(*args, **kwargs)

    def step(self, x, state, **kwargs):
        z = x
        # Prenorm
        z = self.norm(z.transpose(-1, -2)).transpose(-1, -2)
        # Apply layer
        z, state = self.layer.step(z, state, **kwargs)
        # Residual connection
        x = z + x
        return x, state


class S4Block(nn.Module):
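    """Residual block combining an S4 layer with diffusion-step conditioning,
    an optional covariate encoder, a tanh * sigmoid gate, and two 1x1
    convolutions that produce the residual and skip outputs."""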
    def __init__(self, d_model, dropout=0.0, expand=2, num_features=0):
        super().__init__()
        self.s4block = S4Layer(d_model, dropout=dropout)

        self.time_linear = nn.Linear(d_model, d_model)
        self.tanh = nn.Tanh()
        self.sigm = nn.Sigmoid()
        self.out_linear1 = nn.Conv1d(
            in_channels=d_model, out_channels=d_model, kernel_size=1
        )
        self.out_linear2 = nn.Conv1d(
            in_channels=d_model, out_channels=d_model, kernel_size=1
        )
        self.feature_encoder = nn.Conv1d(num_features, d_model, kernel_size=1)

    def forward(self, x, t, features=None):
        t = self.time_linear(t)[:, None, :].repeat(1, x.shape[2], 1)
        t = t.transpose(-1, -2)
        out, _ = self.s4block(x + t)
        if features is not None:
            out = out + self.feature_encoder(features)
        out = self.tanh(out) * self.sigm(out)
        out1 = self.out_linear1(out)
        out2 = self.out_linear2(out)
        return out1 + x, out2


def Conv1dKaiming(in_channels, out_channels, kernel_size):
    layer = nn.Conv1d(in_channels, out_channels, kernel_size)
    nn.init.kaiming_normal_(layer.weight)
    return layer


class BackboneModel(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        step_emb,
        num_residual_blocks,
        num_features,
        residual_block="s4",
        dropout=0.0,
        init_skip=True,
    ):
        super().__init__()
        if residual_block == "s4":
            residual_block = S4Block
        else:
            raise ValueError(f"Unknown residual block {residual_block}")
        self.input_init = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )
        self.time_init = nn.Sequential(
            nn.Linear(step_emb, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )
        self.out_linear = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        residual_blocks = []
        for i in range(num_residual_blocks):
            residual_blocks.append(
                residual_block(
                    hidden_dim, num_features=num_features, dropout=dropout
                )
            )
        self.residual_blocks = nn.ModuleList(residual_blocks)
        self.step_embedding = SinusoidalPositionEmbeddings(step_emb)
        self.init_skip = init_skip

    def forward(self, input, t, features=None):
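        # input: (B, L, C), t: (B,) diffusion steps, features: optional
        # (B, L, F) covariates. The residual blocks operate on (B, C, L);
        # their skip outputs are summed and projected back to
        # (B, L, output_dim).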
        x = self.input_init(input)  # (B, L, C)
        t = self.time_init(self.step_embedding(t))
        x = x.transpose(-1, -2)
        if features is not None:
            features = features.transpose(-1, -2)
        skips = []
        for layer in self.residual_blocks:
            x, skip = layer(x, t, features)
            skips.append(skip)

        skip = torch.stack(skips).sum(0)
        skip = skip.transpose(-1, -2)
        out = self.out_linear(skip)
        if self.init_skip:
            out = out + input
        return out
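
# Minimal usage sketch (illustrative only; the sizes below are assumptions,
# not values taken from the configs in this repository):
#
#   model = BackboneModel(
#       input_dim=1, hidden_dim=64, output_dim=1, step_emb=128,
#       num_residual_blocks=3, num_features=3,
#   )
#   x = torch.randn(8, 360, 1)       # (B, L, C) noisy sequence
#   t = torch.randint(0, 100, (8,))  # diffusion step per sample
#   feats = torch.randn(8, 360, 3)   # optional covariates, (B, L, F)
#   out = model(x, t, feats)         # (B, L, C), same shape as the input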


================================================
FILE: src/uncond_ts_diff/arch/s4.py
================================================
"""Standalone version of Structured (Sequence) State Space (S4) model."""

import logging
from functools import partial
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.utilities import rank_zero_only
from einops import rearrange, repeat
import opt_einsum as oe

contract = oe.contract
contract_expression = oe.contract_expression


def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger."""

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    ):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger


log = get_logger(__name__)

""" Cauchy and Vandermonde kernels """

try:  # Try CUDA extension
    from extensions.cauchy.cauchy import cauchy_mult

    has_cauchy_extension = True
except ImportError:
    # log.warning(
    #     "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%"
    # )
    has_cauchy_extension = False

try:  # Try pykeops
    from pykeops.torch import Genred

    has_pykeops = True
    log.info("Pykeops installation found.")

    def _broadcast_dims(*tensors):
        max_dim = max([len(tensor.shape) for tensor in tensors])
        tensors = [
            tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape)
            for tensor in tensors
        ]
        return tensors

    def cauchy_conj(v, z, w):
        """Pykeops version"""
        expr_num = "z * ComplexReal(v) - Real2Complex(Sum(v * w))"
        expr_denom = "ComplexMult(z-w, z-Conj(w))"

        cauchy_mult = Genred(
            f"ComplexDivide({expr_num}, {expr_denom})",
            [
                "v = Vj(2)",
                "z = Vi(2)",
                "w = Vj(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )

        v, z, w = _broadcast_dims(v, z, w)
        v = _c2r(v)
        z = _c2r(z)
        w = _c2r(w)

        r = 2 * cauchy_mult(v, z, w, backend="GPU")
        return _r2c(r)

    def log_vandermonde(v, x, L):
        expr = "ComplexMult(v, ComplexExp(ComplexMult(x, l)))"
        vandermonde_mult = Genred(
            expr,
            [
                "v = Vj(2)",
                "x = Vj(2)",
                "l = Vi(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )

        l = torch.arange(L).to(x)
        v, x, l = _broadcast_dims(v, x, l)
        v = _c2r(v)
        x = _c2r(x)
        l = _c2r(l)

        r = vandermonde_mult(v, x, l, backend="GPU")
        return 2 * _r2c(r).real

    def log_vandermonde_transpose(u, v, x, L):
        """
        u: ... H L
        v: ... H N
        x: ... H N
        Returns: ... H N

        V = Vandermonde(a, L) : (H N L)
        contract_L(V * u * v)
        """
        expr = "ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))"
        vandermonde_mult = Genred(
            expr,
            [
                "u = Vj(2)",
                "v = Vi(2)",
                "x = Vi(2)",
                "l = Vj(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )

        l = torch.arange(L).to(x)
        u, v, x, l = _broadcast_dims(u, v, x, l)
        u = _c2r(u)
        v = _c2r(v)
        x = _c2r(x)
        l = _c2r(l)

        r = vandermonde_mult(u, v, x, l, backend="GPU")
        return _r2c(r)

except ImportError:
    has_pykeops = False
    if not has_cauchy_extension:
        log.warning(
            "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency."
        )

        def cauchy_naive(v, z, w):
            """
            v, w: (..., N)
            z: (..., L)
            returns: (..., L)
            """
            cauchy_matrix = v.unsqueeze(-1) / (
                z.unsqueeze(-2) - w.unsqueeze(-1)
            )  # (... N L)
            return torch.sum(cauchy_matrix, dim=-2)

    # Vandermonde functions
    log.warning(
        "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency."
    )

    def log_vandermonde(v, x, L):
        """
        v: (..., N)
        x: (..., N)
        returns: (..., L) 2 * Re(sum_n v_n * exp(x_n * l))
        """
        vandermonde_matrix = torch.exp(
            x.unsqueeze(-1) * torch.arange(L).to(x)
        )  # (... N L)
        vandermonde_prod = contract(
            "... n, ... n l -> ... l", v, vandermonde_matrix
        )  # (... L)
        return 2 * vandermonde_prod.real

    def log_vandermonde_transpose(u, v, x, L):
        vandermonde_matrix = torch.exp(
            x.unsqueeze(-1) * torch.arange(L).to(x)
        )  # (... N L)
        vandermonde_prod = contract(
            "... l, ... n, ... n l -> ... n",
            u.to(x),
            v.to(x),
            vandermonde_matrix,
        )  # (... L)
        return vandermonde_prod


def _conj(x):
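    """Concatenate a complex tensor with its conjugate along the last dimension."""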
    return torch.cat([x, x.conj()], dim=-1)


_c2r = torch.view_as_real
_r2c = torch.view_as_complex
if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10):

    def _resolve_conj(x):
        return x.conj().resolve_conj()

else:

    def _resolve_conj(x):
        return x.conj()


""" Simple nn.Module components """


def Activation(activation=None, dim=-1):
    if activation in [None, "id", "identity", "linear"]:
        return nn.Identity()
    elif activation == "tanh":
        return nn.Tanh()
    elif activation == "relu":
        return nn.ReLU()
    elif activation == "gelu":
        return nn.GELU()
    elif activation in ["swish", "silu"]:
        return nn.SiLU()
    elif activation == "glu":
        return nn.GLU(dim=dim)
    elif activation == "sigmoid":
        return nn.Sigmoid()
    else:
        raise NotImplementedError(
            "hidden activation '{}' is not implemented".format(activation)
        )


def LinearActivation(
    d_input,
    d_output,
    bias=True,
    transposed=False,
    activation=None,
    activate=False,  # Apply activation as part of this module
    **kwargs,
):
    """Returns a linear nn.Module with control over axes order, initialization, and activation"""

    # Construct core module
    linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear
    if activation == "glu":
        d_output *= 2
    linear = linear_cls(d_input, d_output, bias=bias, **kwargs)

    if activate and activation is not None:
        activation = Activation(activation, dim=-2 if transposed else -1)
        linear = nn.Sequential(linear, activation)
    return linear


class DropoutNd(nn.Module):
    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError(
                "dropout probability has to be in [0, 1), "
                "but got {}".format(p)
            )
        self.p = p
        self.tie = tie
        self.transposed = transposed
        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)"""
        if self.training:
            if not self.transposed:
                X = rearrange(X, "b d ... -> b ... d")
            mask_shape = (
                X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape
            )
            mask = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p
            X = X * mask * (1.0 / (1 - self.p))
            if not self.transposed:
                X = rearrange(X, "b ... d -> b d ...")
            return X
        return X


""" Misc functional utilities """


def power(L, A, v=None):
    """Compute A^L and the scan sum_i A^i v_i

    A: (..., N, N)
    v: (..., N, L)
    """

    I = torch.eye(A.shape[-1]).to(A)  # , dtype=A.dtype, device=A.device)

    powers = [A]
    l = 1
    while True:
        if L % 2 == 1:
            I = powers[-1] @ I
        L //= 2
        if L == 0:
            break
        l *= 2
        powers.append(powers[-1] @ powers[-1])

    if v is None:
        return I

    # Invariants:
    # powers[-1] := A^l
    # l := largest po2 at most L

    # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A
    # We do this reverse divide-and-conquer for efficiency reasons:
    # 1) it involves fewer padding steps for non-po2 L
    # 2) it involves more contiguous arrays

    # Take care of edge case for non-po2 arrays
    # Note that this initial step is a no-op for the case of power of 2 (l == L)
    k = v.size(-1) - l
    v_ = powers.pop() @ v[..., l:]
    v = v[..., :l]
    v[..., :k] = v[..., :k] + v_

    # Handle reduction for power of 2
    while v.size(-1) > 1:
        v = rearrange(v, "... (z l) -> ... z l", z=2)
        v = v[..., 0, :] + powers.pop() @ v[..., 1, :]
    return I, v.squeeze(-1)


""" HiPPO utilities """


def transition(measure, N):
    """A, B transition matrices for different measures"""
    # Legendre (translated)
    if measure == "legt":
        Q = np.arange(N, dtype=np.float64)
        R = (2 * Q + 1) ** 0.5
        j, i = np.meshgrid(Q, Q)
        A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :]
        B = R[:, None]
        A = -A

        # Halve again for timescale correctness
        A *= 0.5
        B *= 0.5
    # Legendre (scaled)
    elif measure == "legs":
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        B = (
            B.copy()
        )  # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
    elif measure == "legsd":
        # Essentially equivalent to S4D-LegS
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        B = (
            B.copy()
        )  # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
        A += 0.5 * B * B[None, :, 0]
        B = B / 2.0
    elif measure in ["fourier_diag", "foud"]:
        # Essentially equivalent to S4D-Lin
        freqs = np.arange(N // 2)
        d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1]
        A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        A = A - 0.5 * np.eye(N)
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1
        B = B[:, None]
    elif measure in ["fourier", "fout"]:
        freqs = np.arange(N // 2)
        d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
        A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1

        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - B[:, None] * B[None, :]
        B = B[:, None]
    else:
        raise NotImplementedError

    return A, B


def rank_correction(measure, N, rank=1, dtype=torch.float):
    """Return low-rank matrix L such that A + L is normal"""

    if measure == "legs":
        assert rank >= 1
        P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(
            0
        )  # (1 N)
    elif measure == "legt":
        assert rank >= 2
        P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype))  # (N)
        P0 = P.clone()
        P0[0::2] = 0.0
        P1 = P.clone()
        P1[1::2] = 0.0
        P = torch.stack([P0, P1], dim=0)  # (2 N)
        P *= 2 ** (
            -0.5
        )  # Halve the rank correction just like the original matrix was halved
    elif measure in ["fourier", "fout"]:
        P = torch.zeros(N)
        P[0::2] = 2**0.5
        P[0] = 1
        P = P.unsqueeze(0)
    elif measure in ["fourier_diag", "foud", "legsd"]:
        P = torch.zeros(1, N, dtype=dtype)
    else:
        raise NotImplementedError

    d = P.size(0)
    if rank > d:
        P = torch.cat(
            [P, torch.zeros(rank - d, N, dtype=dtype)], dim=0
        )  # (rank N)
    return P


def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True):
    """Return w, p, q, V, B such that
    (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V
    i.e. A = V[w - p q^*]V^*, B = V B
    """
    assert dtype == torch.float or dtype == torch.double
    cdtype = torch.cfloat if dtype == torch.float else torch.cdouble

    A, B = transition(measure, N)
    A = torch.as_tensor(A, dtype=dtype)  # (N, N)
    B = torch.as_tensor(B, dtype=dtype)[:, 0]  # (N,)

    P = rank_correction(measure, N, rank=rank, dtype=dtype)  # (r N)
    AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3)

    # We require AP to be nearly skew-symmetric
    _A = AP + AP.transpose(-1, -2)
    if (
        err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N
    ) > 1e-5:  # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5):
        print("WARNING: HiPPO matrix not skew symmetric", err)

    # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately
    # Imaginary part can use eigh instead of eig
    w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True)

    # Diagonalize in double precision
    if diagonalize_precision:
        AP = AP.to(torch.double)
    w_im, V = torch.linalg.eigh(AP * -1j)  # (..., N) (..., N, N)
    if diagonalize_precision:
        w_im, V = w_im.to(cdtype), V.to(cdtype)
    w = w_re + 1j * w_im
    # Check: V w V^{-1} = A
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))

    # Only keep half of each conjugate pair
    _, idx = torch.sort(w.imag)
    w_sorted = w[idx]
    V_sorted = V[:, idx]

    # There is an edge case when eigenvalues can be 0, which requires some machinery to handle
    # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case)
    V = V_sorted[:, : N // 2]
    w = w_sorted[: N // 2]
    assert (
        w[-2].abs() > 1e-4
    ), "Only 1 zero eigenvalue allowed in diagonal part of A"
    if w[-1].abs() < 1e-4:
        V[:, -1] = 0.0
        V[0, -1] = 2**-0.5
        V[1, -1] = 2**-0.5 * 1j

    _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)
    if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5:
        print(
            "Warning: Diagonalization of A matrix not numerically precise - error",
            err,
        )
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))

    V_inv = V.conj().transpose(-1, -2)

    B = contract("ij, j -> i", V_inv, B.to(V))  # V^* B
    P = contract("ij, ...j -> ...i", V_inv, P.to(V))  # V^* P

    return w, P, B, V


def dplr(
    scaling,
    N,
    rank=1,
    H=1,
    dtype=torch.float,
    real_scale=1.0,
    imag_scale=1.0,
    random_real=False,
    random_imag=False,
    normalize=False,
    diagonal=True,
    random_B=False,
):
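    """Construct a diagonal-plus-low-rank (DPLR) SSM initialization.

    Returns (w, P, B, V): complex eigenvalues w = -real_part + 1j * imag_part,
    whose imaginary parts follow the chosen `scaling` scheme, a low-rank
    correction P (zeroed when `diagonal=True`), the input vector B, and V
    (only used for testing).
    """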
    assert dtype == torch.float or dtype == torch.double
    dtype = torch.cfloat if dtype == torch.float else torch.cdouble

    pi = torch.tensor(math.pi)
    if random_real:
        real_part = torch.rand(H, N // 2)
    else:
        real_part = 0.5 * torch.ones(H, N // 2)
    if random_imag:
        imag_part = N // 2 * torch.rand(H, N // 2)
    else:
        imag_part = repeat(torch.arange(N // 2), "n -> h n", h=H)

    real_part = real_scale * real_part
    if scaling == "random":
        imag_part = torch.randn(H, N // 2)
    elif scaling == "real":
        imag_part = 0 * imag_part
        real_part = 1 + repeat(torch.arange(N // 2), "n -> h n", h=H)
    elif scaling in ["linear", "lin"]:
        imag_part = pi * imag_part
    elif scaling in [
        "inverse",
        "inv",
    ]:  # Based on asymptotics of the default HiPPO matrix
        imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1)
    elif scaling in ["inverse2", "inv2"]:
        imag_part = 1 / pi * N * (N / (1 + imag_part) - 1)
    elif scaling in ["quadratic", "quad"]:
        imag_part = 1 / pi * (1 + 2 * imag_part) ** 2
    elif scaling in ["legs", "hippo"]:
        w, _, _, _ = nplr("legsd", N)
        imag_part = w.imag

    else:
        raise NotImplementedError
    imag_part = imag_scale * imag_part
    w = -real_part + 1j * imag_part

    # Initialize B
    if random_B:
        B = torch.randn(H, N // 2, dtype=dtype)
    else:
        B = torch.ones(H, N // 2, dtype=dtype)

    if normalize:
        norm = (
            -B / w
        )  # (H, N) # Result if you integrate the kernel with constant 1 function
        zeta = 2 * torch.sum(
            torch.abs(norm) ** 2, dim=-1, keepdim=True
        )  # Variance with a random C vector
        B = B / zeta**0.5

    P = torch.randn(rank, H, N // 2, dtype=dtype)
    if diagonal:
        P = P * 0.0
    V = torch.eye(N, dtype=dtype)[:: N // 2]  # Only used in testing
    V = repeat(V, "n m -> h n m", h=H)

    return w, P, B, V


def ssm(measure, N, R, H, **ssm_args):
    """Dispatcher to create single SSM initialization

    N: state size
    R: rank (for DPLR parameterization)
    H: number of independent SSM copies
    """

    if measure == "dplr":
        w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args)
    elif measure.startswith("diag"):
        args = measure.split("-")
        assert args[0] == "diag" and len(args) > 1
        scaling = args[1]
        w, P, B, V = dplr(
            scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args
        )
    else:
        w, P, B, V = nplr(measure, N, R, **ssm_args)
        w = repeat(w, "n -> s n", s=H)
        P = repeat(P, "r n -> r s n", s=H)
        B = repeat(B, "n -> s n", s=H)
        V = repeat(V, "n m -> s n m", s=H)
    return w, P, B, V
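

# Sketch (illustrative, not in the original file): how the measure string is
# dispatched. "diag-<scaling>" routes to dplr(); other measures route to
# nplr() and are then tiled across the H copies.
def _example_ssm_dispatch(N=16, H=4):
    w, P, B, V = ssm("diag-lin", N, R=1, H=H)  # -> dplr(scaling="lin", ...)
    assert w.shape == (H, N // 2)
    return w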


combinations = {
    "hippo": ["legs", "fourier"],
    "diag": ["diag-inv", "diag-lin"],
    "all": ["legs", "fourier", "diag-inv", "diag-lin"],
}


def combination(measures, N, R, S, **ssm_args):
    if isinstance(measures, str):
        measures = (
            combinations[measures] if measures in combinations else [measures]
        )

    assert (
        S % len(measures) == 0
    ), f"{S} independent trainable SSM copies must be a multiple of {len(measures)} different measures"
    w, P, B, V = zip(
        *[
            ssm(measure, N, R, S // len(measures), **ssm_args)
            for measure in measures
        ]
    )
    w = torch.cat(w, dim=0)  # (S N)
    P = torch.cat(P, dim=1)  # (R S N)
    B = torch.cat(B, dim=0)  # (S N)
    V = torch.cat(V, dim=0)  # (S N N)
    return w, P, B, V
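

# Sketch (illustrative, not in the original file): "hippo" expands to
# ["legs", "fourier"], so the S trainable copies are split evenly between the
# two measures and concatenated along the copy dimension.
def _example_combination(N=16, S=4):
    w, P, B, V = combination("hippo", N, R=1, S=S)
    assert w.shape == (S, N // 2)  # S/2 legs copies + S/2 fourier copies
    return w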


class OptimModule(nn.Module):
    """Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters"""

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay"""

        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {"weight_decay": 0.0}
            if lr is not None:
                optim["lr"] = lr
            setattr(getattr(self, name), "_optim", optim)
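

# Sketch of how the `_optim` attribute set above could be consumed by the
# training script (illustrative only; the real hook lives outside this file):
# parameters carrying `_optim` get their own optimizer param groups.
def _example_optimizer_param_groups(module: nn.Module, base_lr: float = 1e-3):
    default, special = [], []
    for p in module.parameters():
        (special if hasattr(p, "_optim") else default).append(p)
    groups = [{"params": default}]
    groups += [{"params": [p], **p._optim} for p in special]
    return torch.optim.AdamW(groups, lr=base_lr)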


class SSKernelNPLR(OptimModule):
    """Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR)"""

    @torch.no_grad()
    def _setup_C(self, L):
        """Construct C~ from C

        Two modes are supported: go directly to length L if self.L is 1, or length is doubled
        """

        if self.L.item() == 0:
            if self.verbose:
                log.info(f"S4: Initializing kernel to length {L}")
            double_length = False
        elif L > self.L.item():
            if self.verbose:
                log.info(
                    f"S4: Doubling length from L = {self.L.item()} to {2*self.L.item()}"
                )
            double_length = True
            L = self.L.item()  # Convenience for the math below
        else:
            return

        C = _r2c(self.C)
        dA, _ = self._setup_state()
        dA_L = power(L, dA)
        # Multiply C by I - dA_L
        C_ = _conj(C)
        prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_)
        if double_length:
            prod = -prod  # Multiply by I + dA_L instead
        C_ = C_ - prod
        C_ = C_[..., : self.N]  # Take conjugate pairs again
        self.C.copy_(_c2r(C_))

        self.L = (
            2 * self.L if double_length else self.L + L
        )  # Preserve type/device

    def _omega(self, L, dtype, device, cache=True):
        """Calculate (and cache) FFT nodes and their "unprocessed" version with the bilinear transform
        This should be called everytime the internal length self.L changes"""

        # Use cached if available
        if (
            cache
            and hasattr(self, "omega")
            and self.omega.size(-1) == L // 2 + 1
        ):
            return self.omega, self.z

        omega = torch.tensor(
            np.exp(-2j * np.pi / (L)), dtype=dtype, device=device
        )  # \omega_{2L}
        omega = omega ** torch.arange(0, L // 2 + 1, device=device)
        z = 2 * (1 - omega) / (1 + omega)
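        # For |omega| = 1, (1 - omega) / (1 + omega) is purely imaginary, so z
        # lies on the imaginary axis: these are the bilinear-transform images
        # of the FFT nodes at which the transfer function is evaluated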

        # Cache if necessary
        if cache:
            self.omega = omega
            self.z = z
        return omega, z

    def __init__(
        self,
        w,
        P,
        B,
        C,
        log_dt,
        L=None,  # starting/maximum length of kernel
        lr=None,
        verbose=False,
        keops=False,
        real_type="exp",  # ['none' | 'exp' | 'relu' | sigmoid']
        real_tolerance=1e-3,
        bandlimit=None,
    ):
        """
        L: Maximum length; this module computes an SSM kernel of length L
        A is represented by diag(w) - PP^*
        w: (S, N) diagonal part
        P: (R, S, N) low-rank part

        B: (S, N)
        C: (C, H, N)
        dt: (H) timescale per feature
        lr: [dict | float | None] hook to set lr of special parameters (A, B, dt)

        Dimensions:
        N (or d_state): state size
        H (or d_model): total SSM copies
        S (or n_ssm): number of trainable copies of (A, B, dt); must divide H
        R (or rank): rank of low-rank part
        C (or channels): system is 1-dim to C-dim

        The forward pass of this Module returns a tensor of shape (C, H, L)

        Note: tensor shape N here denotes half the true state size, because of conjugate symmetry
        """

        super().__init__()
        self.verbose = verbose
        self.keops = keops
        self.bandlimit = bandlimit
        self.real_type = real_type
        self.real_tolerance = real_tolerance

        # Rank of low-rank correction
        self.rank = P.shape[-3]
        assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1)
        self.H = log_dt.size(-1)
        self.N = w.size(-1)

        # Check different SSM inits
        assert w.size(-2) == P.size(-2) == B.size(-2)  # n_ssm
        assert self.H % w.size(0) == 0
        self.n_ssm = w.size(0)
        self.repeat = self.H // w.size(
            0
        )  # Each trainable SSM needs to be duplicated this many times

        # Broadcast everything to correct shapes
        C = C.expand(
            torch.broadcast_shapes(C.shape, (1, self.H, self.N))
        )  # (C, H, N)
        B = B.unsqueeze(0)  # (1, 1, N)

        # Register parameters
        self.C = nn.Parameter(_c2r(_resolve_conj(C)))
        if lr is None or isinstance(lr, float):
            lr_dict = {}
        else:
            lr_dict, lr = lr, None
        self.register("log_dt", log_dt, lr_dict.get("dt", lr))
        self.register("B", _c2r(B), lr_dict.get("B", lr))
        self.register("P", _c2r(P), lr_dict.get("A", lr))
        self.register("inv_w_real", self._w_init(w.real), lr_dict.get("A", lr))
        self.register("w_imag", w.imag, lr_dict.get("A", lr))

        self.l_max = L
        self.register_buffer("L", torch.tensor(0))  # Internal length

    def _w_init(self, w_real):
        w_real = torch.clamp(w_real, max=-self.real_tolerance)
        if self.real_type == "none":
            return -w_real
        elif self.real_type == "exp":
            return torch.log(
                -w_real
            )  # Some of the HiPPO methods have real part 0
        elif self.real_type == "relu":
            return -w_real
        elif self.real_type == "sigmoid":
            return torch.logit(-w_real)
        elif self.real_type == "softplus":
            return torch.log(torch.exp(-w_real) - 1)
        else:
            raise NotImplementedError

    def _w(self):
        # Get the internal w (diagonal) parameter
        if self.real_type == "none":
            w_real = -self.inv_w_real
        elif self.real_type == "exp":
            w_real = -torch.exp(self.inv_w_real)
        elif self.real_type == "relu":
            w_real = -F.relu(self.inv_w_real)
        elif self.real_type == "sigmoid":
            w_real = -F.sigmoid(self.inv_w_real)
        elif self.real_type == "softplus":
            w_real = -F.softplus(self.inv_w_real)
        else:
            raise NotImplementedError
        w = w_real + 1j * self.w_imag
        return w

    def forward(self, state=None, rate=1.0, L=None):
        """
        state: (B, H, N) initial state
        rate: sampling rate factor
        L: target length

        returns:
        (C, H, L) convolution kernel (generally C=1)
        (B, H, L) output from initial state
        """

        # Initialize C~ if necessary (done in forward pass so it's on the correct device)
        if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:
            self._setup_C(self.l_max)

        # Handle sampling rate logic
        # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate
        if L is None:
            L = round(self.L.item() / rate)

        # Increase the internal length if needed
        continuous_L = round(rate * L)
        while continuous_L > self.L.item():
            self._setup_C(continuous_L)
        discrete_L = round(self.L.item() / rate)

        dt = torch.exp(self.log_dt) * rate
        B = _r2c(self.B)
        C = _r2c(self.C)
        P = _r2c(self.P)
        Q = P.conj()
        w = self._w()  # (n_ssm, N)

        # Address bandlimiting
        if self.bandlimit is not None:
            freqs = w.imag.abs() / (2 * math.pi)  # (H, N)
            freqs = dt[:, None] / rate * freqs  # (H, N)
            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
            C = C * mask

        # Get FFT nodes of right length
        omega, z = self._omega(
            discrete_L, dtype=w.dtype, device=w.device, cache=(rate == 1.0)
        )

        # Broadcast parameters to same hidden features H
        B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat)
        P = repeat(P, "r t n -> r (v t) n", v=self.repeat)
        Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat)
        w = repeat(w, "t n -> (v t) n", v=self.repeat)

        # Augment B
        if state is not None:
            # Have to "unbilinear" the state to put it into the same "type" as B
            # Compute 1/dt * (I + dt/2 A) @ state

            # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way
            s = _conj(state) if state.size(-1) == self.N else state  # (B H N)
            sA = s * _conj(w) - contract(  # (B H N)
                "bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P)
            )
            s = s / dt.unsqueeze(-1) + sA / 2
            s = s[..., : self.N]

            B = torch.cat([s, B], dim=-3)  # (B+1, H, N)

        # Incorporate dt into A
        w = w * dt.unsqueeze(-1)  # (H N)

        # Stack B and p, C and q for convenient batching
        B = torch.cat([B, P], dim=-3)  # (B+1+R, H, N)
        C = torch.cat([C, Q], dim=-3)  # (C+R, H, N)

        # Incorporate B and C batch dimensions
        v = B.unsqueeze(-3) * C.unsqueeze(-4)  # (B+1+R, C+R, H, N)

        # Calculate resolvent at omega
        if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops:
            r = cauchy_mult(v, z, w, symmetric=True)
        elif has_pykeops:
            r = cauchy_conj(v, z, w)
        else:
            r = cauchy_naive(v, z, w)
        r = r * dt[None, None, :, None]  # (B+1+R, C+R, H, L)

        # Low-rank Woodbury correction
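        # With Lambda_z = z - diag(w), the resolvent (Lambda_z + P Q^*)^{-1} is
        # expanded via the Woodbury identity, so the kernel reduces to the
        # Cauchy dot products r[i, j] = sum_n v[i, j, n] / (z - w_n) above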
        if self.rank == 1:
            k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (
                1 + r[-1:, -1:, :, :]
            )
        elif self.rank == 2:
            r00 = r[: -self.rank, : -self.rank, :, :]
            r01 = r[: -self.rank, -self.rank :, :, :]
            r10 = r[-self.rank :, : -self.rank, :, :]
            r11 = r[-self.rank :, -self.rank :, :, :]
            det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[
                :1, 1:, :, :
            ] * r11[1:, :1, :, :]
            s = (
                r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :]
                + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :]
                - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :]
                - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :]
            )
            s = s / det
            k_f = r00 - s
        else:
            r00 = r[: -self.rank, : -self.rank, :, :]
            r01 = r[: -self.rank, -self.rank :, :, :]
            r10 = r[-self.rank :, : -self.rank, :, :]
            r11 = r[-self.rank :, -self.rank :, :, :]
            r11 = rearrange(r11, "a b h n -> h n a b")
            r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11)
            r11 = rearrange(r11, "h n a b -> a b h n")
            k_f = r00 - torch.einsum(
                "i j h n, j k h n, k l h n -> i l h n", r01, r11, r10
            )

        # Final correction for the bilinear transform
        k_f = k_f * 2 / (1 + omega)

        # Move from frequency to coefficients
        k = torch.fft.irfft(k_f, n=discrete_L)  # (B+1, C, H, L)

        # Truncate to target length
        k = k[..., :L]

        if state is not None:
            k_state = k[:-1, :, :, :]  # (B, C, H, L)
        else:
            k_state = None
        k_B = k[-1, :, :, :]  # (C H L)

        return k_B, k_state

    @torch.no_grad()
    def _setup_linear(self):
        """Create parameters that allow fast linear stepping of state"""
        w = self._w()
        B = _r2c(self.B)  # (H N)
        P = _r2c(self.P)
        Q = P.conj()

        # Repeat w shape properly
        B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat)
        P = repeat(P, "r t n -> r (v t) n", v=self.repeat)
        Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat)
        w = repeat(w, "t n -> (v t) n", v=self.repeat)

        # Prepare Linear stepping
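        # Bilinear step: x' = (2/dt - A)^{-1} [(2/dt + A) x + 2 B u], with
        # A = diag(w) - P Q^*. The inverse is applied through a rank-R Woodbury
        # solve, using the diagonal D = (2/dt - w)^{-1} and the small matrix R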
        dt = torch.exp(self.log_dt)
        D = (2.0 / dt.unsqueeze(-1) - w).reciprocal()  # (H, N)
        R = (
            torch.eye(self.rank, dtype=w.dtype, device=w.device)
            + 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real
        )  # (H R R)
        Q_D = rearrange(Q * D, "r h n -> h r n")
        try:
            R = torch.linalg.solve(R, Q_D)  # (H R N)
        except Exception:
            R = torch.tensor(
                np.linalg.solve(
                    R.to(Q_D).contiguous().detach().cpu(),
                    Q_D.contiguous().detach().cpu(),
                )
            ).to(Q_D)
        R = rearrange(R, "h r n -> r h n")

        self.step_params = {
            "D": D,  # (H N)
            "R": R,  # (R H N)
            "P": P,  # (R H N)
            "Q": Q,  # (R H N)
            "B": B,  # (1 H N)
            "E": 2.0 / dt.unsqueeze(-1) + w,  # (H N)
        }

    def _step_state_linear(self, u=None, state=None):
        """
        Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization.

        Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster

        u: (H) input
        state: (H, N/2) state with conjugate pairs
          Optionally, the state can have last dimension N
        Returns: same shape as state
        """
        C = _r2c(self.C)  # View used for dtype/device

        if u is None:  # Special case used to find dA
            u = torch.zeros(self.H, dtype=C.dtype, device=C.device)
        if state is None:  # Special case used to find dB
            state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device)

        step_params = self.step_params.copy()
        if (
            state.size(-1) == self.N
        ):  # Only store half of the conjugate pairs; should be true by default
            # There should be a slightly faster way using conjugate symmetry
            def contract_fn(p, x, y):
                return contract(
                    "r h n, r h m, ... h m -> ... h n",
                    _conj(p),
                    _conj(x),
                    _conj(y),
                )[
                    ..., : self.N
                ]  # inner outer product

        else:
            assert state.size(-1) == 2 * self.N
            step_params = {k: _conj(v) for k, v in step_params.items()}

            # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping
            def contract_fn(p, x, y):
                return contract(
                    "r h n, r h m, ... h m -> ... h n", p, x, y
                )  # inner outer product

        D = step_params["D"]  # (H N)
        E = step_params["E"]  # (H N)
        R = step_params["R"]  # (R H N)
        P = step_params["P"]  # (R H N)
        Q = step_params["Q"]  # (R H N)
        B = step_params["B"]  # (1 H N)

        new_state = E * state - contract_fn(P, Q, state)  # (B H N)
        new_state = new_state + 2.0 * B * u.unsqueeze(-1)  # (B H N)
        new_state = D * (new_state - contract_fn(P, R, new_state))

        return new_state

    def _setup_state(self):
        """Construct dA and dB for discretized state equation"""

        # Construct dA and dB by using the stepping
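        # One linear step maps x -> dA x + dB u, so stepping the identity basis
        # with u = 0 reads off the columns of dA, and stepping the zero state
        # with u = 1 reads off dB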
        self._setup_linear()
        C = _r2c(
            self.C
        )  # Just returns a view that we use for finding dtype/device

        state = torch.eye(
            2 * self.N, dtype=C.dtype, device=C.device
        ).unsqueeze(
            -2
        )  # (N 1 N)
        dA = self._step_state_linear(state=state)
        dA = rearrange(dA, "n h m -> h m n")

        u = C.new_ones(self.H)
        dB = self._step_state_linear(u=u)
        dB = _conj(dB)
        dB = rearrange(dB, "1 h n -> h n")  # (H N)
        return dA, dB

    def _step_state(self, u, state):
        """Must be called after self.default_state() is used to construct an initial state!"""
        next_state = self.state_contraction(
            self.dA, state
        ) + self.input_contraction(self.dB, u)
        return next_state

    def _setup_step(self, mode="dense"):
        """Set up dA, dB, dC discretized parameters for stepping"""
        self.dA, self.dB = self._setup_state()

        # Calculate original C
        C = _conj(_r2c(self.C))  # (H C N)
        if self.L.item() == 0:
            dC = C
        else:
            # self.C represents C_tilde
            dA_L = power(self.L.item(), self.dA)
            I = torch.eye(self.dA.size(-1)).to(dA_L)

            dC = torch.linalg.solve(
                I - dA_L.transpose(-1, -2),
                C.unsqueeze(-1),
            ).squeeze(-1)
        self.dC = dC

        # Do special preprocessing for different step modes

        self._step_mode = mode
        if mode == "linear":
            # Linear case: special step function for the state, we need to handle output
            # use conjugate symmetry by default, which affects the output projection
            self.dC = 2 * self.dC[:, :, : self.N]
        elif mode == "diagonal":
            # Eigendecomposition of the A matrix
            L, V = torch.linalg.eig(self.dA)
            V_inv = torch.linalg.inv(V)
            # Check that the eigendecomposition is correct
            if self.verbose:
                print(
                    "Diagonalization error:",
                    torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA),
                )

            # Change the parameterization to diagonalize
            self.dA = L
            self.dB = contract("h n m, h m -> h n", V_inv, self.dB)
            self.dC = contract("h n m, c h n -> c h m", V, self.dC)

        elif mode == "dense":
            pass
        else:
            raise NotImplementedError(
                "NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}"
            )

    def default_state(self, *batch_shape):
        C = _r2c(self.C)
        N = C.size(-1)
        H = C.size(-2)

        # Cache the tensor contractions we will later do, for efficiency
        # These are put in this function because they depend on the batch size
        step_mode = getattr(
            self, "_step_mode", "dense"
        )  # Used in default_state, which is called without _setup_step() in forward_state()
        if step_mode != "linear":
            N *= 2

            if step_mode == "diagonal":
                self.state_contraction = contract_expression(
                    "h n, ... h n -> ... h n",
                    (H, N),
                    batch_shape + (H, N),
                )
            else:
                # Dense (quadratic) case: expand all terms
                self.state_contraction = contract_expression(
                    "h m n, ... h n -> ... h m",
                    (H, N, N),
                    batch_shape + (H, N),
                )

            self.input_contraction = contract_expression(
                "h n, ... h -> ... h n",
                (H, N),  # self.dB.shape
                batch_shape + (H,),
            )

        self.output_contraction = contract_expression(
            "c h n, ... h n -> ... c h",
            (C.shape[0], H, N),  # self.dC.shape
            batch_shape + (H, N),
        )

        state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device)
        return state

    def step(self, u, state):
        """Must have called self._setup_step() and created state with self.default_state() before calling this"""

        if self._step_mode == "linear":
            new_state = self._step_state_linear(u, state)
        else:
            new_state = self._step_state(u, state)
        y = self.output_contraction(self.dC, new_state)
        return y.real, new_state
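

# Numerical sanity sketch for the rank-1 Woodbury step in SSKernelNPLR.forward
# (illustrative, not part of the original module): with a diagonal D,
#   c^* (D + p q^*)^{-1} b  ==  r(c, b) - r(c, p) r(q, b) / (1 + r(q, p)),
# where r(x, y) = sum_n conj(x_n) y_n / D_n are the Cauchy-style dot products.
def _example_woodbury_rank1(N=8):
    D = torch.randn(N, dtype=torch.cfloat) + 2.0  # keep it well conditioned
    p, q, b, c = (torch.randn(N, dtype=torch.cfloat) for _ in range(4))
    lhs = c.conj() @ torch.linalg.solve(
        torch.diag(D) + torch.outer(p, q.conj()), b
    )

    def r(x, y):
        return (x.conj() * y / D).sum()

    rhs = r(c, b) - r(c, p) * r(q, b) / (1 + r(q, p))
    assert torch.allclose(lhs, rhs, atol=1e-4)
    return lhs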


class SSKernelDiag(OptimModule):
    """Version using (complex) diagonal state matrix (S4D)"""

    def __init__(
        self,
        A,
        B,
        C,
        log_dt,
        L=None,
        disc="bilinear",
        real_type="exp",
        lr=None,
        bandlimit=None,
    ):
        super().__init__()
        self.L = L
        self.disc = disc
        self.bandlimit = bandlimit
        self.real_type = real_type

        # Rank of low-rank correction
        assert A.size(-1) == C.size(-1)
        self.H = log_dt.size(-1)
        self.N = A.size(-1)
        assert A.size(-2) == B.size(-2)  # Number of independent SSMs trained
        assert self.H % A.size(-2) == 0
        self.n_ssm = A.size(-2)
        self.repeat = self.H // A.size(0)

        self.channels = C.shape[0]
        self.C = nn.Parameter(_c2r(_resolve_conj(C)))

        # Register parameters
        if lr is None or isinstance(lr, float):
            lr_dict = {}
        else:
            lr_dict, lr = lr, None

        self.register("log_dt", log_dt, lr_dict.get("dt", lr))
        self.register("B", _c2r(B), lr_dict.get("B", lr))
        self.register("inv_A_real", self._A_init(A.real), lr_dict.get("A", lr))
        self.register("A_imag", A.imag, lr_dict.get("A", lr))

    def _A_init(self, A_real):
        A_real = torch.clamp(A_real, max=-1e-4)
        if self.real_type == "none":
            return -A_real
        elif self.real_type == "exp":
            return torch.log(
                -A_real
            )  # Some of the HiPPO methods have real part 0
        elif self.real_type == "relu":
            return -A_real
        elif self.real_type == "sigmoid":
            return torch.logit(-A_real)
        elif self.real_type == "softplus":
            return torch.log(torch.exp(-A_real) - 1)
        else:
            raise NotImplementedError

    def _A(self):
        # Get the internal A (diagonal) parameter
        if self.real_type == "none":
            A_real = -self.inv_A_real
        elif self.real_type == "exp":
            A_real = -torch.exp(self.inv_A_real)
        elif self.real_type == "relu":
            # The JAX version seems to NaN if you allow 0's, although this code was fine without it
            A_real = -F.relu(self.inv_A_real) - 1e-4
        elif self.real_type == "sigmoid":
            A_real = -F.sigmoid(self.inv_A_real)
        elif self.real_type == "softplus":
            A_real = -F.softplus(self.inv_A_real)
        else:
            raise NotImplementedError
        A = A_real + 1j * self.A_imag
        return A

    def forward(self, L, state=None, rate=1.0, u=None):
        """
        state: (B, H, N) initial state
        rate: sampling rate factor
        L: target length

        returns:
        (C, H, L) convolution kernel (generally C=1)
        (B, H, L) output from initial state
        """

        dt = torch.exp(self.log_dt) * rate  # (H)
        C = _r2c(self.C)  # (C H N)
        A = self._A()  # (H N)

        B = _r2c(self.B)
        B = repeat(B, "t n -> 1 (v t) n", v=self.repeat)

        if self.bandlimit is not None:
            freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi)  # (H, N)
            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
            C = C * mask

        # Incorporate dt into A
        A = repeat(A, "t n -> (v t) n", v=self.repeat)
        dtA = A * dt.unsqueeze(-1)  # (H N)

        # Augment B with state
        if state is not None:
            s = state / dt.unsqueeze(-1)
            if self.disc == "bilinear":
                s = s * (1.0 + dtA / 2)
            elif self.disc == "zoh":
                s = s * dtA * dtA.exp() / (dtA.exp() - 1.0)
            B = torch.cat([s, B], dim=-3)  # (1+B H N)

        C = (B[:, None, :, :] * C).view(-1, self.H, self.N)
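        # For diagonal A the kernel is a Vandermonde sum K[l] = sum_n C_n dA_n^l
        # (with B already folded into C above); each branch below evaluates it
        # in log space for numerical stability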
        if self.disc == "zoh":
            # Power up
            C = C * (torch.exp(dtA) - 1.0) / A
            K = log_vandermonde(C, dtA, L)  # (H L)
        elif self.disc == "bilinear":
            C = (
                C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
            )  # or * dtA / A
            dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
            K = log_vandermonde(C, dA.log(), L)
        elif self.disc == "dss":
            # Implementation from DSS, meant for the case when the real part of the eigenvalues can be positive
            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device)  # [H N L]
            A_gt_0 = A.real > 0  # [N]
            if A_gt_0.any():
                with torch.no_grad():
                    P_max = dtA * (A_gt_0 * (L - 1))  # [H N]
                P = P - P_max.unsqueeze(-1)  # [H N L]
            S = P.exp()  # [H N L]

            dtA_neg = dtA * (1 - 2 * A_gt_0)  # [H N]
            num = dtA_neg.exp() - 1  # [H N]
            den = (dtA_neg * L).exp() - 1  # [H N]

            # Inline reciprocal function for DSS logic
            x = den * A
            x_conj = _resolve_conj(x)
            r = x_conj / (x * x_conj + 1e-7)

            C = C * num * r  # [C H N]
            K = contract("chn,hnl->chl", C, S).float()
        else:
            assert False, f"{self.disc} not supported"

        K = K.view(-1, self.channels, self.H, L)  # (1+B C H L)
        if state is not None:
            K_state = K[:-1, :, :, :]  # (B C H L)
        else:
            K_state = None
        K = K[-1, :, :, :]  # (C H L)
        return K, K_state

    def _setup_step(self):
        # These methods are organized like this to be compatible with the NPLR kernel interface
        dt = torch.exp(self.log_dt)  # (H)
        B = _r2c(self.B)  # (H N)
        C = _r2c(self.C)  # (C H N)
        self.dC = C
        A = self._A()  # (H N)

        A = repeat(A, "t n -> (v t) n", v=self.repeat)
        B = repeat(B, "t n -> (v t) n", v=self.repeat)

        # Incorporate dt into A
        dtA = A * dt.unsqueeze(-1)  # (H N)
        if self.disc == "zoh":
            self.dA = torch.exp(dtA)  # (H N)
            self.dB = B * (torch.exp(dtA) - 1.0) / A  # (C H N)
        elif self.disc == "bilinear":
            self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
            self.dB = (
                B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
            )  # or * dtA / A

    def default_state(self, *batch_shape):
        C = _r2c(self.C)
        state = torch.zeros(
            *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device
        )
        return state

    def step(self, u, state):
        next_state = contract(
            "h n, b h n -> b h n", self.dA, state
        ) + contract("h n, b h -> b h n", self.dB, u)
        y = contract("c h n, b h n -> b c h", self.dC, next_state)
        return 2 * y.real, next_state

    def forward_state(self, u, state):
        self._setup_step()
        AL = self.dA ** u.size(-1)
        u = u.flip(-1).to(self.dA).contiguous()  # (B H L)
        v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))
        next_state = AL * state + v
        return next_state
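

# Illustrative sketch (not part of the original file): for a diagonal SSM the
# ZOH-discretized kernel is a Vandermonde sum, K[l] = sum_n C_n dB_n dA_n^l,
# which is what the log_vandermonde call in SSKernelDiag.forward evaluates in
# log space. A direct, unvectorized check:
def _example_diag_zoh_kernel(N=4, L=6, dt=0.1):
    A = -(torch.rand(N) + 0.1) + 1j * torch.randn(N)  # stable eigenvalues
    B = torch.ones(N, dtype=torch.cfloat)
    C = torch.randn(N, dtype=torch.cfloat)
    dA = torch.exp(dt * A)  # ZOH discretization
    dB = B * (torch.exp(dt * A) - 1.0) / A
    K = torch.stack([(C * dB * dA**l).sum() for l in range(L)])
    return 2 * K.real  # factor 2 from conjugate-pair symmetry, as in step()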


class SSKernel(nn.Module):
    """Wrapper around SSKernel parameterizations.

    The SSKernel is expected to support the interface
    forward()
    default_state()
    _setup_step()
    step()
    """

    def __init__(
        self,
        H,
        N=64,
        L=None,
        measure="legs",
        rank=1,
        channels=1,
        dt_min=0.001,
        dt_max=0.1,
        deterministic=False,
        lr=None,
        mode="nplr",
        n_ssm=None,
        verbose=False,
        measure_args={},
        **kernel_args,
    ):
        """State Space Kernel which computes the convolution kernel $\\bar{K}$

        H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config.
        N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doens't affect speed much.
        L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known.
        measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin)
        rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt"
        channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead
        dt_min, dt_max: min and max values for the step size dt (\Delta)
        mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing
        n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H
        lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameers (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters.
        """
        super().__init__()
        self.N = N
        self.H = H
        dtype, cdtype = torch.float, torch.cfloat
        self.channels = channels
        self.n_ssm = n_ssm if n_ssm is not None else H
        self.mode = mode
        self.verbose = verbose
        self.kernel_args = kernel_args

        # Generate dt
        if deterministic:
            # Deterministic init: evenly spaced in log space, matching the
            # random branch below (log_dt stores log(dt), not dt)
            log_dt = torch.linspace(math.log(dt_min), math.log(dt_max), H)
        else:
            log_dt = torch.rand(self.H, dtype=dtype) * (
                math.log(dt_max) - math.log(dt_min)
            ) + math.log(dt_min)

        # Compute the preprocessed representation
        w, P, B, V = combination(
            measure, self.N, rank, self.n_ssm, **measure_args
        )

        # Broadcast C to have H channels
        if deterministic:
            C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype)
            C[:, :, :1] = 1.0
            C = contract(
                "hmn, chn -> chm", V.conj().transpose(-1, -2), C
            )  # V^* C
            C = (
                repeat(C, "c t n -> c (v t) n", v=self.n_ssm // C.size(-2))
                .clone()
                .contiguous()
            )
        else:
            C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype)

        # Broadcast other parameters to have n_ssm copies
        assert (
            self.n_ssm % B.size(-2) == 0
            and self.n_ssm % P.size(-2) == 0
            and self.n_ssm % w.size(-2) == 0
        )
        # Broadcast tensors to n_ssm copies
        # These will be the parameters, so make sure tensors are materialized and contiguous
        B = (
            repeat(B, "t n -> (v t) n", v=self.n_ssm // B.size(-2))
            .clone()
            .contiguous()
        )
        P = (
            repeat(P, "r t n -> r (v t) n", v=self.n_ssm // P.size(-2))
            .clone()
            .contiguous()
        )
        w = (
            repeat(w, "t n -> (v t) n", v=self.n_ssm // w.size(-2))
            .clone()
            .contiguous()
        )

        if mode == "nplr":
            self.kernel = SSKernelNPLR(
                w,
                P,
                B,
                C,
                log_dt,
                L=L,
                lr=lr,
                verbose=verbose,
                **kernel_args,
            )
        elif mode == "diag":
            if not measure.startswith("diag"):
                log.warning(
                    "Diagonal kernel (S4D) activated but initialization is not intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or 'diag-legs' for the main variants, or 'diag' for a combination of S4D-Lin and S4D-Inv."
                )
            C = C * repeat(B, "t n -> (v t) n", v=H // self.n_ssm)
            self.kernel = SSKernelDiag(
                w,
                B,
                C,
                log_dt,
                L=L,
                lr=lr,
                **kernel_args,
            )
        else:
            raise NotImplementedError(f"{mode=} is not valid")

    def forward(self, state=None, L=None, rate=1.0):
        return self.kernel(state=state, L=L, rate=rate)

    @torch.no_grad()
    def forward_state(self, u, state):
        """Forward the state through a sequence, i.e. computes the state after passing chunk through SSM

        state: (B, H, N)
        u: (B, H, L)

        Returns: (B, H, N)
        """

        if hasattr(self.kernel, "forward_state"):
            return self.kernel.forward_state(u, state)

        dA, dB = self.kernel._setup_state()  # Construct dA, dB matrices
        # dA, dB = self.kernel.dA, self.kernel.dB  # (H N N), (H N)
SYMBOL INDEX (249 symbols across 20 files)

FILE: bin/guidance_experiment.py
  function load_model (line 32) | def load_model(config):
  function evaluate_guidance (line 54) | def evaluate_guidance(
  function main (line 126) | def main(config: dict, log_dir: str):

FILE: bin/refinement_experiment.py
  function load_model (line 41) | def load_model(config):
  function get_best_diffusion_step (line 63) | def get_best_diffusion_step(model: TSDiff, data_loader, device):
  function train_and_forecast_base_model (line 81) | def train_and_forecast_base_model(dataset, base_model_name, config):
  function forecast_guidance (line 124) | def forecast_guidance(
  function main (line 158) | def main(config: dict, log_dir: str):

FILE: bin/train_cond_model.py
  function create_model (line 33) | def create_model(config):
  function evaluate_conditional (line 50) | def evaluate_conditional(
  function main (line 101) | def main(config, log_dir):

FILE: bin/train_model.py
  function create_model (line 36) | def create_model(config):
  function evaluate_guidance (line 52) | def evaluate_guidance(
  function main (line 123) | def main(config, log_dir):

FILE: bin/tstr_experiment.py
  function load_model (line 44) | def load_model(config):
  function sample_synthetic (line 66) | def sample_synthetic(
  function sample_real (line 83) | def sample_real(
  function evaluate_tstr (line 108) | def evaluate_tstr(
  function train_and_evaluate (line 143) | def train_and_evaluate(
  function main (line 214) | def main(config: dict, log_dir: str, samples_path: str):

FILE: src/uncond_ts_diff/arch/backbones.py
  class SinusoidalPositionEmbeddings (line 11) | class SinusoidalPositionEmbeddings(nn.Module):
    method __init__ (line 12) | def __init__(self, dim):
    method forward (line 16) | def forward(self, time):
  class S4Layer (line 28) | class S4Layer(nn.Module):
    method __init__ (line 29) | def __init__(
    method forward (line 48) | def forward(self, x):
    method default_state (line 63) | def default_state(self, *args, **kwargs):
    method step (line 66) | def step(self, x, state, **kwargs):
  class S4Block (line 77) | class S4Block(nn.Module):
    method __init__ (line 78) | def __init__(self, d_model, dropout=0.0, expand=2, num_features=0):
    method forward (line 93) | def forward(self, x, t, features=None):
  function Conv1dKaiming (line 105) | def Conv1dKaiming(in_channels, out_channels, kernel_size):
  class BackboneModel (line 111) | class BackboneModel(nn.Module):
    method __init__ (line 112) | def __init__(
    method forward (line 155) | def forward(self, input, t, features=None):

FILE: src/uncond_ts_diff/arch/s4.py
  function get_logger (line 18) | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
  function _broadcast_dims (line 60) | def _broadcast_dims(*tensors):
  function cauchy_conj (line 68) | def cauchy_conj(v, z, w):
  function log_vandermonde (line 92) | def log_vandermonde(v, x, L):
  function log_vandermonde_transpose (line 114) | def log_vandermonde_transpose(u, v, x, L):
  function cauchy_naive (line 154) | def cauchy_naive(v, z, w):
  function log_vandermonde (line 170) | def log_vandermonde(v, x, L):
  function log_vandermonde_transpose (line 184) | def log_vandermonde_transpose(u, v, x, L):
  function _conj (line 197) | def _conj(x):
  function _resolve_conj (line 205) | def _resolve_conj(x):
  function _resolve_conj (line 210) | def _resolve_conj(x):
  function Activation (line 217) | def Activation(activation=None, dim=-1):
  function LinearActivation (line 238) | def LinearActivation(
  class DropoutNd (line 261) | class DropoutNd(nn.Module):
    method __init__ (line 262) | def __init__(self, p: float = 0.5, tie=True, transposed=True):
    method forward (line 277) | def forward(self, X):
  function power (line 296) | def power(L, A, v=None):
  function transition (line 345) | def transition(measure, N):
  function rank_correction (line 412) | def rank_correction(measure, N, rank=1, dtype=torch.float):
  function nplr (line 449) | def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=Tr...
  function dplr (line 518) | def dplr(
  function ssm (line 595) | def ssm(measure, N, R, H, **ssm_args):
  function combination (line 628) | def combination(measures, N, R, S, **ssm_args):
  class OptimModule (line 650) | class OptimModule(nn.Module):
    method register (line 653) | def register(self, name, tensor, lr=None):
  class SSKernelNPLR (line 667) | class SSKernelNPLR(OptimModule):
    method _setup_C (line 671) | def _setup_C(self, L):
    method _omega (line 707) | def _omega(self, L, dtype, device, cache=True):
    method __init__ (line 731) | def __init__(
    method _w_init (line 811) | def _w_init(self, w_real):
    method _w (line 828) | def _w(self):
    method forward (line 845) | def forward(self, state=None, rate=1.0, L=None):
    method _setup_linear (line 981) | def _setup_linear(self):
    method _step_state_linear (line 1022) | def _step_state_linear(self, u=None, state=None):
    method _setup_state (line 1078) | def _setup_state(self):
    method _step_state (line 1101) | def _step_state(self, u, state):
    method _setup_step (line 1108) | def _setup_step(self, mode="dense"):
    method default_state (line 1157) | def default_state(self, *batch_shape):
    method step (line 1199) | def step(self, u, state):
  class SSKernelDiag (line 1210) | class SSKernelDiag(OptimModule):
    method __init__ (line 1213) | def __init__(
    method _A_init (line 1254) | def _A_init(self, A_real):
    method _A (line 1271) | def _A(self):
    method forward (line 1289) | def forward(self, L, state=None, rate=1.0, u=None):
    method _setup_step (line 1368) | def _setup_step(self):
    method default_state (line 1390) | def default_state(self, *batch_shape):
    method step (line 1397) | def step(self, u, state):
    method forward_state (line 1404) | def forward_state(self, u, state):
  class SSKernel (line 1413) | class SSKernel(nn.Module):
    method __init__ (line 1423) | def __init__(
    method forward (line 1548) | def forward(self, state=None, L=None, rate=1.0):
    method forward_state (line 1552) | def forward_state(self, u, state):
    method _setup_step (line 1582) | def _setup_step(self, **kwargs):
    method step (line 1590) | def step(self, u, state, **kwargs):
    method default_state (line 1594) | def default_state(self, *args, **kwargs):
  class S4 (line 1598) | class S4(nn.Module):
    method __init__ (line 1599) | def __init__(
    method forward (line 1721) | def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs):
    method setup_step (line 1807) | def setup_step(self, **kwargs):
    method step (line 1810) | def step(self, u, state):
    method default_state (line 1829) | def default_state(self, *batch_shape, device=None):
    method d_output (line 1835) | def d_output(self):

FILE: src/uncond_ts_diff/dataset.py
  function get_gts_dataset (line 15) | def get_gts_dataset(dataset_name):

FILE: src/uncond_ts_diff/metrics/linear_pred_score.py
  function linear_pred_score (line 20) | def linear_pred_score(

FILE: src/uncond_ts_diff/model/callback.py
  class GradNormCallback (line 21) | class GradNormCallback(Callback):
    method __init__ (line 22) | def __init__(self) -> None:
    method on_before_optimizer_step (line 25) | def on_before_optimizer_step(
    method grad_norm (line 36) | def grad_norm(self, parameters):
  class PredictiveScoreCallback (line 48) | class PredictiveScoreCallback(Callback):
    method __init__ (line 49) | def __init__(
    method _generate_real_samples (line 72) | def _generate_real_samples(
    method _generate_synth_samples (line 104) | def _generate_synth_samples(
    method on_train_epoch_end (line 117) | def on_train_epoch_end(self, trainer, pl_module):
  class EvaluateCallback (line 169) | class EvaluateCallback(Callback):
    method __init__ (line 170) | def __init__(
    method on_train_epoch_end (line 208) | def on_train_epoch_end(self, trainer, pl_module):

FILE: src/uncond_ts_diff/model/diffusion/_base.py
  class TSDiffBase (line 28) | class TSDiffBase(pl.LightningModule):
    method __init__ (line 29) | def __init__(
    method _extract_features (line 97) | def _extract_features(self, data):
    method configure_optimizers (line 100) | def configure_optimizers(self):
    method log (line 107) | def log(self, name, value, **kwargs):
    method get_logs (line 116) | def get_logs(self):
    method q_sample (line 121) | def q_sample(self, x_start, t, noise=None):
    method p_losses (line 137) | def p_losses(
    method p_sample (line 167) | def p_sample(self, x, t, t_index, features=None):
    method p_sample_ddim (line 187) | def p_sample_ddim(self, x, t, features=None, noise=None):
    method p_sample_genddim (line 209) | def p_sample_genddim(
    method sample (line 266) | def sample(self, noise, features=None):
    method fast_denoise (line 283) | def fast_denoise(self, xt, t, features=None, noise=None):
    method forward (line 294) | def forward(self, x, mask):
    method training_step (line 297) | def training_step(self, data, idx):
    method training_epoch_end (line 314) | def training_epoch_end(self, outputs):
    method validation_step (line 320) | def validation_step(self, data, idx):
    method validation_epoch_end (line 335) | def validation_epoch_end(self, outputs):

FILE: src/uncond_ts_diff/model/diffusion/tsdiff.py
  class TSDiff (line 13) | class TSDiff(TSDiffBase):
    method __init__ (line 14) | def __init__(
    method _extract_features (line 74) | def _extract_features(self, data):
    method sample_n (line 133) | def sample_n(
    method on_train_batch_end (line 156) | def on_train_batch_end(self, outputs, batch, batch_idx):
  function update_ema (line 161) | def update_ema(target_state_dict, source_state_dict, rate=0.99):

FILE: src/uncond_ts_diff/model/diffusion/tsdiff_cond.py
  class TSDiffCond (line 15) | class TSDiffCond(TSDiffBase):
    method __init__ (line 16) | def __init__(
    method _extract_features (line 73) | def _extract_features(self, data):
    method step (line 136) | def step(self, x, t, features, loss_mask):
    method training_step (line 158) | def training_step(self, data, idx):
    method validation_step (line 177) | def validation_step(self, data, idx):
    method forecast (line 199) | def forecast(self, observation, observation_mask, features=None):
    method forward (line 220) | def forward(
    method get_predictor (line 270) | def get_predictor(self, input_transform, batch_size=40, device=None):

FILE: src/uncond_ts_diff/model/linear/_estimator.py
  function stack (line 40) | def stack(data):
  function batchify (line 48) | def batchify(data: List[dict]):
  class LinearModel (line 54) | class LinearModel:
    method __init__ (line 55) | def __init__(self, weight, bias, scaler, num_parallel_samples=100) -> ...
    method _linear (line 62) | def _linear(self, x, A, b):
    method __call__ (line 65) | def __call__(self, x, mask):
  function _ (line 73) | def _(prediction_net, args) -> np.ndarray:
  class LinearPredictor (line 77) | class LinearPredictor(Predictor):
    method __init__ (line 78) | def __init__(
    method predict (line 95) | def predict(self, dataset: Dataset, num_samples: Optional[int] = None):
  class LinearEstimator (line 112) | class LinearEstimator(Estimator):
    method __init__ (line 149) | def __init__(
    method create_transformation (line 174) | def create_transformation(self) -> Transformation:
    method _create_instance_splitter (line 188) | def _create_instance_splitter(self, mode: str):
    method _create_training_samples (line 213) | def _create_training_samples(self, training_data) -> np.ndarray:
    method create_predictor (line 242) | def create_predictor(self, transformation, model):
    method train (line 252) | def train(

FILE: src/uncond_ts_diff/model/linear/_scaler.py
  class MeanScaler (line 8) | class MeanScaler:
    method __init__ (line 11) | def __init__(
    method __call__ (line 24) | def __call__(
  class NOPScaler (line 63) | class NOPScaler:
    method __init__ (line 68) | def __init__(self, axis: int, keepdims: bool = False):
    method __call__ (line 73) | def __call__(

FILE: src/uncond_ts_diff/predictor.py
  class PyTorchPredictorWGrads (line 12) | class PyTorchPredictorWGrads(PyTorchPredictor):
    method predict (line 13) | def predict(

FILE: src/uncond_ts_diff/sampler/_base.py
  function grad_fn (line 10) | def grad_fn(fn, x):
  function langevin_dynamics (line 16) | def langevin_dynamics(
  function leapfrog (line 67) | def leapfrog(
  function hmc (line 106) | def hmc(
  function linear_midpoint_em_step (line 150) | def linear_midpoint_em_step(
  function udld (line 161) | def udld(

FILE: src/uncond_ts_diff/sampler/observation_guidance.py
  class Guidance (line 23) | class Guidance(torch.nn.Module):
    method __init__ (line 26) | def __init__(
    method quantile_loss (line 47) | def quantile_loss(self, y_prediction, y_target):
    method energy_func (line 64) | def energy_func(self, y, t, observation, observation_mask, features):
    method score_func (line 79) | def score_func(self, y, t, observation, observation_mask, features):
    method scale_func (line 87) | def scale_func(self, y, t, base_scale):
    method guide (line 90) | def guide(self, observation, observation_mask, features, scale):
    method forward (line 93) | def forward(
    method get_predictor (line 171) | def get_predictor(self, input_transform, batch_size=40, device=None):
  class DDPMGuidance (line 182) | class DDPMGuidance(Guidance):
    method __init__ (line 183) | def __init__(
    method scale_func (line 203) | def scale_func(self, y, t, base_scale):
    method _reverse_diffusion (line 207) | def _reverse_diffusion(
    method guide (line 228) | def guide(self, observation, observation_mask, features, base_scale):
  class DDIMGuidance (line 234) | class DDIMGuidance(Guidance):
    method __init__ (line 237) | def __init__(
    method scale_func (line 264) | def scale_func(self, y, t, base_scale):
    method _get_timesteps (line 270) | def _get_timesteps(self):
    method _reverse_ddim (line 286) | def _reverse_ddim(
    method guide (line 322) | def guide(self, observation, observation_mask, features, base_scale):

FILE: src/uncond_ts_diff/sampler/refiner.py
  class Refiner (line 27) | class Refiner(torch.nn.Module):
    method __init__ (line 28) | def __init__(
    method quantile_loss (line 49) | def quantile_loss(self, y_prediction, y_target):
    method prior (line 66) | def prior(self, y_prediction, obs, obs_mask):
    method refine (line 79) | def refine(self, observation, observation_mask):
    method forward (line 82) | def forward(
    method get_predictor (line 190) | def get_predictor(self, input_transform, batch_size=40, device=None):
  class MostLikelyRefiner (line 201) | class MostLikelyRefiner(Refiner):
    method __init__ (line 202) | def __init__(
    method _most_likely (line 228) | def _most_likely(self, observation, observation_mask):
    method refine (line 254) | def refine(self, observation, observation_mask):
  class MCMCRefiner (line 258) | class MCMCRefiner(Refiner):
    method __init__ (line 261) | def __init__(
    method _mcmc (line 290) | def _mcmc(self, observation, observation_mask):
    method refine (line 355) | def refine(self, observation, observation_mask):

FILE: src/uncond_ts_diff/utils.py
  function filter_metrics (line 47) | def filter_metrics(metrics, select={"ND", "NRMSE", "mean_wQuantileLoss"}):
  function extract (line 51) | def extract(a, t, x_shape):
  function cosine_beta_schedule (line 57) | def cosine_beta_schedule(timesteps, s=0.008):
  function linear_beta_schedule (line 71) | def linear_beta_schedule(timesteps):
  function plot_train_stats (line 77) | def plot_train_stats(df: pd.DataFrame, y_keys=None, skip_first_epoch=True):
  function get_lags_for_freq (line 98) | def get_lags_for_freq(freq_str: str):
  function create_transforms (line 116) | def create_transforms(
  function create_splitter (line 189) | def create_splitter(past_length: int, future_length: int, mode: str = "t...
  function get_next_file_num (line 214) | def get_next_file_num(
  function str2bool (line 255) | def str2bool(v):
  function add_config_to_argparser (line 266) | def add_config_to_argparser(config: Dict, parser: ArgumentParser):
  class AddMeanAndStdFeature (line 280) | class AddMeanAndStdFeature(MapTransformation):
    method __init__ (line 282) | def __init__(
    method map_transform (line 292) | def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
  class ScaleAndAddMeanFeature (line 300) | class ScaleAndAddMeanFeature(MapTransformation):
    method __init__ (line 301) | def __init__(
    method map_transform (line 322) | def map_transform(self, data, is_train: bool):
  class ScaleAndAddMinMaxFeature (line 336) | class ScaleAndAddMinMaxFeature(MapTransformation):
    method __init__ (line 337) | def __init__(
    method map_transform (line 358) | def map_transform(self, data, is_train: bool):
  function descale (line 371) | def descale(data, scale, scaling_type):
  function predict_and_descale (line 381) | def predict_and_descale(predictor, dataset, num_samples, scaling_type):
  function to_dataframe_and_descale (line 420) | def to_dataframe_and_descale(input_label, scaling_type) -> pd.DataFrame:
  function make_evaluation_predictions_with_scaling (line 447) | def make_evaluation_predictions_with_scaling(
  class PairDataset (line 489) | class PairDataset(Dataset):
    method __init__ (line 490) | def __init__(self, x, y) -> None:
    method __getitem__ (line 496) | def __getitem__(self, index):
    method __len__ (line 499) | def __len__(self):
  class GluonTSNumpyDataset (line 503) | class GluonTSNumpyDataset:
    method __init__ (line 514) | def __init__(
    method __iter__ (line 520) | def __iter__(self):
    method __len__ (line 525) | def __len__(self):
  class MaskInput (line 529) | class MaskInput(MapTransformation):
    method __init__ (line 531) | def __init__(
    method map_transform (line 548) | def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
  class ConcatDataset (line 573) | class ConcatDataset:
    method __init__ (line 574) | def __init__(self, test_pairs, axis=-1) -> None:
    method _concat (line 578) | def _concat(self, test_pairs):
    method __iter__ (line 587) | def __iter__(self):
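
Several of the utilities listed above are standard diffusion-model helpers. For orientation, here are the conventional forms of the cosine noise schedule (Nichol & Dhariwal, 2021) and the extract gather helper, written as they commonly appear in DDPM codebases; the clipping constants are the customary ones, and this file's exact code may differ.

import math
import torch

def cosine_beta_schedule(timesteps, s=0.008):
    # Betas derived from a squared-cosine alpha-bar curve; the small
    # offset s keeps betas away from zero near t = 0.
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = (
        torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    )
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def extract(a, t, x_shape):
    # Gather per-timestep coefficients a[t] and reshape them to
    # (batch, 1, ..., 1) so they broadcast over samples of x_shape.
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
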
Condensed preview — 126 files, each showing path, character count, and a content snippet (full structured content: 317K chars).
[
  {
    "path": ".gitignore",
    "chars": 122,
    "preview": "__pycache__\nlightning_logs/\n.DS_Store\n*.egg-info\n/results/\n/ckpts/\n/saved_samples/\n.vscode/\n/sm_runs/\n/data/\n/checkpoint"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 309,
    "preview": "## Code of Conduct\nThis project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-condu"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 3160,
    "preview": "# Contributing Guidelines\n\nThank you for your interest in contributing to our project. Whether it's a bug report, new fe"
  },
  {
    "path": "LICENSE",
    "chars": 10142,
    "preview": "\n                                 Apache License\n                           Version 2.0, January 2004\n                  "
  },
  {
    "path": "NOTICE",
    "chars": 67,
    "preview": "Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n"
  },
  {
    "path": "README.md",
    "chars": 7128,
    "preview": "# TSDiff: An Unconditional Diffusion Model for Time Series\n\n[![preprint](https://img.shields.io/static/v1?label=arXiv&me"
  },
  {
    "path": "THIRD-PARTY-LICENSES.txt",
    "chars": 10475,
    "preview": "** state-spaces; version 1.0 -- https://github.com/HazyResearch/state-spaces\n \nApache License\nVersion 2.0, January 2004\n"
  },
  {
    "path": "bin/guidance_experiment.py",
    "chars": 6372,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport loggin"
  },
  {
    "path": "bin/refinement_experiment.py",
    "chars": 10481,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport json\ni"
  },
  {
    "path": "bin/train_cond_model.py",
    "chars": 8972,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport loggin"
  },
  {
    "path": "bin/train_model.py",
    "chars": 9507,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport loggin"
  },
  {
    "path": "bin/tstr_experiment.py",
    "chars": 10825,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom functool"
  },
  {
    "path": "configs/guidance/guidance_electricity.yaml",
    "chars": 304,
    "preview": "ckpt: dummy/path.ckpt\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_con"
  },
  {
    "path": "configs/guidance/guidance_exchange.yaml",
    "chars": 305,
    "preview": "ckpt: dummy/path.ckpt\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_c"
  },
  {
    "path": "configs/guidance/guidance_kdd_cup.yaml",
    "chars": 315,
    "preview": "ckpt: dummy/path.ckpt\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusi"
  },
  {
    "path": "configs/guidance/guidance_m4.yaml",
    "chars": 298,
    "preview": "ckpt: dummy/path.ckpt\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nfre"
  },
  {
    "path": "configs/guidance/guidance_solar.yaml",
    "chars": 298,
    "preview": "ckpt: dummy/path.ckpt\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nfr"
  },
  {
    "path": "configs/guidance/guidance_traffic.yaml",
    "chars": 299,
    "preview": "ckpt: dummy/path.ckpt\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\n"
  },
  {
    "path": "configs/guidance/guidance_uber_tlc.yaml",
    "chars": 303,
    "preview": "ckpt: dummy/path.ckpt\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_conf"
  },
  {
    "path": "configs/guidance/guidance_wiki.yaml",
    "chars": 303,
    "preview": "ckpt: dummy/path.ckpt\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config"
  },
  {
    "path": "configs/guidance.yaml",
    "chars": 535,
    "preview": "# Model & checkpoint parameters\ndataset: solar_nips\nfreq: H\ndevice: cuda:0\nckpt: ckpts/forecasting/solar_nips/649_.ckpt\n"
  },
  {
    "path": "configs/refinement/electricity_nips-deepar.yaml",
    "chars": 567,
    "preview": "base_model: deepar\nckpt: dummy/electricity_nips.ckpt\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ninit_s"
  },
  {
    "path": "configs/refinement/electricity_nips-linear.yaml",
    "chars": 608,
    "preview": "base_model: linear\nckpt: dummy/electricity_nips.ckpt\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffus"
  },
  {
    "path": "configs/refinement/electricity_nips-seasonal_naive.yaml",
    "chars": 575,
    "preview": "base_model: seasonal_naive\nckpt: dummy/electricity_nips.ckpt\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:"
  },
  {
    "path": "configs/refinement/electricity_nips-transformer.yaml",
    "chars": 572,
    "preview": "base_model: transformer\nckpt: dummy/electricity_nips.ckpt\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ni"
  },
  {
    "path": "configs/refinement/exchange_rate_nips-deepar.yaml",
    "chars": 574,
    "preview": "base_model: deepar\nckpt: dummy/exchange_rate_nips.ckpt\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\nin"
  },
  {
    "path": "configs/refinement/exchange_rate_nips-linear.yaml",
    "chars": 615,
    "preview": "base_model: linear\nckpt: dummy/exchange_rate_nips.ckpt\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndi"
  },
  {
    "path": "configs/refinement/exchange_rate_nips-seasonal_naive.yaml",
    "chars": 582,
    "preview": "base_model: seasonal_naive\nckpt: dummy/exchange_rate_nips.ckpt\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: c"
  },
  {
    "path": "configs/refinement/exchange_rate_nips-transformer.yaml",
    "chars": 579,
    "preview": "base_model: transformer\nckpt: dummy/exchange_rate_nips.ckpt\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda"
  },
  {
    "path": "configs/refinement/kdd_cup_2018_without_missing-deepar.yaml",
    "chars": 653,
    "preview": "base_model: deepar\nbase_model_params: {}\nckpt: dummy/kdd_cup_2018_without_missing.ckpt\ncontext_length: 312\ndataset: kdd_"
  },
  {
    "path": "configs/refinement/kdd_cup_2018_without_missing-linear.yaml",
    "chars": 653,
    "preview": "base_model: linear\nbase_model_params: {}\nckpt: dummy/kdd_cup_2018_without_missing.ckpt\ncontext_length: 312\ndataset: kdd_"
  },
  {
    "path": "configs/refinement/kdd_cup_2018_without_missing-seasonal_naive.yaml",
    "chars": 661,
    "preview": "base_model: seasonal_naive\nbase_model_params: {}\nckpt: dummy/kdd_cup_2018_without_missing.ckpt\ncontext_length: 312\ndatas"
  },
  {
    "path": "configs/refinement/kdd_cup_2018_without_missing-transformer.yaml",
    "chars": 658,
    "preview": "base_model: transformer\nbase_model_params: {}\nckpt: dummy/kdd_cup_2018_without_missing.ckpt\ncontext_length: 312\ndataset:"
  },
  {
    "path": "configs/refinement/m4_hourly-deepar.yaml",
    "chars": 595,
    "preview": "base_model: deepar\nckpt: dummy/m4_hourly.ckpt\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_config: di"
  },
  {
    "path": "configs/refinement/m4_hourly-linear.yaml",
    "chars": 595,
    "preview": "base_model: linear\nckpt: dummy/m4_hourly.ckpt\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_config: di"
  },
  {
    "path": "configs/refinement/m4_hourly-seasonal_naive.yaml",
    "chars": 603,
    "preview": "base_model: seasonal_naive\nckpt: dummy/m4_hourly.ckpt\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_co"
  },
  {
    "path": "configs/refinement/m4_hourly-transformer.yaml",
    "chars": 600,
    "preview": "base_model: transformer\nckpt: dummy/m4_hourly.ckpt\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_confi"
  },
  {
    "path": "configs/refinement/solar_nips-deepar.yaml",
    "chars": 555,
    "preview": "base_model: deepar\nckpt: dummy/solar_nips.ckpt\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ninit_skip: false\ni"
  },
  {
    "path": "configs/refinement/solar_nips-linear.yaml",
    "chars": 596,
    "preview": "base_model: linear\nckpt: dummy/solar_nips.ckpt\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: "
  },
  {
    "path": "configs/refinement/solar_nips-seasonal_naive.yaml",
    "chars": 563,
    "preview": "base_model: seasonal_naive\nckpt: dummy/solar_nips.ckpt\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ninit_skip:"
  },
  {
    "path": "configs/refinement/solar_nips-transformer.yaml",
    "chars": 560,
    "preview": "base_model: transformer\nckpt: dummy/solar_nips.ckpt\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ninit_skip: fa"
  },
  {
    "path": "configs/refinement/traffic_nips-deepar.yaml",
    "chars": 558,
    "preview": "base_model: deepar\nckpt: dummy/traffic_nips.ckpt\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ninit_skip: tru"
  },
  {
    "path": "configs/refinement/traffic_nips-linear.yaml",
    "chars": 599,
    "preview": "base_model: linear\nckpt: dummy/traffic_nips.ckpt\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_conf"
  },
  {
    "path": "configs/refinement/traffic_nips-seasonal_naive.yaml",
    "chars": 566,
    "preview": "base_model: seasonal_naive\nckpt: dummy/traffic_nips.ckpt\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ninit_s"
  },
  {
    "path": "configs/refinement/traffic_nips-transformer.yaml",
    "chars": 563,
    "preview": "base_model: transformer\nckpt: dummy/traffic_nips.ckpt\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ninit_skip"
  },
  {
    "path": "configs/refinement/uber_tlc_hourly-deepar.yaml",
    "chars": 565,
    "preview": "base_model: deepar\nckpt: dummy/uber_tlc_hourly.ckpt\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ninit_ski"
  },
  {
    "path": "configs/refinement/uber_tlc_hourly-linear.yaml",
    "chars": 606,
    "preview": "base_model: linear\nckpt: dummy/uber_tlc_hourly.ckpt\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusio"
  },
  {
    "path": "configs/refinement/uber_tlc_hourly-seasonal_naive.yaml",
    "chars": 573,
    "preview": "base_model: seasonal_naive\nckpt: dummy/uber_tlc_hourly.ckpt\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\n"
  },
  {
    "path": "configs/refinement/uber_tlc_hourly-transformer.yaml",
    "chars": 570,
    "preview": "base_model: transformer\nckpt: dummy/uber_tlc_hourly.ckpt\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\nini"
  },
  {
    "path": "configs/refinement/wiki2000_nips-deepar.yaml",
    "chars": 603,
    "preview": "base_model: deepar\nckpt: dummy/wiki2000_nips.ckpt\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusion_co"
  },
  {
    "path": "configs/refinement/wiki2000_nips-linear.yaml",
    "chars": 603,
    "preview": "base_model: linear\nckpt: dummy/wiki2000_nips.ckpt\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusion_co"
  },
  {
    "path": "configs/refinement/wiki2000_nips-seasonal_naive.yaml",
    "chars": 611,
    "preview": "base_model: seasonal_naive\nckpt: dummy/wiki2000_nips.ckpt\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiff"
  },
  {
    "path": "configs/refinement/wiki2000_nips-transformer.yaml",
    "chars": 608,
    "preview": "base_model: transformer\nckpt: dummy/wiki2000_nips.ckpt\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusi"
  },
  {
    "path": "configs/refinement.yaml",
    "chars": 697,
    "preview": "# Model & checkpoint parameters\ndataset: solar_nips\ndevice: cuda:0\nckpt: ckpts/forecasting/solar_nips/649_.ckpt\ndiffusio"
  },
  {
    "path": "configs/train_tsdiff/train_electricity.yaml",
    "chars": 545,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndat"
  },
  {
    "path": "configs/train_tsdiff/train_exchange.yaml",
    "chars": 546,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndat"
  },
  {
    "path": "configs/train_tsdiff/train_kdd_cup.yaml",
    "chars": 556,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndat"
  },
  {
    "path": "configs/train_tsdiff/train_m4.yaml",
    "chars": 539,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: False\nda"
  },
  {
    "path": "configs/train_tsdiff/train_missing_electricity.yaml",
    "chars": 821,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndat"
  },
  {
    "path": "configs/train_tsdiff/train_missing_exchange.yaml",
    "chars": 822,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndat"
  },
  {
    "path": "configs/train_tsdiff/train_missing_kdd_cup.yaml",
    "chars": 871,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndat"
  },
  {
    "path": "configs/train_tsdiff/train_missing_solar.yaml",
    "chars": 815,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndat"
  },
  {
    "path": "configs/train_tsdiff/train_missing_traffic.yaml",
    "chars": 815,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndat"
  },
  {
    "path": "configs/train_tsdiff/train_missing_uber_tlc.yaml",
    "chars": 820,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndat"
  },
  {
    "path": "configs/train_tsdiff/train_solar.yaml",
    "chars": 539,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndat"
  },
  {
    "path": "configs/train_tsdiff/train_traffic.yaml",
    "chars": 539,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndat"
  },
  {
    "path": "configs/train_tsdiff/train_uber_tlc.yaml",
    "chars": 544,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndat"
  },
  {
    "path": "configs/train_tsdiff/train_wiki.yaml",
    "chars": 543,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: False\nda"
  },
  {
    "path": "configs/train_tsdiff-cond/electricity_nips.yaml",
    "chars": 414,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_"
  },
  {
    "path": "configs/train_tsdiff-cond/exchange_rate_nips.yaml",
    "chars": 415,
    "preview": "batch_size: 64\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nd"
  },
  {
    "path": "configs/train_tsdiff-cond/kdd_cup_2018_without_missing.yaml",
    "chars": 425,
    "preview": "batch_size: 64\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_smal"
  },
  {
    "path": "configs/train_tsdiff-cond/m4_hourly.yaml",
    "chars": 408,
    "preview": "batch_size: 64\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_e"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-B_electricity_nips.yaml",
    "chars": 489,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-B_exchange_rate_nips.yaml",
    "chars": 490,
    "preview": "batch_size: 64\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nd"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-B_kdd_cup_2018_without_missing.yaml",
    "chars": 500,
    "preview": "batch_size: 64\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_smal"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-B_solar_nips.yaml",
    "chars": 483,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-B_traffic_nips.yaml",
    "chars": 484,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_fina"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-B_uber_tlc_hourly.yaml",
    "chars": 488,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_f"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-E_electricity_nips.yaml",
    "chars": 489,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-E_exchange_rate_nips.yaml",
    "chars": 490,
    "preview": "batch_size: 64\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nd"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-E_kdd_cup_2018_without_missing.yaml",
    "chars": 500,
    "preview": "batch_size: 64\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_smal"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-E_solar_nips.yaml",
    "chars": 483,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-E_traffic_nips.yaml",
    "chars": 484,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_fina"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-E_uber_tlc_hourly.yaml",
    "chars": 488,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_f"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_RM_electricity_nips.yaml",
    "chars": 485,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_RM_exchange_rate_nips.yaml",
    "chars": 486,
    "preview": "batch_size: 64\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nd"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_RM_kdd_cup_2018_without_missing.yaml",
    "chars": 496,
    "preview": "batch_size: 64\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_smal"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_RM_solar_nips.yaml",
    "chars": 479,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_RM_traffic_nips.yaml",
    "chars": 480,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_fina"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_RM_uber_tlc_hourly.yaml",
    "chars": 484,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_f"
  },
  {
    "path": "configs/train_tsdiff-cond/solar_nips.yaml",
    "chars": 408,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_"
  },
  {
    "path": "configs/train_tsdiff-cond/traffic_nips.yaml",
    "chars": 409,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_fina"
  },
  {
    "path": "configs/train_tsdiff-cond/uber_tlc_hourly.yaml",
    "chars": 413,
    "preview": "batch_size: 64\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_f"
  },
  {
    "path": "configs/train_tsdiff-cond/wiki2000_nips.yaml",
    "chars": 413,
    "preview": "batch_size: 64\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_fin"
  },
  {
    "path": "configs/train_tsdiff-cond.yaml",
    "chars": 576,
    "preview": "model: conditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndatas"
  },
  {
    "path": "configs/train_tsdiff.yaml",
    "chars": 751,
    "preview": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: False\nda"
  },
  {
    "path": "configs/tstr/electricity_nips.yaml",
    "chars": 229,
    "preview": "ckpt: dummy/electricity_nips.ckpt\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusi"
  },
  {
    "path": "configs/tstr/exchange_rate_nips.yaml",
    "chars": 232,
    "preview": "ckpt: dummy/exchange_rate_nips.ckpt\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: dif"
  },
  {
    "path": "configs/tstr/kdd_cup_2018_without_missing.yaml",
    "chars": 252,
    "preview": "ckpt: dummy/kdd_cup_2018_without_missing.ckpt\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\nd"
  },
  {
    "path": "configs/tstr/m4_hourly.yaml",
    "chars": 216,
    "preview": "ckpt: dummy/m4_hourly.ckpt\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_confi"
  },
  {
    "path": "configs/tstr/solar_nips.yaml",
    "chars": 217,
    "preview": "ckpt: dummy/solar_nips.ckpt\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_con"
  },
  {
    "path": "configs/tstr/traffic_nips.yaml",
    "chars": 220,
    "preview": "ckpt: dummy/traffic_nips.ckpt\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small"
  },
  {
    "path": "configs/tstr/uber_tlc_hourly.yaml",
    "chars": 227,
    "preview": "ckpt: dummy/uber_tlc_hourly.ckpt\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion"
  },
  {
    "path": "configs/tstr/wiki2000_nips.yaml",
    "chars": 224,
    "preview": "ckpt: dummy/wiki2000_nips.ckpt\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusion_config: diffusion_sma"
  },
  {
    "path": "configs/tstr.yaml",
    "chars": 265,
    "preview": "# Model & checkpoint parameters\ndataset: solar_nips\ndevice: cuda:0\nckpt: ckpts/solar_nips/version_236/1299_.ckpt\ndiffusi"
  },
  {
    "path": "pyproject.toml",
    "chars": 515,
    "preview": "[project]\nname = \"uncond-ts-diff\"\nversion = \"0.1.0\"\ndescription = \"TSDiff: An Unconditional Diffusion Model for Time Ser"
  },
  {
    "path": "src/uncond_ts_diff/arch/__init__.py",
    "chars": 173,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom .backbon"
  },
  {
    "path": "src/uncond_ts_diff/arch/backbones.py",
    "chars": 5177,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport math\n\n"
  },
  {
    "path": "src/uncond_ts_diff/arch/s4.py",
    "chars": 61989,
    "preview": "\"\"\"Standalone version of Structured (Sequence) State Space (S4) model.\"\"\"\n\nimport logging\nfrom functools import partial\n"
  },
  {
    "path": "src/uncond_ts_diff/configs.py",
    "chars": 2710,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom uncond_t"
  },
  {
    "path": "src/uncond_ts_diff/dataset.py",
    "chars": 1370,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport os\nimp"
  },
  {
    "path": "src/uncond_ts_diff/metrics/__init__.py",
    "chars": 189,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom .linear_"
  },
  {
    "path": "src/uncond_ts_diff/metrics/linear_pred_score.py",
    "chars": 3853,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom typing i"
  },
  {
    "path": "src/uncond_ts_diff/model/__init__.py",
    "chars": 307,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom .diffusi"
  },
  {
    "path": "src/uncond_ts_diff/model/callback.py",
    "chars": 9719,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom copy imp"
  },
  {
    "path": "src/uncond_ts_diff/model/diffusion/_base.py",
    "chars": 11647,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom typing i"
  },
  {
    "path": "src/uncond_ts_diff/model/diffusion/tsdiff.py",
    "chars": 5594,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport copy\n\n"
  },
  {
    "path": "src/uncond_ts_diff/model/diffusion/tsdiff_cond.py",
    "chars": 9327,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport torch\n"
  },
  {
    "path": "src/uncond_ts_diff/model/linear/_estimator.py",
    "chars": 8977,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom typing i"
  },
  {
    "path": "src/uncond_ts_diff/model/linear/_scaler.py",
    "chars": 2426,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom typing i"
  },
  {
    "path": "src/uncond_ts_diff/predictor.py",
    "chars": 1095,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom typing i"
  },
  {
    "path": "src/uncond_ts_diff/sampler/__init__.py",
    "chars": 319,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom .observa"
  },
  {
    "path": "src/uncond_ts_diff/sampler/_base.py",
    "chars": 5170,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom typing i"
  },
  {
    "path": "src/uncond_ts_diff/sampler/observation_guidance.py",
    "chars": 10976,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport numpy "
  },
  {
    "path": "src/uncond_ts_diff/sampler/refiner.py",
    "chars": 11868,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport numpy "
  },
  {
    "path": "src/uncond_ts_diff/utils.py",
    "chars": 17922,
    "preview": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom copy imp"
  }
]
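
The YAML previews above all share the same flat key-value layout (model, diffusion_config, normalization, dataset, context_length, and so on), and utils.py exposes add_config_to_argparser to surface such keys as CLI flags. Below is a minimal sketch of that pattern, assuming a flat config and naive type inference; the repository's helper may handle details differently, and booleans in particular need care (bool("False") is truthy), which is presumably why the same file also defines str2bool.

import yaml
from argparse import ArgumentParser

def add_config_to_argparser(config: dict, parser: ArgumentParser) -> ArgumentParser:
    # Expose every flat YAML key as an overridable CLI flag whose
    # default is the value from the config file (assumed behavior).
    for key, value in config.items():
        parser.add_argument(f"--{key}", type=type(value), default=value)
    return parser

if __name__ == "__main__":
    # Hypothetical invocation; any train_tsdiff config follows this shape.
    with open("configs/train_tsdiff/train_solar.yaml") as fp:
        config = yaml.safe_load(fp)
    args = add_config_to_argparser(config, ArgumentParser()).parse_args()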

About this extraction

This document contains the full source code of the amazon-science/unconditional-time-series-diffusion GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 126 files (289.4 KB, approximately 78.3k tokens) and includes a symbol index with 249 extracted functions, classes, methods, constants, and types. It can be used with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
