[
  {
    "path": ".gitattributes",
    "content": "# Git LFS tracking for binary outputs in timesfm-forecasting skill\ntimesfm-forecasting/**/*.png filter=lfs diff=lfs merge=lfs -text\ntimesfm-forecasting/**/*.gif filter=lfs diff=lfs merge=lfs -text\n"
  },
  {
    "path": ".github/workflows/main.yml",
    "content": "name: Python package build\n\non:\n  push:\n    branches: [ \"master\" ]\n  pull_request:\n    branches: [ \"master\" ]\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v2\n      - name: Set up Python\n        uses: actions/setup-python@v2\n        with:\n          python-version: '3.11'\n      - name: Install uv\n        run: |\n          curl -LsSf https://astral.sh/uv/install.sh | sh\n          echo \"$HOME/.cargo/bin\" >> $GITHUB_PATH\n      - name: Create virtual environment\n        run: uv venv\n      - name: Install build dependencies\n        run: |\n          uv pip install build \".[torch,flax]\"\n      - name: Build package\n        run: uv run python -m build"
  },
  {
    "path": ".github/workflows/manual_publish.yml",
    "content": "name: Manual PyPI Publish\n\non:\n  workflow_dispatch:\n\njobs:\n  build-and-publish:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v2\n      - name: Set up Python\n        uses: actions/setup-python@v2\n        with:\n          python-version: '3.11'\n      - name: Install uv\n        run: |\n          curl -LsSf https://astral.sh/uv/install.sh | sh\n          echo \"$HOME/.cargo/bin\" >> $GITHUB_PATH\n      - name: Create virtual environment\n        run: uv venv\n      - name: Install build dependencies\n        run: uv pip install build twine\n      - name: Build package\n        run: uv run python -m build\n      - name: Publish to PyPI\n        env:\n          TWINE_USERNAME: __token__\n          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}\n        run: uv run twine upload dist/*"
  },
  {
    "path": ".gitignore",
    "content": ".venv/\ndist/\n__pycache__/\ncheckpoints/\nwandb/\ndatasets/\nresults/\ntimesfm_jax.egg-info/\ndevelopment_setup.md\n"
  },
  {
    "path": "AGENTS.md",
    "content": "# TimesFM — Agent Entry Point\n\nThis repository ships a first-party **Agent Skill** for TimesFM at:\n\n```\ntimesfm-forecasting/\n└── SKILL.md    ← read this for the full skill\n```\n\n## Install the skill\n\nCopy the skill directory into your agent's skills folder:\n\n```bash\n# Cursor / Claude Code / OpenCode / Codex (global install)\ncp -r timesfm-forecasting/ ~/.cursor/skills/\ncp -r timesfm-forecasting/ ~/.claude/skills/\n\n# Or project-level\ncp -r timesfm-forecasting/ .cursor/skills/\n```\n\nAny agent that supports the open [Agent Skills standard](https://agentskills.io) will discover it automatically.\n\n## Working in this repo\n\nIf you are developing TimesFM itself (not using it), the source lives in `src/timesfm/`.\nArchived v1/v2 code and notebooks are in `v1/`.\n\nRun tests:\n\n```bash\npytest v1/tests/\n```\n\nSee `README.md` for full developer setup.\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# TimesFM\n\nTimesFM (Time Series Foundation Model) is a pretrained time-series foundation\nmodel developed by Google Research for time-series forecasting.\n\n*   Paper:\n    [A decoder-only foundation model for time-series forecasting](https://arxiv.org/abs/2310.10688),\n    ICML 2024.\n*   All checkpoints:\n    [TimesFM Hugging Face Collection](https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6).\n*   [Google Research blog](https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/).\n*   [TimesFM in BigQuery](https://cloud.google.com/bigquery/docs/timesfm-model):\n    an official Google product.\n\nThis open version is not an officially supported Google product.\n\n**Latest Model Version:** TimesFM 2.5\n\n**Archived Model Versions:**\n\n-   1.0 and 2.0: relevant code archived in the sub directory `v1`. You can `pip\n    install timesfm==1.3.0` to install an older version of this package to load\n    them.\n\n## Update - Oct. 29, 2025\n\nAdded back the covariate support through XReg for TimesFM 2.5.\n\n\n## Update - Sept. 15, 2025\n\nTimesFM 2.5 is out!\n\nComparing to TimesFM 2.0, this new 2.5 model:\n\n-   uses 200M parameters, down from 500M.\n-   supports up to 16k context length, up from 2048.\n-   supports continuous quantile forecast up to 1k horizon via an optional 30M\n    quantile head.\n-   gets rid of the `frequency` indicator.\n-   has a couple of new forecasting flags.\n\nAlong with the model upgrade we have also upgraded the inference API. This repo\nwill be under construction over the next few weeks to\n\n1.  add support for an upcoming Flax version of the model (faster inference).\n2.  add back covariate support.\n3.  populate more docstrings, docs and notebook.\n\n### Install\n\n1.  Clone the repository:\n    ```shell\n    git clone https://github.com/google-research/timesfm.git\n    cd timesfm\n    ```\n\n2.  
Create a virtual environment and install dependencies using `uv`:\n    ```shell\n    # Create a virtual environment\n    uv venv\n    \n    # Activate the environment\n    source .venv/bin/activate\n    \n    # Install the package in editable mode with torch\n    uv pip install -e .[torch]\n    # Or with flax\n    uv pip install -e .[flax]\n    # Or if XReg is needed\n    uv pip install -e .[xreg]\n    ```\n\n3. [Optional] Install your preferred `torch` / `jax` backend based on your OS and accelerators\n(CPU, GPU, TPU or Apple Silicon):\n\n-   [Install PyTorch](https://pytorch.org/get-started/locally/).\n-   [Install Jax](https://docs.jax.dev/en/latest/installation.html#installation)\n    for Flax.\n\n### Code Example\n\n```python\nimport torch\nimport numpy as np\nimport timesfm\n\ntorch.set_float32_matmul_precision(\"high\")\n\nmodel = timesfm.TimesFM_2p5_200M_torch.from_pretrained(\"google/timesfm-2.5-200m-pytorch\")\n\nmodel.compile(\n    timesfm.ForecastConfig(\n        max_context=1024,\n        max_horizon=256,\n        normalize_inputs=True,\n        use_continuous_quantile_head=True,\n        force_flip_invariance=True,\n        infer_is_positive=True,\n        fix_quantile_crossing=True,\n    )\n)\npoint_forecast, quantile_forecast = model.forecast(\n    horizon=12,\n    inputs=[\n        np.linspace(0, 1, 100),\n        np.sin(np.linspace(0, 20, 67)),\n    ],  # Two dummy inputs\n)\npoint_forecast.shape  # (2, 12)\nquantile_forecast.shape  # (2, 12, 10): mean, then 10th to 90th quantiles.\n```\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"timesfm\"\nversion = \"2.0.0\"\ndescription = \"A time series foundation model.\"\nauthors = [\n    {name = \"Rajat Sen\", email = \"senrajat@google.com\"},\n    {name = \"Yichen Zhou\", email = \"yichenzhou@google.com\"},\n    {name = \"Abhimanyu Das\", email = \"abhidas@google.com\"},\n    {name = \"Petros Mol\", email = \"pmol@google.com\"},\n    {name = \"Michael Chertushkin\", email = \"chertushkinmichael@gmail.com\"},\n]\nlicense = {text = \"Apache-2.0\"}\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\ndependencies = [\n    \"numpy>=1.26.4\",\n    \"huggingface_hub[cli]>=0.23.0\",\n    \"safetensors>=0.5.3\",\n]\n\n[project.optional-dependencies]\ntorch = [\n    \"torch>=2.0.0\",\n]\nflax = [\n    \"flax\",\n    \"optax\",\n    \"einshape\",\n    \"orbax-checkpoint\",\n    \"jaxtyping\",\n    \"jax[cuda]\"\n]\nxreg = [\n    \"jax[cuda]\",\n    \"scikit-learn\",\n]\n\n[tool.ruff]\nline-length = 88\nindent-width = 2\n\n[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n"
  },
  {
    "path": "requirements.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv pip compile pyproject.toml -o requirements.txt\nanyio==4.11.0\n    # via httpx\ncertifi==2025.10.5\n    # via\n    #   httpcore\n    #   httpx\nclick==8.3.0\n    # via typer-slim\nfilelock==3.19.1\n    # via huggingface-hub\nfsspec==2025.9.0\n    # via huggingface-hub\nh11==0.16.0\n    # via httpcore\nhf-xet==1.2.0\n    # via huggingface-hub\nhttpcore==1.0.9\n    # via httpx\nhttpx==0.28.1\n    # via huggingface-hub\nhuggingface-hub==1.0.1\n    # via timesfm (pyproject.toml)\nidna==3.10\n    # via\n    #   anyio\n    #   httpx\nnumpy==2.2.6\n    # via timesfm (pyproject.toml)\npackaging==25.0\n    # via huggingface-hub\npyyaml==6.0.3\n    # via huggingface-hub\nsafetensors==0.6.2\n    # via timesfm (pyproject.toml)\nshellingham==1.5.4\n    # via huggingface-hub\nsniffio==1.3.1\n    # via anyio\ntqdm==4.67.1\n    # via huggingface-hub\ntyper-slim==0.20.0\n    # via huggingface-hub\ntyping-extensions==4.15.0\n    # via\n    #   anyio\n    #   huggingface-hub\n    #   typer-slim\n"
  },
  {
    "path": "src/timesfm/__init__.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"TimesFM API.\"\"\"\n\nfrom .configs import ForecastConfig\n\ntry:\n  from .timesfm_2p5 import timesfm_2p5_torch\n  TimesFM_2p5_200M_torch = timesfm_2p5_torch.TimesFM_2p5_200M_torch\nexcept ImportError:\n  pass\n\ntry:\n  from .timesfm_2p5 import timesfm_2p5_flax\n  TimesFM_2p5_200M_flax = timesfm_2p5_flax.TimesFM_2p5_200M_flax\nexcept ImportError:\n  pass\n"
  },
  {
    "path": "src/timesfm/configs.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Abstract configs for TimesFM layers.\"\"\"\n\nimport dataclasses\nfrom typing import Literal\n\n\n@dataclasses.dataclass(frozen=True)\nclass ForecastConfig:\n  \"\"\"Options for forecasting.\n\n  Attributes:\n    max_context: The maximum context length. This is used by the complied decode\n      function at inference time during batched inference. Any input time series\n      with length less than max_context will be padded with zeros, and with\n      length greater than max_context will be truncated.\n    max_horizon: The maximum horizon length. This is used by the complied decode\n      function at inference time during batched inference. The compiled cached\n      decoding function will by default forecast till max_horizon.\n    normalize_inputs: Whether to normalize the inputs. This is useful when the\n      raw inputs are of extremely large or small magnitudes which may result in\n      numerical issues.\n    window_size: The window size for decomposed forecasting.\n      TODO(siriuz42):implement it.\n    per_core_batch_size: The batch size per core. Used at inference time during\n      batched inference when multiple GPU / TPU devices are used.\n    use_continuous_quantile_head: Whether to use a separate continuous quantile\n      head to avoid quantile collapsing.\n    force_flip_invariance: Whether to force flip invariance. 
TimesFM guarantees\n      that TimesFM(aX + b) = a * TimesFM(X) + b for a >= 0 by default. This flag\n      extends it to a < 0 as well.\n    infer_is_positive: Whether to guarantee nonnegativity of the output if the\n      input is nonnegative.\n    fix_quantile_crossing: Whether to fix quantile crossing.\n    return_backcast: Whether to return backcast.\n  \"\"\"\n\n  max_context: int = 0\n  max_horizon: int = 0\n  normalize_inputs: bool = False\n  window_size: int = 0\n  per_core_batch_size: int = 1\n  use_continuous_quantile_head: bool = False\n  force_flip_invariance: bool = True\n  infer_is_positive: bool = True\n  fix_quantile_crossing: bool = False\n  return_backcast: bool = False\n\n\n@dataclasses.dataclass(frozen=True)\nclass ResidualBlockConfig:\n  \"\"\"Framework-agnostic config for a residual block.\"\"\"\n\n  input_dims: int\n  hidden_dims: int\n  output_dims: int\n  use_bias: bool\n  activation: Literal[\"relu\", \"swish\", \"none\"]\n\n\n@dataclasses.dataclass(frozen=True)\nclass RandomFourierFeaturesConfig:\n  \"\"\"Framework-agnostic config for random fourier features.\"\"\"\n\n  input_dims: int\n  output_dims: int\n  projection_stddev: float\n  use_bias: bool\n\n\n@dataclasses.dataclass(frozen=True)\nclass TransformerConfig:\n  \"\"\"Framework-agnostic config for a transformer.\"\"\"\n\n  model_dims: int\n  hidden_dims: int\n  num_heads: int\n  attention_norm: Literal[\"rms\"]\n  feedforward_norm: Literal[\"rms\"]\n  qk_norm: Literal[\"rms\", \"none\"]\n  use_bias: bool\n  use_rotary_position_embeddings: bool\n  ff_activation: Literal[\"relu\", \"swish\", \"none\"]\n  fuse_qkv: bool\n\n\n@dataclasses.dataclass(frozen=True)\nclass StackedTransformersConfig:\n  \"\"\"Framework-agnostic config for a stacked transformers.\"\"\"\n\n  num_layers: int\n  transformer: TransformerConfig\n"
  },
  {
    "path": "src/timesfm/flax/__init__.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "src/timesfm/flax/dense.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Dense layers for TimesFM.\"\"\"\n\nfrom flax import nnx\nimport jax\nimport jax.numpy as jnp\nimport jaxtyping\n\nfrom .. import configs\n\nArray = jaxtyping.Array\nBool = jaxtyping.Bool\nFloat = jaxtyping.Float\nInteger = jaxtyping.Integer\nNum = jaxtyping.Num\n\nResidualBlockConfig = configs.ResidualBlockConfig\nRandomFourierFeaturesConfig = configs.RandomFourierFeaturesConfig\n\n\nclass ResidualBlock(nnx.Module):\n  \"\"\"Residual block with two linear layers and a linear residual connection.\"\"\"\n\n  def __init__(self, config: ResidualBlockConfig, *, rngs=nnx.Rngs(42)):\n    self.config = config\n    self.hidden_layer = nnx.Linear(\n      in_features=config.input_dims,\n      out_features=config.hidden_dims,\n      use_bias=config.use_bias,\n      rngs=rngs,\n    )\n    self.output_layer = nnx.Linear(\n      in_features=config.hidden_dims,\n      out_features=config.output_dims,\n      use_bias=config.use_bias,\n      rngs=rngs,\n    )\n    self.residual_layer = nnx.Linear(\n      in_features=config.input_dims,\n      out_features=config.output_dims,\n      use_bias=config.use_bias,\n      rngs=rngs,\n    )\n    if config.activation == \"relu\":\n      self.activation = jax.nn.relu\n    elif config.activation == \"swish\":\n      self.activation = jax.nn.swish\n    elif config.activation == \"none\":\n      self.activation = lambda x: x\n    else:\n      
raise ValueError(f\"Activation: {config.activation} not supported.\")\n\n  def __call__(self, x: Float[Array, \"b ... i\"]) -> Float[Array, \"b ... o\"]:\n    return self.output_layer(\n      self.activation(self.hidden_layer(x))\n    ) + self.residual_layer(x)\n\n\nclass RandomFourierFeatures(nnx.Module):\n  \"\"\"Random Fourier features layer.\"\"\"\n\n  __data__ = (\"phrase_shifts\",)\n\n  def __init__(self, config: RandomFourierFeaturesConfig, *, rngs=nnx.Rngs(42)):\n    self.config = config\n\n    if config.output_dims % 4 != 0:\n      raise ValueError(\n        f\"Output dims must be a multiple of 4: {config.output_dims} % 4 != 0.\"\n      )\n    num_projected_features = config.output_dims // 4\n\n    self.phase_shifts = nnx.Param(jnp.zeros(shape=(2, num_projected_features)))\n    self.projection_layer = nnx.Linear(\n      in_features=config.input_dims,\n      out_features=num_projected_features,\n      use_bias=config.use_bias,\n      rngs=rngs,\n    )\n    self.residual_layer = nnx.Linear(\n      in_features=config.input_dims,\n      out_features=config.output_dims,\n      use_bias=config.use_bias,\n      rngs=rngs,\n    )\n\n  def __call__(self, x: Float[Array, \"b ... i\"]) -> Float[Array, \"b ... o\"]:\n    projected = self.projection_layer(x)\n    cos_features = jnp.cos(projected)\n    sin_features = jnp.sin(projected)\n    sq_wave_1 = jnp.sign(jnp.sin(projected + self.phase_shifts[0, :]))\n    sq_wave_2 = jnp.sign(jnp.sin(projected + self.phase_shifts[1, :]))\n    fourier_features = jnp.concatenate(\n      [cos_features, sin_features, sq_wave_1, sq_wave_2], axis=-1\n    )\n    residual = self.residual_layer(x)\n    return fourier_features + residual\n"
  },
  {
    "path": "src/timesfm/flax/normalization.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Normalization layers for TimesFM.\"\"\"\n\nfrom flax import nnx\nimport jax\nimport jax.numpy as jnp\nimport jaxtyping\n\nArray = jaxtyping.Array\nBool = jaxtyping.Bool\nFloat = jaxtyping.Float\nInteger = jaxtyping.Integer\nNum = jaxtyping.Num\n\n\nclass RMSNorm(nnx.Module):\n  \"\"\"RMS normalization.\"\"\"\n\n  __data__ = (\"scale\",)\n\n  def __init__(\n    self,\n    num_features: int,\n    *,\n    epsilon: float = 1e-6,\n    rngs=nnx.Rngs(42),\n  ):\n    del rngs\n    self.scale = nnx.Param(jnp.zeros(shape=(num_features,)))\n    self.num_features = num_features\n    self.epsilon = epsilon\n\n  def __call__(self, inputs: Float[Array, \"b ... d\"]) -> Float[Array, \"b ... d\"]:\n    var = jnp.mean(jnp.square(inputs), axis=-1, keepdims=True)\n    normed_inputs = inputs * jax.lax.rsqrt(var + self.epsilon)\n    normed_inputs *= self.scale\n    return normed_inputs\n\n\nclass LayerNorm(nnx.Module):\n  \"\"\"Layer normalization replica of  LayerNorm.\"\"\"\n\n  __data__ = (\"scale\", \"bias\")\n\n  def __init__(self, num_features: int, *, epsilon: float = 1e-6, rngs=nnx.Rngs(42)):\n    del rngs\n    self.scale = nnx.Param(jnp.ones(shape=(num_features,)))\n    self.bias = nnx.Param(jnp.zeros(shape=(num_features,)))\n    self.num_features = num_features\n    self.epsilon = epsilon\n\n  def __call__(self, inputs: Float[Array, \"b ... d\"]) -> Float[Array, \"b ... 
d\"]:\n    mean = jnp.mean(inputs, axis=-1, keepdims=True)\n    var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True)\n    normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)\n    normed_inputs *= self.scale\n    normed_inputs += self.bias\n    return normed_inputs\n"
  },
  {
    "path": "src/timesfm/flax/transformer.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Transformer layers for TimesFM.\"\"\"\n\nimport functools\nfrom typing import Callable\n\nfrom flax import nnx\nfrom flax.nnx.nn import linear\nimport jax\nfrom jax import lax\nimport jax.numpy as jnp\nimport jaxtyping\n\nfrom .. import configs\nfrom . import normalization, util\n\nArray = jaxtyping.Array\nBool = jaxtyping.Bool\nFloat = jaxtyping.Float\nInteger = jaxtyping.Integer\nNum = jaxtyping.Num\nLayerNorm = normalization.LayerNorm\nRMSNorm = normalization.RMSNorm\nLinearGeneral = linear.LinearGeneral\nTransformerConfig = configs.TransformerConfig\nDecodeCache = util.DecodeCache\n\n\n@functools.partial(\n  jax.jit,\n  static_argnames=(\"query_length\", \"kv_length\"),\n)\ndef make_attn_mask(\n  query_length: int,\n  num_all_masked_kv: Integer[Array, \"b\"],\n  query_index_offset: Integer[Array, \"b\"] | None = None,\n  kv_length: int = 0,\n) -> Bool[Array, \"b 1 q n\"]:\n  \"\"\"Makes attention mask.\"\"\"\n\n  if kv_length == 0:\n    kv_length = query_length\n\n  q_index = jnp.arange(query_length)[None, None, :, None]\n  if query_index_offset is not None:\n    q_index += query_index_offset[:, None, None, None]\n  kv_index = jnp.arange(kv_length)[None, None, None, :]\n  return jnp.logical_and(\n    q_index >= kv_index,\n    kv_index >= num_all_masked_kv[:, None, None, None],\n  )\n\n\nclass RotaryPositionalEmbedding(nnx.Module):\n  \"\"\"Rotary positional 
embedding.\"\"\"\n\n  def __init__(\n    self,\n    embedding_dims: int,\n    min_timescale: int = 1,\n    max_timescale: int = 10000,\n  ):\n    self.embedding_dims = embedding_dims\n    self.min_timescale = min_timescale\n    self.max_timescale = max_timescale\n\n  def __call__(\n    self,\n    inputs: Float[Array, \"b ... d\"],\n    position: Array | None = None,\n  ):\n    \"\"\"Generates a JTensor of sinusoids with different frequencies.\"\"\"\n    if self.embedding_dims != inputs.shape[-1]:\n      raise ValueError(\n        \"The embedding dims of the rotary position embedding\"\n        \"must match the hidden dimension of the inputs.\"\n      )\n    half_embedding_dim = self.embedding_dims // 2\n    fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims\n    timescale = (\n      self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction\n    )\n    if position is None:\n      seq_length = inputs.shape[1]\n      position = jnp.arange(seq_length, dtype=jnp.float32)[None, :]\n    if len(inputs.shape) == 4:\n      position = position[..., None, None]\n      timescale = timescale[None, None, None, :]\n    elif len(inputs.shape) == 3:\n      position = position[..., None]\n      timescale = timescale[None, None, :]\n    else:\n      raise ValueError(\"Inputs must be of rank 3 or 4.\")\n    sinusoid_inp = position / timescale\n    sin = jnp.sin(sinusoid_inp)\n    cos = jnp.cos(sinusoid_inp)\n    first_half, second_half = jnp.split(inputs, 2, axis=-1)\n    first_part = first_half * cos - second_half * sin\n    second_part = second_half * cos + first_half * sin\n    first_part = first_part.astype(None)\n    second_part = second_part.astype(None)\n    return jnp.concatenate([first_part, second_part], axis=-1)\n\n\nclass PerDimScale(nnx.Module):\n  \"\"\"Per-dimension scaling.\"\"\"\n\n  __data__ = (\"per_dim_scale\",)\n\n  def __init__(self, num_dims: int, *, rngs=nnx.Rngs(42)):\n    del rngs\n    self.num_dims = num_dims\n    
self.per_dim_scale = nnx.Param(jnp.zeros(shape=(num_dims,)))\n\n  def __call__(self, x: Float[Array, \"b ... d\"]) -> Float[Array, \"b ... d\"]:\n    return x * (\n      1.442695041 / jnp.sqrt(self.num_dims) * jax.nn.softplus(self.per_dim_scale)\n    )\n\n\nclass MultiHeadAttention(nnx.Module):\n  \"\"\"Multi-head attention.\"\"\"\n\n  def __init__(\n    self,\n    num_heads: int,\n    in_features: int,\n    *,\n    use_per_dim_scale: bool = True,\n    use_rotary_position_embeddings: bool = True,\n    use_bias: bool = False,\n    deterministic: bool | None = None,\n    attention_fn: Callable[..., Array] = nnx.dot_product_attention,\n    qk_norm: str = \"rms\",\n    rngs=nnx.Rngs(42),\n  ):\n    self.num_heads = num_heads\n    self.in_features = in_features\n    self.qkv_features = in_features\n    self.out_features = in_features\n    self.in_kv_features = in_features\n    self.deterministic = deterministic\n    self.use_bias = use_bias\n    self.attention_fn = attention_fn\n    self.qk_norm = qk_norm\n\n    if self.qkv_features % self.num_heads != 0:\n      raise ValueError(\n        f\"Memory dimension ({self.qkv_features}) must be divisible by \"\n        f\"'num_heads' heads ({self.num_heads}).\"\n      )\n    self.head_dim = self.qkv_features // self.num_heads\n\n    linear_general = functools.partial(\n      LinearGeneral,\n      out_features=(self.num_heads, self.head_dim),\n      use_bias=self.use_bias,\n    )\n    # project inputs_q to multi-headed q/k/v\n    # dimensions are then [batch..., length, n_heads, n_features_per_head]\n    self.query = linear_general(self.in_features, rngs=rngs)\n    self.key = linear_general(self.in_kv_features, rngs=rngs)\n    self.value = linear_general(self.in_kv_features, rngs=rngs)\n\n    if self.qk_norm == \"rms\":\n      self.query_ln = RMSNorm(self.head_dim)\n      self.key_ln = RMSNorm(self.head_dim)\n    else:\n      self.query_ln = None\n      self.key_ln = None\n\n    self.out = LinearGeneral(\n      
in_features=(self.num_heads, self.head_dim),\n      out_features=self.out_features,\n      axis=(-2, -1),\n      use_bias=self.use_bias,\n      rngs=rngs,\n    )\n\n    self.use_per_dim_scale = use_per_dim_scale\n    self.use_rotary_position_embeddings = use_rotary_position_embeddings\n    if self.use_rotary_position_embeddings:\n      self.rotary_position_embedding = RotaryPositionalEmbedding(\n        embedding_dims=self.head_dim,\n      )\n    else:\n      self.rotary_position_embedding = None\n\n    if use_per_dim_scale:\n      self.per_dim_scale = PerDimScale(num_dims=self.head_dim, rngs=rngs)\n    else:\n      self.per_dim_scale = None\n\n  def __call__(\n    self,\n    inputs_q: Array,\n    *,\n    decode_cache: DecodeCache | None = None,\n    patch_mask: Array | None = None,\n    deterministic: bool | None = None,\n    sow_weights: bool = False,\n  ) -> tuple[Float[Array, \"b ... o\"], DecodeCache | None]:\n    \"\"\"Applies multi-head dot product attention on the input data.\"\"\"\n    _, n_patches, input_in_features = inputs_q.shape\n    if input_in_features != self.in_features:\n      raise ValueError(\n        f\"Incompatible input dimension, got {input_in_features} \"\n        f\"but module expects {self.in_features}.\"\n      )\n    if patch_mask is None:\n      patch_mask = jnp.zeros_like(inputs_q.shape[:-1], dtype=jnp.bool)\n\n    # For query: rope -> ln -> per_dim_scale\n    query = self.query(inputs_q)\n    key = self.key(inputs_q)\n    value = self.value(inputs_q)\n\n    if decode_cache is None:\n      num_masked = jnp.sum(patch_mask.astype(jnp.int32), axis=-1, keepdims=False)\n      next_index = jnp.zeros_like(num_masked, dtype=jnp.int32)\n    else:\n      num_masked = (\n        jnp.sum(patch_mask.astype(jnp.int32), axis=-1, keepdims=False)\n        + decode_cache.num_masked\n      )\n      next_index = decode_cache.next_index\n\n    if self.use_rotary_position_embeddings:\n      position = (\n        jnp.arange(n_patches, 
dtype=jnp.int32)[None, :]\n        + next_index[:, None]\n        - num_masked[:, None]\n      )\n      query = self.rotary_position_embedding(query, position)\n      key = self.rotary_position_embedding(key, position)\n    if self.query_ln is not None:\n      query = self.query_ln(query)\n    if self.key_ln is not None:\n      key = self.key_ln(key)\n    if self.use_per_dim_scale:\n      query = self.per_dim_scale(query)\n\n    if decode_cache is not None:\n      # Cached decoding.\n      _, decode_cache_size, _, _ = decode_cache.value.shape\n      zero = jnp.array(0, dtype=lax.dtype(next_index.dtype))\n      start_indices = (zero, next_index[0], zero, zero)\n      key = lax.dynamic_update_slice(decode_cache.key, key, start_indices)\n      value = lax.dynamic_update_slice(decode_cache.value, value, start_indices)\n      decode_cache.key = key\n      decode_cache.value = value\n      decode_cache.next_index = next_index + n_patches\n      decode_cache.num_masked = num_masked\n      attn_mask = make_attn_mask(\n        query_length=n_patches,\n        num_all_masked_kv=num_masked,\n        query_index_offset=next_index,\n        kv_length=decode_cache_size,\n      )\n    else:\n      # Training\n      attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked)\n\n    # apply attention\n    x = self.attention_fn(\n      query * jnp.sqrt(self.head_dim),\n      key,\n      value,\n      mask=attn_mask,\n      deterministic=deterministic,\n      module=self if sow_weights else None,\n    )\n    # back to the original inputs dimensions\n    out = self.out(x)\n    return out, decode_cache\n\n\nclass Transformer(nnx.Module):\n  \"\"\"Classic Transformer used in TimesFM.\"\"\"\n\n  def __init__(self, config: TransformerConfig, *, rngs=nnx.Rngs(42)):\n    self.config = config\n\n    if config.attention_norm == \"rms\":\n      self.pre_attn_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)\n      self.post_attn_ln = 
RMSNorm(num_features=config.model_dims, rngs=rngs)\n    else:\n      raise ValueError(f\"Layer norm: {config.attention_norm} not supported.\")\n\n    self.attn = MultiHeadAttention(\n      num_heads=config.num_heads,\n      in_features=config.model_dims,\n      use_per_dim_scale=True,\n      use_rotary_position_embeddings=config.use_rotary_position_embeddings,\n      qk_norm=config.qk_norm,\n      rngs=rngs,\n    )\n\n    if config.feedforward_norm == \"rms\":\n      self.pre_ff_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)\n      self.post_ff_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)\n    else:\n      raise ValueError(f\"Layer norm: {config.feedforward_norm} not supported.\")\n    self.ff0 = nnx.Linear(\n      in_features=config.model_dims,\n      out_features=config.hidden_dims,\n      use_bias=config.use_bias,\n      rngs=rngs,\n    )\n    self.ff1 = nnx.Linear(\n      in_features=config.hidden_dims,\n      out_features=config.model_dims,\n      use_bias=config.use_bias,\n      rngs=rngs,\n    )\n    if config.ff_activation == \"relu\":\n      self.activation = jax.nn.relu\n    elif config.ff_activation == \"swish\":\n      self.activation = jax.nn.swish\n    elif config.ff_activation == \"none\":\n      self.activation = lambda x: x\n    else:\n      raise ValueError(f\"Activation: {config.ff_activation} not supported.\")\n\n  def __call__(\n    self,\n    input_embeddings: Float[Array, \"b n d\"],\n    patch_mask: Bool[Array, \"b n\"],\n    decode_cache: DecodeCache | None = None,\n  ) -> tuple[Float[Array, \"b n d\"], DecodeCache | None]:\n    attn_output, decode_cache = self.attn(\n      inputs_q=self.pre_attn_ln(input_embeddings),\n      decode_cache=decode_cache,\n      patch_mask=patch_mask,\n      sow_weights=False,\n      deterministic=True,\n    )\n    attn_output = self.post_attn_ln(attn_output) + input_embeddings\n    output_embeddings = (\n      
self.post_ff_ln(self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))))\n      + attn_output\n    )\n    return output_embeddings, decode_cache\n"
  },
  {
    "path": "src/timesfm/flax/util.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Flax utility functions for TimesFM layers.\"\"\"\n\nimport dataclasses\nimport functools\nimport jax\nimport jax.numpy as jnp\nimport jaxtyping\n\nFloat = jaxtyping.Float\nArray = jaxtyping.Array\nBool = jaxtyping.Bool\nInteger = jaxtyping.Integer\n\n_TOLERANCE = 1e-6\n\n\n@jax.tree_util.register_dataclass\n@dataclasses.dataclass(frozen=False)\nclass DecodeCache:\n  \"\"\"Cache for decoding.\"\"\"\n\n  next_index: Integer[Array, \"b\"]\n  num_masked: Integer[Array, \"b\"]\n  key: Float[Array, \"b n h d\"]\n  value: Float[Array, \"b n h d\"]\n\n\n@jax.jit\ndef update_running_stats(\n  n: Float[Array, \"b\"],\n  mu: Float[Array, \"b\"],\n  sigma: Float[Array, \"b\"],\n  x: Float[Array, \"b p\"],\n  mask: Bool[Array, \"b p\"],\n) -> tuple[\n  tuple[Float[Array, \"b\"], Float[Array, \"b\"], Float[Array, \"b\"]],\n  tuple[Float[Array, \"b\"], Float[Array, \"b\"], Float[Array, \"b\"]],\n]:\n  \"\"\"Updates the running stats.\"\"\"\n  is_legit = jnp.logical_not(mask)\n  inc_n = jnp.sum(is_legit.astype(jnp.float32), axis=-1, keepdims=False)\n  inc_mu = jnp.where(\n    inc_n == 0, 0.0, jnp.mean(x, axis=-1, keepdims=False, where=is_legit)\n  )\n  inc_sigma = jnp.where(\n    inc_n == 0, 0.0, jnp.std(x, axis=-1, keepdims=False, where=is_legit)\n  )\n  new_n = n + inc_n\n  new_mu = jnp.where(new_n == 0, 0.0, (n * mu + inc_mu * inc_n) / new_n)\n  new_sigma = jnp.sqrt(\n    
jnp.where(\n      new_n == 0,\n      0.0,\n      (\n        n * sigma * sigma\n        + inc_n * inc_sigma * inc_sigma\n        + n * (mu - new_mu) * (mu - new_mu)\n        + inc_n * (inc_mu - new_mu) * (inc_mu - new_mu)\n      )\n      / new_n,\n    )\n  )\n  return (w := (new_n, new_mu, new_sigma), w)\n\n\ndef scan_along_axis(f, init, xs, axis: int, **kwargs):\n  \"\"\"Scans along an axis.\"\"\"\n  moved_xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), xs)\n  carry, moved_ys = jax.lax.scan(f, init, moved_xs, **kwargs)\n  return (\n    carry,\n    jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), moved_ys),\n  )\n\n\n@functools.partial(jax.jit, static_argnames=(\"reverse\",))\ndef revin(\n  x: Float[Array, \"b ...\"],\n  mu: Float[Array, \"b ...\"],\n  sigma: Float[Array, \"b ...\"],\n  reverse: bool = False,\n):\n  \"\"\"Reversible per-instance normalization.\"\"\"\n  if len(mu.shape) == len(x.shape) - 1:\n    mu = mu[..., None]\n    sigma = sigma[..., None]\n  elif len(mu.shape) == len(x.shape) - 2:\n    mu = mu[..., None, None]\n    sigma = sigma[..., None, None]\n  if reverse:\n    return x * sigma + mu\n  else:\n    return (x - mu) / jnp.where(sigma < _TOLERANCE, 1.0, sigma)\n"
  },
  {
    "path": "src/timesfm/timesfm_2p5/timesfm_2p5_base.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"TimesFM 2p5 base implementation.\"\"\"\n\nimport dataclasses\nfrom typing import Any, Callable, Sequence\n\nimport collections\nimport numpy as np\n\nfrom .. import configs\n\nResidualBlockConfig = configs.ResidualBlockConfig\nStackedTransformersConfig = configs.StackedTransformersConfig\nTransformerConfig = configs.TransformerConfig\nForecastConfig = configs.ForecastConfig\nCategory = int | str\nXRegMode = str\n\n\ndef strip_leading_nans(arr):\n  \"\"\"Removes contiguous NaN values from the beginning of a NumPy array.\n\n  Args:\n    arr: The input NumPy array.\n\n  Returns:\n    A new NumPy array with leading NaN values removed.\n    If the array is all NaNs or empty, returns an empty array.\n  \"\"\"\n\n  isnan = np.isnan(arr)\n  first_valid_index = np.argmax(~isnan)\n  return arr[first_valid_index:]\n\n\ndef linear_interpolation(arr):\n  \"\"\"Performs linear interpolation to fill NaN values in a 1D numpy array.\n\n  Args:\n      arr: The 1D numpy array containing NaN values.\n\n  Returns:\n      A new numpy array with NaN values filled using linear interpolation,\n      or the original array if no NaNs are present.\n      Returns None if the input is not a 1D array.\n      Returns the original array if there are no NaN values.\n  \"\"\"\n\n  nans = np.isnan(arr)\n  if not np.any(nans):  # Check if there are any NaNs\n    return arr\n\n  def x(z):\n    
return z.nonzero()[0]\n\n  nans_indices = x(nans)\n  non_nans_indices = x(~nans)\n  non_nans_values = arr[~nans]\n\n  try:\n    arr[nans] = np.interp(nans_indices, non_nans_indices, non_nans_values)\n  except ValueError:\n    if non_nans_values:\n      mu = np.nanmean(arr)\n    else:\n      mu = 0.0\n    arr = np.where(np.isfinite(arr), arr, mu)\n  return arr\n\n\n@dataclasses.dataclass(frozen=True)\nclass TimesFM_2p5_200M_Definition:\n  \"\"\"Framework-agnostic config of TimesFM 2.5.\"\"\"\n\n  context_limit = 16384\n  input_patch_len: int = 32\n  output_patch_len: int = 128\n  output_quantile_len: int = 1024\n  quantiles: list[float] = dataclasses.field(\n    default_factory=lambda: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n  )\n  decode_index: int = 5\n  tokenizer: ResidualBlockConfig = ResidualBlockConfig(\n    input_dims=64,\n    hidden_dims=1280,\n    output_dims=1280,\n    use_bias=True,\n    activation=\"swish\",\n  )\n  stacked_transformers: StackedTransformersConfig = StackedTransformersConfig(\n    num_layers=20,\n    transformer=TransformerConfig(\n      model_dims=1280,\n      hidden_dims=1280,\n      num_heads=16,\n      attention_norm=\"rms\",\n      feedforward_norm=\"rms\",\n      qk_norm=\"rms\",\n      use_bias=False,\n      use_rotary_position_embeddings=True,\n      ff_activation=\"swish\",\n      fuse_qkv=True,\n    ),\n  )\n  output_projection_point: ResidualBlockConfig = ResidualBlockConfig(\n    input_dims=1280,\n    hidden_dims=1280,\n    output_dims=1280,\n    use_bias=False,\n    activation=\"swish\",\n  )\n  output_projection_quantiles: ResidualBlockConfig = ResidualBlockConfig(\n    input_dims=1280,\n    hidden_dims=1280,\n    output_dims=10240,\n    use_bias=False,\n    activation=\"swish\",\n  )\n\n\nclass TimesFM_2p5:\n  \"\"\"Abstract base class for TimesFM models.\n\n  Attributes:\n    forecast_config: Configuration for forecasting flags.\n    compiled_decode: Compiled decode function.\n    global_batch_size: Global batch 
size.\n  \"\"\"\n\n  forecast_config: ForecastConfig | None = None\n  compiled_decode: Callable[..., Any] | None = None\n  global_batch_size: int = 0\n\n  def load_checkpoint(self, path: str):\n    \"\"\"Loads a TimesFM model from a checkpoint.\"\"\"\n    raise NotImplementedError()\n\n  def compile(self, forecast_config: ForecastConfig | None = None):\n    \"\"\"Compiles the TimesFM model for fast decoding.\"\"\"\n    raise NotImplementedError()\n\n  def forecast(\n    self, horizon: int, inputs: list[np.ndarray]\n  ) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"Forecasts the time series.\"\"\"\n    if self.compiled_decode is None:\n      raise RuntimeError(\"Model is not compiled. Please call compile() first.\")\n\n    assert self.global_batch_size > 0\n    assert self.forecast_config is not None\n\n    context = self.forecast_config.max_context\n    num_inputs = len(inputs)\n    if (w := num_inputs % self.global_batch_size) != 0:\n      inputs += [np.array([0.0] * 3)] * (self.global_batch_size - w)\n\n    output_points = []\n    output_quantiles = []\n    values = []\n    masks = []\n    idx = 0\n    for each_input in inputs:\n      value = linear_interpolation(strip_leading_nans(np.array(each_input)))\n      if (w := len(value)) >= context:\n        value = value[-context:]\n        mask = np.zeros_like(value, dtype=bool)\n      else:\n        mask = np.array([True] * (context - w) + [False] * w)\n        value = np.pad(value, (context - w, 0), \"constant\", constant_values=0.0)\n      values.append(value)\n      masks.append(mask)\n      idx += 1\n      if idx == self.global_batch_size:\n        idx = 0\n        point_forecast, quantile_forecast = self.compiled_decode(horizon, values, masks)\n        output_points.append(point_forecast)\n        output_quantiles.append(quantile_forecast)\n        values = []\n        masks = []\n\n    output_points = np.concatenate(output_points, axis=0)\n    output_quantiles = np.concatenate(output_quantiles, axis=0)\n    
return output_points[:num_inputs], output_quantiles[:num_inputs]\n\n  def forecast_with_covariates(\n    self,\n    inputs: list[Sequence[float]],\n    dynamic_numerical_covariates: dict[str, Sequence[Sequence[float]]] | None = None,\n    dynamic_categorical_covariates: (\n      dict[str, Sequence[Sequence[Category]]] | None\n    ) = None,\n    static_numerical_covariates: dict[str, Sequence[float]] | None = None,\n    static_categorical_covariates: dict[str, Sequence[Category]] | None = None,\n    xreg_mode: XRegMode = \"xreg + timesfm\",\n    normalize_xreg_target_per_input: bool = True,\n    ridge: float = 0.0,\n    max_rows_per_col: int = 0,\n    force_on_cpu: bool = False,\n  ):\n    \"\"\"Forecasts on a list of time series with covariates.\n\n    To optimize inference speed, avoid string valued categorical covariates.\n\n    Args:\n      inputs: A list of time series forecast contexts. Each context time series\n        should be in a format convertible to JTensor by `jnp.array`.\n      dynamic_numerical_covariates: A dict of dynamic numerical covariates.\n      dynamic_categorical_covariates: A dict of dynamic categorical covariates.\n      static_numerical_covariates: A dict of static numerical covariates.\n      static_categorical_covariates: A dict of static categorical covariates.\n      xreg_mode: one of \"xreg + timesfm\" or \"timesfm + xreg\". \"timesfm + xreg\"\n        fits a model on the residuals of the TimesFM forecast. \"xreg + timesfm\"\n        fits a model on the targets then forecasts on the residuals via TimesFM.\n      normalize_xreg_target_per_input: whether to normalize the xreg target per\n        input in the given batch.\n      ridge: ridge penalty for the linear model.\n      max_rows_per_col: max number of rows per column for the linear model.\n      force_on_cpu: whether to force running on cpu for the linear model.\n\n    Returns:\n      A tuple of two lists. The first is the outputs of the model. 
The second is\n      the outputs of the xreg.\n    \"\"\"\n    if self.forecast_config is None:\n      raise ValueError(\"Model is not compiled. Please call compile() first.\")\n    elif not self.forecast_config.return_backcast:\n      raise ValueError(\n        \"For XReg, `return_backcast` must be set to True in the forecast config. Please recompile the model.\"\n      )\n\n    from ..utils import xreg_lib\n\n    # Verify and bookkeep covariates.\n    if not (\n      dynamic_numerical_covariates\n      or dynamic_categorical_covariates\n      or static_numerical_covariates\n      or static_categorical_covariates\n    ):\n      raise ValueError(\n        \"At least one of dynamic_numerical_covariates,\"\n        \" dynamic_categorical_covariates, static_numerical_covariates,\"\n        \" static_categorical_covariates must be set.\"\n      )\n\n    # Track the lengths of (1) each input, (2) the part that can be used in the\n    # linear model, and (3) the horizon.\n    input_lens, train_lens, test_lens = [], [], []\n\n    for i, input_ts in enumerate(inputs):\n      input_len = len(input_ts)\n      input_lens.append(input_len)\n\n      if xreg_mode == \"timesfm + xreg\":\n        # For fitting residuals, no TimesFM forecast on the first patch.\n        train_lens.append(max(0, input_len - self.model.p))\n      elif xreg_mode == \"xreg + timesfm\":\n        train_lens.append(input_len)\n      else:\n        raise ValueError(f\"Unsupported mode: {xreg_mode}\")\n\n      if dynamic_numerical_covariates:\n        test_lens.append(\n          len(list(dynamic_numerical_covariates.values())[0][i]) - input_len\n        )\n      elif dynamic_categorical_covariates:\n        test_lens.append(\n          len(list(dynamic_categorical_covariates.values())[0][i]) - input_len\n        )\n      else:\n        test_lens.append(self.forecast_config.max_horizon)\n\n      if test_lens[-1] > self.forecast_config.max_horizon:\n        raise ValueError(\n          \"Forecast horizon 
length inferred from the dynamic covariates is longer than the \"\n          f\"max_horizon defined in the forecast config: {test_lens[-1]} > {self.forecast_config.max_horizon=}.\"\n        )\n\n    # Prepare the covariates into train and test.\n    train_dynamic_numerical_covariates = collections.defaultdict(list)\n    test_dynamic_numerical_covariates = collections.defaultdict(list)\n    train_dynamic_categorical_covariates = collections.defaultdict(list)\n    test_dynamic_categorical_covariates = collections.defaultdict(list)\n    for covariates, train_covariates, test_covariates in (\n      (\n        dynamic_numerical_covariates,\n        train_dynamic_numerical_covariates,\n        test_dynamic_numerical_covariates,\n      ),\n      (\n        dynamic_categorical_covariates,\n        train_dynamic_categorical_covariates,\n        test_dynamic_categorical_covariates,\n      ),\n    ):\n      if not covariates:\n        continue\n      for covariate_name, covariate_values in covariates.items():\n        for input_len, train_len, covariate_value in zip(\n          input_lens, train_lens, covariate_values\n        ):\n          train_covariates[covariate_name].append(\n            covariate_value[(input_len - train_len) : input_len]\n          )\n          test_covariates[covariate_name].append(covariate_value[input_len:])\n\n    # Fit models.\n    if xreg_mode == \"timesfm + xreg\":\n      # Forecast via TimesFM then fit a model on the residuals.\n      point_outputs, quantile_outputs = self.forecast(\n        horizon=self.forecast_config.max_horizon, inputs=inputs\n      )\n      targets = [\n        (\n          np.array(input_ts)[-train_len:]\n          - point_output[: -self.forecast_config.max_horizon][-train_len:]\n        )\n        for input_ts, point_output, train_len in zip(inputs, point_outputs, train_lens)\n      ]\n      per_instance_stats = None\n      if normalize_xreg_target_per_input:\n        targets, per_instance_stats = 
xreg_lib.normalize(targets)\n      xregs = xreg_lib.BatchedInContextXRegLinear(\n        targets=targets,\n        train_lens=train_lens,\n        test_lens=test_lens,\n        train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,\n        test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,\n        train_dynamic_categorical_covariates=train_dynamic_categorical_covariates,\n        test_dynamic_categorical_covariates=test_dynamic_categorical_covariates,\n        static_numerical_covariates=static_numerical_covariates,\n        static_categorical_covariates=static_categorical_covariates,\n      ).fit(\n        ridge=ridge,\n        one_hot_encoder_drop=None if ridge > 0 else \"first\",\n        max_rows_per_col=max_rows_per_col,\n        force_on_cpu=force_on_cpu,\n        debug_info=False,\n        assert_covariates=True,\n        assert_covariate_shapes=True,\n      )\n      if normalize_xreg_target_per_input:\n        xregs = xreg_lib.renormalize(xregs, per_instance_stats)\n      xregs = np.array(xregs)\n      new_point_outputs = [\n        (point_output[-self.forecast_config.max_horizon :][:test_len] + xreg)\n        for point_output, test_len, xreg in zip(point_outputs, test_lens, xregs)\n      ]\n      new_quantile_outputs = [\n        (\n          quantile_output[-self.forecast_config.max_horizon :][:test_len]\n          + xreg[..., None]\n        )\n        for quantile_output, test_len, xreg in zip(quantile_outputs, test_lens, xregs)\n      ]\n\n    else:\n      # Fit a model on the targets then forecast on the residuals via TimesFM.\n      targets = [\n        np.array(input_ts)[-train_len:]\n        for input_ts, train_len in zip(inputs, train_lens)\n      ]\n      per_instance_stats = None\n      if normalize_xreg_target_per_input:\n        targets, per_instance_stats = xreg_lib.normalize(targets)\n      xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear(\n        targets=targets,\n        
train_lens=train_lens,\n        test_lens=test_lens,\n        train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,\n        test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,\n        train_dynamic_categorical_covariates=train_dynamic_categorical_covariates,\n        test_dynamic_categorical_covariates=test_dynamic_categorical_covariates,\n        static_numerical_covariates=static_numerical_covariates,\n        static_categorical_covariates=static_categorical_covariates,\n      ).fit(\n        ridge=ridge,\n        one_hot_encoder_drop=None if ridge > 0 else \"first\",\n        max_rows_per_col=max_rows_per_col,\n        force_on_cpu=force_on_cpu,\n        debug_info=True,\n        assert_covariates=True,\n        assert_covariate_shapes=True,\n      )\n      point_outputs, quantile_outputs = self.forecast(\n        horizon=self.forecast_config.max_horizon,\n        inputs=[\n          target - xreg_on_context\n          for target, xreg_on_context in zip(targets, xregs_on_context)\n        ],\n      )\n      new_point_outputs = [\n        (point_output[-self.forecast_config.max_horizon :][:test_len] + xreg)\n        for point_output, test_len, xreg in zip(point_outputs, test_lens, xregs)\n      ]\n      new_quantile_outputs = [\n        (\n          quantile_output[-self.forecast_config.max_horizon :][:test_len]\n          + xreg[..., None]\n        )\n        for quantile_output, test_len, xreg in zip(quantile_outputs, test_lens, xregs)\n      ]\n      if normalize_xreg_target_per_input:\n        new_point_outputs = xreg_lib.renormalize(new_point_outputs, per_instance_stats)\n        new_quantile_outputs = xreg_lib.renormalize(\n          new_quantile_outputs, per_instance_stats\n        )\n\n    return new_point_outputs, new_quantile_outputs\n"
  },
  {
    "path": "src/timesfm/timesfm_2p5/timesfm_2p5_flax.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"TimesFM models in Flax.\"\"\"\n\nimport dataclasses\nimport functools\nimport gc\nimport logging\nimport math\nimport os\nfrom pathlib import Path\nfrom typing import Any, Callable, Dict\n\nimport einshape\nfrom flax import nnx\nimport huggingface_hub\nimport jax\nimport jax.numpy as jnp\nimport jaxtyping\nimport numpy as np\nimport orbax.checkpoint as ocp\n\nfrom .. import configs\nfrom ..flax import dense, transformer, util\nfrom . 
import timesfm_2p5_base\n\njax_einshape = einshape.jax_einshape\nscan = util.scan_along_axis\nrevin = util.revin\n\nFloat = jaxtyping.Float\nBool = jaxtyping.Bool\nArray = jaxtyping.Array\n\n\ndef try_gc():\n  for d in jax.local_devices():\n    stats = d.memory_stats()\n    if stats is None:\n      return\n    if stats[\"bytes_in_use\"] / stats[\"bytes_limit\"] > 0.75:\n      gc.collect()\n      break\n\n\n@nnx.vmap(in_axes=(None, 0), out_axes=0)\ndef _create_stacked_transformers(\n  config: configs.StackedTransformersConfig, key: jax.Array\n):\n  return transformer.Transformer(config.transformer, rngs=nnx.Rngs(key))\n\n\ndef _scan_along_axis(f, init, xs, axis: int, **kwargs):\n  \"\"\"Scans along an axis.\"\"\"\n  moved_xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), xs)\n  carry, moved_ys = jax.lax.scan(f, init, moved_xs, **kwargs)\n  return (\n    carry,\n    jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), moved_ys),\n  )\n\n\n@nnx.scan(in_axes=(0, nnx.Carry, None, 0), out_axes=(nnx.Carry, 0))\ndef _apply_stacked_transformers(\n  model: transformer.Transformer,\n  x: Float[Array, \"b n d\"],\n  m: Float[Array, \"b n\"],\n  decode_cache: util.DecodeCache | None = None,\n) -> Float[Array, \"b n d\"]:\n  return model(x, m, decode_cache=decode_cache)\n\n\nclass TimesFM_2p5_200M_flax_module(nnx.Module):  # pylint: disable=invalid-name\n  \"\"\"TimesFM 2.5 with 200M parameters.\"\"\"\n\n  config = timesfm_2p5_base.TimesFM_2p5_200M_Definition()\n  decode_index: int = 5\n  compiled_decode: Callable[..., Any] | None = None\n  backend: str = \"\"\n  context: int = 0\n  horizon: int = 0\n  per_core_batch_size: int = 0\n\n  def __init__(self):\n    super().__init__()\n    self.backend = jax.devices()[0].platform\n    self.num_devices = len(jax.devices(self.backend))\n\n    # Names constants.\n    self.p = self.config.input_patch_len  # 32\n    self.o = self.config.output_patch_len  # 128\n    self.os = self.config.output_quantile_len  # 1024\n    
self.m = self.o // self.p  # 4\n    self.x = self.config.stacked_transformers.num_layers  # 20\n    self.h = self.config.stacked_transformers.transformer.num_heads  # 16\n    self.md = self.config.stacked_transformers.transformer.model_dims  # 1280\n    self.hd = self.md // self.h  # 80\n    self.q = len(self.config.quantiles) + 1  # 10\n    self.aridx = self.config.decode_index  # 5\n\n    # Layers.\n    self.tokenizer = dense.ResidualBlock(self.config.tokenizer)\n    self.stacked_xf = _create_stacked_transformers(\n      self.config.stacked_transformers,\n      jax.random.split(jax.random.key(42), self.x),\n    )\n    self.output_projection_point = dense.ResidualBlock(\n      self.config.output_projection_point\n    )\n    self.output_projection_quantiles = dense.ResidualBlock(\n      self.config.output_projection_quantiles\n    )\n\n  def __call__(\n    self,\n    inputs: Float[Array, \"b n p\"],\n    masks: Bool[Array, \"b n p\"],\n    decode_cache: util.DecodeCache | None = None,\n  ):\n    tokenizer_inputs = jnp.concatenate([inputs, masks.astype(inputs.dtype)], axis=-1)\n    input_embeddings = self.tokenizer(tokenizer_inputs)\n    if decode_cache is None:\n      decode_cache = [None] * self.x\n    output_embeddings, decode_cache = _apply_stacked_transformers(\n      self.stacked_xf, input_embeddings, masks[..., -1], decode_cache\n    )\n    output_ts = self.output_projection_point(output_embeddings)\n    output_quantile_spread = self.output_projection_quantiles(output_embeddings)\n    return (\n      input_embeddings,\n      output_embeddings,\n      output_ts,\n      output_quantile_spread,\n    ), decode_cache\n\n  @nnx.jit(static_argnames=(\"horizon\",))\n  def decode(self, horizon: int, inputs, masks):\n    batch_size, context = inputs.shape[0], inputs.shape[1]\n    num_decode_steps = (horizon - 1) // self.o\n    num_input_patches = context // self.p\n    decode_cache_size = num_input_patches + num_decode_steps * self.m\n\n    # Prefill\n    
patched_inputs = jax_einshape(\"b(np)->bnp\", inputs, b=batch_size, p=self.p)\n    patched_masks = jax_einshape(\"b(np)->bnp\", masks, b=batch_size, p=self.p)\n    (last_n, last_mu, last_sigma), (_, context_mu, context_sigma) = scan(\n      lambda carry, xs: util.update_running_stats(*carry, *xs),\n      init=(zero := jnp.zeros(shape=(batch_size)), zero, zero),\n      xs=(patched_inputs, patched_masks),\n      axis=1,\n    )\n    decode_cache = util.DecodeCache(\n      next_index=jnp.zeros(shape=(self.x, batch_size), dtype=jnp.int32),\n      num_masked=jnp.zeros(shape=(self.x, batch_size), dtype=jnp.int32),\n      key=jnp.zeros(shape=(self.x, batch_size, decode_cache_size, self.h, self.hd)),\n      value=jnp.zeros(shape=(self.x, batch_size, decode_cache_size, self.h, self.hd)),\n    )\n    normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)\n    normed_inputs = jnp.where(patched_masks, 0.0, normed_inputs)\n    (_, _, normed_outputs, normed_quantile_spread), decode_cache = self(\n      normed_inputs, patched_masks, decode_cache\n    )\n    renormed_outputs = jax_einshape(\n      \"bn(oq)->bnoq\",\n      revin(normed_outputs, context_mu, context_sigma, reverse=True),\n      o=self.o,\n      q=self.q,\n    )\n    renormed_quantile_spread = jax_einshape(\n      \"bn(oq)->bnoq\",\n      revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),\n      o=self.os,\n      q=self.q,\n    )[:, -1, ...]\n\n    # Autoregressive decode\n    @nnx.scan(in_axes=(None, nnx.Carry, 0), out_axes=(nnx.Carry, 1))\n    def _ar_decode(module, carry, unused_iter):\n      last_renormed_output, (last_n, last_mu, last_sigma), decode_cache = carry\n      new_patched_input = jax_einshape(\n        \"b(mp)->bmp\", last_renormed_output, m=module.m, p=module.p\n      )\n      new_mask = jnp.zeros_like(new_patched_input, dtype=jnp.bool)\n      carry_stats, (_, new_mu, new_sigma) = scan(\n        lambda carry, xs: util.update_running_stats(*carry, *xs),\n       
 init=(last_n, last_mu, last_sigma),\n        xs=(new_patched_input, new_mask),\n        axis=1,\n      )\n      new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)\n      (_, _, new_normed_output, _), decode_cache = module(\n        new_normed_input, new_mask, decode_cache\n      )\n      new_renormed_output = jax_einshape(\n        \"bm(oq)->bmoq\",\n        revin(new_normed_output, new_mu, new_sigma, reverse=True),\n        o=module.o,\n        q=module.q,\n      )[..., -1, :, :]\n\n      return (\n        (\n          new_renormed_output[..., module.decode_index],\n          carry_stats,\n          decode_cache,\n        ),\n        new_renormed_output,\n      )\n\n    if num_decode_steps > 0:\n      _, ar_renormed_outputs = _ar_decode(\n        self,\n        (\n          renormed_outputs[..., -1, :, self.decode_index],\n          (last_n, last_mu, last_sigma),\n          decode_cache,\n        ),\n        jnp.arange(num_decode_steps),\n      )\n    else:\n      ar_renormed_outputs = None\n\n    return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs\n\n  def compile(\n    self,\n    context: int,\n    horizon: int,\n    per_core_batch_size: int = 1,\n  ):\n    if context % self.p != 0:\n      logging.info(\n        \"When compiling, context needs to be multiple of the patch size %d.\"\n        \" Modifying context to %d.\",\n        self.p,\n        context := math.ceil(context / self.p) * self.p,\n      )\n    if horizon % self.o != 0:\n      logging.info(\n        \"When compiling, horizon needs to be multiple of the output patch\"\n        \" size %d. 
Modifying horizon to %d.\",\n        self.o,\n        horizon := math.ceil(horizon / self.o) * self.o,\n      )\n\n    self.context = context\n    self.horizon = horizon\n    self.per_core_batch_size = per_core_batch_size\n\n    @nnx.pmap(\n      in_axes=(None, None, 0, 0),\n      out_axes=(0, 0, 0),\n      devices=jax.devices(self.backend),\n      axis_size=self.num_devices,\n      static_broadcasted_argnums=(1,),\n      axis_name=\"global_batch\",\n    )\n    def compiled_decode_kernel(model, horizon, inputs, masks):\n      return model.decode(horizon, inputs, masks)\n\n    self.compiled_decode = functools.partial(compiled_decode_kernel, self)\n\n\ndef _flip_quantile_fn(x):\n  return jnp.concatenate([x[..., :1], jnp.flip(x[..., 1:], axis=-1)], axis=-1)\n\n\n@functools.partial(\n  jax.jit,\n  donate_argnums=(0, 1, 2),\n)\ndef _force_flip_invariance_fn(\n  flipped_pf_outputs,\n  flipped_quantile_spreads,\n  flipped_ar_outputs,\n):\n  \"\"\"Forces flip invariance.\"\"\"\n  flipped_pf_outputs = _flip_quantile_fn(flipped_pf_outputs)\n  flipped_pf_outputs = jax_einshape(\"tb...->(tb)...\", flipped_pf_outputs)\n  flipped_quantile_spreads = _flip_quantile_fn(flipped_quantile_spreads)\n  flipped_quantile_spreads = jax_einshape(\"tb...->(tb)...\", flipped_quantile_spreads)\n  to_concat = [flipped_pf_outputs[:, -1, ...]]\n  if flipped_ar_outputs is not None:\n    flipped_ar_outputs = _flip_quantile_fn(flipped_ar_outputs)\n    flipped_ar_outputs = jax_einshape(\"tbno...->(tb)(no)...\", flipped_ar_outputs)\n    to_concat.append(flipped_ar_outputs)\n  flipped_full_forecast = jnp.concatenate(to_concat, axis=1)\n\n  return flipped_quantile_spreads, flipped_pf_outputs, flipped_full_forecast\n\n\n@functools.partial(\n  jax.jit,\n  static_argnames=(\"max_horizon\",),\n  donate_argnums=(0,),\n)\ndef _use_continuous_quantile_head_fn(full_forecast, quantile_spreads, max_horizon):\n  \"\"\"Uses continuous quantile head.\"\"\"\n  to_stack = [full_forecast[..., :max_horizon, 0]]\n  for 
quantile_index in [1, 2, 3, 4]:\n    to_stack.append(\n      quantile_spreads[:, :max_horizon, quantile_index]\n      - quantile_spreads[:, :max_horizon, 5]\n      + full_forecast[:, :max_horizon, 5]\n    )\n  to_stack.append(full_forecast[..., :max_horizon, 5])\n  for quantile_index in [6, 7, 8, 9]:\n    to_stack.append(\n      quantile_spreads[:, :max_horizon, quantile_index]\n      - quantile_spreads[:, :max_horizon, 5]\n      + full_forecast[:, :max_horizon, 5]\n    )\n  return jnp.stack(to_stack, axis=-1)\n\n\n@functools.partial(jax.jit, donate_argnums=(0,))\ndef _fix_quantile_crossing_fn(full_forecast):\n  \"\"\"Fixes quantile crossing.\"\"\"\n  lower_quantiles = _scan_along_axis(\n    lambda carry, x: (w := jnp.minimum(carry, x), w),\n    init=full_forecast[..., 5],\n    xs=full_forecast[..., 1:5],\n    axis=-1,\n    reverse=True,\n  )[1]\n  upper_quantiles = _scan_along_axis(\n    lambda carry, x: (w := jnp.maximum(carry, x), w),\n    init=full_forecast[..., 5],\n    xs=full_forecast[..., 6:10],\n    axis=-1,\n    reverse=False,\n  )[1]\n  return jnp.concatenate(\n    [\n      full_forecast[..., :1],\n      lower_quantiles,\n      full_forecast[..., 5:6],\n      upper_quantiles,\n    ],\n    axis=-1,\n  )\n\n\n@functools.partial(jax.jit, static_argnames=(\"fc\",), donate_argnums=(1, 2))\ndef _before_model_decode(fc, inputs, masks):\n  \"\"\"All Jax steps before model decode call.\"\"\"\n  if fc.infer_is_positive:\n    is_positive = jnp.all(inputs >= 0, axis=-1, keepdims=True)\n  else:\n    is_positive = None\n\n  if fc.normalize_inputs:\n    mu = jnp.mean(inputs, axis=-1, keepdims=True)\n    sigma = jnp.std(inputs, axis=-1, keepdims=True)\n    inputs = revin(inputs, mu, sigma, reverse=False)\n  else:\n    mu, sigma = None, None\n\n  inputs = jax_einshape(\"(tb)...->tb...\", inputs, b=fc.per_core_batch_size)\n  masks = jax_einshape(\"(tb)...->tb...\", masks, b=fc.per_core_batch_size)\n\n  return inputs, masks, is_positive, mu, 
sigma\n\n\n@functools.partial(\n  jax.jit,\n  static_argnames=(\n    \"fc\",\n    \"p\",\n  ),\n  donate_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9),\n)\ndef _after_model_decode(\n  fc,\n  pf_outputs,\n  quantile_spreads,\n  ar_outputs,\n  flipped_pf_outputs,\n  flipped_quantile_spreads,\n  flipped_ar_outputs,\n  is_positive,\n  mu,\n  sigma,\n  p,\n):\n  \"\"\"All Jax steps after model decode call.\"\"\"\n  # t: num_devices, b: per_core_batch_size\n  pf_outputs = jax_einshape(\"tb...->(tb)...\", pf_outputs)\n  quantile_spreads = jax_einshape(\"tb...->(tb)...\", quantile_spreads)\n  to_concat = [pf_outputs[:, -1, ...]]\n  if ar_outputs is not None:\n    ar_outputs = jax_einshape(\"tbno...->(tb)(no)...\", ar_outputs)\n    to_concat.append(ar_outputs)\n  full_forecast = jnp.concatenate(to_concat, axis=1)\n\n  if fc.force_flip_invariance:\n    (\n      flipped_quantile_spreads,\n      flipped_pf_outputs,\n      flipped_full_forecast,\n    ) = _force_flip_invariance_fn(\n      flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs\n    )\n    quantile_spreads = (quantile_spreads - flipped_quantile_spreads) / 2\n    pf_outputs = (pf_outputs - flipped_pf_outputs) / 2\n    full_forecast = (full_forecast - flipped_full_forecast) / 2\n\n  if fc.use_continuous_quantile_head:\n    full_forecast = _use_continuous_quantile_head_fn(\n      full_forecast, quantile_spreads, fc.max_horizon\n    )\n\n  if fc.return_backcast:\n    full_backcast = jax_einshape(\"...npq->...(np)q\", pf_outputs[:, :-1, :p, :])\n    full_forecast = jnp.concatenate([full_backcast, full_forecast], axis=1)\n\n  if fc.fix_quantile_crossing:\n    full_forecast = _fix_quantile_crossing_fn(full_forecast)\n\n  if fc.normalize_inputs:\n    full_forecast = revin(full_forecast, mu, sigma, reverse=True)\n\n  if is_positive is not None:\n    full_forecast = jnp.where(\n      is_positive[..., None],\n      jnp.maximum(full_forecast, jnp.zeros_like(full_forecast)),\n      full_forecast,\n    )\n\n  return 
full_forecast\n\n\nclass TimesFM_2p5_200M_flax(timesfm_2p5_base.TimesFM_2p5):\n  \"\"\"Flax implementation of TimesFM 2.5 with 200M parameters.\"\"\"\n\n  model: nnx.Module = TimesFM_2p5_200M_flax_module()\n\n  @classmethod\n  def from_pretrained(\n    cls,\n    model_id: str = \"google/timesfm-2.5-200m-flax\",\n    *,\n    revision: str | None = None,\n    cache_dir: str | Path | None = None,\n    force_download: bool = False,\n    proxies: Dict | None = None,\n    resume_download: bool | None = None,\n    local_files_only: bool | None = None,\n    token: str | None = None,\n    **model_kwargs,\n  ):\n    \"\"\"Loads a Flax TimesFM model.\"\"\"\n\n    # Create an instance of the model wrapper class.\n    instance = cls(**model_kwargs)\n\n    # Determine the path to the model weights.\n    model_file_path = \"\"\n    if os.path.isdir(model_id):\n      logging.info(\"Loading checkpoint from local directory: %s\", model_id)\n      model_file_path = model_id\n    else:\n      logging.info(\"Downloading checkpoint from Hugging Face repo %s\", model_id)\n      model_file_path = huggingface_hub.snapshot_download(\n        repo_id=model_id,\n        revision=revision,\n        cache_dir=cache_dir,\n        force_download=force_download,\n        proxies=proxies,\n        resume_download=resume_download,\n        token=token,\n        local_files_only=local_files_only,\n      )\n      logging.info(\"Loading checkpoint from: %s\", model_file_path)\n\n    checkpointer = ocp.StandardCheckpointer()\n    graph, state = nnx.split(instance.model)\n    state = checkpointer.restore(model_file_path, state)\n    instance.model = nnx.merge(graph, state)\n    return instance\n\n  def compile(\n    self,\n    forecast_config: configs.ForecastConfig,\n    dryrun: bool = True,\n    **kwargs\n  ):\n    # Acronym used during validation.\n    print(\"Compiling model...\")\n\n    fc = forecast_config\n    if fc.max_context % self.model.p != 0:\n      logging.info(\n        \"When compiling, 
max context needs to be multiple of the patch size\"\n        \" %d. Using max context = %d instead.\",\n        self.model.p,\n        new_context := math.ceil(fc.max_context / self.model.p) * self.model.p,\n      )\n      fc = dataclasses.replace(fc, max_context=new_context)\n    if fc.max_horizon % self.model.o != 0:\n      logging.info(\n        \"When compiling, max horizon needs to be multiple of the output patch\"\n        \" size %d. Using max horizon = %d instead.\",\n        self.model.o,\n        new_horizon := math.ceil(fc.max_horizon / self.model.o) * self.model.o,\n      )\n      fc = dataclasses.replace(fc, max_horizon=new_horizon)\n    if fc.max_context + fc.max_horizon > self.model.config.context_limit:\n      raise ValueError(\n        \"Context + horizon must be less than the context limit.\"\n        f\" {fc.max_context} + {fc.max_horizon} >\"\n        f\" {self.model.config.context_limit}.\"\n      )\n    if fc.use_continuous_quantile_head and (fc.max_horizon > self.model.os):\n      raise ValueError(\n        f\"Continuous quantile head is not supported for horizons > {self.model.os}.\"\n      )\n\n    self.forecast_config = fc\n    self.model.compile(\n      context=self.forecast_config.max_context,\n      horizon=self.forecast_config.max_horizon,\n      per_core_batch_size=fc.per_core_batch_size,\n    )\n    self.per_core_batch_size = self.forecast_config.per_core_batch_size\n    self.num_devices = self.model.num_devices\n    self.global_batch_size = (\n      self.forecast_config.per_core_batch_size * self.model.num_devices\n    )\n\n    def compiled_decode_kernel(fc, horizon, inputs, masks):\n      inputs = jnp.array(inputs, dtype=jnp.float32)\n      masks = jnp.array(masks, dtype=jnp.bool)\n      if horizon > fc.max_horizon:\n        raise ValueError(\n          f\"Horizon must be less than the max horizon. 
{horizon} > {fc.max_horizon}.\"\n        )\n      to_trim = fc.max_horizon - horizon\n\n      inputs, masks, is_positive, mu, sigma = _before_model_decode(fc, inputs, masks)\n\n      pf_outputs, quantile_spreads, ar_outputs = self.model.compiled_decode(\n        fc.max_horizon, inputs, masks\n      )\n      if fc.force_flip_invariance:\n        flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (\n          self.model.compiled_decode(fc.max_horizon, -inputs, masks)\n        )\n      else:\n        flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (\n          None,\n          None,\n          None,\n        )\n\n      full_forecast = _after_model_decode(\n        fc,\n        pf_outputs,\n        quantile_spreads,\n        ar_outputs,\n        flipped_pf_outputs,\n        flipped_quantile_spreads,\n        flipped_ar_outputs,\n        is_positive,\n        mu,\n        sigma,\n        self.model.p,\n      )\n      full_forecast_np = np.array(full_forecast)\n      del full_forecast\n      try_gc()\n      if to_trim > 0:\n        full_forecast_np = full_forecast_np[..., :-to_trim, :]\n      return full_forecast_np[..., 5], full_forecast_np\n\n    self.compiled_decode = functools.partial(\n      compiled_decode_kernel, self.forecast_config\n    )\n\n    if dryrun:\n      _ = self.compiled_decode(\n        self.forecast_config.max_horizon,\n        jnp.zeros(\n          (self.global_batch_size, self.forecast_config.max_context), dtype=jnp.float32\n        ),\n        jnp.zeros(\n          (self.global_batch_size, self.forecast_config.max_context), dtype=jnp.bool\n        ),\n      )\n    print(\"Compiling done.\")\n"
  },
  {
    "path": "src/timesfm/timesfm_2p5/timesfm_2p5_torch.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"TimesFM models.\"\"\"\n\nimport dataclasses\nimport logging\nimport math\nimport os\nfrom pathlib import Path\nfrom typing import Optional, Sequence, Union\n\nimport numpy as np\nimport torch\nfrom huggingface_hub import PyTorchModelHubMixin, hf_hub_download\nfrom safetensors.torch import load_file, save_file\nfrom torch import nn\n\nfrom .. import configs\nfrom ..torch import dense, transformer, util\nfrom . 
import timesfm_2p5_base\n\nrevin = util.revin\n\n\nclass TimesFM_2p5_200M_torch_module(nn.Module):\n  \"\"\"TimesFM 2.5 with 200M parameters.\"\"\"\n\n  config = timesfm_2p5_base.TimesFM_2p5_200M_Definition()\n\n  def __init__(self):\n    super().__init__()\n\n    # Names constants.\n    self.p = self.config.input_patch_len  # 32\n    self.o = self.config.output_patch_len  # 128\n    self.os = self.config.output_quantile_len  # 1024\n    self.m = self.o // self.p  # 4\n    self.x = self.config.stacked_transformers.num_layers  # 20\n    self.h = self.config.stacked_transformers.transformer.num_heads  # 16\n    self.md = self.config.stacked_transformers.transformer.model_dims  # 1280\n    self.hd = self.md // self.h  # 80\n    self.q = len(self.config.quantiles) + 1  # 10\n    self.aridx = self.config.decode_index  # 5\n\n    # Layers.\n    self.tokenizer = dense.ResidualBlock(self.config.tokenizer)\n    self.stacked_xf = nn.ModuleList(\n      [\n        transformer.Transformer(self.config.stacked_transformers.transformer)\n        for _ in range(self.x)\n      ]\n    )\n    self.output_projection_point = dense.ResidualBlock(\n      self.config.output_projection_point\n    )\n    self.output_projection_quantiles = dense.ResidualBlock(\n      self.config.output_projection_quantiles\n    )\n\n    # Device.\n    if torch.cuda.is_available():\n      self.device = torch.device(\"cuda:0\")\n      self.device_count = torch.cuda.device_count()\n    else:\n      self.device = torch.device(\"cpu\")\n      self.device_count = 1\n\n  def load_checkpoint(self, path: str, **kwargs):\n    \"\"\"Loads a PyTorch TimesFM model from a checkpoint.\"\"\"\n    tensors = load_file(path)\n    self.load_state_dict(tensors, strict=True)\n    self.to(self.device)\n    torch_compile = True\n    if \"torch_compile\" in kwargs:\n      torch_compile = kwargs[\"torch_compile\"]\n    if torch_compile:\n      print(\"Compiling model...\")\n      self = torch.compile(self)\n\n    self.eval()\n\n  def 
forward(\n    self,\n    inputs: torch.Tensor,\n    masks: torch.Tensor,\n    decode_caches: list[util.DecodeCache] | None = None,\n  ):\n    tokenizer_inputs = torch.cat([inputs, masks.to(inputs.dtype)], dim=-1)\n    input_embeddings = self.tokenizer(tokenizer_inputs)\n\n    if decode_caches is None:\n      decode_caches = [None] * self.x\n\n    output_embeddings = input_embeddings\n    new_decode_caches = []\n    for i, layer in enumerate(self.stacked_xf):\n      output_embeddings, new_cache = layer(\n        output_embeddings, masks[..., -1], decode_caches[i]\n      )\n      new_decode_caches.append(new_cache)\n    output_ts = self.output_projection_point(output_embeddings)\n    output_quantile_spread = self.output_projection_quantiles(output_embeddings)\n\n    return (\n      input_embeddings,\n      output_embeddings,\n      output_ts,\n      output_quantile_spread,\n    ), new_decode_caches\n\n  def decode(self, horizon: int, inputs, masks):\n    \"\"\"Decodes the time series.\"\"\"\n\n    with torch.no_grad():\n      batch_size, context = inputs.shape[0], inputs.shape[1]\n      num_decode_steps = (horizon - 1) // self.o\n      num_input_patches = context // self.p\n      decode_cache_size = num_input_patches + num_decode_steps * self.m\n\n      # Prefill\n      patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))\n      patched_masks = torch.reshape(masks, (batch_size, -1, self.p))\n\n      # running stats\n      n = torch.zeros(batch_size, device=inputs.device)\n      mu = torch.zeros(batch_size, device=inputs.device)\n      sigma = torch.zeros(batch_size, device=inputs.device)\n      patch_mu = []\n      patch_sigma = []\n      for i in range(num_input_patches):\n        (n, mu, sigma), _ = util.update_running_stats(\n          n, mu, sigma, patched_inputs[:, i], patched_masks[:, i]\n        )\n        patch_mu.append(mu)\n        patch_sigma.append(sigma)\n      last_n, last_mu, last_sigma = n, mu, sigma\n      context_mu = 
torch.stack(patch_mu, dim=1)\n      context_sigma = torch.stack(patch_sigma, dim=1)\n\n      decode_caches = [\n        util.DecodeCache(\n          next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),\n          num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),\n          key=torch.zeros(\n            batch_size,\n            decode_cache_size,\n            self.h,\n            self.hd,\n            device=inputs.device,\n          ),\n          value=torch.zeros(\n            batch_size,\n            decode_cache_size,\n            self.h,\n            self.hd,\n            device=inputs.device,\n          ),\n        )\n        for _ in range(self.x)\n      ]\n\n      normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)\n      normed_inputs = torch.where(patched_masks, 0.0, normed_inputs)\n      (_, _, normed_outputs, normed_quantile_spread), decode_caches = self(\n        normed_inputs, patched_masks, decode_caches\n      )\n      renormed_outputs = torch.reshape(\n        revin(normed_outputs, context_mu, context_sigma, reverse=True),\n        (batch_size, -1, self.o, self.q),\n      )\n      renormed_quantile_spread = torch.reshape(\n        revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),\n        (batch_size, -1, self.os, self.q),\n      )[:, -1, ...]\n\n      # Autoregressive decode\n      ar_outputs = []\n      last_renormed_output = renormed_outputs[:, -1, :, self.aridx]\n\n      for _ in range(num_decode_steps):\n        new_patched_input = torch.reshape(\n          last_renormed_output, (batch_size, self.m, self.p)\n        )\n        new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)\n\n        n, mu, sigma = last_n, last_mu, last_sigma\n        new_mus, new_sigmas = [], []\n        for i in range(self.m):\n          (n, mu, sigma), _ = util.update_running_stats(\n            n, mu, sigma, new_patched_input[:, i], new_mask[:, i]\n          
)\n          new_mus.append(mu)\n          new_sigmas.append(sigma)\n        last_n, last_mu, last_sigma = n, mu, sigma\n        new_mu = torch.stack(new_mus, dim=1)\n        new_sigma = torch.stack(new_sigmas, dim=1)\n\n        new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)\n        (_, _, new_normed_output, _), decode_caches = self(\n          new_normed_input, new_mask, decode_caches\n        )\n\n        new_renormed_output = torch.reshape(\n          revin(new_normed_output, new_mu, new_sigma, reverse=True),\n          (batch_size, self.m, self.o, self.q),\n        )\n        ar_outputs.append(new_renormed_output[:, -1, ...])\n        last_renormed_output = new_renormed_output[:, -1, :, self.aridx]\n\n      if num_decode_steps > 0:\n        ar_renormed_outputs = torch.stack(ar_outputs, dim=1)\n      else:\n        ar_renormed_outputs = None\n\n    return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs\n\n  def forecast_naive(\n    self, horizon: int, inputs: Sequence[np.ndarray]\n  ) -> list[np.ndarray]:\n    \"\"\"Forecasts the time series.\n\n    This is a naive implementation for debugging purposes. No forecasting\n    flags are used here. 
Forecasting quality can be subpar.\n\n    Args:\n      horizon: The number of time points to forecast.\n      inputs: A sequence of numpy arrays, each representing a time series to\n        query forecast for.\n\n    Returns:\n      A list of numpy arrays of forecasts.\n    \"\"\"\n    outputs = []\n    for each_input in inputs:\n      input_t = torch.tensor(each_input, dtype=torch.float32)\n      mask = torch.zeros_like(input_t, dtype=torch.bool)\n      len_front_mask = self.p - (len(each_input) % self.p)\n      if len_front_mask < self.p:\n        input_t = torch.cat(\n          [torch.zeros(len_front_mask, dtype=torch.float32), input_t], dim=0\n        )\n        mask = torch.cat([torch.ones(len_front_mask, dtype=torch.bool), mask], dim=0)\n      input_t = input_t[None, ...]\n      mask = mask[None, ...]\n      t_pf, _, t_ar = self.decode(horizon, input_t, mask)\n      to_concat = [t_pf[:, -1, ...]]\n      if t_ar is not None:\n        to_concat.append(t_ar.reshape(1, -1, self.q))\n      torch_forecast = torch.cat(to_concat, dim=1)[..., :horizon]\n      torch_forecast = torch_forecast.squeeze(0)\n      outputs.append(torch_forecast.detach().cpu().numpy())\n    return outputs\n\n\nclass TimesFM_2p5_200M_torch(\n  timesfm_2p5_base.TimesFM_2p5,\n  PyTorchModelHubMixin,\n  library_name=\"timesfm\",\n  repo_url=\"https://github.com/google-research/timesfm\",\n  paper_url=\"https://arxiv.org/abs/2310.10688\",\n  docs_url=\"https://github.com/google-research/timesfm\",\n  license=\"apache-2.0\",\n  pipeline_tag=\"time-series-forecasting\",\n  tags=[\"pytorch\", \"timeseries\", \"forecasting\", \"timesfm-2.5\"],\n):\n  \"\"\"PyTorch implementation of TimesFM 2.5 with 200M parameters.\"\"\"\n\n  DEFAULT_REPO_ID = \"google/timesfm-2.5-200m-pytorch\"\n  WEIGHTS_FILENAME = \"model.safetensors\"\n\n  def __init__(\n    self,\n    torch_compile: bool = True,\n    config: Optional[dict] = None,\n  ):\n    self.model = TimesFM_2p5_200M_torch_module()\n    self.torch_compile = 
torch_compile\n    if config is not None:\n      self._hub_mixin_config = config\n\n  @classmethod\n  def _from_pretrained(\n    cls,\n    *,\n    model_id: str = DEFAULT_REPO_ID,\n    revision: Optional[str],\n    cache_dir: Optional[Union[str, Path]],\n    force_download: bool = False,\n    local_files_only: bool,\n    token: Optional[Union[str, bool]],\n    config: Optional[dict] = None,\n    **model_kwargs,\n  ):\n    \"\"\"\n    Loads a PyTorch safetensors TimesFM model from a local path or the Hugging\n    Face Hub. This method is the backend for the `from_pretrained` class\n    method provided by `PyTorchModelHubMixin`.\n    \"\"\"\n    # Determine the path to the model weights.\n    model_file_path = \"\"\n    if os.path.isdir(model_id):\n      logging.info(\"Loading checkpoint from local directory: %s\", model_id)\n      model_file_path = os.path.join(model_id, cls.WEIGHTS_FILENAME)\n      if not os.path.exists(model_file_path):\n        raise FileNotFoundError(\n          f\"{cls.WEIGHTS_FILENAME} not found in directory {model_id}\"\n        )\n    else:\n      logging.info(\"Downloading checkpoint from Hugging Face repo %s\", model_id)\n      model_file_path = hf_hub_download(\n        repo_id=model_id,\n        filename=cls.WEIGHTS_FILENAME,\n        revision=revision,\n        cache_dir=cache_dir,\n        force_download=force_download,\n        token=token,\n        local_files_only=local_files_only,\n      )\n\n    # Create an instance of the model wrapper class.\n    instance = cls(config=config, **model_kwargs)\n\n    logging.info(\"Loading checkpoint from: %s\", model_file_path)\n    # Load the weights into the model.\n    instance.model.load_checkpoint(\n      model_file_path, torch_compile=instance.torch_compile\n    )\n    return instance\n\n  def _save_pretrained(self, save_directory: Union[str, Path]):\n    \"\"\"\n    Saves the model's state dictionary to a safetensors file. 
This method\n    is called by the `save_pretrained` method from `PyTorchModelHubMixin`.\n    \"\"\"\n    if not os.path.exists(save_directory):\n      os.makedirs(save_directory)\n\n    weights_path = os.path.join(save_directory, self.WEIGHTS_FILENAME)\n    save_file(self.model.state_dict(), weights_path)\n\n  def compile(self, forecast_config: configs.ForecastConfig, **kwargs) -> None:\n    \"\"\"Attempts to compile the model for fast decoding.\n\n    See configs.ForecastConfig for more details on the supported flags.\n\n    Args:\n      forecast_config: Configuration for forecasting flags.\n      **kwargs: Additional keyword arguments to pass to model.compile().\n    \"\"\"\n    self.global_batch_size = (\n      forecast_config.per_core_batch_size * self.model.device_count\n    )\n\n    # Shortcut.\n    fc = forecast_config\n\n    if fc.max_context % self.model.p != 0:\n      logging.info(\n        \"When compiling, max context needs to be multiple of the patch size\"\n        \" %d. Using max context = %d instead.\",\n        self.model.p,\n        new_context := math.ceil(fc.max_context / self.model.p) * self.model.p,\n      )\n      fc = dataclasses.replace(fc, max_context=new_context)\n    if fc.max_horizon % self.model.o != 0:\n      logging.info(\n        \"When compiling, max horizon needs to be multiple of the output patch\"\n        \" size %d. 
Using max horizon = %d instead.\",\n        self.model.o,\n        new_horizon := math.ceil(fc.max_horizon / self.model.o) * self.model.o,\n      )\n      fc = dataclasses.replace(fc, max_horizon=new_horizon)\n    if fc.max_context + fc.max_horizon > self.model.config.context_limit:\n      raise ValueError(\n        \"Context + horizon must be less than the context limit.\"\n        f\" {fc.max_context} + {fc.max_horizon} >\"\n        f\" {self.model.config.context_limit}.\"\n      )\n    if fc.use_continuous_quantile_head and (fc.max_horizon > self.model.os):\n      raise ValueError(\n        f\"Continuous quantile head is not supported for horizons > {self.model.os}.\"\n      )\n    self.forecast_config = fc\n\n    def _compiled_decode(horizon, inputs, masks):\n      if horizon > fc.max_horizon:\n        raise ValueError(\n          f\"Horizon must be less than the max horizon. {horizon} > {fc.max_horizon}.\"\n        )\n\n      inputs = (\n        torch.from_numpy(np.array(inputs)).to(self.model.device).to(torch.float32)\n      )\n      masks = torch.from_numpy(np.array(masks)).to(self.model.device).to(torch.bool)\n      batch_size = inputs.shape[0]\n\n      if fc.infer_is_positive:\n        is_positive = torch.all(inputs >= 0, dim=-1, keepdim=True)\n      else:\n        is_positive = None\n\n      if fc.normalize_inputs:\n        mu = torch.mean(inputs, dim=-1, keepdim=True)\n        sigma = torch.std(inputs, dim=-1, keepdim=True)\n        inputs = revin(inputs, mu, sigma, reverse=False)\n      else:\n        mu, sigma = None, None\n\n      pf_outputs, quantile_spreads, ar_outputs = self.model.decode(\n        forecast_config.max_horizon, inputs, masks\n      )\n      to_cat = [pf_outputs[:, -1, ...]]\n      if ar_outputs is not None:\n        to_cat.append(ar_outputs.reshape(batch_size, -1, self.model.q))\n      full_forecast = torch.cat(to_cat, dim=1)\n\n      def flip_quantile_fn(x):\n        return torch.cat([x[..., :1], torch.flip(x[..., 1:], dims=(-1,))], 
dim=-1)\n\n      if fc.force_flip_invariance:\n        flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (\n          self.model.decode(forecast_config.max_horizon, -inputs, masks)\n        )\n        flipped_quantile_spreads = flip_quantile_fn(flipped_quantile_spreads)\n        flipped_pf_outputs = flip_quantile_fn(flipped_pf_outputs)\n        to_cat = [flipped_pf_outputs[:, -1, ...]]\n        if flipped_ar_outputs is not None:\n          to_cat.append(flipped_ar_outputs.reshape(batch_size, -1, self.model.q))\n        flipped_full_forecast = torch.cat(to_cat, dim=1)\n        quantile_spreads = (quantile_spreads - flipped_quantile_spreads) / 2\n        pf_outputs = (pf_outputs - flipped_pf_outputs) / 2\n        full_forecast = (full_forecast - flipped_full_forecast) / 2\n\n      if fc.use_continuous_quantile_head:\n        for quantile_index in [1, 2, 3, 4, 6, 7, 8, 9]:\n          full_forecast[:, :, quantile_index] = (\n            quantile_spreads[:, : fc.max_horizon, quantile_index]\n            - quantile_spreads[:, : fc.max_horizon, 5]\n            + full_forecast[:, : fc.max_horizon, 5]\n          )\n      full_forecast = full_forecast[:, :horizon, :]\n\n      if fc.return_backcast:\n        full_backcast = pf_outputs[:, :-1, : self.model.p, :].reshape(\n          batch_size, -1, self.model.q\n        )\n        full_forecast = torch.cat([full_backcast, full_forecast], dim=1)\n\n      if fc.fix_quantile_crossing:\n        for i in [4, 3, 2, 1]:\n          full_forecast[:, :, i] = torch.where(\n            full_forecast[:, :, i] < full_forecast[:, :, i + 1],\n            full_forecast[:, :, i],\n            full_forecast[:, :, i + 1],\n          )\n        for i in [6, 7, 8, 9]:\n          full_forecast[:, :, i] = torch.where(\n            full_forecast[:, :, i] > full_forecast[:, :, i - 1],\n            full_forecast[:, :, i],\n            full_forecast[:, :, i - 1],\n          )\n\n      if fc.normalize_inputs:\n        full_forecast = 
revin(full_forecast, mu, sigma, reverse=True)\n\n      if is_positive is not None:\n        full_forecast = torch.where(\n          is_positive[..., None],\n          torch.maximum(full_forecast, torch.zeros_like(full_forecast)),\n          full_forecast,\n        )\n\n      full_forecast = full_forecast.detach().cpu().numpy()\n      return full_forecast[..., 5], full_forecast\n\n    self.compiled_decode = _compiled_decode\n"
  },
  {
    "path": "src/timesfm/torch/__init__.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "src/timesfm/torch/dense.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Dense layers for TimesFM.\"\"\"\n\nimport torch\nfrom torch import nn\n\nfrom .. import configs\n\n\nclass ResidualBlock(nn.Module):\n  \"\"\"Residual block with two linear layers and a linear residual connection.\"\"\"\n\n  def __init__(self, config: configs.ResidualBlockConfig):\n    super().__init__()\n    self.config = config\n    self.hidden_layer = nn.Linear(\n        in_features=config.input_dims,\n        out_features=config.hidden_dims,\n        bias=config.use_bias,\n    )\n    self.output_layer = nn.Linear(\n        in_features=config.hidden_dims,\n        out_features=config.output_dims,\n        bias=config.use_bias,\n    )\n    self.residual_layer = nn.Linear(\n        in_features=config.input_dims,\n        out_features=config.output_dims,\n        bias=config.use_bias,\n    )\n    if config.activation == \"relu\":\n      self.activation = nn.ReLU()\n    elif config.activation == \"swish\":\n      self.activation = nn.SiLU()\n    elif config.activation == \"none\":\n      self.activation = nn.Identity()\n    else:\n      raise ValueError(f\"Activation: {config.activation} not supported.\")\n\n  def forward(self, x: torch.Tensor) -> torch.Tensor:\n    return self.output_layer(\n        self.activation(self.hidden_layer(x))\n    ) + self.residual_layer(x)\n\n\nclass RandomFourierFeatures(nn.Module):\n  \"\"\"Random Fourier features layer.\"\"\"\n\n  
def __init__(self, config: configs.RandomFourierFeaturesConfig):\n    super().__init__()\n    self.config = config\n\n    if config.output_dims % 4 != 0:\n      raise ValueError(\n          f\"Output dims must be a multiple of 4: {config.output_dims} % 4 != 0.\"\n      )\n    num_projected_features = config.output_dims // 4\n\n    self.phase_shifts = nn.Parameter(torch.zeros(2, num_projected_features))\n    self.projection_layer = nn.Linear(\n        in_features=config.input_dims,\n        out_features=num_projected_features,\n        bias=config.use_bias,\n    )\n    self.residual_layer = nn.Linear(\n        in_features=config.input_dims,\n        out_features=config.output_dims,\n        bias=config.use_bias,\n    )\n\n  def forward(self, x: torch.Tensor) -> torch.Tensor:\n    projected = self.projection_layer(x)\n    cos_features = torch.cos(projected)\n    sin_features = torch.sin(projected)\n    sq_wave_1 = torch.sign(torch.sin(projected + self.phase_shifts[0, :]))\n    sq_wave_2 = torch.sign(torch.sin(projected + self.phase_shifts[1, :]))\n    fourier_features = torch.cat(\n        [cos_features, sin_features, sq_wave_1, sq_wave_2], dim=-1\n    )\n    residual = self.residual_layer(x)\n    return fourier_features + residual\n"
  },
  {
    "path": "src/timesfm/torch/normalization.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Normalization layers for TimesFM.\"\"\"\n\nimport torch\nfrom torch import nn\n\n\nclass RMSNorm(nn.Module):\n  \"\"\"RMS normalization.\"\"\"\n\n  def __init__(\n      self,\n      num_features: int,\n      *,\n      epsilon: float = 1e-6,\n  ):\n    super().__init__()\n    self.scale = nn.Parameter(torch.zeros(num_features))\n    self.num_features = num_features\n    self.epsilon = epsilon\n\n  def forward(self, inputs: torch.Tensor) -> torch.Tensor:\n    var = torch.mean(torch.square(inputs), dim=-1, keepdim=True)\n    normed_inputs = inputs * torch.rsqrt(var + self.epsilon)\n    normed_inputs = normed_inputs * self.scale\n    return normed_inputs\n"
  },
  {
    "path": "src/timesfm/torch/transformer.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Transformer layers for TimesFM.\"\"\"\n\nimport math\nfrom typing import Callable\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom .. import configs\nfrom . import normalization, util\n\nLayerNorm = nn.LayerNorm\nRMSNorm = normalization.RMSNorm\nDecodeCache = util.DecodeCache\n\n\ndef make_attn_mask(\n  query_length: int,\n  num_all_masked_kv: torch.Tensor,\n  query_index_offset: torch.Tensor | None = None,\n  kv_length: int = 0,\n) -> torch.Tensor:\n  \"\"\"Makes attention mask.\"\"\"\n  if kv_length == 0:\n    kv_length = query_length\n\n  q_index = torch.arange(query_length, device=num_all_masked_kv.device)[\n    None, None, :, None\n  ]\n  if query_index_offset is not None:\n    q_index = q_index + query_index_offset[:, None, None, None]\n  kv_index = torch.arange(kv_length, device=num_all_masked_kv.device)[\n    None, None, None, :\n  ]\n  return torch.logical_and(\n    q_index >= kv_index,\n    kv_index >= num_all_masked_kv[:, None, None, None],\n  )\n\n\nclass RotaryPositionalEmbedding(nn.Module):\n  \"\"\"Rotary positional embedding.\"\"\"\n\n  def __init__(\n    self,\n    embedding_dims: int,\n    min_timescale: float = 1.0,\n    max_timescale: float = 10000.0,\n  ):\n    super().__init__()\n    self.embedding_dims = embedding_dims\n    self.min_timescale = min_timescale\n    self.max_timescale = max_timescale\n\n  def 
forward(\n    self,\n    inputs: torch.Tensor,\n    position: torch.Tensor | None = None,\n  ):\n    \"\"\"Applies rotary positional embedding to the inputs.\"\"\"\n    if self.embedding_dims != inputs.shape[-1]:\n      raise ValueError(\n        \"The embedding dims of the rotary position embedding\"\n        \"must match the hidden dimension of the inputs.\"\n      )\n    half_embedding_dim = self.embedding_dims // 2\n    fraction = (\n      2\n      * torch.arange(0, half_embedding_dim, device=inputs.device)\n      / self.embedding_dims\n    )\n    timescale = (\n      self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction\n    ).to(inputs.device)\n    if position is None:\n      seq_length = inputs.shape[1]\n      position = torch.arange(seq_length, dtype=torch.float32, device=inputs.device)[\n        None, :\n      ]\n\n    if len(inputs.shape) == 4:\n      position = position[..., None, None]\n      timescale = timescale[None, None, None, :]\n    elif len(inputs.shape) == 3:\n      position = position[..., None]\n      timescale = timescale[None, None, :]\n    else:\n      raise ValueError(\"Inputs must be of rank 3 or 4.\")\n\n    sinusoid_inp = position / timescale\n    sin = torch.sin(sinusoid_inp)\n    cos = torch.cos(sinusoid_inp)\n    first_half, second_half = torch.chunk(inputs, 2, dim=-1)\n    first_part = first_half * cos - second_half * sin\n    second_part = second_half * cos + first_half * sin\n    return torch.cat([first_part, second_part], dim=-1)\n\n\ndef _dot_product_attention(\n  query,\n  key,\n  value,\n  mask=None,\n):\n  \"\"\"Computes dot-product attention given query, key, and value.\"\"\"\n  attn_weights = torch.einsum(\"...qhd,...khd->...hqk\", query, key)\n  if mask is not None:\n    attn_weights = torch.where(\n      mask, attn_weights, -torch.finfo(attn_weights.dtype).max / 2\n    )\n\n  attn_weights = F.softmax(attn_weights, dim=-1)\n\n  return torch.einsum(\"...hqk,...khd->...qhd\", attn_weights, 
value)\n\n\ndef _torch_dot_product_attention(query, key, value, mask=None):\n  \"\"\"\n  Performs the exact same (unscaled) attention as the above function,\n  but using the fast and fused F.scaled_dot_product_attention kernel.\n  \"\"\"\n\n  # 1. Permute inputs from (B, L, H, D) to the expected (B, H, L, D)\n  query = query.permute(0, 2, 1, 3)\n  key = key.permute(0, 2, 1, 3)\n  value = value.permute(0, 2, 1, 3)\n\n  # 2. Call the fused attention kernel\n  #    - Pass the mask to `attn_mask`.\n  #    - Set `scale=1.0` to disable the default 1/sqrt(d_k) scaling.\n  output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, scale=1.0)\n\n  # 3. Permute the output back to the original (B, L, H, D) layout\n  output = output.permute(0, 2, 1, 3)\n\n  return output\n\n\nclass PerDimScale(nn.Module):\n  \"\"\"Per-dimension scaling.\"\"\"\n\n  def __init__(self, num_dims: int):\n    super().__init__()\n    self.num_dims = num_dims\n    self.per_dim_scale = nn.Parameter(torch.zeros(num_dims))\n\n  def forward(self, x: torch.Tensor) -> torch.Tensor:\n    scale_factor = (\n      1.442695041 / math.sqrt(self.num_dims) * F.softplus(self.per_dim_scale)\n    )\n    return x * scale_factor\n\n\nclass MultiHeadAttention(nn.Module):\n  \"\"\"Multi-head attention.\"\"\"\n\n  def __init__(\n    self,\n    num_heads: int,\n    in_features: int,\n    *,\n    use_per_dim_scale: bool = True,\n    use_rotary_position_embeddings: bool = True,\n    use_bias: bool = False,\n    attention_fn: Callable[..., torch.Tensor] = _torch_dot_product_attention,\n    qk_norm: str = \"rms\",\n    fuse_qkv: bool = False,\n  ):\n    super().__init__()\n    self.num_heads = num_heads\n    self.in_features = in_features\n    self.head_dim = in_features // num_heads\n    self.use_bias = use_bias\n    self.attention_fn = attention_fn\n    self.qk_norm = qk_norm\n    self.fuse_qkv = fuse_qkv\n\n    if self.in_features % self.num_heads != 0:\n      raise ValueError(\n        f\"Memory dimension 
({self.in_features}) must be divisible by \"\n        f\"'num_heads' heads ({self.num_heads}).\"\n      )\n\n    if self.fuse_qkv:\n      self.qkv_proj = nn.Linear(self.in_features, 3 * self.in_features, bias=use_bias)\n    else:\n      self.query = nn.Linear(self.in_features, self.in_features, bias=use_bias)\n      self.key = nn.Linear(self.in_features, self.in_features, bias=use_bias)\n      self.value = nn.Linear(self.in_features, self.in_features, bias=use_bias)\n    self.out = nn.Linear(self.in_features, self.in_features, bias=use_bias)\n\n    if self.qk_norm == \"rms\":\n      self.query_ln = RMSNorm(self.head_dim)\n      self.key_ln = RMSNorm(self.head_dim)\n    else:\n      self.query_ln = nn.Identity()\n      self.key_ln = nn.Identity()\n\n    self.use_rotary_position_embeddings = use_rotary_position_embeddings\n    if self.use_rotary_position_embeddings:\n      self.rotary_position_embedding = RotaryPositionalEmbedding(\n        embedding_dims=self.head_dim,\n      )\n\n    self.use_per_dim_scale = use_per_dim_scale\n    if use_per_dim_scale:\n      self.per_dim_scale = PerDimScale(num_dims=self.head_dim)\n\n  def forward(\n    self,\n    inputs_q: torch.Tensor,\n    *,\n    decode_cache: DecodeCache | None = None,\n    patch_mask: torch.Tensor | None = None,\n  ) -> tuple[torch.Tensor, DecodeCache | None]:\n    b, n_patches, _ = inputs_q.shape\n    if patch_mask is None:\n      patch_mask = torch.zeros(b, n_patches, dtype=torch.bool, device=inputs_q.device)\n\n    if self.fuse_qkv:\n      qkv = self.qkv_proj(inputs_q)\n      query, key, value = torch.chunk(qkv, 3, dim=-1)\n      query = query.view(b, n_patches, self.num_heads, self.head_dim)\n      key = key.view(b, n_patches, self.num_heads, self.head_dim)\n      value = value.view(b, n_patches, self.num_heads, self.head_dim)\n    else:\n      query = self.query(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)\n      key = self.key(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)\n 
     value = self.value(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)\n\n    if decode_cache is None:\n      num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1)\n      next_index = torch.zeros_like(num_masked, dtype=torch.int32)\n    else:\n      num_masked = (\n        torch.sum(patch_mask.to(torch.int32), dim=-1) + decode_cache.num_masked\n      )\n      next_index = decode_cache.next_index.clone()\n\n    if self.use_rotary_position_embeddings:\n      position = (\n        torch.arange(n_patches, device=inputs_q.device)[None, :]\n        + next_index[:, None]\n        - num_masked[:, None]\n      )\n      query = self.rotary_position_embedding(query, position)\n      key = self.rotary_position_embedding(key, position)\n\n    query = self.query_ln(query)\n    key = self.key_ln(key)\n\n    if self.use_per_dim_scale:\n      query = self.per_dim_scale(query)\n\n    if decode_cache is not None:\n      _, decode_cache_size, _, _ = decode_cache.value.shape\n\n      start = decode_cache.next_index[0]\n      end = start + n_patches\n\n      # Perform a single, vectorized slice assignment for the entire batch.\n      # This is vastly more efficient than a Python for-loop.\n\n      decode_cache.key[:, start:end] = key\n      decode_cache.value[:, start:end] = value\n\n      key = decode_cache.key\n      value = decode_cache.value\n      decode_cache.next_index += n_patches\n      decode_cache.num_masked = num_masked\n      attn_mask = make_attn_mask(\n        query_length=n_patches,\n        num_all_masked_kv=num_masked,\n        query_index_offset=next_index,\n        kv_length=decode_cache_size,\n      )\n    else:\n      attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked)\n\n    x = self.attention_fn(\n      query,\n      key,\n      value,\n      mask=attn_mask,\n    )\n\n    x = x.reshape(b, n_patches, self.in_features)\n    out = self.out(x)\n    return out, decode_cache\n\n\nclass Transformer(nn.Module):\n  
\"\"\"Classic Transformer used in TimesFM.\"\"\"\n\n  def __init__(self, config: configs.TransformerConfig):\n    super().__init__()\n    self.config = config\n\n    if config.attention_norm == \"rms\":\n      self.pre_attn_ln = RMSNorm(num_features=config.model_dims)\n      self.post_attn_ln = RMSNorm(num_features=config.model_dims)\n    else:\n      raise ValueError(f\"Layer norm: {config.attention_norm} not supported.\")\n\n    self.attn = MultiHeadAttention(\n      num_heads=config.num_heads,\n      in_features=config.model_dims,\n      use_per_dim_scale=True,\n      use_rotary_position_embeddings=config.use_rotary_position_embeddings,\n      qk_norm=config.qk_norm,\n      fuse_qkv=config.fuse_qkv,\n    )\n\n    if config.feedforward_norm == \"rms\":\n      self.pre_ff_ln = RMSNorm(num_features=config.model_dims)\n      self.post_ff_ln = RMSNorm(num_features=config.model_dims)\n    else:\n      raise ValueError(f\"Layer norm: {config.feedforward_norm} not supported.\")\n\n    self.ff0 = nn.Linear(\n      in_features=config.model_dims,\n      out_features=config.hidden_dims,\n      bias=config.use_bias,\n    )\n    self.ff1 = nn.Linear(\n      in_features=config.hidden_dims,\n      out_features=config.model_dims,\n      bias=config.use_bias,\n    )\n    if config.ff_activation == \"relu\":\n      self.activation = nn.ReLU()\n    elif config.ff_activation == \"swish\":\n      self.activation = nn.SiLU()\n    elif config.ff_activation == \"none\":\n      self.activation = nn.Identity()\n    else:\n      raise ValueError(f\"Activation: {config.ff_activation} not supported.\")\n\n  def forward(\n    self,\n    input_embeddings: torch.Tensor,\n    patch_mask: torch.Tensor,\n    decode_cache: DecodeCache | None = None,\n  ) -> tuple[torch.Tensor, DecodeCache | None]:\n    attn_output, decode_cache = self.attn(\n      inputs_q=self.pre_attn_ln(input_embeddings),\n      decode_cache=decode_cache,\n      patch_mask=patch_mask,\n    )\n    attn_output = 
self.post_attn_ln(attn_output) + input_embeddings\n    output_embeddings = (\n      self.post_ff_ln(self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))))\n      + attn_output\n    )\n    return output_embeddings, decode_cache\n"
  },
  {
    "path": "src/timesfm/torch/util.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"PyTorch utility functions for TimesFM layers.\"\"\"\n\nimport dataclasses\nimport torch\n\n_TOLERANCE = 1e-6\n\n\n@dataclasses.dataclass(frozen=False)\nclass DecodeCache:\n  \"\"\"Cache for decoding.\"\"\"\n\n  next_index: torch.Tensor\n  num_masked: torch.Tensor\n  key: torch.Tensor\n  value: torch.Tensor\n\n\ndef update_running_stats(\n    n: torch.Tensor,\n    mu: torch.Tensor,\n    sigma: torch.Tensor,\n    x: torch.Tensor,\n    mask: torch.Tensor,\n) -> tuple[\n    tuple[torch.Tensor, torch.Tensor, torch.Tensor],\n    tuple[torch.Tensor, torch.Tensor, torch.Tensor],\n]:\n  \"\"\"Updates the running stats.\"\"\"\n  is_legit = torch.logical_not(mask)\n  inc_n = torch.sum(is_legit.to(x.dtype), dim=-1)\n\n  inc_mu_numerator = torch.sum(x * is_legit, dim=-1)\n  inc_n_safe = torch.where(inc_n == 0, 1.0, inc_n)\n  inc_mu = inc_mu_numerator / inc_n_safe\n  inc_mu = torch.where(inc_n == 0, 0.0, inc_mu)\n\n  inc_var_numerator = torch.sum(\n      ((x - inc_mu.unsqueeze(-1)) ** 2) * is_legit, dim=-1\n  )\n  inc_var = inc_var_numerator / inc_n_safe\n  inc_var = torch.where(inc_n == 0, 0.0, inc_var)\n  inc_sigma = torch.sqrt(inc_var)\n\n  new_n = n + inc_n\n  new_n_safe = torch.where(new_n == 0, 1.0, new_n)\n\n  new_mu = (n * mu + inc_mu * inc_n) / new_n_safe\n  new_mu = torch.where(new_n == 0, 0.0, new_mu)\n\n  term1 = n * sigma.pow(2)\n  term2 = inc_n * 
inc_sigma.pow(2)\n  term3 = n * (mu - new_mu).pow(2)\n  term4 = inc_n * (inc_mu - new_mu).pow(2)\n\n  new_var = (term1 + term2 + term3 + term4) / new_n_safe\n  new_var = torch.where(new_n == 0, 0.0, new_var)\n  new_sigma = torch.sqrt(torch.clamp(new_var, min=0.0))\n\n  return (w := (new_n, new_mu, new_sigma), w)\n\n\ndef revin(\n    x: torch.Tensor,\n    mu: torch.Tensor,\n    sigma: torch.Tensor,\n    reverse: bool = False,\n):\n  \"\"\"Reversible instance normalization.\"\"\"\n  if len(mu.shape) == len(x.shape) - 1:\n    mu = mu[..., None]\n    sigma = sigma[..., None]\n  elif len(mu.shape) == len(x.shape) - 2:\n    mu = mu[..., None, None]\n    sigma = sigma[..., None, None]\n\n  if reverse:\n    return x * sigma + mu\n  else:\n    return (x - mu) / torch.where(sigma < _TOLERANCE, 1.0, sigma)\n"
  },
  {
    "path": "src/timesfm/utils/xreg_lib.py",
    "content": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Helper functions for in-context covariates and regression.\"\"\"\n\nimport itertools\nimport math\nfrom typing import Any, Iterable, Literal, Mapping, Sequence\n\ntry:\n  import jax\n  import jax.numpy as jnp\n  import numpy as np\n  from sklearn import preprocessing\nexcept ImportError:\n  raise ImportError(\n    \"Failed to load the XReg module. Did you forget to install `timesfm[xreg]`?\"\n  )\n\nCategory = int | str\n\n_TOL = 1e-6\nXRegMode = Literal[\"timesfm + xreg\", \"xreg + timesfm\"]\n\n\ndef _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:\n  return np.array(list(itertools.chain.from_iterable(nested)))\n\n\ndef _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:\n  return np.array(\n    list(itertools.chain.from_iterable(map(itertools.repeat, elements, counts)))\n  )\n\n\ndef _to_padded_jax_array(x: np.ndarray) -> jax.Array:\n  if x.ndim == 1:\n    (i,) = x.shape\n    di = 2 ** math.ceil(math.log2(i)) - i\n    return jnp.pad(x, ((0, di),), mode=\"constant\", constant_values=0.0)\n  elif x.ndim == 2:\n    i, j = x.shape\n    di = 2 ** math.ceil(math.log2(i)) - i\n    dj = 2 ** math.ceil(math.log2(j)) - j\n    return jnp.pad(x, ((0, di), (0, dj)), mode=\"constant\", constant_values=0.0)\n  else:\n    raise ValueError(f\"Unsupported array shape: {x.shape}\")\n\n\n# Per time series normalization: forward.\ndef normalize(batch):\n  
stats = [(np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch]\n  new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)]\n  return new_batch, stats\n\n\n# Per time series normalization: inverse.\ndef renormalize(batch, stats):\n  return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)]\n\n\nclass BatchedInContextXRegBase:\n  \"\"\"Helper class for in-context regression covariate formatting.\n\n  Attributes:\n    targets: List of targets (responses) of the in-context regression.\n    train_lens: List of lengths of each target vector from the context.\n    test_lens: List of lengths of each forecast horizon.\n    train_dynamic_numerical_covariates: Dict of covariate names mapping to the\n      dynamic numerical covariates of each forecast task on the context. Their\n      lengths should match the corresponding lengths in `train_lens`.\n    train_dynamic_categorical_covariates: Dict of covariate names mapping to the\n      dynamic categorical covariates of each forecast task on the context. Their\n      lengths should match the corresponding lengths in `train_lens`.\n    test_dynamic_numerical_covariates: Dict of covariate names mapping to the\n      dynamic numerical covariates of each forecast task on the horizon. Their\n      lengths should match the corresponding lengths in `test_lens`.\n    test_dynamic_categorical_covariates: Dict of covariate names mapping to the\n      dynamic categorical covariates of each forecast task on the horizon. 
Their\n      lengths should match the corresponding lengths in `test_lens`.\n    static_numerical_covariates: Dict of covariate names mapping to the static\n      numerical covariates of each forecast task.\n    static_categorical_covariates: Dict of covariate names mapping to the static\n      categorical covariates of each forecast task.\n  \"\"\"\n\n  def __init__(\n    self,\n    targets: Sequence[Sequence[float]],\n    train_lens: Sequence[int],\n    test_lens: Sequence[int],\n    train_dynamic_numerical_covariates: (\n      Mapping[str, Sequence[Sequence[float]]] | None\n    ) = None,\n    train_dynamic_categorical_covariates: (\n      Mapping[str, Sequence[Sequence[Category]]] | None\n    ) = None,\n    test_dynamic_numerical_covariates: (\n      Mapping[str, Sequence[Sequence[float]]] | None\n    ) = None,\n    test_dynamic_categorical_covariates: (\n      Mapping[str, Sequence[Sequence[Category]]] | None\n    ) = None,\n    static_numerical_covariates: Mapping[str, Sequence[float]] | None = None,\n    static_categorical_covariates: (Mapping[str, Sequence[Category]] | None) = None,\n  ) -> None:\n    \"\"\"Initializes with the exogenous covariate inputs.\n\n    Here we use model fitting language to refer to the context as 'train' and\n    the horizon as 'test'. We assume batched inputs. To properly format the\n    request:\n\n     - `train_lens` represents the contexts in the batch. Targets and all train\n     dynamic covariates should have the same lengths as the corresponding\n     elements\n     in `train_lens`. Notice each `train_len` can be different from the exact\n     length of the corresponding context depending on how much of the context is\n     used for fitting the in-context model.\n     - `test_lens` represents the horizon lengths in the batch. 
All test\n     dynamic\n     covariates should have the same lengths as the corresponding elements in\n     `test_lens`.\n     - Static covariates should be one for each input.\n     - For train and test dynamic covariates, they should have the same\n     covariate\n     names.\n\n     Pass an empty dict {} for a covariate type if it is not present.\n\n     Example:\n       Here is a set of valid inputs whose schema can be used for reference.\n       ```\n       targets = [\n           [0.0, 0.1, 0.2],\n           [0.0, 0.1, 0.2, 0.3],\n       ]  # Two inputs in this batch.\n       train_lens = [3, 4]\n       test_lens = [2, 5]  # Forecast horizons 2 and 5 respectively.\n       train_dynamic_numerical_covariates = {\n           \"cov_1_dn\": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]],\n           \"cov_2_dn\": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]],\n       }  # Each train dynamic covariate has 3 and 4 elements respectively.\n       test_dynamic_numerical_covariates = {\n           \"cov_1_dn\": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]],\n           \"cov_2_dn\": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]],\n       }  # Each test dynamic covariate has 2 and 5 elements respectively.\n       train_dynamic_categorical_covariates = {\n           \"cov_1_dc\": [[0, 1, 0], [0, 1, 2, 3]],\n           \"cov_2_dc\": [[\"good\", \"bad\", \"good\"], [\"good\", \"good\", \"bad\",\n           \"bad\"]],\n       }\n       test_dynamic_categorical_covariates = {\n           \"cov_1_dc\": [[1, 0], [1, 0, 2, 3, 1]],\n           \"cov_2_dc\": [[\"bad\", \"good\"], [\"bad\", \"bad\", \"bad\", \"bad\", \"bad\"]],\n       }\n       static_numerical_covariates = {\n           \"cov_1_sn\": [0.0, 3.0],\n           \"cov_2_sn\": [2.0, 1.0],\n           \"cov_3_sn\": [1.0, 2.0],\n       }  # Each static covariate has 1 element for each input.\n       static_categorical_covariates = {\n           \"cov_1_sc\": [\"apple\", \"orange\"],\n           \"cov_2_sc\": [2, 3],\n       }\n       ```\n\n    
Args:\n      targets: List of targets (responses) of the in-context regression.\n      train_lens: List of lengths of each target vector from the context.\n      test_lens: List of lengths of each forecast horizon.\n      train_dynamic_numerical_covariates: Dict of covariate names mapping to the\n        dynamic numerical covariates of each forecast task on the context. Their\n        lengths should match the corresponding lengths in `train_lens`.\n      train_dynamic_categorical_covariates: Dict of covariate names mapping to\n        the dynamic categorical covariates of each forecast task on the context.\n        Their lengths should match the corresponding lengths in `train_lens`.\n      test_dynamic_numerical_covariates: Dict of covariate names mapping to the\n        dynamic numerical covariates of each forecast task on the horizon. Their\n        lengths should match the corresponding lengths in `test_lens`.\n      test_dynamic_categorical_covariates: Dict of covariate names mapping to\n        the dynamic categorical covariates of each forecast task on the horizon.\n        Their lengths should match the corresponding lengths in `test_lens`.\n      static_numerical_covariates: Dict of covariate names mapping to the static\n        numerical covariates of each forecast task.\n      static_categorical_covariates: Dict of covariate names mapping to the\n        static categorical covariates of each forecast task.\n    \"\"\"\n    self.targets = targets\n    self.train_lens = train_lens\n    self.test_lens = test_lens\n    self.train_dynamic_numerical_covariates = train_dynamic_numerical_covariates or {}\n    self.train_dynamic_categorical_covariates = (\n      train_dynamic_categorical_covariates or {}\n    )\n    self.test_dynamic_numerical_covariates = test_dynamic_numerical_covariates or {}\n    self.test_dynamic_categorical_covariates = test_dynamic_categorical_covariates or {}\n    self.static_numerical_covariates = static_numerical_covariates or {}\n    
self.static_categorical_covariates = static_categorical_covariates or {}\n\n  def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None:\n    \"\"\"Verifies the validity of the covariate inputs.\"\"\"\n\n    # Check presence.\n    if (\n      self.train_dynamic_numerical_covariates\n      and not self.test_dynamic_numerical_covariates\n    ) or (\n      not self.train_dynamic_numerical_covariates\n      and self.test_dynamic_numerical_covariates\n    ):\n      raise ValueError(\n        \"train_dynamic_numerical_covariates and\"\n        \" test_dynamic_numerical_covariates must be both present or both\"\n        \" absent.\"\n      )\n\n    if (\n      self.train_dynamic_categorical_covariates\n      and not self.test_dynamic_categorical_covariates\n    ) or (\n      not self.train_dynamic_categorical_covariates\n      and self.test_dynamic_categorical_covariates\n    ):\n      raise ValueError(\n        \"train_dynamic_categorical_covariates and\"\n        \" test_dynamic_categorical_covariates must be both present or both\"\n        \" absent.\"\n      )\n\n    # Check keys.\n    for dict_a, dict_b, dict_a_name, dict_b_name in (\n      (\n        self.train_dynamic_numerical_covariates,\n        self.test_dynamic_numerical_covariates,\n        \"train_dynamic_numerical_covariates\",\n        \"test_dynamic_numerical_covariates\",\n      ),\n      (\n        self.train_dynamic_categorical_covariates,\n        self.test_dynamic_categorical_covariates,\n        \"train_dynamic_categorical_covariates\",\n        \"test_dynamic_categorical_covariates\",\n      ),\n    ):\n      if w := set(dict_a.keys()) - set(dict_b.keys()):\n        raise ValueError(f\"{dict_a_name} has keys not present in {dict_b_name}: {w}\")\n      if w := set(dict_b.keys()) - set(dict_a.keys()):\n        raise ValueError(f\"{dict_b_name} has keys not present in {dict_a_name}: {w}\")\n\n    # Check shapes.\n    if assert_covariate_shapes:\n      if len(self.targets) != 
len(self.train_lens):\n        raise ValueError(\n          \"targets and train_lens must have the same number of elements.\"\n        )\n\n      if len(self.train_lens) != len(self.test_lens):\n        raise ValueError(\n          \"train_lens and test_lens must have the same number of elements.\"\n        )\n\n      for i, (target, train_len) in enumerate(zip(self.targets, self.train_lens)):\n        if len(target) != train_len:\n          raise ValueError(\n            f\"targets[{i}] has length {len(target)} != expected {train_len}.\"\n          )\n\n      for key, values in self.static_numerical_covariates.items():\n        if len(values) != len(self.train_lens):\n          raise ValueError(\n            f\"static_numerical_covariates has key {key} with number of\"\n            f\" examples {len(values)} != expected {len(self.train_lens)}.\"\n          )\n\n      for key, values in self.static_categorical_covariates.items():\n        if len(values) != len(self.train_lens):\n          raise ValueError(\n            f\"static_categorical_covariates has key {key} with number of\"\n            f\" examples {len(values)} != expected {len(self.train_lens)}.\"\n          )\n\n      for lens, dict_cov, dict_cov_name in (\n        (\n          self.train_lens,\n          self.train_dynamic_numerical_covariates,\n          \"train_dynamic_numerical_covariates\",\n        ),\n        (\n          self.train_lens,\n          self.train_dynamic_categorical_covariates,\n          \"train_dynamic_categorical_covariates\",\n        ),\n        (\n          self.test_lens,\n          self.test_dynamic_numerical_covariates,\n          \"test_dynamic_numerical_covariates\",\n        ),\n        (\n          self.test_lens,\n          self.test_dynamic_categorical_covariates,\n          \"test_dynamic_categorical_covariates\",\n        ),\n      ):\n        for key, cov_values in dict_cov.items():\n          if len(cov_values) != len(lens):\n            raise ValueError(\n        
      f\"{dict_cov_name} has key {key} with number of examples\"\n              f\" {len(cov_values)} != expected {len(lens)}.\"\n            )\n          for i, cov_value in enumerate(cov_values):\n            if len(cov_value) != lens[i]:\n              raise ValueError(\n                f\"{dict_cov_name} has key {key} with its {i}-th example\"\n                f\" length {len(cov_value)} != expected {lens[i]}.\"\n              )\n\n  def create_covariate_matrix(\n    self,\n    one_hot_encoder_drop: str | None = \"first\",\n    use_intercept: bool = True,\n    assert_covariates: bool = False,\n    assert_covariate_shapes: bool = False,\n  ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n    \"\"\"Creates target vector and covariate matrices for in context regression.\n\n    Here we use model fitting language to refer to the context as 'train' and\n    the horizon as 'test'.\n\n    Args:\n      one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.\n      use_intercept: Whether to prepare an intercept (all 1) column in the\n        matrices.\n      assert_covariates: Whether to assert the validity of the covariate inputs.\n      assert_covariate_shapes: Whether to assert the shapes of the covariate\n        inputs when `assert_covariates` is True.\n\n    Returns:\n      A tuple of the target vector, the covariate matrix for the context, and\n      the covariate matrix for the horizon.\n    \"\"\"\n    if assert_covariates:\n      self._assert_covariates(assert_covariate_shapes)\n\n    x_train, x_test = [], []\n\n    # Numerical features.\n    for name in sorted(self.train_dynamic_numerical_covariates):\n      x_train.append(\n        _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis]\n      )\n      x_test.append(\n        _unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis]\n      )\n\n    for covs in self.static_numerical_covariates.values():\n      x_train.append(_repeat(covs, self.train_lens)[:, 
np.newaxis])\n      x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis])\n\n    if x_train:\n      x_train = np.concatenate(x_train, axis=1)\n      x_test = np.concatenate(x_test, axis=1)\n\n      # Normalize for robustness.\n      x_mean = np.mean(x_train, axis=0, keepdims=True)\n      x_std = np.where((w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w, 1.0)\n      x_train = [(x_train - x_mean) / x_std]\n      x_test = [(x_test - x_mean) / x_std]\n\n    # Categorical features. Encode one by one.\n    one_hot_encoder = preprocessing.OneHotEncoder(\n      drop=one_hot_encoder_drop,\n      sparse_output=False,\n      handle_unknown=\"ignore\",\n    )\n    for name in sorted(self.train_dynamic_categorical_covariates.keys()):\n      ohe_train = _unnest(self.train_dynamic_categorical_covariates[name])[\n        :, np.newaxis\n      ]\n      ohe_test = _unnest(self.test_dynamic_categorical_covariates[name])[:, np.newaxis]\n      x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train)))\n      x_test.append(np.array(one_hot_encoder.transform(ohe_test)))\n\n    for covs in self.static_categorical_covariates.values():\n      ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis])\n      x_train.append(_repeat(ohe, self.train_lens))\n      x_test.append(_repeat(ohe, self.test_lens))\n\n    x_train = np.concatenate(x_train, axis=1)\n    x_test = np.concatenate(x_test, axis=1)\n\n    if use_intercept:\n      x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0)\n      x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0)\n\n    return _unnest(self.targets), x_train, x_test\n\n  def fit(self) -> Any:\n    raise NotImplementedError(\"Fit is not implemented.\")\n\n\nclass BatchedInContextXRegLinear(BatchedInContextXRegBase):\n  \"\"\"Linear in-context regression model.\"\"\"\n\n  def fit(\n    self,\n    ridge: float = 0.0,\n    one_hot_encoder_drop: str | None = \"first\",\n    use_intercept: bool = True,\n    force_on_cpu: 
bool = False,\n    max_rows_per_col: int = 0,\n    max_rows_per_col_sample_seed: int = 42,\n    debug_info: bool = False,\n    assert_covariates: bool = False,\n    assert_covariate_shapes: bool = False,\n  ) -> (\n    list[np.ndarray]\n    | tuple[list[np.ndarray], list[np.ndarray], jax.Array, jax.Array, jax.Array]\n  ):\n    \"\"\"Fits a linear model for in-context regression.\n\n    Args:\n      ridge: A non-negative value for specifying the ridge regression penalty.\n        If 0 is provided, fallback to ordinary least squares. Note this penalty\n        is added to the normalized covariate matrix.\n      one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.\n      use_intercept: Whether to prepare an intercept (all 1) column in the\n        matrices.\n      force_on_cpu: Whether to force execution on cpu for accelerator machines.\n      max_rows_per_col: How many rows to subsample per column. 0 for no\n        subsampling. This is for speeding up model fitting.\n      max_rows_per_col_sample_seed: The seed for the subsampling if needed by\n        `max_rows_per_col`.\n      debug_info: Whether to return debug info.\n      assert_covariates: Whether to assert the validity of the covariate inputs.\n      assert_covariate_shapes: Whether to assert the shapes of the covariate\n        inputs when `assert_covariates` is True.\n\n    Returns:\n      If `debug_info` is False:\n        The linear fits on the horizon.\n      If `debug_info` is True:\n        A tuple of:\n        - the linear fits on the horizon,\n        - the linear fits on the context,\n        - the flattened target vector,\n        - the covariate matrix for the context, and\n        - the covariate matrix for the horizon.\n    \"\"\"\n    flat_targets, x_train_raw, x_test = self.create_covariate_matrix(\n      one_hot_encoder_drop=one_hot_encoder_drop,\n      use_intercept=use_intercept,\n      assert_covariates=assert_covariates,\n      
assert_covariate_shapes=assert_covariate_shapes,\n    )\n\n    x_train = x_train_raw.copy()\n    if max_rows_per_col:\n      nrows, ncols = x_train.shape\n      if nrows > (w := ncols * max_rows_per_col):\n        subsample = jax.random.choice(\n          jax.random.PRNGKey(max_rows_per_col_sample_seed),\n          nrows,\n          (w,),\n          replace=False,\n        )\n        x_train = x_train[subsample]\n        flat_targets = flat_targets[subsample]\n\n    device = jax.devices(\"cpu\")[0] if force_on_cpu else None\n    # Runs jitted version of the solvers which are quicker at the cost of\n    # running jitting during the first time calling. Re-jitting happens whenever\n    # new (padded) shapes are encountered.\n    # Occasionally it helps with the speed and the accuracy if we force single\n    # thread execution on cpu for accelerator machines:\n    # 1. Avoid moving data to accelerator memory.\n    # 2. Avoid precision loss if any.\n    with jax.default_device(device):\n      x_train_raw = _to_padded_jax_array(x_train_raw)\n      x_train = _to_padded_jax_array(x_train)\n      flat_targets = _to_padded_jax_array(flat_targets)\n      x_test = _to_padded_jax_array(x_test)\n      beta_hat = (\n        jnp.linalg.pinv(\n          x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]),\n          hermitian=True,\n        )\n        @ x_train.T\n        @ flat_targets\n      )\n      y_hat = x_test @ beta_hat\n      y_hat_context = x_train_raw @ beta_hat if debug_info else None\n\n    outputs = []\n    outputs_context = []\n\n    # Reconstruct the ragged 2-dim batched forecasts from flattened linear fits.\n    train_index, test_index = 0, 0\n    for train_index_delta, test_index_delta in zip(self.train_lens, self.test_lens):\n      outputs.append(np.array(y_hat[test_index : (test_index + test_index_delta)]))\n      if debug_info:\n        outputs_context.append(\n          np.array(y_hat_context[train_index : (train_index + train_index_delta)])\n        )\n   
   train_index += train_index_delta\n      test_index += test_index_delta\n\n    if debug_info:\n      return outputs, outputs_context, flat_targets, x_train, x_test\n    else:\n      return outputs\n"
  },
  {
    "path": "timesfm-forecasting/SKILL.md",
    "content": "---\nname: timesfm-forecasting\ndescription: >\n  Zero-shot time series forecasting with Google's TimesFM foundation model. Use this\n  skill when forecasting ANY univariate time series — sales, sensor readings, stock prices,\n  energy demand, patient vitals, weather, or scientific measurements — without training a\n  custom model. Supports both basic forecasting and advanced covariate forecasting (XReg)\n  with dynamic and static exogenous variables. Automatically checks system RAM/GPU before\n  loading the model, validates dataset fit before processing, supports CSV/DataFrame/array\n  inputs, and returns point forecasts with calibrated prediction intervals. Includes a\n  preflight system checker script that MUST be run before first use to verify the machine\n  can load the model and handle your specific dataset.\nlicense: Apache-2.0\nmetadata:\n  author: Clayton Young (@borealBytes)\n  version: \"1.0.0\"\n---\n\n# TimesFM Forecasting\n\n## Overview\n\nTimesFM (Time Series Foundation Model) is a pretrained decoder-only foundation model\ndeveloped by Google Research for time-series forecasting. It works **zero-shot** — feed it\nany univariate time series and it returns point forecasts with calibrated quantile\nprediction intervals, no training required.\n\nThis skill includes a **mandatory preflight system checker** that verifies RAM, GPU memory,\nand disk space before the model is ever loaded so the agent never crashes the user's machine.\n\n> **Key numbers**: TimesFM 2.5 uses 200M parameters (~800 MB on disk, ~1.5 GB in RAM on\n> CPU, ~1 GB VRAM on GPU). 
The archived v1/v2 500M-parameter model needs ~32 GB RAM.\n> Always run the system checker first.\n\n## When to Use This Skill\n\nUse this skill when:\n\n- Forecasting **any univariate time series** (sales, demand, sensor, vitals, price, weather)\n- You need **zero-shot forecasting** without training a custom model\n- You want **probabilistic forecasts** with calibrated prediction intervals (quantiles)\n- You have time series of **any length** (the model handles 1–16,384 context points)\n- You need to **batch-forecast** hundreds or thousands of series efficiently\n- You want a **foundation model** approach instead of hand-tuning ARIMA/ETS parameters\n- You need **covariate forecasting** with exogenous variables (price, promotions, holidays, day-of-week effects) → use `forecast_with_covariates()` (TimesFM 2.5 + `pip install timesfm[xreg]`)\n\n\nDo **not** use this skill when:\n\n- You need classical statistical models with coefficient interpretation → use `statsmodels`\n- You need time series classification or clustering → use `aeon`\n- You need multivariate vector autoregression or Granger causality → use `statsmodels`\n- Your data is tabular (not temporal) → use `scikit-learn`\n- You cannot install optional dependencies → XReg requires scikit-learn and JAX\n\n\n> **Note on Anomaly Detection**: TimesFM does not have built-in anomaly detection, but you\n> can use the **quantile forecasts as prediction intervals** — values outside the 80% PI\n> (q10–q90) are statistically unusual. See `examples/anomaly-detection/` for a full example.\n\n## ⚠️ Mandatory Preflight: System Requirements Check\n\n**CRITICAL — ALWAYS run the system checker before loading the model for the first time.**\n\n```bash\npython scripts/check_system.py\n```\n\nThis script checks:\n\n1. **Available RAM** — warns if below 4 GB, blocks if below 2 GB\n2. **GPU availability** — detects CUDA/MPS devices and VRAM\n3. **Disk space** — verifies room for the ~800 MB model download\n4. 
**Python version** — requires 3.10+\n5. **Existing installation** — checks if `timesfm` and `torch` are installed\n\n> **Note:** Model weights are **NOT stored in this repository**. TimesFM weights (~800 MB)\n> download on-demand from HuggingFace on first use and cache in `~/.cache/huggingface/`.\n\n```mermaid\nflowchart TD\n    start[\"🚀 Run check_system.py\"] --> ram{\"RAM ≥ 4 GB?\"}\n    ram -->|\"Yes\"| gpu{\"GPU available?\"}\n    ram -->|\"No (2-4 GB)\"| warn_ram[\"⚠️ Warning: tight RAM<br/>CPU-only, small batches\"]\n    ram -->|\"No (< 2 GB)\"| block[\"🛑 BLOCKED<br/>Insufficient memory\"]\n    warn_ram --> disk\n    gpu -->|\"CUDA / MPS\"| vram{\"VRAM ≥ 2 GB?\"}\n    gpu -->|\"CPU only\"| cpu_ok[\"✅ CPU mode<br/>Slower but works\"]\n    vram -->|\"Yes\"| gpu_ok[\"✅ GPU mode<br/>Fast inference\"]\n    vram -->|\"No\"| cpu_ok\n    gpu_ok --> disk{\"Disk ≥ 2 GB free?\"}\n    cpu_ok --> disk\n    disk -->|\"Yes\"| ready[\"✅ READY<br/>Safe to load model\"]\n    disk -->|\"No\"| block_disk[\"🛑 BLOCKED<br/>Need space for weights\"]\n```\n\n### Dataset Preflight (NEW)\n\nBefore loading your actual data, verify it will fit in memory:\n\n```bash\n# Quick estimate for your dataset\npython scripts/check_system.py \\\n  --num-series 1000 \\\n  --context-length 1024 \\\n  --horizon 24 \\\n  --batch-size 32 \\\n  --estimate-only\n```\n\nThis will show you the estimated memory requirements and warn if your dataset is too large.\n\n**Memory Estimation Formula**:\n`RAM ≈ 0.8 GB (model) + 0.5 GB (overhead) + (0.2 MB × num_series × context_length / 1000)`\n\n**Example Outputs**:\n\n✅ **Dataset Fits**:\n```\nTotal CPU memory: 2.34 GB\nTotal GPU memory: 2.15 GB\n```\n\n⚠️ **Dataset Too Large**:\n```\nDataset requires ~12.5 GB RAM but system has 8.0 GB.\nTry: context_length=512 or process in chunks of 50 series.\n```\n\n### Hardware Requirements by Model Version\n\n| Model | Parameters | RAM (CPU) | VRAM (GPU) | Disk | Context |\n| ----- | ---------- | --------- | ---------- | 
---- | ------- |\n| **TimesFM 2.5** (recommended) | 200M | ≥ 4 GB | ≥ 2 GB | ~800 MB | up to 16,384 |\n| TimesFM 2.0 (archived) | 500M | ≥ 16 GB | ≥ 8 GB | ~2 GB | up to 2,048 |\n| TimesFM 1.0 (archived) | 200M | ≥ 8 GB | ≥ 4 GB | ~800 MB | up to 2,048 |\n\n> **Recommendation**: Always use TimesFM 2.5 unless you have a specific reason to use an\n> older checkpoint. It is smaller, faster, and supports 8× longer context.\n\n## 🔧 Installation\n\n### Step 1: Verify System (always first)\n\n```bash\npython scripts/check_system.py\n```\n\n### Step 2: Install TimesFM\n\n```bash\n# Using uv (fast)\nuv pip install timesfm[torch]\n\n# Or using pip\npip install timesfm[torch]\n\n# For JAX/Flax backend (faster on TPU/GPU)\nuv pip install timesfm[flax]\n```\n\n### Step 3: Install PyTorch for Your Hardware\n\n```bash\n# CUDA 12.1 (NVIDIA GPU)\npip install torch>=2.0.0 --index-url https://download.pytorch.org/whl/cu121\n\n# CPU only\npip install torch>=2.0.0 --index-url https://download.pytorch.org/whl/cpu\n\n# Apple Silicon (MPS)\npip install torch>=2.0.0  # MPS support is built-in\n```\n\n## 🎯 Quick Start\n\n### Minimal Example\n\n```python\nimport torch, numpy as np, timesfm\n\ntorch.set_float32_matmul_precision(\"high\")\n\nmodel = timesfm.TimesFM_2p5_200M_torch.from_pretrained(\n    \"google/timesfm-2.5-200m-pytorch\"\n)\nmodel.compile(timesfm.ForecastConfig(\n    max_context=1024, max_horizon=256, normalize_inputs=True,\n    use_continuous_quantile_head=True, force_flip_invariance=True,\n    infer_is_positive=True, fix_quantile_crossing=True,\n))\n\npoint, quantiles = model.forecast(horizon=24, inputs=[\n    np.sin(np.linspace(0, 20, 200)),  # any 1-D array\n])\n# point.shape == (1, 24)         — median forecast\n# quantiles.shape == (1, 24, 10) — 10th–90th percentile bands\n```\n\n### Forecast with Covariates (XReg)\n\nTimesFM 2.5+ supports exogenous variables through `forecast_with_covariates()`.\nRequires `pip install timesfm[xreg]`.\n\n```python\npoint, quantiles = 
model.forecast_with_covariates(\n    inputs=inputs,\n    dynamic_numerical_covariates={\"price\": price_arrays},\n    dynamic_categorical_covariates={\"holiday\": holiday_arrays},\n    static_categorical_covariates={\"region\": region_labels},\n    xreg_mode=\"xreg + timesfm\",  # or \"timesfm + xreg\"\n)\n```\n\n### Anomaly Detection (via Quantile Intervals)\n\n```python\npoint, q = model.forecast(horizon=H, inputs=[values])\n\nlower_80 = q[0, :, 1]  # 10th percentile (lower bound of 80% PI)\nupper_80 = q[0, :, 9]  # 90th percentile (upper bound of 80% PI)\n\nactual = test_values\nanomalies = (actual < lower_80) | (actual > upper_80)\n```\n\n| Severity | Condition | Interpretation |\n| -------- | --------- | -------------- |\n| **Normal** | Inside 60% PI (q20–q80) | Expected behavior |\n| **Warning** | Outside 60% PI, inside 80% PI | Unusual but possible |\n| **Critical** | Outside 80% PI (q10–q90) | Statistically rare (< 20% probability) |\n\n> See `examples/anomaly-detection/` for a complete worked example with visualization.\n\n## 📊 Understanding the Output\n\nTimesFM returns `(point_forecast, quantile_forecast)`:\n\n- **`point_forecast`**: shape `(batch, horizon)` — the median (0.5 quantile)\n- **`quantile_forecast`**: shape `(batch, horizon, 10)` — ten quantile slices:\n\n| Index | Quantile | Use |\n| ----- | -------- | --- |\n| 0 | Mean | Average prediction |\n| 1 | 0.1 | Lower bound of 80% PI |\n| 2 | 0.2 | Lower bound of 60% PI |\n| **5** | **0.5** | **Median (= `point_forecast`)** |\n| 8 | 0.8 | Upper bound of 60% PI |\n| 9 | 0.9 | Upper bound of 80% PI |\n\n```python\npoint, q = model.forecast(horizon=H, inputs=data)\n\nlower_80 = q[:, :, 1]  # 10th percentile\nupper_80 = q[:, :, 9]  # 90th percentile\nmedian   = q[:, :, 5]\n```\n\n## 🔧 ForecastConfig Reference\n\nAll forecasting behavior is controlled by `timesfm.ForecastConfig`:\n\n```python\ntimesfm.ForecastConfig(\n    max_context=1024,                    # Max context window\n    max_horizon=256,                     # Max forecast horizon\n    normalize_inputs=True,               
# RECOMMENDED — prevents scale instability\n    per_core_batch_size=32,              # Tune for memory\n    use_continuous_quantile_head=True,   # Better quantile accuracy for long horizons\n    force_flip_invariance=True,          # Ensures f(-x) = -f(x)\n    infer_is_positive=True,              # Clamp forecasts ≥ 0 when all inputs > 0\n    fix_quantile_crossing=True,          # Ensure q10 ≤ q20 ≤ ... ≤ q90\n    return_backcast=False,               # Return backcast (for covariate workflows)\n)\n```\n\n| Parameter | Default | When to Change |\n| --------- | ------- | -------------- |\n| `max_context` | 0 | Set to match your longest historical window |\n| `normalize_inputs` | False | **Always set True** |\n| `use_continuous_quantile_head` | False | **Set True** for calibrated PIs |\n| `infer_is_positive` | True | Set False for series that can be negative |\n| `fix_quantile_crossing` | False | **Set True** for monotonic quantiles |\n\nSee `references/api_reference.md` for the complete parameter reference.\n\n## 📋 Common Workflows\n\n### Single Series Forecast\n\n```python\nimport torch, numpy as np, pandas as pd, timesfm, matplotlib\nmatplotlib.use(\"Agg\")\nimport matplotlib.pyplot as plt\n\ntorch.set_float32_matmul_precision(\"high\")\nmodel = timesfm.TimesFM_2p5_200M_torch.from_pretrained(\n    \"google/timesfm-2.5-200m-pytorch\"\n)\nmodel.compile(timesfm.ForecastConfig(\n    max_context=512, max_horizon=52, normalize_inputs=True,\n    use_continuous_quantile_head=True, fix_quantile_crossing=True,\n))\n\ndf = pd.read_csv(\"weekly_demand.csv\", parse_dates=[\"week\"])\nvalues = df[\"demand\"].values.astype(np.float32)\n\npoint, quantiles = model.forecast(horizon=52, inputs=[values])\n\nfig, ax = plt.subplots(figsize=(12, 5))\nax.plot(values[-104:], label=\"Historical\")\nx_fc = range(len(values[-104:]), len(values[-104:]) + 52)\nax.plot(x_fc, point[0], label=\"Forecast\", color=\"tab:orange\")\nax.fill_between(x_fc, quantiles[0, :, 1], quantiles[0, :, 9],\n       
         alpha=0.2, color=\"tab:orange\", label=\"80% PI\")\nax.legend(); ax.set_title(\"52-Week Demand Forecast\")\nplt.tight_layout(); plt.savefig(\"forecast.png\", dpi=150)\n```\n\n### Batch Forecasting (Many Series)\n\n```python\ndf = pd.read_csv(\"all_stores.csv\", parse_dates=[\"date\"], index_col=\"date\")\ninputs = [df[col].dropna().values.astype(np.float32) for col in df.columns]\n\npoint, quantiles = model.forecast(horizon=30, inputs=inputs)\n\nimport json\nresults = {col: {\"forecast\": point[i].tolist(),\n                 \"lower_80\": quantiles[i, :, 1].tolist(),\n                 \"upper_80\": quantiles[i, :, 9].tolist()}\n           for i, col in enumerate(df.columns)}\nwith open(\"batch_forecasts.json\", \"w\") as f:\n    json.dump(results, f, indent=2)\n```\n\n### Evaluate Forecast Accuracy\n\n```python\nH = 24\ntrain, actual = values[:-H], values[-H:]\npoint, quantiles = model.forecast(horizon=H, inputs=[train])\npred = point[0]\n\nmae  = np.mean(np.abs(actual - pred))\nrmse = np.sqrt(np.mean((actual - pred) ** 2))\nmape = np.mean(np.abs((actual - pred) / actual)) * 100\ncoverage = np.mean((actual >= quantiles[0, :, 1]) & (actual <= quantiles[0, :, 9])) * 100\n\nprint(f\"MAE: {mae:.2f} | RMSE: {rmse:.2f} | MAPE: {mape:.1f}% | 80% PI Coverage: {coverage:.1f}%\")\n```\n\n## ⚙️ Performance Tuning\n\n```python\n# Always set on Ampere+ GPUs (A100, RTX 3090+)\ntorch.set_float32_matmul_precision(\"high\")\n\n# Batch size guidelines:\n# GPU 8 GB VRAM:  per_core_batch_size=64\n# GPU 16 GB VRAM: per_core_batch_size=128\n# CPU 8 GB RAM:   per_core_batch_size=8\n# CPU 16 GB RAM:  per_core_batch_size=32\n\n# Memory-constrained: process in chunks\nCHUNK = 50\nresults = []\nfor i in range(0, len(inputs), CHUNK):\n    p, q = model.forecast(horizon=H, inputs=inputs[i:i+CHUNK])\n    results.append((p, q))\n```\n\n## 📚 Available Scripts\n\n### `scripts/check_system.py`\n\nMandatory preflight checker — run before first model load.\nNow includes **dataset-aware memory 
estimation** to prevent OOM errors before loading your data.\n\n```bash\n# Basic system check\npython scripts/check_system.py\n\n# Check if your specific dataset will fit\npython scripts/check_system.py \\\n  --num-series 1000 \\\n  --context-length 1024 \\\n  --horizon 24 \\\n  --batch-size 32\n\n# Quick memory estimate without system checks\npython scripts/check_system.py \\\n  --num-series 5000 \\\n  --context-length 2048 \\\n  --estimate-only\n```\n\n**What it checks**:\n\n1. **Available RAM** — warns if below 4 GB, blocks if below 2 GB\n2. **GPU availability** — detects CUDA/MPS devices and VRAM\n3. **Disk space** — verifies room for the ~800 MB model download\n4. **Python version** — requires 3.10+\n5. **Existing installation** — checks if `timesfm` and `torch` are installed\n6. **Dataset fit** (NEW) — estimates memory for your specific dataset and warns if it won't fit\n\n### `scripts/forecast_csv.py`\n\nEnd-to-end CSV forecasting CLI.\n\n```bash\npython scripts/forecast_csv.py input.csv \\\n    --horizon 24 \\\n    --date-col date \\\n    --value-cols sales,revenue \\\n    --output forecasts.csv\n```\n\n## 📖 Reference Documentation\n\n| File | Contents |\n| ---- | -------- |\n| `references/system_requirements.md` | Hardware tiers, GPU/CPU selection, memory estimation |\n| `references/api_reference.md` | Full `ForecastConfig` docs, output shapes, model options |\n| `references/data_preparation.md` | Input formats, NaN handling, CSV loading, covariate setup |\n\n## 🧪 Examples\n\n| Example | Directory | What It Demonstrates |\n| ------- | --------- | -------------------- |\n| **Global Temperature Forecast** | `examples/global-temperature/` | Basic `model.forecast()`, CSV → PNG → GIF pipeline |\n| **Anomaly Detection** | `examples/anomaly-detection/` | Two-phase detrend + Z-score + quantile PI, 2-panel viz |\n| **Covariates (XReg)** | `examples/covariates-forecasting/` | `forecast_with_covariates()`, 2×2 shared-axis viz |\n\n```bash\n# Run all three 
examples:\ncd examples/global-temperature && python run_forecast.py && python visualize_forecast.py\ncd examples/anomaly-detection  && python detect_anomalies.py\ncd examples/covariates-forecasting && python demo_covariates.py\n```\n\n### Expected Outputs\n\n| Example | Key output files | Acceptance criteria |\n| ------- | ---------------- | ------------------- |\n| global-temperature | `output/forecast_output.json`, `output/forecast_visualization.png` | `point_forecast` has 12 values; PNG shows context + forecast + PI bands |\n| anomaly-detection | `output/anomaly_detection.json`, `output/anomaly_detection.png` | Sep 2023 flagged CRITICAL (z ≥ 3.0) |\n| covariates-forecasting | `output/sales_with_covariates.csv`, `output/covariates_data.png` | 108 rows (3 stores × 36 weeks); distinct price arrays per store |\n\n## Model Versions\n\n| Version | Params | Context | Status | HuggingFace checkpoint |\n| ------- | ------ | ------- | ------ | ---------------------- |\n| **2.5** | 200M | 16,384 | **Latest** | `google/timesfm-2.5-200m-pytorch` |\n| 2.0 | 500M | 2,048 | Archived | `google/timesfm-2.0-500m-pytorch` |\n| 1.0 | 200M | 2,048 | Archived | `google/timesfm-1.0-200m-pytorch` |\n\n- TimesFM 1.0/2.0: must pass `freq=[0]` for monthly data\n- TimesFM 2.5: no frequency flag — it was removed\n\n## Resources\n\n- **Paper**: [A Decoder-Only Foundation Model for Time-Series Forecasting](https://arxiv.org/abs/2310.10688) (ICML 2024)\n- **HuggingFace**: https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6\n- **Google Blog**: https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/\n- **BigQuery Integration**: https://cloud.google.com/bigquery/docs/timesfm-model\n\n## Quality Checklist\n\nRun after every TimesFM task before declaring success:\n\n- [ ] **Output shape** — `point_fc` is `(n_series, horizon)`, `quant_fc` is `(n_series, horizon, 10)`\n- [ ] **Quantile indices** — index 0 = mean, 1 = q10 ... 9 = q90. 
NOT 0 = q0.\n- [ ] **Frequency flag** — TimesFM 1.0/2.0: pass `freq=[0]` for monthly. TimesFM 2.5: omit.\n- [ ] **Series length** — context must be ≥ 32 data points.\n- [ ] **No NaN** — `np.isnan(point_fc).any()` must be False.\n- [ ] **Axes** — multiple panels sharing data must use `sharex=True`.\n- [ ] **`matplotlib.use('Agg')`** — before any pyplot import when running headless.\n- [ ] **`infer_is_positive`** — set False for temperature, financial returns, negatives.\n\n## Common Mistakes\n\n1. **Quantile index off-by-one** — `quant_fc[..., 0]` is the **mean**, not q0. q10 = index 1, q90 = index 9. Define: `IDX_Q10, IDX_Q90 = 1, 9`.\n\n2. **Variable shadowing in covariate loops** — don't use the outer loop variable as a comprehension variable when building per-series covariate dicts.\n\n3. **Wrong CSV column name** — global-temperature CSV uses `anomaly_c`, not `anomaly`. Print `df.columns` first.\n\n4. **TimesFM 2.5 required for `forecast_with_covariates()`** — TimesFM 1.0 does NOT have this method.\n\n5. **Future covariates must span the full horizon** — dynamic covariates need values for BOTH context AND forecast windows.\n\n6. **Context anomaly detection uses residuals** — detrend first, then Z-score. Raw Z-scores mislead on trending data.\n\n## Validation & Verification\n\n```bash\n# Anomaly detection regression:\npython -c \"\nimport json\nd = json.load(open('examples/anomaly-detection/output/anomaly_detection.json'))\nassert d['context_summary']['critical'] >= 1, 'Sep 2023 must be CRITICAL'\nprint('Anomaly detection: PASS')\"\n\n# Covariates regression:\npython -c \"\nimport pandas as pd\ndf = pd.read_csv('examples/covariates-forecasting/output/sales_with_covariates.csv')\nassert len(df) == 108, f'Expected 108 rows, got {len(df)}'\nprint('Covariates: PASS')\"\n```\n"
  },
  {
    "path": "timesfm-forecasting/examples/anomaly-detection/detect_anomalies.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nTimesFM Anomaly Detection Example — Two-Phase Method\n\nPhase 1 (context): Linear detrend + Z-score on 36 months of real NOAA\n  temperature anomaly data (2022-01 through 2024-12).\n  Sep 2023 (1.47 C) is a known critical outlier.\n\nPhase 2 (forecast): TimesFM quantile prediction intervals on a 12-month\n  synthetic future with 3 injected anomalies.\n\nOutputs:\n  output/anomaly_detection.png  -- 2-panel visualization\n  output/anomaly_detection.json -- structured detection records\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\nimport matplotlib\n\nmatplotlib.use(\"Agg\")\nimport matplotlib.patches as mpatches\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nHORIZON = 12\nDATA_FILE = (\n    Path(__file__).parent.parent / \"global-temperature\" / \"temperature_anomaly.csv\"\n)\nOUTPUT_DIR = Path(__file__).parent / \"output\"\n\nCRITICAL_Z = 3.0\nWARNING_Z = 2.0\n\n# quant_fc index mapping: 0=mean, 1=q10, 2=q20, ..., 9=q90\nIDX_Q10, IDX_Q20, IDX_Q80, IDX_Q90 = 1, 2, 8, 9\n\nCLR = {\"CRITICAL\": \"#e02020\", \"WARNING\": \"#f08030\", \"NORMAL\": \"#4a90d9\"}\n\n\n# ---------------------------------------------------------------------------\n# Phase 1: context anomaly detection\n# ---------------------------------------------------------------------------\n\n\ndef detect_context_anomalies(\n    values: np.ndarray,\n    dates: list,\n) -> tuple[list[dict], np.ndarray, np.ndarray, float]:\n    \"\"\"Linear detrend + Z-score anomaly detection on context period.\n\n    Returns\n    -------\n    records    : list of dicts, one per month\n    trend_line : fitted linear trend values (same length as values)\n    residuals  : actual - trend_line\n    res_std    : std of residuals (used as sigma for threshold bands)\n    \"\"\"\n    n = len(values)\n    idx = np.arange(n, dtype=float)\n\n    coeffs = np.polyfit(idx, values, 1)\n    trend_line = 
np.polyval(coeffs, idx)\n    residuals = values - trend_line\n    res_std = residuals.std()\n\n    records = []\n    for i, (d, v, r) in enumerate(zip(dates, values, residuals)):\n        z = r / res_std if res_std > 0 else 0.0\n        if abs(z) >= CRITICAL_Z:\n            severity = \"CRITICAL\"\n        elif abs(z) >= WARNING_Z:\n            severity = \"WARNING\"\n        else:\n            severity = \"NORMAL\"\n        records.append(\n            {\n                \"date\": str(d)[:7],\n                \"value\": round(float(v), 4),\n                \"trend\": round(float(trend_line[i]), 4),\n                \"residual\": round(float(r), 4),\n                \"z_score\": round(float(z), 3),\n                \"severity\": severity,\n            }\n        )\n    return records, trend_line, residuals, res_std\n\n\n# ---------------------------------------------------------------------------\n# Phase 2: synthetic future + forecast anomaly detection\n# ---------------------------------------------------------------------------\n\n\ndef build_synthetic_future(\n    context: np.ndarray,\n    n: int,\n    seed: int = 42,\n) -> tuple[np.ndarray, list[int]]:\n    \"\"\"Build a plausible future with 3 injected anomalies.\n\n    Injected months: 3, 8, 11 (0-indexed within the 12-month horizon).\n    Returns (future_values, injected_indices).\n    \"\"\"\n    rng = np.random.default_rng(seed)\n    trend = np.linspace(context[-6:].mean(), context[-6:].mean() + 0.05, n)\n    noise = rng.normal(0, 0.1, n)\n    future = trend + noise\n\n    injected = [3, 8, 11]\n    future[3] += 0.7  # CRITICAL spike\n    future[8] -= 0.65  # CRITICAL dip\n    future[11] += 0.45  # WARNING spike\n\n    return future.astype(np.float32), injected\n\n\ndef detect_forecast_anomalies(\n    future_values: np.ndarray,\n    point: np.ndarray,\n    quant_fc: np.ndarray,\n    future_dates: list,\n    injected_at: list[int],\n) -> list[dict]:\n    \"\"\"Classify each forecast month by which PI band 
it falls outside.\n\n    CRITICAL = outside 80% PI (q10-q90)\n    WARNING  = outside 60% PI (q20-q80) but inside 80% PI\n    NORMAL   = inside 60% PI\n    \"\"\"\n    q10 = quant_fc[IDX_Q10]\n    q20 = quant_fc[IDX_Q20]\n    q80 = quant_fc[IDX_Q80]\n    q90 = quant_fc[IDX_Q90]\n\n    records = []\n    for i, (d, fv, pt) in enumerate(zip(future_dates, future_values, point)):\n        outside_80 = fv < q10[i] or fv > q90[i]\n        outside_60 = fv < q20[i] or fv > q80[i]\n\n        if outside_80:\n            severity = \"CRITICAL\"\n        elif outside_60:\n            severity = \"WARNING\"\n        else:\n            severity = \"NORMAL\"\n\n        records.append(\n            {\n                \"date\": str(d)[:7],\n                \"actual\": round(float(fv), 4),\n                \"forecast\": round(float(pt), 4),\n                \"q10\": round(float(q10[i]), 4),\n                \"q20\": round(float(q20[i]), 4),\n                \"q80\": round(float(q80[i]), 4),\n                \"q90\": round(float(q90[i]), 4),\n                \"severity\": severity,\n                \"was_injected\": i in injected_at,\n            }\n        )\n    return records\n\n\n# ---------------------------------------------------------------------------\n# Visualization\n# ---------------------------------------------------------------------------\n\n\ndef plot_results(\n    context_dates: list,\n    context_values: np.ndarray,\n    ctx_records: list[dict],\n    trend_line: np.ndarray,\n    residuals: np.ndarray,\n    res_std: float,\n    future_dates: list,\n    future_values: np.ndarray,\n    point_fc: np.ndarray,\n    quant_fc: np.ndarray,\n    fc_records: list[dict],\n) -> None:\n    OUTPUT_DIR.mkdir(exist_ok=True)\n\n    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10), gridspec_kw={\"hspace\": 0.42})\n    fig.suptitle(\n        \"TimesFM Anomaly Detection — Two-Phase Method\", fontsize=14, fontweight=\"bold\"\n    )\n\n    # 
-----------------------------------------------------------------------\n    # Panel 1 — full timeline\n    # -----------------------------------------------------------------------\n    ctx_x = [pd.Timestamp(d) for d in context_dates]\n    fut_x = [pd.Timestamp(d) for d in future_dates]\n    divider = ctx_x[-1]\n\n    # context: blue line + trend + 2sigma band\n    ax1.plot(\n        ctx_x,\n        context_values,\n        color=CLR[\"NORMAL\"],\n        lw=2,\n        marker=\"o\",\n        ms=4,\n        label=\"Observed (context)\",\n    )\n    ax1.plot(ctx_x, trend_line, color=\"#aaaaaa\", lw=1.5, ls=\"--\", label=\"Linear trend\")\n    ax1.fill_between(\n        ctx_x,\n        trend_line - 2 * res_std,\n        trend_line + 2 * res_std,\n        alpha=0.15,\n        color=CLR[\"NORMAL\"],\n        label=\"+/-2sigma band\",\n    )\n\n    # context anomaly markers\n    seen_ctx: set[str] = set()\n    for rec in ctx_records:\n        if rec[\"severity\"] == \"NORMAL\":\n            continue\n        d = pd.Timestamp(rec[\"date\"])\n        v = rec[\"value\"]\n        sev = rec[\"severity\"]\n        lbl = f\"Context {sev}\" if sev not in seen_ctx else None\n        seen_ctx.add(sev)\n        ax1.scatter(d, v, marker=\"D\", s=90, color=CLR[sev], zorder=6, label=lbl)\n        ax1.annotate(\n            f\"z={rec['z_score']:+.1f}\",\n            (d, v),\n            textcoords=\"offset points\",\n            xytext=(0, 9),\n            fontsize=7.5,\n            ha=\"center\",\n            color=CLR[sev],\n        )\n\n    # forecast section\n    q10 = quant_fc[IDX_Q10]\n    q20 = quant_fc[IDX_Q20]\n    q80 = quant_fc[IDX_Q80]\n    q90 = quant_fc[IDX_Q90]\n\n    ax1.plot(fut_x, future_values, \"k--\", lw=1.5, label=\"Synthetic future (truth)\")\n    ax1.plot(\n        fut_x,\n        point_fc,\n        color=CLR[\"CRITICAL\"],\n        lw=2,\n        marker=\"s\",\n        ms=4,\n        label=\"TimesFM point forecast\",\n    )\n    ax1.fill_between(fut_x, q10, 
q90, alpha=0.15, color=CLR[\"CRITICAL\"], label=\"80% PI\")\n    ax1.fill_between(fut_x, q20, q80, alpha=0.25, color=CLR[\"CRITICAL\"], label=\"60% PI\")\n\n    seen_fc: set[str] = set()\n    for i, rec in enumerate(fc_records):\n        if rec[\"severity\"] == \"NORMAL\":\n            continue\n        d = pd.Timestamp(rec[\"date\"])\n        v = rec[\"actual\"]\n        sev = rec[\"severity\"]\n        mk = \"X\" if sev == \"CRITICAL\" else \"^\"\n        lbl = f\"Forecast {sev}\" if sev not in seen_fc else None\n        seen_fc.add(sev)\n        ax1.scatter(d, v, marker=mk, s=100, color=CLR[sev], zorder=6, label=lbl)\n\n    ax1.axvline(divider, color=\"#555555\", lw=1.5, ls=\":\")\n    ax1.text(\n        divider,\n        ax1.get_ylim()[1] if ax1.get_ylim()[1] != 0 else 1.5,\n        \"  <- Context | Forecast ->\",\n        fontsize=8.5,\n        color=\"#555555\",\n        style=\"italic\",\n        va=\"top\",\n    )\n\n    ax1.annotate(\n        \"Context: D = Z-score anomaly | Forecast: X = CRITICAL, ^ = WARNING\",\n        xy=(0.01, 0.04),\n        xycoords=\"axes fraction\",\n        fontsize=8,\n        bbox=dict(boxstyle=\"round\", fc=\"white\", ec=\"#cccccc\", alpha=0.9),\n    )\n\n    ax1.set_ylabel(\"Temperature Anomaly (C)\", fontsize=10)\n    ax1.legend(ncol=2, fontsize=7.5, loc=\"upper left\")\n    ax1.grid(True, alpha=0.22)\n\n    # -----------------------------------------------------------------------\n    # Panel 2 — deviation bars across all 48 months\n    # -----------------------------------------------------------------------\n    all_labels: list[str] = []\n    bar_colors: list[str] = []\n    bar_heights: list[float] = []\n\n    for rec in ctx_records:\n        all_labels.append(rec[\"date\"])\n        bar_heights.append(rec[\"residual\"])\n        bar_colors.append(CLR[rec[\"severity\"]])\n\n    fc_deviations: list[float] = []\n    for rec in fc_records:\n        all_labels.append(rec[\"date\"])\n        dev = rec[\"actual\"] - 
rec[\"forecast\"]\n        fc_deviations.append(dev)\n        bar_heights.append(dev)\n        bar_colors.append(CLR[rec[\"severity\"]])\n\n    xs = np.arange(len(all_labels))\n    ax2.bar(xs[:36], bar_heights[:36], color=bar_colors[:36], alpha=0.8)\n    ax2.bar(xs[36:], bar_heights[36:], color=bar_colors[36:], alpha=0.8)\n\n    # threshold lines for context section only\n    ax2.hlines(\n        [2 * res_std, -2 * res_std], -0.5, 35.5, colors=CLR[\"NORMAL\"], lw=1.2, ls=\"--\"\n    )\n    ax2.hlines(\n        [3 * res_std, -3 * res_std], -0.5, 35.5, colors=CLR[\"NORMAL\"], lw=1.0, ls=\":\"\n    )\n\n    # PI bands for forecast section\n    fc_xs = xs[36:]\n    ax2.fill_between(\n        fc_xs,\n        q10 - point_fc,\n        q90 - point_fc,\n        alpha=0.12,\n        color=CLR[\"CRITICAL\"],\n        step=\"mid\",\n    )\n    ax2.fill_between(\n        fc_xs,\n        q20 - point_fc,\n        q80 - point_fc,\n        alpha=0.20,\n        color=CLR[\"CRITICAL\"],\n        step=\"mid\",\n    )\n\n    ax2.axvline(35.5, color=\"#555555\", lw=1.5, ls=\"--\")\n    ax2.axhline(0, color=\"black\", lw=0.8, alpha=0.6)\n\n    ax2.text(\n        10,\n        ax2.get_ylim()[0] * 0.85 if ax2.get_ylim()[0] < 0 else -0.05,\n        \"<- Context: delta from linear trend\",\n        fontsize=8,\n        style=\"italic\",\n        color=\"#555555\",\n        ha=\"center\",\n    )\n    ax2.text(\n        41,\n        ax2.get_ylim()[0] * 0.85 if ax2.get_ylim()[0] < 0 else -0.05,\n        \"Forecast: delta from TimesFM ->\",\n        fontsize=8,\n        style=\"italic\",\n        color=\"#555555\",\n        ha=\"center\",\n    )\n\n    tick_every = 3\n    ax2.set_xticks(xs[::tick_every])\n    ax2.set_xticklabels(all_labels[::tick_every], rotation=45, ha=\"right\", fontsize=7)\n    ax2.set_ylabel(\"Delta from expected (C)\", fontsize=10)\n    ax2.grid(True, alpha=0.22, axis=\"y\")\n\n    legend_patches = [\n        mpatches.Patch(color=CLR[\"CRITICAL\"], label=\"CRITICAL\"),\n     
   mpatches.Patch(color=CLR[\"WARNING\"], label=\"WARNING\"),\n        mpatches.Patch(color=CLR[\"NORMAL\"], label=\"Normal\"),\n    ]\n    ax2.legend(handles=legend_patches, fontsize=8, loc=\"upper right\")\n\n    output_path = OUTPUT_DIR / \"anomaly_detection.png\"\n    plt.savefig(output_path, dpi=150, bbox_inches=\"tight\")\n    plt.close()\n    print(f\"\\n  Saved: {output_path}\")\n\n\n# ---------------------------------------------------------------------------\n# Main\n# ---------------------------------------------------------------------------\n\n\ndef main() -> None:\n    print(\"=\" * 68)\n    print(\"  TIMESFM ANOMALY DETECTION — TWO-PHASE METHOD\")\n    print(\"=\" * 68)\n\n    # --- Load context data ---------------------------------------------------\n    df = pd.read_csv(DATA_FILE)\n    df[\"date\"] = pd.to_datetime(df[\"date\"])\n    df = df.sort_values(\"date\").reset_index(drop=True)\n\n    context_values = df[\"anomaly_c\"].values.astype(np.float32)\n    context_dates = [pd.Timestamp(d) for d in df[\"date\"].tolist()]\n    start_str = context_dates[0].strftime('%Y-%m') if not pd.isnull(context_dates[0]) else '?'\n    end_str   = context_dates[-1].strftime('%Y-%m') if not pd.isnull(context_dates[-1]) else '?'\n    print(f\"\\n  Context: {len(context_values)} months  ({start_str} - {end_str})\")\n\n    # --- Phase 1: context anomaly detection ----------------------------------\n    ctx_records, trend_line, residuals, res_std = detect_context_anomalies(\n        context_values, context_dates\n    )\n    ctx_critical = [r for r in ctx_records if r[\"severity\"] == \"CRITICAL\"]\n    ctx_warning = [r for r in ctx_records if r[\"severity\"] == \"WARNING\"]\n    print(f\"\\n  [Phase 1] Context anomalies (Z-score, sigma={res_std:.3f} C):\")\n    print(f\"    CRITICAL (|Z|>={CRITICAL_Z}): {len(ctx_critical)}\")\n    for r in ctx_critical:\n        print(f\"      {r['date']}  {r['value']:+.3f} C  z={r['z_score']:+.2f}\")\n    print(f\"    WARNING  
(|Z|>={WARNING_Z}): {len(ctx_warning)}\")\n    for r in ctx_warning:\n        print(f\"      {r['date']}  {r['value']:+.3f} C  z={r['z_score']:+.2f}\")\n\n    # --- Load TimesFM --------------------------------------------------------\n    print(\"\\n  Loading TimesFM 1.0 ...\")\n    import timesfm\n\n    hparams = timesfm.TimesFmHparams(horizon_len=HORIZON)\n    checkpoint = timesfm.TimesFmCheckpoint(\n        huggingface_repo_id=\"google/timesfm-1.0-200m-pytorch\"\n    )\n    model = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)\n\n    point_out, quant_out = model.forecast([context_values], freq=[0])\n    point_fc = point_out[0]  # shape (HORIZON,)\n    quant_fc = quant_out[0].T  # shape (10, HORIZON)\n\n    # --- Build synthetic future + Phase 2 detection --------------------------\n    future_values, injected = build_synthetic_future(context_values, HORIZON)\n    last_date = context_dates[-1]\n    future_dates = [last_date + pd.DateOffset(months=i + 1) for i in range(HORIZON)]\n\n    fc_records = detect_forecast_anomalies(\n        future_values, point_fc, quant_fc, future_dates, injected\n    )\n    fc_critical = [r for r in fc_records if r[\"severity\"] == \"CRITICAL\"]\n    fc_warning = [r for r in fc_records if r[\"severity\"] == \"WARNING\"]\n\n    print(f\"\\n  [Phase 2] Forecast anomalies (quantile PI, horizon={HORIZON} months):\")\n    print(f\"    CRITICAL (outside 80% PI): {len(fc_critical)}\")\n    for r in fc_critical:\n        print(\n            f\"      {r['date']}  actual={r['actual']:+.3f}  \"\n            f\"fc={r['forecast']:+.3f}  injected={r['was_injected']}\"\n        )\n    print(f\"    WARNING  (outside 60% PI): {len(fc_warning)}\")\n    for r in fc_warning:\n        print(\n            f\"      {r['date']}  actual={r['actual']:+.3f}  \"\n            f\"fc={r['forecast']:+.3f}  injected={r['was_injected']}\"\n        )\n\n    # --- Plot ----------------------------------------------------------------\n    print(\"\\n  
Generating 2-panel visualization...\")\n    plot_results(\n        context_dates,\n        context_values,\n        ctx_records,\n        trend_line,\n        residuals,\n        res_std,\n        future_dates,\n        future_values,\n        point_fc,\n        quant_fc,\n        fc_records,\n    )\n\n    # --- Save JSON -----------------------------------------------------------\n    OUTPUT_DIR.mkdir(exist_ok=True)\n    out = {\n        \"method\": \"two_phase\",\n        \"context_method\": \"linear_detrend_zscore\",\n        \"forecast_method\": \"quantile_prediction_intervals\",\n        \"thresholds\": {\n            \"critical_z\": CRITICAL_Z,\n            \"warning_z\": WARNING_Z,\n            \"pi_critical_pct\": 80,\n            \"pi_warning_pct\": 60,\n        },\n        \"context_summary\": {\n            \"total\": len(ctx_records),\n            \"critical\": len(ctx_critical),\n            \"warning\": len(ctx_warning),\n            \"normal\": len([r for r in ctx_records if r[\"severity\"] == \"NORMAL\"]),\n            \"res_std\": round(float(res_std), 5),\n        },\n        \"forecast_summary\": {\n            \"total\": len(fc_records),\n            \"critical\": len(fc_critical),\n            \"warning\": len(fc_warning),\n            \"normal\": len([r for r in fc_records if r[\"severity\"] == \"NORMAL\"]),\n        },\n        \"context_detections\": ctx_records,\n        \"forecast_detections\": fc_records,\n    }\n    json_path = OUTPUT_DIR / \"anomaly_detection.json\"\n    with open(json_path, \"w\") as f:\n        json.dump(out, f, indent=2)\n    print(f\"  Saved: {json_path}\")\n\n    print(\"\\n\" + \"=\" * 68)\n    print(\"  SUMMARY\")\n    print(\"=\" * 68)\n    print(\n        f\"  Context  ({len(ctx_records)} months): \"\n        f\"{len(ctx_critical)} CRITICAL, {len(ctx_warning)} WARNING\"\n    )\n    print(\n        f\"  Forecast ({len(fc_records)} months): \"\n        f\"{len(fc_critical)} CRITICAL, {len(fc_warning)} WARNING\"\n 
   )\n    print(\"=\" * 68)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "timesfm-forecasting/examples/anomaly-detection/output/anomaly_detection.json",
    "content": "{\n  \"method\": \"two_phase\",\n  \"context_method\": \"linear_detrend_zscore\",\n  \"forecast_method\": \"quantile_prediction_intervals\",\n  \"thresholds\": {\n    \"critical_z\": 3.0,\n    \"warning_z\": 2.0,\n    \"pi_critical_pct\": 80,\n    \"pi_warning_pct\": 60\n  },\n  \"context_summary\": {\n    \"total\": 36,\n    \"critical\": 1,\n    \"warning\": 0,\n    \"normal\": 35,\n    \"res_std\": 0.11362\n  },\n  \"forecast_summary\": {\n    \"total\": 12,\n    \"critical\": 4,\n    \"warning\": 1,\n    \"normal\": 7\n  },\n  \"context_detections\": [\n    {\n      \"date\": \"2022-01\",\n      \"value\": 0.89,\n      \"trend\": 0.837,\n      \"residual\": 0.053,\n      \"z_score\": 0.467,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2022-02\",\n      \"value\": 0.89,\n      \"trend\": 0.8514,\n      \"residual\": 0.0386,\n      \"z_score\": 0.34,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2022-03\",\n      \"value\": 1.02,\n      \"trend\": 0.8658,\n      \"residual\": 0.1542,\n      \"z_score\": 1.357,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2022-04\",\n      \"value\": 0.88,\n      \"trend\": 0.8803,\n      \"residual\": -0.0003,\n      \"z_score\": -0.002,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2022-05\",\n      \"value\": 0.85,\n      \"trend\": 0.8947,\n      \"residual\": -0.0447,\n      \"z_score\": -0.394,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2022-06\",\n      \"value\": 0.88,\n      \"trend\": 0.9092,\n      \"residual\": -0.0292,\n      \"z_score\": -0.257,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2022-07\",\n      \"value\": 0.88,\n      \"trend\": 0.9236,\n      \"residual\": -0.0436,\n      \"z_score\": -0.384,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2022-08\",\n      \"value\": 0.9,\n      \"trend\": 0.9381,\n      \"residual\": -0.0381,\n      
\"z_score\": -0.335,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2022-09\",\n      \"value\": 0.88,\n      \"trend\": 0.9525,\n      \"residual\": -0.0725,\n      \"z_score\": -0.638,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2022-10\",\n      \"value\": 0.95,\n      \"trend\": 0.9669,\n      \"residual\": -0.0169,\n      \"z_score\": -0.149,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2022-11\",\n      \"value\": 0.77,\n      \"trend\": 0.9814,\n      \"residual\": -0.2114,\n      \"z_score\": -1.86,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2022-12\",\n      \"value\": 0.78,\n      \"trend\": 0.9958,\n      \"residual\": -0.2158,\n      \"z_score\": -1.9,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2023-01\",\n      \"value\": 0.87,\n      \"trend\": 1.0103,\n      \"residual\": -0.1403,\n      \"z_score\": -1.235,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2023-02\",\n      \"value\": 0.98,\n      \"trend\": 1.0247,\n      \"residual\": -0.0447,\n      \"z_score\": -0.394,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2023-03\",\n      \"value\": 1.21,\n      \"trend\": 1.0392,\n      \"residual\": 0.1708,\n      \"z_score\": 1.503,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2023-04\",\n      \"value\": 1.0,\n      \"trend\": 1.0536,\n      \"residual\": -0.0536,\n      \"z_score\": -0.472,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2023-05\",\n      \"value\": 0.94,\n      \"trend\": 1.0681,\n      \"residual\": -0.1281,\n      \"z_score\": -1.127,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2023-06\",\n      \"value\": 1.08,\n      \"trend\": 1.0825,\n      \"residual\": -0.0025,\n      \"z_score\": -0.022,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2023-07\",\n      \"value\": 1.18,\n      \"trend\": 1.0969,\n      
\"residual\": 0.0831,\n      \"z_score\": 0.731,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2023-08\",\n      \"value\": 1.24,\n      \"trend\": 1.1114,\n      \"residual\": 0.1286,\n      \"z_score\": 1.132,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2023-09\",\n      \"value\": 1.47,\n      \"trend\": 1.1258,\n      \"residual\": 0.3442,\n      \"z_score\": 3.029,\n      \"severity\": \"CRITICAL\"\n    },\n    {\n      \"date\": \"2023-10\",\n      \"value\": 1.32,\n      \"trend\": 1.1403,\n      \"residual\": 0.1797,\n      \"z_score\": 1.582,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2023-11\",\n      \"value\": 1.18,\n      \"trend\": 1.1547,\n      \"residual\": 0.0253,\n      \"z_score\": 0.222,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2023-12\",\n      \"value\": 1.16,\n      \"trend\": 1.1692,\n      \"residual\": -0.0092,\n      \"z_score\": -0.081,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2024-01\",\n      \"value\": 1.22,\n      \"trend\": 1.1836,\n      \"residual\": 0.0364,\n      \"z_score\": 0.32,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2024-02\",\n      \"value\": 1.35,\n      \"trend\": 1.1981,\n      \"residual\": 0.1519,\n      \"z_score\": 1.337,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2024-03\",\n      \"value\": 1.34,\n      \"trend\": 1.2125,\n      \"residual\": 0.1275,\n      \"z_score\": 1.122,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2024-04\",\n      \"value\": 1.26,\n      \"trend\": 1.2269,\n      \"residual\": 0.0331,\n      \"z_score\": 0.291,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2024-05\",\n      \"value\": 1.15,\n      \"trend\": 1.2414,\n      \"residual\": -0.0914,\n      \"z_score\": -0.804,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2024-06\",\n      \"value\": 1.2,\n      
\"trend\": 1.2558,\n      \"residual\": -0.0558,\n      \"z_score\": -0.491,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2024-07\",\n      \"value\": 1.24,\n      \"trend\": 1.2703,\n      \"residual\": -0.0303,\n      \"z_score\": -0.266,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2024-08\",\n      \"value\": 1.3,\n      \"trend\": 1.2847,\n      \"residual\": 0.0153,\n      \"z_score\": 0.135,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2024-09\",\n      \"value\": 1.28,\n      \"trend\": 1.2992,\n      \"residual\": -0.0192,\n      \"z_score\": -0.169,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2024-10\",\n      \"value\": 1.27,\n      \"trend\": 1.3136,\n      \"residual\": -0.0436,\n      \"z_score\": -0.384,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2024-11\",\n      \"value\": 1.22,\n      \"trend\": 1.328,\n      \"residual\": -0.108,\n      \"z_score\": -0.951,\n      \"severity\": \"NORMAL\"\n    },\n    {\n      \"date\": \"2024-12\",\n      \"value\": 1.2,\n      \"trend\": 1.3425,\n      \"residual\": -0.1425,\n      \"z_score\": -1.254,\n      \"severity\": \"NORMAL\"\n    }\n  ],\n  \"forecast_detections\": [\n    {\n      \"date\": \"2025-01\",\n      \"actual\": 1.2821,\n      \"forecast\": 1.2593,\n      \"q10\": 1.1407,\n      \"q20\": 1.1881,\n      \"q80\": 1.324,\n      \"q90\": 1.3679,\n      \"severity\": \"NORMAL\",\n      \"was_injected\": false\n    },\n    {\n      \"date\": \"2025-02\",\n      \"actual\": 1.1522,\n      \"forecast\": 1.2857,\n      \"q10\": 1.1406,\n      \"q20\": 1.1961,\n      \"q80\": 1.3751,\n      \"q90\": 1.4254,\n      \"severity\": \"WARNING\",\n      \"was_injected\": false\n    },\n    {\n      \"date\": \"2025-03\",\n      \"actual\": 1.3358,\n      \"forecast\": 1.295,\n      \"q10\": 1.1269,\n      \"q20\": 1.1876,\n      \"q80\": 1.4035,\n      \"q90\": 1.4643,\n      \"severity\": 
\"NORMAL\",\n      \"was_injected\": false\n    },\n    {\n      \"date\": \"2025-04\",\n      \"actual\": 2.0594,\n      \"forecast\": 1.2208,\n      \"q10\": 1.0353,\n      \"q20\": 1.1042,\n      \"q80\": 1.331,\n      \"q90\": 1.4017,\n      \"severity\": \"CRITICAL\",\n      \"was_injected\": true\n    },\n    {\n      \"date\": \"2025-05\",\n      \"actual\": 1.0747,\n      \"forecast\": 1.1703,\n      \"q10\": 0.9691,\n      \"q20\": 1.0431,\n      \"q80\": 1.2892,\n      \"q90\": 1.3632,\n      \"severity\": \"NORMAL\",\n      \"was_injected\": false\n    },\n    {\n      \"date\": \"2025-06\",\n      \"actual\": 1.1442,\n      \"forecast\": 1.1456,\n      \"q10\": 0.942,\n      \"q20\": 1.0111,\n      \"q80\": 1.2703,\n      \"q90\": 1.3454,\n      \"severity\": \"NORMAL\",\n      \"was_injected\": false\n    },\n    {\n      \"date\": \"2025-07\",\n      \"actual\": 1.2917,\n      \"forecast\": 1.1702,\n      \"q10\": 0.9504,\n      \"q20\": 1.0348,\n      \"q80\": 1.2998,\n      \"q90\": 1.3807,\n      \"severity\": \"NORMAL\",\n      \"was_injected\": false\n    },\n    {\n      \"date\": \"2025-08\",\n      \"actual\": 1.2519,\n      \"forecast\": 1.2027,\n      \"q10\": 0.9709,\n      \"q20\": 1.0594,\n      \"q80\": 1.3408,\n      \"q90\": 1.4195,\n      \"severity\": \"NORMAL\",\n      \"was_injected\": false\n    },\n    {\n      \"date\": \"2025-09\",\n      \"actual\": 0.6364,\n      \"forecast\": 1.191,\n      \"q10\": 0.9594,\n      \"q20\": 1.0404,\n      \"q80\": 1.3355,\n      \"q90\": 1.417,\n      \"severity\": \"CRITICAL\",\n      \"was_injected\": true\n    },\n    {\n      \"date\": \"2025-10\",\n      \"actual\": 1.2073,\n      \"forecast\": 1.1491,\n      \"q10\": 0.9079,\n      \"q20\": 0.9953,\n      \"q80\": 1.2869,\n      \"q90\": 1.3775,\n      \"severity\": \"NORMAL\",\n      \"was_injected\": false\n    },\n    {\n      \"date\": \"2025-11\",\n      \"actual\": 1.3851,\n      \"forecast\": 1.0805,\n      \"q10\": 0.8361,\n      
\"q20\": 0.926,\n      \"q80\": 1.2284,\n      \"q90\": 1.3122,\n      \"severity\": \"CRITICAL\",\n      \"was_injected\": false\n    },\n    {\n      \"date\": \"2025-12\",\n      \"actual\": 1.8294,\n      \"forecast\": 1.0613,\n      \"q10\": 0.8022,\n      \"q20\": 0.8952,\n      \"q80\": 1.2169,\n      \"q90\": 1.296,\n      \"severity\": \"CRITICAL\",\n      \"was_injected\": true\n    }\n  ]\n}"
  },
  {
    "path": "timesfm-forecasting/examples/covariates-forecasting/demo_covariates.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nTimesFM Covariates (XReg) Example\n\nDemonstrates the TimesFM covariate API using synthetic retail sales data.\nTimesFM 1.0 does NOT support forecast_with_covariates(); that requires\nTimesFM 2.5 + `pip install timesfm[xreg]`.\n\nThis script:\n  1. Generates synthetic 3-store weekly retail data (24-week context, 12-week horizon)\n  2. Produces a 2x2 visualization showing WHAT each covariate contributes\n     and WHY knowing them improves forecasts -- all panels share the same\n     week x-axis (0 = first context week, 35 = last horizon week)\n  3. Exports a compact CSV (108 rows) and metadata JSON\n\nNOTE ON REAL DATA:\n  If you want to use a real retail dataset (e.g., Kaggle Rossmann Store Sales),\n  download it to a TEMP location -- do NOT commit large CSVs to this repo.\n\n      import tempfile, urllib.request\n      tmp = tempfile.mkdtemp(prefix=\"timesfm_retail_\")\n      # urllib.request.urlretrieve(\"https://...store_sales.csv\", f\"{tmp}/store_sales.csv\")\n      # df = pd.read_csv(f\"{tmp}/store_sales.csv\")\n\n  This skills directory intentionally keeps only tiny reference datasets.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\nimport matplotlib\n\nmatplotlib.use(\"Agg\")\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nEXAMPLE_DIR = Path(__file__).parent\nOUTPUT_DIR = EXAMPLE_DIR / \"output\"\n\nN_STORES = 3\nCONTEXT_LEN = 24\nHORIZON_LEN = 12\nTOTAL_LEN = CONTEXT_LEN + HORIZON_LEN  # 36\n\n\ndef generate_sales_data() -> dict:\n    \"\"\"Generate synthetic retail sales data with covariate components stored separately.\n\n    Returns a dict with:\n      stores:     {store_id: {sales, config}}\n      covariates: {price, promotion, holiday, day_of_week, store_type, region}\n      components: {store_id: {base, price_effect, promo_effect, holiday_effect}}\n\n    Components let us show 'what would sales look like without covariates?' 
--\n    the gap between 'base' and 'sales' IS the covariate signal.\n\n    BUG FIX v3: Previous versions had variable-shadowing where inner dict\n    comprehension `{store_id: ... for store_id in stores}` overwrote the outer\n    loop variable causing all stores to get identical covariate arrays.\n    Fixed by accumulating per-store arrays separately before building covariate dict.\n    \"\"\"\n    rng = np.random.default_rng(42)\n\n    stores = {\n        \"store_A\": {\"type\": \"premium\", \"region\": \"urban\", \"base_sales\": 1000},\n        \"store_B\": {\"type\": \"standard\", \"region\": \"suburban\", \"base_sales\": 750},\n        \"store_C\": {\"type\": \"discount\", \"region\": \"rural\", \"base_sales\": 500},\n    }\n    base_prices = {\"store_A\": 12.0, \"store_B\": 10.0, \"store_C\": 7.5}\n\n    data: dict = {\"stores\": {}, \"covariates\": {}, \"components\": {}}\n\n    prices_by_store: dict[str, np.ndarray] = {}\n    promos_by_store: dict[str, np.ndarray] = {}\n    holidays_by_store: dict[str, np.ndarray] = {}\n    dow_by_store: dict[str, np.ndarray] = {}\n\n    for store_id, config in stores.items():\n        bp = base_prices[store_id]\n        weeks = np.arange(TOTAL_LEN)\n\n        trend = config[\"base_sales\"] * (1 + 0.005 * weeks)\n        seasonality = 80 * np.sin(2 * np.pi * weeks / 52)\n        noise = rng.normal(0, 40, TOTAL_LEN)\n        base = (trend + seasonality + noise).astype(np.float32)\n\n        price = (bp + rng.uniform(-0.5, 0.5, TOTAL_LEN)).astype(np.float32)\n        price_effect = (-20 * (price - bp)).astype(np.float32)\n\n        holidays = np.zeros(TOTAL_LEN, dtype=np.float32)\n        for hw in [0, 11, 23, 35]:\n            if hw < TOTAL_LEN:\n                holidays[hw] = 1.0\n        holiday_effect = (200 * holidays).astype(np.float32)\n\n        promotion = rng.choice([0.0, 1.0], TOTAL_LEN, p=[0.8, 0.2]).astype(np.float32)\n        promo_effect = (150 * promotion).astype(np.float32)\n\n        day_of_week = 
np.tile(np.arange(7), TOTAL_LEN // 7 + 1)[:TOTAL_LEN].astype(\n            np.int32\n        )\n\n        sales = np.maximum(base + price_effect + holiday_effect + promo_effect, 50.0)\n\n        data[\"stores\"][store_id] = {\"sales\": sales, \"config\": config}\n        data[\"components\"][store_id] = {\n            \"base\": base,\n            \"price_effect\": price_effect,\n            \"promo_effect\": promo_effect,\n            \"holiday_effect\": holiday_effect,\n        }\n\n        prices_by_store[store_id] = price\n        promos_by_store[store_id] = promotion\n        holidays_by_store[store_id] = holidays\n        dow_by_store[store_id] = day_of_week\n\n    data[\"covariates\"] = {\n        \"price\": prices_by_store,\n        \"promotion\": promos_by_store,\n        \"holiday\": holidays_by_store,\n        \"day_of_week\": dow_by_store,\n        \"store_type\": {sid: stores[sid][\"type\"] for sid in stores},\n        \"region\": {sid: stores[sid][\"region\"] for sid in stores},\n    }\n    return data\n\n\ndef create_visualization(data: dict) -> None:\n    \"\"\"\n    2x2 figure -- ALL panels share x-axis = weeks 0-35.\n\n    (0,0) Sales by store -- context solid, horizon dashed\n    (0,1) Store A: actual vs baseline (no covariates), with event overlays showing uplift\n    (1,0) Price covariate for all stores -- full 36 weeks including horizon\n    (1,1) Covariate effect decomposition for Store A (stacked fill_between)\n\n    Each panel has a conclusion annotation box explaining what the data shows.\n    \"\"\"\n    OUTPUT_DIR.mkdir(exist_ok=True)\n\n    store_colors = {\"store_A\": \"#1a56db\", \"store_B\": \"#057a55\", \"store_C\": \"#c03221\"}\n    weeks = np.arange(TOTAL_LEN)\n\n    fig, axes = plt.subplots(\n        2,\n        2,\n        figsize=(16, 11),\n        sharex=True,\n        gridspec_kw={\"hspace\": 0.42, \"wspace\": 0.32},\n    )\n    fig.suptitle(\n        \"TimesFM Covariates (XReg) -- Retail Sales with Exogenous Variables\\n\"\n  
      \"Shared x-axis: Week 0-23 = context (observed) | Week 24-35 = forecast horizon\",\n        fontsize=13,\n        fontweight=\"bold\",\n        y=1.01,\n    )\n\n    def add_divider(ax, label_top=True):\n        ax.axvline(CONTEXT_LEN - 0.5, color=\"#9ca3af\", lw=1.3, ls=\"--\", alpha=0.8)\n        ax.axvspan(\n            CONTEXT_LEN - 0.5, TOTAL_LEN - 0.5, alpha=0.06, color=\"grey\", zorder=0\n        )\n        if label_top:\n            ax.text(\n                CONTEXT_LEN + 0.3,\n                1.01,\n                \"<- horizon ->\",\n                transform=ax.get_xaxis_transform(),\n                fontsize=7.5,\n                color=\"#6b7280\",\n                style=\"italic\",\n            )\n\n    # -- (0,0): Sales by Store ---------------------------------------------------\n    ax = axes[0, 0]\n    base_price_labels = {\"store_A\": \"$12\", \"store_B\": \"$10\", \"store_C\": \"$7.50\"}\n    for sid, store_data in data[\"stores\"].items():\n        sales = store_data[\"sales\"]\n        c = store_colors[sid]\n        lbl = f\"{sid} ({store_data['config']['type']}, {base_price_labels[sid]} base)\"\n        ax.plot(\n            weeks[:CONTEXT_LEN],\n            sales[:CONTEXT_LEN],\n            color=c,\n            lw=2,\n            marker=\"o\",\n            ms=3,\n            label=lbl,\n        )\n        ax.plot(\n            weeks[CONTEXT_LEN:],\n            sales[CONTEXT_LEN:],\n            color=c,\n            lw=1.5,\n            ls=\"--\",\n            marker=\"o\",\n            ms=3,\n            alpha=0.6,\n        )\n    add_divider(ax)\n    ax.set_ylabel(\"Weekly Sales (units)\", fontsize=10)\n    ax.set_title(\"Sales by Store\", fontsize=11, fontweight=\"bold\")\n    ax.legend(fontsize=7.5, loc=\"upper left\")\n    ax.grid(True, alpha=0.22)\n    ratio = (\n        data[\"stores\"][\"store_A\"][\"sales\"][:CONTEXT_LEN].mean()\n        / data[\"stores\"][\"store_C\"][\"sales\"][:CONTEXT_LEN].mean()\n    )\n    ax.annotate(\n  
      f\"Store A earns {ratio:.1f}x Store C\\n(premium vs discount pricing)\\n\"\n        f\"-> store_type is a useful static covariate\",\n        xy=(0.97, 0.05),\n        xycoords=\"axes fraction\",\n        ha=\"right\",\n        fontsize=8,\n        bbox=dict(boxstyle=\"round\", fc=\"#fffbe6\", ec=\"#d4a017\", alpha=0.95),\n    )\n\n    # -- (0,1): Store A actual vs baseline ---------------------------------------\n    ax = axes[0, 1]\n    comp_A = data[\"components\"][\"store_A\"]\n    sales_A = data[\"stores\"][\"store_A\"][\"sales\"]\n    base_A = comp_A[\"base\"]\n    promo_A = data[\"covariates\"][\"promotion\"][\"store_A\"]\n    holiday_A = data[\"covariates\"][\"holiday\"][\"store_A\"]\n\n    ax.plot(\n        weeks[:CONTEXT_LEN],\n        base_A[:CONTEXT_LEN],\n        color=\"#9ca3af\",\n        lw=1.8,\n        ls=\"--\",\n        label=\"Baseline (no covariates)\",\n    )\n    ax.fill_between(\n        weeks[:CONTEXT_LEN],\n        base_A[:CONTEXT_LEN],\n        sales_A[:CONTEXT_LEN],\n        where=(sales_A[:CONTEXT_LEN] > base_A[:CONTEXT_LEN]),\n        alpha=0.35,\n        color=\"#22c55e\",\n        label=\"Covariate uplift\",\n    )\n    ax.fill_between(\n        weeks[:CONTEXT_LEN],\n        sales_A[:CONTEXT_LEN],\n        base_A[:CONTEXT_LEN],\n        where=(sales_A[:CONTEXT_LEN] < base_A[:CONTEXT_LEN]),\n        alpha=0.30,\n        color=\"#ef4444\",\n        label=\"Price suppression\",\n    )\n    ax.plot(\n        weeks[:CONTEXT_LEN],\n        sales_A[:CONTEXT_LEN],\n        color=store_colors[\"store_A\"],\n        lw=2,\n        label=\"Actual sales (Store A)\",\n    )\n\n    for w in range(CONTEXT_LEN):\n        if holiday_A[w] > 0:\n            ax.axvspan(w - 0.45, w + 0.45, alpha=0.22, color=\"darkorange\", zorder=0)\n    promo_weeks = [w for w in range(CONTEXT_LEN) if promo_A[w] > 0]\n    if promo_weeks:\n        ax.scatter(\n            promo_weeks,\n            sales_A[promo_weeks],\n            marker=\"^\",\n            
color=\"#16a34a\",\n            s=70,\n            zorder=6,\n            label=\"Promotion week\",\n        )\n\n    add_divider(ax)\n    ax.set_ylabel(\"Weekly Sales (units)\", fontsize=10)\n    ax.set_title(\n        \"Store A -- Actual vs Baseline (No Covariates)\", fontsize=11, fontweight=\"bold\"\n    )\n    ax.legend(fontsize=7.5, loc=\"upper left\", ncol=2)\n    ax.grid(True, alpha=0.22)\n\n    hm = holiday_A[:CONTEXT_LEN] > 0\n    pm = promo_A[:CONTEXT_LEN] > 0\n    h_lift = (\n        (sales_A[:CONTEXT_LEN][hm] - base_A[:CONTEXT_LEN][hm]).mean() if hm.any() else 0\n    )\n    p_lift = (\n        (sales_A[:CONTEXT_LEN][pm] - base_A[:CONTEXT_LEN][pm]).mean() if pm.any() else 0\n    )\n    ax.annotate(\n        f\"Holiday weeks: +{h_lift:.0f} units avg\\n\"\n        f\"Promotion weeks: +{p_lift:.0f} units avg\\n\"\n        f\"Future event schedules must be known for XReg\",\n        xy=(0.97, 0.05),\n        xycoords=\"axes fraction\",\n        ha=\"right\",\n        fontsize=8,\n        bbox=dict(boxstyle=\"round\", fc=\"#fffbe6\", ec=\"#d4a017\", alpha=0.95),\n    )\n\n    # -- (1,0): Price covariate -- full 36 weeks ---------------------------------\n    ax = axes[1, 0]\n    for sid in data[\"stores\"]:\n        ax.plot(\n            weeks,\n            data[\"covariates\"][\"price\"][sid],\n            color=store_colors[sid],\n            lw=2,\n            label=sid,\n            alpha=0.85,\n        )\n    add_divider(ax, label_top=False)\n    ax.set_xlabel(\"Week\", fontsize=10)\n    ax.set_ylabel(\"Price ($)\", fontsize=10)\n    ax.set_title(\n        \"Price Covariate -- Context + Forecast Horizon\", fontsize=11, fontweight=\"bold\"\n    )\n    ax.legend(fontsize=8, loc=\"upper right\")\n    ax.grid(True, alpha=0.22)\n    ax.annotate(\n        \"Prices are planned -- known for forecast horizon\\n\"\n        \"Price elasticity: -$1 increase -> -20 units sold\\n\"\n        \"Store A ($12) consistently more expensive than C ($7.50)\",\n        
xy=(0.97, 0.05),\n        xycoords=\"axes fraction\",\n        ha=\"right\",\n        fontsize=8,\n        bbox=dict(boxstyle=\"round\", fc=\"#fffbe6\", ec=\"#d4a017\", alpha=0.95),\n    )\n\n    # -- (1,1): Covariate effect decomposition -----------------------------------\n    ax = axes[1, 1]\n    pe = comp_A[\"price_effect\"]\n    pre = comp_A[\"promo_effect\"]\n    he = comp_A[\"holiday_effect\"]\n\n    ax.fill_between(\n        weeks,\n        0,\n        pe,\n        alpha=0.65,\n        color=\"steelblue\",\n        step=\"mid\",\n        label=f\"Price effect (max +/-{np.abs(pe).max():.0f} units)\",\n    )\n    ax.fill_between(\n        weeks,\n        pe,\n        pe + pre,\n        alpha=0.70,\n        color=\"#22c55e\",\n        step=\"mid\",\n        label=\"Promotion effect (+150 units)\",\n    )\n    ax.fill_between(\n        weeks,\n        pe + pre,\n        pe + pre + he,\n        alpha=0.70,\n        color=\"darkorange\",\n        step=\"mid\",\n        label=\"Holiday effect (+200 units)\",\n    )\n    total = pe + pre + he\n    ax.plot(weeks, total, \"k-\", lw=1.5, alpha=0.75, label=\"Total covariate effect\")\n    ax.axhline(0, color=\"black\", lw=0.9, alpha=0.6)\n    add_divider(ax, label_top=False)\n    ax.set_xlabel(\"Week\", fontsize=10)\n    ax.set_ylabel(\"Effect on sales (units)\", fontsize=10)\n    ax.set_title(\n        \"Store A -- Covariate Effect Decomposition\", fontsize=11, fontweight=\"bold\"\n    )\n    ax.legend(fontsize=7.5, loc=\"upper right\")\n    ax.grid(True, alpha=0.22, axis=\"y\")\n    ax.annotate(\n        f\"Holidays (+200) and promotions (+150) dominate\\n\"\n        f\"Price effect (+/-{np.abs(pe).max():.0f} units) is minor by comparison\\n\"\n        f\"-> Time-varying covariates explain most sales spikes\",\n        xy=(0.97, 0.55),\n        xycoords=\"axes fraction\",\n        ha=\"right\",\n        fontsize=8,\n        bbox=dict(boxstyle=\"round\", fc=\"#fffbe6\", ec=\"#d4a017\", alpha=0.95),\n    )\n\n    
tick_pos = list(range(0, TOTAL_LEN, 4))\n    for row in [0, 1]:\n        for col in [0, 1]:\n            axes[row, col].set_xticks(tick_pos)\n\n    plt.tight_layout()\n    output_path = OUTPUT_DIR / \"covariates_data.png\"\n    plt.savefig(output_path, dpi=150, bbox_inches=\"tight\")\n    plt.close()\n    print(f\"\\n Saved visualization: {output_path}\")\n\n\ndef demonstrate_api() -> None:\n    print(\"\\n\" + \"=\" * 70)\n    print(\"  TIMESFM COVARIATES API (TimesFM 2.5)\")\n    print(\"=\" * 70)\n    print(\"\"\"\n# Installation\npip install timesfm[xreg]\n\nimport timesfm\nhparams   = timesfm.TimesFmHparams(backend=\"cpu\", per_core_batch_size=32, horizon_len=12)\nckpt      = timesfm.TimesFmCheckpoint(huggingface_repo_id=\"google/timesfm-2.5-200m-pytorch\")\nmodel     = timesfm.TimesFm(hparams=hparams, checkpoint=ckpt)\n\npoint_fc, quant_fc = model.forecast_with_covariates(\n    inputs=[sales_a, sales_b, sales_c],\n    dynamic_numerical_covariates={\"price\": [price_a, price_b, price_c]},\n    dynamic_categorical_covariates={\"holiday\": [hol_a, hol_b, hol_c]},\n    static_categorical_covariates={\"store_type\": [\"premium\",\"standard\",\"discount\"]},\n    xreg_mode=\"xreg + timesfm\",\n    normalize_xreg_target_per_input=True,\n)\n# point_fc:  (num_series, horizon_len)\n# quant_fc:  (num_series, horizon_len, 10)\n\"\"\")\n\n\ndef explain_xreg_modes() -> None:\n    print(\"\\n\" + \"=\" * 70)\n    print(\"  XREG MODES\")\n    print(\"=\" * 70)\n    print(\"\"\"\n\"xreg + timesfm\" (DEFAULT)\n  1. TimesFM makes baseline forecast\n  2. Fit regression on residuals (actual - baseline) ~ covariates\n  3. Final = TimesFM baseline + XReg adjustment\n  Best when: covariates explain residual variation (e.g. promotions)\n\n\"timesfm + xreg\"\n  1. Fit regression: target ~ covariates\n  2. TimesFM forecasts the residuals\n  3. Final = XReg prediction + TimesFM residual forecast\n  Best when: covariates explain the main signal (e.g. 
temperature)\n\"\"\")\n\n\ndef main() -> None:\n    print(\"=\" * 70)\n    print(\"  TIMESFM COVARIATES (XREG) EXAMPLE\")\n    print(\"=\" * 70)\n\n    print(\"\\n Generating synthetic retail sales data...\")\n    data = generate_sales_data()\n\n    print(f\"   Stores:         {list(data['stores'].keys())}\")\n    print(f\"   Context length: {CONTEXT_LEN} weeks\")\n    print(f\"   Horizon length: {HORIZON_LEN} weeks\")\n    print(f\"   Covariates:     {list(data['covariates'].keys())}\")\n\n    demonstrate_api()\n    explain_xreg_modes()\n\n    print(\"\\n Creating 2x2 visualization (shared x-axis)...\")\n    create_visualization(data)\n\n    print(\"\\n Saving output data...\")\n    OUTPUT_DIR.mkdir(exist_ok=True)\n\n    records = []\n    for store_id, store_data in data[\"stores\"].items():\n        for i in range(TOTAL_LEN):\n            records.append(\n                {\n                    \"store_id\": store_id,\n                    \"week\": i,\n                    \"split\": \"context\" if i < CONTEXT_LEN else \"horizon\",\n                    \"sales\": round(float(store_data[\"sales\"][i]), 2),\n                    \"base_sales\": round(\n                        float(data[\"components\"][store_id][\"base\"][i]), 2\n                    ),\n                    \"price\": round(float(data[\"covariates\"][\"price\"][store_id][i]), 4),\n                    \"price_effect\": round(\n                        float(data[\"components\"][store_id][\"price_effect\"][i]), 2\n                    ),\n                    \"promotion\": int(data[\"covariates\"][\"promotion\"][store_id][i]),\n                    \"holiday\": int(data[\"covariates\"][\"holiday\"][store_id][i]),\n                    \"day_of_week\": int(data[\"covariates\"][\"day_of_week\"][store_id][i]),\n                    \"store_type\": data[\"covariates\"][\"store_type\"][store_id],\n                    \"region\": data[\"covariates\"][\"region\"][store_id],\n                }\n            )\n\n    
df = pd.DataFrame(records)\n    csv_path = OUTPUT_DIR / \"sales_with_covariates.csv\"\n    df.to_csv(csv_path, index=False)\n    print(f\"   Saved: {csv_path}  ({len(df)} rows x {len(df.columns)} cols)\")\n\n    metadata = {\n        \"description\": \"Synthetic retail sales data with covariates for TimesFM XReg demo\",\n        \"note_on_real_data\": (\n            \"For real datasets (e.g., Kaggle Rossmann Store Sales), download to \"\n            \"tempfile.mkdtemp() -- do NOT commit to this repo.\"\n        ),\n        \"stores\": {\n            sid: {\n                **sdata[\"config\"],\n                \"mean_sales_context\": round(\n                    float(sdata[\"sales\"][:CONTEXT_LEN].mean()), 1\n                ),\n            }\n            for sid, sdata in data[\"stores\"].items()\n        },\n        \"dimensions\": {\n            \"context_length\": CONTEXT_LEN,\n            \"horizon_length\": HORIZON_LEN,\n            \"total_length\": TOTAL_LEN,\n            \"num_stores\": N_STORES,\n            \"csv_rows\": len(df),\n        },\n        \"covariates\": {\n            \"dynamic_numerical\": [\"price\"],\n            \"dynamic_categorical\": [\"promotion\", \"holiday\", \"day_of_week\"],\n            \"static_categorical\": [\"store_type\", \"region\"],\n        },\n        \"effect_magnitudes\": {\n            \"holiday\": \"+200 units per holiday week\",\n            \"promotion\": \"+150 units per promotion week\",\n            \"price\": \"-20 units per $1 above base price\",\n        },\n        \"xreg_modes\": {\n            \"xreg + timesfm\": \"Regression on TimesFM residuals (default)\",\n            \"timesfm + xreg\": \"TimesFM on regression residuals\",\n        },\n        \"bug_fixes_history\": [\n            \"v1: Variable-shadowing -- all stores had identical covariates\",\n            \"v2: Fixed shadowing; CONTEXT_LEN 48->24\",\n            \"v3: Added component decomposition (base, price/promo/holiday effects); 2x2 sharex 
viz\",\n        ],\n    }\n\n    meta_path = OUTPUT_DIR / \"covariates_metadata.json\"\n    with open(meta_path, \"w\") as f:\n        json.dump(metadata, f, indent=2)\n    print(f\"   Saved: {meta_path}\")\n\n    print(\"\\n\" + \"=\" * 70)\n    print(\"  COVARIATES EXAMPLE COMPLETE\")\n    print(\"=\" * 70)\n    print(\"\"\"\nKey points:\n  1. Requires timesfm[xreg] + TimesFM 2.5+ for actual inference\n  2. Dynamic covariates need values for BOTH context AND horizon (future must be known!)\n  3. Static covariates: one value per series (store_type, region)\n  4. All 4 visualization panels share the same week x-axis (0-35)\n  5. Effect decomposition shows holidays/promotions dominate over price variation\n\nOutput files:\n  output/covariates_data.png         -- 2x2 visualization with conclusions\n  output/sales_with_covariates.csv   -- 108-row compact dataset\n  output/covariates_metadata.json    -- metadata + effect magnitudes\n\"\"\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "timesfm-forecasting/examples/covariates-forecasting/output/covariates_metadata.json",
    "content": "{\n  \"description\": \"Synthetic retail sales data with covariates for TimesFM XReg demo\",\n  \"note_on_real_data\": \"For real datasets (e.g., Kaggle Rossmann Store Sales), download to tempfile.mkdtemp() -- do NOT commit to this repo.\",\n  \"stores\": {\n    \"store_A\": {\n      \"type\": \"premium\",\n      \"region\": \"urban\",\n      \"base_sales\": 1000,\n      \"mean_sales_context\": 1148.7\n    },\n    \"store_B\": {\n      \"type\": \"standard\",\n      \"region\": \"suburban\",\n      \"base_sales\": 750,\n      \"mean_sales_context\": 907.0\n    },\n    \"store_C\": {\n      \"type\": \"discount\",\n      \"region\": \"rural\",\n      \"base_sales\": 500,\n      \"mean_sales_context\": 645.3\n    }\n  },\n  \"dimensions\": {\n    \"context_length\": 24,\n    \"horizon_length\": 12,\n    \"total_length\": 36,\n    \"num_stores\": 3,\n    \"csv_rows\": 108\n  },\n  \"covariates\": {\n    \"dynamic_numerical\": [\n      \"price\"\n    ],\n    \"dynamic_categorical\": [\n      \"promotion\",\n      \"holiday\",\n      \"day_of_week\"\n    ],\n    \"static_categorical\": [\n      \"store_type\",\n      \"region\"\n    ]\n  },\n  \"effect_magnitudes\": {\n    \"holiday\": \"+200 units per holiday week\",\n    \"promotion\": \"+150 units per promotion week\",\n    \"price\": \"-20 units per $1 above base price\"\n  },\n  \"xreg_modes\": {\n    \"xreg + timesfm\": \"Regression on TimesFM residuals (default)\",\n    \"timesfm + xreg\": \"TimesFM on regression residuals\"\n  },\n  \"bug_fixes_history\": [\n    \"v1: Variable-shadowing -- all stores had identical covariates\",\n    \"v2: Fixed shadowing; CONTEXT_LEN 48->24\",\n    \"v3: Added component decomposition (base, price/promo/holiday effects); 2x2 sharex viz\"\n  ]\n}"
  },
  {
    "path": "timesfm-forecasting/examples/covariates-forecasting/output/sales_with_covariates.csv",
    "content": "store_id,week,split,sales,base_sales,price,price_effect,promotion,holiday,day_of_week,store_type,region\nstore_A,0,context,1369.59,1012.19,11.6299,7.4,1,1,0,premium,urban\nstore_A,1,context,973.53,973.04,11.9757,0.49,0,0,1,premium,urban\nstore_A,2,context,1064.63,1059.16,11.7269,5.46,0,0,2,premium,urban\nstore_A,3,context,1077.59,1080.99,12.1698,-3.4,0,0,3,premium,urban\nstore_A,4,context,980.39,979.14,11.9372,1.26,0,0,4,premium,urban\nstore_A,5,context,1011.7,1018.36,12.3327,-6.65,0,0,5,premium,urban\nstore_A,6,context,1084.16,1088.16,12.2003,-4.01,0,0,6,premium,urban\nstore_A,7,context,1085.98,1082.23,11.8124,3.75,0,0,0,premium,urban\nstore_A,8,context,1098.52,1105.17,12.3323,-6.65,0,0,1,premium,urban\nstore_A,9,context,1075.62,1081.71,12.3048,-6.1,0,0,2,premium,urban\nstore_A,10,context,1312.23,1159.98,11.8875,2.25,1,0,3,premium,urban\nstore_A,11,context,1368.02,1163.79,11.7883,4.23,0,1,4,premium,urban\nstore_A,12,context,1138.41,1142.06,12.1825,-3.65,0,0,5,premium,urban\nstore_A,13,context,1197.29,1190.09,11.6398,7.2,0,0,6,premium,urban\nstore_A,14,context,1174.12,1168.12,11.6999,6.0,0,0,0,premium,urban\nstore_A,15,context,1128.16,1118.3,11.5074,9.85,0,0,1,premium,urban\nstore_A,16,context,1163.81,1169.55,12.2869,-5.74,0,0,2,premium,urban\nstore_A,17,context,1114.18,1117.48,12.1649,-3.3,0,0,3,premium,urban\nstore_A,18,context,1186.87,1190.98,12.2052,-4.1,0,0,4,premium,urban\nstore_A,19,context,1147.27,1152.88,12.2807,-5.61,0,0,5,premium,urban\nstore_A,20,context,1146.48,1145.66,11.9589,0.82,0,0,6,premium,urban\nstore_A,21,context,1121.83,1123.21,12.0687,-1.37,0,0,0,premium,urban\nstore_A,22,context,1203.28,1196.08,11.6398,7.2,0,0,1,premium,urban\nstore_A,23,context,1344.9,1137.19,11.6145,7.71,0,1,2,premium,urban\nstore_A,24,horizon,1118.64,1122.01,12.1684,-3.37,0,0,3,premium,urban\nstore_A,25,horizon,1121.14,1120.56,11.9711,0.58,0,0,4,premium,urban\nstore_A,26,horizon,1149.99,1151.29,12.0652,-1.3,0,0,5,premium,urban\nstore_A,27,horizon,1284.67,11
39.97,12.265,-5.3,1,0,6,premium,urban\nstore_A,28,horizon,1284.67,1137.36,12.1347,-2.69,1,0,0,premium,urban\nstore_A,29,horizon,1132.79,1133.86,12.0536,-1.07,0,0,1,premium,urban\nstore_A,30,horizon,1197.3,1198.49,12.0592,-1.18,0,0,2,premium,urban\nstore_A,31,horizon,1247.22,1093.3,11.804,3.92,1,0,3,premium,urban\nstore_A,32,horizon,1095.84,1086.46,11.5308,9.38,0,0,4,premium,urban\nstore_A,33,horizon,1073.83,1072.57,11.9367,1.27,0,0,5,premium,urban\nstore_A,34,horizon,1134.51,1128.8,11.7146,5.71,0,0,6,premium,urban\nstore_A,35,horizon,1351.15,1149.32,11.9085,1.83,0,1,0,premium,urban\nstore_B,0,context,1062.53,712.0,9.9735,0.53,1,1,0,standard,suburban\nstore_B,1,context,904.49,749.83,9.767,4.66,1,0,1,standard,suburban\nstore_B,2,context,813.63,810.26,9.8316,3.37,0,0,2,standard,suburban\nstore_B,3,context,720.11,720.53,10.0207,-0.41,0,0,3,standard,suburban\nstore_B,4,context,820.78,819.55,9.9389,1.22,0,0,4,standard,suburban\nstore_B,5,context,833.27,823.7,9.5216,9.57,0,0,5,standard,suburban\nstore_B,6,context,795.26,801.78,10.3263,-6.53,0,0,6,standard,suburban\nstore_B,7,context,770.37,778.29,10.3962,-7.92,0,0,0,standard,suburban\nstore_B,8,context,855.92,848.72,9.6402,7.2,0,0,1,standard,suburban\nstore_B,9,context,832.33,833.41,10.054,-1.08,0,0,2,standard,suburban\nstore_B,10,context,1029.44,871.61,9.6086,7.83,1,0,3,standard,suburban\nstore_B,11,context,1066.35,869.8,10.1722,-3.44,0,1,4,standard,suburban\nstore_B,12,context,942.86,938.49,9.7812,4.38,0,0,5,standard,suburban\nstore_B,13,context,1015.99,869.18,10.1594,-3.19,1,0,6,standard,suburban\nstore_B,14,context,836.44,840.98,10.227,-4.54,0,0,0,standard,suburban\nstore_B,15,context,885.72,891.1,10.2686,-5.37,0,0,1,standard,suburban\nstore_B,16,context,901.45,893.6,9.6077,7.85,0,0,2,standard,suburban\nstore_B,17,context,1080.63,938.95,10.416,-8.32,1,0,3,standard,suburban\nstore_B,18,context,922.14,916.74,9.7302,5.4,0,0,4,standard,suburban\nstore_B,19,context,904.66,895.41,9.5374,9.25,0,0,5,standard,suburban\nstore_B,
20,context,935.48,936.58,10.0549,-1.1,0,0,6,standard,suburban\nstore_B,21,context,979.23,826.64,9.8709,2.58,1,0,0,standard,suburban\nstore_B,22,context,837.49,844.09,10.3298,-6.6,0,0,1,standard,suburban\nstore_B,23,context,1021.39,827.56,10.3083,-6.17,0,1,2,standard,suburban\nstore_B,24,horizon,847.21,843.55,9.8171,3.66,0,0,3,standard,suburban\nstore_B,25,horizon,789.27,798.33,10.4529,-9.06,0,0,4,standard,suburban\nstore_B,26,horizon,877.09,872.91,9.7909,4.18,0,0,5,standard,suburban\nstore_B,27,horizon,832.42,832.72,10.0151,-0.3,0,0,6,standard,suburban\nstore_B,28,horizon,781.9,777.02,9.756,4.88,0,0,0,standard,suburban\nstore_B,29,horizon,781.04,789.76,10.436,-8.72,0,0,1,standard,suburban\nstore_B,30,horizon,844.57,837.86,9.6646,6.71,0,0,2,standard,suburban\nstore_B,31,horizon,863.43,854.33,9.5449,9.1,0,0,3,standard,suburban\nstore_B,32,horizon,898.12,896.82,9.9351,1.3,0,0,4,standard,suburban\nstore_B,33,horizon,1070.58,930.42,10.4924,-9.85,1,0,5,standard,suburban\nstore_B,34,horizon,820.4,828.24,10.3917,-7.83,0,0,6,standard,suburban\nstore_B,35,horizon,965.86,770.83,10.2486,-4.97,0,1,0,standard,suburban\nstore_C,0,context,709.12,501.23,7.1053,7.89,0,1,0,discount,rural\nstore_C,1,context,651.44,492.78,7.0666,8.67,1,0,1,discount,rural\nstore_C,2,context,659.15,511.04,7.5944,-1.89,1,0,2,discount,rural\nstore_C,3,context,733.06,575.98,7.1462,7.08,1,0,3,discount,rural\nstore_C,4,context,712.21,568.7,7.8247,-6.49,1,0,4,discount,rural\nstore_C,5,context,615.23,611.44,7.3103,3.79,0,0,5,discount,rural\nstore_C,6,context,568.99,561.87,7.1439,7.12,0,0,6,discount,rural\nstore_C,7,context,541.12,549.54,7.921,-8.42,0,0,0,discount,rural\nstore_C,8,context,583.57,576.88,7.1655,6.69,0,0,1,discount,rural\nstore_C,9,context,607.34,603.04,7.2847,4.31,0,0,2,discount,rural\nstore_C,10,context,613.79,606.86,7.1536,6.93,0,0,3,discount,rural\nstore_C,11,context,919.49,561.8,7.1155,7.69,1,1,4,discount,rural\nstore_C,12,context,622.61,613.04,7.0211,9.58,0,0,5,discount,rural\nstore_C,13,conte
xt,630.52,621.63,7.0554,8.89,0,0,6,discount,rural\nstore_C,14,context,721.62,715.12,7.1746,6.51,0,0,0,discount,rural\nstore_C,15,context,699.18,690.25,7.0534,8.93,0,0,1,discount,rural\nstore_C,16,context,578.85,580.67,7.5911,-1.82,0,0,2,discount,rural\nstore_C,17,context,598.23,601.84,7.6807,-3.61,0,0,3,discount,rural\nstore_C,18,context,554.43,552.3,7.3936,2.13,0,0,4,discount,rural\nstore_C,19,context,587.39,583.75,7.318,3.64,0,0,5,discount,rural\nstore_C,20,context,615.58,615.67,7.5045,-0.09,0,0,6,discount,rural\nstore_C,21,context,638.68,646.18,7.875,-7.5,0,0,0,discount,rural\nstore_C,22,context,555.99,563.01,7.8511,-7.02,0,0,1,discount,rural\nstore_C,23,context,768.83,559.7,7.0435,9.13,0,1,2,discount,rural\nstore_C,24,horizon,499.62,493.25,7.1815,6.37,0,0,3,discount,rural\nstore_C,25,horizon,570.9,565.64,7.2367,5.27,0,0,4,discount,rural\nstore_C,26,horizon,677.52,522.5,7.2494,5.01,1,0,5,discount,rural\nstore_C,27,horizon,685.25,536.68,7.5712,-1.42,1,0,6,discount,rural\nstore_C,28,horizon,517.46,515.78,7.4163,1.67,0,0,0,discount,rural\nstore_C,29,horizon,549.38,540.36,7.0493,9.01,0,0,1,discount,rural\nstore_C,30,horizon,470.04,467.51,7.3736,2.53,0,0,2,discount,rural\nstore_C,31,horizon,622.9,473.37,7.5238,-0.48,1,0,3,discount,rural\nstore_C,32,horizon,620.09,612.12,7.1017,7.97,0,0,4,discount,rural\nstore_C,33,horizon,614.45,471.12,7.8335,-6.67,1,0,5,discount,rural\nstore_C,34,horizon,484.25,475.29,7.052,8.96,0,0,6,discount,rural\nstore_C,35,horizon,781.64,590.14,7.9248,-8.5,0,1,0,discount,rural\n"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/README.md",
    "content": "# TimesFM Forecast Report: Global Temperature Anomaly (2025)\n\n**Model:** TimesFM 1.0 (200M) PyTorch  \n**Generated:** 2026-02-21  \n**Source:** NOAA GISTEMP Global Land-Ocean Temperature Index\n\n---\n\n## Executive Summary\n\nTimesFM forecasts a mean temperature anomaly of **1.19°C** for 2025, slightly below the 2024 average of 1.25°C. The model predicts continued elevated temperatures with a peak of 1.30°C in March 2025 and a minimum of 1.06°C in December 2025.\n\n---\n\n## Input Data\n\n### Historical Temperature Anomalies (2022-2024)\n\n| Date | Anomaly (°C) | Date | Anomaly (°C) | Date | Anomaly (°C) |\n|------|-------------|------|-------------|------|-------------|\n| 2022-01 | 0.89 | 2023-01 | 0.87 | 2024-01 | 1.22 |\n| 2022-02 | 0.89 | 2023-02 | 0.98 | 2024-02 | 1.35 |\n| 2022-03 | 1.02 | 2023-03 | 1.21 | 2024-03 | 1.34 |\n| 2022-04 | 0.88 | 2023-04 | 1.00 | 2024-04 | 1.26 |\n| 2022-05 | 0.85 | 2023-05 | 0.94 | 2024-05 | 1.15 |\n| 2022-06 | 0.88 | 2023-06 | 1.08 | 2024-06 | 1.20 |\n| 2022-07 | 0.88 | 2023-07 | 1.18 | 2024-07 | 1.24 |\n| 2022-08 | 0.90 | 2023-08 | 1.24 | 2024-08 | 1.30 |\n| 2022-09 | 0.88 | 2023-09 | 1.47 | 2024-09 | 1.28 |\n| 2022-10 | 0.95 | 2023-10 | 1.32 | 2024-10 | 1.27 |\n| 2022-11 | 0.77 | 2023-11 | 1.18 | 2024-11 | 1.22 |\n| 2022-12 | 0.78 | 2023-12 | 1.16 | 2024-12 | 1.20 |\n\n**Statistics:**\n- Total observations: 36 months\n- Mean anomaly: 1.09°C\n- Trend (2022→2024): +0.37°C\n\n---\n\n## Raw Forecast Output\n\n### Point Forecast and Confidence Intervals\n\n| Month | Point | 80% CI | 90% CI |\n|-------|-------|--------|--------|\n| 2025-01 | 1.259 | [1.141, 1.297] | [1.248, 1.324] |\n| 2025-02 | 1.286 | [1.141, 1.340] | [1.277, 1.375] |\n| 2025-03 | 1.295 | [1.127, 1.355] | [1.287, 1.404] |\n| 2025-04 | 1.221 | [1.035, 1.290] | [1.208, 1.331] |\n| 2025-05 | 1.170 | [0.969, 1.239] | [1.153, 1.289] |\n| 2025-06 | 1.146 | [0.942, 1.218] | [1.128, 1.270] |\n| 2025-07 | 1.170 | [0.950, 1.248] | [1.151, 1.300] |\n| 
2025-08 | 1.203 | [0.971, 1.284] | [1.186, 1.341] |\n| 2025-09 | 1.191 | [0.959, 1.283] | [1.178, 1.335] |\n| 2025-10 | 1.149 | [0.908, 1.240] | [1.126, 1.287] |\n| 2025-11 | 1.080 | [0.836, 1.176] | [1.062, 1.228] |\n| 2025-12 | 1.061 | [0.802, 1.153] | [1.037, 1.217] |\n\n### JSON Output\n\n```json\n{\n  \"model\": \"TimesFM 1.0 (200M) PyTorch\",\n  \"input\": {\n    \"source\": \"NOAA GISTEMP Global Temperature Anomaly\",\n    \"n_observations\": 36,\n    \"date_range\": \"2022-01 to 2024-12\",\n    \"mean_anomaly_c\": 1.089\n  },\n  \"forecast\": {\n    \"horizon\": 12,\n    \"dates\": [\"2025-01\", \"2025-02\", \"2025-03\", \"2025-04\", \"2025-05\", \"2025-06\",\n              \"2025-07\", \"2025-08\", \"2025-09\", \"2025-10\", \"2025-11\", \"2025-12\"],\n    \"point\": [1.259, 1.286, 1.295, 1.221, 1.170, 1.146, 1.170, 1.203, 1.191, 1.149, 1.080, 1.061]\n  },\n  \"summary\": {\n    \"forecast_mean_c\": 1.186,\n    \"forecast_max_c\": 1.295,\n    \"forecast_min_c\": 1.061,\n    \"vs_last_year_mean\": -0.067\n  }\n}\n```\n\n---\n\n## Visualization\n\n![Temperature Anomaly Forecast](forecast_visualization.png)\n\n---\n\n## Findings\n\n### Key Observations\n\n1. **Slight cooling trend expected**: The model forecasts a mean anomaly 0.07°C below 2024 levels, suggesting a potential stabilization after the record-breaking temperatures of 2023-2024.\n\n2. **Seasonal pattern preserved**: The forecast shows the expected seasonal variation with higher anomalies in late winter (Feb-Mar) and lower in late fall (Nov-Dec).\n\n3. **Widening uncertainty**: The 90% CI expands from ±0.04°C in January to ±0.08°C in December, reflecting typical forecast uncertainty growth over time.\n\n4. 
**Peak temperature**: March 2025 is predicted to have the highest anomaly at 1.30°C, potentially approaching the September 2023 record of 1.47°C.\n\n### Limitations\n\n- TimesFM is a zero-shot forecaster without physical climate model constraints\n- The 36-month training window may not capture multi-decadal climate trends\n- El Niño/La Niña cycles are not explicitly modeled\n\n### Recommendations\n\n- Use this forecast as a baseline comparison for physics-based climate models\n- Update forecast quarterly as new observations become available\n- Consider ensemble approaches combining TimesFM with other methods\n\n---\n\n## Reproducibility\n\n### Files\n\n| File | Description |\n|------|-------------|\n| `temperature_anomaly.csv` | Input data (36 months) |\n| `forecast_output.csv` | Point forecast with quantiles |\n| `forecast_output.json` | Machine-readable forecast |\n| `forecast_visualization.png` | Fan chart visualization |\n| `run_forecast.py` | Forecasting script |\n| `visualize_forecast.py` | Visualization script |\n| `run_example.sh` | One-click runner |\n\n### How to Reproduce\n\n```bash\n# Install dependencies\nuv pip install \"timesfm[torch]\" matplotlib pandas numpy\n\n# Run the complete example\ncd scientific-skills/timesfm-forecasting/examples/global-temperature\n./run_example.sh\n```\n\n---\n\n## Technical Notes\n\n### API Discovery\n\nThe TimesFM PyTorch API differs from the GitHub README documentation:\n\n**Documented (GitHub README):**\n```python\nmodel = timesfm.TimesFm(\n    context_len=512,\n    horizon_len=128,\n    backend=\"gpu\",\n)\nmodel.load_from_google_repo(\"google/timesfm-2.5-200m-pytorch\")\n```\n\n**Actual Working API:**\n```python\nhparams = timesfm.TimesFmHparams(horizon_len=12)\ncheckpoint = timesfm.TimesFmCheckpoint(\n    huggingface_repo_id=\"google/timesfm-1.0-200m-pytorch\"\n)\nmodel = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)\n```\n\n### TimesFM 2.5 PyTorch Issue\n\nThe `google/timesfm-2.5-200m-pytorch` checkpoint 
downloads as `model.safetensors`, but the TimesFM loader expects `torch_model.ckpt`. This causes a `FileNotFoundError` at model load time. Loading the TimesFM 1.0 PyTorch checkpoint (`google/timesfm-1.0-200m-pytorch`) instead avoids this issue.\n\n---\n\n*Report generated by TimesFM Forecasting Skill (claude-scientific-skills)*\n"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/generate_animation_data.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nGenerate animation data for interactive forecast visualization.\n\nThis script runs TimesFM forecasts incrementally, starting with minimal data\nand adding one point at a time. Each forecast extends to the final date (2025-12).\n\nOutput: animation_data.json with all forecast steps\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\nimport numpy as np\nimport pandas as pd\nimport timesfm\n\n# Configuration\nMIN_CONTEXT = 12  # Minimum points to start forecasting\nMAX_HORIZON = (\n    36  # Max forecast length (when we have 12 points, forecast 36 months to 2025-12)\n)\nTOTAL_MONTHS = 48  # Total months from 2022-01 to 2025-12 (graph extent)\nINPUT_FILE = Path(__file__).parent / \"temperature_anomaly.csv\"\nOUTPUT_FILE = Path(__file__).parent / \"output\" / \"animation_data.json\"\n\n\ndef main() -> None:\n    print(\"=\" * 60)\n    print(\"  TIMESFM ANIMATION DATA GENERATOR\")\n    print(\"  Dynamic horizon - forecasts always reach 2025-12\")\n    print(\"=\" * 60)\n\n    # Load data\n    df = pd.read_csv(INPUT_FILE, parse_dates=[\"date\"])\n    df = df.sort_values(\"date\").reset_index(drop=True)\n\n    all_dates = df[\"date\"].tolist()\n    all_values = df[\"anomaly_c\"].values.astype(np.float32)\n\n    print(f\"\\n📊 Total data: {len(all_values)} months\")\n    print(\n        f\"   Date range: {all_dates[0].strftime('%Y-%m')} to {all_dates[-1].strftime('%Y-%m')}\"\n    )\n    print(f\"   Animation steps: {len(all_values) - MIN_CONTEXT + 1}\")\n\n    # Load TimesFM with max horizon (will truncate output for shorter forecasts)\n    print(f\"\\n🤖 Loading TimesFM 1.0 (200M) PyTorch (horizon={MAX_HORIZON})...\")\n    hparams = timesfm.TimesFmHparams(horizon_len=MAX_HORIZON)\n    checkpoint = timesfm.TimesFmCheckpoint(\n        huggingface_repo_id=\"google/timesfm-1.0-200m-pytorch\"\n    )\n    model = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)\n\n    # Generate forecasts 
for each step\n    animation_steps = []\n\n    for n_points in range(MIN_CONTEXT, len(all_values) + 1):\n        step_num = n_points - MIN_CONTEXT + 1\n        total_steps = len(all_values) - MIN_CONTEXT + 1\n\n        # Calculate dynamic horizon: forecast enough to reach 2025-12\n        horizon = TOTAL_MONTHS - n_points\n\n        print(\n            f\"\\n📈 Step {step_num}/{total_steps}: Using {n_points} points, forecasting {horizon} months...\"\n        )\n\n        # Get historical data up to this point\n        historical_values = all_values[:n_points]\n        historical_dates = all_dates[:n_points]\n\n        # Run forecast (model outputs MAX_HORIZON, we truncate to actual horizon)\n        point, quantiles = model.forecast(\n            [historical_values],\n            freq=[0],\n        )\n\n        # Truncate to actual horizon\n        point = point[0][:horizon]\n        quantiles = quantiles[0, :horizon, :]\n\n        # Determine forecast dates\n        last_date = historical_dates[-1]\n        forecast_dates = pd.date_range(\n            start=last_date + pd.DateOffset(months=1),\n            periods=horizon,\n            freq=\"MS\",\n        )\n\n        # Store step data\n        step_data = {\n            \"step\": step_num,\n            \"n_points\": n_points,\n            \"horizon\": horizon,\n            \"last_historical_date\": historical_dates[-1].strftime(\"%Y-%m\"),\n            \"historical_dates\": [d.strftime(\"%Y-%m\") for d in historical_dates],\n            \"historical_values\": historical_values.tolist(),\n            \"forecast_dates\": [d.strftime(\"%Y-%m\") for d in forecast_dates],\n            \"point_forecast\": point.tolist(),\n            \"q10\": quantiles[:, 0].tolist(),\n            \"q20\": quantiles[:, 1].tolist(),\n            \"q80\": quantiles[:, 7].tolist(),\n            \"q90\": quantiles[:, 8].tolist(),\n        }\n\n        animation_steps.append(step_data)\n\n        # Show summary\n        print(f\"   Last 
date: {historical_dates[-1].strftime('%Y-%m')}\")\n        print(f\"   Forecast to: {forecast_dates[-1].strftime('%Y-%m')}\")\n        print(f\"   Forecast mean: {point.mean():.3f}°C\")\n\n    # Create output\n    output = {\n        \"metadata\": {\n            \"model\": \"TimesFM 1.0 (200M) PyTorch\",\n            \"total_steps\": len(animation_steps),\n            \"min_context\": MIN_CONTEXT,\n            \"max_horizon\": MAX_HORIZON,\n            \"total_months\": TOTAL_MONTHS,\n            \"data_source\": \"NOAA GISTEMP Global Temperature Anomaly\",\n            \"full_date_range\": f\"{all_dates[0].strftime('%Y-%m')} to {all_dates[-1].strftime('%Y-%m')}\",\n        },\n        \"actual_data\": {\n            \"dates\": [d.strftime(\"%Y-%m\") for d in all_dates],\n            \"values\": all_values.tolist(),\n        },\n        \"animation_steps\": animation_steps,\n    }\n\n    # Save\n    with open(OUTPUT_FILE, \"w\") as f:\n        json.dump(output, f, indent=2)\n\n    print(f\"\\n\" + \"=\" * 60)\n    print(\"  ✅ ANIMATION DATA COMPLETE\")\n    print(\"=\" * 60)\n    print(f\"\\n📁 Output: {OUTPUT_FILE}\")\n    print(f\"   Total steps: {len(animation_steps)}\")\n    print(f\"   Each forecast extends to 2025-12\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/generate_gif.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nGenerate animated GIF showing forecast evolution.\n\nCreates a GIF animation showing how the TimesFM forecast changes\nas more historical data points are added. Shows the full actual data as a background layer.\n\"\"\"\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\nimport matplotlib.pyplot as plt\nimport matplotlib.dates as mdates\nimport numpy as np\nimport pandas as pd\nfrom PIL import Image\n\n# Configuration\nEXAMPLE_DIR = Path(__file__).parent\nDATA_FILE = EXAMPLE_DIR / \"output\" / \"animation_data.json\"\nOUTPUT_FILE = EXAMPLE_DIR / \"output\" / \"forecast_animation.gif\"\nDURATION_MS = 500  # Time per frame in milliseconds\n\n\ndef create_frame(\n    ax,\n    step_data: dict,\n    actual_data: dict,\n    final_forecast: dict,\n    total_steps: int,\n    x_min,\n    x_max,\n    y_min,\n    y_max,\n) -> None:\n    \"\"\"Create a single frame of the animation with fixed axes.\"\"\"\n    ax.clear()\n\n    # Parse dates\n    historical_dates = pd.to_datetime(step_data[\"historical_dates\"])\n    forecast_dates = pd.to_datetime(step_data[\"forecast_dates\"])\n    \n    # Get final forecast dates for full extent\n    final_forecast_dates = pd.to_datetime(final_forecast[\"forecast_dates\"])\n    \n    # All actual dates for full background\n    all_actual_dates = pd.to_datetime(actual_data[\"dates\"])\n    all_actual_values = np.array(actual_data[\"values\"])\n\n    # ========== BACKGROUND LAYER: Full actual data (faded) ==========\n    ax.plot(\n        all_actual_dates,\n        all_actual_values,\n        color=\"#9ca3af\",\n        linewidth=1,\n        marker=\"o\",\n        markersize=2,\n        alpha=0.3,\n        label=\"All observed data\",\n        zorder=1,\n    )\n    \n    # ========== BACKGROUND LAYER: Final forecast (faded) ==========\n    ax.plot(\n        final_forecast_dates,\n        final_forecast[\"point_forecast\"],\n        color=\"#fca5a5\",\n        
linewidth=1,\n        linestyle=\"--\",\n        marker=\"s\",\n        markersize=2,\n        alpha=0.3,\n        label=\"Final forecast\",\n        zorder=2,\n    )\n\n    # ========== FOREGROUND LAYER: Historical data used (bright) ==========\n    ax.plot(\n        historical_dates,\n        step_data[\"historical_values\"],\n        color=\"#3b82f6\",\n        linewidth=2.5,\n        marker=\"o\",\n        markersize=5,\n        label=\"Data used\",\n        zorder=10,\n    )\n\n    # ========== FOREGROUND LAYER: Current forecast (bright) ==========\n    # 80% interval, q10 to q90 (outer band)\n    ax.fill_between(\n        forecast_dates,\n        step_data[\"q10\"],\n        step_data[\"q90\"],\n        alpha=0.15,\n        color=\"#ef4444\",\n        zorder=5,\n    )\n    \n    # 60% interval, q20 to q80 (inner band)\n    ax.fill_between(\n        forecast_dates,\n        step_data[\"q20\"],\n        step_data[\"q80\"],\n        alpha=0.25,\n        color=\"#ef4444\",\n        zorder=6,\n    )\n    \n    # Forecast line\n    ax.plot(\n        forecast_dates,\n        step_data[\"point_forecast\"],\n        color=\"#ef4444\",\n        linewidth=2.5,\n        marker=\"s\",\n        markersize=5,\n        label=\"Forecast\",\n        zorder=7,\n    )\n\n    # ========== Vertical line at forecast boundary ==========\n    ax.axvline(\n        x=historical_dates[-1],\n        color=\"#6b7280\",\n        linestyle=\"--\",\n        linewidth=1.5,\n        alpha=0.7,\n        zorder=8,\n    )\n\n    # ========== Formatting ==========\n    ax.set_xlabel(\"Date\", fontsize=11)\n    ax.set_ylabel(\"Temperature Anomaly (°C)\", fontsize=11)\n    ax.set_title(\n        f\"TimesFM Forecast Evolution\\n\"\n        f\"Step {step_data['step']}/{total_steps}: {step_data['n_points']} points → \"\n        f\"forecast from {step_data['last_historical_date']}\",\n        fontsize=13,\n        fontweight=\"bold\",\n    )\n    \n    ax.grid(True, alpha=0.3, zorder=0)\n    ax.legend(loc=\"upper left\", fontsize=8)\n    \n    # 
FIXED AXES - same for all frames\n    ax.set_xlim(x_min, x_max)\n    ax.set_ylim(y_min, y_max)\n    \n    # Format x-axis\n    ax.xaxis.set_major_formatter(mdates.DateFormatter(\"%Y-%m\"))\n    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=4))\n    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha=\"right\")\n\n\ndef main() -> None:\n    print(\"=\" * 60)\n    print(\"  GENERATING ANIMATED GIF\")\n    print(\"=\" * 60)\n    \n    # Load data\n    with open(DATA_FILE) as f:\n        data = json.load(f)\n    \n    total_steps = len(data[\"animation_steps\"])\n    print(f\"\\n📊 Total frames: {total_steps}\")\n    \n    # Get the final forecast step for reference\n    final_forecast = data[\"animation_steps\"][-1]\n    \n    # Calculate fixed axis extents from ALL data\n    all_actual_dates = pd.to_datetime(data[\"actual_data\"][\"dates\"])\n    all_actual_values = np.array(data[\"actual_data\"][\"values\"])\n    \n    final_forecast_dates = pd.to_datetime(final_forecast[\"forecast_dates\"])\n    final_forecast_values = np.array(final_forecast[\"point_forecast\"])\n    \n    # X-axis: from first actual date to last forecast date\n    x_min = all_actual_dates[0]\n    x_max = final_forecast_dates[-1]\n    \n    # Y-axis: min/max across all actual + all forecasts with CIs\n    all_forecast_q10 = np.array(final_forecast[\"q10\"])\n    all_forecast_q90 = np.array(final_forecast[\"q90\"])\n    \n    all_values = np.concatenate([\n        all_actual_values,\n        final_forecast_values,\n        all_forecast_q10,\n        all_forecast_q90,\n    ])\n    y_min = all_values.min() - 0.05\n    y_max = all_values.max() + 0.05\n    \n    print(f\"   X-axis: {x_min.strftime('%Y-%m')} to {x_max.strftime('%Y-%m')}\")\n    print(f\"   Y-axis: {y_min:.2f}°C to {y_max:.2f}°C\")\n    \n    # Create figure\n    fig, ax = plt.subplots(figsize=(12, 6))\n    \n    # Generate frames\n    frames = []\n    \n    for i, step in enumerate(data[\"animation_steps\"]):\n        
print(f\"   Frame {i + 1}/{total_steps}...\")\n        \n        create_frame(\n            ax,\n            step,\n            data[\"actual_data\"],\n            final_forecast,\n            total_steps,\n            x_min,\n            x_max,\n            y_min,\n            y_max,\n        )\n        \n        # Save frame to buffer\n        fig.canvas.draw()\n        \n        # Convert to PIL Image\n        buf = fig.canvas.buffer_rgba()\n        width, height = fig.canvas.get_width_height()\n        img = Image.frombytes(\"RGBA\", (width, height), buf)\n        frames.append(img.convert(\"RGB\"))\n    \n    plt.close()\n    \n    # Save as GIF\n    print(f\"\\n💾 Saving GIF: {OUTPUT_FILE}\")\n    frames[0].save(\n        OUTPUT_FILE,\n        save_all=True,\n        append_images=frames[1:],\n        duration=DURATION_MS,\n        loop=0,  # Loop forever\n    )\n    \n    # Get file size\n    size_kb = OUTPUT_FILE.stat().st_size / 1024\n    print(f\"   File size: {size_kb:.1f} KB\")\n    print(f\"\\n✅ Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/generate_html.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nGenerate a self-contained HTML file with embedded animation data.\n\nThis creates a single HTML file that can be opened directly in any browser\nwithout needing a server or external JSON file (CORS-safe).\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\nEXAMPLE_DIR = Path(__file__).parent\nDATA_FILE = EXAMPLE_DIR / \"output\" / \"animation_data.json\"\nOUTPUT_FILE = EXAMPLE_DIR / \"output\" / \"interactive_forecast.html\"\n\n\nHTML_TEMPLATE = \"\"\"<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n    <meta charset=\"UTF-8\">\n    <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n    <title>TimesFM Interactive Forecast Animation</title>\n    <script src=\"https://cdn.jsdelivr.net/npm/chart.js\"></script>\n    <style>\n        * {{ margin: 0; padding: 0; box-sizing: border-box; }}\n        \n        body {{\n            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;\n            background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);\n            min-height: 100vh;\n            color: #e0e0e0;\n            padding: 20px;\n        }}\n        \n        .container {{ max-width: 1200px; margin: 0 auto; }}\n        \n        header {{ text-align: center; margin-bottom: 30px; }}\n        \n        h1 {{\n            font-size: 2rem;\n            margin-bottom: 10px;\n            background: linear-gradient(90deg, #60a5fa, #a78bfa);\n            -webkit-background-clip: text;\n            -webkit-text-fill-color: transparent;\n        }}\n        \n        .subtitle {{ color: #9ca3af; font-size: 1.1rem; }}\n        \n        .chart-container {{\n            background: rgba(255, 255, 255, 0.05);\n            border-radius: 16px;\n            padding: 20px;\n            margin-bottom: 20px;\n            box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3);\n        }}\n        \n        #chart {{ width: 100% !important; height: 450px 
!important; }}\n        \n        .controls {{\n            display: flex;\n            flex-direction: column;\n            gap: 20px;\n            background: rgba(255, 255, 255, 0.05);\n            border-radius: 16px;\n            padding: 20px;\n        }}\n        \n        .slider-container {{ display: flex; flex-direction: column; gap: 10px; }}\n        \n        .slider-label {{ display: flex; justify-content: space-between; align-items: center; }}\n        .slider-label span {{ font-size: 0.9rem; color: #9ca3af; }}\n        .slider-label .value {{ font-weight: 600; color: #60a5fa; font-size: 1.1rem; }}\n        \n        input[type=\"range\"] {{\n            width: 100%; height: 8px; border-radius: 4px;\n            background: #374151; outline: none; -webkit-appearance: none;\n        }}\n        \n        input[type=\"range\"]::-webkit-slider-thumb {{\n            -webkit-appearance: none;\n            width: 24px; height: 24px; border-radius: 50%;\n            background: linear-gradient(135deg, #60a5fa, #a78bfa);\n            cursor: pointer;\n            box-shadow: 0 2px 10px rgba(96, 165, 250, 0.5);\n        }}\n        \n        .buttons {{ display: flex; gap: 10px; flex-wrap: wrap; }}\n        \n        button {{\n            flex: 1; min-width: 100px;\n            padding: 12px 20px;\n            border: none; border-radius: 8px;\n            font-size: 1rem; font-weight: 600;\n            cursor: pointer; transition: all 0.2s ease;\n        }}\n        \n        .btn-primary {{\n            background: linear-gradient(135deg, #60a5fa, #a78bfa);\n            color: white;\n        }}\n        .btn-primary:hover {{ transform: translateY(-2px); box-shadow: 0 4px 15px rgba(96, 165, 250, 0.4); }}\n        \n        .btn-secondary {{ background: #374151; color: #e0e0e0; }}\n        .btn-secondary:hover {{ background: #4b5563; }}\n        \n        .stats {{\n            display: grid;\n            grid-template-columns: repeat(auto-fit, minmax(150px, 
1fr));\n            gap: 15px;\n            margin-top: 20px;\n        }}\n        \n        .stat-card {{\n            background: rgba(255, 255, 255, 0.05);\n            border-radius: 12px;\n            padding: 15px;\n            text-align: center;\n        }}\n        .stat-card .label {{ font-size: 0.8rem; color: #9ca3af; margin-bottom: 5px; }}\n        .stat-card .value {{ font-size: 1.3rem; font-weight: 600; color: #60a5fa; }}\n        \n        .legend {{\n            display: flex;\n            justify-content: center;\n            gap: 20px;\n            flex-wrap: wrap;\n            margin-top: 15px;\n            padding-top: 15px;\n            border-top: 1px solid rgba(255, 255, 255, 0.1);\n        }}\n        \n        .legend-item {{ display: flex; align-items: center; gap: 8px; font-size: 0.85rem; }}\n        .legend-color {{ width: 16px; height: 16px; border-radius: 4px; }}\n        \n        footer {{\n            text-align: center;\n            margin-top: 30px;\n            color: #6b7280;\n            font-size: 0.9rem;\n        }}\n        footer a {{ color: #60a5fa; text-decoration: none; }}\n    </style>\n</head>\n<body>\n    <div class=\"container\">\n        <header>\n            <h1>TimesFM Forecast Evolution</h1>\n            <p class=\"subtitle\">Watch the forecast evolve as more data is added — forecasts extend to 2025-12</p>\n        </header>\n        \n        <div class=\"chart-container\">\n            <canvas id=\"chart\"></canvas>\n        </div>\n        \n        <div class=\"controls\">\n            <div class=\"slider-container\">\n                <div class=\"slider-label\">\n                    <span>Data Points Used</span>\n                    <span class=\"value\" id=\"points-value\">12 / 36</span>\n                </div>\n                <input type=\"range\" id=\"slider\" min=\"0\" max=\"24\" value=\"0\" step=\"1\">\n                <div class=\"slider-label\">\n                    <span>2022-01</span>\n             
       <span id=\"date-end\">Using data through 2022-12</span>\n                </div>\n            </div>\n            \n            <div class=\"buttons\">\n                <button class=\"btn-primary\" id=\"play-btn\">▶ Play</button>\n                <button class=\"btn-secondary\" id=\"reset-btn\">↺ Reset</button>\n            </div>\n            \n            <div class=\"stats\">\n                <div class=\"stat-card\">\n                    <div class=\"label\">Forecast Mean</div>\n                    <div class=\"value\" id=\"stat-mean\">0.86°C</div>\n                </div>\n                <div class=\"stat-card\">\n                    <div class=\"label\">Forecast Horizon</div>\n                    <div class=\"value\" id=\"stat-horizon\">36 months</div>\n                </div>\n                <div class=\"stat-card\">\n                    <div class=\"label\">Forecast Max</div>\n                    <div class=\"value\" id=\"stat-max\">--</div>\n                </div>\n                <div class=\"stat-card\">\n                    <div class=\"label\">Forecast Min</div>\n                    <div class=\"value\" id=\"stat-min\">--</div>\n                </div>\n            </div>\n            \n            <div class=\"legend\">\n                <div class=\"legend-item\">\n                    <div class=\"legend-color\" style=\"background: #9ca3af;\"></div>\n                    <span>All Observed Data</span>\n                </div>\n                <div class=\"legend-item\">\n                    <div class=\"legend-color\" style=\"background: #fca5a5;\"></div>\n                    <span>Final Forecast (reference)</span>\n                </div>\n                <div class=\"legend-item\">\n                    <div class=\"legend-color\" style=\"background: #3b82f6;\"></div>\n                    <span>Data Used</span>\n                </div>\n                <div class=\"legend-item\">\n                    <div class=\"legend-color\" style=\"background: 
#ef4444;\"></div>\n                    <span>Current Forecast</span>\n                </div>\n                <div class=\"legend-item\">\n                    <div class=\"legend-color\" style=\"background: rgba(239, 68, 68, 0.25);\"></div>\n                    <span>80% CI</span>\n                </div>\n            </div>\n        </div>\n        \n        <footer>\n            <p>TimesFM 1.0 (200M) PyTorch • <a href=\"https://github.com/google-research/timesfm\">Google Research</a></p>\n        </footer>\n    </div>\n\n    <script>\n        // Embedded animation data (no external fetch needed)\n        const animationData = {data_json};\n        \n        let chart = null;\n        let isPlaying = false;\n        let playInterval = null;\n        let currentStep = 0;\n\n        // Fixed axis extents\n        let allDates = [];\n        let yMin = 0.7;\n        let yMax = 1.55;\n\n        function initChart() {{\n            const ctx = document.getElementById('chart').getContext('2d');\n            \n            // Calculate fixed extents\n            const finalStep = animationData.animation_steps[animationData.animation_steps.length - 1];\n            allDates = [\n                ...animationData.actual_data.dates,\n                ...finalStep.forecast_dates\n            ];\n            \n            // Y extent from all values\n            const allValues = [\n                ...animationData.actual_data.values,\n                ...finalStep.point_forecast,\n                ...finalStep.q10,\n                ...finalStep.q90\n            ];\n            yMin = Math.min(...allValues) - 0.05;\n            yMax = Math.max(...allValues) + 0.05;\n            \n            chart = new Chart(ctx, {{\n                type: 'line',\n                data: {{\n                    labels: allDates,\n                    datasets: [\n                        {{\n                            label: 'All Observed',\n                            data: 
animationData.actual_data.values.map((v, i) => ({{x: animationData.actual_data.dates[i], y: v}})),\n                            borderColor: '#9ca3af',\n                            borderWidth: 1,\n                            pointRadius: 2,\n                            pointBackgroundColor: '#9ca3af',\n                            fill: false,\n                            tension: 0.1,\n                            order: 1,\n                        }},\n                        {{\n                            label: 'Final Forecast',\n                            data: [...Array(animationData.actual_data.dates.length).fill(null), ...finalStep.point_forecast],\n                            borderColor: '#fca5a5',\n                            borderWidth: 1,\n                            borderDash: [4, 4],\n                            pointRadius: 2,\n                            pointBackgroundColor: '#fca5a5',\n                            fill: false,\n                            tension: 0.1,\n                            order: 2,\n                        }},\n                        {{\n                            label: 'Data Used',\n                            data: [],\n                            borderColor: '#3b82f6',\n                            backgroundColor: 'rgba(59, 130, 246, 0.1)',\n                            borderWidth: 2.5,\n                            pointRadius: 4,\n                            pointBackgroundColor: '#3b82f6',\n                            fill: false,\n                            tension: 0.1,\n                            order: 10,\n                        }},\n                        {{\n                            label: '90% CI Lower',\n                            data: [],\n                            borderColor: 'transparent',\n                            backgroundColor: 'rgba(239, 68, 68, 0.08)',\n                            fill: '+1',\n                            pointRadius: 0,\n                            tension: 
0.1,\n                            order: 5,\n                        }},\n                        {{\n                            label: '90% CI Upper',\n                            data: [],\n                            borderColor: 'transparent',\n                            backgroundColor: 'rgba(239, 68, 68, 0.08)',\n                            fill: false,\n                            pointRadius: 0,\n                            tension: 0.1,\n                            order: 5,\n                        }},\n                        {{\n                            label: '80% CI Lower',\n                            data: [],\n                            borderColor: 'transparent',\n                            backgroundColor: 'rgba(239, 68, 68, 0.2)',\n                            fill: '+1',\n                            pointRadius: 0,\n                            tension: 0.1,\n                            order: 6,\n                        }},\n                        {{\n                            label: '80% CI Upper',\n                            data: [],\n                            borderColor: 'transparent',\n                            backgroundColor: 'rgba(239, 68, 68, 0.2)',\n                            fill: false,\n                            pointRadius: 0,\n                            tension: 0.1,\n                            order: 6,\n                        }},\n                        {{\n                            label: 'Forecast',\n                            data: [],\n                            borderColor: '#ef4444',\n                            backgroundColor: 'rgba(239, 68, 68, 0.1)',\n                            borderWidth: 2.5,\n                            pointRadius: 4,\n                            pointBackgroundColor: '#ef4444',\n                            fill: false,\n                            tension: 0.1,\n                            order: 7,\n                        }},\n                    ]\n                
}},\n                options: {{\n                    responsive: true,\n                    maintainAspectRatio: false,\n                    interaction: {{ intersect: false, mode: 'index' }},\n                    plugins: {{\n                        legend: {{ display: false }},\n                        tooltip: {{\n                            backgroundColor: 'rgba(0, 0, 0, 0.8)',\n                            titleColor: '#fff',\n                            bodyColor: '#fff',\n                            padding: 12,\n                        }},\n                    }},\n                    scales: {{\n                        x: {{\n                            grid: {{ color: 'rgba(255, 255, 255, 0.05)' }},\n                            ticks: {{ color: '#9ca3af', maxRotation: 45, minRotation: 45 }},\n                        }},\n                        y: {{\n                            grid: {{ color: 'rgba(255, 255, 255, 0.05)' }},\n                            ticks: {{\n                                color: '#9ca3af',\n                                callback: v => v.toFixed(2) + '°C'\n                            }},\n                            min: yMin,\n                            max: yMax,\n                        }},\n                    }},\n                    animation: {{ duration: 150 }},\n                }},\n            }});\n        }}\n\n        function updateChart(stepIndex) {{\n            if (!animationData || !chart) return;\n            \n            const step = animationData.animation_steps[stepIndex];\n            const finalStep = animationData.animation_steps[animationData.animation_steps.length - 1];\n            const actual = animationData.actual_data;\n            \n            // Build data arrays for each dataset\n            const nHist = step.historical_dates.length;\n            const nForecast = step.forecast_dates.length;\n            const nActual = actual.dates.length;\n            const nFinalForecast = 
finalStep.forecast_dates.length;\n            const totalPoints = nActual + nFinalForecast;\n            \n            // Dataset 0: All observed (always full)\n            chart.data.datasets[0].data = actual.values.map((v, i) => ({{x: actual.dates[i], y: v}}));\n            \n            // Dataset 1: Final forecast reference (always full)\n            chart.data.datasets[1].data = [\n                ...Array(nActual).fill(null),\n                ...finalStep.point_forecast\n            ];\n            \n            // Dataset 2: Data used (historical only)\n            const dataUsed = [];\n            for (let i = 0; i < totalPoints; i++) {{\n                if (i < nHist) {{\n                    dataUsed.push(step.historical_values[i]);\n                }} else {{\n                    dataUsed.push(null);\n                }}\n            }}\n            chart.data.datasets[2].data = dataUsed;\n            \n            // Datasets 3-6: CIs (forecast only)\n            const forecastOffset = nActual;\n            const q90Lower = [];\n            const q90Upper = [];\n            const q80Lower = [];\n            const q80Upper = [];\n            \n            for (let i = 0; i < totalPoints; i++) {{\n                const forecastIdx = i - forecastOffset;\n                if (forecastIdx >= 0 && forecastIdx < nForecast) {{\n                    q90Lower.push(step.q10[forecastIdx]);\n                    q90Upper.push(step.q90[forecastIdx]);\n                    q80Lower.push(step.q20[forecastIdx]);\n                    q80Upper.push(step.q80[forecastIdx]);\n                }} else {{\n                    q90Lower.push(null);\n                    q90Upper.push(null);\n                    q80Lower.push(null);\n                    q80Upper.push(null);\n                }}\n            }}\n            chart.data.datasets[3].data = q90Lower;\n            chart.data.datasets[4].data = q90Upper;\n            chart.data.datasets[5].data = q80Lower;\n            
chart.data.datasets[6].data = q80Upper;\n            \n            // Dataset 7: Forecast line\n            const forecastData = [];\n            for (let i = 0; i < totalPoints; i++) {{\n                const forecastIdx = i - forecastOffset;\n                if (forecastIdx >= 0 && forecastIdx < nForecast) {{\n                    forecastData.push(step.point_forecast[forecastIdx]);\n                }} else {{\n                    forecastData.push(null);\n                }}\n            }}\n            chart.data.datasets[7].data = forecastData;\n            \n            chart.update('none');\n            \n            // Update UI\n            document.getElementById('slider').value = stepIndex;\n            document.getElementById('points-value').textContent = `${{step.n_points}} / 36`;\n            document.getElementById('date-end').textContent = `Using data through ${{step.last_historical_date}}`;\n            \n            // Stats\n            const mean = (step.point_forecast.reduce((a, b) => a + b, 0) / step.point_forecast.length).toFixed(3);\n            const max = Math.max(...step.point_forecast).toFixed(3);\n            const min = Math.min(...step.point_forecast).toFixed(3);\n            \n            document.getElementById('stat-mean').textContent = mean + '°C';\n            document.getElementById('stat-horizon').textContent = step.horizon + ' months';\n            document.getElementById('stat-max').textContent = max + '°C';\n            document.getElementById('stat-min').textContent = min + '°C';\n            \n            currentStep = stepIndex;\n        }}\n\n        document.getElementById('slider').addEventListener('input', e => {{\n            updateChart(parseInt(e.target.value));\n        }});\n\n        document.getElementById('play-btn').addEventListener('click', () => {{\n            const btn = document.getElementById('play-btn');\n            if (isPlaying) {{\n                clearInterval(playInterval);\n                
btn.textContent = '▶ Play';\n                isPlaying = false;\n            }} else {{\n                btn.textContent = '⏸ Pause';\n                isPlaying = true;\n                if (currentStep >= animationData.animation_steps.length - 1) currentStep = 0;\n                playInterval = setInterval(() => {{\n                    if (currentStep >= animationData.animation_steps.length - 1) {{\n                        clearInterval(playInterval);\n                        document.getElementById('play-btn').textContent = '▶ Play';\n                        isPlaying = false;\n                    }} else {{\n                        currentStep++;\n                        updateChart(currentStep);\n                    }}\n                }}, 400);\n            }}\n        }});\n\n        document.getElementById('reset-btn').addEventListener('click', () => {{\n            if (isPlaying) {{\n                clearInterval(playInterval);\n                document.getElementById('play-btn').textContent = '▶ Play';\n                isPlaying = false;\n            }}\n            updateChart(0);\n        }});\n\n        // Initialize on load\n        initChart();\n        updateChart(0);\n    </script>\n</body>\n</html>\n\"\"\"\n\n\ndef main() -> None:\n    print(\"=\" * 60)\n    print(\"  GENERATING SELF-CONTAINED HTML\")\n    print(\"=\" * 60)\n\n    # Load animation data\n    with open(DATA_FILE) as f:\n        data = json.load(f)\n\n    # Generate HTML with embedded data\n    html_content = HTML_TEMPLATE.format(data_json=json.dumps(data, indent=2))\n\n    # Write output\n    with open(OUTPUT_FILE, \"w\") as f:\n        f.write(html_content)\n\n    size_kb = OUTPUT_FILE.stat().st_size / 1024\n    print(f\"\\n✅ Generated: {OUTPUT_FILE}\")\n    print(f\"   File size: {size_kb:.1f} KB\")\n    print(f\"   Fully self-contained — no external dependencies\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/output/animation_data.json",
    "content": "{\n  \"metadata\": {\n    \"model\": \"TimesFM 1.0 (200M) PyTorch\",\n    \"total_steps\": 25,\n    \"min_context\": 12,\n    \"max_horizon\": 36,\n    \"total_months\": 48,\n    \"data_source\": \"NOAA GISTEMP Global Temperature Anomaly\",\n    \"full_date_range\": \"2022-01 to 2024-12\"\n  },\n  \"actual_data\": {\n    \"dates\": [\n      \"2022-01\",\n      \"2022-02\",\n      \"2022-03\",\n      \"2022-04\",\n      \"2022-05\",\n      \"2022-06\",\n      \"2022-07\",\n      \"2022-08\",\n      \"2022-09\",\n      \"2022-10\",\n      \"2022-11\",\n      \"2022-12\",\n      \"2023-01\",\n      \"2023-02\",\n      \"2023-03\",\n      \"2023-04\",\n      \"2023-05\",\n      \"2023-06\",\n      \"2023-07\",\n      \"2023-08\",\n      \"2023-09\",\n      \"2023-10\",\n      \"2023-11\",\n      \"2023-12\",\n      \"2024-01\",\n      \"2024-02\",\n      \"2024-03\",\n      \"2024-04\",\n      \"2024-05\",\n      \"2024-06\",\n      \"2024-07\",\n      \"2024-08\",\n      \"2024-09\",\n      \"2024-10\",\n      \"2024-11\",\n      \"2024-12\"\n    ],\n    \"values\": [\n      0.8899999856948853,\n      0.8899999856948853,\n      1.0199999809265137,\n      0.8799999952316284,\n      0.8500000238418579,\n      0.8799999952316284,\n      0.8799999952316284,\n      0.8999999761581421,\n      0.8799999952316284,\n      0.949999988079071,\n      0.7699999809265137,\n      0.7799999713897705,\n      0.8700000047683716,\n      0.9800000190734863,\n      1.2100000381469727,\n      1.0,\n      0.9399999976158142,\n      1.0800000429153442,\n      1.1799999475479126,\n      1.2400000095367432,\n      1.4700000286102295,\n      1.3200000524520874,\n      1.1799999475479126,\n      1.159999966621399,\n      1.2200000286102295,\n      1.350000023841858,\n      1.340000033378601,\n      1.2599999904632568,\n      1.149999976158142,\n      1.2000000476837158,\n      1.2400000095367432,\n      1.2999999523162842,\n      1.2799999713897705,\n      1.2699999809265137,\n   
   1.2200000286102295,\n      1.2000000476837158\n    ]\n  },\n  \"animation_steps\": [\n    {\n      \"step\": 1,\n      \"n_points\": 12,\n      \"horizon\": 36,\n      \"last_historical_date\": \"2022-12\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705\n      ],\n      \"forecast_dates\": [\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        0.825579047203064,\n        0.8330779075622559,\n        0.8368334174156189,\n        0.8413563370704651,\n        0.8546873331069946,\n        0.8463932275772095,\n        0.852830708026886,\n        0.8635484576225281,\n        
0.873649001121521,\n        0.8784391283988953,\n        0.8793435096740723,\n        0.886539101600647,\n        0.876642107963562,\n        0.8771936297416687,\n        0.8794507384300232,\n        0.8818798065185547,\n        0.8801761269569397,\n        0.878594696521759,\n        0.8841555714607239,\n        0.8686957955360413,\n        0.8627567887306213,\n        0.8599377870559692,\n        0.8534176349639893,\n        0.8439264297485352,\n        0.8403507471084595,\n        0.84540855884552,\n        0.8334686756134033,\n        0.8366615176200867,\n        0.8480817079544067,\n        0.8587210178375244,\n        0.865203857421875,\n        0.8715710043907166,\n        0.883372962474823,\n        0.8742744326591492,\n        0.8734725117683411,\n        0.8783032894134521\n      ],\n      \"q10\": [\n        0.8354606032371521,\n        0.8444467782974243,\n        0.8485234975814819,\n        0.8526979088783264,\n        0.8648908138275146,\n        0.8568621277809143,\n        0.863645076751709,\n        0.872414231300354,\n        0.8817781209945679,\n        0.8863298892974854,\n        0.8866963982582092,\n        0.8946276903152466,\n        0.8833872675895691,\n        0.8827563524246216,\n        0.8864266872406006,\n        0.887717604637146,\n        0.8854249715805054,\n        0.8838265538215637,\n        0.890777051448822,\n        0.8747947812080383,\n        0.8702181577682495,\n        0.8688124418258667,\n        0.8621772527694702,\n        0.8549044728279114,\n        0.8520718812942505,\n        0.8580353856086731,\n        0.8461477756500244,\n        0.8497025966644287,\n        0.8604429364204407,\n        0.8707754015922546,\n        0.8765125870704651,\n        0.8818733096122742,\n        0.893653154373169,\n        0.8849858045578003,\n        0.8816121220588684,\n        0.8867135643959045\n      ],\n      \"q20\": [\n        0.7518579959869385,\n        0.752423882484436,\n        0.7527720928192139,\n        
0.7547875642776489,\n        0.7639567852020264,\n        0.7600989937782288,\n        0.7671870589256287,\n        0.7746827006340027,\n        0.783061146736145,\n        0.7859532237052917,\n        0.7876774072647095,\n        0.7946517467498779,\n        0.7890393137931824,\n        0.7905672192573547,\n        0.7923871874809265,\n        0.7943510413169861,\n        0.7928767204284668,\n        0.7914355993270874,\n        0.7945701479911804,\n        0.784331738948822,\n        0.7799307107925415,\n        0.7775163650512695,\n        0.772225022315979,\n        0.7648971676826477,\n        0.7586244940757751,\n        0.7592141032218933,\n        0.7497149705886841,\n        0.7515254020690918,\n        0.76014643907547,\n        0.7683113813400269,\n        0.7757765054702759,\n        0.7805572748184204,\n        0.790294349193573,\n        0.7851614952087402,\n        0.7844950556755066,\n        0.7886985540390015\n      ],\n      \"q80\": [\n        0.8621454238891602,\n        0.8726990222930908,\n        0.8780758380889893,\n        0.8830247521400452,\n        0.895999014377594,\n        0.8877173066139221,\n        0.8932443261146545,\n        0.9029491543769836,\n        0.9142329096794128,\n        0.918304979801178,\n        0.9192531704902649,\n        0.9270545244216919,\n        0.9149025082588196,\n        0.9147888422012329,\n        0.91729736328125,\n        0.9190108776092529,\n        0.9174938201904297,\n        0.916400671005249,\n        0.9234370589256287,\n        0.9071342349052429,\n        0.9007507562637329,\n        0.8995751142501831,\n        0.8921940326690674,\n        0.8833961486816406,\n        0.8816472291946411,\n        0.8888989686965942,\n        0.8762903809547424,\n        0.8794605731964111,\n        0.891765832901001,\n        0.9021292328834534,\n        0.9087244868278503,\n        0.9149095416069031,\n        0.9275970458984375,\n        0.9168868660926819,\n        0.9142359495162964,\n        
0.9194778800010681\n      ],\n      \"q90\": [\n        0.8872727155685425,\n        0.8990722298622131,\n        0.9044539928436279,\n        0.9107659459114075,\n        0.9254093170166016,\n        0.9146999716758728,\n        0.9196149706840515,\n        0.9299551844596863,\n        0.941527783870697,\n        0.9455176591873169,\n        0.9463357925415039,\n        0.9539710283279419,\n        0.9405434727668762,\n        0.9397023320198059,\n        0.9439040422439575,\n        0.9448938369750977,\n        0.9431376457214355,\n        0.9417189359664917,\n        0.9492916464805603,\n        0.9315186738967896,\n        0.9267769455909729,\n        0.925445020198822,\n        0.9191145300865173,\n        0.910182535648346,\n        0.9100216031074524,\n        0.9180203676223755,\n        0.9048261046409607,\n        0.9081428050994873,\n        0.9206303954124451,\n        0.9308969974517822,\n        0.9380975961685181,\n        0.9430014491081238,\n        0.9572127461433411,\n        0.9447380304336548,\n        0.9412767291069031,\n        0.9464495778083801\n      ]\n    },\n    {\n      \"step\": 2,\n      \"n_points\": 13,\n      \"horizon\": 35,\n      \"last_historical_date\": \"2023-01\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716\n      ],\n      \"forecast_dates\": [\n 
       \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        0.8590402007102966,\n        0.8596092462539673,\n        0.864223062992096,\n        0.8694167733192444,\n        0.8599939346313477,\n        0.8577529191970825,\n        0.8670657873153687,\n        0.8746083378791809,\n        0.8758000731468201,\n        0.8808236718177795,\n        0.8853851556777954,\n        0.8753982186317444,\n        0.8732624053955078,\n        0.8803924322128296,\n        0.8831377029418945,\n        0.8812252879142761,\n        0.8837805986404419,\n        0.8842109441757202,\n        0.8692948818206787,\n        0.8612740635871887,\n        0.8624085783958435,\n        0.8617072105407715,\n        0.8601858615875244,\n        0.8625096082687378,\n        0.8663285374641418,\n        0.8544762134552002,\n        0.8533855080604553,\n        0.862159013748169,\n        0.8707855343818665,\n        0.872623860836029,\n        0.878368079662323,\n        0.8822183012962341,\n        0.8722400665283203,\n        0.8674668669700623,\n        0.8758878111839294\n      ],\n      \"q10\": [\n        0.8657022714614868,\n        0.867158055305481,\n        0.8720226287841797,\n        0.8764638900756836,\n        0.8662244081497192,\n        
0.8640622496604919,\n        0.873618483543396,\n        0.8803330063819885,\n        0.8822183609008789,\n        0.8867899775505066,\n        0.8920900821685791,\n        0.8817423582077026,\n        0.8790065050125122,\n        0.8854852914810181,\n        0.8888370394706726,\n        0.8871243596076965,\n        0.8896916508674622,\n        0.8902166485786438,\n        0.8758934736251831,\n        0.8675172924995422,\n        0.8692970871925354,\n        0.8685914874076843,\n        0.8668439388275146,\n        0.8710702061653137,\n        0.8750268220901489,\n        0.8633314967155457,\n        0.8620151281356812,\n        0.8703252077102661,\n        0.8786934614181519,\n        0.8804004192352295,\n        0.8853165507316589,\n        0.889494776725769,\n        0.8794597387313843,\n        0.8745465278625488,\n        0.8814859390258789\n      ],\n      \"q20\": [\n        0.779899537563324,\n        0.7763701677322388,\n        0.7775852680206299,\n        0.7800794839859009,\n        0.7750610113143921,\n        0.7753159403800964,\n        0.7829091548919678,\n        0.7884992957115173,\n        0.7900261878967285,\n        0.7911601066589355,\n        0.7951517105102539,\n        0.7891175746917725,\n        0.7887728810310364,\n        0.7934086918830872,\n        0.7968956232070923,\n        0.7951973080635071,\n        0.796229898929596,\n        0.7950001358985901,\n        0.7845399379730225,\n        0.7791075110435486,\n        0.7789998650550842,\n        0.7794902324676514,\n        0.7773360013961792,\n        0.7764586806297302,\n        0.7767698168754578,\n        0.7689880132675171,\n        0.7689797282218933,\n        0.7759402394294739,\n        0.7828512787818909,\n        0.7850325107574463,\n        0.7882039546966553,\n        0.7904639840126038,\n        0.7844158411026001,\n        0.7818136215209961,\n        0.7875857353210449\n      ],\n      \"q80\": [\n        0.8950973153114319,\n        0.8978567719459534,\n        
0.9036805033683777,\n        0.9098731875419617,\n        0.8973860144615173,\n        0.8958126306533813,\n        0.9049636125564575,\n        0.9123932123184204,\n        0.9138861298561096,\n        0.9191209077835083,\n        0.9256614446640015,\n        0.9137347936630249,\n        0.9109636545181274,\n        0.9174929857254028,\n        0.9215986728668213,\n        0.9189587831497192,\n        0.9224711060523987,\n        0.9235640168190002,\n        0.9081242084503174,\n        0.8990890979766846,\n        0.900691568851471,\n        0.9007959961891174,\n        0.8983866572380066,\n        0.9030368328094482,\n        0.9082856178283691,\n        0.8958720564842224,\n        0.8932167291641235,\n        0.9023438692092896,\n        0.9115447998046875,\n        0.9133612513542175,\n        0.9190444350242615,\n        0.9236005544662476,\n        0.9117952585220337,\n        0.906220018863678,\n        0.914079487323761\n      ],\n      \"q90\": [\n        0.9195939302444458,\n        0.9236188530921936,\n        0.9301517605781555,\n        0.9359439611434937,\n        0.9242846369743347,\n        0.9196143746376038,\n        0.9301571846008301,\n        0.9382931590080261,\n        0.9394593238830566,\n        0.9451783895492554,\n        0.9518223404884338,\n        0.9389423131942749,\n        0.9352357387542725,\n        0.9424091577529907,\n        0.947126030921936,\n        0.9439764618873596,\n        0.9481194019317627,\n        0.9504281878471375,\n        0.9335556030273438,\n        0.9240644574165344,\n        0.9264681935310364,\n        0.9259119629859924,\n        0.9245560765266418,\n        0.9293811321258545,\n        0.9364281296730042,\n        0.9225189685821533,\n        0.9183617234230042,\n        0.9289659261703491,\n        0.937990665435791,\n        0.9396582245826721,\n        0.9460575580596924,\n        0.9509962797164917,\n        0.9378201961517334,\n        0.9311509132385254,\n        0.9398520588874817\n      ]\n    
},\n    {\n      \"step\": 3,\n      \"n_points\": 14,\n      \"horizon\": 34,\n      \"last_historical_date\": \"2023-02\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863\n      ],\n      \"forecast_dates\": [\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        0.8962793350219727,\n        0.8913998007774353,\n        0.8914807438850403,\n        0.871181845664978,\n        0.8662641644477844,\n        0.8797636032104492,\n        0.8862841129302979,\n        0.884779691696167,\n        0.8836072087287903,\n      
  0.8898857235908508,\n        0.8741991519927979,\n        0.8697925806045532,\n        0.8814526796340942,\n        0.8840450048446655,\n        0.8814879655838013,\n        0.8813571333885193,\n        0.8835927248001099,\n        0.8649601936340332,\n        0.8594167828559875,\n        0.8685873746871948,\n        0.872805118560791,\n        0.8739079236984253,\n        0.8808366060256958,\n        0.8895877003669739,\n        0.8769407868385315,\n        0.8714866638183594,\n        0.8808306455612183,\n        0.888067364692688,\n        0.8873578906059265,\n        0.8892648816108704,\n        0.8923593759536743,\n        0.8761922717094421,\n        0.8705070614814758,\n        0.8820964694023132\n      ],\n      \"q10\": [\n        0.9006780982017517,\n        0.8960930705070496,\n        0.8975709676742554,\n        0.8764383792877197,\n        0.8719356060028076,\n        0.8863880038261414,\n        0.8936481475830078,\n        0.891782283782959,\n        0.8906540274620056,\n        0.8970102667808533,\n        0.8820476531982422,\n        0.8772810101509094,\n        0.889976978302002,\n        0.8918938636779785,\n        0.8886879086494446,\n        0.8894075751304626,\n        0.8912825584411621,\n        0.8730634450912476,\n        0.8673158288002014,\n        0.8772640824317932,\n        0.8791468739509583,\n        0.8799763321876526,\n        0.8868378400802612,\n        0.8973256349563599,\n        0.883881151676178,\n        0.879287600517273,\n        0.8892991542816162,\n        0.8954638242721558,\n        0.8954599499702454,\n        0.8977177739143372,\n        0.9008411765098572,\n        0.8844205737113953,\n        0.8789454102516174,\n        0.8901882767677307\n      ],\n      \"q20\": [\n        0.8080285787582397,\n        0.8004014492034912,\n        0.7992052435874939,\n        0.7845293879508972,\n        0.7833878993988037,\n        0.7934101819992065,\n        0.798040509223938,\n        0.7972208261489868,\n        
0.7961648106575012,\n        0.7998728156089783,\n        0.789516031742096,\n        0.785558819770813,\n        0.794472336769104,\n        0.7951850295066833,\n        0.7945684194564819,\n        0.794198215007782,\n        0.7945625185966492,\n        0.7808390855789185,\n        0.7763155698776245,\n        0.7829429507255554,\n        0.7852435111999512,\n        0.7865880727767944,\n        0.7909019589424133,\n        0.7960636615753174,\n        0.7863008379936218,\n        0.7832475304603577,\n        0.7900716066360474,\n        0.7962746620178223,\n        0.7965481281280518,\n        0.7976964116096497,\n        0.7985848188400269,\n        0.7879433631896973,\n        0.7850476503372192,\n        0.7922680377960205\n      ],\n      \"q80\": [\n        0.9340344071388245,\n        0.9310296177864075,\n        0.931887149810791,\n        0.9107009768486023,\n        0.9042311310768127,\n        0.9196222424507141,\n        0.9265503287315369,\n        0.9255625605583191,\n        0.9238306283950806,\n        0.9304555058479309,\n        0.913487434387207,\n        0.9083813428878784,\n        0.9220874309539795,\n        0.9244784116744995,\n        0.9214062094688416,\n        0.9219330549240112,\n        0.9250167608261108,\n        0.9045271873474121,\n        0.8984488248825073,\n        0.9084285497665405,\n        0.9120396375656128,\n        0.9134330153465271,\n        0.920710563659668,\n        0.9313111305236816,\n        0.9171351194381714,\n        0.9125726222991943,\n        0.922325611114502,\n        0.9292736649513245,\n        0.9300060272216797,\n        0.932316243648529,\n        0.9348157644271851,\n        0.9165349006652832,\n        0.9105325937271118,\n        0.9230691194534302\n      ],\n      \"q90\": [\n        0.9600221514701843,\n        0.9573583006858826,\n        0.9588406682014465,\n        0.9357264041900635,\n        0.9300737380981445,\n        0.9452965259552002,\n        0.953380823135376,\n        
0.9521129727363586,\n        0.9504246711730957,\n        0.9578516483306885,\n        0.9395800828933716,\n        0.9347273707389832,\n        0.9480591416358948,\n        0.950930118560791,\n        0.948790431022644,\n        0.94916832447052,\n        0.9522303342819214,\n        0.9315612316131592,\n        0.9246772527694702,\n        0.9351183772087097,\n        0.9386969208717346,\n        0.9390504956245422,\n        0.9479607939720154,\n        0.9585453867912292,\n        0.9437541961669922,\n        0.9387108683586121,\n        0.9494839906692505,\n        0.9573196172714233,\n        0.9568711519241333,\n        0.9595789909362793,\n        0.9637172222137451,\n        0.9441839456558228,\n        0.936747670173645,\n        0.9499791264533997\n      ]\n    },\n    {\n      \"step\": 4,\n      \"n_points\": 15,\n      \"horizon\": 33,\n      \"last_historical_date\": \"2023-03\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727\n      ],\n      \"forecast_dates\": [\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        
\"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.011451005935669,\n        0.9553948640823364,\n        0.9197208285331726,\n        0.9124891757965088,\n        0.9261340498924255,\n        0.9234520792961121,\n        0.9108935594558716,\n        0.8969470858573914,\n        0.8980726599693298,\n        0.8982804417610168,\n        0.8991943001747131,\n        0.9119693636894226,\n        0.9100792407989502,\n        0.9019815921783447,\n        0.8973109126091003,\n        0.8946781158447266,\n        0.8884148001670837,\n        0.8810747861862183,\n        0.8763440251350403,\n        0.8705035448074341,\n        0.8778358101844788,\n        0.8958552479743958,\n        0.9278874397277832,\n        0.9475082755088806,\n        0.9399139285087585,\n        0.9295593500137329,\n        0.9194858074188232,\n        0.916989803314209,\n        0.9152628779411316,\n        0.9101430773735046,\n        0.8927386999130249,\n        0.8823466897010803,\n        0.8857365250587463\n      ],\n      \"q10\": [\n        1.028891921043396,\n        0.9745897650718689,\n        0.9376441240310669,\n        0.9297030568122864,\n        0.9439254403114319,\n        0.943497896194458,\n        0.9286640286445618,\n        0.9142505526542664,\n        0.9157885313034058,\n        0.9157061576843262,\n        0.9165257215499878,\n        0.929168164730072,\n        0.9264547228813171,\n        0.9190627932548523,\n        0.9123958945274353,\n        
0.9115281105041504,\n        0.9037967324256897,\n        0.8992751836776733,\n        0.8952363133430481,\n        0.8902027010917664,\n        0.8936614990234375,\n        0.910301148891449,\n        0.9421884417533875,\n        0.9664905667304993,\n        0.957619309425354,\n        0.9471821784973145,\n        0.9369155168533325,\n        0.9328755736351013,\n        0.9314517974853516,\n        0.9264087677001953,\n        0.9108965992927551,\n        0.9000225067138672,\n        0.9029441475868225\n      ],\n      \"q20\": [\n        0.8432373404502869,\n        0.8032699823379517,\n        0.7799109220504761,\n        0.7799201011657715,\n        0.7939504981040955,\n        0.7942459583282471,\n        0.7866204380989075,\n        0.7787443399429321,\n        0.7860440611839294,\n        0.7884118556976318,\n        0.7909562587738037,\n        0.7990366220474243,\n        0.7990424633026123,\n        0.7951732277870178,\n        0.7943146228790283,\n        0.7914892435073853,\n        0.786389946937561,\n        0.7805740237236023,\n        0.7728126049041748,\n        0.7663388848304749,\n        0.767531156539917,\n        0.7775982618331909,\n        0.7965872287750244,\n        0.8098679184913635,\n        0.8040605187416077,\n        0.7990914583206177,\n        0.7943341135978699,\n        0.795067548751831,\n        0.7930296659469604,\n        0.7909825444221497,\n        0.7814936637878418,\n        0.7742173671722412,\n        0.7788263559341431\n      ],\n      \"q80\": [\n        1.0893518924713135,\n        1.031952142715454,\n        0.9909453392028809,\n        0.9802313446998596,\n        0.9924889802932739,\n        0.9901573657989502,\n        0.973213791847229,\n        0.9567193984985352,\n        0.9561106562614441,\n        0.9526670575141907,\n        0.9554384350776672,\n        0.966469407081604,\n        0.9650457501411438,\n        0.9547586441040039,\n        0.9497334957122803,\n        0.9472479820251465,\n        
0.9417811632156372,\n        0.9347074627876282,\n        0.9311444163322449,\n        0.925645649433136,\n        0.9340237975120544,\n        0.9546427726745605,\n        0.9898675680160522,\n        1.0140517950057983,\n        1.006885290145874,\n        0.9937493205070496,\n        0.9815763235092163,\n        0.9766898155212402,\n        0.9745802879333496,\n        0.9689580202102661,\n        0.9494245052337646,\n        0.9369281530380249,\n        0.940288782119751\n      ],\n      \"q90\": [\n        1.143047571182251,\n        1.0867642164230347,\n        1.0392613410949707,\n        1.0258489847183228,\n        1.0397703647613525,\n        1.035668134689331,\n        1.0181812047958374,\n        0.9991654753684998,\n        0.9964229464530945,\n        0.9952237606048584,\n        0.994753360748291,\n        1.0074013471603394,\n        1.0027097463607788,\n        0.9933873414993286,\n        0.9889267086982727,\n        0.9854975342750549,\n        0.9785516262054443,\n        0.9728615880012512,\n        0.9702323079109192,\n        0.9645059108734131,\n        0.9732341766357422,\n        0.9938783049583435,\n        1.0329622030258179,\n        1.060141921043396,\n        1.0525397062301636,\n        1.0378689765930176,\n        1.0230897665023804,\n        1.018609642982483,\n        1.0162283182144165,\n        1.0081523656845093,\n        0.9886332750320435,\n        0.9734073877334595,\n        0.9774399399757385\n      ]\n    },\n    {\n      \"step\": 5,\n      \"n_points\": 16,\n      \"horizon\": 32,\n      \"last_historical_date\": \"2023-04\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\"\n      ],\n      
\"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0\n      ],\n      \"forecast_dates\": [\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        0.9379441142082214,\n        0.9161815047264099,\n        0.9183650612831116,\n        0.9345710277557373,\n        0.9429481625556946,\n        0.9236418008804321,\n        0.9020940065383911,\n        0.8962475657463074,\n        0.8969618678092957,\n        0.9029411673545837,\n        0.9058347344398499,\n        0.9071778059005737,\n        0.9064934849739075,\n        0.9002208113670349,\n        0.8948965668678284,\n        0.8888558745384216,\n        0.885951042175293,\n        0.8833035230636597,\n        0.8850363492965698,\n        0.8896763324737549,\n        0.9047040939331055,\n        0.9251466989517212,\n        0.9383421540260315,\n        0.9336385726928711,\n        0.9287689328193665,\n        
0.9275407791137695,\n        0.9268409609794617,\n        0.924099326133728,\n        0.9169213771820068,\n        0.9030519127845764,\n        0.8919728398323059,\n        0.8939611315727234\n      ],\n      \"q10\": [\n        0.9455586075782776,\n        0.9275433421134949,\n        0.9313569068908691,\n        0.9499651789665222,\n        0.957696259021759,\n        0.9388371706008911,\n        0.9148422479629517,\n        0.9104428887367249,\n        0.9122737646102905,\n        0.9160297513008118,\n        0.9193358421325684,\n        0.9216225147247314,\n        0.9201593399047852,\n        0.9155508875846863,\n        0.9093347191810608,\n        0.9044749736785889,\n        0.8999581336975098,\n        0.8994951248168945,\n        0.9004791378974915,\n        0.9077976942062378,\n        0.9192850589752197,\n        0.9383060336112976,\n        0.9530308842658997,\n        0.9488463401794434,\n        0.9426198601722717,\n        0.9435754418373108,\n        0.9431970119476318,\n        0.9382244944572449,\n        0.9305117726325989,\n        0.9167183041572571,\n        0.9076744914054871,\n        0.9097439646720886\n      ],\n      \"q20\": [\n        0.8105636239051819,\n        0.7875122427940369,\n        0.787703812122345,\n        0.8008798360824585,\n        0.8086710572242737,\n        0.7946160435676575,\n        0.7819311022758484,\n        0.7810927629470825,\n        0.7885390520095825,\n        0.7923018336296082,\n        0.7944296002388,\n        0.793520987033844,\n        0.7936148643493652,\n        0.7905219793319702,\n        0.7880567312240601,\n        0.7844575643539429,\n        0.7792351245880127,\n        0.7751155495643616,\n        0.7713013887405396,\n        0.7743531465530396,\n        0.7803812026977539,\n        0.7938993573188782,\n        0.8021929860115051,\n        0.7987417578697205,\n        0.794520914554596,\n        0.7944797277450562,\n        0.7938265800476074,\n        0.7947475910186768,\n        
0.7923287153244019,\n        0.785821259021759,\n        0.7809209823608398,\n        0.7844333648681641\n      ],\n      \"q80\": [\n        0.9937812685966492,\n        0.9760434627532959,\n        0.9809014797210693,\n        0.9971702098846436,\n        1.0051108598709106,\n        0.985238790512085,\n        0.9596951007843018,\n        0.9502063989639282,\n        0.9515751004219055,\n        0.9542210102081299,\n        0.9595392346382141,\n        0.9599698185920715,\n        0.9596587419509888,\n        0.9517510533332825,\n        0.9467341303825378,\n        0.9418620467185974,\n        0.9391661882400513,\n        0.9384753108024597,\n        0.940481960773468,\n        0.9475308656692505,\n        0.963818371295929,\n        0.9858653545379639,\n        1.0016189813613892,\n        0.9964566826820374,\n        0.9913219213485718,\n        0.9908701181411743,\n        0.9896549582481384,\n        0.9836863279342651,\n        0.9743705987930298,\n        0.9582211375236511,\n        0.9449355006217957,\n        0.94720059633255\n      ],\n      \"q90\": [\n        1.0336796045303345,\n        1.0175514221191406,\n        1.021440029144287,\n        1.0401356220245361,\n        1.0489550828933716,\n        1.0270309448242188,\n        0.9989587068557739,\n        0.9885305166244507,\n        0.9877901077270508,\n        0.9937816262245178,\n        0.996868908405304,\n        0.9987958073616028,\n        0.9956378936767578,\n        0.9891375303268433,\n        0.9845867156982422,\n        0.979006290435791,\n        0.9757927656173706,\n        0.9753840565681458,\n        0.9795432090759277,\n        0.9870526194572449,\n        1.0044395923614502,\n        1.0267916917800903,\n        1.0432230234146118,\n        1.0385234355926514,\n        1.0341284275054932,\n        1.0333774089813232,\n        1.0310395956039429,\n        1.025346040725708,\n        1.014280080795288,\n        0.9950195550918579,\n        0.9828959703445435,\n        
0.9817364811897278\n      ]\n    },\n    {\n      \"step\": 6,\n      \"n_points\": 17,\n      \"horizon\": 31,\n      \"last_historical_date\": \"2023-05\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142\n      ],\n      \"forecast_dates\": [\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        0.9097275137901306,\n        0.9010418057441711,\n        0.9079869985580444,\n        0.9222638010978699,\n        0.932843029499054,\n        
0.9133341312408447,\n        0.8972155451774597,\n        0.8887625336647034,\n        0.8941851854324341,\n        0.9068790674209595,\n        0.9091910123825073,\n        0.9068935513496399,\n        0.8990182876586914,\n        0.8986428380012512,\n        0.8881825804710388,\n        0.8843041658401489,\n        0.888336718082428,\n        0.8892695307731628,\n        0.8974661231040955,\n        0.9044860601425171,\n        0.9227194786071777,\n        0.9294296503067017,\n        0.9252649545669556,\n        0.9205634593963623,\n        0.9196065664291382,\n        0.9199687242507935,\n        0.9132981300354004,\n        0.9133179187774658,\n        0.9007443785667419,\n        0.8912027478218079,\n        0.8934641480445862\n      ],\n      \"q10\": [\n        0.9192558526992798,\n        0.9128602147102356,\n        0.9227687120437622,\n        0.9362373352050781,\n        0.9478849172592163,\n        0.9271639585494995,\n        0.910339891910553,\n        0.9013872146606445,\n        0.908535897731781,\n        0.9196968078613281,\n        0.9216489791870117,\n        0.9205824136734009,\n        0.9120896458625793,\n        0.9124637842178345,\n        0.9021389484405518,\n        0.8997719883918762,\n        0.9026364684104919,\n        0.9033412933349609,\n        0.9109377264976501,\n        0.9189012050628662,\n        0.9366557598114014,\n        0.9421946406364441,\n        0.937626302242279,\n        0.9345484972000122,\n        0.9316884875297546,\n        0.9340106844902039,\n        0.9270667433738708,\n        0.9266247749328613,\n        0.9148653745651245,\n        0.9044336676597595,\n        0.9073527455329895\n      ],\n      \"q20\": [\n        0.7991487383842468,\n        0.7880749702453613,\n        0.7902460098266602,\n        0.8014485239982605,\n        0.8115598559379578,\n        0.7963781952857971,\n        0.7883695960044861,\n        0.7836517691612244,\n        0.7910313606262207,\n        0.799010694026947,\n        
0.8031657934188843,\n        0.8004167675971985,\n        0.7960184216499329,\n        0.7969078421592712,\n        0.7900155782699585,\n        0.7853973507881165,\n        0.7849644422531128,\n        0.7844982743263245,\n        0.7866605520248413,\n        0.7920172810554504,\n        0.8011935353279114,\n        0.8064550161361694,\n        0.8041524887084961,\n        0.8006000518798828,\n        0.7974086403846741,\n        0.7984392046928406,\n        0.7938262224197388,\n        0.7966775298118591,\n        0.7895344495773315,\n        0.7830621004104614,\n        0.7873432636260986\n      ],\n      \"q80\": [\n        0.9585660099983215,\n        0.9542173743247986,\n        0.9642703533172607,\n        0.9804073572158813,\n        0.9885033965110779,\n        0.9688029289245605,\n        0.949183464050293,\n        0.9374165534973145,\n        0.9444000124931335,\n        0.9574207663536072,\n        0.9588959217071533,\n        0.9561213254928589,\n        0.9485365748405457,\n        0.9463241100311279,\n        0.9353682994842529,\n        0.934599757194519,\n        0.9394335746765137,\n        0.9425153136253357,\n        0.9504368901252747,\n        0.9591487050056458,\n        0.9809996485710144,\n        0.986733615398407,\n        0.982063353061676,\n        0.9771464467048645,\n        0.9761553406715393,\n        0.977692723274231,\n        0.9702091813087463,\n        0.9681852459907532,\n        0.9539398550987244,\n        0.942665696144104,\n        0.9438384771347046\n      ],\n      \"q90\": [\n        0.994154691696167,\n        0.9911658763885498,\n        1.0009171962738037,\n        1.0182007551193237,\n        1.0296927690505981,\n        1.0062158107757568,\n        0.985028862953186,\n        0.9721169471740723,\n        0.9787886142730713,\n        0.9931607246398926,\n        0.9947684407234192,\n        0.9917771220207214,\n        0.9817482233047485,\n        0.9805346727371216,\n        0.9713162779808044,\n        
0.9691506624221802,\n        0.9753089547157288,\n        0.9789929986000061,\n        0.988203227519989,\n        0.9974985122680664,\n        1.0200386047363281,\n        1.024385929107666,\n        1.0200226306915283,\n        1.0142742395401,\n        1.0153833627700806,\n        1.0168485641479492,\n        1.0072355270385742,\n        1.0065840482711792,\n        0.9912008047103882,\n        0.9780105948448181,\n        0.9798558950424194\n      ]\n    },\n    {\n      \"step\": 7,\n      \"n_points\": 18,\n      \"horizon\": 30,\n      \"last_historical_date\": \"2023-06\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442\n      ],\n      \"forecast_dates\": [\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n   
     \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        0.9665141701698303,\n        0.9519135355949402,\n        0.9444465637207031,\n        0.9402952790260315,\n        0.9306893348693848,\n        0.9244646430015564,\n        0.9174035787582397,\n        0.9139379858970642,\n        0.9132129549980164,\n        0.9145187735557556,\n        0.911784291267395,\n        0.9093538522720337,\n        0.9040751457214355,\n        0.9021264314651489,\n        0.8961065411567688,\n        0.8968585133552551,\n        0.9025744795799255,\n        0.9108133316040039,\n        0.9250923991203308,\n        0.9451119899749756,\n        0.9571705460548401,\n        0.9546100497245789,\n        0.9493789076805115,\n        0.9495347738265991,\n        0.9465805292129517,\n        0.942088782787323,\n        0.934301495552063,\n        0.927003026008606,\n        0.9134135842323303,\n        0.9131123423576355\n      ],\n      \"q10\": [\n        0.9755732417106628,\n        0.9652556777000427,\n        0.9605708122253418,\n        0.9540410041809082,\n        0.944946825504303,\n        0.9393219351768494,\n        0.9324542880058289,\n        0.9295912981033325,\n        0.9304096698760986,\n        0.9316055178642273,\n        0.9279895424842834,\n        0.9257113337516785,\n        0.9213154315948486,\n        0.9203523397445679,\n        0.9135439991950989,\n        0.9169613718986511,\n        0.9193251729011536,\n        0.9290840029716492,\n        0.9407450556755066,\n        0.9611459970474243,\n        0.9715418815612793,\n        0.966630756855011,\n        0.9606484770774841,\n        0.9624485373497009,\n        0.9596085548400879,\n        0.9563205242156982,\n        0.9496365189552307,\n        
0.9395637512207031,\n        0.9281183481216431,\n        0.9275621175765991\n      ],\n      \"q20\": [\n        0.833349347114563,\n        0.8175394535064697,\n        0.8078386783599854,\n        0.8068903088569641,\n        0.8031129837036133,\n        0.801506757736206,\n        0.7994549870491028,\n        0.7967816591262817,\n        0.7986584305763245,\n        0.7988185882568359,\n        0.799284040927887,\n        0.7968909740447998,\n        0.7936790585517883,\n        0.792199432849884,\n        0.7875745892524719,\n        0.7865579128265381,\n        0.7882473468780518,\n        0.7924611568450928,\n        0.7977651357650757,\n        0.8117226362228394,\n        0.8149524331092834,\n        0.8140331506729126,\n        0.8101717233657837,\n        0.8099949359893799,\n        0.8057650923728943,\n        0.8038991093635559,\n        0.7993261814117432,\n        0.798288106918335,\n        0.7926219701766968,\n        0.7953957319259644\n      ],\n      \"q80\": [\n        1.0251524448394775,\n        1.015281319618225,\n        1.0085906982421875,\n        1.0044453144073486,\n        0.9904035329818726,\n        0.9857988953590393,\n        0.977156400680542,\n        0.9709676504135132,\n        0.9726237654685974,\n        0.9721717238426208,\n        0.9683824181556702,\n        0.9648834466934204,\n        0.9616217613220215,\n        0.9584988355636597,\n        0.9530823230743408,\n        0.9561627507209778,\n        0.9611006379127502,\n        0.9723068475723267,\n        0.9880313873291016,\n        1.0103445053100586,\n        1.02413809299469,\n        1.0192902088165283,\n        1.0122601985931396,\n        1.0145885944366455,\n        1.012281060218811,\n        1.0074970722198486,\n        0.9987425804138184,\n        0.987089216709137,\n        0.9722681045532227,\n        0.9707110524177551\n      ],\n      \"q90\": [\n        1.0656019449234009,\n        1.059928059577942,\n        1.0517113208770752,\n        
1.0461057424545288,\n        1.035980224609375,\n        1.0275849103927612,\n        1.0181881189346313,\n        1.0124856233596802,\n        1.0126112699508667,\n        1.0153447389602661,\n        1.0106351375579834,\n        1.0058791637420654,\n        1.0014264583587646,\n        0.999718964099884,\n        0.9958565831184387,\n        0.9977275133132935,\n        1.0037381649017334,\n        1.0153366327285767,\n        1.031912088394165,\n        1.055626630783081,\n        1.0701265335083008,\n        1.0629067420959473,\n        1.0560659170150757,\n        1.0568609237670898,\n        1.0577772855758667,\n        1.0517592430114746,\n        1.0405441522598267,\n        1.030192494392395,\n        1.013637900352478,\n        1.0091335773468018\n      ]\n    },\n    {\n      \"step\": 8,\n      \"n_points\": 19,\n      \"horizon\": 29,\n      \"last_historical_date\": \"2023-07\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126\n      ],\n      \"forecast_dates\": [\n        \"2023-08\",\n        
\"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.0381698608398438,\n        1.012021780014038,\n        0.99420565366745,\n        0.9754087924957275,\n        0.9563038349151611,\n        0.9495773315429688,\n        0.9422544240951538,\n        0.9361824989318848,\n        0.9247673749923706,\n        0.9178153276443481,\n        0.9097317457199097,\n        0.901350200176239,\n        0.8968333601951599,\n        0.8947892189025879,\n        0.8923584818840027,\n        0.8944633603096008,\n        0.9065102338790894,\n        0.9204601049423218,\n        0.951920211315155,\n        0.9842206239700317,\n        0.99086993932724,\n        0.9848544597625732,\n        0.9833636283874512,\n        0.9852919578552246,\n        0.9797993302345276,\n        0.9684444069862366,\n        0.9575868844985962,\n        0.9473453760147095,\n        0.9351227283477783\n      ],\n      \"q10\": [\n        1.0491734743118286,\n        1.028739333152771,\n        1.0114028453826904,\n        0.9906209111213684,\n        0.971588134765625,\n        0.9669111371040344,\n        0.9621954560279846,\n        0.9568055868148804,\n        0.9453385472297668,\n        0.9398422241210938,\n        0.9300127029418945,\n        0.922597348690033,\n        0.9215761423110962,\n        0.9172200560569763,\n        0.9145788550376892,\n        0.9178516864776611,\n        0.9267954230308533,\n   
     0.9420651793479919,\n        0.9693762063980103,\n        1.003636121749878,\n        1.005869746208191,\n        0.9975773096084595,\n        0.9942836165428162,\n        0.9985279440879822,\n        0.9944182634353638,\n        0.985649824142456,\n        0.9736542105674744,\n        0.9612159729003906,\n        0.9520760774612427\n      ],\n      \"q20\": [\n        0.8832447528839111,\n        0.8571564555168152,\n        0.840262234210968,\n        0.8279801607131958,\n        0.8175891637802124,\n        0.8145928382873535,\n        0.8104804754257202,\n        0.8050722479820251,\n        0.8001488447189331,\n        0.7951650619506836,\n        0.7925589084625244,\n        0.78853440284729,\n        0.785635232925415,\n        0.7818436622619629,\n        0.7790342569351196,\n        0.779435932636261,\n        0.7866798639297485,\n        0.7947074174880981,\n        0.8116522431373596,\n        0.834707498550415,\n        0.8330732583999634,\n        0.8280425667762756,\n        0.8265914916992188,\n        0.8280237317085266,\n        0.823756992816925,\n        0.820884108543396,\n        0.8138716816902161,\n        0.8067872524261475,\n        0.8027349710464478\n      ],\n      \"q80\": [\n        1.10765540599823,\n        1.0850690603256226,\n        1.0677224397659302,\n        1.0468156337738037,\n        1.0239413976669312,\n        1.018355131149292,\n        1.0108981132507324,\n        1.0029836893081665,\n        0.9916971325874329,\n        0.9822992086410522,\n        0.9713731408119202,\n        0.9630072712898254,\n        0.9601694941520691,\n        0.9586890339851379,\n        0.955090343952179,\n        0.9576360583305359,\n        0.9701409339904785,\n        0.9886602759361267,\n        1.02058744430542,\n        1.0570831298828125,\n        1.0654001235961914,\n        1.0563757419586182,\n        1.0534954071044922,\n        1.0564368963241577,\n        1.051694393157959,\n        1.0388209819793701,\n        
1.025420904159546,\n        1.0107486248016357,\n        0.9982277750968933\n      ],\n      \"q90\": [\n        1.1553966999053955,\n        1.137328863143921,\n        1.1165260076522827,\n        1.0933233499526978,\n        1.072894811630249,\n        1.065496563911438,\n        1.0601707696914673,\n        1.0506465435028076,\n        1.038832187652588,\n        1.0302690267562866,\n        1.018511414527893,\n        1.0077110528945923,\n        1.0042316913604736,\n        1.0026092529296875,\n        1.0030121803283691,\n        1.0043935775756836,\n        1.018110990524292,\n        1.0365487337112427,\n        1.0698375701904297,\n        1.1068248748779297,\n        1.114990472793579,\n        1.105769395828247,\n        1.1021937131881714,\n        1.1038919687271118,\n        1.1002414226531982,\n        1.0864661931991577,\n        1.0711843967437744,\n        1.0577744245529175,\n        1.044431209564209\n      ]\n    },\n    {\n      \"step\": 9,\n      \"n_points\": 20,\n      \"horizon\": 28,\n      \"last_historical_date\": \"2023-08\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        
1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432\n      ],\n      \"forecast_dates\": [\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1063826084136963,\n        1.0667672157287598,\n        1.0312474966049194,\n        1.0092777013778687,\n        0.9886403679847717,\n        0.9805473685264587,\n        0.96883624792099,\n        0.9500434994697571,\n        0.9289879202842712,\n        0.9156991839408875,\n        0.9083491563796997,\n        0.9020676016807556,\n        0.9000667333602905,\n        0.8952069878578186,\n        0.8887008428573608,\n        0.8977259993553162,\n        0.9318806529045105,\n        0.9759154915809631,\n        1.0011931657791138,\n        1.0136791467666626,\n        1.0154764652252197,\n        1.0213247537612915,\n        1.0302479267120361,\n        1.032987117767334,\n        1.0179458856582642,\n        0.9947344660758972,\n        0.9729111194610596,\n        0.9626883268356323\n      ],\n      \"q10\": [\n        1.114622950553894,\n        1.083889365196228,\n        1.0484296083450317,\n        1.0276585817337036,\n        1.008374571800232,\n        0.999535322189331,\n        0.9902844429016113,\n        0.9757266640663147,\n        0.9533360600471497,\n        0.9409008026123047,\n        0.9341027736663818,\n      
  0.9281788468360901,\n        0.9299426674842834,\n        0.921561062335968,\n        0.9143303632736206,\n        0.9240468144416809,\n        0.9563655853271484,\n        1.0021518468856812,\n        1.0241011381149292,\n        1.0326213836669922,\n        1.0297893285751343,\n        1.0334995985031128,\n        1.0426249504089355,\n        1.047775149345398,\n        1.031937837600708,\n        1.0122848749160767,\n        0.9894399642944336,\n        0.978018045425415\n      ],\n      \"q20\": [\n        0.928669810295105,\n        0.8862699866294861,\n        0.8555266261100769,\n        0.8365516662597656,\n        0.8246086835861206,\n        0.8187647461891174,\n        0.8126576542854309,\n        0.8008460402488708,\n        0.7927306890487671,\n        0.7833954095840454,\n        0.7795919179916382,\n        0.7797963619232178,\n        0.7819650173187256,\n        0.7769280672073364,\n        0.7692436575889587,\n        0.7726868391036987,\n        0.7912442684173584,\n        0.8222379088401794,\n        0.8362159132957458,\n        0.8447703719139099,\n        0.8396773934364319,\n        0.8379412293434143,\n        0.8396240472793579,\n        0.8429920077323914,\n        0.833158016204834,\n        0.823620080947876,\n        0.8104652762413025,\n        0.8035314083099365\n      ],\n      \"q80\": [\n        1.1856414079666138,\n        1.1520715951919556,\n        1.117408037185669,\n        1.0936567783355713,\n        1.0721673965454102,\n        1.0631694793701172,\n        1.048310399055481,\n        1.0276391506195068,\n        1.0055267810821533,\n        0.9882948994636536,\n        0.9792788624763489,\n        0.9736778736114502,\n        0.9714402556419373,\n        0.9655618071556091,\n        0.9581301808357239,\n        0.9696058034896851,\n        1.0068414211273193,\n        1.0576438903808594,\n        1.0841014385223389,\n        1.0951288938522339,\n        1.1002628803253174,\n        1.1048551797866821,\n        
1.1152007579803467,\n        1.1188753843307495,\n        1.1012613773345947,\n        1.0757598876953125,\n        1.0499663352966309,\n        1.0353318452835083\n      ],\n      \"q90\": [\n        1.23917818069458,\n        1.2113547325134277,\n        1.1742331981658936,\n        1.151162028312683,\n        1.1314780712127686,\n        1.1195954084396362,\n        1.10871160030365,\n        1.0842714309692383,\n        1.0615670680999756,\n        1.0447986125946045,\n        1.0315890312194824,\n        1.024493932723999,\n        1.0225589275360107,\n        1.0159486532211304,\n        1.0109714269638062,\n        1.0237358808517456,\n        1.0626462697982788,\n        1.1151678562164307,\n        1.1429989337921143,\n        1.151975393295288,\n        1.1542848348617554,\n        1.1620826721191406,\n        1.1735471487045288,\n        1.1768637895584106,\n        1.1591532230377197,\n        1.1298507452011108,\n        1.1008673906326294,\n        1.0874378681182861\n      ]\n    },\n    {\n      \"step\": 10,\n      \"n_points\": 21,\n      \"horizon\": 27,\n      \"last_historical_date\": \"2023-09\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        
0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295\n      ],\n      \"forecast_dates\": [\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2447655200958252,\n        1.1675736904144287,\n        1.1279113292694092,\n        1.1182188987731934,\n        1.1093124151229858,\n        1.082032322883606,\n        1.038187861442566,\n        0.9993720650672913,\n        0.9796157479286194,\n        0.9642789959907532,\n        0.9476039409637451,\n        0.9355512857437134,\n        0.9285284876823425,\n        0.9198805689811707,\n        0.9171224236488342,\n        0.9307593703269958,\n        0.968371570110321,\n        0.9984195232391357,\n        0.9925985932350159,\n        0.9877574443817139,\n        0.9934283494949341,\n        1.0018013715744019,\n        1.015841007232666,\n        1.0086023807525635,\n        0.981035590171814,\n        0.9596301913261414,\n        0.950019896030426\n      ],\n      \"q10\": [\n        1.2715141773223877,\n        1.2083916664123535,\n        1.1731905937194824,\n        1.165351152420044,\n        1.162253975868225,\n        1.13302481174469,\n        1.085452914237976,\n        1.051274299621582,\n        1.0313252210617065,\n    
    1.0168172121047974,\n        0.9987383484840393,\n        0.9869235754013062,\n        0.9812518358230591,\n        0.9713582396507263,\n        0.9646536111831665,\n        0.9781244397163391,\n        1.0141757726669312,\n        1.050538420677185,\n        1.0407419204711914,\n        1.032418966293335,\n        1.0342923402786255,\n        1.0425127744674683,\n        1.05617356300354,\n        1.0557340383529663,\n        1.0226852893829346,\n        1.0031310319900513,\n        0.9946122169494629\n      ],\n      \"q20\": [\n        0.9692280888557434,\n        0.9033447504043579,\n        0.8709640502929688,\n        0.8632612824440002,\n        0.8616656064987183,\n        0.8437307476997375,\n        0.8145183324813843,\n        0.7942112684249878,\n        0.7919824123382568,\n        0.7849438190460205,\n        0.7758752703666687,\n        0.7725547552108765,\n        0.7724835276603699,\n        0.7696750164031982,\n        0.7656691074371338,\n        0.7687865495681763,\n        0.7848570346832275,\n        0.8048490285873413,\n        0.7928374409675598,\n        0.7848871946334839,\n        0.7746942043304443,\n        0.7734623551368713,\n        0.7735666036605835,\n        0.765663743019104,\n        0.7521377205848694,\n        0.7475736737251282,\n        0.7519190907478333\n      ],\n      \"q80\": [\n        1.3772318363189697,\n        1.3073946237564087,\n        1.267617106437683,\n        1.2576971054077148,\n        1.2495336532592773,\n        1.2185810804367065,\n        1.1627202033996582,\n        1.1192079782485962,\n        1.093948483467102,\n        1.0731803178787231,\n        1.0513980388641357,\n        1.0379669666290283,\n        1.0290329456329346,\n        1.0203547477722168,\n        1.0156269073486328,\n        1.0321729183197021,\n        1.0734044313430786,\n        1.1123948097229004,\n        1.1079280376434326,\n        1.1026053428649902,\n        1.1133449077606201,\n        1.1250957250595093,\n        
1.1411525011062622,\n        1.1397948265075684,\n        1.104438066482544,\n        1.076056957244873,\n        1.0614937543869019\n      ],\n      \"q90\": [\n        1.4695751667022705,\n        1.4090934991836548,\n        1.3679797649383545,\n        1.3577240705490112,\n        1.3525687456130981,\n        1.315553903579712,\n        1.2607886791229248,\n        1.2103060483932495,\n        1.1827821731567383,\n        1.1617928743362427,\n        1.1323959827423096,\n        1.1176999807357788,\n        1.1078500747680664,\n        1.094464659690857,\n        1.0922305583953857,\n        1.1100425720214844,\n        1.1573286056518555,\n        1.1960670948028564,\n        1.1912381649017334,\n        1.1854121685028076,\n        1.1960220336914062,\n        1.2121495008468628,\n        1.231044888496399,\n        1.2304543256759644,\n        1.1941158771514893,\n        1.1591618061065674,\n        1.1404690742492676\n      ]\n    },\n    {\n      \"step\": 11,\n      \"n_points\": 22,\n      \"horizon\": 26,\n      \"last_historical_date\": \"2023-10\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        
0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874\n      ],\n      \"forecast_dates\": [\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1978572607040405,\n        1.1348843574523926,\n        1.107893705368042,\n        1.0890357494354248,\n        1.075318455696106,\n        1.0392742156982422,\n        1.0066328048706055,\n        0.9802011847496033,\n        0.968873143196106,\n        0.9584881663322449,\n        0.9440371990203857,\n        0.929160475730896,\n        0.9231215715408325,\n        0.9263395667076111,\n        0.9453635811805725,\n        0.9834461212158203,\n        1.0165542364120483,\n        1.0119515657424927,\n        1.0067156553268433,\n        1.0273469686508179,\n        1.068889856338501,\n        1.1046394109725952,\n        1.120269536972046,\n        1.084791898727417,\n        1.0440341234207153,\n        1.0170215368270874\n      ],\n      \"q10\": [\n        1.2035126686096191,\n        1.153576135635376,\n        1.1352055072784424,\n        1.1203036308288574,\n        1.1123145818710327,\n        1.0742825269699097,\n        1.0389323234558105,\n        1.017652988433838,\n        1.0134992599487305,\n        1.0038114786148071,\n        
0.9876317381858826,\n        0.972976565361023,\n        0.9668206572532654,\n        0.9690794348716736,\n        0.9879439473152161,\n        1.0214078426361084,\n        1.0546575784683228,\n        1.0502262115478516,\n        1.0401197671890259,\n        1.0604331493377686,\n        1.0953052043914795,\n        1.1325199604034424,\n        1.144276738166809,\n        1.1159130334854126,\n        1.0714142322540283,\n        1.0489661693572998\n      ],\n      \"q20\": [\n        0.9713577032089233,\n        0.9063910245895386,\n        0.8755015134811401,\n        0.8545557260513306,\n        0.8455488681793213,\n        0.8177679777145386,\n        0.799569845199585,\n        0.7851544618606567,\n        0.7884225249290466,\n        0.7802386283874512,\n        0.7720929980278015,\n        0.7622212171554565,\n        0.7612568736076355,\n        0.7628719806671143,\n        0.7799019813537598,\n        0.7968021035194397,\n        0.8116334676742554,\n        0.8049068450927734,\n        0.7901184558868408,\n        0.7965429425239563,\n        0.8083176612854004,\n        0.8315435647964478,\n        0.8326961994171143,\n        0.8086848258972168,\n        0.7895619869232178,\n        0.7825078368186951\n      ],\n      \"q80\": [\n        1.2972460985183716,\n        1.245476484298706,\n        1.2229666709899902,\n        1.210435152053833,\n        1.1973446607589722,\n        1.157381296157837,\n        1.1181674003601074,\n        1.0869324207305908,\n        1.075097680091858,\n        1.0632023811340332,\n        1.0455275774002075,\n        1.0302590131759644,\n        1.0215204954147339,\n        1.025394320487976,\n        1.045914649963379,\n        1.0890913009643555,\n        1.1246864795684814,\n        1.125206470489502,\n        1.1208384037017822,\n        1.145365834236145,\n        1.1913384199142456,\n        1.2334762811660767,\n        1.2504417896270752,\n        1.2180585861206055,\n        1.1676602363586426,\n        
1.1349562406539917\n      ],\n      \"q90\": [\n        1.3638895750045776,\n        1.323225975036621,\n        1.304998755455017,\n        1.2944636344909668,\n        1.2835395336151123,\n        1.2412294149398804,\n        1.1998721361160278,\n        1.1685125827789307,\n        1.1557502746582031,\n        1.1425185203552246,\n        1.1200439929962158,\n        1.1038810014724731,\n        1.0953530073165894,\n        1.0953185558319092,\n        1.1211014986038208,\n        1.165018916130066,\n        1.2059204578399658,\n        1.2043343782424927,\n        1.1997365951538086,\n        1.224650502204895,\n        1.2735167741775513,\n        1.3202080726623535,\n        1.3405519723892212,\n        1.304829716682434,\n        1.2509324550628662,\n        1.213225245475769\n      ]\n    },\n    {\n      \"step\": 12,\n      \"n_points\": 23,\n      \"horizon\": 25,\n      \"last_historical_date\": \"2023-11\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        
1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126\n      ],\n      \"forecast_dates\": [\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1388345956802368,\n        1.1001060009002686,\n        1.0774588584899902,\n        1.0615431070327759,\n        1.0359764099121094,\n        1.0100469589233398,\n        0.9895788431167603,\n        0.9849971532821655,\n        0.9746628999710083,\n        0.9684356451034546,\n        0.9609130024909973,\n        0.947131872177124,\n        0.9499838352203369,\n        0.9744009375572205,\n        1.0130703449249268,\n        1.0441275835037231,\n        1.048531413078308,\n        1.0471770763397217,\n        1.0618387460708618,\n        1.1032129526138306,\n        1.1474189758300781,\n        1.1728394031524658,\n        1.1430011987686157,\n        1.0839053392410278,\n        1.0471035242080688\n      ],\n      \"q10\": [\n        1.143956184387207,\n        1.1164032220840454,\n        1.0988131761550903,\n        1.0883313417434692,\n        1.0633952617645264,\n        1.0377331972122192,\n        1.0185223817825317,\n        1.0154881477355957,\n        1.0130091905593872,\n        1.006235957145691,\n        0.9972001314163208,\n        0.984115719795227,\n        0.9868376851081848,\n        1.0110416412353516,\n        1.0470901727676392,\n        
1.078067660331726,\n        1.0788366794586182,\n        1.0745474100112915,\n        1.0864962339401245,\n        1.1283372640609741,\n        1.1684935092926025,\n        1.194905400276184,\n        1.1594902276992798,\n        1.106303095817566,\n        1.0674790143966675\n      ],\n      \"q20\": [\n        0.9558293223381042,\n        0.9077008962631226,\n        0.875536322593689,\n        0.8599477410316467,\n        0.8395929932594299,\n        0.820803165435791,\n        0.8097033500671387,\n        0.8071569800376892,\n        0.8063573837280273,\n        0.7997854351997375,\n        0.7947160601615906,\n        0.7840617895126343,\n        0.7878046035766602,\n        0.8045357465744019,\n        0.8319349884986877,\n        0.8483662605285645,\n        0.8439525961875916,\n        0.8370295166969299,\n        0.8409282565116882,\n        0.8701899647712708,\n        0.8887082934379578,\n        0.9067206382751465,\n        0.8854538798332214,\n        0.8463788628578186,\n        0.8287973999977112\n      ],\n      \"q80\": [\n        1.2187292575836182,\n        1.1895191669464111,\n        1.1730304956436157,\n        1.1645177602767944,\n        1.1339150667190552,\n        1.1082265377044678,\n        1.0852689743041992,\n        1.0772539377212524,\n        1.0709658861160278,\n        1.0674384832382202,\n        1.0557781457901,\n        1.0452414751052856,\n        1.0445914268493652,\n        1.07282292842865,\n        1.1126301288604736,\n        1.1508023738861084,\n        1.1525412797927856,\n        1.1523170471191406,\n        1.1721690893173218,\n        1.2185754776000977,\n        1.2663267850875854,\n        1.2923482656478882,\n        1.2582346200942993,\n        1.1959534883499146,\n        1.1527845859527588\n      ],\n      \"q90\": [\n        1.2729495763778687,\n        1.2533750534057617,\n        1.2407320737838745,\n        1.2354146242141724,\n        1.2064470052719116,\n        1.1776363849639893,\n        
1.1529877185821533,\n        1.1496665477752686,\n        1.1451096534729004,\n        1.137753963470459,\n        1.1235407590866089,\n        1.1123000383377075,\n        1.1126760244369507,\n        1.1393102407455444,\n        1.185707449913025,\n        1.2234959602355957,\n        1.2277790307998657,\n        1.22639799118042,\n        1.2456328868865967,\n        1.2939422130584717,\n        1.345726490020752,\n        1.3719840049743652,\n        1.338273048400879,\n        1.2677257061004639,\n        1.2217291593551636\n      ]\n    },\n    {\n      \"step\": 13,\n      \"n_points\": 24,\n      \"horizon\": 24,\n      \"last_historical_date\": \"2023-12\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399\n      ],\n      \"forecast_dates\": [\n       
 \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1204800605773926,\n        1.0831129550933838,\n        1.0525826215744019,\n        1.0186809301376343,\n        0.996323823928833,\n        0.9761021733283997,\n        0.966797411441803,\n        0.9621630311012268,\n        0.950423002243042,\n        0.9326475262641907,\n        0.9303779602050781,\n        0.9362010955810547,\n        0.9639466404914856,\n        1.0171366930007935,\n        1.0539826154708862,\n        1.0581066608428955,\n        1.05403470993042,\n        1.07761549949646,\n        1.122676134109497,\n        1.180346965789795,\n        1.1975631713867188,\n        1.1708546876907349,\n        1.117448329925537,\n        1.0691102743148804\n      ],\n      \"q10\": [\n        1.1319338083267212,\n        1.1058242321014404,\n        1.0804548263549805,\n        1.0469233989715576,\n        1.0246795415878296,\n        1.0055618286132812,\n        0.999349057674408,\n        0.9949856996536255,\n        0.9896860718727112,\n        0.9742559194564819,\n        0.9675081968307495,\n        0.9734180569648743,\n        1.0023202896118164,\n        1.053297996520996,\n        1.090195894241333,\n        1.088844656944275,\n        1.082571029663086,\n        1.104530930519104,\n        1.1468923091888428,\n        1.2043083906173706,\n        1.2187085151672363,\n        1.19277822971344,\n        1.1290017366409302,\n        1.0879333019256592\n      ],\n      \"q20\": [\n        
0.9561834335327148,\n        0.9061079621315002,\n        0.8687788844108582,\n        0.8394415378570557,\n        0.8218992948532104,\n        0.8107370138168335,\n        0.8105956315994263,\n        0.8031740784645081,\n        0.8004634380340576,\n        0.7854968309402466,\n        0.7851479053497314,\n        0.7882705330848694,\n        0.8095588684082031,\n        0.8434075117111206,\n        0.8662194013595581,\n        0.8621299862861633,\n        0.8524537682533264,\n        0.8656907677650452,\n        0.896289587020874,\n        0.9350174069404602,\n        0.940517783164978,\n        0.9245748519897461,\n        0.8929179906845093,\n        0.8636151552200317\n      ],\n      \"q80\": [\n        1.19773530960083,\n        1.1693586111068726,\n        1.14640212059021,\n        1.11386239528656,\n        1.082446813583374,\n        1.0650819540023804,\n        1.05680513381958,\n        1.0481219291687012,\n        1.0429224967956543,\n        1.024938702583313,\n        1.0191327333450317,\n        1.028489589691162,\n        1.057991862297058,\n        1.1157665252685547,\n        1.1569236516952515,\n        1.1618187427520752,\n        1.157217025756836,\n        1.1827739477157593,\n        1.2360106706619263,\n        1.2970430850982666,\n        1.3167476654052734,\n        1.2833902835845947,\n        1.2190351486206055,\n        1.1678544282913208\n      ],\n      \"q90\": [\n        1.2482070922851562,\n        1.229236364364624,\n        1.210077166557312,\n        1.18027925491333,\n        1.1515717506408691,\n        1.1297614574432373,\n        1.1205626726150513,\n        1.1177691221237183,\n        1.112573504447937,\n        1.0930581092834473,\n        1.084266185760498,\n        1.0912758111953735,\n        1.1246064901351929,\n        1.182848334312439,\n        1.2307857275009155,\n        1.2338712215423584,\n        1.2311983108520508,\n        1.2551823854446411,\n        1.3106720447540283,\n        1.3747836351394653,\n    
    1.3966447114944458,\n        1.3567662239074707,\n        1.2892448902130127,\n        1.23186457157135\n      ]\n    },\n    {\n      \"step\": 14,\n      \"n_points\": 25,\n      \"horizon\": 23,\n      \"last_historical_date\": \"2024-01\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295\n      ],\n      \"forecast_dates\": [\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        
\"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1701693534851074,\n        1.1349387168884277,\n        1.0960110425949097,\n        1.0637617111206055,\n        1.0402418375015259,\n        1.0265028476715088,\n        1.0204156637191772,\n        1.0119004249572754,\n        0.9878545999526978,\n        0.9743345379829407,\n        0.9826735258102417,\n        0.9942994117736816,\n        1.0274856090545654,\n        1.0622732639312744,\n        1.0651201009750366,\n        1.073114037513733,\n        1.0891278982162476,\n        1.126175880432129,\n        1.157884955406189,\n        1.1790106296539307,\n        1.1665725708007812,\n        1.1304852962493896,\n        1.1106657981872559\n      ],\n      \"q10\": [\n        1.1749104261398315,\n        1.147524118423462,\n        1.1174193620681763,\n        1.086887001991272,\n        1.0630450248718262,\n        1.0531063079833984,\n        1.0497565269470215,\n        1.042683482170105,\n        1.0233265161514282,\n        1.0111165046691895,\n        1.014377236366272,\n        1.0274351835250854,\n        1.062585711479187,\n        1.0902963876724243,\n        1.0922062397003174,\n        1.0912160873413086,\n        1.11197829246521,\n        1.1466877460479736,\n        1.1778086423873901,\n        1.199917197227478,\n        1.1789664030075073,\n        1.1495457887649536,\n        1.123175859451294\n      ],\n      \"q20\": [\n        0.9954406023025513,\n        0.9378616213798523,\n        0.893646240234375,\n        0.8610368967056274,\n        0.8414109945297241,\n        0.8318982124328613,\n        0.829987645149231,\n        0.8171640634536743,\n        0.8035246729850769,\n        0.7929065227508545,\n        0.8037456274032593,\n        0.8133399486541748,\n        0.8395006656646729,\n        0.8581016659736633,\n        0.8608949184417725,\n        
0.8612385392189026,\n        0.8694777488708496,\n        0.896060585975647,\n        0.9189809560775757,\n        0.9354234337806702,\n        0.918700635433197,\n        0.8955419063568115,\n        0.8815087676048279\n      ],\n      \"q80\": [\n        1.2475481033325195,\n        1.2218120098114014,\n        1.1920394897460938,\n        1.1621203422546387,\n        1.1338578462600708,\n        1.1270941495895386,\n        1.1244370937347412,\n        1.11036217212677,\n        1.0929012298583984,\n        1.0770790576934814,\n        1.0825059413909912,\n        1.0962635278701782,\n        1.133682131767273,\n        1.166754126548767,\n        1.1711883544921875,\n        1.1767802238464355,\n        1.1959017515182495,\n        1.2360646724700928,\n        1.272753357887268,\n        1.293941855430603,\n        1.2775542736053467,\n        1.2417978048324585,\n        1.215286374092102\n      ],\n      \"q90\": [\n        1.2978864908218384,\n        1.2807369232177734,\n        1.2577829360961914,\n        1.2319256067276,\n        1.2072914838790894,\n        1.1944835186004639,\n        1.1949646472930908,\n        1.1887325048446655,\n        1.1706409454345703,\n        1.1535823345184326,\n        1.1557773351669312,\n        1.165435791015625,\n        1.2051732540130615,\n        1.2392327785491943,\n        1.2467840909957886,\n        1.2500178813934326,\n        1.2747631072998047,\n        1.3121440410614014,\n        1.3449633121490479,\n        1.3688087463378906,\n        1.353420376777649,\n        1.3119401931762695,\n        1.2864404916763306\n      ]\n    },\n    {\n      \"step\": 15,\n      \"n_points\": 26,\n      \"horizon\": 22,\n      \"last_historical_date\": \"2024-02\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        
\"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858\n      ],\n      \"forecast_dates\": [\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2504206895828247,\n        1.2035315036773682,\n        1.1605435609817505,\n        1.1372957229614258,\n        1.1169792413711548,\n        1.1097407341003418,\n        1.0960330963134766,\n        1.0716885328292847,\n        1.0411385297775269,\n        1.0377408266067505,\n        1.06381356716156,\n  
      1.09853994846344,\n        1.1106432676315308,\n        1.1038997173309326,\n        1.0912792682647705,\n        1.10673189163208,\n        1.128816843032837,\n        1.1672472953796387,\n        1.169884204864502,\n        1.1492640972137451,\n        1.1251699924468994,\n        1.1080049276351929\n      ],\n      \"q10\": [\n        1.253143310546875,\n        1.2137634754180908,\n        1.175628900527954,\n        1.158146858215332,\n        1.1375560760498047,\n        1.1330972909927368,\n        1.1224530935287476,\n        1.0991952419281006,\n        1.0732285976409912,\n        1.069901704788208,\n        1.0908238887786865,\n        1.1302318572998047,\n        1.1447051763534546,\n        1.1265060901641846,\n        1.1150192022323608,\n        1.1237907409667969,\n        1.1495832204818726,\n        1.187064528465271,\n        1.191187858581543,\n        1.1717422008514404,\n        1.1371166706085205,\n        1.1280303001403809\n      ],\n      \"q20\": [\n        1.0437579154968262,\n        0.9754042625427246,\n        0.9281424283981323,\n        0.8999512791633606,\n        0.8835805058479309,\n        0.8786535263061523,\n        0.868209958076477,\n        0.8477093577384949,\n        0.8295252919197083,\n        0.8285472989082336,\n        0.8487096428871155,\n        0.8732921481132507,\n        0.8824164271354675,\n        0.8700266480445862,\n        0.8598465323448181,\n        0.8674743175506592,\n        0.8803960084915161,\n        0.9123423099517822,\n        0.9124201536178589,\n        0.8980945348739624,\n        0.8717573881149292,\n        0.8591221570968628\n      ],\n      \"q80\": [\n        1.3365967273712158,\n        1.29902184009552,\n        1.2669174671173096,\n        1.2462443113327026,\n        1.2251611948013306,\n        1.224426031112671,\n        1.2126585245132446,\n        1.1816699504852295,\n        1.1577259302139282,\n        1.1497776508331299,\n        1.1759350299835205,\n        
1.2160439491271973,\n        1.2304400205612183,\n        1.2202222347259521,\n        1.2069144248962402,\n        1.2211333513259888,\n        1.2466362714767456,\n        1.2859277725219727,\n        1.2911059856414795,\n        1.2705645561218262,\n        1.2402691841125488,\n        1.225570797920227\n      ],\n      \"q90\": [\n        1.394209623336792,\n        1.3661998510360718,\n        1.3383913040161133,\n        1.3226557970046997,\n        1.3062965869903564,\n        1.3001211881637573,\n        1.2918630838394165,\n        1.267607569694519,\n        1.2428820133209229,\n        1.2324764728546143,\n        1.254516839981079,\n        1.291373372077942,\n        1.3111931085586548,\n        1.2970809936523438,\n        1.2853456735610962,\n        1.3002972602844238,\n        1.330471396446228,\n        1.3700449466705322,\n        1.3697110414505005,\n        1.346665382385254,\n        1.31707763671875,\n        1.3017767667770386\n      ]\n    },\n    {\n      \"step\": 16,\n      \"n_points\": 27,\n      \"horizon\": 21,\n      \"last_historical_date\": \"2024-03\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        
0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601\n      ],\n      \"forecast_dates\": [\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2523874044418335,\n        1.2066220045089722,\n        1.1746571063995361,\n        1.1765081882476807,\n        1.1709487438201904,\n        1.169347882270813,\n        1.1399660110473633,\n        1.1141448020935059,\n        1.094247817993164,\n        1.0913820266723633,\n        1.1216974258422852,\n        1.1433929204940796,\n        1.1276500225067139,\n        1.1138465404510498,\n        1.1109668016433716,\n        1.1382179260253906,\n        1.145559310913086,\n        1.165015697479248,\n        1.1428844928741455,\n        1.1122182607650757,\n        1.1095082759857178\n      ],\n      \"q10\": [\n        1.2494522333145142,\n        1.2100024223327637,\n        1.1815905570983887,\n        1.184570550918579,\n        1.181471824645996,\n        1.1847987174987793,\n        1.1554681062698364,\n        1.1273032426834106,\n        1.1124141216278076,\n        
1.1068137884140015,\n        1.1349601745605469,\n        1.160623550415039,\n        1.1481659412384033,\n        1.1232229471206665,\n        1.1228114366531372,\n        1.1419509649276733,\n        1.1522048711776733,\n        1.1742281913757324,\n        1.1551659107208252,\n        1.1268976926803589,\n        1.112238883972168\n      ],\n      \"q20\": [\n        1.0595918893814087,\n        0.9882703423500061,\n        0.9449520111083984,\n        0.9323371648788452,\n        0.921808123588562,\n        0.9140236973762512,\n        0.8879625797271729,\n        0.8599287271499634,\n        0.84772127866745,\n        0.8464851975440979,\n        0.8668861389160156,\n        0.8764016032218933,\n        0.862370491027832,\n        0.8420681953430176,\n        0.8450419306755066,\n        0.8666462898254395,\n        0.8749760985374451,\n        0.8925336003303528,\n        0.8715018033981323,\n        0.8530272841453552,\n        0.8424127697944641\n      ],\n      \"q80\": [\n        1.3265814781188965,\n        1.2932963371276855,\n        1.2723067998886108,\n        1.276952862739563,\n        1.2762058973312378,\n        1.284961462020874,\n        1.2592799663543701,\n        1.2249560356140137,\n        1.213465929031372,\n        1.2041243314743042,\n        1.2399941682815552,\n        1.2660539150238037,\n        1.2495875358581543,\n        1.2333945035934448,\n        1.2315037250518799,\n        1.2564735412597656,\n        1.264156699180603,\n        1.2841299772262573,\n        1.2626703977584839,\n        1.2329728603363037,\n        1.2221158742904663\n      ],\n      \"q90\": [\n        1.3771872520446777,\n        1.3524072170257568,\n        1.3376163244247437,\n        1.347804307937622,\n        1.3534436225891113,\n        1.3581876754760742,\n        1.3364894390106201,\n        1.3079556226730347,\n        1.295357346534729,\n        1.2872941493988037,\n        1.3177791833877563,\n        1.340587854385376,\n        
1.3293620347976685,\n        1.3092248439788818,\n        1.309072494506836,\n        1.3312009572982788,\n        1.3418974876403809,\n        1.3621940612792969,\n        1.3367794752120972,\n        1.3070871829986572,\n        1.296994686126709\n      ]\n    },\n    {\n      \"step\": 17,\n      \"n_points\": 28,\n      \"horizon\": 20,\n      \"last_historical_date\": \"2024-04\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568\n      ],\n      \"forecast_dates\": [\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n   
     \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2068676948547363,\n        1.1843211650848389,\n        1.1752288341522217,\n        1.17955482006073,\n        1.1717453002929688,\n        1.1482445001602173,\n        1.1248430013656616,\n        1.1241732835769653,\n        1.1235134601593018,\n        1.1300708055496216,\n        1.1367747783660889,\n        1.1233289241790771,\n        1.1131789684295654,\n        1.1212987899780273,\n        1.1275365352630615,\n        1.1452269554138184,\n        1.1476627588272095,\n        1.1389117240905762,\n        1.1231611967086792,\n        1.1179301738739014\n      ],\n      \"q10\": [\n        1.202960729598999,\n        1.1801354885101318,\n        1.1744948625564575,\n        1.178760290145874,\n        1.1708077192306519,\n        1.152012586593628,\n        1.1264581680297852,\n        1.1220771074295044,\n        1.12774658203125,\n        1.1319509744644165,\n        1.1353538036346436,\n        1.1257888078689575,\n        1.1163818836212158,\n        1.1152591705322266,\n        1.1232290267944336,\n        1.1383938789367676,\n        1.1435673236846924,\n        1.131921648979187,\n        1.1226390600204468,\n        1.115145206451416\n      ],\n      \"q20\": [\n        1.0335861444473267,\n        0.9781290292739868,\n        0.948025643825531,\n        0.937298595905304,\n        0.9195546507835388,\n        0.8911022543907166,\n        0.8684503436088562,\n        0.8581703901290894,\n        0.8552865386009216,\n        0.8566405177116394,\n        0.8587369918823242,\n        0.8421598076820374,\n        0.8355081081390381,\n        0.835259735584259,\n     
   0.8424496650695801,\n        0.8557251691818237,\n        0.8595790863037109,\n        0.8550817966461182,\n        0.8462545871734619,\n        0.8529651761054993\n      ],\n      \"q80\": [\n        1.2702223062515259,\n        1.2614306211471558,\n        1.2629116773605347,\n        1.27401602268219,\n        1.2682753801345825,\n        1.253630518913269,\n        1.23259437084198,\n        1.2252973318099976,\n        1.2373583316802979,\n        1.2451832294464111,\n        1.2524268627166748,\n        1.2415071725845337,\n        1.2297941446304321,\n        1.2318909168243408,\n        1.242499828338623,\n        1.2596397399902344,\n        1.26153564453125,\n        1.2511622905731201,\n        1.2375322580337524,\n        1.2279977798461914\n      ],\n      \"q90\": [\n        1.3145431280136108,\n        1.313429594039917,\n        1.3208061456680298,\n        1.3359402418136597,\n        1.3367944955825806,\n        1.3163673877716064,\n        1.2994139194488525,\n        1.2974282503128052,\n        1.3131386041641235,\n        1.3206769227981567,\n        1.325730800628662,\n        1.3097118139266968,\n        1.2984554767608643,\n        1.2993582487106323,\n        1.3110051155090332,\n        1.328444242477417,\n        1.3302743434906006,\n        1.3201899528503418,\n        1.3005670309066772,\n        1.295330286026001\n      ]\n    },\n    {\n      \"step\": 18,\n      \"n_points\": 29,\n      \"horizon\": 19,\n      \"last_historical_date\": \"2024-05\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        
\"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142\n      ],\n      \"forecast_dates\": [\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1386852264404297,\n        1.1227259635925293,\n        1.1132360696792603,\n        1.103696346282959,\n        1.0890148878097534,\n        1.0628618001937866,\n        1.0592650175094604,\n        1.0809025764465332,\n        1.1213948726654053,\n        1.1205977201461792,\n        1.10319983959198,\n        1.0873777866363525,\n        1.0977184772491455,\n        1.1334233283996582,\n        1.1537142992019653,\n        
1.15865159034729,\n        1.1413378715515137,\n        1.1311604976654053,\n        1.1258361339569092\n      ],\n      \"q10\": [\n        1.1357723474502563,\n        1.1218345165252686,\n        1.1151096820831299,\n        1.1036633253097534,\n        1.088782787322998,\n        1.0708427429199219,\n        1.0614827871322632,\n        1.0803805589675903,\n        1.1256681680679321,\n        1.124110460281372,\n        1.1017175912857056,\n        1.0866585969924927,\n        1.0974124670028687,\n        1.1265218257904053,\n        1.1448237895965576,\n        1.150303602218628,\n        1.131263256072998,\n        1.1206773519515991,\n        1.1218606233596802\n      ],\n      \"q20\": [\n        0.9705875515937805,\n        0.9261521100997925,\n        0.9002217650413513,\n        0.8800909519195557,\n        0.8597927689552307,\n        0.837051272392273,\n        0.8270405530929565,\n        0.8327914476394653,\n        0.8583639860153198,\n        0.8556785583496094,\n        0.8432221412658691,\n        0.8295676708221436,\n        0.8404796719551086,\n        0.8643808364868164,\n        0.8823158740997314,\n        0.8855088949203491,\n        0.8733288049697876,\n        0.8654991388320923,\n        0.8692165017127991\n      ],\n      \"q80\": [\n        1.2012592554092407,\n        1.2004612684249878,\n        1.1944599151611328,\n        1.1941598653793335,\n        1.178646206855774,\n        1.1608107089996338,\n        1.156977653503418,\n        1.1782780885696411,\n        1.2296812534332275,\n        1.235266089439392,\n        1.2120579481124878,\n        1.1956090927124023,\n        1.2027981281280518,\n        1.2328962087631226,\n        1.2583279609680176,\n        1.2622652053833008,\n        1.2420697212219238,\n        1.2296068668365479,\n        1.2310649156570435\n      ],\n      \"q90\": [\n        1.2429797649383545,\n        1.248335599899292,\n        1.2536859512329102,\n        1.251412272453308,\n        
1.241403341293335,\n        1.21868097782135,\n        1.2173688411712646,\n        1.244056224822998,\n        1.3022620677947998,\n        1.3048560619354248,\n        1.2794227600097656,\n        1.25494384765625,\n        1.2627326250076294,\n        1.292664885520935,\n        1.3210376501083374,\n        1.3248177766799927,\n        1.303199291229248,\n        1.290137529373169,\n        1.2883186340332031\n      ]\n    },\n    {\n      \"step\": 19,\n      \"n_points\": 30,\n      \"horizon\": 18,\n      \"last_historical_date\": \"2024-06\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n   
     1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158\n      ],\n      \"forecast_dates\": [\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1765440702438354,\n        1.1661514043807983,\n        1.1520631313323975,\n        1.1195285320281982,\n        1.0856300592422485,\n        1.0768202543258667,\n        1.0964417457580566,\n        1.1255871057510376,\n        1.155031442642212,\n        1.1183977127075195,\n        1.1013360023498535,\n        1.1082254648208618,\n        1.1356239318847656,\n        1.1829569339752197,\n        1.1888995170593262,\n        1.159764051437378,\n        1.126434564590454,\n        1.1302133798599243\n      ],\n      \"q10\": [\n        1.1751192808151245,\n        1.1651133298873901,\n        1.1592530012130737,\n        1.1195036172866821,\n        1.084028959274292,\n        1.0865756273269653,\n        1.099607229232788,\n        1.1274793148040771,\n        1.160447597503662,\n        1.1203389167785645,\n        1.0989832878112793,\n        1.1072871685028076,\n        1.1345447301864624,\n        1.1779069900512695,\n        1.1820926666259766,\n        1.1511759757995605,\n        1.1156119108200073,\n        1.121741771697998\n      ],\n      \"q20\": [\n        1.0206873416900635,\n        0.9838167428970337,\n        0.9575520157814026,\n        0.9151738882064819,\n        0.8827507495880127,\n        0.876349151134491,\n        0.8842628002166748,\n        0.8949983716011047,\n        0.9151624441146851,\n        0.883825421333313,\n        
0.877031147480011,\n        0.8801717162132263,\n        0.9021454453468323,\n        0.9322755336761475,\n        0.9382153153419495,\n        0.9139386415481567,\n        0.8896767497062683,\n        0.8937186598777771\n      ],\n      \"q80\": [\n        1.2346465587615967,\n        1.238021969795227,\n        1.2284244298934937,\n        1.199608564376831,\n        1.1668167114257812,\n        1.165637731552124,\n        1.1883985996246338,\n        1.2180571556091309,\n        1.25492525100708,\n        1.219463586807251,\n        1.1989303827285767,\n        1.2049015760421753,\n        1.2325347661972046,\n        1.276305079460144,\n        1.2895640134811401,\n        1.2548282146453857,\n        1.217138648033142,\n        1.2198824882507324\n      ],\n      \"q90\": [\n        1.2738851308822632,\n        1.2814069986343384,\n        1.2860920429229736,\n        1.251664638519287,\n        1.2245914936065674,\n        1.2196787595748901,\n        1.2461426258087158,\n        1.2824065685272217,\n        1.3231412172317505,\n        1.2859265804290771,\n        1.2610337734222412,\n        1.2612855434417725,\n        1.2891101837158203,\n        1.3355929851531982,\n        1.3490444421768188,\n        1.3118960857391357,\n        1.2749104499816895,\n        1.2770724296569824\n      ]\n    },\n    {\n      \"step\": 20,\n      \"n_points\": 31,\n      \"horizon\": 17,\n      \"last_historical_date\": \"2024-07\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        
\"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158,\n        1.2400000095367432\n      ],\n      \"forecast_dates\": [\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2069008350372314,\n        1.193657636642456,\n        1.1575161218643188,\n        1.133849859237671,\n        1.1235467195510864,\n        1.1252387762069702,\n        1.1443586349487305,\n        1.1506012678146362,\n        1.134141206741333,\n        1.1200145483016968,\n        1.133240818977356,\n        1.1518402099609375,\n        1.1871200799942017,\n        1.2083441019058228,\n        1.1728931665420532,\n        1.1432249546051025,\n  
      1.133898377418518\n      ],\n      \"q10\": [\n        1.2029995918273926,\n        1.1932077407836914,\n        1.1641241312026978,\n        1.1338424682617188,\n        1.12429940700531,\n        1.1312663555145264,\n        1.1450644731521606,\n        1.1525075435638428,\n        1.1395219564437866,\n        1.121511697769165,\n        1.132306456565857,\n        1.1525789499282837,\n        1.1869525909423828,\n        1.2014143466949463,\n        1.168949007987976,\n        1.132044792175293,\n        1.1256910562515259\n      ],\n      \"q20\": [\n        1.0395362377166748,\n        0.9963122606277466,\n        0.951080322265625,\n        0.9185925126075745,\n        0.9062104821205139,\n        0.9065833687782288,\n        0.9181973934173584,\n        0.9136454463005066,\n        0.9018174409866333,\n        0.8859837055206299,\n        0.8985838890075684,\n        0.9077322483062744,\n        0.9346957206726074,\n        0.9418572187423706,\n        0.9179990291595459,\n        0.8948585987091064,\n        0.88960200548172\n      ],\n      \"q80\": [\n        1.2682971954345703,\n        1.2702425718307495,\n        1.239664077758789,\n        1.2174897193908691,\n        1.2065781354904175,\n        1.2155629396438599,\n        1.2398309707641602,\n        1.240811824798584,\n        1.2331410646438599,\n        1.2164467573165894,\n        1.2326842546463013,\n        1.251672387123108,\n        1.2885876893997192,\n        1.3034658432006836,\n        1.2739078998565674,\n        1.2376470565795898,\n        1.2269697189331055\n      ],\n      \"q90\": [\n        1.3094301223754883,\n        1.3151092529296875,\n        1.2952126264572144,\n        1.273212194442749,\n        1.2679336071014404,\n        1.2713508605957031,\n        1.2985038757324219,\n        1.305957555770874,\n        1.2964022159576416,\n        1.2809231281280518,\n        1.2935620546340942,\n        1.3095386028289795,\n        1.3458640575408936,\n        
1.366443157196045,\n        1.3342082500457764,\n        1.2930103540420532,\n        1.2838038206100464\n      ]\n    },\n    {\n      \"step\": 21,\n      \"n_points\": 32,\n      \"horizon\": 16,\n      \"last_historical_date\": \"2024-08\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158,\n        1.2400000095367432,\n        1.2999999523162842\n      ],\n      \"forecast_dates\": [\n        \"2024-09\",\n        
\"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2892454862594604,\n        1.2497223615646362,\n        1.2063699960708618,\n        1.2123697996139526,\n        1.2295829057693481,\n        1.2457282543182373,\n        1.2520256042480469,\n        1.1976659297943115,\n        1.1560035943984985,\n        1.15586519241333,\n        1.168123483657837,\n        1.188661813735962,\n        1.1947652101516724,\n        1.173640251159668,\n        1.128365397453308,\n        1.128602385520935\n      ],\n      \"q10\": [\n        1.2727627754211426,\n        1.2367907762527466,\n        1.1920455694198608,\n        1.1937742233276367,\n        1.2203925848007202,\n        1.2314530611038208,\n        1.2363964319229126,\n        1.1829954385757446,\n        1.1487408876419067,\n        1.1405112743377686,\n        1.1547985076904297,\n        1.1740177869796753,\n        1.1805450916290283,\n        1.1459304094314575,\n        1.1116427183151245,\n        1.0966339111328125\n      ],\n      \"q20\": [\n        1.11649489402771,\n        1.0445278882980347,\n        0.9846185445785522,\n        0.9668428897857666,\n        0.9715695977210999,\n        0.9662386178970337,\n        0.9553800821304321,\n        0.9113569855690002,\n        0.8853881359100342,\n        0.8746424913406372,\n        0.875267505645752,\n        0.8781014680862427,\n        0.8732690215110779,\n        0.8478219509124756,\n        0.8163697719573975,\n        0.815811276435852\n      ],\n      \"q80\": [\n        1.3429784774780273,\n        1.3280879259109497,\n        1.292254090309143,\n        1.3056862354278564,\n        1.3293191194534302,\n        1.352075219154358,\n      
  1.3609846830368042,\n        1.3075883388519287,\n        1.279836893081665,\n        1.272203803062439,\n        1.2965717315673828,\n        1.3177393674850464,\n        1.3210997581481934,\n        1.295129418373108,\n        1.2528834342956543,\n        1.246609091758728\n      ],\n      \"q90\": [\n        1.3865525722503662,\n        1.3712806701660156,\n        1.3499008417129517,\n        1.3717585802078247,\n        1.4015172719955444,\n        1.4236888885498047,\n        1.4422738552093506,\n        1.3891522884368896,\n        1.3545751571655273,\n        1.349416732788086,\n        1.363886833190918,\n        1.3921372890472412,\n        1.3967747688293457,\n        1.3780581951141357,\n        1.331864356994629,\n        1.3187098503112793\n      ]\n    },\n    {\n      \"step\": 22,\n      \"n_points\": 33,\n      \"horizon\": 15,\n      \"last_historical_date\": \"2024-09\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n      
  0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158,\n        1.2400000095367432,\n        1.2999999523162842,\n        1.2799999713897705\n      ],\n      \"forecast_dates\": [\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2395873069763184,\n        1.192650318145752,\n        1.1737117767333984,\n        1.1951370239257812,\n        1.232491135597229,\n        1.265418291091919,\n        1.2109034061431885,\n        1.1846691370010376,\n        1.1904014348983765,\n        1.2089793682098389,\n        1.2557576894760132,\n        1.2761039733886719,\n        1.2492849826812744,\n        1.2014641761779785,\n        1.1954424381256104\n      ],\n      \"q10\": [\n        1.2416894435882568,\n        1.1871181726455688,\n        1.1744379997253418,\n        1.19320547580719,\n        1.2350860834121704,\n        1.2670172452926636,\n        1.211256980895996,\n        1.1898648738861084,\n        1.1905932426452637,\n        1.1989935636520386,\n        1.247326135635376,\n        1.268507480621338,\n        1.2414063215255737,\n        1.1882392168045044,\n        1.184570550918579\n      ],\n      \"q20\": [\n        1.097076654434204,\n        
1.0414971113204956,\n        1.0175477266311646,\n        1.0278714895248413,\n        1.0624254941940308,\n        1.0802021026611328,\n        1.0272504091262817,\n        1.0036317110061646,\n        1.0009558200836182,\n        1.001404047012329,\n        1.0334482192993164,\n        1.042593240737915,\n        1.0162984132766724,\n        0.9763948321342468,\n        0.9707307815551758\n      ],\n      \"q80\": [\n        1.2870674133300781,\n        1.2494632005691528,\n        1.2323118448257446,\n        1.2594434022903442,\n        1.3010603189468384,\n        1.3373479843139648,\n        1.2841951847076416,\n        1.2637286186218262,\n        1.2685482501983643,\n        1.2876002788543701,\n        1.339444637298584,\n        1.3590757846832275,\n        1.3355648517608643,\n        1.2837905883789062,\n        1.2771517038345337\n      ],\n      \"q90\": [\n        1.3212705850601196,\n        1.2820069789886475,\n        1.2749484777450562,\n        1.2991927862167358,\n        1.3489611148834229,\n        1.384088397026062,\n        1.3305764198303223,\n        1.3098028898239136,\n        1.3174644708633423,\n        1.334850788116455,\n        1.387671709060669,\n        1.4108545780181885,\n        1.3836441040039062,\n        1.3309946060180664,\n        1.3257174491882324\n      ]\n    },\n    {\n      \"step\": 23,\n      \"n_points\": 34,\n      \"horizon\": 14,\n      \"last_historical_date\": \"2024-10\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        
\"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158,\n        1.2400000095367432,\n        1.2999999523162842,\n        1.2799999713897705,\n        1.2699999809265137\n      ],\n      \"forecast_dates\": [\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.200866460800171,\n        1.1866711378097534,\n        1.2232941389083862,\n        1.2719991207122803,\n        1.2799842357635498,\n        1.2515898942947388,\n        1.1958189010620117,\n        1.19310462474823,\n        1.2179431915283203,\n        1.2518219947814941,\n        1.2716079950332642,\n        1.2360819578170776,\n        1.1987874507904053,\n 
       1.1850693225860596\n      ],\n      \"q10\": [\n        1.2021855115890503,\n        1.1821584701538086,\n        1.2226784229278564,\n        1.273689866065979,\n        1.2845158576965332,\n        1.2485958337783813,\n        1.1959373950958252,\n        1.1964659690856934,\n        1.2180784940719604,\n        1.2440263032913208,\n        1.2621558904647827,\n        1.2280503511428833,\n        1.1858408451080322,\n        1.1696057319641113\n      ],\n      \"q20\": [\n        1.0769736766815186,\n        1.0466127395629883,\n        1.0687201023101807,\n        1.1035237312316895,\n        1.1067966222763062,\n        1.0670413970947266,\n        1.0116249322891235,\n        1.003699779510498,\n        1.0221866369247437,\n        1.0382513999938965,\n        1.0417994260787964,\n        1.0053966045379639,\n        0.9645071029663086,\n        0.9537580609321594\n      ],\n      \"q80\": [\n        1.2458512783050537,\n        1.2381772994995117,\n        1.2802457809448242,\n        1.3395813703536987,\n        1.3537287712097168,\n        1.3230884075164795,\n        1.2715508937835693,\n        1.2736643552780151,\n        1.3004214763641357,\n        1.338258147239685,\n        1.3596911430358887,\n        1.3208271265029907,\n        1.2824501991271973,\n        1.2699368000030518\n      ],\n      \"q90\": [\n        1.2776029109954834,\n        1.2695484161376953,\n        1.3248724937438965,\n        1.3829126358032227,\n        1.4010111093521118,\n        1.3700647354125977,\n        1.3196228742599487,\n        1.3224942684173584,\n        1.3526369333267212,\n        1.3852152824401855,\n        1.4081038236618042,\n        1.3699979782104492,\n        1.3278801441192627,\n        1.3165735006332397\n      ]\n    },\n    {\n      \"step\": 24,\n      \"n_points\": 35,\n      \"horizon\": 13,\n      \"last_historical_date\": \"2024-11\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n       
 \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158,\n        1.2400000095367432,\n        1.2999999523162842,\n        1.2799999713897705,\n        1.2699999809265137,\n        1.2200000286102295\n      ],\n      \"forecast_dates\": [\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        
\"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2384696006774902,\n        1.2530195713043213,\n        1.3186349868774414,\n        1.3470391035079956,\n        1.2608959674835205,\n        1.1712164878845215,\n        1.1867536306381226,\n        1.2420611381530762,\n        1.26655912399292,\n        1.2961373329162598,\n        1.2294163703918457,\n        1.166834831237793,\n        1.1554596424102783\n      ],\n      \"q10\": [\n        1.2286468744277954,\n        1.2455438375473022,\n        1.3089576959609985,\n        1.3339853286743164,\n        1.2469478845596313,\n        1.149349570274353,\n        1.1650605201721191,\n        1.2206904888153076,\n        1.2502191066741943,\n        1.267012357711792,\n        1.2066657543182373,\n        1.1346192359924316,\n        1.115806221961975\n      ],\n      \"q20\": [\n        1.1213350296020508,\n        1.1162383556365967,\n        1.1598260402679443,\n        1.164056420326233,\n        1.0658612251281738,\n        0.9682412147521973,\n        0.9661321043968201,\n        1.0035676956176758,\n        1.0229461193084717,\n        1.0328454971313477,\n        0.9562720656394958,\n        0.8820796608924866,\n        0.8598078489303589\n      ],\n      \"q80\": [\n        1.2744736671447754,\n        1.3042716979980469,\n        1.3755841255187988,\n        1.4136451482772827,\n        1.3280566930770874,\n        1.2414281368255615,\n        1.265558123588562,\n        1.3200562000274658,\n        1.3582563400268555,\n        1.391558051109314,\n        1.325316071510315,\n        1.2584038972854614,\n        1.2471749782562256\n      ],\n      \"q90\": [\n        1.3033584356307983,\n        1.337569236755371,\n        1.418811321258545,\n        1.4565141201019287,\n        1.3784044981002808,\n        1.289095401763916,\n        1.3142638206481934,\n        1.378018856048584,\n        1.411327838897705,\n        1.4371509552001953,\n        
1.3814256191253662,\n        1.3204567432403564,\n        1.3057005405426025\n      ]\n    },\n    {\n      \"step\": 25,\n      \"n_points\": 36,\n      \"horizon\": 12,\n      \"last_historical_date\": \"2024-12\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158,\n        1.2400000095367432,\n        1.2999999523162842,\n        
1.2799999713897705,\n        1.2699999809265137,\n        1.2200000286102295,\n        1.2000000476837158\n      ],\n      \"forecast_dates\": [\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.25933837890625,\n        1.285666823387146,\n        1.2950127124786377,\n        1.2207623720169067,\n        1.170255422592163,\n        1.1455552577972412,\n        1.1702347993850708,\n        1.2026824951171875,\n        1.1909748315811157,\n        1.1490840911865234,\n        1.080478549003601,\n        1.0613453388214111\n      ],\n      \"q10\": [\n        1.2481880187988281,\n        1.2773758172988892,\n        1.286991834640503,\n        1.2084007263183594,\n        1.1533130407333374,\n        1.1275498867034912,\n        1.1510555744171143,\n        1.1859495639801025,\n        1.1784849166870117,\n        1.1264795064926147,\n        1.0624356269836426,\n        1.036609172821045\n      ],\n      \"q20\": [\n        1.1407020092010498,\n        1.1406043767929077,\n        1.126852035522461,\n        1.0352504253387451,\n        0.9691494703292847,\n        0.9420379400253296,\n        0.9503718018531799,\n        0.970925509929657,\n        0.9594371318817139,\n        0.9079477190971375,\n        0.8361266255378723,\n        0.8022069334983826\n      ],\n      \"q80\": [\n        1.2971320152282715,\n        1.3400218486785889,\n        1.3547290563583374,\n        1.2898554801940918,\n        1.2390310764312744,\n        1.2180578708648682,\n        1.248227596282959,\n        1.2842004299163818,\n        1.2832940816879272,\n        1.240414023399353,\n        1.175971508026123,\n        1.153149962425232\n      ],\n      \"q90\": [\n        1.3239599466323853,\n        1.3751201629638672,\n       
 1.403548240661621,\n        1.3310348987579346,\n        1.2891905307769775,\n        1.2702757120132446,\n        1.2997852563858032,\n        1.3408125638961792,\n        1.3354730606079102,\n        1.286876916885376,\n        1.2283769845962524,\n        1.2169079780578613\n      ]\n    }\n  ]\n}"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/output/forecast_output.csv",
    "content": "date,point_forecast,q10,q20,q30,q40,q50,q60,q70,q80,q90,q99\n2025-01-01,1.2593384,1.248188,1.140702,1.1880752,1.2137158,1.2394564,1.2593384,1.2767732,1.297132,1.32396,1.367888\n2025-02-01,1.2856668,1.2773758,1.1406044,1.1960833,1.2322671,1.2593892,1.2856668,1.3110137,1.3400218,1.3751202,1.4253658\n2025-03-01,1.2950127,1.2869918,1.126852,1.1876173,1.234988,1.2675052,1.2950127,1.328448,1.354729,1.4035482,1.4642649\n2025-04-01,1.2207624,1.2084007,1.0352504,1.1041918,1.151865,1.1853008,1.2207624,1.256663,1.2898555,1.3310349,1.4016538\n2025-05-01,1.1702554,1.153313,0.9691495,1.0431063,1.0932612,1.1276176,1.1702554,1.201966,1.2390311,1.2891905,1.3632389\n2025-06-01,1.1455553,1.1275499,0.94203794,1.0110554,1.0658777,1.1061188,1.1455553,1.1806211,1.2180579,1.2702757,1.345366\n2025-07-01,1.1702348,1.1510556,0.9503718,1.0347577,1.0847733,1.1287677,1.1702348,1.2114835,1.2482276,1.2997853,1.3807325\n2025-08-01,1.2026825,1.1859496,0.9709255,1.0594383,1.1106675,1.1579902,1.2026825,1.2399211,1.2842004,1.3408126,1.419526\n2025-09-01,1.1909748,1.1784849,0.95943713,1.0403702,1.103606,1.1511956,1.1909748,1.2390201,1.2832941,1.3354731,1.416972\n2025-10-01,1.1490841,1.1264795,0.9079477,0.99529266,1.0548235,1.1052223,1.1490841,1.1897774,1.240414,1.2868769,1.3775467\n2025-11-01,1.0804785,1.0624356,0.8361266,0.9259792,0.9882403,1.0386353,1.0804785,1.1281581,1.1759715,1.228377,1.3122478\n2025-12-01,1.0613453,1.0366092,0.80220693,0.89521873,0.9593707,1.0152239,1.0613453,1.1032857,1.15315,1.216908,1.2959521\n"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/output/forecast_output.json",
    "content": "{\n  \"model\": \"TimesFM 1.0 (200M) PyTorch\",\n  \"input\": {\n    \"source\": \"NOAA GISTEMP Global Temperature Anomaly\",\n    \"n_observations\": 36,\n    \"date_range\": \"2022-01 to 2024-12\",\n    \"mean_anomaly_c\": 1.09\n  },\n  \"forecast\": {\n    \"horizon\": 12,\n    \"dates\": [\n      \"2025-01\",\n      \"2025-02\",\n      \"2025-03\",\n      \"2025-04\",\n      \"2025-05\",\n      \"2025-06\",\n      \"2025-07\",\n      \"2025-08\",\n      \"2025-09\",\n      \"2025-10\",\n      \"2025-11\",\n      \"2025-12\"\n    ],\n    \"point\": [\n      1.25933837890625,\n      1.285666823387146,\n      1.2950127124786377,\n      1.2207623720169067,\n      1.170255422592163,\n      1.1455552577972412,\n      1.1702347993850708,\n      1.2026824951171875,\n      1.1909748315811157,\n      1.1490840911865234,\n      1.080478549003601,\n      1.0613453388214111\n    ],\n    \"quantiles\": {\n      \"10%\": [\n        1.2481880187988281,\n        1.2773758172988892,\n        1.286991834640503,\n        1.2084007263183594,\n        1.1533130407333374,\n        1.1275498867034912,\n        1.1510555744171143,\n        1.1859495639801025,\n        1.1784849166870117,\n        1.1264795064926147,\n        1.0624356269836426,\n        1.036609172821045\n      ],\n      \"20%\": [\n        1.1407020092010498,\n        1.1406043767929077,\n        1.126852035522461,\n        1.0352504253387451,\n        0.9691494703292847,\n        0.9420379400253296,\n        0.9503718018531799,\n        0.970925509929657,\n        0.9594371318817139,\n        0.9079477190971375,\n        0.8361266255378723,\n        0.8022069334983826\n      ],\n      \"30%\": [\n        1.1880751848220825,\n        1.1960833072662354,\n        1.187617301940918,\n        1.104191780090332,\n        1.0431063175201416,\n        1.01105535030365,\n        1.0347577333450317,\n        1.0594383478164673,\n        1.040370225906372,\n        0.9952926635742188,\n        
0.9259791970252991,\n        0.8952187299728394\n      ],\n      \"40%\": [\n        1.2137157917022705,\n        1.232267141342163,\n        1.2349879741668701,\n        1.151865005493164,\n        1.0932612419128418,\n        1.0658776760101318,\n        1.084773302078247,\n        1.1106674671173096,\n        1.1036059856414795,\n        1.0548235177993774,\n        0.9882403016090393,\n        0.9593706727027893\n      ],\n      \"50%\": [\n        1.2394564151763916,\n        1.2593891620635986,\n        1.267505168914795,\n        1.1853008270263672,\n        1.127617597579956,\n        1.1061187982559204,\n        1.128767728805542,\n        1.1579902172088623,\n        1.1511956453323364,\n        1.1052223443984985,\n        1.03863525390625,\n        1.0152238607406616\n      ],\n      \"60%\": [\n        1.25933837890625,\n        1.285666823387146,\n        1.2950127124786377,\n        1.2207623720169067,\n        1.170255422592163,\n        1.1455552577972412,\n        1.1702347993850708,\n        1.2026824951171875,\n        1.1909748315811157,\n        1.1490840911865234,\n        1.080478549003601,\n        1.0613453388214111\n      ],\n      \"70%\": [\n        1.27677321434021,\n        1.3110136985778809,\n        1.3284480571746826,\n        1.2566629648208618,\n        1.2019660472869873,\n        1.1806211471557617,\n        1.2114834785461426,\n        1.2399210929870605,\n        1.2390201091766357,\n        1.1897773742675781,\n        1.1281580924987793,\n        1.1032856702804565\n      ],\n      \"80%\": [\n        1.2971320152282715,\n        1.3400218486785889,\n        1.3547290563583374,\n        1.2898554801940918,\n        1.2390310764312744,\n        1.2180578708648682,\n        1.248227596282959,\n        1.2842004299163818,\n        1.2832940816879272,\n        1.240414023399353,\n        1.175971508026123,\n        1.153149962425232\n      ],\n      \"90%\": [\n        1.3239599466323853,\n        1.3751201629638672,\n        
1.403548240661621,\n        1.3310348987579346,\n        1.2891905307769775,\n        1.2702757120132446,\n        1.2997852563858032,\n        1.3408125638961792,\n        1.3354730606079102,\n        1.286876916885376,\n        1.2283769845962524,\n        1.2169079780578613\n      ],\n      \"99%\": [\n        1.3678879737854004,\n        1.4253658056259155,\n        1.4642648696899414,\n        1.40165376663208,\n        1.3632389307022095,\n        1.3453660011291504,\n        1.380732536315918,\n        1.4195259809494019,\n        1.416972041130066,\n        1.3775466680526733,\n        1.3122477531433105,\n        1.2959520816802979\n      ]\n    }\n  },\n  \"summary\": {\n    \"forecast_mean_c\": 1.186,\n    \"forecast_max_c\": 1.295,\n    \"forecast_min_c\": 1.061,\n    \"vs_last_year_mean\": -0.067\n  }\n}"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/output/interactive_forecast.html",
    "content": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n    <meta charset=\"UTF-8\">\n    <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n    <title>TimesFM Interactive Forecast Animation</title>\n    <script src=\"https://cdn.jsdelivr.net/npm/chart.js\"></script>\n    <style>\n        * { margin: 0; padding: 0; box-sizing: border-box; }\n        \n        body {\n            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;\n            background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);\n            min-height: 100vh;\n            color: #e0e0e0;\n            padding: 20px;\n        }\n        \n        .container { max-width: 1200px; margin: 0 auto; }\n        \n        header { text-align: center; margin-bottom: 30px; }\n        \n        h1 {\n            font-size: 2rem;\n            margin-bottom: 10px;\n            background: linear-gradient(90deg, #60a5fa, #a78bfa);\n            -webkit-background-clip: text;\n            -webkit-text-fill-color: transparent;\n        }\n        \n        .subtitle { color: #9ca3af; font-size: 1.1rem; }\n        \n        .chart-container {\n            background: rgba(255, 255, 255, 0.05);\n            border-radius: 16px;\n            padding: 20px;\n            margin-bottom: 20px;\n            box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3);\n        }\n        \n        #chart { width: 100% !important; height: 450px !important; }\n        \n        .controls {\n            display: flex;\n            flex-direction: column;\n            gap: 20px;\n            background: rgba(255, 255, 255, 0.05);\n            border-radius: 16px;\n            padding: 20px;\n        }\n        \n        .slider-container { display: flex; flex-direction: column; gap: 10px; }\n        \n        .slider-label { display: flex; justify-content: space-between; align-items: center; }\n        .slider-label span { font-size: 0.9rem; color: #9ca3af; }\n        
.slider-label .value { font-weight: 600; color: #60a5fa; font-size: 1.1rem; }\n        \n        input[type=\"range\"] {\n            width: 100%; height: 8px; border-radius: 4px;\n            background: #374151; outline: none; -webkit-appearance: none;\n        }\n        \n        input[type=\"range\"]::-webkit-slider-thumb {\n            -webkit-appearance: none;\n            width: 24px; height: 24px; border-radius: 50%;\n            background: linear-gradient(135deg, #60a5fa, #a78bfa);\n            cursor: pointer;\n            box-shadow: 0 2px 10px rgba(96, 165, 250, 0.5);\n        }\n        \n        .buttons { display: flex; gap: 10px; flex-wrap: wrap; }\n        \n        button {\n            flex: 1; min-width: 100px;\n            padding: 12px 20px;\n            border: none; border-radius: 8px;\n            font-size: 1rem; font-weight: 600;\n            cursor: pointer; transition: all 0.2s ease;\n        }\n        \n        .btn-primary {\n            background: linear-gradient(135deg, #60a5fa, #a78bfa);\n            color: white;\n        }\n        .btn-primary:hover { transform: translateY(-2px); box-shadow: 0 4px 15px rgba(96, 165, 250, 0.4); }\n        \n        .btn-secondary { background: #374151; color: #e0e0e0; }\n        .btn-secondary:hover { background: #4b5563; }\n        \n        .stats {\n            display: grid;\n            grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));\n            gap: 15px;\n            margin-top: 20px;\n        }\n        \n        .stat-card {\n            background: rgba(255, 255, 255, 0.05);\n            border-radius: 12px;\n            padding: 15px;\n            text-align: center;\n        }\n        .stat-card .label { font-size: 0.8rem; color: #9ca3af; margin-bottom: 5px; }\n        .stat-card .value { font-size: 1.3rem; font-weight: 600; color: #60a5fa; }\n        \n        .legend {\n            display: flex;\n            justify-content: center;\n            gap: 20px;\n      
      flex-wrap: wrap;\n            margin-top: 15px;\n            padding-top: 15px;\n            border-top: 1px solid rgba(255, 255, 255, 0.1);\n        }\n        \n        .legend-item { display: flex; align-items: center; gap: 8px; font-size: 0.85rem; }\n        .legend-color { width: 16px; height: 16px; border-radius: 4px; }\n        \n        footer {\n            text-align: center;\n            margin-top: 30px;\n            color: #6b7280;\n            font-size: 0.9rem;\n        }\n        footer a { color: #60a5fa; text-decoration: none; }\n    </style>\n</head>\n<body>\n    <div class=\"container\">\n        <header>\n            <h1>TimesFM Forecast Evolution</h1>\n            <p class=\"subtitle\">Watch the forecast evolve as more data is added — forecasts extend to 2025-12</p>\n        </header>\n        \n        <div class=\"chart-container\">\n            <canvas id=\"chart\"></canvas>\n        </div>\n        \n        <div class=\"controls\">\n            <div class=\"slider-container\">\n                <div class=\"slider-label\">\n                    <span>Data Points Used</span>\n                    <span class=\"value\" id=\"points-value\">12 / 36</span>\n                </div>\n                <input type=\"range\" id=\"slider\" min=\"0\" max=\"24\" value=\"0\" step=\"1\">\n                <div class=\"slider-label\">\n                    <span>2022-01</span>\n                    <span id=\"date-end\">Using data through 2022-12</span>\n                </div>\n            </div>\n            \n            <div class=\"buttons\">\n                <button class=\"btn-primary\" id=\"play-btn\">▶ Play</button>\n                <button class=\"btn-secondary\" id=\"reset-btn\">↺ Reset</button>\n            </div>\n            \n            <div class=\"stats\">\n                <div class=\"stat-card\">\n                    <div class=\"label\">Forecast Mean</div>\n                    <div class=\"value\" id=\"stat-mean\">0.86°C</div>\n         
       </div>\n                <div class=\"stat-card\">\n                    <div class=\"label\">Forecast Horizon</div>\n                    <div class=\"value\" id=\"stat-horizon\">36 months</div>\n                </div>\n                <div class=\"stat-card\">\n                    <div class=\"label\">Forecast Max</div>\n                    <div class=\"value\" id=\"stat-max\">--</div>\n                </div>\n                <div class=\"stat-card\">\n                    <div class=\"label\">Forecast Min</div>\n                    <div class=\"value\" id=\"stat-min\">--</div>\n                </div>\n            </div>\n            \n            <div class=\"legend\">\n                <div class=\"legend-item\">\n                    <div class=\"legend-color\" style=\"background: #9ca3af;\"></div>\n                    <span>All Observed Data</span>\n                </div>\n                <div class=\"legend-item\">\n                    <div class=\"legend-color\" style=\"background: #fca5a5;\"></div>\n                    <span>Final Forecast (reference)</span>\n                </div>\n                <div class=\"legend-item\">\n                    <div class=\"legend-color\" style=\"background: #3b82f6;\"></div>\n                    <span>Data Used</span>\n                </div>\n                <div class=\"legend-item\">\n                    <div class=\"legend-color\" style=\"background: #ef4444;\"></div>\n                    <span>Current Forecast</span>\n                </div>\n                <div class=\"legend-item\">\n                    <div class=\"legend-color\" style=\"background: rgba(239, 68, 68, 0.25);\"></div>\n                    <span>80% CI</span>\n                </div>\n            </div>\n        </div>\n        \n        <footer>\n            <p>TimesFM 1.0 (200M) PyTorch • <a href=\"https://github.com/google-research/timesfm\">Google Research</a></p>\n        </footer>\n    </div>\n\n    <script>\n        // Embedded animation data 
(no external fetch needed)\n        const animationData = {\n  \"metadata\": {\n    \"model\": \"TimesFM 1.0 (200M) PyTorch\",\n    \"total_steps\": 25,\n    \"min_context\": 12,\n    \"max_horizon\": 36,\n    \"total_months\": 48,\n    \"data_source\": \"NOAA GISTEMP Global Temperature Anomaly\",\n    \"full_date_range\": \"2022-01 to 2024-12\"\n  },\n  \"actual_data\": {\n    \"dates\": [\n      \"2022-01\",\n      \"2022-02\",\n      \"2022-03\",\n      \"2022-04\",\n      \"2022-05\",\n      \"2022-06\",\n      \"2022-07\",\n      \"2022-08\",\n      \"2022-09\",\n      \"2022-10\",\n      \"2022-11\",\n      \"2022-12\",\n      \"2023-01\",\n      \"2023-02\",\n      \"2023-03\",\n      \"2023-04\",\n      \"2023-05\",\n      \"2023-06\",\n      \"2023-07\",\n      \"2023-08\",\n      \"2023-09\",\n      \"2023-10\",\n      \"2023-11\",\n      \"2023-12\",\n      \"2024-01\",\n      \"2024-02\",\n      \"2024-03\",\n      \"2024-04\",\n      \"2024-05\",\n      \"2024-06\",\n      \"2024-07\",\n      \"2024-08\",\n      \"2024-09\",\n      \"2024-10\",\n      \"2024-11\",\n      \"2024-12\"\n    ],\n    \"values\": [\n      0.8899999856948853,\n      0.8899999856948853,\n      1.0199999809265137,\n      0.8799999952316284,\n      0.8500000238418579,\n      0.8799999952316284,\n      0.8799999952316284,\n      0.8999999761581421,\n      0.8799999952316284,\n      0.949999988079071,\n      0.7699999809265137,\n      0.7799999713897705,\n      0.8700000047683716,\n      0.9800000190734863,\n      1.2100000381469727,\n      1.0,\n      0.9399999976158142,\n      1.0800000429153442,\n      1.1799999475479126,\n      1.2400000095367432,\n      1.4700000286102295,\n      1.3200000524520874,\n      1.1799999475479126,\n      1.159999966621399,\n      1.2200000286102295,\n      1.350000023841858,\n      1.340000033378601,\n      1.2599999904632568,\n      1.149999976158142,\n      1.2000000476837158,\n      1.2400000095367432,\n      1.2999999523162842,\n      
1.2799999713897705,\n      1.2699999809265137,\n      1.2200000286102295,\n      1.2000000476837158\n    ]\n  },\n  \"animation_steps\": [\n    {\n      \"step\": 1,\n      \"n_points\": 12,\n      \"horizon\": 36,\n      \"last_historical_date\": \"2022-12\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705\n      ],\n      \"forecast_dates\": [\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        0.825579047203064,\n        0.8330779075622559,\n        0.8368334174156189,\n        0.8413563370704651,\n        0.8546873331069946,\n        0.8463932275772095,\n        
0.852830708026886,\n        0.8635484576225281,\n        0.873649001121521,\n        0.8784391283988953,\n        0.8793435096740723,\n        0.886539101600647,\n        0.876642107963562,\n        0.8771936297416687,\n        0.8794507384300232,\n        0.8818798065185547,\n        0.8801761269569397,\n        0.878594696521759,\n        0.8841555714607239,\n        0.8686957955360413,\n        0.8627567887306213,\n        0.8599377870559692,\n        0.8534176349639893,\n        0.8439264297485352,\n        0.8403507471084595,\n        0.84540855884552,\n        0.8334686756134033,\n        0.8366615176200867,\n        0.8480817079544067,\n        0.8587210178375244,\n        0.865203857421875,\n        0.8715710043907166,\n        0.883372962474823,\n        0.8742744326591492,\n        0.8734725117683411,\n        0.8783032894134521\n      ],\n      \"q10\": [\n        0.8354606032371521,\n        0.8444467782974243,\n        0.8485234975814819,\n        0.8526979088783264,\n        0.8648908138275146,\n        0.8568621277809143,\n        0.863645076751709,\n        0.872414231300354,\n        0.8817781209945679,\n        0.8863298892974854,\n        0.8866963982582092,\n        0.8946276903152466,\n        0.8833872675895691,\n        0.8827563524246216,\n        0.8864266872406006,\n        0.887717604637146,\n        0.8854249715805054,\n        0.8838265538215637,\n        0.890777051448822,\n        0.8747947812080383,\n        0.8702181577682495,\n        0.8688124418258667,\n        0.8621772527694702,\n        0.8549044728279114,\n        0.8520718812942505,\n        0.8580353856086731,\n        0.8461477756500244,\n        0.8497025966644287,\n        0.8604429364204407,\n        0.8707754015922546,\n        0.8765125870704651,\n        0.8818733096122742,\n        0.893653154373169,\n        0.8849858045578003,\n        0.8816121220588684,\n        0.8867135643959045\n      ],\n      \"q20\": [\n        0.7518579959869385,\n        
0.752423882484436,\n        0.7527720928192139,\n        0.7547875642776489,\n        0.7639567852020264,\n        0.7600989937782288,\n        0.7671870589256287,\n        0.7746827006340027,\n        0.783061146736145,\n        0.7859532237052917,\n        0.7876774072647095,\n        0.7946517467498779,\n        0.7890393137931824,\n        0.7905672192573547,\n        0.7923871874809265,\n        0.7943510413169861,\n        0.7928767204284668,\n        0.7914355993270874,\n        0.7945701479911804,\n        0.784331738948822,\n        0.7799307107925415,\n        0.7775163650512695,\n        0.772225022315979,\n        0.7648971676826477,\n        0.7586244940757751,\n        0.7592141032218933,\n        0.7497149705886841,\n        0.7515254020690918,\n        0.76014643907547,\n        0.7683113813400269,\n        0.7757765054702759,\n        0.7805572748184204,\n        0.790294349193573,\n        0.7851614952087402,\n        0.7844950556755066,\n        0.7886985540390015\n      ],\n      \"q80\": [\n        0.8621454238891602,\n        0.8726990222930908,\n        0.8780758380889893,\n        0.8830247521400452,\n        0.895999014377594,\n        0.8877173066139221,\n        0.8932443261146545,\n        0.9029491543769836,\n        0.9142329096794128,\n        0.918304979801178,\n        0.9192531704902649,\n        0.9270545244216919,\n        0.9149025082588196,\n        0.9147888422012329,\n        0.91729736328125,\n        0.9190108776092529,\n        0.9174938201904297,\n        0.916400671005249,\n        0.9234370589256287,\n        0.9071342349052429,\n        0.9007507562637329,\n        0.8995751142501831,\n        0.8921940326690674,\n        0.8833961486816406,\n        0.8816472291946411,\n        0.8888989686965942,\n        0.8762903809547424,\n        0.8794605731964111,\n        0.891765832901001,\n        0.9021292328834534,\n        0.9087244868278503,\n        0.9149095416069031,\n        0.9275970458984375,\n        
0.9168868660926819,\n        0.9142359495162964,\n        0.9194778800010681\n      ],\n      \"q90\": [\n        0.8872727155685425,\n        0.8990722298622131,\n        0.9044539928436279,\n        0.9107659459114075,\n        0.9254093170166016,\n        0.9146999716758728,\n        0.9196149706840515,\n        0.9299551844596863,\n        0.941527783870697,\n        0.9455176591873169,\n        0.9463357925415039,\n        0.9539710283279419,\n        0.9405434727668762,\n        0.9397023320198059,\n        0.9439040422439575,\n        0.9448938369750977,\n        0.9431376457214355,\n        0.9417189359664917,\n        0.9492916464805603,\n        0.9315186738967896,\n        0.9267769455909729,\n        0.925445020198822,\n        0.9191145300865173,\n        0.910182535648346,\n        0.9100216031074524,\n        0.9180203676223755,\n        0.9048261046409607,\n        0.9081428050994873,\n        0.9206303954124451,\n        0.9308969974517822,\n        0.9380975961685181,\n        0.9430014491081238,\n        0.9572127461433411,\n        0.9447380304336548,\n        0.9412767291069031,\n        0.9464495778083801\n      ]\n    },\n    {\n      \"step\": 2,\n      \"n_points\": 13,\n      \"horizon\": 35,\n      \"last_historical_date\": \"2023-01\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        
0.8700000047683716\n      ],\n      \"forecast_dates\": [\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        0.8590402007102966,\n        0.8596092462539673,\n        0.864223062992096,\n        0.8694167733192444,\n        0.8599939346313477,\n        0.8577529191970825,\n        0.8670657873153687,\n        0.8746083378791809,\n        0.8758000731468201,\n        0.8808236718177795,\n        0.8853851556777954,\n        0.8753982186317444,\n        0.8732624053955078,\n        0.8803924322128296,\n        0.8831377029418945,\n        0.8812252879142761,\n        0.8837805986404419,\n        0.8842109441757202,\n        0.8692948818206787,\n        0.8612740635871887,\n        0.8624085783958435,\n        0.8617072105407715,\n        0.8601858615875244,\n        0.8625096082687378,\n        0.8663285374641418,\n        0.8544762134552002,\n        0.8533855080604553,\n        0.862159013748169,\n        0.8707855343818665,\n        0.872623860836029,\n        0.878368079662323,\n        0.8822183012962341,\n        0.8722400665283203,\n        0.8674668669700623,\n        0.8758878111839294\n      ],\n      \"q10\": [\n        0.8657022714614868,\n        0.867158055305481,\n        0.8720226287841797,\n        
0.8764638900756836,\n        0.8662244081497192,\n        0.8640622496604919,\n        0.873618483543396,\n        0.8803330063819885,\n        0.8822183609008789,\n        0.8867899775505066,\n        0.8920900821685791,\n        0.8817423582077026,\n        0.8790065050125122,\n        0.8854852914810181,\n        0.8888370394706726,\n        0.8871243596076965,\n        0.8896916508674622,\n        0.8902166485786438,\n        0.8758934736251831,\n        0.8675172924995422,\n        0.8692970871925354,\n        0.8685914874076843,\n        0.8668439388275146,\n        0.8710702061653137,\n        0.8750268220901489,\n        0.8633314967155457,\n        0.8620151281356812,\n        0.8703252077102661,\n        0.8786934614181519,\n        0.8804004192352295,\n        0.8853165507316589,\n        0.889494776725769,\n        0.8794597387313843,\n        0.8745465278625488,\n        0.8814859390258789\n      ],\n      \"q20\": [\n        0.779899537563324,\n        0.7763701677322388,\n        0.7775852680206299,\n        0.7800794839859009,\n        0.7750610113143921,\n        0.7753159403800964,\n        0.7829091548919678,\n        0.7884992957115173,\n        0.7900261878967285,\n        0.7911601066589355,\n        0.7951517105102539,\n        0.7891175746917725,\n        0.7887728810310364,\n        0.7934086918830872,\n        0.7968956232070923,\n        0.7951973080635071,\n        0.796229898929596,\n        0.7950001358985901,\n        0.7845399379730225,\n        0.7791075110435486,\n        0.7789998650550842,\n        0.7794902324676514,\n        0.7773360013961792,\n        0.7764586806297302,\n        0.7767698168754578,\n        0.7689880132675171,\n        0.7689797282218933,\n        0.7759402394294739,\n        0.7828512787818909,\n        0.7850325107574463,\n        0.7882039546966553,\n        0.7904639840126038,\n        0.7844158411026001,\n        0.7818136215209961,\n        0.7875857353210449\n      ],\n      \"q80\": [\n        
0.8950973153114319,\n        0.8978567719459534,\n        0.9036805033683777,\n        0.9098731875419617,\n        0.8973860144615173,\n        0.8958126306533813,\n        0.9049636125564575,\n        0.9123932123184204,\n        0.9138861298561096,\n        0.9191209077835083,\n        0.9256614446640015,\n        0.9137347936630249,\n        0.9109636545181274,\n        0.9174929857254028,\n        0.9215986728668213,\n        0.9189587831497192,\n        0.9224711060523987,\n        0.9235640168190002,\n        0.9081242084503174,\n        0.8990890979766846,\n        0.900691568851471,\n        0.9007959961891174,\n        0.8983866572380066,\n        0.9030368328094482,\n        0.9082856178283691,\n        0.8958720564842224,\n        0.8932167291641235,\n        0.9023438692092896,\n        0.9115447998046875,\n        0.9133612513542175,\n        0.9190444350242615,\n        0.9236005544662476,\n        0.9117952585220337,\n        0.906220018863678,\n        0.914079487323761\n      ],\n      \"q90\": [\n        0.9195939302444458,\n        0.9236188530921936,\n        0.9301517605781555,\n        0.9359439611434937,\n        0.9242846369743347,\n        0.9196143746376038,\n        0.9301571846008301,\n        0.9382931590080261,\n        0.9394593238830566,\n        0.9451783895492554,\n        0.9518223404884338,\n        0.9389423131942749,\n        0.9352357387542725,\n        0.9424091577529907,\n        0.947126030921936,\n        0.9439764618873596,\n        0.9481194019317627,\n        0.9504281878471375,\n        0.9335556030273438,\n        0.9240644574165344,\n        0.9264681935310364,\n        0.9259119629859924,\n        0.9245560765266418,\n        0.9293811321258545,\n        0.9364281296730042,\n        0.9225189685821533,\n        0.9183617234230042,\n        0.9289659261703491,\n        0.937990665435791,\n        0.9396582245826721,\n        0.9460575580596924,\n        0.9509962797164917,\n        0.9378201961517334,\n        
0.9311509132385254,\n        0.9398520588874817\n      ]\n    },\n    {\n      \"step\": 3,\n      \"n_points\": 14,\n      \"horizon\": 34,\n      \"last_historical_date\": \"2023-02\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863\n      ],\n      \"forecast_dates\": [\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        0.8962793350219727,\n        0.8913998007774353,\n        0.8914807438850403,\n        0.871181845664978,\n        0.8662641644477844,\n        0.8797636032104492,\n        0.8862841129302979,\n 
       0.884779691696167,\n        0.8836072087287903,\n        0.8898857235908508,\n        0.8741991519927979,\n        0.8697925806045532,\n        0.8814526796340942,\n        0.8840450048446655,\n        0.8814879655838013,\n        0.8813571333885193,\n        0.8835927248001099,\n        0.8649601936340332,\n        0.8594167828559875,\n        0.8685873746871948,\n        0.872805118560791,\n        0.8739079236984253,\n        0.8808366060256958,\n        0.8895877003669739,\n        0.8769407868385315,\n        0.8714866638183594,\n        0.8808306455612183,\n        0.888067364692688,\n        0.8873578906059265,\n        0.8892648816108704,\n        0.8923593759536743,\n        0.8761922717094421,\n        0.8705070614814758,\n        0.8820964694023132\n      ],\n      \"q10\": [\n        0.9006780982017517,\n        0.8960930705070496,\n        0.8975709676742554,\n        0.8764383792877197,\n        0.8719356060028076,\n        0.8863880038261414,\n        0.8936481475830078,\n        0.891782283782959,\n        0.8906540274620056,\n        0.8970102667808533,\n        0.8820476531982422,\n        0.8772810101509094,\n        0.889976978302002,\n        0.8918938636779785,\n        0.8886879086494446,\n        0.8894075751304626,\n        0.8912825584411621,\n        0.8730634450912476,\n        0.8673158288002014,\n        0.8772640824317932,\n        0.8791468739509583,\n        0.8799763321876526,\n        0.8868378400802612,\n        0.8973256349563599,\n        0.883881151676178,\n        0.879287600517273,\n        0.8892991542816162,\n        0.8954638242721558,\n        0.8954599499702454,\n        0.8977177739143372,\n        0.9008411765098572,\n        0.8844205737113953,\n        0.8789454102516174,\n        0.8901882767677307\n      ],\n      \"q20\": [\n        0.8080285787582397,\n        0.8004014492034912,\n        0.7992052435874939,\n        0.7845293879508972,\n        0.7833878993988037,\n        0.7934101819992065,\n        
0.798040509223938,\n        0.7972208261489868,\n        0.7961648106575012,\n        0.7998728156089783,\n        0.789516031742096,\n        0.785558819770813,\n        0.794472336769104,\n        0.7951850295066833,\n        0.7945684194564819,\n        0.794198215007782,\n        0.7945625185966492,\n        0.7808390855789185,\n        0.7763155698776245,\n        0.7829429507255554,\n        0.7852435111999512,\n        0.7865880727767944,\n        0.7909019589424133,\n        0.7960636615753174,\n        0.7863008379936218,\n        0.7832475304603577,\n        0.7900716066360474,\n        0.7962746620178223,\n        0.7965481281280518,\n        0.7976964116096497,\n        0.7985848188400269,\n        0.7879433631896973,\n        0.7850476503372192,\n        0.7922680377960205\n      ],\n      \"q80\": [\n        0.9340344071388245,\n        0.9310296177864075,\n        0.931887149810791,\n        0.9107009768486023,\n        0.9042311310768127,\n        0.9196222424507141,\n        0.9265503287315369,\n        0.9255625605583191,\n        0.9238306283950806,\n        0.9304555058479309,\n        0.913487434387207,\n        0.9083813428878784,\n        0.9220874309539795,\n        0.9244784116744995,\n        0.9214062094688416,\n        0.9219330549240112,\n        0.9250167608261108,\n        0.9045271873474121,\n        0.8984488248825073,\n        0.9084285497665405,\n        0.9120396375656128,\n        0.9134330153465271,\n        0.920710563659668,\n        0.9313111305236816,\n        0.9171351194381714,\n        0.9125726222991943,\n        0.922325611114502,\n        0.9292736649513245,\n        0.9300060272216797,\n        0.932316243648529,\n        0.9348157644271851,\n        0.9165349006652832,\n        0.9105325937271118,\n        0.9230691194534302\n      ],\n      \"q90\": [\n        0.9600221514701843,\n        0.9573583006858826,\n        0.9588406682014465,\n        0.9357264041900635,\n        0.9300737380981445,\n        
0.9452965259552002,\n        0.953380823135376,\n        0.9521129727363586,\n        0.9504246711730957,\n        0.9578516483306885,\n        0.9395800828933716,\n        0.9347273707389832,\n        0.9480591416358948,\n        0.950930118560791,\n        0.948790431022644,\n        0.94916832447052,\n        0.9522303342819214,\n        0.9315612316131592,\n        0.9246772527694702,\n        0.9351183772087097,\n        0.9386969208717346,\n        0.9390504956245422,\n        0.9479607939720154,\n        0.9585453867912292,\n        0.9437541961669922,\n        0.9387108683586121,\n        0.9494839906692505,\n        0.9573196172714233,\n        0.9568711519241333,\n        0.9595789909362793,\n        0.9637172222137451,\n        0.9441839456558228,\n        0.936747670173645,\n        0.9499791264533997\n      ]\n    },\n    {\n      \"step\": 4,\n      \"n_points\": 15,\n      \"horizon\": 33,\n      \"last_historical_date\": \"2023-03\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727\n      ],\n      \"forecast_dates\": [\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n       
 \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.011451005935669,\n        0.9553948640823364,\n        0.9197208285331726,\n        0.9124891757965088,\n        0.9261340498924255,\n        0.9234520792961121,\n        0.9108935594558716,\n        0.8969470858573914,\n        0.8980726599693298,\n        0.8982804417610168,\n        0.8991943001747131,\n        0.9119693636894226,\n        0.9100792407989502,\n        0.9019815921783447,\n        0.8973109126091003,\n        0.8946781158447266,\n        0.8884148001670837,\n        0.8810747861862183,\n        0.8763440251350403,\n        0.8705035448074341,\n        0.8778358101844788,\n        0.8958552479743958,\n        0.9278874397277832,\n        0.9475082755088806,\n        0.9399139285087585,\n        0.9295593500137329,\n        0.9194858074188232,\n        0.916989803314209,\n        0.9152628779411316,\n        0.9101430773735046,\n        0.8927386999130249,\n        0.8823466897010803,\n        0.8857365250587463\n      ],\n      \"q10\": [\n        1.028891921043396,\n        0.9745897650718689,\n        0.9376441240310669,\n        0.9297030568122864,\n        0.9439254403114319,\n        0.943497896194458,\n        0.9286640286445618,\n        0.9142505526542664,\n        0.9157885313034058,\n        0.9157061576843262,\n        0.9165257215499878,\n        0.929168164730072,\n        0.9264547228813171,\n        0.9190627932548523,\n        
0.9123958945274353,\n        0.9115281105041504,\n        0.9037967324256897,\n        0.8992751836776733,\n        0.8952363133430481,\n        0.8902027010917664,\n        0.8936614990234375,\n        0.910301148891449,\n        0.9421884417533875,\n        0.9664905667304993,\n        0.957619309425354,\n        0.9471821784973145,\n        0.9369155168533325,\n        0.9328755736351013,\n        0.9314517974853516,\n        0.9264087677001953,\n        0.9108965992927551,\n        0.9000225067138672,\n        0.9029441475868225\n      ],\n      \"q20\": [\n        0.8432373404502869,\n        0.8032699823379517,\n        0.7799109220504761,\n        0.7799201011657715,\n        0.7939504981040955,\n        0.7942459583282471,\n        0.7866204380989075,\n        0.7787443399429321,\n        0.7860440611839294,\n        0.7884118556976318,\n        0.7909562587738037,\n        0.7990366220474243,\n        0.7990424633026123,\n        0.7951732277870178,\n        0.7943146228790283,\n        0.7914892435073853,\n        0.786389946937561,\n        0.7805740237236023,\n        0.7728126049041748,\n        0.7663388848304749,\n        0.767531156539917,\n        0.7775982618331909,\n        0.7965872287750244,\n        0.8098679184913635,\n        0.8040605187416077,\n        0.7990914583206177,\n        0.7943341135978699,\n        0.795067548751831,\n        0.7930296659469604,\n        0.7909825444221497,\n        0.7814936637878418,\n        0.7742173671722412,\n        0.7788263559341431\n      ],\n      \"q80\": [\n        1.0893518924713135,\n        1.031952142715454,\n        0.9909453392028809,\n        0.9802313446998596,\n        0.9924889802932739,\n        0.9901573657989502,\n        0.973213791847229,\n        0.9567193984985352,\n        0.9561106562614441,\n        0.9526670575141907,\n        0.9554384350776672,\n        0.966469407081604,\n        0.9650457501411438,\n        0.9547586441040039,\n        0.9497334957122803,\n        
0.9472479820251465,\n        0.9417811632156372,\n        0.9347074627876282,\n        0.9311444163322449,\n        0.925645649433136,\n        0.9340237975120544,\n        0.9546427726745605,\n        0.9898675680160522,\n        1.0140517950057983,\n        1.006885290145874,\n        0.9937493205070496,\n        0.9815763235092163,\n        0.9766898155212402,\n        0.9745802879333496,\n        0.9689580202102661,\n        0.9494245052337646,\n        0.9369281530380249,\n        0.940288782119751\n      ],\n      \"q90\": [\n        1.143047571182251,\n        1.0867642164230347,\n        1.0392613410949707,\n        1.0258489847183228,\n        1.0397703647613525,\n        1.035668134689331,\n        1.0181812047958374,\n        0.9991654753684998,\n        0.9964229464530945,\n        0.9952237606048584,\n        0.994753360748291,\n        1.0074013471603394,\n        1.0027097463607788,\n        0.9933873414993286,\n        0.9889267086982727,\n        0.9854975342750549,\n        0.9785516262054443,\n        0.9728615880012512,\n        0.9702323079109192,\n        0.9645059108734131,\n        0.9732341766357422,\n        0.9938783049583435,\n        1.0329622030258179,\n        1.060141921043396,\n        1.0525397062301636,\n        1.0378689765930176,\n        1.0230897665023804,\n        1.018609642982483,\n        1.0162283182144165,\n        1.0081523656845093,\n        0.9886332750320435,\n        0.9734073877334595,\n        0.9774399399757385\n      ]\n    },\n    {\n      \"step\": 5,\n      \"n_points\": 16,\n      \"horizon\": 32,\n      \"last_historical_date\": \"2023-04\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        
\"2023-04\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0\n      ],\n      \"forecast_dates\": [\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        0.9379441142082214,\n        0.9161815047264099,\n        0.9183650612831116,\n        0.9345710277557373,\n        0.9429481625556946,\n        0.9236418008804321,\n        0.9020940065383911,\n        0.8962475657463074,\n        0.8969618678092957,\n        0.9029411673545837,\n        0.9058347344398499,\n        0.9071778059005737,\n        0.9064934849739075,\n        0.9002208113670349,\n        0.8948965668678284,\n        0.8888558745384216,\n        0.885951042175293,\n        0.8833035230636597,\n        0.8850363492965698,\n        0.8896763324737549,\n        0.9047040939331055,\n        0.9251466989517212,\n        0.9383421540260315,\n        0.9336385726928711,\n        
0.9287689328193665,\n        0.9275407791137695,\n        0.9268409609794617,\n        0.924099326133728,\n        0.9169213771820068,\n        0.9030519127845764,\n        0.8919728398323059,\n        0.8939611315727234\n      ],\n      \"q10\": [\n        0.9455586075782776,\n        0.9275433421134949,\n        0.9313569068908691,\n        0.9499651789665222,\n        0.957696259021759,\n        0.9388371706008911,\n        0.9148422479629517,\n        0.9104428887367249,\n        0.9122737646102905,\n        0.9160297513008118,\n        0.9193358421325684,\n        0.9216225147247314,\n        0.9201593399047852,\n        0.9155508875846863,\n        0.9093347191810608,\n        0.9044749736785889,\n        0.8999581336975098,\n        0.8994951248168945,\n        0.9004791378974915,\n        0.9077976942062378,\n        0.9192850589752197,\n        0.9383060336112976,\n        0.9530308842658997,\n        0.9488463401794434,\n        0.9426198601722717,\n        0.9435754418373108,\n        0.9431970119476318,\n        0.9382244944572449,\n        0.9305117726325989,\n        0.9167183041572571,\n        0.9076744914054871,\n        0.9097439646720886\n      ],\n      \"q20\": [\n        0.8105636239051819,\n        0.7875122427940369,\n        0.787703812122345,\n        0.8008798360824585,\n        0.8086710572242737,\n        0.7946160435676575,\n        0.7819311022758484,\n        0.7810927629470825,\n        0.7885390520095825,\n        0.7923018336296082,\n        0.7944296002388,\n        0.793520987033844,\n        0.7936148643493652,\n        0.7905219793319702,\n        0.7880567312240601,\n        0.7844575643539429,\n        0.7792351245880127,\n        0.7751155495643616,\n        0.7713013887405396,\n        0.7743531465530396,\n        0.7803812026977539,\n        0.7938993573188782,\n        0.8021929860115051,\n        0.7987417578697205,\n        0.794520914554596,\n        0.7944797277450562,\n        0.7938265800476074,\n        
0.7947475910186768,\n        0.7923287153244019,\n        0.785821259021759,\n        0.7809209823608398,\n        0.7844333648681641\n      ],\n      \"q80\": [\n        0.9937812685966492,\n        0.9760434627532959,\n        0.9809014797210693,\n        0.9971702098846436,\n        1.0051108598709106,\n        0.985238790512085,\n        0.9596951007843018,\n        0.9502063989639282,\n        0.9515751004219055,\n        0.9542210102081299,\n        0.9595392346382141,\n        0.9599698185920715,\n        0.9596587419509888,\n        0.9517510533332825,\n        0.9467341303825378,\n        0.9418620467185974,\n        0.9391661882400513,\n        0.9384753108024597,\n        0.940481960773468,\n        0.9475308656692505,\n        0.963818371295929,\n        0.9858653545379639,\n        1.0016189813613892,\n        0.9964566826820374,\n        0.9913219213485718,\n        0.9908701181411743,\n        0.9896549582481384,\n        0.9836863279342651,\n        0.9743705987930298,\n        0.9582211375236511,\n        0.9449355006217957,\n        0.94720059633255\n      ],\n      \"q90\": [\n        1.0336796045303345,\n        1.0175514221191406,\n        1.021440029144287,\n        1.0401356220245361,\n        1.0489550828933716,\n        1.0270309448242188,\n        0.9989587068557739,\n        0.9885305166244507,\n        0.9877901077270508,\n        0.9937816262245178,\n        0.996868908405304,\n        0.9987958073616028,\n        0.9956378936767578,\n        0.9891375303268433,\n        0.9845867156982422,\n        0.979006290435791,\n        0.9757927656173706,\n        0.9753840565681458,\n        0.9795432090759277,\n        0.9870526194572449,\n        1.0044395923614502,\n        1.0267916917800903,\n        1.0432230234146118,\n        1.0385234355926514,\n        1.0341284275054932,\n        1.0333774089813232,\n        1.0310395956039429,\n        1.025346040725708,\n        1.014280080795288,\n        0.9950195550918579,\n        
0.9828959703445435,\n        0.9817364811897278\n      ]\n    },\n    {\n      \"step\": 6,\n      \"n_points\": 17,\n      \"horizon\": 31,\n      \"last_historical_date\": \"2023-05\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142\n      ],\n      \"forecast_dates\": [\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        0.9097275137901306,\n        0.9010418057441711,\n        0.9079869985580444,\n        0.9222638010978699,\n        
0.932843029499054,\n        0.9133341312408447,\n        0.8972155451774597,\n        0.8887625336647034,\n        0.8941851854324341,\n        0.9068790674209595,\n        0.9091910123825073,\n        0.9068935513496399,\n        0.8990182876586914,\n        0.8986428380012512,\n        0.8881825804710388,\n        0.8843041658401489,\n        0.888336718082428,\n        0.8892695307731628,\n        0.8974661231040955,\n        0.9044860601425171,\n        0.9227194786071777,\n        0.9294296503067017,\n        0.9252649545669556,\n        0.9205634593963623,\n        0.9196065664291382,\n        0.9199687242507935,\n        0.9132981300354004,\n        0.9133179187774658,\n        0.9007443785667419,\n        0.8912027478218079,\n        0.8934641480445862\n      ],\n      \"q10\": [\n        0.9192558526992798,\n        0.9128602147102356,\n        0.9227687120437622,\n        0.9362373352050781,\n        0.9478849172592163,\n        0.9271639585494995,\n        0.910339891910553,\n        0.9013872146606445,\n        0.908535897731781,\n        0.9196968078613281,\n        0.9216489791870117,\n        0.9205824136734009,\n        0.9120896458625793,\n        0.9124637842178345,\n        0.9021389484405518,\n        0.8997719883918762,\n        0.9026364684104919,\n        0.9033412933349609,\n        0.9109377264976501,\n        0.9189012050628662,\n        0.9366557598114014,\n        0.9421946406364441,\n        0.937626302242279,\n        0.9345484972000122,\n        0.9316884875297546,\n        0.9340106844902039,\n        0.9270667433738708,\n        0.9266247749328613,\n        0.9148653745651245,\n        0.9044336676597595,\n        0.9073527455329895\n      ],\n      \"q20\": [\n        0.7991487383842468,\n        0.7880749702453613,\n        0.7902460098266602,\n        0.8014485239982605,\n        0.8115598559379578,\n        0.7963781952857971,\n        0.7883695960044861,\n        0.7836517691612244,\n        0.7910313606262207,\n        
0.799010694026947,\n        0.8031657934188843,\n        0.8004167675971985,\n        0.7960184216499329,\n        0.7969078421592712,\n        0.7900155782699585,\n        0.7853973507881165,\n        0.7849644422531128,\n        0.7844982743263245,\n        0.7866605520248413,\n        0.7920172810554504,\n        0.8011935353279114,\n        0.8064550161361694,\n        0.8041524887084961,\n        0.8006000518798828,\n        0.7974086403846741,\n        0.7984392046928406,\n        0.7938262224197388,\n        0.7966775298118591,\n        0.7895344495773315,\n        0.7830621004104614,\n        0.7873432636260986\n      ],\n      \"q80\": [\n        0.9585660099983215,\n        0.9542173743247986,\n        0.9642703533172607,\n        0.9804073572158813,\n        0.9885033965110779,\n        0.9688029289245605,\n        0.949183464050293,\n        0.9374165534973145,\n        0.9444000124931335,\n        0.9574207663536072,\n        0.9588959217071533,\n        0.9561213254928589,\n        0.9485365748405457,\n        0.9463241100311279,\n        0.9353682994842529,\n        0.934599757194519,\n        0.9394335746765137,\n        0.9425153136253357,\n        0.9504368901252747,\n        0.9591487050056458,\n        0.9809996485710144,\n        0.986733615398407,\n        0.982063353061676,\n        0.9771464467048645,\n        0.9761553406715393,\n        0.977692723274231,\n        0.9702091813087463,\n        0.9681852459907532,\n        0.9539398550987244,\n        0.942665696144104,\n        0.9438384771347046\n      ],\n      \"q90\": [\n        0.994154691696167,\n        0.9911658763885498,\n        1.0009171962738037,\n        1.0182007551193237,\n        1.0296927690505981,\n        1.0062158107757568,\n        0.985028862953186,\n        0.9721169471740723,\n        0.9787886142730713,\n        0.9931607246398926,\n        0.9947684407234192,\n        0.9917771220207214,\n        0.9817482233047485,\n        0.9805346727371216,\n        
0.9713162779808044,\n        0.9691506624221802,\n        0.9753089547157288,\n        0.9789929986000061,\n        0.988203227519989,\n        0.9974985122680664,\n        1.0200386047363281,\n        1.024385929107666,\n        1.0200226306915283,\n        1.0142742395401,\n        1.0153833627700806,\n        1.0168485641479492,\n        1.0072355270385742,\n        1.0065840482711792,\n        0.9912008047103882,\n        0.9780105948448181,\n        0.9798558950424194\n      ]\n    },\n    {\n      \"step\": 7,\n      \"n_points\": 18,\n      \"horizon\": 30,\n      \"last_historical_date\": \"2023-06\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442\n      ],\n      \"forecast_dates\": [\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        
\"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        0.9665141701698303,\n        0.9519135355949402,\n        0.9444465637207031,\n        0.9402952790260315,\n        0.9306893348693848,\n        0.9244646430015564,\n        0.9174035787582397,\n        0.9139379858970642,\n        0.9132129549980164,\n        0.9145187735557556,\n        0.911784291267395,\n        0.9093538522720337,\n        0.9040751457214355,\n        0.9021264314651489,\n        0.8961065411567688,\n        0.8968585133552551,\n        0.9025744795799255,\n        0.9108133316040039,\n        0.9250923991203308,\n        0.9451119899749756,\n        0.9571705460548401,\n        0.9546100497245789,\n        0.9493789076805115,\n        0.9495347738265991,\n        0.9465805292129517,\n        0.942088782787323,\n        0.934301495552063,\n        0.927003026008606,\n        0.9134135842323303,\n        0.9131123423576355\n      ],\n      \"q10\": [\n        0.9755732417106628,\n        0.9652556777000427,\n        0.9605708122253418,\n        0.9540410041809082,\n        0.944946825504303,\n        0.9393219351768494,\n        0.9324542880058289,\n        0.9295912981033325,\n        0.9304096698760986,\n        0.9316055178642273,\n        0.9279895424842834,\n        0.9257113337516785,\n        0.9213154315948486,\n        0.9203523397445679,\n        0.9135439991950989,\n        0.9169613718986511,\n        0.9193251729011536,\n        0.9290840029716492,\n        0.9407450556755066,\n        0.9611459970474243,\n        0.9715418815612793,\n        0.966630756855011,\n        0.9606484770774841,\n        0.9624485373497009,\n        0.9596085548400879,\n        0.9563205242156982,\n        
0.9496365189552307,\n        0.9395637512207031,\n        0.9281183481216431,\n        0.9275621175765991\n      ],\n      \"q20\": [\n        0.833349347114563,\n        0.8175394535064697,\n        0.8078386783599854,\n        0.8068903088569641,\n        0.8031129837036133,\n        0.801506757736206,\n        0.7994549870491028,\n        0.7967816591262817,\n        0.7986584305763245,\n        0.7988185882568359,\n        0.799284040927887,\n        0.7968909740447998,\n        0.7936790585517883,\n        0.792199432849884,\n        0.7875745892524719,\n        0.7865579128265381,\n        0.7882473468780518,\n        0.7924611568450928,\n        0.7977651357650757,\n        0.8117226362228394,\n        0.8149524331092834,\n        0.8140331506729126,\n        0.8101717233657837,\n        0.8099949359893799,\n        0.8057650923728943,\n        0.8038991093635559,\n        0.7993261814117432,\n        0.798288106918335,\n        0.7926219701766968,\n        0.7953957319259644\n      ],\n      \"q80\": [\n        1.0251524448394775,\n        1.015281319618225,\n        1.0085906982421875,\n        1.0044453144073486,\n        0.9904035329818726,\n        0.9857988953590393,\n        0.977156400680542,\n        0.9709676504135132,\n        0.9726237654685974,\n        0.9721717238426208,\n        0.9683824181556702,\n        0.9648834466934204,\n        0.9616217613220215,\n        0.9584988355636597,\n        0.9530823230743408,\n        0.9561627507209778,\n        0.9611006379127502,\n        0.9723068475723267,\n        0.9880313873291016,\n        1.0103445053100586,\n        1.02413809299469,\n        1.0192902088165283,\n        1.0122601985931396,\n        1.0145885944366455,\n        1.012281060218811,\n        1.0074970722198486,\n        0.9987425804138184,\n        0.987089216709137,\n        0.9722681045532227,\n        0.9707110524177551\n      ],\n      \"q90\": [\n        1.0656019449234009,\n        1.059928059577942,\n        
1.0517113208770752,\n        1.0461057424545288,\n        1.035980224609375,\n        1.0275849103927612,\n        1.0181881189346313,\n        1.0124856233596802,\n        1.0126112699508667,\n        1.0153447389602661,\n        1.0106351375579834,\n        1.0058791637420654,\n        1.0014264583587646,\n        0.999718964099884,\n        0.9958565831184387,\n        0.9977275133132935,\n        1.0037381649017334,\n        1.0153366327285767,\n        1.031912088394165,\n        1.055626630783081,\n        1.0701265335083008,\n        1.0629067420959473,\n        1.0560659170150757,\n        1.0568609237670898,\n        1.0577772855758667,\n        1.0517592430114746,\n        1.0405441522598267,\n        1.030192494392395,\n        1.013637900352478,\n        1.0091335773468018\n      ]\n    },\n    {\n      \"step\": 8,\n      \"n_points\": 19,\n      \"horizon\": 29,\n      \"last_historical_date\": \"2023-07\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126\n      ],\n      \"forecast_dates\": [\n 
       \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.0381698608398438,\n        1.012021780014038,\n        0.99420565366745,\n        0.9754087924957275,\n        0.9563038349151611,\n        0.9495773315429688,\n        0.9422544240951538,\n        0.9361824989318848,\n        0.9247673749923706,\n        0.9178153276443481,\n        0.9097317457199097,\n        0.901350200176239,\n        0.8968333601951599,\n        0.8947892189025879,\n        0.8923584818840027,\n        0.8944633603096008,\n        0.9065102338790894,\n        0.9204601049423218,\n        0.951920211315155,\n        0.9842206239700317,\n        0.99086993932724,\n        0.9848544597625732,\n        0.9833636283874512,\n        0.9852919578552246,\n        0.9797993302345276,\n        0.9684444069862366,\n        0.9575868844985962,\n        0.9473453760147095,\n        0.9351227283477783\n      ],\n      \"q10\": [\n        1.0491734743118286,\n        1.028739333152771,\n        1.0114028453826904,\n        0.9906209111213684,\n        0.971588134765625,\n        0.9669111371040344,\n        0.9621954560279846,\n        0.9568055868148804,\n        0.9453385472297668,\n        0.9398422241210938,\n        0.9300127029418945,\n        0.922597348690033,\n        0.9215761423110962,\n        0.9172200560569763,\n        0.9145788550376892,\n        0.9178516864776611,\n   
     0.9267954230308533,\n        0.9420651793479919,\n        0.9693762063980103,\n        1.003636121749878,\n        1.005869746208191,\n        0.9975773096084595,\n        0.9942836165428162,\n        0.9985279440879822,\n        0.9944182634353638,\n        0.985649824142456,\n        0.9736542105674744,\n        0.9612159729003906,\n        0.9520760774612427\n      ],\n      \"q20\": [\n        0.8832447528839111,\n        0.8571564555168152,\n        0.840262234210968,\n        0.8279801607131958,\n        0.8175891637802124,\n        0.8145928382873535,\n        0.8104804754257202,\n        0.8050722479820251,\n        0.8001488447189331,\n        0.7951650619506836,\n        0.7925589084625244,\n        0.78853440284729,\n        0.785635232925415,\n        0.7818436622619629,\n        0.7790342569351196,\n        0.779435932636261,\n        0.7866798639297485,\n        0.7947074174880981,\n        0.8116522431373596,\n        0.834707498550415,\n        0.8330732583999634,\n        0.8280425667762756,\n        0.8265914916992188,\n        0.8280237317085266,\n        0.823756992816925,\n        0.820884108543396,\n        0.8138716816902161,\n        0.8067872524261475,\n        0.8027349710464478\n      ],\n      \"q80\": [\n        1.10765540599823,\n        1.0850690603256226,\n        1.0677224397659302,\n        1.0468156337738037,\n        1.0239413976669312,\n        1.018355131149292,\n        1.0108981132507324,\n        1.0029836893081665,\n        0.9916971325874329,\n        0.9822992086410522,\n        0.9713731408119202,\n        0.9630072712898254,\n        0.9601694941520691,\n        0.9586890339851379,\n        0.955090343952179,\n        0.9576360583305359,\n        0.9701409339904785,\n        0.9886602759361267,\n        1.02058744430542,\n        1.0570831298828125,\n        1.0654001235961914,\n        1.0563757419586182,\n        1.0534954071044922,\n        1.0564368963241577,\n        1.051694393157959,\n        
1.0388209819793701,\n        1.025420904159546,\n        1.0107486248016357,\n        0.9982277750968933\n      ],\n      \"q90\": [\n        1.1553966999053955,\n        1.137328863143921,\n        1.1165260076522827,\n        1.0933233499526978,\n        1.072894811630249,\n        1.065496563911438,\n        1.0601707696914673,\n        1.0506465435028076,\n        1.038832187652588,\n        1.0302690267562866,\n        1.018511414527893,\n        1.0077110528945923,\n        1.0042316913604736,\n        1.0026092529296875,\n        1.0030121803283691,\n        1.0043935775756836,\n        1.018110990524292,\n        1.0365487337112427,\n        1.0698375701904297,\n        1.1068248748779297,\n        1.114990472793579,\n        1.105769395828247,\n        1.1021937131881714,\n        1.1038919687271118,\n        1.1002414226531982,\n        1.0864661931991577,\n        1.0711843967437744,\n        1.0577744245529175,\n        1.044431209564209\n      ]\n    },\n    {\n      \"step\": 9,\n      \"n_points\": 20,\n      \"horizon\": 28,\n      \"last_historical_date\": \"2023-08\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        
0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432\n      ],\n      \"forecast_dates\": [\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1063826084136963,\n        1.0667672157287598,\n        1.0312474966049194,\n        1.0092777013778687,\n        0.9886403679847717,\n        0.9805473685264587,\n        0.96883624792099,\n        0.9500434994697571,\n        0.9289879202842712,\n        0.9156991839408875,\n        0.9083491563796997,\n        0.9020676016807556,\n        0.9000667333602905,\n        0.8952069878578186,\n        0.8887008428573608,\n        0.8977259993553162,\n        0.9318806529045105,\n        0.9759154915809631,\n        1.0011931657791138,\n        1.0136791467666626,\n        1.0154764652252197,\n        1.0213247537612915,\n        1.0302479267120361,\n        1.032987117767334,\n        1.0179458856582642,\n        0.9947344660758972,\n        0.9729111194610596,\n        0.9626883268356323\n      ],\n      \"q10\": [\n        1.114622950553894,\n        1.083889365196228,\n        1.0484296083450317,\n        1.0276585817337036,\n        1.008374571800232,\n        0.999535322189331,\n        0.9902844429016113,\n        0.9757266640663147,\n        0.9533360600471497,\n        0.9409008026123047,\n      
  0.9341027736663818,\n        0.9281788468360901,\n        0.9299426674842834,\n        0.921561062335968,\n        0.9143303632736206,\n        0.9240468144416809,\n        0.9563655853271484,\n        1.0021518468856812,\n        1.0241011381149292,\n        1.0326213836669922,\n        1.0297893285751343,\n        1.0334995985031128,\n        1.0426249504089355,\n        1.047775149345398,\n        1.031937837600708,\n        1.0122848749160767,\n        0.9894399642944336,\n        0.978018045425415\n      ],\n      \"q20\": [\n        0.928669810295105,\n        0.8862699866294861,\n        0.8555266261100769,\n        0.8365516662597656,\n        0.8246086835861206,\n        0.8187647461891174,\n        0.8126576542854309,\n        0.8008460402488708,\n        0.7927306890487671,\n        0.7833954095840454,\n        0.7795919179916382,\n        0.7797963619232178,\n        0.7819650173187256,\n        0.7769280672073364,\n        0.7692436575889587,\n        0.7726868391036987,\n        0.7912442684173584,\n        0.8222379088401794,\n        0.8362159132957458,\n        0.8447703719139099,\n        0.8396773934364319,\n        0.8379412293434143,\n        0.8396240472793579,\n        0.8429920077323914,\n        0.833158016204834,\n        0.823620080947876,\n        0.8104652762413025,\n        0.8035314083099365\n      ],\n      \"q80\": [\n        1.1856414079666138,\n        1.1520715951919556,\n        1.117408037185669,\n        1.0936567783355713,\n        1.0721673965454102,\n        1.0631694793701172,\n        1.048310399055481,\n        1.0276391506195068,\n        1.0055267810821533,\n        0.9882948994636536,\n        0.9792788624763489,\n        0.9736778736114502,\n        0.9714402556419373,\n        0.9655618071556091,\n        0.9581301808357239,\n        0.9696058034896851,\n        1.0068414211273193,\n        1.0576438903808594,\n        1.0841014385223389,\n        1.0951288938522339,\n        1.1002628803253174,\n        
1.1048551797866821,\n        1.1152007579803467,\n        1.1188753843307495,\n        1.1012613773345947,\n        1.0757598876953125,\n        1.0499663352966309,\n        1.0353318452835083\n      ],\n      \"q90\": [\n        1.23917818069458,\n        1.2113547325134277,\n        1.1742331981658936,\n        1.151162028312683,\n        1.1314780712127686,\n        1.1195954084396362,\n        1.10871160030365,\n        1.0842714309692383,\n        1.0615670680999756,\n        1.0447986125946045,\n        1.0315890312194824,\n        1.024493932723999,\n        1.0225589275360107,\n        1.0159486532211304,\n        1.0109714269638062,\n        1.0237358808517456,\n        1.0626462697982788,\n        1.1151678562164307,\n        1.1429989337921143,\n        1.151975393295288,\n        1.1542848348617554,\n        1.1620826721191406,\n        1.1735471487045288,\n        1.1768637895584106,\n        1.1591532230377197,\n        1.1298507452011108,\n        1.1008673906326294,\n        1.0874378681182861\n      ]\n    },\n    {\n      \"step\": 10,\n      \"n_points\": 21,\n      \"horizon\": 27,\n      \"last_historical_date\": \"2023-09\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        
0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295\n      ],\n      \"forecast_dates\": [\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2447655200958252,\n        1.1675736904144287,\n        1.1279113292694092,\n        1.1182188987731934,\n        1.1093124151229858,\n        1.082032322883606,\n        1.038187861442566,\n        0.9993720650672913,\n        0.9796157479286194,\n        0.9642789959907532,\n        0.9476039409637451,\n        0.9355512857437134,\n        0.9285284876823425,\n        0.9198805689811707,\n        0.9171224236488342,\n        0.9307593703269958,\n        0.968371570110321,\n        0.9984195232391357,\n        0.9925985932350159,\n        0.9877574443817139,\n        0.9934283494949341,\n        1.0018013715744019,\n        1.015841007232666,\n        1.0086023807525635,\n        0.981035590171814,\n        0.9596301913261414,\n        0.950019896030426\n      ],\n      \"q10\": [\n        1.2715141773223877,\n        1.2083916664123535,\n        1.1731905937194824,\n        1.165351152420044,\n        1.162253975868225,\n        1.13302481174469,\n        1.085452914237976,\n        1.051274299621582,\n    
    1.0313252210617065,\n        1.0168172121047974,\n        0.9987383484840393,\n        0.9869235754013062,\n        0.9812518358230591,\n        0.9713582396507263,\n        0.9646536111831665,\n        0.9781244397163391,\n        1.0141757726669312,\n        1.050538420677185,\n        1.0407419204711914,\n        1.032418966293335,\n        1.0342923402786255,\n        1.0425127744674683,\n        1.05617356300354,\n        1.0557340383529663,\n        1.0226852893829346,\n        1.0031310319900513,\n        0.9946122169494629\n      ],\n      \"q20\": [\n        0.9692280888557434,\n        0.9033447504043579,\n        0.8709640502929688,\n        0.8632612824440002,\n        0.8616656064987183,\n        0.8437307476997375,\n        0.8145183324813843,\n        0.7942112684249878,\n        0.7919824123382568,\n        0.7849438190460205,\n        0.7758752703666687,\n        0.7725547552108765,\n        0.7724835276603699,\n        0.7696750164031982,\n        0.7656691074371338,\n        0.7687865495681763,\n        0.7848570346832275,\n        0.8048490285873413,\n        0.7928374409675598,\n        0.7848871946334839,\n        0.7746942043304443,\n        0.7734623551368713,\n        0.7735666036605835,\n        0.765663743019104,\n        0.7521377205848694,\n        0.7475736737251282,\n        0.7519190907478333\n      ],\n      \"q80\": [\n        1.3772318363189697,\n        1.3073946237564087,\n        1.267617106437683,\n        1.2576971054077148,\n        1.2495336532592773,\n        1.2185810804367065,\n        1.1627202033996582,\n        1.1192079782485962,\n        1.093948483467102,\n        1.0731803178787231,\n        1.0513980388641357,\n        1.0379669666290283,\n        1.0290329456329346,\n        1.0203547477722168,\n        1.0156269073486328,\n        1.0321729183197021,\n        1.0734044313430786,\n        1.1123948097229004,\n        1.1079280376434326,\n        1.1026053428649902,\n        1.1133449077606201,\n        
1.1250957250595093,\n        1.1411525011062622,\n        1.1397948265075684,\n        1.104438066482544,\n        1.076056957244873,\n        1.0614937543869019\n      ],\n      \"q90\": [\n        1.4695751667022705,\n        1.4090934991836548,\n        1.3679797649383545,\n        1.3577240705490112,\n        1.3525687456130981,\n        1.315553903579712,\n        1.2607886791229248,\n        1.2103060483932495,\n        1.1827821731567383,\n        1.1617928743362427,\n        1.1323959827423096,\n        1.1176999807357788,\n        1.1078500747680664,\n        1.094464659690857,\n        1.0922305583953857,\n        1.1100425720214844,\n        1.1573286056518555,\n        1.1960670948028564,\n        1.1912381649017334,\n        1.1854121685028076,\n        1.1960220336914062,\n        1.2121495008468628,\n        1.231044888496399,\n        1.2304543256759644,\n        1.1941158771514893,\n        1.1591618061065674,\n        1.1404690742492676\n      ]\n    },\n    {\n      \"step\": 11,\n      \"n_points\": 22,\n      \"horizon\": 26,\n      \"last_historical_date\": \"2023-10\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        
0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874\n      ],\n      \"forecast_dates\": [\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1978572607040405,\n        1.1348843574523926,\n        1.107893705368042,\n        1.0890357494354248,\n        1.075318455696106,\n        1.0392742156982422,\n        1.0066328048706055,\n        0.9802011847496033,\n        0.968873143196106,\n        0.9584881663322449,\n        0.9440371990203857,\n        0.929160475730896,\n        0.9231215715408325,\n        0.9263395667076111,\n        0.9453635811805725,\n        0.9834461212158203,\n        1.0165542364120483,\n        1.0119515657424927,\n        1.0067156553268433,\n        1.0273469686508179,\n        1.068889856338501,\n        1.1046394109725952,\n        1.120269536972046,\n        1.084791898727417,\n        1.0440341234207153,\n        1.0170215368270874\n      ],\n      \"q10\": [\n        1.2035126686096191,\n        1.153576135635376,\n        1.1352055072784424,\n        1.1203036308288574,\n        1.1123145818710327,\n        1.0742825269699097,\n        1.0389323234558105,\n        1.017652988433838,\n        1.0134992599487305,\n        
1.0038114786148071,\n        0.9876317381858826,\n        0.972976565361023,\n        0.9668206572532654,\n        0.9690794348716736,\n        0.9879439473152161,\n        1.0214078426361084,\n        1.0546575784683228,\n        1.0502262115478516,\n        1.0401197671890259,\n        1.0604331493377686,\n        1.0953052043914795,\n        1.1325199604034424,\n        1.144276738166809,\n        1.1159130334854126,\n        1.0714142322540283,\n        1.0489661693572998\n      ],\n      \"q20\": [\n        0.9713577032089233,\n        0.9063910245895386,\n        0.8755015134811401,\n        0.8545557260513306,\n        0.8455488681793213,\n        0.8177679777145386,\n        0.799569845199585,\n        0.7851544618606567,\n        0.7884225249290466,\n        0.7802386283874512,\n        0.7720929980278015,\n        0.7622212171554565,\n        0.7612568736076355,\n        0.7628719806671143,\n        0.7799019813537598,\n        0.7968021035194397,\n        0.8116334676742554,\n        0.8049068450927734,\n        0.7901184558868408,\n        0.7965429425239563,\n        0.8083176612854004,\n        0.8315435647964478,\n        0.8326961994171143,\n        0.8086848258972168,\n        0.7895619869232178,\n        0.7825078368186951\n      ],\n      \"q80\": [\n        1.2972460985183716,\n        1.245476484298706,\n        1.2229666709899902,\n        1.210435152053833,\n        1.1973446607589722,\n        1.157381296157837,\n        1.1181674003601074,\n        1.0869324207305908,\n        1.075097680091858,\n        1.0632023811340332,\n        1.0455275774002075,\n        1.0302590131759644,\n        1.0215204954147339,\n        1.025394320487976,\n        1.045914649963379,\n        1.0890913009643555,\n        1.1246864795684814,\n        1.125206470489502,\n        1.1208384037017822,\n        1.145365834236145,\n        1.1913384199142456,\n        1.2334762811660767,\n        1.2504417896270752,\n        1.2180585861206055,\n        
1.1676602363586426,\n        1.1349562406539917\n      ],\n      \"q90\": [\n        1.3638895750045776,\n        1.323225975036621,\n        1.304998755455017,\n        1.2944636344909668,\n        1.2835395336151123,\n        1.2412294149398804,\n        1.1998721361160278,\n        1.1685125827789307,\n        1.1557502746582031,\n        1.1425185203552246,\n        1.1200439929962158,\n        1.1038810014724731,\n        1.0953530073165894,\n        1.0953185558319092,\n        1.1211014986038208,\n        1.165018916130066,\n        1.2059204578399658,\n        1.2043343782424927,\n        1.1997365951538086,\n        1.224650502204895,\n        1.2735167741775513,\n        1.3202080726623535,\n        1.3405519723892212,\n        1.304829716682434,\n        1.2509324550628662,\n        1.213225245475769\n      ]\n    },\n    {\n      \"step\": 12,\n      \"n_points\": 23,\n      \"horizon\": 25,\n      \"last_historical_date\": \"2023-11\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        
0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126\n      ],\n      \"forecast_dates\": [\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1388345956802368,\n        1.1001060009002686,\n        1.0774588584899902,\n        1.0615431070327759,\n        1.0359764099121094,\n        1.0100469589233398,\n        0.9895788431167603,\n        0.9849971532821655,\n        0.9746628999710083,\n        0.9684356451034546,\n        0.9609130024909973,\n        0.947131872177124,\n        0.9499838352203369,\n        0.9744009375572205,\n        1.0130703449249268,\n        1.0441275835037231,\n        1.048531413078308,\n        1.0471770763397217,\n        1.0618387460708618,\n        1.1032129526138306,\n        1.1474189758300781,\n        1.1728394031524658,\n        1.1430011987686157,\n        1.0839053392410278,\n        1.0471035242080688\n      ],\n      \"q10\": [\n        1.143956184387207,\n        1.1164032220840454,\n        1.0988131761550903,\n        1.0883313417434692,\n        1.0633952617645264,\n        1.0377331972122192,\n        1.0185223817825317,\n        1.0154881477355957,\n        1.0130091905593872,\n        1.006235957145691,\n        0.9972001314163208,\n        0.984115719795227,\n        0.9868376851081848,\n        1.0110416412353516,\n        
1.0470901727676392,\n        1.078067660331726,\n        1.0788366794586182,\n        1.0745474100112915,\n        1.0864962339401245,\n        1.1283372640609741,\n        1.1684935092926025,\n        1.194905400276184,\n        1.1594902276992798,\n        1.106303095817566,\n        1.0674790143966675\n      ],\n      \"q20\": [\n        0.9558293223381042,\n        0.9077008962631226,\n        0.875536322593689,\n        0.8599477410316467,\n        0.8395929932594299,\n        0.820803165435791,\n        0.8097033500671387,\n        0.8071569800376892,\n        0.8063573837280273,\n        0.7997854351997375,\n        0.7947160601615906,\n        0.7840617895126343,\n        0.7878046035766602,\n        0.8045357465744019,\n        0.8319349884986877,\n        0.8483662605285645,\n        0.8439525961875916,\n        0.8370295166969299,\n        0.8409282565116882,\n        0.8701899647712708,\n        0.8887082934379578,\n        0.9067206382751465,\n        0.8854538798332214,\n        0.8463788628578186,\n        0.8287973999977112\n      ],\n      \"q80\": [\n        1.2187292575836182,\n        1.1895191669464111,\n        1.1730304956436157,\n        1.1645177602767944,\n        1.1339150667190552,\n        1.1082265377044678,\n        1.0852689743041992,\n        1.0772539377212524,\n        1.0709658861160278,\n        1.0674384832382202,\n        1.0557781457901,\n        1.0452414751052856,\n        1.0445914268493652,\n        1.07282292842865,\n        1.1126301288604736,\n        1.1508023738861084,\n        1.1525412797927856,\n        1.1523170471191406,\n        1.1721690893173218,\n        1.2185754776000977,\n        1.2663267850875854,\n        1.2923482656478882,\n        1.2582346200942993,\n        1.1959534883499146,\n        1.1527845859527588\n      ],\n      \"q90\": [\n        1.2729495763778687,\n        1.2533750534057617,\n        1.2407320737838745,\n        1.2354146242141724,\n        1.2064470052719116,\n        
1.1776363849639893,\n        1.1529877185821533,\n        1.1496665477752686,\n        1.1451096534729004,\n        1.137753963470459,\n        1.1235407590866089,\n        1.1123000383377075,\n        1.1126760244369507,\n        1.1393102407455444,\n        1.185707449913025,\n        1.2234959602355957,\n        1.2277790307998657,\n        1.22639799118042,\n        1.2456328868865967,\n        1.2939422130584717,\n        1.345726490020752,\n        1.3719840049743652,\n        1.338273048400879,\n        1.2677257061004639,\n        1.2217291593551636\n      ]\n    },\n    {\n      \"step\": 13,\n      \"n_points\": 24,\n      \"horizon\": 24,\n      \"last_historical_date\": \"2023-12\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399\n      ],\n      
\"forecast_dates\": [\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1204800605773926,\n        1.0831129550933838,\n        1.0525826215744019,\n        1.0186809301376343,\n        0.996323823928833,\n        0.9761021733283997,\n        0.966797411441803,\n        0.9621630311012268,\n        0.950423002243042,\n        0.9326475262641907,\n        0.9303779602050781,\n        0.9362010955810547,\n        0.9639466404914856,\n        1.0171366930007935,\n        1.0539826154708862,\n        1.0581066608428955,\n        1.05403470993042,\n        1.07761549949646,\n        1.122676134109497,\n        1.180346965789795,\n        1.1975631713867188,\n        1.1708546876907349,\n        1.117448329925537,\n        1.0691102743148804\n      ],\n      \"q10\": [\n        1.1319338083267212,\n        1.1058242321014404,\n        1.0804548263549805,\n        1.0469233989715576,\n        1.0246795415878296,\n        1.0055618286132812,\n        0.999349057674408,\n        0.9949856996536255,\n        0.9896860718727112,\n        0.9742559194564819,\n        0.9675081968307495,\n        0.9734180569648743,\n        1.0023202896118164,\n        1.053297996520996,\n        1.090195894241333,\n        1.088844656944275,\n        1.082571029663086,\n        1.104530930519104,\n        1.1468923091888428,\n        1.2043083906173706,\n        1.2187085151672363,\n        1.19277822971344,\n        1.1290017366409302,\n        1.0879333019256592\n      ],\n 
     \"q20\": [\n        0.9561834335327148,\n        0.9061079621315002,\n        0.8687788844108582,\n        0.8394415378570557,\n        0.8218992948532104,\n        0.8107370138168335,\n        0.8105956315994263,\n        0.8031740784645081,\n        0.8004634380340576,\n        0.7854968309402466,\n        0.7851479053497314,\n        0.7882705330848694,\n        0.8095588684082031,\n        0.8434075117111206,\n        0.8662194013595581,\n        0.8621299862861633,\n        0.8524537682533264,\n        0.8656907677650452,\n        0.896289587020874,\n        0.9350174069404602,\n        0.940517783164978,\n        0.9245748519897461,\n        0.8929179906845093,\n        0.8636151552200317\n      ],\n      \"q80\": [\n        1.19773530960083,\n        1.1693586111068726,\n        1.14640212059021,\n        1.11386239528656,\n        1.082446813583374,\n        1.0650819540023804,\n        1.05680513381958,\n        1.0481219291687012,\n        1.0429224967956543,\n        1.024938702583313,\n        1.0191327333450317,\n        1.028489589691162,\n        1.057991862297058,\n        1.1157665252685547,\n        1.1569236516952515,\n        1.1618187427520752,\n        1.157217025756836,\n        1.1827739477157593,\n        1.2360106706619263,\n        1.2970430850982666,\n        1.3167476654052734,\n        1.2833902835845947,\n        1.2190351486206055,\n        1.1678544282913208\n      ],\n      \"q90\": [\n        1.2482070922851562,\n        1.229236364364624,\n        1.210077166557312,\n        1.18027925491333,\n        1.1515717506408691,\n        1.1297614574432373,\n        1.1205626726150513,\n        1.1177691221237183,\n        1.112573504447937,\n        1.0930581092834473,\n        1.084266185760498,\n        1.0912758111953735,\n        1.1246064901351929,\n        1.182848334312439,\n        1.2307857275009155,\n        1.2338712215423584,\n        1.2311983108520508,\n        1.2551823854446411,\n        1.3106720447540283,\n        
1.3747836351394653,\n        1.3966447114944458,\n        1.3567662239074707,\n        1.2892448902130127,\n        1.23186457157135\n      ]\n    },\n    {\n      \"step\": 14,\n      \"n_points\": 25,\n      \"horizon\": 23,\n      \"last_historical_date\": \"2024-01\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295\n      ],\n      \"forecast_dates\": [\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n      
  \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1701693534851074,\n        1.1349387168884277,\n        1.0960110425949097,\n        1.0637617111206055,\n        1.0402418375015259,\n        1.0265028476715088,\n        1.0204156637191772,\n        1.0119004249572754,\n        0.9878545999526978,\n        0.9743345379829407,\n        0.9826735258102417,\n        0.9942994117736816,\n        1.0274856090545654,\n        1.0622732639312744,\n        1.0651201009750366,\n        1.073114037513733,\n        1.0891278982162476,\n        1.126175880432129,\n        1.157884955406189,\n        1.1790106296539307,\n        1.1665725708007812,\n        1.1304852962493896,\n        1.1106657981872559\n      ],\n      \"q10\": [\n        1.1749104261398315,\n        1.147524118423462,\n        1.1174193620681763,\n        1.086887001991272,\n        1.0630450248718262,\n        1.0531063079833984,\n        1.0497565269470215,\n        1.042683482170105,\n        1.0233265161514282,\n        1.0111165046691895,\n        1.014377236366272,\n        1.0274351835250854,\n        1.062585711479187,\n        1.0902963876724243,\n        1.0922062397003174,\n        1.0912160873413086,\n        1.11197829246521,\n        1.1466877460479736,\n        1.1778086423873901,\n        1.199917197227478,\n        1.1789664030075073,\n        1.1495457887649536,\n        1.123175859451294\n      ],\n      \"q20\": [\n        0.9954406023025513,\n        0.9378616213798523,\n        0.893646240234375,\n        0.8610368967056274,\n        0.8414109945297241,\n        0.8318982124328613,\n        0.829987645149231,\n        0.8171640634536743,\n        0.8035246729850769,\n        0.7929065227508545,\n        0.8037456274032593,\n        0.8133399486541748,\n        0.8395006656646729,\n        0.8581016659736633,\n        0.8608949184417725,\n     
   0.8612385392189026,\n        0.8694777488708496,\n        0.896060585975647,\n        0.9189809560775757,\n        0.9354234337806702,\n        0.918700635433197,\n        0.8955419063568115,\n        0.8815087676048279\n      ],\n      \"q80\": [\n        1.2475481033325195,\n        1.2218120098114014,\n        1.1920394897460938,\n        1.1621203422546387,\n        1.1338578462600708,\n        1.1270941495895386,\n        1.1244370937347412,\n        1.11036217212677,\n        1.0929012298583984,\n        1.0770790576934814,\n        1.0825059413909912,\n        1.0962635278701782,\n        1.133682131767273,\n        1.166754126548767,\n        1.1711883544921875,\n        1.1767802238464355,\n        1.1959017515182495,\n        1.2360646724700928,\n        1.272753357887268,\n        1.293941855430603,\n        1.2775542736053467,\n        1.2417978048324585,\n        1.215286374092102\n      ],\n      \"q90\": [\n        1.2978864908218384,\n        1.2807369232177734,\n        1.2577829360961914,\n        1.2319256067276,\n        1.2072914838790894,\n        1.1944835186004639,\n        1.1949646472930908,\n        1.1887325048446655,\n        1.1706409454345703,\n        1.1535823345184326,\n        1.1557773351669312,\n        1.165435791015625,\n        1.2051732540130615,\n        1.2392327785491943,\n        1.2467840909957886,\n        1.2500178813934326,\n        1.2747631072998047,\n        1.3121440410614014,\n        1.3449633121490479,\n        1.3688087463378906,\n        1.353420376777649,\n        1.3119401931762695,\n        1.2864404916763306\n      ]\n    },\n    {\n      \"step\": 15,\n      \"n_points\": 26,\n      \"horizon\": 22,\n      \"last_historical_date\": \"2024-02\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        
\"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858\n      ],\n      \"forecast_dates\": [\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2504206895828247,\n        1.2035315036773682,\n        1.1605435609817505,\n        1.1372957229614258,\n        1.1169792413711548,\n        1.1097407341003418,\n        1.0960330963134766,\n        1.0716885328292847,\n        1.0411385297775269,\n        1.0377408266067505,\n        1.06381356716156,\n  
      1.09853994846344,\n        1.1106432676315308,\n        1.1038997173309326,\n        1.0912792682647705,\n        1.10673189163208,\n        1.128816843032837,\n        1.1672472953796387,\n        1.169884204864502,\n        1.1492640972137451,\n        1.1251699924468994,\n        1.1080049276351929\n      ],\n      \"q10\": [\n        1.253143310546875,\n        1.2137634754180908,\n        1.175628900527954,\n        1.158146858215332,\n        1.1375560760498047,\n        1.1330972909927368,\n        1.1224530935287476,\n        1.0991952419281006,\n        1.0732285976409912,\n        1.069901704788208,\n        1.0908238887786865,\n        1.1302318572998047,\n        1.1447051763534546,\n        1.1265060901641846,\n        1.1150192022323608,\n        1.1237907409667969,\n        1.1495832204818726,\n        1.187064528465271,\n        1.191187858581543,\n        1.1717422008514404,\n        1.1371166706085205,\n        1.1280303001403809\n      ],\n      \"q20\": [\n        1.0437579154968262,\n        0.9754042625427246,\n        0.9281424283981323,\n        0.8999512791633606,\n        0.8835805058479309,\n        0.8786535263061523,\n        0.868209958076477,\n        0.8477093577384949,\n        0.8295252919197083,\n        0.8285472989082336,\n        0.8487096428871155,\n        0.8732921481132507,\n        0.8824164271354675,\n        0.8700266480445862,\n        0.8598465323448181,\n        0.8674743175506592,\n        0.8803960084915161,\n        0.9123423099517822,\n        0.9124201536178589,\n        0.8980945348739624,\n        0.8717573881149292,\n        0.8591221570968628\n      ],\n      \"q80\": [\n        1.3365967273712158,\n        1.29902184009552,\n        1.2669174671173096,\n        1.2462443113327026,\n        1.2251611948013306,\n        1.224426031112671,\n        1.2126585245132446,\n        1.1816699504852295,\n        1.1577259302139282,\n        1.1497776508331299,\n        1.1759350299835205,\n        
1.2160439491271973,\n        1.2304400205612183,\n        1.2202222347259521,\n        1.2069144248962402,\n        1.2211333513259888,\n        1.2466362714767456,\n        1.2859277725219727,\n        1.2911059856414795,\n        1.2705645561218262,\n        1.2402691841125488,\n        1.225570797920227\n      ],\n      \"q90\": [\n        1.394209623336792,\n        1.3661998510360718,\n        1.3383913040161133,\n        1.3226557970046997,\n        1.3062965869903564,\n        1.3001211881637573,\n        1.2918630838394165,\n        1.267607569694519,\n        1.2428820133209229,\n        1.2324764728546143,\n        1.254516839981079,\n        1.291373372077942,\n        1.3111931085586548,\n        1.2970809936523438,\n        1.2853456735610962,\n        1.3002972602844238,\n        1.330471396446228,\n        1.3700449466705322,\n        1.3697110414505005,\n        1.346665382385254,\n        1.31707763671875,\n        1.3017767667770386\n      ]\n    },\n    {\n      \"step\": 16,\n      \"n_points\": 27,\n      \"horizon\": 21,\n      \"last_historical_date\": \"2024-03\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        
0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601\n      ],\n      \"forecast_dates\": [\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2523874044418335,\n        1.2066220045089722,\n        1.1746571063995361,\n        1.1765081882476807,\n        1.1709487438201904,\n        1.169347882270813,\n        1.1399660110473633,\n        1.1141448020935059,\n        1.094247817993164,\n        1.0913820266723633,\n        1.1216974258422852,\n        1.1433929204940796,\n        1.1276500225067139,\n        1.1138465404510498,\n        1.1109668016433716,\n        1.1382179260253906,\n        1.145559310913086,\n        1.165015697479248,\n        1.1428844928741455,\n        1.1122182607650757,\n        1.1095082759857178\n      ],\n      \"q10\": [\n        1.2494522333145142,\n        1.2100024223327637,\n        1.1815905570983887,\n        1.184570550918579,\n        1.181471824645996,\n        1.1847987174987793,\n        1.1554681062698364,\n        1.1273032426834106,\n        1.1124141216278076,\n        
1.1068137884140015,\n        1.1349601745605469,\n        1.160623550415039,\n        1.1481659412384033,\n        1.1232229471206665,\n        1.1228114366531372,\n        1.1419509649276733,\n        1.1522048711776733,\n        1.1742281913757324,\n        1.1551659107208252,\n        1.1268976926803589,\n        1.112238883972168\n      ],\n      \"q20\": [\n        1.0595918893814087,\n        0.9882703423500061,\n        0.9449520111083984,\n        0.9323371648788452,\n        0.921808123588562,\n        0.9140236973762512,\n        0.8879625797271729,\n        0.8599287271499634,\n        0.84772127866745,\n        0.8464851975440979,\n        0.8668861389160156,\n        0.8764016032218933,\n        0.862370491027832,\n        0.8420681953430176,\n        0.8450419306755066,\n        0.8666462898254395,\n        0.8749760985374451,\n        0.8925336003303528,\n        0.8715018033981323,\n        0.8530272841453552,\n        0.8424127697944641\n      ],\n      \"q80\": [\n        1.3265814781188965,\n        1.2932963371276855,\n        1.2723067998886108,\n        1.276952862739563,\n        1.2762058973312378,\n        1.284961462020874,\n        1.2592799663543701,\n        1.2249560356140137,\n        1.213465929031372,\n        1.2041243314743042,\n        1.2399941682815552,\n        1.2660539150238037,\n        1.2495875358581543,\n        1.2333945035934448,\n        1.2315037250518799,\n        1.2564735412597656,\n        1.264156699180603,\n        1.2841299772262573,\n        1.2626703977584839,\n        1.2329728603363037,\n        1.2221158742904663\n      ],\n      \"q90\": [\n        1.3771872520446777,\n        1.3524072170257568,\n        1.3376163244247437,\n        1.347804307937622,\n        1.3534436225891113,\n        1.3581876754760742,\n        1.3364894390106201,\n        1.3079556226730347,\n        1.295357346534729,\n        1.2872941493988037,\n        1.3177791833877563,\n        1.340587854385376,\n        
1.3293620347976685,\n        1.3092248439788818,\n        1.309072494506836,\n        1.3312009572982788,\n        1.3418974876403809,\n        1.3621940612792969,\n        1.3367794752120972,\n        1.3070871829986572,\n        1.296994686126709\n      ]\n    },\n    {\n      \"step\": 17,\n      \"n_points\": 28,\n      \"horizon\": 20,\n      \"last_historical_date\": \"2024-04\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568\n      ],\n      \"forecast_dates\": [\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n   
     \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2068676948547363,\n        1.1843211650848389,\n        1.1752288341522217,\n        1.17955482006073,\n        1.1717453002929688,\n        1.1482445001602173,\n        1.1248430013656616,\n        1.1241732835769653,\n        1.1235134601593018,\n        1.1300708055496216,\n        1.1367747783660889,\n        1.1233289241790771,\n        1.1131789684295654,\n        1.1212987899780273,\n        1.1275365352630615,\n        1.1452269554138184,\n        1.1476627588272095,\n        1.1389117240905762,\n        1.1231611967086792,\n        1.1179301738739014\n      ],\n      \"q10\": [\n        1.202960729598999,\n        1.1801354885101318,\n        1.1744948625564575,\n        1.178760290145874,\n        1.1708077192306519,\n        1.152012586593628,\n        1.1264581680297852,\n        1.1220771074295044,\n        1.12774658203125,\n        1.1319509744644165,\n        1.1353538036346436,\n        1.1257888078689575,\n        1.1163818836212158,\n        1.1152591705322266,\n        1.1232290267944336,\n        1.1383938789367676,\n        1.1435673236846924,\n        1.131921648979187,\n        1.1226390600204468,\n        1.115145206451416\n      ],\n      \"q20\": [\n        1.0335861444473267,\n        0.9781290292739868,\n        0.948025643825531,\n        0.937298595905304,\n        0.9195546507835388,\n        0.8911022543907166,\n        0.8684503436088562,\n        0.8581703901290894,\n        0.8552865386009216,\n        0.8566405177116394,\n        0.8587369918823242,\n        0.8421598076820374,\n        0.8355081081390381,\n        0.835259735584259,\n     
   0.8424496650695801,\n        0.8557251691818237,\n        0.8595790863037109,\n        0.8550817966461182,\n        0.8462545871734619,\n        0.8529651761054993\n      ],\n      \"q80\": [\n        1.2702223062515259,\n        1.2614306211471558,\n        1.2629116773605347,\n        1.27401602268219,\n        1.2682753801345825,\n        1.253630518913269,\n        1.23259437084198,\n        1.2252973318099976,\n        1.2373583316802979,\n        1.2451832294464111,\n        1.2524268627166748,\n        1.2415071725845337,\n        1.2297941446304321,\n        1.2318909168243408,\n        1.242499828338623,\n        1.2596397399902344,\n        1.26153564453125,\n        1.2511622905731201,\n        1.2375322580337524,\n        1.2279977798461914\n      ],\n      \"q90\": [\n        1.3145431280136108,\n        1.313429594039917,\n        1.3208061456680298,\n        1.3359402418136597,\n        1.3367944955825806,\n        1.3163673877716064,\n        1.2994139194488525,\n        1.2974282503128052,\n        1.3131386041641235,\n        1.3206769227981567,\n        1.325730800628662,\n        1.3097118139266968,\n        1.2984554767608643,\n        1.2993582487106323,\n        1.3110051155090332,\n        1.328444242477417,\n        1.3302743434906006,\n        1.3201899528503418,\n        1.3005670309066772,\n        1.295330286026001\n      ]\n    },\n    {\n      \"step\": 18,\n      \"n_points\": 29,\n      \"horizon\": 19,\n      \"last_historical_date\": \"2024-05\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        
\"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142\n      ],\n      \"forecast_dates\": [\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1386852264404297,\n        1.1227259635925293,\n        1.1132360696792603,\n        1.103696346282959,\n        1.0890148878097534,\n        1.0628618001937866,\n        1.0592650175094604,\n        1.0809025764465332,\n        1.1213948726654053,\n        1.1205977201461792,\n        1.10319983959198,\n        1.0873777866363525,\n        1.0977184772491455,\n        1.1334233283996582,\n        1.1537142992019653,\n        
1.15865159034729,\n        1.1413378715515137,\n        1.1311604976654053,\n        1.1258361339569092\n      ],\n      \"q10\": [\n        1.1357723474502563,\n        1.1218345165252686,\n        1.1151096820831299,\n        1.1036633253097534,\n        1.088782787322998,\n        1.0708427429199219,\n        1.0614827871322632,\n        1.0803805589675903,\n        1.1256681680679321,\n        1.124110460281372,\n        1.1017175912857056,\n        1.0866585969924927,\n        1.0974124670028687,\n        1.1265218257904053,\n        1.1448237895965576,\n        1.150303602218628,\n        1.131263256072998,\n        1.1206773519515991,\n        1.1218606233596802\n      ],\n      \"q20\": [\n        0.9705875515937805,\n        0.9261521100997925,\n        0.9002217650413513,\n        0.8800909519195557,\n        0.8597927689552307,\n        0.837051272392273,\n        0.8270405530929565,\n        0.8327914476394653,\n        0.8583639860153198,\n        0.8556785583496094,\n        0.8432221412658691,\n        0.8295676708221436,\n        0.8404796719551086,\n        0.8643808364868164,\n        0.8823158740997314,\n        0.8855088949203491,\n        0.8733288049697876,\n        0.8654991388320923,\n        0.8692165017127991\n      ],\n      \"q80\": [\n        1.2012592554092407,\n        1.2004612684249878,\n        1.1944599151611328,\n        1.1941598653793335,\n        1.178646206855774,\n        1.1608107089996338,\n        1.156977653503418,\n        1.1782780885696411,\n        1.2296812534332275,\n        1.235266089439392,\n        1.2120579481124878,\n        1.1956090927124023,\n        1.2027981281280518,\n        1.2328962087631226,\n        1.2583279609680176,\n        1.2622652053833008,\n        1.2420697212219238,\n        1.2296068668365479,\n        1.2310649156570435\n      ],\n      \"q90\": [\n        1.2429797649383545,\n        1.248335599899292,\n        1.2536859512329102,\n        1.251412272453308,\n        
1.241403341293335,\n        1.21868097782135,\n        1.2173688411712646,\n        1.244056224822998,\n        1.3022620677947998,\n        1.3048560619354248,\n        1.2794227600097656,\n        1.25494384765625,\n        1.2627326250076294,\n        1.292664885520935,\n        1.3210376501083374,\n        1.3248177766799927,\n        1.303199291229248,\n        1.290137529373169,\n        1.2883186340332031\n      ]\n    },\n    {\n      \"step\": 19,\n      \"n_points\": 30,\n      \"horizon\": 18,\n      \"last_historical_date\": \"2024-06\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n   
     1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158\n      ],\n      \"forecast_dates\": [\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.1765440702438354,\n        1.1661514043807983,\n        1.1520631313323975,\n        1.1195285320281982,\n        1.0856300592422485,\n        1.0768202543258667,\n        1.0964417457580566,\n        1.1255871057510376,\n        1.155031442642212,\n        1.1183977127075195,\n        1.1013360023498535,\n        1.1082254648208618,\n        1.1356239318847656,\n        1.1829569339752197,\n        1.1888995170593262,\n        1.159764051437378,\n        1.126434564590454,\n        1.1302133798599243\n      ],\n      \"q10\": [\n        1.1751192808151245,\n        1.1651133298873901,\n        1.1592530012130737,\n        1.1195036172866821,\n        1.084028959274292,\n        1.0865756273269653,\n        1.099607229232788,\n        1.1274793148040771,\n        1.160447597503662,\n        1.1203389167785645,\n        1.0989832878112793,\n        1.1072871685028076,\n        1.1345447301864624,\n        1.1779069900512695,\n        1.1820926666259766,\n        1.1511759757995605,\n        1.1156119108200073,\n        1.121741771697998\n      ],\n      \"q20\": [\n        1.0206873416900635,\n        0.9838167428970337,\n        0.9575520157814026,\n        0.9151738882064819,\n        0.8827507495880127,\n        0.876349151134491,\n        0.8842628002166748,\n        0.8949983716011047,\n        0.9151624441146851,\n        0.883825421333313,\n        
0.877031147480011,\n        0.8801717162132263,\n        0.9021454453468323,\n        0.9322755336761475,\n        0.9382153153419495,\n        0.9139386415481567,\n        0.8896767497062683,\n        0.8937186598777771\n      ],\n      \"q80\": [\n        1.2346465587615967,\n        1.238021969795227,\n        1.2284244298934937,\n        1.199608564376831,\n        1.1668167114257812,\n        1.165637731552124,\n        1.1883985996246338,\n        1.2180571556091309,\n        1.25492525100708,\n        1.219463586807251,\n        1.1989303827285767,\n        1.2049015760421753,\n        1.2325347661972046,\n        1.276305079460144,\n        1.2895640134811401,\n        1.2548282146453857,\n        1.217138648033142,\n        1.2198824882507324\n      ],\n      \"q90\": [\n        1.2738851308822632,\n        1.2814069986343384,\n        1.2860920429229736,\n        1.251664638519287,\n        1.2245914936065674,\n        1.2196787595748901,\n        1.2461426258087158,\n        1.2824065685272217,\n        1.3231412172317505,\n        1.2859265804290771,\n        1.2610337734222412,\n        1.2612855434417725,\n        1.2891101837158203,\n        1.3355929851531982,\n        1.3490444421768188,\n        1.3118960857391357,\n        1.2749104499816895,\n        1.2770724296569824\n      ]\n    },\n    {\n      \"step\": 20,\n      \"n_points\": 31,\n      \"horizon\": 17,\n      \"last_historical_date\": \"2024-07\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        
\"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158,\n        1.2400000095367432\n      ],\n      \"forecast_dates\": [\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2069008350372314,\n        1.193657636642456,\n        1.1575161218643188,\n        1.133849859237671,\n        1.1235467195510864,\n        1.1252387762069702,\n        1.1443586349487305,\n        1.1506012678146362,\n        1.134141206741333,\n        1.1200145483016968,\n        1.133240818977356,\n        1.1518402099609375,\n        1.1871200799942017,\n        1.2083441019058228,\n        1.1728931665420532,\n        1.1432249546051025,\n  
      1.133898377418518\n      ],\n      \"q10\": [\n        1.2029995918273926,\n        1.1932077407836914,\n        1.1641241312026978,\n        1.1338424682617188,\n        1.12429940700531,\n        1.1312663555145264,\n        1.1450644731521606,\n        1.1525075435638428,\n        1.1395219564437866,\n        1.121511697769165,\n        1.132306456565857,\n        1.1525789499282837,\n        1.1869525909423828,\n        1.2014143466949463,\n        1.168949007987976,\n        1.132044792175293,\n        1.1256910562515259\n      ],\n      \"q20\": [\n        1.0395362377166748,\n        0.9963122606277466,\n        0.951080322265625,\n        0.9185925126075745,\n        0.9062104821205139,\n        0.9065833687782288,\n        0.9181973934173584,\n        0.9136454463005066,\n        0.9018174409866333,\n        0.8859837055206299,\n        0.8985838890075684,\n        0.9077322483062744,\n        0.9346957206726074,\n        0.9418572187423706,\n        0.9179990291595459,\n        0.8948585987091064,\n        0.88960200548172\n      ],\n      \"q80\": [\n        1.2682971954345703,\n        1.2702425718307495,\n        1.239664077758789,\n        1.2174897193908691,\n        1.2065781354904175,\n        1.2155629396438599,\n        1.2398309707641602,\n        1.240811824798584,\n        1.2331410646438599,\n        1.2164467573165894,\n        1.2326842546463013,\n        1.251672387123108,\n        1.2885876893997192,\n        1.3034658432006836,\n        1.2739078998565674,\n        1.2376470565795898,\n        1.2269697189331055\n      ],\n      \"q90\": [\n        1.3094301223754883,\n        1.3151092529296875,\n        1.2952126264572144,\n        1.273212194442749,\n        1.2679336071014404,\n        1.2713508605957031,\n        1.2985038757324219,\n        1.305957555770874,\n        1.2964022159576416,\n        1.2809231281280518,\n        1.2935620546340942,\n        1.3095386028289795,\n        1.3458640575408936,\n        
1.366443157196045,\n        1.3342082500457764,\n        1.2930103540420532,\n        1.2838038206100464\n      ]\n    },\n    {\n      \"step\": 21,\n      \"n_points\": 32,\n      \"horizon\": 16,\n      \"last_historical_date\": \"2024-08\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158,\n        1.2400000095367432,\n        1.2999999523162842\n      ],\n      \"forecast_dates\": [\n        \"2024-09\",\n        
\"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2892454862594604,\n        1.2497223615646362,\n        1.2063699960708618,\n        1.2123697996139526,\n        1.2295829057693481,\n        1.2457282543182373,\n        1.2520256042480469,\n        1.1976659297943115,\n        1.1560035943984985,\n        1.15586519241333,\n        1.168123483657837,\n        1.188661813735962,\n        1.1947652101516724,\n        1.173640251159668,\n        1.128365397453308,\n        1.128602385520935\n      ],\n      \"q10\": [\n        1.2727627754211426,\n        1.2367907762527466,\n        1.1920455694198608,\n        1.1937742233276367,\n        1.2203925848007202,\n        1.2314530611038208,\n        1.2363964319229126,\n        1.1829954385757446,\n        1.1487408876419067,\n        1.1405112743377686,\n        1.1547985076904297,\n        1.1740177869796753,\n        1.1805450916290283,\n        1.1459304094314575,\n        1.1116427183151245,\n        1.0966339111328125\n      ],\n      \"q20\": [\n        1.11649489402771,\n        1.0445278882980347,\n        0.9846185445785522,\n        0.9668428897857666,\n        0.9715695977210999,\n        0.9662386178970337,\n        0.9553800821304321,\n        0.9113569855690002,\n        0.8853881359100342,\n        0.8746424913406372,\n        0.875267505645752,\n        0.8781014680862427,\n        0.8732690215110779,\n        0.8478219509124756,\n        0.8163697719573975,\n        0.815811276435852\n      ],\n      \"q80\": [\n        1.3429784774780273,\n        1.3280879259109497,\n        1.292254090309143,\n        1.3056862354278564,\n        1.3293191194534302,\n        1.352075219154358,\n      
  1.3609846830368042,\n        1.3075883388519287,\n        1.279836893081665,\n        1.272203803062439,\n        1.2965717315673828,\n        1.3177393674850464,\n        1.3210997581481934,\n        1.295129418373108,\n        1.2528834342956543,\n        1.246609091758728\n      ],\n      \"q90\": [\n        1.3865525722503662,\n        1.3712806701660156,\n        1.3499008417129517,\n        1.3717585802078247,\n        1.4015172719955444,\n        1.4236888885498047,\n        1.4422738552093506,\n        1.3891522884368896,\n        1.3545751571655273,\n        1.349416732788086,\n        1.363886833190918,\n        1.3921372890472412,\n        1.3967747688293457,\n        1.3780581951141357,\n        1.331864356994629,\n        1.3187098503112793\n      ]\n    },\n    {\n      \"step\": 22,\n      \"n_points\": 33,\n      \"horizon\": 15,\n      \"last_historical_date\": \"2024-09\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n      
  0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158,\n        1.2400000095367432,\n        1.2999999523162842,\n        1.2799999713897705\n      ],\n      \"forecast_dates\": [\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2395873069763184,\n        1.192650318145752,\n        1.1737117767333984,\n        1.1951370239257812,\n        1.232491135597229,\n        1.265418291091919,\n        1.2109034061431885,\n        1.1846691370010376,\n        1.1904014348983765,\n        1.2089793682098389,\n        1.2557576894760132,\n        1.2761039733886719,\n        1.2492849826812744,\n        1.2014641761779785,\n        1.1954424381256104\n      ],\n      \"q10\": [\n        1.2416894435882568,\n        1.1871181726455688,\n        1.1744379997253418,\n        1.19320547580719,\n        1.2350860834121704,\n        1.2670172452926636,\n        1.211256980895996,\n        1.1898648738861084,\n        1.1905932426452637,\n        1.1989935636520386,\n        1.247326135635376,\n        1.268507480621338,\n        1.2414063215255737,\n        1.1882392168045044,\n        1.184570550918579\n      ],\n      \"q20\": [\n        1.097076654434204,\n        
1.0414971113204956,\n        1.0175477266311646,\n        1.0278714895248413,\n        1.0624254941940308,\n        1.0802021026611328,\n        1.0272504091262817,\n        1.0036317110061646,\n        1.0009558200836182,\n        1.001404047012329,\n        1.0334482192993164,\n        1.042593240737915,\n        1.0162984132766724,\n        0.9763948321342468,\n        0.9707307815551758\n      ],\n      \"q80\": [\n        1.2870674133300781,\n        1.2494632005691528,\n        1.2323118448257446,\n        1.2594434022903442,\n        1.3010603189468384,\n        1.3373479843139648,\n        1.2841951847076416,\n        1.2637286186218262,\n        1.2685482501983643,\n        1.2876002788543701,\n        1.339444637298584,\n        1.3590757846832275,\n        1.3355648517608643,\n        1.2837905883789062,\n        1.2771517038345337\n      ],\n      \"q90\": [\n        1.3212705850601196,\n        1.2820069789886475,\n        1.2749484777450562,\n        1.2991927862167358,\n        1.3489611148834229,\n        1.384088397026062,\n        1.3305764198303223,\n        1.3098028898239136,\n        1.3174644708633423,\n        1.334850788116455,\n        1.387671709060669,\n        1.4108545780181885,\n        1.3836441040039062,\n        1.3309946060180664,\n        1.3257174491882324\n      ]\n    },\n    {\n      \"step\": 23,\n      \"n_points\": 34,\n      \"horizon\": 14,\n      \"last_historical_date\": \"2024-10\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        
\"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158,\n        1.2400000095367432,\n        1.2999999523162842,\n        1.2799999713897705,\n        1.2699999809265137\n      ],\n      \"forecast_dates\": [\n        \"2024-11\",\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.200866460800171,\n        1.1866711378097534,\n        1.2232941389083862,\n        1.2719991207122803,\n        1.2799842357635498,\n        1.2515898942947388,\n        1.1958189010620117,\n        1.19310462474823,\n        1.2179431915283203,\n        1.2518219947814941,\n        1.2716079950332642,\n        1.2360819578170776,\n        1.1987874507904053,\n 
       1.1850693225860596\n      ],\n      \"q10\": [\n        1.2021855115890503,\n        1.1821584701538086,\n        1.2226784229278564,\n        1.273689866065979,\n        1.2845158576965332,\n        1.2485958337783813,\n        1.1959373950958252,\n        1.1964659690856934,\n        1.2180784940719604,\n        1.2440263032913208,\n        1.2621558904647827,\n        1.2280503511428833,\n        1.1858408451080322,\n        1.1696057319641113\n      ],\n      \"q20\": [\n        1.0769736766815186,\n        1.0466127395629883,\n        1.0687201023101807,\n        1.1035237312316895,\n        1.1067966222763062,\n        1.0670413970947266,\n        1.0116249322891235,\n        1.003699779510498,\n        1.0221866369247437,\n        1.0382513999938965,\n        1.0417994260787964,\n        1.0053966045379639,\n        0.9645071029663086,\n        0.9537580609321594\n      ],\n      \"q80\": [\n        1.2458512783050537,\n        1.2381772994995117,\n        1.2802457809448242,\n        1.3395813703536987,\n        1.3537287712097168,\n        1.3230884075164795,\n        1.2715508937835693,\n        1.2736643552780151,\n        1.3004214763641357,\n        1.338258147239685,\n        1.3596911430358887,\n        1.3208271265029907,\n        1.2824501991271973,\n        1.2699368000030518\n      ],\n      \"q90\": [\n        1.2776029109954834,\n        1.2695484161376953,\n        1.3248724937438965,\n        1.3829126358032227,\n        1.4010111093521118,\n        1.3700647354125977,\n        1.3196228742599487,\n        1.3224942684173584,\n        1.3526369333267212,\n        1.3852152824401855,\n        1.4081038236618042,\n        1.3699979782104492,\n        1.3278801441192627,\n        1.3165735006332397\n      ]\n    },\n    {\n      \"step\": 24,\n      \"n_points\": 35,\n      \"horizon\": 13,\n      \"last_historical_date\": \"2024-11\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n       
 \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158,\n        1.2400000095367432,\n        1.2999999523162842,\n        1.2799999713897705,\n        1.2699999809265137,\n        1.2200000286102295\n      ],\n      \"forecast_dates\": [\n        \"2024-12\",\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        
\"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.2384696006774902,\n        1.2530195713043213,\n        1.3186349868774414,\n        1.3470391035079956,\n        1.2608959674835205,\n        1.1712164878845215,\n        1.1867536306381226,\n        1.2420611381530762,\n        1.26655912399292,\n        1.2961373329162598,\n        1.2294163703918457,\n        1.166834831237793,\n        1.1554596424102783\n      ],\n      \"q10\": [\n        1.2286468744277954,\n        1.2455438375473022,\n        1.3089576959609985,\n        1.3339853286743164,\n        1.2469478845596313,\n        1.149349570274353,\n        1.1650605201721191,\n        1.2206904888153076,\n        1.2502191066741943,\n        1.267012357711792,\n        1.2066657543182373,\n        1.1346192359924316,\n        1.115806221961975\n      ],\n      \"q20\": [\n        1.1213350296020508,\n        1.1162383556365967,\n        1.1598260402679443,\n        1.164056420326233,\n        1.0658612251281738,\n        0.9682412147521973,\n        0.9661321043968201,\n        1.0035676956176758,\n        1.0229461193084717,\n        1.0328454971313477,\n        0.9562720656394958,\n        0.8820796608924866,\n        0.8598078489303589\n      ],\n      \"q80\": [\n        1.2744736671447754,\n        1.3042716979980469,\n        1.3755841255187988,\n        1.4136451482772827,\n        1.3280566930770874,\n        1.2414281368255615,\n        1.265558123588562,\n        1.3200562000274658,\n        1.3582563400268555,\n        1.391558051109314,\n        1.325316071510315,\n        1.2584038972854614,\n        1.2471749782562256\n      ],\n      \"q90\": [\n        1.3033584356307983,\n        1.337569236755371,\n        1.418811321258545,\n        1.4565141201019287,\n        1.3784044981002808,\n        1.289095401763916,\n        1.3142638206481934,\n        1.378018856048584,\n        1.411327838897705,\n        1.4371509552001953,\n        
1.3814256191253662,\n        1.3204567432403564,\n        1.3057005405426025\n      ]\n    },\n    {\n      \"step\": 25,\n      \"n_points\": 36,\n      \"horizon\": 12,\n      \"last_historical_date\": \"2024-12\",\n      \"historical_dates\": [\n        \"2022-01\",\n        \"2022-02\",\n        \"2022-03\",\n        \"2022-04\",\n        \"2022-05\",\n        \"2022-06\",\n        \"2022-07\",\n        \"2022-08\",\n        \"2022-09\",\n        \"2022-10\",\n        \"2022-11\",\n        \"2022-12\",\n        \"2023-01\",\n        \"2023-02\",\n        \"2023-03\",\n        \"2023-04\",\n        \"2023-05\",\n        \"2023-06\",\n        \"2023-07\",\n        \"2023-08\",\n        \"2023-09\",\n        \"2023-10\",\n        \"2023-11\",\n        \"2023-12\",\n        \"2024-01\",\n        \"2024-02\",\n        \"2024-03\",\n        \"2024-04\",\n        \"2024-05\",\n        \"2024-06\",\n        \"2024-07\",\n        \"2024-08\",\n        \"2024-09\",\n        \"2024-10\",\n        \"2024-11\",\n        \"2024-12\"\n      ],\n      \"historical_values\": [\n        0.8899999856948853,\n        0.8899999856948853,\n        1.0199999809265137,\n        0.8799999952316284,\n        0.8500000238418579,\n        0.8799999952316284,\n        0.8799999952316284,\n        0.8999999761581421,\n        0.8799999952316284,\n        0.949999988079071,\n        0.7699999809265137,\n        0.7799999713897705,\n        0.8700000047683716,\n        0.9800000190734863,\n        1.2100000381469727,\n        1.0,\n        0.9399999976158142,\n        1.0800000429153442,\n        1.1799999475479126,\n        1.2400000095367432,\n        1.4700000286102295,\n        1.3200000524520874,\n        1.1799999475479126,\n        1.159999966621399,\n        1.2200000286102295,\n        1.350000023841858,\n        1.340000033378601,\n        1.2599999904632568,\n        1.149999976158142,\n        1.2000000476837158,\n        1.2400000095367432,\n        1.2999999523162842,\n        
1.2799999713897705,\n        1.2699999809265137,\n        1.2200000286102295,\n        1.2000000476837158\n      ],\n      \"forecast_dates\": [\n        \"2025-01\",\n        \"2025-02\",\n        \"2025-03\",\n        \"2025-04\",\n        \"2025-05\",\n        \"2025-06\",\n        \"2025-07\",\n        \"2025-08\",\n        \"2025-09\",\n        \"2025-10\",\n        \"2025-11\",\n        \"2025-12\"\n      ],\n      \"point_forecast\": [\n        1.25933837890625,\n        1.285666823387146,\n        1.2950127124786377,\n        1.2207623720169067,\n        1.170255422592163,\n        1.1455552577972412,\n        1.1702347993850708,\n        1.2026824951171875,\n        1.1909748315811157,\n        1.1490840911865234,\n        1.080478549003601,\n        1.0613453388214111\n      ],\n      \"q10\": [\n        1.2481880187988281,\n        1.2773758172988892,\n        1.286991834640503,\n        1.2084007263183594,\n        1.1533130407333374,\n        1.1275498867034912,\n        1.1510555744171143,\n        1.1859495639801025,\n        1.1784849166870117,\n        1.1264795064926147,\n        1.0624356269836426,\n        1.036609172821045\n      ],\n      \"q20\": [\n        1.1407020092010498,\n        1.1406043767929077,\n        1.126852035522461,\n        1.0352504253387451,\n        0.9691494703292847,\n        0.9420379400253296,\n        0.9503718018531799,\n        0.970925509929657,\n        0.9594371318817139,\n        0.9079477190971375,\n        0.8361266255378723,\n        0.8022069334983826\n      ],\n      \"q80\": [\n        1.2971320152282715,\n        1.3400218486785889,\n        1.3547290563583374,\n        1.2898554801940918,\n        1.2390310764312744,\n        1.2180578708648682,\n        1.248227596282959,\n        1.2842004299163818,\n        1.2832940816879272,\n        1.240414023399353,\n        1.175971508026123,\n        1.153149962425232\n      ],\n      \"q90\": [\n        1.3239599466323853,\n        1.3751201629638672,\n       
 1.403548240661621,\n        1.3310348987579346,\n        1.2891905307769775,\n        1.2702757120132446,\n        1.2997852563858032,\n        1.3408125638961792,\n        1.3354730606079102,\n        1.286876916885376,\n        1.2283769845962524,\n        1.2169079780578613\n      ]\n    }\n  ]\n};\n        \n        let chart = null;\n        let isPlaying = false;\n        let playInterval = null;\n        let currentStep = 0;\n\n        // Fixed axis extents\n        let allDates = [];\n        let yMin = 0.7;\n        let yMax = 1.55;\n\n        function initChart() {\n            const ctx = document.getElementById('chart').getContext('2d');\n            \n            // Calculate fixed extents\n            const finalStep = animationData.animation_steps[animationData.animation_steps.length - 1];\n            allDates = [\n                ...animationData.actual_data.dates,\n                ...finalStep.forecast_dates\n            ];\n            \n            // Y extent from all values\n            const allValues = [\n                ...animationData.actual_data.values,\n                ...finalStep.point_forecast,\n                ...finalStep.q10,\n                ...finalStep.q90\n            ];\n            yMin = Math.min(...allValues) - 0.05;\n            yMax = Math.max(...allValues) + 0.05;\n            \n            chart = new Chart(ctx, {\n                type: 'line',\n                data: {\n                    labels: allDates,\n                    datasets: [\n                        {\n                            label: 'All Observed',\n                            data: animationData.actual_data.values.map((v, i) => ({x: animationData.actual_data.dates[i], y: v})),\n                            borderColor: '#9ca3af',\n                            borderWidth: 1,\n                            pointRadius: 2,\n                            pointBackgroundColor: '#9ca3af',\n                            fill: false,\n                            
tension: 0.1,\n                            order: 1,\n                        },\n                        {\n                            label: 'Final Forecast',\n                            data: [...Array(animationData.actual_data.dates.length).fill(null), ...finalStep.point_forecast],\n                            borderColor: '#fca5a5',\n                            borderWidth: 1,\n                            borderDash: [4, 4],\n                            pointRadius: 2,\n                            pointBackgroundColor: '#fca5a5',\n                            fill: false,\n                            tension: 0.1,\n                            order: 2,\n                        },\n                        {\n                            label: 'Data Used',\n                            data: [],\n                            borderColor: '#3b82f6',\n                            backgroundColor: 'rgba(59, 130, 246, 0.1)',\n                            borderWidth: 2.5,\n                            pointRadius: 4,\n                            pointBackgroundColor: '#3b82f6',\n                            fill: false,\n                            tension: 0.1,\n                            order: 10,\n                        },\n                        {\n                            label: '90% CI Lower',\n                            data: [],\n                            borderColor: 'transparent',\n                            backgroundColor: 'rgba(239, 68, 68, 0.08)',\n                            fill: '+1',\n                            pointRadius: 0,\n                            tension: 0.1,\n                            order: 5,\n                        },\n                        {\n                            label: '90% CI Upper',\n                            data: [],\n                            borderColor: 'transparent',\n                            backgroundColor: 'rgba(239, 68, 68, 0.08)',\n                            fill: false,\n                      
      pointRadius: 0,\n                            tension: 0.1,\n                            order: 5,\n                        },\n                        {\n                            label: '80% CI Lower',\n                            data: [],\n                            borderColor: 'transparent',\n                            backgroundColor: 'rgba(239, 68, 68, 0.2)',\n                            fill: '+1',\n                            pointRadius: 0,\n                            tension: 0.1,\n                            order: 6,\n                        },\n                        {\n                            label: '80% CI Upper',\n                            data: [],\n                            borderColor: 'transparent',\n                            backgroundColor: 'rgba(239, 68, 68, 0.2)',\n                            fill: false,\n                            pointRadius: 0,\n                            tension: 0.1,\n                            order: 6,\n                        },\n                        {\n                            label: 'Forecast',\n                            data: [],\n                            borderColor: '#ef4444',\n                            backgroundColor: 'rgba(239, 68, 68, 0.1)',\n                            borderWidth: 2.5,\n                            pointRadius: 4,\n                            pointBackgroundColor: '#ef4444',\n                            fill: false,\n                            tension: 0.1,\n                            order: 7,\n                        },\n                    ]\n                },\n                options: {\n                    responsive: true,\n                    maintainAspectRatio: false,\n                    interaction: { intersect: false, mode: 'index' },\n                    plugins: {\n                        legend: { display: false },\n                        tooltip: {\n                            backgroundColor: 'rgba(0, 0, 0, 0.8)',\n                
            titleColor: '#fff',\n                            bodyColor: '#fff',\n                            padding: 12,\n                        },\n                    },\n                    scales: {\n                        x: {\n                            grid: { color: 'rgba(255, 255, 255, 0.05)' },\n                            ticks: { color: '#9ca3af', maxRotation: 45, minRotation: 45 },\n                        },\n                        y: {\n                            grid: { color: 'rgba(255, 255, 255, 0.05)' },\n                            ticks: {\n                                color: '#9ca3af',\n                                callback: v => v.toFixed(2) + '°C'\n                            },\n                            min: yMin,\n                            max: yMax,\n                        },\n                    },\n                    animation: { duration: 150 },\n                },\n            });\n        }\n\n        function updateChart(stepIndex) {\n            if (!animationData || !chart) return;\n            \n            const step = animationData.animation_steps[stepIndex];\n            const finalStep = animationData.animation_steps[animationData.animation_steps.length - 1];\n            const actual = animationData.actual_data;\n            \n            // Build data arrays for each dataset\n            const nHist = step.historical_dates.length;\n            const nForecast = step.forecast_dates.length;\n            const nActual = actual.dates.length;\n            const nFinalForecast = finalStep.forecast_dates.length;\n            const totalPoints = nActual + nFinalForecast;\n            \n            // Dataset 0: All observed (always full)\n            chart.data.datasets[0].data = actual.values.map((v, i) => ({x: actual.dates[i], y: v}));\n            \n            // Dataset 1: Final forecast reference (always full)\n            chart.data.datasets[1].data = [\n                ...Array(nActual).fill(null),\n      
          ...finalStep.point_forecast\n            ];\n            \n            // Dataset 2: Data used (historical only)\n            const dataUsed = [];\n            for (let i = 0; i < totalPoints; i++) {\n                if (i < nHist) {\n                    dataUsed.push(step.historical_values[i]);\n                } else {\n                    dataUsed.push(null);\n                }\n            }\n            chart.data.datasets[2].data = dataUsed;\n            \n            // Datasets 3-6: CIs (forecast only)\n            const forecastOffset = nActual;\n            const q90Lower = [];\n            const q90Upper = [];\n            const q80Lower = [];\n            const q80Upper = [];\n            \n            for (let i = 0; i < totalPoints; i++) {\n                const forecastIdx = i - forecastOffset;\n                if (forecastIdx >= 0 && forecastIdx < nForecast) {\n                    q90Lower.push(step.q10[forecastIdx]);\n                    q90Upper.push(step.q90[forecastIdx]);\n                    q80Lower.push(step.q20[forecastIdx]);\n                    q80Upper.push(step.q80[forecastIdx]);\n                } else {\n                    q90Lower.push(null);\n                    q90Upper.push(null);\n                    q80Lower.push(null);\n                    q80Upper.push(null);\n                }\n            }\n            chart.data.datasets[3].data = q90Lower;\n            chart.data.datasets[4].data = q90Upper;\n            chart.data.datasets[5].data = q80Lower;\n            chart.data.datasets[6].data = q80Upper;\n            \n            // Dataset 7: Forecast line\n            const forecastData = [];\n            for (let i = 0; i < totalPoints; i++) {\n                const forecastIdx = i - forecastOffset;\n                if (forecastIdx >= 0 && forecastIdx < nForecast) {\n                    forecastData.push(step.point_forecast[forecastIdx]);\n                } else {\n                    forecastData.push(null);\n       
         }\n            }\n            chart.data.datasets[7].data = forecastData;\n            \n            chart.update('none');\n            \n            // Update UI\n            document.getElementById('slider').value = stepIndex;\n            document.getElementById('points-value').textContent = `${step.n_points} / 36`;\n            document.getElementById('date-end').textContent = `Using data through ${step.last_historical_date}`;\n            \n            // Stats\n            const mean = (step.point_forecast.reduce((a, b) => a + b, 0) / step.point_forecast.length).toFixed(3);\n            const max = Math.max(...step.point_forecast).toFixed(3);\n            const min = Math.min(...step.point_forecast).toFixed(3);\n            \n            document.getElementById('stat-mean').textContent = mean + '°C';\n            document.getElementById('stat-horizon').textContent = step.horizon + ' months';\n            document.getElementById('stat-max').textContent = max + '°C';\n            document.getElementById('stat-min').textContent = min + '°C';\n            \n            currentStep = stepIndex;\n        }\n\n        document.getElementById('slider').addEventListener('input', e => {\n            updateChart(parseInt(e.target.value));\n        });\n\n        document.getElementById('play-btn').addEventListener('click', () => {\n            const btn = document.getElementById('play-btn');\n            if (isPlaying) {\n                clearInterval(playInterval);\n                btn.textContent = '▶ Play';\n                isPlaying = false;\n            } else {\n                btn.textContent = '⏸ Pause';\n                isPlaying = true;\n                if (currentStep >= animationData.animation_steps.length - 1) currentStep = 0;\n                playInterval = setInterval(() => {\n                    if (currentStep >= animationData.animation_steps.length - 1) {\n                        clearInterval(playInterval);\n                        
document.getElementById('play-btn').textContent = '▶ Play';\n                        isPlaying = false;\n                    } else {\n                        currentStep++;\n                        updateChart(currentStep);\n                    }\n                }, 400);\n            }\n        });\n\n        document.getElementById('reset-btn').addEventListener('click', () => {\n            if (isPlaying) {\n                clearInterval(playInterval);\n                document.getElementById('play-btn').textContent = '▶ Play';\n                isPlaying = false;\n            }\n            updateChart(0);\n        });\n\n        // Initialize on load\n        initChart();\n        updateChart(0);\n    </script>\n</body>\n</html>\n"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/run_example.sh",
    "content": "#!/bin/bash\n# run_example.sh - Run the TimesFM temperature anomaly forecasting example\n#\n# This script:\n# 1. Runs the preflight system check\n# 2. Runs the TimesFM forecast\n# 3. Generates the visualization\n#\n# Usage:\n#   ./run_example.sh\n#\n# Prerequisites:\n#   - Python 3.10+\n#   - timesfm[torch] installed: uv pip install \"timesfm[torch]\"\n#   - matplotlib, pandas, numpy\n\nset -e\n\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\nSKILL_ROOT=\"$(dirname \"$(dirname \"$SCRIPT_DIR\")\")\"\n\necho \"============================================================\"\necho \"  TimesFM Example: Global Temperature Anomaly Forecast\"\necho \"============================================================\"\n\n# Step 1: Preflight check\necho \"\"\necho \"🔍 Step 1: Running preflight system check...\"\npython3 \"$SKILL_ROOT/scripts/check_system.py\" || {\n    echo \"❌ Preflight check failed. Please fix the issues above before continuing.\"\n    exit 1\n}\n\n# Step 2: Run forecast\necho \"\"\necho \"📊 Step 2: Running TimesFM forecast...\"\ncd \"$SCRIPT_DIR\"\npython3 run_forecast.py\n\n# Step 3: Generate visualization\necho \"\"\necho \"📈 Step 3: Generating visualization...\"\npython3 visualize_forecast.py\n\necho \"\"\necho \"============================================================\"\necho \"  ✅ Example complete!\"\necho \"============================================================\"\necho \"\"\necho \"Output files:\"\necho \"  - $SCRIPT_DIR/output/forecast_output.csv\"\necho \"  - $SCRIPT_DIR/output/forecast_output.json\"\necho \"  - $SCRIPT_DIR/output/forecast_visualization.png\"\n"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/run_forecast.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nRun TimesFM forecast on global temperature anomaly data.\nGenerates forecast output CSV and JSON for the example.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\nimport numpy as np\nimport pandas as pd\n\n# Preflight check\nprint(\"=\" * 60)\nprint(\"  TIMeSFM FORECAST - Global Temperature Anomaly Example\")\nprint(\"=\" * 60)\n\n# Load data\ndata_path = Path(__file__).parent / \"temperature_anomaly.csv\"\ndf = pd.read_csv(data_path, parse_dates=[\"date\"])\ndf = df.sort_values(\"date\").reset_index(drop=True)\n\nprint(f\"\\n📊 Input Data: {len(df)} months of temperature anomalies\")\nprint(\n    f\"   Date range: {df['date'].min().strftime('%Y-%m')} to {df['date'].max().strftime('%Y-%m')}\"\n)\nprint(f\"   Mean anomaly: {df['anomaly_c'].mean():.2f}°C\")\nprint(\n    f\"   Trend: {df['anomaly_c'].iloc[-12:].mean() - df['anomaly_c'].iloc[:12].mean():.2f}°C change (first to last year)\"\n)\n\n# Prepare input for TimesFM\n# TimesFM expects a list of 1D numpy arrays\ninput_series = df[\"anomaly_c\"].values.astype(np.float32)\n\n# Load TimesFM 1.0 (PyTorch)\n# NOTE: TimesFM 2.5 PyTorch checkpoint has a file format issue at time of writing.\n# The model.safetensors file is not loadable via torch.load().\n# Using TimesFM 1.0 PyTorch which works correctly.\nprint(\"\\n🤖 Loading TimesFM 1.0 (200M) PyTorch...\")\nimport timesfm\n\nhparams = timesfm.TimesFmHparams(horizon_len=12)\ncheckpoint = timesfm.TimesFmCheckpoint(\n    huggingface_repo_id=\"google/timesfm-1.0-200m-pytorch\"\n)\nmodel = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)\n\n# Forecast\nprint(\"\\n📈 Running forecast (12 months ahead)...\")\nforecast_input = [input_series]\nfrequency_input = [0]  # Monthly data\n\npoint_forecast, experimental_quantile_forecast = model.forecast(\n    forecast_input,\n    freq=frequency_input,\n)\n\nprint(f\"   Point forecast shape: {point_forecast.shape}\")\nprint(f\"   Quantile 
forecast shape: {experimental_quantile_forecast.shape}\")\n\n# Extract results\npoint = point_forecast[0]  # Shape: (horizon,)\nquantiles = experimental_quantile_forecast[0]  # Shape: (horizon, num_quantiles)\n\n# TimesFM quantiles: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99]\n# Index mapping: 0=10%, 1=20%, ..., 4=50% (median), ..., 9=99%\nquantile_labels = [\"10%\", \"20%\", \"30%\", \"40%\", \"50%\", \"60%\", \"70%\", \"80%\", \"90%\", \"99%\"]\n\n# Create forecast dates (2025 monthly)\nlast_date = df[\"date\"].max()\nforecast_dates = pd.date_range(\n    start=last_date + pd.DateOffset(months=1), periods=12, freq=\"MS\"\n)\n\n# Build output DataFrame\noutput_df = pd.DataFrame(\n    {\n        \"date\": forecast_dates.strftime(\"%Y-%m-%d\"),\n        \"point_forecast\": point,\n        \"q10\": quantiles[:, 0],\n        \"q20\": quantiles[:, 1],\n        \"q30\": quantiles[:, 2],\n        \"q40\": quantiles[:, 3],\n        \"q50\": quantiles[:, 4],  # Median\n        \"q60\": quantiles[:, 5],\n        \"q70\": quantiles[:, 6],\n        \"q80\": quantiles[:, 7],\n        \"q90\": quantiles[:, 8],\n        \"q99\": quantiles[:, 9],\n    }\n)\n\n# Save outputs\noutput_dir = Path(__file__).parent / \"output\"\noutput_dir.mkdir(exist_ok=True)\noutput_df.to_csv(output_dir / \"forecast_output.csv\", index=False)\n\n# JSON output for the report\noutput_json = {\n    \"model\": \"TimesFM 1.0 (200M) PyTorch\",\n    \"input\": {\n        \"source\": \"NOAA GISTEMP Global Temperature Anomaly\",\n        \"n_observations\": len(df),\n        \"date_range\": f\"{df['date'].min().strftime('%Y-%m')} to {df['date'].max().strftime('%Y-%m')}\",\n        \"mean_anomaly_c\": round(df[\"anomaly_c\"].mean(), 3),\n    },\n    \"forecast\": {\n        \"horizon\": 12,\n        \"dates\": forecast_dates.strftime(\"%Y-%m\").tolist(),\n        \"point\": point.tolist(),\n        \"quantiles\": {\n            label: quantiles[:, i].tolist() for i, label in enumerate(quantile_labels)\n 
       },\n    },\n    \"summary\": {\n        \"forecast_mean_c\": round(float(point.mean()), 3),\n        \"forecast_max_c\": round(float(point.max()), 3),\n        \"forecast_min_c\": round(float(point.min()), 3),\n        \"vs_last_year_mean\": round(\n            float(point.mean() - df[\"anomaly_c\"].iloc[-12:].mean()), 3\n        ),\n    },\n}\n\nwith open(output_dir / \"forecast_output.json\", \"w\") as f:\n    json.dump(output_json, f, indent=2)\n\n# Print summary\nprint(\"\\n\" + \"=\" * 60)\nprint(\"  FORECAST RESULTS\")\nprint(\"=\" * 60)\nprint(\n    f\"\\n📅 Forecast period: {forecast_dates[0].strftime('%Y-%m')} to {forecast_dates[-1].strftime('%Y-%m')}\"\n)\nprint(f\"\\n🌡️  Temperature Anomaly Forecast (°C above 1951-1980 baseline):\")\nprint(f\"\\n   {'Month':<10} {'Point':>8} {'80% CI':>15} {'90% CI':>15}\")\nprint(f\"   {'-' * 10} {'-' * 8} {'-' * 15} {'-' * 15}\")\nfor i, (date, pt, q10, q90, q05, q95) in enumerate(\n    zip(\n        forecast_dates.strftime(\"%Y-%m\"),\n        point,\n        quantiles[:, 1],  # 20%\n        quantiles[:, 7],  # 80%\n        quantiles[:, 0],  # 10%\n        quantiles[:, 8],  # 90%\n    )\n):\n    print(\n        f\"   {date:<10} {pt:>8.3f} [{q10:>6.3f}, {q90:>6.3f}] [{q05:>6.3f}, {q95:>6.3f}]\"\n    )\n\nprint(f\"\\n📊 Summary Statistics:\")\nprint(f\"   Mean forecast:  {point.mean():.3f}°C\")\nprint(\n    f\"   Max forecast:   {point.max():.3f}°C (Month: {forecast_dates[point.argmax()].strftime('%Y-%m')})\"\n)\nprint(\n    f\"   Min forecast:   {point.min():.3f}°C (Month: {forecast_dates[point.argmin()].strftime('%Y-%m')})\"\n)\nprint(f\"   vs 2024 mean:   {point.mean() - df['anomaly_c'].iloc[-12:].mean():+.3f}°C\")\n\nprint(f\"\\n✅ Output saved to:\")\nprint(f\"   {output_dir / 'forecast_output.csv'}\")\nprint(f\"   {output_dir / 'forecast_output.json'}\")\n"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/temperature_anomaly.csv",
    "content": "date,anomaly_c\n2022-01-01,0.89\n2022-02-01,0.89\n2022-03-01,1.02\n2022-04-01,0.88\n2022-05-01,0.85\n2022-06-01,0.88\n2022-07-01,0.88\n2022-08-01,0.90\n2022-09-01,0.88\n2022-10-01,0.95\n2022-11-01,0.77\n2022-12-01,0.78\n2023-01-01,0.87\n2023-02-01,0.98\n2023-03-01,1.21\n2023-04-01,1.00\n2023-05-01,0.94\n2023-06-01,1.08\n2023-07-01,1.18\n2023-08-01,1.24\n2023-09-01,1.47\n2023-10-01,1.32\n2023-11-01,1.18\n2023-12-01,1.16\n2024-01-01,1.22\n2024-02-01,1.35\n2024-03-01,1.34\n2024-04-01,1.26\n2024-05-01,1.15\n2024-06-01,1.20\n2024-07-01,1.24\n2024-08-01,1.30\n2024-09-01,1.28\n2024-10-01,1.27\n2024-11-01,1.22\n2024-12-01,1.20\n"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/visualize_forecast.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nVisualize TimesFM forecast results for global temperature anomaly.\n\nGenerates a publication-quality figure showing:\n- Historical data (2022-2024)\n- Point forecast (2025)\n- 80% and 90% confidence intervals (fan chart)\n\nUsage:\n    python visualize_forecast.py\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\n# Configuration\nEXAMPLE_DIR = Path(__file__).parent\nINPUT_FILE = EXAMPLE_DIR / \"temperature_anomaly.csv\"\nFORECAST_FILE = EXAMPLE_DIR / \"output\" / \"forecast_output.json\"\nOUTPUT_FILE = EXAMPLE_DIR / \"output\" / \"forecast_visualization.png\"\n\n\ndef main() -> None:\n    # Load historical data\n    df = pd.read_csv(INPUT_FILE, parse_dates=[\"date\"])\n\n    # Load forecast results\n    with open(FORECAST_FILE) as f:\n        forecast = json.load(f)\n\n    # Extract forecast data\n    dates = pd.to_datetime(forecast[\"forecast\"][\"dates\"])\n    point = np.array(forecast[\"forecast\"][\"point\"])\n    q10 = np.array(forecast[\"forecast\"][\"quantiles\"][\"10%\"])\n    q20 = np.array(forecast[\"forecast\"][\"quantiles\"][\"20%\"])\n    q80 = np.array(forecast[\"forecast\"][\"quantiles\"][\"80%\"])\n    q90 = np.array(forecast[\"forecast\"][\"quantiles\"][\"90%\"])\n\n    # Create figure\n    fig, ax = plt.subplots(figsize=(12, 6))\n\n    # Plot historical data\n    ax.plot(\n        df[\"date\"],\n        df[\"anomaly_c\"],\n        color=\"#2563eb\",\n        linewidth=1.5,\n        marker=\"o\",\n        markersize=3,\n        label=\"Historical (NOAA GISTEMP)\",\n    )\n\n    # Plot 90% CI (outer band)\n    ax.fill_between(dates, q10, q90, alpha=0.2, color=\"#dc2626\", label=\"90% CI\")\n\n    # Plot 80% CI (inner band)\n    ax.fill_between(dates, q20, q80, alpha=0.3, color=\"#dc2626\", label=\"80% CI\")\n\n    # Plot point forecast\n    ax.plot(\n        dates,\n        point,\n     
   color=\"#dc2626\",\n        linewidth=2,\n        marker=\"s\",\n        markersize=4,\n        label=\"TimesFM Forecast\",\n    )\n\n    # Add vertical line at forecast boundary\n    ax.axvline(\n        x=df[\"date\"].max(), color=\"#6b7280\", linestyle=\"--\", linewidth=1, alpha=0.7\n    )\n\n    # Formatting\n    ax.set_xlabel(\"Date\", fontsize=12)\n    ax.set_ylabel(\"Temperature Anomaly (°C)\", fontsize=12)\n    ax.set_title(\n        \"TimesFM Zero-Shot Forecast Example\\n36-month Temperature Anomaly → 12-month Forecast\",\n        fontsize=14,\n        fontweight=\"bold\",\n    )\n\n    # Add annotations\n    ax.annotate(\n        f\"Mean forecast: {forecast['summary']['forecast_mean_c']:.2f}°C\\n\"\n        f\"vs 2024: {forecast['summary']['vs_last_year_mean']:+.2f}°C\",\n        xy=(dates[6], point[6]),\n        xytext=(dates[6], point[6] + 0.15),\n        fontsize=10,\n        arrowprops=dict(arrowstyle=\"->\", color=\"#6b7280\", lw=1),\n        bbox=dict(boxstyle=\"round,pad=0.3\", facecolor=\"white\", edgecolor=\"#6b7280\"),\n    )\n\n    # Grid and legend\n    ax.grid(True, alpha=0.3)\n    ax.legend(loc=\"upper left\", fontsize=10)\n\n    # Set y-axis limits\n    ax.set_ylim(0.7, 1.5)\n\n    # Rotate x-axis labels\n    plt.xticks(rotation=45, ha=\"right\")\n\n    # Tight layout\n    plt.tight_layout()\n\n    # Save\n    fig.savefig(OUTPUT_FILE, dpi=150, bbox_inches=\"tight\")\n    print(f\"✅ Saved visualization to: {OUTPUT_FILE}\")\n\n    plt.close()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "timesfm-forecasting/references/api_reference.md",
    "content": "# TimesFM API Reference\n\n## Model Classes\n\n### `timesfm.TimesFM_2p5_200M_torch`\n\nThe primary model class for TimesFM 2.5 (200M parameters, PyTorch backend).\n\n#### `from_pretrained()`\n\n```python\nmodel = timesfm.TimesFM_2p5_200M_torch.from_pretrained(\n    \"google/timesfm-2.5-200m-pytorch\",\n    cache_dir=None,         # Optional: custom cache directory\n    force_download=True,    # Re-download even if cached\n)\n```\n\n| Parameter | Type | Default | Description |\n| --------- | ---- | ------- | ----------- |\n| `model_id` | str | `\"google/timesfm-2.5-200m-pytorch\"` | Hugging Face model ID |\n| `revision` | str \\| None | None | Specific model revision |\n| `cache_dir` | str \\| Path \\| None | None | Custom cache directory |\n| `force_download` | bool | True | Force re-download of weights |\n\n**Returns**: Initialized `TimesFM_2p5_200M_torch` instance (not yet compiled).\n\n#### `compile()`\n\nCompiles the model with the given forecast configuration. **Must be called before `forecast()`.**\n\n```python\nmodel.compile(\n    timesfm.ForecastConfig(\n        max_context=1024,\n        max_horizon=256,\n        normalize_inputs=True,\n        per_core_batch_size=32,\n        use_continuous_quantile_head=True,\n        force_flip_invariance=True,\n        infer_is_positive=True,\n        fix_quantile_crossing=True,\n    )\n)\n```\n\n**Raises**: Nothing (but `forecast()` will raise `RuntimeError` if not compiled).\n\n#### `forecast()`\n\nRun inference on one or more time series.\n\n```python\npoint_forecast, quantile_forecast = model.forecast(\n    horizon=24,\n    inputs=[array1, array2, ...],\n)\n```\n\n| Parameter | Type | Description |\n| --------- | ---- | ----------- |\n| `horizon` | int | Number of future steps to forecast |\n| `inputs` | list[np.ndarray] | List of 1-D numpy arrays (each is a time series) |\n\n**Returns**: `tuple[np.ndarray, np.ndarray]`\n\n- `point_forecast`: shape `(batch_size, horizon)` — median (0.5 quantile)\n- 
`quantile_forecast`: shape `(batch_size, horizon, 10)` — [mean, q10, q20, ..., q90]\n\n**Raises**: `RuntimeError` if model is not compiled.\n\n**Key behaviors**:\n\n- Leading NaN values are stripped automatically\n- Internal NaN values are linearly interpolated\n- Series longer than `max_context` are truncated (last `max_context` points used)\n- Series shorter than `max_context` are padded\n\n#### `forecast_with_covariates()`\n\nRun inference with exogenous variables (requires `timesfm[xreg]`).\n\n```python\npoint, quantiles = model.forecast_with_covariates(\n    inputs=inputs,\n    dynamic_numerical_covariates={\"temp\": [temp_array1, temp_array2]},\n    dynamic_categorical_covariates={\"dow\": [dow_array1, dow_array2]},\n    static_categorical_covariates={\"region\": [\"east\", \"west\"]},\n    xreg_mode=\"xreg + timesfm\",\n)\n```\n\n| Parameter | Type | Description |\n| --------- | ---- | ----------- |\n| `inputs` | list[np.ndarray] | Target time series |\n| `dynamic_numerical_covariates` | dict[str, list[np.ndarray]] | Time-varying numeric features |\n| `dynamic_categorical_covariates` | dict[str, list[np.ndarray]] | Time-varying categorical features |\n| `static_categorical_covariates` | dict[str, list[str]] | Fixed categorical features per series |\n| `xreg_mode` | str | `\"xreg + timesfm\"` or `\"timesfm + xreg\"` |\n\n**Note**: Dynamic covariates must have length `context + horizon` for each series.\n\n---\n\n## `timesfm.ForecastConfig`\n\nImmutable dataclass controlling all forecast behavior.\n\n```python\n@dataclasses.dataclass(frozen=True)\nclass ForecastConfig:\n    max_context: int = 0\n    max_horizon: int = 0\n    normalize_inputs: bool = False\n    per_core_batch_size: int = 1\n    use_continuous_quantile_head: bool = False\n    force_flip_invariance: bool = True\n    infer_is_positive: bool = True\n    fix_quantile_crossing: bool = False\n    return_backcast: bool = False\n    quantiles: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 
0.9]\n    decode_index: int = 5\n```\n\n### Parameter Details\n\n#### `max_context` (int, default=0)\n\nMaximum number of historical time points to use as context.\n\n- **0**: Use the model's maximum supported context (16,384 for v2.5)\n- **N**: Truncate series to last N points\n- **Best practice**: Set to the length of your longest series, or 512–2048 for speed\n\n#### `max_horizon` (int, default=0)\n\nMaximum forecast horizon.\n\n- **0**: Use the model's maximum\n- **N**: Forecasts up to N steps (can still call `forecast(horizon=M)` where M ≤ N)\n- **Best practice**: Set to your expected maximum forecast length\n\n#### `normalize_inputs` (bool, default=False)\n\nWhether to z-normalize each series before feeding to the model.\n\n- **True** (RECOMMENDED): Normalizes each series to zero mean, unit variance\n- **False**: Raw values are passed directly\n- **When False is OK**: Only if your series are already normalized or very close to scale 1.0\n\n#### `per_core_batch_size` (int, default=1)\n\nNumber of series processed per device in each batch.\n\n- Increase for throughput, decrease if OOM\n- See `references/system_requirements.md` for recommended values by hardware\n\n#### `use_continuous_quantile_head` (bool, default=False)\n\nUse the 30M-parameter continuous quantile head for better interval calibration.\n\n- **True** (RECOMMENDED): More accurate prediction intervals, especially for longer horizons\n- **False**: Uses fixed quantile buckets (faster but less accurate intervals)\n\n#### `force_flip_invariance` (bool, default=True)\n\nEnsures the model satisfies `f(-x) = -f(x)`.\n\n- **True** (RECOMMENDED): Mathematical consistency — negating the input series negates the forecast\n- **False**: Slightly faster but may produce asymmetric forecasts\n\n#### `infer_is_positive` (bool, default=True)\n\nAutomatically detect if all input values are positive and clamp forecasts ≥ 0.\n\n- **True**: Safe for sales, demand, counts, prices, volumes\n- **False**: Required for temperature, 
returns, PnL, any series that can be negative\n\n#### `fix_quantile_crossing` (bool, default=False)\n\nPost-process quantiles to ensure monotonicity (q10 ≤ q20 ≤ ... ≤ q90).\n\n- **True** (RECOMMENDED): Guarantees well-ordered quantiles\n- **False**: Slightly faster but quantiles may occasionally cross\n\n#### `return_backcast` (bool, default=False)\n\nReturn the model's reconstruction of the input (backcast) in addition to forecast.\n\n- **True**: Used for covariate workflows and diagnostics\n- **False**: Only return forecast\n\n---\n\n## Available Model Checkpoints\n\n| Model ID | Version | Params | Backend | Context |\n| -------- | ------- | ------ | ------- | ------- |\n| `google/timesfm-2.5-200m-pytorch` | 2.5 | 200M | PyTorch | 16,384 |\n| `google/timesfm-2.5-200m-flax` | 2.5 | 200M | JAX/Flax | 16,384 |\n| `google/timesfm-2.5-200m-transformers` | 2.5 | 200M | Transformers | 16,384 |\n| `google/timesfm-2.0-500m-pytorch` | 2.0 | 500M | PyTorch | 2,048 |\n| `google/timesfm-2.0-500m-jax` | 2.0 | 500M | JAX | 2,048 |\n| `google/timesfm-1.0-200m-pytorch` | 1.0 | 200M | PyTorch | 2,048 |\n| `google/timesfm-1.0-200m` | 1.0 | 200M | JAX | 2,048 |\n\n---\n\n## Output Shape Reference\n\n| Output | Shape | Description |\n| ------ | ----- | ----------- |\n| `point_forecast` | `(B, H)` | Median forecast for B series, H steps |\n| `quantile_forecast` | `(B, H, 10)` | Full quantile distribution |\n| `quantile_forecast[:,:,0]` | `(B, H)` | Mean |\n| `quantile_forecast[:,:,1]` | `(B, H)` | 10th percentile |\n| `quantile_forecast[:,:,5]` | `(B, H)` | 50th percentile (= point_forecast) |\n| `quantile_forecast[:,:,9]` | `(B, H)` | 90th percentile |\n\nWhere `B` = batch size (number of input series), `H` = forecast horizon.\n\n---\n\n---\n\n## Memory Estimation\n\nBefore running forecasts on large datasets, estimate memory requirements:\n\n### Formula\n\n```mermaid\nblock-beta\n    columns 3\n    ram[\"Total RAM Required\"] model[\"Model Weights<br/>~0.8 GB\"] overhead[\"Runtime 
Overhead<br/>~0.5 GB\"] buffers[\"I/O Buffers<br/>~0.2 MB per 1000 series<br/>per 1000 context\"]\n    \n    ram --> model\n    ram --> overhead\n    ram --> buffers\n```\n\n**Formula**:  \n`RAM (GB) ≈ 0.8 + 0.5 + (0.0002 × num_series × context_length)`\n\n**Variables**:\n- `num_series`: Number of time series in your batch\n- `context_length`: Your `max_context` value (or max series length)\n- `batch_size`: Your `per_core_batch_size` (affects parallel processing overhead)\n\n### Quick Reference\n\n| Dataset Size | Context=512 | Context=1024 | Context=2048 |\n|--------------|-------------|--------------|--------------|\n| 100 series | ~1.4 GB | ~1.5 GB | ~1.7 GB |\n| 1,000 series | ~1.9 GB | ~2.3 GB | ~3.1 GB |\n| 10,000 series| ~9.0 GB | ~17.0 GB | ~33.0 GB |\n\n### Using the Preflight Checker\n\n```bash\npython scripts/check_system.py \\\n  --num-series 1000 \\\n  --context-length 1024 \\\n  --batch-size 32\n```\n\nThis validates both system requirements AND dataset fit before loading the model.\n\n### Reducing Memory Usage\n\nIf your dataset is too large:\n\n1. **Reduce context length**: Use `max_context=512` instead of 1024+ (50% reduction)\n2. **Process in chunks**: Split large batches into smaller groups:\n\n```python\nCHUNK_SIZE = 100\nfor i in range(0, len(inputs), CHUNK_SIZE):\n    chunk = inputs[i:i+CHUNK_SIZE]\n    point, quantiles = model.forecast(horizon=H, inputs=chunk)\n    # Save chunk results\n```\n\n3. **Reduce batch size**: Lower `per_core_batch_size` (slower but less memory)\n4. 
**Use CPU**: If the GPU still runs out of memory (`torch.cuda.OutOfMemoryError`), rerun the forecast on CPU — slower, but limited by system RAM rather than GPU memory\n\n\n## Error Handling\n\n| Error | Cause | Fix |\n| ----- | ----- | --- |\n| `RuntimeError: Model is not compiled` | Called `forecast()` before `compile()` | Call `model.compile(ForecastConfig(...))` first |\n| `torch.cuda.OutOfMemoryError` | Batch too large for GPU | Reduce `per_core_batch_size` |\n| `ValueError: inputs must be list` | Passed array instead of list | Wrap in list: `[array]` |\n| `HfHubHTTPError` | Download failed | Check internet, set `HF_HOME` to writable dir |\n"
  },
  {
    "path": "timesfm-forecasting/references/data_preparation.md",
    "content": "# Data Preparation for TimesFM\n\n## Input Format\n\nTimesFM accepts a **list of 1-D numpy arrays**. Each array represents one\nunivariate time series.\n\n```python\ninputs = [\n    np.array([1.0, 2.0, 3.0, 4.0, 5.0]),       # Series 1\n    np.array([10.0, 20.0, 15.0, 25.0]),          # Series 2 (different length)\n    np.array([100.0, 110.0, 105.0, 115.0, 120.0, 130.0]),  # Series 3\n]\n```\n\n### Key Properties\n\n- **Variable lengths**: Series in the same batch can have different lengths\n- **Float values**: Use `np.float32` or `np.float64`\n- **1-D only**: Each array must be 1-dimensional (not 2-D matrix rows)\n- **NaN handling**: Leading NaNs are stripped; internal NaNs are linearly interpolated\n\n## Loading from Common Formats\n\n### CSV — Single Series (Long Format)\n\n```python\nimport pandas as pd\nimport numpy as np\n\ndf = pd.read_csv(\"data.csv\", parse_dates=[\"date\"])\nvalues = df[\"value\"].values.astype(np.float32)\ninputs = [values]\n```\n\n### CSV — Multiple Series (Wide Format)\n\n```python\ndf = pd.read_csv(\"data.csv\", parse_dates=[\"date\"], index_col=\"date\")\ninputs = [df[col].dropna().values.astype(np.float32) for col in df.columns]\n```\n\n### CSV — Long Format with ID Column\n\n```python\ndf = pd.read_csv(\"data.csv\", parse_dates=[\"date\"])\ninputs = []\nfor series_id, group in df.groupby(\"series_id\"):\n    values = group.sort_values(\"date\")[\"value\"].values.astype(np.float32)\n    inputs.append(values)\n```\n\n### Pandas DataFrame\n\n```python\n# Single column\ninputs = [df[\"temperature\"].values.astype(np.float32)]\n\n# Multiple columns\ninputs = [df[col].dropna().values.astype(np.float32) for col in numeric_cols]\n```\n\n### Numpy Arrays\n\n```python\n# 2-D array (rows = series, cols = time steps)\ndata = np.load(\"timeseries.npy\")  # shape (N, T)\ninputs = [data[i] for i in range(data.shape[0])]\n\n# Or from 1-D\ninputs = [np.sin(np.linspace(0, 10, 200))]\n```\n\n### Excel\n\n```python\ndf = 
pd.read_excel(\"data.xlsx\", sheet_name=\"Sheet1\")\ninputs = [df[col].dropna().values.astype(np.float32) for col in df.select_dtypes(include=[np.number]).columns]\n```\n\n### Parquet\n\n```python\ndf = pd.read_parquet(\"data.parquet\")\ninputs = [df[col].dropna().values.astype(np.float32) for col in df.select_dtypes(include=[np.number]).columns]\n```\n\n### JSON\n\n```python\nimport json\n\nwith open(\"data.json\") as f:\n    data = json.load(f)\n\n# Assumes {\"series_name\": [values...], ...}\ninputs = [np.array(values, dtype=np.float32) for values in data.values()]\n```\n\n## NaN Handling\n\nTimesFM handles NaN values automatically:\n\n### Leading NaNs\n\nStripped before feeding to the model:\n\n```python\n# Input:  [NaN, NaN, 1.0, 2.0, 3.0]\n# Actual: [1.0, 2.0, 3.0]\n```\n\n### Internal NaNs\n\nLinearly interpolated:\n\n```python\n# Input:  [1.0, NaN, 3.0, NaN, NaN, 6.0]\n# Actual: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]\n```\n\n### Trailing NaNs\n\n**Not handled** — drop them before passing to the model:\n\n```python\nvalues = df[\"value\"].values.astype(np.float32)\n# Remove trailing NaNs\nwhile len(values) > 0 and np.isnan(values[-1]):\n    values = values[:-1]\ninputs = [values]\n```\n\n### Best Practice\n\n```python\ndef clean_series(arr: np.ndarray) -> np.ndarray:\n    \"\"\"Clean a time series for TimesFM input.\"\"\"\n    arr = np.asarray(arr, dtype=np.float32)\n    # Remove trailing NaNs\n    while len(arr) > 0 and np.isnan(arr[-1]):\n        arr = arr[:-1]\n    # Replace inf with NaN (will be interpolated)\n    arr[np.isinf(arr)] = np.nan\n    return arr\n\ninputs = [clean_series(df[col].values) for col in cols]\n```\n\n## Context Length Considerations\n\n| Context Length | Use Case | Notes |\n| -------------- | -------- | ----- |\n| 64–256 | Quick prototyping | Minimal context, fast |\n| 256–512 | Daily data, ~1 year | Good balance |\n| 512–1024 | Daily data, ~2-3 years | Standard production |\n| 1024–4096 | Hourly data, weekly patterns | More context = 
better |\n| 4096–16384 | High-frequency, long patterns | TimesFM 2.5 maximum |\n\n**Rule of thumb**: Provide at least 3–5 full cycles of the dominant pattern\n(e.g., for weekly seasonality with daily data, provide at least 21–35 days).\n\n## Covariates (XReg)\n\nTimesFM 2.5 supports exogenous variables through the `forecast_with_covariates()` API.\n\n### Types of Covariates\n\n| Type | Description | Example |\n| ---- | ----------- | ------- |\n| **Dynamic numerical** | Time-varying numeric features | Temperature, price, promotion spend |\n| **Dynamic categorical** | Time-varying categorical features | Day of week, holiday flag |\n| **Static categorical** | Fixed per-series features | Store ID, region, product category |\n\n### Preparing Covariates\n\nEach covariate must have length `context + horizon` for each series:\n\n```python\nimport numpy as np\n\ncontext_len = 100   # length of historical data\nhorizon = 24        # forecast horizon\ntotal_len = context_len + horizon\n\n# Dynamic numerical: temperature forecast for each series\ntemp = [\n    np.random.randn(total_len).astype(np.float32),  # Series 1\n    np.random.randn(total_len).astype(np.float32),  # Series 2\n]\n\n# Dynamic categorical: day of week (0-6) for each series\ndow = [\n    np.tile(np.arange(7), total_len // 7 + 1)[:total_len],  # Series 1\n    np.tile(np.arange(7), total_len // 7 + 1)[:total_len],  # Series 2\n]\n\n# Static categorical: one label per series\nregions = [\"east\", \"west\"]\n\n# Forecast with covariates\npoint, quantiles = model.forecast_with_covariates(\n    inputs=[values1, values2],\n    dynamic_numerical_covariates={\"temperature\": temp},\n    dynamic_categorical_covariates={\"day_of_week\": dow},\n    static_categorical_covariates={\"region\": regions},\n    xreg_mode=\"xreg + timesfm\",\n)\n```\n\n### XReg Modes\n\n| Mode | Description |\n| ---- | ----------- |\n| `\"xreg + timesfm\"` | Covariates processed first, then combined with TimesFM forecast |\n| `\"timesfm + 
xreg\"` | TimesFM forecast first, then adjusted by covariates |\n\n## Common Data Issues\n\n### Issue: Series too short\n\nTimesFM needs at least 1 data point to run, but longer context generally gives\nbetter forecasts. Drop series with too few non-NaN observations:\n\n```python\nMIN_LENGTH = 32  # Practical minimum for meaningful forecasts\n\ninputs = [\n    arr for arr in raw_inputs\n    if len(arr[~np.isnan(arr)]) >= MIN_LENGTH\n]\n```\n\n### Issue: Series with constant values\n\nConstant series may produce NaN or zero-width prediction intervals:\n\n```python\nfor i, arr in enumerate(inputs):\n    if np.std(arr[~np.isnan(arr)]) < 1e-10:\n        print(f\"⚠️ Series {i} is constant — forecast will be flat\")\n```\n\n### Issue: Extreme outliers\n\nLarge outliers can destabilize forecasts even with normalization:\n\n```python\ndef clip_outliers(arr: np.ndarray, n_sigma: float = 5.0) -> np.ndarray:\n    \"\"\"Clip values beyond n_sigma standard deviations.\"\"\"\n    mu = np.nanmean(arr)\n    sigma = np.nanstd(arr)\n    if sigma > 0:\n        arr = np.clip(arr, mu - n_sigma * sigma, mu + n_sigma * sigma)\n    return arr\n```\n\n### Issue: Mixed frequencies in batch\n\nTimesFM handles each series independently, so you can mix frequencies:\n\n```python\ninputs = [\n    daily_sales,      # 365 points\n    weekly_revenue,   # 52 points\n    monthly_users,    # 24 points\n]\n# All forecasted in one batch — TimesFM handles different lengths\npoint, q = model.forecast(horizon=12, inputs=inputs)\n```\n\nHowever, the `horizon` is shared across the batch and is counted in steps of each\nseries' own frequency, so `horizon=12` above means 12 days, 12 weeks, and 12 months\nrespectively. If you need different horizons per series, forecast in separate calls.\n"
  },
  {
    "path": "timesfm-forecasting/references/system_requirements.md",
    "content": "# System Requirements for TimesFM\n\n## Hardware Tiers\n\nTimesFM can run on a variety of hardware configurations. This guide helps you\nchoose the right setup and tune performance for your machine.\n\n### How Context Limits Are Determined\n\nThe `max_context` values in each tier are **conservative recommendations** based on memory-performance tradeoffs, not hard limits. TimesFM 2.5 supports up to 16,384 context points, but smaller values are recommended for most use cases.\n\n**Why 512 and 1024?**\n\n| Factor | 512 Context | 1024 Context |\n|--------|-------------|--------------|\n| **Memory per 1000 series** | ~100 MB | ~200 MB |\n| **Typical Use Case** | Daily data, ~1-2 years | Daily data, ~2-3 years |\n| **Inference Speed** | Faster | Moderate |\n| **Hardware** | 4-8 GB RAM | 16 GB RAM or GPU |\n\n**Memory Formula**: `RAM ≈ model_weights + 0.5 GB + (0.2 MB × num_series × context_length / 1000)`\n\nWhere:\n- `model_weights` = ~800 MB (TimesFM 2.5)\n- `context_length` = your `max_context` value\n- `num_series` = number of time series in your batch\n\n**You can use larger contexts** if your hardware supports it:\n- **Up to 2048**: Requires ~16 GB RAM for moderate batch sizes\n- **Up to 4096**: Requires GPU or 32+ GB RAM\n- **Up to 16384**: Maximum supported, requires significant memory\n\nSee [Data Preparation Guide](data_preparation.md) for context length recommendations by data frequency.\n\n### Tier 1: Minimal (CPU-Only, 4–8 GB RAM)\n\n- **Use case**: Light exploration, single-series forecasting, prototyping\n- **Model**: TimesFM 2.5 (200M) only\n- **Batch size**: `per_core_batch_size=4`\n- **Context**: Limit `max_context=512`\n- **Expected speed**: ~2–5 seconds per 100-point series\n\n```python\nmodel.compile(timesfm.ForecastConfig(\n    max_context=512,\n    max_horizon=128,\n    per_core_batch_size=4,\n    normalize_inputs=True,\n    use_continuous_quantile_head=True,\n    fix_quantile_crossing=True,\n))\n```\n\n### Tier 2: Standard (CPU 16 
GB or GPU 4–8 GB VRAM)\n\n- **Use case**: Batch forecasting (dozens of series), evaluation, production prototypes\n- **Model**: TimesFM 2.5 (200M)\n- **Batch size**: `per_core_batch_size=32` (CPU) or `64` (GPU)\n- **Context**: `max_context=1024`\n- **Expected speed**: ~0.5–1 second per 100-point series (GPU)\n\n```python\nmodel.compile(timesfm.ForecastConfig(\n    max_context=1024,\n    max_horizon=256,\n    per_core_batch_size=64,\n    normalize_inputs=True,\n    use_continuous_quantile_head=True,\n    fix_quantile_crossing=True,\n))\n```\n\n### Tier 3: Production (GPU 16+ GB VRAM or Apple Silicon 32+ GB)\n\n- **Use case**: Large-scale batch forecasting (thousands of series), long context\n- **Model**: TimesFM 2.5 (200M)\n- **Batch size**: `per_core_batch_size=128–256`\n- **Context**: `max_context=4096` or higher\n- **Expected speed**: ~0.1–0.3 seconds per 100-point series\n\n```python\nmodel.compile(timesfm.ForecastConfig(\n    max_context=4096,\n    max_horizon=256,\n    per_core_batch_size=128,\n    normalize_inputs=True,\n    use_continuous_quantile_head=True,\n    fix_quantile_crossing=True,\n))\n```\n\n### Tier 4: Legacy Models (v1.0/v2.0 — 500M parameters)\n\n- **⚠️ WARNING**: TimesFM v2.0 (500M) requires **≥ 16 GB RAM** (CPU) or **≥ 8 GB VRAM** (GPU)\n- **⚠️ WARNING**: TimesFM v1.0 legacy JAX version may require **≥ 32 GB RAM**\n- **Recommendation**: Unless you specifically need a legacy checkpoint, use TimesFM 2.5\n\n## Memory Estimation\n\n### CPU Memory (RAM)\n\nApproximate RAM usage during inference:\n\n| Component | TimesFM 2.5 (200M) | TimesFM 2.0 (500M) |\n| --------- | ------------------- | ------------------- |\n| Model weights | ~800 MB | ~2 GB |\n| Runtime overhead | ~500 MB | ~1 GB |\n| Input/output buffers | ~200 MB per 1000 series | ~500 MB per 1000 series |\n| **Total (small batch)** | **~1.5 GB** | **~3.5 GB** |\n| **Total (large batch)** | **~3 GB** | **~6 GB** |\n\n**Formula**: `RAM ≈ model_weights + 0.5 GB + (0.2 MB × num_series × 
context_length / 1000)`\n\n### GPU Memory (VRAM)\n\n| Component | TimesFM 2.5 (200M) |\n| --------- | ------------------- |\n| Model weights | ~800 MB |\n| KV cache + activations | ~200–500 MB (scales with context) |\n| Batch buffers | ~100 MB per 100 series at context=1024 |\n| **Total (batch=32)** | **~1.2 GB** |\n| **Total (batch=128)** | **~1.8 GB** |\n| **Total (batch=256)** | **~2.5 GB** |\n\n### Disk Space\n\n| Item | Size |\n| ---- | ---- |\n| TimesFM 2.5 safetensors | ~800 MB |\n| Hugging Face cache overhead | ~200 MB |\n| **Total download** | **~1 GB** |\n\nModel weights are downloaded once from Hugging Face Hub and cached in\n`~/.cache/huggingface/` (or `$HF_HOME`).\n\n## GPU Selection Guide\n\n### NVIDIA GPUs (CUDA)\n\n| GPU | VRAM | Recommended batch | Notes |\n| --- | ---- | ----------------- | ----- |\n| RTX 3060 | 12 GB | 64 | Good entry-level |\n| RTX 3090 / 4090 | 24 GB | 256 | Excellent for production |\n| A100 (40 GB) | 40 GB | 512 | Cloud/HPC |\n| A100 (80 GB) | 80 GB | 1024 | Cloud/HPC |\n| T4 | 16 GB | 128 | Cloud (Colab, AWS) |\n| V100 | 16–32 GB | 128–256 | Cloud |\n\n### Apple Silicon (MPS)\n\n| Chip | Unified Memory | Recommended batch | Notes |\n| ---- | -------------- | ----------------- | ----- |\n| M1 | 8–16 GB | 16–32 | Works, slower than CUDA |\n| M1 Pro/Max | 16–64 GB | 32–128 | Good performance |\n| M2/M3/M4 Pro/Max | 18–128 GB | 64–256 | Excellent |\n\n### CPU Only\n\nWorks on any CPU with sufficient RAM. 
Expect inference to be 5–20× slower than on a GPU.\n\n## Python and Package Requirements\n\n| Requirement | Minimum | Recommended |\n| ----------- | ------- | ----------- |\n| Python | 3.10 | 3.12+ |\n| numpy | 1.26.4 | latest |\n| torch | 2.0.0 | latest |\n| huggingface_hub | 0.23.0 | latest |\n| safetensors | 0.5.3 | latest |\n\n### Optional Dependencies\n\n| Package | Purpose | Install |\n| ------- | ------- | ------- |\n| jax | Flax backend | `pip install \"jax[cuda12]\"` |\n| flax | Flax backend | `pip install flax` |\n| scikit-learn | XReg covariates | `pip install scikit-learn` |\n\n## Operating System Compatibility\n\n| OS | Status | Notes |\n| -- | ------ | ----- |\n| Linux (Ubuntu 20.04+) | ✅ Fully supported | Best performance with CUDA |\n| macOS 13+ (Ventura) | ✅ Fully supported | MPS acceleration on Apple Silicon |\n| Windows 11 + WSL2 | ✅ Supported | Recommended way to run on Windows |\n| Windows (native) | ⚠️ Partial | PyTorch works, some edge cases |\n\n## Troubleshooting\n\n### Out of Memory (OOM)\n\n```python\n# Reduce batch size\nmodel.compile(timesfm.ForecastConfig(\n    per_core_batch_size=4,  # Start very small\n    max_context=512,        # Reduce context\n    # ... keep your other ForecastConfig settings\n))\n\n# Process in chunks (H is your forecast horizon)\nfor i in range(0, len(inputs), 50):\n    chunk = inputs[i:i+50]\n    p, q = model.forecast(horizon=H, inputs=chunk)  # store p and q; they are overwritten each iteration\n```\n\n### Slow Inference on CPU\n\n```python\n# Ensure matmul precision is set\nimport torch\ntorch.set_float32_matmul_precision(\"high\")\n\n# Use smaller context\nmodel.compile(timesfm.ForecastConfig(\n    max_context=256,  # Shorter context = faster\n    # ... keep your other ForecastConfig settings\n))\n```\n\n### Model Download Fails\n\n```bash\n# Set a different cache directory\nexport HF_HOME=/path/with/more/space\n\n# Or download manually\nhuggingface-cli download google/timesfm-2.5-200m-pytorch\n```\n"
  },
  {
    "path": "timesfm-forecasting/scripts/check_system.py",
    "content": "#!/usr/bin/env python3\n\"\"\"TimesFM System Requirements Preflight Checker.\n\nMANDATORY: Run this script before loading TimesFM for the first time.\nIt checks RAM, GPU/VRAM, disk space, Python version, and package\ninstallation so the agent never crashes a user's machine.\n\nUsage:\n    python check_system.py\n    python check_system.py --model v2.5   # default\n    python check_system.py --model v2.0   # archived 500M model\n    python check_system.py --model v1.0   # archived 200M model\n    python check_system.py --json         # machine-readable output\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport json\nimport os\nimport platform\nimport shutil\nimport struct\nimport sys\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Any\nimport math\n\n\n# ---------------------------------------------------------------------------\n# Model requirement profiles\n# ---------------------------------------------------------------------------\n\nMODEL_PROFILES: dict[str, dict[str, Any]] = {\n    \"v2.5\": {\n        \"name\": \"TimesFM 2.5 (200M)\",\n        \"params\": \"200M\",\n        \"min_ram_gb\": 2.0,\n        \"recommended_ram_gb\": 4.0,\n        \"min_vram_gb\": 2.0,\n        \"recommended_vram_gb\": 4.0,\n        \"disk_gb\": 2.0,  # model weights + overhead\n        \"hf_repo\": \"google/timesfm-2.5-200m-pytorch\",\n    },\n    \"v2.0\": {\n        \"name\": \"TimesFM 2.0 (500M)\",\n        \"params\": \"500M\",\n        \"min_ram_gb\": 8.0,\n        \"recommended_ram_gb\": 16.0,\n        \"min_vram_gb\": 4.0,\n        \"recommended_vram_gb\": 8.0,\n        \"disk_gb\": 4.0,\n        \"hf_repo\": \"google/timesfm-2.0-500m-pytorch\",\n    },\n    \"v1.0\": {\n        \"name\": \"TimesFM 1.0 (200M)\",\n        \"params\": \"200M\",\n        \"min_ram_gb\": 4.0,\n        \"recommended_ram_gb\": 8.0,\n        \"min_vram_gb\": 2.0,\n        \"recommended_vram_gb\": 4.0,\n        
\"disk_gb\": 2.0,\n        \"hf_repo\": \"google/timesfm-1.0-200m-pytorch\",\n    },\n}\n\n\n# ---------------------------------------------------------------------------\n# Result dataclass\n# ---------------------------------------------------------------------------\n\n\n@dataclass\nclass CheckResult:\n    name: str\n    status: str  # \"pass\", \"warn\", \"fail\"\n    detail: str\n    value: str = \"\"\n\n    @property\n    def icon(self) -> str:\n        return {\"pass\": \"✅\", \"warn\": \"⚠️\", \"fail\": \"🛑\"}.get(self.status, \"❓\")\n\n    def __str__(self) -> str:\n        return f\"[{self.name:<10}] {self.value:<40} {self.icon} {self.status.upper()}\"\n\n\n@dataclass\nclass SystemReport:\n    model: str\n    checks: list[CheckResult] = field(default_factory=list)\n    verdict: str = \"\"\n    verdict_detail: str = \"\"\n    recommended_batch_size: int = 1\n    mode: str = \"cpu\"  # \"cpu\", \"gpu\", \"mps\"\n\n    @property\n    def passed(self) -> bool:\n        return all(c.status != \"fail\" for c in self.checks)\n\n    def to_dict(self) -> dict[str, Any]:\n        return {\n            \"model\": self.model,\n            \"passed\": self.passed,\n            \"mode\": self.mode,\n            \"recommended_batch_size\": self.recommended_batch_size,\n            \"verdict\": self.verdict,\n            \"verdict_detail\": self.verdict_detail,\n            \"checks\": [\n                {\n                    \"name\": c.name,\n                    \"status\": c.status,\n                    \"detail\": c.detail,\n                    \"value\": c.value,\n                }\n                for c in self.checks\n            ],\n        }\n\n\n# ---------------------------------------------------------------------------\n# Individual checks\n# ---------------------------------------------------------------------------\n\n\ndef _get_total_ram_gb() -> float:\n    \"\"\"Return total physical RAM in GB, cross-platform.\"\"\"\n    try:\n        if sys.platform == 
\"linux\":\n            with open(\"/proc/meminfo\") as f:\n                for line in f:\n                    if line.startswith(\"MemTotal\"):\n                        return int(line.split()[1]) / (1024 * 1024)\n        elif sys.platform == \"darwin\":\n            import subprocess\n\n            result = subprocess.run(\n                [\"sysctl\", \"-n\", \"hw.memsize\"],\n                capture_output=True,\n                text=True,\n                check=True,\n            )\n            return int(result.stdout.strip()) / (1024**3)\n        elif sys.platform == \"win32\":\n            import ctypes\n\n            kernel32 = ctypes.windll.kernel32  # type: ignore[attr-defined]\n\n            class MEMORYSTATUSEX(ctypes.Structure):\n                _fields_ = [\n                    (\"dwLength\", ctypes.c_ulong),\n                    (\"dwMemoryLoad\", ctypes.c_ulong),\n                    (\"ullTotalPhys\", ctypes.c_ulonglong),\n                    (\"ullAvailPhys\", ctypes.c_ulonglong),\n                    (\"ullTotalPageFile\", ctypes.c_ulonglong),\n                    (\"ullAvailPageFile\", ctypes.c_ulonglong),\n                    (\"ullTotalVirtual\", ctypes.c_ulonglong),\n                    (\"ullAvailVirtual\", ctypes.c_ulonglong),\n                    (\"sullAvailExtendedVirtual\", ctypes.c_ulonglong),\n                ]\n\n            stat = MEMORYSTATUSEX()\n            stat.dwLength = ctypes.sizeof(stat)\n            kernel32.GlobalMemoryStatusEx(ctypes.byref(stat))\n            return stat.ullTotalPhys / (1024**3)\n    except Exception:\n        pass\n\n    # Fallback: use struct to estimate (unreliable)\n    return struct.calcsize(\"P\") * 8 / 8  # placeholder\n\n\ndef _get_available_ram_gb() -> float:\n    \"\"\"Return available RAM in GB.\"\"\"\n    try:\n        if sys.platform == \"linux\":\n            with open(\"/proc/meminfo\") as f:\n                for line in f:\n                    if line.startswith(\"MemAvailable\"):\n      
                  return int(line.split()[1]) / (1024 * 1024)\n        elif sys.platform == \"darwin\":\n            import subprocess\n\n            # Use vm_stat for available memory on macOS\n            result = subprocess.run(\n                [\"vm_stat\"], capture_output=True, text=True, check=True\n            )\n            free = 0\n            page_size = 4096\n            for line in result.stdout.split(\"\\n\"):\n                if \"Pages free\" in line or \"Pages inactive\" in line:\n                    val = line.split(\":\")[1].strip().rstrip(\".\")\n                    free += int(val) * page_size\n            return free / (1024**3)\n        elif sys.platform == \"win32\":\n            import ctypes\n\n            kernel32 = ctypes.windll.kernel32  # type: ignore[attr-defined]\n\n            class MEMORYSTATUSEX(ctypes.Structure):\n                _fields_ = [\n                    (\"dwLength\", ctypes.c_ulong),\n                    (\"dwMemoryLoad\", ctypes.c_ulong),\n                    (\"ullTotalPhys\", ctypes.c_ulonglong),\n                    (\"ullAvailPhys\", ctypes.c_ulonglong),\n                    (\"ullTotalPageFile\", ctypes.c_ulonglong),\n                    (\"ullAvailPageFile\", ctypes.c_ulonglong),\n                    (\"ullTotalVirtual\", ctypes.c_ulonglong),\n                    (\"ullAvailVirtual\", ctypes.c_ulonglong),\n                    (\"sullAvailExtendedVirtual\", ctypes.c_ulonglong),\n                ]\n\n            stat = MEMORYSTATUSEX()\n            stat.dwLength = ctypes.sizeof(stat)\n            kernel32.GlobalMemoryStatusEx(ctypes.byref(stat))\n            return stat.ullAvailPhys / (1024**3)\n    except Exception:\n        pass\n    return 0.0\n\n\ndef check_ram(profile: dict[str, Any]) -> CheckResult:\n    \"\"\"Check if system has enough RAM.\"\"\"\n    total = _get_total_ram_gb()\n    available = _get_available_ram_gb()\n    min_ram = profile[\"min_ram_gb\"]\n    rec_ram = 
profile[\"recommended_ram_gb\"]\n\n    value = f\"Total: {total:.1f} GB | Available: {available:.1f} GB\"\n\n    if total < min_ram:\n        return CheckResult(\n            name=\"RAM\",\n            status=\"fail\",\n            detail=(\n                f\"System has {total:.1f} GB RAM but {profile['name']} requires \"\n                f\"at least {min_ram:.0f} GB. The model will likely fail to load \"\n                f\"or cause the system to swap heavily and become unresponsive.\"\n            ),\n            value=value,\n        )\n    elif total < rec_ram:\n        return CheckResult(\n            name=\"RAM\",\n            status=\"warn\",\n            detail=(\n                f\"System has {total:.1f} GB RAM. {profile['name']} recommends \"\n                f\"{rec_ram:.0f} GB. It may work with small batch sizes but could \"\n                f\"be tight. Use per_core_batch_size=4 or lower.\"\n            ),\n            value=value,\n        )\n    else:\n        return CheckResult(\n            name=\"RAM\",\n            status=\"pass\",\n            detail=f\"System has {total:.1f} GB RAM, meets {rec_ram:.0f} GB recommendation.\",\n            value=value,\n        )\n\n\ndef check_gpu() -> CheckResult:\n    \"\"\"Check GPU availability and VRAM.\"\"\"\n    # Try CUDA first\n    try:\n        import torch\n\n        if torch.cuda.is_available():\n            name = torch.cuda.get_device_name(0)\n            vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)\n            return CheckResult(\n                name=\"GPU\",\n                status=\"pass\",\n                detail=f\"{name} with {vram:.1f} GB VRAM detected.\",\n                value=f\"{name} | VRAM: {vram:.1f} GB\",\n            )\n        elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n            return CheckResult(\n                name=\"GPU\",\n                status=\"pass\",\n                detail=\"Apple Silicon MPS backend 
available. Uses unified memory.\",\n                value=\"Apple Silicon MPS\",\n            )\n        else:\n            return CheckResult(\n                name=\"GPU\",\n                status=\"warn\",\n                detail=(\n                    \"No GPU detected. TimesFM will run on CPU (slower but functional). \"\n                    \"Install CUDA-enabled PyTorch for GPU acceleration.\"\n                ),\n                value=\"None (CPU only)\",\n            )\n    except ImportError:\n        return CheckResult(\n            name=\"GPU\",\n            status=\"warn\",\n            detail=\"PyTorch not installed — cannot check GPU. Install torch first.\",\n            value=\"Unknown (torch not installed)\",\n        )\n\n\ndef check_disk(profile: dict[str, Any]) -> CheckResult:\n    \"\"\"Check available disk space for model download.\"\"\"\n    # Check HuggingFace cache dir or home dir\n    hf_cache = os.environ.get(\"HF_HOME\", os.path.expanduser(\"~/.cache/huggingface\"))\n    cache_dir = Path(hf_cache)\n    check_dir = cache_dir if cache_dir.exists() else Path.home()\n\n    usage = shutil.disk_usage(str(check_dir))\n    free_gb = usage.free / (1024**3)\n    required = profile[\"disk_gb\"]\n\n    value = f\"Free: {free_gb:.1f} GB (in {check_dir})\"\n\n    if free_gb < required:\n        return CheckResult(\n            name=\"Disk\",\n            status=\"fail\",\n            detail=(\n                f\"Only {free_gb:.1f} GB free in {check_dir}. \"\n                f\"Need at least {required:.0f} GB for model weights. 
\"\n                f\"Free up space or set HF_HOME to a larger volume.\"\n            ),\n            value=value,\n        )\n    else:\n        return CheckResult(\n            name=\"Disk\",\n            status=\"pass\",\n            detail=f\"{free_gb:.1f} GB available, exceeds {required:.0f} GB requirement.\",\n            value=value,\n        )\n\n\ndef check_python() -> CheckResult:\n    \"\"\"Check Python version >= 3.10.\"\"\"\n    version = sys.version.split()[0]\n    major, minor = sys.version_info[:2]\n\n    if (major, minor) < (3, 10):\n        return CheckResult(\n            name=\"Python\",\n            status=\"fail\",\n            detail=f\"Python {version} detected. TimesFM requires Python >= 3.10.\",\n            value=version,\n        )\n    else:\n        return CheckResult(\n            name=\"Python\",\n            status=\"pass\",\n            detail=f\"Python {version} meets >= 3.10 requirement.\",\n            value=version,\n        )\n\n\ndef check_package(pkg_name: str, import_name: str | None = None) -> CheckResult:\n    \"\"\"Check if a Python package is installed.\"\"\"\n    import_name = import_name or pkg_name\n    try:\n        mod = __import__(import_name)\n        version = getattr(mod, \"__version__\", \"unknown\")\n        return CheckResult(\n            name=pkg_name,\n            status=\"pass\",\n            detail=f\"{pkg_name} {version} is installed.\",\n            value=f\"Installed ({version})\",\n        )\n    except ImportError:\n        return CheckResult(\n            name=pkg_name,\n            status=\"warn\",\n            detail=f\"{pkg_name} is not installed. 
Run: uv pip install {pkg_name}\",\n            value=\"Not installed\",\n        )\n\n\n# ---------------------------------------------------------------------------\n# Batch size recommendation\n# ---------------------------------------------------------------------------\n\n\ndef recommend_batch_size(report: SystemReport) -> int:\n    \"\"\"Recommend per_core_batch_size based on available resources.\"\"\"\n    total_ram = _get_total_ram_gb()\n\n    # Check if GPU is available\n    gpu_check = next((c for c in report.checks if c.name == \"GPU\"), None)\n\n    if gpu_check and gpu_check.status == \"pass\" and \"VRAM\" in gpu_check.value:\n        # Extract VRAM\n        try:\n            vram_str = gpu_check.value.split(\"VRAM:\")[1].strip().split()[0]\n            vram = float(vram_str)\n            if vram >= 24:\n                return 256\n            elif vram >= 16:\n                return 128\n            elif vram >= 8:\n                return 64\n            elif vram >= 4:\n                return 32\n            else:\n                return 16\n        except (ValueError, IndexError):\n            return 32\n    elif gpu_check and \"MPS\" in gpu_check.value:\n        # Apple Silicon — use unified memory heuristic\n        if total_ram >= 32:\n            return 64\n        elif total_ram >= 16:\n            return 32\n        else:\n            return 16\n    else:\n        # CPU only\n        if total_ram >= 32:\n            return 64\n        elif total_ram >= 16:\n            return 32\n        elif total_ram >= 8:\n            return 8\n        else:\n            return 4\n\n\ndef estimate_memory_gb(\n    num_series: int,\n    context_length: int,\n    horizon: int = 0,\n    batch_size: int = 32,\n    model_version: str = \"v2.5\",\n) -> dict[str, float]:\n    \"\"\"Estimate memory requirements for a dataset.\n\n    Args:\n        num_series: Number of time series in the dataset\n        context_length: Length of each time series context window\n     
   horizon: Forecast horizon (optional, for output storage)\n        batch_size: Batch size for inference\n        model_version: Model version being used\n\n    Returns:\n        Dictionary with memory estimates in GB for different components\n    \"\"\"\n    # Base model memory (weights + overhead)\n    model_memory_gb = 0.8  # ~800MB for model weights\n    overhead_gb = 0.5  # Python overhead, libraries, etc.\n\n    # Input data memory: each value is float32 (4 bytes)\n    # Formula: num_series * context_length * 4 bytes / (1024^3)\n    input_gb = (num_series * context_length * 4) / (1024**3)\n\n    # Batch processing memory (peak during inference)\n    # Each batch needs: batch_size * context_length * 4 bytes\n    batch_input_gb = (batch_size * context_length * 4) / (1024**3)\n\n    # Output memory: horizon * num_series * quantiles * 4 bytes\n    # Default is 10 quantiles (mean + 9 quantiles)\n    num_quantiles = 10\n    output_gb = (num_series * horizon * num_quantiles * 4) / (1024**3) if horizon > 0 else 0\n\n    # Total memory with some headroom for intermediate computations\n    total_gb = model_memory_gb + overhead_gb + input_gb + batch_input_gb + output_gb\n\n    # Add 20% buffer for intermediate tensors and OS overhead\n    total_with_buffer = total_gb * 1.2\n\n    return {\n        \"model_weights\": model_memory_gb,\n        \"overhead\": overhead_gb,\n        \"input_data\": input_gb,\n        \"batch_processing\": batch_input_gb,\n        \"output_data\": output_gb,\n        \"total\": total_gb,\n        \"total_with_buffer\": total_with_buffer,\n    }\n\n\ndef check_dataset_fit(\n    num_series: int,\n    context_length: int,\n    horizon: int = 0,\n    batch_size: int = 32,\n    model_version: str = \"v2.5\",\n) -> tuple[bool, str, dict[str, float]]:\n    \"\"\"Check if a dataset will fit in available memory.\n\n    Args:\n        num_series: Number of time series in the dataset\n        context_length: Length of each time series context window\n   
     horizon: Forecast horizon (optional)\n        batch_size: Batch size for inference\n        model_version: Model version being used\n\n    Returns:\n        Tuple of (fits: bool, message: str, memory_details: dict)\n    \"\"\"\n    memory = estimate_memory_gb(num_series, context_length, horizon, batch_size, model_version)\n    total_ram = _get_total_ram_gb()\n    available_ram = _get_available_ram_gb()\n\n    required = memory[\"total_with_buffer\"]\n\n    # Leave 10% headroom for OS and other processes\n    usable_ram = total_ram * 0.9\n    usable_available = available_ram * 0.9 if available_ram > 0 else usable_ram\n\n    if required > total_ram:\n        return (\n            False,\n            f\"Dataset requires {required:.1f} GB but system only has {total_ram:.1f} GB RAM. \"\n            f\"Consider processing in chunks or using a machine with more RAM.\",\n            memory,\n        )\n    elif required > usable_available:\n        return (\n            False,\n            f\"Dataset requires {required:.1f} GB but only {available_ram:.1f} GB is available. \"\n            f\"Close other applications or restart to free memory.\",\n            memory,\n        )\n    elif required > usable_ram * 0.8:\n        return (\n            True,\n            f\"Dataset will fit ({required:.1f} GB needed, {total_ram:.1f} GB total) \"\n            f\"but memory usage will be high. 
Consider reducing batch_size.\",\n            memory,\n        )\n    else:\n        return (\n            True,\n            f\"Dataset fits comfortably: {required:.1f} GB needed, {total_ram:.1f} GB available.\",\n            memory,\n        )\n\n\ndef print_memory_estimate(\n    num_series: int,\n    context_length: int,\n    horizon: int = 0,\n    batch_size: int = 32,\n    model_version: str = \"v2.5\",\n) -> None:\n    \"\"\"Print a detailed memory estimate for a dataset.\n\n    Args:\n        num_series: Number of time series in the dataset\n        context_length: Length of each time series context window\n        horizon: Forecast horizon (optional)\n        batch_size: Batch size for inference\n        model_version: Model version being used\n    \"\"\"\n    memory = estimate_memory_gb(num_series, context_length, horizon, batch_size, model_version)\n    total_ram = _get_total_ram_gb()\n    available_ram = _get_available_ram_gb()\n\n    print(f\"\\n{'=' * 50}\")\n    print(f\" Memory Estimate for Dataset\")\n    print(f\"{'=' * 50}\")\n    print(f\"  Dataset: {num_series:,} series × {context_length} context length\")\n    if horizon > 0:\n        print(f\"  Horizon: {horizon} steps\")\n    print(f\"  Batch size: {batch_size}\")\n    print(f\"  Model: {model_version}\")\n    print(f\"{'-' * 50}\")\n    print(f\"  Model weights:     {memory['model_weights']:.2f} GB\")\n    print(f\"  Overhead:          {memory['overhead']:.2f} GB\")\n    print(f\"  Input data:        {memory['input_data']:.2f} GB\")\n    print(f\"  Batch processing:  {memory['batch_processing']:.2f} GB\")\n    if horizon > 0:\n        print(f\"  Output data:       {memory['output_data']:.2f} GB\")\n    print(f\"{'-' * 50}\")\n    print(f\"  Total (raw):       {memory['total']:.2f} GB\")\n    print(f\"  Total (+20% buf):  {memory['total_with_buffer']:.2f} GB\")\n    print(f\"{'-' * 50}\")\n    print(f\"  System RAM:        {total_ram:.1f} GB\")\n    print(f\"  Available RAM:     
{available_ram:.1f} GB\")\n    print(f\"{'=' * 50}\")\n\n    fits, message, _ = check_dataset_fit(\n        num_series, context_length, horizon, batch_size, model_version\n    )\n    status_icon = \"✅\" if fits else \"🛑\"\n    print(f\"  {status_icon} {message}\")\n    print(f\"{'=' * 50}\\n\")\n\n\n# ---------------------------------------------------------------------------\n# Main\n# ---------------------------------------------------------------------------\n\n\ndef run_checks(model_version: str = \"v2.5\") -> SystemReport:\n    \"\"\"Run all system checks and return a report.\"\"\"\n    profile = MODEL_PROFILES[model_version]\n    report = SystemReport(model=profile[\"name\"])\n\n    # Run checks\n    report.checks.append(check_ram(profile))\n    report.checks.append(check_gpu())\n    report.checks.append(check_disk(profile))\n    report.checks.append(check_python())\n    report.checks.append(check_package(\"timesfm\"))\n    report.checks.append(check_package(\"torch\"))\n\n    # Determine mode\n    gpu_check = next((c for c in report.checks if c.name == \"GPU\"), None)\n    if gpu_check and gpu_check.status == \"pass\":\n        if \"MPS\" in gpu_check.value:\n            report.mode = \"mps\"\n        else:\n            report.mode = \"gpu\"\n    else:\n        report.mode = \"cpu\"\n\n    # Batch size\n    report.recommended_batch_size = recommend_batch_size(report)\n\n    # Verdict\n    if report.passed:\n        report.verdict = (\n            f\"✅ System is ready for {profile['name']} ({report.mode.upper()} mode)\"\n        )\n        report.verdict_detail = (\n            f\"Recommended: per_core_batch_size={report.recommended_batch_size}\"\n        )\n    else:\n        failed = [c for c in report.checks if c.status == \"fail\"]\n        report.verdict = f\"🛑 System does NOT meet requirements for {profile['name']}\"\n        report.verdict_detail = \"; \".join(c.detail for c in failed)\n\n    return report\n\n\ndef print_report(report: SystemReport) -> 
None:\n    \"\"\"Print a human-readable report to stdout.\"\"\"\n    print(f\"\\n{'=' * 50}\")\n    print(f\"  TimesFM System Requirements Check\")\n    print(f\"  Model: {report.model}\")\n    print(f\"{'=' * 50}\\n\")\n\n    for check in report.checks:\n        print(f\"  {check}\")\n    print()\n\n    print(f\"  VERDICT: {report.verdict}\")\n    if report.verdict_detail:\n        print(f\"  {report.verdict_detail}\")\n    print()\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(\n        description=\"Check system requirements for TimesFM.\",\n    )\n    parser.add_argument(\n        \"--model\",\n        choices=list(MODEL_PROFILES.keys()),\n        default=\"v2.5\",\n        help=\"Model version to check requirements for (default: v2.5)\",\n    )\n    parser.add_argument(\n        \"--json\",\n        action=\"store_true\",\n        help=\"Output results as JSON (machine-readable)\",\n    )\n    # Dataset preflight options (NEW)\n    dataset_group = parser.add_argument_group(\"dataset preflight (optional)\")\n    dataset_group.add_argument(\n        \"--num-series\",\n        type=int,\n        metavar=\"N\",\n        help=\"Number of time series in your dataset (for memory estimation)\",\n    )\n    dataset_group.add_argument(\n        \"--context-length\",\n        type=int,\n        metavar=\"LEN\",\n        help=\"Length of each input time series (max_context value)\",\n    )\n    dataset_group.add_argument(\n        \"--horizon\",\n        type=int,\n        metavar=\"H\",\n        default=24,\n        help=\"Forecast horizon length (default: 24)\",\n    )\n    dataset_group.add_argument(\n        \"--batch-size\",\n        type=int,\n        metavar=\"SIZE\",\n        default=32,\n        help=\"per_core_batch_size from ForecastConfig (default: 32)\",\n    )\n    dataset_group.add_argument(\n        \"--estimate-only\",\n        action=\"store_true\",\n        help=\"Only show memory estimate, skip system checks\",\n    )\n    args = 
parser.parse_args()\n\n    # Handle dataset estimation only mode\n    if args.estimate_only and args.num_series and args.context_length:\n        print_memory_estimate(\n            args.num_series,\n            args.context_length,\n            args.horizon,\n            args.batch_size,\n            args.model,\n        )\n        sys.exit(0)\n\n    # Run system checks\n    report = run_checks(args.model)\n\n    # Add dataset check if parameters provided\n    if args.num_series and args.context_length:\n        print_memory_estimate(\n            args.num_series,\n            args.context_length,\n            args.horizon,\n            args.batch_size,\n            args.model,\n        )\n\n    if args.json:\n        print(json.dumps(report.to_dict(), indent=2))\n    else:\n        print_report(report)\n\n    # Exit with non-zero if any check failed\n    sys.exit(0 if report.passed else 1)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "timesfm-forecasting/scripts/forecast_csv.py",
    "content": "#!/usr/bin/env python3\n\"\"\"End-to-end CSV forecasting with TimesFM.\n\nLoads a CSV, runs the system preflight check, loads TimesFM, forecasts\nthe requested columns, and writes results to a new CSV or JSON.\n\nUsage:\n    python forecast_csv.py input.csv --horizon 24\n    python forecast_csv.py input.csv --horizon 12 --date-col date --value-cols sales,revenue\n    python forecast_csv.py input.csv --horizon 52 --output forecasts.csv\n    python forecast_csv.py input.csv --horizon 30 --output forecasts.json --format json\n\nThe script automatically:\n  1. Runs the system preflight check (exits if it fails).\n  2. Loads TimesFM 2.5 from Hugging Face.\n  3. Reads the CSV and identifies time series columns.\n  4. Forecasts each series with prediction intervals.\n  5. Writes results to the specified output file.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport json\nimport sys\nfrom pathlib import Path\n\nimport numpy as np\nimport pandas as pd\n\n\ndef run_preflight() -> dict:\n    \"\"\"Run the system preflight check and return the report.\"\"\"\n    # Import the check_system module from the same directory\n    script_dir = Path(__file__).parent\n    sys.path.insert(0, str(script_dir))\n    from check_system import run_checks\n\n    report = run_checks(\"v2.5\")\n    if not report.passed:\n        print(\"\\n🛑 System check FAILED. 
Cannot proceed with forecasting.\")\n        print(f\"   {report.verdict_detail}\")\n        print(\"\\nRun 'python scripts/check_system.py' for details.\")\n        sys.exit(1)\n\n    return report.to_dict()\n\n\ndef load_model(batch_size: int = 32):\n    \"\"\"Load and compile the TimesFM model.\"\"\"\n    import torch\n    import timesfm\n\n    torch.set_float32_matmul_precision(\"high\")\n\n    print(\"Loading TimesFM 2.5 from Hugging Face...\")\n    model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(\n        \"google/timesfm-2.5-200m-pytorch\"\n    )\n\n    print(f\"Compiling with per_core_batch_size={batch_size}...\")\n    model.compile(\n        timesfm.ForecastConfig(\n            max_context=1024,\n            max_horizon=256,\n            normalize_inputs=True,\n            use_continuous_quantile_head=True,\n            force_flip_invariance=True,\n            infer_is_positive=True,\n            fix_quantile_crossing=True,\n            per_core_batch_size=batch_size,\n        )\n    )\n\n    return model\n\n\ndef load_csv(\n    path: str,\n    date_col: str | None = None,\n    value_cols: list[str] | None = None,\n) -> tuple[pd.DataFrame, list[str], str | None]:\n    \"\"\"Load CSV and identify time series columns.\n\n    Returns:\n        (dataframe, value_column_names, date_column_name_or_none)\n    \"\"\"\n    df = pd.read_csv(path)\n\n    # Identify date column\n    if date_col and date_col in df.columns:\n        df[date_col] = pd.to_datetime(df[date_col])\n    elif date_col:\n        print(f\"⚠️ Date column '{date_col}' not found. Available: {list(df.columns)}\")\n        date_col = None\n\n    # Identify value columns\n    if value_cols:\n        missing = [c for c in value_cols if c not in df.columns]\n        if missing:\n            print(f\"⚠️ Columns not found: {missing}. 
Available: {list(df.columns)}\")\n            value_cols = [c for c in value_cols if c in df.columns]\n    else:\n        # Auto-detect numeric columns (exclude date)\n        numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()\n        if date_col and date_col in numeric_cols:\n            numeric_cols.remove(date_col)\n        value_cols = numeric_cols\n\n    if not value_cols:\n        print(\"🛑 No numeric columns found to forecast.\")\n        sys.exit(1)\n\n    print(f\"Found {len(value_cols)} series to forecast: {value_cols}\")\n    return df, value_cols, date_col\n\n\ndef forecast_series(\n    model, df: pd.DataFrame, value_cols: list[str], horizon: int\n) -> dict[str, dict]:\n    \"\"\"Forecast all series and return results dict.\"\"\"\n    inputs = []\n    for col in value_cols:\n        values = df[col].dropna().values.astype(np.float32)\n        inputs.append(values)\n\n    print(f\"Forecasting {len(inputs)} series with horizon={horizon}...\")\n    point, quantiles = model.forecast(horizon=horizon, inputs=inputs)\n\n    results = {}\n    for i, col in enumerate(value_cols):\n        results[col] = {\n            \"forecast\": point[i].tolist(),\n            \"lower_90\": quantiles[i, :, 1].tolist(),  # 10th percentile\n            \"lower_80\": quantiles[i, :, 2].tolist(),  # 20th percentile\n            \"median\": quantiles[i, :, 5].tolist(),  # 50th percentile\n            \"upper_80\": quantiles[i, :, 8].tolist(),  # 80th percentile\n            \"upper_90\": quantiles[i, :, 9].tolist(),  # 90th percentile\n        }\n\n    return results\n\n\ndef write_csv_output(\n    results: dict[str, dict],\n    output_path: str,\n    df: pd.DataFrame,\n    date_col: str | None,\n    horizon: int,\n) -> None:\n    \"\"\"Write forecast results to CSV.\"\"\"\n    rows = []\n    for col, data in results.items():\n        # Try to generate future dates\n        future_dates = list(range(1, horizon + 1))\n        if date_col and date_col in 
df.columns:\n            try:\n                last_date = df[date_col].dropna().iloc[-1]\n                freq = pd.infer_freq(df[date_col].dropna())\n                if freq:\n                    future_dates = pd.date_range(\n                        last_date, periods=horizon + 1, freq=freq\n                    )[1:].tolist()\n            except Exception:\n                pass\n\n        for h in range(horizon):\n            row = {\n                \"series\": col,\n                \"step\": h + 1,\n                \"forecast\": data[\"forecast\"][h],\n                \"lower_90\": data[\"lower_90\"][h],\n                \"lower_80\": data[\"lower_80\"][h],\n                \"median\": data[\"median\"][h],\n                \"upper_80\": data[\"upper_80\"][h],\n                \"upper_90\": data[\"upper_90\"][h],\n            }\n            if isinstance(future_dates[0], (pd.Timestamp,)):\n                row[\"date\"] = future_dates[h]\n            rows.append(row)\n\n    out_df = pd.DataFrame(rows)\n    out_df.to_csv(output_path, index=False)\n    print(f\"✅ Wrote {len(rows)} forecast rows to {output_path}\")\n\n\ndef write_json_output(results: dict[str, dict], output_path: str) -> None:\n    \"\"\"Write forecast results to JSON.\"\"\"\n    with open(output_path, \"w\") as f:\n        json.dump(results, f, indent=2)\n    print(f\"✅ Wrote forecasts for {len(results)} series to {output_path}\")\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(\n        description=\"Forecast time series from CSV using TimesFM.\"\n    )\n    parser.add_argument(\"input\", help=\"Path to input CSV file\")\n    parser.add_argument(\n        \"--horizon\", type=int, required=True, help=\"Number of steps to forecast\"\n    )\n    parser.add_argument(\"--date-col\", help=\"Name of the date/time column\")\n    parser.add_argument(\n        \"--value-cols\",\n        help=\"Comma-separated list of value columns to forecast (default: all numeric)\",\n    )\n    
parser.add_argument(\n        \"--output\",\n        default=\"forecasts.csv\",\n        help=\"Output file path (default: forecasts.csv)\",\n    )\n    parser.add_argument(\n        \"--format\",\n        choices=[\"csv\", \"json\"],\n        default=None,\n        help=\"Output format (inferred from --output extension if not set)\",\n    )\n    parser.add_argument(\n        \"--batch-size\",\n        type=int,\n        default=None,\n        help=\"Override per_core_batch_size (auto-detected from system check if omitted)\",\n    )\n    parser.add_argument(\n        \"--skip-check\",\n        action=\"store_true\",\n        help=\"Skip system preflight check (not recommended)\",\n    )\n    args = parser.parse_args()\n\n    # Parse value columns\n    value_cols = None\n    if args.value_cols:\n        value_cols = [c.strip() for c in args.value_cols.split(\",\")]\n\n    # Determine output format\n    out_format = args.format\n    if not out_format:\n        out_format = \"json\" if args.output.endswith(\".json\") else \"csv\"\n\n    # 1. Preflight check\n    if not args.skip_check:\n        print(\"Running system preflight check...\")\n        report = run_preflight()\n        batch_size = args.batch_size or report.get(\"recommended_batch_size\", 32)\n    else:\n        print(\"⚠️ Skipping system check (--skip-check). Proceed with caution.\")\n        batch_size = args.batch_size or 32\n\n    # 2. Load model\n    model = load_model(batch_size=batch_size)\n\n    # 3. Load CSV\n    df, cols, date_col = load_csv(args.input, args.date_col, value_cols)\n\n    # 4. Forecast\n    results = forecast_series(model, df, cols, args.horizon)\n\n    # 5. Write output\n    if out_format == \"json\":\n        write_json_output(results, args.output)\n    else:\n        write_csv_output(results, args.output, df, date_col, args.horizon)\n\n    print(\"\\nDone! 🎉\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "v1/LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "v1/README.md",
    "content": "# TimesFM\n\nTimesFM  (Time Series Foundation Model) is a pretrained time-series foundation model developed by Google\nResearch for time-series forecasting.\n\n* Paper: [A decoder-only foundation model for time-series forecasting](https://arxiv.org/abs/2310.10688), to appear in ICML 2024.\n* [Google Research blog](https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/)\n* [Hugging Face release](https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6)\n\nThis repo contains the code to load public TimesFM checkpoints and run model\ninference. Please visit our \n[Hugging Face release](https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6)\nto download model checkpoints.\n\nThis is not an officially supported Google product.\n\nWe recommend at least 32GB RAM to load TimesFM dependencies.\n\n**Need help?** See [TROUBLESHOOTING.md](TROUBLESHOOTING.md) for common installation and usage issues.\n\n## Update - Dec. 30, 2024\n- We are launching a 500m checkpoint as a part of TimesFM-2.0 release. This new checkpoint can be upto 25% better than v1.0 on leading benchmarks and also has a 4 times longer max. context length.\n- Launched [finetuning support](https://github.com/google-research/timesfm/blob/master/notebooks/finetuning.ipynb) that lets you finetune the weights of the pretrained TimesFM model on your own data.\n- Launched [~zero-shot covariate support](https://github.com/google-research/timesfm/blob/master/notebooks/covariates.ipynb) with external regressors. More details [here](https://github.com/google-research/timesfm?tab=readme-ov-file#covariates-support).\n\n## Update - Feb. 
17, 2024\n- We are providing the option for [finetuning using Pytorch](https://github.com/google-research/timesfm/blob/master/notebooks/finetuning_torch.ipynb), which mimics the previously added functionality from [finetuning support](https://github.com/google-research/timesfm/blob/master/notebooks/finetuning.ipynb).\n- We are also providing the Multi-GPU finetuning with Pytorch. We currently support DDP multi-gpu finetuning; other variants of multi-gpu training (pipeline parallelism/model parallelism) might be added later. In order to use it, follow the steps in [finetuning example](https://github.com/google-research/timesfm/blob/master/finetuning/finetuning_example.py).\n\n## Checkpoint timesfm-1.0-200m (-pytorch)\n\ntimesfm-1.0-200m is our first open model checkpoint:\n\n- It performs univariate time series forecasting for context lengths up to 512 timepoints and any horizon lengths, with an optional frequency indicator.\n- It focuses on point forecasts, and does not support probabilistic forecasts. We experimentally offer quantile heads but they have not been calibrated after pretraining.\n\n## Checkpoint timesfm-2.0-500m (-jax/-pytorch)\n\ntimesfm-2.0-500m is our second open model checkpoint:\n\n- It performs univariate time series forecasting for context lengths up to 2048 timepoints and any horizon lengths, with an optional frequency indicator.\n- It focuses on point forecasts. We experimentally offer 10 quantile heads but they have not been calibrated after pretraining.\n- This new checkpoint can be up to 25% better than v1.0 on leading benchmarks and also has a 4 times longer max. context length.\n\n## Benchmarking\n\nTimesFM 2.0 has been added to [GIFT-Eval](https://huggingface.co/spaces/Salesforce/GIFT-Eval) which is one of the most comprehensive time-series benchmarks available. 
It takes the top spot in terms of aggregated MASE and CRPS, where it is 6\\% better than the next best model in terms of aggregated MASE.\n\n## Installation\n\n### Local installation using poetry\n\nWe will be using `pyenv` and `poetry`. In order to set these things up please follow the instructions [here](https://substack.com/home/post/p-148747960?r=28a5lx&utm_campaign=post&utm_medium=web). Note that the PAX (or JAX) version needs to run on python 3.10.x and the PyTorch version can run on >=3.11.x. Therefore make sure you have two versions of python installed:\n\n```\npyenv install 3.10\npyenv install 3.11\npyenv versions # to list the versions available (let's assume the versions are 3.10.15 and 3.11.10)\n```\n\n### For PAX version installation do the following.\n\n```\npyenv local 3.10.15\npoetry env use 3.10.15\npoetry lock\npoetry install -E  pax\n```\n\nAfter that you can run the timesfm under `poetry shell` or do `poetry run python3 ...`.\n\n### For PyTorch version installation do the following.\n\n```\npyenv local 3.11.10\npoetry env use 3.11.10\npoetry lock\npoetry install -E  torch\n```\n\nAfter that you can run the timesfm under `poetry shell` or do `poetry run python3 ...`.\n\n**Additional Note**: \n\nIf you plan to use the **`forecast_with_covariates`** function (which requires external regressors), \nyou need to install **JAX** and **jaxlib**. If you installed the base version of TimesFM (`torch`), you must manually install the dependencies for **`forecast_with_covariates`** support:\n```\npip install jax jaxlib\n```\n\n**Why is this needed?**  \nThe `forecast_with_covariates` method relies on the `xreg_lib` module, which depends on JAX and jaxlib. If these packages are not installed, \ncalling `forecast_with_covariates` will raise an error. However, due to a lazy import mechanism, `xreg_lib` (and hence JAX/jaxlib) is not needed for standard `forecast` calls.\n\n### Notes\n\n1. Running the provided benchmarks would require additional dependencies. 
Please see the `experiments` folder.\n\n2. The dependency `lingvo` does not support ARM architectures, and the code is not working for machines with Apple silicon. We are aware of this issue and are working on a solution. Stay tuned.\n\n### Install from PyPI (and publish)\n\nOn python 3.11 you can install the torch version using:\n\n```pip install timesfm[torch]```\n\nOn python 3.10 you can install the pax version using:\n\n```pip install timesfm[pax]```\n\n\n## Usage \n\n### Initialize the model and load a checkpoint.\nThen the base class can be loaded as,\n\n```python\nimport timesfm\n\n# Loading the timesfm-2.0 checkpoint:\n# For PAX\ntfm = timesfm.TimesFm(\n      hparams=timesfm.TimesFmHparams(\n          backend=\"gpu\",\n          per_core_batch_size=32,\n          horizon_len=128,\n          num_layers=50,\n          context_len=2048,\n\n          use_positional_embedding=False,\n      ),\n      checkpoint=timesfm.TimesFmCheckpoint(\n          huggingface_repo_id=\"google/timesfm-2.0-500m-jax\"),\n  )\n\n# For Torch\ntfm = timesfm.TimesFm(\n      hparams=timesfm.TimesFmHparams(\n          backend=\"gpu\",\n          per_core_batch_size=32,\n          horizon_len=128,\n          num_layers=50,\n          use_positional_embedding=False,\n          context_len=2048,\n      ),\n      checkpoint=timesfm.TimesFmCheckpoint(\n          huggingface_repo_id=\"google/timesfm-2.0-500m-pytorch\"),\n  )\n\n# Loading the timesfm-1.0 checkpoint:\n# For PAX\ntfm = timesfm.TimesFm(\n      hparams=timesfm.TimesFmHparams(\n          backend=\"gpu\",\n          per_core_batch_size=32,\n          horizon_len=128,\n      ),\n      checkpoint=timesfm.TimesFmCheckpoint(\n          huggingface_repo_id=\"google/timesfm-1.0-200m\"),\n  )\n\n# For Torch\ntfm = timesfm.TimesFm(\n      hparams=timesfm.TimesFmHparams(\n          backend=\"gpu\",\n          per_core_batch_size=32,\n          horizon_len=128,\n      ),\n      checkpoint=timesfm.TimesFmCheckpoint(\n          
huggingface_repo_id=\"google/timesfm-1.0-200m-pytorch\"),\n  )\n```\n\nNote some of the parameters are fixed to load the 200m and 500m models\n\n1. The `context_len` in `hparams` here can be set as the max context length **of the model** (a maximum of 2048 for 2.0 models and 512 for 1.0 models). **It needs to be a multiple of `input_patch_len`, i.e. a multiple of 32.** You can provide a shorter series to the `tfm.forecast()` function and the model will handle it. The input time series can have **any context length**. Padding / truncation will be handled by the inference code if needed.\n\n2. The horizon length can be set to anything. We recommend setting it to the largest horizon length you would need in the forecasting tasks for your application. We generally recommend horizon length <= context length but it is not a requirement in the function call.\n\n3. `backend` is one of \"cpu\", \"gpu\", case sensitive.\n\n### Perform inference\n\nWe provide APIs to forecast from either array inputs or `pandas` dataframe. Both forecast methods expect (1) the input time series contexts, (2) along with their frequencies. Please look at the documentation of the functions `tfm.forecast()` and `tfm.forecast_on_df()` for detailed instructions.\n\nIn particular regarding the frequency, TimesFM expects a categorical indicator valued in {0, 1, 2}:\n\n- **0** (default): high frequency, long horizon time series. We recommend using this for time series up to daily granularity.\n- **1**: medium frequency time series. We recommend using this for weekly and monthly data.\n- **2**: low frequency, short horizon time series. We recommend using this for anything beyond monthly, e.g. quarterly or yearly.\n\nThis categorical value should be directly provided with the array inputs. 
For dataframe inputs, we convert the conventional letter coding of frequencies to our expected categories as follows:\n\n- **0**: T, MIN, H, D, B, U\n- **1**: W, M\n- **2**: Q, Y\n\nNotice you do **NOT** have to strictly follow our recommendation here. Although this is our setup during model training and we expect it to offer the best forecast result, you can also view the frequency input as a free parameter and modify it per your specific use case.\n\n\nExamples:\n\nArray inputs, with the frequencies set to high, medium and low respectively.\n\n```python\nimport numpy as np\nforecast_input = [\n    np.sin(np.linspace(0, 20, 100)),\n    np.sin(np.linspace(0, 20, 200)),\n    np.sin(np.linspace(0, 20, 400)),\n]\nfrequency_input = [0, 1, 2]\n\npoint_forecast, experimental_quantile_forecast = tfm.forecast(\n    forecast_input,\n    freq=frequency_input,\n)\n```\n\n`pandas` dataframe, with the frequency set to \"M\" monthly.\n\n```python\nimport pandas as pd\n\n# e.g. input_df is\n#       unique_id  ds          y\n# 0     T1         1975-12-31  697458.0\n# 1     T1         1976-01-31  1187650.0\n# 2     T1         1976-02-29  1069690.0\n# 3     T1         1976-03-31  1078430.0\n# 4     T1         1976-04-30  1059910.0\n# ...   ...        ...         ...\n# 8175  T99        1986-01-31  602.0\n# 8176  T99        1986-02-28  684.0\n# 8177  T99        1986-03-31  818.0\n# 8178  T99        1986-04-30  836.0\n# 8179  T99        1986-05-31  878.0\n\nforecast_df = tfm.forecast_on_df(\n    inputs=input_df,\n    freq=\"M\",  # monthly\n    value_name=\"y\",\n    num_jobs=-1,\n)\n```\n\n## Covariates Support\n\nWe now have an external regressors library on top of TimesFM that can support static covariates as well as dynamic covariates available in the future. 
We have a usage example in [notebooks/covariates.ipynb](https://github.com/google-research/timesfm/blob/master/notebooks/covariates.ipynb).\n\nIf you plan to use the **`forecast_with_covariates`** on timesfm `torch` version, you need to install **JAX** and **jaxlib**. \nYou must manually install the dependencies for **`forecast_with_covariates`** support:\n```\npip install jax jaxlib\n```\n\nLet's take a toy example of forecasting sales for a grocery store: \n\n**Task:** Given the observed daily sales of this week (7 days), forecast the daily sales of next week (7 days).\n\n```\nProduct: ice cream\nDaily_sales: [30, 30, 4, 5, 7, 8, 10]\nCategory: food\nBase_price: 1.99\nWeekday: [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]\nHas_promotion: [Yes, Yes, No, No, No, Yes, Yes, No, No, No, No, No, No, No]\nDaily_temperature: [31.0, 24.3, 19.4, 26.2, 24.6, 30.0, 31.1, 32.4, 30.9, 26.0, 25.0, 27.8, 29.5, 31.2]\n```\n\n```\nProduct: sunscreen\nDaily_sales: [5, 7, 12, 13, 5, 6, 10]\nCategory: skin product\nBase_price: 29.99\nWeekday: [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]\nHas_promotion: [No, No, Yes, Yes, No, No, No, Yes, Yes, Yes, Yes, Yes, Yes, Yes]\nDaily_temperature: [31.0, 24.3, 19.4, 26.2, 24.6, 30.0, 31.1, 32.4, 30.9, 26.0, 25.0, 27.8, 29.5, 31.2]\n```\n\nIn this example, besides the `Daily_sales`, we also have covariates `Category`, `Base_price`, `Weekday`, `Has_promotion`, `Daily_temperature`. Let's introduce some concepts:\n\n**Static covariates** are covariates for each time series. 
\n- In our example, `Category` is a **static categorical covariate**, \n- `Base_price` is a **static numerical covariate**.\n\n**Dynamic covariates** are covariates for each time stamp.\n- Date / time related features can usually be treated as dynamic covariates.\n- In our example, `Weekday` and `Has_promotion` are **dynamic categorical covariates**.\n- `Daily_temperature` is a **dynamic numerical covariate**.\n\n**Notice:** Here we make it mandatory that the dynamic covariates need to cover both the forecasting context and horizon. For example, all dynamic covariates in the example have 14 values: the first 7 correspond to the observed 7 days, and the last 7 correspond to the next 7 days.\n\nWe can now provide the past data of the two products along with static and dynamic covariates as a batch input to TimesFM and produce forecasts that take the covariates into account. To learn more, check out the example in [notebooks/covariates.ipynb](https://github.com/google-research/timesfm/blob/master/notebooks/covariates.ipynb).\n\n## Finetuning\n\nWe have provided an example of finetuning the model on a new dataset in [notebooks/finetuning.ipynb](https://github.com/google-research/timesfm/blob/master/notebooks/finetuning.ipynb).\n\n## Contribution Style guide\n\nIf you would like to submit a PR please make sure that you use our formatting style. We use [yapf](https://github.com/google/yapf) for formatting with the following options,\n\n```\n[style]\nbased_on_style = google\n# Add your custom style rules here\nindent_width = 2\nspaces_before_comment = 2\n\n```\n\nPlease run `yapf --in-place --recursive <filename>` on all affected files.\n"
  },
  {
    "path": "v1/TROUBLESHOOTING.md",
    "content": "# Troubleshooting\n\nThis document provides solutions to common issues encountered when using TimesFM.\n\n## Installation Issues\n\n### ARM/Apple Silicon Compatibility\n**Problem:** `lingvo` dependency fails on Apple Silicon (M1/M2/M3) machines.\n```\nERROR: Could not build wheels for lingvo\n```\n**Solution:** This is a known issue. The `lingvo` dependency doesn't support ARM architectures. We recommend:\n- Use x86_64 emulation via Rosetta 2: `arch -x86_64 pip install timesfm[pax]`\n- Use the PyTorch version instead, which has better ARM support: `pip install timesfm[torch]`\n- Use Docker with x86_64 emulation for consistent environments\n\n### Memory Issues During Installation\n**Problem:** Installation fails with memory errors.\n```\nKilled (signal 9)\n```\n**Solution:** \n- Ensure at least 32GB RAM is available\n- Close other applications during installation\n- Use `pip install --no-cache-dir timesfm[torch]` to reduce memory usage\n- Install in a clean virtual environment\n\n### JAX/PyTorch Version Conflicts\n**Problem:** Conflicting JAX and PyTorch installations.\n```\nImportError: cannot import name 'jax' from 'jax'\n```\n**Solution:**\n- For PyTorch-only usage: `pip install timesfm[torch]`\n- For covariates with PyTorch: `pip install timesfm[torch] && pip install jax jaxlib`\n- For PAX version: `pip install timesfm[pax]`\n\n## Runtime Errors\n\n### Model Loading Issues\n**Problem:** Checkpoint download fails or is corrupted.\n```\nHfFileNotFoundError: 404 Client Error\n```\n**Solution:**\n- Check internet connectivity\n- Verify Hugging Face Hub access: `huggingface-cli login`\n- Clear cache: `rm -rf ~/.cache/huggingface/`\n- Use explicit checkpoint paths if needed\n\n### CUDA/GPU Issues\n**Problem:** GPU not detected or CUDA errors.\n```\nRuntimeError: CUDA out of memory\n```\n**Solutions:**\n- Reduce `per_core_batch_size` (try 16, 8, or 4)\n- Reduce `context_len` to minimum needed\n- Use `backend=\"cpu\"` for testing\n- Check GPU memory: 
`nvidia-smi`\n\n### Context Length Errors\n**Problem:** Input series longer than model capacity.\n```\nValueError: context_len must be <= 512 for v1.0 models\n```\n**Solutions:**\n- Use TimesFM-2.0 for longer contexts (up to 2048)\n- Ensure `context_len` is multiple of 32\n- Truncate input series if necessary\n- Set appropriate `context_len` in model initialization\n\n## Data Issues\n\n### Frequency Mapping Problems\n**Problem:** Unexpected forecasting results with wrong frequency.\n```\nWarning: Frequency 'D' mapped to category 0\n```\n**Solutions:**\n- Verify frequency mapping: D→0 (high), W/M→1 (medium), Q/Y→2 (low)\n- Override automatic mapping by specifying frequency manually\n- Check data granularity matches chosen frequency category\n\n### Missing Values in Time Series\n**Problem:** NaN or missing values in input data.\n```\nValueError: Input contains NaN values\n```\n**Solutions:**\n- Pre-process data to handle missing values (forward fill, interpolation)\n- Ensure continuous time series without gaps\n- Remove or impute missing values before forecasting\n\n### Covariate Dimension Mismatches\n**Problem:** Covariate lengths don't match forecast horizon.\n```\nValueError: Dynamic covariates must cover context + horizon\n```\n**Solutions:**\n- Ensure dynamic covariates have length = context + horizon\n- Check static vs dynamic covariate classification\n- Verify covariate data alignment with time series\n\n## Performance Issues\n\n### Slow Inference\n**Problem:** Forecasting takes unexpectedly long.\n**Solutions:**\n- Use GPU backend: `backend=\"gpu\"`\n- Optimize batch size: increase `per_core_batch_size`\n- Use appropriate model size for your use case\n- Profile with smaller data first\n\n### Memory Usage\n**Problem:** High memory consumption during inference.\n**Solutions:**\n- Reduce batch size: `per_core_batch_size=1`\n- Process data in chunks\n- Use smaller context length when possible\n- Monitor memory with `htop` or `nvidia-smi`\n\n## Common Error 
Messages\n\n### `ModuleNotFoundError: No module named 'xreg_lib'`\n**Cause:** Missing JAX dependencies for covariates functionality.\n**Solution:** `pip install jax jaxlib`\n\n### `ValueError: horizon_len must be positive`\n**Cause:** Invalid horizon length specified.\n**Solution:** Set `horizon_len > 0` in model initialization.\n\n### `RuntimeError: Expected input batch_size (X) to be divisible by batch_size (Y)`\n**Cause:** Batch size mismatch.\n**Solution:** Adjust `per_core_batch_size` or input data batching.\n\n## Getting Help\n\nIf you encounter issues not covered here:\n1. Check the [GitHub Issues](https://github.com/google-research/timesfm/issues)\n2. Review the [notebooks/](notebooks/) for working examples\n3. Verify your installation follows the exact steps in the Installation section\n4. Test with the provided example data before using your own datasets"
  },
  {
    "path": "v1/docs/contributing.md",
    "content": "# How to Contribute\n\nWe would love to accept your patches and contributions to this project.\n\n## Before you begin\n\n### Sign our Contributor License Agreement\n\nContributions to this project must be accompanied by a\n[Contributor License Agreement](https://cla.developers.google.com/about) (CLA).\nYou (or your employer) retain the copyright to your contribution; this simply\ngives us permission to use and redistribute your contributions as part of the\nproject.\n\nIf you or your current employer have already signed the Google CLA (even if it\nwas for a different project), you probably don't need to do it again.\n\nVisit <https://cla.developers.google.com/> to see your current agreements or to\nsign a new one.\n\n### Review our Community Guidelines\n\nThis project follows [Google's Open Source Community\nGuidelines](https://opensource.google/conduct/).\n\n## Contribution process\n\n### Code Reviews\n\nAll submissions, including submissions by project members, require review. We\nuse [GitHub pull requests](https://docs.github.com/articles/about-pull-requests)\nfor this purpose.\n"
  },
  {
    "path": "v1/experiments/baselines/__init__.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License."
  },
  {
    "path": "v1/experiments/baselines/timegpt_pipeline.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom time import time\nfrom typing import List, Optional, Tuple\nfrom dotenv import load_dotenv\nfrom gluonts.time_feature.seasonality import get_seasonality as _get_seasonality\nfrom nixtla import NixtlaClient\nimport pandas as pd\nfrom tqdm import tqdm\nfrom utilsforecast.processing import (\n    backtest_splits,\n    drop_index_if_pandas,\n    join,\n    maybe_compute_sort_indices,\n    take_rows,\n    vertical_concat,\n)\n\n\ndef get_seasonality(freq: str) -> int:\n  return _get_seasonality(freq, seasonalities={\"D\": 7})\n\n\ndef maybe_convert_col_to_datetime(\n    df: pd.DataFrame, col_name: str\n) -> pd.DataFrame:\n  if not pd.api.types.is_datetime64_any_dtype(df[col_name]):\n    df = df.copy()\n    df[col_name] = pd.to_datetime(df[col_name])\n  return df\n\n\ndef zero_pad_time_series(df, freq, min_length=36):\n  \"\"\"If time_series length is less than min_length, front pad it with zeros.\"\"\"\n  # 1. Calculate required padding for each unique_id\n  value_counts = df[\"unique_id\"].value_counts()\n  to_pad = value_counts[value_counts < min_length].index\n\n  # 2. Create a new DataFrame to hold padded data\n  padded_data = []\n\n  for unique_id in to_pad:\n    # 2a. Filter data for the specific unique_id\n    subset = df[df[\"unique_id\"] == unique_id]\n    if len(subset) > min_length:\n      padded_data.append(subset)\n    else:\n      # 2b. 
Determine earliest date and calculate padding dates\n      start_date = subset[\"ds\"].min()\n      padding_dates = pd.date_range(\n          end=start_date,\n          periods=min_length - len(subset) + 1,\n          freq=freq,  # 'MS' for month start\n      )[\n          :-1\n      ]  # Exclude the start_date itself\n\n      # 2c. Create padding data\n      padding_df = pd.DataFrame(\n          {\"ds\": padding_dates, \"unique_id\": unique_id, \"y\": 0}  # Zero padding\n      )\n\n      # 2d. Combine original and padding data, and append to the list\n      padded_data.append(pd.concat([padding_df, subset]).sort_values(\"ds\"))\n\n  # 3. Combine all padded data and original data (unchanged)\n  result_df = pd.concat(padded_data + [df[~df[\"unique_id\"].isin(to_pad)]])\n  return result_df\n\n\nclass Forecaster:\n  \"\"\"Borrowed from\n\n  https://github.com/Nixtla/nixtla/tree/main/experiments/foundation-time-series-arena/xiuhmolpilli/models.\n  \"\"\"\n\n  def forecast(\n      self,\n      df: pd.DataFrame,\n      h: int,\n      freq: str,\n  ) -> pd.DataFrame:\n    raise NotImplementedError\n\n  def cross_validation(\n      self,\n      df: pd.DataFrame,\n      h: int,\n      freq: str,\n      n_windows: int = 1,\n      step_size: int | None = None,\n  ) -> pd.DataFrame:\n    df = maybe_convert_col_to_datetime(df, \"ds\")\n    # mlforecast cv code\n    results = []\n    sort_idxs = maybe_compute_sort_indices(df, \"unique_id\", \"ds\")\n    if sort_idxs is not None:\n      df = take_rows(df, sort_idxs)\n    splits = backtest_splits(\n        df,\n        n_windows=n_windows,\n        h=h,\n        id_col=\"unique_id\",\n        time_col=\"ds\",\n        freq=pd.tseries.frequencies.to_offset(freq),\n        step_size=h if step_size is None else step_size,\n    )\n    for _, (cutoffs, train, valid) in tqdm(enumerate(splits)):\n      if len(valid.columns) > 3:\n        raise NotImplementedError(\n            \"Cross validation with exogenous variables is not yet 
supported.\"\n        )\n      y_pred = self.forecast(\n          df=train,\n          h=h,\n          freq=freq,\n      )\n      y_pred = join(y_pred, cutoffs, on=\"unique_id\", how=\"left\")\n      result = join(\n          valid[[\"unique_id\", \"ds\", \"y\"]],\n          y_pred,\n          on=[\"unique_id\", \"ds\"],\n      )\n      if result.shape[0] < valid.shape[0]:\n        raise ValueError(\n            \"Cross validation result produced less results than expected.\"\n            \" Please verify that the frequency parameter (freq) matches your\"\n            \" series' and that there aren't any missing periods.\"\n        )\n      results.append(result)\n    out = vertical_concat(results)\n    out = drop_index_if_pandas(out)\n    first_out_cols = [\"unique_id\", \"ds\", \"cutoff\", \"y\"]\n    remaining_cols = [c for c in out.columns if c not in first_out_cols]\n    fcst_cv_df = out[first_out_cols + remaining_cols]\n    return fcst_cv_df\n\n\nclass TimeGPT(Forecaster):\n  \"\"\"Borrowed from\n\n  https://github.com/Nixtla/nixtla/tree/main/experiments/foundation-time-series-arena/xiuhmolpilli/models.\n  We modify the class to take care of edge cases.\n  \"\"\"\n\n  def __init__(\n      self,\n      api_key: str | None = None,\n      base_url: Optional[str] = None,\n      max_retries: int = 1,\n      model: str = \"timegpt-1\",\n      alias: str = \"TimeGPT\",\n  ):\n    self.api_key = api_key\n    self.base_url = base_url\n    self.max_retries = max_retries\n    self.model = model\n    self.alias = alias\n\n  def _get_client(self) -> NixtlaClient:\n    if self.api_key is None:\n      api_key = os.environ[\"NIXTLA_API_KEY\"]\n    else:\n      api_key = self.api_key\n    return NixtlaClient(\n        api_key=api_key,\n        base_url=self.base_url,\n        max_retries=self.max_retries,\n    )\n\n  def forecast(\n      self,\n      df: pd.DataFrame,\n      h: int,\n      freq: str,\n      level: List = [90.0],\n      chunk_size: Optional[int] = None,\n  ) 
-> pd.DataFrame:\n    client = self._get_client()\n    fcst_df = None\n    if chunk_size is None:\n      fcst_df = client.forecast(\n          df=df,\n          h=h,\n          freq=freq,\n          level=level,\n          model=self.model,\n      )\n    else:\n      all_unique_ids = df[\"unique_id\"].unique()\n      all_fcst_df = []\n      for i in range(0, len(all_unique_ids), chunk_size):\n        chunk_ids = all_unique_ids[i : i + chunk_size]\n        chunk_df = df[df[\"unique_id\"].isin(chunk_ids)]\n        fct_chunk_df = client.forecast(\n            df=chunk_df,\n            h=h,\n            freq=freq,\n            level=level,\n        )\n        all_fcst_df.append(fct_chunk_df)\n      fcst_df = pd.concat(all_fcst_df)\n    fcst_df[\"ds\"] = pd.to_datetime(fcst_df[\"ds\"])\n    replace_dict = {}\n    for col in fcst_df.columns:\n      if col.startswith(\"TimeGPT\"):\n        replace_dict[col] = col.replace(\"TimeGPT\", self.alias)\n    fcst_df = fcst_df.rename(columns=replace_dict)\n    return fcst_df\n\n\ndef run_timegpt(\n    train_df: pd.DataFrame,\n    horizon: int,\n    freq: str,\n    seasonality: int,\n    level: List[int],\n    dataset: str,\n    model: str = \"timegpt-1\",\n) -> Tuple[pd.DataFrame, float, str]:\n  os.environ[\"NIXTLA_ID_AS_COL\"] = \"true\"\n  model = TimeGPT(model=\"timegpt-1\", alias=model)\n  padded_train_df = zero_pad_time_series(train_df, freq)\n  init_time = time()\n  # For these datasets the API fails if we do not chunk.\n  if dataset in [\"m5\", \"m4_quarterly\"]:\n    chunk_size = 5000\n  else:\n    chunk_size = None\n  fcsts_df = model.forecast(\n      df=padded_train_df,\n      h=horizon,\n      level=level,\n      freq=freq,\n      chunk_size=chunk_size,\n  )\n  total_time = time() - init_time\n  # In case levels are not returned we replace the levels with the mean predictions.\n  # Note that this does not affect the results table as we only compare on point\n  # forecastign metrics.\n  for lvl in level:\n    if 
f\"{model.alias}-lo-{lvl}\" not in fcsts_df.columns:\n      fcsts_df[f\"{model.alias}-lo-{lvl}\"] = fcsts_df[model.alias]\n    if f\"{model.alias}-hi-{lvl}\" not in fcsts_df.columns:\n      fcsts_df[f\"{model.alias}-hi-{lvl}\"] = fcsts_df[model.alias]\n  return fcsts_df, total_time, model.alias\n"
  },
  {
    "path": "v1/experiments/extended_benchmarks/README.md",
    "content": "# Extended Benchmarks\n\nThe benchmark setting has been borrowed from Nixtla's original [benchmarking](https://github.com/AzulGarza/nixtla/tree/main/experiments/amazon-chronos) of time-series foundation models against a strong statistical ensemble. Later more datasets were added by the Chronos team in this [pull request](https://github.com/shchur/nixtla/tree/chronos-full-eval/experiments/amazon-chronos). We compare on all the datasets in this extended benchmarks.\n\n\n## Running TimesFM on the benchmark\n\nWe need to add the following packages for running these benchmarks. Follow the installation instructions till before `poetry lock`. Then,\n\n```\npoetry add git+https://github.com/awslabs/gluon-ts.git\npoetry lock\npoetry install --only <pax or pytorch>\n```\n\nTo run the timesfm on the benchmark do:\n\n```\npoetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=google/timesfm-1.0-200m(-pytorch) --backend=\"gpu\"\n```\n\n\nNote: In the current version of TimesFM we focus on point forecasts and therefore the mase, smape have been calculated using the quantile head corresponding to the median i.e 0.5 quantile. We do offer 10 quantile heads but they have not been calibrated after pretraining. We recommend using them with caution or calibrate/conformalize them on a hold out for your applications. More to follow on later versions.\n\n## Benchmark Results for TimesFM-1.0\n\n![Benchmark Results Table](./tfm_extended_new.png)\n\n__Update:__ We have added TimeGPT-1 to the benchmark results. We had to remove the Dominick dataset as we were not able to run TimeGPT-1 on this benchmark. Note that the previous results including Dominick remain available at `./tfm_results.png`. In order to reproduce the results for TimeGPT-1, please run `run_timegpt.py`.\n\n_Remark:_ All baselines except the ones involving TimeGPT were run performed on a [g2-standard-32](https://cloud.google.com/compute/docs/gpus). 
Since TimeGPT-1 can only be accessed by an API, the time column might not reflect the true speed of the model as it also includes the communication cost. Moreover, we are not sure about the exact backend hardware for TimeGPT. The TimesFM latency numbers are from the PAX version.\n\nWe can see that TimesFM performs the best in terms of both mase and smape. More importantly it is much faster than the other methods, in particular it is more than 600x faster than StatisticalEnsemble and 80x faster than Chronos (Large).\n\nNote: This benchmark only compares on `one` small horizon window for long horizon datasets like ETT hourly and 15 minutes. More in-depth comparisons on longer horizon rolling validation tasks are presented in our long horizon benchmarks."
  },
  {
    "path": "v1/experiments/extended_benchmarks/run_timegpt.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Evaluation script for timegpt.\"\"\"\n\nimport os\nimport sys\nimport time\n\nfrom absl import flags\nimport numpy as np\nimport pandas as pd\n\nfrom ..baselines.timegpt_pipeline import run_timegpt\nfrom .utils import ExperimentHandler\n\n\ndataset_names = [\n    \"m1_monthly\",\n    \"m1_quarterly\",\n    \"m1_yearly\",\n    \"m3_monthly\",\n    \"m3_other\",\n    \"m3_quarterly\",\n    \"m3_yearly\",\n    \"m4_quarterly\",\n    \"m4_yearly\",\n    \"tourism_monthly\",\n    \"tourism_quarterly\",\n    \"tourism_yearly\",\n    \"nn5_daily_without_missing\",\n    \"m5\",\n    \"nn5_weekly\",\n    \"traffic\",\n    \"weather\",\n    \"australian_electricity_demand\",\n    \"car_parts_without_missing\",\n    \"cif_2016\",\n    \"covid_deaths\",\n    \"ercot\",\n    \"ett_small_15min\",\n    \"ett_small_1h\",\n    \"exchange_rate\",\n    \"fred_md\",\n    \"hospital\",\n]\n\n_MODEL_NAME = flags.DEFINE_string(\n    \"model_name\",\n    \"timegpt-1-long-horizon\",\n    \"Path to model, can also be set to timegpt-1\",\n)\n_SAVE_DIR = flags.DEFINE_string(\"save_dir\", \"./results\", \"Save directory\")\n\n\nQUANTILES = list(np.arange(1, 10) / 10.0)\n\n\ndef main():\n  results_list = []\n  run_id = np.random.randint(100000)\n  model_name = _MODEL_NAME.value\n  for dataset in dataset_names:\n    print(f\"Evaluating model {model_name} on dataset {dataset}\", flush=True)\n  
  exp = ExperimentHandler(dataset, quantiles=QUANTILES)\n    train_df = exp.train_df\n    horizon = exp.horizon\n    seasonality = exp.seasonality\n    freq = exp.freq\n    level = exp.level\n    fcsts_df, total_time, model_name = run_timegpt(\n        train_df=train_df,\n        horizon=exp.horizon,\n        model=model_name,\n        seasonality=seasonality,\n        freq=freq,\n        dataset=dataset,\n        level=level,\n    )\n    time_df = pd.DataFrame({\"time\": [total_time], \"model\": model_name})\n    fcsts_df = exp.fcst_from_level_to_quantiles(fcsts_df, model_name)\n    results = exp.evaluate_from_predictions(\n        models=[model_name], fcsts_df=fcsts_df, times_df=time_df\n    )\n    print(results, flush=True)\n    results_list.append(results)\n    results_full = pd.concat(results_list)\n    save_path = os.path.join(_SAVE_DIR.value, str(run_id))\n    print(f\"Saving results to {save_path}\", flush=True)\n    os.makedirs(save_path, exist_ok=True)\n    results_full.to_csv(f\"{save_path}/results.csv\")\n\n\nif __name__ == \"__main__\":\n  FLAGS = flags.FLAGS\n  FLAGS(sys.argv)\n  main()\n"
  },
  {
    "path": "v1/experiments/extended_benchmarks/run_timesfm.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Evaluation script for timesfm.\"\"\"\n\nimport os\nimport sys\nimport time\n\nfrom absl import flags\nimport numpy as np\nimport pandas as pd\nimport timesfm\n\nfrom .utils import ExperimentHandler\n\ndataset_names = [\n    \"m1_monthly\",\n    \"m1_quarterly\",\n    \"m1_yearly\",\n    \"m3_monthly\",\n    \"m3_other\",\n    \"m3_quarterly\",\n    \"m3_yearly\",\n    \"m4_quarterly\",\n    \"m4_yearly\",\n    \"tourism_monthly\",\n    \"tourism_quarterly\",\n    \"tourism_yearly\",\n    \"nn5_daily_without_missing\",\n    \"m5\",\n    \"nn5_weekly\",\n    \"traffic\",\n    \"weather\",\n    \"australian_electricity_demand\",\n    \"car_parts_without_missing\",\n    \"cif_2016\",\n    \"covid_deaths\",\n    \"ercot\",\n    \"ett_small_15min\",\n    \"ett_small_1h\",\n    \"exchange_rate\",\n    \"fred_md\",\n    \"hospital\",\n]\n\n\ncontext_dict_v2 = {}\n\ncontext_dict_v1 = {\n    \"cif_2016\": 32,\n    \"tourism_yearly\": 64,\n    \"covid_deaths\": 64,\n    \"tourism_quarterly\": 64,\n    \"tourism_monthly\": 64,\n    \"m1_monthly\": 64,\n    \"m1_quarterly\": 64,\n    \"m1_yearly\": 64,\n    \"m3_monthly\": 64,\n    \"m3_other\": 64,\n    \"m3_quarterly\": 64,\n    \"m3_yearly\": 64,\n    \"m4_quarterly\": 64,\n    \"m4_yearly\": 64,\n}\n\n_MODEL_PATH = flags.DEFINE_string(\"model_path\", \"google/timesfm-2.0-500m-jax\",\n                                  
\"Path to model\")\n_BATCH_SIZE = flags.DEFINE_integer(\"batch_size\", 64, \"Batch size\")\n_HORIZON = flags.DEFINE_integer(\"horizon\", 128, \"Horizon\")\n_BACKEND = flags.DEFINE_string(\"backend\", \"gpu\", \"Backend\")\n_NUM_JOBS = flags.DEFINE_integer(\"num_jobs\", 1, \"Number of jobs\")\n_SAVE_DIR = flags.DEFINE_string(\"save_dir\", \"./results\", \"Save directory\")\n\nQUANTILES = list(np.arange(1, 10) / 10.0)\n\n\ndef main():\n  results_list = []\n  model_path = _MODEL_PATH.value\n  num_layers = 20\n  max_context_len = 512\n  use_positional_embedding = True\n  context_dict = context_dict_v1\n  if \"2.0\" in model_path:\n    num_layers = 50\n    use_positional_embedding = False\n    max_context_len = 2048\n    context_dict = context_dict_v2\n\n  tfm = timesfm.TimesFm(\n      hparams=timesfm.TimesFmHparams(\n          backend=\"gpu\",\n          per_core_batch_size=32,\n          horizon_len=128,\n          num_layers=num_layers,\n          context_len=max_context_len,\n          use_positional_embedding=use_positional_embedding,\n      ),\n      checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=model_path),\n  )\n  run_id = np.random.randint(100000)\n  model_name = \"timesfm\"\n  for dataset in dataset_names:\n    print(f\"Evaluating model {model_name} on dataset {dataset}\", flush=True)\n    exp = ExperimentHandler(dataset, quantiles=QUANTILES)\n\n    if dataset in context_dict:\n      context_len = context_dict[dataset]\n    else:\n      context_len = max_context_len\n\n    train_df = exp.train_df\n    freq = exp.freq\n    init_time = time.time()\n    fcsts_df = tfm.forecast_on_df(\n        inputs=train_df,\n        freq=freq,\n        value_name=\"y\",\n        model_name=model_name,\n        forecast_context_len=context_len,\n        num_jobs=_NUM_JOBS.value,\n        normalize=True,\n    )\n    total_time = time.time() - init_time\n    time_df = pd.DataFrame({\"time\": [total_time], \"model\": model_name})\n    results = 
exp.evaluate_from_predictions(models=[model_name],\n                                            fcsts_df=fcsts_df,\n                                            times_df=time_df)\n    print(results, flush=True)\n    results_list.append(results)\n    results_full = pd.concat(results_list)\n    save_path = os.path.join(_SAVE_DIR.value, str(run_id))\n    print(f\"Saving results to {save_path}\", flush=True)\n    os.makedirs(save_path, exist_ok=True)\n    results_full.to_csv(f\"{save_path}/results.csv\")\n\n\nif __name__ == \"__main__\":\n  FLAGS = flags.FLAGS\n  FLAGS(sys.argv)\n  main()\n"
  },
  {
    "path": "v1/experiments/extended_benchmarks/utils.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Forked from https://github.com/Nixtla/nixtla/blob/main/experiments/amazon-chronos/src/utils.py.\"\"\"\n\nfrom functools import partial\nfrom itertools import repeat\nimport multiprocessing\nimport os\nfrom pathlib import Path\nfrom typing import List\n\nfrom gluonts.dataset import Dataset\nfrom gluonts.dataset.repository.datasets import (\n    dataset_names as gluonts_datasets,\n    get_dataset,\n)\nfrom gluonts.time_feature.seasonality import get_seasonality\nimport numpy as np\nimport pandas as pd\nfrom utilsforecast.evaluation import evaluate\nfrom utilsforecast.losses import mae, mase, smape\n\n\ndef parallel_transform(inp):\n  ts, last_n = inp[0], inp[1]\n  return ExperimentHandler._transform_gluonts_instance_to_df(ts, last_n=last_n)\n\n\ndef quantile_loss(\n    df: pd.DataFrame,\n    models: list,\n    q: float = 0.5,\n    id_col: str = \"unique_id\",\n    target_col: str = \"y\",\n) -> pd.DataFrame:\n  delta_y = df[models].sub(df[target_col], axis=0)\n  res = (\n      np.maximum(q * delta_y, (q - 1) * delta_y)\n      .groupby(df[id_col], observed=True)\n      .mean()\n  )\n  res.index.name = id_col\n  res = res.reset_index()\n  return res\n\n\nclass ExperimentHandler:\n\n  def __init__(\n      self,\n      dataset: str,\n      quantiles: List[float] = list(np.arange(1, 10) / 10.0),\n      results_dir: str = \"./results\",\n      models_dir: str = 
\"./models\",\n  ):\n    if dataset not in gluonts_datasets:\n      raise Exception(\n          f\"dataset {dataset} not found in gluonts \"\n          f\"available datasets: {', '.join(gluonts_datasets)}\"\n      )\n    self.dataset = dataset\n    self.quantiles = quantiles\n    self.level = self._transform_quantiles_to_levels(quantiles)\n    self.results_dir = results_dir\n    self.models_dir = models_dir\n    # defining datasets\n    self._maybe_download_m3_or_m5_file(self.dataset)\n    gluonts_dataset = get_dataset(self.dataset)\n    self.horizon = gluonts_dataset.metadata.prediction_length\n    if self.horizon is None:\n      raise Exception(\n          f\"horizon not found for dataset {self.dataset} \"\n          \"experiment cannot be run\"\n      )\n    self.freq = gluonts_dataset.metadata.freq\n    # get_seasonality() returns 1 for freq='D', override this to 7. This significantly improves the accuracy of\n    # statistical models on datasets like m5/nn5_daily. The models like AutoARIMA/AutoETS can still set\n    # seasonality=1 internally on datasets like weather by choosing non-seasonal models during model selection.\n    if self.freq == \"D\":\n      self.seasonality = 7\n    else:\n      self.seasonality = get_seasonality(self.freq)\n    self.gluonts_train_dataset = gluonts_dataset.train\n    self.gluonts_test_dataset = gluonts_dataset.test\n    self._create_dir_if_not_exists(self.results_dir)\n    try:\n      multiprocessing.set_start_method(\"spawn\")\n    except RuntimeError:\n      print(\"Multiprocessing context has already been set.\")\n\n  @staticmethod\n  def _maybe_download_m3_or_m5_file(dataset: str):\n    if dataset[:2] == \"m3\":\n      m3_file = Path.home() / \".gluonts\" / \"datasets\" / \"M3C.xls\"\n      if not m3_file.exists():\n        from datasetsforecast.m3 import M3\n        from datasetsforecast.utils import download_file\n\n        download_file(m3_file.parent, M3.source_url)\n    elif dataset == \"m5\":\n      m5_raw_dir = 
Path.home() / \".gluonts\" / \"m5\"\n      if not m5_raw_dir.exists():\n        import zipfile\n        from datasetsforecast.m5 import M5\n        from datasetsforecast.utils import download_file\n\n        download_file(m5_raw_dir, M5.source_url)\n        with zipfile.ZipFile(m5_raw_dir / \"m5.zip\", \"r\") as zip_ref:\n          zip_ref.extractall(m5_raw_dir)\n\n  @staticmethod\n  def _transform_quantiles_to_levels(quantiles: List[float]) -> List[int]:\n    level = [\n        int(100 - 200 * q) for q in quantiles if q < 0.5\n    ]  # in this case mean=mediain\n    level = sorted(list(set(level)))\n    return level\n\n  @staticmethod\n  def _create_dir_if_not_exists(directory: str):\n    Path(directory).mkdir(parents=True, exist_ok=True)\n\n  @staticmethod\n  def _transform_gluonts_instance_to_df(\n      ts: dict,\n      last_n: int | None = None,\n  ) -> pd.DataFrame:\n    start_period = ts[\"start\"]\n    start_ds, freq = start_period.to_timestamp(), start_period.freq\n    target = ts[\"target\"]\n    ds = pd.date_range(start=start_ds, freq=freq, periods=len(target))\n    if last_n is not None:\n      target = target[-last_n:]\n      ds = ds[-last_n:]\n    ts_df = pd.DataFrame({\"unique_id\": ts[\"item_id\"], \"ds\": ds, \"y\": target})\n    return ts_df\n\n  @staticmethod\n  def _transform_gluonts_dataset_to_df(\n      gluonts_dataset: Dataset,\n      last_n: int | None = None,\n  ) -> pd.DataFrame:\n    with multiprocessing.Pool(os.cpu_count()) as pool:  # Create a process pool\n      results = pool.map(\n          parallel_transform, zip(gluonts_dataset, repeat(last_n))\n      )\n    df = pd.concat(results)\n    df = df.reset_index(drop=True)\n    return df\n\n  @property\n  def train_df(self) -> pd.DataFrame:\n    train_df = self._transform_gluonts_dataset_to_df(self.gluonts_train_dataset)\n    return train_df\n\n  @property\n  def test_df(self) -> pd.DataFrame:\n    test_df = self._transform_gluonts_dataset_to_df(\n        self.gluonts_test_dataset,\n      
  last_n=self.horizon,\n    )\n    # Make sure that only the first backtest window is used for evaluation on `traffic` / `exchange_rate` datasets\n    return test_df.groupby(\"unique_id\", sort=False).head(self.horizon)\n\n  def save_dataframe(self, df: pd.DataFrame, file_name: str):\n    df.to_csv(f\"{self.results_dir}/{file_name}\", index=False)\n\n  def save_results(\n      self, fcst_df: pd.DataFrame, total_time: float, model_name: str\n  ):\n    self.save_dataframe(\n        fcst_df,\n        f\"{model_name}-{self.dataset}-fcst.csv\",\n    )\n    time_df = pd.DataFrame({\"time\": [total_time], \"model\": model_name})\n    self.save_dataframe(\n        time_df,\n        f\"{model_name}-{self.dataset}-time.csv\",\n    )\n\n  def fcst_from_level_to_quantiles(\n      self,\n      fcst_df: pd.DataFrame,\n      model_name: str,\n  ) -> pd.DataFrame:\n    fcst_df = fcst_df.copy()\n    cols = [\"unique_id\", \"ds\", model_name]\n    for q in self.quantiles:\n      if q == 0.5:\n        col = f\"{model_name}\"\n      else:\n        lv = int(100 - 200 * q)\n        hi_or_lo = \"lo\" if lv > 0 else \"hi\"\n        lv = abs(lv)\n        col = f\"{model_name}-{hi_or_lo}-{lv}\"\n      q_col = f\"{model_name}-q-{q}\"\n      fcst_df[q_col] = fcst_df[col].values\n      cols.append(q_col)\n    return fcst_df[cols]\n\n  def evaluate_models(self, models: List[str]) -> pd.DataFrame:\n    fcsts_df = []\n    times_df = []\n    for model in models:\n      fcst_method_df = pd.read_csv(\n          f\"{self.results_dir}/{model}-{self.dataset}-fcst.csv\"\n      ).set_index([\"unique_id\", \"ds\"])\n      fcsts_df.append(fcst_method_df)\n      time_method_df = pd.read_csv(\n          f\"{self.results_dir}/{model}-{self.dataset}-time.csv\"\n      )\n      times_df.append(time_method_df)\n    fcsts_df = pd.concat(fcsts_df, axis=1).reset_index()\n    fcsts_df[\"ds\"] = pd.to_datetime(fcsts_df[\"ds\"])\n    times_df = pd.concat(times_df)\n    return self.evaluate_from_predictions(\n        
models=models, fcsts_df=fcsts_df, times_df=times_df\n    )\n\n  def evaluate_from_predictions(\n      self, models: List[str], fcsts_df: pd.DataFrame, times_df: pd.DataFrame\n  ) -> pd.DataFrame:\n    test_df = self.test_df\n    train_df = self.train_df\n    test_df = test_df.merge(fcsts_df, how=\"left\")\n    assert test_df.isna().sum().sum() == 0, \"merge contains nas\"\n    # point evaluation\n    point_fcsts_cols = [\"unique_id\", \"ds\", \"y\"] + models\n    test_df[\"unique_id\"] = test_df[\"unique_id\"].astype(str)\n    train_df[\"unique_id\"] = train_df[\"unique_id\"].astype(str)\n    mase_seas = partial(mase, seasonality=self.seasonality)\n    eval_df = evaluate(\n        test_df[point_fcsts_cols],\n        train_df=train_df,\n        metrics=[smape, mase_seas, mae],\n    )\n    # probabilistic evaluation\n    eval_prob_df = []\n    for q in self.quantiles:\n      prob_cols = [f\"{model}-q-{q}\" for model in models]\n      eval_q_df = quantile_loss(test_df, models=prob_cols, q=q)\n      eval_q_df[prob_cols] = eval_q_df[prob_cols] * self.horizon\n      eval_q_df = eval_q_df.rename(columns=dict(zip(prob_cols, models)))\n      eval_q_df[\"metric\"] = f\"quantile-loss-{q}\"\n      eval_prob_df.append(eval_q_df)\n    eval_prob_df = pd.concat(eval_prob_df)\n    eval_prob_df = eval_prob_df.groupby(\"metric\").sum().reset_index()\n    total_y = test_df[\"y\"].sum()\n    eval_prob_df[models] = eval_prob_df[models] / total_y\n    eval_prob_df[\"metric\"] = \"scaled_crps\"\n    eval_df = pd.concat([eval_df, eval_prob_df]).reset_index(drop=True)\n    eval_df = eval_df.groupby(\"metric\").mean(numeric_only=True).reset_index()\n    eval_df = eval_df.melt(\n        id_vars=\"metric\", value_name=\"value\", var_name=\"model\"\n    )\n    times_df.insert(0, \"metric\", \"time\")\n    times_df = times_df.rename(columns={\"time\": \"value\"})\n    eval_df = pd.concat([eval_df, times_df])\n    eval_df.insert(0, \"dataset\", self.dataset)\n    eval_df = 
eval_df.sort_values([\"dataset\", \"metric\", \"model\"])\n    eval_df = eval_df.reset_index(drop=True)\n    return eval_df\n\n\nif __name__ == \"__main__\":\n  multiprocessing.set_start_method(\"spawn\")\n"
  },
  {
    "path": "v1/experiments/long_horizon_benchmarks/README.md",
    "content": "# Extended Benchmarks\n\nWe benchmark on the original test set for ETT datasets as per long horizon benchmark papers (see [here](https://openreview.net/forum?id=pCbC3aQB5W) for example.) In the original benchmark, rolling validation task on all test windows (with a stride of 1) is considered. While we can easily run our method on this task, the baselines can take a very long time to run. Therefore we present results on a modified task with stride between windows set to Horizon length i.e all disjoint horizons in the test period is considered.\n\nAll experiments were performed on a [g2-standard-32](https://cloud.google.com/compute/docs/gpus). We compare TimesFM with [Amazon-Chronos](https://github.com/amazon-science/chronos-forecasting).\n\n## Running TimesFM on the benchmark\n\nWe need to add the following packages for running these benchmarks. Follow the installation instructions till before `poetry lock`. Then,\n\n```\npoetry add git+https://github.com/awslabs/gluon-ts.git\npoetry add git+https://github.com/amazon-science/chronos-forecasting.git\npoetry lock\npoetry install --only pax\n```\nNote that for now only the pax version runs on this benchmark, because we had to remove the old tf dependency from the pytorch version. We will fix this issue soon.\n\nTo run the timesfm on the benchmark do:\n\n```\npoetry run python3 -m experiments.long_horizon_benchmarks.run_eval \\\n--model_path=google/timesfm-1.0-200m --backend=\"gpu\" \\\n--pred_len=96 --context_len=512 --dataset=etth1\n```\n\nIn the above, `<model_path>` should point to the checkpoint directory that can be downloaded from HuggingFace. \n\nFor running chronos on the same benchmark you can run the command,\n\n```\npoetry run python3 -m experiments.long_horizon_benchmarks.run_eval \\\n--model_path=amazon/chronos-t5-mini --backend=\"gpu\" \\\n--pred_len=96 --context_len=512 --dataset=etth1\n```\n\nYou can change the model size from \"mini\" to \"large\" as required. 
The datasets we benchmark on are etth1, etth2, ettm1 and ettm2.\n\n## Benchmark Results for TimesFM-1.0\n\n![Benchmark Results Table](./tfm_long_horizon.png)\n\nWe compare the performance on horizon lengths of 96, 192 and 336, while context length is held fixed at 512.\n\nWe can see that TimesFM performs the best in terms of both wape and smape. More importantly it is much faster than the other methods, in particular it is more than 1000x faster than Chronos (Large)."
  },
  {
    "path": "v1/experiments/long_horizon_benchmarks/run_eval.py",
    "content": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Eval pipeline.\"\"\"\n\nimport json\nimport os\nimport sys\nimport time\nfrom absl import flags\nimport chronos\nimport numpy as np\nimport pandas as pd\nimport timesfm\nfrom timesfm import data_loader\nimport torch\nimport tqdm\n\nFLAGS = flags.FLAGS\n\n_BATCH_SIZE = flags.DEFINE_integer(\"batch_size\", 64,\n                                   \"Batch size for the randomly sampled batch\")\n_DATASET = flags.DEFINE_string(\"dataset\", \"etth1\", \"The name of the dataset.\")\n\n_MODEL_PATH = flags.DEFINE_string(\"model_path\", \"google/timesfm-2.0-500m-jax\",\n                                  \"The name of the model.\")\n_DATETIME_COL = flags.DEFINE_string(\"datetime_col\", \"date\",\n                                    \"Column having datetime.\")\n_NUM_COV_COLS = flags.DEFINE_list(\"num_cov_cols\", None,\n                                  \"Column having numerical features.\")\n_CAT_COV_COLS = flags.DEFINE_list(\"cat_cov_cols\", None,\n                                  \"Column having categorical features.\")\n_TS_COLS = flags.DEFINE_list(\"ts_cols\", None, \"Columns of time-series features\")\n_NORMALIZE = flags.DEFINE_bool(\"normalize\", True,\n                               \"normalize data for eval or not\")\n_CONTEXT_LEN = flags.DEFINE_integer(\"context_len\", 2048,\n                                    \"Length of the context 
window\")\n_PRED_LEN = flags.DEFINE_integer(\"pred_len\", 96, \"prediction length.\")\n_BACKEND = flags.DEFINE_string(\"backend\", \"gpu\", \"backend to use\")\n_RESULTS_DIR = flags.DEFINE_string(\"results_dir\", \"./results/long_horizon\",\n                                   \"results directory\")\n\nDATA_DICT = {\n    \"ettm2\": {\n        \"boundaries\": [34560, 46080, 57600],\n        \"data_path\": \"./datasets/ETT-small/ETTm2.csv\",\n        \"freq\": \"15min\",\n    },\n    \"ettm1\": {\n        \"boundaries\": [34560, 46080, 57600],\n        \"data_path\": \"./datasets/ETT-small/ETTm1.csv\",\n        \"freq\": \"15min\",\n    },\n    \"etth2\": {\n        \"boundaries\": [8640, 11520, 14400],\n        \"data_path\": \"./datasets/ETT-small/ETTh2.csv\",\n        \"freq\": \"H\",\n    },\n    \"etth1\": {\n        \"boundaries\": [8640, 11520, 14400],\n        \"data_path\": \"./datasets/ETT-small/ETTh1.csv\",\n        \"freq\": \"H\",\n    },\n    \"elec\": {\n        \"boundaries\": [18413, 21044, 26304],\n        \"data_path\": \"./datasets/electricity/electricity.csv\",\n        \"freq\": \"H\",\n    },\n    \"traffic\": {\n        \"boundaries\": [12280, 14036, 17544],\n        \"data_path\": \"./datasets/traffic/traffic.csv\",\n        \"freq\": \"H\",\n    },\n    \"weather\": {\n        \"boundaries\": [36887, 42157, 52696],\n        \"data_path\": \"./datasets/weather/weather.csv\",\n        \"freq\": \"10min\",\n    },\n}\n\nQUANTILES = list(np.arange(1, 10) / 10.0)\nEPS = 1e-7\n\n\ndef get_forecasts(model_path, model, past, freq, pred_len):\n  \"\"\"Get forecasts.\"\"\"\n  if model_path.startswith(\"amazon\"):\n    out = model.predict(\n        torch.tensor(past),\n        prediction_length=pred_len,\n        limit_prediction_length=False,\n    )\n    out = out.numpy()\n    out = np.median(out, axis=1)\n  else:\n    lfreq = [freq] * past.shape[0]\n    _, out = model.forecast(list(past), lfreq)\n    out = out[:, :, 5]\n  return out\n\n\ndef 
_mse(y_pred, y_true):\n  \"\"\"mse loss.\"\"\"\n  return np.square(y_pred - y_true)\n\n\ndef _mae(y_pred, y_true):\n  \"\"\"mae loss.\"\"\"\n  return np.abs(y_pred - y_true)\n\n\ndef _smape(y_pred, y_true):\n  \"\"\"_smape loss.\"\"\"\n  abs_diff = np.abs(y_pred - y_true)\n  abs_val = (np.abs(y_true) + np.abs(y_pred)) / 2\n  abs_val = np.where(abs_val > EPS, abs_val, 1.0)\n  abs_diff = np.where(abs_val > EPS, abs_diff, 0.0)\n  return abs_diff / abs_val\n\n\ndef eval():\n  \"\"\"Eval pipeline.\"\"\"\n  dataset = _DATASET.value\n  data_path = DATA_DICT[dataset][\"data_path\"]\n  freq = DATA_DICT[dataset][\"freq\"]\n  int_freq = timesfm.freq_map(freq)\n  boundaries = DATA_DICT[dataset][\"boundaries\"]\n\n  data_df = pd.read_csv(open(data_path, \"r\"))\n\n  if _TS_COLS.value is not None:\n    ts_cols = DATA_DICT[dataset][\"ts_cols\"]\n    num_cov_cols = DATA_DICT[dataset][\"num_cov_cols\"]\n    cat_cov_cols = DATA_DICT[dataset][\"cat_cov_cols\"]\n  else:\n    ts_cols = [col for col in data_df.columns if col != _DATETIME_COL.value]\n    num_cov_cols = None\n    cat_cov_cols = None\n  batch_size = min(_BATCH_SIZE.value, len(ts_cols))\n  dtl = data_loader.TimeSeriesdata(\n      data_path=data_path,\n      datetime_col=_DATETIME_COL.value,\n      num_cov_cols=num_cov_cols,\n      cat_cov_cols=cat_cov_cols,\n      ts_cols=np.array(ts_cols),\n      train_range=[0, boundaries[0]],\n      val_range=[boundaries[0], boundaries[1]],\n      test_range=[boundaries[1], boundaries[2]],\n      hist_len=_CONTEXT_LEN.value,\n      pred_len=_PRED_LEN.value,\n      batch_size=batch_size,\n      freq=freq,\n      normalize=_NORMALIZE.value,\n      epoch_len=None,\n      holiday=False,\n      permute=False,\n  )\n  eval_itr = dtl.tf_dataset(mode=\"test\",\n                            shift=_PRED_LEN.value).as_numpy_iterator()\n  model_path = _MODEL_PATH.value\n  if model_path.startswith(\"amazon\"):\n    model = chronos.ChronosPipeline.from_pretrained(\n        model_path,\n        
device_map=\"auto\",\n        torch_dtype=torch.bfloat16,\n    )\n  else:\n    model = timesfm.TimesFm(\n        hparams=timesfm.TimesFmHparams(\n            backend=\"gpu\",\n            per_core_batch_size=32,\n            horizon_len=128,\n            num_layers=50,\n            context_len=_CONTEXT_LEN.value,\n            use_positional_embedding=False,\n        ),\n        checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=model_path),\n    )\n  smape_run_losses = []\n  mse_run_losses = []\n  mae_run_losses = []\n\n  num_elements = 0\n  abs_sum = 0\n  start_time = time.time()\n\n  for batch in tqdm.tqdm(eval_itr):\n    past = batch[0]\n    actuals = batch[3]\n    forecasts = get_forecasts(model_path, model, past, int_freq,\n                              _PRED_LEN.value)\n    forecasts = forecasts[:, 0:actuals.shape[1]]\n    mae_run_losses.append(_mae(forecasts, actuals).sum())\n    mse_run_losses.append(_mse(forecasts, actuals).sum())\n    smape_run_losses.append(_smape(forecasts, actuals).sum())\n    num_elements += actuals.shape[0] * actuals.shape[1]\n    abs_sum += np.abs(actuals).sum()\n\n  mse_val = np.sum(mse_run_losses) / num_elements\n\n  result_dict = {\n      \"mse\": mse_val,\n      \"smape\": np.sum(smape_run_losses) / num_elements,\n      \"mae\": np.sum(mae_run_losses) / num_elements,\n      \"wape\": np.sum(mae_run_losses) / abs_sum,\n      \"nrmse\": np.sqrt(mse_val) / (abs_sum / num_elements),\n      \"num_elements\": num_elements,\n      \"abs_sum\": abs_sum,\n      \"total_time\": time.time() - start_time,\n      \"model_path\": model_path,\n      \"dataset\": dataset,\n      \"freq\": freq,\n      \"pred_len\": _PRED_LEN.value,\n      \"context_len\": _CONTEXT_LEN.value,\n  }\n  run_id = np.random.randint(10000)\n  save_path = os.path.join(_RESULTS_DIR.value, str(run_id))\n  print(f\"Saving results to {save_path}\", flush=True)\n  os.makedirs(save_path, exist_ok=True)\n  with open(os.path.join(save_path, \"results.json\"), \"w\") as 
f:\n    json.dump(result_dict, f)\n  print(result_dict, flush=True)\n\n\nif __name__ == \"__main__\":\n  FLAGS = flags.FLAGS\n  FLAGS(sys.argv)\n  eval()\n"
  },
  {
    "path": "v1/notebooks/covariates.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# TimesFM with Covariates\\n\",\n    \"\\n\",\n    \"This toturial notebook demonstrates how to utilize exogenous covariates with TimesFM when making forecasts. Before running this notebook, make sure:\\n\",\n    \"\\n\",\n    \"- You've read through the README of TimesFM.\\n\",\n    \"- A local kernel with Python 3.10 is up and running, for the jax version.\\n\",\n    \"- Install the JAX version following the installation instructions.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Setup the environment and install TimesFM.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Load the checkpoint\\n\",\n    \"\\n\",\n    \"**Notice:** Please set up the backend as per your machine (\\\"cpu\\\", \\\"gpu\\\" or \\\"tpu\\\"). This notebook will run by default on GPU.\\n\",\n    \"\\n\",\n    \"We load the 2.0-500m model checkpoint from HuggingFace.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import timesfm\\n\",\n    \"timesfm_backend = \\\"gpu\\\"  # @param\\n\",\n    \"\\n\",\n    \"model = timesfm.TimesFm(\\n\",\n    \"      hparams=timesfm.TimesFmHparams(\\n\",\n    \"          backend=timesfm_backend,\\n\",\n    \"          per_core_batch_size=32,\\n\",\n    \"          horizon_len=128,\\n\",\n    \"          num_layers=50,\\n\",\n    \"          use_positional_embedding=False,\\n\",\n    \"          context_len=2048,\\n\",\n    \"      ),\\n\",\n    \"      checkpoint=timesfm.TimesFmCheckpoint(\\n\",\n    \"          huggingface_repo_id=\\\"google/timesfm-2.0-500m-jax\\\"),\\n\",\n    \"  )\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Covariates\\n\",\n    \"\\n\",\n    \"Let's take a toy 
example of forecasting sales for a grocery store: \\n\",\n    \"\\n\",\n    \"**Task:** Given the observed daily sales of this week (7 days), forecast the daily sales of next week (7 days).\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"Product: ice cream\\n\",\n    \"Daily_sales: [30, 30, 4, 5, 7, 8, 10]\\n\",\n    \"Category: food\\n\",\n    \"Base_price: 1.99\\n\",\n    \"Weekday: [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]\\n\",\n    \"Has_promotion: [Yes, Yes, No, No, No, Yes, Yes, No, No, No, No, No, No, No]\\n\",\n    \"Daily_temperature: [31.0, 24.3, 19.4, 26.2, 24.6, 30.0, 31.1, 32.4, 30.9, 26.0, 25.0, 27.8, 29.5, 31.2]\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"Product: sunscreen\\n\",\n    \"Daily_sales: [5, 7, 12, 13, 5, 6, 10]\\n\",\n    \"Category: skin product\\n\",\n    \"Base_price: 29.99\\n\",\n    \"Weekday: [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]\\n\",\n    \"Has_promotion: [No, No, Yes, Yes, No, No, No, Yes, Yes, Yes, Yes, Yes, Yes, Yes]\\n\",\n    \"Daily_temperature: [31.0, 24.3, 19.4, 26.2, 24.6, 30.0, 31.1, 32.4, 30.9, 26.0, 25.0, 27.8, 29.5, 31.2]\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"In this example, besides the `Daily_sales`, we also have covariates `Category`, `Base_price`, `Weekday`, `Has_promotion`, `Daily_temperature`. Let's introduce some concepts:\\n\",\n    \"\\n\",\n    \"**Static covariates** are covariates for each time series. 
\\n\",\n    \"- In our example, `Category` is a **static categorical covariate**, \\n\",\n    \"- `Base_price` is a **static numerical covariate**.\\n\",\n    \"\\n\",\n    \"**Dynamic covariates** are covariates for each timestamp.\\n\",\n    \"- Date / time related features can usually be treated as dynamic covariates.\\n\",\n    \"- In our example, `Weekday` and `Has_promotion` are **dynamic categorical covariates**.\\n\",\n    \"- `Daily_temperature` is a **dynamic numerical covariate**.\\n\",\n    \"\\n\",\n    \"**Notice:** Here we make it mandatory that the dynamic covariates need to cover both the forecasting context and horizon. For example, all dynamic covariates in the example have 14 values: the first 7 correspond to the observed 7 days, and the last 7 correspond to the next 7 days.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# TimesFM with Covariates\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"The strategy we take here is to treat covariates as batched in-context exogenous regressors (XReg) and fit linear models on them outside of TimesFM. The final forecast will be the sum of the TimesFM forecast and the linear model forecast.\\n\",\n    \"\\n\",\n    \" In simple words, we consider these two options.\\n\",\n    \"\\n\",\n    \"**Option 1:** Get the TimesFM forecast, and fit the linear model regressing the residuals on the covariates (\\\"timesfm + xreg\\\").\\n\",\n    \"\\n\",\n    \"**Option 2:** Fit the linear model of the time series itself on the covariates, then forecast the residuals using TimesFM  (\\\"xreg + timesfm\\\").\\n\",\n    \"\\n\",\n    \"Let's take a look at the example of Electricity Price Forecasting (EPF). 
\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import pandas as pd\\n\",\n    \"import numpy as np\\n\",\n    \"from collections import defaultdict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = pd.read_csv('https://datasets-nixtla.s3.amazonaws.com/EPF_FR_BE.csv')\\n\",\n    \"df['ds'] = pd.to_datetime(df['ds'])\\n\",\n    \"df\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"This dataset has a few covariates beside the hourly target `y`:\\n\",\n    \"\\n\",\n    \"- `unique_id`: a static categorical covariate indicating the country.\\n\",\n    \"- `gen_forecast`: a dynamic numerical covariate indicating the estimated electricity to be generated.\\n\",\n    \"- `system_load`: the observed system load. Notice that this **CANNOT** be considered as a dynamic numerical covariate because we cannot know its values over the forecasting horizon in advance.\\n\",\n    \"- `weekday`: a dynamic categorical covariate.\\\\\\n\",\n    \"\\n\",\n    \"Let's now make some forecasting tasks for TimesFM based on this dataset. 
For simplicity we create forecast contexts of 120 time points (hours) and forecast horizons of 24 time points.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Data pipelining\\n\",\n    \"def get_batched_data_fn(\\n\",\n    \"    batch_size: int = 128, \\n\",\n    \"    context_len: int = 120, \\n\",\n    \"    horizon_len: int = 24,\\n\",\n    \"):\\n\",\n    \"  examples = defaultdict(list)\\n\",\n    \"\\n\",\n    \"  num_examples = 0\\n\",\n    \"  for country in (\\\"FR\\\", \\\"BE\\\"):\\n\",\n    \"    sub_df = df[df[\\\"unique_id\\\"] == country]\\n\",\n    \"    for start in range(0, len(sub_df) - (context_len + horizon_len), horizon_len):\\n\",\n    \"      num_examples += 1\\n\",\n    \"      examples[\\\"country\\\"].append(country)\\n\",\n    \"      examples[\\\"inputs\\\"].append(sub_df[\\\"y\\\"][start:(context_end := start + context_len)].tolist())\\n\",\n    \"      examples[\\\"gen_forecast\\\"].append(sub_df[\\\"gen_forecast\\\"][start:context_end + horizon_len].tolist())\\n\",\n    \"      examples[\\\"week_day\\\"].append(sub_df[\\\"week_day\\\"][start:context_end + horizon_len].tolist())\\n\",\n    \"      examples[\\\"outputs\\\"].append(sub_df[\\\"y\\\"][context_end:(context_end + horizon_len)].tolist())\\n\",\n    \"  \\n\",\n    \"  def data_fn():\\n\",\n    \"    for i in range(1 + (num_examples - 1) // batch_size):\\n\",\n    \"      yield {k: v[(i * batch_size) : ((i + 1) * batch_size)] for k, v in examples.items()}\\n\",\n    \"  \\n\",\n    \"  return data_fn\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Define metrics\\n\",\n    \"def mse(y_pred, y_true):\\n\",\n    \"  y_pred = np.array(y_pred)\\n\",\n    \"  y_true = np.array(y_true)\\n\",\n    \"  return np.mean(np.square(y_pred - y_true), axis=1, keepdims=True)\\n\",\n    
\"\\n\",\n    \"def mae(y_pred, y_true):\\n\",\n    \"  y_pred = np.array(y_pred)\\n\",\n    \"  y_true = np.array(y_true)\\n\",\n    \"  return np.mean(np.abs(y_pred - y_true), axis=1, keepdims=True)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Now let's try `model.forecast_with_covariates`. \\n\",\n    \"\\n\",\n    \"In particular, the output is a tuple whose first element is the new forecast.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import time\\n\",\n    \"\\n\",\n    \"# Benchmark\\n\",\n    \"batch_size = 128\\n\",\n    \"context_len = 120\\n\",\n    \"horizon_len = 24\\n\",\n    \"input_data = get_batched_data_fn(batch_size = 128)\\n\",\n    \"metrics = defaultdict(list)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"for i, example in enumerate(input_data()):\\n\",\n    \"  raw_forecast, _ = model.forecast(\\n\",\n    \"      inputs=example[\\\"inputs\\\"], freq=[0] * len(example[\\\"inputs\\\"])\\n\",\n    \"  )\\n\",\n    \"  start_time = time.time()\\n\",\n    \"  # Forecast with covariates\\n\",\n    \"  # Output: new forecast, forecast by the xreg\\n\",\n    \"  cov_forecast, ols_forecast = model.forecast_with_covariates(  \\n\",\n    \"      inputs=example[\\\"inputs\\\"],\\n\",\n    \"      dynamic_numerical_covariates={\\n\",\n    \"          \\\"gen_forecast\\\": example[\\\"gen_forecast\\\"],\\n\",\n    \"      },\\n\",\n    \"      dynamic_categorical_covariates={\\n\",\n    \"          \\\"week_day\\\": example[\\\"week_day\\\"],\\n\",\n    \"      },\\n\",\n    \"      static_numerical_covariates={},\\n\",\n    \"      static_categorical_covariates={\\n\",\n    \"          \\\"country\\\": example[\\\"country\\\"]\\n\",\n    \"      },\\n\",\n    \"      freq=[0] * len(example[\\\"inputs\\\"]),\\n\",\n    \"      xreg_mode=\\\"xreg + timesfm\\\",              # default\\n\",\n    \"      
ridge=0.0,\\n\",\n    \"      force_on_cpu=False,\\n\",\n    \"      normalize_xreg_target_per_input=True,    # default\\n\",\n    \"  )\\n\",\n    \"  print(\\n\",\n    \"      f\\\"\\\\rFinished batch {i} linear in {time.time() - start_time} seconds\\\",\\n\",\n    \"      end=\\\"\\\",\\n\",\n    \"  )\\n\",\n    \"  metrics[\\\"eval_mae_timesfm\\\"].extend(\\n\",\n    \"      mae(raw_forecast[:, :horizon_len], example[\\\"outputs\\\"])\\n\",\n    \"  )\\n\",\n    \"  metrics[\\\"eval_mae_xreg_timesfm\\\"].extend(mae(cov_forecast, example[\\\"outputs\\\"]))\\n\",\n    \"  metrics[\\\"eval_mae_xreg\\\"].extend(mae(ols_forecast, example[\\\"outputs\\\"]))\\n\",\n    \"  metrics[\\\"eval_mse_timesfm\\\"].extend(\\n\",\n    \"      mse(raw_forecast[:, :horizon_len], example[\\\"outputs\\\"])\\n\",\n    \"  )\\n\",\n    \"  metrics[\\\"eval_mse_xreg_timesfm\\\"].extend(mse(cov_forecast, example[\\\"outputs\\\"]))\\n\",\n    \"  metrics[\\\"eval_mse_xreg\\\"].extend(mse(ols_forecast, example[\\\"outputs\\\"]))\\n\",\n    \"\\n\",\n    \"print()\\n\",\n    \"\\n\",\n    \"for k, v in metrics.items():\\n\",\n    \"  print(f\\\"{k}: {np.mean(v)}\\\")\\n\",\n    \"\\n\",\n    \"# My output:\\n\",\n    \"# eval_mae_timesfm: 6.762283045916956\\n\",\n    \"# eval_mae_xreg_timesfm: 5.39219617611074\\n\",\n    \"# eval_mae_xreg: 37.15275842572484\\n\",\n    \"# eval_mse_timesfm: 166.7771466306823\\n\",\n    \"# eval_mse_xreg_timesfm: 120.64757721021306\\n\",\n    \"# eval_mse_xreg: 1672.2116821201796\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"You should see results close to \\n\",\n    \"```\\n\",\n    \"eval_mae_timesfm: 6.729583250571446\\n\",\n    \"eval_mae_xreg_timesfm: 5.3375301110158\\n\",\n    \"eval_mae_xreg: 37.152760709266\\n\",\n    \"eval_mse_timesfm: 162.3132151851567\\n\",\n    \"eval_mse_xreg_timesfm: 120.9900627409689\\n\",\n    \"eval_mse_xreg: 1672.208769045399\\n\",\n    \"```\\n\",\n    \"\\n\",\n    
\"With the covariates, the TimesFM forecast Mean Absolute Error improves from 6.73 to 5.34, and Mean Squared Error from 162.31 to 120.99. The results of purely fitting the linear model are also provided for reference.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Formatting Your Request\\n\",\n    \"\\n\",\n    \"It is quite crucial to get the covariates properly formatted so that we can call this `model.forecast_with_covariates`. Please see its docstring for details. Here let's also grab a batch from a toy data input pipeline for quick explanations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"toy_input_pipeline = get_batched_data_fn(batch_size=2, context_len=5, horizon_len=2)\\n\",\n    \"print(next(toy_input_pipeline()))\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"You should see something similar to this\\n\",\n    \"```\\n\",\n    \"{\\n\",\n    \"    'country': ['FR', 'FR'], \\n\",\n    \"    'inputs': [[53.48, 51.93, 48.76, 42.27, 38.41], [48.76, 42.27, 38.41, 35.72, 32.66]], \\n\",\n    \"    'gen_forecast': [[76905.0, 75492.0, 74394.0, 72639.0, 69347.0, 67960.0, 67564.0], [74394.0, 72639.0, 69347.0, 67960.0, 67564.0, 67277.0, 67019.0]], \\n\",\n    \"    'week_day': [[3, 3, 3, 3, 3, 3, 3], [3, 3, 3, 3, 3, 3, 3]], \\n\",\n    \"    'outputs': [[35.72, 32.66], [32.83, 30.06]],\\n\",\n    \"}\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"Notice:\\n\",\n    \"- We have two examples in this batch.\\n\",\n    \"- For each example we support different context lengths and horizon lengths just as `model.forecast` does, although this is not demonstrated in this dataset.\\n\",\n    \"- If dynamic covariates are present, the horizon lengths will be inferred from them, e.g. how many values are provided in addition to the ones corresponding to the inputs. 
Make sure all your dynamic covariates have the same length per example.\\n\",\n    \"- The static covariates are one per example.\\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## More Applications\\n\",\n    \"\\n\",\n    \"### Past Dynamic Covariates\\n\",\n    \"\\n\",\n    \"Past dynamic covariates are covariates that are only available for the context. For instance in our example `system_load` is a past dynamic covariate. Time series models generally can handle this; however, it is something the batched in-context regression cannot address, because these regressors are not available in the future. If you do have those covariates and consider them very meaningful, there are two hacky options to try immediately:\\n\",\n    \"\\n\",\n    \"1. Shift and repeat these past dynamic covariates to use their delayed version. For example, if you think the `system_load` for this week is meaningful for forecasting next week, you can create a `delay_7_system_load` by shifting 7 timestamps and use this as one dynamic numerical covariate for TimesFM.\\n\",\n    \"2. 
Bootstrap, that is to run TimesFM once to forecast these past dynamic covariates into the horizon, then call TimesFM again using these forecasts as the future part for these dynamic covariates.\\n\",\n    \"\\n\",\n    \"### Multivariate Time Series\\n\",\n    \"\\n\",\n    \"For multivariate time series, if we need a univariate forecast, we can try treating the main time series as the target and use the rest as the dynamic covariates.\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"chronos-v2\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "v1/notebooks/finetuning.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Importing relevant packages for finetuning\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'\\n\",\n    \"os.environ['JAX_PMAP_USE_TENSORSTORE'] = 'false'\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import timesfm\\n\",\n    \"import gc\\n\",\n    \"import numpy as np\\n\",\n    \"import pandas as pd\\n\",\n    \"from timesfm import patched_decoder\\n\",\n    \"from timesfm import data_loader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from tqdm import tqdm\\n\",\n    \"import dataclasses\\n\",\n    \"import IPython\\n\",\n    \"import IPython.display\\n\",\n    \"import matplotlib as mpl\\n\",\n    \"import matplotlib.pyplot as plt\\n\",\n    \"mpl.rcParams['figure.figsize'] = (8, 6)\\n\",\n    \"mpl.rcParams['axes.grid'] = False\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Loading TimesFM pretrained checkpoint\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"timesfm_backend = \\\"gpu\\\"  # @param\\n\",\n    \"\\n\",\n    \"tfm = timesfm.TimesFm(\\n\",\n    \"      hparams=timesfm.TimesFmHparams(\\n\",\n    \"          backend=timesfm_backend,\\n\",\n    \"          per_core_batch_size=32,\\n\",\n    \"          horizon_len=128,\\n\",\n    \"          num_layers=50,\\n\",\n    \"          # Se this to True for v1.0 checkpoints\\n\",\n    \"          use_positional_embedding=False,\\n\",\n    \"          # Note that we could 
set this to as high as 2048 but keeping it 512 here so that\\n\",\n    \"          # both v1.0 and 2.0 checkpoints work\\n\",\n    \"          context_len=512,\\n\",\n    \"      ),\\n\",\n    \"      checkpoint=timesfm.TimesFmCheckpoint(\\n\",\n    \"          huggingface_repo_id=\\\"google/timesfm-2.0-500m-jax\\\"),\\n\",\n    \"  )\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Evaluating pretrained checkpoint on ETT datasets\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"DATA_DICT = {\\n\",\n    \"    \\\"ettm2\\\": {\\n\",\n    \"        \\\"boundaries\\\": [34560, 46080, 57600],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/ETT-small/ETTm2.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"15min\\\",\\n\",\n    \"    },\\n\",\n    \"    \\\"ettm1\\\": {\\n\",\n    \"        \\\"boundaries\\\": [34560, 46080, 57600],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/ETT-small/ETTm1.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"15min\\\",\\n\",\n    \"    },\\n\",\n    \"    \\\"etth2\\\": {\\n\",\n    \"        \\\"boundaries\\\": [8640, 11520, 14400],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/ETT-small/ETTh2.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"H\\\",\\n\",\n    \"    },\\n\",\n    \"    \\\"etth1\\\": {\\n\",\n    \"        \\\"boundaries\\\": [8640, 11520, 14400],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/ETT-small/ETTh1.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"H\\\",\\n\",\n    \"    },\\n\",\n    \"    \\\"elec\\\": {\\n\",\n    \"        \\\"boundaries\\\": [18413, 21044, 26304],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/electricity/electricity.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"H\\\",\\n\",\n    \"    },\\n\",\n    \"    \\\"traffic\\\": {\\n\",\n    \"        \\\"boundaries\\\": [12280, 14036, 17544],\\n\",\n    \"   
     \\\"data_path\\\": \\\"../datasets/traffic/traffic.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"H\\\",\\n\",\n    \"    },\\n\",\n    \"    \\\"weather\\\": {\\n\",\n    \"        \\\"boundaries\\\": [36887, 42157, 52696],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/weather/weather.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"10min\\\",\\n\",\n    \"    },\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataset = \\\"ettm1\\\"\\n\",\n    \"data_path = DATA_DICT[dataset][\\\"data_path\\\"]\\n\",\n    \"freq = DATA_DICT[dataset][\\\"freq\\\"]\\n\",\n    \"int_freq = timesfm.freq_map(freq)\\n\",\n    \"boundaries = DATA_DICT[dataset][\\\"boundaries\\\"]\\n\",\n    \"\\n\",\n    \"data_df = pd.read_csv(open(data_path, \\\"r\\\"))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"ts_cols = [col for col in data_df.columns if col != \\\"date\\\"]\\n\",\n    \"num_cov_cols = None\\n\",\n    \"cat_cov_cols = None\\n\",\n    \"\\n\",\n    \"context_len = 512\\n\",\n    \"pred_len = 96\\n\",\n    \"\\n\",\n    \"num_ts = len(ts_cols)\\n\",\n    \"batch_size = 8\\n\",\n    \"\\n\",\n    \"dtl = data_loader.TimeSeriesdata(\\n\",\n    \"      data_path=data_path,\\n\",\n    \"      datetime_col=\\\"date\\\",\\n\",\n    \"      num_cov_cols=num_cov_cols,\\n\",\n    \"      cat_cov_cols=cat_cov_cols,\\n\",\n    \"      ts_cols=np.array(ts_cols),\\n\",\n    \"      train_range=[0, boundaries[0]],\\n\",\n    \"      val_range=[boundaries[0], boundaries[1]],\\n\",\n    \"      test_range=[boundaries[1], boundaries[2]],\\n\",\n    \"      hist_len=context_len,\\n\",\n    \"      pred_len=pred_len,\\n\",\n    \"      batch_size=num_ts,\\n\",\n    \"      freq=freq,\\n\",\n    \"      normalize=True,\\n\",\n    \"      epoch_len=None,\\n\",\n    \"      holiday=False,\\n\",\n    \"      permute=True,\\n\",\n    \"  )\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_batches = dtl.tf_dataset(mode=\\\"train\\\", shift=1).batch(batch_size)\\n\",\n    \"val_batches = dtl.tf_dataset(mode=\\\"val\\\", shift=pred_len)\\n\",\n    \"test_batches = dtl.tf_dataset(mode=\\\"test\\\", shift=pred_len)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for tbatch in tqdm(train_batches.as_numpy_iterator()):\\n\",\n    \"    break\\n\",\n    \"print(tbatch[0].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### MAE on the test split for the pretrained TimesFM model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"mae_losses = []\\n\",\n    \"for batch in tqdm(test_batches.as_numpy_iterator()):\\n\",\n    \"    past = batch[0]\\n\",\n    \"    actuals = batch[3]\\n\",\n    \"    forecasts, _ = tfm.forecast(list(past), [0] * past.shape[0], normalize=True)\\n\",\n    \"    forecasts = forecasts[:, 0 : actuals.shape[1]]\\n\",\n    \"    mae_losses.append(np.abs(forecasts - actuals).mean())\\n\",\n    \"\\n\",\n    \"print(f\\\"MAE: {np.mean(mae_losses)}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Finetuning the model on the ETT dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import jax\\n\",\n    \"from jax import numpy as jnp\\n\",\n    \"from praxis import pax_fiddle\\n\",\n    \"from praxis import py_utils\\n\",\n    \"from praxis import pytypes\\n\",\n    \"from praxis import base_model\\n\",\n    \"from praxis import optimizers\\n\",\n    \"from praxis import schedules\\n\",\n    \"from praxis import 
base_hyperparams\\n\",\n    \"from praxis import base_layer\\n\",\n    \"from paxml import tasks_lib\\n\",\n    \"from paxml import trainer_lib\\n\",\n    \"from paxml import checkpoints\\n\",\n    \"from paxml import learners\\n\",\n    \"from paxml import partitioning\\n\",\n    \"from paxml import checkpoint_types\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# PAX shortcuts\\n\",\n    \"NestedMap = py_utils.NestedMap\\n\",\n    \"WeightInit = base_layer.WeightInit\\n\",\n    \"WeightHParams = base_layer.WeightHParams\\n\",\n    \"InstantiableParams = py_utils.InstantiableParams\\n\",\n    \"JTensor = pytypes.JTensor\\n\",\n    \"NpTensor = pytypes.NpTensor\\n\",\n    \"WeightedScalars = pytypes.WeightedScalars\\n\",\n    \"instantiate = base_hyperparams.instantiate\\n\",\n    \"LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]\\n\",\n    \"AuxLossStruct = base_layer.AuxLossStruct\\n\",\n    \"\\n\",\n    \"AUX_LOSS = base_layer.AUX_LOSS\\n\",\n    \"template_field = base_layer.template_field\\n\",\n    \"\\n\",\n    \"# Standard prng key names\\n\",\n    \"PARAMS = base_layer.PARAMS\\n\",\n    \"RANDOM = base_layer.RANDOM\\n\",\n    \"\\n\",\n    \"key = jax.random.PRNGKey(seed=1234)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model = pax_fiddle.Config(\\n\",\n    \"    patched_decoder.PatchedDecoderFinetuneModel,\\n\",\n    \"    name='patched_decoder_finetune',\\n\",\n    \"    core_layer_tpl=tfm.model_p,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### We will hold the transformer layers fixed while finetuning, while training all other components.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"@pax_fiddle.auto_config\\n\",\n    \"def build_learner() -> learners.Learner:\\n\",\n    \"  return pax_fiddle.Config(\\n\",\n    \"      learners.Learner,\\n\",\n    \"      name='learner',\\n\",\n    \"      loss_name='avg_qloss',\\n\",\n    \"      optimizer=optimizers.Adam(\\n\",\n    \"          epsilon=1e-7,\\n\",\n    \"          clip_threshold=1e2,\\n\",\n    \"          learning_rate=1e-2,\\n\",\n    \"          lr_schedule=pax_fiddle.Config(\\n\",\n    \"              schedules.Cosine,\\n\",\n    \"              initial_value=1e-3,\\n\",\n    \"              final_value=1e-4,\\n\",\n    \"              total_steps=40000,\\n\",\n    \"          ),\\n\",\n    \"          ema_decay=0.9999,\\n\",\n    \"      ),\\n\",\n    \"      # Linear probing i.e we hold the transformer layers fixed.\\n\",\n    \"      bprop_variable_exclusion=['.*/stacked_transformer_layer/.*'],\\n\",\n    \"  )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"task_p = tasks_lib.SingleTask(\\n\",\n    \"    name='ts-learn',\\n\",\n    \"    model=model,\\n\",\n    \"    train=tasks_lib.SingleTask.Train(\\n\",\n    \"        learner=build_learner(),\\n\",\n    \"    ),\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"task_p.model.ici_mesh_shape = [1, 1, 1]\\n\",\n    \"task_p.model.mesh_axis_names = ['replica', 'data', 'mdl']\\n\",\n    \"\\n\",\n    \"DEVICES = np.array(jax.devices()).reshape([1, 1, 1])\\n\",\n    \"MESH = jax.sharding.Mesh(DEVICES, ['replica', 'data', 'mdl'])\\n\",\n    \"\\n\",\n    \"num_devices = jax.local_device_count()\\n\",\n    \"print(f'num_devices: {num_devices}')\\n\",\n    \"print(f'device kind: {jax.local_devices()[0].device_kind}')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": 
[],\n   \"source\": [\n    \"jax_task = task_p\\n\",\n    \"key, init_key = jax.random.split(key)\\n\",\n    \"\\n\",\n    \"# To correctly prepare a batch of data for model initialization (now that shape\\n\",\n    \"# inference is merged), we take one devices*batch_size tensor tuple of data,\\n\",\n    \"# slice out just one batch, then run the prepare_input_batch function over it.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def process_train_batch(batch):\\n\",\n    \"    past_ts = batch[0].reshape(batch_size * num_ts, -1)\\n\",\n    \"    actual_ts = batch[3].reshape(batch_size * num_ts, -1)\\n\",\n    \"    return NestedMap(input_ts=past_ts, actual_ts=actual_ts)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def process_eval_batch(batch):\\n\",\n    \"    past_ts = batch[0]\\n\",\n    \"    actual_ts = batch[3]\\n\",\n    \"    return NestedMap(input_ts=past_ts, actual_ts=actual_ts)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"jax_model_states, _ = trainer_lib.initialize_model_state(\\n\",\n    \"    jax_task,\\n\",\n    \"    init_key,\\n\",\n    \"    process_train_batch(tbatch),\\n\",\n    \"    checkpoint_type=checkpoint_types.CheckpointType.GDA,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Setting the initial model weights to the pretrained TimesFM parameters.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"jax_model_states.mdl_vars['params']['core_layer'] = tfm._train_state.mdl_vars['params']\\n\",\n    \"jax_vars = jax_model_states.mdl_vars\\n\",\n    \"gc.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Training loop\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"jax_task = task_p\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def 
train_step(states, prng_key, inputs):\\n\",\n    \"  return trainer_lib.train_step_single_learner(\\n\",\n    \"      jax_task, states, prng_key, inputs\\n\",\n    \"  )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def eval_step(states, prng_key, inputs):\\n\",\n    \"  states = states.to_eval_state()\\n\",\n    \"  return trainer_lib.eval_step_single_learner(\\n\",\n    \"      jax_task, states, prng_key, inputs\\n\",\n    \"  )\\n\",\n    \"\\n\",\n    \"key, train_key, eval_key = jax.random.split(key, 3)\\n\",\n    \"train_prng_seed = jax.random.split(train_key, num=jax.local_device_count())\\n\",\n    \"eval_prng_seed = jax.random.split(eval_key, num=jax.local_device_count())\\n\",\n    \"\\n\",\n    \"p_train_step = jax.pmap(train_step, axis_name='batch')\\n\",\n    \"p_eval_step = jax.pmap(eval_step, axis_name='batch')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"replicated_jax_states = trainer_lib.replicate_model_state(jax_model_states)\\n\",\n    \"replicated_jax_vars = replicated_jax_states.mdl_vars\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"best_eval_loss = 1e7\\n\",\n    \"step_count = 0\\n\",\n    \"patience = 0\\n\",\n    \"NUM_EPOCHS = 100\\n\",\n    \"PATIENCE = 5\\n\",\n    \"TRAIN_STEPS_PER_EVAL = 1000\\n\",\n    \"CHECKPOINT_DIR='/home/senrajat_google_com/ettm1_finetune'\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def reshape_batch_for_pmap(batch, num_devices):\\n\",\n    \"  def _reshape(input_tensor):\\n\",\n    \"    bsize = input_tensor.shape[0]\\n\",\n    \"    residual_shape = list(input_tensor.shape[1:])\\n\",\n    \"    nbsize = bsize // num_devices\\n\",\n    \"    return jnp.reshape(input_tensor, [num_devices, nbsize] + residual_shape)\\n\",\n  
  \"\\n\",\n    \"  return jax.tree.map(_reshape, batch)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for epoch in range(NUM_EPOCHS):\\n\",\n    \"    print(f\\\"__________________Epoch: {epoch}__________________\\\", flush=True)\\n\",\n    \"    train_its = train_batches.as_numpy_iterator()\\n\",\n    \"    if patience >= PATIENCE:\\n\",\n    \"        print(\\\"Early stopping.\\\", flush=True)\\n\",\n    \"        break\\n\",\n    \"    for batch in tqdm(train_its):\\n\",\n    \"        train_losses = []\\n\",\n    \"        if patience >= PATIENCE:\\n\",\n    \"            print(\\\"Early stopping.\\\", flush=True)\\n\",\n    \"            break\\n\",\n    \"        tbatch = process_train_batch(batch)\\n\",\n    \"        tbatch = reshape_batch_for_pmap(tbatch, num_devices)\\n\",\n    \"        replicated_jax_states, step_fun_out = p_train_step(\\n\",\n    \"            replicated_jax_states, train_prng_seed, tbatch\\n\",\n    \"        )\\n\",\n    \"        train_losses.append(step_fun_out.loss[0])\\n\",\n    \"        if step_count % TRAIN_STEPS_PER_EVAL == 0:\\n\",\n    \"            print(\\n\",\n    \"                f\\\"Train loss at step {step_count}: {np.mean(train_losses)}\\\",\\n\",\n    \"                flush=True,\\n\",\n    \"            )\\n\",\n    \"            train_losses = []\\n\",\n    \"            print(\\\"Starting eval.\\\", flush=True)\\n\",\n    \"            val_its = val_batches.as_numpy_iterator()\\n\",\n    \"            eval_losses = []\\n\",\n    \"            for ev_batch in tqdm(val_its):\\n\",\n    \"                ebatch = process_eval_batch(ev_batch)\\n\",\n    \"                ebatch = reshape_batch_for_pmap(ebatch, num_devices)\\n\",\n    \"                _, step_fun_out = p_eval_step(\\n\",\n    \"                    replicated_jax_states, eval_prng_seed, ebatch\\n\",\n    \"                )\\n\",\n    \"            
    eval_losses.append(step_fun_out.loss[0])\\n\",\n    \"            mean_loss = np.mean(eval_losses)\\n\",\n    \"            print(f\\\"Eval loss at step {step_count}: {mean_loss}\\\", flush=True)\\n\",\n    \"            if mean_loss < best_eval_loss or np.isnan(mean_loss):\\n\",\n    \"                best_eval_loss = mean_loss\\n\",\n    \"                print(\\\"Saving checkpoint.\\\")\\n\",\n    \"                jax_state_for_saving = py_utils.maybe_unreplicate_for_fully_replicated(\\n\",\n    \"                    replicated_jax_states\\n\",\n    \"                )\\n\",\n    \"                checkpoints.save_checkpoint(\\n\",\n    \"                    jax_state_for_saving, CHECKPOINT_DIR, overwrite=True\\n\",\n    \"                )\\n\",\n    \"                patience = 0\\n\",\n    \"                del jax_state_for_saving\\n\",\n    \"                gc.collect()\\n\",\n    \"            else:\\n\",\n    \"                patience += 1\\n\",\n    \"                print(f\\\"patience: {patience}\\\")\\n\",\n    \"        step_count += 1\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Loading and evaluating the best (according to validation loss) finetuned checkpoint\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_state = checkpoints.restore_checkpoint(jax_model_states, CHECKPOINT_DIR)\\n\",\n    \"print(train_state.step)\\n\",\n    \"tfm._train_state.mdl_vars['params'] = train_state.mdl_vars['params']['core_layer']\\n\",\n    \"tfm.jit_decode()\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"mae_losses = []\\n\",\n    \"for batch in tqdm(test_batches.as_numpy_iterator()):\\n\",\n    \"    past = batch[0]\\n\",\n    \"    actuals = batch[3]\\n\",\n    \"    _, forecasts = 
tfm.forecast(list(past), [0] * past.shape[0])\\n\",\n    \"    forecasts = forecasts[:, 0 : actuals.shape[1], 5]\\n\",\n    \"    mae_losses.append(np.abs(forecasts - actuals).mean())\\n\",\n    \"\\n\",\n    \"print(f\\\"MAE: {np.mean(mae_losses)}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## There is around a __7%__ reduction in MAE from finetuning.\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"chronos-v2\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "v1/notebooks/finetuning_torch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction\\n\",\n    \"This notebook shows how to use TimesFM with finetuning. \\n\",\n    \"\\n\",\n    \"In order to perform finetuning, you need to create the Pytorch Dataset in a proper format. The example of the Dataset is provided below.\\n\",\n    \"The finetuning code can be found in timesfm.finetuning_torch.py. This notebook just imports the methods from finetuning\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Dataset Creation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.\\n\",\n      \"Loaded Jax TimesFM.\\n\",\n      \"Loaded PyTorch TimesFM.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from os import path\\n\",\n    \"from typing import Optional, Tuple\\n\",\n    \"\\n\",\n    \"import numpy as np\\n\",\n    \"import pandas as pd\\n\",\n    \"import torch\\n\",\n    \"import torch.multiprocessing as mp\\n\",\n    \"import yfinance as yf\\n\",\n    \"from finetuning.finetuning_torch import FinetuningConfig, TimesFMFinetuner\\n\",\n    \"from huggingface_hub import snapshot_download\\n\",\n    \"from torch.utils.data import Dataset\\n\",\n    \"\\n\",\n    \"from timesfm import TimesFm, TimesFmCheckpoint, TimesFmHparams\\n\",\n    \"from timesfm.pytorch_patched_decoder import PatchedTimeSeriesDecoder\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class TimeSeriesDataset(Dataset):\\n\",\n    \"  \\\"\\\"\\\"Dataset for time series data compatible with TimesFM.\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"  def __init__(self,\\n\",\n    \"               series: np.ndarray,\\n\",\n    \"            
   context_length: int,\\n\",\n    \"               horizon_length: int,\\n\",\n    \"               freq_type: int = 0):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"        Initialize dataset.\\n\",\n    \"\\n\",\n    \"        Args:\\n\",\n    \"            series: Time series data\\n\",\n    \"            context_length: Number of past timesteps to use as input\\n\",\n    \"            horizon_length: Number of future timesteps to predict\\n\",\n    \"            freq_type: Frequency type (0, 1, or 2)\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"    if freq_type not in [0, 1, 2]:\\n\",\n    \"      raise ValueError(\\\"freq_type must be 0, 1, or 2\\\")\\n\",\n    \"\\n\",\n    \"    self.series = series\\n\",\n    \"    self.context_length = context_length\\n\",\n    \"    self.horizon_length = horizon_length\\n\",\n    \"    self.freq_type = freq_type\\n\",\n    \"    self._prepare_samples()\\n\",\n    \"\\n\",\n    \"  def _prepare_samples(self) -> None:\\n\",\n    \"    \\\"\\\"\\\"Prepare sliding window samples from the time series.\\\"\\\"\\\"\\n\",\n    \"    self.samples = []\\n\",\n    \"    total_length = self.context_length + self.horizon_length\\n\",\n    \"\\n\",\n    \"    for start_idx in range(0, len(self.series) - total_length + 1):\\n\",\n    \"      end_idx = start_idx + self.context_length\\n\",\n    \"      x_context = self.series[start_idx:end_idx]\\n\",\n    \"      x_future = self.series[end_idx:end_idx + self.horizon_length]\\n\",\n    \"      self.samples.append((x_context, x_future))\\n\",\n    \"\\n\",\n    \"  def __len__(self) -> int:\\n\",\n    \"    return len(self.samples)\\n\",\n    \"\\n\",\n    \"  def __getitem__(\\n\",\n    \"      self, index: int\\n\",\n    \"  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\\n\",\n    \"    x_context, x_future = self.samples[index]\\n\",\n    \"\\n\",\n    \"    x_context = torch.tensor(x_context, dtype=torch.float32)\\n\",\n    \"    x_future = torch.tensor(x_future, 
dtype=torch.float32)\\n\",\n    \"\\n\",\n    \"    input_padding = torch.zeros_like(x_context)\\n\",\n    \"    freq = torch.tensor([self.freq_type], dtype=torch.long)\\n\",\n    \"\\n\",\n    \"    return x_context, input_padding, freq, x_future\\n\",\n    \"\\n\",\n    \"def prepare_datasets(series: np.ndarray,\\n\",\n    \"                     context_length: int,\\n\",\n    \"                     horizon_length: int,\\n\",\n    \"                     freq_type: int = 0,\\n\",\n    \"                     train_split: float = 0.8) -> Tuple[Dataset, Dataset]:\\n\",\n    \"  \\\"\\\"\\\"\\n\",\n    \"    Prepare training and validation datasets from time series data.\\n\",\n    \"\\n\",\n    \"    Args:\\n\",\n    \"        series: Input time series data\\n\",\n    \"        context_length: Number of past timesteps to use\\n\",\n    \"        horizon_length: Number of future timesteps to predict\\n\",\n    \"        freq_type: Frequency type (0, 1, or 2)\\n\",\n    \"        train_split: Fraction of data to use for training\\n\",\n    \"\\n\",\n    \"    Returns:\\n\",\n    \"        Tuple of (train_dataset, val_dataset)\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"  train_size = int(len(series) * train_split)\\n\",\n    \"  train_data = series[:train_size]\\n\",\n    \"  val_data = series[train_size:]\\n\",\n    \"\\n\",\n    \"  # Create datasets with specified frequency type\\n\",\n    \"  train_dataset = TimeSeriesDataset(train_data,\\n\",\n    \"                                    context_length=context_length,\\n\",\n    \"                                    horizon_length=horizon_length,\\n\",\n    \"                                    freq_type=freq_type)\\n\",\n    \"\\n\",\n    \"  val_dataset = TimeSeriesDataset(val_data,\\n\",\n    \"                                  context_length=context_length,\\n\",\n    \"                                  horizon_length=horizon_length,\\n\",\n    \"                                  freq_type=freq_type)\\n\",\n    
\"\\n\",\n    \"  return train_dataset, val_dataset\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Model Creation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def get_model(load_weights: bool = False):\\n\",\n    \"  device = \\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\"\\n\",\n    \"  repo_id = \\\"google/timesfm-2.0-500m-pytorch\\\"\\n\",\n    \"  hparams = TimesFmHparams(\\n\",\n    \"      backend=device,\\n\",\n    \"      per_core_batch_size=32,\\n\",\n    \"      horizon_len=128,\\n\",\n    \"      num_layers=50,\\n\",\n    \"      use_positional_embedding=False,\\n\",\n    \"      context_len=\\n\",\n    \"      192,  # Context length can be anything up to 2048 in multiples of 32\\n\",\n    \"  )\\n\",\n    \"  tfm = TimesFm(hparams=hparams,\\n\",\n    \"                checkpoint=TimesFmCheckpoint(huggingface_repo_id=repo_id))\\n\",\n    \"\\n\",\n    \"  model = PatchedTimeSeriesDecoder(tfm._model_config)\\n\",\n    \"  if load_weights:\\n\",\n    \"    checkpoint_path = path.join(snapshot_download(repo_id), \\\"torch_model.ckpt\\\")\\n\",\n    \"    loaded_checkpoint = torch.load(checkpoint_path, weights_only=True)\\n\",\n    \"    model.load_state_dict(loaded_checkpoint)\\n\",\n    \"  return model, hparams, tfm._model_config\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def plot_predictions(\\n\",\n    \"    model: TimesFm,\\n\",\n    \"    val_dataset: Dataset,\\n\",\n    \"    save_path: Optional[str] = \\\"predictions.png\\\",\\n\",\n    \") -> None:\\n\",\n    \"  \\\"\\\"\\\"\\n\",\n    \"    Plot model predictions against ground truth for a batch of validation data.\\n\",\n    \"\\n\",\n    \"    Args:\\n\",\n    \"      model: Trained TimesFM model\\n\",\n    \"      
val_dataset: Validation dataset\\n\",\n    \"      save_path: Path to save the plot\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"  import matplotlib.pyplot as plt\\n\",\n    \"\\n\",\n    \"  model.eval()\\n\",\n    \"\\n\",\n    \"  x_context, x_padding, freq, x_future = val_dataset[0]\\n\",\n    \"  x_context = x_context.unsqueeze(0)  # Add batch dimension\\n\",\n    \"  x_padding = x_padding.unsqueeze(0)\\n\",\n    \"  freq = freq.unsqueeze(0)\\n\",\n    \"  x_future = x_future.unsqueeze(0)\\n\",\n    \"\\n\",\n    \"  device = next(model.parameters()).device\\n\",\n    \"  x_context = x_context.to(device)\\n\",\n    \"  x_padding = x_padding.to(device)\\n\",\n    \"  freq = freq.to(device)\\n\",\n    \"  x_future = x_future.to(device)\\n\",\n    \"\\n\",\n    \"  with torch.no_grad():\\n\",\n    \"    predictions = model(x_context, x_padding.float(), freq)\\n\",\n    \"    predictions_mean = predictions[..., 0]  # [B, N, horizon_len]\\n\",\n    \"    last_patch_pred = predictions_mean[:, -1, :]  # [B, horizon_len]\\n\",\n    \"\\n\",\n    \"  context_vals = x_context[0].cpu().numpy()\\n\",\n    \"  future_vals = x_future[0].cpu().numpy()\\n\",\n    \"  pred_vals = last_patch_pred[0].cpu().numpy()\\n\",\n    \"\\n\",\n    \"  context_len = len(context_vals)\\n\",\n    \"  horizon_len = len(future_vals)\\n\",\n    \"\\n\",\n    \"  plt.figure(figsize=(12, 6))\\n\",\n    \"\\n\",\n    \"  plt.plot(range(context_len),\\n\",\n    \"           context_vals,\\n\",\n    \"           label=\\\"Historical Data\\\",\\n\",\n    \"           color=\\\"blue\\\",\\n\",\n    \"           linewidth=2)\\n\",\n    \"\\n\",\n    \"  plt.plot(\\n\",\n    \"      range(context_len, context_len + horizon_len),\\n\",\n    \"      future_vals,\\n\",\n    \"      label=\\\"Ground Truth\\\",\\n\",\n    \"      color=\\\"green\\\",\\n\",\n    \"      linestyle=\\\"--\\\",\\n\",\n    \"      linewidth=2,\\n\",\n    \"  )\\n\",\n    \"\\n\",\n    \"  plt.plot(range(context_len, context_len + 
horizon_len),\\n\",\n    \"           pred_vals,\\n\",\n    \"           label=\\\"Prediction\\\",\\n\",\n    \"           color=\\\"red\\\",\\n\",\n    \"           linewidth=2)\\n\",\n    \"\\n\",\n    \"  plt.xlabel(\\\"Time Step\\\")\\n\",\n    \"  plt.ylabel(\\\"Value\\\")\\n\",\n    \"  plt.title(\\\"TimesFM Predictions vs Ground Truth\\\")\\n\",\n    \"  plt.legend()\\n\",\n    \"  plt.grid(True)\\n\",\n    \"\\n\",\n    \"  if save_path:\\n\",\n    \"    plt.savefig(save_path)\\n\",\n    \"    print(f\\\"Plot saved to {save_path}\\\")\\n\",\n    \"\\n\",\n    \"  plt.close()\\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def get_data(context_len: int,\\n\",\n    \"             horizon_len: int,\\n\",\n    \"             freq_type: int = 0) -> Tuple[Dataset, Dataset]:\\n\",\n    \"  df = yf.download(\\\"AAPL\\\", start=\\\"2010-01-01\\\", end=\\\"2019-01-01\\\")\\n\",\n    \"  time_series = df[\\\"Close\\\"].values\\n\",\n    \"\\n\",\n    \"  train_dataset, val_dataset = prepare_datasets(\\n\",\n    \"      series=time_series,\\n\",\n    \"      context_length=context_len,\\n\",\n    \"      horizon_length=horizon_len,\\n\",\n    \"      freq_type=freq_type,\\n\",\n    \"      train_split=0.8,\\n\",\n    \"  )\\n\",\n    \"\\n\",\n    \"  print(f\\\"Created datasets:\\\")\\n\",\n    \"  print(f\\\"- Training samples: {len(train_dataset)}\\\")\\n\",\n    \"  print(f\\\"- Validation samples: {len(val_dataset)}\\\")\\n\",\n    \"  print(f\\\"- Using frequency type: {freq_type}\\\")\\n\",\n    \"  return train_dataset, val_dataset\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def single_gpu_example():\\n\",\n    \"  \\\"\\\"\\\"Basic example of finetuning TimesFM on stock data.\\\"\\\"\\\"\\n\",\n    \"  model, hparams, tfm_config = get_model(load_weights=True)\\n\",\n    \"  config = FinetuningConfig(batch_size=256,\\n\",\n    \"           
                 num_epochs=5,\\n\",\n    \"                            learning_rate=1e-4,\\n\",\n    \"                            use_wandb=True,\\n\",\n    \"                            freq_type=1,\\n\",\n    \"                            log_every_n_steps=10,\\n\",\n    \"                            val_check_interval=0.5,\\n\",\n    \"                            use_quantile_loss=True)\\n\",\n    \"\\n\",\n    \"  train_dataset, val_dataset = get_data(128,\\n\",\n    \"                                        tfm_config.horizon_len,\\n\",\n    \"                                        freq_type=config.freq_type)\\n\",\n    \"  finetuner = TimesFMFinetuner(model, config)\\n\",\n    \"\\n\",\n    \"  print(\\\"\\\\nStarting finetuning...\\\")\\n\",\n    \"  results = finetuner.finetune(train_dataset=train_dataset,\\n\",\n    \"                               val_dataset=val_dataset)\\n\",\n    \"\\n\",\n    \"  print(\\\"\\\\nFinetuning completed!\\\")\\n\",\n    \"  print(f\\\"Training history: {len(results['history']['train_loss'])} epochs\\\")\\n\",\n    \"\\n\",\n    \"  plot_predictions(\\n\",\n    \"      model=model,\\n\",\n    \"      val_dataset=val_dataset,\\n\",\n    \"      save_path=\\\"timesfm_predictions.png\\\",\\n\",\n    \"  )\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"ac84aeda3a1749ae8f30b06859067bb1\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"6d9d8081fc514c6d8601a2e0e63954a2\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      
\"text/plain\": [\n       \"Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[*********************100%***********************]  1 of 1 completed\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Created datasets:\\n\",\n      \"- Training samples: 1556\\n\",\n      \"- Validation samples: 198\\n\",\n      \"- Using frequency type: 1\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[34m\\u001b[1mwandb\\u001b[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.\\n\",\n      \"\\u001b[34m\\u001b[1mwandb\\u001b[0m: Currently logged in as: \\u001b[33mmishacamry\\u001b[0m. Use \\u001b[1m`wandb login --relogin`\\u001b[0m to force relogin\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"Tracking run with wandb version 0.19.1\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"Run data is saved locally in <code>/home/chertushkin/forks/timesfm/notebooks/wandb/run-20250217_114343-tjs63ml2</code>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"Syncing run <strong><a href='https://wandb.ai/mishacamry/timesfm-finetuning/runs/tjs63ml2' target=\\\"_blank\\\">chocolate-eon-50</a></strong> to <a href='https://wandb.ai/mishacamry/timesfm-finetuning' target=\\\"_blank\\\">Weights & Biases</a> (<a 
href='https://wandb.me/developer-guide' target=\\\"_blank\\\">docs</a>)<br>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \" View project at <a href='https://wandb.ai/mishacamry/timesfm-finetuning' target=\\\"_blank\\\">https://wandb.ai/mishacamry/timesfm-finetuning</a>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \" View run at <a href='https://wandb.ai/mishacamry/timesfm-finetuning/runs/tjs63ml2' target=\\\"_blank\\\">https://wandb.ai/mishacamry/timesfm-finetuning/runs/tjs63ml2</a>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"Starting finetuning...\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/html\": [],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<br>    <style><br>        .wandb-row {<br>            display: flex;<br>            flex-direction: row;<br>            flex-wrap: wrap;<br>            justify-content: flex-start;<br>            width: 100%;<br>        }<br>        .wandb-col {<br>            display: flex;<br>            flex-direction: column;<br>            flex-basis: 100%;<br>            flex: 1;<br>            padding: 10px;<br>        }<br>    </style><br><div class=\\\"wandb-row\\\"><div class=\\\"wandb-col\\\"><h3>Run history:</h3><br/><table 
class=\\\"wandb\\\"><tr><td>epoch</td><td>▁▃▅▆█</td></tr><tr><td>learning_rate</td><td>▁▁▁▁▁</td></tr><tr><td>train_loss</td><td>█▃▂▁▁</td></tr><tr><td>val_loss</td><td>█▁▄▁▂</td></tr></table><br/></div><div class=\\\"wandb-col\\\"><h3>Run summary:</h3><br/><table class=\\\"wandb\\\"><tr><td>epoch</td><td>5</td></tr><tr><td>learning_rate</td><td>0.0001</td></tr><tr><td>train_loss</td><td>2.85423</td></tr><tr><td>val_loss</td><td>26.7628</td></tr></table><br/></div></div>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \" View run <strong style=\\\"color:#cdcd00\\\">chocolate-eon-50</strong> at: <a href='https://wandb.ai/mishacamry/timesfm-finetuning/runs/tjs63ml2' target=\\\"_blank\\\">https://wandb.ai/mishacamry/timesfm-finetuning/runs/tjs63ml2</a><br> View project at: <a href='https://wandb.ai/mishacamry/timesfm-finetuning' target=\\\"_blank\\\">https://wandb.ai/mishacamry/timesfm-finetuning</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"Find logs at: <code>./wandb/run-20250217_114343-tjs63ml2/logs</code>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"Finetuning completed!\\n\",\n      \"Training history: 5 epochs\\n\",\n      \"Plot saved to timesfm_predictions.png\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"single_gpu_example()\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"timesfm-DnAbSweh-py3.11\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.10\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "v1/peft/README.md",
    "content": "# Fine-Tuning Pipeline\n\nThis folder contains a generic fine-tuning pipeline designed to support multiple PEFT fine-tuning strategies.\n\n## Features\n\n- **Supported Fine-Tuning Strategies**:\n  - **Full Fine-Tuning**: Adjusts all parameters of the model during training.\n  - **[Linear Probing](https://arxiv.org/abs/2302.11939)**: Fine-tunes only the residual blocks and the embedding layer, leaving other parameters unchanged.\n  - **[LoRA (Low-Rank Adaptation)](https://arxiv.org/abs/2106.09685)**: A memory-efficient method that fine-tunes a small number of parameters by decomposing the weight matrices into low-rank matrices.\n  - **[DoRA (Directional LoRA)](https://arxiv.org/abs/2402.09353v4)**: An extension of LoRA that decomposes pre-trained weights into magnitude and direction components. It uses LoRA for directional adaptation, enhancing learning capacity and stability without additional inference overhead.\n\n## Usage\n### Fine-Tuning Script\nThe provided finetune.py script allows you to fine-tune a model with specific configurations. You can customize various parameters to suit your dataset and desired fine-tuning strategy.\n\nExample Usage:\n\n```zsh\nsource finetune.sh\n```\nThis script runs the finetune.py file with a predefined set of hyperparameters for the model. 
You can adjust the parameters in the script as needed.\n\n### Available Options\nRun the script with the --help flag to see a full list of available options and their descriptions:\n```zsh\npython3 finetune.py --help\n```\nScript Configuration\nYou can modify the following key parameters directly in the finetune.sh script:\nFine-Tuning Strategy: Toggle between full fine-tuning, LoRA \\[`--use-lora`\\], DoRA [\\[`--use-dora`\\]], or Linear Probing \\[`--use-linear-probing`\\].\n\n### Performance Comparison\nThe figure below compares the performance of LoRA/DoRA against Linear Probing under the following conditions:\n\n<img width=\"528\" alt=\"image\" src=\"https://github.com/user-attachments/assets/6c9f820b-5865-4821-8014-c346b9d632a5\">\n\n- Training data split: 60% train, 20% validation, 20% test.\n- Benchmark: context_len=128, horizon_len=96\n- Fine-tuning: context_len=128, horizon_len=128\n- Black: Best result.\n- Blue: Second best result.\n"
  },
  {
    "path": "v1/peft/finetune.py",
    "content": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nFinetune pipeline.\n\"\"\"\nimport gc\nimport logging\nimport warnings\nfrom datetime import datetime\nfrom typing import Tuple\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport pandas as pd\nimport typer\nimport wandb\nfrom jax import numpy as jnp\nfrom paxml import checkpoint_types, checkpoints, learners, tasks_lib, trainer_lib\nfrom praxis import optimizers, pax_fiddle, py_utils, schedules\nfrom rich import print\nfrom tqdm import tqdm\nfrom typing_extensions import Annotated\n\nfrom adapter.utils import get_adapter_params, load_adapter_layer\nfrom timesfm import TimesFm, data_loader, patched_decoder\n\nNestedMap = py_utils.NestedMap\n\n\nwarnings.filterwarnings(\"ignore\")\ncmdstanpy_logger = logging.getLogger(\"cmdstanpy\")\nabsl_logger = logging.getLogger(\"absl\")\ncmdstanpy_logger.disabled = True\nabsl_logger.disabled = True\n\n\"\"\"\nTimesFM model config. 
These are fixed since pre-training was done \nwith this configuration.\n\"\"\"\nINPUT_PATCH_LEN = 32\nOUTPUT_PATCH_LEN = 128\nNUM_LAYERS = 20\nMODEL_DIMS = 1280\n\nQUANTILES = list(np.arange(1, 10) / 10.0)\nEPS = 1e-7\nRANDOM_SEED = 1234\n\n\ndef finetune(\n    *,\n    model_name: Annotated[\n        str, typer.Option(help=\"Specify the name of the huggingface model.\")\n    ] = \"google/timesfm-1.0-200m\",\n    checkpoint_path: Annotated[\n        str, typer.Option(help=\"The path to the local model checkpoint.\")\n    ] = None,\n    datetime_col: Annotated[str, typer.Option(help=\"Column having datetime.\")] = \"ds\",\n    ts_cols: Annotated[\n        list[str], typer.Option(help=\"Columns of time-series features.\")\n    ] = [],\n    normalize: Annotated[\n        bool, typer.Option(help=\"Normalize data for eval or not\")\n    ] = True,\n    context_len: Annotated[int, typer.Option(help=\"Length of the context window\")],\n    horizon_len: Annotated[int, typer.Option(help=\"Prediction length.\")],\n    freq: Annotated[\n        str,\n        typer.Option(\n            ...,\n            help=\"Frequency Map Str\",\n        ),\n    ],\n    data_path: Annotated[str, typer.Option(help=\"Path to dataset csv\")],\n    boundaries: Annotated[\n        Tuple[int, int, int],\n        typer.Option(\n            help=\"boundaries of dataset to train, val, test\",\n        ),\n    ] = (0, 0, 0),\n    backend: Annotated[str, typer.Option(help=\"Backend device: cpu, gpu, tpu\")],\n    batch_size: Annotated[\n        int, typer.Option(help=\"Batch size for the randomly sampled batch\")\n    ] = 16,\n    num_epochs: Annotated[int, typer.Option(help=\"Number of epochs\")],\n    learning_rate: Annotated[float, typer.Option(help=\"adam optimizer learning rate\")],\n    adam_epsilon: Annotated[float, typer.Option(help=\"adam optimizer epsilon\")],\n    adam_clip_threshold: Annotated[\n        float, typer.Option(help=\"adam optimizer clip threshold\")\n    ],\n    
cos_initial_decay_value: Annotated[\n        float, typer.Option(help=\"cosine initial decay value\")\n    ],\n    cos_final_decay_value: Annotated[\n        float, typer.Option(help=\"cosine final decay value\")\n    ],\n    cos_decay_steps: Annotated[int, typer.Option(help=\"Number of cosine decay steps\")],\n    ema_decay: Annotated[float, typer.Option(help=\"Exponential moving average decay\")],\n    early_stop_patience: Annotated[\n        int, typer.Option(..., help=\"Early stopping patience\")\n    ] = 5,\n    use_lora: Annotated[\n        bool,\n        typer.Option(\n            help=\"Train low rank adapters for stacked transformer block\",\n        ),\n    ] = False,\n    lora_rank: Annotated[\n        int,\n        typer.Option(\n            help=\"LoRA Rank\",\n        ),\n    ] = 8,\n    lora_target_modules: Annotated[\n        str,\n        typer.Option(\n            help=\"LoRA target modules of the transformer block. Allowed values: [all, attention, mlp]\"\n        ),\n    ] = \"all\",\n    use_dora: Annotated[\n        bool,\n        typer.Option(\n            help=\"Apply DoRA strategy along with LoRA.\",\n        ),\n    ] = False,\n    use_linear_probing: Annotated[\n        bool,\n        typer.Option(\n            help=\"Linear Probing. Train only input/output and embedding params. 
Freeze params in stack transformer block.\",\n        ),\n    ] = False,\n    checkpoint_dir: Annotated[\n        str, typer.Option(help=\"Checkpoint directory\")\n    ] = \"./checkpoints\",\n    wandb_project: Annotated[\n        str, typer.Option(help=\"Weights & Biases project name\")\n    ] = \"google_timesfm_finetune\",\n) -> None:\n    key = jax.random.PRNGKey(seed=RANDOM_SEED)\n    wandb.init(project=wandb_project, config=locals())\n\n    data_df = pd.read_csv(open(data_path, \"r\"))\n\n    if boundaries == (0, 0, 0):\n        # Default boundaries: train 60%, val 20%, test 20%\n        boundaries = [\n            int(len(data_df) * 0.6),\n            int(len(data_df) * 0.8),\n            len(data_df) - 1,\n        ]\n\n    ts_cols = [col for col in data_df.columns if col != datetime_col]\n\n    dtl = data_loader.TimeSeriesdata(\n        data_path=data_path,\n        datetime_col=datetime_col,\n        num_cov_cols=None,\n        cat_cov_cols=None,\n        ts_cols=np.array(ts_cols),\n        train_range=[0, boundaries[0]],\n        val_range=[boundaries[0], boundaries[1]],\n        test_range=[boundaries[1], boundaries[2]],\n        hist_len=context_len,\n        pred_len=horizon_len,\n        batch_size=batch_size,\n        freq=freq,\n        normalize=normalize,\n        epoch_len=None,\n        holiday=False,\n        permute=False,\n    )\n\n    train_batches = dtl.tf_dataset(mode=\"train\", shift=1).batch(batch_size)\n    val_batches = dtl.tf_dataset(mode=\"val\", shift=horizon_len)\n\n    for tbatch in tqdm(train_batches.as_numpy_iterator()):\n        pass\n\n    tfm = TimesFm(\n        context_len=context_len,\n        horizon_len=horizon_len,\n        input_patch_len=INPUT_PATCH_LEN,\n        output_patch_len=OUTPUT_PATCH_LEN,\n        num_layers=NUM_LAYERS,\n        model_dims=MODEL_DIMS,\n        backend=backend,\n        per_core_batch_size=batch_size,\n        quantiles=QUANTILES,\n    )\n\n    if checkpoint_path:\n        
tfm.load_from_checkpoint(\n            checkpoint_path=checkpoint_path,\n            checkpoint_type=checkpoints.CheckpointType.FLAX,\n        )\n    else:\n        tfm.load_from_checkpoint(\n            repo_id=model_name,\n            checkpoint_type=checkpoints.CheckpointType.FLAX,\n        )\n\n    model = pax_fiddle.Config(\n        patched_decoder.PatchedDecoderFinetuneModel,\n        name=\"patched_decoder_finetune\",\n        core_layer_tpl=tfm.model_p,\n    )\n\n    if use_lora:\n        load_adapter_layer(\n            mdl_vars=tfm._train_state.mdl_vars,\n            model=model.core_layer_tpl,\n            lora_rank=lora_rank,\n            lora_target_modules=lora_target_modules,\n            use_dora=use_dora,\n        )\n\n    @pax_fiddle.auto_config\n    def build_learner() -> learners.Learner:\n        bprop_variable_inclusion = []\n        bprop_variable_exclusion = []\n        if use_lora:\n            bprop_variable_inclusion.append(r\"^.*lora.*$\")\n            if use_dora:\n                bprop_variable_inclusion.append(r\"^.*dora.*$\")\n        elif use_linear_probing:\n            bprop_variable_exclusion = [\".*/stacked_transformer_layer/.*\"]\n\n        return pax_fiddle.Config(\n            learners.Learner,\n            name=\"learner\",\n            loss_name=\"avg_qloss\",\n            optimizer=optimizers.Adam(\n                epsilon=adam_epsilon,\n                clip_threshold=adam_clip_threshold,\n                learning_rate=learning_rate,\n                lr_schedule=pax_fiddle.Config(\n                    schedules.Cosine,\n                    initial_value=cos_initial_decay_value,\n                    final_value=cos_final_decay_value,\n                    total_steps=cos_decay_steps,\n                ),\n                ema_decay=ema_decay,\n            ),\n            bprop_variable_exclusion=bprop_variable_exclusion,\n            bprop_variable_inclusion=bprop_variable_inclusion,\n        )\n\n    task_p = 
tasks_lib.SingleTask(\n        name=\"ts-learn\",\n        model=model,\n        train=tasks_lib.SingleTask.Train(\n            learner=build_learner(),\n        ),\n    )\n\n    task_p.model.ici_mesh_shape = [1, 1, 1]\n    task_p.model.mesh_axis_names = [\"replica\", \"data\", \"mdl\"]\n\n    DEVICES = np.array(jax.devices()).reshape([1, 1, 1])\n    jax.sharding.Mesh(DEVICES, [\"replica\", \"data\", \"mdl\"])\n\n    num_devices = jax.local_device_count()\n    print(f\"num_devices: {num_devices}\")\n    print(f\"device kind: {jax.local_devices()[0].device_kind}\")\n\n    jax_task = task_p\n    key, init_key = jax.random.split(key)\n\n    def process_train_batch(batch):\n        past_ts = batch[0].reshape(batch_size * len(ts_cols), -1)\n        actual_ts = batch[3].reshape(batch_size * len(ts_cols), -1)\n        return NestedMap(input_ts=past_ts, actual_ts=actual_ts)\n\n    def process_eval_batch(batch):\n        past_ts = batch[0]\n        actual_ts = batch[3]\n        return NestedMap(input_ts=past_ts, actual_ts=actual_ts)\n\n    jax_model_states, _ = trainer_lib.initialize_model_state(\n        jax_task,\n        init_key,\n        process_train_batch(tbatch),\n        checkpoint_type=checkpoint_types.CheckpointType.GDA,\n    )\n    jax_model_states.mdl_vars[\"params\"][\"core_layer\"] = tfm._train_state.mdl_vars[\n        \"params\"\n    ]\n    gc.collect()\n\n    jax_task = task_p\n\n    def train_step(states, prng_key, inputs):\n        return trainer_lib.train_step_single_learner(jax_task, states, prng_key, inputs)\n\n    def eval_step(states, prng_key, inputs):\n        states = states.to_eval_state()\n        return trainer_lib.eval_step_single_learner(jax_task, states, prng_key, inputs)\n\n    key, train_key, eval_key = jax.random.split(key, 3)\n    train_prng_seed = jax.random.split(train_key, num=jax.local_device_count())\n    eval_prng_seed = jax.random.split(eval_key, num=jax.local_device_count())\n\n    p_train_step = jax.pmap(train_step, 
axis_name=\"batch\")\n    p_eval_step = jax.pmap(eval_step, axis_name=\"batch\")\n\n    replicated_jax_states = trainer_lib.replicate_model_state(jax_model_states)\n\n    def reshape_batch_for_pmap(batch, num_devices):\n        def _reshape(input_tensor):\n            bsize = input_tensor.shape[0]\n            residual_shape = list(input_tensor.shape[1:])\n            nbsize = bsize // num_devices\n            return jnp.reshape(input_tensor, [num_devices, nbsize] + residual_shape)\n\n        return jax.tree.map(_reshape, batch)\n\n    patience = 0\n    best_eval_loss = 1e7\n    checkpoint_dir = f\"{checkpoint_dir}/run_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}\"\n    for epoch in range(num_epochs):\n        if patience >= early_stop_patience:\n            print(\"Early stopping.\")\n            break\n        print(f\"Epoch: {epoch + 1}\")\n        train_its = train_batches.as_numpy_iterator()\n        train_losses = []\n        for batch in tqdm(train_its):\n            tbatch = process_train_batch(batch)\n            tbatch = reshape_batch_for_pmap(tbatch, num_devices)\n            replicated_jax_states, step_fun_out = p_train_step(\n                replicated_jax_states, train_prng_seed, tbatch\n            )\n            train_losses.append(step_fun_out.loss[0])\n            wandb.log({\"train_step_loss\": step_fun_out.loss[0]})\n\n        avg_train_loss = np.mean(train_losses)\n\n        print(\"Starting eval.\")\n        val_its = val_batches.as_numpy_iterator()\n        eval_losses = []\n        for ev_batch in tqdm(val_its):\n            ebatch = process_eval_batch(ev_batch)\n            ebatch = reshape_batch_for_pmap(ebatch, num_devices)\n            _, step_fun_out = p_eval_step(replicated_jax_states, eval_prng_seed, ebatch)\n            eval_losses.append(step_fun_out.loss[0])\n            wandb.log({\"eval_step_loss\": step_fun_out.loss[0]})\n\n        avg_eval_loss = np.mean(eval_losses)\n\n        print(f\"Train Loss: 
{avg_train_loss}, Val Loss: {avg_eval_loss}\")\n\n        wandb.log(\n            {\n                \"epoch\": epoch + 1,\n                \"avg_train_loss\": avg_train_loss,\n                \"avg_val_loss\": avg_eval_loss,\n            }\n        )\n\n        if avg_eval_loss < best_eval_loss or np.isnan(avg_eval_loss):\n            best_eval_loss = avg_eval_loss\n            print(\"Saving checkpoint.\")\n            jax_state_for_saving = py_utils.maybe_unreplicate_for_fully_replicated(\n                replicated_jax_states\n            )\n            if use_lora:\n                adapter_params = get_adapter_params(\n                    params=jax_state_for_saving.mdl_vars,\n                    lora_target_modules=lora_target_modules,\n                    num_layers=NUM_LAYERS,\n                    use_dora=use_dora,\n                )\n                jax_state_for_saving.mdl_vars[\"params\"] = adapter_params\n\n            checkpoints.save_checkpoint(\n                jax_state_for_saving, checkpoint_dir, overwrite=True\n            )\n\n            patience = 0\n            del jax_state_for_saving\n            gc.collect()\n        else:\n            patience += 1\n            print(f\"patience: {patience}\")\n    print(\"Fine-tuning completed.\")\n\n\nif __name__ == \"__main__\":\n    typer.run(finetune)\n"
  },
  {
    "path": "v1/peft/finetune.sh",
    "content": "#!/bin/bash\n\n# Script to finetune a model with specific configurations\n# Adjust the parameters below as needed. For a full list of options and descriptions, run the script with the --help flag.\n\nexport TF_CPP_MIN_LOG_LEVEL=2 XLA_PYTHON_CLIENT_PREALLOCATE=false\n\npython3 finetune.py \\\n    --model-name=\"google/timesfm-1.0-200m\" \\\n    --backend=\"cpu\" \\\n    --horizon-len=128 \\\n    --context-len=512 \\\n    --freq=\"15min\" \\\n    --data-path=\"../datasets/ETT-small/ETTm1.csv\" \\\n    --num-epochs=100 \\\n    --learning-rate=1e-3 \\\n    --adam-epsilon=1e-7 \\\n    --adam-clip-threshold=1e2 \\\n    --early-stop-patience=10 \\\n    --datetime-col=\"date\" \\\n    --use-lora \\\n    --lora-rank=1 \\\n    --lora-target-modules=\"all\" \\\n    --use-dora \\\n    --cos-initial-decay-value=1e-4 \\\n    --cos-decay-steps=40000 \\\n    --cos-final-decay-value=1e-5 \\\n    --ema-decay=0.9999\n\n# To see all available options and their descriptions, use the --help flag\n# python3 finetune.py --help\n"
  },
  {
    "path": "v1/peft/usage.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Load Base Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from timesfm import TimesFm, freq_map, data_loader\\n\",\n    \"from adapter.utils import load_adapter_checkpoint\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"import numpy as np\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"tfm = TimesFm(\\n\",\n    \"    context_len=512,\\n\",\n    \"    horizon_len=128,\\n\",\n    \"    input_patch_len=32,\\n\",\n    \"    output_patch_len=128,\\n\",\n    \"    num_layers=20,\\n\",\n    \"    model_dims=1280,\\n\",\n    \"    backend=\\\"cpu\\\",\\n\",\n    \")\\n\",\n    \"tfm.load_from_checkpoint(repo_id=\\\"google/timesfm-1.0-200m\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"DATA_DICT = {\\n\",\n    \"    \\\"ettm2\\\": {\\n\",\n    \"        \\\"boundaries\\\": [34560, 46080, 57600],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/ETT-small/ETTm2.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"15min\\\",\\n\",\n    \"    },\\n\",\n    \"    \\\"ettm1\\\": {\\n\",\n    \"        \\\"boundaries\\\": [34560, 46080, 57600],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/ETT-small/ETTm1.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"15min\\\",\\n\",\n    \"    },\\n\",\n    \"    \\\"etth2\\\": {\\n\",\n    \"        \\\"boundaries\\\": [8640, 11520, 14400],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/ETT-small/ETTh2.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"H\\\",\\n\",\n    \"    },\\n\",\n    \"    \\\"etth1\\\": {\\n\",\n    \"        \\\"boundaries\\\": [8640, 11520, 14400],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/ETT-small/ETTh1.csv\\\",\\n\",\n    \"        
\\\"freq\\\": \\\"H\\\",\\n\",\n    \"    },\\n\",\n    \"    \\\"elec\\\": {\\n\",\n    \"        \\\"boundaries\\\": [18413, 21044, 26304],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/electricity/electricity.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"H\\\",\\n\",\n    \"    },\\n\",\n    \"    \\\"traffic\\\": {\\n\",\n    \"        \\\"boundaries\\\": [12280, 14036, 17544],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/traffic/traffic.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"H\\\",\\n\",\n    \"    },\\n\",\n    \"    \\\"weather\\\": {\\n\",\n    \"        \\\"boundaries\\\": [36887, 42157, 52696],\\n\",\n    \"        \\\"data_path\\\": \\\"../datasets/weather/weather.csv\\\",\\n\",\n    \"        \\\"freq\\\": \\\"10min\\\",\\n\",\n    \"    },\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Load Adapter Checkpoint\\n\",\n    \"\\n\",\n    \"Specify the adapter checkpoint path, rank and the modules used to train the adapters and whether dora was employed or not.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"load_adapter_checkpoint(\\n\",\n    \"    model=tfm,\\n\",\n    \"    adapter_checkpoint_path=\\\"./checkpoints/run_20240716_163900_lyo4psz3\\\",\\n\",\n    \"    lora_rank=1,\\n\",\n    \"    lora_target_modules=\\\"all\\\",\\n\",\n    \"    use_dora=True,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Test Performance\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataset = \\\"ettm1\\\"\\n\",\n    \"data_path = DATA_DICT[dataset][\\\"data_path\\\"]\\n\",\n    \"freq = DATA_DICT[dataset][\\\"freq\\\"]\\n\",\n    \"int_freq = freq_map(freq)\\n\",\n    \"boundaries = 
DATA_DICT[dataset][\\\"boundaries\\\"]\\n\",\n    \"\\n\",\n    \"data_df = pd.read_csv(open(data_path, \\\"r\\\"))\\n\",\n    \"\\n\",\n    \"ts_cols = [col for col in data_df.columns if col != \\\"date\\\"]\\n\",\n    \"num_cov_cols = None\\n\",\n    \"cat_cov_cols = None\\n\",\n    \"\\n\",\n    \"context_len = 512\\n\",\n    \"pred_len = 96\\n\",\n    \"\\n\",\n    \"num_ts = len(ts_cols)\\n\",\n    \"batch_size = 16\\n\",\n    \"\\n\",\n    \"dtl = data_loader.TimeSeriesdata(\\n\",\n    \"    data_path=data_path,\\n\",\n    \"    datetime_col=\\\"date\\\",\\n\",\n    \"    num_cov_cols=num_cov_cols,\\n\",\n    \"    cat_cov_cols=cat_cov_cols,\\n\",\n    \"    ts_cols=np.array(ts_cols),\\n\",\n    \"    train_range=[0, boundaries[0]],\\n\",\n    \"    val_range=[boundaries[0], boundaries[1]],\\n\",\n    \"    test_range=[boundaries[1], boundaries[2]],\\n\",\n    \"    hist_len=context_len,\\n\",\n    \"    pred_len=pred_len,\\n\",\n    \"    batch_size=num_ts,\\n\",\n    \"    freq=\\\"15min\\\",\\n\",\n    \"    normalize=True,\\n\",\n    \"    epoch_len=None,\\n\",\n    \"    holiday=False,\\n\",\n    \"    permute=True,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"test_batches = dtl.tf_dataset(mode=\\\"test\\\", shift=pred_len)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"mae_losses = []\\n\",\n    \"for batch in tqdm(test_batches.as_numpy_iterator()):\\n\",\n    \"    past = batch[0]\\n\",\n    \"    actuals = batch[3]\\n\",\n    \"    _, forecasts = tfm.forecast(list(past), [0] * past.shape[0])\\n\",\n    \"    forecasts = forecasts[:, 0 : actuals.shape[1], 5]\\n\",\n    \"    mae_losses.append(np.abs(forecasts - actuals).mean())\\n\",\n    \"\\n\",\n    \"print(f\\\"MAE: {np.mean(mae_losses)}\\\")\"\n   ]\n  }\n ],\n \"metadata\": {\n  
\"kernelspec\": {\n   \"display_name\": \"tanmay_tfm_env\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.14\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "v1/pyproject.toml",
    "content": "[tool.poetry]\nname = \"timesfm\"\npackages = [\n    { include = \"timesfm\", from = \"src\" },\n    { include = \"finetuning\", from = \"src\" },\n]\ndescription = \"Open weights time-series foundation model from Google Research.\"\nversion = \"1.3.0\"\nauthors = [\n    \"Rajat Sen <senrajat@google.com>\",\n    \"Yichen Zhou <yichenzhou@google.com>\",\n    \"Abhimanyu Das <abhidas@google.com>\",\n    \"Petros Mol <pmol@google.com>\",\n    \"Justin Güse <guese.justin@gmail.com>\",\n    \"Michael Chertushkin <chertushkinmichael@gmail.com>\"\n]\nreadme = \"README.md\"\nkeywords = [\"time series\", \"timesfm\", \"forecast\", \"time series model\"]\nhomepage = \"https://github.com/google-research/timesfm\"\nrepository = \"https://github.com/google-research/timesfm\"\nclassifiers = [\n    \"Environment :: Console\",\n    \"Framework :: Flake8\",\n    \"Operating System :: OS Independent\",\n    \"Topic :: Software Development :: Documentation\",\n    \"Topic :: Software Development :: Libraries :: Python Modules\",\n    \"Topic :: Software Development :: Quality Assurance\",\n]\ninclude = [\"LICENSE\"]\n\n[tool.poetry.dependencies]\npython = \">=3.10,<3.12\"\neinshape = \">=1.0.0\"\nnumpy = \">=1.26.4\"\npandas = \">=2.0.0\"\nutilsforecast = \">=0.1.10\"\nhuggingface_hub = { version = \">=0.23.0\", extras = [\"cli\"] }\nscikit-learn = \">=1.2.2\"\ntyper = \">=0.12.3\"\nwandb = \">=0.17.5\"\nabsl-py = \">=1.4.0\"\nsafetensors = \"^0.5.3\"\n\n[tool.poetry.extras]\n# Note: `lingvo` is an optional Google-internal dependency with strict Python\n# version and packaging constraints that cause install failures on some\n# environments (Colab etc.). 
We omit it from the pax extra here so users can\n# opt-in explicitly if they need it and have a compatible environment.\npax = [\"paxml\", \"jax\", \"jaxlib\"]\ntorch = [\"torch\"] \n\n[tool.poetry.dependencies.paxml]\nversion = \">=1.4.0\"\npython = \">=3.10,<3.11\"\n\n\n\n[tool.poetry.dependencies.jax]\nversion = \">=0.4.26\"\nextras = [\"cuda12\"]\npython = \">=3.10,<3.12\"  # Support both python versions\n\n[tool.poetry.dependencies.jaxlib]\nversion = \">=0.4.26\"\npython = \">=3.10,<3.12\"  # Support both python versions\n\n[tool.poetry.dependencies.torch]\nversion = \">=2.0.0\"\nextras = [\"cuda\"]\npython = \">=3.11,<3.12\"\n\n[tool.poetry.group.dev.dependencies]\npytest = \">=8.3.2\"\n\n[build-system]\nrequires = [\"poetry-core\"]\nbuild-backend = \"poetry.core.masonry.api\"\n"
  },
  {
    "path": "v1/src/adapter/__init__.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"adapter init file.\"\"\"\n\nfrom .dora_layers import DoraAttentionProjection, DoraCombinedQKVProjection, DoraLinear\nfrom .lora_layers import LoraAttentionProjection, LoraCombinedQKVProjection, LoraLinear\n"
  },
  {
    "path": "v1/src/adapter/dora_layers.py",
    "content": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom jax import numpy as jnp\nfrom praxis import base_layer\nfrom praxis.layers import attentions, linears\n\nWeightInit = base_layer.WeightInit\nWeightHParams = base_layer.WeightHParams\n\n\nclass DoraTheta(base_layer.Theta):\n    def __init__(self, module):\n        self.module = module\n\n    def _dora_initialized(self):\n        if (\n            self.module.has_variable(\"params\", \"lora_a\")\n            and self.module.has_variable(\"params\", \"lora_b\")\n            and self.module.has_variable(\"params\", \"dora_m\")\n            and \"lora_a\" in self.module._weight_hparams\n            and \"lora_b\" in self.module._weight_hparams\n            and \"dora_m\" in self.module._weight_hparams\n        ):\n            return True\n        else:\n            return False\n\n    def _dorafy_var(self, w):\n        lora_a = super().__getattr__(\"lora_a\")\n        lora_b = super().__getattr__(\"lora_b\")\n        dora_m = super().__getattr__(\"dora_m\")\n\n        lora_delta = self.module.einsum(\"...dr,...nr->...dn\", lora_a, lora_b)\n        lora_delta = jnp.reshape(lora_delta, w.shape)\n\n        w_prime = w + lora_delta\n\n        column_norm = jnp.linalg.norm(w_prime, ord=2, axis=0, keepdims=True)\n        norm_adapted = w_prime / column_norm\n        w_prime = dora_m * norm_adapted\n        return w_prime\n\n    def __getattr__(self, k):\n  
      var = super().__getattr__(k)\n        if not self._dora_initialized():\n            return var\n\n        if k == \"w\":\n            return self._dorafy_var(var)\n\n        return var\n\n    def __getitem__(self, k):\n        var = super().__getattr__(k)\n        if not self._dora_initialized():\n            return var\n\n        if k == \"w\":\n            return self._dorafy_var(var)\n\n        return var\n\n\nclass DoraThetaDescriptor:\n    \"\"\"Dot syntax accession descriptor.\"\"\"\n\n    def __get__(self, obj, objtype=None):\n        return DoraTheta(obj)\n\n\nclass DoraLinear(linears.Linear):\n    rank: int = 0\n    lora_init: WeightInit | None = None\n    theta = DoraThetaDescriptor()\n\n    def setup(self) -> None:\n        lora_init = self.lora_init if self.lora_init else self.weight_init\n\n        super().setup()\n        self.create_variable(\n            \"lora_a\",\n            WeightHParams(\n                shape=[self.input_dims, self.rank],\n                init=lora_init,\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[None, None],\n            ),\n        )\n        self.create_variable(\n            \"lora_b\",\n            WeightHParams(\n                shape=[self.output_dims, self.rank],\n                init=WeightInit.Constant(scale=0.0),\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[None, None],\n            ),\n        )\n        self.create_variable(\n            \"dora_m\",\n            WeightHParams(\n                shape=[1, self.output_dims],\n                init=lora_init,\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[None, None],\n            ),\n        )\n\n\nclass DoraAttentionProjection(attentions.AttentionProjection):\n    rank: int = 0\n    lora_init: WeightInit | None = None\n    theta = DoraThetaDescriptor()\n\n    def setup(self) -> None:\n        super().setup()\n        
w_weight_params = self._weight_hparams[\"w\"]\n        lora_init = self.lora_init if self.lora_init else w_weight_params.init\n\n        self.create_variable(\n            \"lora_a\",\n            WeightHParams(\n                shape=[self.input_dim, self.rank],\n                init=lora_init,\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[\n                    None,\n                    None,\n                ],\n            ),\n        )\n        self.create_variable(\n            \"lora_b\",\n            WeightHParams(\n                shape=[self.dim_per_head * self.num_heads, self.rank],\n                init=WeightInit.Constant(scale=0.0),\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[\n                    None,\n                    None,\n                ],\n            ),\n        )\n        self.create_variable(\n            \"dora_m\",\n            WeightHParams(\n                shape=[1, self.num_heads, self.dim_per_head],\n                init=lora_init,\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[None, None, None],\n            ),\n        )\n\n\nclass DoraCombinedQKVProjection(attentions.CombinedQKVProjectionLayer):\n    rank: int = 0\n    lora_init: WeightInit | None = None\n    theta = DoraThetaDescriptor()\n\n    def setup(self) -> None:\n        super().setup()\n        w_weight_params = self._weight_hparams[\"w\"]\n        lora_init = self.lora_init if self.lora_init else w_weight_params.init\n\n        self.create_variable(\n            \"lora_a\",\n            WeightHParams(\n                shape=[3, self.input_dim, self.rank],\n                init=lora_init,\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[None, None, None],\n            ),\n        )\n        self.create_variable(\n            \"lora_b\",\n            WeightHParams(\n                shape=[3, 
self.dim_per_head * self.num_heads, self.rank],\n                init=WeightInit.Constant(scale=0.0),\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[None, None, None],\n            ),\n        )\n        self.create_variable(\n            \"dora_m\",\n            WeightHParams(\n                shape=[3, 1, self.num_heads, self.dim_per_head],\n                init=lora_init,\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[None, None, None, None],\n            ),\n        )\n"
  },
  {
    "path": "v1/src/adapter/lora_layers.py",
    "content": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom jax import numpy as jnp\nfrom praxis import base_layer\nfrom praxis.layers import attentions, linears\n\nWeightInit = base_layer.WeightInit\nWeightHParams = base_layer.WeightHParams\n\n\nclass LoraTheta(base_layer.Theta):\n    def __init__(self, module):\n        self.module = module\n\n    def _lora_initialized(self):\n        if (\n            self.module.has_variable(\"params\", \"lora_a\")\n            and self.module.has_variable(\"params\", \"lora_b\")\n            and \"lora_a\" in self.module._weight_hparams\n            and \"lora_b\" in self.module._weight_hparams\n        ):\n            return True\n        else:\n            return False\n\n    def _lorafy_var(self, w):\n        lora_a = super().__getattr__(\"lora_a\")\n        lora_b = super().__getattr__(\"lora_b\")\n        lora_delta = self.module.einsum(\"...dr,...nr->...dn\", lora_a, lora_b)\n        lora_delta = jnp.reshape(lora_delta, w.shape)\n        w_prime = w + lora_delta\n        return w_prime\n\n    def __getattr__(self, k):\n        var = super().__getattr__(k)\n        if not self._lora_initialized():\n            return var\n\n        if k == \"w\":\n            return self._lorafy_var(var)\n\n        return var\n\n    def __getitem__(self, k):\n        var = super().__getattr__(k)\n        if not self._lora_initialized():\n            return var\n\n        if k 
== \"w\":\n            return self._lorafy_var(var)\n\n        return var\n\n\nclass LoraThetaDescriptor:\n    \"\"\"Dot syntax accession descriptor.\"\"\"\n\n    def __get__(self, obj, objtype=None):\n        return LoraTheta(obj)\n\n\nclass LoraLinear(linears.Linear):\n    rank: int = 0\n    lora_init: WeightInit | None = None\n    theta = LoraThetaDescriptor()\n\n    def setup(self) -> None:\n        lora_init = self.lora_init if self.lora_init else self.weight_init\n\n        super().setup()\n        self.create_variable(\n            \"lora_a\",\n            WeightHParams(\n                shape=[self.input_dims, self.rank],\n                init=lora_init,\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[None, None],\n            ),\n        )\n        self.create_variable(\n            \"lora_b\",\n            WeightHParams(\n                shape=[self.output_dims, self.rank],\n                init=WeightInit.Constant(scale=0.0),\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[None, None],\n            ),\n        )\n\n\nclass LoraAttentionProjection(attentions.AttentionProjection):\n    rank: int = 0\n    lora_init: WeightInit | None = None\n    theta = LoraThetaDescriptor()\n\n    def setup(self) -> None:\n        super().setup()\n        w_weight_params = self._weight_hparams[\"w\"]\n        lora_init = self.lora_init if self.lora_init else w_weight_params.init\n\n        self.create_variable(\n            \"lora_a\",\n            WeightHParams(\n                shape=[self.input_dim, self.rank],\n                init=lora_init,\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[\n                    None,\n                    None,\n                ],\n            ),\n        )\n        self.create_variable(\n            \"lora_b\",\n            WeightHParams(\n                shape=[self.dim_per_head * self.num_heads, self.rank],\n     
           init=WeightInit.Constant(scale=0.0),\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[\n                    None,\n                    None,\n                ],\n            ),\n        )\n\n\nclass LoraCombinedQKVProjection(attentions.CombinedQKVProjectionLayer):\n    rank: int = 0\n    lora_init: WeightInit | None = None\n    theta = LoraThetaDescriptor()\n\n    def setup(self) -> None:\n        super().setup()\n        w_weight_params = self._weight_hparams[\"w\"]\n        lora_init = self.lora_init if self.lora_init else w_weight_params.init\n\n        self.create_variable(\n            \"lora_a\",\n            WeightHParams(\n                shape=[3, self.input_dim, self.rank],\n                init=lora_init,\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[None, None, None],\n            ),\n        )\n        self.create_variable(\n            \"lora_b\",\n            WeightHParams(\n                shape=[3, self.dim_per_head * self.num_heads, self.rank],\n                init=WeightInit.Constant(scale=0.0),\n                mesh_shape=self.mesh_shape,\n                tensor_split_dims_mapping=[None, None, None],\n            ),\n        )\n"
  },
  {
    "path": "v1/src/adapter/utils.py",
    "content": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis file provides functionality for loading and merging adapter weights\nin timesfm model, specifically for LoRA and DoRA.\nLoRA: https://arxiv.org/abs/2106.09685\nDoRA: https://arxiv.org/abs/2402.09353v4 \n\"\"\"\n\nimport time\n\nimport jax\nimport jax.numpy as jnp\nfrom paxml import checkpoints, tasks_lib\nfrom paxml.train_states import TrainState\nfrom praxis import pax_fiddle\n\nfrom adapter.dora_layers import (\n    DoraAttentionProjection,\n    DoraCombinedQKVProjection,\n    DoraLinear,\n)\nfrom adapter.lora_layers import (\n    LoraAttentionProjection,\n    LoraCombinedQKVProjection,\n    LoraLinear,\n)\nfrom timesfm import TimesFm\n\n\ndef get_adapter_params(\n    params: dict, lora_target_modules: str, num_layers: int, use_dora: bool = False\n) -> dict:\n    \"\"\"\n    Extracts adapter parameters from the given model parameters for saving the checkpoint.\n\n    Args:\n        params (dict): The full model parameters.\n        lora_target_modules (str): Target modules for LoRA/DoRA adaptation.\n        num_layers (int): Number of transformer layers.\n        use_dora (bool, optional): Whether DoRA was used or not. 
Defaults to False.\n\n    Returns:\n        dict: A dictionary containing the extracted adapter parameters.\n    \"\"\"\n    adapter_params = {}\n    for i in range(num_layers):\n        layer_key = f\"x_layers_{i}\"\n        adapter_params[layer_key] = {}\n\n        if lora_target_modules in [\"all\", \"mlp\"]:\n            for ff_layer_key in [\"ffn_layer1\", \"ffn_layer2\"]:\n                linear = params[\"params\"][\"core_layer\"][\"stacked_transformer_layer\"][\n                    layer_key\n                ][\"ff_layer\"][ff_layer_key][\"linear\"]\n\n                lora_a = linear[\"lora_a\"]\n                lora_b = linear[\"lora_b\"]\n\n                adapter_params[layer_key][ff_layer_key] = {\n                    \"lora_a\": lora_a,\n                    \"lora_b\": lora_b,\n                }\n\n                if use_dora:\n                    adapter_params[layer_key][ff_layer_key][\"dora_m\"] = linear[\"dora_m\"]\n\n        if lora_target_modules in [\"all\", \"attention\"]:\n            attention = params[\"params\"][\"core_layer\"][\"stacked_transformer_layer\"][\n                layer_key\n            ][\"self_attention\"]\n\n            for component in [\"key\", \"query\", \"value\", \"post\"]:\n                lora_a = attention[component][\"lora_a\"]\n                lora_b = attention[component][\"lora_b\"]\n\n                adapter_params[layer_key][component] = {\n                    \"lora_a\": lora_a,\n                    \"lora_b\": lora_b,\n                }\n\n                if use_dora:\n                    adapter_params[layer_key][component][\"dora_m\"] = attention[\n                        component\n                    ][\"dora_m\"]\n    return adapter_params\n\n\ndef load_adapter_checkpoint(\n    model: TimesFm,\n    adapter_checkpoint_path: str,\n    lora_rank: int,\n    lora_target_modules: str,\n    use_dora: bool,\n) -> None:\n    \"\"\"\n    Loads an adapter checkpoint and merges it with the original model weights.\n\n 
   Args:\n        model (TimesFm): The model to update.\n        adapter_checkpoint_path (str): Path to the adapter checkpoint.\n        lora_rank (int): Rank of the LoRA adaptation.\n        lora_target_modules (str): Target modules for adaptation.\n        use_dora (bool): Whether DoRA was used or not.\n\n    Returns:\n        None\n    \"\"\"\n\n    \"\"\"\n    currently loading and initializing the model with adapter layers first and then merging the\n    adapter weights to original weights and replacing the adapter layers back to original layer.\n    # NOTE: refactor this. there should be a better way to load the LoRA checkpoint.\n    \"\"\"\n    model._logging(f\"Restoring adapter checkpoint from {adapter_checkpoint_path}.\")\n    start_time = time.time()\n    original_linear_tpl, original_attn_tpl, original_combined_qkv_tpl = (\n        load_adapter_layer(\n            mdl_vars=model._train_state.mdl_vars,\n            model=model._model,\n            lora_rank=lora_rank,\n            lora_target_modules=lora_target_modules,\n            use_dora=use_dora,\n        )\n    )\n\n    var_weight_hparams = model._model.abstract_init_with_metadata(\n        model._get_sample_inputs(), do_eval=True\n    )\n\n    adapter_weight_hparams = _get_adapter_weight_params(\n        var_weight_hparams=var_weight_hparams,\n        lora_target_modules=lora_target_modules,\n        num_layers=model._model.stacked_transformer_params_tpl.num_layers,\n        use_dora=use_dora,\n    )\n\n    adapter_state_partition_specs = tasks_lib.create_state_partition_specs(\n        adapter_weight_hparams,\n        mesh_shape=model.mesh_shape,\n        mesh_axis_names=model.mesh_name,\n        discard_opt_states=True,\n        learners=None,\n    )\n    adapter_state_local_shapes = tasks_lib.create_state_unpadded_shapes(\n        adapter_weight_hparams,\n        discard_opt_states=True,\n        learners=None,\n    )\n    adapter_train_state = checkpoints.restore_checkpoint(\n        
state_global_shapes=adapter_state_local_shapes,\n        checkpoint_dir=adapter_checkpoint_path,\n        checkpoint_type=checkpoints.CheckpointType.FLAX,\n        state_specs=adapter_state_partition_specs,\n        step=None,\n    )\n\n    # add adapter weights to the original weights\n    _merge_adapter_weights(\n        model=model,\n        adapter_train_state=adapter_train_state,\n        lora_target_modules=lora_target_modules,\n        num_layers=model._model.stacked_transformer_params_tpl.num_layers,\n        use_dora=use_dora,\n    )\n\n    # replace back with the original model layer\n    if lora_target_modules in [\"all\", \"mlp\"]:\n        model._model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl = (\n            original_linear_tpl\n        )\n\n    if lora_target_modules in [\"all\", \"attention\"]:\n        model._model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_atten_tpl.proj_tpl = (\n            original_attn_tpl\n        )\n        model._model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_atten_tpl.combined_qkv_proj_tpl = (\n            original_combined_qkv_tpl\n        )\n    model._logging(\n        f\"Restored adapter checkpoint in {time.time() - start_time:.2f} seconds.\"\n    )\n\n    # jit compile the model\n    model.jit_decode()\n\n\ndef _merge_adapter_weights(\n    model: TimesFm,\n    adapter_train_state: TrainState,\n    lora_target_modules: str,\n    num_layers: int,\n    use_dora: bool,\n) -> None:\n    \"\"\"\n    Merges adapter weights with the original model weights.\n\n    Args:\n        model (TimesFm): The model to update.\n        adapter_train_state (TrainState): The adapter's train state.\n        lora_target_modules (str): Target modules for adaptation.\n        num_layers (int): Number of transformer layers.\n        use_dora (bool): Whether DoRA was used or not.\n    \"\"\"\n    for i in range(num_layers):\n        layer_key = 
f\"x_layers_{i}\"\n\n        if lora_target_modules in [\"all\", \"mlp\"]:\n            for ff_layer_key in [\"ffn_layer1\", \"ffn_layer2\"]:\n                linear = model._train_state.mdl_vars[\"params\"][\n                    \"stacked_transformer_layer\"\n                ][layer_key][\"ff_layer\"][ff_layer_key][\"linear\"]\n\n                params = adapter_train_state.mdl_vars[layer_key][ff_layer_key]\n                lora_a = params[\"lora_a\"]\n                lora_b = params[\"lora_b\"]\n\n                w = linear[\"w\"]\n\n                lora_delta = jnp.einsum(\"...dr,...nr->...dn\", lora_a, lora_b)\n                lora_delta = jnp.reshape(lora_delta, w.shape)\n                w_prime = w + lora_delta\n\n                if use_dora:\n                    dora_m = params[\"dora_m\"]\n                    column_norm = jnp.linalg.norm(w_prime, ord=2, axis=0, keepdims=True)\n                    norm_adapted = w_prime / column_norm\n                    w_prime = dora_m * norm_adapted\n                    linear[\"w\"] = w_prime\n                    del linear[\"dora_m\"]\n\n                else:\n                    linear[\"w\"] = w_prime\n\n                del linear[\"lora_a\"]\n                del linear[\"lora_b\"]\n\n        if lora_target_modules in [\"all\", \"attention\"]:\n            attention = model._train_state.mdl_vars[\"params\"][\n                \"stacked_transformer_layer\"\n            ][layer_key][\"self_attention\"]\n\n            for component in [\"key\", \"query\", \"value\", \"post\"]:\n                params = adapter_train_state.mdl_vars[layer_key][component]\n                lora_a = params[\"lora_a\"]\n                lora_b = params[\"lora_b\"]\n\n                w = attention[component][\"w\"]\n\n                lora_delta = jnp.einsum(\"...dr,...nr->...dn\", lora_a, lora_b)\n                lora_delta = jnp.reshape(lora_delta, w.shape)\n                w_prime = w + lora_delta\n\n                if use_dora:\n              
      dora_m = params[\"dora_m\"]\n                    column_norm = jnp.linalg.norm(w_prime, ord=2, axis=0, keepdims=True)\n                    norm_adapted = w_prime / column_norm\n                    w_prime = dora_m * norm_adapted\n                    attention[component][\"w\"] = w_prime\n                    del attention[component][\"dora_m\"]\n\n                else:\n                    attention[component][\"w\"] = w_prime\n\n                del attention[component][\"lora_a\"]\n                del attention[component][\"lora_b\"]\n\n\ndef _get_adapter_weight_params(\n    var_weight_hparams: dict, lora_target_modules: str, num_layers: int, use_dora: bool\n) -> dict:\n    \"\"\"\n    Extracts adapter weight parameters from the given variable weight hyperparameters.\n\n    Args:\n        var_weight_hparams (dict): Variable weight hyperparameters.\n        lora_target_modules (str): Target modules for adaptation.\n        num_layers (int): Number of transformer layers.\n        use_dora (bool): Whether DoRA was used or not.\n\n    Returns:\n        dict: A dictionary containing the extracted adapter weight parameters.\n    \"\"\"\n    adapter_params = {}\n    for i in range(num_layers):\n        layer = f\"x_layers_{i}\"\n        adapter_params[layer] = {}\n\n        if lora_target_modules in [\"all\", \"mlp\"]:\n            for ff_layer_key in [\"ffn_layer1\", \"ffn_layer2\"]:\n                adapter_weight_params = var_weight_hparams[\"params\"][\n                    \"stacked_transformer_layer\"\n                ][layer][\"ff_layer\"][ff_layer_key][\"linear\"]\n                adapter_params[layer][ff_layer_key] = {\n                    \"lora_a\": adapter_weight_params[\"lora_a\"],\n                    \"lora_b\": adapter_weight_params[\"lora_b\"],\n                }\n\n                if use_dora:\n                    adapter_params[layer][ff_layer_key][\"dora_m\"] = (\n                        adapter_weight_params[\"dora_m\"]\n                    )\n\n 
       if lora_target_modules in [\"all\", \"attention\"]:\n            for component in [\"key\", \"value\", \"query\", \"post\"]:\n                adapter_weight_params = var_weight_hparams[\"params\"][\n                    \"stacked_transformer_layer\"\n                ][layer][\"self_attention\"][component]\n                adapter_params[layer][component] = {\n                    \"lora_a\": adapter_weight_params[\"lora_a\"],\n                    \"lora_b\": adapter_weight_params[\"lora_b\"],\n                }\n\n                if use_dora:\n                    adapter_params[layer][component][\"dora_m\"] = adapter_weight_params[\n                        \"dora_m\"\n                    ]\n\n    return adapter_params\n\n\ndef load_adapter_layer(\n    mdl_vars: dict,\n    model: pax_fiddle.Config,\n    lora_rank: int,\n    lora_target_modules: str,\n    use_dora: bool = False,\n) -> tuple[pax_fiddle.Config, pax_fiddle.Config]:\n    \"\"\"\n    Updates target modules with adapter layers.\n\n    Args:\n        mdl_vars (dict): Model variables.\n        model (pax_fiddle.Config): Model configuration.\n        lora_rank (int): Rank of the LoRA adaptation.\n        lora_target_modules (str): Target modules for adaptation.\n        use_dora (bool, optional): Whether DoRA was used or not.\n\n    Returns:\n        tuple[pax_fiddle.Config, pax_fiddle.Config]: Updated model configurations.\n    \"\"\"\n    original_linear_tpl = original_attn_tpl = original_combined_qkv_tpl = None\n    if lora_target_modules in [\"all\", \"mlp\"]:\n        original_linear_tpl = (\n            model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl\n        )\n        adapter_linear_tpl = (\n            pax_fiddle.Config(\n                DoraLinear,\n                rank=lora_rank,\n            )\n            if use_dora\n            else pax_fiddle.Config(\n                LoraLinear,\n                rank=lora_rank,\n            )\n       
 )\n        adapter_linear_tpl.copy_fields_from(original_linear_tpl)\n        model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl = (\n            adapter_linear_tpl\n        )\n\n    if lora_target_modules in [\"all\", \"attention\"]:\n        original_attn_tpl = (\n            model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_atten_tpl.proj_tpl\n        )\n\n        adapter_attn_tpl = (\n            pax_fiddle.Config(DoraAttentionProjection, rank=lora_rank)\n            if use_dora\n            else pax_fiddle.Config(LoraAttentionProjection, rank=lora_rank)\n        )\n        adapter_attn_tpl.copy_fields_from(original_attn_tpl)\n\n        original_combined_qkv_tpl = (\n            model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_atten_tpl.combined_qkv_proj_tpl\n        )\n\n        adapter_combined_qkv_tpl = (\n            pax_fiddle.Config(DoraCombinedQKVProjection, rank=lora_rank)\n            if use_dora\n            else pax_fiddle.Config(LoraCombinedQKVProjection, rank=lora_rank)\n        )\n        adapter_combined_qkv_tpl.copy_fields_from(original_combined_qkv_tpl)\n\n        model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_atten_tpl.proj_tpl = (\n            adapter_attn_tpl\n        )\n        model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_atten_tpl.combined_qkv_proj_tpl = (\n            adapter_combined_qkv_tpl\n        )\n\n    # initialize and add adapter weights\n    _initialize_adapter_params(\n        mdl_vars=mdl_vars,\n        num_layers=model.stacked_transformer_params_tpl.num_layers,\n        lora_rank=lora_rank,\n        lora_target_modules=lora_target_modules,\n        use_dora=use_dora,\n    )\n\n    return original_linear_tpl, original_attn_tpl, original_combined_qkv_tpl\n\n\ndef _initialize_adapter_params(\n    mdl_vars: dict,\n    num_layers,\n    lora_rank: int,\n    lora_target_modules: str,\n    
use_dora: bool = False,\n    seed: int = 1234,\n) -> dict:\n    \"\"\"\n    Initializes and adds adapter parameters to target modules.\n\n    Args:\n        mdl_vars (dict): Model variables.\n        num_layers (int): Number of transformer layers.\n        lora_rank (int): Rank of the LoRA adaptation.\n        lora_target_modules (str): Target modules for adaptation.\n        use_dora (bool, optional): Whether DoRA was used or not.\n        seed (int, optional): Random seed for initialization. Defaults to 1234.\n\n    Returns:\n        dict: Updated model variables with initialized adapter parameters.\n    \"\"\"\n    for i in range(num_layers):\n        layer_key = f\"x_layers_{i}\"\n        if lora_target_modules in [\"all\", \"mlp\"]:\n            for ff_layer_key in [\"ffn_layer1\", \"ffn_layer2\"]:\n                linear = mdl_vars[\"params\"][\"stacked_transformer_layer\"][layer_key][\n                    \"ff_layer\"\n                ][ff_layer_key][\"linear\"]\n                original_w = linear[\"w\"]\n                input_dim, output_dim = original_w.shape\n                std_dev = 1 / jnp.sqrt(lora_rank)\n\n                normal_initializer = jax.nn.initializers.normal(std_dev)\n                lora_a = normal_initializer(\n                    jax.random.key(seed), (input_dim, lora_rank), jnp.float32\n                )\n                lora_b = jnp.zeros((output_dim, lora_rank))\n\n                linear[\"lora_a\"] = lora_a\n                linear[\"lora_b\"] = lora_b\n\n                if use_dora:\n                    norm = jnp.linalg.norm(original_w, ord=2, axis=0, keepdims=True)\n                    linear[\"dora_m\"] = norm\n\n        if lora_target_modules in [\"all\", \"attention\"]:\n            attention = mdl_vars[\"params\"][\"stacked_transformer_layer\"][layer_key][\n                \"self_attention\"\n            ]\n\n            for component in [\"key\", \"query\", \"value\", \"post\"]:\n                original_w = 
attention[component][\"w\"]\n                w_dim = original_w.shape[0]\n                std_dev = 1 / jnp.sqrt(lora_rank)\n\n                normal_initializer = jax.nn.initializers.normal(std_dev)\n                lora_a = normal_initializer(\n                    jax.random.key(seed), (w_dim, lora_rank), jnp.float32\n                )\n                lora_b = jnp.zeros((w_dim, lora_rank))\n\n                attention[component][\"lora_a\"] = lora_a\n                attention[component][\"lora_b\"] = lora_b\n\n                if use_dora:\n                    norm = jnp.linalg.norm(\n                        original_w, ord=2, axis=0, keepdims=True\n                    ).astype(jnp.float32)\n                    attention[component][\"dora_m\"] = norm\n    return mdl_vars\n"
  },
  {
    "path": "v1/src/finetuning/__init__.py",
    "content": ""
  },
  {
    "path": "v1/src/finetuning/finetuning_example.py",
    "content": "\"\"\"\nExample usage of the TimesFM Finetuning Framework.\n\nFor single GPU:\npython script.py --training_mode=single\n\nFor multiple GPUs:\npython script.py --training_mode=multi --gpu_ids=0,1,2\n\"\"\"\n\nimport os\nfrom dataclasses import asdict\nfrom os import path\nfrom typing import Optional, Tuple\n\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch.multiprocessing as mp\nimport yfinance as yf\nfrom absl import app, flags\nfrom huggingface_hub import snapshot_download\nfrom safetensors.torch import load_file\nfrom torch.utils.data import Dataset\n\nfrom finetuning.finetuning_torch import FinetuningConfig, TimesFMFinetuner\nfrom timesfm import TimesFm, TimesFmCheckpoint, TimesFmHparams\nfrom timesfm.pytorch_patched_decoder import (PatchedTimeSeriesDecoder,\n                                             TimesFMConfig)\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_enum(\n    \"training_mode\",\n    \"single\",\n    [\"single\", \"multi\"],\n    'Training mode: \"single\" for single-GPU or \"multi\" for multi-GPU training.',\n)\n\nflags.DEFINE_list(\n    \"gpu_ids\", [\"0\"],\n    \"Comma-separated list of GPU IDs to use for multi-GPU training. Example: 0,1,2\"\n)\n\nflags.DEFINE_string(\n    \"local_model_path\",\n    None,\n    \"Path to a local .safetensors model file. 
If provided, overrides Hugging Face download.\"\n)\n\nclass TimeSeriesDataset(Dataset):\n  \"\"\"Dataset for time series data compatible with TimesFM.\"\"\"\n\n  def __init__(self,\n               series: np.ndarray,\n               context_length: int,\n               horizon_length: int,\n               freq_type: int = 0):\n    \"\"\"\n        Initialize dataset.\n\n        Args:\n            series: Time series data\n            context_length: Number of past timesteps to use as input\n            horizon_length: Number of future timesteps to predict\n            freq_type: Frequency type (0, 1, or 2)\n        \"\"\"\n    if freq_type not in [0, 1, 2]:\n      raise ValueError(\"freq_type must be 0, 1, or 2\")\n\n    self.series = series\n    self.context_length = context_length\n    self.horizon_length = horizon_length\n    self.freq_type = freq_type\n    self._prepare_samples()\n\n  def _prepare_samples(self) -> None:\n    \"\"\"Prepare sliding window samples from the time series.\"\"\"\n    self.samples = []\n    total_length = self.context_length + self.horizon_length\n\n    for start_idx in range(0, len(self.series) - total_length + 1):\n      end_idx = start_idx + self.context_length\n      x_context = self.series[start_idx:end_idx]\n      x_future = self.series[end_idx:end_idx + self.horizon_length]\n      self.samples.append((x_context, x_future))\n\n  def __len__(self) -> int:\n    return len(self.samples)\n\n  def __getitem__(\n      self, index: int\n  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    x_context, x_future = self.samples[index]\n\n    x_context = torch.tensor(x_context, dtype=torch.float32)\n    x_future = torch.tensor(x_future, dtype=torch.float32)\n\n    input_padding = torch.zeros_like(x_context)\n    freq = torch.tensor([self.freq_type], dtype=torch.long)\n\n    return x_context, input_padding, freq, x_future\n\n\ndef prepare_datasets(series: np.ndarray,\n                     context_length: int,\n             
        horizon_length: int,\n                     freq_type: int = 0,\n                     train_split: float = 0.8) -> Tuple[Dataset, Dataset]:\n  \"\"\"\n    Prepare training and validation datasets from time series data.\n\n    Args:\n        series: Input time series data\n        context_length: Number of past timesteps to use\n        horizon_length: Number of future timesteps to predict\n        freq_type: Frequency type (0, 1, or 2)\n        train_split: Fraction of data to use for training\n\n    Returns:\n        Tuple of (train_dataset, val_dataset)\n    \"\"\"\n  train_size = int(len(series) * train_split)\n  train_data = series[:train_size]\n  val_data = series[train_size:]\n\n  # Create datasets with specified frequency type\n  train_dataset = TimeSeriesDataset(train_data,\n                                    context_length=context_length,\n                                    horizon_length=horizon_length,\n                                    freq_type=freq_type)\n\n  val_dataset = TimeSeriesDataset(val_data,\n                                  context_length=context_length,\n                                  horizon_length=horizon_length,\n                                  freq_type=freq_type)\n\n  return train_dataset, val_dataset\n\n\ndef get_model(load_weights: bool = False):\n  device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n  hparams = TimesFmHparams(\n      backend=device,\n      per_core_batch_size=32,\n      horizon_len=128,\n      num_layers=50,\n      use_positional_embedding=False,\n      context_len=192,\n  )\n  \n  if load_weights:\n    if FLAGS.local_model_path:\n      tfm_config = TimesFMConfig()\n      model = PatchedTimeSeriesDecoder(tfm_config)\n      loaded_checkpoint = load_file(FLAGS.local_model_path)\n    else:\n      repo_id = \"google/timesfm-2.0-500m-pytorch\"\n      tfm = TimesFm(hparams=hparams,\n              checkpoint=TimesFmCheckpoint(huggingface_repo_id=repo_id))\n\n      tfm_config = tfm._model_config\n  
    model = PatchedTimeSeriesDecoder(tfm_config)\n      checkpoint_path = path.join(snapshot_download(repo_id), \"torch_model.ckpt\")\n      loaded_checkpoint = torch.load(checkpoint_path, weights_only=True)\n\n    model.load_state_dict(loaded_checkpoint)\n  return model, hparams, tfm_config\n\n\ndef plot_predictions(\n    model: TimesFm,\n    val_dataset: Dataset,\n    save_path: Optional[str] = \"predictions.png\",\n) -> None:\n  \"\"\"\n    Plot model predictions against ground truth for a batch of validation data.\n\n    Args:\n      model: Trained TimesFM model\n      val_dataset: Validation dataset\n      save_path: Path to save the plot\n    \"\"\"\n  import matplotlib.pyplot as plt\n\n  model.eval()\n\n  x_context, x_padding, freq, x_future = val_dataset[0]\n  x_context = x_context.unsqueeze(0)  # Add batch dimension\n  x_padding = x_padding.unsqueeze(0)\n  freq = freq.unsqueeze(0)\n  x_future = x_future.unsqueeze(0)\n\n  device = next(model.parameters()).device\n  x_context = x_context.to(device)\n  x_padding = x_padding.to(device)\n  freq = freq.to(device)\n  x_future = x_future.to(device)\n\n  with torch.no_grad():\n    predictions = model(x_context, x_padding.float(), freq)\n    predictions_mean = predictions[..., 0]  # [B, N, horizon_len]\n    last_patch_pred = predictions_mean[:, -1, :]  # [B, horizon_len]\n\n  context_vals = x_context[0].cpu().numpy()\n  future_vals = x_future[0].cpu().numpy()\n  pred_vals = last_patch_pred[0].cpu().numpy()\n\n  context_len = len(context_vals)\n  horizon_len = len(future_vals)\n\n  plt.figure(figsize=(12, 6))\n\n  plt.plot(range(context_len),\n           context_vals,\n           label=\"Historical Data\",\n           color=\"blue\",\n           linewidth=2)\n\n  plt.plot(\n      range(context_len, context_len + horizon_len),\n      future_vals,\n      label=\"Ground Truth\",\n      color=\"green\",\n      linestyle=\"--\",\n      linewidth=2,\n  )\n\n  plt.plot(range(context_len, context_len + horizon_len),\n        
   pred_vals,\n           label=\"Prediction\",\n           color=\"red\",\n           linewidth=2)\n\n  plt.xlabel(\"Time Step\")\n  plt.ylabel(\"Value\")\n  plt.title(\"TimesFM Predictions vs Ground Truth\")\n  plt.legend()\n  plt.grid(True)\n\n  if save_path:\n    plt.savefig(save_path)\n    print(f\"Plot saved to {save_path}\")\n\n  plt.close()\n\n\ndef get_data(context_len: int,\n             horizon_len: int,\n             freq_type: int = 0) -> Tuple[Dataset, Dataset]:\n  df = yf.download(\"AAPL\", start=\"2010-01-01\", end=\"2019-01-01\")\n  time_series = df[\"Close\"].values\n\n  train_dataset, val_dataset = prepare_datasets(\n      series=time_series,\n      context_length=context_len,\n      horizon_length=horizon_len,\n      freq_type=freq_type,\n      train_split=0.8,\n  )\n\n  print(f\"Created datasets:\")\n  print(f\"- Training samples: {len(train_dataset)}\")\n  print(f\"- Validation samples: {len(val_dataset)}\")\n  print(f\"- Using frequency type: {freq_type}\")\n  return train_dataset, val_dataset\n\n\ndef single_gpu_example():\n  \"\"\"Basic example of finetuning TimesFM on stock data.\"\"\"\n  model, hparams, tfm_config = get_model(load_weights=True)\n  config = FinetuningConfig(batch_size=256,\n                            num_epochs=5,\n                            learning_rate=1e-4,\n                            use_wandb=True,\n                            freq_type=1,\n                            log_every_n_steps=10,\n                            val_check_interval=0.5,\n                            use_quantile_loss=True)\n\n  train_dataset, val_dataset = get_data(128,\n                                        tfm_config.horizon_len,\n                                        freq_type=config.freq_type)\n  finetuner = TimesFMFinetuner(model, config)\n\n  print(\"\\nStarting finetuning...\")\n  results = finetuner.finetune(train_dataset=train_dataset,\n                               val_dataset=val_dataset)\n\n  print(\"\\nFinetuning 
completed!\")\n  print(f\"Training history: {len(results['history']['train_loss'])} epochs\")\n\n  plot_predictions(\n      model=model,\n      val_dataset=val_dataset,\n      save_path=\"timesfm_predictions.png\",\n  )\n\n\ndef setup_process(rank, world_size, model, config, train_dataset, val_dataset,\n                  return_dict):\n  \"\"\"Setup process function with optimized CUDA handling.\"\"\"\n  try:\n    if torch.cuda.is_available():\n      torch.cuda.set_device(rank)\n\n    os.environ[\"MASTER_ADDR\"] = config.master_addr\n    os.environ[\"MASTER_PORT\"] = config.master_port\n    if not torch.distributed.is_initialized():\n      torch.distributed.init_process_group(backend=\"nccl\",\n                                           world_size=world_size,\n                                           rank=rank)\n\n    finetuner = TimesFMFinetuner(model, config, rank=rank)\n\n    results = finetuner.finetune(train_dataset=train_dataset,\n                                 val_dataset=val_dataset)\n\n    if rank == 0:\n      return_dict[\"results\"] = results\n      plot_predictions(\n          model=model,\n          val_dataset=val_dataset,\n          save_path=\"timesfm_predictions.png\",\n      )\n\n  except Exception as e:\n    print(f\"Error in process {rank}: {str(e)}\")\n    raise e\n  finally:\n    if torch.distributed.is_initialized():\n      torch.distributed.destroy_process_group()\n\n\ndef multi_gpu_example():\n  \"\"\"Example of finetuning TimesFM using multiple GPUs with optimized spawn.\"\"\"\n  mp.set_start_method(\"spawn\", force=True)\n\n  gpu_ids = [0, 1]\n  world_size = len(gpu_ids)\n\n  model, hparams, tfm_config = get_model(load_weights=True)\n\n  # Create config\n  config = FinetuningConfig(\n      batch_size=256,\n      num_epochs=5,\n      learning_rate=3e-5,\n      use_wandb=True,\n      distributed=True,\n      gpu_ids=gpu_ids,\n      log_every_n_steps=50,\n      val_check_interval=0.5,\n  )\n  train_dataset, val_dataset = get_data(128, 
tfm_config.horizon_len)\n  manager = mp.Manager()\n  return_dict = manager.dict()\n\n  # Launch processes\n  mp.spawn(\n      setup_process,\n      args=(world_size, model, config, train_dataset, val_dataset, return_dict),\n      nprocs=world_size,\n      join=True,\n  )\n\n  results = return_dict.get(\"results\", None)\n  print(\"\\nFinetuning completed!\")\n  return results\n\n\ndef main(argv):\n  \"\"\"Main function that selects and runs the appropriate training mode.\"\"\"\n\n  try:\n    if FLAGS.training_mode == \"single\":\n      print(\"\\nStarting single-GPU training...\")\n      single_gpu_example()\n    else:\n      gpu_ids = [int(id) for id in FLAGS.gpu_ids]\n      print(f\"\\nStarting multi-GPU training using GPUs: {gpu_ids}...\")\n\n      config = FinetuningConfig(\n          batch_size=256,\n          num_epochs=5,\n          learning_rate=3e-5,\n          use_wandb=True,\n          distributed=True,\n          gpu_ids=gpu_ids,\n      )\n\n      results = multi_gpu_example(config)\n      print(\"\\nMulti-GPU training completed!\")\n\n  except Exception as e:\n    print(f\"Training failed: {str(e)}\")\n  finally:\n    if torch.distributed.is_initialized():\n      torch.distributed.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n  app.run(main)\n"
  },
  {
    "path": "v1/src/finetuning/finetuning_torch.py",
    "content": "\"\"\"\nTimesFM Finetuner: A flexible framework for finetuning TimesFM models on custom datasets.\n\"\"\"\n\nimport logging\nimport os\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass, field\nfrom typing import Any, Callable, Dict, List, Optional\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.utils.data import DataLoader, Dataset\nfrom timesfm.pytorch_patched_decoder import create_quantiles\n\nimport wandb\n\n\nclass MetricsLogger(ABC):\n  \"\"\"Abstract base class for logging metrics during training.\n\n    This class defines the interface for logging metrics during model training.\n    Concrete implementations can log to different backends (e.g., WandB, TensorBoard).\n    \"\"\"\n\n  @abstractmethod\n  def log_metrics(self,\n                  metrics: Dict[str, Any],\n                  step: Optional[int] = None) -> None:\n    \"\"\"Log metrics to the specified backend.\n\n        Args:\n          metrics: Dictionary containing metric names and values.\n          step: Optional step number or epoch for the metrics.\n        \"\"\"\n    pass\n\n  @abstractmethod\n  def close(self) -> None:\n    \"\"\"Clean up any resources used by the logger.\"\"\"\n    pass\n\n\nclass WandBLogger(MetricsLogger):\n  \"\"\"Weights & Biases implementation of metrics logging.\n\n    Args:\n      project: Name of the W&B project.\n      config: Configuration dictionary to log.\n      rank: Process rank in distributed training.\n    \"\"\"\n\n  def __init__(self, project: str, config: Dict[str, Any], rank: int = 0):\n    self.rank = rank\n    if rank == 0:\n      wandb.init(project=project, config=config)\n\n  def log_metrics(self,\n                  metrics: Dict[str, Any],\n                  step: Optional[int] = None) -> None:\n    \"\"\"Log metrics to W&B if on the main process.\n\n        Args:\n          metrics: Dictionary of metrics to 
log.\n          step: Current training step or epoch.\n        \"\"\"\n    if self.rank == 0:\n      wandb.log(metrics, step=step)\n\n  def close(self) -> None:\n    \"\"\"Finish the W&B run if on the main process.\"\"\"\n    if self.rank == 0:\n      wandb.finish()\n\n\nclass DistributedManager:\n  \"\"\"Manages distributed training setup and cleanup.\n\n    Args:\n      world_size: Total number of processes.\n      rank: Process rank.\n      master_addr: Address of the master process.\n      master_port: Port for distributed communication.\n      backend: PyTorch distributed backend to use.\n    \"\"\"\n\n  def __init__(\n      self,\n      world_size: int,\n      rank: int,\n      master_addr: str = \"localhost\",\n      master_port: str = \"12358\",\n      backend: str = \"nccl\",\n  ):\n    self.world_size = world_size\n    self.rank = rank\n    self.master_addr = master_addr\n    self.master_port = master_port\n    self.backend = backend\n\n  def setup(self) -> None:\n    \"\"\"Initialize the distributed environment.\"\"\"\n    os.environ[\"MASTER_ADDR\"] = self.master_addr\n    os.environ[\"MASTER_PORT\"] = self.master_port\n\n    if not dist.is_initialized():\n      dist.init_process_group(backend=self.backend,\n                              world_size=self.world_size,\n                              rank=self.rank)\n\n  def cleanup(self) -> None:\n    \"\"\"Clean up the distributed environment.\"\"\"\n    if dist.is_initialized():\n      dist.destroy_process_group()\n\n\n@dataclass\nclass FinetuningConfig:\n  \"\"\"Configuration for model training.\n\n    Args:\n      batch_size: Number of samples per batch.\n      num_epochs: Number of training epochs.\n      learning_rate: Initial learning rate.\n      weight_decay: L2 regularization factor.\n      freq_type: Frequency, can be [0, 1, 2].\n      use_quantile_loss: bool = False  # Flag to enable/disable quantile loss\n      quantiles: Optional[List[float]] = None\n      device: Device to train on ('cuda' or 
'cpu').\n      distributed: Whether to use distributed training.\n      gpu_ids: List of GPU IDs to use.\n      master_port: Port for distributed training.\n      master_addr: Address for distributed training.\n      use_wandb: Whether to use Weights & Biases logging.\n      wandb_project: W&B project name.\n      log_every_n_steps: Log metrics every N steps (batches), this is inspired from Pytorch Lightning\n      val_check_interval: How often within one training epoch to check val metrics. (also from Pytorch Lightning)\n        Can be: float (0.0-1.0): fraction of epoch (e.g., 0.5 = validate twice per epoch)\n                int: validate every N batches\n    \"\"\"\n\n  batch_size: int = 32\n  num_epochs: int = 20\n  learning_rate: float = 1e-4\n  weight_decay: float = 0.01\n  freq_type: int = 0\n  use_quantile_loss: bool = False\n  quantiles: Optional[List[float]] = None\n  device: str = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n  distributed: bool = False\n  gpu_ids: List[int] = field(default_factory=lambda: [0])\n  master_port: str = \"12358\"\n  master_addr: str = \"localhost\"\n  use_wandb: bool = False\n  wandb_project: str = \"timesfm-finetuning\"\n  log_every_n_steps: int = 50\n  val_check_interval: float = 0.5\n\n\nclass TimesFMFinetuner:\n  \"\"\"Handles model training and validation.\n\n    Args:\n      model: PyTorch model to train.\n      config: Training configuration.\n      rank: Process rank for distributed training.\n      loss_fn: Loss function (defaults to MSE).\n      logger: Optional logging.Logger instance.\n    \"\"\"\n\n  def __init__(\n      self,\n      model: nn.Module,\n      config: FinetuningConfig,\n      rank: int = 0,\n      loss_fn: Optional[Callable] = None,\n      logger: Optional[logging.Logger] = None,\n  ):\n    self.model = model\n    self.config = config\n    self.rank = rank\n    self.logger = logger or logging.getLogger(__name__)\n    self.device = torch.device(\n        f\"cuda:{rank}\" if 
torch.cuda.is_available() else \"cpu\")\n    self.loss_fn = loss_fn or (lambda x, y: torch.mean((x - y.squeeze(-1))**2))\n\n    if config.use_wandb:\n      self.metrics_logger = WandBLogger(config.wandb_project, config.__dict__,\n                                        rank)\n\n    if config.distributed:\n      self.dist_manager = DistributedManager(\n          world_size=len(config.gpu_ids),\n          rank=rank,\n          master_addr=config.master_addr,\n          master_port=config.master_port,\n      )\n      self.dist_manager.setup()\n      self.model = self._setup_distributed_model()\n\n  def _setup_distributed_model(self) -> nn.Module:\n    \"\"\"Configure model for distributed training.\"\"\"\n    self.model = self.model.to(self.device)\n    return DDP(self.model,\n               device_ids=[self.config.gpu_ids[self.rank]],\n               output_device=self.config.gpu_ids[self.rank])\n\n  def _create_dataloader(self, dataset: Dataset, is_train: bool) -> DataLoader:\n    \"\"\"Create appropriate DataLoader based on training configuration.\n\n        Args:\n          dataset: Dataset to create loader for.\n          is_train: Whether this is for training (affects shuffling).\n\n        Returns:\n          DataLoader instance.\n        \"\"\"\n    if self.config.distributed:\n      sampler = torch.utils.data.distributed.DistributedSampler(\n          dataset,\n          num_replicas=len(self.config.gpu_ids),\n          rank=dist.get_rank(),\n          shuffle=is_train)\n    else:\n      sampler = None\n\n    return DataLoader(\n        dataset,\n        batch_size=self.config.batch_size,\n        shuffle=(is_train and not self.config.distributed),\n        sampler=sampler,\n    )\n\n  def _quantile_loss(self, pred: torch.Tensor, actual: torch.Tensor,\n                     quantile: float) -> torch.Tensor:\n    \"\"\"Calculates quantile loss.\n        Args:\n            pred: Predicted values\n            actual: Actual values\n            quantile: Quantile 
at which loss is computed\n        Returns:\n            Quantile loss\n        \"\"\"\n    dev = actual - pred\n    loss_first = dev * quantile\n    loss_second = -dev * (1.0 - quantile)\n    return 2 * torch.where(loss_first >= 0, loss_first, loss_second)\n\n  def _process_batch(self, batch: List[torch.Tensor]) -> tuple:\n    \"\"\"Process a single batch of data.\n\n        Args:\n          batch: List of input tensors.\n\n        Returns:\n          Tuple of (loss, predictions).\n        \"\"\"\n    x_context, x_padding, freq, x_future = [\n        t.to(self.device, non_blocking=True) for t in batch\n    ]\n\n    predictions = self.model(x_context, x_padding.float(), freq)\n    predictions_mean = predictions[..., 0]\n    last_patch_pred = predictions_mean[:, -1, :]\n\n    loss = self.loss_fn(last_patch_pred, x_future.squeeze(-1))\n    if self.config.use_quantile_loss:\n      quantiles = self.config.quantiles or create_quantiles()\n      for i, quantile in enumerate(quantiles):\n        last_patch_quantile = predictions[:, -1, :, i + 1]\n        loss += torch.mean(\n            self._quantile_loss(last_patch_quantile, x_future.squeeze(-1),\n                                quantile))\n\n    return loss, predictions\n\n  def _train_epoch(self, train_loader: DataLoader,\n                   optimizer: torch.optim.Optimizer) -> float:\n    \"\"\"Train for one epoch in a distributed setting.\n\n        Args:\n            train_loader: DataLoader for training data.\n            optimizer: Optimizer instance.\n\n        Returns:\n            Average training loss for the epoch.\n        \"\"\"\n    self.model.train()\n    total_loss = 0.0\n    num_batches = len(train_loader)\n\n    for batch in train_loader:\n      loss, _ = self._process_batch(batch)\n\n      optimizer.zero_grad()\n      loss.backward()\n      optimizer.step()\n\n      total_loss += loss.item()\n\n    avg_loss = total_loss / num_batches\n\n    if self.config.distributed:\n      avg_loss_tensor = 
torch.tensor(avg_loss, device=self.device)\n      dist.all_reduce(avg_loss_tensor, op=dist.ReduceOp.SUM)\n      avg_loss = (avg_loss_tensor / dist.get_world_size()).item()\n\n    return avg_loss\n\n  def _validate(self, val_loader: DataLoader) -> float:\n    \"\"\"Perform validation.\n\n        Args:\n            val_loader: DataLoader for validation data.\n\n        Returns:\n            Average validation loss.\n        \"\"\"\n    self.model.eval()\n    total_loss = 0.0\n    num_batches = len(val_loader)\n\n    with torch.no_grad():\n      for batch in val_loader:\n        loss, _ = self._process_batch(batch)\n        total_loss += loss.item()\n\n    avg_loss = total_loss / num_batches\n\n    if self.config.distributed:\n      avg_loss_tensor = torch.tensor(avg_loss, device=self.device)\n      dist.all_reduce(avg_loss_tensor, op=dist.ReduceOp.SUM)\n      avg_loss = (avg_loss_tensor / dist.get_world_size()).item()\n\n    return avg_loss\n\n  def finetune(self, train_dataset: Dataset,\n               val_dataset: Dataset) -> Dict[str, Any]:\n    \"\"\"Train the model.\n\n        Args:\n          train_dataset: Training dataset.\n          val_dataset: Validation dataset.\n\n        Returns:\n          Dictionary containing training history.\n        \"\"\"\n    self.model = self.model.to(self.device)\n    train_loader = self._create_dataloader(train_dataset, is_train=True)\n    val_loader = self._create_dataloader(val_dataset, is_train=False)\n\n    optimizer = torch.optim.Adam(self.model.parameters(),\n                                 lr=self.config.learning_rate,\n                                 weight_decay=self.config.weight_decay)\n\n    history = {\"train_loss\": [], \"val_loss\": [], \"learning_rate\": []}\n\n    self.logger.info(\n        f\"Starting training for {self.config.num_epochs} epochs...\")\n    self.logger.info(f\"Training samples: {len(train_dataset)}\")\n    self.logger.info(f\"Validation samples: {len(val_dataset)}\")\n\n    try:\n      for 
epoch in range(self.config.num_epochs):\n        train_loss = self._train_epoch(train_loader, optimizer)\n        val_loss = self._validate(val_loader)\n        current_lr = optimizer.param_groups[0][\"lr\"]\n\n        metrics = {\n            \"train_loss\": train_loss,\n            \"val_loss\": val_loss,\n            \"learning_rate\": current_lr,\n            \"epoch\": epoch + 1,\n        }\n\n        if self.config.use_wandb:\n          self.metrics_logger.log_metrics(metrics)\n\n        history[\"train_loss\"].append(train_loss)\n        history[\"val_loss\"].append(val_loss)\n        history[\"learning_rate\"].append(current_lr)\n\n        if self.rank == 0:\n          self.logger.info(\n              f\"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}\"\n          )\n\n    except KeyboardInterrupt:\n      self.logger.info(\"Training interrupted by user\")\n\n    if self.config.distributed:\n      self.dist_manager.cleanup()\n\n    if self.config.use_wandb:\n      self.metrics_logger.close()\n\n    return {\"history\": history}\n"
  },
  {
    "path": "v1/src/timesfm/__init__.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"TimesFM init file.\"\"\"\n\nprint(\n    \" See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.\"\n)\nfrom timesfm.timesfm_base import (\n    freq_map,\n    TimesFmCheckpoint,\n    TimesFmHparams,\n    TimesFmBase,\n)\nimport sys\n\ntry:\n    from timesfm.timesfm_jax import TimesFmJax as TimesFm\n    from timesfm import data_loader\n\n    print(f\"Loaded Jax TimesFM, likely because python version is {sys.version}.\")\nexcept Exception as _:\n    from timesfm.timesfm_torch import TimesFmTorch as TimesFm\n\n    print(f\"Loaded PyTorch TimesFM, likely because python version is {sys.version}.\")\n"
  },
  {
    "path": "v1/src/timesfm/data_loader.py",
    "content": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"TF dataloaders for general timeseries datasets.\n\nThe expected input format is csv file with a datetime index.\n\"\"\"\n\nfrom absl import logging\nimport numpy as np\nimport pandas as pd\nfrom sklearn.preprocessing import StandardScaler\nimport tensorflow as tf\nfrom . import time_features\n\n\nclass TimeSeriesdata(object):\n  \"\"\"Data loader class.\"\"\"\n\n  def __init__(\n      self,\n      data_path,\n      datetime_col,\n      num_cov_cols,\n      cat_cov_cols,\n      ts_cols,\n      train_range,\n      val_range,\n      test_range,\n      hist_len,\n      pred_len,\n      batch_size,\n      freq='H',\n      normalize=True,\n      epoch_len=None,\n      holiday=False,\n      permute=True,\n  ):\n    \"\"\"Initialize objects.\n\n    Args:\n      data_path: path to csv file\n      datetime_col: column name for datetime col\n      num_cov_cols: list of numerical global covariates\n      cat_cov_cols: list of categorical global covariates\n      ts_cols: columns corresponding to ts\n      train_range: tuple of train ranges\n      val_range: tuple of validation ranges\n      test_range: tuple of test ranges\n      hist_len: historical context\n      pred_len: prediction length\n      batch_size: batch size (number of ts in a batch)\n      freq: freq of original data\n      normalize: std. 
normalize data or not\n      epoch_len: num iters in an epoch\n      holiday: use holiday features or not\n      permute: permute ts in train batches or not\n\n    Returns:\n      None\n    \"\"\"\n    self.data_df = pd.read_csv(open(data_path, 'r'))\n    if not num_cov_cols:\n      self.data_df['ncol'] = np.zeros(self.data_df.shape[0])\n      num_cov_cols = ['ncol']\n    if not cat_cov_cols:\n      self.data_df['ccol'] = np.zeros(self.data_df.shape[0])\n      cat_cov_cols = ['ccol']\n    self.data_df.fillna(0, inplace=True)\n    self.data_df.set_index(pd.DatetimeIndex(self.data_df[datetime_col]),\n                           inplace=True)\n    self.num_cov_cols = num_cov_cols\n    self.cat_cov_cols = cat_cov_cols\n    self.ts_cols = ts_cols\n    self.train_range = train_range\n    self.val_range = val_range\n    self.test_range = test_range\n    data_df_idx = self.data_df.index\n    date_index = data_df_idx.union(\n        pd.date_range(\n            data_df_idx[-1] + pd.Timedelta(1, freq=freq),\n            periods=pred_len + 1,\n            freq=freq,\n        ))\n    self.time_df = time_features.TimeCovariates(\n        date_index, holiday=holiday).get_covariates()\n    self.hist_len = hist_len\n    self.pred_len = pred_len\n    self.batch_size = batch_size\n    self.freq = freq\n    self.normalize = normalize\n    self.data_mat = self.data_df[self.ts_cols].to_numpy().transpose()\n    self.data_mat = self.data_mat[:, 0:self.test_range[1]]\n    self.time_mat = self.time_df.to_numpy().transpose()\n    self.num_feat_mat = self.data_df[num_cov_cols].to_numpy().transpose()\n    self.cat_feat_mat, self.cat_sizes = self._get_cat_cols(cat_cov_cols)\n    self.normalize = normalize\n    if normalize:\n      self._normalize_data()\n    logging.info(\n        'Data Shapes: %s, %s, %s, %s',\n        self.data_mat.shape,\n        self.time_mat.shape,\n        self.num_feat_mat.shape,\n        self.cat_feat_mat.shape,\n    )\n    self.epoch_len = epoch_len\n    self.permute = 
permute\n\n  def _get_cat_cols(self, cat_cov_cols):\n    \"\"\"Get categorical columns.\"\"\"\n    cat_vars = []\n    cat_sizes = []\n    for col in cat_cov_cols:\n      dct = {x: i for i, x in enumerate(self.data_df[col].unique())}\n      cat_sizes.append(len(dct))\n      mapped = self.data_df[col].map(lambda x: dct[x]).to_numpy().transpose()  # pylint: disable=cell-var-from-loop\n      cat_vars.append(mapped)\n    return np.vstack(cat_vars), cat_sizes\n\n  def _normalize_data(self):\n    self.scaler = StandardScaler()\n    train_mat = self.data_mat[:, 0:self.train_range[1]]\n    self.scaler = self.scaler.fit(train_mat.transpose())\n    self.data_mat = self.scaler.transform(self.data_mat.transpose()).transpose()\n\n  def train_gen(self):\n    \"\"\"Generator for training data.\"\"\"\n    num_ts = len(self.ts_cols)\n    perm = np.arange(\n        self.train_range[0] + self.hist_len,\n        self.train_range[1] - self.pred_len,\n    )\n    perm = np.random.permutation(perm)\n    hist_len = self.hist_len\n    logging.info('Hist len: %s', hist_len)\n    if not self.epoch_len:\n      epoch_len = len(perm)\n    else:\n      epoch_len = self.epoch_len\n    for idx in perm[0:epoch_len]:\n      for _ in range(num_ts // self.batch_size + 1):\n        if self.permute:\n          tsidx = np.random.choice(num_ts, size=self.batch_size, replace=False)\n        else:\n          tsidx = np.arange(num_ts)\n        dtimes = np.arange(idx - hist_len, idx + self.pred_len)\n        (\n            bts_train,\n            bts_pred,\n            bfeats_train,\n            bfeats_pred,\n            bcf_train,\n            bcf_pred,\n        ) = self._get_features_and_ts(dtimes, tsidx, hist_len)\n\n        all_data = [\n            bts_train,\n            bfeats_train,\n            bcf_train,\n            bts_pred,\n            bfeats_pred,\n            bcf_pred,\n            tsidx,\n        ]\n        yield tuple(all_data)\n\n  def test_val_gen(self, mode='val', shift=1):\n    
\"\"\"Generator for validation/test data.\"\"\"\n    if mode == 'val':\n      start = self.val_range[0]\n      end = self.val_range[1] - self.pred_len + 1\n    elif mode == 'test':\n      start = self.test_range[0]\n      end = self.test_range[1] - self.pred_len + 1\n    else:\n      raise NotImplementedError('Eval mode not implemented')\n    num_ts = len(self.ts_cols)\n    hist_len = self.hist_len\n    logging.info('Hist len: %s', hist_len)\n    perm = np.arange(start, end)\n    if self.epoch_len:\n      epoch_len = self.epoch_len\n    else:\n      epoch_len = len(perm)\n    for i in range(0, epoch_len, shift):\n      idx = perm[i]\n      for batch_idx in range(0, num_ts, self.batch_size):\n        tsidx = np.arange(batch_idx, min(batch_idx + self.batch_size, num_ts))\n        dtimes = np.arange(idx - hist_len, idx + self.pred_len)\n        (\n            bts_train,\n            bts_pred,\n            bfeats_train,\n            bfeats_pred,\n            bcf_train,\n            bcf_pred,\n        ) = self._get_features_and_ts(dtimes, tsidx, hist_len)\n        all_data = [\n            bts_train,\n            bfeats_train,\n            bcf_train,\n            bts_pred,\n            bfeats_pred,\n            bcf_pred,\n            tsidx,\n        ]\n        yield tuple(all_data)\n\n  def _get_features_and_ts(self, dtimes, tsidx, hist_len=None):\n    \"\"\"Get features and ts in specified windows.\"\"\"\n    if hist_len is None:\n      hist_len = self.hist_len\n    data_times = dtimes[dtimes < self.data_mat.shape[1]]\n    bdata = self.data_mat[:, data_times]\n    bts = bdata[tsidx, :]\n    bnf = self.num_feat_mat[:, data_times]\n    bcf = self.cat_feat_mat[:, data_times]\n    btf = self.time_mat[:, dtimes]\n    if bnf.shape[1] < btf.shape[1]:\n      rem_len = btf.shape[1] - bnf.shape[1]\n      rem_rep = np.repeat(bnf[:, [-1]], repeats=rem_len)\n      rem_rep_cat = np.repeat(bcf[:, [-1]], repeats=rem_len)\n      bnf = np.hstack([bnf, rem_rep.reshape(bnf.shape[0], 
-1)])\n      bcf = np.hstack([bcf, rem_rep_cat.reshape(bcf.shape[0], -1)])\n    bfeats = np.vstack([btf, bnf])\n    bts_train = bts[:, 0:hist_len]\n    bts_pred = bts[:, hist_len:]\n    bfeats_train = bfeats[:, 0:hist_len]\n    bfeats_pred = bfeats[:, hist_len:]\n    bcf_train = bcf[:, 0:hist_len]\n    bcf_pred = bcf[:, hist_len:]\n    return bts_train, bts_pred, bfeats_train, bfeats_pred, bcf_train, bcf_pred\n\n  def tf_dataset(self, mode='train', shift=1):\n    \"\"\"Tensorflow Dataset.\"\"\"\n    if mode == 'train':\n      gen_fn = self.train_gen\n    else:\n      gen_fn = lambda: self.test_val_gen(mode, shift)\n    output_types = tuple([tf.float32] * 2 + [tf.int32] + [tf.float32] * 2 +\n                         [tf.int32] * 2)\n    dataset = tf.data.Dataset.from_generator(gen_fn, output_types)\n    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)\n    return dataset\n"
  },
  {
    "path": "v1/src/timesfm/patched_decoder.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Pax ML model for patched time-series decoder.\n\nThe file implements Residual MLPs, Patched Decoder layers and PAX ML models.\n\"\"\"\n\nimport dataclasses\nfrom typing import Optional, Tuple\n\nimport einshape as es\nfrom jax import lax\nimport jax.numpy as jnp\nfrom praxis import base_layer\nfrom praxis import base_model\nfrom praxis import layers\nfrom praxis import pax_fiddle\nfrom praxis import py_utils\nfrom praxis import pytypes\nfrom praxis.layers import activations\nfrom praxis.layers import embedding_softmax\nfrom praxis.layers import linears\nfrom praxis.layers import normalizations\nfrom praxis.layers import stochastics\nfrom praxis.layers import transformers\n\n# PAX shortcuts\nNestedMap = py_utils.NestedMap\nJTensor = pytypes.JTensor\n\nLayerTpl = pax_fiddle.Config[base_layer.BaseLayer]\ntemplate_field = base_layer.template_field\n\nPAD_VAL = 1123581321.0\nDEFAULT_QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n\n# NestedMap keys\n_INPUT_TS = \"input_ts\"\n_TARGET_FUTURE = \"actual_ts\"\n_INPUT_PADDING = \"input_padding\"\n_OUTPUT_TS = \"output_ts\"\n_FREQ = \"freq\"\n_OUTPUT_TOKENS = \"output_tokens\"\n_STATS = \"stats\"\n\n# Small numerical value.\n_TOLERANCE = 1e-7\n\n\ndef _shift_padded_seq(mask: JTensor, seq: JTensor) -> JTensor:\n  \"\"\"Shifts rows of seq based on the first 0 in each row of the mask.\"\"\"\n  num = seq.shape[1]\n\n  
# Find the index of the first 0 in each row of the mask\n  first_zero_idx = jnp.argmin(mask, axis=1)\n\n  # Create a range array for indexing\n  idx_range = jnp.arange(num)\n\n  def shift_row(carry, x):\n    seq_row, shift = x\n    shifted_idx = (idx_range - shift) % num\n    shifted_row = seq_row[shifted_idx]\n    return carry, shifted_row\n\n  # Use lax.scan to shift each row of seq based on the corresponding\n  # first_zero_idx.\n  _, shifted_seq = lax.scan(shift_row, None, (seq, first_zero_idx))\n\n  return shifted_seq\n\n\nclass ResidualBlock(base_layer.BaseLayer):\n  \"\"\"Simple feedforward block with residual connection.\n\n  Attributes:\n    input_dims: input dimension.\n    hidden_dims: hidden dimension.\n    output_dims: output dimension.\n    dropout_prob: dropout probability.\n    layer_norm: whether to use layer norm or not.\n    dropout_tpl: config for dropout.\n    ln_tpl: config for layer norm.\n    act_tpl: config for activation in hidden layer.\n  \"\"\"\n\n  input_dims: int = 0\n  hidden_dims: int = 0\n  output_dims: int = 0\n  dropout_prob: float = 0.0\n  layer_norm: bool = False\n  dropout_tpl: LayerTpl = template_field(stochastics.Dropout)\n  ln_tpl: LayerTpl = template_field(normalizations.LayerNorm)\n  act_tpl: LayerTpl = template_field(activations.Swish)\n\n  def setup(self):\n    lnorm_tpl = self.ln_tpl.clone()\n    lnorm_tpl.dim = self.output_dims\n    self.create_child(\"ln_layer\", lnorm_tpl)\n\n    dropout_tpl = self.dropout_tpl.clone()\n    dropout_tpl.keep_prob = 1.0 - self.dropout_prob\n    self.create_child(\"dropout\", dropout_tpl)\n\n    self.create_child(\n        \"hidden_layer\",\n        pax_fiddle.Config(\n            linears.FeedForward,\n            input_dims=self.input_dims,\n            output_dims=self.hidden_dims,\n            activation_tpl=self.act_tpl.clone(),\n        ),\n    )\n\n    self.create_child(\n        \"output_layer\",\n        pax_fiddle.Config(\n            linears.FeedForward,\n            
input_dims=self.hidden_dims,\n            output_dims=self.output_dims,\n            activation_tpl=pax_fiddle.Config(activations.Identity),\n        ),\n    )\n\n    self.create_child(\n        \"residual_layer\",\n        pax_fiddle.Config(\n            linears.FeedForward,\n            input_dims=self.input_dims,\n            output_dims=self.output_dims,\n            activation_tpl=pax_fiddle.Config(activations.Identity),\n        ),\n    )\n\n  def __call__(self, inputs: JTensor) -> JTensor:\n    hidden = self.hidden_layer(inputs)\n    output = self.output_layer(hidden)\n    output = self.dropout(output)\n    residual = self.residual_layer(inputs)\n    if self.layer_norm:\n      return self.ln_layer(output + residual)\n    else:\n      return output + residual\n\n\ndef _masked_mean_std(inputs: JTensor,\n                     padding: JTensor) -> Tuple[JTensor, JTensor]:\n  \"\"\"Calculates mean and standard deviation of arr across axis 1.\n\n  It should exclude values where pad is 1.\n\n  Args:\n    inputs: A JAX array of shape [b, n, p].\n    padding: A JAX array of shape [b, n, p] with values 0 or 1.\n\n  Returns:\n    A tuple containing the mean and standard deviation of arr. 
We return the\n    statistics of the first patch with at least three non-padded values.\n  \"\"\"\n  # Selecting the first patch with at least 3 unpadded values.\n  pad_sum = jnp.sum(1 - padding, axis=2)\n\n  def _get_patch_index(arr: JTensor):\n    indices = jnp.argmax(arr >= 3, axis=1)\n    row_sum = (arr >= 3).sum(axis=1)\n    return jnp.where(row_sum == 0, arr.shape[1] - 1, indices)\n\n  patch_indices = _get_patch_index(pad_sum)\n  bidxs = jnp.arange(inputs.shape[0])\n\n  arr = inputs[bidxs, patch_indices, :]\n  pad = padding[bidxs, patch_indices, :]\n\n  # Create a mask where padding is 0\n  mask = 1 - pad\n\n  # Calculate the number of valid elements\n  num_valid_elements = jnp.sum(mask, axis=1)\n\n  num_valid_elements = jnp.where(num_valid_elements == 0, 1, num_valid_elements)\n\n  # Calculate the masked sum for mean and centered squared sum for variance.\n  masked_sum = jnp.sum(arr * mask, axis=1)\n\n  # Calculate the masked mean and standard deviation\n  masked_mean = masked_sum / num_valid_elements\n  centered = (arr - masked_mean[:, None]) * mask\n  masked_var = jnp.sum(centered**2, axis=1) / num_valid_elements\n  masked_var = jnp.where(masked_var < 0.0, 0.0, masked_var)\n  masked_std = jnp.sqrt(masked_var)\n\n  return masked_mean, masked_std\n\n\ndef _create_quantiles() -> list[float]:\n  \"\"\"Returns the quantiles for forecasting.\"\"\"\n  return DEFAULT_QUANTILES\n\n\nclass PatchedTimeSeriesDecoder(base_layer.BaseLayer):\n  \"\"\"Patch decoder layer for time-series foundation model.\n\n  Attributes:\n    patch_len: length of input patches.\n    horizon_len: length of output patches. 
Referred to as `output_patch_len`\n      during inference.\n    model_dims: model dimension of stacked transformer layer.\n    hidden_dims: hidden dimensions in fully connected layers.\n    quantiles: list of quantiles for non prob model.\n    residual_block_tpl: config for residual block.\n    stacked_transformer_params_tpl: config for stacked transformer.\n    use_freq: whether to use frequency encoding.\n\n  In all of what followed, except specified otherwise, B is batch size, T is\n  sequence length of time-series. N is the number of input patches that can be\n  obtained from T. P is the input patch length and H is the horizon length. Q is\n  number of output logits. D is model dimension.\n  \"\"\"\n\n  patch_len: int = 0\n  horizon_len: int = 0\n  model_dims: int = 0\n  hidden_dims: int = 0\n  quantiles: list[float] = dataclasses.field(default_factory=_create_quantiles)\n  residual_block_tpl: LayerTpl = template_field(ResidualBlock)\n  stacked_transformer_params_tpl: LayerTpl = template_field(\n      transformers.StackedTransformer)\n  use_freq: bool = True\n  use_pos_emb: bool = True\n\n  def setup(self) -> None:\n    \"\"\"Construct the model.\"\"\"\n    num_outputs = len(self.quantiles) + 1\n\n    stl = self.stacked_transformer_params_tpl.clone()\n    stl.model_dims = self.model_dims\n    stl.hidden_dims = self.hidden_dims\n    stl.mask_self_attention = True\n\n    self.create_child(\"stacked_transformer_layer\", stl)\n\n    input_resl = self.residual_block_tpl.clone()\n    ff_in_dims = 2 * self.patch_len\n    input_resl.input_dims = ff_in_dims\n    input_resl.hidden_dims = self.hidden_dims\n    input_resl.output_dims = self.model_dims\n    self.create_child(\n        \"input_ff_layer\",\n        input_resl,\n    )\n\n    horizon_resl = self.residual_block_tpl.clone()\n    horizon_resl.input_dims = self.model_dims\n    horizon_resl.hidden_dims = self.hidden_dims\n    horizon_resl.output_dims = self.horizon_len * num_outputs\n    self.create_child(\n        
\"horizon_ff_layer\",\n        horizon_resl,\n    )\n\n    self.create_child(\n        \"position_emb\",\n        pax_fiddle.Config(layers.PositionalEmbedding,\n                          embedding_dims=self.model_dims),\n    )\n\n    if self.use_freq:\n      self.create_child(\n          \"freq_emb\",\n          pax_fiddle.Config(\n              embedding_softmax.Embedding,\n              num_classes=3,\n              input_dims=self.model_dims,\n          ),\n      )\n\n  def transform_decode_state(\n      self, transform_fn: base_layer.DecodeStateTransformFn) -> None:\n    \"\"\"Transforms all decode state variables based on transform_fn.\"\"\"\n    self.stacked_transformer_layer.transform_decode_state(transform_fn)\n\n  def _forward_transform(\n      self, inputs: JTensor,\n      patched_pads: JTensor) -> Tuple[JTensor, Tuple[JTensor, JTensor]]:\n    \"\"\"Input is of shape [B, N, P].\"\"\"\n    mu, sigma = _masked_mean_std(inputs, patched_pads)\n    sigma = jnp.maximum(sigma, _TOLERANCE)\n    # Normalize each patch.\n    outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]\n    outputs = jnp.where(\n        jnp.abs(inputs - PAD_VAL) < _TOLERANCE, PAD_VAL, outputs)\n    return outputs, (mu, sigma)\n\n  def _reverse_transform(self, outputs: JTensor,\n                         stats: Tuple[JTensor, JTensor]) -> JTensor:\n    \"\"\"Output is of shape [B, N, P, Q].\"\"\"\n    mu, sigma = stats\n    return outputs * sigma[:, None, None, None] + mu[:, None, None, None]\n\n  def _preprocess_input(\n      self,\n      input_ts: JTensor,\n      input_padding: JTensor,\n      pos_emb: Optional[JTensor] = None,\n  ) -> Tuple[JTensor, JTensor, Optional[Tuple[JTensor, JTensor]], JTensor]:\n    \"\"\"Preprocess input for stacked transformer.\"\"\"\n    # Reshape into patches.\n    patched_inputs = es.jax_einshape(\"b(np)->bnp\", input_ts, p=self.patch_len)\n    patched_pads = es.jax_einshape(\"b(np)->bnp\",\n                                   input_padding,\n          
                         p=self.patch_len)\n    patched_inputs = jnp.where(\n        jnp.abs(patched_pads - 1.0) < _TOLERANCE, 0.0, patched_inputs)\n    patched_pads = jnp.where(\n        jnp.abs(patched_inputs - PAD_VAL) < _TOLERANCE, 1, patched_pads)\n    patched_inputs, stats = self._forward_transform(patched_inputs,\n                                                    patched_pads)\n\n    # B x N x D\n    patched_inputs = patched_inputs * (1.0 - patched_pads)\n    concat_inputs = jnp.concatenate([patched_inputs, patched_pads], axis=-1)\n    model_input = self.input_ff_layer(concat_inputs)\n    # A patch should not be padded even if there is at least one zero.\n    patched_padding = jnp.min(patched_pads, axis=-1)\n\n    if self.use_pos_emb:\n      if pos_emb is None:\n        position_emb = self.position_emb(seq_length=model_input.shape[1])\n      else:\n        position_emb = pos_emb\n      if self.do_eval:\n        if position_emb.shape[0] != model_input.shape[0]:\n          position_emb = jnp.repeat(position_emb, model_input.shape[0], axis=0)\n        position_emb = _shift_padded_seq(patched_padding, position_emb)\n      model_input += position_emb\n\n    return model_input, patched_padding, stats, patched_inputs\n\n  def _postprocess_output(\n      self,\n      model_output: JTensor,\n      num_outputs: int,\n      stats: Tuple[JTensor, JTensor],\n  ) -> JTensor:\n    \"\"\"Postprocess output of stacked transformer.\"\"\"\n    # B x N x (H.Q)\n    output_ts = self.horizon_ff_layer(model_output)\n    output_ts = es.jax_einshape(\"bn(hq)->bnhq\",\n                                output_ts,\n                                q=num_outputs,\n                                h=self.horizon_len)\n    return self._reverse_transform(output_ts, stats)\n\n  def __call__(self, inputs: NestedMap) -> NestedMap:\n    \"\"\"PatchTST call.\n\n    Args:\n      inputs: A NestedMap containing (1) input_ts: input sequence of shape [B,\n        T] where T must be multiple of 
patch_length; (2) input_padding: that\n        contains padding map.\n\n    Returns:\n      A nested map with three keys:\n      (1) 'output_tokens' of shape [B, N, D].\n      (2) 'output_ts' of shape [B, N, H, Q]\n      (3) 'stats' a Tuple of statistics for renormalization.\n    \"\"\"\n    input_ts, input_padding = inputs[_INPUT_TS], inputs[_INPUT_PADDING]\n    num_outputs = len(self.quantiles) + 1\n    model_input, patched_padding, stats, _ = self._preprocess_input(\n        input_ts=input_ts,\n        input_padding=input_padding,\n    )\n    if self.use_freq:\n      freq = inputs[_FREQ].astype(jnp.int32)\n      f_emb = self.freq_emb(freq)  # B x 1 x D\n      f_emb = jnp.repeat(f_emb, model_input.shape[1], axis=1)\n      model_input += f_emb\n    model_output = self.stacked_transformer_layer(model_input, patched_padding)\n\n    output_ts = self._postprocess_output(model_output, num_outputs, stats)\n    return NestedMap({\n        _OUTPUT_TOKENS: model_output,\n        _OUTPUT_TS: output_ts,\n        _STATS: stats\n    })\n\n  def decode(\n      self,\n      inputs: NestedMap,\n      horizon_len: int,\n      output_patch_len: Optional[int] = None,\n      max_len: int | None = None,\n      return_forecast_on_context: bool = False,\n  ) -> tuple[JTensor, JTensor]:\n    \"\"\"Auto-regressive decoding without caching.\n\n    Args:\n      inputs: input time-series and paddings. 
Time-series shape B x C, padding\n        shape shape B x (C + H) where H is the prediction length.\n      horizon_len: prediction length.\n      output_patch_len: output length to be fetched from one step of\n        auto-regressive decoding.\n      max_len: maximum training context length.\n      return_forecast_on_context: whether to return the model forecast on the\n        context except the first input patch.\n\n    Returns:\n      Tuple of two forecasting results:\n      - Point (mean) output predictions as a tensor with shape B x H'.\n      - Full predictions (mean and quantiles) as a tensor with shape\n        B x H' x (1 + # quantiles).\n      In particular, if return_forecast_on_context is True, H' is H plus\n      the forecastable context length, i.e. context_len - (first) patch_len.\n    \"\"\"\n    final_out = inputs[_INPUT_TS]\n    context_len = final_out.shape[1]\n    paddings = inputs[_INPUT_PADDING]\n    if max_len is None:\n      max_len = context_len\n    if self.use_freq:\n      freq = inputs[_FREQ].astype(jnp.int32)\n    else:\n      freq = jnp.zeros([final_out.shape[0], 1], dtype=jnp.int32)\n    full_outputs = []\n    if paddings.shape[1] != final_out.shape[1] + horizon_len:\n      raise ValueError(\n          \"Length of paddings must match length of input + horizon_len:\"\n          f\" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}\")\n    if output_patch_len is None:\n      output_patch_len = self.horizon_len\n    num_decode_patches = (horizon_len + output_patch_len -\n                          1) // output_patch_len\n    for step_index in range(num_decode_patches):\n      current_padding = paddings[:, 0:final_out.shape[1]]\n      input_ts = final_out[:, -max_len:]\n      input_padding = current_padding[:, -max_len:]\n      model_input = NestedMap(\n          input_ts=input_ts,\n          input_padding=input_padding,\n          freq=freq,\n      )\n      fprop_outputs = self(model_input)[_OUTPUT_TS]\n      if 
return_forecast_on_context and step_index == 0:\n        # For the first decodings step, collect the model forecast on the\n        # context except the unavailable first input batch forecast.\n        new_full_ts = fprop_outputs[:, :-1, :self.patch_len, :]\n        new_full_ts = es.jax_einshape(\"bnph->b(np)h\", new_full_ts)\n\n        full_outputs.append(new_full_ts)\n\n      # (full batch, last patch, output_patch_len, index of mean forecast = 0)\n      new_ts = fprop_outputs[:, -1, :output_patch_len, 0]\n      new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]\n      # (full batch, last patch, output_patch_len, all output indices)\n      full_outputs.append(new_full_ts)\n      final_out = jnp.concatenate([final_out, new_ts], axis=-1)\n\n    if return_forecast_on_context:\n      # `full_outputs` indexing starts at after the first input patch.\n      full_outputs = jnp.concatenate(full_outputs,\n                                     axis=1)[:, :(context_len - self.patch_len +\n                                                  horizon_len), :]\n    else:\n      # `full_outputs` indexing starts at the forecast horizon.\n      full_outputs = jnp.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :]\n\n    return (full_outputs[:, :, 0], full_outputs)\n\n\nclass PatchedDecoderFinetuneModel(base_model.BaseModel):\n  \"\"\"Model class for finetuning patched time-series decoder.\n\n  Attributes:\n    core_layer_tpl: config for core layer.\n    freq: freq to finetune on.\n  \"\"\"\n\n  core_layer_tpl: LayerTpl = template_field(PatchedTimeSeriesDecoder)\n  freq: int = 0\n\n  def setup(self) -> None:\n    self.create_child(\"core_layer\", self.core_layer_tpl)\n\n  def compute_predictions(self, input_batch: NestedMap) -> NestedMap:\n    input_ts = input_batch[_INPUT_TS]\n    input_padding = jnp.zeros_like(input_ts)\n    context_len = input_ts.shape[1]\n    input_patch_len = self.core_layer_tpl.patch_len\n    context_pad = ((context_len + input_patch_len - 1) //\n       
            input_patch_len) * input_patch_len - context_len\n\n    input_ts = jnp.pad(input_ts, [(0, 0), (context_pad, 0)])\n    input_padding = jnp.pad(input_padding, [(0, 0), (context_pad, 0)],\n                            constant_values=1)\n    freq = jnp.ones([input_ts.shape[0], 1], dtype=jnp.int32) * self.freq\n    new_input_batch = NestedMap(\n        input_ts=input_ts,\n        input_padding=input_padding,\n        freq=freq,\n    )\n    return self.core_layer(new_input_batch)\n\n  def _quantile_loss(self, pred: JTensor, actual: JTensor,\n                     quantile: float) -> JTensor:\n    \"\"\"Calculates quantile loss.\n\n    Args:\n      pred: B x T\n      actual: B x T\n      quantile: quantile at which loss is computed.\n\n    Returns:\n      per coordinate loss.\n    \"\"\"\n    dev = actual - pred\n    loss_first = dev * quantile\n    loss_second = -dev * (1.0 - quantile)\n    return 2 * jnp.where(loss_first >= 0, loss_first, loss_second)\n\n  def compute_loss(self, prediction_output: NestedMap,\n                   input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]:\n    output_ts = prediction_output[_OUTPUT_TS]\n    actual_ts = input_batch[_TARGET_FUTURE]\n    pred_ts = output_ts[:, -1, 0:actual_ts.shape[1], :]\n    loss = jnp.square(pred_ts[:, :, 0] - actual_ts)\n    for i, quantile in enumerate(self.core_layer.quantiles):\n      loss += self._quantile_loss(pred_ts[:, :, i + 1], actual_ts, quantile)\n    loss = loss.mean()\n    loss_weight = jnp.array(1.0, dtype=jnp.float32)\n    per_example_out = NestedMap()\n    return {\"avg_qloss\": (loss, loss_weight)}, per_example_out\n"
  },
  {
    "path": "v1/src/timesfm/pytorch_patched_decoder.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Pytorch version of patched decoder.\"\"\"\n\nimport dataclasses\nimport math\nfrom typing import List, Tuple\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\n\ndef create_quantiles() -> list[float]:\n  return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n\n\n@dataclasses.dataclass\nclass TimesFMConfig:\n  \"\"\"Config for initializing timesfm patched_decoder class.\"\"\"\n\n  # The number of blocks in the model.\n  num_layers: int = 20\n  # The number of attention heads used in the attention layers of the model.\n  num_heads: int = 16\n  # The number of key-value heads for implementing attention.\n  num_kv_heads: int = 16\n  # The hidden size of the model.\n  hidden_size: int = 1280\n  # The dimension of the MLP representations.\n  intermediate_size: int = 1280\n  # The number of head dimensions.\n  head_dim: int = 80\n  # The epsilon used by the rms normalization layers.\n  rms_norm_eps: float = 1e-6\n  # Patch length\n  patch_len: int = 32\n  # Horizon length\n  horizon_len: int = 128\n  # quantiles\n  quantiles: List[float] = dataclasses.field(default_factory=create_quantiles)\n  # Padding value\n  pad_val: float = 1123581321.0\n  # Tolerance\n  tolerance: float = 1e-6\n  # The dtype of the weights.\n  dtype: str = \"bfloat32\"\n  # use positional embedding\n  use_positional_embedding: bool = True\n\n\ndef _masked_mean_std(\n    inputs: 
torch.Tensor,\n    padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n  \"\"\"Calculates mean and standard deviation of `inputs` across axis 1.\n\n  It excludes values where `padding` is 1.\n\n  Args:\n    inputs: A PyTorch tensor of shape [b, n, p].\n    padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.\n\n  Returns:\n    A tuple containing the mean and standard deviation.\n    We return the statistics of the first patch with at least three non-padded\n    values.\n  \"\"\"\n  # Selecting the first patch with at least 3 unpadded values.\n  pad_sum = torch.sum(1 - padding, dim=2)\n\n  def _get_patch_index(arr: torch.Tensor):\n    indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)\n    row_sum = (arr >= 3).to(torch.int32).sum(dim=1)\n    return torch.where(row_sum == 0, arr.shape[1] - 1, indices)\n\n  patch_indices = _get_patch_index(pad_sum)\n  bidxs = torch.arange(inputs.shape[0])\n\n  arr = inputs[bidxs, patch_indices, :]\n  pad = padding[bidxs, patch_indices, :]\n\n  # Create a mask where padding is 0\n  mask = 1 - pad\n\n  # Calculate the number of valid elements\n  num_valid_elements = torch.sum(mask, dim=1)\n  num_valid_elements = torch.clamp(num_valid_elements, min=1.0)\n\n  # Calculate the masked sum and mean\n  masked_sum = torch.sum(arr * mask, dim=1)\n  masked_mean = masked_sum / num_valid_elements\n\n  # Calculate the masked variance using centered values (numerically stable)\n  masked_centered_arr = (arr - masked_mean.unsqueeze(-1)) * mask\n  masked_var = torch.sum(masked_centered_arr**2, dim=1) / num_valid_elements\n  masked_var = torch.clamp(masked_var, min=0.0)\n  masked_std = torch.sqrt(masked_var)\n\n  return masked_mean, masked_std\n\n\ndef _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:\n  \"\"\"Shifts rows of seq based on the first 0 in each row of the mask.\n\n  Args:\n    mask: mask tensor of shape [B, N]\n    seq: seq tensor of shape [B, N, P]\n\n  Returns:\n    Returns the shifted 
sequence.\n  \"\"\"\n  batch_size, num_seq, feature_dim = seq.shape\n\n  new_mask: torch.BoolTensor = mask == 0\n\n  # Use argmax to find the first True value in each row\n  indices = new_mask.to(torch.int32).argmax(dim=1)\n\n  # Handle rows with all zeros\n  indices[~new_mask.any(dim=1)] = -1\n\n  # Create index ranges for each sequence in the batch\n  idx_range = (torch.arange(num_seq).to(\n      seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1,\n                                                    feature_dim))\n\n  # Calculate shifted indices for each element in each sequence\n  shifted_idx = (idx_range - indices[:, None, None]) % num_seq\n\n  # Gather values from seq using shifted indices\n  shifted_seq = seq.gather(1, shifted_idx)\n\n  return shifted_seq\n\n\ndef get_large_negative_number(dtype: torch.dtype) -> torch.Tensor:\n  \"\"\"Returns a large negative value for the given dtype.\"\"\"\n  if dtype.is_floating_point:\n    dtype_max = torch.finfo(dtype).max\n  else:\n    dtype_max = torch.iinfo(dtype).max\n  return torch.tensor(-0.7 * dtype_max, dtype=dtype)\n\n\ndef apply_mask_to_logits(logits: torch.Tensor,\n                         mask: torch.Tensor) -> torch.Tensor:\n  \"\"\"Applies a floating-point mask to a set of logits.\n\n  Args:\n      logits: A torch.Tensor of logit values.\n      mask: A torch.Tensor (float32) of mask values with the encoding described\n        in the function documentation.\n\n  Returns:\n      Masked logits.\n  \"\"\"\n\n  min_value = get_large_negative_number(logits.dtype)\n\n  return torch.where((mask >= min_value * 0.5), logits, min_value)\n\n\ndef convert_paddings_to_mask(\n    paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:\n  \"\"\"Converts binary paddings to a logit mask ready to add to attention matrix.\n\n  Args:\n      paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding\n        token.\n      dtype: data type of the input.\n\n  Returns:\n      A 
torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits.\n  \"\"\"\n  attention_mask = paddings.detach().clone()\n  attention_mask = attention_mask[:, None, None, :]  # Equivalent to jnp.newaxis\n  attention_mask *= get_large_negative_number(dtype)\n  return attention_mask\n\n\ndef causal_mask(input_t: torch.Tensor) -> torch.Tensor:\n  \"\"\"Computes and returns causal mask.\n\n  Args:\n      input_t: A torch.Tensor of shape [B, T, D].\n\n  Returns:\n      An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has\n      already been converted to large negative values.\n  \"\"\"\n  assert input_t.dtype.is_floating_point, input_t.dtype\n  large_negative_number = get_large_negative_number(input_t.dtype)\n  t = input_t.shape[1]\n  col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1)\n  row_idx = torch.arange(t).unsqueeze(1).repeat(1, t)\n  mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number\n  return (mask.unsqueeze(0).unsqueeze(0).to(input_t.device)\n         )  # Equivalent to jnp.newaxis\n\n\ndef merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n  \"\"\"Merges 2 masks.\n\n  logscale mask is expected but 0/1 mask is also fine.\n\n  Args:\n      a: torch.Tensor of shape [1|B, 1, 1|T, S].\n      b: torch.Tensor of shape [1|B, 1, 1|T, S].\n\n  Returns:\n      torch.Tensor of shape [1|B, 1, 1|T, S].\n  \"\"\"\n\n  def expand_t(key_mask):\n    query_mask = key_mask.transpose(-1, -2)  # Equivalent of jnp.transpose\n    return torch.minimum(query_mask, key_mask)\n\n  if a.shape[2] != b.shape[2]:\n    if a.shape[2] == 1:\n      a = expand_t(a)\n    else:\n      assert b.shape[2] == 1\n      b = expand_t(b)\n\n  assert a.shape[1:] == b.shape[1:], f\"a.shape={a.shape}, b.shape={b.shape}.\"\n  return torch.minimum(a, b)  # Element-wise minimum, similar to jnp.minimum\n\n\nclass ResidualBlock(nn.Module):\n  \"\"\"TimesFM residual block.\"\"\"\n\n  def __init__(\n      self,\n      input_dims,\n      hidden_dims,\n      
output_dims,\n  ):\n    super(ResidualBlock, self).__init__()\n    self.input_dims = input_dims\n    self.hidden_dims = hidden_dims\n    self.output_dims = output_dims\n\n    # Hidden Layer\n    self.hidden_layer = nn.Sequential(\n        nn.Linear(input_dims, hidden_dims),\n        nn.SiLU(),\n    )\n\n    # Output Layer\n    self.output_layer = nn.Linear(hidden_dims, output_dims)\n    # Residual Layer\n    self.residual_layer = nn.Linear(input_dims, output_dims)\n\n  def forward(self, x):\n    hidden = self.hidden_layer(x)\n    output = self.output_layer(hidden)\n    residual = self.residual_layer(x)\n    return output + residual\n\n\nclass RMSNorm(torch.nn.Module):\n  \"\"\"Pax rms norm in pytorch.\"\"\"\n\n  def __init__(\n      self,\n      dim: int,\n      eps: float = 1e-6,\n      add_unit_offset: bool = False,\n  ):\n    super().__init__()\n    self.eps = eps\n    self.add_unit_offset = add_unit_offset\n    self.weight = nn.Parameter(torch.zeros(dim))\n\n  def _norm(self, x):\n    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n  def forward(self, x):\n    output = self._norm(x.float())\n    if self.add_unit_offset:\n      output = output * (1 + self.weight.float())\n    else:\n      output = output * self.weight.float()\n    return output.type_as(x)\n\n\nclass TransformerMLP(nn.Module):\n  \"\"\"Pax transformer MLP in pytorch.\"\"\"\n\n  def __init__(\n      self,\n      hidden_size: int,\n      intermediate_size: int,\n  ):\n    super().__init__()\n    self.gate_proj = nn.Linear(hidden_size, intermediate_size)\n    self.down_proj = nn.Linear(intermediate_size, hidden_size)\n    self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)\n\n  def forward(self, x, paddings=None):\n    gate_inp = self.layer_norm(x)\n    gate = self.gate_proj(gate_inp)\n    gate = F.relu(gate)\n    outputs = self.down_proj(gate)\n    if paddings is not None:\n      outputs = outputs * (1.0 - paddings[:, :, None])\n    return outputs + 
x\n\n\nclass TimesFMAttention(nn.Module):\n  \"\"\"Implements the attention used in TimesFM.\"\"\"\n\n  def __init__(\n      self,\n      hidden_size: int,\n      num_heads: int,\n      num_kv_heads: int,\n      head_dim: int,\n  ):\n    super().__init__()\n\n    self.num_heads = num_heads\n    self.num_kv_heads = num_kv_heads\n\n    assert self.num_heads % self.num_kv_heads == 0\n    self.num_queries_per_kv = self.num_heads // self.num_kv_heads\n\n    self.hidden_size = hidden_size\n    self.head_dim = head_dim\n\n    self.q_size = self.num_heads * self.head_dim\n    self.kv_size = self.num_kv_heads * self.head_dim\n    self.scaling = nn.Parameter(\n        torch.empty((self.head_dim,), dtype=torch.float32),)\n\n    self.qkv_proj = nn.Linear(\n        self.hidden_size,\n        (self.num_heads + 2 * self.num_kv_heads) * self.head_dim,\n    )\n    self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)\n\n  def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor:\n    # [batch_size, n_local_heads, input_len, head_dim]\n    r_softplus_0 = 1.442695041\n    softplus_func = torch.nn.Softplus()\n    scale = r_softplus_0 / math.sqrt(self.head_dim)\n    scale = scale * softplus_func(self.scaling)\n    return query * scale[None, None, None, :]\n\n  def forward(\n      self,\n      hidden_states: torch.Tensor,\n      mask: torch.Tensor,\n      kv_write_indices: torch.Tensor | None = None,\n      kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,\n  ) -> torch.Tensor:\n    hidden_states_shape = hidden_states.shape\n    assert len(hidden_states_shape) == 3\n\n    batch_size, input_len, _ = hidden_states_shape\n\n    qkv = self.qkv_proj(hidden_states)\n    xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)\n\n    xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)\n    xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)\n    xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)\n    xq = 
self._per_dim_scaling(xq)\n\n    # Write new kv cache.\n    # [batch_size, input_len, n_local_kv_heads, head_dim]\n    if kv_cache is not None and kv_write_indices is not None:\n      k_cache, v_cache = kv_cache\n      k_cache.index_copy_(1, kv_write_indices, xk)\n      v_cache.index_copy_(1, kv_write_indices, xv)\n\n      key = k_cache\n      value = v_cache\n    else:\n      key = xk\n      value = xv\n    if self.num_kv_heads != self.num_heads:\n      # [batch_size, max_seq_len, n_local_heads, head_dim]\n      key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)\n      value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2)\n\n    # [batch_size, n_local_heads, input_len, head_dim]\n    q = xq.transpose(1, 2)\n    # [batch_size, n_local_heads, max_seq_len, head_dim]\n    k = key.transpose(1, 2)\n    v = value.transpose(1, 2)\n\n    # [batch_size, n_local_heads, input_len, max_seq_len]\n    scores = torch.matmul(q, k.transpose(2, 3))\n    scores = scores + mask\n    scores = F.softmax(scores.float(), dim=-1).type_as(q)\n\n    # [batch_size, n_local_heads, input_len, head_dim]\n    output = torch.matmul(scores, v)\n    # return scores, output.transpose(1, 2).contiguous()\n\n    # [batch_size, input_len, hidden_dim]\n    output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)\n    output = self.o_proj(output)\n    return scores, output\n\n\nclass TimesFMDecoderLayer(nn.Module):\n  \"\"\"Transformer layer.\"\"\"\n\n  def __init__(\n      self,\n      hidden_size: int,\n      intermediate_size: int,\n      num_heads: int,\n      num_kv_heads: int,\n      head_dim: int,\n      rms_norm_eps: float = 1e-6,\n  ):\n    super().__init__()\n    self.self_attn = TimesFMAttention(\n        hidden_size=hidden_size,\n        num_heads=num_heads,\n        num_kv_heads=num_kv_heads,\n        head_dim=head_dim,\n    )\n    self.mlp = TransformerMLP(\n        hidden_size=hidden_size,\n        intermediate_size=intermediate_size,\n    
)\n    self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)\n\n  def forward(\n      self,\n      hidden_states: torch.Tensor,\n      mask: torch.Tensor,\n      paddings: torch.Tensor,\n      kv_write_indices: torch.Tensor | None = None,\n      kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,\n  ) -> torch.Tensor:\n    # Self Attention\n    residual = hidden_states\n    hidden_states = self.input_layernorm(hidden_states)\n    scores, hidden_states = self.self_attn(\n        hidden_states=hidden_states,\n        mask=mask,\n        kv_write_indices=kv_write_indices,\n        kv_cache=kv_cache,\n    )\n    hidden_states = residual + hidden_states\n\n    # MLP\n    hidden_states = self.mlp(hidden_states, paddings=paddings)\n\n    return scores, hidden_states\n\n\nclass StackedDecoder(nn.Module):\n  \"\"\"Stacked transformer layer.\"\"\"\n\n  def __init__(\n      self,\n      hidden_size: int,\n      intermediate_size: int,\n      num_heads: int,\n      num_kv_heads: int,\n      head_dim: int,\n      num_layers: int,\n      rms_norm_eps: float = 1e-6,\n  ):\n    super().__init__()\n\n    self.layers = nn.ModuleList()\n    for _ in range(num_layers):\n      self.layers.append(\n          TimesFMDecoderLayer(\n              hidden_size=hidden_size,\n              intermediate_size=intermediate_size,\n              num_heads=num_heads,\n              num_kv_heads=num_kv_heads,\n              head_dim=head_dim,\n              rms_norm_eps=rms_norm_eps,\n          ))\n\n  def forward(\n      self,\n      hidden_states: torch.Tensor,\n      paddings: torch.Tensor,\n      kv_write_indices: torch.Tensor | None = None,\n      kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,\n  ) -> torch.Tensor:\n    padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype)\n    atten_mask = causal_mask(hidden_states)\n    mask = merge_masks(padding_mask, atten_mask)\n    for i in range(len(self.layers)):\n      layer = self.layers[i]\n      
kv_cache = kv_caches[i] if kv_caches is not None else None\n      _, hidden_states = layer(\n          hidden_states=hidden_states,\n          mask=mask,\n          paddings=paddings,\n          kv_write_indices=kv_write_indices,\n          kv_cache=kv_cache,\n      )\n    return hidden_states\n\n\nclass PositionalEmbedding(torch.nn.Module):\n  \"\"\"Generates position embedding for a given 1-d sequence.\n\n  Attributes:\n      min_timescale: Start of the geometric index. Determines the periodicity of\n        the added signal.\n      max_timescale: End of the geometric index. Determines the frequency of the\n        added signal.\n      embedding_dims: Dimension of the embedding to be generated.\n  \"\"\"\n\n  def __init__(\n      self,\n      embedding_dims: int,\n      min_timescale: int = 1,\n      max_timescale: int = 10_000,\n  ) -> None:\n    super().__init__()\n    self.min_timescale = min_timescale\n    self.max_timescale = max_timescale\n    self.embedding_dims = embedding_dims\n\n  def forward(self, seq_length=None, position=None):\n    \"\"\"Generates a Tensor of sinusoids with different frequencies.\n\n    Args:\n        seq_length: an optional Python int defining the output sequence length.\n          Required if the `position` argument is not specified.\n        position:   [B, seq_length], optional position for each token in the\n          sequence, only required when the sequence is packed.\n\n    Returns:\n        [B, seqlen, D] if `position` is specified, else [1, seqlen, D]\n    \"\"\"\n    if position is None:\n      assert seq_length is not None\n      # [1, seqlen]\n      position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0)\n    else:\n      assert position.ndim == 2, position.shape\n\n    num_timescales = self.embedding_dims // 2\n    log_timescale_increment = math.log(\n        float(self.max_timescale) / float(self.min_timescale)) / max(\n            num_timescales - 1, 1)\n    inv_timescales = self.min_timescale * torch.exp(\n        
torch.arange(num_timescales, dtype=torch.float32) *\n        -log_timescale_increment)\n    scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(\n        0)\n    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)\n    # Padding to ensure correct embedding dimension\n    signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))\n    return signal\n\n\nclass PatchedTimeSeriesDecoder(nn.Module):\n  \"\"\"Patched time-series decoder.\"\"\"\n\n  def __init__(self, config: TimesFMConfig):\n    super().__init__()\n    self.config = config\n    self.input_ff_layer = ResidualBlock(\n        input_dims=2 * config.patch_len,\n        output_dims=config.hidden_size,\n        hidden_dims=config.intermediate_size,\n    )\n    self.freq_emb = nn.Embedding(num_embeddings=3,\n                                 embedding_dim=config.hidden_size)\n    self.horizon_ff_layer = ResidualBlock(\n        input_dims=config.hidden_size,\n        output_dims=config.horizon_len * (1 + len(config.quantiles)),\n        hidden_dims=config.intermediate_size,\n    )\n    self.stacked_transformer = StackedDecoder(\n        hidden_size=self.config.hidden_size,\n        intermediate_size=self.config.intermediate_size,\n        num_heads=self.config.num_heads,\n        num_kv_heads=self.config.num_kv_heads,\n        head_dim=self.config.head_dim,\n        num_layers=self.config.num_layers,\n        rms_norm_eps=self.config.rms_norm_eps,\n    )\n    if self.config.use_positional_embedding:\n      self.position_emb = PositionalEmbedding(self.config.hidden_size)\n\n  def _forward_transform(\n      self, inputs: torch.Tensor, patched_pads: torch.Tensor\n  ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:\n    \"\"\"Input is of shape [B, N, P].\"\"\"\n    mu, sigma = _masked_mean_std(inputs, patched_pads)\n    sigma = torch.clamp(sigma, min=self.config.tolerance)\n\n    # Normalize each patch\n    outputs = (inputs - mu[:, None, None]) / sigma[:, 
None, None]\n    outputs = torch.where(\n        torch.abs(inputs - self.config.pad_val) < self.config.tolerance,\n        torch.tensor(self.config.pad_val,\n                     dtype=outputs.dtype,\n                     device=outputs.device),\n        outputs,\n    )\n    return outputs, (mu, sigma)\n\n  def _reverse_transform(\n      self, outputs: torch.Tensor, stats: tuple[torch.Tensor,\n                                                torch.Tensor]) -> torch.Tensor:\n    \"\"\"Output is of shape [B, N, P, Q].\"\"\"\n    mu, sigma = stats\n    return outputs * sigma[:, None, None, None] + mu[:, None, None, None]\n\n  def _preprocess_input(\n      self,\n      input_ts: torch.Tensor,\n      input_padding: torch.Tensor,\n  ) -> tuple[\n      torch.Tensor,\n      torch.Tensor,\n      tuple[torch.Tensor, torch.Tensor] | None,\n      torch.Tensor,\n  ]:\n    \"\"\"Preprocess input for stacked transformer.\"\"\"\n\n    # Reshape into patches (using view for efficiency)\n    bsize = input_ts.shape[0]\n    patched_inputs = input_ts.view(bsize, -1, self.config.patch_len)\n    patched_pads = input_padding.view(bsize, -1, self.config.patch_len)\n\n    patched_inputs = torch.where(\n        torch.abs(patched_pads - 1.0) < self.config.tolerance,\n        torch.tensor(0.0,\n                     dtype=patched_inputs.dtype,\n                     device=patched_inputs.device),\n        patched_inputs,\n    )\n    patched_pads = torch.where(\n        torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,\n        torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),\n        patched_pads,\n    )\n    patched_inputs, stats = self._forward_transform(patched_inputs,\n                                                    patched_pads)\n\n    # B x N x D\n    patched_inputs = patched_inputs * (1.0 - patched_pads)\n    concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)\n    model_input = self.input_ff_layer(concat_inputs)\n\n    # A 
patch should not be padded even if there is at least one zero.\n    patched_padding = torch.min(patched_pads,\n                                dim=-1)[0]  # Get the values from the min result\n    if self.config.use_positional_embedding:\n      pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device)\n      pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)\n      pos_emb = _shift_padded_seq(patched_padding, pos_emb)\n      model_input += pos_emb\n\n    return model_input, patched_padding, stats, patched_inputs\n\n  def _postprocess_output(\n      self,\n      model_output: torch.Tensor,\n      num_outputs: int,\n      stats: tuple[torch.Tensor, torch.Tensor],\n  ) -> torch.Tensor:\n    \"\"\"Postprocess output of stacked transformer.\"\"\"\n\n    # B x N x (H.Q)\n    output_ts = self.horizon_ff_layer(model_output)\n\n    # Reshape using view\n    b, n, _ = output_ts.shape\n    output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs)\n\n    return self._reverse_transform(output_ts, stats)\n\n  def forward(\n      self,\n      input_ts: torch.Tensor,\n      input_padding: torch.LongTensor,\n      freq: torch.Tensor,\n  ) -> torch.Tensor:\n    num_outputs = len(self.config.quantiles) + 1\n    model_input, patched_padding, stats, _ = self._preprocess_input(\n        input_ts=input_ts,\n        input_padding=input_padding,\n    )\n    f_emb = self.freq_emb(freq)  # B x 1 x D\n    model_input += f_emb\n    model_output = self.stacked_transformer(model_input, patched_padding)\n\n    output_ts = self._postprocess_output(model_output, num_outputs, stats)\n    return output_ts\n\n  def decode(\n      self,\n      input_ts: torch.Tensor,\n      paddings: torch.Tensor,\n      freq: torch.LongTensor,\n      horizon_len: int,\n      output_patch_len: int | None = None,\n      max_len: int | None = None,\n      return_forecast_on_context: bool = False,\n  ) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Auto-regressive decoding 
without caching.\n\n    Args:\n      input_ts: input time-series and paddings. Time-series shape B x C.\n      paddings: padding shape B x (C + H) where H is the prediction length.\n      freq: frequency shape B x 1\n      horizon_len: prediction length.\n      output_patch_len: output length to be fetched from one step of\n        auto-regressive decoding.\n      max_len: maximum training context length.\n      return_forecast_on_context: whether to return the model forecast on the\n        context except the first input patch.\n\n    Returns:\n      Tuple of two forecasting results:\n      - Point (mean) output predictions as a tensor with shape B x H'.\n      - Full predictions (mean and quantiles) as a tensor with shape\n        B x H' x (1 + # quantiles).\n      In particular, if return_forecast_on_context is True, H' is H plus\n      the forecastable context length, i.e. context_len - (first) patch_len.\n    \"\"\"\n    final_out = input_ts\n    context_len = final_out.shape[1]\n    full_outputs = []\n    if max_len is None:\n      max_len = context_len\n    if paddings.shape[1] != final_out.shape[1] + horizon_len:\n      raise ValueError(\n          \"Length of paddings must match length of input + horizon_len:\"\n          f\" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}\")\n    if output_patch_len is None:\n      output_patch_len = self.config.horizon_len\n    num_decode_patches = (horizon_len + output_patch_len -\n                          1) // output_patch_len\n    for step_index in range(num_decode_patches):\n      current_padding = paddings[:, 0:final_out.shape[1]]\n      input_ts = final_out[:, -max_len:]\n      input_padding = current_padding[:, -max_len:]\n      fprop_outputs = self(input_ts, input_padding, freq)\n      if return_forecast_on_context and step_index == 0:\n        # For the first decoding step, collect the model forecast on the\n        # context except the unavailable first input patch forecast.\n        new_full_ts 
= fprop_outputs[:, 0:-1, 0:self.config.patch_len, :]\n        new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1,\n                                          new_full_ts.size(3))\n\n        full_outputs.append(new_full_ts)\n\n      # (full batch, last patch, output_patch_len, index of mean forecast = 0)\n      new_ts = fprop_outputs[:, -1, :output_patch_len, 0]\n      new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]\n      # (full batch, last patch, output_patch_len, all output indices)\n      full_outputs.append(new_full_ts)\n      final_out = torch.concatenate([final_out, new_ts], axis=-1)\n\n    if return_forecast_on_context:\n      # `full_outputs` indexing starts at after the first input patch.\n      full_outputs = torch.concatenate(\n          full_outputs,\n          axis=1)[:, :(context_len - self.config.patch_len + horizon_len), :]\n    else:\n      # `full_outputs` indexing starts at the forecast horizon.\n      full_outputs = torch.concatenate(full_outputs, axis=1)[:,\n                                                             0:horizon_len, :]\n\n    return (full_outputs[:, :, 0], full_outputs)\n"
  },
  {
    "path": "v1/src/timesfm/time_features.py",
    "content": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Directory to extract time covariates.\n\nExtract time covariates from datetime.\n\"\"\"\n\nimport numpy as np\nimport pandas as pd\nfrom pandas.tseries.holiday import EasterMonday\nfrom pandas.tseries.holiday import GoodFriday\nfrom pandas.tseries.holiday import Holiday\nfrom pandas.tseries.holiday import SU\nfrom pandas.tseries.holiday import TH\nfrom pandas.tseries.holiday import USColumbusDay\nfrom pandas.tseries.holiday import USLaborDay\nfrom pandas.tseries.holiday import USMartinLutherKingJr\nfrom pandas.tseries.holiday import USMemorialDay\nfrom pandas.tseries.holiday import USPresidentsDay\nfrom pandas.tseries.holiday import USThanksgivingDay\nfrom pandas.tseries.offsets import DateOffset\nfrom pandas.tseries.offsets import Day\nfrom pandas.tseries.offsets import Easter\nfrom sklearn.preprocessing import StandardScaler\nfrom tqdm import tqdm\n\n\n# This is 183 to cover half a year (in both directions), also for leap years\n# + 17 as Eastern can be between March, 22 - April, 25\nMAX_WINDOW = 183 + 17\n\n\ndef _distance_to_holiday(holiday):\n  \"\"\"Return distance to given holiday.\"\"\"\n\n  def _distance_to_day(index):\n    holiday_date = holiday.dates(\n        index - pd.Timedelta(days=MAX_WINDOW),\n        index + pd.Timedelta(days=MAX_WINDOW),\n    )\n    assert (\n        len(holiday_date) != 0  # pylint: 
disable=g-explicit-length-test\n    ), f\"No closest holiday for the date index {index} found.\"\n    # It sometimes returns two dates if it is exactly half a year after the\n    # holiday. In this case, the smaller distance (182 days) is returned.\n    return (index - holiday_date[0]).days\n\n  return _distance_to_day\n\n\nEasterSunday = Holiday(\n    \"Easter Sunday\", month=1, day=1, offset=[Easter(), Day(0)]\n)\nNewYearsDay = Holiday(\"New Years Day\", month=1, day=1)\nSuperBowl = Holiday(\n    \"Superbowl\", month=2, day=1, offset=DateOffset(weekday=SU(1))\n)\nMothersDay = Holiday(\n    \"Mothers Day\", month=5, day=1, offset=DateOffset(weekday=SU(2))\n)\nIndependenceDay = Holiday(\"Independence Day\", month=7, day=4)\nChristmasEve = Holiday(\"Christmas\", month=12, day=24)\nChristmasDay = Holiday(\"Christmas\", month=12, day=25)\nNewYearsEve = Holiday(\"New Years Eve\", month=12, day=31)\nBlackFriday = Holiday(\n    \"Black Friday\",\n    month=11,\n    day=1,\n    offset=[pd.DateOffset(weekday=TH(4)), Day(1)],\n)\nCyberMonday = Holiday(\n    \"Cyber Monday\",\n    month=11,\n    day=1,\n    offset=[pd.DateOffset(weekday=TH(4)), Day(4)],\n)\n\nHOLIDAYS = [\n    EasterMonday,\n    GoodFriday,\n    USColumbusDay,\n    USLaborDay,\n    USMartinLutherKingJr,\n    USMemorialDay,\n    USPresidentsDay,\n    USThanksgivingDay,\n    EasterSunday,\n    NewYearsDay,\n    SuperBowl,\n    MothersDay,\n    IndependenceDay,\n    ChristmasEve,\n    ChristmasDay,\n    NewYearsEve,\n    BlackFriday,\n    CyberMonday,\n]\n\n\nclass TimeCovariates(object):\n  \"\"\"Extract all time covariates except for holidays.\"\"\"\n\n  def __init__(\n      self,\n      datetimes,\n      normalized=True,\n      holiday=False,\n  ):\n    \"\"\"Init function.\n\n    Args:\n      datetimes: pandas DatetimeIndex (lowest granularity supported is min)\n      normalized: whether to normalize features or not\n      holiday: fetch holiday features or not\n\n    Returns:\n      None\n    \"\"\"\n    
self.normalized = normalized\n    self.dti = datetimes\n    self.holiday = holiday\n\n  def _minute_of_hour(self):\n    minutes = np.array(self.dti.minute, dtype=np.float32)\n    if self.normalized:\n      minutes = minutes / 59.0 - 0.5\n    return minutes\n\n  def _hour_of_day(self):\n    hours = np.array(self.dti.hour, dtype=np.float32)\n    if self.normalized:\n      hours = hours / 23.0 - 0.5\n    return hours\n\n  def _day_of_week(self):\n    day_week = np.array(self.dti.dayofweek, dtype=np.float32)\n    if self.normalized:\n      day_week = day_week / 6.0 - 0.5\n    return day_week\n\n  def _day_of_month(self):\n    day_month = np.array(self.dti.day, dtype=np.float32)\n    if self.normalized:\n      day_month = day_month / 30.0 - 0.5\n    return day_month\n\n  def _day_of_year(self):\n    day_year = np.array(self.dti.dayofyear, dtype=np.float32)\n    if self.normalized:\n      day_year = day_year / 364.0 - 0.5\n    return day_year\n\n  def _month_of_year(self):\n    month_year = np.array(self.dti.month, dtype=np.float32)\n    if self.normalized:\n      month_year = month_year / 11.0 - 0.5\n    return month_year\n\n  def _week_of_year(self):\n    week_year = np.array(self.dti.strftime(\"%U\").astype(int), dtype=np.float32)\n    if self.normalized:\n      week_year = week_year / 51.0 - 0.5\n    return week_year\n\n  def _get_holidays(self):\n    dti_series = self.dti.to_series()\n    hol_variates = np.vstack([\n        dti_series.apply(_distance_to_holiday(h)).values for h in tqdm(HOLIDAYS)\n    ])\n    # hol_variates is (num_holiday, num_time_steps), the normalization should be\n    # performed in the num_time_steps dimension.\n    return StandardScaler().fit_transform(hol_variates.T).T\n\n  def get_covariates(self):\n    \"\"\"Get all time covariates.\"\"\"\n    moh = self._minute_of_hour().reshape(1, -1)\n    hod = self._hour_of_day().reshape(1, -1)\n    dom = self._day_of_month().reshape(1, -1)\n    dow = self._day_of_week().reshape(1, -1)\n    doy = 
self._day_of_year().reshape(1, -1)\n    moy = self._month_of_year().reshape(1, -1)\n    woy = self._week_of_year().reshape(1, -1)\n\n    all_covs = [\n        moh,\n        hod,\n        dom,\n        dow,\n        doy,\n        moy,\n        woy,\n    ]\n    columns = [\"moh\", \"hod\", \"dom\", \"dow\", \"doy\", \"moy\", \"woy\"]\n    if self.holiday:\n      hol_covs = self._get_holidays()\n      all_covs.append(hol_covs)\n      columns += [f\"hol_{i}\" for i in range(len(HOLIDAYS))]\n\n    return pd.DataFrame(\n        data=np.vstack(all_covs).transpose(),\n        columns=columns,\n        index=self.dti,\n    )\n"
  },
  {
    "path": "v1/src/timesfm/timesfm_base.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Base class for TimesFM inference. This will be common to PAX and Pytorch.\"\"\"\n\nimport collections\nimport dataclasses\nimport logging\nimport multiprocessing\nfrom typing import Any, Literal, Sequence, TYPE_CHECKING\n\nimport numpy as np\nimport pandas as pd\n\nfrom utilsforecast.processing import make_future_dataframe\n\nif TYPE_CHECKING:\n    from . import xreg_lib\n    Category = xreg_lib.Category\n    XRegMode = xreg_lib.XRegMode\nelse:\n    Category = int | str\n    XRegMode = str\n\n_TOL = 1e-6\nDEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)\n\n\ndef process_group(key, group, value_name, forecast_context_len):\n  group = group.tail(forecast_context_len)\n  return np.array(group[value_name], dtype=np.float32), key\n\n\ndef moving_average(arr, window_size):\n  \"\"\"Calculates the moving average using NumPy's convolution function.\"\"\"\n  # Pad with zeros to handle initial window positions\n  arr_padded = np.pad(arr, (window_size - 1, 0), \"constant\")\n  smoothed_arr = (np.convolve(arr_padded, np.ones(window_size), \"valid\") / \n                  window_size)\n  return [smoothed_arr, arr - smoothed_arr]\n\n\ndef freq_map(freq: str):\n  \"\"\"Returns the frequency map for the given frequency string.\"\"\"\n  freq = str.upper(freq)\n  if freq.endswith(\"MS\"):\n    return 1\n  elif freq.endswith((\"H\", \"T\", \"MIN\", \"D\", \"B\", 
\"U\", \"S\")):\n    return 0\n  elif (\n    freq.endswith((\"W\", \"M\"))\n    or freq.startswith(\"W-\")\n    or (freq.startswith(\"M\") and len(freq) == 2)\n  ):\n    return 1\n  elif (\n    freq.endswith((\"Y\", \"Q\", \"A\"))\n    or freq.startswith(\"Y-\")\n    or freq.startswith(\"Q-\")\n    or freq.startswith(\"A-\")\n  ):\n    return 2\n  else:\n    raise ValueError(f\"Invalid frequency: {freq}\")\n\n\ndef strip_leading_nans(arr):\n  \"\"\"\n  Removes contiguous NaN values from the beginning of a NumPy array.\n\n  Args:\n    arr: The input NumPy array.\n\n  Returns:\n    A new NumPy array with leading NaN values removed.\n    If the array is all NaNs or empty, returns an empty array.\n  \"\"\"\n\n  isnan = np.isnan(arr)\n  first_valid_index = np.argmax(~isnan)\n  return arr[first_valid_index:]\n\n\ndef linear_interpolation(arr):\n  \"\"\"\n    Performs linear interpolation to fill NaN values in a 1D numpy array.\n\n    Args:\n        arr: The 1D numpy array containing NaN values.\n\n    Returns:\n        A new numpy array with NaN values filled using linear interpolation, \n        or the original array if no NaNs are present. 
\n        The input is assumed to be 1D; no dimensionality check is performed.\n        Returns the original array if there are no NaN values.\n    \"\"\"\n\n  nans = np.isnan(arr)\n  if not np.any(nans):  # Check if there are any NaNs\n    return arr\n\n  def x(z):\n    return z.nonzero()[0]\n\n  nans_indices = x(nans)\n  non_nans_indices = x(~nans)\n  non_nans_values = arr[~nans]\n\n  try:\n    arr[nans] = np.interp(nans_indices, non_nans_indices, non_nans_values)\n  except ValueError:\n    if len(non_nans_values) > 0:\n      mu = np.nanmean(arr)\n    else:\n      mu = 0.0\n    arr = np.where(np.isfinite(arr), arr, mu)\n  return arr\n\n\n# Per time series normalization: forward.\ndef _normalize(batch):\n  stats = [\n      (np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch\n  ]\n  new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)]\n  return new_batch, stats\n\n\n# Per time series normalization: inverse.\ndef _renormalize(batch, stats):\n  return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)]\n\n\n@dataclasses.dataclass(kw_only=True)\nclass TimesFmHparams:\n  \"\"\"Hparams used to initialize a TimesFM model for inference.\n\n  These are the sufficient subset of hparams to configure TimesFM inference\n  agnostic to the checkpoint version, and are not necessarily the same as the\n  hparams used to train the checkpoint.\n\n  Attributes:\n    context_len: Largest context length the model allows for each decode call.\n      This can technically be arbitrarily large, but in practice it should be\n      set to the context length the checkpoint was trained with.\n    horizon_len: Forecast horizon.\n    input_patch_len: Input patch len.\n    output_patch_len: Output patch len. How many timepoints are taken from a\n      single step of autoregressive decoding. 
Can be set as the training horizon\n      of the checkpoint.\n    num_layers: Number of transformer layers in the model.\n    model_dims: Model dimension.\n    per_core_batch_size: Batch size on each core for data parallelism.\n    backend: One of \"cpu\", \"gpu\" or \"tpu\".\n    quantiles: Which quantiles are output by the model.\n  \"\"\"\n\n  context_len: int = 512\n  horizon_len: int = 128\n  input_patch_len: int = 32\n  output_patch_len: int = 128\n  num_layers: int = 20\n  num_heads: int = 16\n  model_dims: int = 1280\n  per_core_batch_size: int = 32\n  backend: Literal[\"cpu\", \"gpu\", \"tpu\"] = \"cpu\"\n  quantiles: Sequence[float] | None = DEFAULT_QUANTILES\n  use_positional_embedding: bool = True\n  # Hparams beyond the model.\n  point_forecast_mode: Literal[\"mean\", \"median\"] = \"median\"\n\n\n@dataclasses.dataclass(kw_only=True)\nclass TimesFmCheckpoint:\n  \"\"\"Checkpoint used to initialize a TimesFM model for inference.\n\n  Attributes:\n    version: Version of the checkpoint, e.g. \"jax\", \"torch\", \"tensorflow\", etc.\n      The factory will create the corresponding TimesFm inference class based on\n      this version.\n    path: Path to the checkpoint.\n    type: If provided, type of the checkpoint used by the specific checkpoint\n      loader per version.\n    step: If provided, step of the checkpoint.\n  \"\"\"\n\n  version: str = \"jax\"\n  path: str | None = None\n  huggingface_repo_id: str | None = None\n  type: Any = None\n  step: int | None = None\n  local_dir: str | None = None\n\n\nclass TimesFmBase:\n  \"\"\"Base TimesFM forecast API for inference.\n\n  This class is the scaffolding for calling TimesFM forecast. To properly use:\n    1. Create an instance with the correct hyperparameters of a TimesFM model.\n    2. Call `load_from_checkpoint` to load a compatible checkpoint.\n    3. 
Call `forecast` for inference.\n  \"\"\"\n\n  def _logging(self, s):\n    print(s)\n\n  def __post_init__(self) -> None:\n    \"\"\"Additional initialization for subclasses before checkpoint loading.\"\"\"\n    pass\n\n  def __init__(self, hparams: TimesFmHparams,\n               checkpoint: TimesFmCheckpoint) -> None:\n    \"\"\"Initializes the TimesFM forecast API.\n\n    Args:\n      hparams: Hyperparameters of the model.\n      checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide\n        which TimesFM version to use.\n    \"\"\"\n    self.hparams = hparams\n\n    # Expand hparams for conciseness within the model code.\n    self.context_len = hparams.context_len\n    self.horizon_len = hparams.horizon_len\n    self.input_patch_len = hparams.input_patch_len\n    self.output_patch_len = hparams.output_patch_len\n    self.num_layers = hparams.num_layers\n    self.model_dims = hparams.model_dims\n    self.backend = hparams.backend\n    self.quantiles = hparams.quantiles\n    self.num_heads = hparams.num_heads\n    self.use_pos_emb = hparams.use_positional_embedding\n\n    # Rewrite these values in __post_init__ for SPMD.\n    self.num_cores = 1\n    self.per_core_batch_size = hparams.per_core_batch_size\n    self.global_batch_size = hparams.per_core_batch_size\n\n    self._horizon_start = self.context_len - self.input_patch_len\n    self.__post_init__()\n    self.load_from_checkpoint(checkpoint)\n\n  def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None:\n    \"\"\"Loads a checkpoint and compiles the decoder.\"\"\"\n    raise NotImplementedError(\"`load_from_checkpoint` is not implemented.\")\n\n  def _preprocess(\n      self, inputs: Sequence[np.ndarray],\n      freq: Sequence[int]) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:\n    \"\"\"Formats and pads raw inputs to feed into the model.\n\n    This function both pads each time series to match the context length, and\n    pads the inputs to meet the SPMD shape 
requirement.\n\n    Args:\n      inputs: A list of 1d JTensors. Each JTensor is the context time series of\n        a single forecast task.\n      freq: list of frequencies\n\n    Returns:\n    A tuple of:\n    - the padded input time series to meet the model required context.\n    - the padding indicator.\n    - the frequency of each input time series.\n    - the number of padded examples for SPMD so that each core has the same\n        number (a multiple of `batch_size`) of examples.\n    \"\"\"\n\n    input_ts, input_padding, inp_freq = [], [], []\n\n    pmap_pad = ((len(inputs) - 1) // self.global_batch_size +\n                1) * self.global_batch_size - len(inputs)\n\n    for i, ts in enumerate(inputs):\n      input_len = ts.shape[0]\n      padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float)\n      if input_len < self.context_len:\n        num_front_pad = self.context_len - input_len\n        ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts],\n                            axis=0)\n        padding = np.concatenate(\n            [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0)\n      elif input_len > self.context_len:\n        ts = ts[-self.context_len:]\n        padding = padding[-(self.context_len + self.horizon_len):]\n\n      input_ts.append(ts)\n      input_padding.append(padding)\n      inp_freq.append(freq[i])\n\n    # Padding the remainder batch.\n    for _ in range(pmap_pad):\n      input_ts.append(input_ts[-1])\n      input_padding.append(input_padding[-1])\n      inp_freq.append(inp_freq[-1])\n\n    return (\n        np.stack(input_ts, axis=0),\n        np.stack(input_padding, axis=0),\n        np.array(inp_freq).astype(np.int32).reshape(-1, 1),\n        pmap_pad,\n    )\n\n  def _forecast(\n      self,\n      inputs: Sequence[Any],\n      freq: Sequence[int] | None = None,\n      window_size: int | None = None,\n      forecast_context_len: int | None = None,\n      return_forecast_on_context: 
bool = False,\n  ) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"Forecasts on a list of time series.\n\n    Args:\n      inputs: list of time series forecast contexts. Each context time series\n        should be in a format convertible to JTensor by `jnp.array`.\n      freq: frequency of each context time series. 0 for high frequency\n        (default), 1 for medium, and 2 for low. Notice this is different from\n        the `freq` required by `forecast_on_df`.\n      window_size: window size of trend + residual decomposition. If None then\n        we do not do decomposition.\n      forecast_context_len: optional max context length.\n      return_forecast_on_context: True to return the forecast on the context\n        when available, i.e. after the first input patch.\n\n    Returns:\n    A tuple for np.array:\n    - the mean forecast of size (# inputs, # forecast horizon),\n    - the full forecast (mean + quantiles) of size\n        (# inputs,  # forecast horizon, 1 + # quantiles).\n\n    Raises:\n    ValueError: If the checkpoint is not properly loaded.\n    \"\"\"\n    raise NotImplementedError(\"`_forecast` is not implemented.\")\n\n  def forecast(\n      self,\n      inputs: Sequence[Any],\n      freq: Sequence[int] | None = None,\n      window_size: int | None = None,\n      forecast_context_len: int | None = None,\n      return_forecast_on_context: bool = False,\n      normalize: bool = False,\n  ) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"Forecasts on a list of time series.\n\n    Args:\n      inputs: list of time series forecast contexts. Each context time series\n        should be in a format convertible to JTensor by `jnp.array`.\n      freq: frequency of each context time series. 0 for high frequency\n        (default), 1 for medium, and 2 for low. Notice this is different from\n        the `freq` required by `forecast_on_df`.\n      window_size: window size of trend + residual decomposition. 
If None then\n        we do not do decomposition.\n      forecast_context_len: optional max context length.\n      return_forecast_on_context: True to return the forecast on the context\n        when available, i.e. after the first input patch.\n      normalize: If True, then we normalize the inputs before forecasting and\n        the outputs are then renormalized to the original scale.\n\n    Returns:\n    A tuple for np.array:\n    - the mean forecast of size (# inputs, # forecast horizon),\n    - the full forecast (mean + quantiles) of size\n        (# inputs,  # forecast horizon, 1 + # quantiles).\n\n    Raises:\n    ValueError: If the checkpoint is not properly loaded.\n    \"\"\"\n    stats = None\n\n    tmp_inputs = []\n    for each_input in inputs:\n      arr = np.array(each_input)\n      if not np.isfinite(arr).all():\n        arr = np.where(np.isfinite(arr), arr, np.nan)\n        arr = strip_leading_nans(arr)\n        arr = linear_interpolation(arr)\n      tmp_inputs.append(arr)\n\n    inputs = tmp_inputs\n    if normalize:\n      inputs, stats = _normalize(inputs)\n    mean_forecast, quantile_forecast = self._forecast(\n        inputs,\n        freq,\n        window_size,\n        forecast_context_len,\n        return_forecast_on_context,\n    )\n    if stats is not None:\n      stats = np.array(stats)\n      mu = stats[:, 0]\n      sigma = stats[:, 1]\n      mean_forecast = mean_forecast * sigma[:, None] + mu[:, None]\n      quantile_forecast = (quantile_forecast * sigma[:, None, None] +\n                           mu[:, None, None])\n    if self.hparams.point_forecast_mode == \"mean\":\n      return mean_forecast, quantile_forecast\n    elif self.hparams.point_forecast_mode == \"median\":\n      if self._median_index == -1:\n        for i, quantile in enumerate(self.quantiles):\n          if quantile == 0.5:\n            self._median_index = i\n            break\n        if self._median_index == -1:\n          raise ValueError(\"Median (0.5) is not 
found in the model quantiles:\"\n                           f\" {self.quantiles}. Please check the hparams.\")\n      return (\n          quantile_forecast[:, :, 1 + self._median_index],\n          quantile_forecast,\n      )\n    else:\n      raise ValueError(\n          \"Unsupported point forecast mode:\"\n          f\" {self.hparams.point_forecast_mode}. Use 'mean' or 'median'.\")\n\n  def forecast_with_covariates(\n      self,\n      inputs: list[Sequence[float]],\n      dynamic_numerical_covariates: (dict[str, Sequence[Sequence[float]]] |\n                                     None) = None,\n      dynamic_categorical_covariates: (dict[str, Sequence[Sequence[Category]]] |\n                                       None) = None,\n      static_numerical_covariates: dict[str, Sequence[float]] | None = None,\n      static_categorical_covariates: (dict[str, Sequence[Category]] |\n                                      None) = None,\n      freq: Sequence[int] | None = None,\n      window_size: int | None = None,\n      forecast_context_len: int | None = None,\n      xreg_mode: XRegMode = \"xreg + timesfm\",\n      normalize_xreg_target_per_input: bool = True,\n      ridge: float = 0.0,\n      max_rows_per_col: int = 0,\n      force_on_cpu: bool = False,\n  ):\n    \"\"\"Forecasts on a list of time series with covariates.\n\n    To optimize inference speed, avoid string valued categorical covariates.\n\n    Args:\n      inputs: A list of time series forecast contexts. Each context time series\n        should be in a format convertible to JTensor by `jnp.array`.\n      dynamic_numerical_covariates: A dict of dynamic numerical covariates.\n      dynamic_categorical_covariates: A dict of dynamic categorical covariates.\n      static_numerical_covariates: A dict of static numerical covariates.\n      static_categorical_covariates: A dict of static categorical covariates.\n      freq: frequency of each context time series. 
0 for high frequency\n        (default), 1 for medium, and 2 for low. Notice this is different from\n        the `freq` required by `forecast_on_df`.\n      window_size: window size of trend + residual decomposition. If None then\n        we do not do decomposition.\n      forecast_context_len: optional max context length.\n      xreg_mode: one of \"xreg + timesfm\" or \"timesfm + xreg\". \"timesfm + xreg\"\n        fits a model on the residuals of the TimesFM forecast. \"xreg + timesfm\"\n        fits a model on the targets then forecasts on the residuals via TimesFM.\n      normalize_xreg_target_per_input: whether to normalize the xreg target per\n        input in the given batch.\n      ridge: ridge penalty for the linear model.\n      max_rows_per_col: max number of rows per column for the linear model.\n      force_on_cpu: whether to force running on cpu for the linear model.\n\n    Returns:\n      A tuple of two lists. The first is the outputs of the model. The second is\n      the outputs of the xreg.\n    \"\"\"\n\n    from . 
import xreg_lib\n\n    # Verify and bookkeep covariates.\n    if not (dynamic_numerical_covariates or dynamic_categorical_covariates or\n            static_numerical_covariates or static_categorical_covariates):\n      raise ValueError(\n          \"At least one of dynamic_numerical_covariates,\"\n          \" dynamic_categorical_covariates, static_numerical_covariates,\"\n          \" static_categorical_covariates must be set.\")\n\n    # Track the lengths of (1) each input, (2) the part that can be used in the\n    # linear model, and (3) the horizon.\n    input_lens, train_lens, test_lens = [], [], []\n\n    for i, input_ts in enumerate(inputs):\n      input_len = len(input_ts)\n      input_lens.append(input_len)\n\n      if xreg_mode == \"timesfm + xreg\":\n        # For fitting residuals, no TimesFM forecast on the first patch.\n        train_lens.append(max(0, input_len - self.input_patch_len))\n      elif xreg_mode == \"xreg + timesfm\":\n        train_lens.append(input_len)\n      else:\n        raise ValueError(f\"Unsupported mode: {xreg_mode}\")\n\n      if dynamic_numerical_covariates:\n        test_lens.append(\n            len(list(dynamic_numerical_covariates.values())[0][i]) - input_len)\n      elif dynamic_categorical_covariates:\n        test_lens.append(\n            len(list(dynamic_categorical_covariates.values())[0][i]) -\n            input_len)\n      else:\n        test_lens.append(self.horizon_len)\n\n      if test_lens[-1] > self.horizon_len:\n        raise ValueError(\n            \"Forecast requested longer horizon than the model definition \"\n            f\"supports: {test_lens[-1]} vs {self.horizon_len}.\")\n\n    # Prepare the covariates into train and test.\n    train_dynamic_numerical_covariates = collections.defaultdict(list)\n    test_dynamic_numerical_covariates = collections.defaultdict(list)\n    train_dynamic_categorical_covariates = collections.defaultdict(list)\n    test_dynamic_categorical_covariates = 
collections.defaultdict(list)\n    for covariates, train_covariates, test_covariates in (\n        (\n            dynamic_numerical_covariates,\n            train_dynamic_numerical_covariates,\n            test_dynamic_numerical_covariates,\n        ),\n        (\n            dynamic_categorical_covariates,\n            train_dynamic_categorical_covariates,\n            test_dynamic_categorical_covariates,\n        ),\n    ):\n      if not covariates:\n        continue\n      for covariate_name, covariate_values in covariates.items():\n        for input_len, train_len, covariate_value in zip(\n            input_lens, train_lens, covariate_values):\n          train_covariates[covariate_name].append(\n              covariate_value[(input_len - train_len):input_len])\n          test_covariates[covariate_name].append(covariate_value[input_len:])\n\n    # Fit models.\n    if xreg_mode == \"timesfm + xreg\":\n      # Forecast via TimesFM then fit a model on the residuals.\n      mean_outputs, _ = self.forecast(\n          inputs,\n          freq,\n          window_size,\n          forecast_context_len,\n          return_forecast_on_context=True,\n      )\n      targets = [\n          (np.array(input_ts)[-train_len:] -\n           mean_output[(self._horizon_start - train_len):self._horizon_start])\n          for input_ts, mean_output, train_len in zip(inputs, mean_outputs,\n                                                      train_lens)\n      ]\n      per_instance_stats = None\n      if normalize_xreg_target_per_input:\n        targets, per_instance_stats = _normalize(targets)\n      xregs = xreg_lib.BatchedInContextXRegLinear(\n          targets=targets,\n          train_lens=train_lens,\n          test_lens=test_lens,\n          train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,\n          test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,\n          train_dynamic_categorical_covariates=\n          
train_dynamic_categorical_covariates,\n          test_dynamic_categorical_covariates=\n          test_dynamic_categorical_covariates,\n          static_numerical_covariates=static_numerical_covariates,\n          static_categorical_covariates=static_categorical_covariates,\n      ).fit(\n          ridge=ridge,\n          one_hot_encoder_drop=None if ridge > 0 else \"first\",\n          max_rows_per_col=max_rows_per_col,\n          force_on_cpu=force_on_cpu,\n          debug_info=False,\n          assert_covariates=True,\n          assert_covariate_shapes=True,\n      )\n      if normalize_xreg_target_per_input:\n        xregs = _renormalize(xregs, per_instance_stats)\n      outputs = [\n          (mean_output[self._horizon_start:(self._horizon_start + test_len)] +\n           xreg)\n          for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs)\n      ]\n\n    else:\n      # Fit a model on the targets then forecast on the residuals via TimesFM.\n      targets = [\n          np.array(input_ts)[-train_len:]\n          for input_ts, train_len in zip(inputs, train_lens)\n      ]\n      per_instance_stats = None\n      if normalize_xreg_target_per_input:\n        targets, per_instance_stats = _normalize(targets)\n      xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear(\n          targets=targets,\n          train_lens=train_lens,\n          test_lens=test_lens,\n          train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,\n          test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,\n          train_dynamic_categorical_covariates=\n          train_dynamic_categorical_covariates,\n          test_dynamic_categorical_covariates=\n          test_dynamic_categorical_covariates,\n          static_numerical_covariates=static_numerical_covariates,\n          static_categorical_covariates=static_categorical_covariates,\n      ).fit(\n          ridge=ridge,\n          one_hot_encoder_drop=None if ridge > 
0 else \"first\",\n          max_rows_per_col=max_rows_per_col,\n          force_on_cpu=force_on_cpu,\n          debug_info=True,\n          assert_covariates=True,\n          assert_covariate_shapes=True,\n      )\n      mean_outputs, _ = self.forecast(\n          [\n              target - xreg_on_context\n              for target, xreg_on_context in zip(targets, xregs_on_context)\n          ],\n          freq,\n          window_size,\n          forecast_context_len,\n          return_forecast_on_context=True,\n      )\n      outputs = [\n          (mean_output[self._horizon_start:(self._horizon_start + test_len)] +\n           xreg)\n          for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs)\n      ]\n      if normalize_xreg_target_per_input:\n        outputs = _renormalize(outputs, per_instance_stats)\n\n    return outputs, xregs\n\n  def forecast_on_df(\n      self,\n      inputs: pd.DataFrame,\n      freq: str,\n      forecast_context_len: int = 0,\n      value_name: str = \"values\",\n      model_name: str = \"timesfm\",\n      window_size: int | None = None,\n      num_jobs: int = 1,\n      normalize: bool = False,\n      verbose: bool = True,\n  ) -> pd.DataFrame:\n    \"\"\"Forecasts on a list of time series.\n\n    Args:\n      inputs: A pd.DataFrame of all time series. The dataframe should have a\n        `unique_id` column for identifying the time series, a `ds` column for\n        timestamps and a value column for the time series values.\n      freq: string valued `freq` of data. Notice this is different from the\n        `freq` required by `forecast`. 
See `freq_map` for allowed values.\n      forecast_context_len: If provided none zero, we take the last\n        `forecast_context_len` time-points from each series as the forecast\n        context instead of the `context_len` set by the model.\n      value_name: The name of the value column.\n      model_name: name of the model to be written into future df.\n      window_size: window size of trend + residual decomposition. If None then\n        we do not do decomposition.\n      num_jobs: number of parallel processes to use for dataframe processing.\n      normalize: normalize context before forecasting or not.\n      verbose: output model states in terminal.\n\n    Returns:\n      Future forecasts dataframe.\n    \"\"\"\n    if not (\"unique_id\" in inputs.columns and \"ds\" in inputs.columns and\n            value_name in inputs.columns):\n      raise ValueError(\n          f\"DataFrame must have unique_id, ds and {value_name} columns.\")\n    if not forecast_context_len:\n      forecast_context_len = self.context_len\n    logging.info(\"Preprocessing dataframe.\")\n    df_sorted = inputs.sort_values(by=[\"unique_id\", \"ds\"])\n    new_inputs = []\n    uids = []\n    if num_jobs == 1:\n      if verbose:\n        print(\"Processing dataframe with single process.\")\n      for key, group in df_sorted.groupby(\"unique_id\"):\n        inp, uid = process_group(\n            key,\n            group,\n            value_name,\n            forecast_context_len,\n        )\n        new_inputs.append(inp)\n        uids.append(uid)\n    else:\n      if num_jobs == -1:\n        num_jobs = multiprocessing.cpu_count()\n      if verbose:\n        print(\"Processing dataframe with multiple processes.\")\n      with multiprocessing.Pool(processes=num_jobs) as pool:\n        results = pool.starmap(\n            process_group,\n            [(key, group, value_name, forecast_context_len)\n             for key, group in df_sorted.groupby(\"unique_id\")],\n        )\n      
new_inputs, uids = zip(*results)\n    if verbose:\n      print(\"Finished preprocessing dataframe.\")\n    freq_inps = [freq_map(freq)] * len(new_inputs)\n    _, full_forecast = self.forecast(new_inputs,\n                                     freq=freq_inps,\n                                     normalize=normalize,\n                                     window_size=window_size)\n    if verbose:\n      print(\"Finished forecasting.\")\n    fcst_df = make_future_dataframe(\n        uids=uids,\n        last_times=df_sorted.groupby(\"unique_id\")[\"ds\"].tail(1),\n        h=self.horizon_len,\n        freq=freq,\n    )\n    fcst_df[model_name] = full_forecast[:, 0:self.horizon_len, 0].reshape(-1, 1)\n\n    for i, q in enumerate(self.quantiles):\n      q_col = f\"{model_name}-q-{q}\"\n      fcst_df[q_col] = full_forecast[:, 0:self.horizon_len,\n                                     1 + i].reshape(-1, 1)\n      if q == 0.5:\n        fcst_df[model_name] = fcst_df[q_col]\n    logging.info(\"Finished creating output dataframe.\")\n    return fcst_df\n"
  },
  {
    "path": "v1/src/timesfm/timesfm_jax.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"TimesFM JAX forecast API for inference.\"\"\"\n\nimport logging\nimport multiprocessing\nimport time\nfrom os import path\nfrom typing import Any, Sequence\n\nimport einshape as es\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom huggingface_hub import snapshot_download\n\nfrom paxml import checkpoints, tasks_lib\nfrom praxis import base_hyperparams, base_layer, pax_fiddle, py_utils, pytypes\nfrom praxis.layers import normalizations, transformers\nfrom timesfm import timesfm_base\nfrom timesfm import patched_decoder\n\ninstantiate = base_hyperparams.instantiate\nNestedMap = py_utils.NestedMap\nJTensor = pytypes.JTensor\n\n_TOL = 1e-6\n\n\nclass TimesFmJax(timesfm_base.TimesFmBase):\n  \"\"\"TimesFM forecast API for inference.\n\n  This class is the scaffolding for calling TimesFM forecast. To properly use:\n    1. Create an instance with the correct hyperparameters of a TimesFM model.\n    2. Call `load_from_checkpoint` to load a compatible checkpoint.\n    3. Call `forecast` for inference.\n\n  Given the model size, this API does not shard the model weights for SPMD. All\n  parallelism happens on the data dimension.\n\n  Compilation happens during the first time `forecast` is called and uses the\n  `per_core_batch_size` to set and freeze the input signature. 
Subsequent calls\n  to `forecast` reflect the actual inference latency.\n  \"\"\"\n\n  def _get_sample_inputs(self):\n    return {\n        \"input_ts\":\n            jnp.zeros(\n                (\n                    self.per_core_batch_size,\n                    self.context_len + self.output_patch_len,\n                ),\n                dtype=jnp.float32,\n            ),\n        \"input_padding\":\n            jnp.zeros(\n                (\n                    self.per_core_batch_size,\n                    self.context_len + self.output_patch_len,\n                ),\n                dtype=jnp.float32,\n            ),\n        \"freq\":\n            jnp.zeros(\n                (\n                    self.per_core_batch_size,\n                    1,\n                ),\n                dtype=jnp.int32,\n            ),\n    }\n\n  def __post_init__(self):\n    self.num_cores = jax.local_device_count(self.backend)\n    self.global_batch_size = self.per_core_batch_size * self.num_cores\n    self._eval_context = base_layer.JaxContext.HParams(do_eval=True)\n    self._pmapped_decode = None\n    self._model = None\n    self._train_state = None\n    self._median_index = -1\n\n  def load_from_checkpoint(\n      self,\n      checkpoint: timesfm_base.TimesFmCheckpoint,\n  ) -> None:\n    \"\"\"Loads a checkpoint and compiles the decoder.\"\"\"\n    checkpoint_type = (checkpoints.CheckpointType.FLAX\n                       if checkpoint.type is None else checkpoint.type)\n    checkpoint_path = checkpoint.path\n    step = checkpoint.step\n    repo_id = checkpoint.huggingface_repo_id\n    if checkpoint_path is None:\n      checkpoint_path = path.join(snapshot_download(repo_id), \"checkpoints\")\n    # Rewrite the devices for Jax.\n    self.mesh_shape = [1, self.num_cores, 1]\n    self.mesh_name = [\"replica\", \"data\", \"mdl\"]\n\n    self.model_p = pax_fiddle.Config(\n        patched_decoder.PatchedTimeSeriesDecoder,\n        name=\"patched_decoder\",\n        
horizon_len=self.output_patch_len,\n        patch_len=self.input_patch_len,\n        model_dims=self.model_dims,\n        hidden_dims=self.model_dims,\n        residual_block_tpl=pax_fiddle.Config(patched_decoder.ResidualBlock),\n        quantiles=self.quantiles,\n        use_freq=True,\n        use_pos_emb=self.use_pos_emb,\n        stacked_transformer_params_tpl=pax_fiddle.Config(\n            transformers.StackedTransformer,\n            num_heads=self.num_heads,\n            num_layers=self.num_layers,\n            transformer_layer_params_tpl=pax_fiddle.Config(\n                transformers.Transformer,\n                ln_tpl=pax_fiddle.Config(normalizations.RmsNorm,),\n            ),\n        ),\n    )\n\n    self._key1, self._key2 = jax.random.split(jax.random.PRNGKey(42))\n    self._model = None\n    self._train_state = None\n    self._pmapped_decode = None\n    self._eval_context = base_layer.JaxContext.HParams(do_eval=True)\n    try:\n      multiprocessing.set_start_method(\"spawn\")\n    except RuntimeError:\n      print(\"Multiprocessing context has already been set.\")\n    # Download the checkpoint from Hugging Face Hub if not given\n\n    #  Initialize the model weights.\n    self._logging(\"Constructing model weights.\")\n    start_time = time.time()\n    self._model = instantiate(self.model_p)\n    var_weight_hparams = self._model.abstract_init_with_metadata(\n        self._get_sample_inputs(), do_eval=True)\n    train_state_partition_specs = tasks_lib.create_state_partition_specs(\n        var_weight_hparams,\n        mesh_shape=self.mesh_shape,\n        mesh_axis_names=self.mesh_name,\n        discard_opt_states=True,\n        learners=None,\n    )\n    train_state_local_shapes = tasks_lib.create_state_unpadded_shapes(\n        var_weight_hparams,\n        discard_opt_states=True,\n        learners=None,\n    )\n    self._logging(\n        f\"Constructed model weights in {time.time() - start_time:.2f} seconds.\")\n\n    # Load the model 
weights.\n    self._logging(f\"Restoring checkpoint from {checkpoint_path}.\")\n    start_time = time.time()\n    self._train_state = checkpoints.restore_checkpoint(\n        train_state_local_shapes,\n        checkpoint_dir=checkpoint_path,\n        checkpoint_type=checkpoint_type,\n        state_specs=train_state_partition_specs,\n        step=step,\n    )\n    self._logging(\n        f\"Restored checkpoint in {time.time() - start_time:.2f} seconds.\")\n    self.jit_decode()\n\n  def jit_decode(self):\n    \"\"\"Jitting decoding function.\"\"\"\n\n    # Initialize and jit the decode fn.\n    def _decode(inputs):\n      assert self._model is not None\n      assert self._train_state is not None\n      return self._model.apply(\n          self._train_state.mdl_vars,\n          inputs,\n          horizon_len=self.horizon_len,\n          output_patch_len=self.output_patch_len,\n          max_len=self.context_len,\n          return_forecast_on_context=True,\n          rngs={\n              base_layer.PARAMS: self._key1,\n              base_layer.RANDOM: self._key2,\n          },\n          method=self._model.decode,\n      )\n\n    self._logging(\"Jitting decoding.\")\n    start_time = time.time()\n    self._pmapped_decode = jax.pmap(\n        _decode,\n        axis_name=\"batch\",\n        devices=jax.devices(self.backend),\n        backend=self.backend,\n        axis_size=self.num_cores,\n    )\n    with base_layer.JaxContext.new_context(hparams=self._eval_context):\n      _ = self._pmapped_decode(\n          NestedMap({\n              \"input_ts\":\n                  jnp.zeros(\n                      (\n                          self.num_cores,\n                          self.per_core_batch_size,\n                          self.context_len,\n                      ),\n                      dtype=jnp.float32,\n                  ),\n              \"input_padding\":\n                  jnp.zeros(\n                      (\n                          self.num_cores,\n       
                   self.per_core_batch_size,\n                          self.context_len + self.horizon_len,\n                      ),\n                      dtype=jnp.float32,\n                  ),\n              \"date_features\":\n                  None,\n              \"freq\":\n                  jnp.zeros(\n                      (self.num_cores, self.per_core_batch_size, 1),\n                      dtype=jnp.int32,\n                  ),\n          }))\n    self._logging(f\"Jitted decoding in {time.time() - start_time:.2f} seconds.\")\n\n  def _forecast(\n      self,\n      inputs: Sequence[Any],\n      freq: Sequence[int] | None = None,\n      window_size: int | None = None,\n      forecast_context_len: int | None = None,\n      return_forecast_on_context: bool = False,\n  ) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"Forecasts on a list of time series.\n\n    Args:\n      inputs: list of time series forecast contexts. Each context time series\n        should be in a format convertible to JTensor by `jnp.array`.\n      freq: frequency of each context time series. 0 for high frequency\n        (default), 1 for medium, and 2 for low. Notice this is different from\n        the `freq` required by `forecast_on_df`.\n      window_size: window size of trend + residual decomposition. If None then\n        we do not do decomposition.\n      forecast_context_len: optional max context length.\n      return_forecast_on_context: True to return the forecast on the context\n        when available, i.e. after the first input patch.\n\n    Returns:\n    A tuple for JTensors:\n    - the mean forecast of size (# inputs, # forecast horizon),\n    - the full forecast (mean + quantiles) of size\n        (# inputs,  # forecast horizon, 1 + # quantiles).\n\n    Raises:\n    ValueError: If the checkpoint is not properly loaded.\n    \"\"\"\n    if not self._train_state or not self._model:\n      raise ValueError(\n          \"Checkpoint not loaded. 
Call `load_from_checkpoint` before\"\n          \" `forecast`.\")\n    if forecast_context_len is None:\n      fcontext_len = self.context_len\n    else:\n      fcontext_len = forecast_context_len\n    inputs = [np.array(ts)[-fcontext_len:] for ts in inputs]\n\n    if window_size is not None:\n      new_inputs = []\n      for ts in inputs:\n        new_inputs.extend(timesfm_base.moving_average(ts, window_size))\n      inputs = new_inputs\n\n    if freq is None:\n      logging.info(\"No frequency provided via `freq`. Default to high (0).\")\n      freq = [0] * len(inputs)\n\n    input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq)\n    with base_layer.JaxContext.new_context(hparams=self._eval_context):\n      mean_outputs = []\n      full_outputs = []\n      assert input_ts.shape[0] % self.global_batch_size == 0\n      for i in range(input_ts.shape[0] // self.global_batch_size):\n        input_ts_in = jnp.array(input_ts[i * self.global_batch_size:(i + 1) *\n                                         self.global_batch_size])\n        input_padding_in = jnp.array(\n            input_padding[i * self.global_batch_size:(i + 1) *\n                          self.global_batch_size],)\n        inp_freq_in = jnp.array(\n            inp_freq[i * self.global_batch_size:(i + 1) *\n                     self.global_batch_size, :],\n            dtype=jnp.int32,\n        )\n        pmapped_inputs = NestedMap({\n            \"input_ts\":\n                es.jax_einshape(\n                    \"(db)...->db...\",\n                    input_ts_in,\n                    d=self.num_cores,\n                ),\n            \"input_padding\":\n                es.jax_einshape(\n                    \"(db)...->db...\",\n                    input_padding_in,\n                    d=self.num_cores,\n                ),\n            \"date_features\":\n                None,\n            \"freq\":\n                es.jax_einshape(\n                    \"(db)...->db...\",\n       
             inp_freq_in,\n                    d=self.num_cores,\n                ),\n        })\n        mean_output, full_output = self._pmapped_decode(pmapped_inputs)\n        if not return_forecast_on_context:\n          mean_output = mean_output[:, :, self._horizon_start:, ...]\n          full_output = full_output[:, :, self._horizon_start:, ...]\n        mean_output = es.jax_einshape(\"db...->(db)...\",\n                                      mean_output,\n                                      d=self.num_cores)\n        full_output = es.jax_einshape(\"db...->(db)...\",\n                                      full_output,\n                                      d=self.num_cores)\n        mean_output = np.array(mean_output)\n        full_output = np.array(full_output)\n        mean_outputs.append(mean_output)\n        full_outputs.append(full_output)\n\n    mean_outputs = np.concatenate(mean_outputs, axis=0)\n    full_outputs = np.concatenate(full_outputs, axis=0)\n\n    if pmap_pad > 0:\n      mean_outputs = mean_outputs[:-pmap_pad, ...]\n      full_outputs = full_outputs[:-pmap_pad, ...]\n\n    if window_size is not None:\n      mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]\n      full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]\n    return mean_outputs, full_outputs\n"
  },
  {
    "path": "v1/src/timesfm/timesfm_torch.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"TimesFM pytorch forecast API for inference.\"\"\"\n\nimport logging\nfrom os import path\nfrom typing import Any, Sequence\n\nimport numpy as np\nimport torch\nfrom huggingface_hub import snapshot_download\nfrom timesfm import timesfm_base\n\nfrom . import pytorch_patched_decoder as ppd\n\n_TOL = 1e-6\n\n\nclass TimesFmTorch(timesfm_base.TimesFmBase):\n  \"\"\"TimesFM forecast API for inference.\"\"\"\n\n  def __post_init__(self):\n    self._model_config = ppd.TimesFMConfig(\n        num_layers=self.num_layers,\n        num_heads=self.num_heads,\n        hidden_size=self.model_dims,\n        intermediate_size=self.model_dims,\n        patch_len=self.input_patch_len,\n        horizon_len=self.output_patch_len,\n        head_dim=self.model_dims // self.num_heads,\n        quantiles=self.quantiles,\n        use_positional_embedding=self.use_pos_emb,\n    )\n    self._model = None\n    self.num_cores = 1\n    self.global_batch_size = self.per_core_batch_size\n    self._device = torch.device(\"cuda:0\" if (\n        torch.cuda.is_available() and self.backend == \"gpu\") else \"cpu\")\n    self._median_index = -1\n\n  def load_from_checkpoint(\n      self,\n      checkpoint: timesfm_base.TimesFmCheckpoint,\n  ) -> None:\n    \"\"\"Loads a checkpoint and compiles the decoder.\"\"\"\n    checkpoint_path = checkpoint.path\n    repo_id = checkpoint.huggingface_repo_id\n    
if checkpoint_path is None:\n      checkpoint_path = path.join(\n          snapshot_download(repo_id, local_dir=checkpoint.local_dir),\n          \"torch_model.ckpt\")\n    self._model = ppd.PatchedTimeSeriesDecoder(self._model_config)\n    loaded_checkpoint = torch.load(checkpoint_path, weights_only=True)\n    logging.info(\"Loading checkpoint from %s\", checkpoint_path)\n    self._model.load_state_dict(loaded_checkpoint)\n    logging.info(\"Sending checkpoint to device %s\", f\"{self._device}\")\n    self._model.to(self._device)\n    self._model.eval()\n    # TODO: add compilation.\n\n  def _forecast(\n      self,\n      inputs: Sequence[Any],\n      freq: Sequence[int] | None = None,\n      window_size: int | None = None,\n      forecast_context_len: int | None = None,\n      return_forecast_on_context: bool = False,\n  ) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"Forecasts on a list of time series.\n\n    Args:\n      inputs: list of time series forecast contexts. Each context time series\n        should be in a format convertible to JTensor by `jnp.array`.\n      freq: frequency of each context time series. 0 for high frequency\n        (default), 1 for medium, and 2 for low. Notice this is different from\n        the `freq` required by `forecast_on_df`.\n      window_size: window size of trend + residual decomposition. If None then\n        we do not do decomposition.\n      forecast_context_len: optional max context length.\n      return_forecast_on_context: True to return the forecast on the context\n        when available, i.e. 
after the first input patch.\n\n    Returns:\n    A tuple for JTensors:\n    - the mean forecast of size (# inputs, # forecast horizon),\n    - the full forecast (mean + quantiles) of size\n        (# inputs,  # forecast horizon, 1 + # quantiles).\n\n    Raises:\n    ValueError: If the checkpoint is not properly loaded.\n    \"\"\"\n    if self._model is None:\n      raise ValueError(\"Checkpoint is not properly loaded.\")\n\n    if forecast_context_len is None:\n      forecast_context_len = self.context_len\n    inputs = [np.array(ts)[-forecast_context_len:] for ts in inputs]\n\n    if window_size is not None:\n      new_inputs = []\n      for ts in inputs:\n        new_inputs.extend(timesfm_base.moving_average(ts, window_size))\n      inputs = new_inputs\n\n    if freq is None:\n      logging.info(\"No frequency provided via `freq`. Default to high (0).\")\n      freq = [0] * len(inputs)\n\n    input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq)\n\n    with torch.no_grad():\n      mean_outputs = []\n      full_outputs = []\n      for i in range(input_ts.shape[0] // self.global_batch_size):\n        t_input_ts = torch.Tensor(input_ts[i * self.global_batch_size:(i + 1) *\n                                           self.global_batch_size]).to(\n                                               self._device)\n        t_input_padding = torch.Tensor(\n            input_padding[i * self.global_batch_size:(i + 1) *\n                          self.global_batch_size]).to(self._device)\n        t_inp_freq = torch.LongTensor(\n            inp_freq[i * self.global_batch_size:(i + 1) *\n                     self.global_batch_size, :]).to(self._device)\n\n        mean_output, full_output = self._model.decode(\n            input_ts=t_input_ts,\n            paddings=t_input_padding,\n            freq=t_inp_freq,\n            horizon_len=self.horizon_len,\n            output_patch_len=self.output_patch_len,\n            # Returns forecasts on context for 
parity with the Jax version.\n            return_forecast_on_context=True,\n        )\n        if not return_forecast_on_context:\n          mean_output = mean_output[:, self._horizon_start:, ...]\n          full_output = full_output[:, self._horizon_start:, ...]\n\n        if self.backend == \"gpu\":\n          mean_output = mean_output.cpu()\n          full_output = full_output.cpu()\n        mean_output = mean_output.detach().numpy()\n        full_output = full_output.detach().numpy()\n        mean_outputs.append(mean_output)\n        full_outputs.append(full_output)\n\n    mean_outputs = np.concatenate(mean_outputs, axis=0)\n    full_outputs = np.concatenate(full_outputs, axis=0)\n\n    if pmap_pad > 0:\n      mean_outputs = mean_outputs[:-pmap_pad, ...]\n      full_outputs = full_outputs[:-pmap_pad, ...]\n\n    if window_size is not None:\n      mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]\n      full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]\n\n    return mean_outputs, full_outputs\n"
  },
  {
    "path": "v1/src/timesfm/xreg_lib.py",
    "content": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Helper functions for in-context covariates and regression.\"\"\"\n\nimport itertools\nimport math\nfrom typing import Any, Iterable, Literal, Mapping, Sequence\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom sklearn import preprocessing\n\nCategory = int | str\n\n_TOL = 1e-6\nXRegMode = Literal[\"timesfm + xreg\", \"xreg + timesfm\"]\n\n\ndef _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:\n  return np.array(list(itertools.chain.from_iterable(nested)))\n\n\ndef _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:\n  return np.array(\n      list(\n          itertools.chain.from_iterable(map(itertools.repeat, elements,\n                                            counts))))\n\n\ndef _to_padded_jax_array(x: np.ndarray) -> jax.Array:\n  if x.ndim == 1:\n    (i,) = x.shape\n    di = 2**math.ceil(math.log2(i)) - i\n    return jnp.pad(x, ((0, di),), mode=\"constant\", constant_values=0.0)\n  elif x.ndim == 2:\n    i, j = x.shape\n    di = 2**math.ceil(math.log2(i)) - i\n    dj = 2**math.ceil(math.log2(j)) - j\n    return jnp.pad(x, ((0, di), (0, dj)), mode=\"constant\", constant_values=0.0)\n  else:\n    raise ValueError(f\"Unsupported array shape: {x.shape}\")\n\n\nclass BatchedInContextXRegBase:\n  \"\"\"Helper class for in-context regression covariate formatting.\n\n  Attributes:\n    targets: List of targets (responses) of 
the in-context regression.\n    train_lens: List of lengths of each target vector from the context.\n    test_lens: List of lengths of each forecast horizon.\n    train_dynamic_numerical_covariates: Dict of covariate names mapping to the\n      dynamic numerical covariates of each forecast task on the context. Their\n      lengths should match the corresponding lengths in `train_lens`.\n    train_dynamic_categorical_covariates: Dict of covariate names mapping to the\n      dynamic categorical covariates of each forecast task on the context. Their\n      lengths should match the corresponding lengths in `train_lens`.\n    test_dynamic_numerical_covariates: Dict of covariate names mapping to the\n      dynamic numerical covariates of each forecast task on the horizon. Their\n      lengths should match the corresponding lengths in `test_lens`.\n    test_dynamic_categorical_covariates: Dict of covariate names mapping to the\n      dynamic categorical covariates of each forecast task on the horizon. 
Their\n      lengths should match the corresponding lengths in `test_lens`.\n    static_numerical_covariates: Dict of covariate names mapping to the static\n      numerical covariates of each forecast task.\n    static_categorical_covariates: Dict of covariate names mapping to the static\n      categorical covariates of each forecast task.\n  \"\"\"\n\n  def __init__(\n      self,\n      targets: Sequence[Sequence[float]],\n      train_lens: Sequence[int],\n      test_lens: Sequence[int],\n      train_dynamic_numerical_covariates: (\n          Mapping[str, Sequence[Sequence[float]]] | None) = None,\n      train_dynamic_categorical_covariates: (\n          Mapping[str, Sequence[Sequence[Category]]] | None) = None,\n      test_dynamic_numerical_covariates: (\n          Mapping[str, Sequence[Sequence[float]]] | None) = None,\n      test_dynamic_categorical_covariates: (\n          Mapping[str, Sequence[Sequence[Category]]] | None) = None,\n      static_numerical_covariates: Mapping[str, Sequence[float]] | None = None,\n      static_categorical_covariates: (Mapping[str, Sequence[Category]] |\n                                      None) = None,\n  ) -> None:\n    \"\"\"Initializes with the exogenous covariate inputs.\n\n    Here we use model fitting language to refer to the context as 'train' and\n    the horizon as 'test'. We assume batched inputs. To properly format the\n    request:\n\n     - `train_lens` represents the contexts in the batch. Targets and all train\n     dynamic covariates should have the same lengths as the corresponding\n     elements\n     in `train_lens`. Notice each `train_len` can be different from the exact\n     length of the corresponding context depending on how much of the context is\n     used for fitting the in-context model.\n     - `test_lens` represents the horizon lengths in the batch. 
All test\n     dynamic\n     covariates should have the same lengths as the corresponding elements in\n     `test_lens`.\n     - Static covariates should be one for each input.\n     - For train and test dynamic covariates, they should have the same\n     covariate\n     names.\n\n     Pass an empty dict {} for a covariate type if it is not present.\n\n     Example:\n       Here is a set of valid inputs whose schema can be used for reference.\n       ```\n       targets = [\n           [0.0, 0.1, 0.2],\n           [0.0, 0.1, 0.2, 0.3],\n       ]  # Two inputs in this batch.\n       train_lens = [3, 4]\n       test_lens = [2, 5]  # Forecast horizons 2 and 5 respectively.\n       train_dynamic_numerical_covariates = {\n           \"cov_1_dn\": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]],\n           \"cov_2_dn\": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]],\n       }  # Each train dynamic covariate has 3 and 4 elements respectively.\n       test_dynamic_numerical_covariates = {\n           \"cov_1_dn\": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]],\n           \"cov_2_dn\": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]],\n       }  # Each test dynamic covariate has 2 and 5 elements respectively.\n       train_dynamic_categorical_covariates = {\n           \"cov_1_dc\": [[0, 1, 0], [0, 1, 2, 3]],\n           \"cov_2_dc\": [[\"good\", \"bad\", \"good\"], [\"good\", \"good\", \"bad\",\n           \"bad\"]],\n       }\n       test_dynamic_categorical_covariates = {\n           \"cov_1_dc\": [[1, 0], [1, 0, 2, 3, 1]],\n           \"cov_2_dc\": [[\"bad\", \"good\"], [\"bad\", \"bad\", \"bad\", \"bad\", \"bad\"]],\n       }\n       static_numerical_covariates = {\n           \"cov_1_sn\": [0.0, 3.0],\n           \"cov_2_sn\": [2.0, 1.0],\n           \"cov_3_sn\": [1.0, 2.0],\n       }  # Each static covariate has 1 element for each input.\n       static_categorical_covariates = {\n           \"cov_1_sc\": [\"apple\", \"orange\"],\n           \"cov_2_sc\": [2, 3],\n       }\n       ```\n\n    
Args:\n      targets: List of targets (responses) of the in-context regression.\n      train_lens: List of lengths of each target vector from the context.\n      test_lens: List of lengths of each forecast horizon.\n      train_dynamic_numerical_covariates: Dict of covariate names mapping to the\n        dynamic numerical covariates of each forecast task on the context. Their\n        lengths should match the corresponding lengths in `train_lens`.\n      train_dynamic_categorical_covariates: Dict of covariate names mapping to\n        the dynamic categorical covariates of each forecast task on the context.\n        Their lengths should match the corresponding lengths in `train_lens`.\n      test_dynamic_numerical_covariates: Dict of covariate names mapping to the\n        dynamic numerical covariates of each forecast task on the horizon. Their\n        lengths should match the corresponding lengths in `test_lens`.\n      test_dynamic_categorical_covariates: Dict of covariate names mapping to\n        the dynamic categorical covariates of each forecast task on the horizon.\n        Their lengths should match the corresponding lengths in `test_lens`.\n      static_numerical_covariates: Dict of covariate names mapping to the static\n        numerical covariates of each forecast task.\n      static_categorical_covariates: Dict of covariate names mapping to the\n        static categorical covariates of each forecast task.\n    \"\"\"\n    self.targets = targets\n    self.train_lens = train_lens\n    self.test_lens = test_lens\n    self.train_dynamic_numerical_covariates = (\n        train_dynamic_numerical_covariates or {})\n    self.train_dynamic_categorical_covariates = (\n        train_dynamic_categorical_covariates or {})\n    self.test_dynamic_numerical_covariates = (test_dynamic_numerical_covariates\n                                              or {})\n    self.test_dynamic_categorical_covariates = (\n        test_dynamic_categorical_covariates or {})\n    
self.static_numerical_covariates = static_numerical_covariates or {}\n    self.static_categorical_covariates = static_categorical_covariates or {}\n\n  def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None:\n    \"\"\"Verifies the validity of the covariate inputs.\"\"\"\n\n    # Check presence.\n    if (self.train_dynamic_numerical_covariates and\n        not self.test_dynamic_numerical_covariates) or (\n            not self.train_dynamic_numerical_covariates and\n            self.test_dynamic_numerical_covariates):\n      raise ValueError(\n          \"train_dynamic_numerical_covariates and\"\n          \" test_dynamic_numerical_covariates must be both present or both\"\n          \" absent.\")\n\n    if (self.train_dynamic_categorical_covariates and\n        not self.test_dynamic_categorical_covariates) or (\n            not self.train_dynamic_categorical_covariates and\n            self.test_dynamic_categorical_covariates):\n      raise ValueError(\n          \"train_dynamic_categorical_covariates and\"\n          \" test_dynamic_categorical_covariates must be both present or both\"\n          \" absent.\")\n\n    # Check keys.\n    for dict_a, dict_b, dict_a_name, dict_b_name in (\n        (\n            self.train_dynamic_numerical_covariates,\n            self.test_dynamic_numerical_covariates,\n            \"train_dynamic_numerical_covariates\",\n            \"test_dynamic_numerical_covariates\",\n        ),\n        (\n            self.train_dynamic_categorical_covariates,\n            self.test_dynamic_categorical_covariates,\n            \"train_dynamic_categorical_covariates\",\n            \"test_dynamic_categorical_covariates\",\n        ),\n    ):\n      if w := set(dict_a.keys()) - set(dict_b.keys()):\n        raise ValueError(\n            f\"{dict_a_name} has keys not present in {dict_b_name}: {w}\")\n      if w := set(dict_b.keys()) - set(dict_a.keys()):\n        raise ValueError(\n            f\"{dict_b_name} has keys not 
present in {dict_a_name}: {w}\")\n\n    # Check shapes.\n    if assert_covariate_shapes:\n      if len(self.targets) != len(self.train_lens):\n        raise ValueError(\n            \"targets and train_lens must have the same number of elements.\")\n\n      if len(self.train_lens) != len(self.test_lens):\n        raise ValueError(\n            \"train_lens and test_lens must have the same number of elements.\")\n\n      for i, (target, train_len) in enumerate(zip(self.targets,\n                                                  self.train_lens)):\n        if len(target) != train_len:\n          raise ValueError(\n              f\"targets[{i}] has length {len(target)} != expected {train_len}.\")\n\n      for key, values in self.static_numerical_covariates.items():\n        if len(values) != len(self.train_lens):\n          raise ValueError(\n              f\"static_numerical_covariates has key {key} with number of\"\n              f\" examples {len(values)} != expected {len(self.train_lens)}.\")\n\n      for key, values in self.static_categorical_covariates.items():\n        if len(values) != len(self.train_lens):\n          raise ValueError(\n              f\"static_categorical_covariates has key {key} with number of\"\n              f\" examples {len(values)} != expected {len(self.train_lens)}.\")\n\n      for lens, dict_cov, dict_cov_name in (\n          (\n              self.train_lens,\n              self.train_dynamic_numerical_covariates,\n              \"train_dynamic_numerical_covariates\",\n          ),\n          (\n              self.train_lens,\n              self.train_dynamic_categorical_covariates,\n              \"train_dynamic_categorical_covariates\",\n          ),\n          (\n              self.test_lens,\n              self.test_dynamic_numerical_covariates,\n              \"test_dynamic_numerical_covariates\",\n          ),\n          (\n              self.test_lens,\n              self.test_dynamic_categorical_covariates,\n              
\"test_dynamic_categorical_covariates\",\n          ),\n      ):\n        for key, cov_values in dict_cov.items():\n          if len(cov_values) != len(lens):\n            raise ValueError(\n                f\"{dict_cov_name} has key {key} with number of examples\"\n                f\" {len(cov_values)} != expected {len(lens)}.\")\n          for i, cov_value in enumerate(cov_values):\n            if len(cov_value) != lens[i]:\n              raise ValueError(\n                  f\"{dict_cov_name} has key {key} with its {i}-th example\"\n                  f\" length {len(cov_value)} != expected {lens[i]}.\")\n\n  def create_covariate_matrix(\n      self,\n      one_hot_encoder_drop: str | None = \"first\",\n      use_intercept: bool = True,\n      assert_covariates: bool = False,\n      assert_covariate_shapes: bool = False,\n  ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n    \"\"\"Creates target vector and covariate matrices for in context regression.\n\n    Here we use model fitting language to refer to the context as 'train' and\n    the horizon as 'test'.\n\n    Args:\n      one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.\n      use_intercept: Whether to prepare an intercept (all 1) column in the\n        matrices.\n      assert_covariates: Whether to assert the validity of the covariate inputs.\n      assert_covariate_shapes: Whether to assert the shapes of the covariate\n        inputs when `assert_covariates` is True.\n\n    Returns:\n      A tuple of the target vector, the covariate matrix for the context, and\n      the covariate matrix for the horizon.\n    \"\"\"\n    if assert_covariates:\n      self._assert_covariates(assert_covariate_shapes)\n\n    x_train, x_test = [], []\n\n    # Numerical features.\n    for name in sorted(self.train_dynamic_numerical_covariates):\n      x_train.append(\n          _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis])\n      x_test.append(\n          
_unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis])\n\n    for covs in self.static_numerical_covariates.values():\n      x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis])\n      x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis])\n\n    if x_train:\n      x_train = np.concatenate(x_train, axis=1)\n      x_test = np.concatenate(x_test, axis=1)\n\n      # Normalize for robustness.\n      x_mean = np.mean(x_train, axis=0, keepdims=True)\n      x_std = np.where((w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w,\n                       1.0)\n      x_train = [(x_train - x_mean) / x_std]\n      x_test = [(x_test - x_mean) / x_std]\n\n    # Categorical features. Encode one by one.\n    one_hot_encoder = preprocessing.OneHotEncoder(\n        drop=one_hot_encoder_drop,\n        sparse_output=False,\n        handle_unknown=\"ignore\",\n    )\n    for name in sorted(self.train_dynamic_categorical_covariates.keys()):\n      ohe_train = _unnest(\n          self.train_dynamic_categorical_covariates[name])[:, np.newaxis]\n      ohe_test = _unnest(\n          self.test_dynamic_categorical_covariates[name])[:, np.newaxis]\n      x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train)))\n      x_test.append(np.array(one_hot_encoder.transform(ohe_test)))\n\n    for covs in self.static_categorical_covariates.values():\n      ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis])\n      x_train.append(_repeat(ohe, self.train_lens))\n      x_test.append(_repeat(ohe, self.test_lens))\n\n    x_train = np.concatenate(x_train, axis=1)\n    x_test = np.concatenate(x_test, axis=1)\n\n    if use_intercept:\n      x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0)\n      x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0)\n\n    return _unnest(self.targets), x_train, x_test\n\n  def fit(self) -> Any:\n    raise NotImplementedError(\"Fit is not implemented.\")\n\n\nclass 
BatchedInContextXRegLinear(BatchedInContextXRegBase):\n  \"\"\"Linear in-context regression model.\"\"\"\n\n  def fit(\n      self,\n      ridge: float = 0.0,\n      one_hot_encoder_drop: str | None = \"first\",\n      use_intercept: bool = True,\n      force_on_cpu: bool = False,\n      max_rows_per_col: int = 0,\n      max_rows_per_col_sample_seed: int = 42,\n      debug_info: bool = False,\n      assert_covariates: bool = False,\n      assert_covariate_shapes: bool = False,\n  ) -> (list[np.ndarray] | tuple[list[np.ndarray], list[np.ndarray], jax.Array,\n                                 jax.Array, jax.Array]):\n    \"\"\"Fits a linear model for in-context regression.\n\n    Args:\n      ridge: A non-negative value for specifying the ridge regression penalty.\n        If 0 is provided, fallback to ordinary least squares. Note this penalty\n        is added to the normalized covariate matrix.\n      one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.\n      use_intercept: Whether to prepare an intercept (all 1) column in the\n        matrices.\n      force_on_cpu: Whether to force execution on cpu for accelerator machines.\n      max_rows_per_col: How many rows to subsample per column. 0 for no\n        subsampling. 
This is for speeding up model fitting.\n      max_rows_per_col_sample_seed: The seed for the subsampling if needed by\n        `max_rows_per_col`.\n      debug_info: Whether to return debug info.\n      assert_covariates: Whether to assert the validity of the covariate inputs.\n      assert_covariate_shapes: Whether to assert the shapes of the covariate\n        inputs when `assert_covariates` is True.\n\n    Returns:\n      If `debug_info` is False:\n        The linear fits on the horizon.\n      If `debug_info` is True:\n        A tuple of:\n        - the linear fits on the horizon,\n        - the linear fits on the context,\n        - the flattened target vector,\n        - the covariate matrix for the context, and\n        - the covariate matrix for the horizon.\n    \"\"\"\n    flat_targets, x_train_raw, x_test = self.create_covariate_matrix(\n        one_hot_encoder_drop=one_hot_encoder_drop,\n        use_intercept=use_intercept,\n        assert_covariates=assert_covariates,\n        assert_covariate_shapes=assert_covariate_shapes,\n    )\n\n    x_train = x_train_raw.copy()\n    if max_rows_per_col:\n      nrows, ncols = x_train.shape\n      if nrows > (w := ncols * max_rows_per_col):\n        subsample = jax.random.choice(\n            jax.random.PRNGKey(max_rows_per_col_sample_seed),\n            nrows,\n            (w,),\n            replace=False,\n        )\n        x_train = x_train[subsample]\n        flat_targets = flat_targets[subsample]\n\n    device = jax.devices(\"cpu\")[0] if force_on_cpu else None\n    # Runs jitted version of the solvers which are quicker at the cost of\n    # running jitting during the first time calling. Re-jitting happens whenever\n    # new (padded) shapes are encountered.\n    # Occasionally it helps with the speed and the accuracy if we force single\n    # thread execution on cpu for accelerator machines:\n    # 1. Avoid moving data to accelerator memory.\n    # 2. 
Avoid precision loss if any.\n    with jax.default_device(device):\n      x_train_raw = _to_padded_jax_array(x_train_raw)\n      x_train = _to_padded_jax_array(x_train)\n      flat_targets = _to_padded_jax_array(flat_targets)\n      x_test = _to_padded_jax_array(x_test)\n      beta_hat = (jnp.linalg.pinv(\n          x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]),\n          hermitian=True,\n      ) @ x_train.T @ flat_targets)\n      y_hat = x_test @ beta_hat\n      y_hat_context = x_train_raw @ beta_hat if debug_info else None\n\n    outputs = []\n    outputs_context = []\n\n    # Reconstruct the ragged 2-dim batched forecasts from flattened linear fits.\n    train_index, test_index = 0, 0\n    for train_index_delta, test_index_delta in zip(self.train_lens,\n                                                   self.test_lens):\n      outputs.append(np.array(y_hat[test_index:(test_index +\n                                                test_index_delta)]))\n      if debug_info:\n        outputs_context.append(\n            np.array(y_hat_context[train_index:(train_index +\n                                                train_index_delta)]))\n      train_index += train_index_delta\n      test_index += test_index_delta\n\n    if debug_info:\n      return outputs, outputs_context, flat_targets, x_train, x_test\n    else:\n      return outputs\n"
  },
  {
    "path": "v1/tests/test_timesfm.py",
    "content": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom datetime import datetime, timedelta\n\nimport numpy as np\nimport pandas as pd\nimport pytest\n\nimport timesfm\n\n\ndef create_sample_dataframe(\n    start_date: datetime, end_date: datetime, freq: str = \"D\"\n) -> pd.DataFrame:\n    \"\"\"\n    Create a sample DataFrame with time series data.\n\n    Args:\n        start_date (datetime): Start date of the time series.\n        end_date (datetime): End date of the time series.\n        freq (str): Frequency of the time series (default: \"D\" for daily).\n\n    Returns:\n        pd.DataFrame: DataFrame with columns 'unique_id', 'ds', and 'ts'.\n    \"\"\"\n    date_range = pd.date_range(start=start_date, end=end_date, freq=freq)\n    ts_data = np.random.randn(len(date_range))\n    df = pd.DataFrame({\"unique_id\": \"ts-1\", \"ds\": date_range, \"ts\": ts_data})\n    return df\n\n\n@pytest.mark.parametrize(\"context_length\", [128, 256, 512])\n@pytest.mark.parametrize(\"prediction_length\", [96, 128, 256])\n@pytest.mark.parametrize(\"freq\", [\"D\", \"H\", \"W\"])\ndef test_timesfm_forecast_on_df(\n    context_length: int,\n    prediction_length: int,\n    freq: str,\n) -> None:\n    model = timesfm.TimesFm(\n        context_len=context_length,\n        horizon_len=prediction_length,\n        input_patch_len=32,\n        output_patch_len=128,\n        num_layers=20,\n        model_dims=1280,\n  
      backend=\"cpu\",\n    )\n    model.load_from_checkpoint(repo_id=\"google/timesfm-1.0-200m\")\n\n    end_date = datetime.now()\n    start_date = end_date - timedelta(days=context_length)\n    input_df = create_sample_dataframe(start_date, end_date, freq)\n\n    forecast_df = model.forecast_on_df(\n        inputs=input_df,\n        freq=freq,\n        value_name=\"ts\",\n        num_jobs=-1,\n    )\n\n    assert (\n        len(forecast_df) == prediction_length\n    ), f\"Expected forecast length of {prediction_length}, but got {len(forecast_df)}\"\n    assert (\n        \"timesfm\" in forecast_df.columns\n    ), \"Forecast DataFrame should contain 'timesfm' column\"\n\n    last_input_date = input_df[\"ds\"].max()\n    first_forecast_date = forecast_df[\"ds\"].min()\n    expected_first_forecast_date = last_input_date + pd.Timedelta(1, unit=freq)\n    assert (\n        first_forecast_date == expected_first_forecast_date\n    ), f\"Forecast should start from {expected_first_forecast_date}, but starts from {first_forecast_date}\"\n\n    print(\n        f\"Successful forecast with context_length={context_length}, prediction_length={prediction_length}, freq={freq}\"\n    )\n"
  }
]