Full Code of google-research/timesfm for AI

master ed29ec5ef7e2 cached
85 files
917.3 KB
280.2k tokens
376 symbols
1 requests
Download .txt
Showing preview only (955K chars total). Download the full file or copy to clipboard to get everything.
Repository: google-research/timesfm
Branch: master
Commit: ed29ec5ef7e2
Files: 85
Total size: 917.3 KB

Directory structure:
gitextract_oox45t84/

├── .gitattributes
├── .github/
│   └── workflows/
│       ├── main.yml
│       └── manual_publish.yml
├── .gitignore
├── AGENTS.md
├── LICENSE
├── README.md
├── pyproject.toml
├── requirements.txt
├── src/
│   └── timesfm/
│       ├── __init__.py
│       ├── configs.py
│       ├── flax/
│       │   ├── __init__.py
│       │   ├── dense.py
│       │   ├── normalization.py
│       │   ├── transformer.py
│       │   └── util.py
│       ├── timesfm_2p5/
│       │   ├── timesfm_2p5_base.py
│       │   ├── timesfm_2p5_flax.py
│       │   └── timesfm_2p5_torch.py
│       ├── torch/
│       │   ├── __init__.py
│       │   ├── dense.py
│       │   ├── normalization.py
│       │   ├── transformer.py
│       │   └── util.py
│       └── utils/
│           └── xreg_lib.py
├── timesfm-forecasting/
│   ├── SKILL.md
│   ├── examples/
│   │   ├── anomaly-detection/
│   │   │   ├── detect_anomalies.py
│   │   │   └── output/
│   │   │       └── anomaly_detection.json
│   │   ├── covariates-forecasting/
│   │   │   ├── demo_covariates.py
│   │   │   └── output/
│   │   │       ├── covariates_metadata.json
│   │   │       └── sales_with_covariates.csv
│   │   └── global-temperature/
│   │       ├── README.md
│   │       ├── generate_animation_data.py
│   │       ├── generate_gif.py
│   │       ├── generate_html.py
│   │       ├── output/
│   │       │   ├── animation_data.json
│   │       │   ├── forecast_output.csv
│   │       │   ├── forecast_output.json
│   │       │   └── interactive_forecast.html
│   │       ├── run_example.sh
│   │       ├── run_forecast.py
│   │       ├── temperature_anomaly.csv
│   │       └── visualize_forecast.py
│   ├── references/
│   │   ├── api_reference.md
│   │   ├── data_preparation.md
│   │   └── system_requirements.md
│   └── scripts/
│       ├── check_system.py
│       └── forecast_csv.py
└── v1/
    ├── LICENSE
    ├── README.md
    ├── TROUBLESHOOTING.md
    ├── docs/
    │   └── contributing.md
    ├── experiments/
    │   ├── baselines/
    │   │   ├── __init__.py
    │   │   └── timegpt_pipeline.py
    │   ├── extended_benchmarks/
    │   │   ├── README.md
    │   │   ├── run_timegpt.py
    │   │   ├── run_timesfm.py
    │   │   └── utils.py
    │   └── long_horizon_benchmarks/
    │       ├── README.md
    │       └── run_eval.py
    ├── notebooks/
    │   ├── covariates.ipynb
    │   ├── finetuning.ipynb
    │   └── finetuning_torch.ipynb
    ├── peft/
    │   ├── README.md
    │   ├── finetune.py
    │   ├── finetune.sh
    │   └── usage.ipynb
    ├── pyproject.toml
    ├── src/
    │   ├── adapter/
    │   │   ├── __init__.py
    │   │   ├── dora_layers.py
    │   │   ├── lora_layers.py
    │   │   └── utils.py
    │   ├── finetuning/
    │   │   ├── __init__.py
    │   │   ├── finetuning_example.py
    │   │   └── finetuning_torch.py
    │   └── timesfm/
    │       ├── __init__.py
    │       ├── data_loader.py
    │       ├── patched_decoder.py
    │       ├── pytorch_patched_decoder.py
    │       ├── time_features.py
    │       ├── timesfm_base.py
    │       ├── timesfm_jax.py
    │       ├── timesfm_torch.py
    │       └── xreg_lib.py
    └── tests/
        └── test_timesfm.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitattributes
================================================
# Git LFS tracking for binary outputs in timesfm-forecasting skill
timesfm-forecasting/**/*.png filter=lfs diff=lfs merge=lfs -text
timesfm-forecasting/**/*.gif filter=lfs diff=lfs merge=lfs -text


================================================
FILE: .github/workflows/main.yml
================================================
name: Python package build

on:
  push:
    branches: [ "master" ]
  pull_request:
    branches: [ "master" ]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.11'
      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          echo "$HOME/.cargo/bin" >> $GITHUB_PATH
      - name: Create virtual environment
        run: uv venv
      - name: Install build dependencies
        run: |
          uv pip install build ".[torch,flax]"
      - name: Build package
        run: uv run python -m build

================================================
FILE: .github/workflows/manual_publish.yml
================================================
name: Manual PyPI Publish

on:
  workflow_dispatch:

