[
  {
    "path": ".gitignore",
    "content": "__pycache__\nlightning_logs/\n.DS_Store\n*.egg-info\n/results/\n/ckpts/\n/saved_samples/\n.vscode/\n/sm_runs/\n/data/\n/checkpoints/"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "## Code of Conduct\nThis project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).\nFor more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact\nopensource-codeofconduct@amazon.com with any additional questions or comments.\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing Guidelines\n\nThank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional\ndocumentation, we greatly value feedback and contributions from our community.\n\nPlease read through this document before submitting any issues or pull requests to ensure we have all the necessary\ninformation to effectively respond to your bug report or contribution.\n\n\n## Reporting Bugs/Feature Requests\n\nWe welcome you to use the GitHub issue tracker to report bugs or suggest features.\n\nWhen filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already\nreported the issue. Please try to include as much information as you can. Details like these are incredibly useful:\n\n* A reproducible test case or series of steps\n* The version of our code being used\n* Any modifications you've made relevant to the bug\n* Anything unusual about your environment or deployment\n\n\n## Contributing via Pull Requests\nContributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:\n\n1. You are working against the latest source on the *main* branch.\n2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.\n3. You open an issue to discuss any significant work - we would hate for your time to be wasted.\n\nTo send us a pull request, please:\n\n1. Fork the repository.\n2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.\n3. Ensure local tests pass.\n4. Commit to your fork using clear commit messages.\n5. Send us a pull request, answering any default questions in the pull request interface.\n6. 
Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.\n\nGitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and\n[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).\n\n\n## Finding contributions to work on\nLooking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.\n\n\n## Code of Conduct\nThis project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).\nFor more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact\nopensource-codeofconduct@amazon.com with any additional questions or comments.\n\n\n## Security issue notifications\nIf you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.\n\n\n## Licensing\n\nSee the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n"
  },
  {
    "path": "NOTICE",
    "content": "Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n"
  },
  {
    "path": "README.md",
    "content": "# TSDiff: An Unconditional Diffusion Model for Time Series\n\n[![preprint](https://img.shields.io/static/v1?label=arXiv&message=2307.11494&color=B31B1B)](https://arxiv.org/abs/2307.11494)\n[![License: Apache-2.0](https://img.shields.io/badge/License-Apache--2.0-yellow.svg)](https://opensource.org/licenses/Apache-2.0)\n[![Venue:NeurIPS 2023](https://img.shields.io/badge/Venue-NeurIPS%202023-007CFF)](https://neurips.cc/)\n<p align=\"center\">\n  <img src=\"./assets/overview.png\" width=\"100%\">\n  <br />\n  <span>Fig. 1: An overview of TSDiff’s use cases. <b>Predict:</b> By utilizing observation self-guidance, TSDiff can be\nconditioned during inference to perform predictive tasks such as forecasting. <b>Refine:</b> Predictions\nof base forecasters can be improved by leveraging the implicit probability density of TSDiff.\n<b>Synthesize:</b> Realistic samples generated by TSDiff can be used to train downstream forecasters, achieving good\nperformance on real test data.</span>\n</p>\n\n---\n\nThis repository contains the official implementation of the NeurIPS 2023 paper [*Predict, Refine, Synthesize: Self-Guiding Diffusion Models for Probabilistic Time Series Forecasting*](https://arxiv.org/abs/2307.11494). In this paper, we propose *TSDiff*, an unconditional diffusion model for time series. Our proposed self-guidance mechanism enables conditioning TSDiff for downstream tasks during inference, without requiring auxiliary networks or altering the training procedure. Furthermore, our refinement scheme leverages the implicit density learned by the diffusion model to iteratively refine the predictions of base forecasters. Finally, we demonstrate the high quality of the synthetic time series by training downstream models solely on generated data and introducing the *Linear Predictive Score (LPS)*.\n\n<p align=\"center\">\n  <img src=\"./assets/forecasts.png\" width=\"60%\">\n  <br />\n  <span>Fig. 
2: Example forecasts generated by TSDiff-Q for\ntime series in Electricity, KDDCup, and Exchange — three datasets with different frequencies and/or prediction lengths.</span>\n</p>\n\n## Installation\n\nTSDiff requires Python 3.8 or higher.\n\n* Create a conda environment (optional, but recommended).\n```sh\nconda create --name tsdiff --yes python=3.8 && conda activate tsdiff\n```\n* Install this package.\n```sh\npip install --editable \".\"\n```\n\n> [!TIP]  \n> We have some updates in the `update` branch. If you're interested in testing out TSDiff or [training it on a custom dataset](https://github.com/amazon-science/unconditional-time-series-diffusion/issues/7), using the `update` branch may be faster for training.\n\n## Usage\n\n### Training Models\n\nTrain models using the `train_model.py` and `train_cond_model.py` scripts for `TSDiff` and `TSDiff-Cond`, respectively. Sample configurations can be found in `configs/train_tsdiff.yaml` and `configs/train_tsdiff-cond.yaml`. Specific configurations used in the paper can be found in `configs/train_tsdiff` and `configs/train_tsdiff-cond`.\n\nExample commands for regular (i.e., no missing values) forecasting:\n```sh\n# Train TSDiff on the Uber dataset for regular forecasting\npython bin/train_model.py -c configs/train_tsdiff/train_uber_tlc.yaml\n\n# Train TSDiff on the M4 dataset for regular forecasting\npython bin/train_model.py -c configs/train_tsdiff/train_m4.yaml\n\n# Train TSDiff-Cond on the Uber dataset for regular forecasting\npython bin/train_cond_model.py -c configs/train_tsdiff-cond/uber_tlc_hourly.yaml\n\n# Train TSDiff-Cond on the M4 dataset for regular forecasting\npython bin/train_cond_model.py -c configs/train_tsdiff-cond/m4_hourly.yaml\n```\n\nExample commands for forecasting with missing values:\n```sh\n# Train TSDiff on the Uber dataset for the missing values experiment\npython bin/train_model.py -c configs/train_tsdiff/train_missing_uber_tlc.yaml\n\n# Train TSDiff on the KDDCup dataset for the 
missing values experiment\npython bin/train_model.py -c configs/train_tsdiff/train_missing_kdd_cup.yaml\n\n# Train TSDiff-Cond on the Uber dataset for the RM missing values experiment\npython bin/train_cond_model.py -c configs/train_tsdiff-cond/missing_RM_uber_tlc_hourly.yaml\n\n# Train TSDiff-Cond on the KDDCup dataset for the BM-B missing values experiment\npython bin/train_cond_model.py -c configs/train_tsdiff-cond/missing_BM-B_kdd_cup_2018_without_missing.yaml\n\n# Train TSDiff-Cond on the KDDCup dataset for the BM-E missing values experiment\npython bin/train_cond_model.py -c configs/train_tsdiff-cond/missing_BM-E_kdd_cup_2018_without_missing.yaml\n```\nNote that for TSDiff, we train only one model; all missing value scenarios are evaluated with the same unconditional model. For TSDiff-Cond, however, one model is trained per missingness scenario.\n\n### Evaluating Models\nThe unconditional models trained above can be used for the following tasks.\n\n#### Predict using Observation Self-Guidance\nUse the `guidance_experiment.py` script and `configs/guidance.yaml` config to run the forecasting experiments. Specific configurations used in the paper can be found in `configs/guidance/`.\n\nExample commands:\n```sh\n# Run observation self-guidance on the Solar dataset\npython bin/guidance_experiment.py -c configs/guidance/guidance_solar.yaml --ckpt /path/to/ckpt\n\n# Run observation self-guidance on the KDDCup dataset\npython bin/guidance_experiment.py -c configs/guidance/guidance_kdd_cup.yaml --ckpt /path/to/ckpt\n```\n\n#### Refine Predictions of Base Forecasters\nUse the `refinement_experiment.py` script and `configs/refinement.yaml` config to run the refinement experiments. 
Specific configurations used in the paper can be found in `configs/refinement/`.\n\nExample commands:\n```sh\n# Refine predictions from the Linear model on the Solar dataset\npython bin/refinement_experiment.py -c configs/refinement/solar_nips-linear.yaml --ckpt /path/to/ckpt\n\n# Refine predictions from the DeepAR model on the M4 dataset\npython bin/refinement_experiment.py -c configs/refinement/m4_hourly-deepar.yaml --ckpt /path/to/ckpt\n```\n\n#### Train Downstream Models using Synthetic Data\nUse the `tstr_experiment.py` script and `configs/tstr.yaml` config to run the _train on synthetic, test on real (TSTR)_ experiments. Specific configurations used in the paper can be found in `configs/tstr/`.\n\nExample commands:\n```sh\n# TSTR on the Solar dataset\npython bin/tstr_experiment.py -c configs/tstr/solar_nips.yaml --ckpt /path/to/ckpt\n\n# TSTR on the KDDCup dataset\npython bin/tstr_experiment.py -c configs/tstr/kdd_cup_2018_without_missing.yaml --ckpt /path/to/ckpt\n```\n\n## BibTeX\n\nIf you find this repository or the ideas presented in our paper useful, please consider citing.\n\n```\n@inproceedings{kollovieh2023predict,\n author    = {Kollovieh, Marcel and Ansari, Abdul Fatir and Bohlke-Schneider, Michael and Zschiegner, Jasper and Wang, Hao and Wang, Yuyang},\n title     = {Predict, Refine, Synthesize: Self-Guiding Diffusion Models for Probabilistic Time Series Forecasting},\n booktitle = {Advances in Neural Information Processing Systems},\n year      = {2023}\n}\n```\n\n## Security\n\nSee [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.\n\n## License\n\nThis project is licensed under the Apache-2.0 License.\n\n"
  },
  {
    "path": "THIRD-PARTY-LICENSES.txt",
    "content": "** state-spaces; version 1.0 -- https://github.com/HazyResearch/state-spaces\n \nApache License\nVersion 2.0, January 2004\nhttp://www.apache.org/licenses/\n\nTERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n1. Definitions.\n\n\"License\" shall mean the terms and conditions for use, reproduction, and\ndistribution as defined by Sections 1 through 9 of this document.\n\n\"Licensor\" shall mean the copyright owner or entity authorized by the copyright\nowner that is granting the License.\n\n\"Legal Entity\" shall mean the union of the acting entity and all other entities\nthat control, are controlled by, or are under common control with that entity.\nFor the purposes of this definition, \"control\" means (i) the power, direct or\nindirect, to cause the direction or management of such entity, whether by\ncontract or otherwise, or (ii) ownership of fifty percent (50%) or more of the\noutstanding shares, or (iii) beneficial ownership of such entity.\n\n\"You\" (or \"Your\") shall mean an individual or Legal Entity exercising\npermissions granted by this License.\n\n\"Source\" form shall mean the preferred form for making modifications, including\nbut not limited to software source code, documentation source, and configuration\nfiles.\n\n\"Object\" form shall mean any form resulting from mechanical transformation or\ntranslation of a Source form, including but not limited to compiled object code,\ngenerated documentation, and conversions to other media types.\n\n\"Work\" shall mean the work of authorship, whether in Source or Object form, made\navailable under the License, as indicated by a copyright notice that is included\nin or attached to the work (an example is provided in the Appendix below).\n\n\"Derivative Works\" shall mean any work, whether in Source or Object form, that\nis based on (or derived from) the Work and for which the editorial revisions,\nannotations, elaborations, or other modifications represent, as a whole, 
an\noriginal work of authorship. For the purposes of this License, Derivative Works\nshall not include works that remain separable from, or merely link (or bind by\nname) to the interfaces of, the Work and Derivative Works thereof.\n\n\"Contribution\" shall mean any work of authorship, including the original version\nof the Work and any modifications or additions to that Work or Derivative Works\nthereof, that is intentionally submitted to Licensor for inclusion in the Work\nby the copyright owner or by an individual or Legal Entity authorized to submit\non behalf of the copyright owner. For the purposes of this definition,\n\"submitted\" means any form of electronic, verbal, or written communication sent\nto the Licensor or its representatives, including but not limited to\ncommunication on electronic mailing lists, source code control systems, and\nissue tracking systems that are managed by, or on behalf of, the Licensor for\nthe purpose of discussing and improving the Work, but excluding communication\nthat is conspicuously marked or otherwise designated in writing by the copyright\nowner as \"Not a Contribution.\"\n\n\"Contributor\" shall mean Licensor and any individual or Legal Entity on behalf\nof whom a Contribution has been received by Licensor and subsequently\nincorporated within the Work.\n\n2. Grant of Copyright License. Subject to the terms and conditions of this\nLicense, each Contributor hereby grants to You a perpetual, worldwide, non-\nexclusive, no-charge, royalty-free, irrevocable copyright license to reproduce,\nprepare Derivative Works of, publicly display, publicly perform, sublicense, and\ndistribute the Work and such Derivative Works in Source or Object form.\n\n3. Grant of Patent License. 
Subject to the terms and conditions of this License,\neach Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-\ncharge, royalty-free, irrevocable (except as stated in this section) patent\nlicense to make, have made, use, offer to sell, sell, import, and otherwise\ntransfer the Work, where such license applies only to those patent claims\nlicensable by such Contributor that are necessarily infringed by their\nContribution(s) alone or by combination of their Contribution(s) with the Work\nto which such Contribution(s) was submitted. If You institute patent litigation\nagainst any entity (including a cross-claim or counterclaim in a lawsuit)\nalleging that the Work or a Contribution incorporated within the Work\nconstitutes direct or contributory patent infringement, then any patent licenses\ngranted to You under this License for that Work shall terminate as of the date\nsuch litigation is filed.\n\n4. Redistribution. You may reproduce and distribute copies of the Work or\nDerivative Works thereof in any medium, with or without modifications, and in\nSource or Object form, provided that You meet the following conditions:\n\n     (a) You must give any other recipients of the Work or Derivative Works a\ncopy of this License; and\n\n     (b) You must cause any modified files to carry prominent notices stating\nthat You changed the files; and\n\n     (c) You must retain, in the Source form of any Derivative Works that You\ndistribute, all copyright, patent, trademark, and attribution notices from the\nSource form of the Work, excluding those notices that do not pertain to any part\nof the Derivative Works; and\n\n     (d) If the Work includes a \"NOTICE\" text file as part of its distribution,\nthen any Derivative Works that You distribute must include a readable copy of\nthe attribution notices contained within such NOTICE file, excluding those\nnotices that do not pertain to any part of the Derivative Works, in at least one\nof the following 
places: within a NOTICE text file distributed as part of the\nDerivative Works; within the Source form or documentation, if provided along\nwith the Derivative Works; or, within a display generated by the Derivative\nWorks, if and wherever such third-party notices normally appear. The contents of\nthe NOTICE file are for informational purposes only and do not modify the\nLicense. You may add Your own attribution notices within Derivative Works that\nYou distribute, alongside or as an addendum to the NOTICE text from the Work,\nprovided that such additional attribution notices cannot be construed as\nmodifying the License.\n\n     You may add Your own copyright statement to Your modifications and may\nprovide additional or different license terms and conditions for use,\nreproduction, or distribution of Your modifications, or for any such Derivative\nWorks as a whole, provided Your use, reproduction, and distribution of the Work\notherwise complies with the conditions stated in this License.\n\n5. Submission of Contributions. Unless You explicitly state otherwise, any\nContribution intentionally submitted for inclusion in the Work by You to the\nLicensor shall be under the terms and conditions of this License, without any\nadditional terms or conditions. Notwithstanding the above, nothing herein shall\nsupersede or modify the terms of any separate license agreement you may have\nexecuted with Licensor regarding such Contributions.\n\n6. Trademarks. This License does not grant permission to use the trade names,\ntrademarks, service marks, or product names of the Licensor, except as required\nfor reasonable and customary use in describing the origin of the Work and\nreproducing the content of the NOTICE file.\n\n7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in\nwriting, Licensor provides the Work (and each Contributor provides its\nContributions) on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\nKIND, either express or implied, including, without limitation, any warranties\nor conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\nPARTICULAR PURPOSE. You are solely responsible for determining the\nappropriateness of using or redistributing the Work and assume any risks\nassociated with Your exercise of permissions under this License.\n\n8. Limitation of Liability. In no event and under no legal theory, whether in\ntort (including negligence), contract, or otherwise, unless required by\napplicable law (such as deliberate and grossly negligent acts) or agreed to in\nwriting, shall any Contributor be liable to You for damages, including any\ndirect, indirect, special, incidental, or consequential damages of any character\narising as a result of this License or out of the use or inability to use the\nWork (including but not limited to damages for loss of goodwill, work stoppage,\ncomputer failure or malfunction, or any and all other commercial damages or\nlosses), even if such Contributor has been advised of the possibility of such\ndamages.\n\n9. Accepting Warranty or Additional Liability. While redistributing the Work or\nDerivative Works thereof, You may choose to offer, and charge a fee for,\nacceptance of support, warranty, indemnity, or other liability obligations\nand/or rights consistent with this License. 
However, in accepting such\nobligations, You may act only on Your own behalf and on Your sole\nresponsibility, not on behalf of any other Contributor, and only if You agree to\nindemnify, defend, and hold each Contributor harmless for any liability incurred\nby, or claims asserted against, such Contributor by reason of your accepting any\nsuch warranty or additional liability.\n\nEND OF TERMS AND CONDITIONS\n\nAPPENDIX: How to apply the Apache License to your work.\n\nTo apply the Apache License to your work, attach the following boilerplate\nnotice, with the fields enclosed by brackets \"[]\" replaced with your own\nidentifying information. (Don't include the brackets!)  The text should be\nenclosed in the appropriate comment syntax for the file format. We also\nrecommend that a file or class name and description of purpose be included on\nthe same \"printed page\" as the copyright notice for easier identification within\nthird-party archives.\n\nCopyright [yyyy] [name of copyright owner]\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n* For state-spaces see also this required NOTICE:\n    Copyright 2022 Albert Gu and Karan Goel and Christopher Re\n"
  },
  {
    "path": "bin/guidance_experiment.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport logging\nimport argparse\nfrom pathlib import Path\n\nimport yaml\nimport torch\nfrom tqdm.auto import tqdm\nfrom gluonts.dataset.field_names import FieldName\nfrom gluonts.evaluation import make_evaluation_predictions, Evaluator\n\nfrom uncond_ts_diff.utils import (\n    create_transforms,\n    create_splitter,\n    get_next_file_num,\n    add_config_to_argparser,\n    filter_metrics,\n    MaskInput,\n)\nfrom uncond_ts_diff.model import TSDiff\nfrom uncond_ts_diff.dataset import get_gts_dataset\nfrom uncond_ts_diff.sampler import (\n    DDPMGuidance,\n    DDIMGuidance,\n)\nimport uncond_ts_diff.configs as diffusion_configs\n\nguidance_map = {\"ddpm\": DDPMGuidance, \"ddim\": DDIMGuidance}\n\n\ndef load_model(config):\n    model = TSDiff(\n        **getattr(\n            diffusion_configs,\n            config.get(\"diffusion_config\", \"diffusion_small_config\"),\n        ),\n        freq=config[\"freq\"],\n        use_features=config[\"use_features\"],\n        use_lags=config[\"use_lags\"],\n        normalization=\"mean\",\n        context_length=config[\"context_length\"],\n        prediction_length=config[\"prediction_length\"],\n        init_skip=config[\"init_skip\"],\n    )\n    model.load_state_dict(\n        torch.load(config[\"ckpt\"], map_location=\"cpu\"),\n        strict=True,\n    )\n    model = model.to(config[\"device\"])\n    return model\n\n\ndef evaluate_guidance(\n    config, model, test_dataset, transformation, num_samples=100\n):\n    logger.info(f\"Evaluating with {num_samples} samples.\")\n    results = []\n    if config[\"setup\"] == \"forecasting\":\n        missing_data_kwargs_list = [\n            {\n                \"missing_scenario\": \"none\",\n                \"missing_values\": 0,\n            }\n        ]\n        config[\"missing_data_configs\"] = missing_data_kwargs_list\n    elif config[\"setup\"] 
== \"missing_values\":\n        missing_data_kwargs_list = config[\"missing_data_configs\"]\n    else:\n        raise ValueError(f\"Unknown setup {config['setup']}\")\n\n    Guidance = guidance_map[config[\"sampler\"]]\n    sampler_params = config[\"sampler_params\"]\n    for missing_data_kwargs in missing_data_kwargs_list:\n        logger.info(\n            f\"Evaluating scenario '{missing_data_kwargs['missing_scenario']}' \"\n            f\"with {missing_data_kwargs['missing_values']:.1f} missing_values.\"\n        )\n\n        sampler = Guidance(\n            model=model,\n            prediction_length=config[\"prediction_length\"],\n            num_samples=num_samples,\n            **missing_data_kwargs,\n            **sampler_params,\n        )\n\n        transformed_testdata = transformation.apply(\n            test_dataset, is_train=False\n        )\n        test_splitter = create_splitter(\n            past_length=config[\"context_length\"] + max(model.lags_seq),\n            future_length=config[\"prediction_length\"],\n            mode=\"test\",\n        )\n\n        masking_transform = MaskInput(\n            FieldName.TARGET,\n            FieldName.OBSERVED_VALUES,\n            config[\"context_length\"],\n            missing_data_kwargs[\"missing_scenario\"],\n            missing_data_kwargs[\"missing_values\"],\n        )\n        test_transform = test_splitter + masking_transform\n\n        predictor = sampler.get_predictor(\n            test_transform,\n            batch_size=1280 // num_samples,\n            device=config[\"device\"],\n        )\n        forecast_it, ts_it = make_evaluation_predictions(\n            dataset=transformed_testdata,\n            predictor=predictor,\n            num_samples=num_samples,\n        )\n        forecasts = list(tqdm(forecast_it, total=len(transformed_testdata)))\n        tss = list(ts_it)\n        evaluator = Evaluator()\n        metrics, _ = evaluator(tss, forecasts)\n        metrics = 
filter_metrics(metrics)\n        results.append(dict(**missing_data_kwargs, **metrics))\n\n    return results\n\n\ndef main(config: dict, log_dir: str):\n    # Read global parameters\n    dataset_name = config[\"dataset\"]\n    freq = config[\"freq\"]\n    prediction_length = config[\"prediction_length\"]\n    num_samples = config[\"num_samples\"]\n\n    # Load dataset and model\n    logger.info(\"Loading model\")\n    model = load_model(config)\n    dataset = get_gts_dataset(dataset_name)\n    assert dataset.metadata.freq == freq\n    assert dataset.metadata.prediction_length == prediction_length\n\n    # Setup data transformation and loading\n    transformation = create_transforms(\n        num_feat_dynamic_real=0,\n        num_feat_static_cat=0,\n        num_feat_static_real=0,\n        time_features=model.time_features,\n        prediction_length=prediction_length,\n    )\n\n    # Run guidance\n    results = evaluate_guidance(\n        config, model, dataset.test, transformation, num_samples=num_samples\n    )\n\n    # Save results\n    log_dir = Path(log_dir) / \"guidance_logs\"\n    log_dir.mkdir(exist_ok=True, parents=True)\n    base_filename = \"results\"\n    run_num = get_next_file_num(\n        base_filename, log_dir, file_type=\"yaml\", separator=\"-\"\n    )\n    save_path = log_dir / f\"{base_filename}-{run_num}.yaml\"\n\n    with open(save_path, \"w\") as fp:\n        yaml.safe_dump(\n            {\"config\": config, \"metrics\": results},\n            fp,\n            default_flow_style=False,\n            sort_keys=False,\n        )\n\n\nif __name__ == \"__main__\":\n    # Setup Logger\n    logging.basicConfig(\n        format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\"\n    )\n    logger = logging.getLogger(__file__)\n    logger.setLevel(logging.INFO)\n\n    # Setup argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"-c\", \"--config\", type=str, required=True, help=\"Path to yaml config\"\n    )\n    
parser.add_argument(\n        \"--out_dir\", type=str, default=\"./results\", help=\"Path to results dir\"\n    )\n    args, _ = parser.parse_known_args()\n\n    with open(args.config, \"r\") as fp:\n        config = yaml.safe_load(fp)\n\n    # Update config from command line\n    parser = add_config_to_argparser(config=config, parser=parser)\n    args = parser.parse_args()\n    config_updates = vars(args)\n    for k in config.keys() & config_updates.keys():\n        orig_val = config[k]\n        updated_val = config_updates[k]\n        if updated_val != orig_val:\n            logger.info(f\"Updated key '{k}': {orig_val} -> {updated_val}\")\n    config.update(config_updates)\n\n    main(config=config, log_dir=args.out_dir)\n"
  },
  {
    "path": "bin/refinement_experiment.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport json\nimport copy\nimport logging\nimport argparse\nfrom pathlib import Path\n\nimport yaml\nimport torch\nimport numpy as np\nfrom tqdm.auto import tqdm\nfrom gluonts.mx import DeepAREstimator, TransformerEstimator\nfrom gluonts.model.seasonal_naive import SeasonalNaivePredictor\nfrom gluonts.evaluation import make_evaluation_predictions, Evaluator\nfrom gluonts.dataset.loader import TrainDataLoader\nfrom gluonts.itertools import Cached\nfrom gluonts.torch.batchify import batchify\n\nfrom uncond_ts_diff.utils import (\n    create_transforms,\n    create_splitter,\n    get_next_file_num,\n    add_config_to_argparser,\n    filter_metrics,\n)\nfrom uncond_ts_diff.model import TSDiff, LinearEstimator\nfrom uncond_ts_diff.dataset import get_gts_dataset\nfrom uncond_ts_diff.sampler import (\n    MostLikelyRefiner,\n    MCMCRefiner,\n    DDPMGuidance,\n    DDIMGuidance,\n)\nimport uncond_ts_diff.configs as diffusion_configs\n\nguidance_map = {\"ddpm\": DDPMGuidance, \"ddim\": DDIMGuidance}\nrefiner_map = {\"most_likely\": MostLikelyRefiner, \"mcmc\": MCMCRefiner}\n\n\ndef load_model(config):\n    model = TSDiff(\n        **getattr(\n            diffusion_configs,\n            config.get(\"diffusion_config\", \"diffusion_small_config\"),\n        ),\n        freq=config[\"freq\"],\n        use_features=config[\"use_features\"],\n        use_lags=config[\"use_lags\"],\n        normalization=\"mean\",\n        context_length=config[\"context_length\"],\n        prediction_length=config[\"prediction_length\"],\n        init_skip=config[\"init_skip\"],\n    )\n    model.load_state_dict(\n        torch.load(config[\"ckpt\"], map_location=\"cpu\"),\n        strict=True,\n    )\n    model = model.to(config[\"device\"])\n    return model\n\n\ndef get_best_diffusion_step(model: TSDiff, data_loader, device):\n    losses = np.zeros(model.timesteps)\n    
batch = {\n        k: v.to(device)\n        for k, v in next(iter(data_loader)).items()\n        if isinstance(v, torch.Tensor)\n    }\n    x, features, scale = model._extract_features(batch)\n    for t in range(model.timesteps):\n        loss, _, _ = model.p_losses(\n            x.to(device), torch.tensor([t], device=device)\n        )\n        # Move the scalar loss to the host before storing it; assigning a\n        # GPU tensor directly into a NumPy array raises a TypeError.\n        losses[t] = loss.item()\n\n    # Pick the diffusion step whose loss is closest to the mean loss\n    # across all steps.\n    best_t = ((losses - losses.mean()) ** 2).argmin()\n    return best_t\n\n\ndef train_and_forecast_base_model(dataset, base_model_name, config):\n    base_model_kwargs = config.get(\"base_model_params\", {})\n    if base_model_name == \"deepar\":\n        predictor = DeepAREstimator(\n            prediction_length=dataset.metadata.prediction_length,\n            freq=dataset.metadata.freq,\n            **base_model_kwargs,\n        ).train(list(dataset.train), cache_data=True)\n    elif base_model_name == \"transformer\":\n        predictor = TransformerEstimator(\n            prediction_length=dataset.metadata.prediction_length,\n            freq=dataset.metadata.freq,\n            **base_model_kwargs,\n        ).train(list(dataset.train), cache_data=True)\n    elif base_model_name == \"seasonal_naive\":\n        predictor = SeasonalNaivePredictor(\n            freq=dataset.metadata.freq,\n            prediction_length=dataset.metadata.prediction_length,\n            **base_model_kwargs,\n        )\n    elif base_model_name == \"linear\":\n        num_train_samples = 10000\n        predictor = LinearEstimator(\n            freq=dataset.metadata.freq,\n            prediction_length=dataset.metadata.prediction_length,\n            context_length=config[\"context_length\"],\n            num_train_samples=num_train_samples,\n            **base_model_kwargs,\n        ).train(list(dataset.train), cache_data=True)\n    else:\n        raise ValueError(f\"Unsupported base model {base_model_name}!\")\n\n    fcst_iter, ts_iter = make_evaluation_predictions(\n        dataset=dataset.test,\n        
predictor=predictor,\n        num_samples=config[\"num_samples\"],\n    )\n    fcsts = list(tqdm(fcst_iter, total=len(dataset.test)))\n    tss = list(ts_iter)\n\n    return fcsts, tss\n\n\ndef forecast_guidance(\n    dataset,\n    base_model_name,\n    config,\n    diffusion_model,\n    transformed_testdata,\n    test_splitter,\n):\n    assert len(dataset.test) == len(transformed_testdata)\n    base_model_kwargs = config.get(\"base_model_params\", {})\n\n    Guidance = guidance_map[base_model_name]\n    predictor = Guidance(\n        model=diffusion_model,\n        prediction_length=dataset.metadata.prediction_length,\n        num_samples=config[\"num_samples\"],\n        **base_model_kwargs,\n    ).get_predictor(\n        input_transform=test_splitter,\n        batch_size=1280 // config[\"num_samples\"],\n        device=config[\"device\"],\n    )\n\n    fcst_iter, ts_iter = make_evaluation_predictions(\n        dataset=transformed_testdata,\n        predictor=predictor,\n        num_samples=config[\"num_samples\"],\n    )\n    fcsts = list(tqdm(fcst_iter, total=len(dataset.test)))\n    tss = list(ts_iter)\n\n    return fcsts, tss\n\n\ndef main(config: dict, log_dir: str):\n    # Read global parameters\n    dataset_name = config[\"dataset\"]\n    device = config[\"device\"]\n    context_length = config[\"context_length\"]\n    prediction_length = config[\"prediction_length\"]\n    base_model_name = config[\"base_model\"]\n    num_samples = config[\"num_samples\"]\n\n    # Load dataset and model\n    logger.info(\"Loading model\")\n    dataset = get_gts_dataset(dataset_name)\n    config[\"freq\"] = dataset.metadata.freq\n\n    assert prediction_length == dataset.metadata.prediction_length\n\n    model = load_model(config)\n\n    # Setup data transformation and loading\n    transformation = create_transforms(\n        num_feat_dynamic_real=0,\n        num_feat_static_cat=0,\n        num_feat_static_real=0,\n        time_features=model.time_features,\n        
prediction_length=prediction_length,\n    )\n    transformed_data = transformation.apply(list(dataset.train), is_train=True)\n\n    transformed_testdata = transformation.apply(\n        list(dataset.test), is_train=False\n    )\n\n    training_splitter = create_splitter(\n        past_length=context_length + max(model.lags_seq),\n        future_length=prediction_length,\n        mode=\"train\",\n    )\n    test_splitter = create_splitter(\n        past_length=context_length + max(model.lags_seq),\n        future_length=prediction_length,\n        mode=\"test\",\n    )\n\n    train_dataloader = TrainDataLoader(\n        Cached(transformed_data),\n        batch_size=1024,\n        stack_fn=batchify,\n        transform=training_splitter,\n        num_batches_per_epoch=2048,\n    )\n\n    best_t = get_best_diffusion_step(model, train_dataloader, device)\n\n    # Train base model & get initial forecasts\n    logger.info(\"Training base model\")\n    if base_model_name in {\"ddpm\", \"ddim\"}:\n        base_fcsts, tss = forecast_guidance(\n            dataset,\n            base_model_name,\n            config,\n            diffusion_model=model,\n            transformed_testdata=transformed_testdata,\n            test_splitter=test_splitter,\n        )\n    else:\n        base_fcsts, tss = train_and_forecast_base_model(\n            dataset, base_model_name, config\n        )\n\n    # Evaluate base forecasts\n    evaluator = Evaluator()\n    baseline_metrics, _ = evaluator(tss, base_fcsts)\n    baseline_metrics = filter_metrics(baseline_metrics)\n\n    # Run refinement\n    log_dir = Path(log_dir) / \"refinement_logs\"\n    log_dir.mkdir(exist_ok=True, parents=True)\n    base_filename = \"results\"\n    run_num = get_next_file_num(\n        base_filename, log_dir, file_type=\"yaml\", separator=\"-\"\n    )\n    save_path = log_dir / f\"{base_filename}-{run_num}.yaml\"\n\n    results = [\n        {\n            \"model\": \"baseline\",\n            \"model_params\": {\n   
             \"name\": base_model_name,\n                **config.get(\"base_model_params\", {}),\n            },\n            **baseline_metrics,\n        }\n    ]\n\n    n_refiner_configs = len(config[\"refiner_configs\"])\n    for i, ref_config in enumerate(config[\"refiner_configs\"]):\n        logger.info(\n            f\"Running refiner ({i+1}/{n_refiner_configs}): {json.dumps(ref_config)}\"\n        )\n\n        refiner_config = copy.deepcopy(ref_config)\n        refiner_name = refiner_config.pop(\"refiner_name\")\n        Refiner = refiner_map[refiner_name]\n        refiner = Refiner(\n            model,\n            prediction_length,\n            init=iter(base_fcsts),\n            num_samples=num_samples,\n            fixed_t=best_t,\n            iterations=config[\"iterations\"],\n            **refiner_config,\n        )\n        refiner_predictor = refiner.get_predictor(\n            test_splitter, batch_size=1024 // num_samples, device=device\n        )\n        forecast_it, ts_it = make_evaluation_predictions(\n            dataset=transformed_testdata,\n            predictor=refiner_predictor,\n            num_samples=num_samples,\n        )\n        evaluator = Evaluator()\n        refined_metrics, _ = evaluator(\n            list(ts_it),\n            list(tqdm(forecast_it, total=len(transformed_testdata))),\n        )\n        refined_metrics = filter_metrics(refined_metrics)\n\n        results.append(\n            {\n                \"model\": refiner_name,\n                \"model_params\": json.dumps(ref_config),\n                **refined_metrics,\n            }\n        )\n\n    with open(save_path, \"w\") as fp:\n        yaml.safe_dump(\n            {\"config\": config, \"metrics\": results},\n            fp,\n            default_flow_style=False,\n            sort_keys=False,\n        )\n\n\nif __name__ == \"__main__\":\n    # Setup Logger\n    logging.basicConfig(\n        format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\"\n    
)\n    logger = logging.getLogger(__file__)\n    logger.setLevel(logging.INFO)\n\n    # Setup argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"-c\", \"--config\", type=str, required=True, help=\"Path to yaml config\"\n    )\n    parser.add_argument(\n        \"--out_dir\", type=str, default=\"./results\", help=\"Path to results dir\"\n    )\n    args, _ = parser.parse_known_args()\n\n    with open(args.config, \"r\") as fp:\n        config = yaml.safe_load(fp)\n\n    # Update config from command line\n    parser = add_config_to_argparser(config=config, parser=parser)\n    args = parser.parse_args()\n    config_updates = vars(args)\n    for k in config.keys() & config_updates.keys():\n        orig_val = config[k]\n        updated_val = config_updates[k]\n        if updated_val != orig_val:\n            logger.info(f\"Updated key '{k}': {orig_val} -> {updated_val}\")\n    config.update(config_updates)\n\n    main(config=config, log_dir=args.out_dir)\n"
  },
  {
    "path": "bin/train_cond_model.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport logging\nimport argparse\nfrom pathlib import Path\n\nimport yaml\nimport torch\nfrom tqdm.auto import tqdm\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar\n\nfrom gluonts.dataset.loader import TrainDataLoader, ValidationDataLoader\nfrom gluonts.dataset.split import OffsetSplitter\nfrom gluonts.itertools import Cached\nfrom gluonts.torch.batchify import batchify\nfrom gluonts.evaluation import make_evaluation_predictions, Evaluator\nfrom gluonts.dataset.field_names import FieldName\n\nimport uncond_ts_diff.configs as diffusion_configs\nfrom uncond_ts_diff.dataset import get_gts_dataset\nfrom uncond_ts_diff.model import TSDiffCond\nfrom uncond_ts_diff.utils import (\n    create_transforms,\n    create_splitter,\n    add_config_to_argparser,\n    filter_metrics,\n    MaskInput,\n    ConcatDataset,\n)\n\n\ndef create_model(config):\n    model = TSDiffCond(\n        **getattr(diffusion_configs, config[\"diffusion_config\"]),\n        freq=config[\"freq\"],\n        use_features=config[\"use_features\"],\n        use_lags=config[\"use_lags\"],\n        normalization=config[\"normalization\"],\n        context_length=config[\"context_length\"],\n        prediction_length=config[\"prediction_length\"],\n        lr=config[\"lr\"],\n        init_skip=config[\"init_skip\"],\n        noise_observed=config[\"noise_observed\"],\n    )\n    model.to(config[\"device\"])\n    return model\n\n\ndef evaluate_conditional(\n    config,\n    model: TSDiffCond,\n    test_dataset,\n    transformation,\n    num_samples=100,\n):\n    logger.info(f\"Evaluating with {num_samples} samples.\")\n    logger.info(\n        f\"Evaluating scenario '{config['missing_scenario']}' \"\n        f\"with {config['missing_values']:.1f} missing_values.\"\n    )\n\n    results = []\n\n    transformed_testdata = 
transformation.apply(test_dataset, is_train=False)\n    test_splitter = create_splitter(\n        past_length=config[\"context_length\"] + max(model.lags_seq),\n        future_length=config[\"prediction_length\"],\n        mode=\"test\",\n    )\n\n    masking_transform = MaskInput(\n        FieldName.TARGET,\n        FieldName.OBSERVED_VALUES,\n        config[\"context_length\"],\n        config[\"missing_scenario\"],\n        config[\"missing_values\"],\n    )\n    test_transform = test_splitter + masking_transform\n\n    predictor = model.get_predictor(\n        test_transform,\n        batch_size=1280,\n        device=config[\"device\"],\n    )\n    forecast_it, ts_it = make_evaluation_predictions(\n        dataset=transformed_testdata,\n        predictor=predictor,\n        num_samples=num_samples,\n    )\n    forecasts = list(tqdm(forecast_it, total=len(transformed_testdata)))\n    tss = list(ts_it)\n    evaluator = Evaluator()\n    metrics, _ = evaluator(tss, forecasts)\n    metrics = filter_metrics(metrics)\n    results.append(dict(**metrics))\n\n    return results\n\n\ndef main(config, log_dir):\n    # Load parameters\n    dataset_name = config[\"dataset\"]\n    freq = config[\"freq\"]\n    context_length = config[\"context_length\"]\n    prediction_length = config[\"prediction_length\"]\n    total_length = context_length + prediction_length\n\n    # Create model\n    model = create_model(config)\n\n    # Setup dataset and data loading\n    dataset = get_gts_dataset(dataset_name)\n    assert dataset.metadata.freq == freq\n    assert dataset.metadata.prediction_length == prediction_length\n\n    if config[\"setup\"] == \"forecasting\":\n        training_data = dataset.train\n    elif config[\"setup\"] == \"missing_values\":\n        missing_values_splitter = OffsetSplitter(offset=-total_length)\n        training_data, _ = missing_values_splitter.split(dataset.train)\n\n    num_rolling_evals = int(len(dataset.test) / len(dataset.train))\n\n    transformation 
= create_transforms(\n        num_feat_dynamic_real=0,\n        num_feat_static_cat=0,\n        num_feat_static_real=0,\n        time_features=model.time_features,\n        prediction_length=config[\"prediction_length\"],\n    )\n\n    training_splitter = create_splitter(\n        past_length=config[\"context_length\"] + max(model.lags_seq),\n        future_length=config[\"prediction_length\"],\n        mode=\"train\",\n    )\n\n    if config[\"setup\"] == \"forecasting\":\n        config[\"missing_scenario\"] = \"none\"\n        config[\"missing_values\"] = 0\n\n    masking_transform = MaskInput(\n        FieldName.TARGET,\n        FieldName.OBSERVED_VALUES,\n        config[\"context_length\"],\n        config.get(\"train_missing_scenario\", config[\"missing_scenario\"]),\n        config[\"missing_values\"],\n    )\n    train_transform = training_splitter + masking_transform\n\n    callbacks = []\n    val_loader = None\n    if config[\"use_validation_set\"]:\n        transformed_data = transformation.apply(training_data, is_train=True)\n        train_val_splitter = OffsetSplitter(\n            offset=-config[\"prediction_length\"] * num_rolling_evals\n        )\n        _, val_gen = train_val_splitter.split(training_data)\n\n        val_dataset = ConcatDataset(\n            val_gen.generate_instances(\n                config[\"prediction_length\"], num_rolling_evals\n            )\n        )\n        val_splitter = create_splitter(\n            past_length=config[\"context_length\"] + max(model.lags_seq),\n            future_length=config[\"prediction_length\"],\n            mode=\"val\",\n        )\n        transformed_valdata = transformation.apply(val_dataset, is_train=True)\n        val_loader = ValidationDataLoader(\n            transformed_valdata,\n            batch_size=1280,\n            stack_fn=batchify,\n            transform=val_splitter + masking_transform,\n        )\n\n        callbacks = []\n        log_monitor = \"valid_loss\"\n    else:\n        
transformed_data = transformation.apply(training_data, is_train=True)\n        log_monitor = \"train_loss\"\n\n    filename = dataset_name + \"-{epoch:03d}-{train_loss:.3f}\"\n\n    data_loader = TrainDataLoader(\n        Cached(transformed_data),\n        batch_size=config[\"batch_size\"],\n        stack_fn=batchify,\n        transform=train_transform,\n        num_batches_per_epoch=config[\"num_batches_per_epoch\"],\n    )\n\n    checkpoint_callback = ModelCheckpoint(\n        save_top_k=3,\n        monitor=f\"{log_monitor}\",\n        mode=\"min\",\n        filename=filename,\n        save_last=True,\n        save_weights_only=True,\n    )\n\n    callbacks.append(checkpoint_callback)\n    callbacks.append(RichProgressBar())\n\n    trainer = pl.Trainer(\n        accelerator=\"gpu\" if torch.cuda.is_available() else None,\n        devices=[int(config[\"device\"].split(\":\")[-1])],\n        max_epochs=config[\"max_epochs\"],\n        enable_progress_bar=True,\n        num_sanity_val_steps=0,\n        callbacks=callbacks,\n        default_root_dir=log_dir,\n        gradient_clip_val=config.get(\"gradient_clip_val\", None),\n        check_val_every_n_epoch=config[\"eval_every\"],\n    )\n    logger.info(f\"Logging to {trainer.logger.log_dir}\")\n    trainer.fit(\n        model, train_dataloaders=data_loader, val_dataloaders=val_loader\n    )\n    logger.info(\"Training completed.\")\n\n    best_ckpt_path = Path(trainer.logger.log_dir) / \"best_checkpoint.ckpt\"\n\n    if not best_ckpt_path.exists():\n        torch.save(\n            torch.load(checkpoint_callback.best_model_path)[\"state_dict\"],\n            best_ckpt_path,\n        )\n    logger.info(f\"Loading {best_ckpt_path}.\")\n    best_state_dict = torch.load(best_ckpt_path)\n    model.load_state_dict(best_state_dict, strict=True)\n\n    metrics = (\n        evaluate_conditional(config, model, dataset.test, transformation)\n        if config.get(\"do_final_eval\", True)\n        else \"Final eval not 
performed\"\n    )\n    with open(Path(trainer.logger.log_dir) / \"results.yaml\", \"w\") as fp:\n        yaml.dump(\n            {\n                \"config\": config,\n                \"version\": trainer.logger.version,\n                \"metrics\": metrics,\n            },\n            fp,\n        )\n\n\nif __name__ == \"__main__\":\n    # Setup Logger\n    logging.basicConfig(\n        format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\"\n    )\n    logger = logging.getLogger(__file__)\n    logger.setLevel(logging.INFO)\n\n    # Setup argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"-c\", \"--config\", type=str, required=True, help=\"Path to yaml config\"\n    )\n    parser.add_argument(\n        \"--out_dir\", type=str, default=\"./\", help=\"Path to results dir\"\n    )\n    args, _ = parser.parse_known_args()\n\n    with open(args.config, \"r\") as fp:\n        config = yaml.safe_load(fp)\n\n    # Update config from command line\n    parser = add_config_to_argparser(config=config, parser=parser)\n    args = parser.parse_args()\n    config_updates = vars(args)\n    for k in config.keys() & config_updates.keys():\n        orig_val = config[k]\n        updated_val = config_updates[k]\n        if updated_val != orig_val:\n            logger.info(f\"Updated key '{k}': {orig_val} -> {updated_val}\")\n    config.update(config_updates)\n\n    main(config=config, log_dir=args.out_dir)\n"
  },
  {
    "path": "bin/train_model.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport logging\nimport argparse\nfrom pathlib import Path\n\nimport yaml\nimport torch\nfrom tqdm.auto import tqdm\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar\n\nfrom gluonts.dataset.loader import TrainDataLoader\nfrom gluonts.dataset.split import OffsetSplitter\nfrom gluonts.itertools import Cached\nfrom gluonts.torch.batchify import batchify\nfrom gluonts.evaluation import make_evaluation_predictions, Evaluator\nfrom gluonts.dataset.field_names import FieldName\n\nimport uncond_ts_diff.configs as diffusion_configs\nfrom uncond_ts_diff.dataset import get_gts_dataset\nfrom uncond_ts_diff.model.callback import EvaluateCallback\nfrom uncond_ts_diff.model import TSDiff\nfrom uncond_ts_diff.sampler import DDPMGuidance, DDIMGuidance\nfrom uncond_ts_diff.utils import (\n    create_transforms,\n    create_splitter,\n    add_config_to_argparser,\n    filter_metrics,\n    MaskInput,\n)\n\nguidance_map = {\"ddpm\": DDPMGuidance, \"ddim\": DDIMGuidance}\n\n\ndef create_model(config):\n    model = TSDiff(\n        **getattr(diffusion_configs, config[\"diffusion_config\"]),\n        freq=config[\"freq\"],\n        use_features=config[\"use_features\"],\n        use_lags=config[\"use_lags\"],\n        normalization=config[\"normalization\"],\n        context_length=config[\"context_length\"],\n        prediction_length=config[\"prediction_length\"],\n        lr=config[\"lr\"],\n        init_skip=config[\"init_skip\"],\n    )\n    model.to(config[\"device\"])\n    return model\n\n\ndef evaluate_guidance(\n    config, model, test_dataset, transformation, num_samples=100\n):\n    logger.info(f\"Evaluating with {num_samples} samples.\")\n    results = []\n    if config[\"setup\"] == \"forecasting\":\n        missing_data_kwargs_list = [\n            {\n                \"missing_scenario\": 
\"none\",\n                \"missing_values\": 0,\n            }\n        ]\n        config[\"missing_data_configs\"] = missing_data_kwargs_list\n    elif config[\"setup\"] == \"missing_values\":\n        missing_data_kwargs_list = config[\"missing_data_configs\"]\n    else:\n        raise ValueError(f\"Unknown setup {config['setup']}\")\n\n    Guidance = guidance_map[config[\"sampler\"]]\n    sampler_kwargs = config[\"sampler_params\"]\n    for missing_data_kwargs in missing_data_kwargs_list:\n        logger.info(\n            f\"Evaluating scenario '{missing_data_kwargs['missing_scenario']}' \"\n            f\"with {missing_data_kwargs['missing_values']:.1f} missing_values.\"\n        )\n        sampler = Guidance(\n            model=model,\n            prediction_length=config[\"prediction_length\"],\n            num_samples=num_samples,\n            **missing_data_kwargs,\n            **sampler_kwargs,\n        )\n\n        transformed_testdata = transformation.apply(\n            test_dataset, is_train=False\n        )\n        test_splitter = create_splitter(\n            past_length=config[\"context_length\"] + max(model.lags_seq),\n            future_length=config[\"prediction_length\"],\n            mode=\"test\",\n        )\n\n        masking_transform = MaskInput(\n            FieldName.TARGET,\n            FieldName.OBSERVED_VALUES,\n            config[\"context_length\"],\n            missing_data_kwargs[\"missing_scenario\"],\n            missing_data_kwargs[\"missing_values\"],\n        )\n        test_transform = test_splitter + masking_transform\n\n        predictor = sampler.get_predictor(\n            test_transform,\n            batch_size=1280 // num_samples,\n            device=config[\"device\"],\n        )\n        forecast_it, ts_it = make_evaluation_predictions(\n            dataset=transformed_testdata,\n            predictor=predictor,\n            num_samples=num_samples,\n        )\n        forecasts = list(tqdm(forecast_it, 
total=len(transformed_testdata)))\n        tss = list(ts_it)\n        evaluator = Evaluator()\n        metrics, _ = evaluator(tss, forecasts)\n        metrics = filter_metrics(metrics)\n        results.append(dict(**missing_data_kwargs, **metrics))\n\n    return results\n\n\ndef main(config, log_dir):\n    # Load parameters\n    dataset_name = config[\"dataset\"]\n    freq = config[\"freq\"]\n    context_length = config[\"context_length\"]\n    prediction_length = config[\"prediction_length\"]\n    total_length = context_length + prediction_length\n\n    # Create model\n    model = create_model(config)\n\n    # Setup dataset and data loading\n    dataset = get_gts_dataset(dataset_name)\n    assert dataset.metadata.freq == freq\n    assert dataset.metadata.prediction_length == prediction_length\n\n    if config[\"setup\"] == \"forecasting\":\n        training_data = dataset.train\n    elif config[\"setup\"] == \"missing_values\":\n        missing_values_splitter = OffsetSplitter(offset=-total_length)\n        training_data, _ = missing_values_splitter.split(dataset.train)\n\n    num_rolling_evals = int(len(dataset.test) / len(dataset.train))\n\n    transformation = create_transforms(\n        num_feat_dynamic_real=0,\n        num_feat_static_cat=0,\n        num_feat_static_real=0,\n        time_features=model.time_features,\n        prediction_length=config[\"prediction_length\"],\n    )\n\n    training_splitter = create_splitter(\n        past_length=config[\"context_length\"] + max(model.lags_seq),\n        future_length=config[\"prediction_length\"],\n        mode=\"train\",\n    )\n\n    callbacks = []\n    if config[\"use_validation_set\"]:\n        transformed_data = transformation.apply(training_data, is_train=True)\n        train_val_splitter = OffsetSplitter(\n            offset=-config[\"prediction_length\"] * num_rolling_evals\n        )\n        _, val_gen = train_val_splitter.split(training_data)\n        val_data = val_gen.generate_instances(\n         
   config[\"prediction_length\"], num_rolling_evals\n        )\n\n        callbacks = [\n            EvaluateCallback(\n                context_length=config[\"context_length\"],\n                prediction_length=config[\"prediction_length\"],\n                sampler=config[\"sampler\"],\n                sampler_kwargs=config[\"sampler_params\"],\n                num_samples=config[\"num_samples\"],\n                model=model,\n                transformation=transformation,\n                test_dataset=dataset.test,\n                val_dataset=val_data,\n                eval_every=config[\"eval_every\"],\n            )\n        ]\n    else:\n        transformed_data = transformation.apply(training_data, is_train=True)\n\n    log_monitor = \"train_loss\"\n    filename = dataset_name + \"-{epoch:03d}-{train_loss:.3f}\"\n\n    data_loader = TrainDataLoader(\n        Cached(transformed_data),\n        batch_size=config[\"batch_size\"],\n        stack_fn=batchify,\n        transform=training_splitter,\n        num_batches_per_epoch=config[\"num_batches_per_epoch\"],\n    )\n\n    checkpoint_callback = ModelCheckpoint(\n        save_top_k=3,\n        monitor=f\"{log_monitor}\",\n        mode=\"min\",\n        filename=filename,\n        save_last=True,\n        save_weights_only=True,\n    )\n\n    callbacks.append(checkpoint_callback)\n    callbacks.append(RichProgressBar())\n\n    trainer = pl.Trainer(\n        accelerator=\"gpu\" if torch.cuda.is_available() else None,\n        devices=[int(config[\"device\"].split(\":\")[-1])],\n        max_epochs=config[\"max_epochs\"],\n        enable_progress_bar=True,\n        num_sanity_val_steps=0,\n        callbacks=callbacks,\n        default_root_dir=log_dir,\n        gradient_clip_val=config.get(\"gradient_clip_val\", None),\n    )\n    logger.info(f\"Logging to {trainer.logger.log_dir}\")\n    trainer.fit(model, train_dataloaders=data_loader)\n    logger.info(\"Training completed.\")\n\n    best_ckpt_path = 
Path(trainer.logger.log_dir) / \"best_checkpoint.ckpt\"\n\n    if not best_ckpt_path.exists():\n        torch.save(\n            torch.load(checkpoint_callback.best_model_path)[\"state_dict\"],\n            best_ckpt_path,\n        )\n    logger.info(f\"Loading {best_ckpt_path}.\")\n    best_state_dict = torch.load(best_ckpt_path)\n    model.load_state_dict(best_state_dict, strict=True)\n\n    metrics = (\n        evaluate_guidance(config, model, dataset.test, transformation)\n        if config.get(\"do_final_eval\", True)\n        else \"Final eval not performed\"\n    )\n    with open(Path(trainer.logger.log_dir) / \"results.yaml\", \"w\") as fp:\n        yaml.dump(\n            {\n                \"config\": config,\n                \"version\": trainer.logger.version,\n                \"metrics\": metrics,\n            },\n            fp,\n        )\n\n\nif __name__ == \"__main__\":\n    # Setup Logger\n    logging.basicConfig(\n        format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\"\n    )\n    logger = logging.getLogger(__file__)\n    logger.setLevel(logging.INFO)\n\n    # Setup argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"-c\", \"--config\", type=str, required=True, help=\"Path to yaml config\"\n    )\n    parser.add_argument(\n        \"--out_dir\", type=str, default=\"./\", help=\"Path to results dir\"\n    )\n    args, _ = parser.parse_known_args()\n\n    with open(args.config, \"r\") as fp:\n        config = yaml.safe_load(fp)\n\n    # Update config from command line\n    parser = add_config_to_argparser(config=config, parser=parser)\n    args = parser.parse_args()\n    config_updates = vars(args)\n    for k in config.keys() & config_updates.keys():\n        orig_val = config[k]\n        updated_val = config_updates[k]\n        if updated_val != orig_val:\n            logger.info(f\"Updated key '{k}': {orig_val} -> {updated_val}\")\n    config.update(config_updates)\n\n    main(config=config, 
log_dir=args.out_dir)\n"
  },
  {
    "path": "bin/tstr_experiment.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom functools import partial\nimport math\nimport logging\nimport argparse\nfrom pathlib import Path\n\nimport yaml\nimport torch\nimport numpy as np\nfrom tqdm.auto import tqdm\n\nfrom gluonts.mx import DeepAREstimator, TransformerEstimator\nfrom gluonts.evaluation import Evaluator\nfrom gluonts.dataset.loader import TrainDataLoader\nfrom gluonts.itertools import Cached\nfrom gluonts.torch.batchify import batchify\nfrom gluonts.time_feature import (\n    get_lags_for_frequency,\n    time_features_from_frequency_str,\n)\nfrom gluonts.dataset.split import slice_data_entry\nfrom gluonts.transform import AdhocTransform, Chain\n\nfrom uncond_ts_diff.utils import (\n    ScaleAndAddMeanFeature,\n    ScaleAndAddMinMaxFeature,\n    GluonTSNumpyDataset,\n    create_transforms,\n    create_splitter,\n    get_next_file_num,\n    add_config_to_argparser,\n    make_evaluation_predictions_with_scaling,\n    filter_metrics,\n)\nfrom uncond_ts_diff.model import TSDiff, LinearEstimator\nfrom uncond_ts_diff.dataset import get_gts_dataset\nimport uncond_ts_diff.configs as diffusion_configs\n\nDOWNSTREAM_MODELS = [\"linear\", \"deepar\", \"transformer\"]\n\n\ndef load_model(config):\n    model = TSDiff(\n        **getattr(\n            diffusion_configs,\n            config.get(\"diffusion_config\", \"diffusion_small_config\"),\n        ),\n        freq=config[\"freq\"],\n        use_features=config[\"use_features\"],\n        use_lags=config[\"use_lags\"],\n        normalization=\"mean\",\n        context_length=config[\"context_length\"],\n        prediction_length=config[\"prediction_length\"],\n        init_skip=config[\"init_skip\"],\n    )\n    model.load_state_dict(\n        torch.load(config[\"ckpt\"], map_location=\"cpu\"),\n        strict=True,\n    )\n    model = model.to(config[\"device\"])\n    return model\n\n\ndef sample_synthetic(\n    model: 
TSDiff,\n    num_samples: int = 10_000,\n    batch_size: int = 1000,\n):\n    synth_samples = []\n\n    n_iters = math.ceil(num_samples / batch_size)\n    for _ in tqdm(range(n_iters)):\n        samples = model.sample_n(num_samples=batch_size)\n        synth_samples.append(samples)\n\n    synth_samples = np.concatenate(synth_samples, axis=0)[:num_samples]\n\n    return synth_samples\n\n\ndef sample_real(\n    data_loader,\n    n_timesteps: int,\n    num_samples: int = 10_000,\n    batch_size: int = 1000,\n):\n    real_samples = []\n    data_iter = iter(data_loader)\n    n_iters = math.ceil(num_samples / batch_size)\n    for _ in tqdm(range(n_iters)):\n        try:\n            batch = next(data_iter)\n        except StopIteration:\n            data_iter = iter(data_loader)\n            batch = next(data_iter)\n        ts = np.concatenate(\n            [batch[\"past_target\"], batch[\"future_target\"]], axis=-1\n        )[:, -n_timesteps:]\n        real_samples.append(ts)\n\n    real_samples = np.concatenate(real_samples, axis=0)[:num_samples]\n\n    return real_samples\n\n\ndef evaluate_tstr(\n    tstr_predictor,\n    test_dataset,\n    context_length,\n    prediction_length,\n    num_samples=100,\n    scaling_type=\"mean\",\n):\n    total_length = context_length + prediction_length\n    # Slice test set to be of the same length as context_length + prediction_length\n    slice_func = partial(slice_data_entry, slice_=slice(-total_length, None))\n    if scaling_type == \"mean\":\n        ScaleAndAddScaleFeature = ScaleAndAddMeanFeature\n    elif scaling_type == \"min-max\":\n        ScaleAndAddScaleFeature = ScaleAndAddMinMaxFeature\n    transformation = Chain(\n        [\n            AdhocTransform(slice_func),\n            # Add scale to data entry for use later during evaluation\n            ScaleAndAddScaleFeature(\"target\", \"scale\", prediction_length),\n        ]\n    )\n    sliced_test_set = transformation.apply(test_dataset)\n\n    fcst_iter, ts_iter = 
make_evaluation_predictions_with_scaling(\n        dataset=sliced_test_set,\n        predictor=tstr_predictor,\n        num_samples=num_samples,\n        scaling_type=scaling_type,\n    )\n    evaluator = Evaluator()\n    metrics, _ = evaluator(list(ts_iter), list(fcst_iter))\n    return filter_metrics(metrics)\n\n\ndef train_and_evaluate(\n    dataset,\n    model_name,\n    synth_samples,\n    real_samples,\n    config,\n    scaling_type=\"mean\",\n):\n    # NOTE: There's no notion of time for synthetic time series,\n    # they are just \"sequences\".\n    # A dummy timestamp is used for start time in synthetic time series.\n    # Hence, time_features are set to [] in the models below.\n    model_name = model_name.lower()\n    freq = dataset.metadata.freq\n    context_length = config[\"context_length\"]\n    prediction_length = config[\"prediction_length\"]\n    total_length = context_length + prediction_length\n\n    assert len(synth_samples) == len(real_samples)\n    assert (\n        synth_samples.shape[-1] == total_length\n        and real_samples.shape[-1] == total_length\n    )\n    num_samples = len(real_samples)\n\n    synthetic_dataset = GluonTSNumpyDataset(synth_samples)\n\n    if model_name == \"linear\":\n        logger.info(f\"Running TSTR for {model_name}\")\n        tstr_predictor = LinearEstimator(\n            freq=freq,  # Not actually used in the estimator\n            prediction_length=prediction_length,\n            context_length=context_length,\n            num_train_samples=num_samples,\n            # Synthetic dataset is in the \"scaled space\"\n            scaling=False,\n        ).train(synthetic_dataset)\n    elif model_name == \"deepar\":\n        logger.info(f\"Running TSTR for {model_name}\")\n        tstr_predictor = DeepAREstimator(\n            freq=freq,\n            prediction_length=prediction_length,\n            # Synthetic dataset is in the \"scaled space\"\n            scaling=False,\n            time_features=[],\n         
   lags_seq=get_lags_for_frequency(freq, lag_ub=context_length),\n        ).train(synthetic_dataset)\n    elif model_name == \"transformer\":\n        logger.info(f\"Running TSTR for {model_name}\")\n        tstr_predictor = TransformerEstimator(\n            freq=freq,\n            prediction_length=prediction_length,\n            # Synthetic dataset is in the \"scaled space\"\n            scaling=False,\n            time_features=[],\n            lags_seq=get_lags_for_frequency(freq, lag_ub=context_length),\n        ).train(synthetic_dataset)\n\n    tstr_metrics = evaluate_tstr(\n        tstr_predictor=tstr_predictor,\n        test_dataset=dataset.test,\n        context_length=context_length,\n        prediction_length=prediction_length,\n        scaling_type=scaling_type,\n    )\n\n    return dict(\n        tstr_metrics=tstr_metrics,\n    )\n\n\ndef main(config: dict, log_dir: str, samples_path: str):\n    # Read global parameters\n    dataset_name = config[\"dataset\"]\n    context_length = config[\"context_length\"]\n    prediction_length = config[\"prediction_length\"]\n\n    # Create log_dir\n    log_dir: Path = Path(log_dir)\n    base_dirname = \"tstr_log\"\n    run_num = get_next_file_num(\n        base_dirname, log_dir, file_type=\"\", separator=\"-\"\n    )\n    log_dir = log_dir / f\"{base_dirname}-{run_num}\"\n    log_dir.mkdir(exist_ok=True, parents=True)\n    logger.info(f\"Logging to {log_dir}\")\n\n    # Load dataset and model\n    logger.info(\"Loading model\")\n    dataset = get_gts_dataset(dataset_name)\n    config[\"freq\"] = dataset.metadata.freq\n    assert prediction_length == dataset.metadata.prediction_length\n\n    model = load_model(config)\n\n    # Setup data transformation and loading\n    transformation = create_transforms(\n        num_feat_dynamic_real=0,\n        num_feat_static_cat=0,\n        num_feat_static_real=0,\n        time_features=time_features_from_frequency_str(config[\"freq\"]),\n        
prediction_length=prediction_length,\n    )\n    transformed_data = transformation.apply(list(dataset.train), is_train=True)\n    training_splitter = create_splitter(\n        past_length=context_length + max(model.lags_seq),\n        future_length=prediction_length,\n        mode=\"train\",\n    )\n    train_dataloader = TrainDataLoader(\n        Cached(transformed_data),\n        batch_size=1000,\n        stack_fn=batchify,\n        transform=training_splitter,\n    )\n\n    # Generate real samples\n    logger.info(\"Generating real samples\")\n    real_samples = sample_real(\n        train_dataloader,\n        n_timesteps=context_length + prediction_length,\n        num_samples=10000,\n    )\n    np.save(log_dir / \"real_samples.npy\", real_samples)\n\n    if samples_path is None:\n        # Generate synthetic samples\n        logger.info(\"Generating synthetic samples\")\n        synth_samples = sample_synthetic(model, num_samples=10000)\n        np.save(log_dir / \"synth_samples.npy\", synth_samples)\n    else:\n        logger.info(f\"Using synthetic samples from {samples_path}\")\n        synth_samples = np.load(samples_path)[:10000]\n        synth_samples = synth_samples.reshape(\n            (10000, context_length + prediction_length)\n        )\n\n    # Run TSTR experiment for each downstream model\n    results = []\n\n    for model_name in DOWNSTREAM_MODELS:\n        logger.info(f\"Training and evaluating {model_name}\")\n        metrics = train_and_evaluate(\n            dataset=dataset,\n            model_name=model_name,\n            synth_samples=synth_samples,\n            real_samples=real_samples,\n            config=config,\n            scaling_type=config[\"scaling_type\"],\n        )\n        results.append({\"model\": model_name, **metrics})\n\n    logger.info(\"Saving results\")\n    with open(log_dir / \"results.yaml\", \"w\") as fp:\n        yaml.safe_dump(\n            {\"config\": config, \"metrics\": results},\n            fp,\n           
 default_flow_style=False,\n            sort_keys=False,\n        )\n\n\nif __name__ == \"__main__\":\n    # Setup Logger\n    logging.basicConfig(\n        format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\"\n    )\n    logger = logging.getLogger(__file__)\n    logger.setLevel(logging.INFO)\n\n    # Setup argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"-c\", \"--config\", type=str, required=True, help=\"Path to yaml config\"\n    )\n    parser.add_argument(\n        \"--out_dir\", type=str, default=\"./results\", help=\"Path to results dir\"\n    )\n    parser.add_argument(\n        \"--samples_path\", type=str, help=\"Path to generated samples\"\n    )\n    args, _ = parser.parse_known_args()\n\n    with open(args.config, \"r\") as fp:\n        config = yaml.safe_load(fp)\n\n    # Update config from command line\n    parser = add_config_to_argparser(config=config, parser=parser)\n    args = parser.parse_args()\n    config_updates = vars(args)\n    for k in config.keys() & config_updates.keys():\n        orig_val = config[k]\n        updated_val = config_updates[k]\n        if updated_val != orig_val:\n            logger.info(f\"Updated key '{k}': {orig_val} -> {updated_val}\")\n    config.update(config_updates)\n\n    main(config=config, log_dir=args.out_dir, samples_path=args.samples_path)\n"
  },
  {
    "path": "configs/guidance/guidance_electricity.yaml",
    "content": "ckpt: dummy/path.ckpt\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nfreq: H\ninit_skip: false\nnum_samples: 100\nprediction_length: 24\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 4\nsetup: forecasting\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/guidance/guidance_exchange.yaml",
    "content": "ckpt: dummy/path.ckpt\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nfreq: B\ninit_skip: true\nnum_samples: 100\nprediction_length: 30\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 8\nsetup: forecasting\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/guidance/guidance_kdd_cup.yaml",
    "content": "ckpt: dummy/path.ckpt\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nfreq: H\ninit_skip: true\nnum_samples: 100\nprediction_length: 48\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 1\nsetup: forecasting\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/guidance/guidance_m4.yaml",
    "content": "ckpt: dummy/path.ckpt\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nfreq: H\ninit_skip: false\nnum_samples: 100\nprediction_length: 48\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 2\nsetup: forecasting\nuse_features: false\nuse_lags: false\n"
  },
  {
    "path": "configs/guidance/guidance_solar.yaml",
    "content": "ckpt: dummy/path.ckpt\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nfreq: H\ninit_skip: false\nnum_samples: 100\nprediction_length: 24\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 8\nsetup: forecasting\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/guidance/guidance_traffic.yaml",
    "content": "ckpt: dummy/path.ckpt\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nfreq: H\ninit_skip: true\nnum_samples: 100\nprediction_length: 24\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 4\nsetup: forecasting\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/guidance/guidance_uber_tlc.yaml",
    "content": "ckpt: dummy/path.ckpt\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nfreq: H\ninit_skip: false\nnum_samples: 100\nprediction_length: 24\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 2\nsetup: forecasting\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/guidance/guidance_wiki.yaml",
    "content": "ckpt: dummy/path.ckpt\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\nfreq: 1D\ninit_skip: false\nnum_samples: 100\nprediction_length: 30\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 2\nsetup: forecasting\nuse_features: false\nuse_lags: false\n"
  },
  {
    "path": "configs/guidance.yaml",
    "content": "# Model & checkpoint parameters\ndataset: solar_nips\nfreq: H\ndevice: cuda:0\nckpt: ckpts/forecasting/solar_nips/649_.ckpt\ndiffusion_config: diffusion_small_config\ncontext_length: 336\nprediction_length: 24\nuse_lags: true\nuse_features: false\ninit_skip: false\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 4\nnum_samples: 100\nsetup: forecasting\n# The following key will be ignored,\n# if the setup is forecasting\nmissing_data_configs:\n- missing_scenario: BM-B\n  missing_values: 168\n- missing_scenario: BM-E\n  missing_values: 168"
  },
  {
    "path": "configs/refinement/electricity_nips-deepar.yaml",
    "content": "base_model: deepar\nckpt: dummy/electricity_nips.ckpt\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/electricity_nips-linear.yaml",
    "content": "base_model: linear\nckpt: dummy/electricity_nips.ckpt\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/electricity_nips-seasonal_naive.yaml",
    "content": "base_model: seasonal_naive\nckpt: dummy/electricity_nips.ckpt\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/electricity_nips-transformer.yaml",
    "content": "base_model: transformer\nckpt: dummy/electricity_nips.ckpt\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/exchange_rate_nips-deepar.yaml",
    "content": "base_model: deepar\nckpt: dummy/exchange_rate_nips.ckpt\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ninit_skip: true\niterations: 20\nnum_samples: 100\nprediction_length: 30\nrefiner_configs:\n- guidance: MSE\n  lr: 0.01\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.01\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.01\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.01\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/exchange_rate_nips-linear.yaml",
    "content": "base_model: linear\nckpt: dummy/exchange_rate_nips.ckpt\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: true\niterations: 20\nnum_samples: 100\nprediction_length: 30\nrefiner_configs:\n- guidance: MSE\n  lr: 0.01\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.01\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.01\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.01\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/exchange_rate_nips-seasonal_naive.yaml",
    "content": "base_model: seasonal_naive\nckpt: dummy/exchange_rate_nips.ckpt\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ninit_skip: true\niterations: 20\nnum_samples: 100\nprediction_length: 30\nrefiner_configs:\n- guidance: MSE\n  lr: 0.01\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.01\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.01\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.01\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/exchange_rate_nips-transformer.yaml",
    "content": "base_model: transformer\nckpt: dummy/exchange_rate_nips.ckpt\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ninit_skip: true\niterations: 20\nnum_samples: 100\nprediction_length: 30\nrefiner_configs:\n- guidance: MSE\n  lr: 0.01\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.01\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.01\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.01\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/kdd_cup_2018_without_missing-deepar.yaml",
    "content": "base_model: deepar\nbase_model_params: {}\nckpt: dummy/kdd_cup_2018_without_missing.ckpt\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: true\niterations: 20\nnum_samples: 100\nprediction_length: 48\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/kdd_cup_2018_without_missing-linear.yaml",
    "content": "base_model: linear\nbase_model_params: {}\nckpt: dummy/kdd_cup_2018_without_missing.ckpt\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: true\niterations: 20\nnum_samples: 100\nprediction_length: 48\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/kdd_cup_2018_without_missing-seasonal_naive.yaml",
    "content": "base_model: seasonal_naive\nbase_model_params: {}\nckpt: dummy/kdd_cup_2018_without_missing.ckpt\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: true\niterations: 20\nnum_samples: 100\nprediction_length: 48\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/kdd_cup_2018_without_missing-transformer.yaml",
    "content": "base_model: transformer\nbase_model_params: {}\nckpt: dummy/kdd_cup_2018_without_missing.ckpt\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: true\niterations: 20\nnum_samples: 100\nprediction_length: 48\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/m4_hourly-deepar.yaml",
    "content": "base_model: deepar\nckpt: dummy/m4_hourly.ckpt\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 48\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: false\n"
  },
  {
    "path": "configs/refinement/m4_hourly-linear.yaml",
    "content": "base_model: linear\nckpt: dummy/m4_hourly.ckpt\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 48\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: false\n"
  },
  {
    "path": "configs/refinement/m4_hourly-seasonal_naive.yaml",
    "content": "base_model: seasonal_naive\nckpt: dummy/m4_hourly.ckpt\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 48\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: false\n"
  },
  {
    "path": "configs/refinement/m4_hourly-transformer.yaml",
    "content": "base_model: transformer\nckpt: dummy/m4_hourly.ckpt\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 48\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: false\n"
  },
  {
    "path": "configs/refinement/solar_nips-deepar.yaml",
    "content": "base_model: deepar\nckpt: dummy/solar_nips.ckpt\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/solar_nips-linear.yaml",
    "content": "base_model: linear\nckpt: dummy/solar_nips.ckpt\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/solar_nips-seasonal_naive.yaml",
    "content": "base_model: seasonal_naive\nckpt: dummy/solar_nips.ckpt\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/solar_nips-transformer.yaml",
    "content": "base_model: transformer\nckpt: dummy/solar_nips.ckpt\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/traffic_nips-deepar.yaml",
    "content": "base_model: deepar\nckpt: dummy/traffic_nips.ckpt\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ninit_skip: true\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/traffic_nips-linear.yaml",
    "content": "base_model: linear\nckpt: dummy/traffic_nips.ckpt\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: true\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/traffic_nips-seasonal_naive.yaml",
    "content": "base_model: seasonal_naive\nckpt: dummy/traffic_nips.ckpt\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ninit_skip: true\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/traffic_nips-transformer.yaml",
    "content": "base_model: transformer\nckpt: dummy/traffic_nips.ckpt\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ninit_skip: true\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/uber_tlc_hourly-deepar.yaml",
    "content": "base_model: deepar\nckpt: dummy/uber_tlc_hourly.ckpt\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/uber_tlc_hourly-linear.yaml",
    "content": "base_model: linear\nckpt: dummy/uber_tlc_hourly.ckpt\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/uber_tlc_hourly-seasonal_naive.yaml",
    "content": "base_model: seasonal_naive\nckpt: dummy/uber_tlc_hourly.ckpt\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/uber_tlc_hourly-transformer.yaml",
    "content": "base_model: transformer\nckpt: dummy/uber_tlc_hourly.ckpt\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 24\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/refinement/wiki2000_nips-deepar.yaml",
    "content": "base_model: deepar\nckpt: dummy/wiki2000_nips.ckpt\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 30\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: false\n"
  },
  {
    "path": "configs/refinement/wiki2000_nips-linear.yaml",
    "content": "base_model: linear\nckpt: dummy/wiki2000_nips.ckpt\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 30\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: false\n"
  },
  {
    "path": "configs/refinement/wiki2000_nips-seasonal_naive.yaml",
    "content": "base_model: seasonal_naive\nckpt: dummy/wiki2000_nips.ckpt\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 30\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: false\n"
  },
  {
    "path": "configs/refinement/wiki2000_nips-transformer.yaml",
    "content": "base_model: transformer\nckpt: dummy/wiki2000_nips.ckpt\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\niterations: 20\nnum_samples: 100\nprediction_length: 30\nrefiner_configs:\n- guidance: MSE\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: quantile\n  lr: 0.1\n  refiner_name: most_likely\n- guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\n- guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n  refiner_name: mcmc\n  step_size: 0.1\nuse_features: false\nuse_lags: false\n"
  },
  {
    "path": "configs/refinement.yaml",
    "content": "# Model & checkpoint parameters\ndataset: solar_nips\ndevice: cuda:0\nckpt: ckpts/forecasting/solar_nips/649_.ckpt\ndiffusion_config: diffusion_small_config\ncontext_length: 336\nprediction_length: 24\nuse_lags: true\nuse_features: false\ninit_skip: false\n# Refinement parameters\nbase_model: linear\nbase_model_params: {}\nnum_samples: 16\niterations: 10\nrefiner_configs:\n- refiner_name: most_likely\n  lr: 1.e-1\n  guidance: MSE\n- refiner_name: most_likely\n  lr: 1.e-1\n  guidance: quantile\n- refiner_name: mcmc\n  step_size: 1.e-1\n  guidance: MSE\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1\n- refiner_name: mcmc\n  step_size: 1.e-1\n  guidance: quantile\n  method: lmc\n  method_kwargs:\n    noise_scale: 0.1"
  },
  {
    "path": "configs/train_tsdiff/train_electricity.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: electricity_nips\nfreq: H\ncontext_length: 336 # 360 for `D`\nprediction_length: 24 # 30 for `D`\nlr: 1.e-3\ninit_skip: False\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 16\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 4\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: forecasting"
  },
  {
    "path": "configs/train_tsdiff/train_exchange.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: exchange_rate_nips\nfreq: B\ncontext_length: 360 # 360 for `D`\nprediction_length: 30 # 30 for `D`\nlr: 1.e-3\ninit_skip: True\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 16\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 8\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: forecasting"
  },
  {
    "path": "configs/train_tsdiff/train_kdd_cup.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: kdd_cup_2018_without_missing\nfreq: H\ncontext_length: 312 # 360 for `D`\nprediction_length: 48 # 30 for `D`\nlr: 1.e-3\ninit_skip: True\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 16\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 1\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: forecasting"
  },
  {
    "path": "configs/train_tsdiff/train_m4.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: False\ndataset: m4_hourly\nfreq: H\ncontext_length: 312 # 360 for `D`\nprediction_length: 48 # 30 for `D`\nlr: 1.e-3\ninit_skip: False\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 16\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 2\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: forecasting"
  },
  {
    "path": "configs/train_tsdiff/train_missing_electricity.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: electricity_nips\nfreq: H\ncontext_length: 336 # 360 for `D`\nprediction_length: 24 # 30 for `D`\nlr: 1.e-3\ninit_skip: False\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 16\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 4\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: missing_values\n# The following key will be ignored,\n# if the setup is forecasting\nmissing_data_configs:\n- missing_scenario: none\n  missing_values: 0\n- missing_scenario: BM-E\n  missing_values: 168\n- missing_scenario: BM-B\n  missing_values: 168\n- missing_scenario: RM\n  missing_values: 168"
  },
  {
    "path": "configs/train_tsdiff/train_missing_exchange.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: exchange_rate_nips\nfreq: B\ncontext_length: 360 # 360 for `D`\nprediction_length: 30 # 30 for `D`\nlr: 1.e-3\ninit_skip: True\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 16\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 8\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: missing_values\n# The following key will be ignored,\n# if the setup is forecasting\nmissing_data_configs:\n- missing_scenario: none\n  missing_values: 0\n- missing_scenario: BM-E\n  missing_values: 180\n- missing_scenario: BM-B\n  missing_values: 180\n- missing_scenario: RM\n  missing_values: 180"
  },
  {
    "path": "configs/train_tsdiff/train_missing_kdd_cup.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: kdd_cup_2018_without_missing\nfreq: H\ncontext_length: 312 # 360 for `D`\nprediction_length: 48 # 30 for `D`\nlr: 1.e-3\ninit_skip: True\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 16\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 1\nguidance: quantile\nuse_validation_set: True\ndo_final_eval: True\neval_every: 50\ndevice: cuda:0\nsetup: missing_values\n# The following key will be ignored,\n# if the setup is forecasting\nmissing_data_configs:\n- missing_scenario: none\n  missing_values: 0\n- missing_scenario: BM-E\n  missing_values: 156\n- missing_scenario: BM-B\n  missing_values: 156\n- missing_scenario: RM\n  missing_values: 156"
  },
  {
    "path": "configs/train_tsdiff/train_missing_solar.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: solar_nips\nfreq: H\ncontext_length: 336 # 360 for `D`\nprediction_length: 24 # 30 for `D`\nlr: 1.e-3\ninit_skip: False\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 16\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 8\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: missing_values\n# The following key will be ignored,\n# if the setup is forecasting\nmissing_data_configs:\n- missing_scenario: none\n  missing_values: 0\n- missing_scenario: BM-E\n  missing_values: 168\n- missing_scenario: BM-B\n  missing_values: 168\n- missing_scenario: RM\n  missing_values: 168"
  },
  {
    "path": "configs/train_tsdiff/train_missing_traffic.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: traffic_nips\nfreq: H\ncontext_length: 336 # 360 for `D`\nprediction_length: 24 # 30 for `D`\nlr: 1.e-3\ninit_skip: True\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 4\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 4\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: missing_values\n# The following key will be ignored,\n# if the setup is forecasting\nmissing_data_configs:\n- missing_scenario: none\n  missing_values: 0\n- missing_scenario: BM-E\n  missing_values: 168\n- missing_scenario: BM-B\n  missing_values: 168\n- missing_scenario: RM\n  missing_values: 168"
  },
  {
    "path": "configs/train_tsdiff/train_missing_uber_tlc.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: uber_tlc_hourly\nfreq: H\ncontext_length: 336 # 360 for `D`\nprediction_length: 24 # 30 for `D`\nlr: 1.e-3\ninit_skip: False\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 16\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 2\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: missing_values\n# The following key will be ignored,\n# if the setup is forecasting\nmissing_data_configs:\n- missing_scenario: none\n  missing_values: 0\n- missing_scenario: BM-E\n  missing_values: 168\n- missing_scenario: BM-B\n  missing_values: 168\n- missing_scenario: RM\n  missing_values: 168"
  },
  {
    "path": "configs/train_tsdiff/train_solar.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: solar_nips\nfreq: H\ncontext_length: 336 # 360 for `D`\nprediction_length: 24 # 30 for `D`\nlr: 1.e-3\ninit_skip: False\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 16\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 8\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: forecasting"
  },
  {
    "path": "configs/train_tsdiff/train_traffic.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: traffic_nips\nfreq: H\ncontext_length: 336 # 360 for `D`\nprediction_length: 24 # 30 for `D`\nlr: 1.e-3\ninit_skip: True\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 4\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 4\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: forecasting"
  },
  {
    "path": "configs/train_tsdiff/train_uber_tlc.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: uber_tlc_hourly\nfreq: H\ncontext_length: 336 # 360 for `D`\nprediction_length: 24 # 30 for `D`\nlr: 1.e-3\ninit_skip: False\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 16\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 2\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: forecasting"
  },
  {
    "path": "configs/train_tsdiff/train_wiki.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: False\ndataset: wiki2000_nips\nfreq: 1D\ncontext_length: 360 # 360 for `D`\nprediction_length: 30 # 30 for `D`\nlr: 1.e-3\ninit_skip: False\ngradient_clip_val: 0.5\nmax_epochs: 1000\nnum_batches_per_epoch: 128\nbatch_size: 64\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 4\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 2\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: forecasting"
  },
  {
    "path": "configs/train_tsdiff-cond/electricity_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: forecasting\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/exchange_rate_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: B\ngradient_clip_val: 0.5\ninit_skip: true\nlr: 0.001\nmax_epochs: 100\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 30\nsetup: forecasting\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/kdd_cup_2018_without_missing.yaml",
    "content": "batch_size: 64\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: true\nlr: 0.001\nmax_epochs: 100\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 48\nsetup: forecasting\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/m4_hourly.yaml",
    "content": "batch_size: 64\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 48\nsetup: forecasting\nuse_features: false\nuse_lags: false\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-B_electricity_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: BM-B\nmissing_values: 168\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: missing_values\ntrain_missing_scenario: BM-B\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-B_exchange_rate_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: B\ngradient_clip_val: 0.5\ninit_skip: true\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: BM-B\nmissing_values: 180\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 30\nsetup: missing_values\ntrain_missing_scenario: BM-B\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-B_kdd_cup_2018_without_missing.yaml",
    "content": "batch_size: 64\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: true\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: BM-B\nmissing_values: 156\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 48\nsetup: missing_values\ntrain_missing_scenario: BM-B\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-B_solar_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: BM-B\nmissing_values: 168\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: missing_values\ntrain_missing_scenario: BM-B\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-B_traffic_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: true\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: BM-B\nmissing_values: 168\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: missing_values\ntrain_missing_scenario: BM-B\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-B_uber_tlc_hourly.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: BM-B\nmissing_values: 168\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: missing_values\ntrain_missing_scenario: BM-B\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-E_electricity_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: BM-E\nmissing_values: 168\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: missing_values\ntrain_missing_scenario: BM-E\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-E_exchange_rate_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: B\ngradient_clip_val: 0.5\ninit_skip: true\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: BM-E\nmissing_values: 180\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 30\nsetup: missing_values\ntrain_missing_scenario: BM-E\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-E_kdd_cup_2018_without_missing.yaml",
    "content": "batch_size: 64\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: true\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: BM-E\nmissing_values: 156\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 48\nsetup: missing_values\ntrain_missing_scenario: BM-E\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-E_solar_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: BM-E\nmissing_values: 168\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: missing_values\ntrain_missing_scenario: BM-E\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-E_traffic_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: true\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: BM-E\nmissing_values: 168\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: missing_values\ntrain_missing_scenario: BM-E\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_BM-E_uber_tlc_hourly.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: BM-E\nmissing_values: 168\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: missing_values\ntrain_missing_scenario: BM-E\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_RM_electricity_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: RM\nmissing_values: 168\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: missing_values\ntrain_missing_scenario: RM\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_RM_exchange_rate_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: B\ngradient_clip_val: 0.5\ninit_skip: true\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: RM\nmissing_values: 180\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 30\nsetup: missing_values\ntrain_missing_scenario: RM\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_RM_kdd_cup_2018_without_missing.yaml",
    "content": "batch_size: 64\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: true\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: RM\nmissing_values: 156\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 48\nsetup: missing_values\ntrain_missing_scenario: RM\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_RM_solar_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: RM\nmissing_values: 168\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: missing_values\ntrain_missing_scenario: RM\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_RM_traffic_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: true\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: RM\nmissing_values: 168\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: missing_values\ntrain_missing_scenario: RM\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/missing_RM_uber_tlc_hourly.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmissing_scenario: RM\nmissing_values: 168\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: missing_values\ntrain_missing_scenario: RM\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/solar_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: forecasting\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/traffic_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: true\nlr: 0.001\nmax_epochs: 100\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: forecasting\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/uber_tlc_hourly.yaml",
    "content": "batch_size: 64\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: H\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 24\nsetup: forecasting\nuse_features: false\nuse_lags: true\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond/wiki2000_nips.yaml",
    "content": "batch_size: 64\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ndo_final_eval: true\neval_every: 10\nfreq: 1D\ngradient_clip_val: 0.5\ninit_skip: false\nlr: 0.001\nmax_epochs: 100\nmodel: conditional\nnoise_observed: true\nnormalization: mean\nnum_batches_per_epoch: 128\nprediction_length: 30\nsetup: forecasting\nuse_features: false\nuse_lags: false\nuse_validation_set: true\n"
  },
  {
    "path": "configs/train_tsdiff-cond.yaml",
    "content": "model: conditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: True\ndataset: solar_nips\nfreq: H\ncontext_length: 336 # 360 for `D`\nprediction_length: 24 # 30 for `D`\nlr: 1.e-3\ninit_skip: False\ngradient_clip_val: 0.5\nmax_epochs: 100\nnum_batches_per_epoch: 128\nbatch_size: 64\nuse_validation_set: True\neval_every: 10\ndevice: cuda:0\nnoise_observed: True\ndo_final_eval: True\nsetup: missing_values\n# The following keys will be ignored, if the setup is forecasting\ntrain_missing_scenario: BM-E\nmissing_scenario: BM-E\nmissing_values: 168\n"
  },
  {
    "path": "configs/train_tsdiff.yaml",
    "content": "model: unconditional\ndiffusion_config: diffusion_small_config\nnormalization: mean\nuse_features: False\nuse_lags: False\ndataset: solar_nips\nfreq: H\ncontext_length: 336 # 360 for `D`\nprediction_length: 24 # 30 for `D`\nlr: 1.e-3\ninit_skip: True\ngradient_clip_val: 0.5\nmax_epochs: 100\nnum_batches_per_epoch: 128\nbatch_size: 64\nscale: 4\n# Used only in callback,\n# the final evaluation uses 100 samples\nnum_samples: 16\nsampler: ddpm\nsampler_params:\n  guidance: quantile\n  scale: 4\nuse_validation_set: True\neval_every: 50\ndevice: cuda:0\nsetup: forecasting\ndo_final_eval: True\n# The following key will be ignored,\n# if the setup is forecasting\nmissing_data_configs:\n- missing_scenario: BM-B\n  missing_values: 168\n- missing_scenario: BM-E\n  missing_values: 168\n"
  },
  {
    "path": "configs/tstr/electricity_nips.yaml",
    "content": "ckpt: dummy/electricity_nips.ckpt\ncontext_length: 336\ndataset: electricity_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\nprediction_length: 24\nscaling_type: mean\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/tstr/exchange_rate_nips.yaml",
    "content": "ckpt: dummy/exchange_rate_nips.ckpt\ncontext_length: 360\ndataset: exchange_rate_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: true\nprediction_length: 30\nscaling_type: mean\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/tstr/kdd_cup_2018_without_missing.yaml",
    "content": "ckpt: dummy/kdd_cup_2018_without_missing.ckpt\ncontext_length: 312\ndataset: kdd_cup_2018_without_missing\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: true\nprediction_length: 48\nscaling_type: mean\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/tstr/m4_hourly.yaml",
    "content": "ckpt: dummy/m4_hourly.ckpt\ncontext_length: 312\ndataset: m4_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\nprediction_length: 48\nscaling_type: mean\nuse_features: false\nuse_lags: false\n"
  },
  {
    "path": "configs/tstr/solar_nips.yaml",
    "content": "ckpt: dummy/solar_nips.ckpt\ncontext_length: 336\ndataset: solar_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\nprediction_length: 24\nscaling_type: mean\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/tstr/traffic_nips.yaml",
    "content": "ckpt: dummy/traffic_nips.ckpt\ncontext_length: 336\ndataset: traffic_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: true\nprediction_length: 24\nscaling_type: mean\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/tstr/uber_tlc_hourly.yaml",
    "content": "ckpt: dummy/uber_tlc_hourly.ckpt\ncontext_length: 336\ndataset: uber_tlc_hourly\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\nprediction_length: 24\nscaling_type: mean\nuse_features: false\nuse_lags: true\n"
  },
  {
    "path": "configs/tstr/wiki2000_nips.yaml",
    "content": "ckpt: dummy/wiki2000_nips.ckpt\ncontext_length: 360\ndataset: wiki2000_nips\ndevice: cuda:0\ndiffusion_config: diffusion_small_config\ninit_skip: false\nprediction_length: 30\nscaling_type: mean\nuse_features: false\nuse_lags: false\n"
  },
  {
    "path": "configs/tstr.yaml",
    "content": "# Model & checkpoint parameters\ndataset: solar_nips\ndevice: cuda:0\nckpt: ckpts/solar_nips/version_236/1299_.ckpt\ndiffusion_config: diffusion_small_config\ncontext_length: 336\nprediction_length: 24\nuse_lags: true\nuse_features: false\ninit_skip: true\nscaling_type: mean"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"uncond-ts-diff\"\nversion = \"0.1.0\"\ndescription = \"TSDiff: An Unconditional Diffusion Model for Time Series\"\nauthors = []\ndependencies = [\n    \"torch~=1.13.1\",\n    \"pytorch-lightning~=1.9.4\",\n    \"gluonts[mxnet,pro]~=0.12.3\",\n    \"matplotlib\",\n    \"seaborn\",\n    \"opt_einsum~=3.3.0\",\n    \"einops\",\n    \"black\",\n    \"tqdm\",\n    \"scipy\",\n    \"scikit-learn\",\n    \"numba\",\n    \"jupyter\",\n    \"rich\",\n    \"pykeops==2.1.1\",\n]\nreadme = \"README.md\"\nrequires-python = \">= 3.8\"\n\n[tool.black]\nline-length = 79\n"
  },
  {
    "path": "src/uncond_ts_diff/arch/__init__.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom .backbones import BackboneModel\n\n__all__ = [\"BackboneModel\"]\n"
  },
  {
    "path": "src/uncond_ts_diff/arch/backbones.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport math\n\nimport torch\nfrom torch import nn\n\nfrom .s4 import S4\n\n\nclass SinusoidalPositionEmbeddings(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.dim = dim\n\n    def forward(self, time):\n        device = time.device\n        half_dim = self.dim // 2\n        embeddings = math.log(10000) / (half_dim - 1)\n        embeddings = torch.exp(\n            torch.arange(half_dim, device=device) * -embeddings\n        )\n        embeddings = time[:, None] * embeddings[None, :]\n        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)\n        return embeddings\n\n\nclass S4Layer(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        dropout=0.0,\n    ):\n        super().__init__()\n        self.layer = S4(\n            d_model=d_model,\n            d_state=128,\n            bidirectional=True,\n            dropout=dropout,\n            transposed=True,\n            postact=None,\n        )\n        self.norm = nn.LayerNorm(d_model)\n        self.dropout = (\n            nn.Dropout1d(dropout) if dropout > 0.0 else nn.Identity()\n        )\n\n    def forward(self, x):\n        \"\"\"\n        Input x is shape (B, d_input, L)\n        \"\"\"\n        z = x\n        # Prenorm\n        z = self.norm(z.transpose(-1, -2)).transpose(-1, -2)\n        # Apply layer: we ignore the state input and output for training\n        z, _ = self.layer(z)\n        # Dropout on the output of the layer\n        z = self.dropout(z)\n        # Residual connection\n        x = z + x\n        return x, None\n\n    def default_state(self, *args, **kwargs):\n        return self.layer.default_state(*args, **kwargs)\n\n    def step(self, x, state, **kwargs):\n        z = x\n        # Prenorm\n        z = self.norm(z.transpose(-1, -2)).transpose(-1, -2)\n        # Apply layer\n        z, state = 
self.layer.step(z, state, **kwargs)\n        # Residual connection\n        x = z + x\n        return x, state\n\n\nclass S4Block(nn.Module):\n    def __init__(self, d_model, dropout=0.0, expand=2, num_features=0):\n        super().__init__()\n        self.s4block = S4Layer(d_model, dropout=dropout)\n\n        self.time_linear = nn.Linear(d_model, d_model)\n        self.tanh = nn.Tanh()\n        self.sigm = nn.Sigmoid()\n        self.out_linear1 = nn.Conv1d(\n            in_channels=d_model, out_channels=d_model, kernel_size=1\n        )\n        self.out_linear2 = nn.Conv1d(\n            in_channels=d_model, out_channels=d_model, kernel_size=1\n        )\n        self.feature_encoder = nn.Conv1d(num_features, d_model, kernel_size=1)\n\n    def forward(self, x, t, features=None):\n        t = self.time_linear(t)[:, None, :].repeat(1, x.shape[2], 1)\n        t = t.transpose(-1, -2)\n        out, _ = self.s4block(x + t)\n        if features is not None:\n            out = out + self.feature_encoder(features)\n        out = self.tanh(out) * self.sigm(out)\n        out1 = self.out_linear1(out)\n        out2 = self.out_linear2(out)\n        return out1 + x, out2\n\n\ndef Conv1dKaiming(in_channels, out_channels, kernel_size):\n    layer = nn.Conv1d(in_channels, out_channels, kernel_size)\n    nn.init.kaiming_normal_(layer.weight)\n    return layer\n\n\nclass BackboneModel(nn.Module):\n    def __init__(\n        self,\n        input_dim,\n        hidden_dim,\n        output_dim,\n        step_emb,\n        num_residual_blocks,\n        num_features,\n        residual_block=\"s4\",\n        dropout=0.0,\n        init_skip=True,\n    ):\n        super().__init__()\n        if residual_block == \"s4\":\n            residual_block = S4Block\n        else:\n            raise ValueError(f\"Unknown residual block {residual_block}\")\n        self.input_init = nn.Sequential(\n            nn.Linear(input_dim, hidden_dim),\n            nn.ReLU(),\n        )\n        self.time_init 
 = nn.Sequential(\n            nn.Linear(step_emb, hidden_dim),\n            nn.SiLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.SiLU(),\n        )\n        self.out_linear = nn.Sequential(\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, output_dim),\n        )\n        residual_blocks = []\n        for i in range(num_residual_blocks):\n            residual_blocks.append(\n                residual_block(\n                    hidden_dim, num_features=num_features, dropout=dropout\n                )\n            )\n        self.residual_blocks = nn.ModuleList(residual_blocks)\n        self.step_embedding = SinusoidalPositionEmbeddings(step_emb)\n        self.init_skip = init_skip\n\n    def forward(self, input, t, features=None):\n        x = self.input_init(input)  # B, L, C\n        t = self.time_init(self.step_embedding(t))\n        x = x.transpose(-1, -2)\n        if features is not None:\n            features = features.transpose(-1, -2)\n        skips = []\n        for layer in self.residual_blocks:\n            x, skip = layer(x, t, features)\n            skips.append(skip)\n\n        skip = torch.stack(skips).sum(0)\n        skip = skip.transpose(-1, -2)\n        out = self.out_linear(skip)\n        if self.init_skip:\n            out = out + input\n        return out\n"
  },
  {
    "path": "src/uncond_ts_diff/arch/s4.py",
    "content": "\"\"\"Standalone version of Structured (Sequence) State Space (S4) model.\"\"\"\n\nimport logging\nfrom functools import partial\nimport math\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom pytorch_lightning.utilities import rank_zero_only\nfrom einops import rearrange, repeat\nimport opt_einsum as oe\n\ncontract = oe.contract\ncontract_expression = oe.contract_expression\n\n\ndef get_logger(name=__name__, level=logging.INFO) -> logging.Logger:\n    \"\"\"Initializes multi-GPU-friendly python logger.\"\"\"\n\n    logger = logging.getLogger(name)\n    logger.setLevel(level)\n\n    # this ensures all logging levels get marked with the rank zero decorator\n    # otherwise logs would get multiplied for each GPU process in multi-GPU setup\n    for level in (\n        \"debug\",\n        \"info\",\n        \"warning\",\n        \"error\",\n        \"exception\",\n        \"fatal\",\n        \"critical\",\n    ):\n        setattr(logger, level, rank_zero_only(getattr(logger, level)))\n\n    return logger\n\n\nlog = get_logger(__name__)\n\n\"\"\" Cauchy and Vandermonde kernels \"\"\"\n\ntry:  # Try CUDA extension\n    from extensions.cauchy.cauchy import cauchy_mult\n\n    has_cauchy_extension = True\nexcept ImportError:\n    # log.warning(\n    #     \"CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. 
This should speed up end-to-end training by 10-50%\"\n    # )\n    has_cauchy_extension = False\n\ntry:  # Try pykeops\n    from pykeops.torch import Genred\n\n    has_pykeops = True\n    log.info(\"Pykeops installation found.\")\n\n    def _broadcast_dims(*tensors):\n        max_dim = max([len(tensor.shape) for tensor in tensors])\n        tensors = [\n            tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape)\n            for tensor in tensors\n        ]\n        return tensors\n\n    def cauchy_conj(v, z, w):\n        \"\"\"Pykeops version\"\"\"\n        expr_num = \"z * ComplexReal(v) - Real2Complex(Sum(v * w))\"\n        expr_denom = \"ComplexMult(z-w, z-Conj(w))\"\n\n        cauchy_mult = Genred(\n            f\"ComplexDivide({expr_num}, {expr_denom})\",\n            [\n                \"v = Vj(2)\",\n                \"z = Vi(2)\",\n                \"w = Vj(2)\",\n            ],\n            reduction_op=\"Sum\",\n            axis=1,\n        )\n\n        v, z, w = _broadcast_dims(v, z, w)\n        v = _c2r(v)\n        z = _c2r(z)\n        w = _c2r(w)\n\n        r = 2 * cauchy_mult(v, z, w, backend=\"GPU\")\n        return _r2c(r)\n\n    def log_vandermonde(v, x, L):\n        expr = \"ComplexMult(v, ComplexExp(ComplexMult(x, l)))\"\n        vandermonde_mult = Genred(\n            expr,\n            [\n                \"v = Vj(2)\",\n                \"x = Vj(2)\",\n                \"l = Vi(2)\",\n            ],\n            reduction_op=\"Sum\",\n            axis=1,\n        )\n\n        l = torch.arange(L).to(x)\n        v, x, l = _broadcast_dims(v, x, l)\n        v = _c2r(v)\n        x = _c2r(x)\n        l = _c2r(l)\n\n        r = vandermonde_mult(v, x, l, backend=\"GPU\")\n        return 2 * _r2c(r).real\n\n    def log_vandermonde_transpose(u, v, x, L):\n        \"\"\"\n        u: ... H L\n        v: ... H N\n        x: ... H N\n        Returns: ... 
H N\n\n        V = Vandermonde(a, L) : (H N L)\n        contract_L(V * u * v)\n        \"\"\"\n        expr = \"ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))\"\n        vandermonde_mult = Genred(\n            expr,\n            [\n                \"u = Vj(2)\",\n                \"v = Vi(2)\",\n                \"x = Vi(2)\",\n                \"l = Vj(2)\",\n            ],\n            reduction_op=\"Sum\",\n            axis=1,\n        )\n\n        l = torch.arange(L).to(x)\n        u, v, x, l = _broadcast_dims(u, v, x, l)\n        u = _c2r(u)\n        v = _c2r(v)\n        x = _c2r(x)\n        l = _c2r(l)\n\n        r = vandermonde_mult(u, v, x, l, backend=\"GPU\")\n        return _r2c(r)\n\nexcept ImportError:\n    has_pykeops = False\n    if not has_cauchy_extension:\n        log.warning(\n            \"Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency.\"\n        )\n\n        def cauchy_naive(v, z, w):\n            \"\"\"\n            v, w: (..., N)\n            z: (..., L)\n            returns: (..., L)\n            \"\"\"\n            cauchy_matrix = v.unsqueeze(-1) / (\n                z.unsqueeze(-2) - w.unsqueeze(-1)\n            )  # (... N L)\n            return torch.sum(cauchy_matrix, dim=-2)\n\n    # Vandermonde functions\n    log.warning(\n        \"Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency.\"\n    )\n\n    def log_vandermonde(v, x, L):\n        \"\"\"\n        v: (..., N)\n        x: (..., N)\n        returns: (..., L) \\sum v x^l\n        \"\"\"\n        vandermonde_matrix = torch.exp(\n            x.unsqueeze(-1) * torch.arange(L).to(x)\n        )  # (... N L)\n        vandermonde_prod = contract(\n            \"... n, ... n l -> ... l\", v, vandermonde_matrix\n        )  # (... 
L)\n        return 2 * vandermonde_prod.real\n\n    def log_vandermonde_transpose(u, v, x, L):\n        vandermonde_matrix = torch.exp(\n            x.unsqueeze(-1) * torch.arange(L).to(x)\n        )  # (... N L)\n        vandermonde_prod = contract(\n            \"... l, ... n, ... n l -> ... n\",\n            u.to(x),\n            v.to(x),\n            vandermonde_matrix,\n        )  # (... L)\n        return vandermonde_prod\n\n\ndef _conj(x):\n    return torch.cat([x, x.conj()], dim=-1)\n\n\n_c2r = torch.view_as_real\n_r2c = torch.view_as_complex\nif tuple(map(int, torch.__version__.split(\".\")[:2])) >= (1, 10):\n\n    def _resolve_conj(x):\n        return x.conj().resolve_conj()\n\nelse:\n\n    def _resolve_conj(x):\n        return x.conj()\n\n\n\"\"\" Simple nn.Module components \"\"\"\n\n\ndef Activation(activation=None, dim=-1):\n    if activation in [None, \"id\", \"identity\", \"linear\"]:\n        return nn.Identity()\n    elif activation == \"tanh\":\n        return nn.Tanh()\n    elif activation == \"relu\":\n        return nn.ReLU()\n    elif activation == \"gelu\":\n        return nn.GELU()\n    elif activation in [\"swish\", \"silu\"]:\n        return nn.SiLU()\n    elif activation == \"glu\":\n        return nn.GLU(dim=dim)\n    elif activation == \"sigmoid\":\n        return nn.Sigmoid()\n    else:\n        raise NotImplementedError(\n            \"hidden activation '{}' is not implemented\".format(activation)\n        )\n\n\ndef LinearActivation(\n    d_input,\n    d_output,\n    bias=True,\n    transposed=False,\n    activation=None,\n    activate=False,  # Apply activation as part of this module\n    **kwargs,\n):\n    \"\"\"Returns a linear nn.Module with control over axes order, initialization, and activation\"\"\"\n\n    # Construct core module\n    linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear\n    if activation == \"glu\":\n        d_output *= 2\n    linear = linear_cls(d_input, d_output, bias=bias, 
**kwargs)\n\n    if activate and activation is not None:\n        activation = Activation(activation, dim=-2 if transposed else -1)\n        linear = nn.Sequential(linear, activation)\n    return linear\n\n\nclass DropoutNd(nn.Module):\n    def __init__(self, p: float = 0.5, tie=True, transposed=True):\n        \"\"\"\n        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)\n        \"\"\"\n        super().__init__()\n        if p < 0 or p >= 1:\n            raise ValueError(\n                \"dropout probability has to be in [0, 1), \"\n                \"but got {}\".format(p)\n            )\n        self.p = p\n        self.tie = tie\n        self.transposed = transposed\n        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)\n\n    def forward(self, X):\n        \"\"\"X: (batch, dim, lengths...)\"\"\"\n        if self.training:\n            if not self.transposed:\n                X = rearrange(X, \"b d ... -> b ... d\")\n            mask_shape = (\n                X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape\n            )\n            mask = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p\n            X = X * mask * (1.0 / (1 - self.p))\n            if not self.transposed:\n                X = rearrange(X, \"b ... 
d -> b d ...\")\n            return X\n        return X\n\n\n\"\"\" Misc functional utilities \"\"\"\n\n\ndef power(L, A, v=None):\n    \"\"\"Compute A^L and the scan sum_i A^i v_i\n\n    A: (..., N, N)\n    v: (..., N, L)\n    \"\"\"\n\n    I = torch.eye(A.shape[-1]).to(A)  # , dtype=A.dtype, device=A.device)\n\n    powers = [A]\n    l = 1\n    while True:\n        if L % 2 == 1:\n            I = powers[-1] @ I\n        L //= 2\n        if L == 0:\n            break\n        l *= 2\n        powers.append(powers[-1] @ powers[-1])\n\n    if v is None:\n        return I\n\n    # Invariants:\n    # powers[-1] := A^l\n    # l := largest po2 at most L\n\n    # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A\n    # We do this reverse divide-and-conquer for efficiency reasons:\n    # 1) it involves fewer padding steps for non-po2 L\n    # 2) it involves more contiguous arrays\n\n    # Take care of edge case for non-po2 arrays\n    # Note that this initial step is a no-op for the case of power of 2 (l == L)\n    k = v.size(-1) - l\n    v_ = powers.pop() @ v[..., l:]\n    v = v[..., :l]\n    v[..., :k] = v[..., :k] + v_\n\n    # Handle reduction for power of 2\n    while v.size(-1) > 1:\n        v = rearrange(v, \"... (z l) -> ... 
z l\", z=2)\n        v = v[..., 0, :] + powers.pop() @ v[..., 1, :]\n    return I, v.squeeze(-1)\n\n\n\"\"\" HiPPO utilities \"\"\"\n\n\ndef transition(measure, N):\n    \"\"\"A, B transition matrices for different measures\"\"\"\n    # Legendre (translated)\n    if measure == \"legt\":\n        Q = np.arange(N, dtype=np.float64)\n        R = (2 * Q + 1) ** 0.5\n        j, i = np.meshgrid(Q, Q)\n        A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :]\n        B = R[:, None]\n        A = -A\n\n        # Halve again for timescale correctness\n        A *= 0.5\n        B *= 0.5\n    # Legendre (scaled)\n    elif measure == \"legs\":\n        q = np.arange(N, dtype=np.float64)\n        col, row = np.meshgrid(q, q)\n        r = 2 * q + 1\n        M = -(np.where(row >= col, r, 0) - np.diag(q))\n        T = np.sqrt(np.diag(2 * q + 1))\n        A = T @ M @ np.linalg.inv(T)\n        B = np.diag(T)[:, None]\n        B = (\n            B.copy()\n        )  # Otherwise \"UserWarning: given NumPY array is not writeable...\" after torch.as_tensor(B)\n    elif measure == \"legsd\":\n        # Essentially equivalent to S4D-LegS\n        q = np.arange(N, dtype=np.float64)\n        col, row = np.meshgrid(q, q)\n        r = 2 * q + 1\n        M = -(np.where(row >= col, r, 0) - np.diag(q))\n        T = np.sqrt(np.diag(2 * q + 1))\n        A = T @ M @ np.linalg.inv(T)\n        B = np.diag(T)[:, None]\n        B = (\n            B.copy()\n        )  # Otherwise \"UserWarning: given NumPY array is not writeable...\" after torch.as_tensor(B)\n        A += 0.5 * B * B[None, :, 0]\n        B = B / 2.0\n    elif measure in [\"fourier_diag\", \"foud\"]:\n        # Essentially equivalent to S4D-Lin\n        freqs = np.arange(N // 2)\n        d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1]\n        A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1))\n        A = A - 0.5 * np.eye(N)\n        B = np.zeros(N)\n        B[0::2] = 2**0.5\n        B[0] = 1\n        
B = B[:, None]\n    elif measure in [\"fourier\", \"fout\"]:\n        freqs = np.arange(N // 2)\n        d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]\n        A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))\n        B = np.zeros(N)\n        B[0::2] = 2**0.5\n        B[0] = 1\n\n        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case\n        A = A - B[:, None] * B[None, :]\n        B = B[:, None]\n    else:\n        raise NotImplementedError\n\n    return A, B\n\n\ndef rank_correction(measure, N, rank=1, dtype=torch.float):\n    \"\"\"Return low-rank matrix L such that A + L is normal\"\"\"\n\n    if measure == \"legs\":\n        assert rank >= 1\n        P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(\n            0\n        )  # (1 N)\n    elif measure == \"legt\":\n        assert rank >= 2\n        P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype))  # (N)\n        P0 = P.clone()\n        P0[0::2] = 0.0\n        P1 = P.clone()\n        P1[1::2] = 0.0\n        P = torch.stack([P0, P1], dim=0)  # (2 N)\n        P *= 2 ** (\n            -0.5\n        )  # Halve the rank correct just like the original matrix was halved\n    elif measure in [\"fourier\", \"fout\"]:\n        P = torch.zeros(N)\n        P[0::2] = 2**0.5\n        P[0] = 1\n        P = P.unsqueeze(0)\n    elif measure in [\"fourier_diag\", \"foud\", \"legsd\"]:\n        P = torch.zeros(1, N, dtype=dtype)\n    else:\n        raise NotImplementedError\n\n    d = P.size(0)\n    if rank > d:\n        P = torch.cat(\n            [P, torch.zeros(rank - d, N, dtype=dtype)], dim=0\n        )  # (rank N)\n    return P\n\n\ndef nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True):\n    \"\"\"Return w, p, q, V, B such that\n    (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V\n    i.e. 
A = V[w - p q^*]V^*, B = V B\n    \"\"\"\n    assert dtype == torch.float or dtype == torch.double\n    cdtype = torch.cfloat if dtype == torch.float else torch.cdouble\n\n    A, B = transition(measure, N)\n    A = torch.as_tensor(A, dtype=dtype)  # (N, N)\n    B = torch.as_tensor(B, dtype=dtype)[:, 0]  # (N,)\n\n    P = rank_correction(measure, N, rank=rank, dtype=dtype)  # (r N)\n    AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3)\n\n    # We require AP to be nearly skew-symmetric\n    _A = AP + AP.transpose(-1, -2)\n    if (\n        err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N\n    ) > 1e-5:  # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5):\n        print(\"WARNING: HiPPO matrix not skew symmetric\", err)\n\n    # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately\n    # Imaginary part can use eigh instead of eig\n    w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True)\n\n    # Diagonalize in double precision\n    if diagonalize_precision:\n        AP = AP.to(torch.double)\n    w_im, V = torch.linalg.eigh(AP * -1j)  # (..., N) (..., N, N)\n    if diagonalize_precision:\n        w_im, V = w_im.to(cdtype), V.to(cdtype)\n    w = w_re + 1j * w_im\n    # Check: V w V^{-1} = A\n    # print(\"check\", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))\n\n    # Only keep half of each conjugate pair\n    _, idx = torch.sort(w.imag)\n    w_sorted = w[idx]\n    V_sorted = V[:, idx]\n\n    # There is an edge case when eigenvalues can be 0, which requires some machinery to handle\n    # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case)\n    V = V_sorted[:, : N // 2]\n    w = w_sorted[: N // 2]\n    assert (\n        w[-2].abs() > 1e-4\n    ), \"Only 1 zero eigenvalue allowed in diagonal part of A\"\n    if w[-1].abs() < 1e-4:\n        V[:, -1] = 0.0\n        V[0, -1] = 2**-0.5\n        
V[1, -1] = 2**-0.5 * 1j\n\n    _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)\n    if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5:\n        print(\n            \"Warning: Diagonalization of A matrix not numerically precise - error\",\n            err,\n        )\n    # print(\"check\", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))\n\n    V_inv = V.conj().transpose(-1, -2)\n\n    B = contract(\"ij, j -> i\", V_inv, B.to(V))  # V^* B\n    P = contract(\"ij, ...j -> ...i\", V_inv, P.to(V))  # V^* P\n\n    return w, P, B, V\n\n\ndef dplr(\n    scaling,\n    N,\n    rank=1,\n    H=1,\n    dtype=torch.float,\n    real_scale=1.0,\n    imag_scale=1.0,\n    random_real=False,\n    random_imag=False,\n    normalize=False,\n    diagonal=True,\n    random_B=False,\n):\n    assert dtype == torch.float or dtype == torch.double\n    dtype = torch.cfloat if dtype == torch.float else torch.cdouble\n\n    pi = torch.tensor(math.pi)\n    if random_real:\n        real_part = torch.rand(H, N // 2)\n    else:\n        real_part = 0.5 * torch.ones(H, N // 2)\n    if random_imag:\n        imag_part = N // 2 * torch.rand(H, N // 2)\n    else:\n        imag_part = repeat(torch.arange(N // 2), \"n -> h n\", h=H)\n\n    real_part = real_scale * real_part\n    if scaling == \"random\":\n        imag_part = torch.randn(H, N // 2)\n    elif scaling == \"real\":\n        imag_part = 0 * imag_part\n        real_part = 1 + repeat(torch.arange(N // 2), \"n -> h n\", h=H)\n    elif scaling in [\"linear\", \"lin\"]:\n        imag_part = pi * imag_part\n    elif scaling in [\n        \"inverse\",\n        \"inv\",\n    ]:  # Based on asymptotics of the default HiPPO matrix\n        imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1)\n    elif scaling in [\"inverse2\", \"inv2\"]:\n        imag_part = 1 / pi * N * (N / (1 + imag_part) - 1)\n    elif scaling in [\"quadratic\", \"quad\"]:\n        imag_part = 1 / pi * (1 + 2 * imag_part) ** 2\n    elif scaling in 
[\"legs\", \"hippo\"]:\n        w, _, _, _ = nplr(\"legsd\", N)\n        imag_part = w.imag\n\n    else:\n        raise NotImplementedError\n    imag_part = imag_scale * imag_part\n    w = -real_part + 1j * imag_part\n\n    # Initialize B\n    if random_B:\n        B = torch.randn(H, N // 2, dtype=dtype)\n    else:\n        B = torch.ones(H, N // 2, dtype=dtype)\n\n    if normalize:\n        norm = (\n            -B / w\n        )  # (H, N) # Result if you integrate the kernel with constant 1 function\n        zeta = 2 * torch.sum(\n            torch.abs(norm) ** 2, dim=-1, keepdim=True\n        )  # Variance with a random C vector\n        B = B / zeta**0.5\n\n    P = torch.randn(rank, H, N // 2, dtype=dtype)\n    if diagonal:\n        P = P * 0.0\n    V = torch.eye(N, dtype=dtype)[:, : N // 2]  # Only used in testing\n    V = repeat(V, \"n m -> h n m\", h=H)\n\n    return w, P, B, V\n\n\ndef ssm(measure, N, R, H, **ssm_args):\n    \"\"\"Dispatcher to create a single SSM initialization\n\n    N: state size\n    R: rank (for DPLR parameterization)\n    H: number of independent SSM copies\n    \"\"\"\n\n    if measure == \"dplr\":\n        w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args)\n    elif measure.startswith(\"diag\"):\n        args = measure.split(\"-\")\n        assert args[0] == \"diag\" and len(args) > 1\n        scaling = args[1]\n        w, P, B, V = dplr(\n            scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args\n        )\n    else:\n        w, P, B, V = nplr(measure, N, R, **ssm_args)\n        w = repeat(w, \"n -> s n\", s=H)\n        P = repeat(P, \"r n -> r s n\", s=H)\n        B = repeat(B, \"n -> s n\", s=H)\n        V = repeat(V, \"n m -> s n m\", s=H)\n    return w, P, B, V\n\n\ncombinations = {\n    \"hippo\": [\"legs\", \"fourier\"],\n    \"diag\": [\"diag-inv\", \"diag-lin\"],\n    \"all\": [\"legs\", \"fourier\", \"diag-inv\", \"diag-lin\"],\n}\n\n\ndef combination(measures, N, R, S, **ssm_args):\n    if isinstance(measures, 
str):\n        measures = (\n            combinations[measures] if measures in combinations else [measures]\n        )\n\n    assert (\n        S % len(measures) == 0\n    ), f\"{S} independent trainable SSM copies must be a multiple of {len(measures)} different measures\"\n    w, P, B, V = zip(\n        *[\n            ssm(measure, N, R, S // len(measures), **ssm_args)\n            for measure in measures\n        ]\n    )\n    w = torch.cat(w, dim=0)  # (S N)\n    P = torch.cat(P, dim=1)  # (R S N)\n    B = torch.cat(B, dim=0)  # (S N)\n    V = torch.cat(V, dim=0)  # (S N N)\n    return w, P, B, V\n\n\nclass OptimModule(nn.Module):\n    \"\"\"Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters\"\"\"\n\n    def register(self, name, tensor, lr=None):\n        \"\"\"Register a tensor with a configurable learning rate and 0 weight decay\"\"\"\n\n        if lr == 0.0:\n            self.register_buffer(name, tensor)\n        else:\n            self.register_parameter(name, nn.Parameter(tensor))\n\n            optim = {\"weight_decay\": 0.0}\n            if lr is not None:\n                optim[\"lr\"] = lr\n            setattr(getattr(self, name), \"_optim\", optim)\n\n\nclass SSKernelNPLR(OptimModule):\n    \"\"\"Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR)\"\"\"\n\n    @torch.no_grad()\n    def _setup_C(self, L):\n        \"\"\"Construct C~ from C\n\n        Two modes are supported: go directly to length L if self.L is 0, or length is doubled\n        \"\"\"\n\n        if self.L.item() == 0:\n            if self.verbose:\n                log.info(f\"S4: Initializing kernel to length {L}\")\n            double_length = False\n        elif L > self.L.item():  # 2*int(self.L) == L:\n            if self.verbose:\n                log.info(\n                    f\"S4: Doubling length from L = 
{self.L.item()} to {2*self.L.item()}\"\n                )\n            double_length = True\n            L = self.L.item()  # Convenience for the math below\n        else:\n            return\n\n        C = _r2c(self.C)\n        dA, _ = self._setup_state()\n        dA_L = power(L, dA)\n        # Multiply C by I - dA_L\n        C_ = _conj(C)\n        prod = contract(\"h m n, c h n -> c h m\", dA_L.transpose(-1, -2), C_)\n        if double_length:\n            prod = -prod  # Multiply by I + dA_L instead\n        C_ = C_ - prod\n        C_ = C_[..., : self.N]  # Take conjugate pairs again\n        self.C.copy_(_c2r(C_))\n\n        self.L = (\n            2 * self.L if double_length else self.L + L\n        )  # Preserve type/device\n\n    def _omega(self, L, dtype, device, cache=True):\n        \"\"\"Calculate (and cache) FFT nodes and their \"unprocessed\" version with the bilinear transform.\n        This should be called every time the internal length self.L changes\"\"\"\n\n        # Use cached if available\n        if (\n            cache\n            and hasattr(self, \"omega\")\n            and self.omega.size(-1) == L // 2 + 1\n        ):\n            return self.omega, self.z\n\n        omega = torch.tensor(\n            np.exp(-2j * np.pi / (L)), dtype=dtype, device=device\n        )  # \\omega_{2L}\n        omega = omega ** torch.arange(0, L // 2 + 1, device=device)\n        z = 2 * (1 - omega) / (1 + omega)\n\n        # Cache if necessary\n        if cache:\n            self.omega = omega\n            self.z = z\n        return omega, z\n\n    def __init__(\n        self,\n        w,\n        P,\n        B,\n        C,\n        log_dt,\n        L=None,  # starting/maximum length of kernel\n        lr=None,\n        verbose=False,\n        keops=False,\n        real_type=\"exp\",  # ['none' | 'exp' | 'relu' | 'sigmoid' | 'softplus']\n        real_tolerance=1e-3,\n        bandlimit=None,\n    ):\n        \"\"\"\n        L: Maximum length; this module computes an SSM kernel 
of length L\n        A is represented by diag(w) - PP^*\n        w: (S, N) diagonal part\n        P: (R, S, N) low-rank part\n\n        B: (S, N)\n        C: (C, H, N)\n        dt: (H) timescale per feature\n        lr: [dict | float | None] hook to set lr of special parameters (A, B, dt)\n\n        Dimensions:\n        N (or d_state): state size\n        H (or d_model): total SSM copies\n        S (or n_ssm): number of trainable copies of (A, B, dt); must divide H\n        R (or rank): rank of low-rank part\n        C (or channels): system is 1-dim to C-dim\n\n        The forward pass of this Module returns a tensor of shape (C, H, L)\n\n        Note: tensor shape N here denotes half the true state size, because of conjugate symmetry\n        \"\"\"\n\n        super().__init__()\n        self.verbose = verbose\n        self.keops = keops\n        self.bandlimit = bandlimit\n        self.real_type = real_type\n        self.real_tolerance = real_tolerance\n\n        # Rank of low-rank correction\n        self.rank = P.shape[-3]\n        assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1)\n        self.H = log_dt.size(-1)\n        self.N = w.size(-1)\n\n        # Check different SSM inits\n        assert w.size(-2) == P.size(-2) == B.size(-2)  # n_ssm\n        assert self.H % w.size(0) == 0\n        self.n_ssm = w.size(0)\n        self.repeat = self.H // w.size(\n            0\n        )  # Each trainable SSM needs to be duplicated this many times\n\n        # Broadcast everything to correct shapes\n        C = C.expand(\n            torch.broadcast_shapes(C.shape, (1, self.H, self.N))\n        )  # (C, H, N)\n        B = B.unsqueeze(0)  # (1, 1, N)\n\n        # Register parameters\n        self.C = nn.Parameter(_c2r(_resolve_conj(C)))\n        if lr is None or isinstance(lr, float):\n            lr_dict = {}\n        else:\n            lr_dict, lr = lr, None\n        self.register(\"log_dt\", log_dt, lr_dict.get(\"dt\", lr))\n        self.register(\"B\", 
_c2r(B), lr_dict.get(\"B\", lr))\n        self.register(\"P\", _c2r(P), lr_dict.get(\"A\", lr))\n        self.register(\"inv_w_real\", self._w_init(w.real), lr_dict.get(\"A\", lr))\n        self.register(\"w_imag\", w.imag, lr_dict.get(\"A\", lr))\n\n        self.l_max = L\n        self.register_buffer(\"L\", torch.tensor(0))  # Internal length\n\n    def _w_init(self, w_real):\n        w_real = torch.clamp(w_real, max=-self.real_tolerance)\n        if self.real_type == \"none\":\n            return -w_real\n        elif self.real_type == \"exp\":\n            return torch.log(\n                -w_real\n            )  # Some of the HiPPO methods have real part 0\n        elif self.real_type == \"relu\":\n            return -w_real\n        elif self.real_type == \"sigmoid\":\n            return torch.logit(-w_real)\n        elif self.real_type == \"softplus\":\n            return torch.log(torch.exp(-w_real) - 1)\n        else:\n            raise NotImplementedError\n\n    def _w(self):\n        # Get the internal w (diagonal) parameter\n        if self.real_type == \"none\":\n            w_real = -self.inv_w_real\n        elif self.real_type == \"exp\":\n            w_real = -torch.exp(self.inv_w_real)\n        elif self.real_type == \"relu\":\n            w_real = -F.relu(self.inv_w_real)\n        elif self.real_type == \"sigmoid\":\n            w_real = -F.sigmoid(self.inv_w_real)\n        elif self.real_type == \"softplus\":\n            w_real = -F.softplus(self.inv_w_real)\n        else:\n            raise NotImplementedError\n        w = w_real + 1j * self.w_imag\n        return w\n\n    def forward(self, state=None, rate=1.0, L=None):\n        \"\"\"\n        state: (B, H, N) initial state\n        rate: sampling rate factor\n        L: target length\n\n        returns:\n        (C, H, L) convolution kernel (generally C=1)\n        (B, H, L) output from initial state\n        \"\"\"\n\n        # Initialize C~ if necessary (done in forward pass so it's on 
the correct device)\n        if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:\n            self._setup_C(self.l_max)\n\n        # Handle sampling rate logic\n        # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate\n        if L is None:\n            L = round(self.L.item() / rate)\n\n        # Increase the internal length if needed\n        continuous_L = round(rate * L)\n        while continuous_L > self.L.item():\n            self._setup_C(continuous_L)\n        discrete_L = round(self.L.item() / rate)\n\n        dt = torch.exp(self.log_dt) * rate\n        B = _r2c(self.B)\n        C = _r2c(self.C)\n        P = _r2c(self.P)\n        Q = P.conj()\n        w = self._w()  # (n_ssm, N)\n\n        # Address bandlimiting\n        if self.bandlimit is not None:\n            freqs = w.imag.abs() / (2 * math.pi)  # (H, N)\n            freqs = dt[:, None] / rate * freqs  # (H, N)\n            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)\n            C = C * mask\n\n        # Get FFT nodes of right length\n        omega, z = self._omega(\n            discrete_L, dtype=w.dtype, device=w.device, cache=(rate == 1.0)\n        )\n\n        # Broadcast parameters to same hidden features H\n        B = repeat(B, \"1 t n -> 1 (v t) n\", v=self.repeat)\n        P = repeat(P, \"r t n -> r (v t) n\", v=self.repeat)\n        Q = repeat(Q, \"r t n -> r (v t) n\", v=self.repeat)\n        w = repeat(w, \"t n -> (v t) n\", v=self.repeat)\n\n        # Augment B\n        if state is not None:\n            # Have to \"unbilinear\" the state to put it into the same \"type\" as B\n            # Compute 1/dt * (I + dt/2 A) @ state\n\n            # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way\n            s = _conj(state) if state.size(-1) == self.N else state  # (B H N)\n            sA = s * 
_conj(w) - contract(  # (B H N)\n                \"bhm, rhm, rhn -> bhn\", s, _conj(Q), _conj(P)\n            )\n            s = s / dt.unsqueeze(-1) + sA / 2\n            s = s[..., : self.N]\n\n            B = torch.cat([s, B], dim=-3)  # (B+1, H, N)\n\n        # Incorporate dt into A\n        w = w * dt.unsqueeze(-1)  # (H N)\n\n        # Stack B and p, C and q for convenient batching\n        B = torch.cat([B, P], dim=-3)  # (B+1+R, H, N)\n        C = torch.cat([C, Q], dim=-3)  # (C+R, H, N)\n\n        # Incorporate B and C batch dimensions\n        v = B.unsqueeze(-3) * C.unsqueeze(-4)  # (B+1+R, C+R, H, N)\n\n        # Calculate resolvent at omega\n        if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops:\n            r = cauchy_mult(v, z, w, symmetric=True)\n        elif has_pykeops:\n            r = cauchy_conj(v, z, w)\n        else:\n            r = cauchy_naive(v, z, w)\n        r = r * dt[None, None, :, None]  # (B+1+R, C+R, H, L)\n\n        # Low-rank Woodbury correction\n        if self.rank == 1:\n            k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (\n                1 + r[-1:, -1:, :, :]\n            )\n        elif self.rank == 2:\n            r00 = r[: -self.rank, : -self.rank, :, :]\n            r01 = r[: -self.rank, -self.rank :, :, :]\n            r10 = r[-self.rank :, : -self.rank, :, :]\n            r11 = r[-self.rank :, -self.rank :, :, :]\n            det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[\n                :1, 1:, :, :\n            ] * r11[1:, :1, :, :]\n            s = (\n                r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :]\n                + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :]\n                - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :]\n                - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :]\n            )\n            s = s / det\n            k_f = r00 - s\n        else:\n         
   r00 = r[: -self.rank, : -self.rank, :, :]\n            r01 = r[: -self.rank, -self.rank :, :, :]\n            r10 = r[-self.rank :, : -self.rank, :, :]\n            r11 = r[-self.rank :, -self.rank :, :, :]\n            r11 = rearrange(r11, \"a b h n -> h n a b\")\n            r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11)\n            r11 = rearrange(r11, \"h n a b -> a b h n\")\n            k_f = r00 - torch.einsum(\n                \"i j h n, j k h n, k l h n -> i l h n\", r01, r11, r10\n            )\n\n        # Final correction for the bilinear transform\n        k_f = k_f * 2 / (1 + omega)\n\n        # Move from frequency to coefficients\n        k = torch.fft.irfft(k_f, n=discrete_L)  # (B+1, C, H, L)\n\n        # # Truncate to target length\n        k = k[..., :L]\n\n        if state is not None:\n            k_state = k[:-1, :, :, :]  # (B, C, H, L)\n        else:\n            k_state = None\n        k_B = k[-1, :, :, :]  # (C H L)\n\n        return k_B, k_state\n\n    @torch.no_grad()\n    def _setup_linear(self):\n        \"\"\"Create parameters that allow fast linear stepping of state\"\"\"\n        w = self._w()\n        B = _r2c(self.B)  # (H N)\n        P = _r2c(self.P)\n        Q = P.conj()\n\n        # Repeat w shape properly\n        B = repeat(B, \"1 t n -> 1 (v t) n\", v=self.repeat)\n        P = repeat(P, \"r t n -> r (v t) n\", v=self.repeat)\n        Q = repeat(Q, \"r t n -> r (v t) n\", v=self.repeat)\n        w = repeat(w, \"t n -> (v t) n\", v=self.repeat)\n\n        # Prepare Linear stepping\n        dt = torch.exp(self.log_dt)\n        D = (2.0 / dt.unsqueeze(-1) - w).reciprocal()  # (H, N)\n        R = (\n            torch.eye(self.rank, dtype=w.dtype, device=w.device)\n            + 2 * contract(\"r h n, h n, s h n -> h r s\", Q, D, P).real\n        )  # (H R R)\n        Q_D = rearrange(Q * D, \"r h n -> h r n\")\n        try:\n            R = torch.linalg.solve(R, Q_D)  # (H R N)\n        except Exception:\n   
         R = torch.tensor(\n                np.linalg.solve(\n                    R.to(Q_D).contiguous().detach().cpu(),\n                    Q_D.contiguous().detach().cpu(),\n                )\n            ).to(Q_D)\n        R = rearrange(R, \"h r n -> r h n\")\n\n        self.step_params = {\n            \"D\": D,  # (H N)\n            \"R\": R,  # (R H N)\n            \"P\": P,  # (R H N)\n            \"Q\": Q,  # (R H N)\n            \"B\": B,  # (1 H N)\n            \"E\": 2.0 / dt.unsqueeze(-1) + w,  # (H N)\n        }\n\n    def _step_state_linear(self, u=None, state=None):\n        \"\"\"\n        Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization.\n\n        Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster\n\n        u: (H) input\n        state: (H, N/2) state with conjugate pairs\n          Optionally, the state can have last dimension N\n        Returns: same shape as state\n        \"\"\"\n        C = _r2c(self.C)  # View used for dtype/device\n\n        if u is None:  # Special case used to find dA\n            u = torch.zeros(self.H, dtype=C.dtype, device=C.device)\n        if state is None:  # Special case used to find dB\n            state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device)\n\n        step_params = self.step_params.copy()\n        if (\n            state.size(-1) == self.N\n        ):  # Only store half of the conjugate pairs; should be true by default\n            # There should be a slightly faster way using conjugate symmetry\n            def contract_fn(p, x, y):\n                return contract(\n                    \"r h n, r h m, ... h m -> ... 
h n\",\n                    _conj(p),\n                    _conj(x),\n                    _conj(y),\n                )[\n                    ..., : self.N\n                ]  # inner outer product\n\n        else:\n            assert state.size(-1) == 2 * self.N\n            step_params = {k: _conj(v) for k, v in step_params.items()}\n\n            # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping\n            def contract_fn(p, x, y):\n                return contract(\n                    \"r h n, r h m, ... h m -> ... h n\", p, x, y\n                )  # inner outer product\n\n        D = step_params[\"D\"]  # (H N)\n        E = step_params[\"E\"]  # (H N)\n        R = step_params[\"R\"]  # (R H N)\n        P = step_params[\"P\"]  # (R H N)\n        Q = step_params[\"Q\"]  # (R H N)\n        B = step_params[\"B\"]  # (1 H N)\n\n        new_state = E * state - contract_fn(P, Q, state)  # (B H N)\n        new_state = new_state + 2.0 * B * u.unsqueeze(-1)  # (B H N)\n        new_state = D * (new_state - contract_fn(P, R, new_state))\n\n        return new_state\n\n    def _setup_state(self):\n        \"\"\"Construct dA and dB for discretized state equation\"\"\"\n\n        # Construct dA and dB by using the stepping\n        self._setup_linear()\n        C = _r2c(\n            self.C\n        )  # Just returns a view that we use for finding dtype/device\n\n        state = torch.eye(\n            2 * self.N, dtype=C.dtype, device=C.device\n        ).unsqueeze(\n            -2\n        )  # (N 1 N)\n        dA = self._step_state_linear(state=state)\n        dA = rearrange(dA, \"n h m -> h m n\")\n\n        u = C.new_ones(self.H)\n        dB = self._step_state_linear(u=u)\n        dB = _conj(dB)\n        dB = rearrange(dB, \"1 h n -> h n\")  # (H N)\n        return dA, dB\n\n    def _step_state(self, u, state):\n        \"\"\"Must be called after self.default_state() is used to construct an initial 
state!\"\"\"\n        next_state = self.state_contraction(\n            self.dA, state\n        ) + self.input_contraction(self.dB, u)\n        return next_state\n\n    def _setup_step(self, mode=\"dense\"):\n        \"\"\"Set up dA, dB, dC discretized parameters for stepping\"\"\"\n        self.dA, self.dB = self._setup_state()\n\n        # Calculate original C\n        C = _conj(_r2c(self.C))  # (H C N)\n        if self.L.item() == 0:\n            dC = C\n        else:\n            # self.C represents C_tilde\n            dA_L = power(self.L.item(), self.dA)\n            I = torch.eye(self.dA.size(-1)).to(dA_L)\n\n            dC = torch.linalg.solve(\n                I - dA_L.transpose(-1, -2),\n                C.unsqueeze(-1),\n            ).squeeze(-1)\n        self.dC = dC\n\n        # Do special preprocessing for different step modes\n\n        self._step_mode = mode\n        if mode == \"linear\":\n            # Linear case: special step function for the state, we need to handle output\n            # use conjugate symmetry by default, which affects the output projection\n            self.dC = 2 * self.dC[:, :, : self.N]\n        elif mode == \"diagonal\":\n            # Eigendecomposition of the A matrix\n            L, V = torch.linalg.eig(self.dA)\n            V_inv = torch.linalg.inv(V)\n            # Check that the eigendecomposition is correct\n            if self.verbose:\n                print(\n                    \"Diagonalization error:\",\n                    torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA),\n                )\n\n            # Change the parameterization to diagonalize\n            self.dA = L\n            self.dB = contract(\"h n m, h m -> h n\", V_inv, self.dB)\n            self.dC = contract(\"h n m, c h n -> c h m\", V, self.dC)\n\n        elif mode == \"dense\":\n            pass\n        else:\n            raise NotImplementedError(\n                \"NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}\"\n  
          )\n\n    def default_state(self, *batch_shape):\n        C = _r2c(self.C)\n        N = C.size(-1)\n        H = C.size(-2)\n\n        # Cache the tensor contractions we will later do, for efficiency\n        # These are put in this function because they depend on the batch size\n        step_mode = getattr(\n            self, \"_step_mode\", \"dense\"\n        )  # Used in default_state, which is called without _setup_step() in forward_state()\n        if step_mode != \"linear\":\n            N *= 2\n\n            if step_mode == \"diagonal\":\n                self.state_contraction = contract_expression(\n                    \"h n, ... h n -> ... h n\",\n                    (H, N),\n                    batch_shape + (H, N),\n                )\n            else:\n                # Dense (quadratic) case: expand all terms\n                self.state_contraction = contract_expression(\n                    \"h m n, ... h n -> ... h m\",\n                    (H, N, N),\n                    batch_shape + (H, N),\n                )\n\n            self.input_contraction = contract_expression(\n                \"h n, ... h -> ... h n\",\n                (H, N),  # self.dB.shape\n                batch_shape + (H,),\n            )\n\n        self.output_contraction = contract_expression(\n            \"c h n, ... h n -> ... 
c h\",\n            (C.shape[0], H, N),  # self.dC.shape\n            batch_shape + (H, N),\n        )\n\n        state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device)\n        return state\n\n    def step(self, u, state):\n        \"\"\"Must have called self._setup_step() and created state with self.default_state() before calling this\"\"\"\n\n        if self._step_mode == \"linear\":\n            new_state = self._step_state_linear(u, state)\n        else:\n            new_state = self._step_state(u, state)\n        y = self.output_contraction(self.dC, new_state)\n        return y.real, new_state\n\n\nclass SSKernelDiag(OptimModule):\n    \"\"\"Version using (complex) diagonal state matrix (S4D)\"\"\"\n\n    def __init__(\n        self,\n        A,\n        B,\n        C,\n        log_dt,\n        L=None,\n        disc=\"bilinear\",\n        real_type=\"exp\",\n        lr=None,\n        bandlimit=None,\n    ):\n        super().__init__()\n        self.L = L\n        self.disc = disc\n        self.bandlimit = bandlimit\n        self.real_type = real_type\n\n        # Rank of low-rank correction\n        assert A.size(-1) == C.size(-1)\n        self.H = log_dt.size(-1)\n        self.N = A.size(-1)\n        assert A.size(-2) == B.size(-2)  # Number of independent SSMs trained\n        assert self.H % A.size(-2) == 0\n        self.n_ssm = A.size(-2)\n        self.repeat = self.H // A.size(0)\n\n        self.channels = C.shape[0]\n        self.C = nn.Parameter(_c2r(_resolve_conj(C)))\n\n        # Register parameters\n        if lr is None or isinstance(lr, float):\n            lr_dict = {}\n        else:\n            lr_dict, lr = lr, None\n\n        self.register(\"log_dt\", log_dt, lr_dict.get(\"dt\", lr))\n        self.register(\"B\", _c2r(B), lr_dict.get(\"B\", lr))\n        self.register(\"inv_A_real\", self._A_init(A.real), lr_dict.get(\"A\", lr))\n        self.register(\"A_imag\", A.imag, lr_dict.get(\"A\", lr))\n\n    def _A_init(self, 
A_real):\n        A_real = torch.clamp(A_real, max=-1e-4)\n        if self.real_type == \"none\":\n            return -A_real\n        elif self.real_type == \"exp\":\n            return torch.log(\n                -A_real\n            )  # Some of the HiPPO methods have real part 0\n        elif self.real_type == \"relu\":\n            return -A_real\n        elif self.real_type == \"sigmoid\":\n            return torch.logit(-A_real)\n        elif self.real_type == \"softplus\":\n            return torch.log(torch.exp(-A_real) - 1)\n        else:\n            raise NotImplementedError\n\n    def _A(self):\n        # Get the internal A (diagonal) parameter\n        if self.real_type == \"none\":\n            A_real = -self.inv_A_real\n        elif self.real_type == \"exp\":\n            A_real = -torch.exp(self.inv_A_real)\n        elif self.real_type == \"relu\":\n            # JAX version seems to NaN if you allow 0's, although this code was fine without it\n            A_real = -F.relu(self.inv_A_real) - 1e-4\n        elif self.real_type == \"sigmoid\":\n            A_real = -F.sigmoid(self.inv_A_real)\n        elif self.real_type == \"softplus\":\n            A_real = -F.softplus(self.inv_A_real)\n        else:\n            raise NotImplementedError\n        A = A_real + 1j * self.A_imag\n        return A\n\n    def forward(self, L, state=None, rate=1.0, u=None):\n        \"\"\"\n        state: (B, H, N) initial state\n        rate: sampling rate factor\n        L: target length\n\n        returns:\n        (C, H, L) convolution kernel (generally C=1)\n        (B, H, L) output from initial state\n        \"\"\"\n\n        dt = torch.exp(self.log_dt) * rate  # (H)\n        C = _r2c(self.C)  # (C H N)\n        A = self._A()  # (H N)\n\n        B = _r2c(self.B)\n        B = repeat(B, \"t n -> 1 (v t) n\", v=self.repeat)\n\n        if self.bandlimit is not None:\n            freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi)  # (H, N)\n            mask = 
torch.where(freqs < self.bandlimit * 0.5, 1, 0)\n            C = C * mask\n\n        # Incorporate dt into A\n        A = repeat(A, \"t n -> (v t) n\", v=self.repeat)\n        dtA = A * dt.unsqueeze(-1)  # (H N)\n\n        # Augment B with state\n        if state is not None:\n            s = state / dt.unsqueeze(-1)\n            if self.disc == \"bilinear\":\n                s = s * (1.0 + dtA / 2)\n            elif self.disc == \"zoh\":\n                s = s * dtA * dtA.exp() / (dtA.exp() - 1.0)\n            B = torch.cat([s, B], dim=-3)  # (1+B H N)\n\n        C = (B[:, None, :, :] * C).view(-1, self.H, self.N)\n        if self.disc == \"zoh\":\n            # Power up\n            C = C * (torch.exp(dtA) - 1.0) / A\n            K = log_vandermonde(C, dtA, L)  # (H L)\n        elif self.disc == \"bilinear\":\n            C = (\n                C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)\n            )  # or * dtA / A\n            dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)\n            K = log_vandermonde(C, dA.log(), L)\n        elif self.disc == \"dss\":\n            # Implementation from DSS meant for case when real eigenvalues can be positive\n            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device)  # [H N L]\n            A_gt_0 = A.real > 0  # [N]\n            if A_gt_0.any():\n                with torch.no_grad():\n                    P_max = dtA * (A_gt_0 * (L - 1))  # [H N]\n                P = P - P_max.unsqueeze(-1)  # [H N L]\n            S = P.exp()  # [H N L]\n\n            dtA_neg = dtA * (1 - 2 * A_gt_0)  # [H N]\n            num = dtA_neg.exp() - 1  # [H N]\n            den = (dtA_neg * L).exp() - 1  # [H N]\n\n            # Inline reciprocal function for DSS logic\n            x = den * A\n            x_conj = _resolve_conj(x)\n            r = x_conj / (x * x_conj + 1e-7)\n\n            C = C * num * r  # [C H N]\n            K = contract(\"chn,hnl->chl\", C, S).float()\n        else:\n            assert False, f\"{self.disc} not 
supported\"\n\n        K = K.view(-1, self.channels, self.H, L)  # (1+B C H L)\n        if state is not None:\n            K_state = K[:-1, :, :, :]  # (B C H L)\n        else:\n            K_state = None\n        K = K[-1, :, :, :]  # (C H L)\n        return K, K_state\n\n    def _setup_step(self):\n        # These methods are organized like this to be compatible with the NPLR kernel interface\n        dt = torch.exp(self.log_dt)  # (H)\n        B = _r2c(self.B)  # (H N)\n        C = _r2c(self.C)  # (C H N)\n        self.dC = C\n        A = self._A()  # (H N)\n\n        A = repeat(A, \"t n -> (v t) n\", v=self.repeat)\n        B = repeat(B, \"t n -> (v t) n\", v=self.repeat)\n\n        # Incorporate dt into A\n        dtA = A * dt.unsqueeze(-1)  # (H N)\n        if self.disc == \"zoh\":\n            self.dA = torch.exp(dtA)  # (H N)\n            self.dB = B * (torch.exp(dtA) - 1.0) / A  # (C H N)\n        elif self.disc == \"bilinear\":\n            self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)\n            self.dB = (\n                B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)\n            )  # or * dtA / A\n\n    def default_state(self, *batch_shape):\n        C = _r2c(self.C)\n        state = torch.zeros(\n            *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device\n        )\n        return state\n\n    def step(self, u, state):\n        next_state = contract(\n            \"h n, b h n -> b h n\", self.dA, state\n        ) + contract(\"h n, b h -> b h n\", self.dB, u)\n        y = contract(\"c h n, b h n -> b c h\", self.dC, next_state)\n        return 2 * y.real, next_state\n\n    def forward_state(self, u, state):\n        self._setup_step()\n        AL = self.dA ** u.size(-1)\n        u = u.flip(-1).to(self.dA).contiguous()  # (B H L)\n        v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))\n        next_state = AL * state + v\n        return next_state\n\n\nclass SSKernel(nn.Module):\n    \"\"\"Wrapper around 
SSKernel parameterizations.\n\n    The SSKernel is expected to support the interface\n    forward()\n    default_state()\n    _setup_step()\n    step()\n    \"\"\"\n\n    def __init__(\n        self,\n        H,\n        N=64,\n        L=None,\n        measure=\"legs\",\n        rank=1,\n        channels=1,\n        dt_min=0.001,\n        dt_max=0.1,\n        deterministic=False,\n        lr=None,\n        mode=\"nplr\",\n        n_ssm=None,\n        verbose=False,\n        measure_args={},\n        **kernel_args,\n    ):\n        \"\"\"State Space Kernel which computes the convolution kernel $\\\\bar{K}$\n\n        H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config.\n        N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doesn't affect speed much.\n        L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known.\n        measure: Options for initialization of (A, B). For NPLR mode, recommendations are \"legs\", \"fout\", \"hippo\" (combination of both). For Diag mode, recommendations are \"diag-inv\", \"diag-lin\", \"diag-legs\", and \"diag\" (combination of diag-inv and diag-lin)\n        rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure \"legt\"\n        channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it as having C separate \"heads\" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead\n        dt_min, dt_max: min and max values for the step size dt (\\Delta)\n        mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing\n        n_ssm: Number of independent trainable (A, B) SSMs, e.g. 
n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H\n        lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameters (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters.\n        \"\"\"\n        super().__init__()\n        self.N = N\n        self.H = H\n        dtype, cdtype = torch.float, torch.cfloat\n        self.channels = channels\n        self.n_ssm = n_ssm if n_ssm is not None else H\n        self.mode = mode\n        self.verbose = verbose\n        self.kernel_args = kernel_args\n\n        # Generate dt\n        if deterministic:\n            log_dt = torch.exp(\n                torch.linspace(math.log(dt_min), math.log(dt_max), H)\n            )\n        else:\n            log_dt = torch.rand(self.H, dtype=dtype) * (\n                math.log(dt_max) - math.log(dt_min)\n            ) + math.log(dt_min)\n\n        # Compute the preprocessed representation\n        w, P, B, V = combination(\n            measure, self.N, rank, self.n_ssm, **measure_args\n        )\n\n        # Broadcast C to have H channels\n        if deterministic:\n            C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype)\n            C[:, :, :1] = 1.0\n            C = contract(\n                \"hmn, chn -> chm\", V.conj().transpose(-1, -2), C\n            )  # V^* C\n            C = (\n                repeat(C, \"c t n -> c (v t) n\", v=self.n_ssm // C.size(-2))\n                .clone()\n                .contiguous()\n            )\n        else:\n            C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype)\n\n        # Broadcast other parameters to have n_ssm copies\n        assert (\n            self.n_ssm % B.size(-2) == 0\n            and 
self.n_ssm % P.size(-2) == 0\n            and self.n_ssm % w.size(-2) == 0\n        )\n        # Broadcast tensors to n_ssm copies\n        # These will be the parameters, so make sure tensors are materialized and contiguous\n        B = (\n            repeat(B, \"t n -> (v t) n\", v=self.n_ssm // B.size(-2))\n            .clone()\n            .contiguous()\n        )\n        P = (\n            repeat(P, \"r t n -> r (v t) n\", v=self.n_ssm // P.size(-2))\n            .clone()\n            .contiguous()\n        )\n        w = (\n            repeat(w, \"t n -> (v t) n\", v=self.n_ssm // w.size(-2))\n            .clone()\n            .contiguous()\n        )\n\n        if mode == \"nplr\":\n            self.kernel = SSKernelNPLR(\n                w,\n                P,\n                B,\n                C,\n                log_dt,\n                L=L,\n                lr=lr,\n                verbose=verbose,\n                **kernel_args,\n            )\n        elif mode == \"diag\":\n            if not measure.startswith(\"diag\"):\n                log.warning(\n                    \"Diagonal kernel (S4D) activated but initialization is not intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or 'diag-legs' for the main variants, or 'diag' for a combination of S4D-Lin and S4D-Inv.\"\n                )\n            C = C * repeat(B, \"t n -> (v t) n\", v=H // self.n_ssm)\n            self.kernel = SSKernelDiag(\n                w,\n                B,\n                C,\n                log_dt,\n                L=L,\n                lr=lr,\n                **kernel_args,\n            )\n        else:\n            raise NotImplementedError(f\"{mode=} is not valid\")\n\n    def forward(self, state=None, L=None, rate=1.0):\n        return self.kernel(state=state, L=L, rate=rate)\n\n    @torch.no_grad()\n    def forward_state(self, u, state):\n        \"\"\"Forward the state through a sequence, i.e. 
computes the state after passing chunk through SSM\n\n        state: (B, H, N)\n        u: (B, H, L)\n\n        Returns: (B, H, N)\n        \"\"\"\n\n        if hasattr(self.kernel, \"forward_state\"):\n            return self.kernel.forward_state(u, state)\n\n        dA, dB = self.kernel._setup_state()  # Construct dA, dB matrices\n        # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N)\n\n        conj = state.size(-1) != dA.size(-1)\n        if conj:\n            state = _conj(state)\n\n        v = contract(\n            \"h n, b h l -> b h n l\", dB, u.flip(-1)\n        )  # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2)\n        AL, v = power(u.size(-1), dA, v)\n        next_state = contract(\"h m n, b h n -> b h m\", AL, state)\n        next_state = next_state + v\n\n        if conj:\n            next_state = next_state[..., : next_state.size(-1) // 2]\n        return next_state\n\n    def _setup_step(self, **kwargs):\n        # This method is intended to be private so that setting up an S4 module with\n        # ```\n        # if hasattr(module, 'setup_step'): module.setup_step()\n        # ```\n        # will not trigger this method multiple times\n        self.kernel._setup_step(**kwargs)\n\n    def step(self, u, state, **kwargs):\n        y, state = self.kernel.step(u, state, **kwargs)\n        return y, state\n\n    def default_state(self, *args, **kwargs):\n        return self.kernel.default_state(*args, **kwargs)\n\n\nclass S4(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        d_state=64,\n        l_max=None,\n        channels=1,\n        bidirectional=False,\n        # Arguments for position-wise feedforward components\n        activation=\"gelu\",\n        postact=\"glu\",\n        hyper_act=None,\n        dropout=0.0,\n        tie_dropout=False,\n        bottleneck=None,\n        gate=None,\n        transposed=True,\n        verbose=False,\n        # SSM Kernel arguments\n        **kernel_args,\n    ):\n        \"\"\"\n   
     d_state: the dimension of the state, also denoted by N\n        l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel\n        channels: can be interpreted as a number of \"heads\"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models\n        bidirectional: if True, convolution kernel will be two-sided\n\n        Position-wise feedforward components:\n        --------------------\n        activation: activation in between SS and FF\n        postact: activation after FF\n        hyper_act: use a \"hypernetwork\" multiplication (experimental)\n        dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d\n\n        Other arguments:\n        --------------------\n        transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension]\n        gate: add gated activation (GSS)\n        bottleneck: reduce SSM dimension (GSS)\n\n        See the class SSKernel for the kernel constructor which accepts kernel_args. 
Relevant options that are worth considering and tuning include \"mode\" + \"measure\", \"dt_min\", \"dt_max\", \"lr\"\n\n        Other options are all experimental and should not need to be configured\n        \"\"\"\n\n        super().__init__()\n        if verbose:\n            log.info(\n                f\"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})\"\n            )\n\n        self.d_model = d_model\n        self.H = d_model\n        self.N = d_state\n        self.L = l_max\n        self.bidirectional = bidirectional\n        self.channels = channels\n        self.transposed = transposed\n\n        self.gate = gate\n        self.bottleneck = bottleneck\n\n        if bottleneck is not None:\n            self.H = self.H // bottleneck\n            self.input_linear = LinearActivation(\n                self.d_model,\n                self.H,\n                transposed=self.transposed,\n                activation=activation,\n                activate=True,\n            )\n\n        if gate is not None:\n            self.input_gate = LinearActivation(\n                self.d_model,\n                self.d_model * gate,\n                transposed=self.transposed,\n                activation=activation,\n                activate=True,\n            )\n            self.output_gate = LinearActivation(\n                self.d_model * gate,\n                self.d_model,\n                transposed=self.transposed,\n                activation=None,\n                activate=False,\n            )\n\n        # optional multiplicative modulation GLU-style\n        # https://arxiv.org/abs/2002.05202\n        self.hyper = hyper_act is not None\n        if self.hyper:\n            channels *= 2\n            self.hyper_activation = Activation(hyper_act)\n\n        self.D = nn.Parameter(torch.randn(channels, self.H))\n\n        if self.bidirectional:\n            channels *= 2\n\n        # SSM Kernel\n        self.kernel = SSKernel(\n            self.H,\n            
N=self.N,\n            L=self.L,\n            channels=channels,\n            verbose=verbose,\n            **kernel_args,\n        )\n\n        # Pointwise\n        self.activation = Activation(activation)\n        dropout_fn = DropoutNd if tie_dropout else nn.Dropout\n        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()\n        # position-wise output transform to mix features\n        self.output_linear = LinearActivation(\n            self.H * self.channels,\n            self.d_model * (1 if self.gate is None else self.gate),\n            transposed=self.transposed,\n            activation=postact,\n            activate=True,\n        )\n\n    def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs):\n        \"\"\"\n        u: (B H L) if self.transposed else (B L H)\n        state: (H N) never needed unless you know what you're doing\n\n        Returns: same shape as u\n        \"\"\"\n        if not self.transposed:\n            u = u.transpose(-1, -2)\n        L = u.size(-1)\n\n        # Mask out padding tokens\n        if isinstance(lengths, int):\n            if lengths != L:\n                lengths = torch.tensor(\n                    lengths, dtype=torch.long, device=u.device\n                )\n            else:\n                lengths = None\n        if lengths is not None:\n            assert (\n                isinstance(lengths, torch.Tensor)\n                and lengths.ndim == 1\n                and lengths.size(0) in [1, u.size(0)]\n            )\n            mask = torch.where(\n                torch.arange(L, device=lengths.device)\n                < lengths[:, None, None],\n                1.0,\n                0.0,\n            )\n            u = u * mask\n\n        if self.gate is not None:\n            v = self.input_gate(u)\n        if self.bottleneck is not None:\n            u = self.input_linear(u)\n\n        # Compute SS Kernel\n        L_kernel = L if self.L is None else min(L, round(self.L / 
rate))\n        k, k_state = self.kernel(\n            L=L_kernel, rate=rate, state=state\n        )  # (C H L) (B C H L)\n\n        # Convolution\n        if self.bidirectional:\n            k0, k1 = rearrange(k, \"(s c) h l -> s c h l\", s=2)\n            k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0))\n        k_f = torch.fft.rfft(k, n=L_kernel + L)  # (C H L)\n        u_f = torch.fft.rfft(u, n=L_kernel + L)  # (B H L)\n        y_f = contract(\"bhl,chl->bchl\", u_f, k_f)\n        y = torch.fft.irfft(y_f, n=L_kernel + L)[..., :L]  # (B C H L)\n\n        # Compute D term in state space equation - essentially a skip connection\n        y = y + contract(\"bhl,ch->bchl\", u, self.D)\n\n        # Compute state update\n        if state is not None:\n            assert (\n                not self.bidirectional\n            ), \"Bidirectional not supported with state forwarding\"\n            y = y + k_state  #\n            next_state = self.kernel.forward_state(u, state)\n        else:\n            next_state = None\n\n        # Optional hyper-network multiplication\n        if self.hyper:\n            y, yh = rearrange(y, \"b (s c) h l -> s b c h l\", s=2)\n            y = self.hyper_activation(yh) * y\n\n        # Reshape to flatten channels\n        y = rearrange(y, \"... c h l -> ... (c h) l\")\n\n        y = self.dropout(self.activation(y))\n\n        if not self.transposed:\n            y = y.transpose(-1, -2)\n\n        y = self.output_linear(y)\n\n        if self.gate is not None:\n            y = self.output_gate(y * v)\n\n        return y, next_state\n\n    def setup_step(self, **kwargs):\n        self.kernel._setup_step(**kwargs)\n\n    def step(self, u, state):\n        \"\"\"Step one time step as a recurrent model. 
Intended to be used during validation.\n\n        u: (B H)\n        state: (B H N)\n        Returns: output (B H), state (B H N)\n        \"\"\"\n        assert not self.training\n\n        y, next_state = self.kernel.step(u, state)  # (B C H)\n        y = y + u.unsqueeze(-2) * self.D\n        y = rearrange(y, \"b c h -> b (c h)\")\n        y = self.activation(y)\n        if self.transposed:\n            y = self.output_linear(y.unsqueeze(-1)).squeeze(-1)\n        else:\n            y = self.output_linear(y)\n        return y, next_state\n\n    def default_state(self, *batch_shape, device=None):\n        # kernel is not a SequenceModule so it doesn't need to adhere to same interface\n        # the kernel will know the device of its own parameters\n        return self.kernel.default_state(*batch_shape)\n\n    @property\n    def d_output(self):\n        return self.d_model\n"
  },
  {
    "path": "src/uncond_ts_diff/configs.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom uncond_ts_diff.utils import linear_beta_schedule\n\nresidual_block_s4_backbone = {\n    \"input_dim\": 1,\n    \"hidden_dim\": 128,\n    \"output_dim\": 1,\n    \"step_emb\": 128,\n    \"num_residual_blocks\": 6,\n    \"residual_block\": \"s4\",\n}\n\nresidual_block_s4_backbone_smallv2 = {\n    \"input_dim\": 1,\n    \"hidden_dim\": 512,\n    \"output_dim\": 1,\n    \"step_emb\": 128,\n    \"num_residual_blocks\": 3,\n    \"residual_block\": \"s4\",\n}\n\nresidual_block_s4_backbone_small = {\n    \"input_dim\": 1,\n    \"hidden_dim\": 64,\n    \"output_dim\": 1,\n    \"step_emb\": 128,\n    \"num_residual_blocks\": 3,\n    \"residual_block\": \"s4\",\n}\n\nresidual_block_s4_backbone_small_dropout01 = {\n    \"input_dim\": 1,\n    \"hidden_dim\": 64,\n    \"output_dim\": 1,\n    \"step_emb\": 128,\n    \"num_residual_blocks\": 3,\n    \"dropout\": 0.1,\n    \"residual_block\": \"s4\",\n}\n\nresidual_block_s4_backbone_small_dropout02 = {\n    \"input_dim\": 1,\n    \"hidden_dim\": 64,\n    \"output_dim\": 1,\n    \"step_emb\": 128,\n    \"num_residual_blocks\": 3,\n    \"dropout\": 0.2,\n    \"residual_block\": \"s4\",\n}\n\nresidual_block_s4_backbone_small_dropout03 = {\n    \"input_dim\": 1,\n    \"hidden_dim\": 64,\n    \"output_dim\": 1,\n    \"step_emb\": 128,\n    \"num_residual_blocks\": 3,\n    \"dropout\": 0.3,\n    \"residual_block\": \"s4\",\n}\n\nresidual_block_s4_backbone_large = {\n    \"input_dim\": 1,\n    \"hidden_dim\": 128,\n    \"output_dim\": 1,\n    \"step_emb\": 128,\n    \"num_residual_blocks\": 18,\n    \"residual_block\": \"s4\",\n}\n\n\ndiffusion_config = {\n    \"backbone_parameters\": residual_block_s4_backbone,\n    \"timesteps\": 100,\n    \"diffusion_scheduler\": linear_beta_schedule,\n}\ndiffusion_small_config = {\n    \"backbone_parameters\": residual_block_s4_backbone_small,\n    \"timesteps\": 100,\n    
\"diffusion_scheduler\": linear_beta_schedule,\n}\n\ndiffusion_small_configv2 = {\n    \"backbone_parameters\": residual_block_s4_backbone_smallv2,\n    \"timesteps\": 100,\n    \"diffusion_scheduler\": linear_beta_schedule,\n}\n\n\ndiffusion_small_config_dropout = {\n    \"backbone_parameters\": residual_block_s4_backbone_small_dropout01,\n    \"timesteps\": 100,\n    \"diffusion_scheduler\": linear_beta_schedule,\n}\n\ndiffusion_small_config_dropout02 = {\n    \"backbone_parameters\": residual_block_s4_backbone_small_dropout02,\n    \"timesteps\": 100,\n    \"diffusion_scheduler\": linear_beta_schedule,\n}\n\ndiffusion_small_config_dropout03 = {\n    \"backbone_parameters\": residual_block_s4_backbone_small_dropout03,\n    \"timesteps\": 100,\n    \"diffusion_scheduler\": linear_beta_schedule,\n}\n\ndiffusion_large_config = {\n    \"backbone_parameters\": residual_block_s4_backbone_large,\n    \"timesteps\": 100,\n    \"diffusion_scheduler\": linear_beta_schedule,\n}\n"
  },
  {
    "path": "src/uncond_ts_diff/dataset.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport os\nimport tarfile\nfrom pathlib import Path\nfrom urllib import request\n\nfrom gluonts.dataset.common import load_datasets\nfrom gluonts.dataset.repository.datasets import get_dataset, get_download_path\n\ndefault_dataset_path: Path = get_download_path() / \"datasets\"\nwiki2k_download_link: str = \"https://github.com/awslabs/gluonts/raw/b89f203595183340651411a41eeb0ee60570a4d9/datasets/wiki2000_nips.tar.gz\"  # noqa: E501\n\n\ndef get_gts_dataset(dataset_name):\n    if dataset_name == \"wiki2000_nips\":\n        wiki_dataset_path = default_dataset_path / dataset_name\n        Path(default_dataset_path).mkdir(parents=True, exist_ok=True)\n        if not wiki_dataset_path.exists():\n            tar_file_path = wiki_dataset_path.parent / f\"{dataset_name}.tar.gz\"\n            request.urlretrieve(\n                wiki2k_download_link,\n                tar_file_path,\n            )\n\n            with tarfile.open(tar_file_path) as tar:\n                tar.extractall(path=wiki_dataset_path.parent)\n\n            os.remove(tar_file_path)\n        return load_datasets(\n            metadata=wiki_dataset_path / \"metadata\",\n            train=wiki_dataset_path / \"train\",\n            test=wiki_dataset_path / \"test\",\n        )\n    else:\n        return get_dataset(dataset_name)\n"
  },
  {
    "path": "src/uncond_ts_diff/metrics/__init__.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom .linear_pred_score import linear_pred_score\n\n__all__ = [\"linear_pred_score\"]\n"
  },
  {
    "path": "src/uncond_ts_diff/metrics/linear_pred_score.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom typing import Tuple\nfrom functools import partial\n\nimport numpy as np\nfrom gluonts.evaluation import Evaluator\nfrom gluonts.dataset.split import slice_data_entry\nfrom gluonts.transform import AdhocTransform, Chain\n\nfrom uncond_ts_diff.model import LinearEstimator\nfrom uncond_ts_diff.utils import (\n    GluonTSNumpyDataset,\n    ScaleAndAddMeanFeature,\n    ScaleAndAddMinMaxFeature,\n    make_evaluation_predictions_with_scaling,\n)\n\n\ndef linear_pred_score(\n    samples: np.ndarray,\n    context_length: int,\n    prediction_length: int,\n    test_dataset,\n    num_samples: int = 1,\n    scaling_type: str = \"mean\",\n) -> Tuple[dict, list, list]:\n    \"\"\"Compute the linear predictive score.\n    Uses the `samples` to fit a LinearRegression model\n    and evaluate the forecast performance on the provided\n    `test_dataset`.\n\n    Parameters\n    ----------\n    samples\n        The samples used to fit the linear regression model.\n        A numpy array of shape [N, T].\n        Assumed to be already scaled.\n    context_length\n        The context length for the linear model.\n    prediction_length\n        The prediction length for the linear model.\n        Must be the same as the prediction length of the\n        target `test_dataset`.\n    test_dataset\n        The test dataset on which the linear model will\n        be evaluated.\n    num_samples, optional\n        Number of samples to draw from the linear model.\n        Since the linear model is a point forecaster,\n        `num_samples` > 1 would just result in the forecast\n        being repeated `num_samples` times, by default 1\n    scaling_type, optional\n        Scaling type should be one of {\"mean\", \"min-max\"}\n        Min-max scaling is used in TimeGAN, defaults to \"mean\"\n\n    Returns\n    -------\n        Evaluation metrics, target test time 
series and forecasts\n    \"\"\"\n    min_past = context_length + prediction_length\n    assert samples.shape[1] >= min_past\n    dataset = GluonTSNumpyDataset(samples)\n\n    linear_predictor = LinearEstimator(\n        freq=\"H\",  # Not actually used in the estimator\n        prediction_length=prediction_length,\n        context_length=context_length,\n        num_train_samples=len(dataset),\n        # Since `samples` are synthetic samples, they are assumed to be already scaled\n        scaling=False,\n    ).train(dataset)\n\n    # The linear predictor has been trained on scaled samples,\n    # however, the test dataset is still in the original space.\n    # Therefore, the test time series need to be sliced and\n    # scaled before being fed into the predictor.\n    # After prediction, the time series must be scaled back to\n    # the original space for metric computation.\n    # The following lines of code perform this custom evaluation.\n\n    # Slice test set to be of the same length as context_length + prediction_length\n    slice_func = partial(slice_data_entry, slice_=slice(-min_past, None))\n    if scaling_type == \"mean\":\n        ScaleAndAddScaleFeature = ScaleAndAddMeanFeature\n    elif scaling_type == \"min-max\":\n        ScaleAndAddScaleFeature = ScaleAndAddMinMaxFeature\n    transformation = Chain(\n        [\n            AdhocTransform(slice_func),\n            # Add scale to data entry for use later during evaluation\n            ScaleAndAddScaleFeature(\"target\", \"scale\", prediction_length),\n        ]\n    )\n    sliced_test_set = transformation.apply(test_dataset)\n\n    evaluator = Evaluator()\n    forecast_it, ts_it = make_evaluation_predictions_with_scaling(\n        dataset=sliced_test_set,\n        predictor=linear_predictor,\n        num_samples=num_samples,\n        scaling_type=scaling_type,\n    )\n    forecasts = list(forecast_it)\n    tss = list(ts_it)\n\n    metrics, _ = evaluator(tss, forecasts)\n\n    return metrics, tss, 
forecasts\n"
  },
  {
    "path": "src/uncond_ts_diff/model/__init__.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom .diffusion.tsdiff import TSDiff\nfrom .diffusion.tsdiff_cond import TSDiffCond\nfrom .linear._estimator import LinearEstimator\n\n__all__ = [\n    \"TSDiff\",\n    \"TSDiffCond\",\n    \"LinearEstimator\",\n]\n"
  },
  {
    "path": "src/uncond_ts_diff/model/callback.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom copy import deepcopy\nimport math\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\n\nfrom gluonts.dataset.field_names import FieldName\nfrom gluonts.evaluation import make_evaluation_predictions, Evaluator\n\nfrom gluonts.transform import TestSplitSampler, InstanceSplitter\nfrom pytorch_lightning import Callback\n\nfrom uncond_ts_diff.sampler import DDPMGuidance, DDIMGuidance\nfrom uncond_ts_diff.metrics import linear_pred_score\nfrom uncond_ts_diff.utils import ConcatDataset\n\n\nclass GradNormCallback(Callback):\n    def __init__(self) -> None:\n        super().__init__()\n\n    def on_before_optimizer_step(\n        self,\n        trainer,\n        pl_module,\n        optimizer,\n        opt_idx: int,\n    ) -> None:\n        return pl_module.log(\n            \"grad_norm\", self.grad_norm(pl_module.parameters()), prog_bar=True\n        )\n\n    def grad_norm(self, parameters):\n        parameters = [p for p in parameters if p.grad is not None]\n        device = parameters[0].grad.device\n        total_norm = torch.norm(\n            torch.stack(\n                [torch.norm(p.grad.detach(), 2).to(device) for p in parameters]\n            ),\n            2,\n        )\n        return total_norm\n\n\nclass PredictiveScoreCallback(Callback):\n    def __init__(\n        self,\n        context_length,\n        prediction_length,\n        model,\n        transformation,\n        train_dataloader,\n        train_batch_size,\n        test_dataset,\n        eval_every=10,\n    ):\n        super().__init__()\n        self.context_length = context_length\n        self.prediction_length = prediction_length\n        self.model = model\n        self.transformation = transformation\n        self.train_dataloader = train_dataloader\n        self.train_batch_size = train_batch_size\n        self.test_dataset = test_dataset\n        
self.eval_every = eval_every\n        # Number of samples used to train the downstream predictor\n        self.n_pred_samples = 10000\n\n    def _generate_real_samples(\n        self,\n        data_loader,\n        num_samples: int,\n        n_timesteps: int,\n        batch_size: int,\n        cache_path: Path,\n    ):\n        if cache_path.exists():\n            real_samples = np.load(cache_path)\n            if len(real_samples) == num_samples:\n                return real_samples\n\n        real_samples = []\n        data_iter = iter(data_loader)\n        n_iters = math.ceil(num_samples / batch_size)\n        for i in range(n_iters):\n            try:\n                batch = next(data_iter)\n            except StopIteration:\n                data_iter = iter(data_loader)\n                batch = next(data_iter)\n            ts = np.concatenate(\n                [batch[\"past_target\"], batch[\"future_target\"]], axis=-1\n            )[:, -n_timesteps:]\n            real_samples.append(ts)\n\n        real_samples = np.concatenate(real_samples, axis=0)[:num_samples]\n        np.save(cache_path, real_samples)\n\n        return real_samples\n\n    def _generate_synth_samples(\n        self, model, num_samples: int, batch_size: int = 1000\n    ):\n        synth_samples = []\n\n        n_iters = math.ceil(num_samples / batch_size)\n        for _ in range(n_iters):\n            samples = model.sample_n(num_samples=batch_size)\n            synth_samples.append(samples)\n\n        synth_samples = np.concatenate(synth_samples, axis=0)[:num_samples]\n        return synth_samples\n\n    def on_train_epoch_end(self, trainer, pl_module):\n        if (pl_module.current_epoch + 1) % self.eval_every == 0:\n            device = next(pl_module.backbone.parameters()).device\n            pl_module.eval()\n            assert pl_module.training is False\n\n            real_samples = self._generate_real_samples(\n                self.train_dataloader,\n                
self.n_pred_samples,\n                self.context_length + self.prediction_length,\n                self.train_batch_size,\n                cache_path=Path(trainer.logger.log_dir) / \"real_samples.npy\",\n            )\n            synth_samples = self._generate_synth_samples(\n                self.model,\n                self.n_pred_samples,\n            )\n\n            # Train using synthetic samples, test on test set\n            synth_metrics, _, _ = linear_pred_score(\n                synth_samples,\n                self.context_length,\n                self.prediction_length,\n                self.test_dataset,\n                scaling_type=\"mean\",\n            )\n\n            # Train using real samples, test on test set\n            scaled_real_samples, _ = self.model.scaler(\n                torch.from_numpy(real_samples).to(device),\n                torch.from_numpy(np.ones_like(real_samples)).to(device),\n            )\n            real_metrics, _, _ = linear_pred_score(\n                scaled_real_samples.cpu().numpy(),\n                self.context_length,\n                self.prediction_length,\n                self.test_dataset,\n                scaling_type=\"mean\",\n            )\n\n            pl_module.log_dict(\n                {\n                    \"synth_linear_ND\": synth_metrics[\"ND\"],\n                    \"synth_linear_NRMSE\": synth_metrics[\"NRMSE\"],\n                    \"real_linear_ND\": real_metrics[\"ND\"],\n                    \"real_linear_NRMSE\": real_metrics[\"NRMSE\"],\n                }\n            )\n\n            pl_module.train()\n\n\nclass EvaluateCallback(Callback):\n    def __init__(\n        self,\n        context_length,\n        prediction_length,\n        sampler,\n        sampler_kwargs,\n        num_samples,\n        model,\n        transformation,\n        test_dataset,\n        val_dataset,\n        eval_every=50,\n    ):\n        super().__init__()\n        self.context_length = context_length\n    
    self.prediction_length = prediction_length\n        self.sampler = sampler\n        self.num_samples = num_samples\n        self.sampler_kwargs = sampler_kwargs\n        self.model = model\n        self.transformation = transformation\n        self.test_dataset = test_dataset\n        self.val_data = val_dataset\n        self.original_state_dict = {}\n        self.eval_every = eval_every\n        self.log_metrics = {\n            \"CRPS\",\n            \"ND\",\n            \"NRMSE\",\n        }\n\n        if sampler == \"ddpm\":\n            self.Guidance = DDPMGuidance\n        elif sampler == \"ddim\":\n            self.Guidance = DDIMGuidance\n        else:\n            raise ValueError(f\"Unknown sampler type: {sampler}\")\n\n    def on_train_epoch_end(self, trainer, pl_module):\n        if (pl_module.current_epoch + 1) % self.eval_every == 0:\n            device = next(pl_module.backbone.parameters()).device\n            self.original_state_dict = deepcopy(\n                pl_module.backbone.state_dict()\n            )\n            pl_module.eval()\n            assert pl_module.training is False\n            for label, state_dict in zip(\n                [\"\"] + [str(rate) for rate in pl_module.ema_rate],\n                [pl_module.backbone.state_dict()] + pl_module.ema_state_dicts,\n            ):\n                pl_module.backbone.load_state_dict(state_dict, strict=True)\n                pl_module.to(device)\n                prediction_splitter = InstanceSplitter(\n                    target_field=FieldName.TARGET,\n                    is_pad_field=FieldName.IS_PAD,\n                    start_field=FieldName.START,\n                    forecast_start_field=FieldName.FORECAST_START,\n                    instance_sampler=TestSplitSampler(),\n                    past_length=self.context_length + max(self.model.lags_seq),\n                    future_length=self.prediction_length,\n                    time_series_fields=[\n                        
FieldName.FEAT_TIME,\n                        FieldName.OBSERVED_VALUES,\n                    ],\n                )\n                og = self.Guidance(\n                    self.model,\n                    self.prediction_length,\n                    num_samples=self.num_samples,\n                    **self.sampler_kwargs,\n                )\n                predictor_pytorch = og.get_predictor(\n                    prediction_splitter,\n                    batch_size=1024 // self.num_samples,\n                    device=device,\n                )\n                evaluator = Evaluator()\n\n                transformed_valdata = self.transformation.apply(\n                    ConcatDataset(self.val_data), is_train=False\n                )\n\n                forecast_it, ts_it = make_evaluation_predictions(\n                    dataset=transformed_valdata,\n                    predictor=predictor_pytorch,\n                    num_samples=self.num_samples,\n                )\n\n                forecasts_pytorch = list(forecast_it)\n                tss_pytorch = list(ts_it)\n\n                metrics_pytorch, per_ts = evaluator(\n                    tss_pytorch, forecasts_pytorch\n                )\n                metrics_pytorch[\"CRPS\"] = metrics_pytorch[\"mean_wQuantileLoss\"]\n                if metrics_pytorch[\"CRPS\"] < pl_module.best_crps:\n                    pl_module.best_crps = metrics_pytorch[\"CRPS\"]\n                    ckpt_path = (\n                        Path(trainer.logger.log_dir) / \"best_checkpoint.ckpt\"\n                    )\n                    torch.save(\n                        pl_module.state_dict(),\n                        ckpt_path,\n                    )\n                pl_module.log_dict(\n                    {\n                        f\"val_{metric}{label}\": metrics_pytorch[metric]\n                        for metric in self.log_metrics\n                    }\n                )\n            
pl_module.backbone.load_state_dict(\n                self.original_state_dict, strict=True\n            )\n            pl_module.train()\n"
  },
  {
    "path": "src/uncond_ts_diff/model/diffusion/_base.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom typing import Optional\n\nimport numpy as np\nimport pandas as pd\n\nimport torch\nimport torch.nn.functional as F\nimport pytorch_lightning as pl\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\nfrom gluonts.time_feature import time_features_from_frequency_str\nfrom gluonts.torch.modules.feature import FeatureEmbedder\nfrom gluonts.torch.modules.scaler import MeanScaler, NOPScaler\n\nfrom uncond_ts_diff.utils import extract\n\nPREDICTION_INPUT_NAMES = [\n    \"past_target\",\n    \"past_observed_values\",\n    \"feat_static_cat\",\n    \"feat_static_real\",\n    \"past_time_feat\",\n    \"future_time_feat\",\n]\n\n\nclass TSDiffBase(pl.LightningModule):\n    def __init__(\n        self,\n        backbone_parameters,\n        timesteps,\n        diffusion_scheduler,\n        context_length,\n        prediction_length,\n        num_feat_dynamic_real: int = 0,\n        num_feat_static_cat: int = 0,\n        num_feat_static_real: int = 0,\n        cardinalities=None,\n        freq=None,\n        normalization=\"none\",\n        use_features=False,\n        use_lags=True,\n        lr: float = 1e-3,\n    ):\n        super().__init__()\n        self.save_hyperparameters()\n        self.timesteps = timesteps\n        self.betas = diffusion_scheduler(timesteps)\n        self.sqrt_one_minus_beta = torch.sqrt(1.0 - self.betas)\n        self.alphas = 1 - self.betas\n        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)\n        self.alphas_cumprod_prev = F.pad(\n            self.alphas_cumprod[:-1], (1, 0), value=1.0\n        )\n        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)\n        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)\n        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(\n            1.0 - self.alphas_cumprod\n        )\n        self.posterior_variance = (\n            self.betas\n      
      * (1.0 - self.alphas_cumprod_prev)\n            / (1.0 - self.alphas_cumprod)\n        )\n        self.logs = {}\n        self.normalization = normalization\n        if normalization == \"mean\":\n            self.scaler = MeanScaler(dim=1, keepdim=True)\n        else:\n            self.scaler = NOPScaler(dim=1, keepdim=True)\n        if cardinalities is None:\n            cardinalities = [1]\n        self.embedder = FeatureEmbedder(\n            cardinalities=cardinalities,\n            embedding_dims=[min(50, (cat + 1) // 2) for cat in cardinalities],\n        )\n        self.time_features = (\n            time_features_from_frequency_str(freq) if freq is not None else []\n        )\n\n        self.num_feat_dynamic_real = (\n            1 + num_feat_dynamic_real + len(self.time_features)\n        )\n        self.num_feat_static_cat = max(num_feat_static_cat, 1)\n        self.num_feat_static_real = max(num_feat_static_real, 1)\n\n        self.use_features = use_features\n        self.use_lags = use_lags\n\n        self.context_length = context_length\n        self.prediction_length = prediction_length\n        self.losses_running_mean = torch.ones(timesteps, requires_grad=False)\n        self.lr = lr\n        self.best_crps = np.inf\n\n    def _extract_features(self, data):\n        raise NotImplementedError()\n\n    def configure_optimizers(self):\n        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n        scheduler = ReduceLROnPlateau(\n            optimizer, mode=\"min\", factor=0.5, patience=int(1e12)\n        )\n        return [optimizer], {\"scheduler\": scheduler, \"monitor\": \"train_loss\"}\n\n    def log(self, name, value, **kwargs):\n        super().log(name, value, **kwargs)\n        if isinstance(value, torch.Tensor):\n            value = value.detach().cpu().item()\n        if name not in self.logs:\n            self.logs[name] = [value]\n        else:\n            self.logs[name].append(value)\n\n    def get_logs(self):\n     
   logs = self.logs\n        logs[\"epochs\"] = list(range(self.current_epoch))\n        return pd.DataFrame.from_dict(logs)\n\n    def q_sample(self, x_start, t, noise=None):\n        device = next(self.backbone.parameters()).device\n        if noise is None:\n            noise = torch.randn_like(x_start, device=device)\n        sqrt_alphas_cumprod_t = extract(\n            self.sqrt_alphas_cumprod, t, x_start.shape\n        )\n        sqrt_one_minus_alphas_cumprod_t = extract(\n            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape\n        )\n\n        return (\n            sqrt_alphas_cumprod_t * x_start\n            + sqrt_one_minus_alphas_cumprod_t * noise\n        )\n\n    def p_losses(\n        self,\n        x_start,\n        t,\n        features=None,\n        noise=None,\n        loss_type=\"l2\",\n        reduction=\"mean\",\n    ):\n        device = next(self.backbone.parameters()).device\n        if noise is None:\n            noise = torch.randn_like(x_start, device=device)\n\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n        predicted_noise = self.backbone(x_noisy, t, features)\n\n        if loss_type == \"l1\":\n            loss = F.l1_loss(noise, predicted_noise, reduction=reduction)\n        elif loss_type == \"l2\":\n            loss = F.mse_loss(noise, predicted_noise, reduction=reduction)\n        elif loss_type == \"huber\":\n            loss = F.smooth_l1_loss(\n                noise, predicted_noise, reduction=reduction\n            )\n        else:\n            raise NotImplementedError()\n\n        return loss, x_noisy, predicted_noise\n\n    @torch.no_grad()\n    def p_sample(self, x, t, t_index, features=None):\n        betas_t = extract(self.betas, t, x.shape)\n        sqrt_one_minus_alphas_cumprod_t = extract(\n            self.sqrt_one_minus_alphas_cumprod, t, x.shape\n        )\n        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)\n        predicted_noise = self.backbone(x, t, 
features)\n\n        model_mean = sqrt_recip_alphas_t * (\n            x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t\n        )\n\n        if t_index == 0:\n            return model_mean\n        else:\n            posterior_variance_t = extract(self.posterior_variance, t, x.shape)\n            noise = torch.randn_like(x)\n            return model_mean + torch.sqrt(posterior_variance_t) * noise\n\n    @torch.no_grad()\n    def p_sample_ddim(self, x, t, features=None, noise=None):\n        if noise is None:\n            noise = self.backbone(x, t, features)\n        sqrt_alphas_cumprod_prev_t = extract(\n            self.alphas_cumprod_prev, t, x.shape\n        ).sqrt()\n        sqrt_one_minus_alphas_cumprod_prev_t = extract(\n            1 - self.alphas_cumprod_prev, t, x.shape\n        ).sqrt()\n        sqrt_one_minus_alphas_cumprod_t = extract(\n            self.sqrt_one_minus_alphas_cumprod, t, x.shape\n        )\n        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x.shape)\n        x0pointer = (\n            sqrt_alphas_cumprod_prev_t\n            * (x - sqrt_one_minus_alphas_cumprod_t * noise)\n            / sqrt_alphas_cumprod_t\n        )\n        xtpointer = sqrt_one_minus_alphas_cumprod_prev_t * noise\n        return x0pointer + xtpointer\n\n    @torch.no_grad()\n    def p_sample_genddim(\n        self,\n        x: torch.Tensor,\n        t: torch.Tensor,\n        t_index: int,\n        t_prev: Optional[torch.Tensor] = None,\n        eta: float = 0.0,\n        features=None,\n        noise: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        \"\"\"Generalized DDIM step that interpolates between\n        DDPM (eta=1) and DDIM (eta=0).\n\n        Args:\n            x (torch.Tensor): Noisy samples x_t at the current step.\n            t (torch.Tensor): Current timesteps, shape (batch_size,).\n            t_index (int): Position of t in the sampling schedule.\n            t_prev (Optional[torch.Tensor], optional): Previous timesteps in\n                the schedule. Defaults to t - 1.\n            eta (float, optional): Interpolation coefficient between DDIM\n                (0.0) and DDPM (1.0). Defaults to 0.0.\n            features (optional): Conditioning features. Defaults to None.\n            noise (Optional[torch.Tensor], optional): Predicted noise; if\n                None, it is computed by the backbone. Defaults to None.\n\n
        Returns:\n            torch.Tensor: Samples at step t_prev.\n        \"\"\"\n        if noise is None:\n            noise = self.backbone(x, t, features)\n        if t_prev is None:\n            t_prev = t - 1\n\n        alphas_cumprod_t = extract(self.alphas_cumprod, t, x.shape)\n        alphas_cumprod_prev_t = (\n            extract(self.alphas_cumprod, t_prev, x.shape)\n            if t_index > 0\n            else torch.ones_like(alphas_cumprod_t)\n        )\n        sqrt_alphas_cumprod_prev_t = alphas_cumprod_prev_t.sqrt()\n\n        sqrt_one_minus_alphas_cumprod_t = extract(\n            self.sqrt_one_minus_alphas_cumprod, t, x.shape\n        )\n        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x.shape)\n\n        x0pointer = (\n            sqrt_alphas_cumprod_prev_t\n            * (x - sqrt_one_minus_alphas_cumprod_t * noise)\n            / sqrt_alphas_cumprod_t\n        )\n        c1 = (\n            eta\n            * (\n                (1 - alphas_cumprod_t / alphas_cumprod_prev_t)\n                * (1 - alphas_cumprod_prev_t)\n                / (1 - alphas_cumprod_t)\n            ).sqrt()\n        )\n        c2 = ((1 - alphas_cumprod_prev_t) - c1**2).sqrt()\n        return x0pointer + c1 * torch.randn_like(x) + c2 * noise\n\n    @torch.no_grad()\n    def sample(self, noise, features=None):\n        device = next(self.backbone.parameters()).device\n        batch_size, length, ch = noise.shape\n        seq = noise\n        # Convert to numpy up front so the list passed to np.stack is\n        # homogeneous (p_sample outputs are also stored as numpy arrays).\n        seqs = [seq.cpu().numpy()]\n\n        for i in reversed(range(0, self.timesteps)):\n            seq = self.p_sample(\n                seq,\n                torch.full((batch_size,), i, device=device, dtype=torch.long),\n                i,\n                features,\n            )\n            seqs.append(seq.cpu().numpy())\n\n        return np.stack(seqs, axis=0)\n\n
    def fast_denoise(self, xt, t, features=None, noise=None):\n        if noise is None:\n            noise = self.backbone(xt, t, features)\n        sqrt_one_minus_alphas_cumprod_t = extract(\n            self.sqrt_one_minus_alphas_cumprod, t, xt.shape\n        )\n        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, xt.shape)\n        return (\n            xt - sqrt_one_minus_alphas_cumprod_t * noise\n        ) / sqrt_alphas_cumprod_t\n\n    def forward(self, x, mask):\n        raise NotImplementedError()\n\n    def training_step(self, data, idx):\n        assert self.training is True\n        device = next(self.backbone.parameters()).device\n        if isinstance(data, dict):\n            x, _, features = self._extract_features(data)\n        else:\n            x, _ = self.scaler(data, torch.ones_like(data))\n            features = None\n\n        t = torch.randint(\n            0, self.timesteps, (x.shape[0],), device=device\n        ).long()\n        elbo_loss, xt, noise = self.p_losses(x, t, features, loss_type=\"l2\")\n        return {\n            \"loss\": elbo_loss,\n            \"elbo_loss\": elbo_loss,\n        }\n\n    def training_epoch_end(self, outputs):\n        epoch_loss = sum(x[\"loss\"] for x in outputs) / len(outputs)\n        elbo_loss = sum(x[\"elbo_loss\"] for x in outputs) / len(outputs)\n        self.log(\"train_loss\", epoch_loss)\n        self.log(\"train_elbo_loss\", elbo_loss)\n\n    def validation_step(self, data, idx):\n        device = next(self.backbone.parameters()).device\n        if isinstance(data, dict):\n            x, _, features = self._extract_features(data)\n        else:\n            x, features = data, None\n        t = torch.randint(\n            0, self.timesteps, (x.shape[0],), device=device\n        ).long()\n        elbo_loss, xt, noise = self.p_losses(x, t, features, loss_type=\"l2\")\n        return {\n            \"loss\": elbo_loss,\n            \"elbo_loss\": elbo_loss,\n        }\n\n    def validation_epoch_end(self, outputs):\n        epoch_loss = sum(x[\"loss\"] for x in outputs) / len(outputs)\n        elbo_loss = sum(x[\"elbo_loss\"] for x in outputs) / 
len(outputs)\n        self.log(\"valid_loss\", epoch_loss)\n        self.log(\"valid_elbo_loss\", elbo_loss)\n"
  },
  {
    "path": "src/uncond_ts_diff/model/diffusion/tsdiff.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport copy\n\nimport torch\nfrom gluonts.torch.util import lagged_sequence_values\n\nfrom uncond_ts_diff.arch import BackboneModel\nfrom uncond_ts_diff.model.diffusion._base import TSDiffBase\nfrom uncond_ts_diff.utils import get_lags_for_freq\n\n\nclass TSDiff(TSDiffBase):\n    def __init__(\n        self,\n        backbone_parameters,\n        timesteps,\n        diffusion_scheduler,\n        context_length,\n        prediction_length,\n        num_feat_dynamic_real: int = 0,\n        num_feat_static_cat: int = 0,\n        num_feat_static_real: int = 0,\n        cardinalities=None,\n        freq=None,\n        normalization=\"none\",\n        use_features=False,\n        use_lags=True,\n        init_skip=True,\n        lr=1e-3,\n    ):\n        super().__init__(\n            backbone_parameters,\n            timesteps=timesteps,\n            diffusion_scheduler=diffusion_scheduler,\n            context_length=context_length,\n            prediction_length=prediction_length,\n            num_feat_dynamic_real=num_feat_dynamic_real,\n            num_feat_static_cat=num_feat_static_cat,\n            num_feat_static_real=num_feat_static_real,\n            cardinalities=cardinalities,\n            freq=freq,\n            normalization=normalization,\n            use_features=use_features,\n            use_lags=use_lags,\n            lr=lr,\n        )\n\n        self.freq = freq\n        if use_lags:\n            self.lags_seq = get_lags_for_freq(freq)\n            backbone_parameters = backbone_parameters.copy()\n            backbone_parameters[\"input_dim\"] += len(self.lags_seq)\n            backbone_parameters[\"output_dim\"] += len(self.lags_seq)\n        else:\n            self.lags_seq = [0]\n        self.input_dim = backbone_parameters[\"input_dim\"]\n        self.backbone = BackboneModel(\n            **backbone_parameters,\n            
num_features=(\n                self.num_feat_static_real\n                + self.num_feat_static_cat\n                + self.num_feat_dynamic_real\n                + 1  # log_scale\n            ),\n            init_skip=init_skip,\n        )\n        self.ema_rate = []  # [0.9999]\n        self.ema_state_dicts = [\n            copy.deepcopy(self.backbone.state_dict())\n            for _ in range(len(self.ema_rate))\n        ]\n\n    def _extract_features(self, data):\n        prior = data[\"past_target\"][:, : -self.context_length]\n        context = data[\"past_target\"][:, -self.context_length :]\n        context_observed = data[\"past_observed_values\"][\n            :, -self.context_length :\n        ]\n        if self.normalization == \"zscore\":\n            scaled_context, scale = self.scaler(\n                context, context_observed, data[\"stats\"]\n            )\n        else:\n            scaled_context, scale = self.scaler(context, context_observed)\n        features = []\n\n        scaled_prior = prior / scale\n        scaled_future = data[\"future_target\"] / scale\n        features.append(scale.log())\n\n        x = torch.cat([scaled_context, scaled_future], dim=1)\n        if data[\"feat_static_cat\"] is not None:\n            features.append(self.embedder(data[\"feat_static_cat\"]))\n        if data[\"feat_static_real\"] is not None:\n            features.append(data[\"feat_static_real\"])\n        static_feat = torch.cat(\n            features,\n            dim=1,\n        )\n        expanded_static_feat = static_feat.unsqueeze(1).expand(\n            -1, x.shape[1], -1\n        )\n\n        features = [expanded_static_feat]\n\n        time_features = []\n        if data[\"past_time_feat\"] is not None:\n            time_features.append(\n                data[\"past_time_feat\"][:, -self.context_length :]\n            )\n        if data[\"future_time_feat\"] is not None:\n            time_features.append(data[\"future_time_feat\"])\n        
features.append(torch.cat(time_features, dim=1))\n        features = torch.cat(features, dim=-1)\n\n        if self.use_lags:\n            lags = lagged_sequence_values(\n                self.lags_seq,\n                scaled_prior,\n                torch.cat([scaled_context, scaled_future], dim=1),\n                dim=1,\n            )\n            x = torch.cat([x[:, :, None], lags], dim=-1)\n        else:\n            x = x[:, :, None]\n        if not self.use_features:\n            features = None\n\n        return x, scale[:, :, None], features\n\n    @torch.no_grad()\n    def sample_n(\n        self,\n        num_samples: int = 1,\n        return_lags: bool = False,\n    ):\n        device = next(self.backbone.parameters()).device\n        seq_len = self.context_length + self.prediction_length\n\n        samples = torch.randn(\n            (num_samples, seq_len, self.input_dim), device=device\n        )\n\n        for i in reversed(range(0, self.timesteps)):\n            t = torch.full((num_samples,), i, device=device, dtype=torch.long)\n            samples = self.p_sample(samples, t, i, features=None)\n\n        samples = samples.cpu().numpy()\n\n        if return_lags:\n            return samples\n\n        return samples[..., 0]\n\n    def on_train_batch_end(self, outputs, batch, batch_idx):\n        for rate, state_dict in zip(self.ema_rate, self.ema_state_dicts):\n            update_ema(state_dict, self.backbone.state_dict(), rate=rate)\n\n\ndef update_ema(target_state_dict, source_state_dict, rate=0.99):\n    with torch.no_grad():\n        for key, value in source_state_dict.items():\n            ema_value = target_state_dict[key]\n            ema_value.copy_(\n                rate * ema_value + (1.0 - rate) * value.cpu(),\n                non_blocking=True,\n            )\n"
  },
  {
    "path": "src/uncond_ts_diff/model/diffusion/tsdiff_cond.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport torch\nfrom gluonts.torch.model.predictor import PyTorchPredictor\nfrom gluonts.torch.util import lagged_sequence_values\n\nfrom uncond_ts_diff.arch import BackboneModel\nfrom uncond_ts_diff.model.diffusion._base import TSDiffBase\nfrom uncond_ts_diff.model.diffusion._base import PREDICTION_INPUT_NAMES\nfrom uncond_ts_diff.utils import get_lags_for_freq\n\nPREDICTION_INPUT_NAMES = PREDICTION_INPUT_NAMES + [\"orig_past_target\"]\n\n\nclass TSDiffCond(TSDiffBase):\n    def __init__(\n        self,\n        backbone_parameters,\n        timesteps,\n        diffusion_scheduler,\n        context_length,\n        prediction_length,\n        num_feat_dynamic_real: int = 0,\n        num_feat_static_cat: int = 0,\n        num_feat_static_real: int = 0,\n        cardinalities=None,\n        freq=None,\n        normalization=\"none\",\n        use_features=False,\n        use_lags=True,\n        lr=1e-3,\n        init_skip=True,\n        noise_observed=True,\n    ):\n        super().__init__(\n            backbone_parameters,\n            timesteps=timesteps,\n            diffusion_scheduler=diffusion_scheduler,\n            context_length=context_length,\n            prediction_length=prediction_length,\n            num_feat_dynamic_real=num_feat_dynamic_real,\n            num_feat_static_cat=num_feat_static_cat,\n            num_feat_static_real=num_feat_static_real,\n            cardinalities=cardinalities,\n            freq=freq,\n            normalization=normalization,\n            use_features=use_features,\n            use_lags=use_lags,\n            lr=lr,\n        )\n\n        num_features = (\n            (\n                self.num_feat_dynamic_real\n                + self.num_feat_static_cat\n                + self.num_feat_static_real\n                + 1\n            )\n            if use_features\n            else 0\n        )\n   
     self.freq = freq\n        self.lags_seq = get_lags_for_freq(freq) if use_lags else [0]\n        self.backbone = BackboneModel(\n            **backbone_parameters,\n            num_features=(\n                num_features + 2 + (len(self.lags_seq) if use_lags else 0)\n            ),\n            init_skip=init_skip,\n        )\n        self.noise_observed = noise_observed\n\n    def _extract_features(self, data):\n        device = next(self.parameters()).device\n        prior = data[\"past_target\"][:, : -self.context_length]\n        context = data[\"past_target\"][:, -self.context_length :]\n        context_observed = data[\"past_observed_values\"][\n            :, -self.context_length :\n        ]\n        scaled_context, scale = self.scaler(context, context_observed)\n        features = []\n\n        scaled_prior = prior / scale\n        scaled_future = data[\"future_target\"] / scale\n        scaled_orig_context = (\n            data[\"orig_past_target\"][:, -self.context_length :]\n        ) / scale\n\n        x = torch.cat([scaled_orig_context, scaled_future], dim=1)\n        observation_mask = torch.zeros_like(x, device=device)\n        observation_mask[:, : -self.prediction_length] = data[\n            \"past_observed_values\"\n        ][:, -self.context_length :].data\n        x_past = torch.cat(\n            [scaled_context, torch.zeros_like(scaled_future)], dim=1\n        ).clone()\n\n        assert x.size() == x_past.size()\n\n        if data[\"feat_static_cat\"] is not None:\n            features.append(self.embedder(data[\"feat_static_cat\"]))\n        if data[\"feat_static_real\"] is not None:\n            features.append(data[\"feat_static_real\"])\n        static_feat = torch.cat(\n            features,\n            dim=1,\n        )\n        expanded_static_feat = static_feat.unsqueeze(1).expand(\n            -1, x.shape[1], -1\n        )\n        features = []\n        if self.use_features:\n            
features.append(expanded_static_feat)\n\n            time_features = []\n            if data[\"past_time_feat\"] is not None:\n                time_features.append(\n                    data[\"past_time_feat\"][:, -self.context_length :]\n                )\n            if data[\"future_time_feat\"] is not None:\n                time_features.append(data[\"future_time_feat\"])\n            features.append(torch.cat(time_features, dim=1))\n        lags = lagged_sequence_values(\n            self.lags_seq,\n            scaled_prior,\n            torch.cat([scaled_context, scaled_future], dim=1),\n            dim=1,\n        )\n        if self.use_lags:\n            features.append(lags)\n        features.append(x_past[..., None])\n        features.append(observation_mask[..., None])\n        features = torch.cat(features, dim=-1)\n        return x[..., None], scale[..., None], features\n\n    def step(self, x, t, features, loss_mask):\n        noise = torch.randn_like(x)\n        if not self.noise_observed:\n            noise = (1 - loss_mask) * x + noise * loss_mask\n\n        num_eval = loss_mask.sum()\n        sq_err, _, _ = self.p_losses(\n            x,\n            t,\n            features,\n            loss_type=\"l2\",\n            reduction=\"none\",\n            noise=noise,\n        )\n\n        if self.noise_observed:\n            elbo_loss = sq_err.mean()\n        else:\n            sq_err = sq_err * loss_mask\n            elbo_loss = sq_err.sum() / (num_eval if num_eval else 1)\n        return elbo_loss\n\n    def training_step(self, data, idx):\n        assert self.training is True\n        device = next(self.parameters()).device\n\n        x, _, features = self._extract_features(data)\n\n        # Last dim of features has the observation mask\n        observation_mask = features[..., -1:]\n        loss_mask = 1 - observation_mask\n\n        t = torch.randint(\n            0, self.timesteps, (x.shape[0],), device=device\n        ).long()\n        
elbo_loss = self.step(x, t, features, loss_mask)\n        return {\n            \"loss\": elbo_loss,\n            \"elbo_loss\": elbo_loss,\n        }\n\n    def validation_step(self, data, idx):\n        device = next(self.parameters()).device\n\n        x, _, features = self._extract_features(data)\n\n        # Last dim of features has the observation mask\n        observation_mask = features[..., -1:]\n        loss_mask = 1 - observation_mask\n\n        val_loss = 0.0\n        for i in range(self.timesteps):\n            t = torch.full((x.shape[0],), i, device=device).long()\n            val_loss += self.step(x, t, features, loss_mask)\n\n        val_loss /= self.timesteps\n\n        return {\n            \"loss\": val_loss,\n            \"elbo_loss\": val_loss,\n        }\n\n    @torch.no_grad()\n    def forecast(self, observation, observation_mask, features=None):\n        device = next(self.backbone.parameters()).device\n        batch_size, length, ch = observation.shape\n\n        seq = torch.randn_like(observation)\n\n        for i in reversed(range(0, self.timesteps)):\n            if not self.noise_observed:\n                seq = observation_mask * observation + seq * (\n                    1 - observation_mask\n                )\n\n            seq = self.p_sample(\n                seq,\n                torch.full((batch_size,), i, device=device, dtype=torch.long),\n                i,\n                features,\n            )\n\n        return seq\n\n    def forward(\n        self,\n        past_target: torch.Tensor,\n        past_observed_values: torch.Tensor,\n        feat_static_cat: torch.Tensor = None,\n        feat_static_real: torch.Tensor = None,\n        past_time_feat: torch.Tensor = None,\n        future_time_feat: torch.Tensor = None,\n        orig_past_target: torch.Tensor = None,\n    ):\n        # This is only used during prediction\n        device = next(self.backbone.parameters()).device\n        data = dict(\n            
feat_static_cat=feat_static_cat.to(device)\n            if feat_static_cat is not None\n            else None,\n            feat_static_real=feat_static_real.to(device)\n            if feat_static_real is not None\n            else None,\n            past_time_feat=past_time_feat.to(device)\n            if past_time_feat is not None\n            else None,\n            past_target=past_target.to(device),\n            orig_past_target=orig_past_target.to(device),\n            future_target=torch.zeros(\n                past_target.shape[0], self.prediction_length, device=device\n            ),\n            past_observed_values=past_observed_values.to(device)\n            if past_observed_values is not None\n            else None,\n            future_time_feat=future_time_feat.to(device)\n            if future_time_feat is not None\n            else None,\n        )\n\n        observation, scale, features = self._extract_features(data)\n        observation = observation.to(device)\n        batch_size, length, ch = observation.shape\n        observation_mask = features[..., -1:]\n\n        pred = self.forecast(\n            observation=observation,\n            observation_mask=observation_mask,\n            features=features,\n        )\n\n        pred = pred * scale\n\n        return pred[:, None, length - self.prediction_length :, 0]\n\n    def get_predictor(self, input_transform, batch_size=40, device=None):\n        return PyTorchPredictor(\n            prediction_length=self.prediction_length,\n            input_names=PREDICTION_INPUT_NAMES,\n            prediction_net=self,\n            batch_size=batch_size,\n            input_transform=input_transform,\n            device=device,\n        )\n"
  },
  {
    "path": "src/uncond_ts_diff/model/linear/_estimator.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom typing import Optional, List\nimport math\n\nimport numpy as np\nfrom sklearn.linear_model import LinearRegression, Ridge\nfrom gluonts.model import Estimator, Predictor\nfrom gluonts.dataset.common import Dataset\nfrom gluonts.dataset.field_names import FieldName\nfrom gluonts.transform import (\n    Transformation,\n    AddObservedValuesIndicator,\n    InstanceSplitter,\n    TestSplitSampler,\n    ExpectedNumInstanceSampler,\n    SelectFields,\n)\nfrom gluonts.dataset.loader import TrainDataLoader, InferenceDataLoader\nfrom gluonts.itertools import Cached\nfrom gluonts.model.forecast_generator import (\n    ForecastGenerator,\n    SampleForecastGenerator,\n    predict_to_numpy,\n)\n\nfrom ._scaler import MeanScaler, NOPScaler\n\nPREDICTION_INPUT_NAMES = [\n    \"past_target\",\n    \"past_observed_values\",\n]\n\nTRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [\n    \"future_target\",\n    \"future_observed_values\",\n]\n\n\ndef stack(data):\n    if isinstance(data[0], np.ndarray):\n        data = np.array(data)\n    elif isinstance(data[0], (list, tuple)):\n        return list(stack(t) for t in zip(*data))\n    return data\n\n\ndef batchify(data: List[dict]):\n    return {\n        key: stack(data=[item[key] for item in data]) for key in data[0].keys()\n    }\n\n\nclass LinearModel:\n    def __init__(self, weight, bias, scaler, num_parallel_samples=100) -> None:\n        super().__init__()\n        self.scaler = scaler\n        self.weight = weight\n        self.bias = bias\n        self.num_parallel_samples = num_parallel_samples\n\n    def _linear(self, x, A, b):\n        return x @ A.T + b\n\n    def __call__(self, x, mask):\n        assert x.ndim == 2\n        x, scale = self.scaler(x, np.ones_like(x))\n        out = self._linear(x, self.weight, self.bias) * scale\n        return np.tile(out[:, None], (1, 
self.num_parallel_samples, 1))\n\n\n@predict_to_numpy.register(LinearModel)\ndef _(prediction_net, args) -> np.ndarray:\n    return prediction_net(*args)\n\n\nclass LinearPredictor(Predictor):\n    def __init__(\n        self,\n        input_names: List[str],\n        prediction_net: LinearModel,\n        batch_size: int,\n        prediction_length: int,\n        input_transform: Transformation,\n        forecast_generator: ForecastGenerator = SampleForecastGenerator(),\n        lead_time: int = 0,\n    ) -> None:\n        super().__init__(prediction_length, lead_time=lead_time)\n        self.input_names = input_names\n        self.prediction_net = prediction_net\n        self.batch_size = batch_size\n        self.input_transform = input_transform\n        self.forecast_generator = forecast_generator\n\n    def predict(self, dataset: Dataset, num_samples: Optional[int] = None):\n        inference_data_loader = InferenceDataLoader(\n            dataset,\n            transform=self.input_transform,\n            batch_size=self.batch_size,\n            stack_fn=batchify,\n        )\n\n        yield from self.forecast_generator(\n            inference_data_loader=inference_data_loader,\n            prediction_net=self.prediction_net,\n            input_names=self.input_names,\n            output_transform=None,\n            num_samples=num_samples,\n        )\n\n\nclass LinearEstimator(Estimator):\n    \"\"\"A Linear regressor that takes inputs of size equal to `context_length`\n    and outputs forecasts of size equal to `prediction_length`. 
This model uses\n    LinearRegression from scikit-learn under the hood.\n\n    Example usage:\n    ```python\n    estimator = LinearEstimator(\n        dataset.metadata.freq,\n        prediction_length=dataset.metadata.prediction_length,\n        context_length=24 * 7 * 2,\n    )\n\n    predictor = estimator.train(dataset.train)\n    ```\n\n    Parameters\n    ----------\n    freq\n        Frequency of the dataset (not actually used)\n    prediction_length\n        Prediction length\n    context_length, optional\n        Context length for the linear model,\n        by default equal to 4 * prediction_length\n    num_train_samples, optional\n        Number of samples used to fit the LinearRegression model,\n        by default 10000\n    model, optional\n        Which sklearn linear model to use, one of {\"linear\", \"ridge\"},\n        by default \"ridge\".\n    scaling, optional\n        Whether to use scaling, by default True\n    batch_size, optional\n        Batch size (only relevant during prediction), by default 64\n    \"\"\"\n\n    def __init__(\n        self,\n        freq: str,\n        prediction_length: int,\n        context_length: Optional[int] = None,\n        num_train_samples: int = 10000,\n        model: str = \"ridge\",\n        scaling: bool = True,\n        batch_size: int = 64,\n        **kwargs,\n    ) -> None:\n        super().__init__(**kwargs)\n        assert model in {\"linear\", \"ridge\"}\n        self.freq = freq\n        self.prediction_length = prediction_length\n        self.context_length = context_length or 4 * prediction_length\n        self.num_train_samples = num_train_samples\n        self.model = model\n\n        if scaling:\n            self.scaler = MeanScaler(axis=-1, keepdims=True)\n        else:\n            self.scaler = NOPScaler(axis=-1, keepdims=True)\n        self.batch_size = batch_size\n\n    def create_transformation(self) -> Transformation:\n        return SelectFields(\n            [\n                
FieldName.ITEM_ID,\n                FieldName.INFO,\n                FieldName.START,\n                FieldName.TARGET,\n            ],\n            allow_missing=True,\n        ) + AddObservedValuesIndicator(\n            target_field=FieldName.TARGET,\n            output_field=FieldName.OBSERVED_VALUES,\n        )\n\n    def _create_instance_splitter(self, mode: str):\n        assert mode in [\"training\", \"test\"]\n\n        instance_sampler = {\n            \"training\": ExpectedNumInstanceSampler(\n                num_instances=1,\n                min_past=self.context_length,\n                min_future=self.prediction_length,\n            ),\n            \"test\": TestSplitSampler(),\n        }[mode]\n\n        return InstanceSplitter(\n            target_field=FieldName.TARGET,\n            is_pad_field=FieldName.IS_PAD,\n            start_field=FieldName.START,\n            forecast_start_field=FieldName.FORECAST_START,\n            instance_sampler=instance_sampler,\n            past_length=self.context_length,\n            future_length=self.prediction_length,\n            time_series_fields=[\n                FieldName.OBSERVED_VALUES,\n            ],\n        )\n\n    def _create_training_samples(\n        self, training_data\n    ) -> tuple[np.ndarray, np.ndarray]:\n        transformation = self._create_instance_splitter(\n            \"training\"\n        ) + SelectFields(TRAINING_INPUT_NAMES)\n        num_batches_per_epoch = math.ceil(self.num_train_samples / 100)\n        data_loader = TrainDataLoader(\n            training_data,\n            batch_size=100,\n            stack_fn=batchify,\n            transform=transformation,\n            num_batches_per_epoch=num_batches_per_epoch,\n        )\n\n        train_X, train_y = [], []\n        for batch in data_loader:\n            train_X.append(batch[\"past_target\"])\n            train_y.append(batch[\"future_target\"])\n            assert np.all(batch[\"past_observed_values\"] == 1.0) and np.all(\n                
batch[\"future_observed_values\"] == 1.0\n            ), \"Missing values not supported!\"\n        train_X = np.concatenate(train_X, 0)\n        train_y = np.concatenate(train_y, 0)\n        train_X = train_X[: self.num_train_samples]\n        train_y = train_y[: self.num_train_samples]\n\n        assert len(train_X) == self.num_train_samples\n\n        return train_X, train_y\n\n    def create_predictor(self, transformation, model):\n        prediction_splitter = self._create_instance_splitter(\"test\")\n        return LinearPredictor(\n            input_names=PREDICTION_INPUT_NAMES,\n            prediction_net=model,\n            batch_size=self.batch_size,\n            prediction_length=self.prediction_length,\n            input_transform=transformation + prediction_splitter,\n        )\n\n    def train(\n        self,\n        training_data: Dataset,\n        validation_data: Optional[Dataset] = None,\n        cache_data: bool = False,\n    ) -> Predictor:\n        transformation = self.create_transformation()\n        transformed_data = transformation.apply(training_data, is_train=True)\n\n        if cache_data:\n            transformed_data = Cached(transformed_data)\n\n        train_X, train_y = self._create_training_samples(transformed_data)\n        scaled_train_X, scale = self.scaler(train_X, np.ones_like(train_X))\n        scaled_train_y = train_y / scale\n\n        if self.model == \"linear\":\n            SKLearnLinear = LinearRegression\n        elif self.model == \"ridge\":\n            SKLearnLinear = Ridge\n        regressor = SKLearnLinear().fit(scaled_train_X, scaled_train_y)\n        model = LinearModel(regressor.coef_, regressor.intercept_, self.scaler)\n        return self.create_predictor(\n            transformation=transformation, model=model\n        )\n"
  },
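The estimator above reduces to a simple recipe: slice (context, future) windows out of the series, divide each pair by the context window's mean absolute value, and fit a linear model on the scaled pairs. A minimal numpy sketch of that recipe, with a closed-form ridge solve standing in for scikit-learn's `Ridge` (the `fit_linear_forecaster` name and signature are illustrative, not part of this codebase):

```python
import numpy as np


def fit_linear_forecaster(series, context_length, prediction_length, alpha=1.0):
    """Fit a ridge model mapping scaled context windows to scaled futures."""
    X, y = [], []
    total = context_length + prediction_length
    for start in range(len(series) - total + 1):
        window = series[start : start + total]
        X.append(window[:context_length])
        y.append(window[context_length:])
    X, y = np.asarray(X), np.asarray(y)

    # Per-window mean-abs scale, as MeanScaler computes with all-ones weights.
    scale = np.maximum(np.abs(X).mean(axis=1, keepdims=True), 1e-10)
    Xs, ys = X / scale, y / scale

    # Closed-form ridge with intercept: augment with a bias column and
    # leave the intercept unpenalized.
    Xb = np.hstack([Xs, np.ones((len(Xs), 1))])
    reg = alpha * np.eye(Xb.shape[1])
    reg[-1, -1] = 0.0
    W = np.linalg.solve(Xb.T @ Xb + reg, Xb.T @ ys)

    def predict(context):
        # Inference mirrors training: scale by the context's own mean-abs
        # value, forecast, then multiply the scale back in.
        s = max(np.abs(np.asarray(context)).mean(), 1e-10)
        xb = np.append(np.asarray(context) / s, 1.0)
        return (xb @ W) * s

    return predict
```

The scale-in/scale-out symmetry in `predict` mirrors what `LinearModel` does with `MeanScaler` at prediction time.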
  {
    "path": "src/uncond_ts_diff/model/linear/_scaler.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom typing import Optional, Tuple\n\nimport numpy as np\n\n\nclass MeanScaler:\n    \"\"\"Just like torch MeanScaler, but for numpy.\"\"\"\n\n    def __init__(\n        self,\n        axis: int,\n        keepdims: bool = False,\n        default_scale: Optional[float] = None,\n        minimum_scale: float = 1e-10,\n    ):\n        super().__init__()\n        self.axis = axis\n        self.keepdims = keepdims\n        self.minimum_scale = minimum_scale\n        self.default_scale = default_scale or 0.0\n\n    def __call__(\n        self, data: np.ndarray, weights: np.ndarray\n    ) -> Tuple[np.ndarray, np.ndarray]:\n        # these will have shape (N, C)\n        total_weight = weights.sum(axis=self.axis)\n        weighted_sum = (np.abs(data) * weights).sum(axis=self.axis)\n\n        # first compute a global scale per-dimension\n        total_observed = total_weight.sum(axis=0)\n        denominator = np.maximum(total_observed, np.ones_like(total_observed))\n\n        if self.default_scale != 0.0:\n            default_scale = self.default_scale\n        else:\n            default_scale = weighted_sum.sum(axis=0) / denominator\n\n        # then compute a per-item, per-dimension scale\n        denominator = np.maximum(total_weight, np.ones_like(total_weight))\n        scale = weighted_sum / denominator\n\n        # use per-batch scale when no element is observed\n        # or when the sequence contains only zeros\n        scale = np.expand_dims(\n            np.maximum(\n                self.minimum_scale,\n                np.where(\n                    weighted_sum > np.zeros_like(weighted_sum),\n                    scale,\n                    default_scale * np.ones_like(total_weight),\n                ),\n            ),\n            axis=self.axis,\n        )\n\n        return data / scale, scale if self.keepdims else scale.squeeze(\n          
  axis=self.axis\n        )\n\n\nclass NOPScaler:\n    \"\"\"\n    Just like torch NOPScaler, but for numpy.\n    \"\"\"\n\n    def __init__(self, axis: int, keepdims: bool = False):\n        super().__init__()\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def __call__(\n        self, data: np.ndarray, weights: np.ndarray\n    ) -> Tuple[np.ndarray, np.ndarray]:\n        scale = np.ones_like(data).mean(\n            axis=self.axis,\n            keepdims=self.keepdims,\n        )\n        return data, scale\n"
  },
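The core of `MeanScaler` is a weighted mean-absolute-value scale per item, with a batch-level default used for items that have no observed entries (or only zeros). A small numpy sketch of that rule (`mean_abs_scale` is a hypothetical helper for illustration, not this module's API):

```python
import numpy as np


def mean_abs_scale(data, weights, minimum_scale=1e-10):
    """Per-item scale: weighted mean |x|, falling back to the batch average
    when an item has no weight or sums to zero."""
    total_weight = weights.sum(axis=-1)
    weighted_sum = (np.abs(data) * weights).sum(axis=-1)

    # Batch-level default scale over all observed entries.
    denom = max(total_weight.sum(), 1.0)
    default_scale = weighted_sum.sum() / denom

    # Per-item scale, guarding against division by zero.
    per_item = weighted_sum / np.maximum(total_weight, 1.0)
    scale = np.where(weighted_sum > 0, per_item, default_scale)
    return np.maximum(scale, minimum_scale)
```

An item of all-missing values thus inherits the batch's average scale rather than collapsing to `minimum_scale`.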
  {
    "path": "src/uncond_ts_diff/predictor.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom typing import Iterator, Optional\n\nfrom gluonts.dataset import Dataset\nfrom gluonts.dataset.loader import InferenceDataLoader\nfrom gluonts.model import Forecast\nfrom gluonts.torch.batchify import batchify\nfrom gluonts.torch.model.predictor import PyTorchPredictor\n\n\nclass PyTorchPredictorWGrads(PyTorchPredictor):\n    def predict(\n        self, dataset: Dataset, num_samples: Optional[int] = None\n    ) -> Iterator[Forecast]:\n        inference_data_loader = InferenceDataLoader(\n            dataset,\n            transform=self.input_transform,\n            batch_size=self.batch_size,\n            stack_fn=lambda data: batchify(data, self.device),\n        )\n\n        self.prediction_net.eval()\n\n        yield from self.forecast_generator(\n            inference_data_loader=inference_data_loader,\n            prediction_net=self.prediction_net,\n            input_names=self.input_names,\n            output_transform=self.output_transform,\n            num_samples=num_samples,\n        )\n"
  },
  {
    "path": "src/uncond_ts_diff/sampler/__init__.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom .observation_guidance import DDIMGuidance, DDPMGuidance\nfrom .refiner import MostLikelyRefiner, MCMCRefiner\n\n__all__ = [\n    \"DDIMGuidance\",\n    \"DDPMGuidance\",\n    \"MostLikelyRefiner\",\n    \"MCMCRefiner\",\n]\n"
  },
  {
    "path": "src/uncond_ts_diff/sampler/_base.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom typing import Callable, Tuple\nfrom functools import partial\n\nimport numpy as np\nimport torch\n\n\ndef grad_fn(fn, x):\n    x.requires_grad_(True)\n    return torch.autograd.grad(fn(x), x)[0]\n\n\n@torch.no_grad()\ndef langevin_dynamics(\n    z0: torch.Tensor,\n    energy_func: Callable = None,\n    score_func: Callable = None,\n    step_size: float = 0.1,\n    noise_scale: float = 0.1,\n    n_steps: int = 1,\n):\n    \"\"\"Overdamped Langevin dynamics.\n\n    Parameters\n    ----------\n    z0\n        Initial guess.\n    energy_func, optional\n        Energy function, only one of energy function or score function\n        must be specified, by default None\n    score_func, optional\n        Score function, only one of energy function or score function\n        must be specified, by default None\n    step_size, optional\n        Step size, by default 0.1\n    noise_scale, optional\n        Scale for Brownian noise, by default 0.1\n    n_steps, optional\n        Number of Langevin steps, by default 1\n\n    Returns\n    -------\n        Updated point.\n    \"\"\"\n    assert energy_func is not None or score_func is not None\n    z = z0\n    sqrt_2eta = torch.sqrt(2 * torch.tensor(step_size))\n    for _ in range(n_steps):\n        if energy_func is not None:\n            with torch.enable_grad():\n                z.requires_grad_(True)\n                Ez = energy_func(z)\n                v = -torch.autograd.grad(Ez, z)[0]\n        else:\n            v = score_func(z)\n        z = (\n            z.detach()\n            + step_size * v\n            + sqrt_2eta * noise_scale * torch.randn_like(z)\n        )\n    return z\n\n\n@torch.enable_grad()\ndef leapfrog(\n    xt: torch.Tensor,\n    pt: torch.Tensor,\n    dynamics_p: Callable[[torch.Tensor], torch.Tensor],\n    mass: float,\n    h: float,\n    n_steps: int,\n) -> 
Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Leapfrog integrator.\n\n    Parameters\n    ----------\n    xt\n        Position.\n    pt\n        Momentum.\n    dynamics_p\n        Dynamics function for momentum\n    mass\n        Mass of particle\n    h\n        Step size\n    n_steps\n        Number of leapfrog integration steps\n\n    Returns\n    -------\n        Updated position and momentum.\n    \"\"\"\n    for _ in range(n_steps):\n        pt = pt - (h / 2) * dynamics_p(xt)\n        xt = xt + h * pt / mass\n        pt = pt - (h / 2) * dynamics_p(xt)\n        xt, pt = xt.detach(), pt.detach()\n\n    return xt, pt\n\n\n@torch.no_grad()\ndef hmc(\n    x0: torch.Tensor,\n    energy_func: Callable[[torch.Tensor], torch.Tensor],\n    step_size: float,\n    mass: float,\n    n_leapfrog_steps: int = 10,\n    n_steps: int = 100,\n) -> torch.Tensor:\n    \"\"\"Hamiltonian Monte Carlo (every proposal is accepted; no Metropolis\n    accept/reject correction is applied).\n\n    Parameters\n    ----------\n    x0\n        Initial guess of shape [B, T, C].\n    energy_func\n        Energy function E: [B, T, C] -> []\n    step_size\n        Step size.\n    mass\n        Mass of particle.\n    n_leapfrog_steps, optional\n        Number of leapfrog integration steps, by default 10\n    n_steps, optional\n        Number of HMC steps, by default 100\n\n    Returns\n    -------\n        Updated tensor of shape [B, T, C].\n    \"\"\"\n    potential_energy_func = energy_func\n    batch_size, length, ch = x0.shape\n\n    drift_func = partial(grad_fn, potential_energy_func)\n    xt = x0\n    for _ in range(n_steps):\n        pt = np.sqrt(mass) * torch.randn_like(xt)\n        xt_prop, pt_prop = leapfrog(\n            xt, pt, drift_func, mass, step_size, n_leapfrog_steps\n        )\n        xt = xt_prop\n\n    return xt\n\n\ndef linear_midpoint_em_step(\n    zt: torch.Tensor, coeff: float, h: float, sigma: float\n):\n    \"\"\"Midpoint Euler-Maruyama step.\"\"\"\n    eta = torch.randn_like(zt)\n    ztp1 = zt - h * coeff * zt / 2 + np.sqrt(h) * sigma * eta\n    ztp1 = 
ztp1 / (1 + h * coeff / 2)\n    return ztp1.detach()\n\n\n@torch.no_grad()\ndef udld(\n    x0: torch.Tensor,\n    potential_energy_func: Callable[[torch.Tensor], torch.Tensor],\n    step_size: float,\n    friction: float,\n    mass: float,\n    n_leapfrog_steps: int = 1,\n    n_steps: int = 100,\n) -> torch.Tensor:\n    \"\"\"Underdamped Langevin dynamics.\n\n    Parameters\n    ----------\n    x0\n        Initial guess of shape [B, T, C]\n    potential_energy_func\n        Energy function E: [B, T, C] -> []\n    step_size\n        Step size\n    friction\n        Friction coefficient\n    mass\n        Mass of the particle\n    n_leapfrog_steps, optional\n        Number of leapfrog integration steps, by default 1\n    n_steps, optional\n        Number of UDLD steps, by default 100\n\n    Returns\n    -------\n         Updated tensor of shape [B, T, C].\n    \"\"\"\n    batch_size, length, ch = x0.shape\n    xt = x0\n    drift_func = partial(grad_fn, potential_energy_func)\n\n    pt = np.sqrt(mass) * torch.randn_like(xt)\n\n    coeff = friction / mass\n    sigma = np.sqrt(2 * friction)\n    for _ in range(n_steps):\n        pt = linear_midpoint_em_step(pt, coeff, step_size / 2, sigma)\n        xt_prop, pt_prop = leapfrog(\n            xt, pt, drift_func, mass, step_size, n_leapfrog_steps\n        )\n        xt, pt = xt_prop, pt_prop\n        pt = linear_midpoint_em_step(pt, coeff, step_size / 2, sigma)\n\n    return xt\n"
  },
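The overdamped Langevin update in `langevin_dynamics` is `z + eps * score(z) + sqrt(2 * eps) * noise`; with `noise_scale=0` it degenerates to plain gradient descent on the energy. A numpy sketch on a toy quadratic energy E(z) = z²/2, whose score is -z (illustrative only, not the torch implementation above):

```python
import numpy as np


def langevin_step(z, score, step_size, noise_scale, rng):
    """One overdamped Langevin update: z + eps*score(z) + sqrt(2*eps)*noise."""
    noise = rng.standard_normal(z.shape)
    return z + step_size * score(z) + np.sqrt(2 * step_size) * noise_scale * noise


# With noise_scale=0 and score(z) = -z, each step multiplies z by
# (1 - step_size), so the chain contracts toward the energy minimum at 0.
rng = np.random.default_rng(0)
z = np.full((4,), 5.0)
for _ in range(100):
    z = langevin_step(z, lambda x: -x, step_size=0.1, noise_scale=0.0, rng=rng)
```

With a nonzero `noise_scale` the same update instead samples (approximately) from the Gibbs distribution exp(-E(z)).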
  {
    "path": "src/uncond_ts_diff/sampler/observation_guidance.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom gluonts.torch.util import lagged_sequence_values\n\nfrom uncond_ts_diff.predictor import PyTorchPredictorWGrads\nfrom uncond_ts_diff.utils import extract\nfrom uncond_ts_diff.model import TSDiff\n\nPREDICTION_INPUT_NAMES = [\n    \"past_target\",\n    \"past_observed_values\",\n    \"feat_static_cat\",\n    \"feat_static_real\",\n    \"past_time_feat\",\n    \"future_time_feat\",\n    \"stats\",\n]\n\n\nclass Guidance(torch.nn.Module):\n    _missing_scenarios = [\"none\", \"RM\", \"BM-B\", \"BM-E\"]\n\n    def __init__(\n        self,\n        model: TSDiff,\n        prediction_length: int,\n        scale: float = 1.0,\n        num_samples: int = 1,\n        guidance: str = \"quantile\",\n        missing_scenario: str = \"none\",\n        missing_values: int = 0,\n    ):\n        super().__init__()\n        assert missing_scenario in self._missing_scenarios\n\n        self.model = model\n        self.prediction_length = prediction_length\n        self.scale = scale\n        self.num_samples = num_samples\n        self.guidance = guidance\n        self.missing_scenario = missing_scenario\n        self.missing_values = missing_values\n\n    def quantile_loss(self, y_prediction, y_target):\n        assert y_target.shape == y_prediction.shape\n        device = y_prediction.device\n        batch_size_x_num_samples, length, ch = y_target.shape\n        batch_size = batch_size_x_num_samples // self.num_samples\n        # num_samples uniformly distributed quantiles between 0 and 1\n        # repeat for each element in the batch\n        q = (torch.arange(self.num_samples).repeat(batch_size) + 1).to(\n            device\n        ) / (self.num_samples + 1)\n        # (batch_size x num_samples,)\n\n        q = q[:, None, None]  # (batch_size x num_samples, 1, 1)\n        e = 
y_target - y_prediction\n        loss = torch.max(q * e, (q - 1) * e)\n        return loss\n\n    def energy_func(self, y, t, observation, observation_mask, features):\n        if self.guidance == \"MSE\":\n            return F.mse_loss(\n                self.model.fast_denoise(y, t, features),\n                observation,\n                reduction=\"none\",\n            )[observation_mask == 1].sum()\n        elif self.guidance == \"quantile\":\n            return self.quantile_loss(\n                self.model.fast_denoise(y, t, features),\n                observation,\n            )[observation_mask == 1].sum()\n        else:\n            raise ValueError(f\"Unknown guidance {self.guidance}!\")\n\n    def score_func(self, y, t, observation, observation_mask, features):\n        with torch.enable_grad():\n            y.requires_grad_(True)\n            Ey = self.energy_func(\n                y, t, observation, observation_mask, features\n            )\n            return -torch.autograd.grad(Ey, y)[0]\n\n    def scale_func(self, y, t, base_scale):\n        raise NotImplementedError(\"Must be implemented by a subclass!\")\n\n    def guide(self, observation, observation_mask, features, scale):\n        raise NotImplementedError(\"Must be implemented by a subclass!\")\n\n    def forward(\n        self,\n        past_target: torch.Tensor,\n        past_observed_values: torch.Tensor,\n        feat_static_cat: torch.Tensor = None,\n        feat_static_real: torch.Tensor = None,\n        past_time_feat: torch.Tensor = None,\n        future_time_feat: torch.Tensor = None,\n        stats: torch.Tensor = None,\n    ):\n        device = next(self.model.parameters()).device\n\n        future_target = torch.zeros(\n            past_target.shape[0], self.prediction_length, device=device\n        )\n        data = dict(\n            feat_static_cat=feat_static_cat.to(device)\n            if feat_static_cat is not None\n            else None,\n            
feat_static_real=feat_static_real.to(device)\n            if feat_static_real is not None\n            else None,\n            past_time_feat=past_time_feat.to(device)\n            if past_time_feat is not None\n            else None,\n            past_target=past_target.to(device),\n            future_target=future_target,\n            past_observed_values=past_observed_values.to(device)\n            if past_observed_values is not None\n            else None,\n            future_time_feat=future_time_feat.to(device)\n            if future_time_feat is not None\n            else None,\n            stats=stats.to(device) if stats is not None else None,\n        )\n\n        observation, scale_params, features = self.model._extract_features(\n            data\n        )\n\n        observation = observation.to(device)\n\n        batch_size, length, ch = observation.shape\n        prior_mask = past_observed_values[:, : -self.model.context_length]\n        context_mask = past_observed_values[:, -self.model.context_length :]\n        future_mask = torch.zeros_like(future_target)\n        observation_mask = torch.cat([context_mask, future_mask], dim=1)\n        if self.model.use_lags:\n            lagged_mask = lagged_sequence_values(\n                self.model.lags_seq,\n                prior_mask,\n                observation_mask,\n                dim=1,\n            )\n            observation_mask = torch.cat(\n                [observation_mask[:, :, None], lagged_mask], dim=-1\n            )\n        else:\n            observation_mask = observation_mask[:, :, None]\n\n        observation = observation.repeat_interleave(self.num_samples, dim=0)\n        observation_mask = observation_mask.repeat_interleave(\n            self.num_samples, dim=0\n        )\n        if features is not None:\n            features = features.repeat_interleave(self.num_samples, dim=0)\n\n        # base_scale = self.scale / (\n        #     context_mask.sum() / 
torch.ones_like(context_mask).sum()\n        # )\n        base_scale = self.scale\n\n        pred = self.guide(observation, observation_mask, features, base_scale)\n        pred = pred[:, :, 0].reshape(batch_size, self.num_samples, -1)\n        pred = pred * scale_params\n\n        return pred[..., length - self.prediction_length :]\n\n    def get_predictor(self, input_transform, batch_size=40, device=None):\n        return PyTorchPredictorWGrads(\n            prediction_length=self.prediction_length,\n            input_names=PREDICTION_INPUT_NAMES,\n            prediction_net=self,\n            batch_size=batch_size,\n            input_transform=input_transform,\n            device=device,\n        )\n\n\nclass DDPMGuidance(Guidance):\n    def __init__(\n        self,\n        model: TSDiff,\n        prediction_length: int,\n        scale: float = 1,\n        num_samples: int = 1,\n        guidance: str = \"quantile\",\n        missing_scenario: str = \"none\",\n        missing_values: int = 0,\n    ):\n        super().__init__(\n            model,\n            prediction_length,\n            scale,\n            num_samples,\n            guidance,\n            missing_scenario,\n            missing_values,\n        )\n\n    def scale_func(self, y, t, base_scale):\n        return extract(self.model.posterior_variance, t, y.shape) * base_scale\n\n    @torch.no_grad()\n    def _reverse_diffusion(\n        self, observation, observation_mask, features, base_scale\n    ):\n        device = observation.device\n        batch_size = observation.shape[0]\n\n        seq = torch.randn_like(observation)\n        for i in reversed(range(0, self.model.timesteps)):\n            t = torch.full((batch_size,), i, device=device, dtype=torch.long)\n            seq = self.model.p_sample(seq, t, i, features)\n            scale = self.scale_func(seq, t, base_scale=base_scale)\n            seq = seq + scale * self.score_func(\n                seq,\n                t,\n                
observation=observation,\n                observation_mask=observation_mask,\n                features=features,\n            )\n\n        return seq\n\n    def guide(self, observation, observation_mask, features, base_scale):\n        return self._reverse_diffusion(\n            observation, observation_mask, features, base_scale\n        )\n\n\nclass DDIMGuidance(Guidance):\n    _skip_types = [\"uniform\", \"quadratic\"]\n\n    def __init__(\n        self,\n        model: TSDiff,\n        prediction_length: int,\n        eta: float = 0.0,\n        skip_factor: int = 1,\n        skip_type: str = \"uniform\",\n        scale: float = 1,\n        num_samples: int = 1,\n        guidance: str = \"quantile\",\n        missing_scenario: str = \"none\",\n        missing_values: int = 0,\n    ):\n        super().__init__(\n            model,\n            prediction_length,\n            scale,\n            num_samples,\n            guidance,\n            missing_scenario,\n            missing_values,\n        )\n        assert skip_type in self._skip_types\n        self.eta = eta\n        self.skip_factor = skip_factor\n        self.skip_type = skip_type\n\n    def scale_func(self, y, t, base_scale):\n        return (\n            extract(self.model.sqrt_one_minus_alphas_cumprod, t, y.shape)\n            * base_scale\n        )\n\n    def _get_timesteps(self):\n        if self.skip_type == \"uniform\":\n            timesteps = range(0, self.model.timesteps, self.skip_factor)\n        elif self.skip_type == \"quadratic\":\n            n_test_timesteps = int(self.model.timesteps / self.skip_factor)\n            c = 1 - self.skip_factor / self.model.timesteps\n            timesteps = np.square(\n                np.linspace(\n                    0, np.sqrt(self.model.timesteps * c), n_test_timesteps\n                )\n            )\n            timesteps = timesteps.astype(np.int64).tolist()\n        timesteps = sorted(set(timesteps))\n        return timesteps\n\n    
@torch.no_grad()\n    def _reverse_ddim(\n        self, observation, observation_mask, features, base_scale\n    ):\n        device = observation.device\n        batch_size = observation.shape[0]\n        timesteps = self._get_timesteps()\n        timesteps_prev = [-1] + timesteps[:-1]\n\n        seq = torch.randn_like(observation)\n\n        for i, j in zip(reversed(timesteps), reversed(timesteps_prev)):\n            t = torch.full((batch_size,), i, device=device, dtype=torch.long)\n            t_prev = torch.full(\n                (batch_size,), j, device=device, dtype=torch.long\n            )\n            noise = self.model.backbone(seq, t, features)\n            scale = self.scale_func(seq, t, base_scale=base_scale)\n            noise = noise - scale * self.score_func(\n                seq,\n                t,\n                observation=observation,\n                observation_mask=observation_mask,\n                features=features,\n            )\n            seq = self.model.p_sample_genddim(\n                seq,\n                t,\n                t_index=i,\n                t_prev=t_prev,\n                eta=self.eta,\n                features=features,\n                noise=noise,\n            )\n\n        return seq\n\n    def guide(self, observation, observation_mask, features, base_scale):\n        return self._reverse_ddim(\n            observation, observation_mask, features, base_scale\n        )\n"
  },
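The `quantile` guidance mode above is the pinball loss, with each of the `num_samples` contiguous copies of a batch item (as produced by `repeat_interleave`) assigned the level (k+1)/(num_samples+1). A numpy sketch of both pieces (`pinball_loss` and `sample_quantile_levels` are illustrative names, not part of this codebase):

```python
import numpy as np


def pinball_loss(y_target, y_prediction, q):
    """Quantile (pinball) loss: max(q*e, (q-1)*e) with e = target - prediction."""
    e = y_target - y_prediction
    return np.maximum(q * e, (q - 1) * e)


def sample_quantile_levels(batch_size, num_samples):
    """Levels (k+1)/(S+1), tiled so that each batch item's S contiguous
    copies get the quantiles 1/(S+1), ..., S/(S+1)."""
    levels = (np.arange(num_samples) + 1) / (num_samples + 1)
    return np.tile(levels, batch_size)
```

Underprediction at a high quantile (q near 1) is penalized by roughly q per unit of error, overprediction by only 1-q, which pushes each sample toward its assigned quantile of the predictive distribution.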
  {
    "path": "src/uncond_ts_diff/sampler/refiner.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom gluonts.time_feature import get_seasonality\n\nfrom uncond_ts_diff.predictor import PyTorchPredictorWGrads\nfrom uncond_ts_diff.sampler._base import (\n    langevin_dynamics,\n    hmc,\n    udld,\n)\n\nPREDICTION_INPUT_NAMES = [\n    \"past_target\",\n    \"past_observed_values\",\n    \"feat_static_cat\",\n    \"feat_static_real\",\n    \"past_time_feat\",\n    \"future_time_feat\",\n    \"stats\",\n]\n\n\nclass Refiner(torch.nn.Module):\n    def __init__(\n        self,\n        model,\n        prediction_length,\n        fixed_t=20,\n        iterations=1,\n        init=None,\n        num_samples=1,\n        guidance=\"quantile\",\n        scale=1,\n    ):\n        super().__init__()\n        self.model = model\n        self.prediction_length = prediction_length\n        self.fixed_t = fixed_t\n        self.iterations = iterations\n        self.init = init\n        self.num_samples = num_samples\n        self.guidance = guidance\n        self.scale = scale\n\n    def quantile_loss(self, y_prediction, y_target):\n        assert y_target.shape == y_prediction.shape\n        device = y_prediction.device\n        batch_size_x_num_samples, length, ch = y_target.shape\n        batch_size = batch_size_x_num_samples // self.num_samples\n        # num_samples uniformly distributed quantiles between 0 and 1\n        # repeat for each element in the batch\n        q = (torch.arange(self.num_samples).repeat(batch_size) + 1).to(\n            device\n        ) / (self.num_samples + 1)\n        # (batch_size x num_samples,)\n        q = q[:, None, None]\n        # (batch_size x num_samples, 1, 1)\n        e = y_target - y_prediction\n        loss = torch.max(q * e, (q - 1) * e)\n        return loss\n\n    def prior(self, y_prediction, obs, obs_mask):\n        if 
self.guidance == \"MSE\":\n            return (\n                self.scale\n                * F.mse_loss(y_prediction, obs, reduction=\"none\")[\n                    obs_mask == 1\n                ].sum()\n            )\n        elif self.guidance == \"quantile\":\n            return self.scale * self.quantile_loss(y_prediction, obs).sum()\n        else:\n            raise ValueError(f\"Unknown guidance {self.guidance}!\")\n\n    def refine(self, observation, observation_mask):\n        raise NotImplementedError(\"Must be implemented by a subclass!\")\n\n    def forward(\n        self,\n        past_target: torch.Tensor,\n        past_observed_values: torch.Tensor,\n        feat_static_cat: torch.Tensor = None,\n        feat_static_real: torch.Tensor = None,\n        past_time_feat: torch.Tensor = None,\n        future_time_feat: torch.Tensor = None,\n        stats: torch.Tensor = None,\n    ):\n        device = next(self.model.backbone.parameters()).device\n        data = dict(\n            feat_static_cat=feat_static_cat.to(device)\n            if feat_static_cat is not None\n            else None,\n            feat_static_real=feat_static_real.to(device)\n            if feat_static_real is not None\n            else None,\n            past_time_feat=past_time_feat.to(device)\n            if past_time_feat is not None\n            else None,\n            past_target=past_target.to(device),\n            future_target=torch.zeros(\n                past_target.shape[0], self.prediction_length, device=device\n            ),\n            past_observed_values=past_observed_values.to(device)\n            if past_observed_values is not None\n            else None,\n            future_time_feat=future_time_feat.to(device)\n            if future_time_feat is not None\n            else None,\n        )\n\n        observation, scale, features = self.model._extract_features(data)\n\n        observation = observation.to(device)\n        batch_size, length, ch = 
observation.shape\n        observation_mask = torch.ones_like(observation, device=device)\n        observation_mask[:, length - self.prediction_length :, 0] = 0\n\n        observation = observation.repeat_interleave(self.num_samples, dim=0)\n        observation_mask = observation_mask.repeat_interleave(\n            self.num_samples, dim=0\n        )\n        if features is not None:\n            features = features.repeat_interleave(self.num_samples, dim=0)\n\n        if self.init is not None:\n            init_forecasts = np.stack(\n                [next(self.init).samples for _ in range(batch_size)]\n            )\n\n            if init_forecasts.shape[1] == 1:\n                # Single sample, e.g., for SeasonalNaive\n                init_forecasts = np.tile(\n                    init_forecasts, (1, self.num_samples, 1)\n                )\n\n            # create numpy array out of list and sort them to\n            # match to their corresponding quantile\n            init = np.sort(init_forecasts, axis=1)\n            init = torch.from_numpy(init).to(device)\n\n            # scale input\n            init = init / scale\n\n            # reshape from B x num_samples x prediction_length to\n            # B * self.num_samples x prediction_length\n            init = init.reshape(\n                batch_size * self.num_samples, self.prediction_length\n            )\n\n            # use it as initial guess\n            observation[:, length - self.prediction_length :, 0] = init\n\n        else:\n            season_length = get_seasonality(self.model.freq)\n\n            # Initialize using Seasonal Naive predictions\n            if (length - self.prediction_length) >= season_length:\n                indices = [\n                    length\n                    - self.prediction_length\n                    - season_length\n                    + k % season_length\n                    for k in range(self.prediction_length)\n                ]\n                observation[\n 
                   :, length - self.prediction_length :, 0\n                ] = observation[:, indices, 0]\n\n            # Initialize using the mean of the context window\n            else:\n                observation[\n                    :, length - self.prediction_length :, 0\n                ] = torch.mean(\n                    observation[:, : length - self.prediction_length, 0],\n                    dim=1,\n                    keepdim=True,\n                )\n\n        pred = self.refine(observation, observation_mask)\n\n        pred = pred[:, :, 0].reshape(batch_size, self.num_samples, -1)\n        pred = pred * scale\n\n        return pred[:, :, length - self.prediction_length :]\n\n    def get_predictor(self, input_transform, batch_size=40, device=None):\n        return PyTorchPredictorWGrads(\n            prediction_length=self.prediction_length,\n            input_names=PREDICTION_INPUT_NAMES,\n            prediction_net=self,\n            batch_size=batch_size,\n            input_transform=input_transform,\n            device=device,\n        )\n\n\nclass MostLikelyRefiner(Refiner):\n    def __init__(\n        self,\n        model,\n        prediction_length,\n        lr=1e-1,\n        patience=100,\n        fixed_t=20,\n        iterations=1,\n        init=None,\n        num_samples=1,\n        guidance=\"quantile\",\n        scale=1,\n    ):\n        super().__init__(\n            model,\n            prediction_length,\n            fixed_t,\n            iterations,\n            init,\n            num_samples,\n            guidance,\n            scale,\n        )\n        self.lr = lr\n        self.patience = patience\n\n    def _most_likely(self, observation, observation_mask):\n        device = next(self.model.backbone.parameters()).device\n        observation = observation.to(device)\n        seq = nn.Parameter(torch.clone(observation), requires_grad=True)\n        optim = torch.optim.SGD([seq], lr=self.lr)\n        scheduler = 
torch.optim.lr_scheduler.ReduceLROnPlateau(\n            optim, \"min\", patience=self.patience, factor=0.5\n        )\n        with torch.enable_grad():\n            for i in range(self.iterations):\n                optim.zero_grad()\n                t = torch.randint(\n                    0, self.model.timesteps, (seq.shape[0],), device=device\n                ).long()\n                if self.fixed_t != -1:\n                    t = t * 0 + self.fixed_t\n                loss = self.model.p_losses(\n                    seq, t, loss_type=\"l2\", reduction=\"sum\"\n                )[0] + self.prior(seq, observation, observation_mask)\n                loss.backward()\n\n                optim.step()\n                scheduler.step(loss.item())\n\n        return seq.detach()\n\n    def refine(self, observation, observation_mask):\n        return self._most_likely(observation, observation_mask)\n\n\nclass MCMCRefiner(Refiner):\n    _available_methods = {\"lmc\", \"hmc\", \"udld\", \"cdld\"}\n\n    def __init__(\n        self,\n        model,\n        prediction_length,\n        step_size=1e-1,\n        method=\"lmc\",\n        method_kwargs={},\n        fixed_t=20,\n        iterations=1,\n        init=None,\n        num_samples=1,\n        guidance=\"quantile\",\n        scale=1,\n    ):\n        super().__init__(\n            model,\n            prediction_length,\n            fixed_t,\n            iterations,\n            init,\n            num_samples,\n            guidance,\n            scale,\n        )\n        assert method in self._available_methods\n        self.step_size: float = step_size\n        self.method: str = method\n        self.method_kwargs: dict = method_kwargs\n\n    def _mcmc(self, observation, observation_mask):\n        device = next(self.model.backbone.parameters()).device\n        observation = observation.to(device)\n        seq = torch.clone(observation)\n\n        for i in range(self.iterations):\n            t = torch.randint(\n           
     0, self.model.timesteps, (seq.shape[0],), device=device\n            ).long()\n            if self.fixed_t != -1:\n                t = t * 0 + self.fixed_t\n\n            energy_func = lambda x: self.model.p_losses(  # noqa: E731\n                x, t, loss_type=\"l2\", reduction=\"sum\"\n            )[0] + self.prior(x, observation, observation_mask)\n\n            if self.method == \"lmc\":\n                method_kwargs = {\n                    \"noise_scale\": 0.1,\n                    \"n_steps\": 1,\n                }\n                method_kwargs.update(self.method_kwargs)\n                seq = langevin_dynamics(\n                    seq,\n                    energy_func,\n                    score_func=None,\n                    step_size=self.step_size,\n                    **method_kwargs,\n                )\n            elif self.method == \"hmc\":\n                method_kwargs = {\n                    \"mass\": 1.0,\n                    \"n_steps\": 1,\n                    \"n_leapfrog_steps\": 5,\n                }\n                method_kwargs.update(self.method_kwargs)\n                seq = hmc(\n                    seq, energy_func, step_size=self.step_size, **method_kwargs\n                )\n            elif self.method == \"udld\":\n                method_kwargs = {\n                    \"mass\": 1.0,\n                    \"friction\": 1.0,\n                    \"n_steps\": 1,\n                    \"n_leapfrog_steps\": 5,\n                }\n                method_kwargs.update(self.method_kwargs)\n                seq = udld(\n                    seq, energy_func, step_size=self.step_size, **method_kwargs\n                )\n            elif self.method == \"cdld\":\n                method_kwargs = {\n                    \"mass\": 1.0,\n                    \"n_steps\": 1,\n                    \"n_leapfrog_steps\": 5,\n                }\n                method_kwargs.update(self.method_kwargs)\n                # friction^2 = 4 x 
mass\n                method_kwargs[\"friction\"] = np.sqrt(4 * method_kwargs[\"mass\"])\n                seq = udld(\n                    seq, energy_func, step_size=self.step_size, **method_kwargs\n                )\n\n        return seq.detach()\n\n    def refine(self, observation, observation_mask):\n        return self._mcmc(observation, observation_mask)\n"
  },
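The `lmc` branch of `MCMCRefiner` delegates to a `langevin_dynamics` helper imported elsewhere in the project. As a self-contained sketch of the update rule such a sampler typically applies — gradient descent on the energy plus injected Gaussian noise — the snippet below uses an illustrative quadratic energy with an analytic gradient; `langevin_step`, `mu`, and the parameter names are assumptions for illustration, not the project's actual API:

```python
import numpy as np

def langevin_step(x, grad_energy, step_size, noise_scale=0.1, rng=None):
    # One unadjusted Langevin update: move downhill on the energy,
    # then add Gaussian noise so the chain explores around the minimum.
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.standard_normal(x.shape)
    return (
        x
        - step_size * grad_energy(x)
        + noise_scale * np.sqrt(2.0 * step_size) * noise
    )

# Illustrative quadratic energy E(x) = 0.5 * ||x - mu||^2, so grad E(x) = x - mu
mu = np.array([1.0, -2.0])
grad_energy = lambda x: x - mu

rng = np.random.default_rng(0)
x = np.zeros(2)
for _ in range(500):
    x = langevin_step(x, grad_energy, step_size=0.05, rng=rng)
# x now fluctuates in a small neighborhood of mu
```

With `noise_scale=0` this reduces to plain gradient descent on the energy, which is essentially what `MostLikelyRefiner` does with an SGD optimizer instead of an explicit sampler.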
  {
    "path": "src/uncond_ts_diff/utils.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\nfrom copy import deepcopy\nfrom typing import Type, Dict\nfrom pathlib import Path\nfrom argparse import ArgumentParser, ArgumentTypeError\nfrom functools import partial\nimport re\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport seaborn as sns\nimport pandas as pd\nimport torch\nfrom torch.utils.data import Dataset\nfrom pandas.tseries.frequencies import to_offset\n\nfrom gluonts.core.component import validated\nfrom gluonts.dataset import DataEntry\nfrom gluonts.dataset.field_names import FieldName\nfrom gluonts.dataset.split import split\nfrom gluonts.dataset.util import period_index\nfrom gluonts.transform import (\n    Chain,\n    RemoveFields,\n    SetField,\n    AsNumpyArray,\n    AddObservedValuesIndicator,\n    AddTimeFeatures,\n    AddAgeFeature,\n    VstackFeatures,\n    MapTransformation,\n    ExpectedNumInstanceSampler,\n    InstanceSplitter,\n    TestSplitSampler,\n    ValidationSplitSampler,\n)\nfrom gluonts.model.forecast import SampleForecast\n\nsns.set(\n    style=\"white\",\n    font_scale=1.1,\n    rc={\"figure.dpi\": 125, \"lines.linewidth\": 2.5, \"axes.linewidth\": 1.5},\n)\n\n\ndef filter_metrics(metrics, select={\"ND\", \"NRMSE\", \"mean_wQuantileLoss\"}):\n    return {m: metrics[m].item() for m in select}\n\n\ndef extract(a, t, x_shape):\n    batch_size = t.shape[0]\n    out = a.gather(-1, t.cpu())\n    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)\n\n\ndef cosine_beta_schedule(timesteps, s=0.008):\n    \"\"\"\n    cosine schedule as proposed in https://arxiv.org/abs/2102.09672\n    \"\"\"\n    steps = timesteps + 1\n    x = torch.linspace(0, timesteps, steps)\n    alphas_cumprod = (\n        torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2\n    )\n    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n    betas = 1 - (alphas_cumprod[1:] / 
alphas_cumprod[:-1])\n    return torch.clip(betas, 0.0001, 0.9999)\n\n\ndef linear_beta_schedule(timesteps):\n    beta_start = 0.0001\n    beta_end = 0.1\n    return torch.linspace(beta_start, beta_end, timesteps)\n\n\ndef plot_train_stats(df: pd.DataFrame, y_keys=None, skip_first_epoch=True):\n    if skip_first_epoch:\n        df = df.iloc[1:, :]\n    if y_keys is None:\n        y_keys = [\"train_loss\", \"valid_loss\"]\n\n    fig, ax = plt.subplots(1, 1, figsize=(6.5, 4))\n    for y_key in y_keys:\n        sns.lineplot(\n            ax=ax,\n            data=df,\n            x=\"epochs\",\n            y=y_key,\n            label=y_key.replace(\"_\", \" \").capitalize(),\n        )\n    ax.legend()\n    ax.set_ylabel(\"Loss\")\n    ax.set_xlabel(\"Epoch\")\n    plt.show()\n\n\ndef get_lags_for_freq(freq_str: str):\n    offset = to_offset(freq_str)\n    if offset.n > 1:\n        raise NotImplementedError(\n            \"Lags for freq multiple > 1 are not implemented yet.\"\n        )\n    if offset.name == \"H\":\n        lags_seq = [24 * i for i in [1, 2, 3, 4, 5, 6, 7, 14, 21, 28]]\n    elif offset.name == \"D\" or offset.name == \"B\":\n        # TODO: Fix lags for B\n        lags_seq = [30 * i for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]\n    else:\n        raise NotImplementedError(\n            f\"Lags for {freq_str} are not implemented yet.\"\n        )\n    return lags_seq\n\n\ndef create_transforms(\n    num_feat_dynamic_real,\n    num_feat_static_cat,\n    num_feat_static_real,\n    time_features,\n    prediction_length,\n):\n    remove_field_names = []\n    if num_feat_static_real == 0:\n        remove_field_names.append(FieldName.FEAT_STATIC_REAL)\n    if num_feat_dynamic_real == 0:\n        remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)\n\n    return Chain(\n        [RemoveFields(field_names=remove_field_names)]\n        + (\n            [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]\n            if not num_feat_static_cat > 
0\n            else []\n        )\n        + (\n            [SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]\n            if not num_feat_static_real > 0\n            else []\n        )\n        + [\n            AsNumpyArray(\n                field=FieldName.FEAT_STATIC_CAT,\n                expected_ndim=1,\n                dtype=int,\n            ),\n            AsNumpyArray(\n                field=FieldName.FEAT_STATIC_REAL,\n                expected_ndim=1,\n            ),\n            AsNumpyArray(\n                field=FieldName.TARGET,\n                expected_ndim=1,\n            ),\n            AddObservedValuesIndicator(\n                target_field=FieldName.TARGET,\n                output_field=FieldName.OBSERVED_VALUES,\n            ),\n            AddTimeFeatures(\n                start_field=FieldName.START,\n                target_field=FieldName.TARGET,\n                output_field=FieldName.FEAT_TIME,\n                time_features=time_features,\n                pred_length=prediction_length,\n            ),\n            AddAgeFeature(\n                target_field=FieldName.TARGET,\n                output_field=FieldName.FEAT_AGE,\n                pred_length=prediction_length,\n                log_scale=True,\n            ),\n            AddMeanAndStdFeature(\n                target_field=FieldName.TARGET,\n                output_field=\"stats\",\n            ),\n            VstackFeatures(\n                output_field=FieldName.FEAT_TIME,\n                input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]\n                + (\n                    [FieldName.FEAT_DYNAMIC_REAL]\n                    if num_feat_dynamic_real > 0\n                    else []\n                ),\n            ),\n        ]\n    )\n\n\ndef create_splitter(past_length: int, future_length: int, mode: str = \"train\"):\n    if mode == \"train\":\n        instance_sampler = ExpectedNumInstanceSampler(\n            num_instances=1,\n            
min_past=past_length,\n            min_future=future_length,\n        )\n    elif mode == \"val\":\n        instance_sampler = ValidationSplitSampler(min_future=future_length)\n    elif mode == \"test\":\n        instance_sampler = TestSplitSampler()\n\n    splitter = InstanceSplitter(\n        target_field=FieldName.TARGET,\n        is_pad_field=FieldName.IS_PAD,\n        start_field=FieldName.START,\n        forecast_start_field=FieldName.FORECAST_START,\n        instance_sampler=instance_sampler,\n        past_length=past_length,\n        future_length=future_length,\n        time_series_fields=[FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES],\n    )\n    return splitter\n\n\ndef get_next_file_num(\n    base_fname: str,\n    base_dir: Path,\n    file_type: str = \"yaml\",\n    separator: str = \"-\",\n):\n    \"\"\"Gets the next available file number in a directory.\n    e.g., if `base_fname=\"results\"` and `base_dir` has\n    files [\"results-0.yaml\", \"results-1.yaml\"],\n    this function returns 2.\n\n    Parameters\n    ----------\n    base_fname\n        Base name of the file.\n    base_dir\n        Base directory where files are located.\n    file_type\n        File extension to look for, by default \"yaml\".\n        If empty, directories are matched instead of files.\n    separator\n        Separator between `base_fname` and the file number,\n        by default \"-\".\n\n    Returns\n    -------\n        Next available file number\n    \"\"\"\n    if file_type == \"\":\n        # Directory\n        items = filter(\n            lambda x: x.is_dir() and x.name.startswith(base_fname),\n            base_dir.glob(\"*\"),\n        )\n    else:\n        # File\n        items = filter(\n            lambda x: x.name.startswith(base_fname),\n            base_dir.glob(f\"*.{file_type}\"),\n        )\n    run_nums = list(\n        map(lambda x: int(x.stem.replace(base_fname + separator, \"\")), items)\n    ) + [-1]\n\n    return max(run_nums) + 1\n\n\ndef str2bool(v):\n    if isinstance(v, bool):\n        return v\n    if v.lower() in (\"yes\", \"true\", \"t\", \"y\", \"1\"):\n        return True\n    elif v.lower() in (\"no\", \"false\", \"f\", \"n\", \"0\"):\n        return False\n    else:\n  
      raise ArgumentTypeError(\"Boolean value expected.\")\n\n\ndef add_config_to_argparser(config: Dict, parser: ArgumentParser):\n    for k, v in config.items():\n        sanitized_key = re.sub(r\"[^\\w\\-]\", \"\", k).replace(\"-\", \"_\")\n        val_type = type(v)\n        if val_type not in {int, float, str, bool}:\n            print(f\"WARNING: Skipping key {k}!\")\n            continue\n        if val_type == bool:\n            parser.add_argument(f\"--{sanitized_key}\", type=str2bool, default=v)\n        else:\n            parser.add_argument(f\"--{sanitized_key}\", type=val_type, default=v)\n    return parser\n\n\nclass AddMeanAndStdFeature(MapTransformation):\n    @validated()\n    def __init__(\n        self,\n        target_field: str,\n        output_field: str,\n        dtype: Type = np.float32,\n    ) -> None:\n        self.target_field = target_field\n        self.feature_name = output_field\n        self.dtype = dtype\n\n    def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:\n        data[self.feature_name] = np.array(\n            [data[self.target_field].mean(), data[self.target_field].std()]\n        )\n\n        return data\n\n\nclass ScaleAndAddMeanFeature(MapTransformation):\n    def __init__(\n        self, target_field: str, output_field: str, prediction_length: int\n    ) -> None:\n        \"\"\"Scale the time series using mean scaler and\n        add the scale to `output_field`.\n\n        Parameters\n        ----------\n        target_field\n            Key for target time series\n        output_field\n            Key for the mean feature\n        prediction_length\n            prediction length, only the time series before the\n            last `prediction_length` timesteps is used for\n            scale computation\n        \"\"\"\n        self.target_field = target_field\n        self.feature_name = output_field\n        self.prediction_length = prediction_length\n\n    def map_transform(self, data, is_train: 
bool):\n        scale = np.mean(\n            np.abs(data[self.target_field][..., : -self.prediction_length]),\n            axis=-1,\n            keepdims=True,\n        )\n        scale = np.maximum(scale, 1e-7)\n        scaled_target = data[self.target_field] / scale\n        data[self.target_field] = scaled_target\n        data[self.feature_name] = scale\n\n        return data\n\n\nclass ScaleAndAddMinMaxFeature(MapTransformation):\n    def __init__(\n        self, target_field: str, output_field: str, prediction_length: int\n    ) -> None:\n        \"\"\"Scale the time series using min-max scaler and\n        add the scale to `output_field`.\n\n        Parameters\n        ----------\n        target_field\n            Key for target time series\n        output_field\n            Key for the min-max feature\n        prediction_length\n            prediction length, only the time series before the\n            last `prediction_length` timesteps is used for\n            scale computation\n        \"\"\"\n        self.target_field = target_field\n        self.feature_name = output_field\n        self.prediction_length = prediction_length\n\n    def map_transform(self, data, is_train: bool):\n        # compute loc and scale from the context only,\n        # but scale the full series (as in the mean scaler above)\n        context = data[self.target_field][..., : -self.prediction_length]\n        min_val = np.min(context, axis=-1, keepdims=True)\n        max_val = np.max(context, axis=-1, keepdims=True)\n        loc = min_val\n        scale = np.maximum(max_val - min_val, 1e-7)\n        scaled_target = (data[self.target_field] - loc) / scale\n        data[self.target_field] = scaled_target\n        data[self.feature_name] = (loc, scale)\n\n        return data\n\n\ndef descale(data, scale, scaling_type):\n    if scaling_type == \"mean\":\n        return data * scale\n    elif scaling_type == \"min-max\":\n        loc, scale = scale\n        return data * scale + loc\n    else:\n        raise ValueError(f\"Unknown scaling type: {scaling_type}\")\n\n\ndef predict_and_descale(predictor, dataset, 
num_samples, scaling_type):\n    \"\"\"Generates forecasts using the predictor on the test\n    dataset and then scales them back to the original space\n    using the scale feature from `ScaleAndAddMeanFeature`\n    or `ScaleAndAddMinMaxFeature` transformation.\n\n    Parameters\n    ----------\n    predictor\n        GluonTS predictor\n    dataset\n        Test dataset\n    num_samples\n        Number of forecast samples\n    scaling_type\n        Scaling type should be one of {\"mean\", \"min-max\"}\n        Min-max scaling is used in TimeGAN, defaults to \"mean\"\n\n    Yields\n    ------\n        SampleForecast objects\n\n    Raises\n    ------\n    ValueError\n        If the predictor generates Forecast objects other than SampleForecast\n    \"\"\"\n    forecasts = predictor.predict(dataset, num_samples=num_samples)\n    for input_ts, fcst in zip(dataset, forecasts):\n        scale = input_ts[\"scale\"]\n        if isinstance(fcst, SampleForecast):\n            fcst.samples = descale(\n                fcst.samples, scale, scaling_type=scaling_type\n            )\n        else:\n            raise ValueError(\"Only SampleForecast objects supported!\")\n        yield fcst\n\n\ndef to_dataframe_and_descale(input_label, scaling_type) -> pd.DataFrame:\n    \"\"\"Glues together \"input\" and \"label\" time series and scales\n    them back using the scale feature from the transformation.\n\n    Parameters\n    ----------\n    input_label\n        Input-Label pair generated from the test template\n    scaling_type\n        Scaling type should be one of {\"mean\", \"min-max\"}\n        Min-max scaling is used in TimeGAN, defaults to \"mean\"\n\n    Returns\n    -------\n        A DataFrame containing the time series\n    \"\"\"\n    start = input_label[0][FieldName.START]\n    scale = input_label[0][\"scale\"]\n    targets = [entry[FieldName.TARGET] for entry in input_label]\n    full_target = np.concatenate(targets, axis=-1)\n    full_target = descale(full_target, scale, 
scaling_type=scaling_type)\n    index = period_index(\n        {FieldName.START: start, FieldName.TARGET: full_target}\n    )\n    return pd.DataFrame(full_target.transpose(), index=index)\n\n\ndef make_evaluation_predictions_with_scaling(\n    dataset, predictor, num_samples: int = 100, scaling_type=\"mean\"\n):\n    \"\"\"A customized version of `make_evaluation_predictions` utility\n    that first scales the test time series, generates the forecasts and\n    then scales them back to the original space.\n\n    Parameters\n    ----------\n    dataset\n        Test dataset\n    predictor\n        GluonTS predictor\n    num_samples, optional\n        Number of forecast samples, by default 100\n    scaling_type, optional\n        Scaling type should be one of {\"mean\", \"min-max\"}\n        Min-max scaling is used in TimeGAN, defaults to \"mean\"\n\n    Returns\n    -------\n        A tuple of forecast and time series iterators\n    \"\"\"\n    window_length = predictor.prediction_length + predictor.lead_time\n    _, test_template = split(dataset, offset=-window_length)\n    test_data = test_template.generate_instances(window_length)\n    input_test_data = list(test_data.input)\n\n    return (\n        predict_and_descale(\n            predictor,\n            input_test_data,\n            num_samples=num_samples,\n            scaling_type=scaling_type,\n        ),\n        map(\n            partial(to_dataframe_and_descale, scaling_type=scaling_type),\n            test_data,\n        ),\n    )\n\n\nclass PairDataset(Dataset):\n    def __init__(self, x, y) -> None:\n        super().__init__()\n        assert x.shape[0] == y.shape[0]\n        self.x = x\n        self.y = y\n\n    def __getitem__(self, index):\n        return self.x[index], self.y[index]\n\n    def __len__(self):\n        return self.x.shape[0]\n\n\nclass GluonTSNumpyDataset:\n    \"\"\"GluonTS dataset from a numpy array.\n\n    Parameters\n    ----------\n    data\n        Numpy array of samples with shape 
[N, T].\n    start_date, optional\n        Dummy start date field, by default pd.Period(\"2023\", \"H\")\n    \"\"\"\n\n    def __init__(\n        self, data: np.ndarray, start_date: pd.Period = pd.Period(\"2023\", \"H\")\n    ):\n        self.data = data\n        self.start_date = start_date\n\n    def __iter__(self):\n        for ts in self.data:\n            item = {\"target\": ts, \"start\": self.start_date}\n            yield item\n\n    def __len__(self):\n        return len(self.data)\n\n\nclass MaskInput(MapTransformation):\n    @validated()\n    def __init__(\n        self,\n        target_field: str,\n        observed_field: str,\n        context_length: int,\n        missing_scenario: str,\n        missing_values: int,\n        dtype: Type = np.float32,\n    ) -> None:\n        # FIXME: Remove hardcoding of fields\n        self.target_field = target_field\n        self.observed_field = observed_field\n        self.context_length = context_length\n        self.missing_scenario = missing_scenario\n        self.missing_values = missing_values\n        self.dtype = dtype\n\n    def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:\n        data = deepcopy(data)\n        data[\"orig_past_target\"] = data[\"past_target\"].copy()\n        if self.missing_scenario == \"BM-E\" and self.missing_values > 0:\n            data[\"past_target\"][-self.missing_values :] = 0\n            data[\"past_observed_values\"][-self.missing_values :] = 0\n        elif self.missing_scenario == \"BM-B\" and self.missing_values > 0:\n            data[\"past_target\"][\n                -self.context_length : -self.context_length\n                + self.missing_values\n            ] = 0\n            data[\"past_observed_values\"][\n                -self.context_length : -self.context_length\n                + self.missing_values\n            ] = 0\n        elif self.missing_scenario == \"RM\" and self.missing_values > 0:\n            weights = 
torch.ones(self.context_length)\n            missing_idxs = -self.context_length + torch.multinomial(\n                weights, self.missing_values, replacement=False\n            )\n            data[\"past_target\"][missing_idxs] = 0\n            data[\"past_observed_values\"][missing_idxs] = 0\n        return data\n\n\nclass ConcatDataset:\n    def __init__(self, test_pairs, axis=-1) -> None:\n        self.test_pairs = test_pairs\n        self.axis = axis\n\n    def _concat(self, test_pairs):\n        for t1, t2 in test_pairs:\n            yield {\n                \"target\": np.concatenate(\n                    [t1[\"target\"], t2[\"target\"]], axis=self.axis\n                ),\n                \"start\": t1[\"start\"],\n            }\n\n    def __iter__(self):\n        yield from self._concat(self.test_pairs)\n"
  }
]
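The `cosine_beta_schedule` in `src/uncond_ts_diff/utils.py` derives each beta from the ratio of consecutive cumulative alphas and clips the result to [1e-4, 0.9999]. A standalone NumPy mirror of that torch function (illustrative only; the project itself uses the torch version) shows the resulting shape of the schedule:

```python
import numpy as np

def cosine_beta_schedule(timesteps, s=0.008):
    # NumPy mirror of the torch implementation in utils.py:
    # betas follow from ratios of consecutive cumulative alphas.
    steps = timesteps + 1
    x = np.linspace(0, timesteps, steps)
    alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, 0.0001, 0.9999)

betas = cosine_beta_schedule(100)
# one beta per diffusion step: tiny early on, rising toward the upper clip
```

The small offset `s` keeps the first betas from vanishing, and the upper clip bounds the final step, where the raw ratio drives beta toward 1.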