Repository: google-research/timesfm
Branch: master
Commit: ed29ec5ef7e2
Files: 85
Total size: 917.3 KB
Directory structure:
gitextract_oox45t84/
├── .gitattributes
├── .github/
│   └── workflows/
│       ├── main.yml
│       └── manual_publish.yml
├── .gitignore
├── AGENTS.md
├── LICENSE
├── README.md
├── pyproject.toml
├── requirements.txt
├── src/
│   └── timesfm/
│       ├── __init__.py
│       ├── configs.py
│       ├── flax/
│       │   ├── __init__.py
│       │   ├── dense.py
│       │   ├── normalization.py
│       │   ├── transformer.py
│       │   └── util.py
│       ├── timesfm_2p5/
│       │   ├── timesfm_2p5_base.py
│       │   ├── timesfm_2p5_flax.py
│       │   └── timesfm_2p5_torch.py
│       ├── torch/
│       │   ├── __init__.py
│       │   ├── dense.py
│       │   ├── normalization.py
│       │   ├── transformer.py
│       │   └── util.py
│       └── utils/
│           └── xreg_lib.py
├── timesfm-forecasting/
│   ├── SKILL.md
│   ├── examples/
│   │   ├── anomaly-detection/
│   │   │   ├── detect_anomalies.py
│   │   │   └── output/
│   │   │       └── anomaly_detection.json
│   │   ├── covariates-forecasting/
│   │   │   ├── demo_covariates.py
│   │   │   └── output/
│   │   │       ├── covariates_metadata.json
│   │   │       └── sales_with_covariates.csv
│   │   └── global-temperature/
│   │       ├── README.md
│   │       ├── generate_animation_data.py
│   │       ├── generate_gif.py
│   │       ├── generate_html.py
│   │       ├── output/
│   │       │   ├── animation_data.json
│   │       │   ├── forecast_output.csv
│   │       │   ├── forecast_output.json
│   │       │   └── interactive_forecast.html
│   │       ├── run_example.sh
│   │       ├── run_forecast.py
│   │       ├── temperature_anomaly.csv
│   │       └── visualize_forecast.py
│   ├── references/
│   │   ├── api_reference.md
│   │   ├── data_preparation.md
│   │   └── system_requirements.md
│   └── scripts/
│       ├── check_system.py
│       └── forecast_csv.py
└── v1/
    ├── LICENSE
    ├── README.md
    ├── TROUBLESHOOTING.md
    ├── docs/
    │   └── contributing.md
    ├── experiments/
    │   ├── baselines/
    │   │   ├── __init__.py
    │   │   └── timegpt_pipeline.py
    │   ├── extended_benchmarks/
    │   │   ├── README.md
    │   │   ├── run_timegpt.py
    │   │   ├── run_timesfm.py
    │   │   └── utils.py
    │   └── long_horizon_benchmarks/
    │       ├── README.md
    │       └── run_eval.py
    ├── notebooks/
    │   ├── covariates.ipynb
    │   ├── finetuning.ipynb
    │   └── finetuning_torch.ipynb
    ├── peft/
    │   ├── README.md
    │   ├── finetune.py
    │   ├── finetune.sh
    │   └── usage.ipynb
    ├── pyproject.toml
    ├── src/
    │   ├── adapter/
    │   │   ├── __init__.py
    │   │   ├── dora_layers.py
    │   │   ├── lora_layers.py
    │   │   └── utils.py
    │   ├── finetuning/
    │   │   ├── __init__.py
    │   │   ├── finetuning_example.py
    │   │   └── finetuning_torch.py
    │   └── timesfm/
    │       ├── __init__.py
    │       ├── data_loader.py
    │       ├── patched_decoder.py
    │       ├── pytorch_patched_decoder.py
    │       ├── time_features.py
    │       ├── timesfm_base.py
    │       ├── timesfm_jax.py
    │       ├── timesfm_torch.py
    │       └── xreg_lib.py
    └── tests/
        └── test_timesfm.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitattributes
================================================
# Git LFS tracking for binary outputs in timesfm-forecasting skill
timesfm-forecasting/**/*.png filter=lfs diff=lfs merge=lfs -text
timesfm-forecasting/**/*.gif filter=lfs diff=lfs merge=lfs -text
================================================
FILE: .github/workflows/main.yml
================================================
name: Python package build
on:
  push:
    branches: [ "master" ]
  pull_request:
    branches: [ "master" ]
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.11'
      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          echo "$HOME/.cargo/bin" >> $GITHUB_PATH
      - name: Create virtual environment
        run: uv venv
      - name: Install build dependencies
        run: |
          uv pip install build ".[torch,flax]"
      - name: Build package
        run: uv run python -m build
================================================
FILE: .github/workflows/manual_publish.yml
================================================
name: Manual PyPI Publish
on:
  workflow_dispatch:
jobs:
  build-and-publish:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.11'
      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          echo "$HOME/.cargo/bin" >> $GITHUB_PATH
      - name: Create virtual environment
        run: uv venv
      - name: Install build dependencies
        run: uv pip install build twine
      - name: Build package
        run: uv run python -m build
      - name: Publish to PyPI
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: uv run twine upload dist/*
================================================
FILE: .gitignore
================================================
.venv/
dist/
__pycache__/
checkpoints/
wandb/
datasets/
results/
timesfm_jax.egg-info/
development_setup.md
================================================
FILE: AGENTS.md
================================================
# TimesFM — Agent Entry Point
This repository ships a first-party **Agent Skill** for TimesFM at:
```
timesfm-forecasting/
└── SKILL.md ← read this for the full skill
```
## Install the skill
Copy the skill directory into your agent's skills folder:
```bash
# Cursor / Claude Code / OpenCode / Codex (global install)
cp -r timesfm-forecasting/ ~/.cursor/skills/
cp -r timesfm-forecasting/ ~/.claude/skills/
# Or project-level
cp -r timesfm-forecasting/ .cursor/skills/
```
Any agent that supports the open [Agent Skills standard](https://agentskills.io) will discover it automatically.
## Working in this repo
If you are developing TimesFM itself (not using it), the source lives in `src/timesfm/`.
Archived v1/v2 code and notebooks are in `v1/`.
Run tests:
```bash
pytest v1/tests/
```
See `README.md` for full developer setup.
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# TimesFM
TimesFM (Time Series Foundation Model) is a pretrained time-series foundation
model developed by Google Research for time-series forecasting.
* Paper:
[A decoder-only foundation model for time-series forecasting](https://arxiv.org/abs/2310.10688),
ICML 2024.
* All checkpoints:
[TimesFM Hugging Face Collection](https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6).
* [Google Research blog](https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/).
* [TimesFM in BigQuery](https://cloud.google.com/bigquery/docs/timesfm-model):
an official Google product.
This open version is not an officially supported Google product.
**Latest Model Version:** TimesFM 2.5
**Archived Model Versions:**
- 1.0 and 2.0: the relevant code is archived in the subdirectory `v1`. You can
  `pip install timesfm==1.3.0` to install an older version of this package that
  loads them.
## Update - Oct. 29, 2025
Added back covariate support through XReg for TimesFM 2.5.
## Update - Sept. 15, 2025
TimesFM 2.5 is out!
Compared to TimesFM 2.0, the new 2.5 model:
- uses 200M parameters, down from 500M.
- supports up to 16k context length, up from 2048.
- supports continuous quantile forecast up to 1k horizon via an optional 30M
quantile head.
- gets rid of the `frequency` indicator.
- has a couple of new forecasting flags.
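Quantile forecasts are produced level by level, so nothing structurally prevents a lower quantile from exceeding a higher one at some horizon step ("quantile crossing"); `ForecastConfig` exposes a `fix_quantile_crossing` flag for this. A simple, generic remedy is to sort the quantile channels at each step. The minimal sketch below assumes the output layout documented in the code example of this README (last axis: the mean, then the 10th to 90th percentiles) and uses a hypothetical helper, `fix_crossing`; TimesFM's own fix may work differently.

```python
import numpy as np


def fix_crossing(quantile_forecast: np.ndarray) -> np.ndarray:
  """Returns a copy whose quantile channels are non-decreasing in level.

  Hypothetical helper. Assumes the last axis is [mean, q10, q20, ..., q90];
  the mean channel (index 0) is copied through unchanged.
  """
  fixed = quantile_forecast.copy()
  # Sort only the quantile channels, independently for every (series, step).
  fixed[..., 1:] = np.sort(fixed[..., 1:], axis=-1)
  return fixed


# One series, one horizon step, where q30 (4.0) > q40 (3.5): a crossing.
qf = np.array([[[5.0, 1.0, 2.0, 4.0, 3.5, 5.0, 6.0, 7.0, 8.0, 9.0]]])
fixed = fix_crossing(qf)
print(fixed[0, 0].tolist())
# [5.0, 1.0, 2.0, 3.5, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
```

Sorting leaves the set of predicted values unchanged and only reassigns them to levels, so it never widens the overall predictive range.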
Along with the model upgrade, we have also upgraded the inference API. This repo
will be under construction over the next few weeks to:
1. add support for an upcoming Flax version of the model (faster inference).
2. add back covariate support.
3. populate more docstrings, docs and notebooks.
### Install
1. Clone the repository:
```shell
git clone https://github.com/google-research/timesfm.git
cd timesfm
```
2. Create a virtual environment and install dependencies using `uv`:
```shell
# Create a virtual environment
uv venv
# Activate the environment
source .venv/bin/activate
# Install the package in editable mode with torch
uv pip install -e .[torch]
# Or with flax
uv pip install -e .[flax]
# Or, if XReg is needed
uv pip install -e .[xreg]
```
3. [Optional] Install your preferred `torch` / `jax` backend based on your OS and accelerators
(CPU, GPU, TPU or Apple Silicon):
- [Install PyTorch](https://pytorch.org/get-started/locally/).
- [Install JAX](https://docs.jax.dev/en/latest/installation.html#installation)
for Flax.
### Code Example
```python
import torch
import numpy as np
import timesfm
torch.set_float32_matmul_precision("high")
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch")
model.compile(
    timesfm.ForecastConfig(
        max_context=1024,
        max_horizon=256,
        normalize_inputs=True,
        use_continuous_quantile_head=True,
        force_flip_invariance=True,
        infer_is_positive=True,
        fix_quantile_crossing=True,
    )
)
point_forecast, quantile_forecast = model.forecast(
    horizon=12,
    inputs=[
        np.linspace(0, 1, 100),
        np.sin(np.linspace(0, 20, 67)),
    ],  # Two dummy inputs
)
point_forecast.shape # (2, 12)
quantile_forecast.shape # (2, 12, 10): mean, then 10th to 90th quantiles.
```
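The `force_flip_invariance` flag in the example refers to the equivariance property documented on `ForecastConfig`: TimesFM(aX + b) = a * TimesFM(X) + b holds for a >= 0 by default, and the flag extends it to a < 0. One standard way to make any point forecaster sign-symmetric is to average f(x) with -f(-x). The NumPy sketch below illustrates that idea; `toy_forecast` and `flip_symmetrized` are hypothetical helpers written for this illustration, not a description of how TimesFM implements the flag internally. It also ignores quantile outputs, where negating the series additionally swaps low and high quantile levels.

```python
from typing import Callable

import numpy as np

Forecaster = Callable[[np.ndarray, int], np.ndarray]


def toy_forecast(x: np.ndarray, horizon: int) -> np.ndarray:
  """Hypothetical stand-in forecaster (not TimesFM).

  Extrapolates the last value with the mean of the *positive* increments, so
  f(a*x + b) = a*f(x) + b holds for a >= 0 but not for a < 0.
  """
  diffs = np.diff(x)
  ups = diffs[diffs > 0]
  step = ups.mean() if ups.size else 0.0
  return x[-1] + step * np.arange(1, horizon + 1)


def flip_symmetrized(f: Forecaster) -> Forecaster:
  """Hypothetical wrapper: g(x) = (f(x) - f(-x)) / 2, so g(-x) == -g(x)."""

  def g(x: np.ndarray, horizon: int) -> np.ndarray:
    return 0.5 * (f(x, horizon) - f(-x, horizon))

  return g


x = np.sin(np.linspace(0, 20, 67))
sym = flip_symmetrized(toy_forecast)
print(np.allclose(toy_forecast(-x, 12), -toy_forecast(x, 12)))  # False
print(np.allclose(sym(-x, 12), -sym(x, 12)))  # True
print(np.allclose(sym(-2 * x + 3, 12), -2 * sym(x, 12) + 3))  # True
```

The symmetrized forecaster needs two forward passes per series (on x and on -x). The last check shows that combining the toy model's a >= 0 equivariance with the symmetrization also yields equivariance for a negative scale (a = -2).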
================================================
FILE: pyproject.toml
================================================
[project]
name = "timesfm"
version = "2.0.0"
description = "A time series foundation model."
authors = [
{name = "Rajat Sen", email = "senrajat@google.com"},
{name = "Yichen Zhou", email = "yichenzhou@google.com"},
{name = "Abhimanyu Das", email = "abhidas@google.com"},
{name = "Petros Mol", email = "pmol@google.com"},
{name = "Michael Chertushkin", email = "chertushkinmichael@gmail.com"},
]
license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"numpy>=1.26.4",
"huggingface_hub[cli]>=0.23.0",
"safetensors>=0.5.3",
]
[project.optional-dependencies]
torch = [
"torch>=2.0.0",
]
flax = [
"flax",
"optax",
"einshape",
"orbax-checkpoint",
"jaxtyping",
"jax[cuda]"
]
xreg = [
"jax[cuda]",
"scikit-learn",
]
[tool.ruff]
line-length = 88
indent-width = 2
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
================================================
FILE: requirements.txt
================================================
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o requirements.txt
anyio==4.11.0
# via httpx
certifi==2025.10.5
# via
# httpcore
# httpx
click==8.3.0
# via typer-slim
filelock==3.19.1
# via huggingface-hub
fsspec==2025.9.0
# via huggingface-hub
h11==0.16.0
# via httpcore
hf-xet==1.2.0
# via huggingface-hub
httpcore==1.0.9
# via httpx
httpx==0.28.1
# via huggingface-hub
huggingface-hub==1.0.1
# via timesfm (pyproject.toml)
idna==3.10
# via
# anyio
# httpx
numpy==2.2.6
# via timesfm (pyproject.toml)
packaging==25.0
# via huggingface-hub
pyyaml==6.0.3
# via huggingface-hub
safetensors==0.6.2
# via timesfm (pyproject.toml)
shellingham==1.5.4
# via huggingface-hub
sniffio==1.3.1
# via anyio
tqdm==4.67.1
# via huggingface-hub
typer-slim==0.20.0
# via huggingface-hub
typing-extensions==4.15.0
# via
# anyio
# huggingface-hub
# typer-slim
================================================
FILE: src/timesfm/__init__.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM API."""
from .configs import ForecastConfig
try:
  from .timesfm_2p5 import timesfm_2p5_torch
  TimesFM_2p5_200M_torch = timesfm_2p5_torch.TimesFM_2p5_200M_torch
except ImportError:
  pass
try:
  from .timesfm_2p5 import timesfm_2p5_flax
  TimesFM_2p5_200M_flax = timesfm_2p5_flax.TimesFM_2p5_200M_flax
except ImportError:
  pass
================================================
FILE: src/timesfm/configs.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract configs for TimesFM layers."""
import dataclasses
from typing import Literal
@dataclasses.dataclass(frozen=True)
class ForecastConfig:
  """Options for forecasting.
  Attributes:
    max_context: The maximum context length. This is used by the compiled
      decode function at inference time during batched inference. Any input
      time series shorter than max_context will be padded with zeros, and any
      longer than max_context will be truncated.
    max_horizon: The maximum horizon length. This is used by the compiled
      decode function at inference time during batched inference. The compiled
      cached decoding function will by default forecast up to max_horizon.
    normalize_inputs: Whether to normalize the inputs. This is useful when the
      raw inputs have extremely large or small magnitudes, which may cause
      numerical issues.
    window_size: The window size for decomposed forecasting.
      TODO(siriuz42): implement it.
    per_core_batch_size: The batch size per core. Used at inference time during
      batched inference when multiple GPU / TPU devices are used.
    use_continuous_quantile_head: Whether to use a separate continuous quantile
      head to avoid quantile collapsing.
    force_flip_invariance: Whether to force flip invariance. TimesFM guarantees
      that TimesFM(aX + b) = a * TimesFM(X) + b for a >= 0 by default. This flag
      extends it to a < 0 as well.
    infer_is_positive: Whether to guarantee nonnegativity of the output if the
      input is nonnegative.
    fix_quantile_crossing: Whether to fix quantile crossing.
    return_backcast: Whether to return the backcast.
  """
  max_context: int = 0
  max_horizon: int = 0
  normalize_inputs: bool = False
  window_size: int = 0
  per_core_batch_size: int = 1
  use_continuous_quantile_head: bool = False
  force_flip_invariance: bool = True
  infer_is_positive: bool = True
  fix_quantile_crossing: bool = False
  return_backcast: bool = False
@dataclasses.dataclass(frozen=True)
class ResidualBlockConfig:
  """Framework-agnostic config for a residual block."""
  input_dims: int
  hidden_dims: int
  output_dims: int
  use_bias: bool
  activation: Literal["relu", "swish", "none"]
@dataclasses.dataclass(frozen=True)
class RandomFourierFeaturesConfig:
  """Framework-agnostic config for random Fourier features."""
  input_dims: int
  output_dims: int
  projection_stddev: float
  use_bias: bool
@dataclasses.dataclass(frozen=True)
class TransformerConfig:
  """Framework-agnostic config for a transformer."""
  model_dims: int
  hidden_dims: int
  num_heads: int
  attention_norm: Literal["rms"]
  feedforward_norm: Literal["rms"]
  qk_norm: Literal["rms", "none"]
  use_bias: bool
  use_rotary_position_embeddings: bool
  ff_activation: Literal["relu", "swish", "none"]
  fuse_qkv: bool
@dataclasses.dataclass(frozen=True)
class StackedTransformersConfig:
  """Framework-agnostic config for stacked transformers."""
  num_layers: int
  transformer: TransformerConfig
================================================
FILE: src/timesfm/flax/__init__.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: src/timesfm/flax/dense.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dense layers for TimesFM."""
from flax import nnx
import jax
import jax.numpy as jnp
import jaxtyping
from .. import configs
Array = jaxtyping.Array
Bool = jaxtyping.Bool
Float = jaxtyping.Float
Integer = jaxtyping.Integer
Num = jaxtyping.Num
ResidualBlockConfig = configs.ResidualBlockConfig
RandomFourierFeaturesConfig = configs.RandomFourierFeaturesConfig
class ResidualBlock(nnx.Module):
  """Residual block with two linear layers and a linear residual connection."""
  def __init__(self, config: ResidualBlockConfig, *, rngs=nnx.Rngs(42)):
    self.config = config
    self.hidden_layer = nnx.Linear(
        in_features=config.input_dims,
        out_features=config.hidden_dims,
        use_bias=config.use_bias,
        rngs=rngs,
    )
    self.output_layer = nnx.Linear(
        in_features=config.hidden_dims,
        out_features=config.output_dims,
        use_bias=config.use_bias,
        rngs=rngs,
    )
    self.residual_layer = nnx.Linear(
        in_features=config.input_dims,
        out_features=config.output_dims,
        use_bias=config.use_bias,
        rngs=rngs,
    )
    if config.activation == "relu":
      self.activation = jax.nn.relu
    elif config.activation == "swish":
      self.activation = jax.nn.swish
    elif config.activation == "none":
      self.activation = lambda x: x
    else:
      raise ValueError(f"Activation: {config.activation} not supported.")
  def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... o"]:
    return self.output_layer(
        self.activation(self.hidden_layer(x))
    ) + self.residual_layer(x)
class RandomFourierFeatures(nnx.Module):
  """Random Fourier features layer."""
  __data__ = ("phase_shifts",)
  def __init__(self, config: RandomFourierFeaturesConfig, *, rngs=nnx.Rngs(42)):
    self.config = config
    if config.output_dims % 4 != 0:
      raise ValueError(
          f"Output dims must be a multiple of 4: {config.output_dims} % 4 != 0."
      )
    num_projected_features = config.output_dims // 4
    self.phase_shifts = nnx.Param(jnp.zeros(shape=(2, num_projected_features)))
    self.projection_layer = nnx.Linear(
        in_features=config.input_dims,
        out_features=num_projected_features,
        use_bias=config.use_bias,
        rngs=rngs,
    )
    self.residual_layer = nnx.Linear(
        in_features=config.input_dims,
        out_features=config.output_dims,
        use_bias=config.use_bias,
        rngs=rngs,
    )
  def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... o"]:
    projected = self.projection_layer(x)
    cos_features = jnp.cos(projected)
    sin_features = jnp.sin(projected)
    sq_wave_1 = jnp.sign(jnp.sin(projected + self.phase_shifts[0, :]))
    sq_wave_2 = jnp.sign(jnp.sin(projected + self.phase_shifts[1, :]))
    fourier_features = jnp.concatenate(
        [cos_features, sin_features, sq_wave_1, sq_wave_2], axis=-1
    )
    residual = self.residual_layer(x)
    return fourier_features + residual
================================================
FILE: src/timesfm/flax/normalization.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers for TimesFM."""
from flax import nnx
import jax
import jax.numpy as jnp
import jaxtyping
Array = jaxtyping.Array
Bool = jaxtyping.Bool
Float = jaxtyping.Float
Integer = jaxtyping.Integer
Num = jaxtyping.Num
class RMSNorm(nnx.Module):
  """RMS normalization."""
  __data__ = ("scale",)
  def __init__(
      self,
      num_features: int,
      *,
      epsilon: float = 1e-6,
      rngs=nnx.Rngs(42),
  ):
    del rngs
    self.scale = nnx.Param(jnp.zeros(shape=(num_features,)))
    self.num_features = num_features
    self.epsilon = epsilon
  def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
    var = jnp.mean(jnp.square(inputs), axis=-1, keepdims=True)
    normed_inputs = inputs * jax.lax.rsqrt(var + self.epsilon)
    normed_inputs *= self.scale
    return normed_inputs
class LayerNorm(nnx.Module):
  """Layer normalization replica of LayerNorm."""
  __data__ = ("scale", "bias")
  def __init__(self, num_features: int, *, epsilon: float = 1e-6, rngs=nnx.Rngs(42)):
    del rngs
    self.scale = nnx.Param(jnp.ones(shape=(num_features,)))
    self.bias = nnx.Param(jnp.zeros(shape=(num_features,)))
    self.num_features = num_features
    self.epsilon = epsilon
  def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
    mean = jnp.mean(inputs, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True)
    normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
    normed_inputs *= self.scale
    normed_inputs += self.bias
    return normed_inputs
================================================
FILE: src/timesfm/flax/transformer.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer layers for TimesFM."""
import functools
from typing import Callable
from flax import nnx
from flax.nnx.nn import linear
import jax
from jax import lax
import jax.numpy as jnp
import jaxtyping
from .. import configs
from . import normalization, util
Array = jaxtyping.Array
Bool = jaxtyping.Bool
Float = jaxtyping.Float
Integer = jaxtyping.Integer
Num = jaxtyping.Num
LayerNorm = normalization.LayerNorm
RMSNorm = normalization.RMSNorm
LinearGeneral = linear.LinearGeneral
TransformerConfig = configs.TransformerConfig
DecodeCache = util.DecodeCache
@functools.partial(
jax.jit,
static_argnames=("query_length", "kv_length"),
)
def make_attn_mask(
query_length: int,
num_all_masked_kv: Integer[Array, "b"],
query_index_offset: Integer[Array, "b"] | None = None,
kv_length: int = 0,
) -> Bool[Array, "b 1 q n"]:
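  """Makes a causal attention mask that also hides masked key positions.
  Query position i (shifted by `query_index_offset` when decoding with a cache)
  may attend to key position j iff j <= i and j >= `num_all_masked_kv`, i.e.
  masked positions are assumed to be left padding.
  """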
if kv_length == 0:
kv_length = query_length
q_index = jnp.arange(query_length)[None, None, :, None]
if query_index_offset is not None:
q_index += query_index_offset[:, None, None, None]
kv_index = jnp.arange(kv_length)[None, None, None, :]
return jnp.logical_and(
q_index >= kv_index,
kv_index >= num_all_masked_kv[:, None, None, None],
)
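# Illustrative sketch (hypothetical helper, not used by the model): the same
# masking rule as `make_attn_mask`, written in NumPy so the result is easy to
# print and inspect. Query i may attend to key j only if j <= i + offset
# (causal) and j is not one of the leading `num_all_masked_kv` positions, which
# assumes left padding as produced by `TimesFM_2p5.forecast`.
def _reference_attn_mask(
    query_length, num_all_masked_kv, query_index_offset=None, kv_length=0
):
  import numpy as np
  if kv_length == 0:
    kv_length = query_length
  q_index = np.arange(query_length)[None, None, :, None]
  if query_index_offset is not None:
    q_index = q_index + np.asarray(query_index_offset)[:, None, None, None]
  kv_index = np.arange(kv_length)[None, None, None, :]
  num_masked = np.asarray(num_all_masked_kv)[:, None, None, None]
  return np.logical_and(q_index >= kv_index, kv_index >= num_masked)
# With 4 patches and 1 leading padded patch, row 0 is fully masked and rows
# 1-3 form a causal lower triangle starting at column 1.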
class RotaryPositionalEmbedding(nnx.Module):
"""Rotary positional embedding."""
def __init__(
self,
embedding_dims: int,
min_timescale: int = 1,
max_timescale: int = 10000,
):
self.embedding_dims = embedding_dims
self.min_timescale = min_timescale
self.max_timescale = max_timescale
def __call__(
self,
inputs: Float[Array, "b ... d"],
position: Array | None = None,
):
    """Applies rotary position embeddings along the last dimension of `inputs`."""
    if self.embedding_dims != inputs.shape[-1]:
      raise ValueError(
          "The embedding dims of the rotary position embedding "
          "must match the hidden dimension of the inputs."
      )
half_embedding_dim = self.embedding_dims // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
timescale = (
self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
)
if position is None:
seq_length = inputs.shape[1]
position = jnp.arange(seq_length, dtype=jnp.float32)[None, :]
if len(inputs.shape) == 4:
position = position[..., None, None]
timescale = timescale[None, None, None, :]
elif len(inputs.shape) == 3:
position = position[..., None]
timescale = timescale[None, None, :]
else:
raise ValueError("Inputs must be of rank 3 or 4.")
sinusoid_inp = position / timescale
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)
first_half, second_half = jnp.split(inputs, 2, axis=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
first_part = first_part.astype(None)
second_part = second_part.astype(None)
return jnp.concatenate([first_part, second_part], axis=-1)
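# Illustrative sketch (hypothetical helper, not used by the model): a NumPy
# version of the rotation `RotaryPositionalEmbedding` applies to rank-3 inputs
# [b, n, d]. Feature i of the first half is paired with feature i of the
# second half, and each pair is rotated by position / timescale_i. Rotations
# preserve vector norms, and the dot product of a rotated query and key depends
# only on the difference of their positions. In `MultiHeadAttention` the
# position is arange(n) + next_index - num_masked, so the first unpadded patch
# gets position 0.
def _reference_rope(inputs, position, min_timescale=1.0, max_timescale=10000.0):
  import numpy as np
  inputs = np.asarray(inputs, dtype=np.float64)  # [b, n, d] with d even.
  dims = inputs.shape[-1]
  fraction = 2 * np.arange(dims // 2) / dims
  timescale = min_timescale * (max_timescale / min_timescale) ** fraction
  # Rotation angle per (batch, position, feature pair): [b, n, d // 2].
  angle = np.asarray(position, dtype=np.float64)[..., None] / timescale
  first_half, second_half = np.split(inputs, 2, axis=-1)
  sin, cos = np.sin(angle), np.cos(angle)
  return np.concatenate(
      [first_half * cos - second_half * sin, second_half * cos + first_half * sin],
      axis=-1,
  )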
class PerDimScale(nnx.Module):
"""Per-dimension scaling."""
__data__ = ("per_dim_scale",)
def __init__(self, num_dims: int, *, rngs=nnx.Rngs(42)):
del rngs
self.num_dims = num_dims
self.per_dim_scale = nnx.Param(jnp.zeros(shape=(num_dims,)))
def __call__(self, x: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
return x * (
1.442695041 / jnp.sqrt(self.num_dims) * jax.nn.softplus(self.per_dim_scale)
)
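# Worked example (hypothetical helper, not used by the model): the constant
# 1.442695041 is approximately 1 / ln(2), and softplus(0) = ln(2). A freshly
# initialized PerDimScale (all-zero `per_dim_scale`) therefore scales queries
# by 1 / sqrt(num_dims), up to rounding of the constant, which is the usual
# attention scaling; training can then adjust it per dimension. When enabled,
# MultiHeadAttention multiplies the query by sqrt(head_dim) before calling
# `nnx.dot_product_attention`, which divides it by sqrt(head_dim) again, so
# this layer sets the effective logit scale.
def _per_dim_scale_factors(per_dim_scale, num_dims):
  import math
  # softplus(v) = log(1 + exp(v)), computed with log1p for accuracy near 0.
  return [
      1.442695041 / math.sqrt(num_dims) * math.log1p(math.exp(v))
      for v in per_dim_scale
  ]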
class MultiHeadAttention(nnx.Module):
"""Multi-head attention."""
def __init__(
self,
num_heads: int,
in_features: int,
*,
use_per_dim_scale: bool = True,
use_rotary_position_embeddings: bool = True,
use_bias: bool = False,
deterministic: bool | None = None,
attention_fn: Callable[..., Array] = nnx.dot_product_attention,
qk_norm: str = "rms",
rngs=nnx.Rngs(42),
):
self.num_heads = num_heads
self.in_features = in_features
self.qkv_features = in_features
self.out_features = in_features
self.in_kv_features = in_features
self.deterministic = deterministic
self.use_bias = use_bias
self.attention_fn = attention_fn
self.qk_norm = qk_norm
if self.qkv_features % self.num_heads != 0:
raise ValueError(
f"Memory dimension ({self.qkv_features}) must be divisible by "
f"'num_heads' heads ({self.num_heads})."
)
self.head_dim = self.qkv_features // self.num_heads
linear_general = functools.partial(
LinearGeneral,
out_features=(self.num_heads, self.head_dim),
use_bias=self.use_bias,
)
# project inputs_q to multi-headed q/k/v
# dimensions are then [batch..., length, n_heads, n_features_per_head]
self.query = linear_general(self.in_features, rngs=rngs)
self.key = linear_general(self.in_kv_features, rngs=rngs)
self.value = linear_general(self.in_kv_features, rngs=rngs)
if self.qk_norm == "rms":
self.query_ln = RMSNorm(self.head_dim)
self.key_ln = RMSNorm(self.head_dim)
else:
self.query_ln = None
self.key_ln = None
self.out = LinearGeneral(
in_features=(self.num_heads, self.head_dim),
out_features=self.out_features,
axis=(-2, -1),
use_bias=self.use_bias,
rngs=rngs,
)
self.use_per_dim_scale = use_per_dim_scale
self.use_rotary_position_embeddings = use_rotary_position_embeddings
if self.use_rotary_position_embeddings:
self.rotary_position_embedding = RotaryPositionalEmbedding(
embedding_dims=self.head_dim,
)
else:
self.rotary_position_embedding = None
if use_per_dim_scale:
self.per_dim_scale = PerDimScale(num_dims=self.head_dim, rngs=rngs)
else:
self.per_dim_scale = None
def __call__(
self,
inputs_q: Array,
*,
decode_cache: DecodeCache | None = None,
patch_mask: Array | None = None,
deterministic: bool | None = None,
sow_weights: bool = False,
) -> tuple[Float[Array, "b ... o"], DecodeCache | None]:
"""Applies multi-head dot product attention on the input data."""
_, n_patches, input_in_features = inputs_q.shape
if input_in_features != self.in_features:
raise ValueError(
f"Incompatible input dimension, got {input_in_features} "
f"but module expects {self.in_features}."
)
    if patch_mask is None:
      patch_mask = jnp.zeros(inputs_q.shape[:-1], dtype=jnp.bool)
# For query: rope -> ln -> per_dim_scale
query = self.query(inputs_q)
key = self.key(inputs_q)
value = self.value(inputs_q)
if decode_cache is None:
num_masked = jnp.sum(patch_mask.astype(jnp.int32), axis=-1, keepdims=False)
next_index = jnp.zeros_like(num_masked, dtype=jnp.int32)
else:
num_masked = (
jnp.sum(patch_mask.astype(jnp.int32), axis=-1, keepdims=False)
+ decode_cache.num_masked
)
next_index = decode_cache.next_index
if self.use_rotary_position_embeddings:
position = (
jnp.arange(n_patches, dtype=jnp.int32)[None, :]
+ next_index[:, None]
- num_masked[:, None]
)
query = self.rotary_position_embedding(query, position)
key = self.rotary_position_embedding(key, position)
if self.query_ln is not None:
query = self.query_ln(query)
if self.key_ln is not None:
key = self.key_ln(key)
if self.use_per_dim_scale:
query = self.per_dim_scale(query)
if decode_cache is not None:
# Cached decoding.
_, decode_cache_size, _, _ = decode_cache.value.shape
zero = jnp.array(0, dtype=lax.dtype(next_index.dtype))
start_indices = (zero, next_index[0], zero, zero)
key = lax.dynamic_update_slice(decode_cache.key, key, start_indices)
value = lax.dynamic_update_slice(decode_cache.value, value, start_indices)
decode_cache.key = key
decode_cache.value = value
decode_cache.next_index = next_index + n_patches
decode_cache.num_masked = num_masked
attn_mask = make_attn_mask(
query_length=n_patches,
num_all_masked_kv=num_masked,
query_index_offset=next_index,
kv_length=decode_cache_size,
)
else:
      # No decode cache (training or a single full-context pass).
attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked)
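    # Apply attention. `nnx.dot_product_attention` divides the query by
    # sqrt(head_dim); multiplying by sqrt(head_dim) here cancels that, so the
    # logit scale is controlled by PerDimScale (when enabled).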
x = self.attention_fn(
query * jnp.sqrt(self.head_dim),
key,
value,
mask=attn_mask,
deterministic=deterministic,
module=self if sow_weights else None,
)
# back to the original inputs dimensions
out = self.out(x)
return out, decode_cache
class Transformer(nnx.Module):
  """A single Transformer layer (attention + feed-forward) used in TimesFM."""
def __init__(self, config: TransformerConfig, *, rngs=nnx.Rngs(42)):
self.config = config
if config.attention_norm == "rms":
self.pre_attn_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
self.post_attn_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
else:
raise ValueError(f"Layer norm: {config.attention_norm} not supported.")
self.attn = MultiHeadAttention(
num_heads=config.num_heads,
in_features=config.model_dims,
use_per_dim_scale=True,
use_rotary_position_embeddings=config.use_rotary_position_embeddings,
qk_norm=config.qk_norm,
rngs=rngs,
)
if config.feedforward_norm == "rms":
self.pre_ff_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
self.post_ff_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
else:
raise ValueError(f"Layer norm: {config.feedforward_norm} not supported.")
self.ff0 = nnx.Linear(
in_features=config.model_dims,
out_features=config.hidden_dims,
use_bias=config.use_bias,
rngs=rngs,
)
self.ff1 = nnx.Linear(
in_features=config.hidden_dims,
out_features=config.model_dims,
use_bias=config.use_bias,
rngs=rngs,
)
if config.ff_activation == "relu":
self.activation = jax.nn.relu
elif config.ff_activation == "swish":
self.activation = jax.nn.swish
elif config.ff_activation == "none":
self.activation = lambda x: x
else:
raise ValueError(f"Activation: {config.ff_activation} not supported.")
def __call__(
self,
input_embeddings: Float[Array, "b n d"],
patch_mask: Bool[Array, "b n"],
decode_cache: DecodeCache | None = None,
) -> tuple[Float[Array, "b n d"], DecodeCache | None]:
attn_output, decode_cache = self.attn(
inputs_q=self.pre_attn_ln(input_embeddings),
decode_cache=decode_cache,
patch_mask=patch_mask,
sow_weights=False,
deterministic=True,
)
attn_output = self.post_attn_ln(attn_output) + input_embeddings
output_embeddings = (
self.post_ff_ln(self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))))
+ attn_output
)
return output_embeddings, decode_cache
================================================
FILE: src/timesfm/flax/util.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flax utility functions for TimesFM layers."""
import dataclasses
import functools
import jax
import jax.numpy as jnp
import jaxtyping
Float = jaxtyping.Float
Array = jaxtyping.Array
Bool = jaxtyping.Bool
Integer = jaxtyping.Integer
_TOLERANCE = 1e-6
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=False)
class DecodeCache:
"""Cache for decoding."""
next_index: Integer[Array, "b"]
num_masked: Integer[Array, "b"]
key: Float[Array, "b n h d"]
value: Float[Array, "b n h d"]
@jax.jit
def update_running_stats(
n: Float[Array, "b"],
mu: Float[Array, "b"],
sigma: Float[Array, "b"],
x: Float[Array, "b p"],
mask: Bool[Array, "b p"],
) -> tuple[
tuple[Float[Array, "b"], Float[Array, "b"], Float[Array, "b"]],
tuple[Float[Array, "b"], Float[Array, "b"], Float[Array, "b"]],
]:
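  """Merges the unmasked values of `x` into a running count, mean and std.
  Entries where `mask` is True are ignored, and the std is the population std.
  The new stats are returned twice, as (carry, output), so the function can be
  used directly as a `jax.lax.scan` body.
  """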
is_legit = jnp.logical_not(mask)
inc_n = jnp.sum(is_legit.astype(jnp.float32), axis=-1, keepdims=False)
inc_mu = jnp.where(
inc_n == 0, 0.0, jnp.mean(x, axis=-1, keepdims=False, where=is_legit)
)
inc_sigma = jnp.where(
inc_n == 0, 0.0, jnp.std(x, axis=-1, keepdims=False, where=is_legit)
)
new_n = n + inc_n
new_mu = jnp.where(new_n == 0, 0.0, (n * mu + inc_mu * inc_n) / new_n)
new_sigma = jnp.sqrt(
jnp.where(
new_n == 0,
0.0,
(
n * sigma * sigma
+ inc_n * inc_sigma * inc_sigma
+ n * (mu - new_mu) * (mu - new_mu)
+ inc_n * (inc_mu - new_mu) * (inc_mu - new_mu)
)
/ new_n,
)
)
return (w := (new_n, new_mu, new_sigma), w)
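# Illustrative sketch (hypothetical helper, not used by the model): the
# unmasked, scalar form of the merge in `update_running_stats`. Combining a
# running (count, mean, population std) with the stats of a new chunk gives
# the same result as computing them over all values seen so far, which lets
# per-series stats be accumulated patch by patch inside `jax.lax.scan`.
def _merge_running_stats(n, mu, sigma, chunk):
  import numpy as np
  chunk = np.asarray(chunk, dtype=np.float64)
  inc_n = chunk.size
  if inc_n == 0:
    return n, mu, sigma
  inc_mu, inc_sigma = chunk.mean(), chunk.std()  # Population std (ddof=0).
  new_n = n + inc_n
  new_mu = (n * mu + inc_n * inc_mu) / new_n
  new_var = (
      n * sigma**2
      + inc_n * inc_sigma**2
      + n * (mu - new_mu) ** 2
      + inc_n * (inc_mu - new_mu) ** 2
  ) / new_n
  return new_n, float(new_mu), float(np.sqrt(new_var))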
def scan_along_axis(f, init, xs, axis: int, **kwargs):
"""Scans along an axis."""
moved_xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), xs)
carry, moved_ys = jax.lax.scan(f, init, moved_xs, **kwargs)
return (
carry,
jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), moved_ys),
)
@functools.partial(jax.jit, static_argnames=("reverse",))
def revin(
x: Float[Array, "b ..."],
mu: Float[Array, "b ..."],
sigma: Float[Array, "b ..."],
reverse: bool = False,
):
"""Reversible per-instance normalization."""
if len(mu.shape) == len(x.shape) - 1:
mu = mu[..., None]
sigma = sigma[..., None]
elif len(mu.shape) == len(x.shape) - 2:
mu = mu[..., None, None]
sigma = sigma[..., None, None]
if reverse:
return x * sigma + mu
else:
return (x - mu) / jnp.where(sigma < _TOLERANCE, 1.0, sigma)
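# Illustrative sketch (hypothetical helper, not used by the model): a NumPy
# round trip through the same formulas `revin` uses, for 2-D inputs [b, n]
# with per-row stats. The forward pass divides by 1.0 instead of sigma when
# sigma < _TOLERANCE, so constant series normalize to zeros rather than NaN.
# The reverse pass still multiplies by the stored sigma, so the round trip is
# exact for constant series but not for series with 0 < sigma < _TOLERANCE.
def _revin_round_trip(x, tolerance=1e-6):
  import numpy as np
  x = np.asarray(x, dtype=np.float64)
  mu = x.mean(axis=-1)[..., None]
  sigma = x.std(axis=-1)[..., None]
  normed = (x - mu) / np.where(sigma < tolerance, 1.0, sigma)
  restored = normed * sigma + mu
  return normed, restored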
================================================
FILE: src/timesfm/timesfm_2p5/timesfm_2p5_base.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM 2p5 base implementation."""
import dataclasses
from typing import Any, Callable, Sequence
import collections
import numpy as np
from .. import configs
ResidualBlockConfig = configs.ResidualBlockConfig
StackedTransformersConfig = configs.StackedTransformersConfig
TransformerConfig = configs.TransformerConfig
ForecastConfig = configs.ForecastConfig
Category = int | str
XRegMode = str
def strip_leading_nans(arr):
"""Removes contiguous NaN values from the beginning of a NumPy array.
Args:
arr: The input NumPy array.
Returns:
A new NumPy array with leading NaN values removed.
If the array is all NaNs or empty, returns an empty array.
"""
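  isnan = np.isnan(arr)
  if np.all(isnan):
    # np.argmax fails on an empty array and returns 0 for an all-NaN one.
    return arr[:0]
  first_valid_index = np.argmax(~isnan)
  return arr[first_valid_index:]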
def linear_interpolation(arr):
  """Performs linear interpolation to fill NaN values in a 1D numpy array.
  Args:
    arr: The 1D numpy array containing NaN values. It may be modified in place.
  Returns:
    The array with NaN values filled using linear interpolation, or the
    original array if it contains no NaN values. If there are no non-NaN values
    to interpolate from, non-finite values are replaced with 0.0.
  """
nans = np.isnan(arr)
if not np.any(nans): # Check if there are any NaNs
return arr
def x(z):
return z.nonzero()[0]
nans_indices = x(nans)
non_nans_indices = x(~nans)
non_nans_values = arr[~nans]
try:
arr[nans] = np.interp(nans_indices, non_nans_indices, non_nans_values)
except ValueError:
    if non_nans_values.size > 0:
mu = np.nanmean(arr)
else:
mu = 0.0
arr = np.where(np.isfinite(arr), arr, mu)
return arr
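# Illustrative sketch (hypothetical helper, not part of the TimesFM API): how
# `TimesFM_2p5.forecast` shapes each cleaned series to the fixed
# `forecast_config.max_context` length before decoding. Long series keep their
# most recent `context` points; short ones are left-padded with zeros, and a
# boolean mask marks the padded positions as True so they are excluded from
# the normalization stats and from attention.
def _pad_and_mask(series, context):
  import numpy as np
  value = np.asarray(series, dtype=np.float64)
  if len(value) >= context:
    return value[-context:], np.zeros(context, dtype=bool)
  num_pad = context - len(value)
  mask = np.array([True] * num_pad + [False] * len(value))
  return np.pad(value, (num_pad, 0), "constant", constant_values=0.0), mask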
@dataclasses.dataclass(frozen=True)
class TimesFM_2p5_200M_Definition:
"""Framework-agnostic config of TimesFM 2.5."""
context_limit = 16384
input_patch_len: int = 32
output_patch_len: int = 128
output_quantile_len: int = 1024
quantiles: list[float] = dataclasses.field(
default_factory=lambda: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
)
decode_index: int = 5
tokenizer: ResidualBlockConfig = ResidualBlockConfig(
input_dims=64,
hidden_dims=1280,
output_dims=1280,
use_bias=True,
activation="swish",
)
stacked_transformers: StackedTransformersConfig = StackedTransformersConfig(
num_layers=20,
transformer=TransformerConfig(
model_dims=1280,
hidden_dims=1280,
num_heads=16,
attention_norm="rms",
feedforward_norm="rms",
qk_norm="rms",
use_bias=False,
use_rotary_position_embeddings=True,
ff_activation="swish",
fuse_qkv=True,
),
)
output_projection_point: ResidualBlockConfig = ResidualBlockConfig(
input_dims=1280,
hidden_dims=1280,
output_dims=1280,
use_bias=False,
activation="swish",
)
output_projection_quantiles: ResidualBlockConfig = ResidualBlockConfig(
input_dims=1280,
hidden_dims=1280,
output_dims=10240,
use_bias=False,
activation="swish",
)
class TimesFM_2p5:
"""Abstract base class for TimesFM models.
Attributes:
forecast_config: Configuration for forecasting flags.
compiled_decode: Compiled decode function.
global_batch_size: Global batch size.
"""
forecast_config: ForecastConfig | None = None
compiled_decode: Callable[..., Any] | None = None
global_batch_size: int = 0
def load_checkpoint(self, path: str):
"""Loads a TimesFM model from a checkpoint."""
raise NotImplementedError()
def compile(self, forecast_config: ForecastConfig | None = None):
"""Compiles the TimesFM model for fast decoding."""
raise NotImplementedError()
def forecast(
self, horizon: int, inputs: list[np.ndarray]
) -> tuple[np.ndarray, np.ndarray]:
"""Forecasts the time series."""
if self.compiled_decode is None:
raise RuntimeError("Model is not compiled. Please call compile() first.")
assert self.global_batch_size > 0
assert self.forecast_config is not None
context = self.forecast_config.max_context
num_inputs = len(inputs)
if (w := num_inputs % self.global_batch_size) != 0:
      # Pad with dummy series; copy first so the caller's list is not mutated.
      inputs = list(inputs) + [np.array([0.0] * 3)] * (self.global_batch_size - w)
output_points = []
output_quantiles = []
values = []
masks = []
idx = 0
for each_input in inputs:
value = linear_interpolation(strip_leading_nans(np.array(each_input)))
if (w := len(value)) >= context:
value = value[-context:]
mask = np.zeros_like(value, dtype=bool)
else:
mask = np.array([True] * (context - w) + [False] * w)
value = np.pad(value, (context - w, 0), "constant", constant_values=0.0)
values.append(value)
masks.append(mask)
idx += 1
if idx == self.global_batch_size:
idx = 0
point_forecast, quantile_forecast = self.compiled_decode(horizon, values, masks)
output_points.append(point_forecast)
output_quantiles.append(quantile_forecast)
values = []
masks = []
output_points = np.concatenate(output_points, axis=0)
output_quantiles = np.concatenate(output_quantiles, axis=0)
return output_points[:num_inputs], output_quantiles[:num_inputs]
def forecast_with_covariates(
self,
inputs: list[Sequence[float]],
dynamic_numerical_covariates: dict[str, Sequence[Sequence[float]]] | None = None,
dynamic_categorical_covariates: (
dict[str, Sequence[Sequence[Category]]] | None
) = None,
static_numerical_covariates: dict[str, Sequence[float]] | None = None,
static_categorical_covariates: dict[str, Sequence[Category]] | None = None,
xreg_mode: XRegMode = "xreg + timesfm",
normalize_xreg_target_per_input: bool = True,
ridge: float = 0.0,
max_rows_per_col: int = 0,
force_on_cpu: bool = False,
):
"""Forecasts on a list of time series with covariates.
To optimize inference speed, avoid string valued categorical covariates.
Args:
      inputs: A list of time series forecast contexts. Each context time series
        should be convertible to a NumPy array by `np.array`.
dynamic_numerical_covariates: A dict of dynamic numerical covariates.
dynamic_categorical_covariates: A dict of dynamic categorical covariates.
static_numerical_covariates: A dict of static numerical covariates.
static_categorical_covariates: A dict of static categorical covariates.
xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "timesfm + xreg"
fits a model on the residuals of the TimesFM forecast. "xreg + timesfm"
fits a model on the targets then forecasts on the residuals via TimesFM.
normalize_xreg_target_per_input: whether to normalize the xreg target per
input in the given batch.
ridge: ridge penalty for the linear model.
max_rows_per_col: max number of rows per column for the linear model.
force_on_cpu: whether to force running on cpu for the linear model.
    Returns:
      A tuple of two lists: the point forecasts and the quantile forecasts for
      each input, combining the TimesFM and in-context linear (XReg) outputs and
      truncated to each input's horizon.
"""
if self.forecast_config is None:
raise ValueError("Model is not compiled. Please call compile() first.")
elif not self.forecast_config.return_backcast:
raise ValueError(
"For XReg, `return_backcast` must be set to True in the forecast config. Please recompile the model."
)
from ..utils import xreg_lib
# Verify and bookkeep covariates.
if not (
dynamic_numerical_covariates
or dynamic_categorical_covariates
or static_numerical_covariates
or static_categorical_covariates
):
raise ValueError(
"At least one of dynamic_numerical_covariates,"
" dynamic_categorical_covariates, static_numerical_covariates,"
" static_categorical_covariates must be set."
)
# Track the lengths of (1) each input, (2) the part that can be used in the
# linear model, and (3) the horizon.
input_lens, train_lens, test_lens = [], [], []
for i, input_ts in enumerate(inputs):
input_len = len(input_ts)
input_lens.append(input_len)
if xreg_mode == "timesfm + xreg":
# For fitting residuals, no TimesFM forecast on the first patch.
train_lens.append(max(0, input_len - self.model.p))
elif xreg_mode == "xreg + timesfm":
train_lens.append(input_len)
else:
raise ValueError(f"Unsupported mode: {xreg_mode}")
if dynamic_numerical_covariates:
test_lens.append(
len(list(dynamic_numerical_covariates.values())[0][i]) - input_len
)
elif dynamic_categorical_covariates:
test_lens.append(
len(list(dynamic_categorical_covariates.values())[0][i]) - input_len
)
else:
test_lens.append(self.forecast_config.max_horizon)
if test_lens[-1] > self.forecast_config.max_horizon:
raise ValueError(
            "Forecast horizon length inferred from the dynamic covariates is longer than the"
            f" max_horizon defined in the forecast config: {test_lens[-1]} > {self.forecast_config.max_horizon=}."
)
# Prepare the covariates into train and test.
train_dynamic_numerical_covariates = collections.defaultdict(list)
test_dynamic_numerical_covariates = collections.defaultdict(list)
train_dynamic_categorical_covariates = collections.defaultdict(list)
test_dynamic_categorical_covariates = collections.defaultdict(list)
for covariates, train_covariates, test_covariates in (
(
dynamic_numerical_covariates,
train_dynamic_numerical_covariates,
test_dynamic_numerical_covariates,
),
(
dynamic_categorical_covariates,
train_dynamic_categorical_covariates,
test_dynamic_categorical_covariates,
),
):
if not covariates:
continue
for covariate_name, covariate_values in covariates.items():
for input_len, train_len, covariate_value in zip(
input_lens, train_lens, covariate_values
):
train_covariates[covariate_name].append(
covariate_value[(input_len - train_len) : input_len]
)
test_covariates[covariate_name].append(covariate_value[input_len:])
# Fit models.
if xreg_mode == "timesfm + xreg":
# Forecast via TimesFM then fit a model on the residuals.
point_outputs, quantile_outputs = self.forecast(
horizon=self.forecast_config.max_horizon, inputs=inputs
)
targets = [
(
np.array(input_ts)[-train_len:]
- point_output[: -self.forecast_config.max_horizon][-train_len:]
)
for input_ts, point_output, train_len in zip(inputs, point_outputs, train_lens)
]
per_instance_stats = None
if normalize_xreg_target_per_input:
targets, per_instance_stats = xreg_lib.normalize(targets)
xregs = xreg_lib.BatchedInContextXRegLinear(
targets=targets,
train_lens=train_lens,
test_lens=test_lens,
train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
train_dynamic_categorical_covariates=train_dynamic_categorical_covariates,
test_dynamic_categorical_covariates=test_dynamic_categorical_covariates,
static_numerical_covariates=static_numerical_covariates,
static_categorical_covariates=static_categorical_covariates,
).fit(
ridge=ridge,
one_hot_encoder_drop=None if ridge > 0 else "first",
max_rows_per_col=max_rows_per_col,
force_on_cpu=force_on_cpu,
debug_info=False,
assert_covariates=True,
assert_covariate_shapes=True,
)
if normalize_xreg_target_per_input:
xregs = xreg_lib.renormalize(xregs, per_instance_stats)
xregs = np.array(xregs)
new_point_outputs = [
(point_output[-self.forecast_config.max_horizon :][:test_len] + xreg)
for point_output, test_len, xreg in zip(point_outputs, test_lens, xregs)
]
new_quantile_outputs = [
(
quantile_output[-self.forecast_config.max_horizon :][:test_len]
+ xreg[..., None]
)
for quantile_output, test_len, xreg in zip(quantile_outputs, test_lens, xregs)
]
else:
# Fit a model on the targets then forecast on the residuals via TimesFM.
targets = [
np.array(input_ts)[-train_len:]
for input_ts, train_len in zip(inputs, train_lens)
]
per_instance_stats = None
if normalize_xreg_target_per_input:
targets, per_instance_stats = xreg_lib.normalize(targets)
xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear(
targets=targets,
train_lens=train_lens,
test_lens=test_lens,
train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
train_dynamic_categorical_covariates=train_dynamic_categorical_covariates,
test_dynamic_categorical_covariates=test_dynamic_categorical_covariates,
static_numerical_covariates=static_numerical_covariates,
static_categorical_covariates=static_categorical_covariates,
).fit(
ridge=ridge,
one_hot_encoder_drop=None if ridge > 0 else "first",
max_rows_per_col=max_rows_per_col,
force_on_cpu=force_on_cpu,
debug_info=True,
assert_covariates=True,
assert_covariate_shapes=True,
)
point_outputs, quantile_outputs = self.forecast(
horizon=self.forecast_config.max_horizon,
inputs=[
target - xreg_on_context
for target, xreg_on_context in zip(targets, xregs_on_context)
],
)
new_point_outputs = [
(point_output[-self.forecast_config.max_horizon :][:test_len] + xreg)
for point_output, test_len, xreg in zip(point_outputs, test_lens, xregs)
]
new_quantile_outputs = [
(
quantile_output[-self.forecast_config.max_horizon :][:test_len]
+ xreg[..., None]
)
for quantile_output, test_len, xreg in zip(quantile_outputs, test_lens, xregs)
]
if normalize_xreg_target_per_input:
new_point_outputs = xreg_lib.renormalize(new_point_outputs, per_instance_stats)
new_quantile_outputs = xreg_lib.renormalize(
new_quantile_outputs, per_instance_stats
)
return new_point_outputs, new_quantile_outputs
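# Illustrative sketch (hypothetical helper, not part of the TimesFM API): how
# `forecast_with_covariates` splits one dynamic covariate. Each dynamic
# covariate covers the context plus the horizon, so the horizon length is
# inferred as len(covariate) - len(input). The last `train_len` context steps
# are used to fit the linear model, and the steps after the context are used
# to predict. In "xreg + timesfm" mode train_len equals the input length; in
# "timesfm + xreg" mode it is the input length minus one input patch (at
# least 0).
def _split_dynamic_covariate(covariate_value, input_len, train_len):
  train = list(covariate_value[input_len - train_len : input_len])
  test = list(covariate_value[input_len:])
  return train, test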
================================================
FILE: src/timesfm/timesfm_2p5/timesfm_2p5_flax.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM models in Flax."""
import dataclasses
import functools
import gc
import logging
import math
import os
from pathlib import Path
from typing import Any, Callable, Dict
import einshape
from flax import nnx
import huggingface_hub
import jax
import jax.numpy as jnp
import jaxtyping
import numpy as np
import orbax.checkpoint as ocp
from .. import configs
from ..flax import dense, transformer, util
from . import timesfm_2p5_base
jax_einshape = einshape.jax_einshape
scan = util.scan_along_axis
revin = util.revin
Float = jaxtyping.Float
Bool = jaxtyping.Bool
Array = jaxtyping.Array
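def try_gc():
  """Runs `gc.collect()` if any local device uses over 75% of its memory.
  Does nothing if the backend does not report memory stats.
  """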
for d in jax.local_devices():
stats = d.memory_stats()
if stats is None:
return
if stats["bytes_in_use"] / stats["bytes_limit"] > 0.75:
gc.collect()
break
@nnx.vmap(in_axes=(None, 0), out_axes=0)
def _create_stacked_transformers(
config: configs.StackedTransformersConfig, key: jax.Array
):
return transformer.Transformer(config.transformer, rngs=nnx.Rngs(key))
def _scan_along_axis(f, init, xs, axis: int, **kwargs):
"""Scans along an axis."""
moved_xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), xs)
carry, moved_ys = jax.lax.scan(f, init, moved_xs, **kwargs)
return (
carry,
jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), moved_ys),
)
@nnx.scan(in_axes=(0, nnx.Carry, None, 0), out_axes=(nnx.Carry, 0))
def _apply_stacked_transformers(
model: transformer.Transformer,
x: Float[Array, "b n d"],
    m: Bool[Array, "b n"],
    decode_cache: util.DecodeCache | None = None,
) -> tuple[Float[Array, "b n d"], util.DecodeCache | None]:
return model(x, m, decode_cache=decode_cache)
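# Worked example (hypothetical helper, not used by the model): the cache
# sizing in `TimesFM_2p5_200M_flax_module.decode`. The prefill pass already
# produces one output patch of `o` = 128 steps, so only (horizon - 1) // o
# further autoregressive steps are needed; each one appends o // p = 4 new
# input patches to every layer's KV cache.
def _decode_plan(context, horizon, p=32, o=128):
  num_decode_steps = (horizon - 1) // o
  num_input_patches = context // p
  decode_cache_size = num_input_patches + num_decode_steps * (o // p)
  return num_decode_steps, decode_cache_size
# For example, context=1024 and horizon=512 need 3 extra steps and a cache of
# 32 + 3 * 4 = 44 patches.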
class TimesFM_2p5_200M_flax_module(nnx.Module): # pylint: disable=invalid-name
"""TimesFM 2.5 with 200M parameters."""
config = timesfm_2p5_base.TimesFM_2p5_200M_Definition()
decode_index: int = 5
compiled_decode: Callable[..., Any] | None = None
backend: str = ""
context: int = 0
horizon: int = 0
per_core_batch_size: int = 0
def __init__(self):
super().__init__()
self.backend = jax.devices()[0].platform
self.num_devices = len(jax.devices(self.backend))
    # Named constants.
self.p = self.config.input_patch_len # 32
self.o = self.config.output_patch_len # 128
self.os = self.config.output_quantile_len # 1024
self.m = self.o // self.p # 4
self.x = self.config.stacked_transformers.num_layers # 20
self.h = self.config.stacked_transformers.transformer.num_heads # 16
self.md = self.config.stacked_transformers.transformer.model_dims # 1280
self.hd = self.md // self.h # 80
self.q = len(self.config.quantiles) + 1 # 10
self.aridx = self.config.decode_index # 5
# Layers.
self.tokenizer = dense.ResidualBlock(self.config.tokenizer)
self.stacked_xf = _create_stacked_transformers(
self.config.stacked_transformers,
jax.random.split(jax.random.key(42), self.x),
)
self.output_projection_point = dense.ResidualBlock(
self.config.output_projection_point
)
self.output_projection_quantiles = dense.ResidualBlock(
self.config.output_projection_quantiles
)
def __call__(
self,
inputs: Float[Array, "b n p"],
masks: Bool[Array, "b n p"],
decode_cache: util.DecodeCache | None = None,
):
tokenizer_inputs = jnp.concatenate([inputs, masks.astype(inputs.dtype)], axis=-1)
input_embeddings = self.tokenizer(tokenizer_inputs)
if decode_cache is None:
decode_cache = [None] * self.x
output_embeddings, decode_cache = _apply_stacked_transformers(
self.stacked_xf, input_embeddings, masks[..., -1], decode_cache
)
output_ts = self.output_projection_point(output_embeddings)
output_quantile_spread = self.output_projection_quantiles(output_embeddings)
return (
input_embeddings,
output_embeddings,
output_ts,
output_quantile_spread,
), decode_cache
@nnx.jit(static_argnames=("horizon",))
def decode(self, horizon: int, inputs, masks):
batch_size, context = inputs.shape[0], inputs.shape[1]
num_decode_steps = (horizon - 1) // self.o
num_input_patches = context // self.p
decode_cache_size = num_input_patches + num_decode_steps * self.m
# Prefill
patched_inputs = jax_einshape("b(np)->bnp", inputs, b=batch_size, p=self.p)
patched_masks = jax_einshape("b(np)->bnp", masks, b=batch_size, p=self.p)
(last_n, last_mu, last_sigma), (_, context_mu, context_sigma) = scan(
lambda carry, xs: util.update_running_stats(*carry, *xs),
init=(zero := jnp.zeros(shape=(batch_size)), zero, zero),
xs=(patched_inputs, patched_masks),
axis=1,
)
decode_cache = util.DecodeCache(
next_index=jnp.zeros(shape=(self.x, batch_size), dtype=jnp.int32),
num_masked=jnp.zeros(shape=(self.x, batch_size), dtype=jnp.int32),
key=jnp.zeros(shape=(self.x, batch_size, decode_cache_size, self.h, self.hd)),
value=jnp.zeros(shape=(self.x, batch_size, decode_cache_size, self.h, self.hd)),
)
normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
normed_inputs = jnp.where(patched_masks, 0.0, normed_inputs)
(_, _, normed_outputs, normed_quantile_spread), decode_cache = self(
normed_inputs, patched_masks, decode_cache
)
renormed_outputs = jax_einshape(
"bn(oq)->bnoq",
revin(normed_outputs, context_mu, context_sigma, reverse=True),
o=self.o,
q=self.q,
)
renormed_quantile_spread = jax_einshape(
"bn(oq)->bnoq",
revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
o=self.os,
q=self.q,
)[:, -1, ...]
# Autoregressive decode
@nnx.scan(in_axes=(None, nnx.Carry, 0), out_axes=(nnx.Carry, 1))
def _ar_decode(module, carry, unused_iter):
last_renormed_output, (last_n, last_mu, last_sigma), decode_cache = carry
new_patched_input = jax_einshape(
"b(mp)->bmp", last_renormed_output, m=module.m, p=module.p
)
new_mask = jnp.zeros_like(new_patched_input, dtype=jnp.bool)
carry_stats, (_, new_mu, new_sigma) = scan(
lambda carry, xs: util.update_running_stats(*carry, *xs),
init=(last_n, last_mu, last_sigma),
xs=(new_patched_input, new_mask),
axis=1,
)
new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
(_, _, new_normed_output, _), decode_cache = module(
new_normed_input, new_mask, decode_cache
)
new_renormed_output = jax_einshape(
"bm(oq)->bmoq",
revin(new_normed_output, new_mu, new_sigma, reverse=True),
o=module.o,
q=module.q,
)[..., -1, :, :]
return (
(
new_renormed_output[..., module.decode_index],
carry_stats,
decode_cache,
),
new_renormed_output,
)
if num_decode_steps > 0:
_, ar_renormed_outputs = _ar_decode(
self,
(
renormed_outputs[..., -1, :, self.decode_index],
(last_n, last_mu, last_sigma),
decode_cache,
),
jnp.arange(num_decode_steps),
)
else:
ar_renormed_outputs = None
return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
def compile(
self,
context: int,
horizon: int,
per_core_batch_size: int = 1,
):
if context % self.p != 0:
logging.info(
"When compiling, context needs to be multiple of the patch size %d."
" Modifying context to %d.",
self.p,
context := math.ceil(context / self.p) * self.p,
)
if horizon % self.o != 0:
logging.info(
"When compiling, horizon needs to be multiple of the output patch"
" size %d. Modifying horizon to %d.",
self.o,
horizon := math.ceil(horizon / self.o) * self.o,
)
self.context = context
self.horizon = horizon
self.per_core_batch_size = per_core_batch_size
@nnx.pmap(
in_axes=(None, None, 0, 0),
out_axes=(0, 0, 0),
devices=jax.devices(self.backend),
axis_size=self.num_devices,
static_broadcasted_argnums=(1,),
axis_name="global_batch",
)
def compiled_decode_kernel(model, horizon, inputs, masks):
return model.decode(horizon, inputs, masks)
self.compiled_decode = functools.partial(compiled_decode_kernel, self)
def _flip_quantile_fn(x):
return jnp.concatenate([x[..., :1], jnp.flip(x[..., 1:], axis=-1)], axis=-1)
@functools.partial(
jax.jit,
donate_argnums=(0, 1, 2),
)
def _force_flip_invariance_fn(
flipped_pf_outputs,
flipped_quantile_spreads,
flipped_ar_outputs,
):
"""Forces flip invariance."""
flipped_pf_outputs = _flip_quantile_fn(flipped_pf_outputs)
flipped_pf_outputs = jax_einshape("tb...->(tb)...", flipped_pf_outputs)
flipped_quantile_spreads = _flip_quantile_fn(flipped_quantile_spreads)
flipped_quantile_spreads = jax_einshape("tb...->(tb)...", flipped_quantile_spreads)
to_concat = [flipped_pf_outputs[:, -1, ...]]
if flipped_ar_outputs is not None:
flipped_ar_outputs = _flip_quantile_fn(flipped_ar_outputs)
flipped_ar_outputs = jax_einshape("tbno...->(tb)(no)...", flipped_ar_outputs)
to_concat.append(flipped_ar_outputs)
flipped_full_forecast = jnp.concatenate(to_concat, axis=1)
return flipped_quantile_spreads, flipped_pf_outputs, flipped_full_forecast
@functools.partial(
jax.jit,
static_argnames=("max_horizon",),
donate_argnums=(0,),
)
def _use_continuous_quantile_head_fn(full_forecast, quantile_spreads, max_horizon):
"""Uses continuous quantile head."""
to_stack = [full_forecast[..., :max_horizon, 0]]
for quantile_index in [1, 2, 3, 4]:
to_stack.append(
quantile_spreads[:, :max_horizon, quantile_index]
- quantile_spreads[:, :max_horizon, 5]
+ full_forecast[:, :max_horizon, 5]
)
to_stack.append(full_forecast[..., :max_horizon, 5])
for quantile_index in [6, 7, 8, 9]:
to_stack.append(
quantile_spreads[:, :max_horizon, quantile_index]
- quantile_spreads[:, :max_horizon, 5]
+ full_forecast[:, :max_horizon, 5]
)
return jnp.stack(to_stack, axis=-1)
@functools.partial(jax.jit, donate_argnums=(0,))
def _fix_quantile_crossing_fn(full_forecast):
"""Fixes quantile crossing."""
lower_quantiles = _scan_along_axis(
lambda carry, x: (w := jnp.minimum(carry, x), w),
init=full_forecast[..., 5],
xs=full_forecast[..., 1:5],
axis=-1,
reverse=True,
)[1]
upper_quantiles = _scan_along_axis(
lambda carry, x: (w := jnp.maximum(carry, x), w),
init=full_forecast[..., 5],
xs=full_forecast[..., 6:10],
axis=-1,
reverse=False,
)[1]
return jnp.concatenate(
[
full_forecast[..., :1],
lower_quantiles,
full_forecast[..., 5:6],
upper_quantiles,
],
axis=-1,
)
@functools.partial(jax.jit, static_argnames=("fc",), donate_argnums=(1, 2))
def _before_model_decode(fc, inputs, masks):
"""All Jax steps before model decode call."""
if fc.infer_is_positive:
is_positive = jnp.all(inputs >= 0, axis=-1, keepdims=True)
else:
is_positive = None
if fc.normalize_inputs:
mu = jnp.mean(inputs, axis=-1, keepdims=True)
sigma = jnp.std(inputs, axis=-1, keepdims=True)
inputs = revin(inputs, mu, sigma, reverse=False)
else:
mu, sigma = None, None
inputs = jax_einshape("(tb)...->tb...", inputs, b=fc.per_core_batch_size)
masks = jax_einshape("(tb)...->tb...", masks, b=fc.per_core_batch_size)
return inputs, masks, is_positive, mu, sigma
@functools.partial(
jax.jit,
static_argnames=(
"fc",
"p",
),
donate_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9),
)
def _after_model_decode(
fc,
pf_outputs,
quantile_spreads,
ar_outputs,
flipped_pf_outputs,
flipped_quantile_spreads,
flipped_ar_outputs,
is_positive,
mu,
sigma,
p,
):
"""All Jax steps after model decode call."""
# t: num_devices, b: per_core_batch_size
pf_outputs = jax_einshape("tb...->(tb)...", pf_outputs)
quantile_spreads = jax_einshape("tb...->(tb)...", quantile_spreads)
to_concat = [pf_outputs[:, -1, ...]]
if ar_outputs is not None:
ar_outputs = jax_einshape("tbno...->(tb)(no)...", ar_outputs)
to_concat.append(ar_outputs)
full_forecast = jnp.concatenate(to_concat, axis=1)
if fc.force_flip_invariance:
(
flipped_quantile_spreads,
flipped_pf_outputs,
flipped_full_forecast,
) = _force_flip_invariance_fn(
flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs
)
quantile_spreads = (quantile_spreads - flipped_quantile_spreads) / 2
pf_outputs = (pf_outputs - flipped_pf_outputs) / 2
full_forecast = (full_forecast - flipped_full_forecast) / 2
if fc.use_continuous_quantile_head:
full_forecast = _use_continuous_quantile_head_fn(
full_forecast, quantile_spreads, fc.max_horizon
)
if fc.return_backcast:
full_backcast = jax_einshape("...npq->...(np)q", pf_outputs[:, :-1, :p, :])
full_forecast = jnp.concatenate([full_backcast, full_forecast], axis=1)
if fc.fix_quantile_crossing:
full_forecast = _fix_quantile_crossing_fn(full_forecast)
if fc.normalize_inputs:
full_forecast = revin(full_forecast, mu, sigma, reverse=True)
if is_positive is not None:
full_forecast = jnp.where(
is_positive[..., None],
jnp.maximum(full_forecast, jnp.zeros_like(full_forecast)),
full_forecast,
)
return full_forecast
class TimesFM_2p5_200M_flax(timesfm_2p5_base.TimesFM_2p5):
"""Flax implementation of TimesFM 2.5 with 200M parameters."""
model: nnx.Module = TimesFM_2p5_200M_flax_module()
@classmethod
def from_pretrained(
cls,
model_id: str = "google/timesfm-2.5-200m-flax",
*,
revision: str | None = None,
cache_dir: str | Path | None = None,
force_download: bool = False,
proxies: Dict | None = None,
resume_download: bool | None = None,
local_files_only: bool | None = None,
token: str | None = None,
**model_kwargs,
):
"""Loads a Flax TimesFM model."""
# Create an instance of the model wrapper class.
instance = cls(**model_kwargs)
# Determine the path to the model weights.
model_file_path = ""
if os.path.isdir(model_id):
logging.info("Loading checkpoint from local directory: %s", model_id)
model_file_path = model_id
else:
logging.info("Downloading checkpoint from Hugging Face repo %s", model_id)
model_file_path = huggingface_hub.snapshot_download(
repo_id=model_id,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
logging.info("Loading checkpoint from: %s", model_file_path)
checkpointer = ocp.StandardCheckpointer()
graph, state = nnx.split(instance.model)
state = checkpointer.restore(model_file_path, state)
instance.model = nnx.merge(graph, state)
return instance
def compile(
self,
forecast_config: configs.ForecastConfig,
dryrun: bool = True,
**kwargs
):
# Short alias for the forecast config, used during validation.
print("Compiling model...")
fc = forecast_config
if fc.max_context % self.model.p != 0:
logging.info(
"When compiling, max context needs to be multiple of the patch size"
" %d. Using max context = %d instead.",
self.model.p,
new_context := math.ceil(fc.max_context / self.model.p) * self.model.p,
)
fc = dataclasses.replace(fc, max_context=new_context)
if fc.max_horizon % self.model.o != 0:
logging.info(
"When compiling, max horizon needs to be multiple of the output patch"
" size %d. Using max horizon = %d instead.",
self.model.o,
new_horizon := math.ceil(fc.max_horizon / self.model.o) * self.model.o,
)
fc = dataclasses.replace(fc, max_horizon=new_horizon)
if fc.max_context + fc.max_horizon > self.model.config.context_limit:
raise ValueError(
"Context + horizon must not exceed the context limit."
f" {fc.max_context} + {fc.max_horizon} >"
f" {self.model.config.context_limit}."
)
if fc.use_continuous_quantile_head and (fc.max_horizon > self.model.os):
raise ValueError(
f"Continuous quantile head is not supported for horizons > {self.model.os}."
)
self.forecast_config = fc
self.model.compile(
context=self.forecast_config.max_context,
horizon=self.forecast_config.max_horizon,
per_core_batch_size=fc.per_core_batch_size,
)
self.per_core_batch_size = self.forecast_config.per_core_batch_size
self.num_devices = self.model.num_devices
self.global_batch_size = (
self.forecast_config.per_core_batch_size * self.model.num_devices
)
def compiled_decode_kernel(fc, horizon, inputs, masks):
inputs = jnp.array(inputs, dtype=jnp.float32)
masks = jnp.array(masks, dtype=jnp.bool)
if horizon > fc.max_horizon:
raise ValueError(
f"Horizon must not exceed the max horizon. {horizon} > {fc.max_horizon}."
)
to_trim = fc.max_horizon - horizon
inputs, masks, is_positive, mu, sigma = _before_model_decode(fc, inputs, masks)
pf_outputs, quantile_spreads, ar_outputs = self.model.compiled_decode(
fc.max_horizon, inputs, masks
)
if fc.force_flip_invariance:
flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
self.model.compiled_decode(fc.max_horizon, -inputs, masks)
)
else:
flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
None,
None,
None,
)
full_forecast = _after_model_decode(
fc,
pf_outputs,
quantile_spreads,
ar_outputs,
flipped_pf_outputs,
flipped_quantile_spreads,
flipped_ar_outputs,
is_positive,
mu,
sigma,
self.model.p,
)
full_forecast_np = np.array(full_forecast)
del full_forecast
try_gc()
if to_trim > 0:
full_forecast_np = full_forecast_np[..., :-to_trim, :]
return full_forecast_np[..., 5], full_forecast_np
self.compiled_decode = functools.partial(
compiled_decode_kernel, self.forecast_config
)
if dryrun:
_ = self.compiled_decode(
self.forecast_config.max_horizon,
jnp.zeros(
(self.global_batch_size, self.forecast_config.max_context), dtype=jnp.float32
),
jnp.zeros(
(self.global_batch_size, self.forecast_config.max_context), dtype=jnp.bool
),
)
print("Compiling done.")
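A minimal NumPy sketch of the monotonicity fix performed by `_fix_quantile_crossing_fn`, assuming (as the indexing above suggests) that channel 0 is the point forecast, channel 5 is the median, and channels 1-9 are quantiles in increasing order. `fix_quantile_crossing` is a hypothetical helper for illustration, not part of the TimesFM API. Walking outward from the median, each lower quantile is capped by a running minimum and each upper quantile is floored by a running maximum, which is what the reverse and forward `_scan_along_axis` calls compute.

```python
import numpy as np


def fix_quantile_crossing(forecast: np.ndarray) -> np.ndarray:
    """Forces quantile channels 1-9 to be non-decreasing around the median.

    Assumes the last axis has 10 channels: 0 = point forecast,
    1-4 = lower quantiles, 5 = median, 6-9 = upper quantiles.
    """
    median = forecast[..., 5:6]
    # Lower side: visit q4, q3, q2, q1 and keep a running minimum from the median.
    lower_outward = np.concatenate([median, forecast[..., 4:0:-1]], axis=-1)
    lower = np.minimum.accumulate(lower_outward, axis=-1)[..., 1:][..., ::-1]
    # Upper side: visit q6, ..., q9 and keep a running maximum from the median.
    upper_outward = np.concatenate([median, forecast[..., 6:10]], axis=-1)
    upper = np.maximum.accumulate(upper_outward, axis=-1)[..., 1:]
    return np.concatenate([forecast[..., :1], lower, median, upper], axis=-1)


# One time step with crossings at q2/q3, at q6 (below the median) and at q7/q8.
raw = np.array([[0.0, 1.0, 3.0, 2.0, 4.0, 5.0, 4.0, 7.0, 6.0, 9.0]])
print(fix_quantile_crossing(raw))
# [[0. 1. 2. 2. 4. 5. 5. 7. 7. 9.]]
```

The point forecast in channel 0 is passed through untouched; only the quantile channels are reordered relative to the median.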
================================================
FILE: src/timesfm/timesfm_2p5/timesfm_2p5_torch.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM models."""
import dataclasses
import logging
import math
import os
from pathlib import Path
from typing import Optional, Sequence, Union
import numpy as np
import torch
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from safetensors.torch import load_file, save_file
from torch import nn
from .. import configs
from ..torch import dense, transformer, util
from . import timesfm_2p5_base
revin = util.revin
class TimesFM_2p5_200M_torch_module(nn.Module):
"""TimesFM 2.5 with 200M parameters."""
config = timesfm_2p5_base.TimesFM_2p5_200M_Definition()
def __init__(self):
super().__init__()
# Names constants.
self.p = self.config.input_patch_len # 32
self.o = self.config.output_patch_len # 128
self.os = self.config.output_quantile_len # 1024
self.m = self.o // self.p # 4
self.x = self.config.stacked_transformers.num_layers # 20
self.h = self.config.stacked_transformers.transformer.num_heads # 16
self.md = self.config.stacked_transformers.transformer.model_dims # 1280
self.hd = self.md // self.h # 80
self.q = len(self.config.quantiles) + 1 # 10
self.aridx = self.config.decode_index # 5
# Layers.
self.tokenizer = dense.ResidualBlock(self.config.tokenizer)
self.stacked_xf = nn.ModuleList(
[
transformer.Transformer(self.config.stacked_transformers.transformer)
for _ in range(self.x)
]
)
self.output_projection_point = dense.ResidualBlock(
self.config.output_projection_point
)
self.output_projection_quantiles = dense.ResidualBlock(
self.config.output_projection_quantiles
)
# Device.
if torch.cuda.is_available():
self.device = torch.device("cuda:0")
self.device_count = torch.cuda.device_count()
else:
self.device = torch.device("cpu")
self.device_count = 1
def load_checkpoint(self, path: str, **kwargs):
"""Loads a PyTorch TimesFM model from a checkpoint."""
tensors = load_file(path)
self.load_state_dict(tensors, strict=True)
self.to(self.device)
torch_compile = kwargs.get("torch_compile", True)
if torch_compile:
print("Compiling model...")
# Rebinding the local name `self` would discard the compiled module;
# compile this module's forward in place instead (requires PyTorch >= 2.2).
self.compile()
self.eval()
def forward(
self,
inputs: torch.Tensor,
masks: torch.Tensor,
decode_caches: list[util.DecodeCache] | None = None,
):
tokenizer_inputs = torch.cat([inputs, masks.to(inputs.dtype)], dim=-1)
input_embeddings = self.tokenizer(tokenizer_inputs)
if decode_caches is None:
decode_caches = [None] * self.x
output_embeddings = input_embeddings
new_decode_caches = []
for i, layer in enumerate(self.stacked_xf):
output_embeddings, new_cache = layer(
output_embeddings, masks[..., -1], decode_caches[i]
)
new_decode_caches.append(new_cache)
output_ts = self.output_projection_point(output_embeddings)
output_quantile_spread = self.output_projection_quantiles(output_embeddings)
return (
input_embeddings,
output_embeddings,
output_ts,
output_quantile_spread,
), new_decode_caches
def decode(self, horizon: int, inputs, masks):
"""Decodes the time series."""
with torch.no_grad():
batch_size, context = inputs.shape[0], inputs.shape[1]
num_decode_steps = (horizon - 1) // self.o
num_input_patches = context // self.p
decode_cache_size = num_input_patches + num_decode_steps * self.m
# Prefill
patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
patched_masks = torch.reshape(masks, (batch_size, -1, self.p))
# running stats
n = torch.zeros(batch_size, device=inputs.device)
mu = torch.zeros(batch_size, device=inputs.device)
sigma = torch.zeros(batch_size, device=inputs.device)
patch_mu = []
patch_sigma = []
for i in range(num_input_patches):
(n, mu, sigma), _ = util.update_running_stats(
n, mu, sigma, patched_inputs[:, i], patched_masks[:, i]
)
patch_mu.append(mu)
patch_sigma.append(sigma)
last_n, last_mu, last_sigma = n, mu, sigma
context_mu = torch.stack(patch_mu, dim=1)
context_sigma = torch.stack(patch_sigma, dim=1)
decode_caches = [
util.DecodeCache(
next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
key=torch.zeros(
batch_size,
decode_cache_size,
self.h,
self.hd,
device=inputs.device,
),
value=torch.zeros(
batch_size,
decode_cache_size,
self.h,
self.hd,
device=inputs.device,
),
)
for _ in range(self.x)
]
normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
normed_inputs = torch.where(patched_masks, 0.0, normed_inputs)
(_, _, normed_outputs, normed_quantile_spread), decode_caches = self(
normed_inputs, patched_masks, decode_caches
)
renormed_outputs = torch.reshape(
revin(normed_outputs, context_mu, context_sigma, reverse=True),
(batch_size, -1, self.o, self.q),
)
renormed_quantile_spread = torch.reshape(
revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
(batch_size, -1, self.os, self.q),
)[:, -1, ...]
# Autoregressive decode
ar_outputs = []
last_renormed_output = renormed_outputs[:, -1, :, self.aridx]
for _ in range(num_decode_steps):
new_patched_input = torch.reshape(
last_renormed_output, (batch_size, self.m, self.p)
)
new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)
n, mu, sigma = last_n, last_mu, last_sigma
new_mus, new_sigmas = [], []
for i in range(self.m):
(n, mu, sigma), _ = util.update_running_stats(
n, mu, sigma, new_patched_input[:, i], new_mask[:, i]
)
new_mus.append(mu)
new_sigmas.append(sigma)
last_n, last_mu, last_sigma = n, mu, sigma
new_mu = torch.stack(new_mus, dim=1)
new_sigma = torch.stack(new_sigmas, dim=1)
new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
(_, _, new_normed_output, _), decode_caches = self(
new_normed_input, new_mask, decode_caches
)
new_renormed_output = torch.reshape(
revin(new_normed_output, new_mu, new_sigma, reverse=True),
(batch_size, self.m, self.o, self.q),
)
ar_outputs.append(new_renormed_output[:, -1, ...])
last_renormed_output = new_renormed_output[:, -1, :, self.aridx]
if num_decode_steps > 0:
ar_renormed_outputs = torch.stack(ar_outputs, dim=1)
else:
ar_renormed_outputs = None
return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
def forecast_naive(
self, horizon: int, inputs: Sequence[np.ndarray]
) -> list[np.ndarray]:
"""Forecasts the time series.
This is a naive implementation for debugging purposes. No forecasting
flags are used here. Forecasting quality can be subpar.
Args:
horizon: The number of time points to forecast.
inputs: A sequence of numpy arrays, each representing a time series to
query forecast for.
Returns:
A list of numpy arrays of forecasts.
"""
outputs = []
for each_input in inputs:
input_t = torch.tensor(each_input, dtype=torch.float32)
mask = torch.zeros_like(input_t, dtype=torch.bool)
len_front_mask = self.p - (len(each_input) % self.p)
if len_front_mask < self.p:
input_t = torch.cat(
[torch.zeros(len_front_mask, dtype=torch.float32), input_t], dim=0
)
mask = torch.cat([torch.ones(len_front_mask, dtype=torch.bool), mask], dim=0)
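# Add a batch axis and move to the model's device.
input_t = input_t[None, ...].to(self.device)
mask = mask[None, ...].to(self.device)
t_pf, _, t_ar = self.decode(horizon, input_t, mask)
to_concat = [t_pf[:, -1, ...]]
if t_ar is not None:
to_concat.append(t_ar.reshape(1, -1, self.q))
# Trim along the time axis (dim 1), not the quantile axis.
torch_forecast = torch.cat(to_concat, dim=1)[:, :horizon, :]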
torch_forecast = torch_forecast.squeeze(0)
outputs.append(torch_forecast.detach().cpu().numpy())
return outputs
class TimesFM_2p5_200M_torch(
timesfm_2p5_base.TimesFM_2p5,
PyTorchModelHubMixin,
library_name="timesfm",
repo_url="https://github.com/google-research/timesfm",
paper_url="https://arxiv.org/abs/2310.10688",
docs_url="https://github.com/google-research/timesfm",
license="apache-2.0",
pipeline_tag="time-series-forecasting",
tags=["pytorch", "timeseries", "forecasting", "timesfm-2.5"],
):
"""PyTorch implementation of TimesFM 2.5 with 200M parameters."""
DEFAULT_REPO_ID = "google/timesfm-2.5-200m-pytorch"
WEIGHTS_FILENAME = "model.safetensors"
def __init__(
self,
torch_compile: bool = True,
config: Optional[dict] = None,
):
self.model = TimesFM_2p5_200M_torch_module()
self.torch_compile = torch_compile
if config is not None:
self._hub_mixin_config = config
@classmethod
def _from_pretrained(
cls,
*,
model_id: str = DEFAULT_REPO_ID,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool = False,
local_files_only: bool,
token: Optional[Union[str, bool]],
config: Optional[dict] = None,
**model_kwargs,
):
"""
Loads a PyTorch safetensors TimesFM model from a local path or the Hugging
Face Hub. This method is the backend for the `from_pretrained` class
method provided by `PyTorchModelHubMixin`.
"""
# Determine the path to the model weights.
model_file_path = ""
if os.path.isdir(model_id):
logging.info("Loading checkpoint from local directory: %s", model_id)
model_file_path = os.path.join(model_id, cls.WEIGHTS_FILENAME)
if not os.path.exists(model_file_path):
raise FileNotFoundError(
f"{cls.WEIGHTS_FILENAME} not found in directory {model_id}"
)
else:
logging.info("Downloading checkpoint from Hugging Face repo %s", model_id)
model_file_path = hf_hub_download(
repo_id=model_id,
filename=cls.WEIGHTS_FILENAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
token=token,
local_files_only=local_files_only,
)
# Create an instance of the model wrapper class.
instance = cls(config=config, **model_kwargs)
logging.info("Loading checkpoint from: %s", model_file_path)
# Load the weights into the model.
instance.model.load_checkpoint(
model_file_path, torch_compile=instance.torch_compile
)
return instance
def _save_pretrained(self, save_directory: Union[str, Path]):
"""
Saves the model's state dictionary to a safetensors file. This method
is called by the `save_pretrained` method from `PyTorchModelHubMixin`.
"""
if not os.path.exists(save_directory):
os.makedirs(save_directory)
weights_path = os.path.join(save_directory, self.WEIGHTS_FILENAME)
save_file(self.model.state_dict(), weights_path)
def compile(self, forecast_config: configs.ForecastConfig, **kwargs) -> None:
"""Attempts to compile the model for fast decoding.
See configs.ForecastConfig for more details on the supported flags.
Args:
forecast_config: Configuration for forecasting flags.
**kwargs: Additional keyword arguments to pass to model.compile().
"""
self.global_batch_size = (
forecast_config.per_core_batch_size * self.model.device_count
)
# Shortcut.
fc = forecast_config
if fc.max_context % self.model.p != 0:
logging.info(
"When compiling, max context needs to be multiple of the patch size"
" %d. Using max context = %d instead.",
self.model.p,
new_context := math.ceil(fc.max_context / self.model.p) * self.model.p,
)
fc = dataclasses.replace(fc, max_context=new_context)
if fc.max_horizon % self.model.o != 0:
logging.info(
"When compiling, max horizon needs to be multiple of the output patch"
" size %d. Using max horizon = %d instead.",
self.model.o,
new_horizon := math.ceil(fc.max_horizon / self.model.o) * self.model.o,
)
fc = dataclasses.replace(fc, max_horizon=new_horizon)
if fc.max_context + fc.max_horizon > self.model.config.context_limit:
raise ValueError(
"Context + horizon must not exceed the context limit."
f" {fc.max_context} + {fc.max_horizon} >"
f" {self.model.config.context_limit}."
)
if fc.use_continuous_quantile_head and (fc.max_horizon > self.model.os):
raise ValueError(
f"Continuous quantile head is not supported for horizons > {self.model.os}."
)
self.forecast_config = fc
def _compiled_decode(horizon, inputs, masks):
if horizon > fc.max_horizon:
raise ValueError(
f"Horizon must not exceed the max horizon. {horizon} > {fc.max_horizon}."
)
inputs = (
torch.from_numpy(np.array(inputs)).to(self.model.device).to(torch.float32)
)
masks = torch.from_numpy(np.array(masks)).to(self.model.device).to(torch.bool)
batch_size = inputs.shape[0]
if fc.infer_is_positive:
is_positive = torch.all(inputs >= 0, dim=-1, keepdim=True)
else:
is_positive = None
if fc.normalize_inputs:
mu = torch.mean(inputs, dim=-1, keepdim=True)
sigma = torch.std(inputs, dim=-1, keepdim=True)
inputs = revin(inputs, mu, sigma, reverse=False)
else:
mu, sigma = None, None
pf_outputs, quantile_spreads, ar_outputs = self.model.decode(
forecast_config.max_horizon, inputs, masks
)
to_cat = [pf_outputs[:, -1, ...]]
if ar_outputs is not None:
to_cat.append(ar_outputs.reshape(batch_size, -1, self.model.q))
full_forecast = torch.cat(to_cat, dim=1)
def flip_quantile_fn(x):
return torch.cat([x[..., :1], torch.flip(x[..., 1:], dims=(-1,))], dim=-1)
if fc.force_flip_invariance:
flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
self.model.decode(forecast_config.max_horizon, -inputs, masks)
)
flipped_quantile_spreads = flip_quantile_fn(flipped_quantile_spreads)
flipped_pf_outputs = flip_quantile_fn(flipped_pf_outputs)
to_cat = [flipped_pf_outputs[:, -1, ...]]
if flipped_ar_outputs is not None:
to_cat.append(flipped_ar_outputs.reshape(batch_size, -1, self.model.q))
flipped_full_forecast = torch.cat(to_cat, dim=1)
quantile_spreads = (quantile_spreads - flipped_quantile_spreads) / 2
pf_outputs = (pf_outputs - flipped_pf_outputs) / 2
full_forecast = (full_forecast - flipped_full_forecast) / 2
if fc.use_continuous_quantile_head:
for quantile_index in [1, 2, 3, 4, 6, 7, 8, 9]:
full_forecast[:, :, quantile_index] = (
quantile_spreads[:, : fc.max_horizon, quantile_index]
- quantile_spreads[:, : fc.max_horizon, 5]
+ full_forecast[:, : fc.max_horizon, 5]
)
full_forecast = full_forecast[:, :horizon, :]
if fc.return_backcast:
full_backcast = pf_outputs[:, :-1, : self.model.p, :].reshape(
batch_size, -1, self.model.q
)
full_forecast = torch.cat([full_backcast, full_forecast], dim=1)
if fc.fix_quantile_crossing:
for i in [4, 3, 2, 1]:
full_forecast[:, :, i] = torch.where(
full_forecast[:, :, i] < full_forecast[:, :, i + 1],
full_forecast[:, :, i],
full_forecast[:, :, i + 1],
)
for i in [6, 7, 8, 9]:
full_forecast[:, :, i] = torch.where(
full_forecast[:, :, i] > full_forecast[:, :, i - 1],
full_forecast[:, :, i],
full_forecast[:, :, i - 1],
)
if fc.normalize_inputs:
full_forecast = revin(full_forecast, mu, sigma, reverse=True)
if is_positive is not None:
full_forecast = torch.where(
is_positive[..., None],
torch.maximum(full_forecast, torch.zeros_like(full_forecast)),
full_forecast,
)
full_forecast = full_forecast.detach().cpu().numpy()
return full_forecast[..., 5], full_forecast
self.compiled_decode = _compiled_decode
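With `force_flip_invariance`, both backends also decode `-inputs`, reverse the quantile channels of that second forecast (so its q-th quantile lines up with the (1-q)-th quantile of the original), and average `(f(x) - flip(f(-x))) / 2`. The NumPy sketch below illustrates why this makes the result sign-equivariant. `toy_model`, `flip_quantiles` and `symmetrize` are hypothetical stand-ins, and the sketch assumes the quantile levels are symmetric around the median, which is what the channel flip relies on.

```python
import numpy as np


def flip_quantiles(q: np.ndarray) -> np.ndarray:
    """Keeps channel 0 (point forecast) and reverses quantile channels 1-9."""
    return np.concatenate([q[..., :1], q[..., :0:-1]], axis=-1)


def symmetrize(model, x: np.ndarray) -> np.ndarray:
    """Averages model(x) with the negated, quantile-flipped model(-x)."""
    return (model(x) - flip_quantiles(model(-x))) / 2


# Hypothetical toy forecaster: a constant +0.3 bias breaks sign symmetry.
LEVELS = np.linspace(-1.0, 1.0, 9)  # quantile offsets, symmetric around 0


def toy_model(x: np.ndarray) -> np.ndarray:
    center = x.mean(axis=-1, keepdims=True) + 0.3
    spread = x.std(axis=-1, keepdims=True)
    return np.concatenate([center, center + spread * LEVELS], axis=-1)


x = np.array([[1.0, 2.0, 4.0, 7.0]])
g = symmetrize(toy_model, x)
print(g[..., 0])  # close to the mean of x: the +0.3 bias cancels
```

After symmetrization, negating the input negates the forecast and swaps lower and upper quantiles, i.e. `symmetrize(f, -x) == -flip_quantiles(symmetrize(f, x))`, whatever sign bias the underlying model has.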
================================================
FILE: src/timesfm/torch/__init__.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: src/timesfm/torch/dense.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dense layers for TimesFM."""
import torch
from torch import nn
from .. import configs
class ResidualBlock(nn.Module):
"""Residual block with two linear layers and a linear residual connection."""
def __init__(self, config: configs.ResidualBlockConfig):
super().__init__()
self.config = config
self.hidden_layer = nn.Linear(
in_features=config.input_dims,
out_features=config.hidden_dims,
bias=config.use_bias,
)
self.output_layer = nn.Linear(
in_features=config.hidden_dims,
out_features=config.output_dims,
bias=config.use_bias,
)
self.residual_layer = nn.Linear(
in_features=config.input_dims,
out_features=config.output_dims,
bias=config.use_bias,
)
if config.activation == "relu":
self.activation = nn.ReLU()
elif config.activation == "swish":
self.activation = nn.SiLU()
elif config.activation == "none":
self.activation = nn.Identity()
else:
raise ValueError(f"Activation: {config.activation} not supported.")
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.output_layer(
self.activation(self.hidden_layer(x))
) + self.residual_layer(x)
class RandomFourierFeatures(nn.Module):
"""Random Fourier features layer."""
def __init__(self, config: configs.RandomFourierFeaturesConfig):
super().__init__()
self.config = config
if config.output_dims % 4 != 0:
raise ValueError(
f"Output dims must be a multiple of 4: {config.output_dims} % 4 != 0."
)
num_projected_features = config.output_dims // 4
self.phase_shifts = nn.Parameter(torch.zeros(2, num_projected_features))
self.projection_layer = nn.Linear(
in_features=config.input_dims,
out_features=num_projected_features,
bias=config.use_bias,
)
self.residual_layer = nn.Linear(
in_features=config.input_dims,
out_features=config.output_dims,
bias=config.use_bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
projected = self.projection_layer(x)
cos_features = torch.cos(projected)
sin_features = torch.sin(projected)
sq_wave_1 = torch.sign(torch.sin(projected + self.phase_shifts[0, :]))
sq_wave_2 = torch.sign(torch.sin(projected + self.phase_shifts[1, :]))
fourier_features = torch.cat(
[cos_features, sin_features, sq_wave_1, sq_wave_2], dim=-1
)
residual = self.residual_layer(x)
return fourier_features + residual
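`ResidualBlock` computes `output_layer(activation(hidden_layer(x))) + residual_layer(x)`. In both model classes it serves as the tokenizer over `concat([inputs, masks], -1)`, so each patch of length `p` enters as `2 * p` features. The NumPy sketch below omits the bias terms and uses toy dimensions rather than the TimesFM 2.5 configuration; `residual_block` and `swish` are illustrative names.

```python
import numpy as np


def swish(h: np.ndarray) -> np.ndarray:
    """Swish / SiLU: h * sigmoid(h), the same function as nn.SiLU."""
    return h / (1.0 + np.exp(-h))


def residual_block(x, w_hidden, w_out, w_res, activation=swish):
    """output_layer(activation(hidden_layer(x))) + residual_layer(x), biases omitted."""
    return activation(x @ w_hidden) @ w_out + x @ w_res


rng = np.random.default_rng(0)
patch_len = 4  # toy size; the models above use 32
batch, num_patches, hidden, out = 2, 3, 16, 8
patches = rng.normal(size=(batch, num_patches, patch_len))
masks = np.zeros_like(patches)  # nothing padded in this example
# As in the tokenizer call, values and masks are concatenated per patch.
tokens = np.concatenate([patches, masks], axis=-1)  # (2, 3, 8)
w_hidden = rng.normal(size=(2 * patch_len, hidden))
w_out = rng.normal(size=(hidden, out))
w_res = rng.normal(size=(2 * patch_len, out))
emb = residual_block(tokens, w_hidden, w_out, w_res)
print(emb.shape)  # (2, 3, 8)
```

The linear residual path lets the block fall back to a plain projection: if the output layer contributes nothing, the result is just `x @ w_res`.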
================================================
FILE: src/timesfm/torch/normalization.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers for TimesFM."""
import torch
from torch import nn
class RMSNorm(nn.Module):
"""RMS normalization."""
def __init__(
self,
num_features: int,
*,
epsilon: float = 1e-6,
):
super().__init__()
self.scale = nn.Parameter(torch.zeros(num_features))
self.num_features = num_features
self.epsilon = epsilon
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
var = torch.mean(torch.square(inputs), dim=-1, keepdim=True)
normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
normed_inputs = normed_inputs * self.scale
return normed_inputs
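`RMSNorm.forward` rescales each feature vector by the reciprocal of its root mean square and then multiplies by `scale`; unlike `nn.LayerNorm`, it does not subtract the mean. Because `scale` is initialized with `torch.zeros`, a freshly constructed layer outputs zeros until trained weights are loaded. A NumPy sketch of the same computation (`rms_norm` is an illustrative name):

```python
import numpy as np


def rms_norm(x: np.ndarray, scale: np.ndarray, epsilon: float = 1e-6) -> np.ndarray:
    """x / sqrt(mean(x**2) + epsilon) * scale over the last axis, no mean subtraction."""
    var = np.mean(np.square(x), axis=-1, keepdims=True)
    return x / np.sqrt(var + epsilon) * scale


x = np.array([[3.0, 4.0], [-1.0, 1.0]])
y = rms_norm(x, scale=np.ones(2))
print(np.sqrt(np.mean(y**2, axis=-1)))  # each row now has RMS of about 1
print(rms_norm(x, scale=np.zeros(2)))   # a zero-initialized scale zeroes the output
```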
================================================
FILE: src/timesfm/torch/transformer.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer layers for TimesFM."""
import math
from typing import Callable
import torch
import torch.nn.functional as F
from torch import nn
from .. import configs
from . import normalization, util
LayerNorm = nn.LayerNorm
RMSNorm = normalization.RMSNorm
DecodeCache = util.DecodeCache
def make_attn_mask(
query_length: int,
num_all_masked_kv: torch.Tensor,
query_index_offset: torch.Tensor | None = None,
kv_length: int = 0,
) -> torch.Tensor:
"""Makes attention mask."""
if kv_length == 0:
kv_length = query_length
q_index = torch.arange(query_length, device=num_all_masked_kv.device)[
None, None, :, None
]
if query_index_offset is not None:
q_index = q_index + query_index_offset[:, None, None, None]
kv_index = torch.arange(kv_length, device=num_all_masked_kv.device)[
None, None, None, :
]
return torch.logical_and(
q_index >= kv_index,
kv_index >= num_all_masked_kv[:, None, None, None],
)
class RotaryPositionalEmbedding(nn.Module):
"""Rotary positional embedding."""
def __init__(
self,
embedding_dims: int,
min_timescale: float = 1.0,
max_timescale: float = 10000.0,
):
super().__init__()
self.embedding_dims = embedding_dims
self.min_timescale = min_timescale
self.max_timescale = max_timescale
def forward(
self,
inputs: torch.Tensor,
position: torch.Tensor | None = None,
):
"""Applies rotary position embeddings (RoPE) to `inputs` along the last axis."""
if self.embedding_dims != inputs.shape[-1]:
raise ValueError(
"The embedding dims of the rotary position embedding "
"must match the hidden dimension of the inputs."
)
half_embedding_dim = self.embedding_dims // 2
fraction = (
2
* torch.arange(0, half_embedding_dim, device=inputs.device)
/ self.embedding_dims
)
timescale = (
self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
).to(inputs.device)
if position is None:
seq_length = inputs.shape[1]
position = torch.arange(seq_length, dtype=torch.float32, device=inputs.device)[
None, :
]
if len(inputs.shape) == 4:
position = position[..., None, None]
timescale = timescale[None, None, None, :]
elif len(inputs.shape) == 3:
position = position[..., None]
timescale = timescale[None, None, :]
else:
raise ValueError("Inputs must be of rank 3 or 4.")
sinusoid_inp = position / timescale
sin = torch.sin(sinusoid_inp)
cos = torch.cos(sinusoid_inp)
first_half, second_half = torch.chunk(inputs, 2, dim=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
return torch.cat([first_part, second_part], dim=-1)
def _dot_product_attention(
query,
key,
value,
mask=None,
):
"""Computes dot-product attention given query, key, and value."""
attn_weights = torch.einsum("...qhd,...khd->...hqk", query, key)
if mask is not None:
attn_weights = torch.where(
mask, attn_weights, -torch.finfo(attn_weights.dtype).max / 2
)
attn_weights = F.softmax(attn_weights, dim=-1)
return torch.einsum("...hqk,...khd->...qhd", attn_weights, value)
def _torch_dot_product_attention(query, key, value, mask=None):
"""
Performs the exact same (unscaled) attention as the above function,
but using the fast and fused F.scaled_dot_product_attention kernel.
"""
# 1. Permute inputs from (B, L, H, D) to the expected (B, H, L, D)
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
# 2. Call the fused attention kernel
# - Pass the mask to `attn_mask`.
# - Set `scale=1.0` to disable the default 1/sqrt(d_k) scaling.
output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, scale=1.0)
# 3. Permute the output back to the original (B, L, H, D) layout
output = output.permute(0, 2, 1, 3)
return output
class PerDimScale(nn.Module):
"""Per-dimension scaling."""
def __init__(self, num_dims: int):
super().__init__()
self.num_dims = num_dims
self.per_dim_scale = nn.Parameter(torch.zeros(num_dims))
def forward(self, x: torch.Tensor) -> torch.Tensor:
scale_factor = (
1.442695041 / math.sqrt(self.num_dims) * F.softplus(self.per_dim_scale)
)
return x * scale_factor
class MultiHeadAttention(nn.Module):
"""Multi-head attention."""
def __init__(
self,
num_heads: int,
in_features: int,
*,
use_per_dim_scale: bool = True,
use_rotary_position_embeddings: bool = True,
use_bias: bool = False,
attention_fn: Callable[..., torch.Tensor] = _torch_dot_product_attention,
qk_norm: str = "rms",
fuse_qkv: bool = False,
):
super().__init__()
self.num_heads = num_heads
self.in_features = in_features
self.head_dim = in_features // num_heads
self.use_bias = use_bias
self.attention_fn = attention_fn
self.qk_norm = qk_norm
self.fuse_qkv = fuse_qkv
if self.in_features % self.num_heads != 0:
raise ValueError(
f"Memory dimension ({self.in_features}) must be divisible by "
f"'num_heads' heads ({self.num_heads})."
)
if self.fuse_qkv:
self.qkv_proj = nn.Linear(self.in_features, 3 * self.in_features, bias=use_bias)
else:
self.query = nn.Linear(self.in_features, self.in_features, bias=use_bias)
self.key = nn.Linear(self.in_features, self.in_features, bias=use_bias)
self.value = nn.Linear(self.in_features, self.in_features, bias=use_bias)
self.out = nn.Linear(self.in_features, self.in_features, bias=use_bias)
if self.qk_norm == "rms":
self.query_ln = RMSNorm(self.head_dim)
self.key_ln = RMSNorm(self.head_dim)
else:
self.query_ln = nn.Identity()
self.key_ln = nn.Identity()
self.use_rotary_position_embeddings = use_rotary_position_embeddings
if self.use_rotary_position_embeddings:
self.rotary_position_embedding = RotaryPositionalEmbedding(
embedding_dims=self.head_dim,
)
self.use_per_dim_scale = use_per_dim_scale
if use_per_dim_scale:
self.per_dim_scale = PerDimScale(num_dims=self.head_dim)
def forward(
self,
inputs_q: torch.Tensor,
*,
decode_cache: DecodeCache | None = None,
patch_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, DecodeCache | None]:
b, n_patches, _ = inputs_q.shape
if patch_mask is None:
patch_mask = torch.zeros(b, n_patches, dtype=torch.bool, device=inputs_q.device)
if self.fuse_qkv:
qkv = self.qkv_proj(inputs_q)
query, key, value = torch.chunk(qkv, 3, dim=-1)
query = query.view(b, n_patches, self.num_heads, self.head_dim)
key = key.view(b, n_patches, self.num_heads, self.head_dim)
value = value.view(b, n_patches, self.num_heads, self.head_dim)
else:
query = self.query(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
key = self.key(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
value = self.value(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
if decode_cache is None:
num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1)
next_index = torch.zeros_like(num_masked, dtype=torch.int32)
else:
num_masked = (
torch.sum(patch_mask.to(torch.int32), dim=-1) + decode_cache.num_masked
)
next_index = decode_cache.next_index.clone()
if self.use_rotary_position_embeddings:
position = (
torch.arange(n_patches, device=inputs_q.device)[None, :]
+ next_index[:, None]
- num_masked[:, None]
)
query = self.rotary_position_embedding(query, position)
key = self.rotary_position_embedding(key, position)
query = self.query_ln(query)
key = self.key_ln(key)
if self.use_per_dim_scale:
query = self.per_dim_scale(query)
if decode_cache is not None:
_, decode_cache_size, _, _ = decode_cache.value.shape
start = decode_cache.next_index[0]
end = start + n_patches
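# Write the new keys/values into the cache with one vectorized slice
# assignment for the whole batch (much faster than a Python for-loop).
# The write offset is next_index[0], so this assumes every sequence in the
# batch shares the same next_index.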
decode_cache.key[:, start:end] = key
decode_cache.value[:, start:end] = value
key = decode_cache.key
value = decode_cache.value
decode_cache.next_index += n_patches
decode_cache.num_masked = num_masked
attn_mask = make_attn_mask(
query_length=n_patches,
num_all_masked_kv=num_masked,
query_index_offset=next_index,
kv_length=decode_cache_size,
)
else:
attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked)
x = self.attention_fn(
query,
key,
value,
mask=attn_mask,
)
x = x.reshape(b, n_patches, self.in_features)
out = self.out(x)
return out, decode_cache
class Transformer(nn.Module):
"""A single Transformer layer (self-attention + feed-forward) used in TimesFM."""
def __init__(self, config: configs.TransformerConfig):
super().__init__()
self.config = config
if config.attention_norm == "rms":
self.pre_attn_ln = RMSNorm(num_features=config.model_dims)
self.post_attn_ln = RMSNorm(num_features=config.model_dims)
else:
raise ValueError(f"Layer norm: {config.attention_norm} not supported.")
self.attn = MultiHeadAttention(
num_heads=config.num_heads,
in_features=config.model_dims,
use_per_dim_scale=True,
use_rotary_position_embeddings=config.use_rotary_position_embeddings,
qk_norm=config.qk_norm,
fuse_qkv=config.fuse_qkv,
)
if config.feedforward_norm == "rms":
self.pre_ff_ln = RMSNorm(num_features=config.model_dims)
self.post_ff_ln = RMSNorm(num_features=config.model_dims)
else:
raise ValueError(f"Layer norm: {config.feedforward_norm} not supported.")
self.ff0 = nn.Linear(
in_features=config.model_dims,
out_features=config.hidden_dims,
bias=config.use_bias,
)
self.ff1 = nn.Linear(
in_features=config.hidden_dims,
out_features=config.model_dims,
bias=config.use_bias,
)
if config.ff_activation == "relu":
self.activation = nn.ReLU()
elif config.ff_activation == "swish":
self.activation = nn.SiLU()
elif config.ff_activation == "none":
self.activation = nn.Identity()
else:
raise ValueError(f"Activation: {config.ff_activation} not supported.")
def forward(
self,
input_embeddings: torch.Tensor,
patch_mask: torch.Tensor,
decode_cache: DecodeCache | None = None,
) -> tuple[torch.Tensor, DecodeCache | None]:
attn_output, decode_cache = self.attn(
inputs_q=self.pre_attn_ln(input_embeddings),
decode_cache=decode_cache,
patch_mask=patch_mask,
)
attn_output = self.post_attn_ln(attn_output) + input_embeddings
output_embeddings = (
self.post_ff_ln(self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))))
+ attn_output
)
return output_embeddings, decode_cache
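The attention path above combines a causal mask that also hides left-padded patches, a learned per-dimension query scale, and unscaled softmax attention (hence `scale=1.0` in the fused kernel). The NumPy sketch below mirrors `make_attn_mask`, `PerDimScale`, and `_dot_product_attention` for one forward pass without a decode cache; the helper names and toy shapes are illustrative assumptions. It also shows why no extra `1/sqrt(head_dim)` factor is needed: with `per_dim_scale` at its zero initialization, `softplus(0) = ln 2`, and `1.442695041` is approximately `1 / ln 2`, so the query ends up scaled by `1/sqrt(head_dim)`.

```python
import numpy as np


def make_attn_mask(query_length, num_all_masked_kv, query_index_offset=None, kv_length=0):
    """NumPy mirror of make_attn_mask: True where a query may attend to a key."""
    if kv_length == 0:
        kv_length = query_length
    q_index = np.arange(query_length)[None, None, :, None]
    if query_index_offset is not None:
        q_index = q_index + query_index_offset[:, None, None, None]
    kv_index = np.arange(kv_length)[None, None, None, :]
    # Causal (key not after query) AND key is not one of the left-padded patches.
    return np.logical_and(q_index >= kv_index,
                          kv_index >= num_all_masked_kv[:, None, None, None])


def per_dim_scale(x, per_dim_scale_param):
    """Mirror of PerDimScale: x * 1.442695041 / sqrt(d) * softplus(p)."""
    softplus = np.log1p(np.exp(per_dim_scale_param))
    return x * (1.442695041 / np.sqrt(x.shape[-1]) * softplus)


def dot_product_attention(query, key, value, mask):
    """Unscaled masked attention on (B, L, H, D) inputs, like _dot_product_attention."""
    logits = np.einsum("bqhd,bkhd->bhqk", query, key)
    logits = np.where(mask, logits, -np.finfo(logits.dtype).max / 2)
    # Numerically stable softmax over the key axis.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", weights, value)


rng = np.random.default_rng(0)
b, n_patches, heads, head_dim = 1, 4, 2, 8
q, k, v = (rng.normal(size=(b, n_patches, heads, head_dim)) for _ in range(3))

# One left-padded patch: rows are queries, columns are keys.
mask = make_attn_mask(n_patches, num_all_masked_kv=np.array([1]))
print(mask[0, 0].astype(int))
# At initialization (p = 0) the per-dim scale reduces to 1/sqrt(head_dim).
q_scaled = per_dim_scale(q, np.zeros(head_dim))
out = dot_product_attention(q_scaled, k, v, mask)
```

Because query 1 may only attend to key 1 (key 0 is padding and keys 2 and 3 lie in the future), its output is exactly `v[0, 1]`.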
================================================
FILE: src/timesfm/torch/util.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch utility functions for TimesFM layers."""
import dataclasses
import torch
_TOLERANCE = 1e-6
@dataclasses.dataclass(frozen=False)
class DecodeCache:
"""Cache for decoding."""
next_index: torch.Tensor
num_masked: torch.Tensor
key: torch.Tensor
value: torch.Tensor
def update_running_stats(
n: torch.Tensor,
mu: torch.Tensor,
sigma: torch.Tensor,
x: torch.Tensor,
mask: torch.Tensor,
) -> tuple[
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""Merges unmasked `x` into running (n, mu, sigma); returns the new stats twice."""
is_legit = torch.logical_not(mask)
inc_n = torch.sum(is_legit.to(x.dtype), dim=-1)
inc_mu_numerator = torch.sum(x * is_legit, dim=-1)
inc_n_safe = torch.where(inc_n == 0, 1.0, inc_n)
inc_mu = inc_mu_numerator / inc_n_safe
inc_mu = torch.where(inc_n == 0, 0.0, inc_mu)
inc_var_numerator = torch.sum(
((x - inc_mu.unsqueeze(-1)) ** 2) * is_legit, dim=-1
)
inc_var = inc_var_numerator / inc_n_safe
inc_var = torch.where(inc_n == 0, 0.0, inc_var)
inc_sigma = torch.sqrt(inc_var)
new_n = n + inc_n
new_n_safe = torch.where(new_n == 0, 1.0, new_n)
new_mu = (n * mu + inc_mu * inc_n) / new_n_safe
new_mu = torch.where(new_n == 0, 0.0, new_mu)
term1 = n * sigma.pow(2)
term2 = inc_n * inc_sigma.pow(2)
term3 = n * (mu - new_mu).pow(2)
term4 = inc_n * (inc_mu - new_mu).pow(2)
new_var = (term1 + term2 + term3 + term4) / new_n_safe
new_var = torch.where(new_n == 0, 0.0, new_var)
new_sigma = torch.sqrt(torch.clamp(new_var, min=0.0))
return (w := (new_n, new_mu, new_sigma), w)
def revin(
x: torch.Tensor,
mu: torch.Tensor,
sigma: torch.Tensor,
reverse: bool = False,
):
"""Reversible instance normalization."""
if len(mu.shape) == len(x.shape) - 1:
mu = mu[..., None]
sigma = sigma[..., None]
elif len(mu.shape) == len(x.shape) - 2:
mu = mu[..., None, None]
sigma = sigma[..., None, None]
if reverse:
return x * sigma + mu
else:
return (x - mu) / torch.where(sigma < _TOLERANCE, 1.0, sigma)
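`update_running_stats` merges each new chunk into the running count, mean, and (population) standard deviation using the standard pairwise combination of means and variances, and `revin` then standardizes or de-standardizes with those statistics. The NumPy sketch below mirrors both functions (simplified so `mu`/`sigma` always have one fewer dimension than `x`) and checks that feeding a partly padded series in two chunks reproduces the mean and standard deviation of its unpadded values. The data and the chunk size of 32 are made up for illustration.

```python
import numpy as np


def update_running_stats(n, mu, sigma, x, mask):
    """NumPy mirror of update_running_stats: merge a new chunk into (n, mu, sigma)."""
    is_legit = np.logical_not(mask)
    inc_n = np.sum(is_legit, axis=-1).astype(x.dtype)
    inc_n_safe = np.where(inc_n == 0, 1.0, inc_n)
    inc_mu = np.where(inc_n == 0, 0.0, np.sum(x * is_legit, axis=-1) / inc_n_safe)
    inc_var = np.sum((x - inc_mu[..., None]) ** 2 * is_legit, axis=-1) / inc_n_safe
    inc_var = np.where(inc_n == 0, 0.0, inc_var)
    new_n = n + inc_n
    new_n_safe = np.where(new_n == 0, 1.0, new_n)
    new_mu = np.where(new_n == 0, 0.0, (n * mu + inc_n * inc_mu) / new_n_safe)
    # Pooled population variance: within-chunk terms plus mean-shift terms.
    new_var = (n * sigma**2 + inc_n * inc_var
               + n * (mu - new_mu) ** 2 + inc_n * (inc_mu - new_mu) ** 2) / new_n_safe
    new_var = np.where(new_n == 0, 0.0, new_var)
    return new_n, new_mu, np.sqrt(np.clip(new_var, 0.0, None))


def revin(x, mu, sigma, reverse=False, tol=1e-6):
    """Mirror of revin for per-series stats broadcast over the last axis."""
    mu, sigma = mu[..., None], sigma[..., None]
    if reverse:
        return x * sigma + mu
    return (x - mu) / np.where(sigma < tol, 1.0, sigma)


rng = np.random.default_rng(0)
series = rng.normal(loc=5.0, scale=2.0, size=(2, 64))
mask = np.zeros_like(series, dtype=bool)
mask[0, :5] = True  # pretend the first 5 points of series 0 are padding

n = mu = sigma = np.zeros(2)
for start in range(0, 64, 32):  # feed two chunks of 32 points each
    chunk = slice(start, start + 32)
    n, mu, sigma = update_running_stats(n, mu, sigma, series[:, chunk], mask[:, chunk])

normed = revin(series, mu, sigma)
restored = revin(normed, mu, sigma, reverse=True)
```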
================================================
FILE: src/timesfm/utils/xreg_lib.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for in-context covariates and regression."""
import itertools
import math
from typing import Any, Iterable, Literal, Mapping, Sequence
try:
import jax
import jax.numpy as jnp
import numpy as np
from sklearn import preprocessing
except ImportError:
raise ImportError(
"Failed to load the XReg module. Did you forget to install `timesfm[xreg]`?"
)
Category = int | str
_TOL = 1e-6
XRegMode = Literal["timesfm + xreg", "xreg + timesfm"]
def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:
return np.array(list(itertools.chain.from_iterable(nested)))
def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:
return np.array(
list(itertools.chain.from_iterable(map(itertools.repeat, elements, counts)))
)
def _to_padded_jax_array(x: np.ndarray) -> jax.Array:
if x.ndim == 1:
(i,) = x.shape
di = 2 ** math.ceil(math.log2(i)) - i
return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0)
elif x.ndim == 2:
i, j = x.shape
di = 2 ** math.ceil(math.log2(i)) - i
dj = 2 ** math.ceil(math.log2(j)) - j
return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0)
else:
raise ValueError(f"Unsupported array shape: {x.shape}")
# Per time series normalization: forward.
def normalize(batch):
stats = [(np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch]
new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)]
return new_batch, stats
# Per time series normalization: inverse.
def renormalize(batch, stats):
return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)]
class BatchedInContextXRegBase:
"""Helper class for in-context regression covariate formatting.
Attributes:
targets: List of targets (responses) of the in-context regression.
train_lens: List of lengths of each target vector from the context.
test_lens: List of lengths of each forecast horizon.
train_dynamic_numerical_covariates: Dict of covariate names mapping to the
dynamic numerical covariates of each forecast task on the context. Their
lengths should match the corresponding lengths in `train_lens`.
train_dynamic_categorical_covariates: Dict of covariate names mapping to the
dynamic categorical covariates of each forecast task on the context. Their
lengths should match the corresponding lengths in `train_lens`.
test_dynamic_numerical_covariates: Dict of covariate names mapping to the
dynamic numerical covariates of each forecast task on the horizon. Their
lengths should match the corresponding lengths in `test_lens`.
test_dynamic_categorical_covariates: Dict of covariate names mapping to the
dynamic categorical covariates of each forecast task on the horizon. Their
lengths should match the corresponding lengths in `test_lens`.
static_numerical_covariates: Dict of covariate names mapping to the static
numerical covariates of each forecast task.
static_categorical_covariates: Dict of covariate names mapping to the static
categorical covariates of each forecast task.
"""
def __init__(
self,
targets: Sequence[Sequence[float]],
train_lens: Sequence[int],
test_lens: Sequence[int],
train_dynamic_numerical_covariates: (
Mapping[str, Sequence[Sequence[float]]] | None
) = None,
train_dynamic_categorical_covariates: (
Mapping[str, Sequence[Sequence[Category]]] | None
) = None,
test_dynamic_numerical_covariates: (
Mapping[str, Sequence[Sequence[float]]] | None
) = None,
test_dynamic_categorical_covariates: (
Mapping[str, Sequence[Sequence[Category]]] | None
) = None,
static_numerical_covariates: Mapping[str, Sequence[float]] | None = None,
static_categorical_covariates: (Mapping[str, Sequence[Category]] | None) = None,
) -> None:
"""Initializes with the exogenous covariate inputs.
Here we use model fitting language to refer to the context as 'train' and
the horizon as 'test'. We assume batched inputs. To properly format the
request:
- `train_lens` gives the context lengths in the batch. Targets and all train
dynamic covariates should have the same lengths as the corresponding
elements in `train_lens`. Note that each `train_len` can differ from the
exact length of the corresponding context, depending on how much of the
context is used for fitting the in-context model.
- `test_lens` gives the horizon lengths in the batch. All test dynamic
covariates should have the same lengths as the corresponding elements in
`test_lens`.
- Static covariates should have exactly one value per input.
- Train and test dynamic covariates should have the same covariate names.
Pass an empty dict {} (or None) for a covariate type that is not present.
Example:
Here is a set of valid inputs whose schema can be used for reference.
```
targets = [
[0.0, 0.1, 0.2],
[0.0, 0.1, 0.2, 0.3],
] # Two inputs in this batch.
train_lens = [3, 4]
test_lens = [2, 5] # Forecast horizons 2 and 5 respectively.
train_dynamic_numerical_covariates = {
"cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]],
"cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]],
} # Each train dynamic covariate has 3 and 4 elements respectively.
test_dynamic_numerical_covariates = {
"cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]],
"cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]],
} # Each test dynamic covariate has 2 and 5 elements respectively.
train_dynamic_categorical_covariates = {
"cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]],
"cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad",
"bad"]],
}
test_dynamic_categorical_covariates = {
"cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]],
"cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]],
}
static_numerical_covariates = {
"cov_1_sn": [0.0, 3.0],
"cov_2_sn": [2.0, 1.0],
"cov_3_sn": [1.0, 2.0],
} # Each static covariate has 1 element for each input.
static_categorical_covariates = {
"cov_1_sc": ["apple", "orange"],
"cov_2_sc": [2, 3],
}
```
Args:
targets: List of targets (responses) of the in-context regression.
train_lens: List of lengths of each target vector from the context.
test_lens: List of lengths of each forecast horizon.
train_dynamic_numerical_covariates: Dict of covariate names mapping to the
dynamic numerical covariates of each forecast task on the context. Their
lengths should match the corresponding lengths in `train_lens`.
train_dynamic_categorical_covariates: Dict of covariate names mapping to
the dynamic categorical covariates of each forecast task on the context.
Their lengths should match the corresponding lengths in `train_lens`.
test_dynamic_numerical_covariates: Dict of covariate names mapping to the
dynamic numerical covariates of each forecast task on the horizon. Their
lengths should match the corresponding lengths in `test_lens`.
test_dynamic_categorical_covariates: Dict of covariate names mapping to
the dynamic categorical covariates of each forecast task on the horizon.
Their lengths should match the corresponding lengths in `test_lens`.
static_numerical_covariates: Dict of covariate names mapping to the static
numerical covariates of each forecast task.
static_categorical_covariates: Dict of covariate names mapping to the
static categorical covariates of each forecast task.
"""
self.targets = targets
self.train_lens = train_lens
self.test_lens = test_lens
self.train_dynamic_numerical_covariates = train_dynamic_numerical_covariates or {}
self.train_dynamic_categorical_covariates = (
train_dynamic_categorical_covariates or {}
)
self.test_dynamic_numerical_covariates = test_dynamic_numerical_covariates or {}
self.test_dynamic_categorical_covariates = test_dynamic_categorical_covariates or {}
self.static_numerical_covariates = static_numerical_covariates or {}
self.static_categorical_covariates = static_categorical_covariates or {}
def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None:
"""Verifies the validity of the covariate inputs."""
# Check presence.
if (
self.train_dynamic_numerical_covariates
and not self.test_dynamic_numerical_covariates
) or (
not self.train_dynamic_numerical_covariates
and self.test_dynamic_numerical_covariates
):
raise ValueError(
"train_dynamic_numerical_covariates and"
" test_dynamic_numerical_covariates must be both present or both"
" absent."
)
if (
self.train_dynamic_categorical_covariates
and not self.test_dynamic_categorical_covariates
) or (
not self.train_dynamic_categorical_covariates
and self.test_dynamic_categorical_covariates
):
raise ValueError(
"train_dynamic_categorical_covariates and"
" test_dynamic_categorical_covariates must be both present or both"
" absent."
)
# Check keys.
for dict_a, dict_b, dict_a_name, dict_b_name in (
(
self.train_dynamic_numerical_covariates,
self.test_dynamic_numerical_covariates,
"train_dynamic_numerical_covariates",
"test_dynamic_numerical_covariates",
),
(
self.train_dynamic_categorical_covariates,
self.test_dynamic_categorical_covariates,
"train_dynamic_categorical_covariates",
"test_dynamic_categorical_covariates",
),
):
if w := set(dict_a.keys()) - set(dict_b.keys()):
raise ValueError(f"{dict_a_name} has keys not present in {dict_b_name}: {w}")
if w := set(dict_b.keys()) - set(dict_a.keys()):
raise ValueError(f"{dict_b_name} has keys not present in {dict_a_name}: {w}")
# Check shapes.
if assert_covariate_shapes:
if len(self.targets) != len(self.train_lens):
raise ValueError(
"targets and train_lens must have the same number of elements."
)
if len(self.train_lens) != len(self.test_lens):
raise ValueError(
"train_lens and test_lens must have the same number of elements."
)
for i, (target, train_len) in enumerate(zip(self.targets, self.train_lens)):
if len(target) != train_len:
raise ValueError(
f"targets[{i}] has length {len(target)} != expected {train_len}."
)
for key, values in self.static_numerical_covariates.items():
if len(values) != len(self.train_lens):
raise ValueError(
f"static_numerical_covariates has key {key} with number of"
f" examples {len(values)} != expected {len(self.train_lens)}."
)
for key, values in self.static_categorical_covariates.items():
if len(values) != len(self.train_lens):
raise ValueError(
f"static_categorical_covariates has key {key} with number of"
f" examples {len(values)} != expected {len(self.train_lens)}."
)
for lens, dict_cov, dict_cov_name in (
(
self.train_lens,
self.train_dynamic_numerical_covariates,
"train_dynamic_numerical_covariates",
),
(
self.train_lens,
self.train_dynamic_categorical_covariates,
"train_dynamic_categorical_covariates",
),
(
self.test_lens,
self.test_dynamic_numerical_covariates,
"test_dynamic_numerical_covariates",
),
(
self.test_lens,
self.test_dynamic_categorical_covariates,
"test_dynamic_categorical_covariates",
),
):
for key, cov_values in dict_cov.items():
if len(cov_values) != len(lens):
raise ValueError(
f"{dict_cov_name} has key {key} with number of examples"
f" {len(cov_values)} != expected {len(lens)}."
)
for i, cov_value in enumerate(cov_values):
if len(cov_value) != lens[i]:
raise ValueError(
f"{dict_cov_name} has key {key} with its {i}-th example"
f" length {len(cov_value)} != expected {lens[i]}."
)
def create_covariate_matrix(
self,
one_hot_encoder_drop: str | None = "first",
use_intercept: bool = True,
assert_covariates: bool = False,
assert_covariate_shapes: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Creates target vector and covariate matrices for in context regression.
Here we use model fitting language to refer to the context as 'train' and
the horizon as 'test'.
Args:
one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.
use_intercept: Whether to prepare an intercept (all 1) column in the
matrices.
assert_covariates: Whether to assert the validity of the covariate inputs.
assert_covariate_shapes: Whether to assert the shapes of the covariate
inputs when `assert_covariates` is True.
Returns:
A tuple of the target vector, the covariate matrix for the context, and
the covariate matrix for the horizon.
"""
if assert_covariates:
self._assert_covariates(assert_covariate_shapes)
x_train, x_test = [], []
# Numerical features.
for name in sorted(self.train_dynamic_numerical_covariates):
x_train.append(
_unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis]
)
x_test.append(
_unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis]
)
for covs in self.static_numerical_covariates.values():
x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis])
x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis])
if x_train:
x_train = np.concatenate(x_train, axis=1)
x_test = np.concatenate(x_test, axis=1)
# Normalize for robustness.
x_mean = np.mean(x_train, axis=0, keepdims=True)
x_std = np.where((w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w, 1.0)
x_train = [(x_train - x_mean) / x_std]
x_test = [(x_test - x_mean) / x_std]
# Categorical features. Encode one by one.
one_hot_encoder = preprocessing.OneHotEncoder(
drop=one_hot_encoder_drop,
sparse_output=False,
handle_unknown="ignore",
)
for name in sorted(self.train_dynamic_categorical_covariates.keys()):
ohe_train = _unnest(self.train_dynamic_categorical_covariates[name])[
:, np.newaxis
]
ohe_test = _unnest(self.test_dynamic_categorical_covariates[name])[:, np.newaxis]
x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train)))
x_test.append(np.array(one_hot_encoder.transform(ohe_test)))
for covs in self.static_categorical_covariates.values():
ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis])
x_train.append(_repeat(ohe, self.train_lens))
x_test.append(_repeat(ohe, self.test_lens))
x_train = np.concatenate(x_train, axis=1)
x_test = np.concatenate(x_test, axis=1)
if use_intercept:
x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0)
x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0)
return _unnest(self.targets), x_train, x_test
def fit(self) -> Any:
raise NotImplementedError("Fit is not implemented.")
class BatchedInContextXRegLinear(BatchedInContextXRegBase):
"""Linear in-context regression model."""
def fit(
self,
ridge: float = 0.0,
one_hot_encoder_drop: str | None = "first",
use_intercept: bool = True,
force_on_cpu: bool = False,
max_rows_per_col: int = 0,
max_rows_per_col_sample_seed: int = 42,
debug_info: bool = False,
assert_covariates: bool = False,
assert_covariate_shapes: bool = False,
) -> (
list[np.ndarray]
| tuple[list[np.ndarray], list[np.ndarray], jax.Array, jax.Array, jax.Array]
):
"""Fits a linear model for in-context regression.
Args:
ridge: A non-negative ridge regression penalty. If 0, the fit falls back
to ordinary least squares. Note that the penalty is applied after the
numerical covariates have been normalized.
one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.
use_intercept: Whether to prepare an intercept (all 1) column in the
matrices.
force_on_cpu: Whether to force execution on cpu for accelerator machines.
max_rows_per_col: How many rows to subsample per column. 0 for no
subsampling. This is for speeding up model fitting.
max_rows_per_col_sample_seed: The seed for the subsampling if needed by
`max_rows_per_col`.
debug_info: Whether to return debug info.
assert_covariates: Whether to assert the validity of the covariate inputs.
assert_covariate_shapes: Whether to assert the shapes of the covariate
inputs when `assert_covariates` is True.
Returns:
If `debug_info` is False:
The linear fits on the horizon.
If `debug_info` is True:
A tuple of:
- the linear fits on the horizon,
- the linear fits on the context,
- the flattened target vector,
- the covariate matrix for the context, and
- the covariate matrix for the horizon.
"""
flat_targets, x_train_raw, x_test = self.create_covariate_matrix(
one_hot_encoder_drop=one_hot_encoder_drop,
use_intercept=use_intercept,
assert_covariates=assert_covariates,
assert_covariate_shapes=assert_covariate_shapes,
)
x_train = x_train_raw.copy()
if max_rows_per_col:
nrows, ncols = x_train.shape
if nrows > (w := ncols * max_rows_per_col):
subsample = jax.random.choice(
jax.random.PRNGKey(max_rows_per_col_sample_seed),
nrows,
(w,),
replace=False,
)
x_train = x_train[subsample]
flat_targets = flat_targets[subsample]
device = jax.devices("cpu")[0] if force_on_cpu else None
# The solvers run as compiled JAX ops, which are faster at the cost of
# compilation the first time they are called. Recompilation happens whenever
# new (padded) shapes are encountered.
# On accelerator machines it occasionally helps both speed and accuracy to
# force execution on CPU, which:
# 1. avoids moving data to accelerator memory, and
# 2. avoids precision loss, if any.
with jax.default_device(device):
x_train_raw = _to_padded_jax_array(x_train_raw)
x_train = _to_padded_jax_array(x_train)
flat_targets = _to_padded_jax_array(flat_targets)
x_test = _to_padded_jax_array(x_test)
beta_hat = (
jnp.linalg.pinv(
x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]),
hermitian=True,
)
@ x_train.T
@ flat_targets
)
y_hat = x_test @ beta_hat
y_hat_context = x_train_raw @ beta_hat if debug_info else None
outputs = []
outputs_context = []
# Reconstruct the ragged 2-dim batched forecasts from flattened linear fits.
train_index, test_index = 0, 0
for train_index_delta, test_index_delta in zip(self.train_lens, self.test_lens):
outputs.append(np.array(y_hat[test_index : (test_index + test_index_delta)]))
if debug_info:
outputs_context.append(
np.array(y_hat_context[train_index : (train_index + train_index_delta)])
)
train_index += train_index_delta
test_index += test_index_delta
if debug_info:
return outputs, outputs_context, flat_targets, x_train, x_test
else:
return outputs
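At its core, `BatchedInContextXRegLinear.fit` flattens every series in the batch into one design matrix (standardized numerical covariates, one-hot categoricals, an intercept column), solves a single ridge regression with a pseudo-inverse, and slices the flat predictions back into per-series horizons. The NumPy sketch below reproduces that computation for the two dynamic numerical covariates from the class docstring example; JAX, subsampling, and one-hot encoding are omitted, and the helper names and `ridge=0.1` are illustrative choices. It also checks why the power-of-two zero padding done by `_to_padded_jax_array` is safe: zero rows and columns change neither the coefficients of the real columns nor the real predictions.

```python
import math

import numpy as np


def pad_pow2(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: zero-pad every axis up to the next power of two."""
    return np.pad(x, [(0, 2 ** math.ceil(math.log2(s)) - s) for s in x.shape])


def ridge_predict(x_train, y, x_test, ridge=0.0):
    """beta = pinv(X^T X + ridge * I) X^T y, then predict on the horizon rows."""
    gram = x_train.T @ x_train + ridge * np.eye(x_train.shape[1])
    beta = np.linalg.pinv(gram, hermitian=True) @ x_train.T @ y
    return x_test @ beta


# Toy batch from the class docstring: two series, two dynamic numerical covariates.
train_lens, test_lens = [3, 4], [2, 5]
y = np.array([0.0, 0.1, 0.2, 0.0, 0.1, 0.2, 0.3])  # flattened targets
x_train = np.array([[0.0, 0.5, 1.0, 0.0, 0.5, 1.0, 1.5],
                    [0.0, 1.5, 1.0, 0.0, 1.5, 1.0, 2.5]]).T
x_test = np.array([[0.1, 0.6, 0.1, 0.6, 1.1, 1.6, 2.4],
                   [0.1, 1.1, 0.1, 1.6, 1.1, 2.6, 10.0]]).T

# Standardize with context statistics, then prepend an intercept column.
mean, std = x_train.mean(axis=0), x_train.std(axis=0)
x_train = np.pad((x_train - mean) / std, ((0, 0), (1, 0)), constant_values=1.0)
x_test = np.pad((x_test - mean) / std, ((0, 0), (1, 0)), constant_values=1.0)

y_hat = ridge_predict(x_train, y, x_test, ridge=0.1)
# Zero padding of rows and columns should leave the real predictions unchanged.
y_hat_padded = ridge_predict(
    pad_pow2(x_train), pad_pow2(y), pad_pow2(x_test), ridge=0.1
)[: len(y_hat)]

# Split the flat predictions back into one forecast per series.
forecasts = np.split(y_hat, np.cumsum(test_lens)[:-1])
```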
================================================
FILE: timesfm-forecasting/SKILL.md
================================================
---
name: timesfm-forecasting
description: >
Zero-shot time series forecasting with Google's TimesFM foundation model. Use this
skill when forecasting ANY univariate time series — sales, sensor readings, stock prices,
energy demand, patient vitals, weather, or scientific measurements — without training a
custom model. Supports both basic forecasting and advanced covariate forecasting (XReg)
with dynamic and static exogenous variables. Automatically checks system RAM/GPU before
loading the model, validates dataset fit before processing, supports CSV/DataFrame/array
inputs, and returns point forecasts with calibrated prediction intervals. Includes a
preflight system checker script that MUST be run before first use to verify the machine
can load the model and handle your specific dataset.
license: Apache-2.0
metadata:
author: Clayton Young (@borealBytes)
version: "1.0.0"
---
# TimesFM Forecasting
## Overview
TimesFM (Time Series Foundation Model) is a pretrained decoder-only foundation model
developed by Google Research for time-series forecasting. It works **zero-shot** — feed it
any univariate time series and it returns point forecasts with calibrated quantile
prediction intervals, no training required.
This skill includes a **mandatory preflight system checker** that verifies RAM, GPU memory,
and disk space before the model is ever loaded, so the agent does not try to load a model the user's machine cannot handle.
> **Key numbers**: TimesFM 2.5 has 200M parameters (~800 MB on disk, ~1.5 GB in RAM on
> CPU, ~1 GB VRAM on GPU). The archived TimesFM 2.0 checkpoint has 500M parameters and needs
> considerably more memory (see the hardware table below).
> Always run the system checker first.
## When to Use This Skill
Use this skill when:
- Forecasting **any univariate time series** (sales, demand, sensor, vitals, price, weather)
- You need **zero-shot forecasting** without training a custom model
- You want **probabilistic forecasts** with calibrated prediction intervals (quantiles)
- You have time series of **varying lengths** (TimesFM 2.5 accepts up to 16,384 context points)
- You need to **batch-forecast** hundreds or thousands of series efficiently
- You want a **foundation model** approach instead of hand-tuning ARIMA/ETS parameters
- You need **covariate forecasting** with exogenous variables (price, promotions, holidays, day-of-week effects) → use `forecast_with_covariates()` (TimesFM 2.5 + `pip install timesfm[xreg]`)
Do **not** use this skill when:
- You need classical statistical models with coefficient interpretation → use `statsmodels`
- You need time series classification or clustering → use `aeon`
- You need multivariate vector autoregression or Granger causality → use `statsmodels`
- Your data is tabular (not temporal) → use `scikit-learn`
- You cannot install optional dependencies → XReg requires scikit-learn and JAX
> **Note on Anomaly Detection**: TimesFM does not have built-in anomaly detection, but you
> can use the **quantile forecasts as prediction intervals**: values outside the 80% interval
> (q10–q90) are statistically unusual. See `examples/anomaly-detection/` for a full example.
## ⚠️ Mandatory Preflight: System Requirements Check
**CRITICAL — ALWAYS run the system checker before loading the model for the first time.**
```bash
python scripts/check_system.py
```
This script checks:
1. **Available RAM** — warns if below 4 GB, blocks if below 2 GB
2. **GPU availability** — detects CUDA/MPS devices and VRAM
3. **Disk space** — verifies room for the ~800 MB model download
4. **Python version** — requires 3.10+
5. **Existing installation** — checks if `timesfm` and `torch` are installed
> **Note:** Model weights are **NOT stored in this repository**. TimesFM weights (~800 MB)
> download on-demand from HuggingFace on first use and cache in `~/.cache/huggingface/`.
```mermaid
flowchart TD
start["🚀 Run check_system.py"] --> ram{"RAM ≥ 4 GB?"}
ram -->|"Yes"| gpu{"GPU available?"}
ram -->|"No (2-4 GB)"| warn_ram["⚠️ Warning: tight RAM<br/>CPU-only, small batches"]
ram -->|"No (< 2 GB)"| block["🛑 BLOCKED<br/>Insufficient memory"]
warn_ram --> disk
gpu -->|"CUDA / MPS"| vram{"VRAM ≥ 2 GB?"}
gpu -->|"CPU only"| cpu_ok["✅ CPU mode<br/>Slower but works"]
vram -->|"Yes"| gpu_ok["✅ GPU mode<br/>Fast inference"]
vram -->|"No"| cpu_ok
gpu_ok --> disk{"Disk ≥ 2 GB free?"}
cpu_ok --> disk
disk -->|"Yes"| ready["✅ READY<br/>Safe to load model"]
disk -->|"No"| block_disk["🛑 BLOCKED<br/>Need space for weights"]
```
### Dataset Preflight (NEW)
Before loading your actual data, verify it will fit in memory:
```bash
# Quick estimate for your dataset
python scripts/check_system.py \
--num-series 1000 \
--context-length 1024 \
--horizon 24 \
--batch-size 32 \
--estimate-only
```
This will show you the estimated memory requirements and warn if your dataset is too large.
**Memory Estimation Formula**:
`RAM ≈ 0.8 GB (model) + 0.5 GB (overhead) + (0.2 MB × num_series × context_length / 1000)`
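For a quick back-of-envelope check without running the script, the formula above can be transcribed into a small Python helper. This is a sketch: `estimate_ram_gb` is a hypothetical name (it is not part of `check_system.py`), and it assumes 1 GB = 1000 MB when adding the per-series term, which the formula leaves unstated.

```python
def estimate_ram_gb(num_series: int, context_length: int) -> float:
    """Rough CPU RAM estimate in GB, transcribing the formula above."""
    model_gb = 0.8      # model weights
    overhead_gb = 0.5   # runtime overhead
    # 0.2 MB per series per 1000 context points, converted to GB (assumes 1 GB = 1000 MB)
    data_gb = 0.2 * num_series * context_length / 1000 / 1000
    return model_gb + overhead_gb + data_gb


# 1000 series with 1024 context points each
print(f"{estimate_ram_gb(1000, 1024):.2f} GB")
```

The real checker may account for additional terms (batch size, horizon), so treat this as a lower-bound sanity check rather than a replacement for `--estimate-only`.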
**Example Outputs**:
✅ **Dataset Fits**:
```
Total CPU memory: 2.34 GB
Total GPU memory: 2.15 GB
```
⚠️ **Dataset Too Large**:
```
Dataset requires ~12.5 GB RAM but system has 8.0 GB.
Try: context_length=512 or process in chunks of 50 series.
```
### Hardware Requirements by Model Version
| Model | Parameters | RAM (CPU) | VRAM (GPU) | Disk | Context |
| ----- | ---------- | --------- | ---------- | ---- | ------- |
| **TimesFM 2.5** (recommended) | 200M | ≥ 4 GB | ≥ 2 GB | ~800 MB | up to 16,384 |
| TimesFM 2.0 (archived) | 500M | ≥ 16 GB | ≥ 8 GB | ~2 GB | up to 2,048 |
| TimesFM 1.0 (archived) | 200M | ≥ 8 GB | ≥ 4 GB | ~800 MB | up to 2,048 |
> **Recommendation**: Always use TimesFM 2.5 unless you have a specific reason to use an
> older checkpoint. It is smaller and faster than 2.0 and supports an 8× longer context (16,384 vs. 2,048 points).
## 🔧 Installation
### Step 1: Verify System (always first)
```bash
python scripts/check_system.py
```
### Step 2: Install TimesFM
```bash
# Using uv (fast); quote extras so shells like zsh don't expand the brackets
uv pip install "timesfm[torch]"
# Or using pip
pip install "timesfm[torch]"
# For JAX/Flax backend (faster on TPU/GPU)
uv pip install "timesfm[flax]"
```
### Step 3: Install PyTorch for Your Hardware
```bash
# Quote the version specifier: unquoted, the shell treats ">" as an output redirect
# CUDA 12.1 (NVIDIA GPU)
pip install "torch>=2.0.0" --index-url https://download.pytorch.org/whl/cu121
# CPU only
pip install "torch>=2.0.0" --index-url https://download.pytorch.org/whl/cpu
# Apple Silicon (MPS)
pip install "torch>=2.0.0"  # MPS support is built-in
```
## 🎯 Quick Start
### Minimal Example
```python
import torch, numpy as np, timesfm
torch.set_float32_matmul_precision("high")
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
"google/timesfm-2.5-200m-pytorch"
)
model.compile(timesfm.ForecastConfig(
max_context=1024, max_horizon=256, normalize_inputs=True,
use_continuous_quantile_head=True, force_flip_invariance=True,
infer_is_positive=True, fix_quantile_crossing=True,
))
point, quantiles = model.forecast(horizon=24, inputs=[
np.sin(np.linspace(0, 20, 200)), # any 1-D array
])
# point.shape == (1, 24): median forecast
# quantiles.shape == (1, 24, 10): mean (index 0) plus the 10th–90th percentiles (indices 1–9)
```
### Forecast with Covariates (XReg)
TimesFM 2.5+ supports exogenous variables through `forecast_with_covariates()`.
Requires `pip install timesfm[xreg]`.
```python
point, quantiles = model.forecast_with_covariates(
inputs=inputs,
dynamic_numerical_covariates={"price": price_arrays},
dynamic_categorical_covariates={"holiday": holiday_arrays},
static_categorical_covariates={"region": region_labels},
xreg_mode="xreg + timesfm", # or "timesfm + xreg"
)
```
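Dynamic covariates need values for both the context and the forecast window, so a length check before calling `forecast_with_covariates()` catches the most common shape error early. A minimal sketch, assuming each dynamic covariate is passed as one array per input series whose length should equal `len(context) + horizon`; `check_dynamic_covariates` is a hypothetical helper, not part of the TimesFM API.

```python
import numpy as np


def check_dynamic_covariates(inputs, covariates, horizon):
    """Raise ValueError if any dynamic covariate does not span context + horizon.

    inputs:     list of 1-D context arrays, one per series
    covariates: dict mapping covariate name -> list of per-series arrays
    """
    for name, per_series in covariates.items():
        if len(per_series) != len(inputs):
            raise ValueError(f"{name}: expected {len(inputs)} series, got {len(per_series)}")
        for i, (context, cov) in enumerate(zip(inputs, per_series)):
            expected = len(context) + horizon
            if len(cov) != expected:
                raise ValueError(
                    f"{name}[{i}]: length {len(cov)}, expected {expected} "
                    f"({len(context)} context + {horizon} horizon)"
                )


# Two series with different context lengths and a 12-step horizon
inputs = [np.zeros(36), np.zeros(40)]
price = [np.ones(36 + 12), np.ones(40 + 12)]
check_dynamic_covariates(inputs, {"price": price}, horizon=12)  # passes silently
```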
### Anomaly Detection (via Quantile Intervals)
```python
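# H, values (history) and test_values (the next H observations) are placeholders
point, q = model.forecast(horizon=H, inputs=[values])
lower_80 = q[0, :, 1]  # 10th percentile (index 0 is the mean)
upper_80 = q[0, :, 9]  # 90th percentile
actual = test_values
anomalies = (actual < lower_80) | (actual > upper_80)  # outside the 80% PI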
```
| Severity | Condition | Interpretation |
| -------- | --------- | -------------- |
| **Normal** | Inside 60% PI (q20–q80) | Expected behavior |
| **Warning** | Outside 60% PI, inside 80% PI | Unusual but possible |
| **Critical** | Outside 80% PI (q10–q90) | Rare: ~10% probability per tail if the forecast is calibrated |
> See `examples/anomaly-detection/` for a complete worked example with visualization.
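The banding used in `examples/anomaly-detection/detect_anomalies.py` (CRITICAL outside the 80% PI q10–q90, WARNING outside the 60% PI q20–q80 but inside the 80% PI, NORMAL otherwise) can be vectorized over a whole horizon with NumPy. A minimal sketch, assuming `q` is one series' quantile slice of shape `(horizon, 10)` with index 0 = mean; `classify_severity` is a hypothetical helper name and the toy quantiles are made up for illustration.

```python
import numpy as np

IDX_Q10, IDX_Q20, IDX_Q80, IDX_Q90 = 1, 2, 8, 9  # index 0 is the mean, not q0


def classify_severity(actual: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Label each step NORMAL / WARNING / CRITICAL from one series' quantiles.

    actual: observed values, shape (horizon,)
    q:      quantile forecast for the same series, shape (horizon, 10)
    """
    outside_80 = (actual < q[:, IDX_Q10]) | (actual > q[:, IDX_Q90])
    outside_60 = (actual < q[:, IDX_Q20]) | (actual > q[:, IDX_Q80])
    return np.where(outside_80, "CRITICAL", np.where(outside_60, "WARNING", "NORMAL"))


# Toy quantiles: mean 0, then q10..q90 = -4..4, repeated for a 3-step horizon
q = np.tile(np.array([0.0, -4, -3, -2, -1, 0, 1, 2, 3, 4]), (3, 1))
print(classify_severity(np.array([0.5, 3.5, -5.0]), q))
```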
## 📊 Understanding the Output
TimesFM returns `(point_forecast, quantile_forecast)`:
- **`point_forecast`**: shape `(batch, horizon)` — the median (0.5 quantile)
- **`quantile_forecast`**: shape `(batch, horizon, 10)` — ten quantile slices:
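| Index | Quantile | Use |
| ----- | -------- | --- |
| 0 | Mean | Average prediction |
| 1 | 0.1 | Lower bound of 80% PI |
| 2 | 0.2 | Lower bound of 60% PI |
| 3 | 0.3 | Lower bound of 40% PI |
| 4 | 0.4 | Lower bound of 20% PI |
| **5** | **0.5** | **Median (= `point_forecast`)** |
| 6 | 0.6 | Upper bound of 20% PI |
| 7 | 0.7 | Upper bound of 40% PI |
| 8 | 0.8 | Upper bound of 60% PI |
| 9 | 0.9 | Upper bound of 80% PI |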
```python
point, q = model.forecast(horizon=H, inputs=data)
lower_80 = q[:, :, 1] # 10th percentile
upper_80 = q[:, :, 9] # 90th percentile
median = q[:, :, 5]
```
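Because index 0 holds the mean rather than q0, the quantile for probability p sits at index 10·p. A central interval with nominal coverage c therefore spans indices 10·(1 − c)/2 and 10·(1 + c)/2; for c = 0.8 that is 1 and 9. A sketch of this mapping, where `interval_bounds` is a hypothetical helper (not part of the TimesFM API) that only accepts coverages landing exactly on the decile grid.

```python
import numpy as np


def interval_bounds(quantiles: np.ndarray, coverage: float):
    """Return (lower, upper) bounds, each (batch, horizon), of a central PI.

    quantiles has shape (batch, horizon, 10): index 0 = mean, index k = q(k/10).
    Only coverages on the decile grid (0.2, 0.4, 0.6, 0.8) are supported.
    """
    lo_idx = round(10 * (1 - coverage) / 2)
    hi_idx = round(10 * (1 + coverage) / 2)
    on_grid = abs(lo_idx / 10 - (1 - coverage) / 2) < 1e-9
    if not (on_grid and 1 <= lo_idx < hi_idx <= 9):
        raise ValueError(f"coverage={coverage} is not available from the deciles")
    return quantiles[:, :, lo_idx], quantiles[:, :, hi_idx]


# Dummy array with the same layout as quantile_forecast: (batch=2, horizon=3, 10)
q = np.arange(2 * 3 * 10, dtype=float).reshape(2, 3, 10)
lower_80, upper_80 = interval_bounds(q, 0.8)  # same as q[:, :, 1] and q[:, :, 9]
lower_60, upper_60 = interval_bounds(q, 0.6)  # same as q[:, :, 2] and q[:, :, 8]
```

A 90% interval is rejected on purpose: it would need the 5th and 95th percentiles, which the ten-slice output does not contain.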
## 🔧 ForecastConfig Reference
All forecasting behavior is controlled by `timesfm.ForecastConfig`:
```python
timesfm.ForecastConfig(
max_context=1024, # Max context window
max_horizon=256, # Max forecast horizon
normalize_inputs=True, # RECOMMENDED — prevents scale instability
per_core_batch_size=32, # Tune for memory
use_continuous_quantile_head=True, # Better quantile accuracy for long horizons
force_flip_invariance=True, # Ensures f(-x) = -f(x)
infer_is_positive=True, # Clamp forecasts ≥ 0 when all inputs > 0
fix_quantile_crossing=True, # Ensure q10 ≤ q20 ≤ ... ≤ q90
return_backcast=False, # Return backcast (for covariate workflows)
)
```
| Parameter | Default | When to Change |
| --------- | ------- | -------------- |
| `max_context` | 0 | Set to match your longest historical window |
| `normalize_inputs` | False | **Always set True** |
| `use_continuous_quantile_head` | False | **Set True** for calibrated PIs |
| `infer_is_positive` | True | Set False for series that can be negative |
| `fix_quantile_crossing` | False | **Set True** for monotonic quantiles |
See `references/api_reference.md` for the complete parameter reference.
## 📋 Common Workflows
### Single Series Forecast
```python
import torch, numpy as np, pandas as pd, timesfm, matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
torch.set_float32_matmul_precision("high")
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
"google/timesfm-2.5-200m-pytorch"
)
model.compile(timesfm.ForecastConfig(
max_context=512, max_horizon=52, normalize_inputs=True,
use_continuous_quantile_head=True, fix_quantile_crossing=True,
))
df = pd.read_csv("weekly_demand.csv", parse_dates=["week"])
values = df["demand"].values.astype(np.float32)
point, quantiles = model.forecast(horizon=52, inputs=[values])
fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(values[-104:], label="Historical")
x_fc = range(len(values[-104:]), len(values[-104:]) + 52)
ax.plot(x_fc, point[0], label="Forecast", color="tab:orange")
ax.fill_between(x_fc, quantiles[0, :, 1], quantiles[0, :, 9],
alpha=0.2, color="tab:orange", label="80% PI")
ax.legend(); ax.set_title("52-Week Demand Forecast")
plt.tight_layout(); plt.savefig("forecast.png", dpi=150)
```
### Batch Forecasting (Many Series)
```python
import json

# Reuses `model`, `np` and `pd` from the single-series example above
df = pd.read_csv("all_stores.csv", parse_dates=["date"], index_col="date")
inputs = [df[col].dropna().values.astype(np.float32) for col in df.columns]
point, quantiles = model.forecast(horizon=30, inputs=inputs)
results = {
    col: {
        "forecast": point[i].tolist(),
        "lower_80": quantiles[i, :, 1].tolist(),
        "upper_80": quantiles[i, :, 9].tolist(),
    }
    for i, col in enumerate(df.columns)
}
with open("batch_forecasts.json", "w") as f:
    json.dump(results, f, indent=2)
```
### Evaluate Forecast Accuracy
```python
H = 24
train, actual = values[:-H], values[-H:]
point, quantiles = model.forecast(horizon=H, inputs=[train])
pred = point[0]
mae = np.mean(np.abs(actual - pred))
rmse = np.sqrt(np.mean((actual - pred) ** 2))
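# MAPE is undefined where actual == 0, so score only the nonzero points
nonzero = actual != 0
mape = np.mean(np.abs((actual[nonzero] - pred[nonzero]) / actual[nonzero])) * 100
# Share of actuals inside the q10–q90 band (nominal 80% PI)
coverage = np.mean((actual >= quantiles[0, :, 1]) & (actual <= quantiles[0, :, 9])) * 100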
print(f"MAE: {mae:.2f} | RMSE: {rmse:.2f} | MAPE: {mape:.1f}% | 80% PI Coverage: {coverage:.1f}%")
```
## ⚙️ Performance Tuning
```python
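# Always set on Ampere+ GPUs (A100, RTX 3090+)
torch.set_float32_matmul_precision("high")
# per_core_batch_size (set in ForecastConfig) guidelines:
#   GPU  8 GB VRAM: per_core_batch_size=64
#   GPU 16 GB VRAM: per_core_batch_size=128
#   CPU  8 GB RAM:  per_core_batch_size=8
#   CPU 16 GB RAM:  per_core_batch_size=32
# Memory-constrained: forecast in chunks, then stitch the chunks back together
CHUNK = 50
point_parts, quantile_parts = [], []
for i in range(0, len(inputs), CHUNK):
    p, q = model.forecast(horizon=H, inputs=inputs[i:i + CHUNK])
    point_parts.append(p)
    quantile_parts.append(q)
point = np.concatenate(point_parts, axis=0)         # (n_series, H)
quantiles = np.concatenate(quantile_parts, axis=0)  # (n_series, H, 10)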
```
## 📚 Available Scripts
### `scripts/check_system.py`
Mandatory preflight checker — run before first model load.
Now includes **dataset-aware memory estimation** to prevent OOM errors before loading your data.
```bash
# Basic system check
python scripts/check_system.py
# Check if your specific dataset will fit
python scripts/check_system.py \
--num-series 1000 \
--context-length 1024 \
--horizon 24 \
--batch-size 32
# Quick memory estimate without system checks
python scripts/check_system.py \
--num-series 5000 \
--context-length 2048 \
--estimate-only
```
**What it checks**:
1. **Available RAM** — warns if below 4 GB, blocks if below 2 GB
2. **GPU availability** — detects CUDA/MPS devices and VRAM
3. **Disk space** — verifies room for the ~800 MB model download
4. **Python version** — requires 3.10+
5. **Existing installation** — checks if `timesfm` and `torch` are installed
6. **Dataset fit** (NEW) — estimates memory for your specific dataset and warns if it won't fit
### `scripts/forecast_csv.py`
End-to-end CSV forecasting CLI.
```bash
python scripts/forecast_csv.py input.csv \
--horizon 24 \
--date-col date \
--value-cols sales,revenue \
--output forecasts.csv
```
## 📖 Reference Documentation
| File | Contents |
| ---- | -------- |
| `references/system_requirements.md` | Hardware tiers, GPU/CPU selection, memory estimation |
| `references/api_reference.md` | Full `ForecastConfig` docs, output shapes, model options |
| `references/data_preparation.md` | Input formats, NaN handling, CSV loading, covariate setup |
## 🧪 Examples
| Example | Directory | What It Demonstrates |
| ------- | --------- | -------------------- |
| **Global Temperature Forecast** | `examples/global-temperature/` | Basic `model.forecast()`, CSV → PNG → GIF pipeline |
| **Anomaly Detection** | `examples/anomaly-detection/` | Two-phase detrend + Z-score + quantile PI, 2-panel viz |
| **Covariates (XReg)** | `examples/covariates-forecasting/` | `forecast_with_covariates()`, 2×2 shared-axis viz |
```bash
# Run all three examples:
cd examples/global-temperature && python run_forecast.py && python visualize_forecast.py
cd examples/anomaly-detection && python detect_anomalies.py
cd examples/covariates-forecasting && python demo_covariates.py
```
### Expected Outputs
| Example | Key output files | Acceptance criteria |
| ------- | ---------------- | ------------------- |
| global-temperature | `output/forecast_output.json`, `output/forecast_visualization.png` | `point_forecast` has 12 values; PNG shows context + forecast + PI bands |
| anomaly-detection | `output/anomaly_detection.json`, `output/anomaly_detection.png` | Sep 2023 flagged CRITICAL (z ≥ 3.0) |
| covariates-forecasting | `output/sales_with_covariates.csv`, `output/covariates_data.png` | 108 rows (3 stores × 36 weeks); distinct price arrays per store |
## Model Versions
| Version | Params | Context | Status | HuggingFace checkpoint |
| ------- | ------ | ------- | ------ | ---------------------- |
| **2.5** | 200M | 16,384 | **Latest** | `google/timesfm-2.5-200m-pytorch` |
| 2.0 | 500M | 2,048 | Archived | `google/timesfm-2.0-500m-pytorch` |
| 1.0 | 200M | 2,048 | Archived | `google/timesfm-1.0-200m-pytorch` |
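- TimesFM 1.0/2.0: pass a `freq` category per series (official guidance: 0 for data up to daily, 1 for weekly and monthly, 2 for quarterly and coarser)
- TimesFM 2.5: no frequency flag; the argument was removed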
## Resources
- **Paper**: [A Decoder-Only Foundation Model for Time-Series Forecasting](https://arxiv.org/abs/2310.10688) (ICML 2024)
- **HuggingFace**: https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6
- **Google Blog**: https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/
- **BigQuery Integration**: https://cloud.google.com/bigquery/docs/timesfm-model
## Quality Checklist
Run after every TimesFM task before declaring success:
- [ ] **Output shape** — `point_fc` is `(n_series, horizon)`, `quant_fc` is `(n_series, horizon, 10)`
- [ ] **Quantile indices** — index 0 = mean, 1 = q10 ... 9 = q90. NOT 0 = q0.
- [ ] **Frequency flag**: TimesFM 1.0/2.0 take a `freq` category (0 = up to daily, 1 = weekly/monthly, 2 = quarterly and coarser). TimesFM 2.5: omit.
- [ ] **Series length** — context must be ≥ 32 data points.
- [ ] **No NaN** — `np.isnan(point_fc).any()` must be False.
- [ ] **Axes** — multiple panels sharing data must use `sharex=True`.
- [ ] **`matplotlib.use('Agg')`** — before any pyplot import when running headless.
- [ ] **`infer_is_positive`** — set False for temperature, financial returns, negatives.
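The shape, NaN, and quantile-ordering items in this checklist can be automated. A minimal sketch with a hypothetical `check_forecast_outputs` helper; the ordering check assumes the model was compiled with `fix_quantile_crossing=True` and skips index 0 (the mean). The random, pre-sorted arrays only stand in for a real forecast.

```python
import numpy as np


def check_forecast_outputs(point_fc, quant_fc, n_series, horizon):
    """Raise ValueError if shapes, NaNs, or quantile ordering look wrong."""
    if point_fc.shape != (n_series, horizon):
        raise ValueError(f"point_fc shape {point_fc.shape}, expected {(n_series, horizon)}")
    if quant_fc.shape != (n_series, horizon, 10):
        raise ValueError(f"quant_fc shape {quant_fc.shape}, expected {(n_series, horizon, 10)}")
    if np.isnan(point_fc).any() or np.isnan(quant_fc).any():
        raise ValueError("NaN in forecast output")
    # q10 <= q20 <= ... <= q90 along the last axis; index 0 (the mean) is skipped
    if np.any(np.diff(quant_fc[:, :, 1:], axis=-1) < 0):
        raise ValueError("quantile crossing detected")


# Synthetic, already-sorted quantiles stand in for a real model.forecast() result
rng = np.random.default_rng(0)
quant = np.sort(rng.normal(size=(2, 5, 10)), axis=-1)
point = quant[:, :, 5]
check_forecast_outputs(point, quant, n_series=2, horizon=5)  # passes silently
```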
## Common Mistakes
1. **Quantile index off-by-one** — `quant_fc[..., 0]` is the **mean**, not q0. q10 = index 1, q90 = index 9. Define: `IDX_Q10, IDX_Q90 = 1, 9`.
2. **Variable shadowing in covariate loops** — don't use the outer loop variable as a comprehension variable when building per-series covariate dicts.
3. **Wrong CSV column name** — global-temperature CSV uses `anomaly_c`, not `anomaly`. Print `df.columns` first.
4. **TimesFM 2.5 required for `forecast_with_covariates()`** — TimesFM 1.0 does NOT have this method.
5. **Future covariates must span the full horizon** — dynamic covariates need values for BOTH context AND forecast windows.
6. **Context anomaly detection uses residuals** — detrend first, then Z-score. Raw Z-scores mislead on trending data.
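A minimal sketch of the context-phase method from `detect_anomalies.py`: a least-squares linear detrend with `np.polyfit`, then Z-scores of the residuals against their population standard deviation. The synthetic series, the injected spike, and the 3.0 threshold are illustrative assumptions.

```python
import numpy as np


def detrended_zscores(values: np.ndarray) -> np.ndarray:
    """Z-scores of residuals after removing a least-squares linear trend."""
    t = np.arange(len(values), dtype=float)
    slope, intercept = np.polyfit(t, values, 1)
    residuals = values - (slope * t + intercept)
    std = residuals.std()  # population std (ddof=0), as in detect_anomalies.py
    return residuals / std if std > 0 else np.zeros_like(residuals)


# Synthetic trending series (36 months) with one injected spike at index 20
values = 0.01 * np.arange(36)
values[20] += 1.0
z = detrended_zscores(values)
print(np.flatnonzero(np.abs(z) >= 3.0))  # indices flagged CRITICAL
```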
## Validation & Verification
```bash
# Anomaly detection regression:
python -c "
import json
d = json.load(open('examples/anomaly-detection/output/anomaly_detection.json'))
assert d['context_summary']['critical'] >= 1, 'Sep 2023 must be CRITICAL'
print('Anomaly detection: PASS')"
# Covariates regression:
python -c "
import pandas as pd
df = pd.read_csv('examples/covariates-forecasting/output/sales_with_covariates.csv')
assert len(df) == 108, f'Expected 108 rows, got {len(df)}'
print('Covariates: PASS')"
```
================================================
FILE: timesfm-forecasting/examples/anomaly-detection/detect_anomalies.py
================================================
#!/usr/bin/env python3
"""
TimesFM Anomaly Detection Example — Two-Phase Method
Phase 1 (context): Linear detrend + Z-score on 36 months of real NOAA
temperature anomaly data (2022-01 through 2024-12).
Sep 2023 (1.47 C) is a known critical outlier.
Phase 2 (forecast): TimesFM quantile prediction intervals on a 12-month
synthetic future with 3 injected anomalies.
Outputs:
output/anomaly_detection.png -- 2-panel visualization
output/anomaly_detection.json -- structured detection records
"""
from __future__ import annotations
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
HORIZON = 12
DATA_FILE = (
Path(__file__).parent.parent / "global-temperature" / "temperature_anomaly.csv"
)
OUTPUT_DIR = Path(__file__).parent / "output"
CRITICAL_Z = 3.0
WARNING_Z = 2.0
# quant_fc index mapping: 0=mean, 1=q10, 2=q20, ..., 9=q90
IDX_Q10, IDX_Q20, IDX_Q80, IDX_Q90 = 1, 2, 8, 9
CLR = {"CRITICAL": "#e02020", "WARNING": "#f08030", "NORMAL": "#4a90d9"}
# ---------------------------------------------------------------------------
# Phase 1: context anomaly detection
# ---------------------------------------------------------------------------
def detect_context_anomalies(
values: np.ndarray,
dates: list,
) -> tuple[list[dict], np.ndarray, np.ndarray, float]:
"""Linear detrend + Z-score anomaly detection on context period.
Returns
-------
records : list of dicts, one per month
trend_line : fitted linear trend values (same length as values)
residuals : actual - trend_line
res_std : std of residuals (used as sigma for threshold bands)
"""
n = len(values)
idx = np.arange(n, dtype=float)
coeffs = np.polyfit(idx, values, 1)
trend_line = np.polyval(coeffs, idx)
residuals = values - trend_line
res_std = residuals.std()
records = []
for i, (d, v, r) in enumerate(zip(dates, values, residuals)):
z = r / res_std if res_std > 0 else 0.0
if abs(z) >= CRITICAL_Z:
severity = "CRITICAL"
elif abs(z) >= WARNING_Z:
severity = "WARNING"
else:
severity = "NORMAL"
records.append(
{
"date": str(d)[:7],
"value": round(float(v), 4),
"trend": round(float(trend_line[i]), 4),
"residual": round(float(r), 4),
"z_score": round(float(z), 3),
"severity": severity,
}
)
return records, trend_line, residuals, res_std
# ---------------------------------------------------------------------------
# Phase 2: synthetic future + forecast anomaly detection
# ---------------------------------------------------------------------------
def build_synthetic_future(
context: np.ndarray,
n: int,
seed: int = 42,
) -> tuple[np.ndarray, list[int]]:
"""Build a plausible future with 3 injected anomalies.
Injected months: 3, 8, 11 (0-indexed within the 12-month horizon).
Returns (future_values, injected_indices).
"""
rng = np.random.default_rng(seed)
trend = np.linspace(context[-6:].mean(), context[-6:].mean() + 0.05, n)
noise = rng.normal(0, 0.1, n)
future = trend + noise
injected = [3, 8, 11]
future[3] += 0.7 # CRITICAL spike
future[8] -= 0.65 # CRITICAL dip
future[11] += 0.45 # WARNING spike
return future.astype(np.float32), injected
def detect_forecast_anomalies(
future_values: np.ndarray,
point: np.ndarray,
quant_fc: np.ndarray,
future_dates: list,
injected_at: list[int],
) -> list[dict]:
"""Classify each forecast month by which PI band it falls outside.
CRITICAL = outside 80% PI (q10-q90)
WARNING = outside 60% PI (q20-q80) but inside 80% PI
NORMAL = inside 60% PI
"""
q10 = quant_fc[IDX_Q10]
q20 = quant_fc[IDX_Q20]
q80 = quant_fc[IDX_Q80]
q90 = quant_fc[IDX_Q90]
records = []
for i, (d, fv, pt) in enumerate(zip(future_dates, future_values, point)):
outside_80 = fv < q10[i] or fv > q90[i]
outside_60 = fv < q20[i] or fv > q80[i]
if outside_80:
severity = "CRITICAL"
elif outside_60:
severity = "WARNING"
else:
severity = "NORMAL"
records.append(
{
"date": str(d)[:7],
"actual": round(float(fv), 4),
"forecast": round(float(pt), 4),
"q10": round(float(q10[i]), 4),
"q20": round(float(q20[i]), 4),
"q80": round(float(q80[i]), 4),
"q90": round(float(q90[i]), 4),
"severity": severity,
"was_injected": i in injected_at,
}
)
return records
# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------
def plot_results(
context_dates: list,
context_values: np.ndarray,
ctx_records: list[dict],
trend_line: np.ndarray,
residuals: np.ndarray,
res_std: float,
future_dates: list,
future_values: np.ndarray,
point_fc: np.ndarray,
quant_fc: np.ndarray,
fc_records: list[dict],
) -> None:
OUTPUT_DIR.mkdir(exist_ok=True)
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10), gridspec_kw={"hspace": 0.42})
fig.suptitle(
"TimesFM Anomaly Detection — Two-Phase Method", fontsize=14, fontweight="bold"
)
# -----------------------------------------------------------------------
# Panel 1 — full timeline
# -----------------------------------------------------------------------
ctx_x = [pd.Timestamp(d) for d in context_dates]
fut_x = [pd.Timestamp(d) for d in future_dates]
divider = ctx_x[-1]
# context: blue line + trend + 2sigma band
ax1.plot(
ctx_x,
context_values,
color=CLR["NORMAL"],
lw=2,
marker="o",
ms=4,
label="Observed (context)",
)
ax1.plot(ctx_x, trend_line, color="#aaaaaa", lw=1.5, ls="--", label="Linear trend")
ax1.fill_between(
ctx_x,
trend_line - 2 * res_std,
trend_line + 2 * res_std,
alpha=0.15,
color=CLR["NORMAL"],
label="+/-2sigma band",
)
# context anomaly markers
seen_ctx: set[str] = set()
for rec in ctx_records:
if rec["severity"] == "NORMAL":
continue
d = pd.Timestamp(rec["date"])
v = rec["value"]
sev = rec["severity"]
lbl = f"Context {sev}" if sev not in seen_ctx else None
seen_ctx.add(sev)
ax1.scatter(d, v, marker="D", s=90, color=CLR[sev], zorder=6, label=lbl)
ax1.annotate(
f"z={rec['z_score']:+.1f}",
(d, v),
textcoords="offset points",
xytext=(0, 9),
fontsize=7.5,
ha="center",
color=CLR[sev],
)
# forecast section
q10 = quant_fc[IDX_Q10]
q20 = quant_fc[IDX_Q20]
q80 = quant_fc[IDX_Q80]
q90 = quant_fc[IDX_Q90]
ax1.plot(fut_x, future_values, "k--", lw=1.5, label="Synthetic future (truth)")
ax1.plot(
fut_x,
point_fc,
color=CLR["CRITICAL"],
lw=2,
marker="s",
ms=4,
label="TimesFM point forecast",
)
ax1.fill_between(fut_x, q10, q90, alpha=0.15, color=CLR["CRITICAL"], label="80% PI")
ax1.fill_between(fut_x, q20, q80, alpha=0.25, color=CLR["CRITICAL"], label="60% PI")
seen_fc: set[str] = set()
for i, rec in enumerate(fc_records):
if rec["severity"] == "NORMAL":
continue
d = pd.Timestamp(rec["date"])
v = rec["actual"]
sev = rec["severity"]
mk = "X" if sev == "CRITICAL" else "^"
lbl = f"Forecast {sev}" if sev not in seen_fc else None
seen_fc.add(sev)
ax1.scatter(d, v, marker=mk, s=100, color=CLR[sev], zorder=6, label=lbl)
ax1.axvline(divider, color="#555555", lw=1.5, ls=":")
ax1.text(
divider,
ax1.get_ylim()[1] if ax1.get_ylim()[1] != 0 else 1.5,
" <- Context | Forecast ->",
fontsize=8.5,
color="#555555",
style="italic",
va="top",
)
ax1.annotate(
"Context: D = Z-score anomaly | Forecast: X = CRITICAL, ^ = WARNING",
xy=(0.01, 0.04),
xycoords="axes fraction",
fontsize=8,
bbox=dict(boxstyle="round", fc="white", ec="#cccccc", alpha=0.9),
)
ax1.set_ylabel("Temperature Anomaly (C)", fontsize=10)
ax1.legend(ncol=2, fontsize=7.5, loc="upper left")
ax1.grid(True, alpha=0.22)
# -----------------------------------------------------------------------
# Panel 2 — deviation bars across all 48 months
# -----------------------------------------------------------------------
all_labels: list[str] = []
bar_colors: list[str] = []
bar_heights: list[float] = []
for rec in ctx_records:
all_labels.append(rec["date"])
bar_heights.append(rec["residual"])
bar_colors.append(CLR[rec["severity"]])
fc_deviations: list[float] = []
for rec in fc_records:
all_labels.append(rec["date"])
dev = rec["actual"] - rec["forecast"]
fc_deviations.append(dev)
bar_heights.append(dev)
bar_colors.append(CLR[rec["severity"]])
xs = np.arange(len(all_labels))
ax2.bar(xs[:36], bar_heights[:36], color=bar_colors[:36], alpha=0.8)
ax2.bar(xs[36:], bar_heights[36:], color=bar_colors[36:], alpha=0.8)
# threshold lines for context section only
ax2.hlines(
[2 * res_std, -2 * res_std], -0.5, 35.5, colors=CLR["NORMAL"], lw=1.2, ls="--"
)
ax2.hlines(
[3 * res_std, -3 * res_std], -0.5, 35.5, colors=CLR["NORMAL"], lw=1.0, ls=":"
)
# PI bands for forecast section
fc_xs = xs[36:]
ax2.fill_between(
fc_xs,
q10 - point_fc,
q90 - point_fc,
alpha=0.12,
color=CLR["CRITICAL"],
step="mid",
)
ax2.fill_between(
fc_xs,
q20 - point_fc,
q80 - point_fc,
alpha=0.20,
color=CLR["CRITICAL"],
step="mid",
)
ax2.axvline(35.5, color="#555555", lw=1.5, ls="--")
ax2.axhline(0, color="black", lw=0.8, alpha=0.6)
ax2.text(
10,
ax2.get_ylim()[0] * 0.85 if ax2.get_ylim()[0] < 0 else -0.05,
"<- Context: delta from linear trend",
fontsize=8,
style="italic",
color="#555555",
ha="center",
)
ax2.text(
41,
ax2.get_ylim()[0] * 0.85 if ax2.get_ylim()[0] < 0 else -0.05,
"Forecast: delta from TimesFM ->",
fontsize=8,
style="italic",
color="#555555",
ha="center",
)
tick_every = 3
ax2.set_xticks(xs[::tick_every])
ax2.set_xticklabels(all_labels[::tick_every], rotation=45, ha="right", fontsize=7)
ax2.set_ylabel("Delta from expected (C)", fontsize=10)
ax2.grid(True, alpha=0.22, axis="y")
legend_patches = [
mpatches.Patch(color=CLR["CRITICAL"], label="CRITICAL"),
mpatches.Patch(color=CLR["WARNING"], label="WARNING"),
mpatches.Patch(color=CLR["NORMAL"], label="Normal"),
]
ax2.legend(handles=legend_patches, fontsize=8, loc="upper right")
output_path = OUTPUT_DIR / "anomaly_detection.png"
plt.savefig(output_path, dpi=150, bbox_inches="tight")
plt.close()
print(f"\n Saved: {output_path}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
print("=" * 68)
print(" TIMESFM ANOMALY DETECTION — TWO-PHASE METHOD")
print("=" * 68)
# --- Load context data ---------------------------------------------------
df = pd.read_csv(DATA_FILE)
df["date"] = pd.to_datetime(df["date"])
df = df.sort_values("date").reset_index(drop=True)
context_values = df["anomaly_c"].values.astype(np.float32)
context_dates = [pd.Timestamp(d) for d in df["date"].tolist()]
start_str = context_dates[0].strftime('%Y-%m') if not pd.isnull(context_dates[0]) else '?'
end_str = context_dates[-1].strftime('%Y-%m') if not pd.isnull(context_dates[-1]) else '?'
print(f"\n Context: {len(context_values)} months ({start_str} - {end_str})")
# --- Phase 1: context anomaly detection ----------------------------------
ctx_records, trend_line, residuals, res_std = detect_context_anomalies(
context_values, context_dates
)
ctx_critical = [r for r in ctx_records if r["severity"] == "CRITICAL"]
ctx_warning = [r for r in ctx_records if r["severity"] == "WARNING"]
print(f"\n [Phase 1] Context anomalies (Z-score, sigma={res_std:.3f} C):")
print(f" CRITICAL (|Z|>={CRITICAL_Z}): {len(ctx_critical)}")
for r in ctx_critical:
print(f" {r['date']} {r['value']:+.3f} C z={r['z_score']:+.2f}")
print(f" WARNING (|Z|>={WARNING_Z}): {len(ctx_warning)}")
for r in ctx_warning:
print(f" {r['date']} {r['value']:+.3f} C z={r['z_score']:+.2f}")
# --- Load TimesFM --------------------------------------------------------
print("\n Loading TimesFM 1.0 ...")
import timesfm
hparams = timesfm.TimesFmHparams(horizon_len=HORIZON)
checkpoint = timesfm.TimesFmCheckpoint(
huggingface_repo_id="google/timesfm-1.0-200m-pytorch"
)
model = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)
point_out, quant_out = model.forecast([context_values], freq=[0])
point_fc = point_out[0] # shape (HORIZON,)
quant_fc = quant_out[0].T # shape (10, HORIZON)
# --- Build synthetic future + Phase 2 detection --------------------------
future_values, injected = build_synthetic_future(context_values, HORIZON)
last_date = context_dates[-1]
future_dates = [last_date + pd.DateOffset(months=i + 1) for i in range(HORIZON)]
fc_records = detect_forecast_anomalies(
future_values, point_fc, quant_fc, future_dates, injected
)
fc_critical = [r for r in fc_records if r["severity"] == "CRITICAL"]
fc_warning = [r for r in fc_records if r["severity"] == "WARNING"]
print(f"\n [Phase 2] Forecast anomalies (quantile PI, horizon={HORIZON} months):")
print(f" CRITICAL (outside 80% PI): {len(fc_critical)}")
for r in fc_critical:
print(
f" {r['date']} actual={r['actual']:+.3f} "
f"fc={r['forecast']:+.3f} injected={r['was_injected']}"
)
print(f" WARNING (outside 60% PI): {len(fc_warning)}")
for r in fc_warning:
print(
f" {r['date']} actual={r['actual']:+.3f} "
f"fc={r['forecast']:+.3f} injected={r['was_injected']}"
)
# --- Plot ----------------------------------------------------------------
print("\n Generating 2-panel visualization...")
plot_results(
context_dates,
context_values,
ctx_records,
trend_line,
residuals,
res_std,
future_dates,
future_values,
point_fc,
quant_fc,
fc_records,
)
# --- Save JSON -----------------------------------------------------------
OUTPUT_DIR.mkdir(exist_ok=True)
out = {
"method": "two_phase",
"context_method": "linear_detrend_zscore",
"forecast_method": "quantile_prediction_intervals",
"thresholds": {
"critical_z": CRITICAL_Z,
"warning_z": WARNING_Z,
"pi_critical_pct": 80,
"pi_warning_pct": 60,
},
"context_summary": {
"total": len(ctx_records),
"critical": len(ctx_critical),
"warning": len(ctx_warning),
"normal": len([r for r in ctx_records if r["severity"] == "NORMAL"]),
"res_std": round(float(res_std), 5),
},
"forecast_summary": {
"total": len(fc_records),
"critical": len(fc_critical),
"warning": len(fc_warning),
"normal": len([r for r in fc_records if r["severity"] == "NORMAL"]),
},
"context_detections": ctx_records,
"forecast_detections": fc_records,
}
json_path = OUTPUT_DIR / "anomaly_detection.json"
with open(json_path, "w") as f:
json.dump(out, f, indent=2)
print(f" Saved: {json_path}")
print("\n" + "=" * 68)
print(" SUMMARY")
print("=" * 68)
print(
f" Context ({len(ctx_records)} months): "
f"{len(ctx_critical)} CRITICAL, {len(ctx_warning)} WARNING"
)
print(
f" Forecast ({len(fc_records)} months): "
f"{len(fc_critical)} CRITICAL, {len(fc_warning)} WARNING"
)
print("=" * 68)
if __name__ == "__main__":
main()
================================================
FILE: timesfm-forecasting/examples/anomaly-detection/output/anomaly_detection.json
================================================
{
"method": "two_phase",
"context_method": "linear_detrend_zscore",
"forecast_method": "quantile_prediction_intervals",
"thresholds": {
"critical_z": 3.0,
"warning_z": 2.0,
"pi_critical_pct": 80,
"pi_warning_pct": 60
},
"context_summary": {
"total": 36,
"critical": 1,
"warning": 0,
"normal": 35,
"res_std": 0.11362
},
"forecast_summary": {
"total": 12,
"critical": 4,
"warning": 1,
"normal": 7
},
"context_detections": [
{
"date": "2022-01",
"value": 0.89,
"trend": 0.837,
"residual": 0.053,
"z_score": 0.467,
"severity": "NORMAL"
},
{
"date": "2022-02",
"value": 0.89,
"trend": 0.8514,
"residual": 0.0386,
"z_score": 0.34,
"severity": "NORMAL"
},
{
"date": "2022-03",
"value": 1.02,
"trend": 0.8658,
"residual": 0.1542,
"z_score": 1.357,
"severity": "NORMAL"
},
{
"date": "2022-04",
"value": 0.88,
"trend": 0.8803,
"residual": -0.0003,
"z_score": -0.002,
"severity": "NORMAL"
},
{
"date": "2022-05",
"value": 0.85,
"trend": 0.8947,
"residual": -0.0447,
"z_score": -0.394,
"severity": "NORMAL"
},
{
"date": "2022-06",
"value": 0.88,
"trend": 0.9092,
"residual": -0.0292,
"z_score": -0.257,
"severity": "NORMAL"
},
{
"date": "2022-07",
"value": 0.88,
"trend": 0.9236,
"residual": -0.0436,
"z_score": -0.384,
"severity": "NORMAL"
},
{
"date": "2022-08",
"value": 0.9,
"trend": 0.9381,
"residual": -0.0381,
"z_score": -0.335,
"severity": "NORMAL"
},
{
"date": "2022-09",
"value": 0.88,
"trend": 0.9525,
"residual": -0.0725,
"z_score": -0.638,
"severity": "NORMAL"
},
{
"date": "2022-10",
"value": 0.95,
"trend": 0.9669,
"residual": -0.0169,
"z_score": -0.149,
"severity": "NORMAL"
},
{
"date": "2022-11",
"value": 0.77,
"trend": 0.9814,
"residual": -0.2114,
"z_score": -1.86,
"severity": "NORMAL"
},
{
"date": "2022-12",
"value": 0.78,
"trend": 0.9958,
"residual": -0.2158,
"z_score": -1.9,
"severity": "NORMAL"
},
{
"date": "2023-01",
"value": 0.87,
"trend": 1.0103,
"residual": -0.1403,
"z_score": -1.235,
"severity": "NORMAL"
},
{
"date": "2023-02",
"value": 0.98,
"trend": 1.0247,
"residual": -0.0447,
"z_score": -0.394,
"severity": "NORMAL"
},
{
"date": "2023-03",
"value": 1.21,
"trend": 1.0392,
"residual": 0.1708,
"z_score": 1.503,
"severity": "NORMAL"
},
{
"date": "2023-04",
"value": 1.0,
"trend": 1.0536,
"residual": -0.0536,
"z_score": -0.472,
"severity": "NORMAL"
},
{
"date": "2023-05",
"value": 0.94,
"trend": 1.0681,
"residual": -0.1281,
"z_score": -1.127,
"severity": "NORMAL"
},
{
"date": "2023-06",
"value": 1.08,
"trend": 1.0825,
"residual": -0.0025,
"z_score": -0.022,
"severity": "NORMAL"
},
{
"date": "2023-07",
"value": 1.18,
"trend": 1.0969,
"residual": 0.0831,
"z_score": 0.731,
"severity": "NORMAL"
},
{
"date": "2023-08",
"value": 1.24,
"trend": 1.1114,
"residual": 0.1286,
"z_score": 1.132,
"severity": "NORMAL"
},
{
"date": "2023-09",
"value": 1.47,
"trend": 1.1258,
"residual": 0.3442,
"z_score": 3.029,
"severity": "CRITICAL"
},
{
"date": "2023-10",
"value": 1.32,
"trend": 1.1403,
"residual": 0.1797,
"z_score": 1.582,
"severity": "NORMAL"
},
{
"date": "2023-11",
"value": 1.18,
"trend": 1.1547,
"residual": 0.0253,
"z_score": 0.222,
"severity": "NORMAL"
},
{
"date": "2023-12",
"value": 1.16,
"trend": 1.1692,
"residual": -0.0092,
"z_score": -0.081,
"severity": "NORMAL"
},
{
"date": "2024-01",
"value": 1.22,
"trend": 1.1836,
"residual": 0.0364,
"z_score": 0.32,
"severity": "NORMAL"
},
{
"date": "2024-02",
"value": 1.35,
"trend": 1.1981,
"residual": 0.1519,
"z_score": 1.337,
"severity": "NORMAL"
},
{
"date": "2024-03",
"value": 1.34,
"trend": 1.2125,
"residual": 0.1275,
"z_score": 1.122,
"severity": "NORMAL"
},
{
"date": "2024-04",
"value": 1.26,
"trend": 1.2269,
"residual": 0.0331,
"z_score": 0.291,
"severity": "NORMAL"
},
{
"date": "2024-05",
"value": 1.15,
"trend": 1.2414,
"residual": -0.0914,
"z_score": -0.804,
"severity": "NORMAL"
},
{
"date": "2024-06",
"value": 1.2,
"trend": 1.2558,
"residual": -0.0558,
"z_score": -0.491,
"severity": "NORMAL"
},
{
"date": "2024-07",
"value": 1.24,
"trend": 1.2703,
"residual": -0.0303,
"z_score": -0.266,
"severity": "NORMAL"
},
{
"date": "2024-08",
"value": 1.3,
"trend": 1.2847,
"residual": 0.0153,
"z_score": 0.135,
"severity": "NORMAL"
},
{
"date": "2024-09",
"value": 1.28,
"trend": 1.2992,
"residual": -0.0192,
"z_score": -0.169,
"severity": "NORMAL"
},
{
"date": "2024-10",
"value": 1.27,
"trend": 1.3136,
"residual": -0.0436,
"z_score": -0.384,
"severity": "NORMAL"
},
{
"date": "2024-11",
"value": 1.22,
"trend": 1.328,
"residual": -0.108,
"z_score": -0.951,
"severity": "NORMAL"
},
{
"date": "2024-12",
"value": 1.2,
"trend": 1.3425,
"residual": -0.1425,
"z_score": -1.254,
"severity": "NORMAL"
}
],
"forecast_detections": [
{
"date": "2025-01",
"actual": 1.2821,
"forecast": 1.2593,
"q10": 1.1407,
"q20": 1.1881,
"q80": 1.324,
"q90": 1.3679,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-02",
"actual": 1.1522,
"forecast": 1.2857,
"q10": 1.1406,
"q20": 1.1961,
"q80": 1.3751,
"q90": 1.4254,
"severity": "WARNING",
"was_injected": false
},
{
"date": "2025-03",
"actual": 1.3358,
"forecast": 1.295,
"q10": 1.1269,
"q20": 1.1876,
"q80": 1.4035,
"q90": 1.4643,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-04",
"actual": 2.0594,
"forecast": 1.2208,
"q10": 1.0353,
"q20": 1.1042,
"q80": 1.331,
"q90": 1.4017,
"severity": "CRITICAL",
"was_injected": true
},
{
"date": "2025-05",
"actual": 1.0747,
"forecast": 1.1703,
"q10": 0.9691,
"q20": 1.0431,
"q80": 1.2892,
"q90": 1.3632,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-06",
"actual": 1.1442,
"forecast": 1.1456,
"q10": 0.942,
"q20": 1.0111,
"q80": 1.2703,
"q90": 1.3454,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-07",
"actual": 1.2917,
"forecast": 1.1702,
"q10": 0.9504,
"q20": 1.0348,
"q80": 1.2998,
"q90": 1.3807,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-08",
"actual": 1.2519,
"forecast": 1.2027,
"q10": 0.9709,
"q20": 1.0594,
"q80": 1.3408,
"q90": 1.4195,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-09",
"actual": 0.6364,
"forecast": 1.191,
"q10": 0.9594,
"q20": 1.0404,
"q80": 1.3355,
"q90": 1.417,
"severity": "CRITICAL",
"was_injected": true
},
{
"date": "2025-10",
"actual": 1.2073,
"forecast": 1.1491,
"q10": 0.9079,
"q20": 0.9953,
"q80": 1.2869,
"q90": 1.3775,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-11",
"actual": 1.3851,
"forecast": 1.0805,
"q10": 0.8361,
"q20": 0.926,
"q80": 1.2284,
"q90": 1.3122,
"severity": "CRITICAL",
"was_injected": false
},
{
"date": "2025-12",
"actual": 1.8294,
"forecast": 1.0613,
"q10": 0.8022,
"q20": 0.8952,
"q80": 1.2169,
"q90": 1.296,
"severity": "CRITICAL",
"was_injected": true
}
]
}
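In the forecast phase each month is labelled by where the actual value falls relative to the TimesFM quantiles: outside q10..q90 (the 80% interval, `pi_critical_pct`) is CRITICAL, and outside q20..q80 (the 60% interval, `pi_warning_pct`) is WARNING. That rule is inferred from the thresholds and rows above, and it reproduces all twelve recorded labels; the sketch below is still an illustration, not the script's `detect_forecast_anomalies`:

```python
def classify_forecast(actual: float, q10: float, q20: float, q80: float, q90: float) -> str:
    """Severity of one forecast-period observation from TimesFM quantiles."""
    if actual < q10 or actual > q90:  # outside the 80% interval
        return "CRITICAL"
    if actual < q20 or actual > q80:  # outside the 60% interval
        return "WARNING"
    return "NORMAL"


# (date, actual, q10, q20, q80, q90, severity recorded in anomaly_detection.json)
rows = [
    ("2025-01", 1.2821, 1.1407, 1.1881, 1.324, 1.3679, "NORMAL"),
    ("2025-02", 1.1522, 1.1406, 1.1961, 1.3751, 1.4254, "WARNING"),
    ("2025-04", 2.0594, 1.0353, 1.1042, 1.331, 1.4017, "CRITICAL"),
    ("2025-09", 0.6364, 0.9594, 1.0404, 1.3355, 1.417, "CRITICAL"),
    ("2025-11", 1.3851, 0.8361, 0.926, 1.2284, 1.3122, "CRITICAL"),
]
for date, actual, q10, q20, q80, q90, recorded in rows:
    print(date, classify_forecast(actual, q10, q20, q80, q90), recorded)
```

Note that 2025-11 is CRITICAL without having been injected: the interval check flags anything TimesFM considered unlikely, not only the planted anomalies.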
================================================
FILE: timesfm-forecasting/examples/covariates-forecasting/demo_covariates.py
================================================
#!/usr/bin/env python3
"""
TimesFM Covariates (XReg) Example
Demonstrates the TimesFM covariate API using synthetic retail sales data.
TimesFM 1.0 does NOT support forecast_with_covariates(); that requires
TimesFM 2.5 + `pip install timesfm[xreg]`.
This script:
  1. Generates synthetic 3-store weekly retail data (24-week context, 12-week horizon)
  2. Produces a 2x2 visualization showing WHAT each covariate contributes
     and WHY knowing them improves forecasts -- all panels share the same
     week x-axis (0 = first context week, 35 = last horizon week)
  3. Exports a compact CSV (108 rows) and metadata JSON
NOTE ON REAL DATA:
  If you want to use a real retail dataset (e.g., Kaggle Rossmann Store Sales),
  download it to a TEMP location -- do NOT commit large CSVs to this repo.
    import tempfile, urllib.request
    tmp = tempfile.mkdtemp(prefix="timesfm_retail_")
    # urllib.request.urlretrieve("https://...store_sales.csv", f"{tmp}/store_sales.csv")
    # df = pd.read_csv(f"{tmp}/store_sales.csv")
  This skills directory intentionally keeps only tiny reference datasets.
"""
from __future__ import annotations
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
EXAMPLE_DIR = Path(__file__).parent
OUTPUT_DIR = EXAMPLE_DIR / "output"
N_STORES = 3
CONTEXT_LEN = 24
HORIZON_LEN = 12
TOTAL_LEN = CONTEXT_LEN + HORIZON_LEN # 36
def generate_sales_data() -> dict:
    """Generate synthetic retail sales data with covariate components stored separately.
    Returns a dict with:
        stores: {store_id: {sales, config}}
        covariates: {price, promotion, holiday, day_of_week, store_type, region}
        components: {store_id: {base, price_effect, promo_effect, holiday_effect}}
    Components let us show 'what would sales look like without covariates?' --
    the gap between 'base' and 'sales' IS the covariate signal.
    BUG FIX v3: Previous versions had variable-shadowing where inner dict
    comprehension `{store_id: ... for store_id in stores}` overwrote the outer
    loop variable, causing all stores to get identical covariate arrays.
    Fixed by accumulating per-store arrays separately before building the covariate dict.
    """
    rng = np.random.default_rng(42)
    stores = {
        "store_A": {"type": "premium", "region": "urban", "base_sales": 1000},
        "store_B": {"type": "standard", "region": "suburban", "base_sales": 750},
        "store_C": {"type": "discount", "region": "rural", "base_sales": 500},
    }
    base_prices = {"store_A": 12.0, "store_B": 10.0, "store_C": 7.5}
    data: dict = {"stores": {}, "covariates": {}, "components": {}}
    prices_by_store: dict[str, np.ndarray] = {}
    promos_by_store: dict[str, np.ndarray] = {}
    holidays_by_store: dict[str, np.ndarray] = {}
    dow_by_store: dict[str, np.ndarray] = {}
    for store_id, config in stores.items():
        bp = base_prices[store_id]
        weeks = np.arange(TOTAL_LEN)
        trend = config["base_sales"] * (1 + 0.005 * weeks)
        seasonality = 80 * np.sin(2 * np.pi * weeks / 52)
        noise = rng.normal(0, 40, TOTAL_LEN)
        base = (trend + seasonality + noise).astype(np.float32)
        price = (bp + rng.uniform(-0.5, 0.5, TOTAL_LEN)).astype(np.float32)
        price_effect = (-20 * (price - bp)).astype(np.float32)
        holidays = np.zeros(TOTAL_LEN, dtype=np.float32)
        for hw in [0, 11, 23, 35]:
            if hw < TOTAL_LEN:
                holidays[hw] = 1.0
        holiday_effect = (200 * holidays).astype(np.float32)
        promotion = rng.choice([0.0, 1.0], TOTAL_LEN, p=[0.8, 0.2]).astype(np.float32)
        promo_effect = (150 * promotion).astype(np.float32)
        day_of_week = np.tile(np.arange(7), TOTAL_LEN // 7 + 1)[:TOTAL_LEN].astype(
            np.int32
        )
        sales = np.maximum(base + price_effect + holiday_effect + promo_effect, 50.0)
        data["stores"][store_id] = {"sales": sales, "config": config}
        data["components"][store_id] = {
            "base": base,
            "price_effect": price_effect,
            "promo_effect": promo_effect,
            "holiday_effect": holiday_effect,
        }
        prices_by_store[store_id] = price
        promos_by_store[store_id] = promotion
        holidays_by_store[store_id] = holidays
        dow_by_store[store_id] = day_of_week
    data["covariates"] = {
        "price": prices_by_store,
        "promotion": promos_by_store,
        "holiday": holidays_by_store,
        "day_of_week": dow_by_store,
        "store_type": {sid: stores[sid]["type"] for sid in stores},
        "region": {sid: stores[sid]["region"] for sid in stores},
    }
    return data
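The docstring of `generate_sales_data` says the gap between `base` and `sales` is the covariate signal. Because every effect is linear, ordinary least squares on that gap recovers the coefficients used to build it (-20 units per $1 of price, +150 for a promotion, +200 for a holiday). This is a sketch of the "fit regression on residuals ~ covariates" step that `explain_xreg_modes` describes, with the known synthetic `base` standing in for a TimesFM baseline; it uses plain `np.linalg.lstsq` rather than timesfm's `xreg_lib`, and the promotion weeks are fixed values chosen for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
n_weeks = 36
bp = 12.0  # Store A's base price in the demo

# Same effect structure as generate_sales_data; the max(..., 50) floor is
# omitted because sales never get near it at this scale.
price = bp + rng.uniform(-0.5, 0.5, n_weeks)
promotion = np.zeros(n_weeks)
promotion[[3, 8, 15, 27]] = 1.0  # illustrative promotion weeks
holiday = np.zeros(n_weeks)
holiday[[0, 11, 23, 35]] = 1.0  # same holiday weeks as the demo
base = 1000 * (1 + 0.005 * np.arange(n_weeks)) + rng.normal(0, 40, n_weeks)
sales = base - 20 * (price - bp) + 150 * promotion + 200 * holiday

# Regress what the baseline misses (sales - base) on the covariates.
X = np.column_stack([price - bp, promotion, holiday])
coef, *_ = np.linalg.lstsq(X, sales - base, rcond=None)
print(np.round(coef, 6))  # approximately [-20. 150. 200.]
```

With a real model the baseline is only an estimate, so the residual also contains forecast error and the fitted coefficients will be noisier than in this noise-free setup.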
def create_visualization(data: dict) -> None:
    """
    2x2 figure -- ALL panels share x-axis = weeks 0-35.
      (0,0) Sales by store -- context solid, horizon dashed
      (0,1) Store A: actual vs baseline (no covariates), with event overlays showing uplift
      (1,0) Price covariate for all stores -- full 36 weeks including horizon
      (1,1) Covariate effect decomposition for Store A (stacked fill_between)
    Each panel has a conclusion annotation box explaining what the data shows.
    """
    OUTPUT_DIR.mkdir(exist_ok=True)
    store_colors = {"store_A": "#1a56db", "store_B": "#057a55", "store_C": "#c03221"}
    weeks = np.arange(TOTAL_LEN)
    fig, axes = plt.subplots(
        2,
        2,
        figsize=(16, 11),
        sharex=True,
        gridspec_kw={"hspace": 0.42, "wspace": 0.32},
    )
    fig.suptitle(
        "TimesFM Covariates (XReg) -- Retail Sales with Exogenous Variables\n"
        "Shared x-axis: Week 0-23 = context (observed) | Week 24-35 = forecast horizon",
        fontsize=13,
        fontweight="bold",
        y=1.01,
    )
    def add_divider(ax, label_top=True):
        # Dashed line at the context/horizon boundary plus a grey horizon band.
        ax.axvline(CONTEXT_LEN - 0.5, color="#9ca3af", lw=1.3, ls="--", alpha=0.8)
        ax.axvspan(
            CONTEXT_LEN - 0.5, TOTAL_LEN - 0.5, alpha=0.06, color="grey", zorder=0
        )
        if label_top:
            ax.text(
                CONTEXT_LEN + 0.3,
                1.01,
                "<- horizon ->",
                transform=ax.get_xaxis_transform(),
                fontsize=7.5,
                color="#6b7280",
                style="italic",
            )
    # -- (0,0): Sales by Store ---------------------------------------------------
    ax = axes[0, 0]
    base_price_labels = {"store_A": "$12", "store_B": "$10", "store_C": "$7.50"}
    for sid, store_data in data["stores"].items():
        sales = store_data["sales"]
        c = store_colors[sid]
        lbl = f"{sid} ({store_data['config']['type']}, {base_price_labels[sid]} base)"
        ax.plot(
            weeks[:CONTEXT_LEN],
            sales[:CONTEXT_LEN],
            color=c,
            lw=2,
            marker="o",
            ms=3,
            label=lbl,
        )
        ax.plot(
            weeks[CONTEXT_LEN:],
            sales[CONTEXT_LEN:],
            color=c,
            lw=1.5,
            ls="--",
            marker="o",
            ms=3,
            alpha=0.6,
        )
    add_divider(ax)
    ax.set_ylabel("Weekly Sales (units)", fontsize=10)
    ax.set_title("Sales by Store", fontsize=11, fontweight="bold")
    ax.legend(fontsize=7.5, loc="upper left")
    ax.grid(True, alpha=0.22)
    ratio = (
        data["stores"]["store_A"]["sales"][:CONTEXT_LEN].mean()
        / data["stores"]["store_C"]["sales"][:CONTEXT_LEN].mean()
    )
    ax.annotate(
        f"Store A earns {ratio:.1f}x Store C\n(premium vs discount pricing)\n"
        "-> store_type is a useful static covariate",
        xy=(0.97, 0.05),
        xycoords="axes fraction",
        ha="right",
        fontsize=8,
        bbox=dict(boxstyle="round", fc="#fffbe6", ec="#d4a017", alpha=0.95),
    )
    # -- (0,1): Store A actual vs baseline ---------------------------------------
    ax = axes[0, 1]
    comp_A = data["components"]["store_A"]
    sales_A = data["stores"]["store_A"]["sales"]
    base_A = comp_A["base"]
    promo_A = data["covariates"]["promotion"]["store_A"]
    holiday_A = data["covariates"]["holiday"]["store_A"]
    ax.plot(
        weeks[:CONTEXT_LEN],
        base_A[:CONTEXT_LEN],
        color="#9ca3af",
        lw=1.8,
        ls="--",
        label="Baseline (no covariates)",
    )
    ax.fill_between(
        weeks[:CONTEXT_LEN],
        base_A[:CONTEXT_LEN],
        sales_A[:CONTEXT_LEN],
        where=(sales_A[:CONTEXT_LEN] > base_A[:CONTEXT_LEN]),
        alpha=0.35,
        color="#22c55e",
        label="Covariate uplift",
    )
    ax.fill_between(
        weeks[:CONTEXT_LEN],
        sales_A[:CONTEXT_LEN],
        base_A[:CONTEXT_LEN],
        where=(sales_A[:CONTEXT_LEN] < base_A[:CONTEXT_LEN]),
        alpha=0.30,
        color="#ef4444",
        label="Price suppression",
    )
    ax.plot(
        weeks[:CONTEXT_LEN],
        sales_A[:CONTEXT_LEN],
        color=store_colors["store_A"],
        lw=2,
        label="Actual sales (Store A)",
    )
    # Shade holiday weeks in the context window.
    for w in range(CONTEXT_LEN):
        if holiday_A[w] > 0:
            ax.axvspan(w - 0.45, w + 0.45, alpha=0.22, color="darkorange", zorder=0)
    promo_weeks = [w for w in range(CONTEXT_LEN) if promo_A[w] > 0]
    if promo_weeks:
        ax.scatter(
            promo_weeks,
            sales_A[promo_weeks],
            marker="^",
            color="#16a34a",
            s=70,
            zorder=6,
            label="Promotion week",
        )
    add_divider(ax)
    ax.set_ylabel("Weekly Sales (units)", fontsize=10)
    ax.set_title(
        "Store A -- Actual vs Baseline (No Covariates)", fontsize=11, fontweight="bold"
    )
    ax.legend(fontsize=7.5, loc="upper left", ncol=2)
    ax.grid(True, alpha=0.22)
    hm = holiday_A[:CONTEXT_LEN] > 0
    pm = promo_A[:CONTEXT_LEN] > 0
    h_lift = (
        (sales_A[:CONTEXT_LEN][hm] - base_A[:CONTEXT_LEN][hm]).mean() if hm.any() else 0
    )
    p_lift = (
        (sales_A[:CONTEXT_LEN][pm] - base_A[:CONTEXT_LEN][pm]).mean() if pm.any() else 0
    )
    ax.annotate(
        f"Holiday weeks: +{h_lift:.0f} units avg\n"
        f"Promotion weeks: +{p_lift:.0f} units avg\n"
        "Future event schedules must be known for XReg",
        xy=(0.97, 0.05),
        xycoords="axes fraction",
        ha="right",
        fontsize=8,
        bbox=dict(boxstyle="round", fc="#fffbe6", ec="#d4a017", alpha=0.95),
    )
    # -- (1,0): Price covariate -- full 36 weeks ---------------------------------
    ax = axes[1, 0]
    for sid in data["stores"]:
        ax.plot(
            weeks,
            data["covariates"]["price"][sid],
            color=store_colors[sid],
            lw=2,
            label=sid,
            alpha=0.85,
        )
    add_divider(ax, label_top=False)
    ax.set_xlabel("Week", fontsize=10)
    ax.set_ylabel("Price ($)", fontsize=10)
    ax.set_title(
        "Price Covariate -- Context + Forecast Horizon", fontsize=11, fontweight="bold"
    )
    ax.legend(fontsize=8, loc="upper right")
    ax.grid(True, alpha=0.22)
    ax.annotate(
        "Prices are planned -- known for forecast horizon\n"
        "Price elasticity: each +$1 -> -20 units sold\n"
        "Store A ($12) is consistently more expensive than C ($7.50)",
        xy=(0.97, 0.05),
        xycoords="axes fraction",
        ha="right",
        fontsize=8,
        bbox=dict(boxstyle="round", fc="#fffbe6", ec="#d4a017", alpha=0.95),
    )
    # -- (1,1): Covariate effect decomposition -----------------------------------
    ax = axes[1, 1]
    pe = comp_A["price_effect"]
    pre = comp_A["promo_effect"]
    he = comp_A["holiday_effect"]
    # Stack price, promotion, and holiday effects on top of each other.
    ax.fill_between(
        weeks,
        0,
        pe,
        alpha=0.65,
        color="steelblue",
        step="mid",
        label=f"Price effect (max +/-{np.abs(pe).max():.0f} units)",
    )
    ax.fill_between(
        weeks,
        pe,
        pe + pre,
        alpha=0.70,
        color="#22c55e",
        step="mid",
        label="Promotion effect (+150 units)",
    )
    ax.fill_between(
        weeks,
        pe + pre,
        pe + pre + he,
        alpha=0.70,
        color="darkorange",
        step="mid",
        label="Holiday effect (+200 units)",
    )
    total = pe + pre + he
    ax.plot(weeks, total, "k-", lw=1.5, alpha=0.75, label="Total covariate effect")
    ax.axhline(0, color="black", lw=0.9, alpha=0.6)
    add_divider(ax, label_top=False)
    ax.set_xlabel("Week", fontsize=10)
    ax.set_ylabel("Effect on sales (units)", fontsize=10)
    ax.set_title(
        "Store A -- Covariate Effect Decomposition", fontsize=11, fontweight="bold"
    )
    ax.legend(fontsize=7.5, loc="upper right")
    ax.grid(True, alpha=0.22, axis="y")
    ax.annotate(
        "Holidays (+200) and promotions (+150) dominate\n"
        f"Price effect (+/-{np.abs(pe).max():.0f} units) is minor by comparison\n"
        "-> Time-varying covariates explain most sales spikes",
        xy=(0.97, 0.55),
        xycoords="axes fraction",
        ha="right",
        fontsize=8,
        bbox=dict(boxstyle="round", fc="#fffbe6", ec="#d4a017", alpha=0.95),
    )
    tick_pos = list(range(0, TOTAL_LEN, 4))
    for row in [0, 1]:
        for col in [0, 1]:
            axes[row, col].set_xticks(tick_pos)
    plt.tight_layout()
    output_path = OUTPUT_DIR / "covariates_data.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"\n Saved visualization: {output_path}")
def demonstrate_api() -> None:
    print("\n" + "=" * 70)
    print(" TIMESFM COVARIATES API (TimesFM 2.5)")
    print("=" * 70)
    print("""
# Installation
pip install timesfm[xreg]
import timesfm
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch")
model.compile(
    timesfm.ForecastConfig(
        max_context=1024,
        max_horizon=256,
        normalize_inputs=True,
        use_continuous_quantile_head=True,
        return_backcast=True,  # needed by forecast_with_covariates
    )
)
# Dynamic covariates must cover context + horizon (24 + 12 = 36 weeks here).
point_fc, quant_fc = model.forecast_with_covariates(
    inputs=[sales_a, sales_b, sales_c],
    dynamic_numerical_covariates={"price": [price_a, price_b, price_c]},
    dynamic_categorical_covariates={"holiday": [hol_a, hol_b, hol_c]},
    static_categorical_covariates={"store_type": ["premium", "standard", "discount"]},
    xreg_mode="xreg + timesfm",
    normalize_xreg_target_per_input=True,
)
# point_fc: (num_series, horizon_len)
# quant_fc: (num_series, horizon_len, 10)
""")
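Because the demo stresses that prices and event schedules must be known for the forecast horizon, it can help to validate covariate lengths before calling the model. A minimal sketch, assuming (as the price panel and the "Future event schedules must be known for XReg" note imply) that each dynamic covariate must span context + horizon for its series; `check_dynamic_covariates` is a hypothetical helper, not part of the timesfm package:

```python
from collections.abc import Mapping, Sequence


def check_dynamic_covariates(
    inputs: Sequence[Sequence[float]],
    dynamic_covariates: Mapping[str, Sequence[Sequence[float]]],
    horizon: int,
) -> None:
    """Raise ValueError if any dynamic covariate does not span context + horizon."""
    for name, per_series in dynamic_covariates.items():
        if len(per_series) != len(inputs):
            raise ValueError(f"{name}: expected {len(inputs)} series, got {len(per_series)}")
        for i, (context, cov) in enumerate(zip(inputs, per_series)):
            expected = len(context) + horizon
            if len(cov) != expected:
                raise ValueError(f"{name}[{i}]: expected length {expected}, got {len(cov)}")


# 24-week context and 12-week horizon, as in this demo.
inputs = [[1000.0] * 24, [750.0] * 24]
check_dynamic_covariates(inputs, {"price": [[12.0] * 36, [10.0] * 36]}, horizon=12)  # passes
```

A price plan that stops at week 23 (length 24) would raise here instead of surfacing later as a shape error inside the regression.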
def explain_xreg_modes() -> None:
    print("\n" + "=" * 70)
    print(" XREG MODES")
    print("=" * 70)
    print("""
"xreg + timesfm" (DEFAULT)
1. TimesFM makes baseline forecast
2. Fit regression on residuals (actual - baseline) ~ covariates
3. Final = TimesFM baseline + XReg adjustment
Best when: covariates explain residual variat
SYMBOL INDEX (376 symbols across 41 files)
FILE: src/timesfm/configs.py
class ForecastConfig (line 22) | class ForecastConfig:
class ResidualBlockConfig (line 64) | class ResidualBlockConfig:
class RandomFourierFeaturesConfig (line 75) | class RandomFourierFeaturesConfig:
class TransformerConfig (line 85) | class TransformerConfig:
class StackedTransformersConfig (line 101) | class StackedTransformersConfig:
FILE: src/timesfm/flax/dense.py
class ResidualBlock (line 34) | class ResidualBlock(nnx.Module):
method __init__ (line 37) | def __init__(self, config: ResidualBlockConfig, *, rngs=nnx.Rngs(42)):
method __call__ (line 66) | def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... ...
class RandomFourierFeatures (line 72) | class RandomFourierFeatures(nnx.Module):
method __init__ (line 77) | def __init__(self, config: RandomFourierFeaturesConfig, *, rngs=nnx.Rn...
method __call__ (line 100) | def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... ...
FILE: src/timesfm/flax/normalization.py
class RMSNorm (line 29) | class RMSNorm(nnx.Module):
method __init__ (line 34) | def __init__(
method __call__ (line 46) | def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b...
class LayerNorm (line 53) | class LayerNorm(nnx.Module):
method __init__ (line 58) | def __init__(self, num_features: int, *, epsilon: float = 1e-6, rngs=n...
method __call__ (line 65) | def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b...
FILE: src/timesfm/flax/transformer.py
function make_attn_mask (line 46) | def make_attn_mask(
class RotaryPositionalEmbedding (line 67) | class RotaryPositionalEmbedding(nnx.Module):
method __init__ (line 70) | def __init__(
method __call__ (line 80) | def __call__(
class PerDimScale (line 118) | class PerDimScale(nnx.Module):
method __init__ (line 123) | def __init__(self, num_dims: int, *, rngs=nnx.Rngs(42)):
method __call__ (line 128) | def __call__(self, x: Float[Array, "b ... d"]) -> Float[Array, "b ... ...
class MultiHeadAttention (line 134) | class MultiHeadAttention(nnx.Module):
method __init__ (line 137) | def __init__(
method __call__ (line 207) | def __call__(
class Transformer (line 291) | class Transformer(nnx.Module):
method __init__ (line 294) | def __init__(self, config: TransformerConfig, *, rngs=nnx.Rngs(42)):
method __call__ (line 338) | def __call__(
FILE: src/timesfm/flax/util.py
class DecodeCache (line 33) | class DecodeCache:
function update_running_stats (line 43) | def update_running_stats(
function scan_along_axis (line 80) | def scan_along_axis(f, init, xs, axis: int, **kwargs):
function revin (line 91) | def revin(
FILE: src/timesfm/timesfm_2p5/timesfm_2p5_base.py
function strip_leading_nans (line 33) | def strip_leading_nans(arr):
function linear_interpolation (line 49) | def linear_interpolation(arr):
class TimesFM_2p5_200M_Definition (line 85) | class TimesFM_2p5_200M_Definition:
class TimesFM_2p5 (line 134) | class TimesFM_2p5:
method load_checkpoint (line 147) | def load_checkpoint(self, path: str):
method compile (line 151) | def compile(self, forecast_config: ForecastConfig | None = None):
method forecast (line 155) | def forecast(
method forecast_with_covariates (line 198) | def forecast_with_covariates(
FILE: src/timesfm/timesfm_2p5/timesfm_2p5_flax.py
function try_gc (line 48) | def try_gc():
function _create_stacked_transformers (line 59) | def _create_stacked_transformers(
function _scan_along_axis (line 65) | def _scan_along_axis(f, init, xs, axis: int, **kwargs):
function _apply_stacked_transformers (line 76) | def _apply_stacked_transformers(
class TimesFM_2p5_200M_flax_module (line 85) | class TimesFM_2p5_200M_flax_module(nnx.Module): # pylint: disable=inval...
method __init__ (line 96) | def __init__(self):
method __call__ (line 126) | def __call__(
method decode (line 149) | def decode(self, horizon: int, inputs, masks):
method compile (line 237) | def compile(
function _flip_quantile_fn (line 276) | def _flip_quantile_fn(x):
function _force_flip_invariance_fn (line 284) | def _force_flip_invariance_fn(
function _use_continuous_quantile_head_fn (line 309) | def _use_continuous_quantile_head_fn(full_forecast, quantile_spreads, ma...
function _fix_quantile_crossing_fn (line 329) | def _fix_quantile_crossing_fn(full_forecast):
function _before_model_decode (line 357) | def _before_model_decode(fc, inputs, masks):
function _after_model_decode (line 385) | def _after_model_decode(
class TimesFM_2p5_200M_flax (line 445) | class TimesFM_2p5_200M_flax(timesfm_2p5_base.TimesFM_2p5):
method from_pretrained (line 451) | def from_pretrained(
method compile (line 494) | def compile(
FILE: src/timesfm/timesfm_2p5/timesfm_2p5_torch.py
class TimesFM_2p5_200M_torch_module (line 36) | class TimesFM_2p5_200M_torch_module(nn.Module):
method __init__ (line 41) | def __init__(self):
method load_checkpoint (line 79) | def load_checkpoint(self, path: str, **kwargs):
method forward (line 93) | def forward(
method decode (line 122) | def decode(self, horizon: int, inputs, masks):
method forecast_naive (line 228) | def forecast_naive(
class TimesFM_2p5_200M_torch (line 266) | class TimesFM_2p5_200M_torch(
method __init__ (line 282) | def __init__(
method _from_pretrained (line 293) | def _from_pretrained(
method _save_pretrained (line 341) | def _save_pretrained(self, save_directory: Union[str, Path]):
method compile (line 352) | def compile(self, forecast_config: configs.ForecastConfig, **kwargs) -...
FILE: src/timesfm/torch/dense.py
class ResidualBlock (line 23) | class ResidualBlock(nn.Module):
method __init__ (line 26) | def __init__(self, config: configs.ResidualBlockConfig):
method forward (line 53) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class RandomFourierFeatures (line 59) | class RandomFourierFeatures(nn.Module):
method __init__ (line 62) | def __init__(self, config: configs.RandomFourierFeaturesConfig):
method forward (line 84) | def forward(self, x: torch.Tensor) -> torch.Tensor:
FILE: src/timesfm/torch/normalization.py
class RMSNorm (line 21) | class RMSNorm(nn.Module):
method __init__ (line 24) | def __init__(
method forward (line 35) | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
FILE: src/timesfm/torch/transformer.py
function make_attn_mask (line 32) | def make_attn_mask(
class RotaryPositionalEmbedding (line 56) | class RotaryPositionalEmbedding(nn.Module):
method __init__ (line 59) | def __init__(
method forward (line 70) | def forward(
function _dot_product_attention (line 114) | def _dot_product_attention(
function _torch_dot_product_attention (line 132) | def _torch_dot_product_attention(query, key, value, mask=None):
class PerDimScale (line 154) | class PerDimScale(nn.Module):
method __init__ (line 157) | def __init__(self, num_dims: int):
method forward (line 162) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class MultiHeadAttention (line 169) | class MultiHeadAttention(nn.Module):
method __init__ (line 172) | def __init__(
method forward (line 224) | def forward(
class Transformer (line 307) | class Transformer(nn.Module):
method __init__ (line 310) | def __init__(self, config: configs.TransformerConfig):
method forward (line 354) | def forward(
FILE: src/timesfm/torch/util.py
class DecodeCache (line 24) | class DecodeCache:
function update_running_stats (line 33) | def update_running_stats(
function revin (line 77) | def revin(
FILE: src/timesfm/utils/xreg_lib.py
function _unnest (line 36) | def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:
function _repeat (line 40) | def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:
function _to_padded_jax_array (line 46) | def _to_padded_jax_array(x: np.ndarray) -> jax.Array:
function normalize (line 61) | def normalize(batch):
function renormalize (line 68) | def renormalize(batch, stats):
class BatchedInContextXRegBase (line 72) | class BatchedInContextXRegBase:
method __init__ (line 97) | def __init__(
method _assert_covariates (line 210) | def _assert_covariates(self, assert_covariate_shapes: bool = False) ->...
method create_covariate_matrix (line 327) | def create_covariate_matrix(
method fit (line 407) | def fit(self) -> Any:
class BatchedInContextXRegLinear (line 411) | class BatchedInContextXRegLinear(BatchedInContextXRegBase):
method fit (line 414) | def fit(
FILE: timesfm-forecasting/examples/anomaly-detection/detect_anomalies.py
function detect_context_anomalies (line 50) | def detect_context_anomalies(
function build_synthetic_future (line 98) | def build_synthetic_future(
function detect_forecast_anomalies (line 121) | def detect_forecast_anomalies(
function plot_results (line 172) | def plot_results(
function main (line 391) | def main() -> None:
FILE: timesfm-forecasting/examples/covariates-forecasting/demo_covariates.py
function generate_sales_data (line 49) | def generate_sales_data() -> dict:
function create_visualization (line 132) | def create_visualization(data: dict) -> None:
function demonstrate_api (line 405) | def demonstrate_api() -> None:
function explain_xreg_modes (line 431) | def explain_xreg_modes() -> None:
function main (line 450) | def main() -> None:
FILE: timesfm-forecasting/examples/global-temperature/generate_animation_data.py
function main (line 30) | def main() -> None:
FILE: timesfm-forecasting/examples/global-temperature/generate_gif.py
function create_frame (line 26) | def create_frame(
function main (line 157) | def main() -> None:
FILE: timesfm-forecasting/examples/global-temperature/generate_html.py
function main (line 521) | def main() -> None:
FILE: timesfm-forecasting/examples/global-temperature/visualize_forecast.py
function main (line 30) | def main() -> None:
FILE: timesfm-forecasting/scripts/check_system.py
class CheckResult (line 75) | class CheckResult:
method icon (line 82) | def icon(self) -> str:
method __str__ (line 85) | def __str__(self) -> str:
class SystemReport (line 90) | class SystemReport:
method passed (line 99) | def passed(self) -> bool:
method to_dict (line 102) | def to_dict(self) -> dict[str, Any]:
function _get_total_ram_gb (line 127) | def _get_total_ram_gb() -> float:
function _get_available_ram_gb (line 174) | def _get_available_ram_gb() -> float:
function check_ram (line 223) | def check_ram(profile: dict[str, Any]) -> CheckResult:
function check_gpu (line 263) | def check_gpu() -> CheckResult:
function check_disk (line 304) | def check_disk(profile: dict[str, Any]) -> CheckResult:
function check_python (line 337) | def check_python() -> CheckResult:
function check_package (line 358) | def check_package(pkg_name: str, import_name: str | None = None) -> Chec...
function recommend_batch_size (line 384) | def recommend_batch_size(report: SystemReport) -> int:
function estimate_memory_gb (line 428) | def estimate_memory_gb(
function check_dataset_fit (line 481) | def check_dataset_fit(
function print_memory_estimate (line 539) | def print_memory_estimate(
function run_checks (line 595) | def run_checks(model_version: str = "v2.5") -> SystemReport:
function print_report (line 637) | def print_report(report: SystemReport) -> None:
function main (line 654) | def main() -> None:
FILE: timesfm-forecasting/scripts/forecast_csv.py
function run_preflight (line 32) | def run_preflight() -> dict:
function load_model (line 49) | def load_model(batch_size: int = 32):
function load_csv (line 78) | def load_csv(
function forecast_series (line 118) | def forecast_series(
function write_csv_output (line 144) | def write_csv_output(
function write_json_output (line 187) | def write_json_output(results: dict[str, dict], output_path: str) -> None:
function main (line 194) | def main() -> None:
FILE: v1/experiments/baselines/timegpt_pipeline.py
function get_seasonality (line 33) | def get_seasonality(freq: str) -> int:
function maybe_convert_col_to_datetime (line 37) | def maybe_convert_col_to_datetime(
function zero_pad_time_series (line 46) | def zero_pad_time_series(df, freq, min_length=36):
class Forecaster (line 84) | class Forecaster:
method forecast (line 90) | def forecast(
method cross_validation (line 98) | def cross_validation(
class TimeGPT (line 152) | class TimeGPT(Forecaster):
method __init__ (line 159) | def __init__(
method _get_client (line 173) | def _get_client(self) -> NixtlaClient:
method forecast (line 184) | def forecast(
function run_timegpt (line 225) | def run_timegpt(
FILE: v1/experiments/extended_benchmarks/run_timegpt.py
function main (line 70) | def main():
FILE: v1/experiments/extended_benchmarks/run_timesfm.py
function main (line 88) | def main():
FILE: v1/experiments/extended_benchmarks/utils.py
function parallel_transform (line 36) | def parallel_transform(inp):
function quantile_loss (line 41) | def quantile_loss(
class ExperimentHandler (line 59) | class ExperimentHandler:
method __init__ (line 61) | def __init__(
method _maybe_download_m3_or_m5_file (line 104) | def _maybe_download_m3_or_m5_file(dataset: str):
method _transform_quantiles_to_levels (line 124) | def _transform_quantiles_to_levels(quantiles: List[float]) -> List[int]:
method _create_dir_if_not_exists (line 132) | def _create_dir_if_not_exists(directory: str):
method _transform_gluonts_instance_to_df (line 136) | def _transform_gluonts_instance_to_df(
method _transform_gluonts_dataset_to_df (line 151) | def _transform_gluonts_dataset_to_df(
method train_df (line 164) | def train_df(self) -> pd.DataFrame:
method test_df (line 169) | def test_df(self) -> pd.DataFrame:
method save_dataframe (line 177) | def save_dataframe(self, df: pd.DataFrame, file_name: str):
method save_results (line 180) | def save_results(
method fcst_from_level_to_quantiles (line 193) | def fcst_from_level_to_quantiles(
method evaluate_models (line 213) | def evaluate_models(self, models: List[str]) -> pd.DataFrame:
method evaluate_from_predictions (line 232) | def evaluate_from_predictions(
FILE: v1/experiments/long_horizon_benchmarks/run_eval.py
function get_forecasts (line 95) | def get_forecasts(model_path, model, past, freq, pred_len):
function _mse (line 112) | def _mse(y_pred, y_true):
function _mae (line 117) | def _mae(y_pred, y_true):
function _smape (line 122) | def _smape(y_pred, y_true):
function eval (line 131) | def eval():
FILE: v1/peft/finetune.py
function finetune (line 63) | def finetune(
FILE: v1/src/adapter/dora_layers.py
class DoraTheta (line 23) | class DoraTheta(base_layer.Theta):
method __init__ (line 24) | def __init__(self, module):
method _dora_initialized (line 27) | def _dora_initialized(self):
method _dorafy_var (line 40) | def _dorafy_var(self, w):
method __getattr__ (line 55) | def __getattr__(self, k):
method __getitem__ (line 65) | def __getitem__(self, k):
class DoraThetaDescriptor (line 76) | class DoraThetaDescriptor:
method __get__ (line 79) | def __get__(self, obj, objtype=None):
class DoraLinear (line 83) | class DoraLinear(linears.Linear):
method setup (line 88) | def setup(self) -> None:
class DoraAttentionProjection (line 121) | class DoraAttentionProjection(attentions.AttentionProjection):
method setup (line 126) | def setup(self) -> None:
class DoraCombinedQKVProjection (line 166) | class DoraCombinedQKVProjection(attentions.CombinedQKVProjectionLayer):
method setup (line 171) | def setup(self) -> None:
FILE: v1/src/adapter/lora_layers.py
class LoraTheta (line 23) | class LoraTheta(base_layer.Theta):
method __init__ (line 24) | def __init__(self, module):
method _lora_initialized (line 27) | def _lora_initialized(self):
method _lorafy_var (line 38) | def _lorafy_var(self, w):
method __getattr__ (line 46) | def __getattr__(self, k):
method __getitem__ (line 56) | def __getitem__(self, k):
class LoraThetaDescriptor (line 67) | class LoraThetaDescriptor:
method __get__ (line 70) | def __get__(self, obj, objtype=None):
class LoraLinear (line 74) | class LoraLinear(linears.Linear):
method setup (line 79) | def setup(self) -> None:
class LoraAttentionProjection (line 103) | class LoraAttentionProjection(attentions.AttentionProjection):
method setup (line 108) | def setup(self) -> None:
class LoraCombinedQKVProjection (line 139) | class LoraCombinedQKVProjection(attentions.CombinedQKVProjectionLayer):
method setup (line 144) | def setup(self) -> None:
FILE: v1/src/adapter/utils.py
function get_adapter_params (line 43) | def get_adapter_params(
function load_adapter_checkpoint (line 101) | def load_adapter_checkpoint(
function _merge_adapter_weights (line 200) | def _merge_adapter_weights(
function _get_adapter_weight_params (line 281) | def _get_adapter_weight_params(
function load_adapter_layer (line 334) | def load_adapter_layer(
function _initialize_adapter_params (line 417) | def _initialize_adapter_params(
FILE: v1/src/finetuning/finetuning_example.py
class TimeSeriesDataset (line 51) | class TimeSeriesDataset(Dataset):
method __init__ (line 54) | def __init__(self,
method _prepare_samples (line 77) | def _prepare_samples(self) -> None:
method __len__ (line 88) | def __len__(self) -> int:
method __getitem__ (line 91) | def __getitem__(
function prepare_datasets (line 105) | def prepare_datasets(series: np.ndarray,
function get_model (line 141) | def get_model(load_weights: bool = False):
function plot_predictions (line 171) | def plot_predictions(
function get_data (line 248) | def get_data(context_len: int,
function single_gpu_example (line 269) | def single_gpu_example():
function setup_process (line 300) | def setup_process(rank, world_size, model, config, train_dataset, val_da...
function multi_gpu_example (line 335) | def multi_gpu_example():
function main (line 372) | def main(argv):
FILE: v1/src/finetuning/finetuning_torch.py
class MetricsLogger (line 21) | class MetricsLogger(ABC):
method log_metrics (line 29) | def log_metrics(self,
method close (line 41) | def close(self) -> None:
class WandBLogger (line 46) | class WandBLogger(MetricsLogger):
method __init__ (line 55) | def __init__(self, project: str, config: Dict[str, Any], rank: int = 0):
method log_metrics (line 60) | def log_metrics(self,
method close (line 72) | def close(self) -> None:
class DistributedManager (line 78) | class DistributedManager:
method __init__ (line 89) | def __init__(
method setup (line 103) | def setup(self) -> None:
method cleanup (line 113) | def cleanup(self) -> None:
class FinetuningConfig (line 120) | class FinetuningConfig:
class TimesFMFinetuner (line 162) | class TimesFMFinetuner:
method __init__ (line 173) | def __init__(
method _setup_distributed_model (line 203) | def _setup_distributed_model(self) -> nn.Module:
method _create_dataloader (line 210) | def _create_dataloader(self, dataset: Dataset, is_train: bool) -> Data...
method _quantile_loss (line 236) | def _quantile_loss(self, pred: torch.Tensor, actual: torch.Tensor,
method _process_batch (line 251) | def _process_batch(self, batch: List[torch.Tensor]) -> tuple:
method _train_epoch (line 279) | def _train_epoch(self, train_loader: DataLoader,
method _validate (line 312) | def _validate(self, val_loader: DataLoader) -> float:
method finetune (line 339) | def finetune(self, train_dataset: Dataset,
FILE: v1/src/timesfm/data_loader.py
class TimeSeriesdata (line 27) | class TimeSeriesdata(object):
method __init__ (line 30) | def __init__(
method _get_cat_cols (line 120) | def _get_cat_cols(self, cat_cov_cols):
method _normalize_data (line 131) | def _normalize_data(self):
method train_gen (line 137) | def train_gen(self):
method test_val_gen (line 178) | def test_val_gen(self, mode='val', shift=1):
method _get_features_and_ts (line 220) | def _get_features_and_ts(self, dtimes, tsidx, hist_len=None):
method tf_dataset (line 245) | def tf_dataset(self, mode='train', shift=1):
FILE: v1/src/timesfm/patched_decoder.py
function _shift_padded_seq (line 61) | def _shift_padded_seq(mask: JTensor, seq: JTensor) -> JTensor:
class ResidualBlock (line 84) | class ResidualBlock(base_layer.BaseLayer):
method setup (line 107) | def setup(self):
method __call__ (line 146) | def __call__(self, inputs: JTensor) -> JTensor:
function _masked_mean_std (line 157) | def _masked_mean_std(inputs: JTensor,
function _create_quantiles (line 206) | def _create_quantiles() -> list[float]:
class PatchedTimeSeriesDecoder (line 211) | class PatchedTimeSeriesDecoder(base_layer.BaseLayer):
method setup (line 242) | def setup(self) -> None:
method transform_decode_state (line 288) | def transform_decode_state(
method _forward_transform (line 293) | def _forward_transform(
method _reverse_transform (line 305) | def _reverse_transform(self, outputs: JTensor,
method _preprocess_input (line 311) | def _preprocess_input(
method _postprocess_output (line 350) | def _postprocess_output(
method __call__ (line 365) | def __call__(self, inputs: NestedMap) -> NestedMap:
method decode (line 399) | def decode(
class PatchedDecoderFinetuneModel (line 482) | class PatchedDecoderFinetuneModel(base_model.BaseModel):
method setup (line 493) | def setup(self) -> None:
method compute_predictions (line 496) | def compute_predictions(self, input_batch: NestedMap) -> NestedMap:
method _quantile_loss (line 515) | def _quantile_loss(self, pred: JTensor, actual: JTensor,
method compute_loss (line 532) | def compute_loss(self, prediction_output: NestedMap,
FILE: v1/src/timesfm/pytorch_patched_decoder.py
function create_quantiles (line 24) | def create_quantiles() -> list[float]:
class TimesFMConfig (line 29) | class TimesFMConfig:
function _masked_mean_std (line 62) | def _masked_mean_std(
function _shift_padded_seq (line 112) | def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Te...
function get_large_negative_number (line 146) | def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor:
function apply_mask_to_logits (line 155) | def apply_mask_to_logits(logits: torch.Tensor,
function convert_paddings_to_mask (line 173) | def convert_paddings_to_mask(
function causal_mask (line 191) | def causal_mask(input_t: torch.Tensor) -> torch.Tensor:
function merge_masks (line 211) | def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
class ResidualBlock (line 239) | class ResidualBlock(nn.Module):
method __init__ (line 242) | def __init__(
method forward (line 264) | def forward(self, x):
class RMSNorm (line 271) | class RMSNorm(torch.nn.Module):
method __init__ (line 274) | def __init__(
method _norm (line 285) | def _norm(self, x):
method forward (line 288) | def forward(self, x):
class TransformerMLP (line 297) | class TransformerMLP(nn.Module):
method __init__ (line 300) | def __init__(
method forward (line 310) | def forward(self, x, paddings=None):
class TimesFMAttention (line 320) | class TimesFMAttention(nn.Module):
method __init__ (line 323) | def __init__(
method _per_dim_scaling (line 352) | def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor:
method forward (line 360) | def forward(
class TimesFMDecoderLayer (line 418) | class TimesFMDecoderLayer(nn.Module):
method __init__ (line 421) | def __init__(
method forward (line 443) | def forward(
class StackedDecoder (line 468) | class StackedDecoder(nn.Module):
method __init__ (line 471) | def __init__(
method forward (line 495) | def forward(
class PositionalEmbedding (line 518) | class PositionalEmbedding(torch.nn.Module):
method __init__ (line 529) | def __init__(
method forward (line 540) | def forward(self, seq_length=None, position=None):
class PatchedTimeSeriesDecoder (line 574) | class PatchedTimeSeriesDecoder(nn.Module):
method __init__ (line 577) | def __init__(self, config: TimesFMConfig):
method _forward_transform (line 604) | def _forward_transform(
method _reverse_transform (line 622) | def _reverse_transform(
method _preprocess_input (line 629) | def _preprocess_input(
method _postprocess_output (line 677) | def _postprocess_output(
method forward (line 694) | def forward(
method decode (line 712) | def decode(
FILE: v1/src/timesfm/time_features.py
function _distance_to_holiday (line 45) | def _distance_to_holiday(holiday):
class TimeCovariates (line 112) | class TimeCovariates(object):
method __init__ (line 115) | def __init__(
method _minute_of_hour (line 135) | def _minute_of_hour(self):
method _hour_of_day (line 141) | def _hour_of_day(self):
method _day_of_week (line 147) | def _day_of_week(self):
method _day_of_month (line 153) | def _day_of_month(self):
method _day_of_year (line 159) | def _day_of_year(self):
method _month_of_year (line 165) | def _month_of_year(self):
method _week_of_year (line 171) | def _week_of_year(self):
method _get_holidays (line 177) | def _get_holidays(self):
method get_covariates (line 186) | def get_covariates(self):
FILE: v1/src/timesfm/timesfm_base.py
function process_group (line 39) | def process_group(key, group, value_name, forecast_context_len):
function moving_average (line 44) | def moving_average(arr, window_size):
function freq_map (line 53) | def freq_map(freq: str):
function strip_leading_nans (line 77) | def strip_leading_nans(arr):
function linear_interpolation (line 94) | def linear_interpolation(arr):
function _normalize (line 131) | def _normalize(batch):
function _renormalize (line 140) | def _renormalize(batch, stats):
class TimesFmHparams (line 145) | class TimesFmHparams:
class TimesFmCheckpoint (line 184) | class TimesFmCheckpoint:
class TimesFmBase (line 205) | class TimesFmBase:
method _logging (line 214) | def _logging(self, s):
method __post_init__ (line 217) | def __post_init__(self) -> None:
method __init__ (line 221) | def __init__(self, hparams: TimesFmHparams,
method load_from_checkpoint (line 253) | def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None:
method _preprocess (line 257) | def _preprocess(
method _forecast (line 314) | def _forecast(
method forecast (line 347) | def forecast(
method forecast_with_covariates (line 429) | def forecast_with_covariates(
method forecast_on_df (line 644) | def forecast_on_df(
FILE: v1/src/timesfm/timesfm_jax.py
class TimesFmJax (line 41) | class TimesFmJax(timesfm_base.TimesFmBase):
method _get_sample_inputs (line 57) | def _get_sample_inputs(self):
method __post_init__ (line 85) | def __post_init__(self):
method load_from_checkpoint (line 94) | def load_from_checkpoint(
method jit_decode (line 178) | def jit_decode(self):
method _forecast (line 239) | def _forecast(
FILE: v1/src/timesfm/timesfm_torch.py
class TimesFmTorch (line 30) | class TimesFmTorch(timesfm_base.TimesFmBase):
method __post_init__ (line 33) | def __post_init__(self):
method load_from_checkpoint (line 52) | def load_from_checkpoint(
method _forecast (line 72) | def _forecast(
FILE: v1/src/timesfm/xreg_lib.py
function _unnest (line 31) | def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:
function _repeat (line 35) | def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:
function _to_padded_jax_array (line 42) | def _to_padded_jax_array(x: np.ndarray) -> jax.Array:
class BatchedInContextXRegBase (line 56) | class BatchedInContextXRegBase:
method __init__ (line 81) | def __init__(
method _assert_covariates (line 193) | def _assert_covariates(self, assert_covariate_shapes: bool = False) ->...
method create_covariate_matrix (line 298) | def create_covariate_matrix(
method fit (line 377) | def fit(self) -> Any:
class BatchedInContextXRegLinear (line 381) | class BatchedInContextXRegLinear(BatchedInContextXRegBase):
method fit (line 384) | def fit(
FILE: v1/tests/test_timesfm.py
function create_sample_dataframe (line 25) | def create_sample_dataframe(
function test_timesfm_forecast_on_df (line 48) | def test_timesfm_forecast_on_df(
Condensed preview: 85 files, each showing path, character count, and a content snippet (995K chars of structured content in full).
[
{
"path": ".gitattributes",
"chars": 197,
"preview": "# Git LFS tracking for binary outputs in timesfm-forecasting skill\ntimesfm-forecasting/**/*.png filter=lfs diff=lfs merg"
},
{
"path": ".github/workflows/main.yml",
"chars": 687,
"preview": "name: Python package build\n\non:\n push:\n branches: [ \"master\" ]\n pull_request:\n branches: [ \"master\" ]\n\njobs:\n b"
},
{
"path": ".github/workflows/manual_publish.yml",
"chars": 794,
"preview": "name: Manual PyPI Publish\n\non:\n workflow_dispatch:\n\njobs:\n build-and-publish:\n runs-on: ubuntu-latest\n steps:\n "
},
{
"path": ".gitignore",
"chars": 108,
"preview": ".venv/\ndist/\n__pycache__/\ncheckpoints/\nwandb/\ndatasets/\nresults/\ntimesfm_jax.egg-info/\ndevelopment_setup.md\n"
},
{
"path": "AGENTS.md",
"chars": 844,
"preview": "# TimesFM — Agent Entry Point\n\nThis repository ships a first-party **Agent Skill** for TimesFM at:\n\n```\ntimesfm-forecast"
},
{
"path": "LICENSE",
"chars": 11358,
"preview": "\n Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 3338,
"preview": "# TimesFM\n\nTimesFM (Time Series Foundation Model) is a pretrained time-series foundation\nmodel developed by Google Resea"
},
{
"path": "pyproject.toml",
"chars": 955,
"preview": "[project]\nname = \"timesfm\"\nversion = \"2.0.0\"\ndescription = \"A time series foundation model.\"\nauthors = [\n {name = \"Ra"
},
{
"path": "requirements.txt",
"chars": 1015,
"preview": "# This file was autogenerated by uv via the following command:\n# uv pip compile pyproject.toml -o requirements.txt\nan"
},
{
"path": "src/timesfm/__init__.py",
"chars": 919,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/configs.py",
"chars": 3617,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/flax/__init__.py",
"chars": 574,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/flax/dense.py",
"chars": 3522,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/flax/normalization.py",
"chars": 2168,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/flax/transformer.py",
"chars": 11506,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/flax/util.py",
"chars": 3027,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/timesfm_2p5/timesfm_2p5_base.py",
"chars": 15174,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/timesfm_2p5/timesfm_2p5_flax.py",
"chars": 19375,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/timesfm_2p5/timesfm_2p5_torch.py",
"chars": 17324,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/torch/__init__.py",
"chars": 574,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/torch/dense.py",
"chars": 3109,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/torch/normalization.py",
"chars": 1204,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/torch/transformer.py",
"chars": 11720,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/torch/util.py",
"chars": 2654,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "src/timesfm/utils/xreg_lib.py",
"chars": 20721,
"preview": "# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "timesfm-forecasting/SKILL.md",
"chars": 18707,
"preview": "---\nname: timesfm-forecasting\ndescription: >\n Zero-shot time series forecasting with Google's TimesFM foundation model."
},
{
"path": "timesfm-forecasting/examples/anomaly-detection/detect_anomalies.py",
"chars": 17022,
"preview": "#!/usr/bin/env python3\n\"\"\"\nTimesFM Anomaly Detection Example — Two-Phase Method\n\nPhase 1 (context): Linear detrend + Z-s"
},
{
"path": "timesfm-forecasting/examples/anomaly-detection/output/anomaly_detection.json",
"chars": 9019,
"preview": "{\n \"method\": \"two_phase\",\n \"context_method\": \"linear_detrend_zscore\",\n \"forecast_method\": \"quantile_prediction_interv"
},
{
"path": "timesfm-forecasting/examples/covariates-forecasting/demo_covariates.py",
"chars": 19694,
"preview": "#!/usr/bin/env python3\n\"\"\"\nTimesFM Covariates (XReg) Example\n\nDemonstrates the TimesFM covariate API using synthetic ret"
},
{
"path": "timesfm-forecasting/examples/covariates-forecasting/output/covariates_metadata.json",
"chars": 1586,
"preview": "{\n \"description\": \"Synthetic retail sales data with covariates for TimesFM XReg demo\",\n \"note_on_real_data\": \"For real"
},
{
"path": "timesfm-forecasting/examples/covariates-forecasting/output/sales_with_covariates.csv",
"chars": 7399,
"preview": "store_id,week,split,sales,base_sales,price,price_effect,promotion,holiday,day_of_week,store_type,region\nstore_A,0,contex"
},
{
"path": "timesfm-forecasting/examples/global-temperature/README.md",
"chars": 5708,
"preview": "# TimesFM Forecast Report: Global Temperature Anomaly (2025)\n\n**Model:** TimesFM 1.0 (200M) PyTorch \n**Generated:** 202"
},
{
"path": "timesfm-forecasting/examples/global-temperature/generate_animation_data.py",
"chars": 4993,
"preview": "#!/usr/bin/env python3\n\"\"\"\nGenerate animation data for interactive forecast visualization.\n\nThis script runs TimesFM for"
},
{
"path": "timesfm-forecasting/examples/global-temperature/generate_gif.py",
"chars": 6646,
"preview": "#!/usr/bin/env python3\n\"\"\"\nGenerate animated GIF showing forecast evolution.\n\nCreates a GIF animation showing how the Ti"
},
{
"path": "timesfm-forecasting/examples/global-temperature/generate_html.py",
"chars": 21136,
"preview": "#!/usr/bin/env python3\n\"\"\"\nGenerate a self-contained HTML file with embedded animation data.\n\nThis creates a single HTML"
},
{
"path": "timesfm-forecasting/examples/global-temperature/output/animation_data.json",
"chars": 133208,
"preview": "{\n \"metadata\": {\n \"model\": \"TimesFM 1.0 (200M) PyTorch\",\n \"total_steps\": 25,\n \"min_context\": 12,\n \"max_hori"
},
{
"path": "timesfm-forecasting/examples/global-temperature/output/forecast_output.csv",
"chars": 1494,
"preview": "date,point_forecast,q10,q20,q30,q40,q50,q60,q70,q80,q90,q99\n2025-01-01,1.2593384,1.248188,1.140702,1.1880752,1.2137158,1"
},
{
"path": "timesfm-forecasting/examples/global-temperature/output/forecast_output.json",
"chars": 4524,
"preview": "{\n \"model\": \"TimesFM 1.0 (200M) PyTorch\",\n \"input\": {\n \"source\": \"NOAA GISTEMP Global Temperature Anomaly\",\n \"n_"
},
{
"path": "timesfm-forecasting/examples/global-temperature/output/interactive_forecast.html",
"chars": 153042,
"preview": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n <meta charset=\"UTF-8\">\n <meta name=\"viewport\" content=\"width=device-width"
},
{
"path": "timesfm-forecasting/examples/global-temperature/run_example.sh",
"chars": 1519,
"preview": "#!/bin/bash\n# run_example.sh - Run the TimesFM temperature anomaly forecasting example\n#\n# This script:\n# 1. Runs the pr"
},
{
"path": "timesfm-forecasting/examples/global-temperature/run_forecast.py",
"chars": 5470,
"preview": "#!/usr/bin/env python3\n\"\"\"\nRun TimesFM forecast on global temperature anomaly data.\nGenerates forecast output CSV and JS"
},
{
"path": "timesfm-forecasting/examples/global-temperature/temperature_anomaly.csv",
"chars": 591,
"preview": "date,anomaly_c\n2022-01-01,0.89\n2022-02-01,0.89\n2022-03-01,1.02\n2022-04-01,0.88\n2022-05-01,0.85\n2022-06-01,0.88\n2022-07-0"
},
{
"path": "timesfm-forecasting/examples/global-temperature/visualize_forecast.py",
"chars": 3287,
"preview": "#!/usr/bin/env python3\n\"\"\"\nVisualize TimesFM forecast results for global temperature anomaly.\n\nGenerates a publication-q"
},
{
"path": "timesfm-forecasting/references/api_reference.md",
"chars": 9725,
"preview": "# TimesFM API Reference\n\n## Model Classes\n\n### `timesfm.TimesFM_2p5_200M_torch`\n\nThe primary model class for TimesFM 2.5"
},
{
"path": "timesfm-forecasting/references/data_preparation.md",
"chars": 7175,
"preview": "# Data Preparation for TimesFM\n\n## Input Format\n\nTimesFM accepts a **list of 1-D numpy arrays**. Each array represents o"
},
{
"path": "timesfm-forecasting/references/system_requirements.md",
"chars": 7004,
"preview": "# System Requirements for TimesFM\n\n## Hardware Tiers\n\nTimesFM can run on a variety of hardware configurations. This guid"
},
{
"path": "timesfm-forecasting/scripts/check_system.py",
"chars": 24136,
"preview": "#!/usr/bin/env python3\n\"\"\"TimesFM System Requirements Preflight Checker.\n\nMANDATORY: Run this script before loading Time"
},
{
"path": "timesfm-forecasting/scripts/forecast_csv.py",
"chars": 8664,
"preview": "#!/usr/bin/env python3\n\"\"\"End-to-end CSV forecasting with TimesFM.\n\nLoads a CSV, runs the system preflight check, loads "
},
{
"path": "v1/LICENSE",
"chars": 11358,
"preview": "\n Apache License\n Version 2.0, January 2004\n "
},
{
"path": "v1/README.md",
"chars": 13885,
"preview": "# TimesFM\n\nTimesFM (Time Series Foundation Model) is a pretrained time-series foundation model developed by Google\nRese"
},
{
"path": "v1/TROUBLESHOOTING.md",
"chars": 4703,
"preview": "# Troubleshooting\n\nThis document provides solutions to common issues encountered when using TimesFM.\n\n## Installation Is"
},
{
"path": "v1/docs/contributing.md",
"chars": 1067,
"preview": "# How to Contribute\n\nWe would love to accept your patches and contributions to this project.\n\n## Before you begin\n\n### S"
},
{
"path": "v1/experiments/baselines/__init__.py",
"chars": 573,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/experiments/baselines/timegpt_pipeline.py",
"chars": 7880,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/experiments/extended_benchmarks/README.md",
"chars": 2667,
"preview": "# Extended Benchmarks\n\nThe benchmark setting has been borrowed from Nixtla's original [benchmarking](https://github.com/"
},
{
"path": "v1/experiments/extended_benchmarks/run_timegpt.py",
"chars": 2907,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/experiments/extended_benchmarks/run_timesfm.py",
"chars": 4249,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/experiments/extended_benchmarks/utils.py",
"chars": 9710,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/experiments/long_horizon_benchmarks/README.md",
"chars": 2372,
"preview": "# Extended Benchmarks\n\nWe benchmark on the original test set for ETT datasets as per long horizon benchmark papers (see "
},
{
"path": "v1/experiments/long_horizon_benchmarks/run_eval.py",
"chars": 7626,
"preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "v1/notebooks/covariates.ipynb",
"chars": 14988,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# TimesFM with Covariates\\n\",\n \""
},
{
"path": "v1/notebooks/finetuning.ipynb",
"chars": 18560,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"## Importing relevant packages for "
},
{
"path": "v1/notebooks/finetuning_torch.ipynb",
"chars": 18550,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# Introduction\\n\",\n \"This notebo"
},
{
"path": "v1/peft/README.md",
"chars": 2202,
"preview": "# Fine-Tuning Pipeline\n\nThis folder contains a generic fine-tuning pipeline designed to support multiple PEFT fine-tunin"
},
{
"path": "v1/peft/finetune.py",
"chars": 13525,
"preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "v1/peft/finetune.sh",
"chars": 957,
"preview": "#!/bin/bash\n\n# Script to finetune a model with specific configurations\n# Adjust the parameters below as needed. For a fu"
},
{
"path": "v1/peft/usage.ipynb",
"chars": 5536,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"## Load Base Model\"\n ]\n },\n {\n "
},
{
"path": "v1/pyproject.toml",
"chars": 2269,
"preview": "[tool.poetry]\nname = \"timesfm\"\npackages = [\n { include = \"timesfm\", from = \"src\" },\n { include = \"finetuning\", fro"
},
{
"path": "v1/src/adapter/__init__.py",
"chars": 777,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/src/adapter/dora_layers.py",
"chars": 6256,
"preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "v1/src/adapter/lora_layers.py",
"chars": 5029,
"preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "v1/src/adapter/utils.py",
"chars": 17697,
"preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "v1/src/finetuning/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "v1/src/finetuning/finetuning_example.py",
"chars": 12001,
"preview": "\"\"\"\nExample usage of the TimesFM Finetuning Framework.\n\nFor single GPU:\npython script.py --training_mode=single\n\nFor mul"
},
{
"path": "v1/src/finetuning/finetuning_torch.py",
"chars": 12456,
"preview": "\"\"\"\nTimesFM Finetuner: A flexible framework for finetuning TimesFM models on custom datasets.\n\"\"\"\n\nimport logging\nimport"
},
{
"path": "v1/src/timesfm/__init__.py",
"chars": 1180,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/src/timesfm/data_loader.py",
"chars": 8496,
"preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "v1/src/timesfm/patched_decoder.py",
"chars": 19100,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/src/timesfm/pytorch_patched_decoder.py",
"chars": 26137,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/src/timesfm/time_features.py",
"chars": 6285,
"preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "v1/src/timesfm/timesfm_base.py",
"chars": 27305,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/src/timesfm/timesfm_jax.py",
"chars": 12707,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/src/timesfm/timesfm_torch.py",
"chars": 6369,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/src/timesfm/xreg_lib.py",
"chars": 20671,
"preview": "# Copyright 2024 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "v1/tests/test_timesfm.py",
"chars": 3030,
"preview": "# Copyright 2024 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
}
]
About this extraction
This page contains the full source code of the google-research/timesfm GitHub repository, extracted and formatted as plain text. The extraction includes 85 files (917.3 KB), approximately 280.2k tokens, and a symbol index with 376 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.