jobs:
  build-and-publish:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.11'
      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          echo "$HOME/.cargo/bin" >> $GITHUB_PATH
      - name: Create virtual environment
        run: uv venv
      - name: Install build dependencies
        run: uv pip install build twine
      - name: Build package
        run: uv run python -m build
      - name: Publish to PyPI
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: uv run twine upload dist/*

================================================
FILE: .gitignore
================================================
.venv/
dist/
__pycache__/
checkpoints/
wandb/
datasets/
results/
timesfm_jax.egg-info/
development_setup.md


================================================
FILE: AGENTS.md
================================================
# TimesFM — Agent Entry Point

This repository ships a first-party **Agent Skill** for TimesFM at:

```
timesfm-forecasting/
└── SKILL.md    ← read this for the full skill
```

## Install the skill

Copy the skill directory into your agent's skills folder:

```bash
# Cursor / Claude Code / OpenCode / Codex (global install)
cp -r timesfm-forecasting/ ~/.cursor/skills/
cp -r timesfm-forecasting/ ~/.claude/skills/

# Or project-level
cp -r timesfm-forecasting/ .cursor/skills/
```

Any agent that supports the open [Agent Skills standard](https://agentskills.io) will discover it automatically.

## Working in this repo

If you are developing TimesFM itself (not using it), the source lives in `src/timesfm/`.
Archived v1/v2 code and notebooks are in `v1/`.

Run tests:

```bash
pytest v1/tests/
```

See `README.md` for full developer setup.


================================================
FILE: LICENSE
================================================

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# TimesFM

TimesFM (Time Series Foundation Model) is a pretrained time-series foundation
model developed by Google Research for time-series forecasting.

*   Paper:
    [A decoder-only foundation model for time-series forecasting](https://arxiv.org/abs/2310.10688),
    ICML 2024.
*   All checkpoints:
    [TimesFM Hugging Face Collection](https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6).
*   [Google Research blog](https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/).
*   [TimesFM in BigQuery](https://cloud.google.com/bigquery/docs/timesfm-model):
    an official Google product.

This open version is not an officially supported Google product.

**Latest Model Version:** TimesFM 2.5

**Archived Model Versions:**

-   1.0 and 2.0: relevant code archived in the sub directory `v1`. You can `pip
    install timesfm==1.3.0` to install an older version of this package to load
    them.

## Update - Oct. 29, 2025

Added back the covariate support through XReg for TimesFM 2.5.


## Update - Sept. 15, 2025

TimesFM 2.5 is out!

Comparing to TimesFM 2.0, this new 2.5 model:

-   uses 200M parameters, down from 500M.
-   supports up to 16k context length, up from 2048.
-   supports continuous quantile forecast up to 1k horizon via an optional 30M
    quantile head.
-   gets rid of the `frequency` indicator.
-   has a couple of new forecasting flags.

Along with the model upgrade we have also upgraded the inference API. This repo
will be under construction over the next few weeks to

1.  add support for an upcoming Flax version of the model (faster inference).
2.  add back covariate support.
3.  populate more docstrings, docs and notebook.

### Install

1.  Clone the repository:
    ```shell
    git clone https://github.com/google-research/timesfm.git
    cd timesfm
    ```

2.  Create a virtual environment and install dependencies using `uv`:
    ```shell
    # Create a virtual environment
    uv venv
    
    # Activate the environment
    source .venv/bin/activate
    
    # Install the package in editable mode with torch
    uv pip install -e .[torch]
    # Or with flax
    uv pip install -e .[flax]
    # Or XReg is needed
    uv pip install -e .[xreg]
    ```

3. [Optional] Install your preferred `torch` / `jax` backend based on your OS and accelerators
(CPU, GPU, TPU or Apple Silicon).:

-   [Install PyTorch](https://pytorch.org/get-started/locally/).
-   [Install Jax](https://docs.jax.dev/en/latest/installation.html#installation)
    for Flax.

### Code Example

```python
import torch
import numpy as np
import timesfm

torch.set_float32_matmul_precision("high")

model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch")

model.compile(
    timesfm.ForecastConfig(
        max_context=1024,
        max_horizon=256,
        normalize_inputs=True,
        use_continuous_quantile_head=True,
        force_flip_invariance=True,
        infer_is_positive=True,
        fix_quantile_crossing=True,
    )
)
point_forecast, quantile_forecast = model.forecast(
    horizon=12,
    inputs=[
        np.linspace(0, 1, 100),
        np.sin(np.linspace(0, 20, 67)),
    ],  # Two dummy inputs
)
point_forecast.shape  # (2, 12)
quantile_forecast.shape  # (2, 12, 10): mean, then 10th to 90th quantiles.
```


================================================
FILE: pyproject.toml
================================================
[project]
name = "timesfm"
version = "2.0.0"
description = "A time series foundation model."
authors = [
    {name = "Rajat Sen", email = "senrajat@google.com"},
    {name = "Yichen Zhou", email = "yichenzhou@google.com"},
    {name = "Abhimanyu Das", email = "abhidas@google.com"},
    {name = "Petros Mol", email = "pmol@google.com"},
    {name = "Michael Chertushkin", email = "chertushkinmichael@gmail.com"},
]
license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "numpy>=1.26.4",
    "huggingface_hub[cli]>=0.23.0",
    "safetensors>=0.5.3",
]

[project.optional-dependencies]
torch = [
    "torch>=2.0.0",
]
flax = [
    "flax",
    "optax",
    "einshape",
    "orbax-checkpoint",
    "jaxtyping",
    "jax[cuda]"
]
xreg = [
    "jax[cuda]",
    "scikit-learn",
]

[tool.ruff]
line-length = 88
indent-width = 2

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"



================================================
FILE: requirements.txt
================================================
# This file was autogenerated by uv via the following command:
#    uv pip compile pyproject.toml -o requirements.txt
anyio==4.11.0
    # via httpx
certifi==2025.10.5
    # via
    #   httpcore
    #   httpx
click==8.3.0
    # via typer-slim
filelock==3.19.1
    # via huggingface-hub
fsspec==2025.9.0
    # via huggingface-hub
h11==0.16.0
    # via httpcore
hf-xet==1.2.0
    # via huggingface-hub
httpcore==1.0.9
    # via httpx
httpx==0.28.1
    # via huggingface-hub
huggingface-hub==1.0.1
    # via timesfm (pyproject.toml)
idna==3.10
    # via
    #   anyio
    #   httpx
numpy==2.2.6
    # via timesfm (pyproject.toml)
packaging==25.0
    # via huggingface-hub
pyyaml==6.0.3
    # via huggingface-hub
safetensors==0.6.2
    # via timesfm (pyproject.toml)
shellingham==1.5.4
    # via huggingface-hub
sniffio==1.3.1
    # via anyio
tqdm==4.67.1
    # via huggingface-hub
typer-slim==0.20.0
    # via huggingface-hub
typing-extensions==4.15.0
    # via
    #   anyio
    #   huggingface-hub
    #   typer-slim


================================================
FILE: src/timesfm/__init__.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TimesFM API."""

from .configs import ForecastConfig

try:
  from .timesfm_2p5 import timesfm_2p5_torch
  TimesFM_2p5_200M_torch = timesfm_2p5_torch.TimesFM_2p5_200M_torch
except ImportError:
  pass

try:
  from .timesfm_2p5 import timesfm_2p5_flax
  TimesFM_2p5_200M_flax = timesfm_2p5_flax.TimesFM_2p5_200M_flax
except ImportError:
  pass


================================================
FILE: src/timesfm/configs.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Abstract configs for TimesFM layers."""

import dataclasses
from typing import Literal


@dataclasses.dataclass(frozen=True)
class ForecastConfig:
  """Options for forecasting.

  Attributes:
    max_context: The maximum context length. This is used by the complied decode
      function at inference time during batched inference. Any input time series
      with length less than max_context will be padded with zeros, and with
      length greater than max_context will be truncated.
    max_horizon: The maximum horizon length. This is used by the complied decode
      function at inference time during batched inference. The compiled cached
      decoding function will by default forecast till max_horizon.
    normalize_inputs: Whether to normalize the inputs. This is useful when the
      raw inputs are of extremely large or small magnitudes which may result in
      numerical issues.
    window_size: The window size for decomposed forecasting.
      TODO(siriuz42):implement it.
    per_core_batch_size: The batch size per core. Used at inference time during
      batched inference when multiple GPU / TPU devices are used.
    use_continuous_quantile_head: Whether to use a separate continuous quantile
      head to avoid quantile collapsing.
    force_flip_invariance: Whether to force flip invariance. TimesFM guarantees
      that TimesFM(aX + b) = a * TimesFM(x) + b for a >= 0 by default. This flag
      extends it to a < 0 as well.
    infer_is_positive: Whether to guarantee nonnegativity of the output if the
      input is nonnegative.
    fix_quantile_crossing: Whether to fix quantile crossing.
    return_backcast: Whether to return backcast.
  """

  max_context: int = 0
  max_horizon: int = 0
  normalize_inputs: bool = False
  window_size: int = 0
  per_core_batch_size: int = 1
  use_continuous_quantile_head: bool = False
  force_flip_invariance: bool = True
  infer_is_positive: bool = True
  fix_quantile_crossing: bool = False
  return_backcast: bool = False


@dataclasses.dataclass(frozen=True)
class ResidualBlockConfig:
  """Framework-agnostic config for a residual block."""

  input_dims: int
  hidden_dims: int
  output_dims: int
  use_bias: bool
  activation: Literal["relu", "swish", "none"]


@dataclasses.dataclass(frozen=True)
class RandomFourierFeaturesConfig:
  """Framework-agnostic config for random fourier features."""

  input_dims: int
  output_dims: int
  projection_stddev: float
  use_bias: bool


@dataclasses.dataclass(frozen=True)
class TransformerConfig:
  """Framework-agnostic config for a transformer."""

  model_dims: int
  hidden_dims: int
  num_heads: int
  attention_norm: Literal["rms"]
  feedforward_norm: Literal["rms"]
  qk_norm: Literal["rms", "none"]
  use_bias: bool
  use_rotary_position_embeddings: bool
  ff_activation: Literal["relu", "swish", "none"]
  fuse_qkv: bool


@dataclasses.dataclass(frozen=True)
class StackedTransformersConfig:
  """Framework-agnostic config for a stacked transformers."""

  num_layers: int
  transformer: TransformerConfig


================================================
FILE: src/timesfm/flax/__init__.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


================================================
FILE: src/timesfm/flax/dense.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dense layers for TimesFM."""

from flax import nnx
import jax
import jax.numpy as jnp
import jaxtyping

from .. import configs

Array = jaxtyping.Array
Bool = jaxtyping.Bool
Float = jaxtyping.Float
Integer = jaxtyping.Integer
Num = jaxtyping.Num

ResidualBlockConfig = configs.ResidualBlockConfig
RandomFourierFeaturesConfig = configs.RandomFourierFeaturesConfig


class ResidualBlock(nnx.Module):
  """Residual block with two linear layers and a linear residual connection."""

  def __init__(self, config: ResidualBlockConfig, *, rngs=nnx.Rngs(42)):
    self.config = config
    self.hidden_layer = nnx.Linear(
      in_features=config.input_dims,
      out_features=config.hidden_dims,
      use_bias=config.use_bias,
      rngs=rngs,
    )
    self.output_layer = nnx.Linear(
      in_features=config.hidden_dims,
      out_features=config.output_dims,
      use_bias=config.use_bias,
      rngs=rngs,
    )
    self.residual_layer = nnx.Linear(
      in_features=config.input_dims,
      out_features=config.output_dims,
      use_bias=config.use_bias,
      rngs=rngs,
    )
    if config.activation == "relu":
      self.activation = jax.nn.relu
    elif config.activation == "swish":
      self.activation = jax.nn.swish
    elif config.activation == "none":
      self.activation = lambda x: x
    else:
      raise ValueError(f"Activation: {config.activation} not supported.")

  def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... o"]:
    return self.output_layer(
      self.activation(self.hidden_layer(x))
    ) + self.residual_layer(x)


class RandomFourierFeatures(nnx.Module):
  """Random Fourier features layer."""

  __data__ = ("phrase_shifts",)

  def __init__(self, config: RandomFourierFeaturesConfig, *, rngs=nnx.Rngs(42)):
    self.config = config

    if config.output_dims % 4 != 0:
      raise ValueError(
        f"Output dims must be a multiple of 4: {config.output_dims} % 4 != 0."
      )
    num_projected_features = config.output_dims // 4

    self.phase_shifts = nnx.Param(jnp.zeros(shape=(2, num_projected_features)))
    self.projection_layer = nnx.Linear(
      in_features=config.input_dims,
      out_features=num_projected_features,
      use_bias=config.use_bias,
      rngs=rngs,
    )
    self.residual_layer = nnx.Linear(
      in_features=config.input_dims,
      out_features=config.output_dims,
      use_bias=config.use_bias,
      rngs=rngs,
    )

  def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... o"]:
    projected = self.projection_layer(x)
    cos_features = jnp.cos(projected)
    sin_features = jnp.sin(projected)
    sq_wave_1 = jnp.sign(jnp.sin(projected + self.phase_shifts[0, :]))
    sq_wave_2 = jnp.sign(jnp.sin(projected + self.phase_shifts[1, :]))
    fourier_features = jnp.concatenate(
      [cos_features, sin_features, sq_wave_1, sq_wave_2], axis=-1
    )
    residual = self.residual_layer(x)
    return fourier_features + residual


================================================
FILE: src/timesfm/flax/normalization.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Normalization layers for TimesFM."""

from flax import nnx
import jax
import jax.numpy as jnp
import jaxtyping

Array = jaxtyping.Array
Bool = jaxtyping.Bool
Float = jaxtyping.Float
Integer = jaxtyping.Integer
Num = jaxtyping.Num


class RMSNorm(nnx.Module):
  """RMS normalization."""

  __data__ = ("scale",)

  def __init__(
    self,
    num_features: int,
    *,
    epsilon: float = 1e-6,
    rngs=nnx.Rngs(42),
  ):
    del rngs
    self.scale = nnx.Param(jnp.zeros(shape=(num_features,)))
    self.num_features = num_features
    self.epsilon = epsilon

  def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
    var = jnp.mean(jnp.square(inputs), axis=-1, keepdims=True)
    normed_inputs = inputs * jax.lax.rsqrt(var + self.epsilon)
    normed_inputs *= self.scale
    return normed_inputs


class LayerNorm(nnx.Module):
  """Layer normalization replica of  LayerNorm."""

  __data__ = ("scale", "bias")

  def __init__(self, num_features: int, *, epsilon: float = 1e-6, rngs=nnx.Rngs(42)):
    del rngs
    self.scale = nnx.Param(jnp.ones(shape=(num_features,)))
    self.bias = nnx.Param(jnp.zeros(shape=(num_features,)))
    self.num_features = num_features
    self.epsilon = epsilon

  def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
    mean = jnp.mean(inputs, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True)
    normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
    normed_inputs *= self.scale
    normed_inputs += self.bias
    return normed_inputs


================================================
FILE: src/timesfm/flax/transformer.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer layers for TimesFM."""

import functools
from typing import Callable

from flax import nnx
from flax.nnx.nn import linear
import jax
from jax import lax
import jax.numpy as jnp
import jaxtyping

from .. import configs
from . import normalization, util

Array = jaxtyping.Array
Bool = jaxtyping.Bool
Float = jaxtyping.Float
Integer = jaxtyping.Integer
Num = jaxtyping.Num
LayerNorm = normalization.LayerNorm
RMSNorm = normalization.RMSNorm
LinearGeneral = linear.LinearGeneral
TransformerConfig = configs.TransformerConfig
DecodeCache = util.DecodeCache


@functools.partial(
  jax.jit,
  static_argnames=("query_length", "kv_length"),
)
def make_attn_mask(
  query_length: int,
  num_all_masked_kv: Integer[Array, "b"],
  query_index_offset: Integer[Array, "b"] | None = None,
  kv_length: int = 0,
) -> Bool[Array, "b 1 q n"]:
  """Makes attention mask."""

  if kv_length == 0:
    kv_length = query_length

  q_index = jnp.arange(query_length)[None, None, :, None]
  if query_index_offset is not None:
    q_index += query_index_offset[:, None, None, None]
  kv_index = jnp.arange(kv_length)[None, None, None, :]
  return jnp.logical_and(
    q_index >= kv_index,
    kv_index >= num_all_masked_kv[:, None, None, None],
  )


class RotaryPositionalEmbedding(nnx.Module):
  """Rotary positional embedding."""

  def __init__(
    self,
    embedding_dims: int,
    min_timescale: int = 1,
    max_timescale: int = 10000,
  ):
    self.embedding_dims = embedding_dims
    self.min_timescale = min_timescale
    self.max_timescale = max_timescale

  def __call__(
    self,
    inputs: Float[Array, "b ... d"],
    position: Array | None = None,
  ):
    """Generates a JTensor of sinusoids with different frequencies."""
    if self.embedding_dims != inputs.shape[-1]:
      raise ValueError(
        "The embedding dims of the rotary position embedding"
        "must match the hidden dimension of the inputs."
      )
    half_embedding_dim = self.embedding_dims // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
    timescale = (
      self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
    )
    if position is None:
      seq_length = inputs.shape[1]
      position = jnp.arange(seq_length, dtype=jnp.float32)[None, :]
    if len(inputs.shape) == 4:
      position = position[..., None, None]
      timescale = timescale[None, None, None, :]
    elif len(inputs.shape) == 3:
      position = position[..., None]
      timescale = timescale[None, None, :]
    else:
      raise ValueError("Inputs must be of rank 3 or 4.")
    sinusoid_inp = position / timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    first_part = first_part.astype(None)
    second_part = second_part.astype(None)
    return jnp.concatenate([first_part, second_part], axis=-1)


class PerDimScale(nnx.Module):
  """Per-dimension scaling."""

  __data__ = ("per_dim_scale",)

  def __init__(self, num_dims: int, *, rngs=nnx.Rngs(42)):
    del rngs
    self.num_dims = num_dims
    self.per_dim_scale = nnx.Param(jnp.zeros(shape=(num_dims,)))

  def __call__(self, x: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
    return x * (
      1.442695041 / jnp.sqrt(self.num_dims) * jax.nn.softplus(self.per_dim_scale)
    )


class MultiHeadAttention(nnx.Module):
  """Multi-head attention."""

  def __init__(
    self,
    num_heads: int,
    in_features: int,
    *,
    use_per_dim_scale: bool = True,
    use_rotary_position_embeddings: bool = True,
    use_bias: bool = False,
    deterministic: bool | None = None,
    attention_fn: Callable[..., Array] = nnx.dot_product_attention,
    qk_norm: str = "rms",
    rngs=nnx.Rngs(42),
  ):
    self.num_heads = num_heads
    self.in_features = in_features
    self.qkv_features = in_features
    self.out_features = in_features
    self.in_kv_features = in_features
    self.deterministic = deterministic
    self.use_bias = use_bias
    self.attention_fn = attention_fn
    self.qk_norm = qk_norm

    if self.qkv_features % self.num_heads != 0:
      raise ValueError(
        f"Memory dimension ({self.qkv_features}) must be divisible by "
        f"'num_heads' heads ({self.num_heads})."
      )
    self.head_dim = self.qkv_features // self.num_heads

    linear_general = functools.partial(
      LinearGeneral,
      out_features=(self.num_heads, self.head_dim),
      use_bias=self.use_bias,
    )
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [batch..., length, n_heads, n_features_per_head]
    self.query = linear_general(self.in_features, rngs=rngs)
    self.key = linear_general(self.in_kv_features, rngs=rngs)
    self.value = linear_general(self.in_kv_features, rngs=rngs)

    if self.qk_norm == "rms":
      self.query_ln = RMSNorm(self.head_dim)
      self.key_ln = RMSNorm(self.head_dim)
    else:
      self.query_ln = None
      self.key_ln = None

    self.out = LinearGeneral(
      in_features=(self.num_heads, self.head_dim),
      out_features=self.out_features,
      axis=(-2, -1),
      use_bias=self.use_bias,
      rngs=rngs,
    )

    self.use_per_dim_scale = use_per_dim_scale
    self.use_rotary_position_embeddings = use_rotary_position_embeddings
    if self.use_rotary_position_embeddings:
      self.rotary_position_embedding = RotaryPositionalEmbedding(
        embedding_dims=self.head_dim,
      )
    else:
      self.rotary_position_embedding = None

    if use_per_dim_scale:
      self.per_dim_scale = PerDimScale(num_dims=self.head_dim, rngs=rngs)
    else:
      self.per_dim_scale = None

  def __call__(
    self,
    inputs_q: Array,
    *,
    decode_cache: DecodeCache | None = None,
    patch_mask: Array | None = None,
    deterministic: bool | None = None,
    sow_weights: bool = False,
  ) -> tuple[Float[Array, "b ... o"], DecodeCache | None]:
    """Applies multi-head dot product attention on the input data."""
    _, n_patches, input_in_features = inputs_q.shape
    if input_in_features != self.in_features:
      raise ValueError(
        f"Incompatible input dimension, got {input_in_features} "
        f"but module expects {self.in_features}."
      )
    if patch_mask is None:
      patch_mask = jnp.zeros_like(inputs_q.shape[:-1], dtype=jnp.bool)

    # For query: rope -> ln -> per_dim_scale
    query = self.query(inputs_q)
    key = self.key(inputs_q)
    value = self.value(inputs_q)

    if decode_cache is None:
      num_masked = jnp.sum(patch_mask.astype(jnp.int32), axis=-1, keepdims=False)
      next_index = jnp.zeros_like(num_masked, dtype=jnp.int32)
    else:
      num_masked = (
        jnp.sum(patch_mask.astype(jnp.int32), axis=-1, keepdims=False)
        + decode_cache.num_masked
      )
      next_index = decode_cache.next_index

    if self.use_rotary_position_embeddings:
      position = (
        jnp.arange(n_patches, dtype=jnp.int32)[None, :]
        + next_index[:, None]
        - num_masked[:, None]
      )
      query = self.rotary_position_embedding(query, position)
      key = self.rotary_position_embedding(key, position)
    if self.query_ln is not None:
      query = self.query_ln(query)
    if self.key_ln is not None:
      key = self.key_ln(key)
    if self.use_per_dim_scale:
      query = self.per_dim_scale(query)

    if decode_cache is not None:
      # Cached decoding.
      _, decode_cache_size, _, _ = decode_cache.value.shape
      zero = jnp.array(0, dtype=lax.dtype(next_index.dtype))
      start_indices = (zero, next_index[0], zero, zero)
      key = lax.dynamic_update_slice(decode_cache.key, key, start_indices)
      value = lax.dynamic_update_slice(decode_cache.value, value, start_indices)
      decode_cache.key = key
      decode_cache.value = value
      decode_cache.next_index = next_index + n_patches
      decode_cache.num_masked = num_masked
      attn_mask = make_attn_mask(
        query_length=n_patches,
        num_all_masked_kv=num_masked,
        query_index_offset=next_index,
        kv_length=decode_cache_size,
      )
    else:
      # Training
      attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked)

    # apply attention
    x = self.attention_fn(
      query * jnp.sqrt(self.head_dim),
      key,
      value,
      mask=attn_mask,
      deterministic=deterministic,
      module=self if sow_weights else None,
    )
    # back to the original inputs dimensions
    out = self.out(x)
    return out, decode_cache


class Transformer(nnx.Module):
  """Classic Transformer used in TimesFM."""

  def __init__(self, config: TransformerConfig, *, rngs=nnx.Rngs(42)):
    self.config = config

    if config.attention_norm == "rms":
      self.pre_attn_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
      self.post_attn_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
    else:
      raise ValueError(f"Layer norm: {config.attention_norm} not supported.")

    self.attn = MultiHeadAttention(
      num_heads=config.num_heads,
      in_features=config.model_dims,
      use_per_dim_scale=True,
      use_rotary_position_embeddings=config.use_rotary_position_embeddings,
      qk_norm=config.qk_norm,
      rngs=rngs,
    )

    if config.feedforward_norm == "rms":
      self.pre_ff_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
      self.post_ff_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
    else:
      raise ValueError(f"Layer norm: {config.feedforward_norm} not supported.")
    self.ff0 = nnx.Linear(
      in_features=config.model_dims,
      out_features=config.hidden_dims,
      use_bias=config.use_bias,
      rngs=rngs,
    )
    self.ff1 = nnx.Linear(
      in_features=config.hidden_dims,
      out_features=config.model_dims,
      use_bias=config.use_bias,
      rngs=rngs,
    )
    if config.ff_activation == "relu":
      self.activation = jax.nn.relu
    elif config.ff_activation == "swish":
      self.activation = jax.nn.swish
    elif config.ff_activation == "none":
      self.activation = lambda x: x
    else:
      raise ValueError(f"Activation: {config.ff_activation} not supported.")

  def __call__(
    self,
    input_embeddings: Float[Array, "b n d"],
    patch_mask: Bool[Array, "b n"],
    decode_cache: DecodeCache | None = None,
  ) -> tuple[Float[Array, "b n d"], DecodeCache | None]:
    attn_output, decode_cache = self.attn(
      inputs_q=self.pre_attn_ln(input_embeddings),
      decode_cache=decode_cache,
      patch_mask=patch_mask,
      sow_weights=False,
      deterministic=True,
    )
    attn_output = self.post_attn_ln(attn_output) + input_embeddings
    output_embeddings = (
      self.post_ff_ln(self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))))
      + attn_output
    )
    return output_embeddings, decode_cache


================================================
FILE: src/timesfm/flax/util.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flax utility functions for TimesFM layers."""

import dataclasses
import functools
import jax
import jax.numpy as jnp
import jaxtyping

Float = jaxtyping.Float
Array = jaxtyping.Array
Bool = jaxtyping.Bool
Integer = jaxtyping.Integer

_TOLERANCE = 1e-6


@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=False)
class DecodeCache:
  """Cache for decoding."""

  next_index: Integer[Array, "b"]
  num_masked: Integer[Array, "b"]
  key: Float[Array, "b n h d"]
  value: Float[Array, "b n h d"]


@jax.jit
def update_running_stats(
  n: Float[Array, "b"],
  mu: Float[Array, "b"],
  sigma: Float[Array, "b"],
  x: Float[Array, "b p"],
  mask: Bool[Array, "b p"],
) -> tuple[
  tuple[Float[Array, "b"], Float[Array, "b"], Float[Array, "b"]],
  tuple[Float[Array, "b"], Float[Array, "b"], Float[Array, "b"]],
]:
  """Updates the running stats."""
  is_legit = jnp.logical_not(mask)
  inc_n = jnp.sum(is_legit.astype(jnp.float32), axis=-1, keepdims=False)
  inc_mu = jnp.where(
    inc_n == 0, 0.0, jnp.mean(x, axis=-1, keepdims=False, where=is_legit)
  )
  inc_sigma = jnp.where(
    inc_n == 0, 0.0, jnp.std(x, axis=-1, keepdims=False, where=is_legit)
  )
  new_n = n + inc_n
  new_mu = jnp.where(new_n == 0, 0.0, (n * mu + inc_mu * inc_n) / new_n)
  new_sigma = jnp.sqrt(
    jnp.where(
      new_n == 0,
      0.0,
      (
        n * sigma * sigma
        + inc_n * inc_sigma * inc_sigma
        + n * (mu - new_mu) * (mu - new_mu)
        + inc_n * (inc_mu - new_mu) * (inc_mu - new_mu)
      )
      / new_n,
    )
  )
  return (w := (new_n, new_mu, new_sigma), w)


def scan_along_axis(f, init, xs, axis: int, **kwargs):
  """Scans along an axis."""
  moved_xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), xs)
  carry, moved_ys = jax.lax.scan(f, init, moved_xs, **kwargs)
  return (
    carry,
    jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), moved_ys),
  )


@functools.partial(jax.jit, static_argnames=("reverse",))
def revin(
  x: Float[Array, "b ..."],
  mu: Float[Array, "b ..."],
  sigma: Float[Array, "b ..."],
  reverse: bool = False,
):
  """Reversible per-instance normalization."""
  if len(mu.shape) == len(x.shape) - 1:
    mu = mu[..., None]
    sigma = sigma[..., None]
  elif len(mu.shape) == len(x.shape) - 2:
    mu = mu[..., None, None]
    sigma = sigma[..., None, None]
  if reverse:
    return x * sigma + mu
  else:
    return (x - mu) / jnp.where(sigma < _TOLERANCE, 1.0, sigma)


================================================
FILE: src/timesfm/timesfm_2p5/timesfm_2p5_base.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TimesFM 2p5 base implementation."""

import dataclasses
from typing import Any, Callable, Sequence

import collections
import numpy as np

from .. import configs

ResidualBlockConfig = configs.ResidualBlockConfig
StackedTransformersConfig = configs.StackedTransformersConfig
TransformerConfig = configs.TransformerConfig
ForecastConfig = configs.ForecastConfig
Category = int | str
XRegMode = str


def strip_leading_nans(arr):
  """Removes contiguous NaN values from the beginning of a NumPy array.

  Args:
    arr: The input NumPy array.

  Returns:
    A new NumPy array with leading NaN values removed.
    If the array is all NaNs or empty, returns an empty array.
  """

  isnan = np.isnan(arr)
  first_valid_index = np.argmax(~isnan)
  return arr[first_valid_index:]


def linear_interpolation(arr):
  """Performs linear interpolation to fill NaN values in a 1D numpy array.

  Args:
      arr: The 1D numpy array containing NaN values.

  Returns:
      A new numpy array with NaN values filled using linear interpolation,
      or the original array if no NaNs are present.
      Returns None if the input is not a 1D array.
      Returns the original array if there are no NaN values.
  """

  nans = np.isnan(arr)
  if not np.any(nans):  # Check if there are any NaNs
    return arr

  def x(z):
    return z.nonzero()[0]

  nans_indices = x(nans)
  non_nans_indices = x(~nans)
  non_nans_values = arr[~nans]

  try:
    arr[nans] = np.interp(nans_indices, non_nans_indices, non_nans_values)
  except ValueError:
    if non_nans_values:
      mu = np.nanmean(arr)
    else:
      mu = 0.0
    arr = np.where(np.isfinite(arr), arr, mu)
  return arr


@dataclasses.dataclass(frozen=True)
class TimesFM_2p5_200M_Definition:
  """Framework-agnostic config of TimesFM 2.5."""

  context_limit = 16384
  input_patch_len: int = 32
  output_patch_len: int = 128
  output_quantile_len: int = 1024
  quantiles: list[float] = dataclasses.field(
    default_factory=lambda: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
  )
  decode_index: int = 5
  tokenizer: ResidualBlockConfig = ResidualBlockConfig(
    input_dims=64,
    hidden_dims=1280,
    output_dims=1280,
    use_bias=True,
    activation="swish",
  )
  stacked_transformers: StackedTransformersConfig = StackedTransformersConfig(
    num_layers=20,
    transformer=TransformerConfig(
      model_dims=1280,
      hidden_dims=1280,
      num_heads=16,
      attention_norm="rms",
      feedforward_norm="rms",
      qk_norm="rms",
      use_bias=False,
      use_rotary_position_embeddings=True,
      ff_activation="swish",
      fuse_qkv=True,
    ),
  )
  output_projection_point: ResidualBlockConfig = ResidualBlockConfig(
    input_dims=1280,
    hidden_dims=1280,
    output_dims=1280,
    use_bias=False,
    activation="swish",
  )
  output_projection_quantiles: ResidualBlockConfig = ResidualBlockConfig(
    input_dims=1280,
    hidden_dims=1280,
    output_dims=10240,
    use_bias=False,
    activation="swish",
  )


class TimesFM_2p5:
  """Abstract base class for TimesFM models.

  Attributes:
    forecast_config: Configuration for forecasting flags.
    compiled_decode: Compiled decode function.
    global_batch_size: Global batch size.
  """

  forecast_config: ForecastConfig | None = None
  compiled_decode: Callable[..., Any] | None = None
  global_batch_size: int = 0

  def load_checkpoint(self, path: str):
    """Loads a TimesFM model from a checkpoint."""
    raise NotImplementedError()

  def compile(self, forecast_config: ForecastConfig | None = None):
    """Compiles the TimesFM model for fast decoding."""
    raise NotImplementedError()

  def forecast(
    self, horizon: int, inputs: list[np.ndarray]
  ) -> tuple[np.ndarray, np.ndarray]:
    """Forecasts the time series."""
    if self.compiled_decode is None:
      raise RuntimeError("Model is not compiled. Please call compile() first.")

    assert self.global_batch_size > 0
    assert self.forecast_config is not None

    context = self.forecast_config.max_context
    num_inputs = len(inputs)
    if (w := num_inputs % self.global_batch_size) != 0:
      inputs += [np.array([0.0] * 3)] * (self.global_batch_size - w)

    output_points = []
    output_quantiles = []
    values = []
    masks = []
    idx = 0
    for each_input in inputs:
      value = linear_interpolation(strip_leading_nans(np.array(each_input)))
      if (w := len(value)) >= context:
        value = value[-context:]
        mask = np.zeros_like(value, dtype=bool)
      else:
        mask = np.array([True] * (context - w) + [False] * w)
        value = np.pad(value, (context - w, 0), "constant", constant_values=0.0)
      values.append(value)
      masks.append(mask)
      idx += 1
      if idx == self.global_batch_size:
        idx = 0
        point_forecast, quantile_forecast = self.compiled_decode(horizon, values, masks)
        output_points.append(point_forecast)
        output_quantiles.append(quantile_forecast)
        values = []
        masks = []

    output_points = np.concatenate(output_points, axis=0)
    output_quantiles = np.concatenate(output_quantiles, axis=0)
    return output_points[:num_inputs], output_quantiles[:num_inputs]

  def forecast_with_covariates(
    self,
    inputs: list[Sequence[float]],
    dynamic_numerical_covariates: dict[str, Sequence[Sequence[float]]] | None = None,
    dynamic_categorical_covariates: (
      dict[str, Sequence[Sequence[Category]]] | None
    ) = None,
    static_numerical_covariates: dict[str, Sequence[float]] | None = None,
    static_categorical_covariates: dict[str, Sequence[Category]] | None = None,
    xreg_mode: XRegMode = "xreg + timesfm",
    normalize_xreg_target_per_input: bool = True,
    ridge: float = 0.0,
    max_rows_per_col: int = 0,
    force_on_cpu: bool = False,
  ):
    """Forecasts on a list of time series with covariates.

    To optimize inference speed, avoid string valued categorical covariates.

    Args:
      inputs: A list of time series forecast contexts. Each context time series
        should be in a format convertible to JTensor by `jnp.array`.
      dynamic_numerical_covariates: A dict of dynamic numerical covariates.
      dynamic_categorical_covariates: A dict of dynamic categorical covariates.
      static_numerical_covariates: A dict of static numerical covariates.
      static_categorical_covariates: A dict of static categorical covariates.
      xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "timesfm + xreg"
        fits a model on the residuals of the TimesFM forecast. "xreg + timesfm"
        fits a model on the targets then forecasts on the residuals via TimesFM.
      normalize_xreg_target_per_input: whether to normalize the xreg target per
        input in the given batch.
      ridge: ridge penalty for the linear model.
      max_rows_per_col: max number of rows per column for the linear model.
      force_on_cpu: whether to force running on cpu for the linear model.

    Returns:
      A tuple of two lists. The first is the outputs of the model. The second is
      the outputs of the xreg.
    """
    if self.forecast_config is None:
      raise ValueError("Model is not compiled. Please call compile() first.")
    elif not self.forecast_config.return_backcast:
      raise ValueError(
        "For XReg, `return_backcast` must be set to True in the forecast config. Please recompile the model."
      )

    from ..utils import xreg_lib

    # Verify and bookkeep covariates.
    if not (
      dynamic_numerical_covariates
      or dynamic_categorical_covariates
      or static_numerical_covariates
      or static_categorical_covariates
    ):
      raise ValueError(
        "At least one of dynamic_numerical_covariates,"
        " dynamic_categorical_covariates, static_numerical_covariates,"
        " static_categorical_covariates must be set."
      )

    # Track the lengths of (1) each input, (2) the part that can be used in the
    # linear model, and (3) the horizon.
    input_lens, train_lens, test_lens = [], [], []

    for i, input_ts in enumerate(inputs):
      input_len = len(input_ts)
      input_lens.append(input_len)

      if xreg_mode == "timesfm + xreg":
        # For fitting residuals, no TimesFM forecast on the first patch.
        train_lens.append(max(0, input_len - self.model.p))
      elif xreg_mode == "xreg + timesfm":
        train_lens.append(input_len)
      else:
        raise ValueError(f"Unsupported mode: {xreg_mode}")

      if dynamic_numerical_covariates:
        test_lens.append(
          len(list(dynamic_numerical_covariates.values())[0][i]) - input_len
        )
      elif dynamic_categorical_covariates:
        test_lens.append(
          len(list(dynamic_categorical_covariates.values())[0][i]) - input_len
        )
      else:
        test_lens.append(self.forecast_config.max_horizon)

      if test_lens[-1] > self.forecast_config.max_horizon:
        raise ValueError(
          "Forecast horizon length inferred from the dynamic covaraites is longer than the"
          f"max_horizon defined in the forecast config: {test_lens[-1]} > {self.forecast_config.max_horizon=}."
        )

    # Prepare the covariates into train and test.
    train_dynamic_numerical_covariates = collections.defaultdict(list)
    test_dynamic_numerical_covariates = collections.defaultdict(list)
    train_dynamic_categorical_covariates = collections.defaultdict(list)
    test_dynamic_categorical_covariates = collections.defaultdict(list)
    for covariates, train_covariates, test_covariates in (
      (
        dynamic_numerical_covariates,
        train_dynamic_numerical_covariates,
        test_dynamic_numerical_covariates,
      ),
      (
        dynamic_categorical_covariates,
        train_dynamic_categorical_covariates,
        test_dynamic_categorical_covariates,
      ),
    ):
      if not covariates:
        continue
      for covariate_name, covariate_values in covariates.items():
        for input_len, train_len, covariate_value in zip(
          input_lens, train_lens, covariate_values
        ):
          train_covariates[covariate_name].append(
            covariate_value[(input_len - train_len) : input_len]
          )
          test_covariates[covariate_name].append(covariate_value[input_len:])

    # Fit models.
    if xreg_mode == "timesfm + xreg":
      # Forecast via TimesFM then fit a model on the residuals.
      point_outputs, quantile_outputs = self.forecast(
        horizon=self.forecast_config.max_horizon, inputs=inputs
      )
      targets = [
        (
          np.array(input_ts)[-train_len:]
          - point_output[: -self.forecast_config.max_horizon][-train_len:]
        )
        for input_ts, point_output, train_len in zip(inputs, point_outputs, train_lens)
      ]
      per_instance_stats = None
      if normalize_xreg_target_per_input:
        targets, per_instance_stats = xreg_lib.normalize(targets)
      xregs = xreg_lib.BatchedInContextXRegLinear(
        targets=targets,
        train_lens=train_lens,
        test_lens=test_lens,
        train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
        test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
        train_dynamic_categorical_covariates=train_dynamic_categorical_covariates,
        test_dynamic_categorical_covariates=test_dynamic_categorical_covariates,
        static_numerical_covariates=static_numerical_covariates,
        static_categorical_covariates=static_categorical_covariates,
      ).fit(
        ridge=ridge,
        one_hot_encoder_drop=None if ridge > 0 else "first",
        max_rows_per_col=max_rows_per_col,
        force_on_cpu=force_on_cpu,
        debug_info=False,
        assert_covariates=True,
        assert_covariate_shapes=True,
      )
      if normalize_xreg_target_per_input:
        xregs = xreg_lib.renormalize(xregs, per_instance_stats)
      xregs = np.array(xregs)
      new_point_outputs = [
        (point_output[-self.forecast_config.max_horizon :][:test_len] + xreg)
        for point_output, test_len, xreg in zip(point_outputs, test_lens, xregs)
      ]
      new_quantile_outputs = [
        (
          quantile_output[-self.forecast_config.max_horizon :][:test_len]
          + xreg[..., None]
        )
        for quantile_output, test_len, xreg in zip(quantile_outputs, test_lens, xregs)
      ]

    else:
      # Fit a model on the targets then forecast on the residuals via TimesFM.
      targets = [
        np.array(input_ts)[-train_len:]
        for input_ts, train_len in zip(inputs, train_lens)
      ]
      per_instance_stats = None
      if normalize_xreg_target_per_input:
        targets, per_instance_stats = xreg_lib.normalize(targets)
      xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear(
        targets=targets,
        train_lens=train_lens,
        test_lens=test_lens,
        train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
        test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
        train_dynamic_categorical_covariates=train_dynamic_categorical_covariates,
        test_dynamic_categorical_covariates=test_dynamic_categorical_covariates,
        static_numerical_covariates=static_numerical_covariates,
        static_categorical_covariates=static_categorical_covariates,
      ).fit(
        ridge=ridge,
        one_hot_encoder_drop=None if ridge > 0 else "first",
        max_rows_per_col=max_rows_per_col,
        force_on_cpu=force_on_cpu,
        debug_info=True,
        assert_covariates=True,
        assert_covariate_shapes=True,
      )
      point_outputs, quantile_outputs = self.forecast(
        horizon=self.forecast_config.max_horizon,
        inputs=[
          target - xreg_on_context
          for target, xreg_on_context in zip(targets, xregs_on_context)
        ],
      )
      new_point_outputs = [
        (point_output[-self.forecast_config.max_horizon :][:test_len] + xreg)
        for point_output, test_len, xreg in zip(point_outputs, test_lens, xregs)
      ]
      new_quantile_outputs = [
        (
          quantile_output[-self.forecast_config.max_horizon :][:test_len]
          + xreg[..., None]
        )
        for quantile_output, test_len, xreg in zip(quantile_outputs, test_lens, xregs)
      ]
      if normalize_xreg_target_per_input:
        new_point_outputs = xreg_lib.renormalize(new_point_outputs, per_instance_stats)
        new_quantile_outputs = xreg_lib.renormalize(
          new_quantile_outputs, per_instance_stats
        )

    return new_point_outputs, new_quantile_outputs


================================================
FILE: src/timesfm/timesfm_2p5/timesfm_2p5_flax.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TimesFM models in Flax."""

import dataclasses
import functools
import gc
import logging
import math
import os
from pathlib import Path
from typing import Any, Callable, Dict

import einshape
from flax import nnx
import huggingface_hub
import jax
import jax.numpy as jnp
import jaxtyping
import numpy as np
import orbax.checkpoint as ocp

from .. import configs
from ..flax import dense, transformer, util
from . import timesfm_2p5_base

jax_einshape = einshape.jax_einshape
scan = util.scan_along_axis
revin = util.revin

Float = jaxtyping.Float
Bool = jaxtyping.Bool
Array = jaxtyping.Array


def try_gc():
  for d in jax.local_devices():
    stats = d.memory_stats()
    if stats is None:
      return
    if stats["bytes_in_use"] / stats["bytes_limit"] > 0.75:
      gc.collect()
      break


@nnx.vmap(in_axes=(None, 0), out_axes=0)
def _create_stacked_transformers(
  config: configs.StackedTransformersConfig, key: jax.Array
):
  return transformer.Transformer(config.transformer, rngs=nnx.Rngs(key))


def _scan_along_axis(f, init, xs, axis: int, **kwargs):
  """Scans along an axis."""
  moved_xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), xs)
  carry, moved_ys = jax.lax.scan(f, init, moved_xs, **kwargs)
  return (
    carry,
    jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), moved_ys),
  )


@nnx.scan(in_axes=(0, nnx.Carry, None, 0), out_axes=(nnx.Carry, 0))
def _apply_stacked_transformers(
  model: transformer.Transformer,
  x: Float[Array, "b n d"],
  m: Float[Array, "b n"],
  decode_cache: util.DecodeCache | None = None,
) -> Float[Array, "b n d"]:
  return model(x, m, decode_cache=decode_cache)


class TimesFM_2p5_200M_flax_module(nnx.Module):  # pylint: disable=invalid-name
  """TimesFM 2.5 with 200M parameters."""

  config = timesfm_2p5_base.TimesFM_2p5_200M_Definition()
  decode_index: int = 5
  compiled_decode: Callable[..., Any] | None = None
  backend: str = ""
  context: int = 0
  horizon: int = 0
  per_core_batch_size: int = 0

  def __init__(self):
    super().__init__()
    self.backend = jax.devices()[0].platform
    self.num_devices = len(jax.devices(self.backend))

    # Names constants.
    self.p = self.config.input_patch_len  # 32
    self.o = self.config.output_patch_len  # 128
    self.os = self.config.output_quantile_len  # 1024
    self.m = self.o // self.p  # 4
    self.x = self.config.stacked_transformers.num_layers  # 20
    self.h = self.config.stacked_transformers.transformer.num_heads  # 16
    self.md = self.config.stacked_transformers.transformer.model_dims  # 1280
    self.hd = self.md // self.h  # 80
    self.q = len(self.config.quantiles) + 1  # 10
    self.aridx = self.config.decode_index  # 5

    # Layers.
    self.tokenizer = dense.ResidualBlock(self.config.tokenizer)
    self.stacked_xf = _create_stacked_transformers(
      self.config.stacked_transformers,
      jax.random.split(jax.random.key(42), self.x),
    )
    self.output_projection_point = dense.ResidualBlock(
      self.config.output_projection_point
    )
    self.output_projection_quantiles = dense.ResidualBlock(
      self.config.output_projection_quantiles
    )

  def __call__(
    self,
    inputs: Float[Array, "b n p"],
    masks: Bool[Array, "b n p"],
    decode_cache: util.DecodeCache | None = None,
  ):
    tokenizer_inputs = jnp.concatenate([inputs, masks.astype(inputs.dtype)], axis=-1)
    input_embeddings = self.tokenizer(tokenizer_inputs)
    if decode_cache is None:
      decode_cache = [None] * self.x
    output_embeddings, decode_cache = _apply_stacked_transformers(
      self.stacked_xf, input_embeddings, masks[..., -1], decode_cache
    )
    output_ts = self.output_projection_point(output_embeddings)
    output_quantile_spread = self.output_projection_quantiles(output_embeddings)
    return (
      input_embeddings,
      output_embeddings,
      output_ts,
      output_quantile_spread,
    ), decode_cache

  @nnx.jit(static_argnames=("horizon",))
  def decode(self, horizon: int, inputs, masks):
    batch_size, context = inputs.shape[0], inputs.shape[1]
    num_decode_steps = (horizon - 1) // self.o
    num_input_patches = context // self.p
    decode_cache_size = num_input_patches + num_decode_steps * self.m

    # Prefill
    patched_inputs = jax_einshape("b(np)->bnp", inputs, b=batch_size, p=self.p)
    patched_masks = jax_einshape("b(np)->bnp", masks, b=batch_size, p=self.p)
    (last_n, last_mu, last_sigma), (_, context_mu, context_sigma) = scan(
      lambda carry, xs: util.update_running_stats(*carry, *xs),
      init=(zero := jnp.zeros(shape=(batch_size)), zero, zero),
      xs=(patched_inputs, patched_masks),
      axis=1,
    )
    decode_cache = util.DecodeCache(
      next_index=jnp.zeros(shape=(self.x, batch_size), dtype=jnp.int32),
      num_masked=jnp.zeros(shape=(self.x, batch_size), dtype=jnp.int32),
      key=jnp.zeros(shape=(self.x, batch_size, decode_cache_size, self.h, self.hd)),
      value=jnp.zeros(shape=(self.x, batch_size, decode_cache_size, self.h, self.hd)),
    )
    normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
    normed_inputs = jnp.where(patched_masks, 0.0, normed_inputs)
    (_, _, normed_outputs, normed_quantile_spread), decode_cache = self(
      normed_inputs, patched_masks, decode_cache
    )
    renormed_outputs = jax_einshape(
      "bn(oq)->bnoq",
      revin(normed_outputs, context_mu, context_sigma, reverse=True),
      o=self.o,
      q=self.q,
    )
    renormed_quantile_spread = jax_einshape(
      "bn(oq)->bnoq",
      revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
      o=self.os,
      q=self.q,
    )[:, -1, ...]

    # Autogressive decode
    @nnx.scan(in_axes=(None, nnx.Carry, 0), out_axes=(nnx.Carry, 1))
    def _ar_decode(module, carry, unused_iter):
      last_renormed_output, (last_n, last_mu, last_sigma), decode_cache = carry
      new_patched_input = jax_einshape(
        "b(mp)->bmp", last_renormed_output, m=module.m, p=module.p
      )
      new_mask = jnp.zeros_like(new_patched_input, dtype=jnp.bool)
      carry_stats, (_, new_mu, new_sigma) = scan(
        lambda carry, xs: util.update_running_stats(*carry, *xs),
        init=(last_n, last_mu, last_sigma),
        xs=(new_patched_input, new_mask),
        axis=1,
      )
      new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
      (_, _, new_normed_output, _), decode_cache = module(
        new_normed_input, new_mask, decode_cache
      )
      new_renormed_output = jax_einshape(
        "bm(oq)->bmoq",
        revin(new_normed_output, new_mu, new_sigma, reverse=True),
        o=module.o,
        q=module.q,
      )[..., -1, :, :]

      return (
        (
          new_renormed_output[..., module.decode_index],
          carry_stats,
          decode_cache,
        ),
        new_renormed_output,
      )

    if num_decode_steps > 0:
      _, ar_renormed_outputs = _ar_decode(
        self,
        (
          renormed_outputs[..., -1, :, self.decode_index],
          (last_n, last_mu, last_sigma),
          decode_cache,
        ),
        jnp.arange(num_decode_steps),
      )
    else:
      ar_renormed_outputs = None

    return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs

  def compile(
    self,
    context: int,
    horizon: int,
    per_core_batch_size: int = 1,
  ):
    if context % self.p != 0:
      logging.info(
        "When compiling, context needs to be multiple of the patch size %d."
        " Modifying context to %d.",
        self.p,
        context := math.ceil(context / self.p) * self.p,
      )
    if horizon % self.o != 0:
      logging.info(
        "When compiling, horizon needs to be multiple of the output patch"
        " size %d. Modifying horizon to %d.",
        self.o,
        horizon := math.ceil(horizon / self.o) * self.o,
      )

    self.context = context
    self.horizon = horizon
    self.per_core_batch_size = per_core_batch_size

    @nnx.pmap(
      in_axes=(None, None, 0, 0),
      out_axes=(0, 0, 0),
      devices=jax.devices(self.backend),
      axis_size=self.num_devices,
      static_broadcasted_argnums=(1,),
      axis_name="global_batch",
    )
    def compiled_decode_kernel(model, horizon, inputs, masks):
      return model.decode(horizon, inputs, masks)

    self.compiled_decode = functools.partial(compiled_decode_kernel, self)


def _flip_quantile_fn(x):
  return jnp.concatenate([x[..., :1], jnp.flip(x[..., 1:], axis=-1)], axis=-1)


@functools.partial(
  jax.jit,
  donate_argnums=(0, 1, 2),
)
def _force_flip_invariance_fn(
  flipped_pf_outputs,
  flipped_quantile_spreads,
  flipped_ar_outputs,
):
  """Forces flip invariance."""
  flipped_pf_outputs = _flip_quantile_fn(flipped_pf_outputs)
  flipped_pf_outputs = jax_einshape("tb...->(tb)...", flipped_pf_outputs)
  flipped_quantile_spreads = _flip_quantile_fn(flipped_quantile_spreads)
  flipped_quantile_spreads = jax_einshape("tb...->(tb)...", flipped_quantile_spreads)
  to_concat = [flipped_pf_outputs[:, -1, ...]]
  if flipped_ar_outputs is not None:
    flipped_ar_outputs = _flip_quantile_fn(flipped_ar_outputs)
    flipped_ar_outputs = jax_einshape("tbno...->(tb)(no)...", flipped_ar_outputs)
    to_concat.append(flipped_ar_outputs)
  flipped_full_forecast = jnp.concatenate(to_concat, axis=1)

  return flipped_quantile_spreads, flipped_pf_outputs, flipped_full_forecast


@functools.partial(
  jax.jit,
  static_argnames=("max_horizon",),
  donate_argnums=(0,),
)
def _use_continuous_quantile_head_fn(full_forecast, quantile_spreads, max_horizon):
  """Uses continuous quantile head."""
  to_stack = [full_forecast[..., :max_horizon, 0]]
  for quantile_index in [1, 2, 3, 4]:
    to_stack.append(
      quantile_spreads[:, :max_horizon, quantile_index]
      - quantile_spreads[:, :max_horizon, 5]
      + full_forecast[:, :max_horizon, 5]
    )
  to_stack.append(full_forecast[..., :max_horizon, 5])
  for quantile_index in [6, 7, 8, 9]:
    to_stack.append(
      quantile_spreads[:, :max_horizon, quantile_index]
      - quantile_spreads[:, :max_horizon, 5]
      + full_forecast[:, :max_horizon, 5]
    )
  return jnp.stack(to_stack, axis=-1)


@functools.partial(jax.jit, donate_argnums=(0,))
def _fix_quantile_crossing_fn(full_forecast):
  """Fixes quantile crossing."""
  lower_quantiles = _scan_along_axis(
    lambda carry, x: (w := jnp.minimum(carry, x), w),
    init=full_forecast[..., 5],
    xs=full_forecast[..., 1:5],
    axis=-1,
    reverse=True,
  )[1]
  upper_quantiles = _scan_along_axis(
    lambda carry, x: (w := jnp.maximum(carry, x), w),
    init=full_forecast[..., 5],
    xs=full_forecast[..., 6:10],
    axis=-1,
    reverse=False,
  )[1]
  return jnp.concatenate(
    [
      full_forecast[..., :1],
      lower_quantiles,
      full_forecast[..., 5:6],
      upper_quantiles,
    ],
    axis=-1,
  )


@functools.partial(jax.jit, static_argnames=("fc",), donate_argnums=(1, 2))
def _before_model_decode(fc, inputs, masks):
  """All Jax steps before model decode call."""
  if fc.infer_is_positive:
    is_positive = jnp.all(inputs >= 0, axis=-1, keepdims=True)
  else:
    is_positive = None

  if fc.normalize_inputs:
    mu = jnp.mean(inputs, axis=-1, keepdims=True)
    sigma = jnp.std(inputs, axis=-1, keepdims=True)
    inputs = revin(inputs, mu, sigma, reverse=False)
  else:
    mu, sigma = None, None

  inputs = jax_einshape("(tb)...->tb...", inputs, b=fc.per_core_batch_size)
  masks = jax_einshape("(tb)...->tb...", masks, b=fc.per_core_batch_size)

  return inputs, masks, is_positive, mu, sigma


@functools.partial(
  jax.jit,
  static_argnames=(
    "fc",
    "p",
  ),
  donate_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9),
)
def _after_model_decode(
  fc,
  pf_outputs,
  quantile_spreads,
  ar_outputs,
  flipped_pf_outputs,
  flipped_quantile_spreads,
  flipped_ar_outputs,
  is_positive,
  mu,
  sigma,
  p,
):
  """All Jax steps after model decode call."""
  # t: num_devices, b: per_core_batch_size
  pf_outputs = jax_einshape("tb...->(tb)...", pf_outputs)
  quantile_spreads = jax_einshape("tb...->(tb)...", quantile_spreads)
  to_concat = [pf_outputs[:, -1, ...]]
  if ar_outputs is not None:
    ar_outputs = jax_einshape("tbno...->(tb)(no)...", ar_outputs)
    to_concat.append(ar_outputs)
  full_forecast = jnp.concatenate(to_concat, axis=1)

  if fc.force_flip_invariance:
    (
      flipped_quantile_spreads,
      flipped_pf_outputs,
      flipped_full_forecast,
    ) = _force_flip_invariance_fn(
      flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs
    )
    quantile_spreads = (quantile_spreads - flipped_quantile_spreads) / 2
    pf_outputs = (pf_outputs - flipped_pf_outputs) / 2
    full_forecast = (full_forecast - flipped_full_forecast) / 2

  if fc.use_continuous_quantile_head:
    full_forecast = _use_continuous_quantile_head_fn(
      full_forecast, quantile_spreads, fc.max_horizon
    )

  if fc.return_backcast:
    full_backcast = jax_einshape("...npq->...(np)q", pf_outputs[:, :-1, :p, :])
    full_forecast = jnp.concatenate([full_backcast, full_forecast], axis=1)

  if fc.fix_quantile_crossing:
    full_forecast = _fix_quantile_crossing_fn(full_forecast)

  if fc.normalize_inputs:
    full_forecast = revin(full_forecast, mu, sigma, reverse=True)

  if is_positive is not None:
    full_forecast = jnp.where(
      is_positive[..., None],
      jnp.maximum(full_forecast, jnp.zeros_like(full_forecast)),
      full_forecast,
    )

  return full_forecast


class TimesFM_2p5_200M_flax(timesfm_2p5_base.TimesFM_2p5):
  """Flax implementation of TimesFM 2.5 with 200M parameters."""

  model: nnx.Module = TimesFM_2p5_200M_flax_module()

  @classmethod
  def from_pretrained(
    cls,
    model_id: str = "google/timesfm-2.5-200m-flax",
    *,
    revision: str | None = None,
    cache_dir: str | Path | None = None,
    force_download: bool = False,
    proxies: Dict | None = None,
    resume_download: bool | None = None,
    local_files_only: bool | None = None,
    token: str | None = None,
    **model_kwargs,
  ):
    """Loads a Flax TimesFM model."""

    # Create an instance of the model wrapper class.
    instance = cls(**model_kwargs)

    # Determine the path to the model weights.
    model_file_path = ""
    if os.path.isdir(model_id):
      logging.info("Loading checkpoint from local directory: %s", model_id)
      model_file_path = model_id
    else:
      logging.info("Downloading checkpoint from Hugging Face repo %s", model_id)
      model_file_path = huggingface_hub.snapshot_download(
        repo_id=model_id,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        token=token,
        local_files_only=local_files_only,
      )
      logging.info("Loading checkpoint from: %s", model_file_path)

    checkpointer = ocp.StandardCheckpointer()
    graph, state = nnx.split(instance.model)
    state = checkpointer.restore(model_file_path, state)
    instance.model = nnx.merge(graph, state)
    return instance

  def compile(
    self,
    forecast_config: configs.ForecastConfig,
    dryrun: bool = True,
    **kwargs
  ):
    # Acrobym used during validation.
    print("Compiling model...")

    fc = forecast_config
    if fc.max_context % self.model.p != 0:
      logging.info(
        "When compiling, max context needs to be multiple of the patch size"
        " %d. Using max context = %d instead.",
        self.model.p,
        new_context := math.ceil(fc.max_context / self.model.p) * self.model.p,
      )
      fc = dataclasses.replace(fc, max_context=new_context)
    if fc.max_horizon % self.model.o != 0:
      logging.info(
        "When compiling, max horizon needs to be multiple of the output patch"
        " size %d. Using max horizon = %d instead.",
        self.model.o,
        new_horizon := math.ceil(fc.max_horizon / self.model.o) * self.model.o,
      )
      fc = dataclasses.replace(fc, max_horizon=new_horizon)
    if fc.max_context + fc.max_horizon > self.model.config.context_limit:
      raise ValueError(
        "Context + horizon must be less than the context limit."
        f" {fc.max_context} + {fc.max_horizon} >"
        f" {self.model.config.context_limit}."
      )
    if fc.use_continuous_quantile_head and (fc.max_horizon > self.model.os):
      raise ValueError(
        f"Continuous quantile head is not supported for horizons > {self.model.os}."
      )

    self.forecast_config = fc
    self.model.compile(
      context=self.forecast_config.max_context,
      horizon=self.forecast_config.max_horizon,
      per_core_batch_size=fc.per_core_batch_size,
    )
    self.per_core_batch_size = self.forecast_config.per_core_batch_size
    self.num_devices = self.model.num_devices
    self.global_batch_size = (
      self.forecast_config.per_core_batch_size * self.model.num_devices
    )

    def compiled_decode_kernel(fc, horizon, inputs, masks):
      inputs = jnp.array(inputs, dtype=jnp.float32)
      masks = jnp.array(masks, dtype=jnp.bool)
      if horizon > fc.max_horizon:
        raise ValueError(
          f"Horizon must be less than the max horizon. {horizon} > {fc.max_horizon}."
        )
      to_trim = fc.max_horizon - horizon

      inputs, masks, is_positive, mu, sigma = _before_model_decode(fc, inputs, masks)

      pf_outputs, quantile_spreads, ar_outputs = self.model.compiled_decode(
        fc.max_horizon, inputs, masks
      )
      if fc.force_flip_invariance:
        flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
          self.model.compiled_decode(fc.max_horizon, -inputs, masks)
        )
      else:
        flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
          None,
          None,
          None,
        )

      full_forecast = _after_model_decode(
        fc,
        pf_outputs,
        quantile_spreads,
        ar_outputs,
        flipped_pf_outputs,
        flipped_quantile_spreads,
        flipped_ar_outputs,
        is_positive,
        mu,
        sigma,
        self.model.p,
      )
      full_forecast_np = np.array(full_forecast)
      del full_forecast
      try_gc()
      if to_trim > 0:
        full_forecast_np = full_forecast_np[..., :-to_trim, :]
      return full_forecast_np[..., 5], full_forecast_np

    self.compiled_decode = functools.partial(
      compiled_decode_kernel, self.forecast_config
    )

    if dryrun:
      _ = self.compiled_decode(
        self.forecast_config.max_horizon,
        jnp.zeros(
          (self.global_batch_size, self.forecast_config.max_context), dtype=jnp.float32
        ),
        jnp.zeros(
          (self.global_batch_size, self.forecast_config.max_context), dtype=jnp.bool
        ),
      )
    print("Compiling done.")


================================================
FILE: src/timesfm/timesfm_2p5/timesfm_2p5_torch.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM models."""

import dataclasses
import logging
import math
import os
from pathlib import Path
from typing import Optional, Sequence, Union

import numpy as np
import torch
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from safetensors.torch import load_file, save_file
from torch import nn

from .. import configs
from ..torch import dense, transformer, util
from . import timesfm_2p5_base

revin = util.revin


class TimesFM_2p5_200M_torch_module(nn.Module):
  """TimesFM 2.5 with 200M parameters."""

  config = timesfm_2p5_base.TimesFM_2p5_200M_Definition()

  def __init__(self):
    super().__init__()

    # Names constants.
    self.p = self.config.input_patch_len  # 32
    self.o = self.config.output_patch_len  # 128
    self.os = self.config.output_quantile_len  # 1024
    self.m = self.o // self.p  # 4
    self.x = self.config.stacked_transformers.num_layers  # 20
    self.h = self.config.stacked_transformers.transformer.num_heads  # 16
    self.md = self.config.stacked_transformers.transformer.model_dims  # 1280
    self.hd = self.md // self.h  # 80
    self.q = len(self.config.quantiles) + 1  # 10
    self.aridx = self.config.decode_index  # 5

    # Layers.
    self.tokenizer = dense.ResidualBlock(self.config.tokenizer)
    self.stacked_xf = nn.ModuleList(
      [
        transformer.Transformer(self.config.stacked_transformers.transformer)
        for _ in range(self.x)
      ]
    )
    self.output_projection_point = dense.ResidualBlock(
      self.config.output_projection_point
    )
    self.output_projection_quantiles = dense.ResidualBlock(
      self.config.output_projection_quantiles
    )

    # Device.
    if torch.cuda.is_available():
      self.device = torch.device("cuda:0")
      self.device_count = torch.cuda.device_count()
    else:
      self.device = torch.device("cpu")
      self.device_count = 1

  def load_checkpoint(self, path: str, **kwargs):
    """Loads a PyTorch TimesFM model from a checkpoint."""
    tensors = load_file(path)
    self.load_state_dict(tensors, strict=True)
    self.to(self.device)
    torch_compile = True
    if "torch_compile" in kwargs:
      torch_compile = kwargs["torch_compile"]
    if torch_compile:
      print("Compiling model...")
      self = torch.compile(self)

    self.eval()

  def forward(
    self,
    inputs: torch.Tensor,
    masks: torch.Tensor,
    decode_caches: list[util.DecodeCache] | None = None,
  ):
    tokenizer_inputs = torch.cat([inputs, masks.to(inputs.dtype)], dim=-1)
    input_embeddings = self.tokenizer(tokenizer_inputs)

    if decode_caches is None:
      decode_caches = [None] * self.x

    output_embeddings = input_embeddings
    new_decode_caches = []
    for i, layer in enumerate(self.stacked_xf):
      output_embeddings, new_cache = layer(
        output_embeddings, masks[..., -1], decode_caches[i]
      )
      new_decode_caches.append(new_cache)
    output_ts = self.output_projection_point(output_embeddings)
    output_quantile_spread = self.output_projection_quantiles(output_embeddings)

    return (
      input_embeddings,
      output_embeddings,
      output_ts,
      output_quantile_spread,
    ), new_decode_caches

  def decode(self, horizon: int, inputs, masks):
    """Decodes the time series."""

    with torch.no_grad():
      batch_size, context = inputs.shape[0], inputs.shape[1]
      num_decode_steps = (horizon - 1) // self.o
      num_input_patches = context // self.p
      decode_cache_size = num_input_patches + num_decode_steps * self.m

      # Prefill
      patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
      patched_masks = torch.reshape(masks, (batch_size, -1, self.p))

      # running stats
      n = torch.zeros(batch_size, device=inputs.device)
      mu = torch.zeros(batch_size, device=inputs.device)
      sigma = torch.zeros(batch_size, device=inputs.device)
      patch_mu = []
      patch_sigma = []
      for i in range(num_input_patches):
        (n, mu, sigma), _ = util.update_running_stats(
          n, mu, sigma, patched_inputs[:, i], patched_masks[:, i]
        )
        patch_mu.append(mu)
        patch_sigma.append(sigma)
      last_n, last_mu, last_sigma = n, mu, sigma
      context_mu = torch.stack(patch_mu, dim=1)
      context_sigma = torch.stack(patch_sigma, dim=1)

      decode_caches = [
        util.DecodeCache(
          next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
          num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
          key=torch.zeros(
            batch_size,
            decode_cache_size,
            self.h,
            self.hd,
            device=inputs.device,
          ),
          value=torch.zeros(
            batch_size,
            decode_cache_size,
            self.h,
            self.hd,
            device=inputs.device,
          ),
        )
        for _ in range(self.x)
      ]

      normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
      normed_inputs = torch.where(patched_masks, 0.0, normed_inputs)
      (_, _, normed_outputs, normed_quantile_spread), decode_caches = self(
        normed_inputs, patched_masks, decode_caches
      )
      renormed_outputs = torch.reshape(
        revin(normed_outputs, context_mu, context_sigma, reverse=True),
        (batch_size, -1, self.o, self.q),
      )
      renormed_quantile_spread = torch.reshape(
        revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
        (batch_size, -1, self.os, self.q),
      )[:, -1, ...]

      # Autogressive decode
      ar_outputs = []
      last_renormed_output = renormed_outputs[:, -1, :, self.aridx]

      for _ in range(num_decode_steps):
        new_patched_input = torch.reshape(
          last_renormed_output, (batch_size, self.m, self.p)
        )
        new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)

        n, mu, sigma = last_n, last_mu, last_sigma
        new_mus, new_sigmas = [], []
        for i in range(self.m):
          (n, mu, sigma), _ = util.update_running_stats(
            n, mu, sigma, new_patched_input[:, i], new_mask[:, i]
          )
          new_mus.append(mu)
          new_sigmas.append(sigma)
        last_n, last_mu, last_sigma = n, mu, sigma
        new_mu = torch.stack(new_mus, dim=1)
        new_sigma = torch.stack(new_sigmas, dim=1)

        new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
        (_, _, new_normed_output, _), decode_caches = self(
          new_normed_input, new_mask, decode_caches
        )

        new_renormed_output = torch.reshape(
          revin(new_normed_output, new_mu, new_sigma, reverse=True),
          (batch_size, self.m, self.o, self.q),
        )
        ar_outputs.append(new_renormed_output[:, -1, ...])
        last_renormed_output = new_renormed_output[:, -1, :, self.aridx]

      if num_decode_steps > 0:
        ar_renormed_outputs = torch.stack(ar_outputs, dim=1)
      else:
        ar_renormed_outputs = None

    return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs

  def forecast_naive(
    self, horizon: int, inputs: Sequence[np.ndarray]
  ) -> list[np.ndarray]:
    """Forecasts the time series.

    This is a naive implementation for debugging purposes. No forecasting
    flags are used here. Forecasting quality can be subpar.

    Args:
      horizon: The number of time points to forecast.
      inputs: A sequence of numpy arrays, each representing a time series to
        query forecast for.

    Returns:
      A list of numpy arrays of forecasts.
    """
    outputs = []
    for each_input in inputs:
      input_t = torch.tensor(each_input, dtype=torch.float32)
      mask = torch.zeros_like(input_t, dtype=torch.bool)
      len_front_mask = self.p - (len(each_input) % self.p)
      if len_front_mask < self.p:
        input_t = torch.cat(
          [torch.zeros(len_front_mask, dtype=torch.float32), input_t], dim=0
        )
        mask = torch.cat([torch.ones(len_front_mask, dtype=torch.bool), mask], dim=0)
      input_t = input_t[None, ...]
      mask = mask[None, ...]
      t_pf, _, t_ar = self.decode(horizon, input_t, mask)
      to_concat = [t_pf[:, -1, ...]]
      if t_ar is not None:
        to_concat.append(t_ar.reshape(1, -1, self.q))
      torch_forecast = torch.cat(to_concat, dim=1)[..., :horizon]
      torch_forecast = torch_forecast.squeeze(0)
      outputs.append(torch_forecast.detach().cpu().numpy())
    return outputs


class TimesFM_2p5_200M_torch(
  timesfm_2p5_base.TimesFM_2p5,
  PyTorchModelHubMixin,
  library_name="timesfm",
  repo_url="https://github.com/google-research/timesfm",
  paper_url="https://arxiv.org/abs/2310.10688",
  docs_url="https://github.com/google-research/timesfm",
  license="apache-2.0",
  pipeline_tag="time-series-forecasting",
  tags=["pytorch", "timeseries", "forecasting", "timesfm-2.5"],
):
  """PyTorch implementation of TimesFM 2.5 with 200M parameters."""

  DEFAULT_REPO_ID = "google/timesfm-2.5-200m-pytorch"
  WEIGHTS_FILENAME = "model.safetensors"

  def __init__(
    self,
    torch_compile: bool = True,
    config: Optional[dict] = None,
  ):
    self.model = TimesFM_2p5_200M_torch_module()
    self.torch_compile = torch_compile
    if config is not None:
      self._hub_mixin_config = config

  @classmethod
  def _from_pretrained(
    cls,
    *,
    model_id: str = DEFAULT_REPO_ID,
    revision: Optional[str],
    cache_dir: Optional[Union[str, Path]],
    force_download: bool = False,
    local_files_only: bool,
    token: Optional[Union[str, bool]],
    config: Optional[dict] = None,
    **model_kwargs,
  ):
    """
    Loads a PyTorch safetensors TimesFM model from a local path or the Hugging
    Face Hub. This method is the backend for the `from_pretrained` class
    method provided by `PyTorchModelHubMixin`.
    """
    # Determine the path to the model weights.
    model_file_path = ""
    if os.path.isdir(model_id):
      logging.info("Loading checkpoint from local directory: %s", model_id)
      model_file_path = os.path.join(model_id, cls.WEIGHTS_FILENAME)
      if not os.path.exists(model_file_path):
        raise FileNotFoundError(
          f"{cls.WEIGHTS_FILENAME} not found in directory {model_id}"
        )
    else:
      logging.info("Downloading checkpoint from Hugging Face repo %s", model_id)
      model_file_path = hf_hub_download(
        repo_id=model_id,
        filename=cls.WEIGHTS_FILENAME,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        token=token,
        local_files_only=local_files_only,
      )

    # Create an instance of the model wrapper class.
    instance = cls(config=config, **model_kwargs)

    logging.info("Loading checkpoint from: %s", model_file_path)
    # Load the weights into the model.
    instance.model.load_checkpoint(
      model_file_path, torch_compile=instance.torch_compile
    )
    return instance

  def _save_pretrained(self, save_directory: Union[str, Path]):
    """
    Saves the model's state dictionary to a safetensors file. This method
    is called by the `save_pretrained` method from `PyTorchModelHubMixin`.
    """
    if not os.path.exists(save_directory):
      os.makedirs(save_directory)

    weights_path = os.path.join(save_directory, self.WEIGHTS_FILENAME)
    save_file(self.model.state_dict(), weights_path)

  def compile(self, forecast_config: configs.ForecastConfig, **kwargs) -> None:
    """Attempts to compile the model for fast decoding.

    See configs.ForecastConfig for more details on the supported flags.

    Args:
      forecast_config: Configuration for forecasting flags.
      **kwargs: Additional keyword arguments to pass to model.compile().
    """
    self.global_batch_size = (
      forecast_config.per_core_batch_size * self.model.device_count
    )

    # Shortcut.
    fc = forecast_config

    if fc.max_context % self.model.p != 0:
      logging.info(
        "When compiling, max context needs to be multiple of the patch size"
        " %d. Using max context = %d instead.",
        self.model.p,
        new_context := math.ceil(fc.max_context / self.model.p) * self.model.p,
      )
      fc = dataclasses.replace(fc, max_context=new_context)
    if fc.max_horizon % self.model.o != 0:
      logging.info(
        "When compiling, max horizon needs to be multiple of the output patch"
        " size %d. Using max horizon = %d instead.",
        self.model.o,
        new_horizon := math.ceil(fc.max_horizon / self.model.o) * self.model.o,
      )
      fc = dataclasses.replace(fc, max_horizon=new_horizon)
    if fc.max_context + fc.max_horizon > self.model.config.context_limit:
      raise ValueError(
        "Context + horizon must be less than the context limit."
        f" {fc.max_context} + {fc.max_horizon} >"
        f" {self.model.config.context_limit}."
      )
    if fc.use_continuous_quantile_head and (fc.max_horizon > self.model.os):
      raise ValueError(
        f"Continuous quantile head is not supported for horizons > {self.model.os}."
      )
    self.forecast_config = fc

    def _compiled_decode(horizon, inputs, masks):
      if horizon > fc.max_horizon:
        raise ValueError(
          f"Horizon must be less than the max horizon. {horizon} > {fc.max_horizon}."
        )

      inputs = (
        torch.from_numpy(np.array(inputs)).to(self.model.device).to(torch.float32)
      )
      masks = torch.from_numpy(np.array(masks)).to(self.model.device).to(torch.bool)
      batch_size = inputs.shape[0]

      if fc.infer_is_positive:
        is_positive = torch.all(inputs >= 0, dim=-1, keepdim=True)
      else:
        is_positive = None

      if fc.normalize_inputs:
        mu = torch.mean(inputs, dim=-1, keepdim=True)
        sigma = torch.std(inputs, dim=-1, keepdim=True)
        inputs = revin(inputs, mu, sigma, reverse=False)
      else:
        mu, sigma = None, None

      pf_outputs, quantile_spreads, ar_outputs = self.model.decode(
        forecast_config.max_horizon, inputs, masks
      )
      to_cat = [pf_outputs[:, -1, ...]]
      if ar_outputs is not None:
        to_cat.append(ar_outputs.reshape(batch_size, -1, self.model.q))
      full_forecast = torch.cat(to_cat, dim=1)

      def flip_quantile_fn(x):
        return torch.cat([x[..., :1], torch.flip(x[..., 1:], dims=(-1,))], dim=-1)

      if fc.force_flip_invariance:
        flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
          self.model.decode(forecast_config.max_horizon, -inputs, masks)
        )
        flipped_quantile_spreads = flip_quantile_fn(flipped_quantile_spreads)
        flipped_pf_outputs = flip_quantile_fn(flipped_pf_outputs)
        to_cat = [flipped_pf_outputs[:, -1, ...]]
        if flipped_ar_outputs is not None:
          to_cat.append(flipped_ar_outputs.reshape(batch_size, -1, self.model.q))
        flipped_full_forecast = torch.cat(to_cat, dim=1)
        quantile_spreads = (quantile_spreads - flipped_quantile_spreads) / 2
        pf_outputs = (pf_outputs - flipped_pf_outputs) / 2
        full_forecast = (full_forecast - flipped_full_forecast) / 2

      if fc.use_continuous_quantile_head:
        for quantile_index in [1, 2, 3, 4, 6, 7, 8, 9]:
          full_forecast[:, :, quantile_index] = (
            quantile_spreads[:, : fc.max_horizon, quantile_index]
            - quantile_spreads[:, : fc.max_horizon, 5]
            + full_forecast[:, : fc.max_horizon, 5]
          )
      full_forecast = full_forecast[:, :horizon, :]

      if fc.return_backcast:
        full_backcast = pf_outputs[:, :-1, : self.model.p, :].reshape(
          batch_size, -1, self.model.q
        )
        full_forecast = torch.cat([full_backcast, full_forecast], dim=1)

      if fc.fix_quantile_crossing:
        for i in [4, 3, 2, 1]:
          full_forecast[:, :, i] = torch.where(
            full_forecast[:, :, i] < full_forecast[:, :, i + 1],
            full_forecast[:, :, i],
            full_forecast[:, :, i + 1],
          )
        for i in [6, 7, 8, 9]:
          full_forecast[:, :, i] = torch.where(
            full_forecast[:, :, i] > full_forecast[:, :, i - 1],
            full_forecast[:, :, i],
            full_forecast[:, :, i - 1],
          )

      if fc.normalize_inputs:
        full_forecast = revin(full_forecast, mu, sigma, reverse=True)

      if is_positive is not None:
        full_forecast = torch.where(
          is_positive[..., None],
          torch.maximum(full_forecast, torch.zeros_like(full_forecast)),
          full_forecast,
        )

      full_forecast = full_forecast.detach().cpu().numpy()
      return full_forecast[..., 5], full_forecast

    self.compiled_decode = _compiled_decode


================================================
FILE: src/timesfm/torch/__init__.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


================================================
FILE: src/timesfm/torch/dense.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dense layers for TimesFM."""

import torch
from torch import nn

from .. import configs


class ResidualBlock(nn.Module):
  """Residual block with two linear layers and a linear residual connection."""

  def __init__(self, config: configs.ResidualBlockConfig):
    super().__init__()
    self.config = config
    self.hidden_layer = nn.Linear(
        in_features=config.input_dims,
        out_features=config.hidden_dims,
        bias=config.use_bias,
    )
    self.output_layer = nn.Linear(
        in_features=config.hidden_dims,
        out_features=config.output_dims,
        bias=config.use_bias,
    )
    self.residual_layer = nn.Linear(
        in_features=config.input_dims,
        out_features=config.output_dims,
        bias=config.use_bias,
    )
    if config.activation == "relu":
      self.activation = nn.ReLU()
    elif config.activation == "swish":
      self.activation = nn.SiLU()
    elif config.activation == "none":
      self.activation = nn.Identity()
    else:
      raise ValueError(f"Activation: {config.activation} not supported.")

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.output_layer(
        self.activation(self.hidden_layer(x))
    ) + self.residual_layer(x)


class RandomFourierFeatures(nn.Module):
  """Random Fourier features layer."""

  def __init__(self, config: configs.RandomFourierFeaturesConfig):
    super().__init__()
    self.config = config

    if config.output_dims % 4 != 0:
      raise ValueError(
          f"Output dims must be a multiple of 4: {config.output_dims} % 4 != 0."
      )
    num_projected_features = config.output_dims // 4

    self.phase_shifts = nn.Parameter(torch.zeros(2, num_projected_features))
    self.projection_layer = nn.Linear(
        in_features=config.input_dims,
        out_features=num_projected_features,
        bias=config.use_bias,
    )
    self.residual_layer = nn.Linear(
        in_features=config.input_dims,
        out_features=config.output_dims,
        bias=config.use_bias,
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    projected = self.projection_layer(x)
    cos_features = torch.cos(projected)
    sin_features = torch.sin(projected)
    sq_wave_1 = torch.sign(torch.sin(projected + self.phase_shifts[0, :]))
    sq_wave_2 = torch.sign(torch.sin(projected + self.phase_shifts[1, :]))
    fourier_features = torch.cat(
        [cos_features, sin_features, sq_wave_1, sq_wave_2], dim=-1
    )
    residual = self.residual_layer(x)
    return fourier_features + residual


================================================
FILE: src/timesfm/torch/normalization.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Normalization layers for TimesFM."""

import torch
from torch import nn


class RMSNorm(nn.Module):
  """RMS normalization."""

  def __init__(
      self,
      num_features: int,
      *,
      epsilon: float = 1e-6,
  ):
    super().__init__()
    self.scale = nn.Parameter(torch.zeros(num_features))
    self.num_features = num_features
    self.epsilon = epsilon

  def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    var = torch.mean(torch.square(inputs), dim=-1, keepdim=True)
    normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
    normed_inputs = normed_inputs * self.scale
    return normed_inputs


================================================
FILE: src/timesfm/torch/transformer.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer layers for TimesFM."""

import math
from typing import Callable

import torch
import torch.nn.functional as F
from torch import nn

from .. import configs
from . import normalization, util

LayerNorm = nn.LayerNorm
RMSNorm = normalization.RMSNorm
DecodeCache = util.DecodeCache


def make_attn_mask(
  query_length: int,
  num_all_masked_kv: torch.Tensor,
  query_index_offset: torch.Tensor | None = None,
  kv_length: int = 0,
) -> torch.Tensor:
  """Makes attention mask."""
  if kv_length == 0:
    kv_length = query_length

  q_index = torch.arange(query_length, device=num_all_masked_kv.device)[
    None, None, :, None
  ]
  if query_index_offset is not None:
    q_index = q_index + query_index_offset[:, None, None, None]
  kv_index = torch.arange(kv_length, device=num_all_masked_kv.device)[
    None, None, None, :
  ]
  return torch.logical_and(
    q_index >= kv_index,
    kv_index >= num_all_masked_kv[:, None, None, None],
  )


class RotaryPositionalEmbedding(nn.Module):
  """Rotary positional embedding."""

  def __init__(
    self,
    embedding_dims: int,
    min_timescale: float = 1.0,
    max_timescale: float = 10000.0,
  ):
    super().__init__()
    self.embedding_dims = embedding_dims
    self.min_timescale = min_timescale
    self.max_timescale = max_timescale

  def forward(
    self,
    inputs: torch.Tensor,
    position: torch.Tensor | None = None,
  ):
    """Generates a JTensor of sinusoids with different frequencies."""
    if self.embedding_dims != inputs.shape[-1]:
      raise ValueError(
        "The embedding dims of the rotary position embedding"
        "must match the hidden dimension of the inputs."
      )
    half_embedding_dim = self.embedding_dims // 2
    fraction = (
      2
      * torch.arange(0, half_embedding_dim, device=inputs.device)
      / self.embedding_dims
    )
    timescale = (
      self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
    ).to(inputs.device)
    if position is None:
      seq_length = inputs.shape[1]
      position = torch.arange(seq_length, dtype=torch.float32, device=inputs.device)[
        None, :
      ]

    if len(inputs.shape) == 4:
      position = position[..., None, None]
      timescale = timescale[None, None, None, :]
    elif len(inputs.shape) == 3:
      position = position[..., None]
      timescale = timescale[None, None, :]
    else:
      raise ValueError("Inputs must be of rank 3 or 4.")

    sinusoid_inp = position / timescale
    sin = torch.sin(sinusoid_inp)
    cos = torch.cos(sinusoid_inp)
    first_half, second_half = torch.chunk(inputs, 2, dim=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    return torch.cat([first_part, second_part], dim=-1)


def _dot_product_attention(
  query,
  key,
  value,
  mask=None,
):
  """Computes dot-product attention given query, key, and value."""
  attn_weights = torch.einsum("...qhd,...khd->...hqk", query, key)
  if mask is not None:
    attn_weights = torch.where(
      mask, attn_weights, -torch.finfo(attn_weights.dtype).max / 2
    )

  attn_weights = F.softmax(attn_weights, dim=-1)

  return torch.einsum("...hqk,...khd->...qhd", attn_weights, value)


def _torch_dot_product_attention(query, key, value, mask=None):
  """
  Performs the exact same (unscaled) attention as the above function,
  but using the fast and fused F.scaled_dot_product_attention kernel.
  """

  # 1. Permute inputs from (B, L, H, D) to the expected (B, H, L, D)
  query = query.permute(0, 2, 1, 3)
  key = key.permute(0, 2, 1, 3)
  value = value.permute(0, 2, 1, 3)

  # 2. Call the fused attention kernel
  #    - Pass the mask to `attn_mask`.
  #    - Set `scale=1.0` to disable the default 1/sqrt(d_k) scaling.
  output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, scale=1.0)

  # 3. Permute the output back to the original (B, L, H, D) layout
  output = output.permute(0, 2, 1, 3)

  return output


class PerDimScale(nn.Module):
  """Per-dimension scaling."""

  def __init__(self, num_dims: int):
    super().__init__()
    self.num_dims = num_dims
    self.per_dim_scale = nn.Parameter(torch.zeros(num_dims))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    scale_factor = (
      1.442695041 / math.sqrt(self.num_dims) * F.softplus(self.per_dim_scale)
    )
    return x * scale_factor


class MultiHeadAttention(nn.Module):
  """Multi-head attention."""

  def __init__(
    self,
    num_heads: int,
    in_features: int,
    *,
    use_per_dim_scale: bool = True,
    use_rotary_position_embeddings: bool = True,
    use_bias: bool = False,
    attention_fn: Callable[..., torch.Tensor] = _torch_dot_product_attention,
    qk_norm: str = "rms",
    fuse_qkv: bool = False,
  ):
    super().__init__()
    self.num_heads = num_heads
    self.in_features = in_features
    self.head_dim = in_features // num_heads
    self.use_bias = use_bias
    self.attention_fn = attention_fn
    self.qk_norm = qk_norm
    self.fuse_qkv = fuse_qkv

    if self.in_features % self.num_heads != 0:
      raise ValueError(
        f"Memory dimension ({self.in_features}) must be divisible by "
        f"'num_heads' heads ({self.num_heads})."
      )

    if self.fuse_qkv:
      self.qkv_proj = nn.Linear(self.in_features, 3 * self.in_features, bias=use_bias)
    else:
      self.query = nn.Linear(self.in_features, self.in_features, bias=use_bias)
      self.key = nn.Linear(self.in_features, self.in_features, bias=use_bias)
      self.value = nn.Linear(self.in_features, self.in_features, bias=use_bias)
    self.out = nn.Linear(self.in_features, self.in_features, bias=use_bias)

    if self.qk_norm == "rms":
      self.query_ln = RMSNorm(self.head_dim)
      self.key_ln = RMSNorm(self.head_dim)
    else:
      self.query_ln = nn.Identity()
      self.key_ln = nn.Identity()

    self.use_rotary_position_embeddings = use_rotary_position_embeddings
    if self.use_rotary_position_embeddings:
      self.rotary_position_embedding = RotaryPositionalEmbedding(
        embedding_dims=self.head_dim,
      )

    self.use_per_dim_scale = use_per_dim_scale
    if use_per_dim_scale:
      self.per_dim_scale = PerDimScale(num_dims=self.head_dim)

  def forward(
    self,
    inputs_q: torch.Tensor,
    *,
    decode_cache: DecodeCache | None = None,
    patch_mask: torch.Tensor | None = None,
  ) -> tuple[torch.Tensor, DecodeCache | None]:
    b, n_patches, _ = inputs_q.shape
    if patch_mask is None:
      patch_mask = torch.zeros(b, n_patches, dtype=torch.bool, device=inputs_q.device)

    if self.fuse_qkv:
      qkv = self.qkv_proj(inputs_q)
      query, key, value = torch.chunk(qkv, 3, dim=-1)
      query = query.view(b, n_patches, self.num_heads, self.head_dim)
      key = key.view(b, n_patches, self.num_heads, self.head_dim)
      value = value.view(b, n_patches, self.num_heads, self.head_dim)
    else:
      query = self.query(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
      key = self.key(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
      value = self.value(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)

    if decode_cache is None:
      num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1)
      next_index = torch.zeros_like(num_masked, dtype=torch.int32)
    else:
      num_masked = (
        torch.sum(patch_mask.to(torch.int32), dim=-1) + decode_cache.num_masked
      )
      next_index = decode_cache.next_index.clone()

    if self.use_rotary_position_embeddings:
      position = (
        torch.arange(n_patches, device=inputs_q.device)[None, :]
        + next_index[:, None]
        - num_masked[:, None]
      )
      query = self.rotary_position_embedding(query, position)
      key = self.rotary_position_embedding(key, position)

    query = self.query_ln(query)
    key = self.key_ln(key)

    if self.use_per_dim_scale:
      query = self.per_dim_scale(query)

    if decode_cache is not None:
      _, decode_cache_size, _, _ = decode_cache.value.shape

      start = decode_cache.next_index[0]
      end = start + n_patches

      # Perform a single, vectorized slice assignment for the entire batch.
      # This is vastly more efficient than a Python for-loop.

      decode_cache.key[:, start:end] = key
      decode_cache.value[:, start:end] = value

      key = decode_cache.key
      value = decode_cache.value
      decode_cache.next_index += n_patches
      decode_cache.num_masked = num_masked
      attn_mask = make_attn_mask(
        query_length=n_patches,
        num_all_masked_kv=num_masked,
        query_index_offset=next_index,
        kv_length=decode_cache_size,
      )
    else:
      attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked)

    x = self.attention_fn(
      query,
      key,
      value,
      mask=attn_mask,
    )

    x = x.reshape(b, n_patches, self.in_features)
    out = self.out(x)
    return out, decode_cache


class Transformer(nn.Module):
  """Classic Transformer used in TimesFM."""

  def __init__(self, config: configs.TransformerConfig):
    super().__init__()
    self.config = config

    if config.attention_norm == "rms":
      self.pre_attn_ln = RMSNorm(num_features=config.model_dims)
      self.post_attn_ln = RMSNorm(num_features=config.model_dims)
    else:
      raise ValueError(f"Layer norm: {config.attention_norm} not supported.")

    self.attn = MultiHeadAttention(
      num_heads=config.num_heads,
      in_features=config.model_dims,
      use_per_dim_scale=True,
      use_rotary_position_embeddings=config.use_rotary_position_embeddings,
      qk_norm=config.qk_norm,
      fuse_qkv=config.fuse_qkv,
    )

    if config.feedforward_norm == "rms":
      self.pre_ff_ln = RMSNorm(num_features=config.model_dims)
      self.post_ff_ln = RMSNorm(num_features=config.model_dims)
    else:
      raise ValueError(f"Layer norm: {config.feedforward_norm} not supported.")

    self.ff0 = nn.Linear(
      in_features=config.model_dims,
      out_features=config.hidden_dims,
      bias=config.use_bias,
    )
    self.ff1 = nn.Linear(
      in_features=config.hidden_dims,
      out_features=config.model_dims,
      bias=config.use_bias,
    )
    if config.ff_activation == "relu":
      self.activation = nn.ReLU()
    elif config.ff_activation == "swish":
      self.activation = nn.SiLU()
    elif config.ff_activation == "none":
      self.activation = nn.Identity()
    else:
      raise ValueError(f"Activation: {config.ff_activation} not supported.")

  def forward(
    self,
    input_embeddings: torch.Tensor,
    patch_mask: torch.Tensor,
    decode_cache: DecodeCache | None = None,
  ) -> tuple[torch.Tensor, DecodeCache | None]:
    attn_output, decode_cache = self.attn(
      inputs_q=self.pre_attn_ln(input_embeddings),
      decode_cache=decode_cache,
      patch_mask=patch_mask,
    )
    attn_output = self.post_attn_ln(attn_output) + input_embeddings
    output_embeddings = (
      self.post_ff_ln(self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))))
      + attn_output
    )
    return output_embeddings, decode_cache


================================================
FILE: src/timesfm/torch/util.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch utility functions for TimesFM layers."""

import dataclasses
import torch

_TOLERANCE = 1e-6


@dataclasses.dataclass(frozen=False)
class DecodeCache:
  """Cache for decoding."""

  next_index: torch.Tensor
  num_masked: torch.Tensor
  key: torch.Tensor
  value: torch.Tensor


def update_running_stats(
    n: torch.Tensor,
    mu: torch.Tensor,
    sigma: torch.Tensor,
    x: torch.Tensor,
    mask: torch.Tensor,
) -> tuple[
    tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
  """Updates the running stats."""
  is_legit = torch.logical_not(mask)
  inc_n = torch.sum(is_legit.to(x.dtype), dim=-1)

  inc_mu_numerator = torch.sum(x * is_legit, dim=-1)
  inc_n_safe = torch.where(inc_n == 0, 1.0, inc_n)
  inc_mu = inc_mu_numerator / inc_n_safe
  inc_mu = torch.where(inc_n == 0, 0.0, inc_mu)

  inc_var_numerator = torch.sum(
      ((x - inc_mu.unsqueeze(-1)) ** 2) * is_legit, dim=-1
  )
  inc_var = inc_var_numerator / inc_n_safe
  inc_var = torch.where(inc_n == 0, 0.0, inc_var)
  inc_sigma = torch.sqrt(inc_var)

  new_n = n + inc_n
  new_n_safe = torch.where(new_n == 0, 1.0, new_n)

  new_mu = (n * mu + inc_mu * inc_n) / new_n_safe
  new_mu = torch.where(new_n == 0, 0.0, new_mu)

  term1 = n * sigma.pow(2)
  term2 = inc_n * inc_sigma.pow(2)
  term3 = n * (mu - new_mu).pow(2)
  term4 = inc_n * (inc_mu - new_mu).pow(2)

  new_var = (term1 + term2 + term3 + term4) / new_n_safe
  new_var = torch.where(new_n == 0, 0.0, new_var)
  new_sigma = torch.sqrt(torch.clamp(new_var, min=0.0))

  return (w := (new_n, new_mu, new_sigma), w)


def revin(
    x: torch.Tensor,
    mu: torch.Tensor,
    sigma: torch.Tensor,
    reverse: bool = False,
):
  """Reversible instance normalization."""
  if len(mu.shape) == len(x.shape) - 1:
    mu = mu[..., None]
    sigma = sigma[..., None]
  elif len(mu.shape) == len(x.shape) - 2:
    mu = mu[..., None, None]
    sigma = sigma[..., None, None]

  if reverse:
    return x * sigma + mu
  else:
    return (x - mu) / torch.where(sigma < _TOLERANCE, 1.0, sigma)


================================================
FILE: src/timesfm/utils/xreg_lib.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for in-context covariates and regression."""

import itertools
import math
from typing import Any, Iterable, Literal, Mapping, Sequence

try:
  import jax
  import jax.numpy as jnp
  import numpy as np
  from sklearn import preprocessing
except ImportError:
  raise ImportError(
    "Failed to load the XReg module. Did you forget to install `timesfm[xreg]`?"
  )

Category = int | str

_TOL = 1e-6
XRegMode = Literal["timesfm + xreg", "xreg + timesfm"]


def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:
  return np.array(list(itertools.chain.from_iterable(nested)))


def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:
  return np.array(
    list(itertools.chain.from_iterable(map(itertools.repeat, elements, counts)))
  )


def _to_padded_jax_array(x: np.ndarray) -> jax.Array:
  if x.ndim == 1:
    (i,) = x.shape
    di = 2 ** math.ceil(math.log2(i)) - i
    return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0)
  elif x.ndim == 2:
    i, j = x.shape
    di = 2 ** math.ceil(math.log2(i)) - i
    dj = 2 ** math.ceil(math.log2(j)) - j
    return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0)
  else:
    raise ValueError(f"Unsupported array shape: {x.shape}")


# Per time series normalization: forward.
def normalize(batch):
  stats = [(np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch]
  new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)]
  return new_batch, stats


# Per time series normalization: inverse.
def renormalize(batch, stats):
  return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)]


class BatchedInContextXRegBase:
  """Helper class for in-context regression covariate formatting.

  Attributes:
    targets: List of targets (responses) of the in-context regression.
    train_lens: List of lengths of each target vector from the context.
    test_lens: List of lengths of each forecast horizon.
    train_dynamic_numerical_covariates: Dict of covariate names mapping to the
      dynamic numerical covariates of each forecast task on the context. Their
      lengths should match the corresponding lengths in `train_lens`.
    train_dynamic_categorical_covariates: Dict of covariate names mapping to the
      dynamic categorical covariates of each forecast task on the context. Their
      lengths should match the corresponding lengths in `train_lens`.
    test_dynamic_numerical_covariates: Dict of covariate names mapping to the
      dynamic numerical covariates of each forecast task on the horizon. Their
      lengths should match the corresponding lengths in `test_lens`.
    test_dynamic_categorical_covariates: Dict of covariate names mapping to the
      dynamic categorical covariates of each forecast task on the horizon. Their
      lengths should match the corresponding lengths in `test_lens`.
    static_numerical_covariates: Dict of covariate names mapping to the static
      numerical covariates of each forecast task.
    static_categorical_covariates: Dict of covariate names mapping to the static
      categorical covariates of each forecast task.
  """

  def __init__(
    self,
    targets: Sequence[Sequence[float]],
    train_lens: Sequence[int],
    test_lens: Sequence[int],
    train_dynamic_numerical_covariates: (
      Mapping[str, Sequence[Sequence[float]]] | None
    ) = None,
    train_dynamic_categorical_covariates: (
      Mapping[str, Sequence[Sequence[Category]]] | None
    ) = None,
    test_dynamic_numerical_covariates: (
      Mapping[str, Sequence[Sequence[float]]] | None
    ) = None,
    test_dynamic_categorical_covariates: (
      Mapping[str, Sequence[Sequence[Category]]] | None
    ) = None,
    static_numerical_covariates: Mapping[str, Sequence[float]] | None = None,
    static_categorical_covariates: (Mapping[str, Sequence[Category]] | None) = None,
  ) -> None:
    """Initializes with the exogenous covariate inputs.

    Here we use model fitting language to refer to the context as 'train' and
    the horizon as 'test'. We assume batched inputs. To properly format the
    request:

     - `train_lens` represents the contexts in the batch. Targets and all train
     dynamic covariates should have the same lengths as the corresponding
     elements
     in `train_lens`. Notice each `train_len` can be different from the exact
     length of the corresponding context depending on how much of the context is
     used for fitting the in-context model.
     - `test_lens` represents the horizon lengths in the batch. All tesdt
     dynamic
     covariates should have the same lengths as the corresponding elements in
     `test_lens`.
     - Static covariates should be one for each input.
     - For train and test dynamic covariates, they should have the same
     covariate
     names.

     Pass an empty dict {} for a covariate type if it is not present.

     Example:
       Here is a set of valid inputs whose schema can be used for reference.
       ```
       targets = [
           [0.0, 0.1, 0.2],
           [0.0, 0.1, 0.2, 0.3],
       ]  # Two inputs in this batch.
       train_lens = [3, 4]
       test_lens = [2, 5]  # Forecast horizons 2 and 5 respectively.
       train_dynamic_numerical_covariates = {
           "cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]],
           "cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]],
       }  # Each train dynamic covariate has 3 and 4 elements respectively.
       test_dynamic_numerical_covariates = {
           "cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]],
           "cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]],
       }  # Each test dynamic covariate has 2 and 5 elements respectively.
       train_dynamic_categorical_covariates = {
           "cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]],
           "cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad",
           "bad"]],
       }
       test_dynamic_categorical_covariates = {
           "cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]],
           "cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]],
       }
       static_numerical_covariates = {
           "cov_1_sn": [0.0, 3.0],
           "cov_2_sn": [2.0, 1.0],
           "cov_3_sn": [1.0, 2.0],
       }  # Each static covariate has 1 element for each input.
       static_categorical_covariates = {
           "cov_1_sc": ["apple", "orange"],
           "cov_2_sc": [2, 3],
       }
       ```

    Args:
      targets: List of targets (responses) of the in-context regression.
      train_lens: List of lengths of each target vector from the context.
      test_lens: List of lengths of each forecast horizon.
      train_dynamic_numerical_covariates: Dict of covariate names mapping to the
        dynamic numerical covariates of each forecast task on the context. Their
        lengths should match the corresponding lengths in `train_lens`.
      train_dynamic_categorical_covariates: Dict of covariate names mapping to
        the dynamic categorical covariates of each forecast task on the context.
        Their lengths should match the corresponding lengths in `train_lens`.
      test_dynamic_numerical_covariates: Dict of covariate names mapping to the
        dynamic numerical covariates of each forecast task on the horizon. Their
        lengths should match the corresponding lengths in `test_lens`.
      test_dynamic_categorical_covariates: Dict of covariate names mapping to
        the dynamic categorical covariates of each forecast task on the horizon.
        Their lengths should match the corresponding lengths in `test_lens`.
      static_numerical_covariates: Dict of covariate names mapping to the static
        numerical covariates of each forecast task.
      static_categorical_covariates: Dict of covariate names mapping to the
        static categorical covariates of each forecast task.
    """
    self.targets = targets
    self.train_lens = train_lens
    self.test_lens = test_lens
    self.train_dynamic_numerical_covariates = train_dynamic_numerical_covariates or {}
    self.train_dynamic_categorical_covariates = (
      train_dynamic_categorical_covariates or {}
    )
    self.test_dynamic_numerical_covariates = test_dynamic_numerical_covariates or {}
    self.test_dynamic_categorical_covariates = test_dynamic_categorical_covariates or {}
    self.static_numerical_covariates = static_numerical_covariates or {}
    self.static_categorical_covariates = static_categorical_covariates or {}

  def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None:
    """Verifies the validity of the covariate inputs."""

    # Check presence.
    if (
      self.train_dynamic_numerical_covariates
      and not self.test_dynamic_numerical_covariates
    ) or (
      not self.train_dynamic_numerical_covariates
      and self.test_dynamic_numerical_covariates
    ):
      raise ValueError(
        "train_dynamic_numerical_covariates and"
        " test_dynamic_numerical_covariates must be both present or both"
        " absent."
      )

    if (
      self.train_dynamic_categorical_covariates
      and not self.test_dynamic_categorical_covariates
    ) or (
      not self.train_dynamic_categorical_covariates
      and self.test_dynamic_categorical_covariates
    ):
      raise ValueError(
        "train_dynamic_categorical_covariates and"
        " test_dynamic_categorical_covariates must be both present or both"
        " absent."
      )

    # Check keys.
    for dict_a, dict_b, dict_a_name, dict_b_name in (
      (
        self.train_dynamic_numerical_covariates,
        self.test_dynamic_numerical_covariates,
        "train_dynamic_numerical_covariates",
        "test_dynamic_numerical_covariates",
      ),
      (
        self.train_dynamic_categorical_covariates,
        self.test_dynamic_categorical_covariates,
        "train_dynamic_categorical_covariates",
        "test_dynamic_categorical_covariates",
      ),
    ):
      if w := set(dict_a.keys()) - set(dict_b.keys()):
        raise ValueError(f"{dict_a_name} has keys not present in {dict_b_name}: {w}")
      if w := set(dict_b.keys()) - set(dict_a.keys()):
        raise ValueError(f"{dict_b_name} has keys not present in {dict_a_name}: {w}")

    # Check shapes.
    if assert_covariate_shapes:
      if len(self.targets) != len(self.train_lens):
        raise ValueError(
          "targets and train_lens must have the same number of elements."
        )

      if len(self.train_lens) != len(self.test_lens):
        raise ValueError(
          "train_lens and test_lens must have the same number of elements."
        )

      for i, (target, train_len) in enumerate(zip(self.targets, self.train_lens)):
        if len(target) != train_len:
          raise ValueError(
            f"targets[{i}] has length {len(target)} != expected {train_len}."
          )

      for key, values in self.static_numerical_covariates.items():
        if len(values) != len(self.train_lens):
          raise ValueError(
            f"static_numerical_covariates has key {key} with number of"
            f" examples {len(values)} != expected {len(self.train_lens)}."
          )

      for key, values in self.static_categorical_covariates.items():
        if len(values) != len(self.train_lens):
          raise ValueError(
            f"static_categorical_covariates has key {key} with number of"
            f" examples {len(values)} != expected {len(self.train_lens)}."
          )

      for lens, dict_cov, dict_cov_name in (
        (
          self.train_lens,
          self.train_dynamic_numerical_covariates,
          "train_dynamic_numerical_covariates",
        ),
        (
          self.train_lens,
          self.train_dynamic_categorical_covariates,
          "train_dynamic_categorical_covariates",
        ),
        (
          self.test_lens,
          self.test_dynamic_numerical_covariates,
          "test_dynamic_numerical_covariates",
        ),
        (
          self.test_lens,
          self.test_dynamic_categorical_covariates,
          "test_dynamic_categorical_covariates",
        ),
      ):
        for key, cov_values in dict_cov.items():
          if len(cov_values) != len(lens):
            raise ValueError(
              f"{dict_cov_name} has key {key} with number of examples"
              f" {len(cov_values)} != expected {len(lens)}."
            )
          for i, cov_value in enumerate(cov_values):
            if len(cov_value) != lens[i]:
              raise ValueError(
                f"{dict_cov_name} has key {key} with its {i}-th example"
                f" length {len(cov_value)} != expected {lens[i]}."
              )

  def create_covariate_matrix(
    self,
    one_hot_encoder_drop: str | None = "first",
    use_intercept: bool = True,
    assert_covariates: bool = False,
    assert_covariate_shapes: bool = False,
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Creates target vector and covariate matrices for in context regression.

    Here we use model fitting language to refer to the context as 'train' and
    the horizon as 'test'.

    Args:
      one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.
      use_intercept: Whether to prepare an intercept (all 1) column in the
        matrices.
      assert_covariates: Whether to assert the validity of the covariate inputs.
      assert_covariate_shapes: Whether to assert the shapes of the covariate
        inputs when `assert_covariates` is True.

    Returns:
      A tuple of the target vector, the covariate matrix for the context, and
      the covariate matrix for the horizon.
    """
    if assert_covariates:
      self._assert_covariates(assert_covariate_shapes)

    x_train, x_test = [], []

    # Numerical features.
    for name in sorted(self.train_dynamic_numerical_covariates):
      x_train.append(
        _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis]
      )
      x_test.append(
        _unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis]
      )

    for covs in self.static_numerical_covariates.values():
      x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis])
      x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis])

    if x_train:
      x_train = np.concatenate(x_train, axis=1)
      x_test = np.concatenate(x_test, axis=1)

      # Normalize for robustness.
      x_mean = np.mean(x_train, axis=0, keepdims=True)
      x_std = np.where((w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w, 1.0)
      x_train = [(x_train - x_mean) / x_std]
      x_test = [(x_test - x_mean) / x_std]

    # Categorical features. Encode one by one.
    one_hot_encoder = preprocessing.OneHotEncoder(
      drop=one_hot_encoder_drop,
      sparse_output=False,
      handle_unknown="ignore",
    )
    for name in sorted(self.train_dynamic_categorical_covariates.keys()):
      ohe_train = _unnest(self.train_dynamic_categorical_covariates[name])[
        :, np.newaxis
      ]
      ohe_test = _unnest(self.test_dynamic_categorical_covariates[name])[:, np.newaxis]
      x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train)))
      x_test.append(np.array(one_hot_encoder.transform(ohe_test)))

    for covs in self.static_categorical_covariates.values():
      ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis])
      x_train.append(_repeat(ohe, self.train_lens))
      x_test.append(_repeat(ohe, self.test_lens))

    x_train = np.concatenate(x_train, axis=1)
    x_test = np.concatenate(x_test, axis=1)

    if use_intercept:
      x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0)
      x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0)

    return _unnest(self.targets), x_train, x_test

  def fit(self) -> Any:
    raise NotImplementedError("Fit is not implemented.")


class BatchedInContextXRegLinear(BatchedInContextXRegBase):
  """Linear in-context regression model."""

  def fit(
    self,
    ridge: float = 0.0,
    one_hot_encoder_drop: str | None = "first",
    use_intercept: bool = True,
    force_on_cpu: bool = False,
    max_rows_per_col: int = 0,
    max_rows_per_col_sample_seed: int = 42,
    debug_info: bool = False,
    assert_covariates: bool = False,
    assert_covariate_shapes: bool = False,
  ) -> (
    list[np.ndarray]
    | tuple[list[np.ndarray], list[np.ndarray], jax.Array, jax.Array, jax.Array]
  ):
    """Fits a linear model for in-context regression.

    Args:
      ridge: A non-negative value for specifying the ridge regression penalty.
        If 0 is provided, fallback to ordinary least squares. Note this penalty
        is added to the normalized covariate matrix.
      one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.
      use_intercept: Whether to prepare an intercept (all 1) column in the
        matrices.
      force_on_cpu: Whether to force execution on cpu for accelerator machines.
      max_rows_per_col: How many rows to subsample per column. 0 for no
        subsampling. This is for speeding up model fitting.
      max_rows_per_col_sample_seed: The seed for the subsampling if needed by
        `max_rows_per_col`.
      debug_info: Whether to return debug info.
      assert_covariates: Whether to assert the validity of the covariate inputs.
      assert_covariate_shapes: Whether to assert the shapes of the covariate
        inputs when `assert_covariates` is True.

    Returns:
      If `debug_info` is False:
        The linear fits on the horizon.
      If `debug_info` is True:
        A tuple of:
        - the linear fits on the horizon,
        - the linear fits on the context,
        - the flattened target vector,
        - the covariate matrix for the context, and
        - the covariate matrix for the horizon.
    """
    flat_targets, x_train_raw, x_test = self.create_covariate_matrix(
      one_hot_encoder_drop=one_hot_encoder_drop,
      use_intercept=use_intercept,
      assert_covariates=assert_covariates,
      assert_covariate_shapes=assert_covariate_shapes,
    )

    x_train = x_train_raw.copy()
    if max_rows_per_col:
      nrows, ncols = x_train.shape
      if nrows > (w := ncols * max_rows_per_col):
        subsample = jax.random.choice(
          jax.random.PRNGKey(max_rows_per_col_sample_seed),
          nrows,
          (w,),
          replace=False,
        )
        x_train = x_train[subsample]
        flat_targets = flat_targets[subsample]

    device = jax.devices("cpu")[0] if force_on_cpu else None
    # Runs jitted version of the solvers which are quicker at the cost of
    # running jitting during the first time calling. Re-jitting happens whenever
    # new (padded) shapes are encountered.
    # Ocassionally it helps with the speed and the accuracy if we force single
    # thread execution on cpu for accelerator machines:
    # 1. Avoid moving data to accelarator memory.
    # 2. Avoid precision loss if any.
    with jax.default_device(device):
      x_train_raw = _to_padded_jax_array(x_train_raw)
      x_train = _to_padded_jax_array(x_train)
      flat_targets = _to_padded_jax_array(flat_targets)
      x_test = _to_padded_jax_array(x_test)
      beta_hat = (
        jnp.linalg.pinv(
          x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]),
          hermitian=True,
        )
        @ x_train.T
        @ flat_targets
      )
      y_hat = x_test @ beta_hat
      y_hat_context = x_train_raw @ beta_hat if debug_info else None

    outputs = []
    outputs_context = []

    # Reconstruct the ragged 2-dim batched forecasts from flattened linear fits.
    train_index, test_index = 0, 0
    for train_index_delta, test_index_delta in zip(self.train_lens, self.test_lens):
      outputs.append(np.array(y_hat[test_index : (test_index + test_index_delta)]))
      if debug_info:
        outputs_context.append(
          np.array(y_hat_context[train_index : (train_index + train_index_delta)])
        )
      train_index += train_index_delta
      test_index += test_index_delta

    if debug_info:
      return outputs, outputs_context, flat_targets, x_train, x_test
    else:
      return outputs


================================================
FILE: timesfm-forecasting/SKILL.md
================================================
---
name: timesfm-forecasting
description: >
  Zero-shot time series forecasting with Google's TimesFM foundation model. Use this
  skill when forecasting ANY univariate time series — sales, sensor readings, stock prices,
  energy demand, patient vitals, weather, or scientific measurements — without training a
  custom model. Supports both basic forecasting and advanced covariate forecasting (XReg)
  with dynamic and static exogenous variables. Automatically checks system RAM/GPU before
  loading the model, validates dataset fit before processing, supports CSV/DataFrame/array
  inputs, and returns point forecasts with calibrated prediction intervals. Includes a
  preflight system checker script that MUST be run before first use to verify the machine
  can load the model and handle your specific dataset.
license: Apache-2.0
metadata:
  author: Clayton Young (@borealBytes)
  version: "1.0.0"
---

# TimesFM Forecasting

## Overview

TimesFM (Time Series Foundation Model) is a pretrained decoder-only foundation model
developed by Google Research for time-series forecasting. It works **zero-shot** — feed it
any univariate time series and it returns point forecasts with calibrated quantile
prediction intervals, no training required.

This skill includes a **mandatory preflight system checker** that verifies RAM, GPU memory,
and disk space before the model is ever loaded so the agent never crashes the user's machine.

> **Key numbers**: TimesFM 2.5 uses 200M parameters (~800 MB on disk, ~1.5 GB in RAM on
> CPU, ~1 GB VRAM on GPU). The archived v1/v2 500M-parameter model needs ~32 GB RAM.
> Always run the system checker first.

## When to Use This Skill

Use this skill when:

- Forecasting **any univariate time series** (sales, demand, sensor, vitals, price, weather)
- You need **zero-shot forecasting** without training a custom model
- You want **probabilistic forecasts** with calibrated prediction intervals (quantiles)
- You have time series of **any length** (the model handles 1–16,384 context points)
- You need to **batch-forecast** hundreds or thousands of series efficiently
- You want a **foundation model** approach instead of hand-tuning ARIMA/ETS parameters
- You need **covariate forecasting** with exogenous variables (price, promotions, holidays, day-of-week effects) → use `forecast_with_covariates()` (TimesFM 2.5 + `pip install timesfm[xreg]`)


Do **not** use this skill when:

- You need classical statistical models with coefficient interpretation → use `statsmodels`
- You need time series classification or clustering → use `aeon`
- You need multivariate vector autoregression or Granger causality → use `statsmodels`
- Your data is tabular (not temporal) → use `scikit-learn`
- You cannot install optional dependencies → XReg requires scikit-learn and JAX


> **Note on Anomaly Detection**: TimesFM does not have built-in anomaly detection, but you
> can use the **quantile forecasts as prediction intervals** — values outside the 90% CI
> (q10–q90) are statistically unusual. See `examples/anomaly-detection/` for a full example.

## ⚠️ Mandatory Preflight: System Requirements Check

**CRITICAL — ALWAYS run the system checker before loading the model for the first time.**

```bash
python scripts/check_system.py
```

This script checks:

1. **Available RAM** — warns if below 4 GB, blocks if below 2 GB
2. **GPU availability** — detects CUDA/MPS devices and VRAM
3. **Disk space** — verifies room for the ~800 MB model download
4. **Python version** — requires 3.10+
5. **Existing installation** — checks if `timesfm` and `torch` are installed

> **Note:** Model weights are **NOT stored in this repository**. TimesFM weights (~800 MB)
> download on-demand from HuggingFace on first use and cache in `~/.cache/huggingface/`.

```mermaid
flowchart TD
    start["🚀 Run check_system.py"] --> ram{"RAM ≥ 4 GB?"}
    ram -->|"Yes"| gpu{"GPU available?"}
    ram -->|"No (2-4 GB)"| warn_ram["⚠️ Warning: tight RAM<br/>CPU-only, small batches"]
    ram -->|"No (< 2 GB)"| block["🛑 BLOCKED<br/>Insufficient memory"]
    warn_ram --> disk
    gpu -->|"CUDA / MPS"| vram{"VRAM ≥ 2 GB?"}
    gpu -->|"CPU only"| cpu_ok["✅ CPU mode<br/>Slower but works"]
    vram -->|"Yes"| gpu_ok["✅ GPU mode<br/>Fast inference"]
    vram -->|"No"| cpu_ok
    gpu_ok --> disk{"Disk ≥ 2 GB free?"}
    cpu_ok --> disk
    disk -->|"Yes"| ready["✅ READY<br/>Safe to load model"]
    disk -->|"No"| block_disk["🛑 BLOCKED<br/>Need space for weights"]
```

### Dataset Preflight (NEW)

Before loading your actual data, verify it will fit in memory:

```bash
# Quick estimate for your dataset
python scripts/check_system.py \
  --num-series 1000 \
  --context-length 1024 \
  --horizon 24 \
  --batch-size 32 \
  --estimate-only
```

This will show you the estimated memory requirements and warn if your dataset is too large.

**Memory Estimation Formula**:
`RAM ≈ 0.8 GB (model) + 0.5 GB (overhead) + (0.2 MB × num_series × context_length / 1000)`

**Example Outputs**:

✅ **Dataset Fits**:
```
Total CPU memory: 2.34 GB
Total GPU memory: 2.15 GB
```

⚠️ **Dataset Too Large**:
```
Dataset requires ~12.5 GB RAM but system has 8.0 GB.
Try: context_length=512 or process in chunks of 50 series.
```

### Hardware Requirements by Model Version

| Model | Parameters | RAM (CPU) | VRAM (GPU) | Disk | Context |
| ----- | ---------- | --------- | ---------- | ---- | ------- |
| **TimesFM 2.5** (recommended) | 200M | ≥ 4 GB | ≥ 2 GB | ~800 MB | up to 16,384 |
| TimesFM 2.0 (archived) | 500M | ≥ 16 GB | ≥ 8 GB | ~2 GB | up to 2,048 |
| TimesFM 1.0 (archived) | 200M | ≥ 8 GB | ≥ 4 GB | ~800 MB | up to 2,048 |

> **Recommendation**: Always use TimesFM 2.5 unless you have a specific reason to use an
> older checkpoint. It is smaller, faster, and supports 8× longer context.

## 🔧 Installation

### Step 1: Verify System (always first)

```bash
python scripts/check_system.py
```

### Step 2: Install TimesFM

```bash
# Using uv (fast)
uv pip install timesfm[torch]

# Or using pip
pip install timesfm[torch]

# For JAX/Flax backend (faster on TPU/GPU)
uv pip install timesfm[flax]
```

### Step 3: Install PyTorch for Your Hardware

```bash
# CUDA 12.1 (NVIDIA GPU)
pip install torch>=2.0.0 --index-url https://download.pytorch.org/whl/cu121

# CPU only
pip install torch>=2.0.0 --index-url https://download.pytorch.org/whl/cpu

# Apple Silicon (MPS)
pip install torch>=2.0.0  # MPS support is built-in
```

## 🎯 Quick Start

### Minimal Example

```python
import torch, numpy as np, timesfm

torch.set_float32_matmul_precision("high")

model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
    "google/timesfm-2.5-200m-pytorch"
)
model.compile(timesfm.ForecastConfig(
    max_context=1024, max_horizon=256, normalize_inputs=True,
    use_continuous_quantile_head=True, force_flip_invariance=True,
    infer_is_positive=True, fix_quantile_crossing=True,
))

point, quantiles = model.forecast(horizon=24, inputs=[
    np.sin(np.linspace(0, 20, 200)),  # any 1-D array
])
# point.shape == (1, 24)         — median forecast
# quantiles.shape == (1, 24, 10) — 10th–90th percentile bands
```

### Forecast with Covariates (XReg)

TimesFM 2.5+ supports exogenous variables through `forecast_with_covariates()`.
Requires `pip install timesfm[xreg]`.

```python
point, quantiles = model.forecast_with_covariates(
    inputs=inputs,
    dynamic_numerical_covariates={"price": price_arrays},
    dynamic_categorical_covariates={"holiday": holiday_arrays},
    static_categorical_covariates={"region": region_labels},
    xreg_mode="xreg + timesfm",  # or "timesfm + xreg"
)
```

### Anomaly Detection (via Quantile Intervals)

```python
point, q = model.forecast(horizon=H, inputs=[values])

lower_90 = q[0, :, 1]  # 10th percentile
upper_90 = q[0, :, 9]  # 90th percentile

actual = test_values
anomalies = (actual < lower_90) | (actual > upper_90)
```

| Severity | Condition | Interpretation |
| -------- | --------- | -------------- |
| **Normal** | Inside 80% CI | Expected behavior |
| **Warning** | Outside 80% CI | Unusual but possible |
| **Critical** | Outside 90% CI | Statistically rare (< 10% probability) |

> See `examples/anomaly-detection/` for a complete worked example with visualization.

## 📊 Understanding the Output

TimesFM returns `(point_forecast, quantile_forecast)`:

- **`point_forecast`**: shape `(batch, horizon)` — the median (0.5 quantile)
- **`quantile_forecast`**: shape `(batch, horizon, 10)` — ten quantile slices:

| Index | Quantile | Use |
| ----- | -------- | --- |
| 0 | Mean | Average prediction |
| 1 | 0.1 | Lower bound of 80% PI |
| 2 | 0.2 | Lower bound of 60% PI |
| **5** | **0.5** | **Median (= `point_forecast`)** |
| 8 | 0.8 | Upper bound of 60% PI |
| 9 | 0.9 | Upper bound of 80% PI |

```python
point, q = model.forecast(horizon=H, inputs=data)

lower_80 = q[:, :, 1]  # 10th percentile
upper_80 = q[:, :, 9]  # 90th percentile
median   = q[:, :, 5]
```

## 🔧 ForecastConfig Reference

All forecasting behavior is controlled by `timesfm.ForecastConfig`:

```python
timesfm.ForecastConfig(
    max_context=1024,                    # Max context window
    max_horizon=256,                     # Max forecast horizon
    normalize_inputs=True,               # RECOMMENDED — prevents scale instability
    per_core_batch_size=32,              # Tune for memory
    use_continuous_quantile_head=True,   # Better quantile accuracy for long horizons
    force_flip_invariance=True,          # Ensures f(-x) = -f(x)
    infer_is_positive=True,              # Clamp forecasts ≥ 0 when all inputs > 0
    fix_quantile_crossing=True,          # Ensure q10 ≤ q20 ≤ ... ≤ q90
    return_backcast=False,               # Return backcast (for covariate workflows)
)
```

| Parameter | Default | When to Change |
| --------- | ------- | -------------- |
| `max_context` | 0 | Set to match your longest historical window |
| `normalize_inputs` | False | **Always set True** |
| `use_continuous_quantile_head` | False | **Set True** for calibrated PIs |
| `infer_is_positive` | True | Set False for series that can be negative |
| `fix_quantile_crossing` | False | **Set True** for monotonic quantiles |

See `references/api_reference.md` for the complete parameter reference.

## 📋 Common Workflows

### Single Series Forecast

```python
import torch, numpy as np, pandas as pd, timesfm, matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

torch.set_float32_matmul_precision("high")
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
    "google/timesfm-2.5-200m-pytorch"
)
model.compile(timesfm.ForecastConfig(
    max_context=512, max_horizon=52, normalize_inputs=True,
    use_continuous_quantile_head=True, fix_quantile_crossing=True,
))

df = pd.read_csv("weekly_demand.csv", parse_dates=["week"])
values = df["demand"].values.astype(np.float32)

point, quantiles = model.forecast(horizon=52, inputs=[values])

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(values[-104:], label="Historical")
x_fc = range(len(values[-104:]), len(values[-104:]) + 52)
ax.plot(x_fc, point[0], label="Forecast", color="tab:orange")
ax.fill_between(x_fc, quantiles[0, :, 1], quantiles[0, :, 9],
                alpha=0.2, color="tab:orange", label="80% PI")
ax.legend(); ax.set_title("52-Week Demand Forecast")
plt.tight_layout(); plt.savefig("forecast.png", dpi=150)
```

### Batch Forecasting (Many Series)

```python
df = pd.read_csv("all_stores.csv", parse_dates=["date"], index_col="date")
inputs = [df[col].dropna().values.astype(np.float32) for col in df.columns]

point, quantiles = model.forecast(horizon=30, inputs=inputs)

import json
results = {col: {"forecast": point[i].tolist(),
                 "lower_80": quantiles[i, :, 1].tolist(),
                 "upper_80": quantiles[i, :, 9].tolist()}
           for i, col in enumerate(df.columns)}
with open("batch_forecasts.json", "w") as f:
    json.dump(results, f, indent=2)
```

### Evaluate Forecast Accuracy

```python
H = 24
train, actual = values[:-H], values[-H:]
point, quantiles = model.forecast(horizon=H, inputs=[train])
pred = point[0]

mae  = np.mean(np.abs(actual - pred))
rmse = np.sqrt(np.mean((actual - pred) ** 2))
mape = np.mean(np.abs((actual - pred) / actual)) * 100
coverage = np.mean((actual >= quantiles[0, :, 1]) & (actual <= quantiles[0, :, 9])) * 100

print(f"MAE: {mae:.2f} | RMSE: {rmse:.2f} | MAPE: {mape:.1f}% | 80% PI Coverage: {coverage:.1f}%")
```

## ⚙️ Performance Tuning

```python
# Always set on Ampere+ GPUs (A100, RTX 3090+)
torch.set_float32_matmul_precision("high")

# Batch size guidelines:
# GPU 8 GB VRAM:  per_core_batch_size=64
# GPU 16 GB VRAM: per_core_batch_size=128
# CPU 8 GB RAM:   per_core_batch_size=8
# CPU 16 GB RAM:  per_core_batch_size=32

# Memory-constrained: process in chunks
CHUNK = 50
results = []
for i in range(0, len(inputs), CHUNK):
    p, q = model.forecast(horizon=H, inputs=inputs[i:i+CHUNK])
    results.append((p, q))
```

## 📚 Available Scripts

### `scripts/check_system.py`

Mandatory preflight checker — run before first model load.
Now includes **dataset-aware memory estimation** to prevent OOM errors before loading your data.

```bash
# Basic system check
python scripts/check_system.py

# Check if your specific dataset will fit
python scripts/check_system.py \
  --num-series 1000 \
  --context-length 1024 \
  --horizon 24 \
  --batch-size 32

# Quick memory estimate without system checks
python scripts/check_system.py \
  --num-series 5000 \
  --context-length 2048 \
  --estimate-only
```

**What it checks**:

1. **Available RAM** — warns if below 4 GB, blocks if below 2 GB
2. **GPU availability** — detects CUDA/MPS devices and VRAM
3. **Disk space** — verifies room for the ~800 MB model download
4. **Python version** — requires 3.10+
5. **Existing installation** — checks if `timesfm` and `torch` are installed
6. **Dataset fit** (NEW) — estimates memory for your specific dataset and warns if it won't fit

### `scripts/forecast_csv.py`

End-to-end CSV forecasting CLI.

```bash
python scripts/forecast_csv.py input.csv \
    --horizon 24 \
    --date-col date \
    --value-cols sales,revenue \
    --output forecasts.csv
```

## 📖 Reference Documentation

| File | Contents |
| ---- | -------- |
| `references/system_requirements.md` | Hardware tiers, GPU/CPU selection, memory estimation |
| `references/api_reference.md` | Full `ForecastConfig` docs, output shapes, model options |
| `references/data_preparation.md` | Input formats, NaN handling, CSV loading, covariate setup |

## 🧪 Examples

| Example | Directory | What It Demonstrates |
| ------- | --------- | -------------------- |
| **Global Temperature Forecast** | `examples/global-temperature/` | Basic `model.forecast()`, CSV → PNG → GIF pipeline |
| **Anomaly Detection** | `examples/anomaly-detection/` | Two-phase detrend + Z-score + quantile PI, 2-panel viz |
| **Covariates (XReg)** | `examples/covariates-forecasting/` | `forecast_with_covariates()`, 2×2 shared-axis viz |

```bash
# Run all three examples:
cd examples/global-temperature && python run_forecast.py && python visualize_forecast.py
cd examples/anomaly-detection  && python detect_anomalies.py
cd examples/covariates-forecasting && python demo_covariates.py
```

### Expected Outputs

| Example | Key output files | Acceptance criteria |
| ------- | ---------------- | ------------------- |
| global-temperature | `output/forecast_output.json`, `output/forecast_visualization.png` | `point_forecast` has 12 values; PNG shows context + forecast + PI bands |
| anomaly-detection | `output/anomaly_detection.json`, `output/anomaly_detection.png` | Sep 2023 flagged CRITICAL (z ≥ 3.0) |
| covariates-forecasting | `output/sales_with_covariates.csv`, `output/covariates_data.png` | 108 rows (3 stores × 36 weeks); distinct price arrays per store |

## Model Versions

| Version | Params | Context | Status | HuggingFace checkpoint |
| ------- | ------ | ------- | ------ | ---------------------- |
| **2.5** | 200M | 16,384 | **Latest** | `google/timesfm-2.5-200m-pytorch` |
| 2.0 | 500M | 2,048 | Archived | `google/timesfm-2.0-500m-pytorch` |
| 1.0 | 200M | 2,048 | Archived | `google/timesfm-1.0-200m-pytorch` |

- TimesFM 1.0/2.0: must pass `freq=[0]` for monthly data
- TimesFM 2.5: no frequency flag — it was removed

## Resources

- **Paper**: [A Decoder-Only Foundation Model for Time-Series Forecasting](https://arxiv.org/abs/2310.10688) (ICML 2024)
- **HuggingFace**: https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6
- **Google Blog**: https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/
- **BigQuery Integration**: https://cloud.google.com/bigquery/docs/timesfm-model

## Quality Checklist

Run after every TimesFM task before declaring success:

- [ ] **Output shape** — `point_fc` is `(n_series, horizon)`, `quant_fc` is `(n_series, horizon, 10)`
- [ ] **Quantile indices** — index 0 = mean, 1 = q10 ... 9 = q90. NOT 0 = q0.
- [ ] **Frequency flag** — TimesFM 1.0/2.0: pass `freq=[0]` for monthly. TimesFM 2.5: omit.
- [ ] **Series length** — context must be ≥ 32 data points.
- [ ] **No NaN** — `np.isnan(point_fc).any()` must be False.
- [ ] **Axes** — multiple panels sharing data must use `sharex=True`.
- [ ] **`matplotlib.use('Agg')`** — before any pyplot import when running headless.
- [ ] **`infer_is_positive`** — set False for temperature, financial returns, negatives.

## Common Mistakes

1. **Quantile index off-by-one** — `quant_fc[..., 0]` is the **mean**, not q0. q10 = index 1, q90 = index 9. Define: `IDX_Q10, IDX_Q90 = 1, 9`.

2. **Variable shadowing in covariate loops** — don't use the outer loop variable as a comprehension variable when building per-series covariate dicts.

3. **Wrong CSV column name** — global-temperature CSV uses `anomaly_c`, not `anomaly`. Print `df.columns` first.

4. **TimesFM 2.5 required for `forecast_with_covariates()`** — TimesFM 1.0 does NOT have this method.

5. **Future covariates must span the full horizon** — dynamic covariates need values for BOTH context AND forecast windows.

6. **Context anomaly detection uses residuals** — detrend first, then Z-score. Raw Z-scores mislead on trending data.

## Validation & Verification

```bash
# Anomaly detection regression:
python -c "
import json
d = json.load(open('examples/anomaly-detection/output/anomaly_detection.json'))
assert d['context_summary']['critical'] >= 1, 'Sep 2023 must be CRITICAL'
print('Anomaly detection: PASS')"

# Covariates regression:
python -c "
import pandas as pd
df = pd.read_csv('examples/covariates-forecasting/output/sales_with_covariates.csv')
assert len(df) == 108, f'Expected 108 rows, got {len(df)}'
print('Covariates: PASS')"
```


================================================
FILE: timesfm-forecasting/examples/anomaly-detection/detect_anomalies.py
================================================
#!/usr/bin/env python3
"""
TimesFM Anomaly Detection Example — Two-Phase Method

Phase 1 (context): Linear detrend + Z-score on 36 months of real NOAA
  temperature anomaly data (2022-01 through 2024-12).
  Sep 2023 (1.47 C) is a known critical outlier.

Phase 2 (forecast): TimesFM quantile prediction intervals on a 12-month
  synthetic future with 3 injected anomalies.

Outputs:
  output/anomaly_detection.png  -- 2-panel visualization
  output/anomaly_detection.json -- structured detection records
"""

from __future__ import annotations

import json
from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

HORIZON = 12
DATA_FILE = (
    Path(__file__).parent.parent / "global-temperature" / "temperature_anomaly.csv"
)
OUTPUT_DIR = Path(__file__).parent / "output"

CRITICAL_Z = 3.0
WARNING_Z = 2.0

# quant_fc index mapping: 0=mean, 1=q10, 2=q20, ..., 9=q90
IDX_Q10, IDX_Q20, IDX_Q80, IDX_Q90 = 1, 2, 8, 9

CLR = {"CRITICAL": "#e02020", "WARNING": "#f08030", "NORMAL": "#4a90d9"}


# ---------------------------------------------------------------------------
# Phase 1: context anomaly detection
# ---------------------------------------------------------------------------


def detect_context_anomalies(
    values: np.ndarray,
    dates: list,
) -> tuple[list[dict], np.ndarray, np.ndarray, float]:
    """Linear detrend + Z-score anomaly detection on context period.

    Returns
    -------
    records    : list of dicts, one per month
    trend_line : fitted linear trend values (same length as values)
    residuals  : actual - trend_line
    res_std    : std of residuals (used as sigma for threshold bands)
    """
    n = len(values)
    idx = np.arange(n, dtype=float)

    coeffs = np.polyfit(idx, values, 1)
    trend_line = np.polyval(coeffs, idx)
    residuals = values - trend_line
    res_std = residuals.std()

    records = []
    for i, (d, v, r) in enumerate(zip(dates, values, residuals)):
        z = r / res_std if res_std > 0 else 0.0
        if abs(z) >= CRITICAL_Z:
            severity = "CRITICAL"
        elif abs(z) >= WARNING_Z:
            severity = "WARNING"
        else:
            severity = "NORMAL"
        records.append(
            {
                "date": str(d)[:7],
                "value": round(float(v), 4),
                "trend": round(float(trend_line[i]), 4),
                "residual": round(float(r), 4),
                "z_score": round(float(z), 3),
                "severity": severity,
            }
        )
    return records, trend_line, residuals, res_std


# ---------------------------------------------------------------------------
# Phase 2: synthetic future + forecast anomaly detection
# ---------------------------------------------------------------------------


def build_synthetic_future(
    context: np.ndarray,
    n: int,
    seed: int = 42,
) -> tuple[np.ndarray, list[int]]:
    """Build a plausible future with 3 injected anomalies.

    Injected months: 3, 8, 11 (0-indexed within the 12-month horizon).
    Returns (future_values, injected_indices).
    """
    rng = np.random.default_rng(seed)
    trend = np.linspace(context[-6:].mean(), context[-6:].mean() + 0.05, n)
    noise = rng.normal(0, 0.1, n)
    future = trend + noise

    injected = [3, 8, 11]
    future[3] += 0.7  # CRITICAL spike
    future[8] -= 0.65  # CRITICAL dip
    future[11] += 0.45  # WARNING spike

    return future.astype(np.float32), injected


def detect_forecast_anomalies(
    future_values: np.ndarray,
    point: np.ndarray,
    quant_fc: np.ndarray,
    future_dates: list,
    injected_at: list[int],
) -> list[dict]:
    """Classify each forecast month by which PI band it falls outside.

    CRITICAL = outside 80% PI (q10-q90)
    WARNING  = outside 60% PI (q20-q80) but inside 80% PI
    NORMAL   = inside 60% PI
    """
    q10 = quant_fc[IDX_Q10]
    q20 = quant_fc[IDX_Q20]
    q80 = quant_fc[IDX_Q80]
    q90 = quant_fc[IDX_Q90]

    records = []
    for i, (d, fv, pt) in enumerate(zip(future_dates, future_values, point)):
        outside_80 = fv < q10[i] or fv > q90[i]
        outside_60 = fv < q20[i] or fv > q80[i]

        if outside_80:
            severity = "CRITICAL"
        elif outside_60:
            severity = "WARNING"
        else:
            severity = "NORMAL"

        records.append(
            {
                "date": str(d)[:7],
                "actual": round(float(fv), 4),
                "forecast": round(float(pt), 4),
                "q10": round(float(q10[i]), 4),
                "q20": round(float(q20[i]), 4),
                "q80": round(float(q80[i]), 4),
                "q90": round(float(q90[i]), 4),
                "severity": severity,
                "was_injected": i in injected_at,
            }
        )
    return records


# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------


def plot_results(
    context_dates: list,
    context_values: np.ndarray,
    ctx_records: list[dict],
    trend_line: np.ndarray,
    residuals: np.ndarray,
    res_std: float,
    future_dates: list,
    future_values: np.ndarray,
    point_fc: np.ndarray,
    quant_fc: np.ndarray,
    fc_records: list[dict],
) -> None:
    OUTPUT_DIR.mkdir(exist_ok=True)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10), gridspec_kw={"hspace": 0.42})
    fig.suptitle(
        "TimesFM Anomaly Detection — Two-Phase Method", fontsize=14, fontweight="bold"
    )

    # -----------------------------------------------------------------------
    # Panel 1 — full timeline
    # -----------------------------------------------------------------------
    ctx_x = [pd.Timestamp(d) for d in context_dates]
    fut_x = [pd.Timestamp(d) for d in future_dates]
    divider = ctx_x[-1]

    # context: blue line + trend + 2sigma band
    ax1.plot(
        ctx_x,
        context_values,
        color=CLR["NORMAL"],
        lw=2,
        marker="o",
        ms=4,
        label="Observed (context)",
    )
    ax1.plot(ctx_x, trend_line, color="#aaaaaa", lw=1.5, ls="--", label="Linear trend")
    ax1.fill_between(
        ctx_x,
        trend_line - 2 * res_std,
        trend_line + 2 * res_std,
        alpha=0.15,
        color=CLR["NORMAL"],
        label="+/-2sigma band",
    )

    # context anomaly markers
    seen_ctx: set[str] = set()
    for rec in ctx_records:
        if rec["severity"] == "NORMAL":
            continue
        d = pd.Timestamp(rec["date"])
        v = rec["value"]
        sev = rec["severity"]
        lbl = f"Context {sev}" if sev not in seen_ctx else None
        seen_ctx.add(sev)
        ax1.scatter(d, v, marker="D", s=90, color=CLR[sev], zorder=6, label=lbl)
        ax1.annotate(
            f"z={rec['z_score']:+.1f}",
            (d, v),
            textcoords="offset points",
            xytext=(0, 9),
            fontsize=7.5,
            ha="center",
            color=CLR[sev],
        )

    # forecast section
    q10 = quant_fc[IDX_Q10]
    q20 = quant_fc[IDX_Q20]
    q80 = quant_fc[IDX_Q80]
    q90 = quant_fc[IDX_Q90]

    ax1.plot(fut_x, future_values, "k--", lw=1.5, label="Synthetic future (truth)")
    ax1.plot(
        fut_x,
        point_fc,
        color=CLR["CRITICAL"],
        lw=2,
        marker="s",
        ms=4,
        label="TimesFM point forecast",
    )
    ax1.fill_between(fut_x, q10, q90, alpha=0.15, color=CLR["CRITICAL"], label="80% PI")
    ax1.fill_between(fut_x, q20, q80, alpha=0.25, color=CLR["CRITICAL"], label="60% PI")

    seen_fc: set[str] = set()
    for i, rec in enumerate(fc_records):
        if rec["severity"] == "NORMAL":
            continue
        d = pd.Timestamp(rec["date"])
        v = rec["actual"]
        sev = rec["severity"]
        mk = "X" if sev == "CRITICAL" else "^"
        lbl = f"Forecast {sev}" if sev not in seen_fc else None
        seen_fc.add(sev)
        ax1.scatter(d, v, marker=mk, s=100, color=CLR[sev], zorder=6, label=lbl)

    ax1.axvline(divider, color="#555555", lw=1.5, ls=":")
    ax1.text(
        divider,
        ax1.get_ylim()[1] if ax1.get_ylim()[1] != 0 else 1.5,
        "  <- Context | Forecast ->",
        fontsize=8.5,
        color="#555555",
        style="italic",
        va="top",
    )

    ax1.annotate(
        "Context: D = Z-score anomaly | Forecast: X = CRITICAL, ^ = WARNING",
        xy=(0.01, 0.04),
        xycoords="axes fraction",
        fontsize=8,
        bbox=dict(boxstyle="round", fc="white", ec="#cccccc", alpha=0.9),
    )

    ax1.set_ylabel("Temperature Anomaly (C)", fontsize=10)
    ax1.legend(ncol=2, fontsize=7.5, loc="upper left")
    ax1.grid(True, alpha=0.22)

    # -----------------------------------------------------------------------
    # Panel 2 — deviation bars across all 48 months
    # -----------------------------------------------------------------------
    all_labels: list[str] = []
    bar_colors: list[str] = []
    bar_heights: list[float] = []

    for rec in ctx_records:
        all_labels.append(rec["date"])
        bar_heights.append(rec["residual"])
        bar_colors.append(CLR[rec["severity"]])

    fc_deviations: list[float] = []
    for rec in fc_records:
        all_labels.append(rec["date"])
        dev = rec["actual"] - rec["forecast"]
        fc_deviations.append(dev)
        bar_heights.append(dev)
        bar_colors.append(CLR[rec["severity"]])

    xs = np.arange(len(all_labels))
    ax2.bar(xs[:36], bar_heights[:36], color=bar_colors[:36], alpha=0.8)
    ax2.bar(xs[36:], bar_heights[36:], color=bar_colors[36:], alpha=0.8)

    # threshold lines for context section only
    ax2.hlines(
        [2 * res_std, -2 * res_std], -0.5, 35.5, colors=CLR["NORMAL"], lw=1.2, ls="--"
    )
    ax2.hlines(
        [3 * res_std, -3 * res_std], -0.5, 35.5, colors=CLR["NORMAL"], lw=1.0, ls=":"
    )

    # PI bands for forecast section
    fc_xs = xs[36:]
    ax2.fill_between(
        fc_xs,
        q10 - point_fc,
        q90 - point_fc,
        alpha=0.12,
        color=CLR["CRITICAL"],
        step="mid",
    )
    ax2.fill_between(
        fc_xs,
        q20 - point_fc,
        q80 - point_fc,
        alpha=0.20,
        color=CLR["CRITICAL"],
        step="mid",
    )

    ax2.axvline(35.5, color="#555555", lw=1.5, ls="--")
    ax2.axhline(0, color="black", lw=0.8, alpha=0.6)

    ax2.text(
        10,
        ax2.get_ylim()[0] * 0.85 if ax2.get_ylim()[0] < 0 else -0.05,
        "<- Context: delta from linear trend",
        fontsize=8,
        style="italic",
        color="#555555",
        ha="center",
    )
    ax2.text(
        41,
        ax2.get_ylim()[0] * 0.85 if ax2.get_ylim()[0] < 0 else -0.05,
        "Forecast: delta from TimesFM ->",
        fontsize=8,
        style="italic",
        color="#555555",
        ha="center",
    )

    tick_every = 3
    ax2.set_xticks(xs[::tick_every])
    ax2.set_xticklabels(all_labels[::tick_every], rotation=45, ha="right", fontsize=7)
    ax2.set_ylabel("Delta from expected (C)", fontsize=10)
    ax2.grid(True, alpha=0.22, axis="y")

    legend_patches = [
        mpatches.Patch(color=CLR["CRITICAL"], label="CRITICAL"),
        mpatches.Patch(color=CLR["WARNING"], label="WARNING"),
        mpatches.Patch(color=CLR["NORMAL"], label="Normal"),
    ]
    ax2.legend(handles=legend_patches, fontsize=8, loc="upper right")

    output_path = OUTPUT_DIR / "anomaly_detection.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"\n  Saved: {output_path}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main() -> None:
    print("=" * 68)
    print("  TIMESFM ANOMALY DETECTION — TWO-PHASE METHOD")
    print("=" * 68)

    # --- Load context data ---------------------------------------------------
    df = pd.read_csv(DATA_FILE)
    df["date"] = pd.to_datetime(df["date"])
    df = df.sort_values("date").reset_index(drop=True)

    context_values = df["anomaly_c"].values.astype(np.float32)
    context_dates = [pd.Timestamp(d) for d in df["date"].tolist()]
    start_str = context_dates[0].strftime('%Y-%m') if not pd.isnull(context_dates[0]) else '?'
    end_str   = context_dates[-1].strftime('%Y-%m') if not pd.isnull(context_dates[-1]) else '?'
    print(f"\n  Context: {len(context_values)} months  ({start_str} - {end_str})")

    # --- Phase 1: context anomaly detection ----------------------------------
    ctx_records, trend_line, residuals, res_std = detect_context_anomalies(
        context_values, context_dates
    )
    ctx_critical = [r for r in ctx_records if r["severity"] == "CRITICAL"]
    ctx_warning = [r for r in ctx_records if r["severity"] == "WARNING"]
    print(f"\n  [Phase 1] Context anomalies (Z-score, sigma={res_std:.3f} C):")
    print(f"    CRITICAL (|Z|>={CRITICAL_Z}): {len(ctx_critical)}")
    for r in ctx_critical:
        print(f"      {r['date']}  {r['value']:+.3f} C  z={r['z_score']:+.2f}")
    print(f"    WARNING  (|Z|>={WARNING_Z}): {len(ctx_warning)}")
    for r in ctx_warning:
        print(f"      {r['date']}  {r['value']:+.3f} C  z={r['z_score']:+.2f}")

    # --- Load TimesFM --------------------------------------------------------
    print("\n  Loading TimesFM 1.0 ...")
    import timesfm

    hparams = timesfm.TimesFmHparams(horizon_len=HORIZON)
    checkpoint = timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-1.0-200m-pytorch"
    )
    model = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)

    point_out, quant_out = model.forecast([context_values], freq=[0])
    point_fc = point_out[0]  # shape (HORIZON,)
    quant_fc = quant_out[0].T  # shape (10, HORIZON)

    # --- Build synthetic future + Phase 2 detection --------------------------
    future_values, injected = build_synthetic_future(context_values, HORIZON)
    last_date = context_dates[-1]
    future_dates = [last_date + pd.DateOffset(months=i + 1) for i in range(HORIZON)]

    fc_records = detect_forecast_anomalies(
        future_values, point_fc, quant_fc, future_dates, injected
    )
    fc_critical = [r for r in fc_records if r["severity"] == "CRITICAL"]
    fc_warning = [r for r in fc_records if r["severity"] == "WARNING"]

    print(f"\n  [Phase 2] Forecast anomalies (quantile PI, horizon={HORIZON} months):")
    print(f"    CRITICAL (outside 80% PI): {len(fc_critical)}")
    for r in fc_critical:
        print(
            f"      {r['date']}  actual={r['actual']:+.3f}  "
            f"fc={r['forecast']:+.3f}  injected={r['was_injected']}"
        )
    print(f"    WARNING  (outside 60% PI): {len(fc_warning)}")
    for r in fc_warning:
        print(
            f"      {r['date']}  actual={r['actual']:+.3f}  "
            f"fc={r['forecast']:+.3f}  injected={r['was_injected']}"
        )

    # --- Plot ----------------------------------------------------------------
    print("\n  Generating 2-panel visualization...")
    plot_results(
        context_dates,
        context_values,
        ctx_records,
        trend_line,
        residuals,
        res_std,
        future_dates,
        future_values,
        point_fc,
        quant_fc,
        fc_records,
    )

    # --- Save JSON -----------------------------------------------------------
    OUTPUT_DIR.mkdir(exist_ok=True)
    out = {
        "method": "two_phase",
        "context_method": "linear_detrend_zscore",
        "forecast_method": "quantile_prediction_intervals",
        "thresholds": {
            "critical_z": CRITICAL_Z,
            "warning_z": WARNING_Z,
            "pi_critical_pct": 80,
            "pi_warning_pct": 60,
        },
        "context_summary": {
            "total": len(ctx_records),
            "critical": len(ctx_critical),
            "warning": len(ctx_warning),
            "normal": len([r for r in ctx_records if r["severity"] == "NORMAL"]),
            "res_std": round(float(res_std), 5),
        },
        "forecast_summary": {
            "total": len(fc_records),
            "critical": len(fc_critical),
            "warning": len(fc_warning),
            "normal": len([r for r in fc_records if r["severity"] == "NORMAL"]),
        },
        "context_detections": ctx_records,
        "forecast_detections": fc_records,
    }
    json_path = OUTPUT_DIR / "anomaly_detection.json"
    with open(json_path, "w") as f:
        json.dump(out, f, indent=2)
    print(f"  Saved: {json_path}")

    print("\n" + "=" * 68)
    print("  SUMMARY")
    print("=" * 68)
    print(
        f"  Context  ({len(ctx_records)} months): "
        f"{len(ctx_critical)} CRITICAL, {len(ctx_warning)} WARNING"
    )
    print(
        f"  Forecast ({len(fc_records)} months): "
        f"{len(fc_critical)} CRITICAL, {len(fc_warning)} WARNING"
    )
    print("=" * 68)


if __name__ == "__main__":
    main()


================================================
FILE: timesfm-forecasting/examples/anomaly-detection/output/anomaly_detection.json
================================================
{
  "method": "two_phase",
  "context_method": "linear_detrend_zscore",
  "forecast_method": "quantile_prediction_intervals",
  "thresholds": {
    "critical_z": 3.0,
    "warning_z": 2.0,
    "pi_critical_pct": 80,
    "pi_warning_pct": 60
  },
  "context_summary": {
    "total": 36,
    "critical": 1,
    "warning": 0,
    "normal": 35,
    "res_std": 0.11362
  },
  "forecast_summary": {
    "total": 12,
    "critical": 4,
    "warning": 1,
    "normal": 7
  },
  "context_detections": [
    {
      "date": "2022-01",
      "value": 0.89,
      "trend": 0.837,
      "residual": 0.053,
      "z_score": 0.467,
      "severity": "NORMAL"
    },
    {
      "date": "2022-02",
      "value": 0.89,
      "trend": 0.8514,
      "residual": 0.0386,
      "z_score": 0.34,
      "severity": "NORMAL"
    },
    {
      "date": "2022-03",
      "value": 1.02,
      "trend": 0.8658,
      "residual": 0.1542,
      "z_score": 1.357,
      "severity": "NORMAL"
    },
    {
      "date": "2022-04",
      "value": 0.88,
      "trend": 0.8803,
      "residual": -0.0003,
      "z_score": -0.002,
      "severity": "NORMAL"
    },
    {
      "date": "2022-05",
      "value": 0.85,
      "trend": 0.8947,
      "residual": -0.0447,
      "z_score": -0.394,
      "severity": "NORMAL"
    },
    {
      "date": "2022-06",
      "value": 0.88,
      "trend": 0.9092,
      "residual": -0.0292,
      "z_score": -0.257,
      "severity": "NORMAL"
    },
    {
      "date": "2022-07",
      "value": 0.88,
      "trend": 0.9236,
      "residual": -0.0436,
      "z_score": -0.384,
      "severity": "NORMAL"
    },
    {
      "date": "2022-08",
      "value": 0.9,
      "trend": 0.9381,
      "residual": -0.0381,
      "z_score": -0.335,
      "severity": "NORMAL"
    },
    {
      "date": "2022-09",
      "value": 0.88,
      "trend": 0.9525,
      "residual": -0.0725,
      "z_score": -0.638,
      "severity": "NORMAL"
    },
    {
      "date": "2022-10",
      "value": 0.95,
      "trend": 0.9669,
      "residual": -0.0169,
      "z_score": -0.149,
      "severity": "NORMAL"
    },
    {
      "date": "2022-11",
      "value": 0.77,
      "trend": 0.9814,
      "residual": -0.2114,
      "z_score": -1.86,
      "severity": "NORMAL"
    },
    {
      "date": "2022-12",
      "value": 0.78,
      "trend": 0.9958,
      "residual": -0.2158,
      "z_score": -1.9,
      "severity": "NORMAL"
    },
    {
      "date": "2023-01",
      "value": 0.87,
      "trend": 1.0103,
      "residual": -0.1403,
      "z_score": -1.235,
      "severity": "NORMAL"
    },
    {
      "date": "2023-02",
      "value": 0.98,
      "trend": 1.0247,
      "residual": -0.0447,
      "z_score": -0.394,
      "severity": "NORMAL"
    },
    {
      "date": "2023-03",
      "value": 1.21,
      "trend": 1.0392,
      "residual": 0.1708,
      "z_score": 1.503,
      "severity": "NORMAL"
    },
    {
      "date": "2023-04",
      "value": 1.0,
      "trend": 1.0536,
      "residual": -0.0536,
      "z_score": -0.472,
      "severity": "NORMAL"
    },
    {
      "date": "2023-05",
      "value": 0.94,
      "trend": 1.0681,
      "residual": -0.1281,
      "z_score": -1.127,
      "severity": "NORMAL"
    },
    {
      "date": "2023-06",
      "value": 1.08,
      "trend": 1.0825,
      "residual": -0.0025,
      "z_score": -0.022,
      "severity": "NORMAL"
    },
    {
      "date": "2023-07",
      "value": 1.18,
      "trend": 1.0969,
      "residual": 0.0831,
      "z_score": 0.731,
      "severity": "NORMAL"
    },
    {
      "date": "2023-08",
      "value": 1.24,
      "trend": 1.1114,
      "residual": 0.1286,
      "z_score": 1.132,
      "severity": "NORMAL"
    },
    {
      "date": "2023-09",
      "value": 1.47,
      "trend": 1.1258,
      "residual": 0.3442,
      "z_score": 3.029,
      "severity": "CRITICAL"
    },
    {
      "date": "2023-10",
      "value": 1.32,
      "trend": 1.1403,
      "residual": 0.1797,
      "z_score": 1.582,
      "severity": "NORMAL"
    },
    {
      "date": "2023-11",
      "value": 1.18,
      "trend": 1.1547,
      "residual": 0.0253,
      "z_score": 0.222,
      "severity": "NORMAL"
    },
    {
      "date": "2023-12",
      "value": 1.16,
      "trend": 1.1692,
      "residual": -0.0092,
      "z_score": -0.081,
      "severity": "NORMAL"
    },
    {
      "date": "2024-01",
      "value": 1.22,
      "trend": 1.1836,
      "residual": 0.0364,
      "z_score": 0.32,
      "severity": "NORMAL"
    },
    {
      "date": "2024-02",
      "value": 1.35,
      "trend": 1.1981,
      "residual": 0.1519,
      "z_score": 1.337,
      "severity": "NORMAL"
    },
    {
      "date": "2024-03",
      "value": 1.34,
      "trend": 1.2125,
      "residual": 0.1275,
      "z_score": 1.122,
      "severity": "NORMAL"
    },
    {
      "date": "2024-04",
      "value": 1.26,
      "trend": 1.2269,
      "residual": 0.0331,
      "z_score": 0.291,
      "severity": "NORMAL"
    },
    {
      "date": "2024-05",
      "value": 1.15,
      "trend": 1.2414,
      "residual": -0.0914,
      "z_score": -0.804,
      "severity": "NORMAL"
    },
    {
      "date": "2024-06",
      "value": 1.2,
      "trend": 1.2558,
      "residual": -0.0558,
      "z_score": -0.491,
      "severity": "NORMAL"
    },
    {
      "date": "2024-07",
      "value": 1.24,
      "trend": 1.2703,
      "residual": -0.0303,
      "z_score": -0.266,
      "severity": "NORMAL"
    },
    {
      "date": "2024-08",
      "value": 1.3,
      "trend": 1.2847,
      "residual": 0.0153,
      "z_score": 0.135,
      "severity": "NORMAL"
    },
    {
      "date": "2024-09",
      "value": 1.28,
      "trend": 1.2992,
      "residual": -0.0192,
      "z_score": -0.169,
      "severity": "NORMAL"
    },
    {
      "date": "2024-10",
      "value": 1.27,
      "trend": 1.3136,
      "residual": -0.0436,
      "z_score": -0.384,
      "severity": "NORMAL"
    },
    {
      "date": "2024-11",
      "value": 1.22,
      "trend": 1.328,
      "residual": -0.108,
      "z_score": -0.951,
      "severity": "NORMAL"
    },
    {
      "date": "2024-12",
      "value": 1.2,
      "trend": 1.3425,
      "residual": -0.1425,
      "z_score": -1.254,
      "severity": "NORMAL"
    }
  ],
  "forecast_detections": [
    {
      "date": "2025-01",
      "actual": 1.2821,
      "forecast": 1.2593,
      "q10": 1.1407,
      "q20": 1.1881,
      "q80": 1.324,
      "q90": 1.3679,
      "severity": "NORMAL",
      "was_injected": false
    },
    {
      "date": "2025-02",
      "actual": 1.1522,
      "forecast": 1.2857,
      "q10": 1.1406,
      "q20": 1.1961,
      "q80": 1.3751,
      "q90": 1.4254,
      "severity": "WARNING",
      "was_injected": false
    },
    {
      "date": "2025-03",
      "actual": 1.3358,
      "forecast": 1.295,
      "q10": 1.1269,
      "q20": 1.1876,
      "q80": 1.4035,
      "q90": 1.4643,
      "severity": "NORMAL",
      "was_injected": false
    },
    {
      "date": "2025-04",
      "actual": 2.0594,
      "forecast": 1.2208,
      "q10": 1.0353,
      "q20": 1.1042,
      "q80": 1.331,
      "q90": 1.4017,
      "severity": "CRITICAL",
      "was_injected": true
    },
    {
      "date": "2025-05",
      "actual": 1.0747,
      "forecast": 1.1703,
      "q10": 0.9691,
      "q20": 1.0431,
      "q80": 1.2892,
      "q90": 1.3632,
      "severity": "NORMAL",
      "was_injected": false
    },
    {
      "date": "2025-06",
      "actual": 1.1442,
      "forecast": 1.1456,
      "q10": 0.942,
      "q20": 1.0111,
      "q80": 1.2703,
      "q90": 1.3454,
      "severity": "NORMAL",
      "was_injected": false
    },
    {
      "date": "2025-07",
      "actual": 1.2917,
      "forecast": 1.1702,
      "q10": 0.9504,
      "q20": 1.0348,
      "q80": 1.2998,
      "q90": 1.3807,
      "severity": "NORMAL",
      "was_injected": false
    },
    {
      "date": "2025-08",
      "actual": 1.2519,
      "forecast": 1.2027,
      "q10": 0.9709,
      "q20": 1.0594,
      "q80": 1.3408,
      "q90": 1.4195,
      "severity": "NORMAL",
      "was_injected": false
    },
    {
      "date": "2025-09",
      "actual": 0.6364,
      "forecast": 1.191,
      "q10": 0.9594,
      "q20": 1.0404,
      "q80": 1.3355,
      "q90": 1.417,
      "severity": "CRITICAL",
      "was_injected": true
    },
    {
      "date": "2025-10",
      "actual": 1.2073,
      "forecast": 1.1491,
      "q10": 0.9079,
      "q20": 0.9953,
      "q80": 1.2869,
      "q90": 1.3775,
      "severity": "NORMAL",
      "was_injected": false
    },
    {
      "date": "2025-11",
      "actual": 1.3851,
      "forecast": 1.0805,
      "q10": 0.8361,
      "q20": 0.926,
      "q80": 1.2284,
      "q90": 1.3122,
      "severity": "CRITICAL",
      "was_injected": false
    },
    {
      "date": "2025-12",
      "actual": 1.8294,
      "forecast": 1.0613,
      "q10": 0.8022,
      "q20": 0.8952,
      "q80": 1.2169,
      "q90": 1.296,
      "severity": "CRITICAL",
      "was_injected": true
    }
  ]
}

================================================
FILE: timesfm-forecasting/examples/covariates-forecasting/demo_covariates.py
================================================
#!/usr/bin/env python3
"""
TimesFM Covariates (XReg) Example

Demonstrates the TimesFM covariate API using synthetic retail sales data.
TimesFM 1.0 does NOT support forecast_with_covariates(); that requires
TimesFM 2.5 + `pip install timesfm[xreg]`.

This script:
  1. Generates synthetic 3-store weekly retail data (24-week context, 12-week horizon)
  2. Produces a 2x2 visualization showing WHAT each covariate contributes
     and WHY knowing them improves forecasts -- all panels share the same
     week x-axis (0 = first context week, 35 = last horizon week)
  3. Exports a compact CSV (108 rows) and metadata JSON

NOTE ON REAL DATA:
  If you want to use a real retail dataset (e.g., Kaggle Rossmann Store Sales),
  download it to a TEMP location -- do NOT commit large CSVs to this repo.

      import tempfile, urllib.request
      tmp = tempfile.mkdtemp(prefix="timesfm_retail_")
      # urllib.request.urlretrieve("https://...store_sales.csv", f"{tmp}/store_sales.csv")
      # df = pd.read_csv(f"{tmp}/store_sales.csv")

  This skills directory intentionally keeps only tiny reference datasets.
"""

from __future__ import annotations

import json
from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

EXAMPLE_DIR = Path(__file__).parent
OUTPUT_DIR = EXAMPLE_DIR / "output"

N_STORES = 3
CONTEXT_LEN = 24
HORIZON_LEN = 12
TOTAL_LEN = CONTEXT_LEN + HORIZON_LEN  # 36


def generate_sales_data() -> dict:
    """Generate synthetic retail sales data with covariate components stored separately.

    Returns a dict with:
      stores:     {store_id: {sales, config}}
      covariates: {price, promotion, holiday, day_of_week, store_type, region}
      components: {store_id: {base, price_effect, promo_effect, holiday_effect}}

    Components let us show 'what would sales look like without covariates?' --
    the gap between 'base' and 'sales' IS the covariate signal.

    BUG FIX v3: Previous versions had variable-shadowing where inner dict
    comprehension `{store_id: ... for store_id in stores}` overwrote the outer
    loop variable causing all stores to get identical covariate arrays.
    Fixed by accumulating per-store arrays separately before building covariate dict.
    """
    rng = np.random.default_rng(42)

    stores = {
        "store_A": {"type": "premium", "region": "urban", "base_sales": 1000},
        "store_B": {"type": "standard", "region": "suburban", "base_sales": 750},
        "store_C": {"type": "discount", "region": "rural", "base_sales": 500},
    }
    base_prices = {"store_A": 12.0, "store_B": 10.0, "store_C": 7.5}

    data: dict = {"stores": {}, "covariates": {}, "components": {}}

    prices_by_store: dict[str, np.ndarray] = {}
    promos_by_store: dict[str, np.ndarray] = {}
    holidays_by_store: dict[str, np.ndarray] = {}
    dow_by_store: dict[str, np.ndarray] = {}

    for store_id, config in stores.items():
        bp = base_prices[store_id]
        weeks = np.arange(TOTAL_LEN)

        trend = config["base_sales"] * (1 + 0.005 * weeks)
        seasonality = 80 * np.sin(2 * np.pi * weeks / 52)
        noise = rng.normal(0, 40, TOTAL_LEN)
        base = (trend + seasonality + noise).astype(np.float32)

        price = (bp + rng.uniform(-0.5, 0.5, TOTAL_LEN)).astype(np.float32)
        price_effect = (-20 * (price - bp)).astype(np.float32)

        holidays = np.zeros(TOTAL_LEN, dtype=np.float32)
        for hw in [0, 11, 23, 35]:
            if hw < TOTAL_LEN:
                holidays[hw] = 1.0
        holiday_effect = (200 * holidays).astype(np.float32)

        promotion = rng.choice([0.0, 1.0], TOTAL_LEN, p=[0.8, 0.2]).astype(np.float32)
        promo_effect = (150 * promotion).astype(np.float32)

        day_of_week = np.tile(np.arange(7), TOTAL_LEN // 7 + 1)[:TOTAL_LEN].astype(
            np.int32
        )

        sales = np.maximum(base + price_effect + holiday_effect + promo_effect, 50.0)

        data["stores"][store_id] = {"sales": sales, "config": config}
        data["components"][store_id] = {
            "base": base,
            "price_effect": price_effect,
            "promo_effect": promo_effect,
            "holiday_effect": holiday_effect,
        }

        prices_by_store[store_id] = price
        promos_by_store[store_id] = promotion
        holidays_by_store[store_id] = holidays
        dow_by_store[store_id] = day_of_week

    data["covariates"] = {
        "price": prices_by_store,
        "promotion": promos_by_store,
        "holiday": holidays_by_store,
        "day_of_week": dow_by_store,
        "store_type": {sid: stores[sid]["type"] for sid in stores},
        "region": {sid: stores[sid]["region"] for sid in stores},
    }
    return data


def create_visualization(data: dict) -> None:
    """
    2x2 figure -- ALL panels share x-axis = weeks 0-35.

    (0,0) Sales by store -- context solid, horizon dashed
    (0,1) Store A: actual vs baseline (no covariates), with event overlays showing uplift
    (1,0) Price covariate for all stores -- full 36 weeks including horizon
    (1,1) Covariate effect decomposition for Store A (stacked fill_between)

    Each panel has a conclusion annotation box explaining what the data shows.
    """
    OUTPUT_DIR.mkdir(exist_ok=True)

    store_colors = {"store_A": "#1a56db", "store_B": "#057a55", "store_C": "#c03221"}
    weeks = np.arange(TOTAL_LEN)

    fig, axes = plt.subplots(
        2,
        2,
        figsize=(16, 11),
        sharex=True,
        gridspec_kw={"hspace": 0.42, "wspace": 0.32},
    )
    fig.suptitle(
        "TimesFM Covariates (XReg) -- Retail Sales with Exogenous Variables\n"
        "Shared x-axis: Week 0-23 = context (observed) | Week 24-35 = forecast horizon",
        fontsize=13,
        fontweight="bold",
        y=1.01,
    )

    def add_divider(ax, label_top=True):
        ax.axvline(CONTEXT_LEN - 0.5, color="#9ca3af", lw=1.3, ls="--", alpha=0.8)
        ax.axvspan(
            CONTEXT_LEN - 0.5, TOTAL_LEN - 0.5, alpha=0.06, color="grey", zorder=0
        )
        if label_top:
            ax.text(
                CONTEXT_LEN + 0.3,
                1.01,
                "<- horizon ->",
                transform=ax.get_xaxis_transform(),
                fontsize=7.5,
                color="#6b7280",
                style="italic",
            )

    # -- (0,0): Sales by Store ---------------------------------------------------
    ax = axes[0, 0]
    base_price_labels = {"store_A": "$12", "store_B": "$10", "store_C": "$7.50"}
    for sid, store_data in data["stores"].items():
        sales = store_data["sales"]
        c = store_colors[sid]
        lbl = f"{sid} ({store_data['config']['type']}, {base_price_labels[sid]} base)"
        ax.plot(
            weeks[:CONTEXT_LEN],
            sales[:CONTEXT_LEN],
            color=c,
            lw=2,
            marker="o",
            ms=3,
            label=lbl,
        )
        ax.plot(
            weeks[CONTEXT_LEN:],
            sales[CONTEXT_LEN:],
            color=c,
            lw=1.5,
            ls="--",
            marker="o",
            ms=3,
            alpha=0.6,
        )
    add_divider(ax)
    ax.set_ylabel("Weekly Sales (units)", fontsize=10)
    ax.set_title("Sales by Store", fontsize=11, fontweight="bold")
    ax.legend(fontsize=7.5, loc="upper left")
    ax.grid(True, alpha=0.22)
    ratio = (
        data["stores"]["store_A"]["sales"][:CONTEXT_LEN].mean()
        / data["stores"]["store_C"]["sales"][:CONTEXT_LEN].mean()
    )
    ax.annotate(
        f"Store A earns {ratio:.1f}x Store C\n(premium vs discount pricing)\n"
        f"-> store_type is a useful static covariate",
        xy=(0.97, 0.05),
        xycoords="axes fraction",
        ha="right",
        fontsize=8,
        bbox=dict(boxstyle="round", fc="#fffbe6", ec="#d4a017", alpha=0.95),
    )

    # -- (0,1): Store A actual vs baseline ---------------------------------------
    ax = axes[0, 1]
    comp_A = data["components"]["store_A"]
    sales_A = data["stores"]["store_A"]["sales"]
    base_A = comp_A["base"]
    promo_A = data["covariates"]["promotion"]["store_A"]
    holiday_A = data["covariates"]["holiday"]["store_A"]

    ax.plot(
        weeks[:CONTEXT_LEN],
        base_A[:CONTEXT_LEN],
        color="#9ca3af",
        lw=1.8,
        ls="--",
        label="Baseline (no covariates)",
    )
    ax.fill_between(
        weeks[:CONTEXT_LEN],
        base_A[:CONTEXT_LEN],
        sales_A[:CONTEXT_LEN],
        where=(sales_A[:CONTEXT_LEN] > base_A[:CONTEXT_LEN]),
        alpha=0.35,
        color="#22c55e",
        label="Covariate uplift",
    )
    ax.fill_between(
        weeks[:CONTEXT_LEN],
        sales_A[:CONTEXT_LEN],
        base_A[:CONTEXT_LEN],
        where=(sales_A[:CONTEXT_LEN] < base_A[:CONTEXT_LEN]),
        alpha=0.30,
        color="#ef4444",
        label="Price suppression",
    )
    ax.plot(
        weeks[:CONTEXT_LEN],
        sales_A[:CONTEXT_LEN],
        color=store_colors["store_A"],
        lw=2,
        label="Actual sales (Store A)",
    )

    for w in range(CONTEXT_LEN):
        if holiday_A[w] > 0:
            ax.axvspan(w - 0.45, w + 0.45, alpha=0.22, color="darkorange", zorder=0)
    promo_weeks = [w for w in range(CONTEXT_LEN) if promo_A[w] > 0]
    if promo_weeks:
        ax.scatter(
            promo_weeks,
            sales_A[promo_weeks],
            marker="^",
            color="#16a34a",
            s=70,
            zorder=6,
            label="Promotion week",
        )

    add_divider(ax)
    ax.set_ylabel("Weekly Sales (units)", fontsize=10)
    ax.set_title(
        "Store A -- Actual vs Baseline (No Covariates)", fontsize=11, fontweight="bold"
    )
    ax.legend(fontsize=7.5, loc="upper left", ncol=2)
    ax.grid(True, alpha=0.22)

    hm = holiday_A[:CONTEXT_LEN] > 0
    pm = promo_A[:CONTEXT_LEN] > 0
    h_lift = (
        (sales_A[:CONTEXT_LEN][hm] - base_A[:CONTEXT_LEN][hm]).mean() if hm.any() else 0
    )
    p_lift = (
        (sales_A[:CONTEXT_LEN][pm] - base_A[:CONTEXT_LEN][pm]).mean() if pm.any() else 0
    )
    ax.annotate(
        f"Holiday weeks: +{h_lift:.0f} units avg\n"
        f"Promotion weeks: +{p_lift:.0f} units avg\n"
        f"Future event schedules must be known for XReg",
        xy=(0.97, 0.05),
        xycoords="axes fraction",
        ha="right",
        fontsize=8,
        bbox=dict(boxstyle="round", fc="#fffbe6", ec="#d4a017", alpha=0.95),
    )

    # -- (1,0): Price covariate -- full 36 weeks ---------------------------------
    ax = axes[1, 0]
    for sid in data["stores"]:
        ax.plot(
            weeks,
            data["covariates"]["price"][sid],
            color=store_colors[sid],
            lw=2,
            label=sid,
            alpha=0.85,
        )
    add_divider(ax, label_top=False)
    ax.set_xlabel("Week", fontsize=10)
    ax.set_ylabel("Price ($)", fontsize=10)
    ax.set_title(
        "Price Covariate -- Context + Forecast Horizon", fontsize=11, fontweight="bold"
    )
    ax.legend(fontsize=8, loc="upper right")
    ax.grid(True, alpha=0.22)
    ax.annotate(
        "Prices are planned -- known for forecast horizon\n"
        "Price elasticity: -$1 increase -> -20 units sold\n"
        "Store A ($12) consistently more expensive than C ($7.50)",
        xy=(0.97, 0.05),
        xycoords="axes fraction",
        ha="right",
        fontsize=8,
        bbox=dict(boxstyle="round", fc="#fffbe6", ec="#d4a017", alpha=0.95),
    )

    # -- (1,1): Covariate effect decomposition -----------------------------------
    ax = axes[1, 1]
    pe = comp_A["price_effect"]
    pre = comp_A["promo_effect"]
    he = comp_A["holiday_effect"]

    ax.fill_between(
        weeks,
        0,
        pe,
        alpha=0.65,
        color="steelblue",
        step="mid",
        label=f"Price effect (max +/-{np.abs(pe).max():.0f} units)",
    )
    ax.fill_between(
        weeks,
        pe,
        pe + pre,
        alpha=0.70,
        color="#22c55e",
        step="mid",
        label="Promotion effect (+150 units)",
    )
    ax.fill_between(
        weeks,
        pe + pre,
        pe + pre + he,
        alpha=0.70,
        color="darkorange",
        step="mid",
        label="Holiday effect (+200 units)",
    )
    total = pe + pre + he
    ax.plot(weeks, total, "k-", lw=1.5, alpha=0.75, label="Total covariate effect")
    ax.axhline(0, color="black", lw=0.9, alpha=0.6)
    add_divider(ax, label_top=False)
    ax.set_xlabel("Week", fontsize=10)
    ax.set_ylabel("Effect on sales (units)", fontsize=10)
    ax.set_title(
        "Store A -- Covariate Effect Decomposition", fontsize=11, fontweight="bold"
    )
    ax.legend(fontsize=7.5, loc="upper right")
    ax.grid(True, alpha=0.22, axis="y")
    ax.annotate(
        f"Holidays (+200) and promotions (+150) dominate\n"
        f"Price effect (+/-{np.abs(pe).max():.0f} units) is minor by comparison\n"
        f"-> Time-varying covariates explain most sales spikes",
        xy=(0.97, 0.55),
        xycoords="axes fraction",
        ha="right",
        fontsize=8,
        bbox=dict(boxstyle="round", fc="#fffbe6", ec="#d4a017", alpha=0.95),
    )

    tick_pos = list(range(0, TOTAL_LEN, 4))
    for row in [0, 1]:
        for col in [0, 1]:
            axes[row, col].set_xticks(tick_pos)

    plt.tight_layout()
    output_path = OUTPUT_DIR / "covariates_data.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"\n Saved visualization: {output_path}")


def demonstrate_api() -> None:
    print("\n" + "=" * 70)
    print("  TIMESFM COVARIATES API (TimesFM 2.5)")
    print("=" * 70)
    print("""
# Installation
pip install timesfm[xreg]

import timesfm
hparams   = timesfm.TimesFmHparams(backend="cpu", per_core_batch_size=32, horizon_len=12)
ckpt      = timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-2.5-200m-pytorch")
model     = timesfm.TimesFm(hparams=hparams, checkpoint=ckpt)

point_fc, quant_fc = model.forecast_with_covariates(
    inputs=[sales_a, sales_b, sales_c],
    dynamic_numerical_covariates={"price": [price_a, price_b, price_c]},
    dynamic_categorical_covariates={"holiday": [hol_a, hol_b, hol_c]},
    static_categorical_covariates={"store_type": ["premium","standard","discount"]},
    xreg_mode="xreg + timesfm",
    normalize_xreg_target_per_input=True,
)
# point_fc:  (num_series, horizon_len)
# quant_fc:  (num_series, horizon_len, 10)
""")


def explain_xreg_modes() -> None:
    print("\n" + "=" * 70)
    print("  XREG MODES")
    print("=" * 70)
    print("""
"xreg + timesfm" (DEFAULT)
  1. TimesFM makes baseline forecast
  2. Fit regression on residuals (actual - baseline) ~ covariates
  3. Final = TimesFM baseline + XReg adjustment
  Best when: covariates explain residual variat
Download .txt
gitextract_oox45t84/

├── .gitattributes
├── .github/
│   └── workflows/
│       ├── main.yml
│       └── manual_publish.yml
├── .gitignore
├── AGENTS.md
├── LICENSE
├── README.md
├── pyproject.toml
├── requirements.txt
├── src/
│   └── timesfm/
│       ├── __init__.py
│       ├── configs.py
│       ├── flax/
│       │   ├── __init__.py
│       │   ├── dense.py
│       │   ├── normalization.py
│       │   ├── transformer.py
│       │   └── util.py
│       ├── timesfm_2p5/
│       │   ├── timesfm_2p5_base.py
│       │   ├── timesfm_2p5_flax.py
│       │   └── timesfm_2p5_torch.py
│       ├── torch/
│       │   ├── __init__.py
│       │   ├── dense.py
│       │   ├── normalization.py
│       │   ├── transformer.py
│       │   └── util.py
│       └── utils/
│           └── xreg_lib.py
├── timesfm-forecasting/
│   ├── SKILL.md
│   ├── examples/
│   │   ├── anomaly-detection/
│   │   │   ├── detect_anomalies.py
│   │   │   └── output/
│   │   │       └── anomaly_detection.json
│   │   ├── covariates-forecasting/
│   │   │   ├── demo_covariates.py
│   │   │   └── output/
│   │   │       ├── covariates_metadata.json
│   │   │       └── sales_with_covariates.csv
│   │   └── global-temperature/
│   │       ├── README.md
│   │       ├── generate_animation_data.py
│   │       ├── generate_gif.py
│   │       ├── generate_html.py
│   │       ├── output/
│   │       │   ├── animation_data.json
│   │       │   ├── forecast_output.csv
│   │       │   ├── forecast_output.json
│   │       │   └── interactive_forecast.html
│   │       ├── run_example.sh
│   │       ├── run_forecast.py
│   │       ├── temperature_anomaly.csv
│   │       └── visualize_forecast.py
│   ├── references/
│   │   ├── api_reference.md
│   │   ├── data_preparation.md
│   │   └── system_requirements.md
│   └── scripts/
│       ├── check_system.py
│       └── forecast_csv.py
└── v1/
    ├── LICENSE
    ├── README.md
    ├── TROUBLESHOOTING.md
    ├── docs/
    │   └── contributing.md
    ├── experiments/
    │   ├── baselines/
    │   │   ├── __init__.py
    │   │   └── timegpt_pipeline.py
    │   ├── extended_benchmarks/
    │   │   ├── README.md
    │   │   ├── run_timegpt.py
    │   │   ├── run_timesfm.py
    │   │   └── utils.py
    │   └── long_horizon_benchmarks/
    │       ├── README.md
    │       └── run_eval.py
    ├── notebooks/
    │   ├── covariates.ipynb
    │   ├── finetuning.ipynb
    │   └── finetuning_torch.ipynb
    ├── peft/
    │   ├── README.md
    │   ├── finetune.py
    │   ├── finetune.sh
    │   └── usage.ipynb
    ├── pyproject.toml
    ├── src/
    │   ├── adapter/
    │   │   ├── __init__.py
    │   │   ├── dora_layers.py
    │   │   ├── lora_layers.py
    │   │   └── utils.py
    │   ├── finetuning/
    │   │   ├── __init__.py
    │   │   ├── finetuning_example.py
    │   │   └── finetuning_torch.py
    │   └── timesfm/
    │       ├── __init__.py
    │       ├── data_loader.py
    │       ├── patched_decoder.py
    │       ├── pytorch_patched_decoder.py
    │       ├── time_features.py
    │       ├── timesfm_base.py
    │       ├── timesfm_jax.py
    │       ├── timesfm_torch.py
    │       └── xreg_lib.py
    └── tests/
        └── test_timesfm.py
Download .txt
SYMBOL INDEX (376 symbols across 41 files)

FILE: src/timesfm/configs.py
  class ForecastConfig (line 22) | class ForecastConfig:
  class ResidualBlockConfig (line 64) | class ResidualBlockConfig:
  class RandomFourierFeaturesConfig (line 75) | class RandomFourierFeaturesConfig:
  class TransformerConfig (line 85) | class TransformerConfig:
  class StackedTransformersConfig (line 101) | class StackedTransformersConfig:

FILE: src/timesfm/flax/dense.py
  class ResidualBlock (line 34) | class ResidualBlock(nnx.Module):
    method __init__ (line 37) | def __init__(self, config: ResidualBlockConfig, *, rngs=nnx.Rngs(42)):
    method __call__ (line 66) | def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... ...
  class RandomFourierFeatures (line 72) | class RandomFourierFeatures(nnx.Module):
    method __init__ (line 77) | def __init__(self, config: RandomFourierFeaturesConfig, *, rngs=nnx.Rn...
    method __call__ (line 100) | def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... ...

FILE: src/timesfm/flax/normalization.py
  class RMSNorm (line 29) | class RMSNorm(nnx.Module):
    method __init__ (line 34) | def __init__(
    method __call__ (line 46) | def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b...
  class LayerNorm (line 53) | class LayerNorm(nnx.Module):
    method __init__ (line 58) | def __init__(self, num_features: int, *, epsilon: float = 1e-6, rngs=n...
    method __call__ (line 65) | def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b...

FILE: src/timesfm/flax/transformer.py
  function make_attn_mask (line 46) | def make_attn_mask(
  class RotaryPositionalEmbedding (line 67) | class RotaryPositionalEmbedding(nnx.Module):
    method __init__ (line 70) | def __init__(
    method __call__ (line 80) | def __call__(
  class PerDimScale (line 118) | class PerDimScale(nnx.Module):
    method __init__ (line 123) | def __init__(self, num_dims: int, *, rngs=nnx.Rngs(42)):
    method __call__ (line 128) | def __call__(self, x: Float[Array, "b ... d"]) -> Float[Array, "b ... ...
  class MultiHeadAttention (line 134) | class MultiHeadAttention(nnx.Module):
    method __init__ (line 137) | def __init__(
    method __call__ (line 207) | def __call__(
  class Transformer (line 291) | class Transformer(nnx.Module):
    method __init__ (line 294) | def __init__(self, config: TransformerConfig, *, rngs=nnx.Rngs(42)):
    method __call__ (line 338) | def __call__(

FILE: src/timesfm/flax/util.py
  class DecodeCache (line 33) | class DecodeCache:
  function update_running_stats (line 43) | def update_running_stats(
  function scan_along_axis (line 80) | def scan_along_axis(f, init, xs, axis: int, **kwargs):
  function revin (line 91) | def revin(

FILE: src/timesfm/timesfm_2p5/timesfm_2p5_base.py
  function strip_leading_nans (line 33) | def strip_leading_nans(arr):
  function linear_interpolation (line 49) | def linear_interpolation(arr):
  class TimesFM_2p5_200M_Definition (line 85) | class TimesFM_2p5_200M_Definition:
  class TimesFM_2p5 (line 134) | class TimesFM_2p5:
    method load_checkpoint (line 147) | def load_checkpoint(self, path: str):
    method compile (line 151) | def compile(self, forecast_config: ForecastConfig | None = None):
    method forecast (line 155) | def forecast(
    method forecast_with_covariates (line 198) | def forecast_with_covariates(

FILE: src/timesfm/timesfm_2p5/timesfm_2p5_flax.py
  function try_gc (line 48) | def try_gc():
  function _create_stacked_transformers (line 59) | def _create_stacked_transformers(
  function _scan_along_axis (line 65) | def _scan_along_axis(f, init, xs, axis: int, **kwargs):
  function _apply_stacked_transformers (line 76) | def _apply_stacked_transformers(
  class TimesFM_2p5_200M_flax_module (line 85) | class TimesFM_2p5_200M_flax_module(nnx.Module):  # pylint: disable=inval...
    method __init__ (line 96) | def __init__(self):
    method __call__ (line 126) | def __call__(
    method decode (line 149) | def decode(self, horizon: int, inputs, masks):
    method compile (line 237) | def compile(
  function _flip_quantile_fn (line 276) | def _flip_quantile_fn(x):
  function _force_flip_invariance_fn (line 284) | def _force_flip_invariance_fn(
  function _use_continuous_quantile_head_fn (line 309) | def _use_continuous_quantile_head_fn(full_forecast, quantile_spreads, ma...
  function _fix_quantile_crossing_fn (line 329) | def _fix_quantile_crossing_fn(full_forecast):
  function _before_model_decode (line 357) | def _before_model_decode(fc, inputs, masks):
  function _after_model_decode (line 385) | def _after_model_decode(
  class TimesFM_2p5_200M_flax (line 445) | class TimesFM_2p5_200M_flax(timesfm_2p5_base.TimesFM_2p5):
    method from_pretrained (line 451) | def from_pretrained(
    method compile (line 494) | def compile(

FILE: src/timesfm/timesfm_2p5/timesfm_2p5_torch.py
  class TimesFM_2p5_200M_torch_module (line 36) | class TimesFM_2p5_200M_torch_module(nn.Module):
    method __init__ (line 41) | def __init__(self):
    method load_checkpoint (line 79) | def load_checkpoint(self, path: str, **kwargs):
    method forward (line 93) | def forward(
    method decode (line 122) | def decode(self, horizon: int, inputs, masks):
    method forecast_naive (line 228) | def forecast_naive(
  class TimesFM_2p5_200M_torch (line 266) | class TimesFM_2p5_200M_torch(
    method __init__ (line 282) | def __init__(
    method _from_pretrained (line 293) | def _from_pretrained(
    method _save_pretrained (line 341) | def _save_pretrained(self, save_directory: Union[str, Path]):
    method compile (line 352) | def compile(self, forecast_config: configs.ForecastConfig, **kwargs) -...

FILE: src/timesfm/torch/dense.py
  class ResidualBlock (line 23) | class ResidualBlock(nn.Module):
    method __init__ (line 26) | def __init__(self, config: configs.ResidualBlockConfig):
    method forward (line 53) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class RandomFourierFeatures (line 59) | class RandomFourierFeatures(nn.Module):
    method __init__ (line 62) | def __init__(self, config: configs.RandomFourierFeaturesConfig):
    method forward (line 84) | def forward(self, x: torch.Tensor) -> torch.Tensor:

FILE: src/timesfm/torch/normalization.py
  class RMSNorm (line 21) | class RMSNorm(nn.Module):
    method __init__ (line 24) | def __init__(
    method forward (line 35) | def forward(self, inputs: torch.Tensor) -> torch.Tensor:

FILE: src/timesfm/torch/transformer.py
  function make_attn_mask (line 32) | def make_attn_mask(
  class RotaryPositionalEmbedding (line 56) | class RotaryPositionalEmbedding(nn.Module):
    method __init__ (line 59) | def __init__(
    method forward (line 70) | def forward(
  function _dot_product_attention (line 114) | def _dot_product_attention(
  function _torch_dot_product_attention (line 132) | def _torch_dot_product_attention(query, key, value, mask=None):
  class PerDimScale (line 154) | class PerDimScale(nn.Module):
    method __init__ (line 157) | def __init__(self, num_dims: int):
    method forward (line 162) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class MultiHeadAttention (line 169) | class MultiHeadAttention(nn.Module):
    method __init__ (line 172) | def __init__(
    method forward (line 224) | def forward(
  class Transformer (line 307) | class Transformer(nn.Module):
    method __init__ (line 310) | def __init__(self, config: configs.TransformerConfig):
    method forward (line 354) | def forward(

FILE: src/timesfm/torch/util.py
  class DecodeCache (line 24) | class DecodeCache:
  function update_running_stats (line 33) | def update_running_stats(
  function revin (line 77) | def revin(

FILE: src/timesfm/utils/xreg_lib.py
  function _unnest (line 36) | def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:
  function _repeat (line 40) | def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:
  function _to_padded_jax_array (line 46) | def _to_padded_jax_array(x: np.ndarray) -> jax.Array:
  function normalize (line 61) | def normalize(batch):
  function renormalize (line 68) | def renormalize(batch, stats):
  class BatchedInContextXRegBase (line 72) | class BatchedInContextXRegBase:
    method __init__ (line 97) | def __init__(
    method _assert_covariates (line 210) | def _assert_covariates(self, assert_covariate_shapes: bool = False) ->...
    method create_covariate_matrix (line 327) | def create_covariate_matrix(
    method fit (line 407) | def fit(self) -> Any:
  class BatchedInContextXRegLinear (line 411) | class BatchedInContextXRegLinear(BatchedInContextXRegBase):
    method fit (line 414) | def fit(

FILE: timesfm-forecasting/examples/anomaly-detection/detect_anomalies.py
  function detect_context_anomalies (line 50) | def detect_context_anomalies(
  function build_synthetic_future (line 98) | def build_synthetic_future(
  function detect_forecast_anomalies (line 121) | def detect_forecast_anomalies(
  function plot_results (line 172) | def plot_results(
  function main (line 391) | def main() -> None:

FILE: timesfm-forecasting/examples/covariates-forecasting/demo_covariates.py
  function generate_sales_data (line 49) | def generate_sales_data() -> dict:
  function create_visualization (line 132) | def create_visualization(data: dict) -> None:
  function demonstrate_api (line 405) | def demonstrate_api() -> None:
  function explain_xreg_modes (line 431) | def explain_xreg_modes() -> None:
  function main (line 450) | def main() -> None:

FILE: timesfm-forecasting/examples/global-temperature/generate_animation_data.py
  function main (line 30) | def main() -> None:

FILE: timesfm-forecasting/examples/global-temperature/generate_gif.py
  function create_frame (line 26) | def create_frame(
  function main (line 157) | def main() -> None:

FILE: timesfm-forecasting/examples/global-temperature/generate_html.py
  function main (line 521) | def main() -> None:

FILE: timesfm-forecasting/examples/global-temperature/visualize_forecast.py
  function main (line 30) | def main() -> None:

FILE: timesfm-forecasting/scripts/check_system.py
  class CheckResult (line 75) | class CheckResult:
    method icon (line 82) | def icon(self) -> str:
    method __str__ (line 85) | def __str__(self) -> str:
  class SystemReport (line 90) | class SystemReport:
    method passed (line 99) | def passed(self) -> bool:
    method to_dict (line 102) | def to_dict(self) -> dict[str, Any]:
  function _get_total_ram_gb (line 127) | def _get_total_ram_gb() -> float:
  function _get_available_ram_gb (line 174) | def _get_available_ram_gb() -> float:
  function check_ram (line 223) | def check_ram(profile: dict[str, Any]) -> CheckResult:
  function check_gpu (line 263) | def check_gpu() -> CheckResult:
  function check_disk (line 304) | def check_disk(profile: dict[str, Any]) -> CheckResult:
  function check_python (line 337) | def check_python() -> CheckResult:
  function check_package (line 358) | def check_package(pkg_name: str, import_name: str | None = None) -> Chec...
  function recommend_batch_size (line 384) | def recommend_batch_size(report: SystemReport) -> int:
  function estimate_memory_gb (line 428) | def estimate_memory_gb(
  function check_dataset_fit (line 481) | def check_dataset_fit(
  function print_memory_estimate (line 539) | def print_memory_estimate(
  function run_checks (line 595) | def run_checks(model_version: str = "v2.5") -> SystemReport:
  function print_report (line 637) | def print_report(report: SystemReport) -> None:
  function main (line 654) | def main() -> None:

FILE: timesfm-forecasting/scripts/forecast_csv.py
  function run_preflight (line 32) | def run_preflight() -> dict:
  function load_model (line 49) | def load_model(batch_size: int = 32):
  function load_csv (line 78) | def load_csv(
  function forecast_series (line 118) | def forecast_series(
  function write_csv_output (line 144) | def write_csv_output(
  function write_json_output (line 187) | def write_json_output(results: dict[str, dict], output_path: str) -> None:
  function main (line 194) | def main() -> None:

FILE: v1/experiments/baselines/timegpt_pipeline.py
  function get_seasonality (line 33) | def get_seasonality(freq: str) -> int:
  function maybe_convert_col_to_datetime (line 37) | def maybe_convert_col_to_datetime(
  function zero_pad_time_series (line 46) | def zero_pad_time_series(df, freq, min_length=36):
  class Forecaster (line 84) | class Forecaster:
    method forecast (line 90) | def forecast(
    method cross_validation (line 98) | def cross_validation(
  class TimeGPT (line 152) | class TimeGPT(Forecaster):
    method __init__ (line 159) | def __init__(
    method _get_client (line 173) | def _get_client(self) -> NixtlaClient:
    method forecast (line 184) | def forecast(
  function run_timegpt (line 225) | def run_timegpt(

FILE: v1/experiments/extended_benchmarks/run_timegpt.py
  function main (line 70) | def main():

FILE: v1/experiments/extended_benchmarks/run_timesfm.py
  function main (line 88) | def main():

FILE: v1/experiments/extended_benchmarks/utils.py
  function parallel_transform (line 36) | def parallel_transform(inp):
  function quantile_loss (line 41) | def quantile_loss(
  class ExperimentHandler (line 59) | class ExperimentHandler:
    method __init__ (line 61) | def __init__(
    method _maybe_download_m3_or_m5_file (line 104) | def _maybe_download_m3_or_m5_file(dataset: str):
    method _transform_quantiles_to_levels (line 124) | def _transform_quantiles_to_levels(quantiles: List[float]) -> List[int]:
    method _create_dir_if_not_exists (line 132) | def _create_dir_if_not_exists(directory: str):
    method _transform_gluonts_instance_to_df (line 136) | def _transform_gluonts_instance_to_df(
    method _transform_gluonts_dataset_to_df (line 151) | def _transform_gluonts_dataset_to_df(
    method train_df (line 164) | def train_df(self) -> pd.DataFrame:
    method test_df (line 169) | def test_df(self) -> pd.DataFrame:
    method save_dataframe (line 177) | def save_dataframe(self, df: pd.DataFrame, file_name: str):
    method save_results (line 180) | def save_results(
    method fcst_from_level_to_quantiles (line 193) | def fcst_from_level_to_quantiles(
    method evaluate_models (line 213) | def evaluate_models(self, models: List[str]) -> pd.DataFrame:
    method evaluate_from_predictions (line 232) | def evaluate_from_predictions(

FILE: v1/experiments/long_horizon_benchmarks/run_eval.py
  function get_forecasts (line 95) | def get_forecasts(model_path, model, past, freq, pred_len):
  function _mse (line 112) | def _mse(y_pred, y_true):
  function _mae (line 117) | def _mae(y_pred, y_true):
  function _smape (line 122) | def _smape(y_pred, y_true):
  function eval (line 131) | def eval():

FILE: v1/peft/finetune.py
  function finetune (line 63) | def finetune(

FILE: v1/src/adapter/dora_layers.py
  class DoraTheta (line 23) | class DoraTheta(base_layer.Theta):
    method __init__ (line 24) | def __init__(self, module):
    method _dora_initialized (line 27) | def _dora_initialized(self):
    method _dorafy_var (line 40) | def _dorafy_var(self, w):
    method __getattr__ (line 55) | def __getattr__(self, k):
    method __getitem__ (line 65) | def __getitem__(self, k):
  class DoraThetaDescriptor (line 76) | class DoraThetaDescriptor:
    method __get__ (line 79) | def __get__(self, obj, objtype=None):
  class DoraLinear (line 83) | class DoraLinear(linears.Linear):
    method setup (line 88) | def setup(self) -> None:
  class DoraAttentionProjection (line 121) | class DoraAttentionProjection(attentions.AttentionProjection):
    method setup (line 126) | def setup(self) -> None:
  class DoraCombinedQKVProjection (line 166) | class DoraCombinedQKVProjection(attentions.CombinedQKVProjectionLayer):
    method setup (line 171) | def setup(self) -> None:

FILE: v1/src/adapter/lora_layers.py
  class LoraTheta (line 23) | class LoraTheta(base_layer.Theta):
    method __init__ (line 24) | def __init__(self, module):
    method _lora_initialized (line 27) | def _lora_initialized(self):
    method _lorafy_var (line 38) | def _lorafy_var(self, w):
    method __getattr__ (line 46) | def __getattr__(self, k):
    method __getitem__ (line 56) | def __getitem__(self, k):
  class LoraThetaDescriptor (line 67) | class LoraThetaDescriptor:
    method __get__ (line 70) | def __get__(self, obj, objtype=None):
  class LoraLinear (line 74) | class LoraLinear(linears.Linear):
    method setup (line 79) | def setup(self) -> None:
  class LoraAttentionProjection (line 103) | class LoraAttentionProjection(attentions.AttentionProjection):
    method setup (line 108) | def setup(self) -> None:
  class LoraCombinedQKVProjection (line 139) | class LoraCombinedQKVProjection(attentions.CombinedQKVProjectionLayer):
    method setup (line 144) | def setup(self) -> None:

FILE: v1/src/adapter/utils.py
  function get_adapter_params (line 43) | def get_adapter_params(
  function load_adapter_checkpoint (line 101) | def load_adapter_checkpoint(
  function _merge_adapter_weights (line 200) | def _merge_adapter_weights(
  function _get_adapter_weight_params (line 281) | def _get_adapter_weight_params(
  function load_adapter_layer (line 334) | def load_adapter_layer(
  function _initialize_adapter_params (line 417) | def _initialize_adapter_params(

FILE: v1/src/finetuning/finetuning_example.py
  class TimeSeriesDataset (line 51) | class TimeSeriesDataset(Dataset):
    method __init__ (line 54) | def __init__(self,
    method _prepare_samples (line 77) | def _prepare_samples(self) -> None:
    method __len__ (line 88) | def __len__(self) -> int:
    method __getitem__ (line 91) | def __getitem__(
  function prepare_datasets (line 105) | def prepare_datasets(series: np.ndarray,
  function get_model (line 141) | def get_model(load_weights: bool = False):
  function plot_predictions (line 171) | def plot_predictions(
  function get_data (line 248) | def get_data(context_len: int,
  function single_gpu_example (line 269) | def single_gpu_example():
  function setup_process (line 300) | def setup_process(rank, world_size, model, config, train_dataset, val_da...
  function multi_gpu_example (line 335) | def multi_gpu_example():
  function main (line 372) | def main(argv):

FILE: v1/src/finetuning/finetuning_torch.py
  class MetricsLogger (line 21) | class MetricsLogger(ABC):
    method log_metrics (line 29) | def log_metrics(self,
    method close (line 41) | def close(self) -> None:
  class WandBLogger (line 46) | class WandBLogger(MetricsLogger):
    method __init__ (line 55) | def __init__(self, project: str, config: Dict[str, Any], rank: int = 0):
    method log_metrics (line 60) | def log_metrics(self,
    method close (line 72) | def close(self) -> None:
  class DistributedManager (line 78) | class DistributedManager:
    method __init__ (line 89) | def __init__(
    method setup (line 103) | def setup(self) -> None:
    method cleanup (line 113) | def cleanup(self) -> None:
  class FinetuningConfig (line 120) | class FinetuningConfig:
  class TimesFMFinetuner (line 162) | class TimesFMFinetuner:
    method __init__ (line 173) | def __init__(
    method _setup_distributed_model (line 203) | def _setup_distributed_model(self) -> nn.Module:
    method _create_dataloader (line 210) | def _create_dataloader(self, dataset: Dataset, is_train: bool) -> Data...
    method _quantile_loss (line 236) | def _quantile_loss(self, pred: torch.Tensor, actual: torch.Tensor,
    method _process_batch (line 251) | def _process_batch(self, batch: List[torch.Tensor]) -> tuple:
    method _train_epoch (line 279) | def _train_epoch(self, train_loader: DataLoader,
    method _validate (line 312) | def _validate(self, val_loader: DataLoader) -> float:
    method finetune (line 339) | def finetune(self, train_dataset: Dataset,

FILE: v1/src/timesfm/data_loader.py
  class TimeSeriesdata (line 27) | class TimeSeriesdata(object):
    method __init__ (line 30) | def __init__(
    method _get_cat_cols (line 120) | def _get_cat_cols(self, cat_cov_cols):
    method _normalize_data (line 131) | def _normalize_data(self):
    method train_gen (line 137) | def train_gen(self):
    method test_val_gen (line 178) | def test_val_gen(self, mode='val', shift=1):
    method _get_features_and_ts (line 220) | def _get_features_and_ts(self, dtimes, tsidx, hist_len=None):
    method tf_dataset (line 245) | def tf_dataset(self, mode='train', shift=1):

FILE: v1/src/timesfm/patched_decoder.py
  function _shift_padded_seq (line 61) | def _shift_padded_seq(mask: JTensor, seq: JTensor) -> JTensor:
  class ResidualBlock (line 84) | class ResidualBlock(base_layer.BaseLayer):
    method setup (line 107) | def setup(self):
    method __call__ (line 146) | def __call__(self, inputs: JTensor) -> JTensor:
  function _masked_mean_std (line 157) | def _masked_mean_std(inputs: JTensor,
  function _create_quantiles (line 206) | def _create_quantiles() -> list[float]:
  class PatchedTimeSeriesDecoder (line 211) | class PatchedTimeSeriesDecoder(base_layer.BaseLayer):
    method setup (line 242) | def setup(self) -> None:
    method transform_decode_state (line 288) | def transform_decode_state(
    method _forward_transform (line 293) | def _forward_transform(
    method _reverse_transform (line 305) | def _reverse_transform(self, outputs: JTensor,
    method _preprocess_input (line 311) | def _preprocess_input(
    method _postprocess_output (line 350) | def _postprocess_output(
    method __call__ (line 365) | def __call__(self, inputs: NestedMap) -> NestedMap:
    method decode (line 399) | def decode(
  class PatchedDecoderFinetuneModel (line 482) | class PatchedDecoderFinetuneModel(base_model.BaseModel):
    method setup (line 493) | def setup(self) -> None:
    method compute_predictions (line 496) | def compute_predictions(self, input_batch: NestedMap) -> NestedMap:
    method _quantile_loss (line 515) | def _quantile_loss(self, pred: JTensor, actual: JTensor,
    method compute_loss (line 532) | def compute_loss(self, prediction_output: NestedMap,

FILE: v1/src/timesfm/pytorch_patched_decoder.py
  function create_quantiles (line 24) | def create_quantiles() -> list[float]:
  class TimesFMConfig (line 29) | class TimesFMConfig:
  function _masked_mean_std (line 62) | def _masked_mean_std(
  function _shift_padded_seq (line 112) | def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Te...
  function get_large_negative_number (line 146) | def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor:
  function apply_mask_to_logits (line 155) | def apply_mask_to_logits(logits: torch.Tensor,
  function convert_paddings_to_mask (line 173) | def convert_paddings_to_mask(
  function causal_mask (line 191) | def causal_mask(input_t: torch.Tensor) -> torch.Tensor:
  function merge_masks (line 211) | def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
  class ResidualBlock (line 239) | class ResidualBlock(nn.Module):
    method __init__ (line 242) | def __init__(
    method forward (line 264) | def forward(self, x):
  class RMSNorm (line 271) | class RMSNorm(torch.nn.Module):
    method __init__ (line 274) | def __init__(
    method _norm (line 285) | def _norm(self, x):
    method forward (line 288) | def forward(self, x):
  class TransformerMLP (line 297) | class TransformerMLP(nn.Module):
    method __init__ (line 300) | def __init__(
    method forward (line 310) | def forward(self, x, paddings=None):
  class TimesFMAttention (line 320) | class TimesFMAttention(nn.Module):
    method __init__ (line 323) | def __init__(
    method _per_dim_scaling (line 352) | def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor:
    method forward (line 360) | def forward(
  class TimesFMDecoderLayer (line 418) | class TimesFMDecoderLayer(nn.Module):
    method __init__ (line 421) | def __init__(
    method forward (line 443) | def forward(
  class StackedDecoder (line 468) | class StackedDecoder(nn.Module):
    method __init__ (line 471) | def __init__(
    method forward (line 495) | def forward(
  class PositionalEmbedding (line 518) | class PositionalEmbedding(torch.nn.Module):
    method __init__ (line 529) | def __init__(
    method forward (line 540) | def forward(self, seq_length=None, position=None):
  class PatchedTimeSeriesDecoder (line 574) | class PatchedTimeSeriesDecoder(nn.Module):
    method __init__ (line 577) | def __init__(self, config: TimesFMConfig):
    method _forward_transform (line 604) | def _forward_transform(
    method _reverse_transform (line 622) | def _reverse_transform(
    method _preprocess_input (line 629) | def _preprocess_input(
    method _postprocess_output (line 677) | def _postprocess_output(
    method forward (line 694) | def forward(
    method decode (line 712) | def decode(

FILE: v1/src/timesfm/time_features.py
  function _distance_to_holiday (line 45) | def _distance_to_holiday(holiday):
  class TimeCovariates (line 112) | class TimeCovariates(object):
    method __init__ (line 115) | def __init__(
    method _minute_of_hour (line 135) | def _minute_of_hour(self):
    method _hour_of_day (line 141) | def _hour_of_day(self):
    method _day_of_week (line 147) | def _day_of_week(self):
    method _day_of_month (line 153) | def _day_of_month(self):
    method _day_of_year (line 159) | def _day_of_year(self):
    method _month_of_year (line 165) | def _month_of_year(self):
    method _week_of_year (line 171) | def _week_of_year(self):
    method _get_holidays (line 177) | def _get_holidays(self):
    method get_covariates (line 186) | def get_covariates(self):

FILE: v1/src/timesfm/timesfm_base.py
  function process_group (line 39) | def process_group(key, group, value_name, forecast_context_len):
  function moving_average (line 44) | def moving_average(arr, window_size):
  function freq_map (line 53) | def freq_map(freq: str):
  function strip_leading_nans (line 77) | def strip_leading_nans(arr):
  function linear_interpolation (line 94) | def linear_interpolation(arr):
  function _normalize (line 131) | def _normalize(batch):
  function _renormalize (line 140) | def _renormalize(batch, stats):
  class TimesFmHparams (line 145) | class TimesFmHparams:
  class TimesFmCheckpoint (line 184) | class TimesFmCheckpoint:
  class TimesFmBase (line 205) | class TimesFmBase:
    method _logging (line 214) | def _logging(self, s):
    method __post_init__ (line 217) | def __post_init__(self) -> None:
    method __init__ (line 221) | def __init__(self, hparams: TimesFmHparams,
    method load_from_checkpoint (line 253) | def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None:
    method _preprocess (line 257) | def _preprocess(
    method _forecast (line 314) | def _forecast(
    method forecast (line 347) | def forecast(
    method forecast_with_covariates (line 429) | def forecast_with_covariates(
    method forecast_on_df (line 644) | def forecast_on_df(

FILE: v1/src/timesfm/timesfm_jax.py
  class TimesFmJax (line 41) | class TimesFmJax(timesfm_base.TimesFmBase):
    method _get_sample_inputs (line 57) | def _get_sample_inputs(self):
    method __post_init__ (line 85) | def __post_init__(self):
    method load_from_checkpoint (line 94) | def load_from_checkpoint(
    method jit_decode (line 178) | def jit_decode(self):
    method _forecast (line 239) | def _forecast(

FILE: v1/src/timesfm/timesfm_torch.py
  class TimesFmTorch (line 30) | class TimesFmTorch(timesfm_base.TimesFmBase):
    method __post_init__ (line 33) | def __post_init__(self):
    method load_from_checkpoint (line 52) | def load_from_checkpoint(
    method _forecast (line 72) | def _forecast(

FILE: v1/src/timesfm/xreg_lib.py
  function _unnest (line 31) | def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:
  function _repeat (line 35) | def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:
  function _to_padded_jax_array (line 42) | def _to_padded_jax_array(x: np.ndarray) -> jax.Array:
  class BatchedInContextXRegBase (line 56) | class BatchedInContextXRegBase:
    method __init__ (line 81) | def __init__(
    method _assert_covariates (line 193) | def _assert_covariates(self, assert_covariate_shapes: bool = False) ->...
    method create_covariate_matrix (line 298) | def create_covariate_matrix(
    method fit (line 377) | def fit(self) -> Any:
  class BatchedInContextXRegLinear (line 381) | class BatchedInContextXRegLinear(BatchedInContextXRegBase):
    method fit (line 384) | def fit(

FILE: v1/tests/test_timesfm.py
  function create_sample_dataframe (line 25) | def create_sample_dataframe(
  function test_timesfm_forecast_on_df (line 48) | def test_timesfm_forecast_on_df(
Condensed preview — 85 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (995K chars).
[
  {
    "path": ".gitattributes",
    "chars": 197,
    "preview": "# Git LFS tracking for binary outputs in timesfm-forecasting skill\ntimesfm-forecasting/**/*.png filter=lfs diff=lfs merg"
  },
  {
    "path": ".github/workflows/main.yml",
    "chars": 687,
    "preview": "name: Python package build\n\non:\n  push:\n    branches: [ \"master\" ]\n  pull_request:\n    branches: [ \"master\" ]\n\njobs:\n  b"
  },
  {
    "path": ".github/workflows/manual_publish.yml",
    "chars": 794,
    "preview": "name: Manual PyPI Publish\n\non:\n  workflow_dispatch:\n\njobs:\n  build-and-publish:\n    runs-on: ubuntu-latest\n    steps:\n  "
  },
  {
    "path": ".gitignore",
    "chars": 108,
    "preview": ".venv/\ndist/\n__pycache__/\ncheckpoints/\nwandb/\ndatasets/\nresults/\ntimesfm_jax.egg-info/\ndevelopment_setup.md\n"
  },
  {
    "path": "AGENTS.md",
    "chars": 844,
    "preview": "# TimesFM — Agent Entry Point\n\nThis repository ships a first-party **Agent Skill** for TimesFM at:\n\n```\ntimesfm-forecast"
  },
  {
    "path": "LICENSE",
    "chars": 11358,
    "preview": "\n                                 Apache License\n                           Version 2.0, January 2004\n                  "
  },
  {
    "path": "README.md",
    "chars": 3338,
    "preview": "# TimesFM\n\nTimesFM (Time Series Foundation Model) is a pretrained time-series foundation\nmodel developed by Google Resea"
  },
  {
    "path": "pyproject.toml",
    "chars": 955,
    "preview": "[project]\nname = \"timesfm\"\nversion = \"2.0.0\"\ndescription = \"A time series foundation model.\"\nauthors = [\n    {name = \"Ra"
  },
  {
    "path": "requirements.txt",
    "chars": 1015,
    "preview": "# This file was autogenerated by uv via the following command:\n#    uv pip compile pyproject.toml -o requirements.txt\nan"
  },
  {
    "path": "src/timesfm/__init__.py",
    "chars": 919,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/configs.py",
    "chars": 3617,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/flax/__init__.py",
    "chars": 574,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/flax/dense.py",
    "chars": 3522,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/flax/normalization.py",
    "chars": 2168,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/flax/transformer.py",
    "chars": 11506,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/flax/util.py",
    "chars": 3027,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/timesfm_2p5/timesfm_2p5_base.py",
    "chars": 15174,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/timesfm_2p5/timesfm_2p5_flax.py",
    "chars": 19375,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/timesfm_2p5/timesfm_2p5_torch.py",
    "chars": 17324,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/torch/__init__.py",
    "chars": 574,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/torch/dense.py",
    "chars": 3109,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/torch/normalization.py",
    "chars": 1204,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/torch/transformer.py",
    "chars": 11720,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/torch/util.py",
    "chars": 2654,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "src/timesfm/utils/xreg_lib.py",
    "chars": 20721,
    "preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "timesfm-forecasting/SKILL.md",
    "chars": 18707,
    "preview": "---\nname: timesfm-forecasting\ndescription: >\n  Zero-shot time series forecasting with Google's TimesFM foundation model."
  },
  {
    "path": "timesfm-forecasting/examples/anomaly-detection/detect_anomalies.py",
    "chars": 17022,
    "preview": "#!/usr/bin/env python3\n\"\"\"\nTimesFM Anomaly Detection Example — Two-Phase Method\n\nPhase 1 (context): Linear detrend + Z-s"
  },
  {
    "path": "timesfm-forecasting/examples/anomaly-detection/output/anomaly_detection.json",
    "chars": 9019,
    "preview": "{\n  \"method\": \"two_phase\",\n  \"context_method\": \"linear_detrend_zscore\",\n  \"forecast_method\": \"quantile_prediction_interv"
  },
  {
    "path": "timesfm-forecasting/examples/covariates-forecasting/demo_covariates.py",
    "chars": 19694,
    "preview": "#!/usr/bin/env python3\n\"\"\"\nTimesFM Covariates (XReg) Example\n\nDemonstrates the TimesFM covariate API using synthetic ret"
  },
  {
    "path": "timesfm-forecasting/examples/covariates-forecasting/output/covariates_metadata.json",
    "chars": 1586,
    "preview": "{\n  \"description\": \"Synthetic retail sales data with covariates for TimesFM XReg demo\",\n  \"note_on_real_data\": \"For real"
  },
  {
    "path": "timesfm-forecasting/examples/covariates-forecasting/output/sales_with_covariates.csv",
    "chars": 7399,
    "preview": "store_id,week,split,sales,base_sales,price,price_effect,promotion,holiday,day_of_week,store_type,region\nstore_A,0,contex"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/README.md",
    "chars": 5708,
    "preview": "# TimesFM Forecast Report: Global Temperature Anomaly (2025)\n\n**Model:** TimesFM 1.0 (200M) PyTorch  \n**Generated:** 202"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/generate_animation_data.py",
    "chars": 4993,
    "preview": "#!/usr/bin/env python3\n\"\"\"\nGenerate animation data for interactive forecast visualization.\n\nThis script runs TimesFM for"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/generate_gif.py",
    "chars": 6646,
    "preview": "#!/usr/bin/env python3\n\"\"\"\nGenerate animated GIF showing forecast evolution.\n\nCreates a GIF animation showing how the Ti"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/generate_html.py",
    "chars": 21136,
    "preview": "#!/usr/bin/env python3\n\"\"\"\nGenerate a self-contained HTML file with embedded animation data.\n\nThis creates a single HTML"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/output/animation_data.json",
    "chars": 133208,
    "preview": "{\n  \"metadata\": {\n    \"model\": \"TimesFM 1.0 (200M) PyTorch\",\n    \"total_steps\": 25,\n    \"min_context\": 12,\n    \"max_hori"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/output/forecast_output.csv",
    "chars": 1494,
    "preview": "date,point_forecast,q10,q20,q30,q40,q50,q60,q70,q80,q90,q99\n2025-01-01,1.2593384,1.248188,1.140702,1.1880752,1.2137158,1"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/output/forecast_output.json",
    "chars": 4524,
    "preview": "{\n  \"model\": \"TimesFM 1.0 (200M) PyTorch\",\n  \"input\": {\n    \"source\": \"NOAA GISTEMP Global Temperature Anomaly\",\n    \"n_"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/output/interactive_forecast.html",
    "chars": 153042,
    "preview": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n    <meta charset=\"UTF-8\">\n    <meta name=\"viewport\" content=\"width=device-width"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/run_example.sh",
    "chars": 1519,
    "preview": "#!/bin/bash\n# run_example.sh - Run the TimesFM temperature anomaly forecasting example\n#\n# This script:\n# 1. Runs the pr"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/run_forecast.py",
    "chars": 5470,
    "preview": "#!/usr/bin/env python3\n\"\"\"\nRun TimesFM forecast on global temperature anomaly data.\nGenerates forecast output CSV and JS"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/temperature_anomaly.csv",
    "chars": 591,
    "preview": "date,anomaly_c\n2022-01-01,0.89\n2022-02-01,0.89\n2022-03-01,1.02\n2022-04-01,0.88\n2022-05-01,0.85\n2022-06-01,0.88\n2022-07-0"
  },
  {
    "path": "timesfm-forecasting/examples/global-temperature/visualize_forecast.py",
    "chars": 3287,
    "preview": "#!/usr/bin/env python3\n\"\"\"\nVisualize TimesFM forecast results for global temperature anomaly.\n\nGenerates a publication-q"
  },
  {
    "path": "timesfm-forecasting/references/api_reference.md",
    "chars": 9725,
    "preview": "# TimesFM API Reference\n\n## Model Classes\n\n### `timesfm.TimesFM_2p5_200M_torch`\n\nThe primary model class for TimesFM 2.5"
  },
  {
    "path": "timesfm-forecasting/references/data_preparation.md",
    "chars": 7175,
    "preview": "# Data Preparation for TimesFM\n\n## Input Format\n\nTimesFM accepts a **list of 1-D numpy arrays**. Each array represents o"
  },
  {
    "path": "timesfm-forecasting/references/system_requirements.md",
    "chars": 7004,
    "preview": "# System Requirements for TimesFM\n\n## Hardware Tiers\n\nTimesFM can run on a variety of hardware configurations. This guid"
  },
  {
    "path": "timesfm-forecasting/scripts/check_system.py",
    "chars": 24136,
    "preview": "#!/usr/bin/env python3\n\"\"\"TimesFM System Requirements Preflight Checker.\n\nMANDATORY: Run this script before loading Time"
  },
  {
    "path": "timesfm-forecasting/scripts/forecast_csv.py",
    "chars": 8664,
    "preview": "#!/usr/bin/env python3\n\"\"\"End-to-end CSV forecasting with TimesFM.\n\nLoads a CSV, runs the system preflight check, loads "
  },
  {
    "path": "v1/LICENSE",
    "chars": 11358,
    "preview": "\n                                 Apache License\n                           Version 2.0, January 2004\n                  "
  },
  {
    "path": "v1/README.md",
    "chars": 13885,
    "preview": "# TimesFM\n\nTimesFM  (Time Series Foundation Model) is a pretrained time-series foundation model developed by Google\nRese"
  },
  {
    "path": "v1/TROUBLESHOOTING.md",
    "chars": 4703,
    "preview": "# Troubleshooting\n\nThis document provides solutions to common issues encountered when using TimesFM.\n\n## Installation Is"
  },
  {
    "path": "v1/docs/contributing.md",
    "chars": 1067,
    "preview": "# How to Contribute\n\nWe would love to accept your patches and contributions to this project.\n\n## Before you begin\n\n### S"
  },
  {
    "path": "v1/experiments/baselines/__init__.py",
    "chars": 573,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/experiments/baselines/timegpt_pipeline.py",
    "chars": 7880,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/experiments/extended_benchmarks/README.md",
    "chars": 2667,
    "preview": "# Extended Benchmarks\n\nThe benchmark setting has been borrowed from Nixtla's original [benchmarking](https://github.com/"
  },
  {
    "path": "v1/experiments/extended_benchmarks/run_timegpt.py",
    "chars": 2907,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/experiments/extended_benchmarks/run_timesfm.py",
    "chars": 4249,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/experiments/extended_benchmarks/utils.py",
    "chars": 9710,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/experiments/long_horizon_benchmarks/README.md",
    "chars": 2372,
    "preview": "# Extended Benchmarks\n\nWe benchmark on the original test set for ETT datasets as per long horizon benchmark papers (see "
  },
  {
    "path": "v1/experiments/long_horizon_benchmarks/run_eval.py",
    "chars": 7626,
    "preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
  },
  {
    "path": "v1/notebooks/covariates.ipynb",
    "chars": 14988,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# TimesFM with Covariates\\n\",\n    \""
  },
  {
    "path": "v1/notebooks/finetuning.ipynb",
    "chars": 18560,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Importing relevant packages for "
  },
  {
    "path": "v1/notebooks/finetuning_torch.ipynb",
    "chars": 18550,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction\\n\",\n    \"This notebo"
  },
  {
    "path": "v1/peft/README.md",
    "chars": 2202,
    "preview": "# Fine-Tuning Pipeline\n\nThis folder contains a generic fine-tuning pipeline designed to support multiple PEFT fine-tunin"
  },
  {
    "path": "v1/peft/finetune.py",
    "chars": 13525,
    "preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
  },
  {
    "path": "v1/peft/finetune.sh",
    "chars": 957,
    "preview": "#!/bin/bash\n\n# Script to finetune a model with specific configurations\n# Adjust the parameters below as needed. For a fu"
  },
  {
    "path": "v1/peft/usage.ipynb",
    "chars": 5536,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Load Base Model\"\n   ]\n  },\n  {\n "
  },
  {
    "path": "v1/pyproject.toml",
    "chars": 2269,
    "preview": "[tool.poetry]\nname = \"timesfm\"\npackages = [\n    { include = \"timesfm\", from = \"src\" },\n    { include = \"finetuning\", fro"
  },
  {
    "path": "v1/src/adapter/__init__.py",
    "chars": 777,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/src/adapter/dora_layers.py",
    "chars": 6256,
    "preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
  },
  {
    "path": "v1/src/adapter/lora_layers.py",
    "chars": 5029,
    "preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
  },
  {
    "path": "v1/src/adapter/utils.py",
    "chars": 17697,
    "preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
  },
  {
    "path": "v1/src/finetuning/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "v1/src/finetuning/finetuning_example.py",
    "chars": 12001,
    "preview": "\"\"\"\nExample usage of the TimesFM Finetuning Framework.\n\nFor single GPU:\npython script.py --training_mode=single\n\nFor mul"
  },
  {
    "path": "v1/src/finetuning/finetuning_torch.py",
    "chars": 12456,
    "preview": "\"\"\"\nTimesFM Finetuner: A flexible framework for finetuning TimesFM models on custom datasets.\n\"\"\"\n\nimport logging\nimport"
  },
  {
    "path": "v1/src/timesfm/__init__.py",
    "chars": 1180,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/src/timesfm/data_loader.py",
    "chars": 8496,
    "preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
  },
  {
    "path": "v1/src/timesfm/patched_decoder.py",
    "chars": 19100,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/src/timesfm/pytorch_patched_decoder.py",
    "chars": 26137,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/src/timesfm/time_features.py",
    "chars": 6285,
    "preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
  },
  {
    "path": "v1/src/timesfm/timesfm_base.py",
    "chars": 27305,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/src/timesfm/timesfm_jax.py",
    "chars": 12707,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/src/timesfm/timesfm_torch.py",
    "chars": 6369,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/src/timesfm/xreg_lib.py",
    "chars": 20671,
    "preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "v1/tests/test_timesfm.py",
    "chars": 3030,
    "preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
  }
]

About this extraction

This page contains the full source code of the google-research/timesfm GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 85 files (917.3 KB), approximately 280.2k tokens, and a symbol index with 376 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!