Repository: google-research/timesfm
Branch: master
Commit: ed29ec5ef7e2
Files: 85
Total size: 917.3 KB
Directory structure:
gitextract_oox45t84/
├── .gitattributes
├── .github/
│ └── workflows/
│ ├── main.yml
│ └── manual_publish.yml
├── .gitignore
├── AGENTS.md
├── LICENSE
├── README.md
├── pyproject.toml
├── requirements.txt
├── src/
│ └── timesfm/
│ ├── __init__.py
│ ├── configs.py
│ ├── flax/
│ │ ├── __init__.py
│ │ ├── dense.py
│ │ ├── normalization.py
│ │ ├── transformer.py
│ │ └── util.py
│ ├── timesfm_2p5/
│ │ ├── timesfm_2p5_base.py
│ │ ├── timesfm_2p5_flax.py
│ │ └── timesfm_2p5_torch.py
│ ├── torch/
│ │ ├── __init__.py
│ │ ├── dense.py
│ │ ├── normalization.py
│ │ ├── transformer.py
│ │ └── util.py
│ └── utils/
│ └── xreg_lib.py
├── timesfm-forecasting/
│ ├── SKILL.md
│ ├── examples/
│ │ ├── anomaly-detection/
│ │ │ ├── detect_anomalies.py
│ │ │ └── output/
│ │ │ └── anomaly_detection.json
│ │ ├── covariates-forecasting/
│ │ │ ├── demo_covariates.py
│ │ │ └── output/
│ │ │ ├── covariates_metadata.json
│ │ │ └── sales_with_covariates.csv
│ │ └── global-temperature/
│ │ ├── README.md
│ │ ├── generate_animation_data.py
│ │ ├── generate_gif.py
│ │ ├── generate_html.py
│ │ ├── output/
│ │ │ ├── animation_data.json
│ │ │ ├── forecast_output.csv
│ │ │ ├── forecast_output.json
│ │ │ └── interactive_forecast.html
│ │ ├── run_example.sh
│ │ ├── run_forecast.py
│ │ ├── temperature_anomaly.csv
│ │ └── visualize_forecast.py
│ ├── references/
│ │ ├── api_reference.md
│ │ ├── data_preparation.md
│ │ └── system_requirements.md
│ └── scripts/
│ ├── check_system.py
│ └── forecast_csv.py
└── v1/
├── LICENSE
├── README.md
├── TROUBLESHOOTING.md
├── docs/
│ └── contributing.md
├── experiments/
│ ├── baselines/
│ │ ├── __init__.py
│ │ └── timegpt_pipeline.py
│ ├── extended_benchmarks/
│ │ ├── README.md
│ │ ├── run_timegpt.py
│ │ ├── run_timesfm.py
│ │ └── utils.py
│ └── long_horizon_benchmarks/
│ ├── README.md
│ └── run_eval.py
├── notebooks/
│ ├── covariates.ipynb
│ ├── finetuning.ipynb
│ └── finetuning_torch.ipynb
├── peft/
│ ├── README.md
│ ├── finetune.py
│ ├── finetune.sh
│ └── usage.ipynb
├── pyproject.toml
├── src/
│ ├── adapter/
│ │ ├── __init__.py
│ │ ├── dora_layers.py
│ │ ├── lora_layers.py
│ │ └── utils.py
│ ├── finetuning/
│ │ ├── __init__.py
│ │ ├── finetuning_example.py
│ │ └── finetuning_torch.py
│ └── timesfm/
│ ├── __init__.py
│ ├── data_loader.py
│ ├── patched_decoder.py
│ ├── pytorch_patched_decoder.py
│ ├── time_features.py
│ ├── timesfm_base.py
│ ├── timesfm_jax.py
│ ├── timesfm_torch.py
│ └── xreg_lib.py
└── tests/
└── test_timesfm.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitattributes
================================================
# Git LFS tracking for binary outputs in timesfm-forecasting skill
timesfm-forecasting/**/*.png filter=lfs diff=lfs merge=lfs -text
timesfm-forecasting/**/*.gif filter=lfs diff=lfs merge=lfs -text
================================================
FILE: .github/workflows/main.yml
================================================
name: Python package build
on:
push:
branches: [ "master" ]
pull_request:
branches: [ "master" ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.11'
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Create virtual environment
run: uv venv
- name: Install build dependencies
run: |
uv pip install build ".[torch,flax]"
- name: Build package
run: uv run python -m build
================================================
FILE: .github/workflows/manual_publish.yml
================================================
name: Manual PyPI Publish
on:
workflow_dispatch:
jobs:
build-and-publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.11'
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Create virtual environment
run: uv venv
- name: Install build dependencies
run: uv pip install build twine
- name: Build package
run: uv run python -m build
- name: Publish to PyPI
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: uv run twine upload dist/*
================================================
FILE: .gitignore
================================================
.venv/
dist/
__pycache__/
checkpoints/
wandb/
datasets/
results/
timesfm_jax.egg-info/
development_setup.md
================================================
FILE: AGENTS.md
================================================
# TimesFM — Agent Entry Point
This repository ships a first-party **Agent Skill** for TimesFM at:
```
timesfm-forecasting/
└── SKILL.md ← read this for the full skill
```
## Install the skill
Copy the skill directory into your agent's skills folder:
```bash
# Cursor / Claude Code / OpenCode / Codex (global install)
cp -r timesfm-forecasting/ ~/.cursor/skills/
cp -r timesfm-forecasting/ ~/.claude/skills/
# Or project-level
cp -r timesfm-forecasting/ .cursor/skills/
```
Any agent that supports the open [Agent Skills standard](https://agentskills.io) will discover it automatically.
## Working in this repo
If you are developing TimesFM itself (not using it), the source lives in `src/timesfm/`.
Archived v1/v2 code and notebooks are in `v1/`.
Run tests:
```bash
pytest v1/tests/
```
See `README.md` for full developer setup.
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# TimesFM
TimesFM (Time Series Foundation Model) is a pretrained time-series foundation
model developed by Google Research for time-series forecasting.
* Paper:
[A decoder-only foundation model for time-series forecasting](https://arxiv.org/abs/2310.10688),
ICML 2024.
* All checkpoints:
[TimesFM Hugging Face Collection](https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6).
* [Google Research blog](https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/).
* [TimesFM in BigQuery](https://cloud.google.com/bigquery/docs/timesfm-model):
an official Google product.
This open version is not an officially supported Google product.
**Latest Model Version:** TimesFM 2.5
**Archived Model Versions:**
- 1.0 and 2.0: relevant code archived in the sub directory `v1`. You can `pip
install timesfm==1.3.0` to install an older version of this package to load
them.
## Update - Oct. 29, 2025
Added back the covariate support through XReg for TimesFM 2.5.
## Update - Sept. 15, 2025
TimesFM 2.5 is out!
Comparing to TimesFM 2.0, this new 2.5 model:
- uses 200M parameters, down from 500M.
- supports up to 16k context length, up from 2048.
- supports continuous quantile forecast up to 1k horizon via an optional 30M
quantile head.
- gets rid of the `frequency` indicator.
- has a couple of new forecasting flags.
Along with the model upgrade we have also upgraded the inference API. This repo
will be under construction over the next few weeks to
1. add support for an upcoming Flax version of the model (faster inference).
2. add back covariate support.
3. populate more docstrings, docs and notebook.
### Install
1. Clone the repository:
```shell
git clone https://github.com/google-research/timesfm.git
cd timesfm
```
2. Create a virtual environment and install dependencies using `uv`:
```shell
# Create a virtual environment
uv venv
# Activate the environment
source .venv/bin/activate
# Install the package in editable mode with torch
uv pip install -e .[torch]
# Or with flax
uv pip install -e .[flax]
# Or XReg is needed
uv pip install -e .[xreg]
```
3. [Optional] Install your preferred `torch` / `jax` backend based on your OS and accelerators
(CPU, GPU, TPU or Apple Silicon).:
- [Install PyTorch](https://pytorch.org/get-started/locally/).
- [Install Jax](https://docs.jax.dev/en/latest/installation.html#installation)
for Flax.
### Code Example
```python
import torch
import numpy as np
import timesfm
torch.set_float32_matmul_precision("high")
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch")
model.compile(
timesfm.ForecastConfig(
max_context=1024,
max_horizon=256,
normalize_inputs=True,
use_continuous_quantile_head=True,
force_flip_invariance=True,
infer_is_positive=True,
fix_quantile_crossing=True,
)
)
point_forecast, quantile_forecast = model.forecast(
horizon=12,
inputs=[
np.linspace(0, 1, 100),
np.sin(np.linspace(0, 20, 67)),
], # Two dummy inputs
)
point_forecast.shape # (2, 12)
quantile_forecast.shape # (2, 12, 10): mean, then 10th to 90th quantiles.
```
================================================
FILE: pyproject.toml
================================================
[project]
name = "timesfm"
version = "2.0.0"
description = "A time series foundation model."
authors = [
{name = "Rajat Sen", email = "senrajat@google.com"},
{name = "Yichen Zhou", email = "yichenzhou@google.com"},
{name = "Abhimanyu Das", email = "abhidas@google.com"},
{name = "Petros Mol", email = "pmol@google.com"},
{name = "Michael Chertushkin", email = "chertushkinmichael@gmail.com"},
]
license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"numpy>=1.26.4",
"huggingface_hub[cli]>=0.23.0",
"safetensors>=0.5.3",
]
[project.optional-dependencies]
torch = [
"torch>=2.0.0",
]
flax = [
"flax",
"optax",
"einshape",
"orbax-checkpoint",
"jaxtyping",
"jax[cuda]"
]
xreg = [
"jax[cuda]",
"scikit-learn",
]
[tool.ruff]
line-length = 88
indent-width = 2
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
================================================
FILE: requirements.txt
================================================
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o requirements.txt
anyio==4.11.0
# via httpx
certifi==2025.10.5
# via
# httpcore
# httpx
click==8.3.0
# via typer-slim
filelock==3.19.1
# via huggingface-hub
fsspec==2025.9.0
# via huggingface-hub
h11==0.16.0
# via httpcore
hf-xet==1.2.0
# via huggingface-hub
httpcore==1.0.9
# via httpx
httpx==0.28.1
# via huggingface-hub
huggingface-hub==1.0.1
# via timesfm (pyproject.toml)
idna==3.10
# via
# anyio
# httpx
numpy==2.2.6
# via timesfm (pyproject.toml)
packaging==25.0
# via huggingface-hub
pyyaml==6.0.3
# via huggingface-hub
safetensors==0.6.2
# via timesfm (pyproject.toml)
shellingham==1.5.4
# via huggingface-hub
sniffio==1.3.1
# via anyio
tqdm==4.67.1
# via huggingface-hub
typer-slim==0.20.0
# via huggingface-hub
typing-extensions==4.15.0
# via
# anyio
# huggingface-hub
# typer-slim
================================================
FILE: src/timesfm/__init__.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM API."""
from .configs import ForecastConfig
try:
from .timesfm_2p5 import timesfm_2p5_torch
TimesFM_2p5_200M_torch = timesfm_2p5_torch.TimesFM_2p5_200M_torch
except ImportError:
pass
try:
from .timesfm_2p5 import timesfm_2p5_flax
TimesFM_2p5_200M_flax = timesfm_2p5_flax.TimesFM_2p5_200M_flax
except ImportError:
pass
================================================
FILE: src/timesfm/configs.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract configs for TimesFM layers."""
import dataclasses
from typing import Literal
@dataclasses.dataclass(frozen=True)
class ForecastConfig:
"""Options for forecasting.
Attributes:
max_context: The maximum context length. This is used by the complied decode
function at inference time during batched inference. Any input time series
with length less than max_context will be padded with zeros, and with
length greater than max_context will be truncated.
max_horizon: The maximum horizon length. This is used by the complied decode
function at inference time during batched inference. The compiled cached
decoding function will by default forecast till max_horizon.
normalize_inputs: Whether to normalize the inputs. This is useful when the
raw inputs are of extremely large or small magnitudes which may result in
numerical issues.
window_size: The window size for decomposed forecasting.
TODO(siriuz42):implement it.
per_core_batch_size: The batch size per core. Used at inference time during
batched inference when multiple GPU / TPU devices are used.
use_continuous_quantile_head: Whether to use a separate continuous quantile
head to avoid quantile collapsing.
force_flip_invariance: Whether to force flip invariance. TimesFM guarantees
that TimesFM(aX + b) = a * TimesFM(x) + b for a >= 0 by default. This flag
extends it to a < 0 as well.
infer_is_positive: Whether to guarantee nonnegativity of the output if the
input is nonnegative.
fix_quantile_crossing: Whether to fix quantile crossing.
return_backcast: Whether to return backcast.
"""
max_context: int = 0
max_horizon: int = 0
normalize_inputs: bool = False
window_size: int = 0
per_core_batch_size: int = 1
use_continuous_quantile_head: bool = False
force_flip_invariance: bool = True
infer_is_positive: bool = True
fix_quantile_crossing: bool = False
return_backcast: bool = False
@dataclasses.dataclass(frozen=True)
class ResidualBlockConfig:
"""Framework-agnostic config for a residual block."""
input_dims: int
hidden_dims: int
output_dims: int
use_bias: bool
activation: Literal["relu", "swish", "none"]
@dataclasses.dataclass(frozen=True)
class RandomFourierFeaturesConfig:
"""Framework-agnostic config for random fourier features."""
input_dims: int
output_dims: int
projection_stddev: float
use_bias: bool
@dataclasses.dataclass(frozen=True)
class TransformerConfig:
"""Framework-agnostic config for a transformer."""
model_dims: int
hidden_dims: int
num_heads: int
attention_norm: Literal["rms"]
feedforward_norm: Literal["rms"]
qk_norm: Literal["rms", "none"]
use_bias: bool
use_rotary_position_embeddings: bool
ff_activation: Literal["relu", "swish", "none"]
fuse_qkv: bool
@dataclasses.dataclass(frozen=True)
class StackedTransformersConfig:
"""Framework-agnostic config for a stacked transformers."""
num_layers: int
transformer: TransformerConfig
================================================
FILE: src/timesfm/flax/__init__.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: src/timesfm/flax/dense.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dense layers for TimesFM."""
from flax import nnx
import jax
import jax.numpy as jnp
import jaxtyping
from .. import configs
Array = jaxtyping.Array
Bool = jaxtyping.Bool
Float = jaxtyping.Float
Integer = jaxtyping.Integer
Num = jaxtyping.Num
ResidualBlockConfig = configs.ResidualBlockConfig
RandomFourierFeaturesConfig = configs.RandomFourierFeaturesConfig
class ResidualBlock(nnx.Module):
"""Residual block with two linear layers and a linear residual connection."""
def __init__(self, config: ResidualBlockConfig, *, rngs=nnx.Rngs(42)):
self.config = config
self.hidden_layer = nnx.Linear(
in_features=config.input_dims,
out_features=config.hidden_dims,
use_bias=config.use_bias,
rngs=rngs,
)
self.output_layer = nnx.Linear(
in_features=config.hidden_dims,
out_features=config.output_dims,
use_bias=config.use_bias,
rngs=rngs,
)
self.residual_layer = nnx.Linear(
in_features=config.input_dims,
out_features=config.output_dims,
use_bias=config.use_bias,
rngs=rngs,
)
if config.activation == "relu":
self.activation = jax.nn.relu
elif config.activation == "swish":
self.activation = jax.nn.swish
elif config.activation == "none":
self.activation = lambda x: x
else:
raise ValueError(f"Activation: {config.activation} not supported.")
def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... o"]:
return self.output_layer(
self.activation(self.hidden_layer(x))
) + self.residual_layer(x)
class RandomFourierFeatures(nnx.Module):
"""Random Fourier features layer."""
__data__ = ("phrase_shifts",)
def __init__(self, config: RandomFourierFeaturesConfig, *, rngs=nnx.Rngs(42)):
self.config = config
if config.output_dims % 4 != 0:
raise ValueError(
f"Output dims must be a multiple of 4: {config.output_dims} % 4 != 0."
)
num_projected_features = config.output_dims // 4
self.phase_shifts = nnx.Param(jnp.zeros(shape=(2, num_projected_features)))
self.projection_layer = nnx.Linear(
in_features=config.input_dims,
out_features=num_projected_features,
use_bias=config.use_bias,
rngs=rngs,
)
self.residual_layer = nnx.Linear(
in_features=config.input_dims,
out_features=config.output_dims,
use_bias=config.use_bias,
rngs=rngs,
)
def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... o"]:
projected = self.projection_layer(x)
cos_features = jnp.cos(projected)
sin_features = jnp.sin(projected)
sq_wave_1 = jnp.sign(jnp.sin(projected + self.phase_shifts[0, :]))
sq_wave_2 = jnp.sign(jnp.sin(projected + self.phase_shifts[1, :]))
fourier_features = jnp.concatenate(
[cos_features, sin_features, sq_wave_1, sq_wave_2], axis=-1
)
residual = self.residual_layer(x)
return fourier_features + residual
================================================
FILE: src/timesfm/flax/normalization.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers for TimesFM."""
from flax import nnx
import jax
import jax.numpy as jnp
import jaxtyping
Array = jaxtyping.Array
Bool = jaxtyping.Bool
Float = jaxtyping.Float
Integer = jaxtyping.Integer
Num = jaxtyping.Num
class RMSNorm(nnx.Module):
"""RMS normalization."""
__data__ = ("scale",)
def __init__(
self,
num_features: int,
*,
epsilon: float = 1e-6,
rngs=nnx.Rngs(42),
):
del rngs
self.scale = nnx.Param(jnp.zeros(shape=(num_features,)))
self.num_features = num_features
self.epsilon = epsilon
def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
var = jnp.mean(jnp.square(inputs), axis=-1, keepdims=True)
normed_inputs = inputs * jax.lax.rsqrt(var + self.epsilon)
normed_inputs *= self.scale
return normed_inputs
class LayerNorm(nnx.Module):
"""Layer normalization replica of LayerNorm."""
__data__ = ("scale", "bias")
def __init__(self, num_features: int, *, epsilon: float = 1e-6, rngs=nnx.Rngs(42)):
del rngs
self.scale = nnx.Param(jnp.ones(shape=(num_features,)))
self.bias = nnx.Param(jnp.zeros(shape=(num_features,)))
self.num_features = num_features
self.epsilon = epsilon
def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
mean = jnp.mean(inputs, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True)
normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
normed_inputs *= self.scale
normed_inputs += self.bias
return normed_inputs
================================================
FILE: src/timesfm/flax/transformer.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer layers for TimesFM."""
import functools
from typing import Callable
from flax import nnx
from flax.nnx.nn import linear
import jax
from jax import lax
import jax.numpy as jnp
import jaxtyping
from .. import configs
from . import normalization, util
Array = jaxtyping.Array
Bool = jaxtyping.Bool
Float = jaxtyping.Float
Integer = jaxtyping.Integer
Num = jaxtyping.Num
LayerNorm = normalization.LayerNorm
RMSNorm = normalization.RMSNorm
LinearGeneral = linear.LinearGeneral
TransformerConfig = configs.TransformerConfig
DecodeCache = util.DecodeCache
@functools.partial(
jax.jit,
static_argnames=("query_length", "kv_length"),
)
def make_attn_mask(
query_length: int,
num_all_masked_kv: Integer[Array, "b"],
query_index_offset: Integer[Array, "b"] | None = None,
kv_length: int = 0,
) -> Bool[Array, "b 1 q n"]:
"""Makes attention mask."""
if kv_length == 0:
kv_length = query_length
q_index = jnp.arange(query_length)[None, None, :, None]
if query_index_offset is not None:
q_index += query_index_offset[:, None, None, None]
kv_index = jnp.arange(kv_length)[None, None, None, :]
return jnp.logical_and(
q_index >= kv_index,
kv_index >= num_all_masked_kv[:, None, None, None],
)
class RotaryPositionalEmbedding(nnx.Module):
"""Rotary positional embedding."""
def __init__(
self,
embedding_dims: int,
min_timescale: int = 1,
max_timescale: int = 10000,
):
self.embedding_dims = embedding_dims
self.min_timescale = min_timescale
self.max_timescale = max_timescale
def __call__(
self,
inputs: Float[Array, "b ... d"],
position: Array | None = None,
):
"""Generates a JTensor of sinusoids with different frequencies."""
if self.embedding_dims != inputs.shape[-1]:
raise ValueError(
"The embedding dims of the rotary position embedding"
"must match the hidden dimension of the inputs."
)
half_embedding_dim = self.embedding_dims // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
timescale = (
self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
)
if position is None:
seq_length = inputs.shape[1]
position = jnp.arange(seq_length, dtype=jnp.float32)[None, :]
if len(inputs.shape) == 4:
position = position[..., None, None]
timescale = timescale[None, None, None, :]
elif len(inputs.shape) == 3:
position = position[..., None]
timescale = timescale[None, None, :]
else:
raise ValueError("Inputs must be of rank 3 or 4.")
sinusoid_inp = position / timescale
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)
first_half, second_half = jnp.split(inputs, 2, axis=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
first_part = first_part.astype(None)
second_part = second_part.astype(None)
return jnp.concatenate([first_part, second_part], axis=-1)
class PerDimScale(nnx.Module):
"""Per-dimension scaling."""
__data__ = ("per_dim_scale",)
def __init__(self, num_dims: int, *, rngs=nnx.Rngs(42)):
del rngs
self.num_dims = num_dims
self.per_dim_scale = nnx.Param(jnp.zeros(shape=(num_dims,)))
def __call__(self, x: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
return x * (
1.442695041 / jnp.sqrt(self.num_dims) * jax.nn.softplus(self.per_dim_scale)
)
class MultiHeadAttention(nnx.Module):
"""Multi-head attention."""
def __init__(
self,
num_heads: int,
in_features: int,
*,
use_per_dim_scale: bool = True,
use_rotary_position_embeddings: bool = True,
use_bias: bool = False,
deterministic: bool | None = None,
attention_fn: Callable[..., Array] = nnx.dot_product_attention,
qk_norm: str = "rms",
rngs=nnx.Rngs(42),
):
self.num_heads = num_heads
self.in_features = in_features
self.qkv_features = in_features
self.out_features = in_features
self.in_kv_features = in_features
self.deterministic = deterministic
self.use_bias = use_bias
self.attention_fn = attention_fn
self.qk_norm = qk_norm
if self.qkv_features % self.num_heads != 0:
raise ValueError(
f"Memory dimension ({self.qkv_features}) must be divisible by "
f"'num_heads' heads ({self.num_heads})."
)
self.head_dim = self.qkv_features // self.num_heads
linear_general = functools.partial(
LinearGeneral,
out_features=(self.num_heads, self.head_dim),
use_bias=self.use_bias,
)
# project inputs_q to multi-headed q/k/v
# dimensions are then [batch..., length, n_heads, n_features_per_head]
self.query = linear_general(self.in_features, rngs=rngs)
self.key = linear_general(self.in_kv_features, rngs=rngs)
self.value = linear_general(self.in_kv_features, rngs=rngs)
if self.qk_norm == "rms":
self.query_ln = RMSNorm(self.head_dim)
self.key_ln = RMSNorm(self.head_dim)
else:
self.query_ln = None
self.key_ln = None
self.out = LinearGeneral(
in_features=(self.num_heads, self.head_dim),
out_features=self.out_features,
axis=(-2, -1),
use_bias=self.use_bias,
rngs=rngs,
)
self.use_per_dim_scale = use_per_dim_scale
self.use_rotary_position_embeddings = use_rotary_position_embeddings
if self.use_rotary_position_embeddings:
self.rotary_position_embedding = RotaryPositionalEmbedding(
embedding_dims=self.head_dim,
)
else:
self.rotary_position_embedding = None
if use_per_dim_scale:
self.per_dim_scale = PerDimScale(num_dims=self.head_dim, rngs=rngs)
else:
self.per_dim_scale = None
def __call__(
self,
inputs_q: Array,
*,
decode_cache: DecodeCache | None = None,
patch_mask: Array | None = None,
deterministic: bool | None = None,
sow_weights: bool = False,
) -> tuple[Float[Array, "b ... o"], DecodeCache | None]:
"""Applies multi-head dot product attention on the input data."""
_, n_patches, input_in_features = inputs_q.shape
if input_in_features != self.in_features:
raise ValueError(
f"Incompatible input dimension, got {input_in_features} "
f"but module expects {self.in_features}."
)
if patch_mask is None:
patch_mask = jnp.zeros_like(inputs_q.shape[:-1], dtype=jnp.bool)
# For query: rope -> ln -> per_dim_scale
query = self.query(inputs_q)
key = self.key(inputs_q)
value = self.value(inputs_q)
if decode_cache is None:
num_masked = jnp.sum(patch_mask.astype(jnp.int32), axis=-1, keepdims=False)
next_index = jnp.zeros_like(num_masked, dtype=jnp.int32)
else:
num_masked = (
jnp.sum(patch_mask.astype(jnp.int32), axis=-1, keepdims=False)
+ decode_cache.num_masked
)
next_index = decode_cache.next_index
if self.use_rotary_position_embeddings:
position = (
jnp.arange(n_patches, dtype=jnp.int32)[None, :]
+ next_index[:, None]
- num_masked[:, None]
)
query = self.rotary_position_embedding(query, position)
key = self.rotary_position_embedding(key, position)
if self.query_ln is not None:
query = self.query_ln(query)
if self.key_ln is not None:
key = self.key_ln(key)
if self.use_per_dim_scale:
query = self.per_dim_scale(query)
if decode_cache is not None:
# Cached decoding.
_, decode_cache_size, _, _ = decode_cache.value.shape
zero = jnp.array(0, dtype=lax.dtype(next_index.dtype))
start_indices = (zero, next_index[0], zero, zero)
key = lax.dynamic_update_slice(decode_cache.key, key, start_indices)
value = lax.dynamic_update_slice(decode_cache.value, value, start_indices)
decode_cache.key = key
decode_cache.value = value
decode_cache.next_index = next_index + n_patches
decode_cache.num_masked = num_masked
attn_mask = make_attn_mask(
query_length=n_patches,
num_all_masked_kv=num_masked,
query_index_offset=next_index,
kv_length=decode_cache_size,
)
else:
# Training
attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked)
# apply attention
x = self.attention_fn(
query * jnp.sqrt(self.head_dim),
key,
value,
mask=attn_mask,
deterministic=deterministic,
module=self if sow_weights else None,
)
# back to the original inputs dimensions
out = self.out(x)
return out, decode_cache
class Transformer(nnx.Module):
"""Classic Transformer used in TimesFM."""
def __init__(self, config: TransformerConfig, *, rngs=nnx.Rngs(42)):
self.config = config
if config.attention_norm == "rms":
self.pre_attn_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
self.post_attn_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
else:
raise ValueError(f"Layer norm: {config.attention_norm} not supported.")
self.attn = MultiHeadAttention(
num_heads=config.num_heads,
in_features=config.model_dims,
use_per_dim_scale=True,
use_rotary_position_embeddings=config.use_rotary_position_embeddings,
qk_norm=config.qk_norm,
rngs=rngs,
)
if config.feedforward_norm == "rms":
self.pre_ff_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
self.post_ff_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
else:
raise ValueError(f"Layer norm: {config.feedforward_norm} not supported.")
self.ff0 = nnx.Linear(
in_features=config.model_dims,
out_features=config.hidden_dims,
use_bias=config.use_bias,
rngs=rngs,
)
self.ff1 = nnx.Linear(
in_features=config.hidden_dims,
out_features=config.model_dims,
use_bias=config.use_bias,
rngs=rngs,
)
if config.ff_activation == "relu":
self.activation = jax.nn.relu
elif config.ff_activation == "swish":
self.activation = jax.nn.swish
elif config.ff_activation == "none":
self.activation = lambda x: x
else:
raise ValueError(f"Activation: {config.ff_activation} not supported.")
def __call__(
self,
input_embeddings: Float[Array, "b n d"],
patch_mask: Bool[Array, "b n"],
decode_cache: DecodeCache | None = None,
) -> tuple[Float[Array, "b n d"], DecodeCache | None]:
attn_output, decode_cache = self.attn(
inputs_q=self.pre_attn_ln(input_embeddings),
decode_cache=decode_cache,
patch_mask=patch_mask,
sow_weights=False,
deterministic=True,
)
attn_output = self.post_attn_ln(attn_output) + input_embeddings
output_embeddings = (
self.post_ff_ln(self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))))
+ attn_output
)
return output_embeddings, decode_cache
================================================
FILE: src/timesfm/flax/util.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flax utility functions for TimesFM layers."""
import dataclasses
import functools
import jax
import jax.numpy as jnp
import jaxtyping
Float = jaxtyping.Float
Array = jaxtyping.Array
Bool = jaxtyping.Bool
Integer = jaxtyping.Integer
_TOLERANCE = 1e-6
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=False)
class DecodeCache:
"""Cache for decoding."""
next_index: Integer[Array, "b"]
num_masked: Integer[Array, "b"]
key: Float[Array, "b n h d"]
value: Float[Array, "b n h d"]
@jax.jit
def update_running_stats(
n: Float[Array, "b"],
mu: Float[Array, "b"],
sigma: Float[Array, "b"],
x: Float[Array, "b p"],
mask: Bool[Array, "b p"],
) -> tuple[
tuple[Float[Array, "b"], Float[Array, "b"], Float[Array, "b"]],
tuple[Float[Array, "b"], Float[Array, "b"], Float[Array, "b"]],
]:
"""Updates the running stats."""
is_legit = jnp.logical_not(mask)
inc_n = jnp.sum(is_legit.astype(jnp.float32), axis=-1, keepdims=False)
inc_mu = jnp.where(
inc_n == 0, 0.0, jnp.mean(x, axis=-1, keepdims=False, where=is_legit)
)
inc_sigma = jnp.where(
inc_n == 0, 0.0, jnp.std(x, axis=-1, keepdims=False, where=is_legit)
)
new_n = n + inc_n
new_mu = jnp.where(new_n == 0, 0.0, (n * mu + inc_mu * inc_n) / new_n)
new_sigma = jnp.sqrt(
jnp.where(
new_n == 0,
0.0,
(
n * sigma * sigma
+ inc_n * inc_sigma * inc_sigma
+ n * (mu - new_mu) * (mu - new_mu)
+ inc_n * (inc_mu - new_mu) * (inc_mu - new_mu)
)
/ new_n,
)
)
return (w := (new_n, new_mu, new_sigma), w)
def scan_along_axis(f, init, xs, axis: int, **kwargs):
"""Scans along an axis."""
moved_xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), xs)
carry, moved_ys = jax.lax.scan(f, init, moved_xs, **kwargs)
return (
carry,
jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), moved_ys),
)
@functools.partial(jax.jit, static_argnames=("reverse",))
def revin(
x: Float[Array, "b ..."],
mu: Float[Array, "b ..."],
sigma: Float[Array, "b ..."],
reverse: bool = False,
):
"""Reversible per-instance normalization."""
if len(mu.shape) == len(x.shape) - 1:
mu = mu[..., None]
sigma = sigma[..., None]
elif len(mu.shape) == len(x.shape) - 2:
mu = mu[..., None, None]
sigma = sigma[..., None, None]
if reverse:
return x * sigma + mu
else:
return (x - mu) / jnp.where(sigma < _TOLERANCE, 1.0, sigma)
================================================
FILE: src/timesfm/timesfm_2p5/timesfm_2p5_base.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM 2p5 base implementation."""
import dataclasses
from typing import Any, Callable, Sequence
import collections
import numpy as np
from .. import configs
ResidualBlockConfig = configs.ResidualBlockConfig
StackedTransformersConfig = configs.StackedTransformersConfig
TransformerConfig = configs.TransformerConfig
ForecastConfig = configs.ForecastConfig
Category = int | str
XRegMode = str
def strip_leading_nans(arr):
"""Removes contiguous NaN values from the beginning of a NumPy array.
Args:
arr: The input NumPy array.
Returns:
A new NumPy array with leading NaN values removed.
If the array is all NaNs or empty, returns an empty array.
"""
isnan = np.isnan(arr)
first_valid_index = np.argmax(~isnan)
return arr[first_valid_index:]
def linear_interpolation(arr):
"""Performs linear interpolation to fill NaN values in a 1D numpy array.
Args:
arr: The 1D numpy array containing NaN values.
Returns:
A new numpy array with NaN values filled using linear interpolation,
or the original array if no NaNs are present.
Returns None if the input is not a 1D array.
Returns the original array if there are no NaN values.
"""
nans = np.isnan(arr)
if not np.any(nans): # Check if there are any NaNs
return arr
def x(z):
return z.nonzero()[0]
nans_indices = x(nans)
non_nans_indices = x(~nans)
non_nans_values = arr[~nans]
try:
arr[nans] = np.interp(nans_indices, non_nans_indices, non_nans_values)
except ValueError:
if non_nans_values:
mu = np.nanmean(arr)
else:
mu = 0.0
arr = np.where(np.isfinite(arr), arr, mu)
return arr
@dataclasses.dataclass(frozen=True)
class TimesFM_2p5_200M_Definition:
"""Framework-agnostic config of TimesFM 2.5."""
context_limit = 16384
input_patch_len: int = 32
output_patch_len: int = 128
output_quantile_len: int = 1024
quantiles: list[float] = dataclasses.field(
default_factory=lambda: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
)
decode_index: int = 5
tokenizer: ResidualBlockConfig = ResidualBlockConfig(
input_dims=64,
hidden_dims=1280,
output_dims=1280,
use_bias=True,
activation="swish",
)
stacked_transformers: StackedTransformersConfig = StackedTransformersConfig(
num_layers=20,
transformer=TransformerConfig(
model_dims=1280,
hidden_dims=1280,
num_heads=16,
attention_norm="rms",
feedforward_norm="rms",
qk_norm="rms",
use_bias=False,
use_rotary_position_embeddings=True,
ff_activation="swish",
fuse_qkv=True,
),
)
output_projection_point: ResidualBlockConfig = ResidualBlockConfig(
input_dims=1280,
hidden_dims=1280,
output_dims=1280,
use_bias=False,
activation="swish",
)
output_projection_quantiles: ResidualBlockConfig = ResidualBlockConfig(
input_dims=1280,
hidden_dims=1280,
output_dims=10240,
use_bias=False,
activation="swish",
)
class TimesFM_2p5:
"""Abstract base class for TimesFM models.
Attributes:
forecast_config: Configuration for forecasting flags.
compiled_decode: Compiled decode function.
global_batch_size: Global batch size.
"""
forecast_config: ForecastConfig | None = None
compiled_decode: Callable[..., Any] | None = None
global_batch_size: int = 0
def load_checkpoint(self, path: str):
"""Loads a TimesFM model from a checkpoint."""
raise NotImplementedError()
def compile(self, forecast_config: ForecastConfig | None = None):
"""Compiles the TimesFM model for fast decoding."""
raise NotImplementedError()
def forecast(
self, horizon: int, inputs: list[np.ndarray]
) -> tuple[np.ndarray, np.ndarray]:
"""Forecasts the time series."""
if self.compiled_decode is None:
raise RuntimeError("Model is not compiled. Please call compile() first.")
assert self.global_batch_size > 0
assert self.forecast_config is not None
context = self.forecast_config.max_context
num_inputs = len(inputs)
if (w := num_inputs % self.global_batch_size) != 0:
inputs += [np.array([0.0] * 3)] * (self.global_batch_size - w)
output_points = []
output_quantiles = []
values = []
masks = []
idx = 0
for each_input in inputs:
value = linear_interpolation(strip_leading_nans(np.array(each_input)))
if (w := len(value)) >= context:
value = value[-context:]
mask = np.zeros_like(value, dtype=bool)
else:
mask = np.array([True] * (context - w) + [False] * w)
value = np.pad(value, (context - w, 0), "constant", constant_values=0.0)
values.append(value)
masks.append(mask)
idx += 1
if idx == self.global_batch_size:
idx = 0
point_forecast, quantile_forecast = self.compiled_decode(horizon, values, masks)
output_points.append(point_forecast)
output_quantiles.append(quantile_forecast)
values = []
masks = []
output_points = np.concatenate(output_points, axis=0)
output_quantiles = np.concatenate(output_quantiles, axis=0)
return output_points[:num_inputs], output_quantiles[:num_inputs]
def forecast_with_covariates(
self,
inputs: list[Sequence[float]],
dynamic_numerical_covariates: dict[str, Sequence[Sequence[float]]] | None = None,
dynamic_categorical_covariates: (
dict[str, Sequence[Sequence[Category]]] | None
) = None,
static_numerical_covariates: dict[str, Sequence[float]] | None = None,
static_categorical_covariates: dict[str, Sequence[Category]] | None = None,
xreg_mode: XRegMode = "xreg + timesfm",
normalize_xreg_target_per_input: bool = True,
ridge: float = 0.0,
max_rows_per_col: int = 0,
force_on_cpu: bool = False,
):
"""Forecasts on a list of time series with covariates.
To optimize inference speed, avoid string valued categorical covariates.
Args:
inputs: A list of time series forecast contexts. Each context time series
should be in a format convertible to JTensor by `jnp.array`.
dynamic_numerical_covariates: A dict of dynamic numerical covariates.
dynamic_categorical_covariates: A dict of dynamic categorical covariates.
static_numerical_covariates: A dict of static numerical covariates.
static_categorical_covariates: A dict of static categorical covariates.
xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "timesfm + xreg"
fits a model on the residuals of the TimesFM forecast. "xreg + timesfm"
fits a model on the targets then forecasts on the residuals via TimesFM.
normalize_xreg_target_per_input: whether to normalize the xreg target per
input in the given batch.
ridge: ridge penalty for the linear model.
max_rows_per_col: max number of rows per column for the linear model.
force_on_cpu: whether to force running on cpu for the linear model.
Returns:
A tuple of two lists. The first is the outputs of the model. The second is
the outputs of the xreg.
"""
if self.forecast_config is None:
raise ValueError("Model is not compiled. Please call compile() first.")
elif not self.forecast_config.return_backcast:
raise ValueError(
"For XReg, `return_backcast` must be set to True in the forecast config. Please recompile the model."
)
from ..utils import xreg_lib
# Verify and bookkeep covariates.
if not (
dynamic_numerical_covariates
or dynamic_categorical_covariates
or static_numerical_covariates
or static_categorical_covariates
):
raise ValueError(
"At least one of dynamic_numerical_covariates,"
" dynamic_categorical_covariates, static_numerical_covariates,"
" static_categorical_covariates must be set."
)
# Track the lengths of (1) each input, (2) the part that can be used in the
# linear model, and (3) the horizon.
input_lens, train_lens, test_lens = [], [], []
for i, input_ts in enumerate(inputs):
input_len = len(input_ts)
input_lens.append(input_len)
if xreg_mode == "timesfm + xreg":
# For fitting residuals, no TimesFM forecast on the first patch.
train_lens.append(max(0, input_len - self.model.p))
elif xreg_mode == "xreg + timesfm":
train_lens.append(input_len)
else:
raise ValueError(f"Unsupported mode: {xreg_mode}")
if dynamic_numerical_covariates:
test_lens.append(
len(list(dynamic_numerical_covariates.values())[0][i]) - input_len
)
elif dynamic_categorical_covariates:
test_lens.append(
len(list(dynamic_categorical_covariates.values())[0][i]) - input_len
)
else:
test_lens.append(self.forecast_config.max_horizon)
if test_lens[-1] > self.forecast_config.max_horizon:
raise ValueError(
"Forecast horizon length inferred from the dynamic covaraites is longer than the"
f"max_horizon defined in the forecast config: {test_lens[-1]} > {self.forecast_config.max_horizon=}."
)
# Prepare the covariates into train and test.
train_dynamic_numerical_covariates = collections.defaultdict(list)
test_dynamic_numerical_covariates = collections.defaultdict(list)
train_dynamic_categorical_covariates = collections.defaultdict(list)
test_dynamic_categorical_covariates = collections.defaultdict(list)
for covariates, train_covariates, test_covariates in (
(
dynamic_numerical_covariates,
train_dynamic_numerical_covariates,
test_dynamic_numerical_covariates,
),
(
dynamic_categorical_covariates,
train_dynamic_categorical_covariates,
test_dynamic_categorical_covariates,
),
):
if not covariates:
continue
for covariate_name, covariate_values in covariates.items():
for input_len, train_len, covariate_value in zip(
input_lens, train_lens, covariate_values
):
train_covariates[covariate_name].append(
covariate_value[(input_len - train_len) : input_len]
)
test_covariates[covariate_name].append(covariate_value[input_len:])
# Fit models.
if xreg_mode == "timesfm + xreg":
# Forecast via TimesFM then fit a model on the residuals.
point_outputs, quantile_outputs = self.forecast(
horizon=self.forecast_config.max_horizon, inputs=inputs
)
targets = [
(
np.array(input_ts)[-train_len:]
- point_output[: -self.forecast_config.max_horizon][-train_len:]
)
for input_ts, point_output, train_len in zip(inputs, point_outputs, train_lens)
]
per_instance_stats = None
if normalize_xreg_target_per_input:
targets, per_instance_stats = xreg_lib.normalize(targets)
xregs = xreg_lib.BatchedInContextXRegLinear(
targets=targets,
train_lens=train_lens,
test_lens=test_lens,
train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
train_dynamic_categorical_covariates=train_dynamic_categorical_covariates,
test_dynamic_categorical_covariates=test_dynamic_categorical_covariates,
static_numerical_covariates=static_numerical_covariates,
static_categorical_covariates=static_categorical_covariates,
).fit(
ridge=ridge,
one_hot_encoder_drop=None if ridge > 0 else "first",
max_rows_per_col=max_rows_per_col,
force_on_cpu=force_on_cpu,
debug_info=False,
assert_covariates=True,
assert_covariate_shapes=True,
)
if normalize_xreg_target_per_input:
xregs = xreg_lib.renormalize(xregs, per_instance_stats)
xregs = np.array(xregs)
new_point_outputs = [
(point_output[-self.forecast_config.max_horizon :][:test_len] + xreg)
for point_output, test_len, xreg in zip(point_outputs, test_lens, xregs)
]
new_quantile_outputs = [
(
quantile_output[-self.forecast_config.max_horizon :][:test_len]
+ xreg[..., None]
)
for quantile_output, test_len, xreg in zip(quantile_outputs, test_lens, xregs)
]
else:
# Fit a model on the targets then forecast on the residuals via TimesFM.
targets = [
np.array(input_ts)[-train_len:]
for input_ts, train_len in zip(inputs, train_lens)
]
per_instance_stats = None
if normalize_xreg_target_per_input:
targets, per_instance_stats = xreg_lib.normalize(targets)
xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear(
targets=targets,
train_lens=train_lens,
test_lens=test_lens,
train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
train_dynamic_categorical_covariates=train_dynamic_categorical_covariates,
test_dynamic_categorical_covariates=test_dynamic_categorical_covariates,
static_numerical_covariates=static_numerical_covariates,
static_categorical_covariates=static_categorical_covariates,
).fit(
ridge=ridge,
one_hot_encoder_drop=None if ridge > 0 else "first",
max_rows_per_col=max_rows_per_col,
force_on_cpu=force_on_cpu,
debug_info=True,
assert_covariates=True,
assert_covariate_shapes=True,
)
point_outputs, quantile_outputs = self.forecast(
horizon=self.forecast_config.max_horizon,
inputs=[
target - xreg_on_context
for target, xreg_on_context in zip(targets, xregs_on_context)
],
)
new_point_outputs = [
(point_output[-self.forecast_config.max_horizon :][:test_len] + xreg)
for point_output, test_len, xreg in zip(point_outputs, test_lens, xregs)
]
new_quantile_outputs = [
(
quantile_output[-self.forecast_config.max_horizon :][:test_len]
+ xreg[..., None]
)
for quantile_output, test_len, xreg in zip(quantile_outputs, test_lens, xregs)
]
if normalize_xreg_target_per_input:
new_point_outputs = xreg_lib.renormalize(new_point_outputs, per_instance_stats)
new_quantile_outputs = xreg_lib.renormalize(
new_quantile_outputs, per_instance_stats
)
return new_point_outputs, new_quantile_outputs
================================================
FILE: src/timesfm/timesfm_2p5/timesfm_2p5_flax.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM models in Flax."""
import dataclasses
import functools
import gc
import logging
import math
import os
from pathlib import Path
from typing import Any, Callable, Dict
import einshape
from flax import nnx
import huggingface_hub
import jax
import jax.numpy as jnp
import jaxtyping
import numpy as np
import orbax.checkpoint as ocp
from .. import configs
from ..flax import dense, transformer, util
from . import timesfm_2p5_base
jax_einshape = einshape.jax_einshape
scan = util.scan_along_axis
revin = util.revin
Float = jaxtyping.Float
Bool = jaxtyping.Bool
Array = jaxtyping.Array
def try_gc():
for d in jax.local_devices():
stats = d.memory_stats()
if stats is None:
return
if stats["bytes_in_use"] / stats["bytes_limit"] > 0.75:
gc.collect()
break
@nnx.vmap(in_axes=(None, 0), out_axes=0)
def _create_stacked_transformers(
config: configs.StackedTransformersConfig, key: jax.Array
):
return transformer.Transformer(config.transformer, rngs=nnx.Rngs(key))
def _scan_along_axis(f, init, xs, axis: int, **kwargs):
"""Scans along an axis."""
moved_xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), xs)
carry, moved_ys = jax.lax.scan(f, init, moved_xs, **kwargs)
return (
carry,
jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), moved_ys),
)
@nnx.scan(in_axes=(0, nnx.Carry, None, 0), out_axes=(nnx.Carry, 0))
def _apply_stacked_transformers(
model: transformer.Transformer,
x: Float[Array, "b n d"],
m: Float[Array, "b n"],
decode_cache: util.DecodeCache | None = None,
) -> Float[Array, "b n d"]:
return model(x, m, decode_cache=decode_cache)
class TimesFM_2p5_200M_flax_module(nnx.Module): # pylint: disable=invalid-name
"""TimesFM 2.5 with 200M parameters."""
config = timesfm_2p5_base.TimesFM_2p5_200M_Definition()
decode_index: int = 5
compiled_decode: Callable[..., Any] | None = None
backend: str = ""
context: int = 0
horizon: int = 0
per_core_batch_size: int = 0
def __init__(self):
super().__init__()
self.backend = jax.devices()[0].platform
self.num_devices = len(jax.devices(self.backend))
# Names constants.
self.p = self.config.input_patch_len # 32
self.o = self.config.output_patch_len # 128
self.os = self.config.output_quantile_len # 1024
self.m = self.o // self.p # 4
self.x = self.config.stacked_transformers.num_layers # 20
self.h = self.config.stacked_transformers.transformer.num_heads # 16
self.md = self.config.stacked_transformers.transformer.model_dims # 1280
self.hd = self.md // self.h # 80
self.q = len(self.config.quantiles) + 1 # 10
self.aridx = self.config.decode_index # 5
# Layers.
self.tokenizer = dense.ResidualBlock(self.config.tokenizer)
self.stacked_xf = _create_stacked_transformers(
self.config.stacked_transformers,
jax.random.split(jax.random.key(42), self.x),
)
self.output_projection_point = dense.ResidualBlock(
self.config.output_projection_point
)
self.output_projection_quantiles = dense.ResidualBlock(
self.config.output_projection_quantiles
)
def __call__(
self,
inputs: Float[Array, "b n p"],
masks: Bool[Array, "b n p"],
decode_cache: util.DecodeCache | None = None,
):
tokenizer_inputs = jnp.concatenate([inputs, masks.astype(inputs.dtype)], axis=-1)
input_embeddings = self.tokenizer(tokenizer_inputs)
if decode_cache is None:
decode_cache = [None] * self.x
output_embeddings, decode_cache = _apply_stacked_transformers(
self.stacked_xf, input_embeddings, masks[..., -1], decode_cache
)
output_ts = self.output_projection_point(output_embeddings)
output_quantile_spread = self.output_projection_quantiles(output_embeddings)
return (
input_embeddings,
output_embeddings,
output_ts,
output_quantile_spread,
), decode_cache
@nnx.jit(static_argnames=("horizon",))
def decode(self, horizon: int, inputs, masks):
batch_size, context = inputs.shape[0], inputs.shape[1]
num_decode_steps = (horizon - 1) // self.o
num_input_patches = context // self.p
decode_cache_size = num_input_patches + num_decode_steps * self.m
# Prefill
patched_inputs = jax_einshape("b(np)->bnp", inputs, b=batch_size, p=self.p)
patched_masks = jax_einshape("b(np)->bnp", masks, b=batch_size, p=self.p)
(last_n, last_mu, last_sigma), (_, context_mu, context_sigma) = scan(
lambda carry, xs: util.update_running_stats(*carry, *xs),
init=(zero := jnp.zeros(shape=(batch_size)), zero, zero),
xs=(patched_inputs, patched_masks),
axis=1,
)
decode_cache = util.DecodeCache(
next_index=jnp.zeros(shape=(self.x, batch_size), dtype=jnp.int32),
num_masked=jnp.zeros(shape=(self.x, batch_size), dtype=jnp.int32),
key=jnp.zeros(shape=(self.x, batch_size, decode_cache_size, self.h, self.hd)),
value=jnp.zeros(shape=(self.x, batch_size, decode_cache_size, self.h, self.hd)),
)
normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
normed_inputs = jnp.where(patched_masks, 0.0, normed_inputs)
(_, _, normed_outputs, normed_quantile_spread), decode_cache = self(
normed_inputs, patched_masks, decode_cache
)
renormed_outputs = jax_einshape(
"bn(oq)->bnoq",
revin(normed_outputs, context_mu, context_sigma, reverse=True),
o=self.o,
q=self.q,
)
renormed_quantile_spread = jax_einshape(
"bn(oq)->bnoq",
revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
o=self.os,
q=self.q,
)[:, -1, ...]
# Autogressive decode
@nnx.scan(in_axes=(None, nnx.Carry, 0), out_axes=(nnx.Carry, 1))
def _ar_decode(module, carry, unused_iter):
last_renormed_output, (last_n, last_mu, last_sigma), decode_cache = carry
new_patched_input = jax_einshape(
"b(mp)->bmp", last_renormed_output, m=module.m, p=module.p
)
new_mask = jnp.zeros_like(new_patched_input, dtype=jnp.bool)
carry_stats, (_, new_mu, new_sigma) = scan(
lambda carry, xs: util.update_running_stats(*carry, *xs),
init=(last_n, last_mu, last_sigma),
xs=(new_patched_input, new_mask),
axis=1,
)
new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
(_, _, new_normed_output, _), decode_cache = module(
new_normed_input, new_mask, decode_cache
)
new_renormed_output = jax_einshape(
"bm(oq)->bmoq",
revin(new_normed_output, new_mu, new_sigma, reverse=True),
o=module.o,
q=module.q,
)[..., -1, :, :]
return (
(
new_renormed_output[..., module.decode_index],
carry_stats,
decode_cache,
),
new_renormed_output,
)
if num_decode_steps > 0:
_, ar_renormed_outputs = _ar_decode(
self,
(
renormed_outputs[..., -1, :, self.decode_index],
(last_n, last_mu, last_sigma),
decode_cache,
),
jnp.arange(num_decode_steps),
)
else:
ar_renormed_outputs = None
return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
def compile(
self,
context: int,
horizon: int,
per_core_batch_size: int = 1,
):
if context % self.p != 0:
logging.info(
"When compiling, context needs to be multiple of the patch size %d."
" Modifying context to %d.",
self.p,
context := math.ceil(context / self.p) * self.p,
)
if horizon % self.o != 0:
logging.info(
"When compiling, horizon needs to be multiple of the output patch"
" size %d. Modifying horizon to %d.",
self.o,
horizon := math.ceil(horizon / self.o) * self.o,
)
self.context = context
self.horizon = horizon
self.per_core_batch_size = per_core_batch_size
@nnx.pmap(
in_axes=(None, None, 0, 0),
out_axes=(0, 0, 0),
devices=jax.devices(self.backend),
axis_size=self.num_devices,
static_broadcasted_argnums=(1,),
axis_name="global_batch",
)
def compiled_decode_kernel(model, horizon, inputs, masks):
return model.decode(horizon, inputs, masks)
self.compiled_decode = functools.partial(compiled_decode_kernel, self)
def _flip_quantile_fn(x):
return jnp.concatenate([x[..., :1], jnp.flip(x[..., 1:], axis=-1)], axis=-1)
@functools.partial(
jax.jit,
donate_argnums=(0, 1, 2),
)
def _force_flip_invariance_fn(
flipped_pf_outputs,
flipped_quantile_spreads,
flipped_ar_outputs,
):
"""Forces flip invariance."""
flipped_pf_outputs = _flip_quantile_fn(flipped_pf_outputs)
flipped_pf_outputs = jax_einshape("tb...->(tb)...", flipped_pf_outputs)
flipped_quantile_spreads = _flip_quantile_fn(flipped_quantile_spreads)
flipped_quantile_spreads = jax_einshape("tb...->(tb)...", flipped_quantile_spreads)
to_concat = [flipped_pf_outputs[:, -1, ...]]
if flipped_ar_outputs is not None:
flipped_ar_outputs = _flip_quantile_fn(flipped_ar_outputs)
flipped_ar_outputs = jax_einshape("tbno...->(tb)(no)...", flipped_ar_outputs)
to_concat.append(flipped_ar_outputs)
flipped_full_forecast = jnp.concatenate(to_concat, axis=1)
return flipped_quantile_spreads, flipped_pf_outputs, flipped_full_forecast
@functools.partial(
jax.jit,
static_argnames=("max_horizon",),
donate_argnums=(0,),
)
def _use_continuous_quantile_head_fn(full_forecast, quantile_spreads, max_horizon):
"""Uses continuous quantile head."""
to_stack = [full_forecast[..., :max_horizon, 0]]
for quantile_index in [1, 2, 3, 4]:
to_stack.append(
quantile_spreads[:, :max_horizon, quantile_index]
- quantile_spreads[:, :max_horizon, 5]
+ full_forecast[:, :max_horizon, 5]
)
to_stack.append(full_forecast[..., :max_horizon, 5])
for quantile_index in [6, 7, 8, 9]:
to_stack.append(
quantile_spreads[:, :max_horizon, quantile_index]
- quantile_spreads[:, :max_horizon, 5]
+ full_forecast[:, :max_horizon, 5]
)
return jnp.stack(to_stack, axis=-1)
@functools.partial(jax.jit, donate_argnums=(0,))
def _fix_quantile_crossing_fn(full_forecast):
"""Fixes quantile crossing."""
lower_quantiles = _scan_along_axis(
lambda carry, x: (w := jnp.minimum(carry, x), w),
init=full_forecast[..., 5],
xs=full_forecast[..., 1:5],
axis=-1,
reverse=True,
)[1]
upper_quantiles = _scan_along_axis(
lambda carry, x: (w := jnp.maximum(carry, x), w),
init=full_forecast[..., 5],
xs=full_forecast[..., 6:10],
axis=-1,
reverse=False,
)[1]
return jnp.concatenate(
[
full_forecast[..., :1],
lower_quantiles,
full_forecast[..., 5:6],
upper_quantiles,
],
axis=-1,
)
@functools.partial(jax.jit, static_argnames=("fc",), donate_argnums=(1, 2))
def _before_model_decode(fc, inputs, masks):
"""All Jax steps before model decode call."""
if fc.infer_is_positive:
is_positive = jnp.all(inputs >= 0, axis=-1, keepdims=True)
else:
is_positive = None
if fc.normalize_inputs:
mu = jnp.mean(inputs, axis=-1, keepdims=True)
sigma = jnp.std(inputs, axis=-1, keepdims=True)
inputs = revin(inputs, mu, sigma, reverse=False)
else:
mu, sigma = None, None
inputs = jax_einshape("(tb)...->tb...", inputs, b=fc.per_core_batch_size)
masks = jax_einshape("(tb)...->tb...", masks, b=fc.per_core_batch_size)
return inputs, masks, is_positive, mu, sigma
@functools.partial(
jax.jit,
static_argnames=(
"fc",
"p",
),
donate_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9),
)
def _after_model_decode(
fc,
pf_outputs,
quantile_spreads,
ar_outputs,
flipped_pf_outputs,
flipped_quantile_spreads,
flipped_ar_outputs,
is_positive,
mu,
sigma,
p,
):
"""All Jax steps after model decode call."""
# t: num_devices, b: per_core_batch_size
pf_outputs = jax_einshape("tb...->(tb)...", pf_outputs)
quantile_spreads = jax_einshape("tb...->(tb)...", quantile_spreads)
to_concat = [pf_outputs[:, -1, ...]]
if ar_outputs is not None:
ar_outputs = jax_einshape("tbno...->(tb)(no)...", ar_outputs)
to_concat.append(ar_outputs)
full_forecast = jnp.concatenate(to_concat, axis=1)
if fc.force_flip_invariance:
(
flipped_quantile_spreads,
flipped_pf_outputs,
flipped_full_forecast,
) = _force_flip_invariance_fn(
flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs
)
quantile_spreads = (quantile_spreads - flipped_quantile_spreads) / 2
pf_outputs = (pf_outputs - flipped_pf_outputs) / 2
full_forecast = (full_forecast - flipped_full_forecast) / 2
if fc.use_continuous_quantile_head:
full_forecast = _use_continuous_quantile_head_fn(
full_forecast, quantile_spreads, fc.max_horizon
)
if fc.return_backcast:
full_backcast = jax_einshape("...npq->...(np)q", pf_outputs[:, :-1, :p, :])
full_forecast = jnp.concatenate([full_backcast, full_forecast], axis=1)
if fc.fix_quantile_crossing:
full_forecast = _fix_quantile_crossing_fn(full_forecast)
if fc.normalize_inputs:
full_forecast = revin(full_forecast, mu, sigma, reverse=True)
if is_positive is not None:
full_forecast = jnp.where(
is_positive[..., None],
jnp.maximum(full_forecast, jnp.zeros_like(full_forecast)),
full_forecast,
)
return full_forecast
class TimesFM_2p5_200M_flax(timesfm_2p5_base.TimesFM_2p5):
"""Flax implementation of TimesFM 2.5 with 200M parameters."""
model: nnx.Module = TimesFM_2p5_200M_flax_module()
@classmethod
def from_pretrained(
cls,
model_id: str = "google/timesfm-2.5-200m-flax",
*,
revision: str | None = None,
cache_dir: str | Path | None = None,
force_download: bool = False,
proxies: Dict | None = None,
resume_download: bool | None = None,
local_files_only: bool | None = None,
token: str | None = None,
**model_kwargs,
):
"""Loads a Flax TimesFM model."""
# Create an instance of the model wrapper class.
instance = cls(**model_kwargs)
# Determine the path to the model weights.
model_file_path = ""
if os.path.isdir(model_id):
logging.info("Loading checkpoint from local directory: %s", model_id)
model_file_path = model_id
else:
logging.info("Downloading checkpoint from Hugging Face repo %s", model_id)
model_file_path = huggingface_hub.snapshot_download(
repo_id=model_id,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
logging.info("Loading checkpoint from: %s", model_file_path)
checkpointer = ocp.StandardCheckpointer()
graph, state = nnx.split(instance.model)
state = checkpointer.restore(model_file_path, state)
instance.model = nnx.merge(graph, state)
return instance
def compile(
self,
forecast_config: configs.ForecastConfig,
dryrun: bool = True,
**kwargs
):
# Acrobym used during validation.
print("Compiling model...")
fc = forecast_config
if fc.max_context % self.model.p != 0:
logging.info(
"When compiling, max context needs to be multiple of the patch size"
" %d. Using max context = %d instead.",
self.model.p,
new_context := math.ceil(fc.max_context / self.model.p) * self.model.p,
)
fc = dataclasses.replace(fc, max_context=new_context)
if fc.max_horizon % self.model.o != 0:
logging.info(
"When compiling, max horizon needs to be multiple of the output patch"
" size %d. Using max horizon = %d instead.",
self.model.o,
new_horizon := math.ceil(fc.max_horizon / self.model.o) * self.model.o,
)
fc = dataclasses.replace(fc, max_horizon=new_horizon)
if fc.max_context + fc.max_horizon > self.model.config.context_limit:
raise ValueError(
"Context + horizon must be less than the context limit."
f" {fc.max_context} + {fc.max_horizon} >"
f" {self.model.config.context_limit}."
)
if fc.use_continuous_quantile_head and (fc.max_horizon > self.model.os):
raise ValueError(
f"Continuous quantile head is not supported for horizons > {self.model.os}."
)
self.forecast_config = fc
self.model.compile(
context=self.forecast_config.max_context,
horizon=self.forecast_config.max_horizon,
per_core_batch_size=fc.per_core_batch_size,
)
self.per_core_batch_size = self.forecast_config.per_core_batch_size
self.num_devices = self.model.num_devices
self.global_batch_size = (
self.forecast_config.per_core_batch_size * self.model.num_devices
)
def compiled_decode_kernel(fc, horizon, inputs, masks):
inputs = jnp.array(inputs, dtype=jnp.float32)
masks = jnp.array(masks, dtype=jnp.bool)
if horizon > fc.max_horizon:
raise ValueError(
f"Horizon must be less than the max horizon. {horizon} > {fc.max_horizon}."
)
to_trim = fc.max_horizon - horizon
inputs, masks, is_positive, mu, sigma = _before_model_decode(fc, inputs, masks)
pf_outputs, quantile_spreads, ar_outputs = self.model.compiled_decode(
fc.max_horizon, inputs, masks
)
if fc.force_flip_invariance:
flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
self.model.compiled_decode(fc.max_horizon, -inputs, masks)
)
else:
flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
None,
None,
None,
)
full_forecast = _after_model_decode(
fc,
pf_outputs,
quantile_spreads,
ar_outputs,
flipped_pf_outputs,
flipped_quantile_spreads,
flipped_ar_outputs,
is_positive,
mu,
sigma,
self.model.p,
)
full_forecast_np = np.array(full_forecast)
del full_forecast
try_gc()
if to_trim > 0:
full_forecast_np = full_forecast_np[..., :-to_trim, :]
return full_forecast_np[..., 5], full_forecast_np
self.compiled_decode = functools.partial(
compiled_decode_kernel, self.forecast_config
)
if dryrun:
_ = self.compiled_decode(
self.forecast_config.max_horizon,
jnp.zeros(
(self.global_batch_size, self.forecast_config.max_context), dtype=jnp.float32
),
jnp.zeros(
(self.global_batch_size, self.forecast_config.max_context), dtype=jnp.bool
),
)
print("Compiling done.")
================================================
FILE: src/timesfm/timesfm_2p5/timesfm_2p5_torch.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM models."""
import dataclasses
import logging
import math
import os
from pathlib import Path
from typing import Optional, Sequence, Union
import numpy as np
import torch
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from safetensors.torch import load_file, save_file
from torch import nn
from .. import configs
from ..torch import dense, transformer, util
from . import timesfm_2p5_base
revin = util.revin
class TimesFM_2p5_200M_torch_module(nn.Module):
"""TimesFM 2.5 with 200M parameters."""
config = timesfm_2p5_base.TimesFM_2p5_200M_Definition()
def __init__(self):
super().__init__()
# Names constants.
self.p = self.config.input_patch_len # 32
self.o = self.config.output_patch_len # 128
self.os = self.config.output_quantile_len # 1024
self.m = self.o // self.p # 4
self.x = self.config.stacked_transformers.num_layers # 20
self.h = self.config.stacked_transformers.transformer.num_heads # 16
self.md = self.config.stacked_transformers.transformer.model_dims # 1280
self.hd = self.md // self.h # 80
self.q = len(self.config.quantiles) + 1 # 10
self.aridx = self.config.decode_index # 5
# Layers.
self.tokenizer = dense.ResidualBlock(self.config.tokenizer)
self.stacked_xf = nn.ModuleList(
[
transformer.Transformer(self.config.stacked_transformers.transformer)
for _ in range(self.x)
]
)
self.output_projection_point = dense.ResidualBlock(
self.config.output_projection_point
)
self.output_projection_quantiles = dense.ResidualBlock(
self.config.output_projection_quantiles
)
# Device.
if torch.cuda.is_available():
self.device = torch.device("cuda:0")
self.device_count = torch.cuda.device_count()
else:
self.device = torch.device("cpu")
self.device_count = 1
def load_checkpoint(self, path: str, **kwargs):
"""Loads a PyTorch TimesFM model from a checkpoint."""
tensors = load_file(path)
self.load_state_dict(tensors, strict=True)
self.to(self.device)
torch_compile = True
if "torch_compile" in kwargs:
torch_compile = kwargs["torch_compile"]
if torch_compile:
print("Compiling model...")
self = torch.compile(self)
self.eval()
def forward(
self,
inputs: torch.Tensor,
masks: torch.Tensor,
decode_caches: list[util.DecodeCache] | None = None,
):
tokenizer_inputs = torch.cat([inputs, masks.to(inputs.dtype)], dim=-1)
input_embeddings = self.tokenizer(tokenizer_inputs)
if decode_caches is None:
decode_caches = [None] * self.x
output_embeddings = input_embeddings
new_decode_caches = []
for i, layer in enumerate(self.stacked_xf):
output_embeddings, new_cache = layer(
output_embeddings, masks[..., -1], decode_caches[i]
)
new_decode_caches.append(new_cache)
output_ts = self.output_projection_point(output_embeddings)
output_quantile_spread = self.output_projection_quantiles(output_embeddings)
return (
input_embeddings,
output_embeddings,
output_ts,
output_quantile_spread,
), new_decode_caches
def decode(self, horizon: int, inputs, masks):
"""Decodes the time series."""
with torch.no_grad():
batch_size, context = inputs.shape[0], inputs.shape[1]
num_decode_steps = (horizon - 1) // self.o
num_input_patches = context // self.p
decode_cache_size = num_input_patches + num_decode_steps * self.m
# Prefill
patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
patched_masks = torch.reshape(masks, (batch_size, -1, self.p))
# running stats
n = torch.zeros(batch_size, device=inputs.device)
mu = torch.zeros(batch_size, device=inputs.device)
sigma = torch.zeros(batch_size, device=inputs.device)
patch_mu = []
patch_sigma = []
for i in range(num_input_patches):
(n, mu, sigma), _ = util.update_running_stats(
n, mu, sigma, patched_inputs[:, i], patched_masks[:, i]
)
patch_mu.append(mu)
patch_sigma.append(sigma)
last_n, last_mu, last_sigma = n, mu, sigma
context_mu = torch.stack(patch_mu, dim=1)
context_sigma = torch.stack(patch_sigma, dim=1)
decode_caches = [
util.DecodeCache(
next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
key=torch.zeros(
batch_size,
decode_cache_size,
self.h,
self.hd,
device=inputs.device,
),
value=torch.zeros(
batch_size,
decode_cache_size,
self.h,
self.hd,
device=inputs.device,
),
)
for _ in range(self.x)
]
normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
normed_inputs = torch.where(patched_masks, 0.0, normed_inputs)
(_, _, normed_outputs, normed_quantile_spread), decode_caches = self(
normed_inputs, patched_masks, decode_caches
)
renormed_outputs = torch.reshape(
revin(normed_outputs, context_mu, context_sigma, reverse=True),
(batch_size, -1, self.o, self.q),
)
renormed_quantile_spread = torch.reshape(
revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
(batch_size, -1, self.os, self.q),
)[:, -1, ...]
# Autogressive decode
ar_outputs = []
last_renormed_output = renormed_outputs[:, -1, :, self.aridx]
for _ in range(num_decode_steps):
new_patched_input = torch.reshape(
last_renormed_output, (batch_size, self.m, self.p)
)
new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)
n, mu, sigma = last_n, last_mu, last_sigma
new_mus, new_sigmas = [], []
for i in range(self.m):
(n, mu, sigma), _ = util.update_running_stats(
n, mu, sigma, new_patched_input[:, i], new_mask[:, i]
)
new_mus.append(mu)
new_sigmas.append(sigma)
last_n, last_mu, last_sigma = n, mu, sigma
new_mu = torch.stack(new_mus, dim=1)
new_sigma = torch.stack(new_sigmas, dim=1)
new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
(_, _, new_normed_output, _), decode_caches = self(
new_normed_input, new_mask, decode_caches
)
new_renormed_output = torch.reshape(
revin(new_normed_output, new_mu, new_sigma, reverse=True),
(batch_size, self.m, self.o, self.q),
)
ar_outputs.append(new_renormed_output[:, -1, ...])
last_renormed_output = new_renormed_output[:, -1, :, self.aridx]
if num_decode_steps > 0:
ar_renormed_outputs = torch.stack(ar_outputs, dim=1)
else:
ar_renormed_outputs = None
return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
def forecast_naive(
self, horizon: int, inputs: Sequence[np.ndarray]
) -> list[np.ndarray]:
"""Forecasts the time series.
This is a naive implementation for debugging purposes. No forecasting
flags are used here. Forecasting quality can be subpar.
Args:
horizon: The number of time points to forecast.
inputs: A sequence of numpy arrays, each representing a time series to
query forecast for.
Returns:
A list of numpy arrays of forecasts.
"""
outputs = []
for each_input in inputs:
input_t = torch.tensor(each_input, dtype=torch.float32)
mask = torch.zeros_like(input_t, dtype=torch.bool)
len_front_mask = self.p - (len(each_input) % self.p)
if len_front_mask < self.p:
input_t = torch.cat(
[torch.zeros(len_front_mask, dtype=torch.float32), input_t], dim=0
)
mask = torch.cat([torch.ones(len_front_mask, dtype=torch.bool), mask], dim=0)
input_t = input_t[None, ...]
mask = mask[None, ...]
t_pf, _, t_ar = self.decode(horizon, input_t, mask)
to_concat = [t_pf[:, -1, ...]]
if t_ar is not None:
to_concat.append(t_ar.reshape(1, -1, self.q))
torch_forecast = torch.cat(to_concat, dim=1)[..., :horizon]
torch_forecast = torch_forecast.squeeze(0)
outputs.append(torch_forecast.detach().cpu().numpy())
return outputs
class TimesFM_2p5_200M_torch(
timesfm_2p5_base.TimesFM_2p5,
PyTorchModelHubMixin,
library_name="timesfm",
repo_url="https://github.com/google-research/timesfm",
paper_url="https://arxiv.org/abs/2310.10688",
docs_url="https://github.com/google-research/timesfm",
license="apache-2.0",
pipeline_tag="time-series-forecasting",
tags=["pytorch", "timeseries", "forecasting", "timesfm-2.5"],
):
"""PyTorch implementation of TimesFM 2.5 with 200M parameters."""
DEFAULT_REPO_ID = "google/timesfm-2.5-200m-pytorch"
WEIGHTS_FILENAME = "model.safetensors"
def __init__(
self,
torch_compile: bool = True,
config: Optional[dict] = None,
):
self.model = TimesFM_2p5_200M_torch_module()
self.torch_compile = torch_compile
if config is not None:
self._hub_mixin_config = config
@classmethod
def _from_pretrained(
cls,
*,
model_id: str = DEFAULT_REPO_ID,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool = False,
local_files_only: bool,
token: Optional[Union[str, bool]],
config: Optional[dict] = None,
**model_kwargs,
):
"""
Loads a PyTorch safetensors TimesFM model from a local path or the Hugging
Face Hub. This method is the backend for the `from_pretrained` class
method provided by `PyTorchModelHubMixin`.
"""
# Determine the path to the model weights.
model_file_path = ""
if os.path.isdir(model_id):
logging.info("Loading checkpoint from local directory: %s", model_id)
model_file_path = os.path.join(model_id, cls.WEIGHTS_FILENAME)
if not os.path.exists(model_file_path):
raise FileNotFoundError(
f"{cls.WEIGHTS_FILENAME} not found in directory {model_id}"
)
else:
logging.info("Downloading checkpoint from Hugging Face repo %s", model_id)
model_file_path = hf_hub_download(
repo_id=model_id,
filename=cls.WEIGHTS_FILENAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
token=token,
local_files_only=local_files_only,
)
# Create an instance of the model wrapper class.
instance = cls(config=config, **model_kwargs)
logging.info("Loading checkpoint from: %s", model_file_path)
# Load the weights into the model.
instance.model.load_checkpoint(
model_file_path, torch_compile=instance.torch_compile
)
return instance
def _save_pretrained(self, save_directory: Union[str, Path]):
"""
Saves the model's state dictionary to a safetensors file. This method
is called by the `save_pretrained` method from `PyTorchModelHubMixin`.
"""
if not os.path.exists(save_directory):
os.makedirs(save_directory)
weights_path = os.path.join(save_directory, self.WEIGHTS_FILENAME)
save_file(self.model.state_dict(), weights_path)
def compile(self, forecast_config: configs.ForecastConfig, **kwargs) -> None:
"""Attempts to compile the model for fast decoding.
See configs.ForecastConfig for more details on the supported flags.
Args:
forecast_config: Configuration for forecasting flags.
**kwargs: Additional keyword arguments to pass to model.compile().
"""
self.global_batch_size = (
forecast_config.per_core_batch_size * self.model.device_count
)
# Shortcut.
fc = forecast_config
if fc.max_context % self.model.p != 0:
logging.info(
"When compiling, max context needs to be multiple of the patch size"
" %d. Using max context = %d instead.",
self.model.p,
new_context := math.ceil(fc.max_context / self.model.p) * self.model.p,
)
fc = dataclasses.replace(fc, max_context=new_context)
if fc.max_horizon % self.model.o != 0:
logging.info(
"When compiling, max horizon needs to be multiple of the output patch"
" size %d. Using max horizon = %d instead.",
self.model.o,
new_horizon := math.ceil(fc.max_horizon / self.model.o) * self.model.o,
)
fc = dataclasses.replace(fc, max_horizon=new_horizon)
if fc.max_context + fc.max_horizon > self.model.config.context_limit:
raise ValueError(
"Context + horizon must be less than the context limit."
f" {fc.max_context} + {fc.max_horizon} >"
f" {self.model.config.context_limit}."
)
if fc.use_continuous_quantile_head and (fc.max_horizon > self.model.os):
raise ValueError(
f"Continuous quantile head is not supported for horizons > {self.model.os}."
)
self.forecast_config = fc
def _compiled_decode(horizon, inputs, masks):
if horizon > fc.max_horizon:
raise ValueError(
f"Horizon must be less than the max horizon. {horizon} > {fc.max_horizon}."
)
inputs = (
torch.from_numpy(np.array(inputs)).to(self.model.device).to(torch.float32)
)
masks = torch.from_numpy(np.array(masks)).to(self.model.device).to(torch.bool)
batch_size = inputs.shape[0]
if fc.infer_is_positive:
is_positive = torch.all(inputs >= 0, dim=-1, keepdim=True)
else:
is_positive = None
if fc.normalize_inputs:
mu = torch.mean(inputs, dim=-1, keepdim=True)
sigma = torch.std(inputs, dim=-1, keepdim=True)
inputs = revin(inputs, mu, sigma, reverse=False)
else:
mu, sigma = None, None
pf_outputs, quantile_spreads, ar_outputs = self.model.decode(
forecast_config.max_horizon, inputs, masks
)
to_cat = [pf_outputs[:, -1, ...]]
if ar_outputs is not None:
to_cat.append(ar_outputs.reshape(batch_size, -1, self.model.q))
full_forecast = torch.cat(to_cat, dim=1)
def flip_quantile_fn(x):
return torch.cat([x[..., :1], torch.flip(x[..., 1:], dims=(-1,))], dim=-1)
if fc.force_flip_invariance:
flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
self.model.decode(forecast_config.max_horizon, -inputs, masks)
)
flipped_quantile_spreads = flip_quantile_fn(flipped_quantile_spreads)
flipped_pf_outputs = flip_quantile_fn(flipped_pf_outputs)
to_cat = [flipped_pf_outputs[:, -1, ...]]
if flipped_ar_outputs is not None:
to_cat.append(flipped_ar_outputs.reshape(batch_size, -1, self.model.q))
flipped_full_forecast = torch.cat(to_cat, dim=1)
quantile_spreads = (quantile_spreads - flipped_quantile_spreads) / 2
pf_outputs = (pf_outputs - flipped_pf_outputs) / 2
full_forecast = (full_forecast - flipped_full_forecast) / 2
if fc.use_continuous_quantile_head:
for quantile_index in [1, 2, 3, 4, 6, 7, 8, 9]:
full_forecast[:, :, quantile_index] = (
quantile_spreads[:, : fc.max_horizon, quantile_index]
- quantile_spreads[:, : fc.max_horizon, 5]
+ full_forecast[:, : fc.max_horizon, 5]
)
full_forecast = full_forecast[:, :horizon, :]
if fc.return_backcast:
full_backcast = pf_outputs[:, :-1, : self.model.p, :].reshape(
batch_size, -1, self.model.q
)
full_forecast = torch.cat([full_backcast, full_forecast], dim=1)
if fc.fix_quantile_crossing:
for i in [4, 3, 2, 1]:
full_forecast[:, :, i] = torch.where(
full_forecast[:, :, i] < full_forecast[:, :, i + 1],
full_forecast[:, :, i],
full_forecast[:, :, i + 1],
)
for i in [6, 7, 8, 9]:
full_forecast[:, :, i] = torch.where(
full_forecast[:, :, i] > full_forecast[:, :, i - 1],
full_forecast[:, :, i],
full_forecast[:, :, i - 1],
)
if fc.normalize_inputs:
full_forecast = revin(full_forecast, mu, sigma, reverse=True)
if is_positive is not None:
full_forecast = torch.where(
is_positive[..., None],
torch.maximum(full_forecast, torch.zeros_like(full_forecast)),
full_forecast,
)
full_forecast = full_forecast.detach().cpu().numpy()
return full_forecast[..., 5], full_forecast
self.compiled_decode = _compiled_decode
================================================
FILE: src/timesfm/torch/__init__.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: src/timesfm/torch/dense.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dense layers for TimesFM."""
import torch
from torch import nn
from .. import configs
class ResidualBlock(nn.Module):
"""Residual block with two linear layers and a linear residual connection."""
def __init__(self, config: configs.ResidualBlockConfig):
super().__init__()
self.config = config
self.hidden_layer = nn.Linear(
in_features=config.input_dims,
out_features=config.hidden_dims,
bias=config.use_bias,
)
self.output_layer = nn.Linear(
in_features=config.hidden_dims,
out_features=config.output_dims,
bias=config.use_bias,
)
self.residual_layer = nn.Linear(
in_features=config.input_dims,
out_features=config.output_dims,
bias=config.use_bias,
)
if config.activation == "relu":
self.activation = nn.ReLU()
elif config.activation == "swish":
self.activation = nn.SiLU()
elif config.activation == "none":
self.activation = nn.Identity()
else:
raise ValueError(f"Activation: {config.activation} not supported.")
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.output_layer(
self.activation(self.hidden_layer(x))
) + self.residual_layer(x)
class RandomFourierFeatures(nn.Module):
"""Random Fourier features layer."""
def __init__(self, config: configs.RandomFourierFeaturesConfig):
super().__init__()
self.config = config
if config.output_dims % 4 != 0:
raise ValueError(
f"Output dims must be a multiple of 4: {config.output_dims} % 4 != 0."
)
num_projected_features = config.output_dims // 4
self.phase_shifts = nn.Parameter(torch.zeros(2, num_projected_features))
self.projection_layer = nn.Linear(
in_features=config.input_dims,
out_features=num_projected_features,
bias=config.use_bias,
)
self.residual_layer = nn.Linear(
in_features=config.input_dims,
out_features=config.output_dims,
bias=config.use_bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
projected = self.projection_layer(x)
cos_features = torch.cos(projected)
sin_features = torch.sin(projected)
sq_wave_1 = torch.sign(torch.sin(projected + self.phase_shifts[0, :]))
sq_wave_2 = torch.sign(torch.sin(projected + self.phase_shifts[1, :]))
fourier_features = torch.cat(
[cos_features, sin_features, sq_wave_1, sq_wave_2], dim=-1
)
residual = self.residual_layer(x)
return fourier_features + residual
================================================
FILE: src/timesfm/torch/normalization.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers for TimesFM."""
import torch
from torch import nn
class RMSNorm(nn.Module):
"""RMS normalization."""
def __init__(
self,
num_features: int,
*,
epsilon: float = 1e-6,
):
super().__init__()
self.scale = nn.Parameter(torch.zeros(num_features))
self.num_features = num_features
self.epsilon = epsilon
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
var = torch.mean(torch.square(inputs), dim=-1, keepdim=True)
normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
normed_inputs = normed_inputs * self.scale
return normed_inputs
================================================
FILE: src/timesfm/torch/transformer.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer layers for TimesFM."""
import math
from typing import Callable
import torch
import torch.nn.functional as F
from torch import nn
from .. import configs
from . import normalization, util
LayerNorm = nn.LayerNorm
RMSNorm = normalization.RMSNorm
DecodeCache = util.DecodeCache
def make_attn_mask(
query_length: int,
num_all_masked_kv: torch.Tensor,
query_index_offset: torch.Tensor | None = None,
kv_length: int = 0,
) -> torch.Tensor:
"""Makes attention mask."""
if kv_length == 0:
kv_length = query_length
q_index = torch.arange(query_length, device=num_all_masked_kv.device)[
None, None, :, None
]
if query_index_offset is not None:
q_index = q_index + query_index_offset[:, None, None, None]
kv_index = torch.arange(kv_length, device=num_all_masked_kv.device)[
None, None, None, :
]
return torch.logical_and(
q_index >= kv_index,
kv_index >= num_all_masked_kv[:, None, None, None],
)
class RotaryPositionalEmbedding(nn.Module):
"""Rotary positional embedding."""
def __init__(
self,
embedding_dims: int,
min_timescale: float = 1.0,
max_timescale: float = 10000.0,
):
super().__init__()
self.embedding_dims = embedding_dims
self.min_timescale = min_timescale
self.max_timescale = max_timescale
def forward(
self,
inputs: torch.Tensor,
position: torch.Tensor | None = None,
):
"""Generates a JTensor of sinusoids with different frequencies."""
if self.embedding_dims != inputs.shape[-1]:
raise ValueError(
"The embedding dims of the rotary position embedding"
"must match the hidden dimension of the inputs."
)
half_embedding_dim = self.embedding_dims // 2
fraction = (
2
* torch.arange(0, half_embedding_dim, device=inputs.device)
/ self.embedding_dims
)
timescale = (
self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
).to(inputs.device)
if position is None:
seq_length = inputs.shape[1]
position = torch.arange(seq_length, dtype=torch.float32, device=inputs.device)[
None, :
]
if len(inputs.shape) == 4:
position = position[..., None, None]
timescale = timescale[None, None, None, :]
elif len(inputs.shape) == 3:
position = position[..., None]
timescale = timescale[None, None, :]
else:
raise ValueError("Inputs must be of rank 3 or 4.")
sinusoid_inp = position / timescale
sin = torch.sin(sinusoid_inp)
cos = torch.cos(sinusoid_inp)
first_half, second_half = torch.chunk(inputs, 2, dim=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
return torch.cat([first_part, second_part], dim=-1)
def _dot_product_attention(
query,
key,
value,
mask=None,
):
"""Computes dot-product attention given query, key, and value."""
attn_weights = torch.einsum("...qhd,...khd->...hqk", query, key)
if mask is not None:
attn_weights = torch.where(
mask, attn_weights, -torch.finfo(attn_weights.dtype).max / 2
)
attn_weights = F.softmax(attn_weights, dim=-1)
return torch.einsum("...hqk,...khd->...qhd", attn_weights, value)
def _torch_dot_product_attention(query, key, value, mask=None):
"""
Performs the exact same (unscaled) attention as the above function,
but using the fast and fused F.scaled_dot_product_attention kernel.
"""
# 1. Permute inputs from (B, L, H, D) to the expected (B, H, L, D)
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
# 2. Call the fused attention kernel
# - Pass the mask to `attn_mask`.
# - Set `scale=1.0` to disable the default 1/sqrt(d_k) scaling.
output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, scale=1.0)
# 3. Permute the output back to the original (B, L, H, D) layout
output = output.permute(0, 2, 1, 3)
return output
class PerDimScale(nn.Module):
"""Per-dimension scaling."""
def __init__(self, num_dims: int):
super().__init__()
self.num_dims = num_dims
self.per_dim_scale = nn.Parameter(torch.zeros(num_dims))
def forward(self, x: torch.Tensor) -> torch.Tensor:
scale_factor = (
1.442695041 / math.sqrt(self.num_dims) * F.softplus(self.per_dim_scale)
)
return x * scale_factor
class MultiHeadAttention(nn.Module):
"""Multi-head attention."""
def __init__(
self,
num_heads: int,
in_features: int,
*,
use_per_dim_scale: bool = True,
use_rotary_position_embeddings: bool = True,
use_bias: bool = False,
attention_fn: Callable[..., torch.Tensor] = _torch_dot_product_attention,
qk_norm: str = "rms",
fuse_qkv: bool = False,
):
super().__init__()
self.num_heads = num_heads
self.in_features = in_features
self.head_dim = in_features // num_heads
self.use_bias = use_bias
self.attention_fn = attention_fn
self.qk_norm = qk_norm
self.fuse_qkv = fuse_qkv
if self.in_features % self.num_heads != 0:
raise ValueError(
f"Memory dimension ({self.in_features}) must be divisible by "
f"'num_heads' heads ({self.num_heads})."
)
if self.fuse_qkv:
self.qkv_proj = nn.Linear(self.in_features, 3 * self.in_features, bias=use_bias)
else:
self.query = nn.Linear(self.in_features, self.in_features, bias=use_bias)
self.key = nn.Linear(self.in_features, self.in_features, bias=use_bias)
self.value = nn.Linear(self.in_features, self.in_features, bias=use_bias)
self.out = nn.Linear(self.in_features, self.in_features, bias=use_bias)
if self.qk_norm == "rms":
self.query_ln = RMSNorm(self.head_dim)
self.key_ln = RMSNorm(self.head_dim)
else:
self.query_ln = nn.Identity()
self.key_ln = nn.Identity()
self.use_rotary_position_embeddings = use_rotary_position_embeddings
if self.use_rotary_position_embeddings:
self.rotary_position_embedding = RotaryPositionalEmbedding(
embedding_dims=self.head_dim,
)
self.use_per_dim_scale = use_per_dim_scale
if use_per_dim_scale:
self.per_dim_scale = PerDimScale(num_dims=self.head_dim)
def forward(
self,
inputs_q: torch.Tensor,
*,
decode_cache: DecodeCache | None = None,
patch_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, DecodeCache | None]:
b, n_patches, _ = inputs_q.shape
if patch_mask is None:
patch_mask = torch.zeros(b, n_patches, dtype=torch.bool, device=inputs_q.device)
if self.fuse_qkv:
qkv = self.qkv_proj(inputs_q)
query, key, value = torch.chunk(qkv, 3, dim=-1)
query = query.view(b, n_patches, self.num_heads, self.head_dim)
key = key.view(b, n_patches, self.num_heads, self.head_dim)
value = value.view(b, n_patches, self.num_heads, self.head_dim)
else:
query = self.query(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
key = self.key(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
value = self.value(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
if decode_cache is None:
num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1)
next_index = torch.zeros_like(num_masked, dtype=torch.int32)
else:
num_masked = (
torch.sum(patch_mask.to(torch.int32), dim=-1) + decode_cache.num_masked
)
next_index = decode_cache.next_index.clone()
if self.use_rotary_position_embeddings:
position = (
torch.arange(n_patches, device=inputs_q.device)[None, :]
+ next_index[:, None]
- num_masked[:, None]
)
query = self.rotary_position_embedding(query, position)
key = self.rotary_position_embedding(key, position)
query = self.query_ln(query)
key = self.key_ln(key)
if self.use_per_dim_scale:
query = self.per_dim_scale(query)
if decode_cache is not None:
_, decode_cache_size, _, _ = decode_cache.value.shape
start = decode_cache.next_index[0]
end = start + n_patches
# Perform a single, vectorized slice assignment for the entire batch.
# This is vastly more efficient than a Python for-loop.
decode_cache.key[:, start:end] = key
decode_cache.value[:, start:end] = value
key = decode_cache.key
value = decode_cache.value
decode_cache.next_index += n_patches
decode_cache.num_masked = num_masked
attn_mask = make_attn_mask(
query_length=n_patches,
num_all_masked_kv=num_masked,
query_index_offset=next_index,
kv_length=decode_cache_size,
)
else:
attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked)
x = self.attention_fn(
query,
key,
value,
mask=attn_mask,
)
x = x.reshape(b, n_patches, self.in_features)
out = self.out(x)
return out, decode_cache
class Transformer(nn.Module):
"""Classic Transformer used in TimesFM."""
def __init__(self, config: configs.TransformerConfig):
super().__init__()
self.config = config
if config.attention_norm == "rms":
self.pre_attn_ln = RMSNorm(num_features=config.model_dims)
self.post_attn_ln = RMSNorm(num_features=config.model_dims)
else:
raise ValueError(f"Layer norm: {config.attention_norm} not supported.")
self.attn = MultiHeadAttention(
num_heads=config.num_heads,
in_features=config.model_dims,
use_per_dim_scale=True,
use_rotary_position_embeddings=config.use_rotary_position_embeddings,
qk_norm=config.qk_norm,
fuse_qkv=config.fuse_qkv,
)
if config.feedforward_norm == "rms":
self.pre_ff_ln = RMSNorm(num_features=config.model_dims)
self.post_ff_ln = RMSNorm(num_features=config.model_dims)
else:
raise ValueError(f"Layer norm: {config.feedforward_norm} not supported.")
self.ff0 = nn.Linear(
in_features=config.model_dims,
out_features=config.hidden_dims,
bias=config.use_bias,
)
self.ff1 = nn.Linear(
in_features=config.hidden_dims,
out_features=config.model_dims,
bias=config.use_bias,
)
if config.ff_activation == "relu":
self.activation = nn.ReLU()
elif config.ff_activation == "swish":
self.activation = nn.SiLU()
elif config.ff_activation == "none":
self.activation = nn.Identity()
else:
raise ValueError(f"Activation: {config.ff_activation} not supported.")
def forward(
self,
input_embeddings: torch.Tensor,
patch_mask: torch.Tensor,
decode_cache: DecodeCache | None = None,
) -> tuple[torch.Tensor, DecodeCache | None]:
attn_output, decode_cache = self.attn(
inputs_q=self.pre_attn_ln(input_embeddings),
decode_cache=decode_cache,
patch_mask=patch_mask,
)
attn_output = self.post_attn_ln(attn_output) + input_embeddings
output_embeddings = (
self.post_ff_ln(self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))))
+ attn_output
)
return output_embeddings, decode_cache
================================================
FILE: src/timesfm/torch/util.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch utility functions for TimesFM layers."""
import dataclasses
import torch
_TOLERANCE = 1e-6
@dataclasses.dataclass(frozen=False)
class DecodeCache:
"""Cache for decoding."""
next_index: torch.Tensor
num_masked: torch.Tensor
key: torch.Tensor
value: torch.Tensor
def update_running_stats(
n: torch.Tensor,
mu: torch.Tensor,
sigma: torch.Tensor,
x: torch.Tensor,
mask: torch.Tensor,
) -> tuple[
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""Updates the running stats."""
is_legit = torch.logical_not(mask)
inc_n = torch.sum(is_legit.to(x.dtype), dim=-1)
inc_mu_numerator = torch.sum(x * is_legit, dim=-1)
inc_n_safe = torch.where(inc_n == 0, 1.0, inc_n)
inc_mu = inc_mu_numerator / inc_n_safe
inc_mu = torch.where(inc_n == 0, 0.0, inc_mu)
inc_var_numerator = torch.sum(
((x - inc_mu.unsqueeze(-1)) ** 2) * is_legit, dim=-1
)
inc_var = inc_var_numerator / inc_n_safe
inc_var = torch.where(inc_n == 0, 0.0, inc_var)
inc_sigma = torch.sqrt(inc_var)
new_n = n + inc_n
new_n_safe = torch.where(new_n == 0, 1.0, new_n)
new_mu = (n * mu + inc_mu * inc_n) / new_n_safe
new_mu = torch.where(new_n == 0, 0.0, new_mu)
term1 = n * sigma.pow(2)
term2 = inc_n * inc_sigma.pow(2)
term3 = n * (mu - new_mu).pow(2)
term4 = inc_n * (inc_mu - new_mu).pow(2)
new_var = (term1 + term2 + term3 + term4) / new_n_safe
new_var = torch.where(new_n == 0, 0.0, new_var)
new_sigma = torch.sqrt(torch.clamp(new_var, min=0.0))
return (w := (new_n, new_mu, new_sigma), w)
def revin(
x: torch.Tensor,
mu: torch.Tensor,
sigma: torch.Tensor,
reverse: bool = False,
):
"""Reversible instance normalization."""
if len(mu.shape) == len(x.shape) - 1:
mu = mu[..., None]
sigma = sigma[..., None]
elif len(mu.shape) == len(x.shape) - 2:
mu = mu[..., None, None]
sigma = sigma[..., None, None]
if reverse:
return x * sigma + mu
else:
return (x - mu) / torch.where(sigma < _TOLERANCE, 1.0, sigma)
================================================
FILE: src/timesfm/utils/xreg_lib.py
================================================
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for in-context covariates and regression."""
import itertools
import math
from typing import Any, Iterable, Literal, Mapping, Sequence
try:
import jax
import jax.numpy as jnp
import numpy as np
from sklearn import preprocessing
except ImportError:
raise ImportError(
"Failed to load the XReg module. Did you forget to install `timesfm[xreg]`?"
)
Category = int | str
_TOL = 1e-6
XRegMode = Literal["timesfm + xreg", "xreg + timesfm"]
def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:
return np.array(list(itertools.chain.from_iterable(nested)))
def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:
return np.array(
list(itertools.chain.from_iterable(map(itertools.repeat, elements, counts)))
)
def _to_padded_jax_array(x: np.ndarray) -> jax.Array:
if x.ndim == 1:
(i,) = x.shape
di = 2 ** math.ceil(math.log2(i)) - i
return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0)
elif x.ndim == 2:
i, j = x.shape
di = 2 ** math.ceil(math.log2(i)) - i
dj = 2 ** math.ceil(math.log2(j)) - j
return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0)
else:
raise ValueError(f"Unsupported array shape: {x.shape}")
# Per time series normalization: forward.
def normalize(batch):
stats = [(np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch]
new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)]
return new_batch, stats
# Per time series normalization: inverse.
def renormalize(batch, stats):
return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)]
class BatchedInContextXRegBase:
"""Helper class for in-context regression covariate formatting.
Attributes:
targets: List of targets (responses) of the in-context regression.
train_lens: List of lengths of each target vector from the context.
test_lens: List of lengths of each forecast horizon.
train_dynamic_numerical_covariates: Dict of covariate names mapping to the
dynamic numerical covariates of each forecast task on the context. Their
lengths should match the corresponding lengths in `train_lens`.
train_dynamic_categorical_covariates: Dict of covariate names mapping to the
dynamic categorical covariates of each forecast task on the context. Their
lengths should match the corresponding lengths in `train_lens`.
test_dynamic_numerical_covariates: Dict of covariate names mapping to the
dynamic numerical covariates of each forecast task on the horizon. Their
lengths should match the corresponding lengths in `test_lens`.
test_dynamic_categorical_covariates: Dict of covariate names mapping to the
dynamic categorical covariates of each forecast task on the horizon. Their
lengths should match the corresponding lengths in `test_lens`.
static_numerical_covariates: Dict of covariate names mapping to the static
numerical covariates of each forecast task.
static_categorical_covariates: Dict of covariate names mapping to the static
categorical covariates of each forecast task.
"""
def __init__(
self,
targets: Sequence[Sequence[float]],
train_lens: Sequence[int],
test_lens: Sequence[int],
train_dynamic_numerical_covariates: (
Mapping[str, Sequence[Sequence[float]]] | None
) = None,
train_dynamic_categorical_covariates: (
Mapping[str, Sequence[Sequence[Category]]] | None
) = None,
test_dynamic_numerical_covariates: (
Mapping[str, Sequence[Sequence[float]]] | None
) = None,
test_dynamic_categorical_covariates: (
Mapping[str, Sequence[Sequence[Category]]] | None
) = None,
static_numerical_covariates: Mapping[str, Sequence[float]] | None = None,
static_categorical_covariates: (Mapping[str, Sequence[Category]] | None) = None,
) -> None:
"""Initializes with the exogenous covariate inputs.
Here we use model fitting language to refer to the context as 'train' and
the horizon as 'test'. We assume batched inputs. To properly format the
request:
- `train_lens` represents the contexts in the batch. Targets and all train
dynamic covariates should have the same lengths as the corresponding
elements
in `train_lens`. Notice each `train_len` can be different from the exact
length of the corresponding context depending on how much of the context is
used for fitting the in-context model.
- `test_lens` represents the horizon lengths in the batch. All tesdt
dynamic
covariates should have the same lengths as the corresponding elements in
`test_lens`.
- Static covariates should be one for each input.
- For train and test dynamic covariates, they should have the same
covariate
names.
Pass an empty dict {} for a covariate type if it is not present.
Example:
Here is a set of valid inputs whose schema can be used for reference.
```
targets = [
[0.0, 0.1, 0.2],
[0.0, 0.1, 0.2, 0.3],
] # Two inputs in this batch.
train_lens = [3, 4]
test_lens = [2, 5] # Forecast horizons 2 and 5 respectively.
train_dynamic_numerical_covariates = {
"cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]],
"cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]],
} # Each train dynamic covariate has 3 and 4 elements respectively.
test_dynamic_numerical_covariates = {
"cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]],
"cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]],
} # Each test dynamic covariate has 2 and 5 elements respectively.
train_dynamic_categorical_covariates = {
"cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]],
"cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad",
"bad"]],
}
test_dynamic_categorical_covariates = {
"cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]],
"cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]],
}
static_numerical_covariates = {
"cov_1_sn": [0.0, 3.0],
"cov_2_sn": [2.0, 1.0],
"cov_3_sn": [1.0, 2.0],
} # Each static covariate has 1 element for each input.
static_categorical_covariates = {
"cov_1_sc": ["apple", "orange"],
"cov_2_sc": [2, 3],
}
```
Args:
targets: List of targets (responses) of the in-context regression.
train_lens: List of lengths of each target vector from the context.
test_lens: List of lengths of each forecast horizon.
train_dynamic_numerical_covariates: Dict of covariate names mapping to the
dynamic numerical covariates of each forecast task on the context. Their
lengths should match the corresponding lengths in `train_lens`.
train_dynamic_categorical_covariates: Dict of covariate names mapping to
the dynamic categorical covariates of each forecast task on the context.
Their lengths should match the corresponding lengths in `train_lens`.
test_dynamic_numerical_covariates: Dict of covariate names mapping to the
dynamic numerical covariates of each forecast task on the horizon. Their
lengths should match the corresponding lengths in `test_lens`.
test_dynamic_categorical_covariates: Dict of covariate names mapping to
the dynamic categorical covariates of each forecast task on the horizon.
Their lengths should match the corresponding lengths in `test_lens`.
static_numerical_covariates: Dict of covariate names mapping to the static
numerical covariates of each forecast task.
static_categorical_covariates: Dict of covariate names mapping to the
static categorical covariates of each forecast task.
"""
self.targets = targets
self.train_lens = train_lens
self.test_lens = test_lens
self.train_dynamic_numerical_covariates = train_dynamic_numerical_covariates or {}
self.train_dynamic_categorical_covariates = (
train_dynamic_categorical_covariates or {}
)
self.test_dynamic_numerical_covariates = test_dynamic_numerical_covariates or {}
self.test_dynamic_categorical_covariates = test_dynamic_categorical_covariates or {}
self.static_numerical_covariates = static_numerical_covariates or {}
self.static_categorical_covariates = static_categorical_covariates or {}
def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None:
"""Verifies the validity of the covariate inputs."""
# Check presence.
if (
self.train_dynamic_numerical_covariates
and not self.test_dynamic_numerical_covariates
) or (
not self.train_dynamic_numerical_covariates
and self.test_dynamic_numerical_covariates
):
raise ValueError(
"train_dynamic_numerical_covariates and"
" test_dynamic_numerical_covariates must be both present or both"
" absent."
)
if (
self.train_dynamic_categorical_covariates
and not self.test_dynamic_categorical_covariates
) or (
not self.train_dynamic_categorical_covariates
and self.test_dynamic_categorical_covariates
):
raise ValueError(
"train_dynamic_categorical_covariates and"
" test_dynamic_categorical_covariates must be both present or both"
" absent."
)
# Check keys.
for dict_a, dict_b, dict_a_name, dict_b_name in (
(
self.train_dynamic_numerical_covariates,
self.test_dynamic_numerical_covariates,
"train_dynamic_numerical_covariates",
"test_dynamic_numerical_covariates",
),
(
self.train_dynamic_categorical_covariates,
self.test_dynamic_categorical_covariates,
"train_dynamic_categorical_covariates",
"test_dynamic_categorical_covariates",
),
):
if w := set(dict_a.keys()) - set(dict_b.keys()):
raise ValueError(f"{dict_a_name} has keys not present in {dict_b_name}: {w}")
if w := set(dict_b.keys()) - set(dict_a.keys()):
raise ValueError(f"{dict_b_name} has keys not present in {dict_a_name}: {w}")
# Check shapes.
if assert_covariate_shapes:
if len(self.targets) != len(self.train_lens):
raise ValueError(
"targets and train_lens must have the same number of elements."
)
if len(self.train_lens) != len(self.test_lens):
raise ValueError(
"train_lens and test_lens must have the same number of elements."
)
for i, (target, train_len) in enumerate(zip(self.targets, self.train_lens)):
if len(target) != train_len:
raise ValueError(
f"targets[{i}] has length {len(target)} != expected {train_len}."
)
for key, values in self.static_numerical_covariates.items():
if len(values) != len(self.train_lens):
raise ValueError(
f"static_numerical_covariates has key {key} with number of"
f" examples {len(values)} != expected {len(self.train_lens)}."
)
for key, values in self.static_categorical_covariates.items():
if len(values) != len(self.train_lens):
raise ValueError(
f"static_categorical_covariates has key {key} with number of"
f" examples {len(values)} != expected {len(self.train_lens)}."
)
for lens, dict_cov, dict_cov_name in (
(
self.train_lens,
self.train_dynamic_numerical_covariates,
"train_dynamic_numerical_covariates",
),
(
self.train_lens,
self.train_dynamic_categorical_covariates,
"train_dynamic_categorical_covariates",
),
(
self.test_lens,
self.test_dynamic_numerical_covariates,
"test_dynamic_numerical_covariates",
),
(
self.test_lens,
self.test_dynamic_categorical_covariates,
"test_dynamic_categorical_covariates",
),
):
for key, cov_values in dict_cov.items():
if len(cov_values) != len(lens):
raise ValueError(
f"{dict_cov_name} has key {key} with number of examples"
f" {len(cov_values)} != expected {len(lens)}."
)
for i, cov_value in enumerate(cov_values):
if len(cov_value) != lens[i]:
raise ValueError(
f"{dict_cov_name} has key {key} with its {i}-th example"
f" length {len(cov_value)} != expected {lens[i]}."
)
def create_covariate_matrix(
self,
one_hot_encoder_drop: str | None = "first",
use_intercept: bool = True,
assert_covariates: bool = False,
assert_covariate_shapes: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Creates target vector and covariate matrices for in context regression.
Here we use model fitting language to refer to the context as 'train' and
the horizon as 'test'.
Args:
one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.
use_intercept: Whether to prepare an intercept (all 1) column in the
matrices.
assert_covariates: Whether to assert the validity of the covariate inputs.
assert_covariate_shapes: Whether to assert the shapes of the covariate
inputs when `assert_covariates` is True.
Returns:
A tuple of the target vector, the covariate matrix for the context, and
the covariate matrix for the horizon.
"""
if assert_covariates:
self._assert_covariates(assert_covariate_shapes)
x_train, x_test = [], []
# Numerical features.
for name in sorted(self.train_dynamic_numerical_covariates):
x_train.append(
_unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis]
)
x_test.append(
_unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis]
)
for covs in self.static_numerical_covariates.values():
x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis])
x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis])
if x_train:
x_train = np.concatenate(x_train, axis=1)
x_test = np.concatenate(x_test, axis=1)
# Normalize for robustness.
x_mean = np.mean(x_train, axis=0, keepdims=True)
x_std = np.where((w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w, 1.0)
x_train = [(x_train - x_mean) / x_std]
x_test = [(x_test - x_mean) / x_std]
# Categorical features. Encode one by one.
one_hot_encoder = preprocessing.OneHotEncoder(
drop=one_hot_encoder_drop,
sparse_output=False,
handle_unknown="ignore",
)
for name in sorted(self.train_dynamic_categorical_covariates.keys()):
ohe_train = _unnest(self.train_dynamic_categorical_covariates[name])[
:, np.newaxis
]
ohe_test = _unnest(self.test_dynamic_categorical_covariates[name])[:, np.newaxis]
x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train)))
x_test.append(np.array(one_hot_encoder.transform(ohe_test)))
for covs in self.static_categorical_covariates.values():
ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis])
x_train.append(_repeat(ohe, self.train_lens))
x_test.append(_repeat(ohe, self.test_lens))
x_train = np.concatenate(x_train, axis=1)
x_test = np.concatenate(x_test, axis=1)
if use_intercept:
x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0)
x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0)
return _unnest(self.targets), x_train, x_test
def fit(self) -> Any:
raise NotImplementedError("Fit is not implemented.")
class BatchedInContextXRegLinear(BatchedInContextXRegBase):
"""Linear in-context regression model."""
def fit(
self,
ridge: float = 0.0,
one_hot_encoder_drop: str | None = "first",
use_intercept: bool = True,
force_on_cpu: bool = False,
max_rows_per_col: int = 0,
max_rows_per_col_sample_seed: int = 42,
debug_info: bool = False,
assert_covariates: bool = False,
assert_covariate_shapes: bool = False,
) -> (
list[np.ndarray]
| tuple[list[np.ndarray], list[np.ndarray], jax.Array, jax.Array, jax.Array]
):
"""Fits a linear model for in-context regression.
Args:
ridge: A non-negative value for specifying the ridge regression penalty.
If 0 is provided, fallback to ordinary least squares. Note this penalty
is added to the normalized covariate matrix.
one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.
use_intercept: Whether to prepare an intercept (all 1) column in the
matrices.
force_on_cpu: Whether to force execution on cpu for accelerator machines.
max_rows_per_col: How many rows to subsample per column. 0 for no
subsampling. This is for speeding up model fitting.
max_rows_per_col_sample_seed: The seed for the subsampling if needed by
`max_rows_per_col`.
debug_info: Whether to return debug info.
assert_covariates: Whether to assert the validity of the covariate inputs.
assert_covariate_shapes: Whether to assert the shapes of the covariate
inputs when `assert_covariates` is True.
Returns:
If `debug_info` is False:
The linear fits on the horizon.
If `debug_info` is True:
A tuple of:
- the linear fits on the horizon,
- the linear fits on the context,
- the flattened target vector,
- the covariate matrix for the context, and
- the covariate matrix for the horizon.
"""
flat_targets, x_train_raw, x_test = self.create_covariate_matrix(
one_hot_encoder_drop=one_hot_encoder_drop,
use_intercept=use_intercept,
assert_covariates=assert_covariates,
assert_covariate_shapes=assert_covariate_shapes,
)
x_train = x_train_raw.copy()
if max_rows_per_col:
nrows, ncols = x_train.shape
if nrows > (w := ncols * max_rows_per_col):
subsample = jax.random.choice(
jax.random.PRNGKey(max_rows_per_col_sample_seed),
nrows,
(w,),
replace=False,
)
x_train = x_train[subsample]
flat_targets = flat_targets[subsample]
device = jax.devices("cpu")[0] if force_on_cpu else None
# Runs jitted version of the solvers which are quicker at the cost of
# running jitting during the first time calling. Re-jitting happens whenever
# new (padded) shapes are encountered.
# Ocassionally it helps with the speed and the accuracy if we force single
# thread execution on cpu for accelerator machines:
# 1. Avoid moving data to accelarator memory.
# 2. Avoid precision loss if any.
with jax.default_device(device):
x_train_raw = _to_padded_jax_array(x_train_raw)
x_train = _to_padded_jax_array(x_train)
flat_targets = _to_padded_jax_array(flat_targets)
x_test = _to_padded_jax_array(x_test)
beta_hat = (
jnp.linalg.pinv(
x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]),
hermitian=True,
)
@ x_train.T
@ flat_targets
)
y_hat = x_test @ beta_hat
y_hat_context = x_train_raw @ beta_hat if debug_info else None
outputs = []
outputs_context = []
# Reconstruct the ragged 2-dim batched forecasts from flattened linear fits.
train_index, test_index = 0, 0
for train_index_delta, test_index_delta in zip(self.train_lens, self.test_lens):
outputs.append(np.array(y_hat[test_index : (test_index + test_index_delta)]))
if debug_info:
outputs_context.append(
np.array(y_hat_context[train_index : (train_index + train_index_delta)])
)
train_index += train_index_delta
test_index += test_index_delta
if debug_info:
return outputs, outputs_context, flat_targets, x_train, x_test
else:
return outputs
================================================
FILE: timesfm-forecasting/SKILL.md
================================================
---
name: timesfm-forecasting
description: >
Zero-shot time series forecasting with Google's TimesFM foundation model. Use this
skill when forecasting ANY univariate time series — sales, sensor readings, stock prices,
energy demand, patient vitals, weather, or scientific measurements — without training a
custom model. Supports both basic forecasting and advanced covariate forecasting (XReg)
with dynamic and static exogenous variables. Automatically checks system RAM/GPU before
loading the model, validates dataset fit before processing, supports CSV/DataFrame/array
inputs, and returns point forecasts with calibrated prediction intervals. Includes a
preflight system checker script that MUST be run before first use to verify the machine
can load the model and handle your specific dataset.
license: Apache-2.0
metadata:
author: Clayton Young (@borealBytes)
version: "1.0.0"
---
# TimesFM Forecasting
## Overview
TimesFM (Time Series Foundation Model) is a pretrained decoder-only foundation model
developed by Google Research for time-series forecasting. It works **zero-shot** — feed it
any univariate time series and it returns point forecasts with calibrated quantile
prediction intervals, no training required.
This skill includes a **mandatory preflight system checker** that verifies RAM, GPU memory,
and disk space before the model is ever loaded so the agent never crashes the user's machine.
> **Key numbers**: TimesFM 2.5 uses 200M parameters (~800 MB on disk, ~1.5 GB in RAM on
> CPU, ~1 GB VRAM on GPU). The archived v1/v2 500M-parameter model needs ~32 GB RAM.
> Always run the system checker first.
## When to Use This Skill
Use this skill when:
- Forecasting **any univariate time series** (sales, demand, sensor, vitals, price, weather)
- You need **zero-shot forecasting** without training a custom model
- You want **probabilistic forecasts** with calibrated prediction intervals (quantiles)
- You have time series of **any length** (the model handles 1–16,384 context points)
- You need to **batch-forecast** hundreds or thousands of series efficiently
- You want a **foundation model** approach instead of hand-tuning ARIMA/ETS parameters
- You need **covariate forecasting** with exogenous variables (price, promotions, holidays, day-of-week effects) → use `forecast_with_covariates()` (TimesFM 2.5 + `pip install timesfm[xreg]`)
Do **not** use this skill when:
- You need classical statistical models with coefficient interpretation → use `statsmodels`
- You need time series classification or clustering → use `aeon`
- You need multivariate vector autoregression or Granger causality → use `statsmodels`
- Your data is tabular (not temporal) → use `scikit-learn`
- You cannot install optional dependencies → XReg requires scikit-learn and JAX
> **Note on Anomaly Detection**: TimesFM does not have built-in anomaly detection, but you
> can use the **quantile forecasts as prediction intervals** — values outside the 90% CI
> (q10–q90) are statistically unusual. See `examples/anomaly-detection/` for a full example.
## ⚠️ Mandatory Preflight: System Requirements Check
**CRITICAL — ALWAYS run the system checker before loading the model for the first time.**
```bash
python scripts/check_system.py
```
This script checks:
1. **Available RAM** — warns if below 4 GB, blocks if below 2 GB
2. **GPU availability** — detects CUDA/MPS devices and VRAM
3. **Disk space** — verifies room for the ~800 MB model download
4. **Python version** — requires 3.10+
5. **Existing installation** — checks if `timesfm` and `torch` are installed
> **Note:** Model weights are **NOT stored in this repository**. TimesFM weights (~800 MB)
> download on-demand from HuggingFace on first use and cache in `~/.cache/huggingface/`.
```mermaid
flowchart TD
start["🚀 Run check_system.py"] --> ram{"RAM ≥ 4 GB?"}
ram -->|"Yes"| gpu{"GPU available?"}
ram -->|"No (2-4 GB)"| warn_ram["⚠️ Warning: tight RAM CPU-only, small batches"]
ram -->|"No (< 2 GB)"| block["🛑 BLOCKED Insufficient memory"]
warn_ram --> disk
gpu -->|"CUDA / MPS"| vram{"VRAM ≥ 2 GB?"}
gpu -->|"CPU only"| cpu_ok["✅ CPU mode Slower but works"]
vram -->|"Yes"| gpu_ok["✅ GPU mode Fast inference"]
vram -->|"No"| cpu_ok
gpu_ok --> disk{"Disk ≥ 2 GB free?"}
cpu_ok --> disk
disk -->|"Yes"| ready["✅ READY Safe to load model"]
disk -->|"No"| block_disk["🛑 BLOCKED Need space for weights"]
```
### Dataset Preflight (NEW)
Before loading your actual data, verify it will fit in memory:
```bash
# Quick estimate for your dataset
python scripts/check_system.py \
--num-series 1000 \
--context-length 1024 \
--horizon 24 \
--batch-size 32 \
--estimate-only
```
This will show you the estimated memory requirements and warn if your dataset is too large.
**Memory Estimation Formula**:
`RAM ≈ 0.8 GB (model) + 0.5 GB (overhead) + (0.2 MB × num_series × context_length / 1000)`
**Example Outputs**:
✅ **Dataset Fits**:
```
Total CPU memory: 2.34 GB
Total GPU memory: 2.15 GB
```
⚠️ **Dataset Too Large**:
```
Dataset requires ~12.5 GB RAM but system has 8.0 GB.
Try: context_length=512 or process in chunks of 50 series.
```
### Hardware Requirements by Model Version
| Model | Parameters | RAM (CPU) | VRAM (GPU) | Disk | Context |
| ----- | ---------- | --------- | ---------- | ---- | ------- |
| **TimesFM 2.5** (recommended) | 200M | ≥ 4 GB | ≥ 2 GB | ~800 MB | up to 16,384 |
| TimesFM 2.0 (archived) | 500M | ≥ 16 GB | ≥ 8 GB | ~2 GB | up to 2,048 |
| TimesFM 1.0 (archived) | 200M | ≥ 8 GB | ≥ 4 GB | ~800 MB | up to 2,048 |
> **Recommendation**: Always use TimesFM 2.5 unless you have a specific reason to use an
> older checkpoint. It is smaller, faster, and supports 8× longer context.
## 🔧 Installation
### Step 1: Verify System (always first)
```bash
python scripts/check_system.py
```
### Step 2: Install TimesFM
```bash
# Using uv (fast)
uv pip install timesfm[torch]
# Or using pip
pip install timesfm[torch]
# For JAX/Flax backend (faster on TPU/GPU)
uv pip install timesfm[flax]
```
### Step 3: Install PyTorch for Your Hardware
```bash
# CUDA 12.1 (NVIDIA GPU)
pip install torch>=2.0.0 --index-url https://download.pytorch.org/whl/cu121
# CPU only
pip install torch>=2.0.0 --index-url https://download.pytorch.org/whl/cpu
# Apple Silicon (MPS)
pip install torch>=2.0.0 # MPS support is built-in
```
## 🎯 Quick Start
### Minimal Example
```python
import torch, numpy as np, timesfm
torch.set_float32_matmul_precision("high")
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
"google/timesfm-2.5-200m-pytorch"
)
model.compile(timesfm.ForecastConfig(
max_context=1024, max_horizon=256, normalize_inputs=True,
use_continuous_quantile_head=True, force_flip_invariance=True,
infer_is_positive=True, fix_quantile_crossing=True,
))
point, quantiles = model.forecast(horizon=24, inputs=[
np.sin(np.linspace(0, 20, 200)), # any 1-D array
])
# point.shape == (1, 24) — median forecast
# quantiles.shape == (1, 24, 10) — 10th–90th percentile bands
```
### Forecast with Covariates (XReg)
TimesFM 2.5+ supports exogenous variables through `forecast_with_covariates()`.
Requires `pip install timesfm[xreg]`.
```python
point, quantiles = model.forecast_with_covariates(
inputs=inputs,
dynamic_numerical_covariates={"price": price_arrays},
dynamic_categorical_covariates={"holiday": holiday_arrays},
static_categorical_covariates={"region": region_labels},
xreg_mode="xreg + timesfm", # or "timesfm + xreg"
)
```
### Anomaly Detection (via Quantile Intervals)
```python
point, q = model.forecast(horizon=H, inputs=[values])
lower_90 = q[0, :, 1] # 10th percentile
upper_90 = q[0, :, 9] # 90th percentile
actual = test_values
anomalies = (actual < lower_90) | (actual > upper_90)
```
| Severity | Condition | Interpretation |
| -------- | --------- | -------------- |
| **Normal** | Inside 80% CI | Expected behavior |
| **Warning** | Outside 80% CI | Unusual but possible |
| **Critical** | Outside 90% CI | Statistically rare (< 10% probability) |
> See `examples/anomaly-detection/` for a complete worked example with visualization.
## 📊 Understanding the Output
TimesFM returns `(point_forecast, quantile_forecast)`:
- **`point_forecast`**: shape `(batch, horizon)` — the median (0.5 quantile)
- **`quantile_forecast`**: shape `(batch, horizon, 10)` — ten quantile slices:
| Index | Quantile | Use |
| ----- | -------- | --- |
| 0 | Mean | Average prediction |
| 1 | 0.1 | Lower bound of 80% PI |
| 2 | 0.2 | Lower bound of 60% PI |
| **5** | **0.5** | **Median (= `point_forecast`)** |
| 8 | 0.8 | Upper bound of 60% PI |
| 9 | 0.9 | Upper bound of 80% PI |
```python
point, q = model.forecast(horizon=H, inputs=data)
lower_80 = q[:, :, 1] # 10th percentile
upper_80 = q[:, :, 9] # 90th percentile
median = q[:, :, 5]
```
## 🔧 ForecastConfig Reference
All forecasting behavior is controlled by `timesfm.ForecastConfig`:
```python
timesfm.ForecastConfig(
max_context=1024, # Max context window
max_horizon=256, # Max forecast horizon
normalize_inputs=True, # RECOMMENDED — prevents scale instability
per_core_batch_size=32, # Tune for memory
use_continuous_quantile_head=True, # Better quantile accuracy for long horizons
force_flip_invariance=True, # Ensures f(-x) = -f(x)
infer_is_positive=True, # Clamp forecasts ≥ 0 when all inputs > 0
fix_quantile_crossing=True, # Ensure q10 ≤ q20 ≤ ... ≤ q90
return_backcast=False, # Return backcast (for covariate workflows)
)
```
| Parameter | Default | When to Change |
| --------- | ------- | -------------- |
| `max_context` | 0 | Set to match your longest historical window |
| `normalize_inputs` | False | **Always set True** |
| `use_continuous_quantile_head` | False | **Set True** for calibrated PIs |
| `infer_is_positive` | True | Set False for series that can be negative |
| `fix_quantile_crossing` | False | **Set True** for monotonic quantiles |
See `references/api_reference.md` for the complete parameter reference.
## 📋 Common Workflows
### Single Series Forecast
```python
import torch, numpy as np, pandas as pd, timesfm, matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
torch.set_float32_matmul_precision("high")
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
"google/timesfm-2.5-200m-pytorch"
)
model.compile(timesfm.ForecastConfig(
max_context=512, max_horizon=52, normalize_inputs=True,
use_continuous_quantile_head=True, fix_quantile_crossing=True,
))
df = pd.read_csv("weekly_demand.csv", parse_dates=["week"])
values = df["demand"].values.astype(np.float32)
point, quantiles = model.forecast(horizon=52, inputs=[values])
fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(values[-104:], label="Historical")
x_fc = range(len(values[-104:]), len(values[-104:]) + 52)
ax.plot(x_fc, point[0], label="Forecast", color="tab:orange")
ax.fill_between(x_fc, quantiles[0, :, 1], quantiles[0, :, 9],
alpha=0.2, color="tab:orange", label="80% PI")
ax.legend(); ax.set_title("52-Week Demand Forecast")
plt.tight_layout(); plt.savefig("forecast.png", dpi=150)
```
### Batch Forecasting (Many Series)
```python
df = pd.read_csv("all_stores.csv", parse_dates=["date"], index_col="date")
inputs = [df[col].dropna().values.astype(np.float32) for col in df.columns]
point, quantiles = model.forecast(horizon=30, inputs=inputs)
import json
results = {col: {"forecast": point[i].tolist(),
"lower_80": quantiles[i, :, 1].tolist(),
"upper_80": quantiles[i, :, 9].tolist()}
for i, col in enumerate(df.columns)}
with open("batch_forecasts.json", "w") as f:
json.dump(results, f, indent=2)
```
### Evaluate Forecast Accuracy
```python
H = 24
train, actual = values[:-H], values[-H:]
point, quantiles = model.forecast(horizon=H, inputs=[train])
pred = point[0]
mae = np.mean(np.abs(actual - pred))
rmse = np.sqrt(np.mean((actual - pred) ** 2))
mape = np.mean(np.abs((actual - pred) / actual)) * 100
coverage = np.mean((actual >= quantiles[0, :, 1]) & (actual <= quantiles[0, :, 9])) * 100
print(f"MAE: {mae:.2f} | RMSE: {rmse:.2f} | MAPE: {mape:.1f}% | 80% PI Coverage: {coverage:.1f}%")
```
## ⚙️ Performance Tuning
```python
# Always set on Ampere+ GPUs (A100, RTX 3090+)
torch.set_float32_matmul_precision("high")
# Batch size guidelines:
# GPU 8 GB VRAM: per_core_batch_size=64
# GPU 16 GB VRAM: per_core_batch_size=128
# CPU 8 GB RAM: per_core_batch_size=8
# CPU 16 GB RAM: per_core_batch_size=32
# Memory-constrained: process in chunks
CHUNK = 50
results = []
for i in range(0, len(inputs), CHUNK):
p, q = model.forecast(horizon=H, inputs=inputs[i:i+CHUNK])
results.append((p, q))
```
## 📚 Available Scripts
### `scripts/check_system.py`
Mandatory preflight checker — run before first model load.
Now includes **dataset-aware memory estimation** to prevent OOM errors before loading your data.
```bash
# Basic system check
python scripts/check_system.py
# Check if your specific dataset will fit
python scripts/check_system.py \
--num-series 1000 \
--context-length 1024 \
--horizon 24 \
--batch-size 32
# Quick memory estimate without system checks
python scripts/check_system.py \
--num-series 5000 \
--context-length 2048 \
--estimate-only
```
**What it checks**:
1. **Available RAM** — warns if below 4 GB, blocks if below 2 GB
2. **GPU availability** — detects CUDA/MPS devices and VRAM
3. **Disk space** — verifies room for the ~800 MB model download
4. **Python version** — requires 3.10+
5. **Existing installation** — checks if `timesfm` and `torch` are installed
6. **Dataset fit** (NEW) — estimates memory for your specific dataset and warns if it won't fit
### `scripts/forecast_csv.py`
End-to-end CSV forecasting CLI.
```bash
python scripts/forecast_csv.py input.csv \
--horizon 24 \
--date-col date \
--value-cols sales,revenue \
--output forecasts.csv
```
## 📖 Reference Documentation
| File | Contents |
| ---- | -------- |
| `references/system_requirements.md` | Hardware tiers, GPU/CPU selection, memory estimation |
| `references/api_reference.md` | Full `ForecastConfig` docs, output shapes, model options |
| `references/data_preparation.md` | Input formats, NaN handling, CSV loading, covariate setup |
## 🧪 Examples
| Example | Directory | What It Demonstrates |
| ------- | --------- | -------------------- |
| **Global Temperature Forecast** | `examples/global-temperature/` | Basic `model.forecast()`, CSV → PNG → GIF pipeline |
| **Anomaly Detection** | `examples/anomaly-detection/` | Two-phase detrend + Z-score + quantile PI, 2-panel viz |
| **Covariates (XReg)** | `examples/covariates-forecasting/` | `forecast_with_covariates()`, 2×2 shared-axis viz |
```bash
# Run all three examples:
cd examples/global-temperature && python run_forecast.py && python visualize_forecast.py
cd examples/anomaly-detection && python detect_anomalies.py
cd examples/covariates-forecasting && python demo_covariates.py
```
### Expected Outputs
| Example | Key output files | Acceptance criteria |
| ------- | ---------------- | ------------------- |
| global-temperature | `output/forecast_output.json`, `output/forecast_visualization.png` | `point_forecast` has 12 values; PNG shows context + forecast + PI bands |
| anomaly-detection | `output/anomaly_detection.json`, `output/anomaly_detection.png` | Sep 2023 flagged CRITICAL (z ≥ 3.0) |
| covariates-forecasting | `output/sales_with_covariates.csv`, `output/covariates_data.png` | 108 rows (3 stores × 36 weeks); distinct price arrays per store |
## Model Versions
| Version | Params | Context | Status | HuggingFace checkpoint |
| ------- | ------ | ------- | ------ | ---------------------- |
| **2.5** | 200M | 16,384 | **Latest** | `google/timesfm-2.5-200m-pytorch` |
| 2.0 | 500M | 2,048 | Archived | `google/timesfm-2.0-500m-pytorch` |
| 1.0 | 200M | 2,048 | Archived | `google/timesfm-1.0-200m-pytorch` |
- TimesFM 1.0/2.0: must pass `freq=[0]` for monthly data
- TimesFM 2.5: no frequency flag — it was removed
## Resources
- **Paper**: [A Decoder-Only Foundation Model for Time-Series Forecasting](https://arxiv.org/abs/2310.10688) (ICML 2024)
- **HuggingFace**: https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6
- **Google Blog**: https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/
- **BigQuery Integration**: https://cloud.google.com/bigquery/docs/timesfm-model
## Quality Checklist
Run after every TimesFM task before declaring success:
- [ ] **Output shape** — `point_fc` is `(n_series, horizon)`, `quant_fc` is `(n_series, horizon, 10)`
- [ ] **Quantile indices** — index 0 = mean, 1 = q10 ... 9 = q90. NOT 0 = q0.
- [ ] **Frequency flag** — TimesFM 1.0/2.0: pass `freq=[0]` for monthly. TimesFM 2.5: omit.
- [ ] **Series length** — context must be ≥ 32 data points.
- [ ] **No NaN** — `np.isnan(point_fc).any()` must be False.
- [ ] **Axes** — multiple panels sharing data must use `sharex=True`.
- [ ] **`matplotlib.use('Agg')`** — before any pyplot import when running headless.
- [ ] **`infer_is_positive`** — set False for temperature, financial returns, negatives.
## Common Mistakes
1. **Quantile index off-by-one** — `quant_fc[..., 0]` is the **mean**, not q0. q10 = index 1, q90 = index 9. Define: `IDX_Q10, IDX_Q90 = 1, 9`.
2. **Variable shadowing in covariate loops** — don't use the outer loop variable as a comprehension variable when building per-series covariate dicts.
3. **Wrong CSV column name** — global-temperature CSV uses `anomaly_c`, not `anomaly`. Print `df.columns` first.
4. **TimesFM 2.5 required for `forecast_with_covariates()`** — TimesFM 1.0 does NOT have this method.
5. **Future covariates must span the full horizon** — dynamic covariates need values for BOTH context AND forecast windows.
6. **Context anomaly detection uses residuals** — detrend first, then Z-score. Raw Z-scores mislead on trending data.
## Validation & Verification
```bash
# Anomaly detection regression:
python -c "
import json
d = json.load(open('examples/anomaly-detection/output/anomaly_detection.json'))
assert d['context_summary']['critical'] >= 1, 'Sep 2023 must be CRITICAL'
print('Anomaly detection: PASS')"
# Covariates regression:
python -c "
import pandas as pd
df = pd.read_csv('examples/covariates-forecasting/output/sales_with_covariates.csv')
assert len(df) == 108, f'Expected 108 rows, got {len(df)}'
print('Covariates: PASS')"
```
================================================
FILE: timesfm-forecasting/examples/anomaly-detection/detect_anomalies.py
================================================
#!/usr/bin/env python3
"""
TimesFM Anomaly Detection Example — Two-Phase Method
Phase 1 (context): Linear detrend + Z-score on 36 months of real NOAA
temperature anomaly data (2022-01 through 2024-12).
Sep 2023 (1.47 C) is a known critical outlier.
Phase 2 (forecast): TimesFM quantile prediction intervals on a 12-month
synthetic future with 3 injected anomalies.
Outputs:
output/anomaly_detection.png -- 2-panel visualization
output/anomaly_detection.json -- structured detection records
"""
from __future__ import annotations
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
HORIZON = 12
DATA_FILE = (
Path(__file__).parent.parent / "global-temperature" / "temperature_anomaly.csv"
)
OUTPUT_DIR = Path(__file__).parent / "output"
CRITICAL_Z = 3.0
WARNING_Z = 2.0
# quant_fc index mapping: 0=mean, 1=q10, 2=q20, ..., 9=q90
IDX_Q10, IDX_Q20, IDX_Q80, IDX_Q90 = 1, 2, 8, 9
CLR = {"CRITICAL": "#e02020", "WARNING": "#f08030", "NORMAL": "#4a90d9"}
# ---------------------------------------------------------------------------
# Phase 1: context anomaly detection
# ---------------------------------------------------------------------------
def detect_context_anomalies(
values: np.ndarray,
dates: list,
) -> tuple[list[dict], np.ndarray, np.ndarray, float]:
"""Linear detrend + Z-score anomaly detection on context period.
Returns
-------
records : list of dicts, one per month
trend_line : fitted linear trend values (same length as values)
residuals : actual - trend_line
res_std : std of residuals (used as sigma for threshold bands)
"""
n = len(values)
idx = np.arange(n, dtype=float)
coeffs = np.polyfit(idx, values, 1)
trend_line = np.polyval(coeffs, idx)
residuals = values - trend_line
res_std = residuals.std()
records = []
for i, (d, v, r) in enumerate(zip(dates, values, residuals)):
z = r / res_std if res_std > 0 else 0.0
if abs(z) >= CRITICAL_Z:
severity = "CRITICAL"
elif abs(z) >= WARNING_Z:
severity = "WARNING"
else:
severity = "NORMAL"
records.append(
{
"date": str(d)[:7],
"value": round(float(v), 4),
"trend": round(float(trend_line[i]), 4),
"residual": round(float(r), 4),
"z_score": round(float(z), 3),
"severity": severity,
}
)
return records, trend_line, residuals, res_std
# ---------------------------------------------------------------------------
# Phase 2: synthetic future + forecast anomaly detection
# ---------------------------------------------------------------------------
def build_synthetic_future(
context: np.ndarray,
n: int,
seed: int = 42,
) -> tuple[np.ndarray, list[int]]:
"""Build a plausible future with 3 injected anomalies.
Injected months: 3, 8, 11 (0-indexed within the 12-month horizon).
Returns (future_values, injected_indices).
"""
rng = np.random.default_rng(seed)
trend = np.linspace(context[-6:].mean(), context[-6:].mean() + 0.05, n)
noise = rng.normal(0, 0.1, n)
future = trend + noise
injected = [3, 8, 11]
future[3] += 0.7 # CRITICAL spike
future[8] -= 0.65 # CRITICAL dip
future[11] += 0.45 # WARNING spike
return future.astype(np.float32), injected
def detect_forecast_anomalies(
future_values: np.ndarray,
point: np.ndarray,
quant_fc: np.ndarray,
future_dates: list,
injected_at: list[int],
) -> list[dict]:
"""Classify each forecast month by which PI band it falls outside.
CRITICAL = outside 80% PI (q10-q90)
WARNING = outside 60% PI (q20-q80) but inside 80% PI
NORMAL = inside 60% PI
"""
q10 = quant_fc[IDX_Q10]
q20 = quant_fc[IDX_Q20]
q80 = quant_fc[IDX_Q80]
q90 = quant_fc[IDX_Q90]
records = []
for i, (d, fv, pt) in enumerate(zip(future_dates, future_values, point)):
outside_80 = fv < q10[i] or fv > q90[i]
outside_60 = fv < q20[i] or fv > q80[i]
if outside_80:
severity = "CRITICAL"
elif outside_60:
severity = "WARNING"
else:
severity = "NORMAL"
records.append(
{
"date": str(d)[:7],
"actual": round(float(fv), 4),
"forecast": round(float(pt), 4),
"q10": round(float(q10[i]), 4),
"q20": round(float(q20[i]), 4),
"q80": round(float(q80[i]), 4),
"q90": round(float(q90[i]), 4),
"severity": severity,
"was_injected": i in injected_at,
}
)
return records
# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------
def plot_results(
context_dates: list,
context_values: np.ndarray,
ctx_records: list[dict],
trend_line: np.ndarray,
residuals: np.ndarray,
res_std: float,
future_dates: list,
future_values: np.ndarray,
point_fc: np.ndarray,
quant_fc: np.ndarray,
fc_records: list[dict],
) -> None:
OUTPUT_DIR.mkdir(exist_ok=True)
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10), gridspec_kw={"hspace": 0.42})
fig.suptitle(
"TimesFM Anomaly Detection — Two-Phase Method", fontsize=14, fontweight="bold"
)
# -----------------------------------------------------------------------
# Panel 1 — full timeline
# -----------------------------------------------------------------------
ctx_x = [pd.Timestamp(d) for d in context_dates]
fut_x = [pd.Timestamp(d) for d in future_dates]
divider = ctx_x[-1]
# context: blue line + trend + 2sigma band
ax1.plot(
ctx_x,
context_values,
color=CLR["NORMAL"],
lw=2,
marker="o",
ms=4,
label="Observed (context)",
)
ax1.plot(ctx_x, trend_line, color="#aaaaaa", lw=1.5, ls="--", label="Linear trend")
ax1.fill_between(
ctx_x,
trend_line - 2 * res_std,
trend_line + 2 * res_std,
alpha=0.15,
color=CLR["NORMAL"],
label="+/-2sigma band",
)
# context anomaly markers
seen_ctx: set[str] = set()
for rec in ctx_records:
if rec["severity"] == "NORMAL":
continue
d = pd.Timestamp(rec["date"])
v = rec["value"]
sev = rec["severity"]
lbl = f"Context {sev}" if sev not in seen_ctx else None
seen_ctx.add(sev)
ax1.scatter(d, v, marker="D", s=90, color=CLR[sev], zorder=6, label=lbl)
ax1.annotate(
f"z={rec['z_score']:+.1f}",
(d, v),
textcoords="offset points",
xytext=(0, 9),
fontsize=7.5,
ha="center",
color=CLR[sev],
)
# forecast section
q10 = quant_fc[IDX_Q10]
q20 = quant_fc[IDX_Q20]
q80 = quant_fc[IDX_Q80]
q90 = quant_fc[IDX_Q90]
ax1.plot(fut_x, future_values, "k--", lw=1.5, label="Synthetic future (truth)")
ax1.plot(
fut_x,
point_fc,
color=CLR["CRITICAL"],
lw=2,
marker="s",
ms=4,
label="TimesFM point forecast",
)
ax1.fill_between(fut_x, q10, q90, alpha=0.15, color=CLR["CRITICAL"], label="80% PI")
ax1.fill_between(fut_x, q20, q80, alpha=0.25, color=CLR["CRITICAL"], label="60% PI")
seen_fc: set[str] = set()
for i, rec in enumerate(fc_records):
if rec["severity"] == "NORMAL":
continue
d = pd.Timestamp(rec["date"])
v = rec["actual"]
sev = rec["severity"]
mk = "X" if sev == "CRITICAL" else "^"
lbl = f"Forecast {sev}" if sev not in seen_fc else None
seen_fc.add(sev)
ax1.scatter(d, v, marker=mk, s=100, color=CLR[sev], zorder=6, label=lbl)
ax1.axvline(divider, color="#555555", lw=1.5, ls=":")
ax1.text(
divider,
ax1.get_ylim()[1] if ax1.get_ylim()[1] != 0 else 1.5,
" <- Context | Forecast ->",
fontsize=8.5,
color="#555555",
style="italic",
va="top",
)
ax1.annotate(
"Context: D = Z-score anomaly | Forecast: X = CRITICAL, ^ = WARNING",
xy=(0.01, 0.04),
xycoords="axes fraction",
fontsize=8,
bbox=dict(boxstyle="round", fc="white", ec="#cccccc", alpha=0.9),
)
ax1.set_ylabel("Temperature Anomaly (C)", fontsize=10)
ax1.legend(ncol=2, fontsize=7.5, loc="upper left")
ax1.grid(True, alpha=0.22)
# -----------------------------------------------------------------------
# Panel 2 — deviation bars across all 48 months
# -----------------------------------------------------------------------
all_labels: list[str] = []
bar_colors: list[str] = []
bar_heights: list[float] = []
for rec in ctx_records:
all_labels.append(rec["date"])
bar_heights.append(rec["residual"])
bar_colors.append(CLR[rec["severity"]])
fc_deviations: list[float] = []
for rec in fc_records:
all_labels.append(rec["date"])
dev = rec["actual"] - rec["forecast"]
fc_deviations.append(dev)
bar_heights.append(dev)
bar_colors.append(CLR[rec["severity"]])
xs = np.arange(len(all_labels))
ax2.bar(xs[:36], bar_heights[:36], color=bar_colors[:36], alpha=0.8)
ax2.bar(xs[36:], bar_heights[36:], color=bar_colors[36:], alpha=0.8)
# threshold lines for context section only
ax2.hlines(
[2 * res_std, -2 * res_std], -0.5, 35.5, colors=CLR["NORMAL"], lw=1.2, ls="--"
)
ax2.hlines(
[3 * res_std, -3 * res_std], -0.5, 35.5, colors=CLR["NORMAL"], lw=1.0, ls=":"
)
# PI bands for forecast section
fc_xs = xs[36:]
ax2.fill_between(
fc_xs,
q10 - point_fc,
q90 - point_fc,
alpha=0.12,
color=CLR["CRITICAL"],
step="mid",
)
ax2.fill_between(
fc_xs,
q20 - point_fc,
q80 - point_fc,
alpha=0.20,
color=CLR["CRITICAL"],
step="mid",
)
ax2.axvline(35.5, color="#555555", lw=1.5, ls="--")
ax2.axhline(0, color="black", lw=0.8, alpha=0.6)
ax2.text(
10,
ax2.get_ylim()[0] * 0.85 if ax2.get_ylim()[0] < 0 else -0.05,
"<- Context: delta from linear trend",
fontsize=8,
style="italic",
color="#555555",
ha="center",
)
ax2.text(
41,
ax2.get_ylim()[0] * 0.85 if ax2.get_ylim()[0] < 0 else -0.05,
"Forecast: delta from TimesFM ->",
fontsize=8,
style="italic",
color="#555555",
ha="center",
)
tick_every = 3
ax2.set_xticks(xs[::tick_every])
ax2.set_xticklabels(all_labels[::tick_every], rotation=45, ha="right", fontsize=7)
ax2.set_ylabel("Delta from expected (C)", fontsize=10)
ax2.grid(True, alpha=0.22, axis="y")
legend_patches = [
mpatches.Patch(color=CLR["CRITICAL"], label="CRITICAL"),
mpatches.Patch(color=CLR["WARNING"], label="WARNING"),
mpatches.Patch(color=CLR["NORMAL"], label="Normal"),
]
ax2.legend(handles=legend_patches, fontsize=8, loc="upper right")
output_path = OUTPUT_DIR / "anomaly_detection.png"
plt.savefig(output_path, dpi=150, bbox_inches="tight")
plt.close()
print(f"\n Saved: {output_path}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
print("=" * 68)
print(" TIMESFM ANOMALY DETECTION — TWO-PHASE METHOD")
print("=" * 68)
# --- Load context data ---------------------------------------------------
df = pd.read_csv(DATA_FILE)
df["date"] = pd.to_datetime(df["date"])
df = df.sort_values("date").reset_index(drop=True)
context_values = df["anomaly_c"].values.astype(np.float32)
context_dates = [pd.Timestamp(d) for d in df["date"].tolist()]
start_str = context_dates[0].strftime('%Y-%m') if not pd.isnull(context_dates[0]) else '?'
end_str = context_dates[-1].strftime('%Y-%m') if not pd.isnull(context_dates[-1]) else '?'
print(f"\n Context: {len(context_values)} months ({start_str} - {end_str})")
# --- Phase 1: context anomaly detection ----------------------------------
ctx_records, trend_line, residuals, res_std = detect_context_anomalies(
context_values, context_dates
)
ctx_critical = [r for r in ctx_records if r["severity"] == "CRITICAL"]
ctx_warning = [r for r in ctx_records if r["severity"] == "WARNING"]
print(f"\n [Phase 1] Context anomalies (Z-score, sigma={res_std:.3f} C):")
print(f" CRITICAL (|Z|>={CRITICAL_Z}): {len(ctx_critical)}")
for r in ctx_critical:
print(f" {r['date']} {r['value']:+.3f} C z={r['z_score']:+.2f}")
print(f" WARNING (|Z|>={WARNING_Z}): {len(ctx_warning)}")
for r in ctx_warning:
print(f" {r['date']} {r['value']:+.3f} C z={r['z_score']:+.2f}")
# --- Load TimesFM --------------------------------------------------------
print("\n Loading TimesFM 1.0 ...")
import timesfm
hparams = timesfm.TimesFmHparams(horizon_len=HORIZON)
checkpoint = timesfm.TimesFmCheckpoint(
huggingface_repo_id="google/timesfm-1.0-200m-pytorch"
)
model = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)
point_out, quant_out = model.forecast([context_values], freq=[0])
point_fc = point_out[0] # shape (HORIZON,)
quant_fc = quant_out[0].T # shape (10, HORIZON)
# --- Build synthetic future + Phase 2 detection --------------------------
future_values, injected = build_synthetic_future(context_values, HORIZON)
last_date = context_dates[-1]
future_dates = [last_date + pd.DateOffset(months=i + 1) for i in range(HORIZON)]
fc_records = detect_forecast_anomalies(
future_values, point_fc, quant_fc, future_dates, injected
)
fc_critical = [r for r in fc_records if r["severity"] == "CRITICAL"]
fc_warning = [r for r in fc_records if r["severity"] == "WARNING"]
print(f"\n [Phase 2] Forecast anomalies (quantile PI, horizon={HORIZON} months):")
print(f" CRITICAL (outside 80% PI): {len(fc_critical)}")
for r in fc_critical:
print(
f" {r['date']} actual={r['actual']:+.3f} "
f"fc={r['forecast']:+.3f} injected={r['was_injected']}"
)
print(f" WARNING (outside 60% PI): {len(fc_warning)}")
for r in fc_warning:
print(
f" {r['date']} actual={r['actual']:+.3f} "
f"fc={r['forecast']:+.3f} injected={r['was_injected']}"
)
# --- Plot ----------------------------------------------------------------
print("\n Generating 2-panel visualization...")
plot_results(
context_dates,
context_values,
ctx_records,
trend_line,
residuals,
res_std,
future_dates,
future_values,
point_fc,
quant_fc,
fc_records,
)
# --- Save JSON -----------------------------------------------------------
OUTPUT_DIR.mkdir(exist_ok=True)
out = {
"method": "two_phase",
"context_method": "linear_detrend_zscore",
"forecast_method": "quantile_prediction_intervals",
"thresholds": {
"critical_z": CRITICAL_Z,
"warning_z": WARNING_Z,
"pi_critical_pct": 80,
"pi_warning_pct": 60,
},
"context_summary": {
"total": len(ctx_records),
"critical": len(ctx_critical),
"warning": len(ctx_warning),
"normal": len([r for r in ctx_records if r["severity"] == "NORMAL"]),
"res_std": round(float(res_std), 5),
},
"forecast_summary": {
"total": len(fc_records),
"critical": len(fc_critical),
"warning": len(fc_warning),
"normal": len([r for r in fc_records if r["severity"] == "NORMAL"]),
},
"context_detections": ctx_records,
"forecast_detections": fc_records,
}
json_path = OUTPUT_DIR / "anomaly_detection.json"
with open(json_path, "w") as f:
json.dump(out, f, indent=2)
print(f" Saved: {json_path}")
print("\n" + "=" * 68)
print(" SUMMARY")
print("=" * 68)
print(
f" Context ({len(ctx_records)} months): "
f"{len(ctx_critical)} CRITICAL, {len(ctx_warning)} WARNING"
)
print(
f" Forecast ({len(fc_records)} months): "
f"{len(fc_critical)} CRITICAL, {len(fc_warning)} WARNING"
)
print("=" * 68)
if __name__ == "__main__":
main()
================================================
FILE: timesfm-forecasting/examples/anomaly-detection/output/anomaly_detection.json
================================================
{
"method": "two_phase",
"context_method": "linear_detrend_zscore",
"forecast_method": "quantile_prediction_intervals",
"thresholds": {
"critical_z": 3.0,
"warning_z": 2.0,
"pi_critical_pct": 80,
"pi_warning_pct": 60
},
"context_summary": {
"total": 36,
"critical": 1,
"warning": 0,
"normal": 35,
"res_std": 0.11362
},
"forecast_summary": {
"total": 12,
"critical": 4,
"warning": 1,
"normal": 7
},
"context_detections": [
{
"date": "2022-01",
"value": 0.89,
"trend": 0.837,
"residual": 0.053,
"z_score": 0.467,
"severity": "NORMAL"
},
{
"date": "2022-02",
"value": 0.89,
"trend": 0.8514,
"residual": 0.0386,
"z_score": 0.34,
"severity": "NORMAL"
},
{
"date": "2022-03",
"value": 1.02,
"trend": 0.8658,
"residual": 0.1542,
"z_score": 1.357,
"severity": "NORMAL"
},
{
"date": "2022-04",
"value": 0.88,
"trend": 0.8803,
"residual": -0.0003,
"z_score": -0.002,
"severity": "NORMAL"
},
{
"date": "2022-05",
"value": 0.85,
"trend": 0.8947,
"residual": -0.0447,
"z_score": -0.394,
"severity": "NORMAL"
},
{
"date": "2022-06",
"value": 0.88,
"trend": 0.9092,
"residual": -0.0292,
"z_score": -0.257,
"severity": "NORMAL"
},
{
"date": "2022-07",
"value": 0.88,
"trend": 0.9236,
"residual": -0.0436,
"z_score": -0.384,
"severity": "NORMAL"
},
{
"date": "2022-08",
"value": 0.9,
"trend": 0.9381,
"residual": -0.0381,
"z_score": -0.335,
"severity": "NORMAL"
},
{
"date": "2022-09",
"value": 0.88,
"trend": 0.9525,
"residual": -0.0725,
"z_score": -0.638,
"severity": "NORMAL"
},
{
"date": "2022-10",
"value": 0.95,
"trend": 0.9669,
"residual": -0.0169,
"z_score": -0.149,
"severity": "NORMAL"
},
{
"date": "2022-11",
"value": 0.77,
"trend": 0.9814,
"residual": -0.2114,
"z_score": -1.86,
"severity": "NORMAL"
},
{
"date": "2022-12",
"value": 0.78,
"trend": 0.9958,
"residual": -0.2158,
"z_score": -1.9,
"severity": "NORMAL"
},
{
"date": "2023-01",
"value": 0.87,
"trend": 1.0103,
"residual": -0.1403,
"z_score": -1.235,
"severity": "NORMAL"
},
{
"date": "2023-02",
"value": 0.98,
"trend": 1.0247,
"residual": -0.0447,
"z_score": -0.394,
"severity": "NORMAL"
},
{
"date": "2023-03",
"value": 1.21,
"trend": 1.0392,
"residual": 0.1708,
"z_score": 1.503,
"severity": "NORMAL"
},
{
"date": "2023-04",
"value": 1.0,
"trend": 1.0536,
"residual": -0.0536,
"z_score": -0.472,
"severity": "NORMAL"
},
{
"date": "2023-05",
"value": 0.94,
"trend": 1.0681,
"residual": -0.1281,
"z_score": -1.127,
"severity": "NORMAL"
},
{
"date": "2023-06",
"value": 1.08,
"trend": 1.0825,
"residual": -0.0025,
"z_score": -0.022,
"severity": "NORMAL"
},
{
"date": "2023-07",
"value": 1.18,
"trend": 1.0969,
"residual": 0.0831,
"z_score": 0.731,
"severity": "NORMAL"
},
{
"date": "2023-08",
"value": 1.24,
"trend": 1.1114,
"residual": 0.1286,
"z_score": 1.132,
"severity": "NORMAL"
},
{
"date": "2023-09",
"value": 1.47,
"trend": 1.1258,
"residual": 0.3442,
"z_score": 3.029,
"severity": "CRITICAL"
},
{
"date": "2023-10",
"value": 1.32,
"trend": 1.1403,
"residual": 0.1797,
"z_score": 1.582,
"severity": "NORMAL"
},
{
"date": "2023-11",
"value": 1.18,
"trend": 1.1547,
"residual": 0.0253,
"z_score": 0.222,
"severity": "NORMAL"
},
{
"date": "2023-12",
"value": 1.16,
"trend": 1.1692,
"residual": -0.0092,
"z_score": -0.081,
"severity": "NORMAL"
},
{
"date": "2024-01",
"value": 1.22,
"trend": 1.1836,
"residual": 0.0364,
"z_score": 0.32,
"severity": "NORMAL"
},
{
"date": "2024-02",
"value": 1.35,
"trend": 1.1981,
"residual": 0.1519,
"z_score": 1.337,
"severity": "NORMAL"
},
{
"date": "2024-03",
"value": 1.34,
"trend": 1.2125,
"residual": 0.1275,
"z_score": 1.122,
"severity": "NORMAL"
},
{
"date": "2024-04",
"value": 1.26,
"trend": 1.2269,
"residual": 0.0331,
"z_score": 0.291,
"severity": "NORMAL"
},
{
"date": "2024-05",
"value": 1.15,
"trend": 1.2414,
"residual": -0.0914,
"z_score": -0.804,
"severity": "NORMAL"
},
{
"date": "2024-06",
"value": 1.2,
"trend": 1.2558,
"residual": -0.0558,
"z_score": -0.491,
"severity": "NORMAL"
},
{
"date": "2024-07",
"value": 1.24,
"trend": 1.2703,
"residual": -0.0303,
"z_score": -0.266,
"severity": "NORMAL"
},
{
"date": "2024-08",
"value": 1.3,
"trend": 1.2847,
"residual": 0.0153,
"z_score": 0.135,
"severity": "NORMAL"
},
{
"date": "2024-09",
"value": 1.28,
"trend": 1.2992,
"residual": -0.0192,
"z_score": -0.169,
"severity": "NORMAL"
},
{
"date": "2024-10",
"value": 1.27,
"trend": 1.3136,
"residual": -0.0436,
"z_score": -0.384,
"severity": "NORMAL"
},
{
"date": "2024-11",
"value": 1.22,
"trend": 1.328,
"residual": -0.108,
"z_score": -0.951,
"severity": "NORMAL"
},
{
"date": "2024-12",
"value": 1.2,
"trend": 1.3425,
"residual": -0.1425,
"z_score": -1.254,
"severity": "NORMAL"
}
],
"forecast_detections": [
{
"date": "2025-01",
"actual": 1.2821,
"forecast": 1.2593,
"q10": 1.1407,
"q20": 1.1881,
"q80": 1.324,
"q90": 1.3679,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-02",
"actual": 1.1522,
"forecast": 1.2857,
"q10": 1.1406,
"q20": 1.1961,
"q80": 1.3751,
"q90": 1.4254,
"severity": "WARNING",
"was_injected": false
},
{
"date": "2025-03",
"actual": 1.3358,
"forecast": 1.295,
"q10": 1.1269,
"q20": 1.1876,
"q80": 1.4035,
"q90": 1.4643,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-04",
"actual": 2.0594,
"forecast": 1.2208,
"q10": 1.0353,
"q20": 1.1042,
"q80": 1.331,
"q90": 1.4017,
"severity": "CRITICAL",
"was_injected": true
},
{
"date": "2025-05",
"actual": 1.0747,
"forecast": 1.1703,
"q10": 0.9691,
"q20": 1.0431,
"q80": 1.2892,
"q90": 1.3632,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-06",
"actual": 1.1442,
"forecast": 1.1456,
"q10": 0.942,
"q20": 1.0111,
"q80": 1.2703,
"q90": 1.3454,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-07",
"actual": 1.2917,
"forecast": 1.1702,
"q10": 0.9504,
"q20": 1.0348,
"q80": 1.2998,
"q90": 1.3807,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-08",
"actual": 1.2519,
"forecast": 1.2027,
"q10": 0.9709,
"q20": 1.0594,
"q80": 1.3408,
"q90": 1.4195,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-09",
"actual": 0.6364,
"forecast": 1.191,
"q10": 0.9594,
"q20": 1.0404,
"q80": 1.3355,
"q90": 1.417,
"severity": "CRITICAL",
"was_injected": true
},
{
"date": "2025-10",
"actual": 1.2073,
"forecast": 1.1491,
"q10": 0.9079,
"q20": 0.9953,
"q80": 1.2869,
"q90": 1.3775,
"severity": "NORMAL",
"was_injected": false
},
{
"date": "2025-11",
"actual": 1.3851,
"forecast": 1.0805,
"q10": 0.8361,
"q20": 0.926,
"q80": 1.2284,
"q90": 1.3122,
"severity": "CRITICAL",
"was_injected": false
},
{
"date": "2025-12",
"actual": 1.8294,
"forecast": 1.0613,
"q10": 0.8022,
"q20": 0.8952,
"q80": 1.2169,
"q90": 1.296,
"severity": "CRITICAL",
"was_injected": true
}
]
}
================================================
FILE: timesfm-forecasting/examples/covariates-forecasting/demo_covariates.py
================================================
#!/usr/bin/env python3
"""
TimesFM Covariates (XReg) Example
Demonstrates the TimesFM covariate API using synthetic retail sales data.
TimesFM 1.0 does NOT support forecast_with_covariates(); that requires
TimesFM 2.5 + `pip install timesfm[xreg]`.
This script:
1. Generates synthetic 3-store weekly retail data (24-week context, 12-week horizon)
2. Produces a 2x2 visualization showing WHAT each covariate contributes
and WHY knowing them improves forecasts -- all panels share the same
week x-axis (0 = first context week, 35 = last horizon week)
3. Exports a compact CSV (108 rows) and metadata JSON
NOTE ON REAL DATA:
If you want to use a real retail dataset (e.g., Kaggle Rossmann Store Sales),
download it to a TEMP location -- do NOT commit large CSVs to this repo.
import tempfile, urllib.request
tmp = tempfile.mkdtemp(prefix="timesfm_retail_")
# urllib.request.urlretrieve("https://...store_sales.csv", f"{tmp}/store_sales.csv")
# df = pd.read_csv(f"{tmp}/store_sales.csv")
This skills directory intentionally keeps only tiny reference datasets.
"""
from __future__ import annotations
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
EXAMPLE_DIR = Path(__file__).parent
OUTPUT_DIR = EXAMPLE_DIR / "output"
N_STORES = 3
CONTEXT_LEN = 24
HORIZON_LEN = 12
TOTAL_LEN = CONTEXT_LEN + HORIZON_LEN # 36
def generate_sales_data() -> dict:
"""Generate synthetic retail sales data with covariate components stored separately.
Returns a dict with:
stores: {store_id: {sales, config}}
covariates: {price, promotion, holiday, day_of_week, store_type, region}
components: {store_id: {base, price_effect, promo_effect, holiday_effect}}
Components let us show 'what would sales look like without covariates?' --
the gap between 'base' and 'sales' IS the covariate signal.
BUG FIX v3: Previous versions had variable-shadowing where inner dict
comprehension `{store_id: ... for store_id in stores}` overwrote the outer
loop variable causing all stores to get identical covariate arrays.
Fixed by accumulating per-store arrays separately before building covariate dict.
"""
rng = np.random.default_rng(42)
stores = {
"store_A": {"type": "premium", "region": "urban", "base_sales": 1000},
"store_B": {"type": "standard", "region": "suburban", "base_sales": 750},
"store_C": {"type": "discount", "region": "rural", "base_sales": 500},
}
base_prices = {"store_A": 12.0, "store_B": 10.0, "store_C": 7.5}
data: dict = {"stores": {}, "covariates": {}, "components": {}}
prices_by_store: dict[str, np.ndarray] = {}
promos_by_store: dict[str, np.ndarray] = {}
holidays_by_store: dict[str, np.ndarray] = {}
dow_by_store: dict[str, np.ndarray] = {}
for store_id, config in stores.items():
bp = base_prices[store_id]
weeks = np.arange(TOTAL_LEN)
trend = config["base_sales"] * (1 + 0.005 * weeks)
seasonality = 80 * np.sin(2 * np.pi * weeks / 52)
noise = rng.normal(0, 40, TOTAL_LEN)
base = (trend + seasonality + noise).astype(np.float32)
price = (bp + rng.uniform(-0.5, 0.5, TOTAL_LEN)).astype(np.float32)
price_effect = (-20 * (price - bp)).astype(np.float32)
holidays = np.zeros(TOTAL_LEN, dtype=np.float32)
for hw in [0, 11, 23, 35]:
if hw < TOTAL_LEN:
holidays[hw] = 1.0
holiday_effect = (200 * holidays).astype(np.float32)
promotion = rng.choice([0.0, 1.0], TOTAL_LEN, p=[0.8, 0.2]).astype(np.float32)
promo_effect = (150 * promotion).astype(np.float32)
day_of_week = np.tile(np.arange(7), TOTAL_LEN // 7 + 1)[:TOTAL_LEN].astype(
np.int32
)
sales = np.maximum(base + price_effect + holiday_effect + promo_effect, 50.0)
data["stores"][store_id] = {"sales": sales, "config": config}
data["components"][store_id] = {
"base": base,
"price_effect": price_effect,
"promo_effect": promo_effect,
"holiday_effect": holiday_effect,
}
prices_by_store[store_id] = price
promos_by_store[store_id] = promotion
holidays_by_store[store_id] = holidays
dow_by_store[store_id] = day_of_week
data["covariates"] = {
"price": prices_by_store,
"promotion": promos_by_store,
"holiday": holidays_by_store,
"day_of_week": dow_by_store,
"store_type": {sid: stores[sid]["type"] for sid in stores},
"region": {sid: stores[sid]["region"] for sid in stores},
}
return data
def create_visualization(data: dict) -> None:
"""
2x2 figure -- ALL panels share x-axis = weeks 0-35.
(0,0) Sales by store -- context solid, horizon dashed
(0,1) Store A: actual vs baseline (no covariates), with event overlays showing uplift
(1,0) Price covariate for all stores -- full 36 weeks including horizon
(1,1) Covariate effect decomposition for Store A (stacked fill_between)
Each panel has a conclusion annotation box explaining what the data shows.
"""
OUTPUT_DIR.mkdir(exist_ok=True)
store_colors = {"store_A": "#1a56db", "store_B": "#057a55", "store_C": "#c03221"}
weeks = np.arange(TOTAL_LEN)
fig, axes = plt.subplots(
2,
2,
figsize=(16, 11),
sharex=True,
gridspec_kw={"hspace": 0.42, "wspace": 0.32},
)
fig.suptitle(
"TimesFM Covariates (XReg) -- Retail Sales with Exogenous Variables\n"
"Shared x-axis: Week 0-23 = context (observed) | Week 24-35 = forecast horizon",
fontsize=13,
fontweight="bold",
y=1.01,
)
def add_divider(ax, label_top=True):
ax.axvline(CONTEXT_LEN - 0.5, color="#9ca3af", lw=1.3, ls="--", alpha=0.8)
ax.axvspan(
CONTEXT_LEN - 0.5, TOTAL_LEN - 0.5, alpha=0.06, color="grey", zorder=0
)
if label_top:
ax.text(
CONTEXT_LEN + 0.3,
1.01,
"<- horizon ->",
transform=ax.get_xaxis_transform(),
fontsize=7.5,
color="#6b7280",
style="italic",
)
# -- (0,0): Sales by Store ---------------------------------------------------
ax = axes[0, 0]
base_price_labels = {"store_A": "$12", "store_B": "$10", "store_C": "$7.50"}
for sid, store_data in data["stores"].items():
sales = store_data["sales"]
c = store_colors[sid]
lbl = f"{sid} ({store_data['config']['type']}, {base_price_labels[sid]} base)"
ax.plot(
weeks[:CONTEXT_LEN],
sales[:CONTEXT_LEN],
color=c,
lw=2,
marker="o",
ms=3,
label=lbl,
)
ax.plot(
weeks[CONTEXT_LEN:],
sales[CONTEXT_LEN:],
color=c,
lw=1.5,
ls="--",
marker="o",
ms=3,
alpha=0.6,
)
add_divider(ax)
ax.set_ylabel("Weekly Sales (units)", fontsize=10)
ax.set_title("Sales by Store", fontsize=11, fontweight="bold")
ax.legend(fontsize=7.5, loc="upper left")
ax.grid(True, alpha=0.22)
ratio = (
data["stores"]["store_A"]["sales"][:CONTEXT_LEN].mean()
/ data["stores"]["store_C"]["sales"][:CONTEXT_LEN].mean()
)
ax.annotate(
f"Store A earns {ratio:.1f}x Store C\n(premium vs discount pricing)\n"
f"-> store_type is a useful static covariate",
xy=(0.97, 0.05),
xycoords="axes fraction",
ha="right",
fontsize=8,
bbox=dict(boxstyle="round", fc="#fffbe6", ec="#d4a017", alpha=0.95),
)
# -- (0,1): Store A actual vs baseline ---------------------------------------
ax = axes[0, 1]
comp_A = data["components"]["store_A"]
sales_A = data["stores"]["store_A"]["sales"]
base_A = comp_A["base"]
promo_A = data["covariates"]["promotion"]["store_A"]
holiday_A = data["covariates"]["holiday"]["store_A"]
ax.plot(
weeks[:CONTEXT_LEN],
base_A[:CONTEXT_LEN],
color="#9ca3af",
lw=1.8,
ls="--",
label="Baseline (no covariates)",
)
ax.fill_between(
weeks[:CONTEXT_LEN],
base_A[:CONTEXT_LEN],
sales_A[:CONTEXT_LEN],
where=(sales_A[:CONTEXT_LEN] > base_A[:CONTEXT_LEN]),
alpha=0.35,
color="#22c55e",
label="Covariate uplift",
)
ax.fill_between(
weeks[:CONTEXT_LEN],
sales_A[:CONTEXT_LEN],
base_A[:CONTEXT_LEN],
where=(sales_A[:CONTEXT_LEN] < base_A[:CONTEXT_LEN]),
alpha=0.30,
color="#ef4444",
label="Price suppression",
)
ax.plot(
weeks[:CONTEXT_LEN],
sales_A[:CONTEXT_LEN],
color=store_colors["store_A"],
lw=2,
label="Actual sales (Store A)",
)
for w in range(CONTEXT_LEN):
if holiday_A[w] > 0:
ax.axvspan(w - 0.45, w + 0.45, alpha=0.22, color="darkorange", zorder=0)
promo_weeks = [w for w in range(CONTEXT_LEN) if promo_A[w] > 0]
if promo_weeks:
ax.scatter(
promo_weeks,
sales_A[promo_weeks],
marker="^",
color="#16a34a",
s=70,
zorder=6,
label="Promotion week",
)
add_divider(ax)
ax.set_ylabel("Weekly Sales (units)", fontsize=10)
ax.set_title(
"Store A -- Actual vs Baseline (No Covariates)", fontsize=11, fontweight="bold"
)
ax.legend(fontsize=7.5, loc="upper left", ncol=2)
ax.grid(True, alpha=0.22)
hm = holiday_A[:CONTEXT_LEN] > 0
pm = promo_A[:CONTEXT_LEN] > 0
h_lift = (
(sales_A[:CONTEXT_LEN][hm] - base_A[:CONTEXT_LEN][hm]).mean() if hm.any() else 0
)
p_lift = (
(sales_A[:CONTEXT_LEN][pm] - base_A[:CONTEXT_LEN][pm]).mean() if pm.any() else 0
)
ax.annotate(
f"Holiday weeks: +{h_lift:.0f} units avg\n"
f"Promotion weeks: +{p_lift:.0f} units avg\n"
f"Future event schedules must be known for XReg",
xy=(0.97, 0.05),
xycoords="axes fraction",
ha="right",
fontsize=8,
bbox=dict(boxstyle="round", fc="#fffbe6", ec="#d4a017", alpha=0.95),
)
# -- (1,0): Price covariate -- full 36 weeks ---------------------------------
ax = axes[1, 0]
for sid in data["stores"]:
ax.plot(
weeks,
data["covariates"]["price"][sid],
color=store_colors[sid],
lw=2,
label=sid,
alpha=0.85,
)
add_divider(ax, label_top=False)
ax.set_xlabel("Week", fontsize=10)
ax.set_ylabel("Price ($)", fontsize=10)
ax.set_title(
"Price Covariate -- Context + Forecast Horizon", fontsize=11, fontweight="bold"
)
ax.legend(fontsize=8, loc="upper right")
ax.grid(True, alpha=0.22)
ax.annotate(
"Prices are planned -- known for forecast horizon\n"
"Price elasticity: -$1 increase -> -20 units sold\n"
"Store A ($12) consistently more expensive than C ($7.50)",
xy=(0.97, 0.05),
xycoords="axes fraction",
ha="right",
fontsize=8,
bbox=dict(boxstyle="round", fc="#fffbe6", ec="#d4a017", alpha=0.95),
)
# -- (1,1): Covariate effect decomposition -----------------------------------
ax = axes[1, 1]
pe = comp_A["price_effect"]
pre = comp_A["promo_effect"]
he = comp_A["holiday_effect"]
ax.fill_between(
weeks,
0,
pe,
alpha=0.65,
color="steelblue",
step="mid",
label=f"Price effect (max +/-{np.abs(pe).max():.0f} units)",
)
ax.fill_between(
weeks,
pe,
pe + pre,
alpha=0.70,
color="#22c55e",
step="mid",
label="Promotion effect (+150 units)",
)
ax.fill_between(
weeks,
pe + pre,
pe + pre + he,
alpha=0.70,
color="darkorange",
step="mid",
label="Holiday effect (+200 units)",
)
total = pe + pre + he
ax.plot(weeks, total, "k-", lw=1.5, alpha=0.75, label="Total covariate effect")
ax.axhline(0, color="black", lw=0.9, alpha=0.6)
add_divider(ax, label_top=False)
ax.set_xlabel("Week", fontsize=10)
ax.set_ylabel("Effect on sales (units)", fontsize=10)
ax.set_title(
"Store A -- Covariate Effect Decomposition", fontsize=11, fontweight="bold"
)
ax.legend(fontsize=7.5, loc="upper right")
ax.grid(True, alpha=0.22, axis="y")
ax.annotate(
f"Holidays (+200) and promotions (+150) dominate\n"
f"Price effect (+/-{np.abs(pe).max():.0f} units) is minor by comparison\n"
f"-> Time-varying covariates explain most sales spikes",
xy=(0.97, 0.55),
xycoords="axes fraction",
ha="right",
fontsize=8,
bbox=dict(boxstyle="round", fc="#fffbe6", ec="#d4a017", alpha=0.95),
)
tick_pos = list(range(0, TOTAL_LEN, 4))
for row in [0, 1]:
for col in [0, 1]:
axes[row, col].set_xticks(tick_pos)
plt.tight_layout()
output_path = OUTPUT_DIR / "covariates_data.png"
plt.savefig(output_path, dpi=150, bbox_inches="tight")
plt.close()
print(f"\n Saved visualization: {output_path}")
def demonstrate_api() -> None:
print("\n" + "=" * 70)
print(" TIMESFM COVARIATES API (TimesFM 2.5)")
print("=" * 70)
print("""
# Installation
pip install timesfm[xreg]
import timesfm
hparams = timesfm.TimesFmHparams(backend="cpu", per_core_batch_size=32, horizon_len=12)
ckpt = timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-2.5-200m-pytorch")
model = timesfm.TimesFm(hparams=hparams, checkpoint=ckpt)
point_fc, quant_fc = model.forecast_with_covariates(
inputs=[sales_a, sales_b, sales_c],
dynamic_numerical_covariates={"price": [price_a, price_b, price_c]},
dynamic_categorical_covariates={"holiday": [hol_a, hol_b, hol_c]},
static_categorical_covariates={"store_type": ["premium","standard","discount"]},
xreg_mode="xreg + timesfm",
normalize_xreg_target_per_input=True,
)
# point_fc: (num_series, horizon_len)
# quant_fc: (num_series, horizon_len, 10)
""")
def explain_xreg_modes() -> None:
print("\n" + "=" * 70)
print(" XREG MODES")
print("=" * 70)
print("""
"xreg + timesfm" (DEFAULT)
1. TimesFM makes baseline forecast
2. Fit regression on residuals (actual - baseline) ~ covariates
3. Final = TimesFM baseline + XReg adjustment
Best when: covariates explain residual variation (e.g. promotions)
"timesfm + xreg"
1. Fit regression: target ~ covariates
2. TimesFM forecasts the residuals
3. Final = XReg prediction + TimesFM residual forecast
Best when: covariates explain the main signal (e.g. temperature)
""")
def main() -> None:
print("=" * 70)
print(" TIMESFM COVARIATES (XREG) EXAMPLE")
print("=" * 70)
print("\n Generating synthetic retail sales data...")
data = generate_sales_data()
print(f" Stores: {list(data['stores'].keys())}")
print(f" Context length: {CONTEXT_LEN} weeks")
print(f" Horizon length: {HORIZON_LEN} weeks")
print(f" Covariates: {list(data['covariates'].keys())}")
demonstrate_api()
explain_xreg_modes()
print("\n Creating 2x2 visualization (shared x-axis)...")
create_visualization(data)
print("\n Saving output data...")
OUTPUT_DIR.mkdir(exist_ok=True)
records = []
for store_id, store_data in data["stores"].items():
for i in range(TOTAL_LEN):
records.append(
{
"store_id": store_id,
"week": i,
"split": "context" if i < CONTEXT_LEN else "horizon",
"sales": round(float(store_data["sales"][i]), 2),
"base_sales": round(
float(data["components"][store_id]["base"][i]), 2
),
"price": round(float(data["covariates"]["price"][store_id][i]), 4),
"price_effect": round(
float(data["components"][store_id]["price_effect"][i]), 2
),
"promotion": int(data["covariates"]["promotion"][store_id][i]),
"holiday": int(data["covariates"]["holiday"][store_id][i]),
"day_of_week": int(data["covariates"]["day_of_week"][store_id][i]),
"store_type": data["covariates"]["store_type"][store_id],
"region": data["covariates"]["region"][store_id],
}
)
df = pd.DataFrame(records)
csv_path = OUTPUT_DIR / "sales_with_covariates.csv"
df.to_csv(csv_path, index=False)
print(f" Saved: {csv_path} ({len(df)} rows x {len(df.columns)} cols)")
metadata = {
"description": "Synthetic retail sales data with covariates for TimesFM XReg demo",
"note_on_real_data": (
"For real datasets (e.g., Kaggle Rossmann Store Sales), download to "
"tempfile.mkdtemp() -- do NOT commit to this repo."
),
"stores": {
sid: {
**sdata["config"],
"mean_sales_context": round(
float(sdata["sales"][:CONTEXT_LEN].mean()), 1
),
}
for sid, sdata in data["stores"].items()
},
"dimensions": {
"context_length": CONTEXT_LEN,
"horizon_length": HORIZON_LEN,
"total_length": TOTAL_LEN,
"num_stores": N_STORES,
"csv_rows": len(df),
},
"covariates": {
"dynamic_numerical": ["price"],
"dynamic_categorical": ["promotion", "holiday", "day_of_week"],
"static_categorical": ["store_type", "region"],
},
"effect_magnitudes": {
"holiday": "+200 units per holiday week",
"promotion": "+150 units per promotion week",
"price": "-20 units per $1 above base price",
},
"xreg_modes": {
"xreg + timesfm": "Regression on TimesFM residuals (default)",
"timesfm + xreg": "TimesFM on regression residuals",
},
"bug_fixes_history": [
"v1: Variable-shadowing -- all stores had identical covariates",
"v2: Fixed shadowing; CONTEXT_LEN 48->24",
"v3: Added component decomposition (base, price/promo/holiday effects); 2x2 sharex viz",
],
}
meta_path = OUTPUT_DIR / "covariates_metadata.json"
with open(meta_path, "w") as f:
json.dump(metadata, f, indent=2)
print(f" Saved: {meta_path}")
print("\n" + "=" * 70)
print(" COVARIATES EXAMPLE COMPLETE")
print("=" * 70)
print("""
Key points:
1. Requires timesfm[xreg] + TimesFM 2.5+ for actual inference
2. Dynamic covariates need values for BOTH context AND horizon (future must be known!)
3. Static covariates: one value per series (store_type, region)
4. All 4 visualization panels share the same week x-axis (0-35)
5. Effect decomposition shows holidays/promotions dominate over price variation
Output files:
output/covariates_data.png -- 2x2 visualization with conclusions
output/sales_with_covariates.csv -- 108-row compact dataset
output/covariates_metadata.json -- metadata + effect magnitudes
""")
if __name__ == "__main__":
main()
================================================
FILE: timesfm-forecasting/examples/covariates-forecasting/output/covariates_metadata.json
================================================
{
"description": "Synthetic retail sales data with covariates for TimesFM XReg demo",
"note_on_real_data": "For real datasets (e.g., Kaggle Rossmann Store Sales), download to tempfile.mkdtemp() -- do NOT commit to this repo.",
"stores": {
"store_A": {
"type": "premium",
"region": "urban",
"base_sales": 1000,
"mean_sales_context": 1148.7
},
"store_B": {
"type": "standard",
"region": "suburban",
"base_sales": 750,
"mean_sales_context": 907.0
},
"store_C": {
"type": "discount",
"region": "rural",
"base_sales": 500,
"mean_sales_context": 645.3
}
},
"dimensions": {
"context_length": 24,
"horizon_length": 12,
"total_length": 36,
"num_stores": 3,
"csv_rows": 108
},
"covariates": {
"dynamic_numerical": [
"price"
],
"dynamic_categorical": [
"promotion",
"holiday",
"day_of_week"
],
"static_categorical": [
"store_type",
"region"
]
},
"effect_magnitudes": {
"holiday": "+200 units per holiday week",
"promotion": "+150 units per promotion week",
"price": "-20 units per $1 above base price"
},
"xreg_modes": {
"xreg + timesfm": "Regression on TimesFM residuals (default)",
"timesfm + xreg": "TimesFM on regression residuals"
},
"bug_fixes_history": [
"v1: Variable-shadowing -- all stores had identical covariates",
"v2: Fixed shadowing; CONTEXT_LEN 48->24",
"v3: Added component decomposition (base, price/promo/holiday effects); 2x2 sharex viz"
]
}
================================================
FILE: timesfm-forecasting/examples/covariates-forecasting/output/sales_with_covariates.csv
================================================
store_id,week,split,sales,base_sales,price,price_effect,promotion,holiday,day_of_week,store_type,region
store_A,0,context,1369.59,1012.19,11.6299,7.4,1,1,0,premium,urban
store_A,1,context,973.53,973.04,11.9757,0.49,0,0,1,premium,urban
store_A,2,context,1064.63,1059.16,11.7269,5.46,0,0,2,premium,urban
store_A,3,context,1077.59,1080.99,12.1698,-3.4,0,0,3,premium,urban
store_A,4,context,980.39,979.14,11.9372,1.26,0,0,4,premium,urban
store_A,5,context,1011.7,1018.36,12.3327,-6.65,0,0,5,premium,urban
store_A,6,context,1084.16,1088.16,12.2003,-4.01,0,0,6,premium,urban
store_A,7,context,1085.98,1082.23,11.8124,3.75,0,0,0,premium,urban
store_A,8,context,1098.52,1105.17,12.3323,-6.65,0,0,1,premium,urban
store_A,9,context,1075.62,1081.71,12.3048,-6.1,0,0,2,premium,urban
store_A,10,context,1312.23,1159.98,11.8875,2.25,1,0,3,premium,urban
store_A,11,context,1368.02,1163.79,11.7883,4.23,0,1,4,premium,urban
store_A,12,context,1138.41,1142.06,12.1825,-3.65,0,0,5,premium,urban
store_A,13,context,1197.29,1190.09,11.6398,7.2,0,0,6,premium,urban
store_A,14,context,1174.12,1168.12,11.6999,6.0,0,0,0,premium,urban
store_A,15,context,1128.16,1118.3,11.5074,9.85,0,0,1,premium,urban
store_A,16,context,1163.81,1169.55,12.2869,-5.74,0,0,2,premium,urban
store_A,17,context,1114.18,1117.48,12.1649,-3.3,0,0,3,premium,urban
store_A,18,context,1186.87,1190.98,12.2052,-4.1,0,0,4,premium,urban
store_A,19,context,1147.27,1152.88,12.2807,-5.61,0,0,5,premium,urban
store_A,20,context,1146.48,1145.66,11.9589,0.82,0,0,6,premium,urban
store_A,21,context,1121.83,1123.21,12.0687,-1.37,0,0,0,premium,urban
store_A,22,context,1203.28,1196.08,11.6398,7.2,0,0,1,premium,urban
store_A,23,context,1344.9,1137.19,11.6145,7.71,0,1,2,premium,urban
store_A,24,horizon,1118.64,1122.01,12.1684,-3.37,0,0,3,premium,urban
store_A,25,horizon,1121.14,1120.56,11.9711,0.58,0,0,4,premium,urban
store_A,26,horizon,1149.99,1151.29,12.0652,-1.3,0,0,5,premium,urban
store_A,27,horizon,1284.67,1139.97,12.265,-5.3,1,0,6,premium,urban
store_A,28,horizon,1284.67,1137.36,12.1347,-2.69,1,0,0,premium,urban
store_A,29,horizon,1132.79,1133.86,12.0536,-1.07,0,0,1,premium,urban
store_A,30,horizon,1197.3,1198.49,12.0592,-1.18,0,0,2,premium,urban
store_A,31,horizon,1247.22,1093.3,11.804,3.92,1,0,3,premium,urban
store_A,32,horizon,1095.84,1086.46,11.5308,9.38,0,0,4,premium,urban
store_A,33,horizon,1073.83,1072.57,11.9367,1.27,0,0,5,premium,urban
store_A,34,horizon,1134.51,1128.8,11.7146,5.71,0,0,6,premium,urban
store_A,35,horizon,1351.15,1149.32,11.9085,1.83,0,1,0,premium,urban
store_B,0,context,1062.53,712.0,9.9735,0.53,1,1,0,standard,suburban
store_B,1,context,904.49,749.83,9.767,4.66,1,0,1,standard,suburban
store_B,2,context,813.63,810.26,9.8316,3.37,0,0,2,standard,suburban
store_B,3,context,720.11,720.53,10.0207,-0.41,0,0,3,standard,suburban
store_B,4,context,820.78,819.55,9.9389,1.22,0,0,4,standard,suburban
store_B,5,context,833.27,823.7,9.5216,9.57,0,0,5,standard,suburban
store_B,6,context,795.26,801.78,10.3263,-6.53,0,0,6,standard,suburban
store_B,7,context,770.37,778.29,10.3962,-7.92,0,0,0,standard,suburban
store_B,8,context,855.92,848.72,9.6402,7.2,0,0,1,standard,suburban
store_B,9,context,832.33,833.41,10.054,-1.08,0,0,2,standard,suburban
store_B,10,context,1029.44,871.61,9.6086,7.83,1,0,3,standard,suburban
store_B,11,context,1066.35,869.8,10.1722,-3.44,0,1,4,standard,suburban
store_B,12,context,942.86,938.49,9.7812,4.38,0,0,5,standard,suburban
store_B,13,context,1015.99,869.18,10.1594,-3.19,1,0,6,standard,suburban
store_B,14,context,836.44,840.98,10.227,-4.54,0,0,0,standard,suburban
store_B,15,context,885.72,891.1,10.2686,-5.37,0,0,1,standard,suburban
store_B,16,context,901.45,893.6,9.6077,7.85,0,0,2,standard,suburban
store_B,17,context,1080.63,938.95,10.416,-8.32,1,0,3,standard,suburban
store_B,18,context,922.14,916.74,9.7302,5.4,0,0,4,standard,suburban
store_B,19,context,904.66,895.41,9.5374,9.25,0,0,5,standard,suburban
store_B,20,context,935.48,936.58,10.0549,-1.1,0,0,6,standard,suburban
store_B,21,context,979.23,826.64,9.8709,2.58,1,0,0,standard,suburban
store_B,22,context,837.49,844.09,10.3298,-6.6,0,0,1,standard,suburban
store_B,23,context,1021.39,827.56,10.3083,-6.17,0,1,2,standard,suburban
store_B,24,horizon,847.21,843.55,9.8171,3.66,0,0,3,standard,suburban
store_B,25,horizon,789.27,798.33,10.4529,-9.06,0,0,4,standard,suburban
store_B,26,horizon,877.09,872.91,9.7909,4.18,0,0,5,standard,suburban
store_B,27,horizon,832.42,832.72,10.0151,-0.3,0,0,6,standard,suburban
store_B,28,horizon,781.9,777.02,9.756,4.88,0,0,0,standard,suburban
store_B,29,horizon,781.04,789.76,10.436,-8.72,0,0,1,standard,suburban
store_B,30,horizon,844.57,837.86,9.6646,6.71,0,0,2,standard,suburban
store_B,31,horizon,863.43,854.33,9.5449,9.1,0,0,3,standard,suburban
store_B,32,horizon,898.12,896.82,9.9351,1.3,0,0,4,standard,suburban
store_B,33,horizon,1070.58,930.42,10.4924,-9.85,1,0,5,standard,suburban
store_B,34,horizon,820.4,828.24,10.3917,-7.83,0,0,6,standard,suburban
store_B,35,horizon,965.86,770.83,10.2486,-4.97,0,1,0,standard,suburban
store_C,0,context,709.12,501.23,7.1053,7.89,0,1,0,discount,rural
store_C,1,context,651.44,492.78,7.0666,8.67,1,0,1,discount,rural
store_C,2,context,659.15,511.04,7.5944,-1.89,1,0,2,discount,rural
store_C,3,context,733.06,575.98,7.1462,7.08,1,0,3,discount,rural
store_C,4,context,712.21,568.7,7.8247,-6.49,1,0,4,discount,rural
store_C,5,context,615.23,611.44,7.3103,3.79,0,0,5,discount,rural
store_C,6,context,568.99,561.87,7.1439,7.12,0,0,6,discount,rural
store_C,7,context,541.12,549.54,7.921,-8.42,0,0,0,discount,rural
store_C,8,context,583.57,576.88,7.1655,6.69,0,0,1,discount,rural
store_C,9,context,607.34,603.04,7.2847,4.31,0,0,2,discount,rural
store_C,10,context,613.79,606.86,7.1536,6.93,0,0,3,discount,rural
store_C,11,context,919.49,561.8,7.1155,7.69,1,1,4,discount,rural
store_C,12,context,622.61,613.04,7.0211,9.58,0,0,5,discount,rural
store_C,13,context,630.52,621.63,7.0554,8.89,0,0,6,discount,rural
store_C,14,context,721.62,715.12,7.1746,6.51,0,0,0,discount,rural
store_C,15,context,699.18,690.25,7.0534,8.93,0,0,1,discount,rural
store_C,16,context,578.85,580.67,7.5911,-1.82,0,0,2,discount,rural
store_C,17,context,598.23,601.84,7.6807,-3.61,0,0,3,discount,rural
store_C,18,context,554.43,552.3,7.3936,2.13,0,0,4,discount,rural
store_C,19,context,587.39,583.75,7.318,3.64,0,0,5,discount,rural
store_C,20,context,615.58,615.67,7.5045,-0.09,0,0,6,discount,rural
store_C,21,context,638.68,646.18,7.875,-7.5,0,0,0,discount,rural
store_C,22,context,555.99,563.01,7.8511,-7.02,0,0,1,discount,rural
store_C,23,context,768.83,559.7,7.0435,9.13,0,1,2,discount,rural
store_C,24,horizon,499.62,493.25,7.1815,6.37,0,0,3,discount,rural
store_C,25,horizon,570.9,565.64,7.2367,5.27,0,0,4,discount,rural
store_C,26,horizon,677.52,522.5,7.2494,5.01,1,0,5,discount,rural
store_C,27,horizon,685.25,536.68,7.5712,-1.42,1,0,6,discount,rural
store_C,28,horizon,517.46,515.78,7.4163,1.67,0,0,0,discount,rural
store_C,29,horizon,549.38,540.36,7.0493,9.01,0,0,1,discount,rural
store_C,30,horizon,470.04,467.51,7.3736,2.53,0,0,2,discount,rural
store_C,31,horizon,622.9,473.37,7.5238,-0.48,1,0,3,discount,rural
store_C,32,horizon,620.09,612.12,7.1017,7.97,0,0,4,discount,rural
store_C,33,horizon,614.45,471.12,7.8335,-6.67,1,0,5,discount,rural
store_C,34,horizon,484.25,475.29,7.052,8.96,0,0,6,discount,rural
store_C,35,horizon,781.64,590.14,7.9248,-8.5,0,1,0,discount,rural
================================================
FILE: timesfm-forecasting/examples/global-temperature/README.md
================================================
# TimesFM Forecast Report: Global Temperature Anomaly (2025)
**Model:** TimesFM 1.0 (200M) PyTorch
**Generated:** 2026-02-21
**Source:** NOAA GISTEMP Global Land-Ocean Temperature Index
---
## Executive Summary
TimesFM forecasts a mean temperature anomaly of **1.19°C** for 2025, slightly below the 2024 average of 1.25°C. The model predicts continued elevated temperatures with a peak of 1.30°C in March 2025 and a minimum of 1.06°C in December 2025.
---
## Input Data
### Historical Temperature Anomalies (2022-2024)
| Date | Anomaly (°C) | Date | Anomaly (°C) | Date | Anomaly (°C) |
|------|-------------|------|-------------|------|-------------|
| 2022-01 | 0.89 | 2023-01 | 0.87 | 2024-01 | 1.22 |
| 2022-02 | 0.89 | 2023-02 | 0.98 | 2024-02 | 1.35 |
| 2022-03 | 1.02 | 2023-03 | 1.21 | 2024-03 | 1.34 |
| 2022-04 | 0.88 | 2023-04 | 1.00 | 2024-04 | 1.26 |
| 2022-05 | 0.85 | 2023-05 | 0.94 | 2024-05 | 1.15 |
| 2022-06 | 0.88 | 2023-06 | 1.08 | 2024-06 | 1.20 |
| 2022-07 | 0.88 | 2023-07 | 1.18 | 2024-07 | 1.24 |
| 2022-08 | 0.90 | 2023-08 | 1.24 | 2024-08 | 1.30 |
| 2022-09 | 0.88 | 2023-09 | 1.47 | 2024-09 | 1.28 |
| 2022-10 | 0.95 | 2023-10 | 1.32 | 2024-10 | 1.27 |
| 2022-11 | 0.77 | 2023-11 | 1.18 | 2024-11 | 1.22 |
| 2022-12 | 0.78 | 2023-12 | 1.16 | 2024-12 | 1.20 |
**Statistics:**
- Total observations: 36 months
- Mean anomaly: 1.09°C
- Trend (2022→2024): +0.37°C
---
## Raw Forecast Output
### Point Forecast and Confidence Intervals
| Month | Point | 80% CI | 90% CI |
|-------|-------|--------|--------|
| 2025-01 | 1.259 | [1.141, 1.297] | [1.248, 1.324] |
| 2025-02 | 1.286 | [1.141, 1.340] | [1.277, 1.375] |
| 2025-03 | 1.295 | [1.127, 1.355] | [1.287, 1.404] |
| 2025-04 | 1.221 | [1.035, 1.290] | [1.208, 1.331] |
| 2025-05 | 1.170 | [0.969, 1.239] | [1.153, 1.289] |
| 2025-06 | 1.146 | [0.942, 1.218] | [1.128, 1.270] |
| 2025-07 | 1.170 | [0.950, 1.248] | [1.151, 1.300] |
| 2025-08 | 1.203 | [0.971, 1.284] | [1.186, 1.341] |
| 2025-09 | 1.191 | [0.959, 1.283] | [1.178, 1.335] |
| 2025-10 | 1.149 | [0.908, 1.240] | [1.126, 1.287] |
| 2025-11 | 1.080 | [0.836, 1.176] | [1.062, 1.228] |
| 2025-12 | 1.061 | [0.802, 1.153] | [1.037, 1.217] |
### JSON Output
```json
{
"model": "TimesFM 1.0 (200M) PyTorch",
"input": {
"source": "NOAA GISTEMP Global Temperature Anomaly",
"n_observations": 36,
"date_range": "2022-01 to 2024-12",
"mean_anomaly_c": 1.089
},
"forecast": {
"horizon": 12,
"dates": ["2025-01", "2025-02", "2025-03", "2025-04", "2025-05", "2025-06",
"2025-07", "2025-08", "2025-09", "2025-10", "2025-11", "2025-12"],
"point": [1.259, 1.286, 1.295, 1.221, 1.170, 1.146, 1.170, 1.203, 1.191, 1.149, 1.080, 1.061]
},
"summary": {
"forecast_mean_c": 1.186,
"forecast_max_c": 1.295,
"forecast_min_c": 1.061,
"vs_last_year_mean": -0.067
}
}
```
---
## Visualization

---
## Findings
### Key Observations
1. **Slight cooling trend expected**: The model forecasts a mean anomaly 0.07°C below 2024 levels, suggesting a potential stabilization after the record-breaking temperatures of 2023-2024.
2. **Seasonal pattern preserved**: The forecast shows the expected seasonal variation with higher anomalies in late winter (Feb-Mar) and lower in late fall (Nov-Dec).
3. **Widening uncertainty**: The 90% CI expands from ±0.04°C in January to ±0.08°C in December, reflecting typical forecast uncertainty growth over time.
4. **Peak temperature**: March 2025 is predicted to have the highest anomaly at 1.30°C, potentially approaching the September 2023 record of 1.47°C.
### Limitations
- TimesFM is a zero-shot forecaster without physical climate model constraints
- The 36-month training window may not capture multi-decadal climate trends
- El Niño/La Niña cycles are not explicitly modeled
### Recommendations
- Use this forecast as a baseline comparison for physics-based climate models
- Update forecast quarterly as new observations become available
- Consider ensemble approaches combining TimesFM with other methods
---
## Reproducibility
### Files
| File | Description |
|------|-------------|
| `temperature_anomaly.csv` | Input data (36 months) |
| `forecast_output.csv` | Point forecast with quantiles |
| `forecast_output.json` | Machine-readable forecast |
| `forecast_visualization.png` | Fan chart visualization |
| `run_forecast.py` | Forecasting script |
| `visualize_forecast.py` | Visualization script |
| `run_example.sh` | One-click runner |
### How to Reproduce
```bash
# Install dependencies
uv pip install "timesfm[torch]" matplotlib pandas numpy
# Run the complete example
cd scientific-skills/timesfm-forecasting/examples/global-temperature
./run_example.sh
```
---
## Technical Notes
### API Discovery
The TimesFM PyTorch API differs from the GitHub README documentation:
**Documented (GitHub README):**
```python
model = timesfm.TimesFm(
context_len=512,
horizon_len=128,
backend="gpu",
)
model.load_from_google_repo("google/timesfm-2.5-200m-pytorch")
```
**Actual Working API:**
```python
hparams = timesfm.TimesFmHparams(horizon_len=12)
checkpoint = timesfm.TimesFmCheckpoint(
huggingface_repo_id="google/timesfm-1.0-200m-pytorch"
)
model = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)
```
### TimesFM 2.5 PyTorch Issue
The `google/timesfm-2.5-200m-pytorch` checkpoint downloads as `model.safetensors`, but the TimesFM loader expects `torch_model.ckpt`. This causes a `FileNotFoundError` at model load time. Using TimesFM 1.0 PyTorch resolves this issue.
---
*Report generated by TimesFM Forecasting Skill (claude-scientific-skills)*
================================================
FILE: timesfm-forecasting/examples/global-temperature/generate_animation_data.py
================================================
#!/usr/bin/env python3
"""
Generate animation data for interactive forecast visualization.
This script runs TimesFM forecasts incrementally, starting with minimal data
and adding one point at a time. Each forecast extends to the final date (2025-12).
Output: animation_data.json with all forecast steps
"""
from __future__ import annotations
import json
from pathlib import Path
import numpy as np
import pandas as pd
import timesfm
# Configuration
MIN_CONTEXT = 12 # Minimum points to start forecasting
MAX_HORIZON = (
36 # Max forecast length (when we have 12 points, forecast 36 months to 2025-12)
)
TOTAL_MONTHS = 48 # Total months from 2022-01 to 2025-12 (graph extent)
INPUT_FILE = Path(__file__).parent / "temperature_anomaly.csv"
OUTPUT_FILE = Path(__file__).parent / "output" / "animation_data.json"
def main() -> None:
print("=" * 60)
print(" TIMESFM ANIMATION DATA GENERATOR")
print(" Dynamic horizon - forecasts always reach 2025-12")
print("=" * 60)
# Load data
df = pd.read_csv(INPUT_FILE, parse_dates=["date"])
df = df.sort_values("date").reset_index(drop=True)
all_dates = df["date"].tolist()
all_values = df["anomaly_c"].values.astype(np.float32)
print(f"\n📊 Total data: {len(all_values)} months")
print(
f" Date range: {all_dates[0].strftime('%Y-%m')} to {all_dates[-1].strftime('%Y-%m')}"
)
print(f" Animation steps: {len(all_values) - MIN_CONTEXT + 1}")
# Load TimesFM with max horizon (will truncate output for shorter forecasts)
print(f"\n🤖 Loading TimesFM 1.0 (200M) PyTorch (horizon={MAX_HORIZON})...")
hparams = timesfm.TimesFmHparams(horizon_len=MAX_HORIZON)
checkpoint = timesfm.TimesFmCheckpoint(
huggingface_repo_id="google/timesfm-1.0-200m-pytorch"
)
model = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)
# Generate forecasts for each step
animation_steps = []
for n_points in range(MIN_CONTEXT, len(all_values) + 1):
step_num = n_points - MIN_CONTEXT + 1
total_steps = len(all_values) - MIN_CONTEXT + 1
# Calculate dynamic horizon: forecast enough to reach 2025-12
horizon = TOTAL_MONTHS - n_points
print(
f"\n📈 Step {step_num}/{total_steps}: Using {n_points} points, forecasting {horizon} months..."
)
# Get historical data up to this point
historical_values = all_values[:n_points]
historical_dates = all_dates[:n_points]
# Run forecast (model outputs MAX_HORIZON, we truncate to actual horizon)
point, quantiles = model.forecast(
[historical_values],
freq=[0],
)
# Truncate to actual horizon
point = point[0][:horizon]
quantiles = quantiles[0, :horizon, :]
# Determine forecast dates
last_date = historical_dates[-1]
forecast_dates = pd.date_range(
start=last_date + pd.DateOffset(months=1),
periods=horizon,
freq="MS",
)
# Store step data
step_data = {
"step": step_num,
"n_points": n_points,
"horizon": horizon,
"last_historical_date": historical_dates[-1].strftime("%Y-%m"),
"historical_dates": [d.strftime("%Y-%m") for d in historical_dates],
"historical_values": historical_values.tolist(),
"forecast_dates": [d.strftime("%Y-%m") for d in forecast_dates],
"point_forecast": point.tolist(),
"q10": quantiles[:, 0].tolist(),
"q20": quantiles[:, 1].tolist(),
"q80": quantiles[:, 7].tolist(),
"q90": quantiles[:, 8].tolist(),
}
animation_steps.append(step_data)
# Show summary
print(f" Last date: {historical_dates[-1].strftime('%Y-%m')}")
print(f" Forecast to: {forecast_dates[-1].strftime('%Y-%m')}")
print(f" Forecast mean: {point.mean():.3f}°C")
# Create output
output = {
"metadata": {
"model": "TimesFM 1.0 (200M) PyTorch",
"total_steps": len(animation_steps),
"min_context": MIN_CONTEXT,
"max_horizon": MAX_HORIZON,
"total_months": TOTAL_MONTHS,
"data_source": "NOAA GISTEMP Global Temperature Anomaly",
"full_date_range": f"{all_dates[0].strftime('%Y-%m')} to {all_dates[-1].strftime('%Y-%m')}",
},
"actual_data": {
"dates": [d.strftime("%Y-%m") for d in all_dates],
"values": all_values.tolist(),
},
"animation_steps": animation_steps,
}
# Save
with open(OUTPUT_FILE, "w") as f:
json.dump(output, f, indent=2)
print(f"\n" + "=" * 60)
print(" ✅ ANIMATION DATA COMPLETE")
print("=" * 60)
print(f"\n📁 Output: {OUTPUT_FILE}")
print(f" Total steps: {len(animation_steps)}")
print(f" Each forecast extends to 2025-12")
if __name__ == "__main__":
main()
================================================
FILE: timesfm-forecasting/examples/global-temperature/generate_gif.py
================================================
#!/usr/bin/env python3
"""
Generate animated GIF showing forecast evolution.
Creates a GIF animation showing how the TimesFM forecast changes
as more historical data points are added. Shows the full actual data as a background layer.
"""
from __future__ import annotations
import json
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
import pandas as pd
from PIL import Image
# Configuration
EXAMPLE_DIR = Path(__file__).parent
DATA_FILE = EXAMPLE_DIR / "output" / "animation_data.json"
OUTPUT_FILE = EXAMPLE_DIR / "output" / "forecast_animation.gif"
DURATION_MS = 500 # Time per frame in milliseconds
def create_frame(
ax,
step_data: dict,
actual_data: dict,
final_forecast: dict,
total_steps: int,
x_min,
x_max,
y_min,
y_max,
) -> None:
"""Create a single frame of the animation with fixed axes."""
ax.clear()
# Parse dates
historical_dates = pd.to_datetime(step_data["historical_dates"])
forecast_dates = pd.to_datetime(step_data["forecast_dates"])
# Get final forecast dates for full extent
final_forecast_dates = pd.to_datetime(final_forecast["forecast_dates"])
# All actual dates for full background
all_actual_dates = pd.to_datetime(actual_data["dates"])
all_actual_values = np.array(actual_data["values"])
# ========== BACKGROUND LAYER: Full actual data (faded) ==========
ax.plot(
all_actual_dates,
all_actual_values,
color="#9ca3af",
linewidth=1,
marker="o",
markersize=2,
alpha=0.3,
label="All observed data",
zorder=1,
)
# ========== BACKGROUND LAYER: Final forecast (faded) ==========
ax.plot(
final_forecast_dates,
final_forecast["point_forecast"],
color="#fca5a5",
linewidth=1,
linestyle="--",
marker="s",
markersize=2,
alpha=0.3,
label="Final forecast",
zorder=2,
)
# ========== FOREGROUND LAYER: Historical data used (bright) ==========
ax.plot(
historical_dates,
step_data["historical_values"],
color="#3b82f6",
linewidth=2.5,
marker="o",
markersize=5,
label="Data used",
zorder=10,
)
# ========== FOREGROUND LAYER: Current forecast (bright) ==========
# 90% CI (outer)
ax.fill_between(
forecast_dates,
step_data["q10"],
step_data["q90"],
alpha=0.15,
color="#ef4444",
zorder=5,
)
# 80% CI (inner)
ax.fill_between(
forecast_dates,
step_data["q20"],
step_data["q80"],
alpha=0.25,
color="#ef4444",
zorder=6,
)
# Forecast line
ax.plot(
forecast_dates,
step_data["point_forecast"],
color="#ef4444",
linewidth=2.5,
marker="s",
markersize=5,
label="Forecast",
zorder=7,
)
# ========== Vertical line at forecast boundary ==========
ax.axvline(
x=historical_dates[-1],
color="#6b7280",
linestyle="--",
linewidth=1.5,
alpha=0.7,
zorder=8,
)
# ========== Formatting ==========
ax.set_xlabel("Date", fontsize=11)
ax.set_ylabel("Temperature Anomaly (°C)", fontsize=11)
ax.set_title(
f"TimesFM Forecast Evolution\n"
f"Step {step_data['step']}/{total_steps}: {step_data['n_points']} points → "
f"forecast from {step_data['last_historical_date']}",
fontsize=13,
fontweight="bold",
)
ax.grid(True, alpha=0.3, zorder=0)
ax.legend(loc="upper left", fontsize=8)
# FIXED AXES - same for all frames
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
# Format x-axis
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
ax.xaxis.set_major_locator(mdates.MonthLocator(interval=4))
plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha="right")
def main() -> None:
print("=" * 60)
print(" GENERATING ANIMATED GIF")
print("=" * 60)
# Load data
with open(DATA_FILE) as f:
data = json.load(f)
total_steps = len(data["animation_steps"])
print(f"\n📊 Total frames: {total_steps}")
# Get the final forecast step for reference
final_forecast = data["animation_steps"][-1]
# Calculate fixed axis extents from ALL data
all_actual_dates = pd.to_datetime(data["actual_data"]["dates"])
all_actual_values = np.array(data["actual_data"]["values"])
final_forecast_dates = pd.to_datetime(final_forecast["forecast_dates"])
final_forecast_values = np.array(final_forecast["point_forecast"])
# X-axis: from first actual date to last forecast date
x_min = all_actual_dates[0]
x_max = final_forecast_dates[-1]
# Y-axis: min/max across all actual + all forecasts with CIs
all_forecast_q10 = np.array(final_forecast["q10"])
all_forecast_q90 = np.array(final_forecast["q90"])
all_values = np.concatenate([
all_actual_values,
final_forecast_values,
all_forecast_q10,
all_forecast_q90,
])
y_min = all_values.min() - 0.05
y_max = all_values.max() + 0.05
print(f" X-axis: {x_min.strftime('%Y-%m')} to {x_max.strftime('%Y-%m')}")
print(f" Y-axis: {y_min:.2f}°C to {y_max:.2f}°C")
# Create figure
fig, ax = plt.subplots(figsize=(12, 6))
# Generate frames
frames = []
for i, step in enumerate(data["animation_steps"]):
print(f" Frame {i + 1}/{total_steps}...")
create_frame(
ax,
step,
data["actual_data"],
final_forecast,
total_steps,
x_min,
x_max,
y_min,
y_max,
)
# Save frame to buffer
fig.canvas.draw()
# Convert to PIL Image
buf = fig.canvas.buffer_rgba()
width, height = fig.canvas.get_width_height()
img = Image.frombytes("RGBA", (width, height), buf)
frames.append(img.convert("RGB"))
plt.close()
# Save as GIF
print(f"\n💾 Saving GIF: {OUTPUT_FILE}")
frames[0].save(
OUTPUT_FILE,
save_all=True,
append_images=frames[1:],
duration=DURATION_MS,
loop=0, # Loop forever
)
# Get file size
size_kb = OUTPUT_FILE.stat().st_size / 1024
print(f" File size: {size_kb:.1f} KB")
print(f"\n✅ Done!")
if __name__ == "__main__":
main()
================================================
FILE: timesfm-forecasting/examples/global-temperature/generate_html.py
================================================
#!/usr/bin/env python3
"""
Generate a self-contained HTML file with embedded animation data.
This creates a single HTML file that can be opened directly in any browser
without needing a server or external JSON file (CORS-safe).
"""
from __future__ import annotations
import json
from pathlib import Path
EXAMPLE_DIR = Path(__file__).parent
DATA_FILE = EXAMPLE_DIR / "output" / "animation_data.json"
OUTPUT_FILE = EXAMPLE_DIR / "output" / "interactive_forecast.html"
HTML_TEMPLATE = """
TimesFM Interactive Forecast Animation
TimesFM Forecast Evolution
Watch the forecast evolve as more data is added — forecasts extend to 2025-12
Watch the forecast evolve as more data is added — forecasts extend to 2025-12
Data Points Used12 / 36
2022-01Using data through 2022-12
Forecast Mean
0.86°C
Forecast Horizon
36 months
Forecast Max
--
Forecast Min
--
All Observed Data
Final Forecast (reference)
Data Used
Current Forecast
80% CI
================================================
FILE: timesfm-forecasting/examples/global-temperature/run_example.sh
================================================
#!/bin/bash
# run_example.sh - Run the TimesFM temperature anomaly forecasting example
#
# This script:
# 1. Runs the preflight system check
# 2. Runs the TimesFM forecast
# 3. Generates the visualization
#
# Usage:
# ./run_example.sh
#
# Prerequisites:
# - Python 3.10+
# - timesfm[torch] installed: uv pip install "timesfm[torch]"
# - matplotlib, pandas, numpy
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
SKILL_ROOT="$(dirname "$(dirname "$SCRIPT_DIR")")"
echo "============================================================"
echo " TimesFM Example: Global Temperature Anomaly Forecast"
echo "============================================================"
# Step 1: Preflight check
echo ""
echo "🔍 Step 1: Running preflight system check..."
python3 "$SKILL_ROOT/scripts/check_system.py" || {
echo "❌ Preflight check failed. Please fix the issues above before continuing."
exit 1
}
# Step 2: Run forecast
echo ""
echo "📊 Step 2: Running TimesFM forecast..."
cd "$SCRIPT_DIR"
python3 run_forecast.py
# Step 3: Generate visualization
echo ""
echo "📈 Step 3: Generating visualization..."
python3 visualize_forecast.py
echo ""
echo "============================================================"
echo " ✅ Example complete!"
echo "============================================================"
echo ""
echo "Output files:"
echo " - $SCRIPT_DIR/output/forecast_output.csv"
echo " - $SCRIPT_DIR/output/forecast_output.json"
echo " - $SCRIPT_DIR/output/forecast_visualization.png"
================================================
FILE: timesfm-forecasting/examples/global-temperature/run_forecast.py
================================================
#!/usr/bin/env python3
"""
Run TimesFM forecast on global temperature anomaly data.
Generates forecast output CSV and JSON for the example.
"""
from __future__ import annotations
import json
from pathlib import Path
import numpy as np
import pandas as pd
# Preflight check
print("=" * 60)
print(" TIMeSFM FORECAST - Global Temperature Anomaly Example")
print("=" * 60)
# Load data
data_path = Path(__file__).parent / "temperature_anomaly.csv"
df = pd.read_csv(data_path, parse_dates=["date"])
df = df.sort_values("date").reset_index(drop=True)
print(f"\n📊 Input Data: {len(df)} months of temperature anomalies")
print(
f" Date range: {df['date'].min().strftime('%Y-%m')} to {df['date'].max().strftime('%Y-%m')}"
)
print(f" Mean anomaly: {df['anomaly_c'].mean():.2f}°C")
print(
f" Trend: {df['anomaly_c'].iloc[-12:].mean() - df['anomaly_c'].iloc[:12].mean():.2f}°C change (first to last year)"
)
# Prepare input for TimesFM
# TimesFM expects a list of 1D numpy arrays
input_series = df["anomaly_c"].values.astype(np.float32)
# Load TimesFM 1.0 (PyTorch)
# NOTE: TimesFM 2.5 PyTorch checkpoint has a file format issue at time of writing.
# The model.safetensors file is not loadable via torch.load().
# Using TimesFM 1.0 PyTorch which works correctly.
print("\n🤖 Loading TimesFM 1.0 (200M) PyTorch...")
import timesfm
hparams = timesfm.TimesFmHparams(horizon_len=12)
checkpoint = timesfm.TimesFmCheckpoint(
huggingface_repo_id="google/timesfm-1.0-200m-pytorch"
)
model = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)
# Forecast
print("\n📈 Running forecast (12 months ahead)...")
forecast_input = [input_series]
frequency_input = [0] # Monthly data
point_forecast, experimental_quantile_forecast = model.forecast(
forecast_input,
freq=frequency_input,
)
print(f" Point forecast shape: {point_forecast.shape}")
print(f" Quantile forecast shape: {experimental_quantile_forecast.shape}")
# Extract results
point = point_forecast[0] # Shape: (horizon,)
quantiles = experimental_quantile_forecast[0] # Shape: (horizon, num_quantiles)
# TimesFM quantiles: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99]
# Index mapping: 0=10%, 1=20%, ..., 4=50% (median), ..., 9=99%
quantile_labels = ["10%", "20%", "30%", "40%", "50%", "60%", "70%", "80%", "90%", "99%"]
# Create forecast dates (2025 monthly)
last_date = df["date"].max()
forecast_dates = pd.date_range(
start=last_date + pd.DateOffset(months=1), periods=12, freq="MS"
)
# Build output DataFrame
output_df = pd.DataFrame(
{
"date": forecast_dates.strftime("%Y-%m-%d"),
"point_forecast": point,
"q10": quantiles[:, 0],
"q20": quantiles[:, 1],
"q30": quantiles[:, 2],
"q40": quantiles[:, 3],
"q50": quantiles[:, 4], # Median
"q60": quantiles[:, 5],
"q70": quantiles[:, 6],
"q80": quantiles[:, 7],
"q90": quantiles[:, 8],
"q99": quantiles[:, 9],
}
)
# Save outputs
output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True)
output_df.to_csv(output_dir / "forecast_output.csv", index=False)
# JSON output for the report
output_json = {
"model": "TimesFM 1.0 (200M) PyTorch",
"input": {
"source": "NOAA GISTEMP Global Temperature Anomaly",
"n_observations": len(df),
"date_range": f"{df['date'].min().strftime('%Y-%m')} to {df['date'].max().strftime('%Y-%m')}",
"mean_anomaly_c": round(df["anomaly_c"].mean(), 3),
},
"forecast": {
"horizon": 12,
"dates": forecast_dates.strftime("%Y-%m").tolist(),
"point": point.tolist(),
"quantiles": {
label: quantiles[:, i].tolist() for i, label in enumerate(quantile_labels)
},
},
"summary": {
"forecast_mean_c": round(float(point.mean()), 3),
"forecast_max_c": round(float(point.max()), 3),
"forecast_min_c": round(float(point.min()), 3),
"vs_last_year_mean": round(
float(point.mean() - df["anomaly_c"].iloc[-12:].mean()), 3
),
},
}
with open(output_dir / "forecast_output.json", "w") as f:
json.dump(output_json, f, indent=2)
# Print summary
print("\n" + "=" * 60)
print(" FORECAST RESULTS")
print("=" * 60)
print(
f"\n📅 Forecast period: {forecast_dates[0].strftime('%Y-%m')} to {forecast_dates[-1].strftime('%Y-%m')}"
)
print(f"\n🌡️ Temperature Anomaly Forecast (°C above 1951-1980 baseline):")
print(f"\n {'Month':<10} {'Point':>8} {'80% CI':>15} {'90% CI':>15}")
print(f" {'-' * 10} {'-' * 8} {'-' * 15} {'-' * 15}")
for i, (date, pt, q10, q90, q05, q95) in enumerate(
zip(
forecast_dates.strftime("%Y-%m"),
point,
quantiles[:, 1], # 20%
quantiles[:, 7], # 80%
quantiles[:, 0], # 10%
quantiles[:, 8], # 90%
)
):
print(
f" {date:<10} {pt:>8.3f} [{q10:>6.3f}, {q90:>6.3f}] [{q05:>6.3f}, {q95:>6.3f}]"
)
print(f"\n📊 Summary Statistics:")
print(f" Mean forecast: {point.mean():.3f}°C")
print(
f" Max forecast: {point.max():.3f}°C (Month: {forecast_dates[point.argmax()].strftime('%Y-%m')})"
)
print(
f" Min forecast: {point.min():.3f}°C (Month: {forecast_dates[point.argmin()].strftime('%Y-%m')})"
)
print(f" vs 2024 mean: {point.mean() - df['anomaly_c'].iloc[-12:].mean():+.3f}°C")
print(f"\n✅ Output saved to:")
print(f" {output_dir / 'forecast_output.csv'}")
print(f" {output_dir / 'forecast_output.json'}")
================================================
FILE: timesfm-forecasting/examples/global-temperature/temperature_anomaly.csv
================================================
date,anomaly_c
2022-01-01,0.89
2022-02-01,0.89
2022-03-01,1.02
2022-04-01,0.88
2022-05-01,0.85
2022-06-01,0.88
2022-07-01,0.88
2022-08-01,0.90
2022-09-01,0.88
2022-10-01,0.95
2022-11-01,0.77
2022-12-01,0.78
2023-01-01,0.87
2023-02-01,0.98
2023-03-01,1.21
2023-04-01,1.00
2023-05-01,0.94
2023-06-01,1.08
2023-07-01,1.18
2023-08-01,1.24
2023-09-01,1.47
2023-10-01,1.32
2023-11-01,1.18
2023-12-01,1.16
2024-01-01,1.22
2024-02-01,1.35
2024-03-01,1.34
2024-04-01,1.26
2024-05-01,1.15
2024-06-01,1.20
2024-07-01,1.24
2024-08-01,1.30
2024-09-01,1.28
2024-10-01,1.27
2024-11-01,1.22
2024-12-01,1.20
================================================
FILE: timesfm-forecasting/examples/global-temperature/visualize_forecast.py
================================================
#!/usr/bin/env python3
"""
Visualize TimesFM forecast results for global temperature anomaly.
Generates a publication-quality figure showing:
- Historical data (2022-2024)
- Point forecast (2025)
- 80% and 90% confidence intervals (fan chart)
Usage:
python visualize_forecast.py
"""
from __future__ import annotations
import json
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Configuration
EXAMPLE_DIR = Path(__file__).parent
INPUT_FILE = EXAMPLE_DIR / "temperature_anomaly.csv"
FORECAST_FILE = EXAMPLE_DIR / "output" / "forecast_output.json"
OUTPUT_FILE = EXAMPLE_DIR / "output" / "forecast_visualization.png"
def main() -> None:
# Load historical data
df = pd.read_csv(INPUT_FILE, parse_dates=["date"])
# Load forecast results
with open(FORECAST_FILE) as f:
forecast = json.load(f)
# Extract forecast data
dates = pd.to_datetime(forecast["forecast"]["dates"])
point = np.array(forecast["forecast"]["point"])
q10 = np.array(forecast["forecast"]["quantiles"]["10%"])
q20 = np.array(forecast["forecast"]["quantiles"]["20%"])
q80 = np.array(forecast["forecast"]["quantiles"]["80%"])
q90 = np.array(forecast["forecast"]["quantiles"]["90%"])
# Create figure
fig, ax = plt.subplots(figsize=(12, 6))
# Plot historical data
ax.plot(
df["date"],
df["anomaly_c"],
color="#2563eb",
linewidth=1.5,
marker="o",
markersize=3,
label="Historical (NOAA GISTEMP)",
)
# Plot 90% CI (outer band)
ax.fill_between(dates, q10, q90, alpha=0.2, color="#dc2626", label="90% CI")
# Plot 80% CI (inner band)
ax.fill_between(dates, q20, q80, alpha=0.3, color="#dc2626", label="80% CI")
# Plot point forecast
ax.plot(
dates,
point,
color="#dc2626",
linewidth=2,
marker="s",
markersize=4,
label="TimesFM Forecast",
)
# Add vertical line at forecast boundary
ax.axvline(
x=df["date"].max(), color="#6b7280", linestyle="--", linewidth=1, alpha=0.7
)
# Formatting
ax.set_xlabel("Date", fontsize=12)
ax.set_ylabel("Temperature Anomaly (°C)", fontsize=12)
ax.set_title(
"TimesFM Zero-Shot Forecast Example\n36-month Temperature Anomaly → 12-month Forecast",
fontsize=14,
fontweight="bold",
)
# Add annotations
ax.annotate(
f"Mean forecast: {forecast['summary']['forecast_mean_c']:.2f}°C\n"
f"vs 2024: {forecast['summary']['vs_last_year_mean']:+.2f}°C",
xy=(dates[6], point[6]),
xytext=(dates[6], point[6] + 0.15),
fontsize=10,
arrowprops=dict(arrowstyle="->", color="#6b7280", lw=1),
bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="#6b7280"),
)
# Grid and legend
ax.grid(True, alpha=0.3)
ax.legend(loc="upper left", fontsize=10)
# Set y-axis limits
ax.set_ylim(0.7, 1.5)
# Rotate x-axis labels
plt.xticks(rotation=45, ha="right")
# Tight layout
plt.tight_layout()
# Save
fig.savefig(OUTPUT_FILE, dpi=150, bbox_inches="tight")
print(f"✅ Saved visualization to: {OUTPUT_FILE}")
plt.close()
if __name__ == "__main__":
main()
================================================
FILE: timesfm-forecasting/references/api_reference.md
================================================
# TimesFM API Reference
## Model Classes
### `timesfm.TimesFM_2p5_200M_torch`
The primary model class for TimesFM 2.5 (200M parameters, PyTorch backend).
#### `from_pretrained()`
```python
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
"google/timesfm-2.5-200m-pytorch",
cache_dir=None, # Optional: custom cache directory
force_download=True, # Re-download even if cached
)
```
| Parameter | Type | Default | Description |
| --------- | ---- | ------- | ----------- |
| `model_id` | str | `"google/timesfm-2.5-200m-pytorch"` | Hugging Face model ID |
| `revision` | str \| None | None | Specific model revision |
| `cache_dir` | str \| Path \| None | None | Custom cache directory |
| `force_download` | bool | True | Force re-download of weights |
**Returns**: Initialized `TimesFM_2p5_200M_torch` instance (not yet compiled).
#### `compile()`
Compiles the model with the given forecast configuration. **Must be called before `forecast()`.**
```python
model.compile(
timesfm.ForecastConfig(
max_context=1024,
max_horizon=256,
normalize_inputs=True,
per_core_batch_size=32,
use_continuous_quantile_head=True,
force_flip_invariance=True,
infer_is_positive=True,
fix_quantile_crossing=True,
)
)
```
**Raises**: Nothing (but `forecast()` will raise `RuntimeError` if not compiled).
#### `forecast()`
Run inference on one or more time series.
```python
point_forecast, quantile_forecast = model.forecast(
horizon=24,
inputs=[array1, array2, ...],
)
```
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `horizon` | int | Number of future steps to forecast |
| `inputs` | list[np.ndarray] | List of 1-D numpy arrays (each is a time series) |
**Returns**: `tuple[np.ndarray, np.ndarray]`
- `point_forecast`: shape `(batch_size, horizon)` — median (0.5 quantile)
- `quantile_forecast`: shape `(batch_size, horizon, 10)` — [mean, q10, q20, ..., q90]
**Raises**: `RuntimeError` if model is not compiled.
**Key behaviors**:
- Leading NaN values are stripped automatically
- Internal NaN values are linearly interpolated
- Series longer than `max_context` are truncated (last `max_context` points used)
- Series shorter than `max_context` are padded
#### `forecast_with_covariates()`
Run inference with exogenous variables (requires `timesfm[xreg]`).
```python
point, quantiles = model.forecast_with_covariates(
inputs=inputs,
dynamic_numerical_covariates={"temp": [temp_array1, temp_array2]},
dynamic_categorical_covariates={"dow": [dow_array1, dow_array2]},
static_categorical_covariates={"region": ["east", "west"]},
xreg_mode="xreg + timesfm",
)
```
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `inputs` | list[np.ndarray] | Target time series |
| `dynamic_numerical_covariates` | dict[str, list[np.ndarray]] | Time-varying numeric features |
| `dynamic_categorical_covariates` | dict[str, list[np.ndarray]] | Time-varying categorical features |
| `static_categorical_covariates` | dict[str, list[str]] | Fixed categorical features per series |
| `xreg_mode` | str | `"xreg + timesfm"` or `"timesfm + xreg"` |
**Note**: Dynamic covariates must have length `context + horizon` for each series.
---
## `timesfm.ForecastConfig`
Immutable dataclass controlling all forecast behavior.
```python
@dataclasses.dataclass(frozen=True)
class ForecastConfig:
max_context: int = 0
max_horizon: int = 0
normalize_inputs: bool = False
per_core_batch_size: int = 1
use_continuous_quantile_head: bool = False
force_flip_invariance: bool = True
infer_is_positive: bool = True
fix_quantile_crossing: bool = False
return_backcast: bool = False
quantiles: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
decode_index: int = 5
```
### Parameter Details
#### `max_context` (int, default=0)
Maximum number of historical time points to use as context.
- **0**: Use the model's maximum supported context (16,384 for v2.5)
- **N**: Truncate series to last N points
- **Best practice**: Set to the length of your longest series, or 512–2048 for speed
#### `max_horizon` (int, default=0)
Maximum forecast horizon.
- **0**: Use the model's maximum
- **N**: Forecasts up to N steps (can still call `forecast(horizon=M)` where M ≤ N)
- **Best practice**: Set to your expected maximum forecast length
#### `normalize_inputs` (bool, default=False)
Whether to z-normalize each series before feeding to the model.
- **True** (RECOMMENDED): Normalizes each series to zero mean, unit variance
- **False**: Raw values are passed directly
- **When False is OK**: Only if your series are already normalized or very close to scale 1.0
#### `per_core_batch_size` (int, default=1)
Number of series processed per device in each batch.
- Increase for throughput, decrease if OOM
- See `references/system_requirements.md` for recommended values by hardware
#### `use_continuous_quantile_head` (bool, default=False)
Use the 30M-parameter continuous quantile head for better interval calibration.
- **True** (RECOMMENDED): More accurate prediction intervals, especially for longer horizons
- **False**: Uses fixed quantile buckets (faster but less accurate intervals)
#### `force_flip_invariance` (bool, default=True)
Ensures the model satisfies `f(-x) = -f(x)`.
- **True** (RECOMMENDED): Mathematical consistency — forecasts are invariant to sign flip
- **False**: Slightly faster but may produce asymmetric forecasts
#### `infer_is_positive` (bool, default=True)
Automatically detect if all input values are positive and clamp forecasts ≥ 0.
- **True**: Safe for sales, demand, counts, prices, volumes
- **False**: Required for temperature, returns, PnL, any series that can be negative
#### `fix_quantile_crossing` (bool, default=False)
Post-process quantiles to ensure monotonicity (q10 ≤ q20 ≤ ... ≤ q90).
- **True** (RECOMMENDED): Guarantees well-ordered quantiles
- **False**: Slightly faster but quantiles may occasionally cross
#### `return_backcast` (bool, default=False)
Return the model's reconstruction of the input (backcast) in addition to forecast.
- **True**: Used for covariate workflows and diagnostics
- **False**: Only return forecast
---
## Available Model Checkpoints
| Model ID | Version | Params | Backend | Context |
| -------- | ------- | ------ | ------- | ------- |
| `google/timesfm-2.5-200m-pytorch` | 2.5 | 200M | PyTorch | 16,384 |
| `google/timesfm-2.5-200m-flax` | 2.5 | 200M | JAX/Flax | 16,384 |
| `google/timesfm-2.5-200m-transformers` | 2.5 | 200M | Transformers | 16,384 |
| `google/timesfm-2.0-500m-pytorch` | 2.0 | 500M | PyTorch | 2,048 |
| `google/timesfm-2.0-500m-jax` | 2.0 | 500M | JAX | 2,048 |
| `google/timesfm-1.0-200m-pytorch` | 1.0 | 200M | PyTorch | 2,048 |
| `google/timesfm-1.0-200m` | 1.0 | 200M | JAX | 2,048 |
---
## Output Shape Reference
| Output | Shape | Description |
| ------ | ----- | ----------- |
| `point_forecast` | `(B, H)` | Median forecast for B series, H steps |
| `quantile_forecast` | `(B, H, 10)` | Full quantile distribution |
| `quantile_forecast[:,:,0]` | `(B, H)` | Mean |
| `quantile_forecast[:,:,1]` | `(B, H)` | 10th percentile |
| `quantile_forecast[:,:,5]` | `(B, H)` | 50th percentile (= point_forecast) |
| `quantile_forecast[:,:,9]` | `(B, H)` | 90th percentile |
Where `B` = batch size (number of input series), `H` = forecast horizon.
---
---
## Memory Estimation
Before running forecasts on large datasets, estimate memory requirements:
### Formula
```mermaid
block-beta
columns 3
ram["Total RAM Required"] model["Model Weights ~0.8 GB"] overhead["Runtime Overhead ~0.5 GB"] buffers["I/O Buffers ~0.2 MB per 1000 series per 1000 context"]
ram --> model
ram --> overhead
ram --> buffers
```
**Formula**:
`RAM (GB) ≈ 0.8 + 0.5 + (0.0002 × num_series × context_length)`
**Variables**:
- `num_series`: Number of time series in your batch
- `context_length`: Your `max_context` value (or max series length)
- `batch_size`: Your `per_core_batch_size` (affects parallel processing overhead)
### Quick Reference
| Dataset Size | Context=512 | Context=1024 | Context=2048 |
|--------------|-------------|--------------|--------------|
| 100 series | ~1.4 GB | ~1.5 GB | ~1.7 GB |
| 1,000 series | ~1.9 GB | ~2.3 GB | ~3.1 GB |
| 10,000 series| ~9.0 GB | ~17.0 GB | ~33.0 GB |
### Using the Preflight Checker
```bash
python scripts/check_system.py \
--num-series 1000 \
--context-length 1024 \
--batch-size 32
```
This validates both system requirements AND dataset fit before loading the model.
### Reducing Memory Usage
If your dataset is too large:
1. **Reduce context length**: Use `max_context=512` instead of 1024+ (50% reduction)
2. **Process in chunks**: Split large batches into smaller groups:
```python
CHUNK_SIZE = 100
for i in range(0, len(inputs), CHUNK_SIZE):
chunk = inputs[i:i+CHUNK_SIZE]
point, quantiles = model.forecast(horizon=H, inputs=chunk)
# Save chunk results
```
3. **Reduce batch size**: Lower `per_core_batch_size` (slower but less memory)
4. **Use CPU**: If GPU OOM, the model will automatically fall back to CPU
## Error Handling
| Error | Cause | Fix |
| ----- | ----- | --- |
| `RuntimeError: Model is not compiled` | Called `forecast()` before `compile()` | Call `model.compile(ForecastConfig(...))` first |
| `torch.cuda.OutOfMemoryError` | Batch too large for GPU | Reduce `per_core_batch_size` |
| `ValueError: inputs must be list` | Passed array instead of list | Wrap in list: `[array]` |
| `HfHubHTTPError` | Download failed | Check internet, set `HF_HOME` to writable dir |
================================================
FILE: timesfm-forecasting/references/data_preparation.md
================================================
# Data Preparation for TimesFM
## Input Format
TimesFM accepts a **list of 1-D numpy arrays**. Each array represents one
univariate time series.
```python
inputs = [
np.array([1.0, 2.0, 3.0, 4.0, 5.0]), # Series 1
np.array([10.0, 20.0, 15.0, 25.0]), # Series 2 (different length)
np.array([100.0, 110.0, 105.0, 115.0, 120.0, 130.0]), # Series 3
]
```
### Key Properties
- **Variable lengths**: Series in the same batch can have different lengths
- **Float values**: Use `np.float32` or `np.float64`
- **1-D only**: Each array must be 1-dimensional (not 2-D matrix rows)
- **NaN handling**: Leading NaNs are stripped; internal NaNs are linearly interpolated
## Loading from Common Formats
### CSV — Single Series (Long Format)
```python
import pandas as pd
import numpy as np
df = pd.read_csv("data.csv", parse_dates=["date"])
values = df["value"].values.astype(np.float32)
inputs = [values]
```
### CSV — Multiple Series (Wide Format)
```python
df = pd.read_csv("data.csv", parse_dates=["date"], index_col="date")
inputs = [df[col].dropna().values.astype(np.float32) for col in df.columns]
```
### CSV — Long Format with ID Column
```python
df = pd.read_csv("data.csv", parse_dates=["date"])
inputs = []
for series_id, group in df.groupby("series_id"):
values = group.sort_values("date")["value"].values.astype(np.float32)
inputs.append(values)
```
### Pandas DataFrame
```python
# Single column
inputs = [df["temperature"].values.astype(np.float32)]
# Multiple columns
inputs = [df[col].dropna().values.astype(np.float32) for col in numeric_cols]
```
### Numpy Arrays
```python
# 2-D array (rows = series, cols = time steps)
data = np.load("timeseries.npy") # shape (N, T)
inputs = [data[i] for i in range(data.shape[0])]
# Or from 1-D
inputs = [np.sin(np.linspace(0, 10, 200))]
```
### Excel
```python
df = pd.read_excel("data.xlsx", sheet_name="Sheet1")
inputs = [df[col].dropna().values.astype(np.float32) for col in df.select_dtypes(include=[np.number]).columns]
```
### Parquet
```python
df = pd.read_parquet("data.parquet")
inputs = [df[col].dropna().values.astype(np.float32) for col in df.select_dtypes(include=[np.number]).columns]
```
### JSON
```python
import json
with open("data.json") as f:
data = json.load(f)
# Assumes {"series_name": [values...], ...}
inputs = [np.array(values, dtype=np.float32) for values in data.values()]
```
## NaN Handling
TimesFM handles NaN values automatically:
### Leading NaNs
Stripped before feeding to the model:
```python
# Input: [NaN, NaN, 1.0, 2.0, 3.0]
# Actual: [1.0, 2.0, 3.0]
```
### Internal NaNs
Linearly interpolated:
```python
# Input: [1.0, NaN, 3.0, NaN, NaN, 6.0]
# Actual: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```
### Trailing NaNs
**Not handled** — drop them before passing to the model:
```python
values = df["value"].values.astype(np.float32)
# Remove trailing NaNs
while len(values) > 0 and np.isnan(values[-1]):
values = values[:-1]
inputs = [values]
```
### Best Practice
```python
def clean_series(arr: np.ndarray) -> np.ndarray:
"""Clean a time series for TimesFM input."""
arr = np.asarray(arr, dtype=np.float32)
# Remove trailing NaNs
while len(arr) > 0 and np.isnan(arr[-1]):
arr = arr[:-1]
# Replace inf with NaN (will be interpolated)
arr[np.isinf(arr)] = np.nan
return arr
inputs = [clean_series(df[col].values) for col in cols]
```
## Context Length Considerations
| Context Length | Use Case | Notes |
| -------------- | -------- | ----- |
| 64–256 | Quick prototyping | Minimal context, fast |
| 256–512 | Daily data, ~1 year | Good balance |
| 512–1024 | Daily data, ~2-3 years | Standard production |
| 1024–4096 | Hourly data, weekly patterns | More context = better |
| 4096–16384 | High-frequency, long patterns | TimesFM 2.5 maximum |
**Rule of thumb**: Provide at least 3–5 full cycles of the dominant pattern
(e.g., for weekly seasonality with daily data, provide at least 21–35 days).
## Covariates (XReg)
TimesFM 2.5 supports exogenous variables through the `forecast_with_covariates()` API.
### Types of Covariates
| Type | Description | Example |
| ---- | ----------- | ------- |
| **Dynamic numerical** | Time-varying numeric features | Temperature, price, promotion spend |
| **Dynamic categorical** | Time-varying categorical features | Day of week, holiday flag |
| **Static categorical** | Fixed per-series features | Store ID, region, product category |
### Preparing Covariates
Each covariate must have length `context + horizon` for each series:
```python
import numpy as np
context_len = 100 # length of historical data
horizon = 24 # forecast horizon
total_len = context_len + horizon
# Dynamic numerical: temperature forecast for each series
temp = [
np.random.randn(total_len).astype(np.float32), # Series 1
np.random.randn(total_len).astype(np.float32), # Series 2
]
# Dynamic categorical: day of week (0-6) for each series
dow = [
np.tile(np.arange(7), total_len // 7 + 1)[:total_len], # Series 1
np.tile(np.arange(7), total_len // 7 + 1)[:total_len], # Series 2
]
# Static categorical: one label per series
regions = ["east", "west"]
# Forecast with covariates
point, quantiles = model.forecast_with_covariates(
inputs=[values1, values2],
dynamic_numerical_covariates={"temperature": temp},
dynamic_categorical_covariates={"day_of_week": dow},
static_categorical_covariates={"region": regions},
xreg_mode="xreg + timesfm",
)
```
### XReg Modes
| Mode | Description |
| ---- | ----------- |
| `"xreg + timesfm"` | Covariates processed first, then combined with TimesFM forecast |
| `"timesfm + xreg"` | TimesFM forecast first, then adjusted by covariates |
## Common Data Issues
### Issue: Series too short
TimesFM needs at least 1 data point, but more context = better forecasts.
```python
MIN_LENGTH = 32 # Practical minimum for meaningful forecasts
inputs = [
arr for arr in raw_inputs
if len(arr[~np.isnan(arr)]) >= MIN_LENGTH
]
```
### Issue: Series with constant values
Constant series may produce NaN or zero-width prediction intervals:
```python
for i, arr in enumerate(inputs):
if np.std(arr[~np.isnan(arr)]) < 1e-10:
print(f"⚠️ Series {i} is constant — forecast will be flat")
```
### Issue: Extreme outliers
Large outliers can destabilize forecasts even with normalization:
```python
def clip_outliers(arr: np.ndarray, n_sigma: float = 5.0) -> np.ndarray:
"""Clip values beyond n_sigma standard deviations."""
mu = np.nanmean(arr)
sigma = np.nanstd(arr)
if sigma > 0:
arr = np.clip(arr, mu - n_sigma * sigma, mu + n_sigma * sigma)
return arr
```
### Issue: Mixed frequencies in batch
TimesFM handles each series independently, so you can mix frequencies:
```python
inputs = [
daily_sales, # 365 points
weekly_revenue, # 52 points
monthly_users, # 24 points
]
# All forecasted in one batch — TimesFM handles different lengths
point, q = model.forecast(horizon=12, inputs=inputs)
```
However, the `horizon` is shared. If you need different horizons per series,
forecast in separate calls.
================================================
FILE: timesfm-forecasting/references/system_requirements.md
================================================
# System Requirements for TimesFM
## Hardware Tiers
TimesFM can run on a variety of hardware configurations. This guide helps you
choose the right setup and tune performance for your machine.
### How Context Limits Are Determined
The `max_context` values in each tier are **conservative recommendations** based on memory-performance tradeoffs, not hard limits. TimesFM 2.5 supports up to 16,384 context points, but smaller values are recommended for most use cases.
**Why 512 and 1024?**
| Factor | 512 Context | 1024 Context |
|--------|-------------|--------------|
| **Memory per 1000 series** | ~100 MB | ~200 MB |
| **Typical Use Case** | Daily data, ~1-2 years | Daily data, ~2-3 years |
| **Inference Speed** | Faster | Moderate |
| **Hardware** | 4-8 GB RAM | 16 GB RAM or GPU |
**Memory Formula**: `RAM ≈ model_weights + 0.5 GB + (0.2 MB × num_series × context_length / 1000)`
Where:
- `model_weights` = ~800 MB (TimesFM 2.5)
- `context_length` = your `max_context` value
- `num_series` = number of time series in your batch
**You can use larger contexts** if your hardware supports it:
- **Up to 2048**: Requires ~16 GB RAM for moderate batch sizes
- **Up to 4096**: Requires GPU or 32+ GB RAM
- **Up to 16384**: Maximum supported, requires significant memory
See [Data Preparation Guide](data_preparation.md) for context length recommendations by data frequency.
### Tier 1: Minimal (CPU-Only, 4–8 GB RAM)
- **Use case**: Light exploration, single-series forecasting, prototyping
- **Model**: TimesFM 2.5 (200M) only
- **Batch size**: `per_core_batch_size=4`
- **Context**: Limit `max_context=512`
- **Expected speed**: ~2–5 seconds per 100-point series
```python
model.compile(timesfm.ForecastConfig(
max_context=512,
max_horizon=128,
per_core_batch_size=4,
normalize_inputs=True,
use_continuous_quantile_head=True,
fix_quantile_crossing=True,
))
```
### Tier 2: Standard (CPU 16 GB or GPU 4–8 GB VRAM)
- **Use case**: Batch forecasting (dozens of series), evaluation, production prototypes
- **Model**: TimesFM 2.5 (200M)
- **Batch size**: `per_core_batch_size=32` (CPU) or `64` (GPU)
- **Context**: `max_context=1024`
- **Expected speed**: ~0.5–1 second per 100-point series (GPU)
```python
model.compile(timesfm.ForecastConfig(
max_context=1024,
max_horizon=256,
per_core_batch_size=64,
normalize_inputs=True,
use_continuous_quantile_head=True,
fix_quantile_crossing=True,
))
```
### Tier 3: Production (GPU 16+ GB VRAM or Apple Silicon 32+ GB)
- **Use case**: Large-scale batch forecasting (thousands of series), long context
- **Model**: TimesFM 2.5 (200M)
- **Batch size**: `per_core_batch_size=128–256`
- **Context**: `max_context=4096` or higher
- **Expected speed**: ~0.1–0.3 seconds per 100-point series
```python
model.compile(timesfm.ForecastConfig(
max_context=4096,
max_horizon=256,
per_core_batch_size=128,
normalize_inputs=True,
use_continuous_quantile_head=True,
fix_quantile_crossing=True,
))
```
### Tier 4: Legacy Models (v1.0/v2.0 — 500M parameters)
- **⚠️ WARNING**: TimesFM v2.0 (500M) requires **≥ 16 GB RAM** (CPU) or **≥ 8 GB VRAM** (GPU)
- **⚠️ WARNING**: TimesFM v1.0 legacy JAX version may require **≥ 32 GB RAM**
- **Recommendation**: Unless you specifically need a legacy checkpoint, use TimesFM 2.5
## Memory Estimation
### CPU Memory (RAM)
Approximate RAM usage during inference:
| Component | TimesFM 2.5 (200M) | TimesFM 2.0 (500M) |
| --------- | ------------------- | ------------------- |
| Model weights | ~800 MB | ~2 GB |
| Runtime overhead | ~500 MB | ~1 GB |
| Input/output buffers | ~200 MB per 1000 series | ~500 MB per 1000 series |
| **Total (small batch)** | **~1.5 GB** | **~3.5 GB** |
| **Total (large batch)** | **~3 GB** | **~6 GB** |
**Formula**: `RAM ≈ model_weights + 0.5 GB + (0.2 MB × num_series × context_length / 1000)`
### GPU Memory (VRAM)
| Component | TimesFM 2.5 (200M) |
| --------- | ------------------- |
| Model weights | ~800 MB |
| KV cache + activations | ~200–500 MB (scales with context) |
| Batch buffers | ~100 MB per 100 series at context=1024 |
| **Total (batch=32)** | **~1.2 GB** |
| **Total (batch=128)** | **~1.8 GB** |
| **Total (batch=256)** | **~2.5 GB** |
### Disk Space
| Item | Size |
| ---- | ---- |
| TimesFM 2.5 safetensors | ~800 MB |
| Hugging Face cache overhead | ~200 MB |
| **Total download** | **~1 GB** |
Model weights are downloaded once from Hugging Face Hub and cached in
`~/.cache/huggingface/` (or `$HF_HOME`).
## GPU Selection Guide
### NVIDIA GPUs (CUDA)
| GPU | VRAM | Recommended batch | Notes |
| --- | ---- | ----------------- | ----- |
| RTX 3060 | 12 GB | 64 | Good entry-level |
| RTX 3090 / 4090 | 24 GB | 256 | Excellent for production |
| A100 (40 GB) | 40 GB | 512 | Cloud/HPC |
| A100 (80 GB) | 80 GB | 1024 | Cloud/HPC |
| T4 | 16 GB | 128 | Cloud (Colab, AWS) |
| V100 | 16–32 GB | 128–256 | Cloud |
### Apple Silicon (MPS)
| Chip | Unified Memory | Recommended batch | Notes |
| ---- | -------------- | ----------------- | ----- |
| M1 | 8–16 GB | 16–32 | Works, slower than CUDA |
| M1 Pro/Max | 16–64 GB | 32–128 | Good performance |
| M2/M3/M4 Pro/Max | 18–128 GB | 64–256 | Excellent |
### CPU Only
Works on any CPU with sufficient RAM. Expect 5–20× slower than GPU.
## Python and Package Requirements
| Requirement | Minimum | Recommended |
| ----------- | ------- | ----------- |
| Python | 3.10 | 3.12+ |
| numpy | 1.26.4 | latest |
| torch | 2.0.0 | latest |
| huggingface_hub | 0.23.0 | latest |
| safetensors | 0.5.3 | latest |
### Optional Dependencies
| Package | Purpose | Install |
| ------- | ------- | ------- |
| jax | Flax backend | `pip install jax[cuda]` |
| flax | Flax backend | `pip install flax` |
| scikit-learn | XReg covariates | `pip install scikit-learn` |
## Operating System Compatibility
| OS | Status | Notes |
| -- | ------ | ----- |
| Linux (Ubuntu 20.04+) | ✅ Fully supported | Best performance with CUDA |
| macOS 13+ (Ventura) | ✅ Fully supported | MPS acceleration on Apple Silicon |
| Windows 11 + WSL2 | ✅ Supported | Use WSL2 for best experience |
| Windows (native) | ⚠️ Partial | PyTorch works, some edge cases |
## Troubleshooting
### Out of Memory (OOM)
```python
# Reduce batch size
model.compile(timesfm.ForecastConfig(
per_core_batch_size=4, # Start very small
max_context=512, # Reduce context
...
))
# Process in chunks
for i in range(0, len(inputs), 50):
chunk = inputs[i:i+50]
p, q = model.forecast(horizon=H, inputs=chunk)
```
### Slow Inference on CPU
```python
# Ensure matmul precision is set
import torch
torch.set_float32_matmul_precision("high")
# Use smaller context
model.compile(timesfm.ForecastConfig(
max_context=256, # Shorter context = faster
...
))
```
### Model Download Fails
```bash
# Set a different cache directory
export HF_HOME=/path/with/more/space
# Or download manually
huggingface-cli download google/timesfm-2.5-200m-pytorch
```
================================================
FILE: timesfm-forecasting/scripts/check_system.py
================================================
#!/usr/bin/env python3
"""TimesFM System Requirements Preflight Checker.
MANDATORY: Run this script before loading TimesFM for the first time.
It checks RAM, GPU/VRAM, disk space, Python version, and package
installation so the agent never crashes a user's machine.
Usage:
python check_system.py
python check_system.py --model v2.5 # default
python check_system.py --model v2.0 # archived 500M model
python check_system.py --model v1.0 # archived 200M model
python check_system.py --json # machine-readable output
"""
from __future__ import annotations
import argparse
import json
import os
import platform
import shutil
import struct
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import math
# ---------------------------------------------------------------------------
# Model requirement profiles
# ---------------------------------------------------------------------------
MODEL_PROFILES: dict[str, dict[str, Any]] = {
"v2.5": {
"name": "TimesFM 2.5 (200M)",
"params": "200M",
"min_ram_gb": 2.0,
"recommended_ram_gb": 4.0,
"min_vram_gb": 2.0,
"recommended_vram_gb": 4.0,
"disk_gb": 2.0, # model weights + overhead
"hf_repo": "google/timesfm-2.5-200m-pytorch",
},
"v2.0": {
"name": "TimesFM 2.0 (500M)",
"params": "500M",
"min_ram_gb": 8.0,
"recommended_ram_gb": 16.0,
"min_vram_gb": 4.0,
"recommended_vram_gb": 8.0,
"disk_gb": 4.0,
"hf_repo": "google/timesfm-2.0-500m-pytorch",
},
"v1.0": {
"name": "TimesFM 1.0 (200M)",
"params": "200M",
"min_ram_gb": 4.0,
"recommended_ram_gb": 8.0,
"min_vram_gb": 2.0,
"recommended_vram_gb": 4.0,
"disk_gb": 2.0,
"hf_repo": "google/timesfm-1.0-200m-pytorch",
},
}
# ---------------------------------------------------------------------------
# Result dataclass
# ---------------------------------------------------------------------------
@dataclass
class CheckResult:
name: str
status: str # "pass", "warn", "fail"
detail: str
value: str = ""
@property
def icon(self) -> str:
return {"pass": "✅", "warn": "⚠️", "fail": "🛑"}.get(self.status, "❓")
def __str__(self) -> str:
return f"[{self.name:<10}] {self.value:<40} {self.icon} {self.status.upper()}"
@dataclass
class SystemReport:
model: str
checks: list[CheckResult] = field(default_factory=list)
verdict: str = ""
verdict_detail: str = ""
recommended_batch_size: int = 1
mode: str = "cpu" # "cpu", "gpu", "mps"
@property
def passed(self) -> bool:
return all(c.status != "fail" for c in self.checks)
def to_dict(self) -> dict[str, Any]:
return {
"model": self.model,
"passed": self.passed,
"mode": self.mode,
"recommended_batch_size": self.recommended_batch_size,
"verdict": self.verdict,
"verdict_detail": self.verdict_detail,
"checks": [
{
"name": c.name,
"status": c.status,
"detail": c.detail,
"value": c.value,
}
for c in self.checks
],
}
# ---------------------------------------------------------------------------
# Individual checks
# ---------------------------------------------------------------------------
def _get_total_ram_gb() -> float:
"""Return total physical RAM in GB, cross-platform."""
try:
if sys.platform == "linux":
with open("/proc/meminfo") as f:
for line in f:
if line.startswith("MemTotal"):
return int(line.split()[1]) / (1024 * 1024)
elif sys.platform == "darwin":
import subprocess
result = subprocess.run(
["sysctl", "-n", "hw.memsize"],
capture_output=True,
text=True,
check=True,
)
return int(result.stdout.strip()) / (1024**3)
elif sys.platform == "win32":
import ctypes
kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined]
class MEMORYSTATUSEX(ctypes.Structure):
_fields_ = [
("dwLength", ctypes.c_ulong),
("dwMemoryLoad", ctypes.c_ulong),
("ullTotalPhys", ctypes.c_ulonglong),
("ullAvailPhys", ctypes.c_ulonglong),
("ullTotalPageFile", ctypes.c_ulonglong),
("ullAvailPageFile", ctypes.c_ulonglong),
("ullTotalVirtual", ctypes.c_ulonglong),
("ullAvailVirtual", ctypes.c_ulonglong),
("sullAvailExtendedVirtual", ctypes.c_ulonglong),
]
stat = MEMORYSTATUSEX()
stat.dwLength = ctypes.sizeof(stat)
kernel32.GlobalMemoryStatusEx(ctypes.byref(stat))
return stat.ullTotalPhys / (1024**3)
except Exception:
pass
# Fallback: use struct to estimate (unreliable)
return struct.calcsize("P") * 8 / 8 # placeholder
def _get_available_ram_gb() -> float:
"""Return available RAM in GB."""
try:
if sys.platform == "linux":
with open("/proc/meminfo") as f:
for line in f:
if line.startswith("MemAvailable"):
return int(line.split()[1]) / (1024 * 1024)
elif sys.platform == "darwin":
import subprocess
# Use vm_stat for available memory on macOS
result = subprocess.run(
["vm_stat"], capture_output=True, text=True, check=True
)
free = 0
page_size = 4096
for line in result.stdout.split("\n"):
if "Pages free" in line or "Pages inactive" in line:
val = line.split(":")[1].strip().rstrip(".")
free += int(val) * page_size
return free / (1024**3)
elif sys.platform == "win32":
import ctypes
kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined]
class MEMORYSTATUSEX(ctypes.Structure):
_fields_ = [
("dwLength", ctypes.c_ulong),
("dwMemoryLoad", ctypes.c_ulong),
("ullTotalPhys", ctypes.c_ulonglong),
("ullAvailPhys", ctypes.c_ulonglong),
("ullTotalPageFile", ctypes.c_ulonglong),
("ullAvailPageFile", ctypes.c_ulonglong),
("ullTotalVirtual", ctypes.c_ulonglong),
("ullAvailVirtual", ctypes.c_ulonglong),
("sullAvailExtendedVirtual", ctypes.c_ulonglong),
]
stat = MEMORYSTATUSEX()
stat.dwLength = ctypes.sizeof(stat)
kernel32.GlobalMemoryStatusEx(ctypes.byref(stat))
return stat.ullAvailPhys / (1024**3)
except Exception:
pass
return 0.0
def check_ram(profile: dict[str, Any]) -> CheckResult:
"""Check if system has enough RAM."""
total = _get_total_ram_gb()
available = _get_available_ram_gb()
min_ram = profile["min_ram_gb"]
rec_ram = profile["recommended_ram_gb"]
value = f"Total: {total:.1f} GB | Available: {available:.1f} GB"
if total < min_ram:
return CheckResult(
name="RAM",
status="fail",
detail=(
f"System has {total:.1f} GB RAM but {profile['name']} requires "
f"at least {min_ram:.0f} GB. The model will likely fail to load "
f"or cause the system to swap heavily and become unresponsive."
),
value=value,
)
elif total < rec_ram:
return CheckResult(
name="RAM",
status="warn",
detail=(
f"System has {total:.1f} GB RAM. {profile['name']} recommends "
f"{rec_ram:.0f} GB. It may work with small batch sizes but could "
f"be tight. Use per_core_batch_size=4 or lower."
),
value=value,
)
else:
return CheckResult(
name="RAM",
status="pass",
detail=f"System has {total:.1f} GB RAM, meets {rec_ram:.0f} GB recommendation.",
value=value,
)
def check_gpu() -> CheckResult:
"""Check GPU availability and VRAM."""
# Try CUDA first
try:
import torch
if torch.cuda.is_available():
name = torch.cuda.get_device_name(0)
vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
return CheckResult(
name="GPU",
status="pass",
detail=f"{name} with {vram:.1f} GB VRAM detected.",
value=f"{name} | VRAM: {vram:.1f} GB",
)
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return CheckResult(
name="GPU",
status="pass",
detail="Apple Silicon MPS backend available. Uses unified memory.",
value="Apple Silicon MPS",
)
else:
return CheckResult(
name="GPU",
status="warn",
detail=(
"No GPU detected. TimesFM will run on CPU (slower but functional). "
"Install CUDA-enabled PyTorch for GPU acceleration."
),
value="None (CPU only)",
)
except ImportError:
return CheckResult(
name="GPU",
status="warn",
detail="PyTorch not installed — cannot check GPU. Install torch first.",
value="Unknown (torch not installed)",
)
def check_disk(profile: dict[str, Any]) -> CheckResult:
"""Check available disk space for model download."""
# Check HuggingFace cache dir or home dir
hf_cache = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
cache_dir = Path(hf_cache)
check_dir = cache_dir if cache_dir.exists() else Path.home()
usage = shutil.disk_usage(str(check_dir))
free_gb = usage.free / (1024**3)
required = profile["disk_gb"]
value = f"Free: {free_gb:.1f} GB (in {check_dir})"
if free_gb < required:
return CheckResult(
name="Disk",
status="fail",
detail=(
f"Only {free_gb:.1f} GB free in {check_dir}. "
f"Need at least {required:.0f} GB for model weights. "
f"Free up space or set HF_HOME to a larger volume."
),
value=value,
)
else:
return CheckResult(
name="Disk",
status="pass",
detail=f"{free_gb:.1f} GB available, exceeds {required:.0f} GB requirement.",
value=value,
)
def check_python() -> CheckResult:
"""Check Python version >= 3.10."""
version = sys.version.split()[0]
major, minor = sys.version_info[:2]
if (major, minor) < (3, 10):
return CheckResult(
name="Python",
status="fail",
detail=f"Python {version} detected. TimesFM requires Python >= 3.10.",
value=version,
)
else:
return CheckResult(
name="Python",
status="pass",
detail=f"Python {version} meets >= 3.10 requirement.",
value=version,
)
def check_package(pkg_name: str, import_name: str | None = None) -> CheckResult:
"""Check if a Python package is installed."""
import_name = import_name or pkg_name
try:
mod = __import__(import_name)
version = getattr(mod, "__version__", "unknown")
return CheckResult(
name=pkg_name,
status="pass",
detail=f"{pkg_name} {version} is installed.",
value=f"Installed ({version})",
)
except ImportError:
return CheckResult(
name=pkg_name,
status="warn",
detail=f"{pkg_name} is not installed. Run: uv pip install {pkg_name}",
value="Not installed",
)
# ---------------------------------------------------------------------------
# Batch size recommendation
# ---------------------------------------------------------------------------
def recommend_batch_size(report: SystemReport) -> int:
"""Recommend per_core_batch_size based on available resources."""
total_ram = _get_total_ram_gb()
# Check if GPU is available
gpu_check = next((c for c in report.checks if c.name == "GPU"), None)
if gpu_check and gpu_check.status == "pass" and "VRAM" in gpu_check.value:
# Extract VRAM
try:
vram_str = gpu_check.value.split("VRAM:")[1].strip().split()[0]
vram = float(vram_str)
if vram >= 24:
return 256
elif vram >= 16:
return 128
elif vram >= 8:
return 64
elif vram >= 4:
return 32
else:
return 16
except (ValueError, IndexError):
return 32
elif gpu_check and "MPS" in gpu_check.value:
# Apple Silicon — use unified memory heuristic
if total_ram >= 32:
return 64
elif total_ram >= 16:
return 32
else:
return 16
else:
# CPU only
if total_ram >= 32:
return 64
elif total_ram >= 16:
return 32
elif total_ram >= 8:
return 8
else:
return 4
def estimate_memory_gb(
num_series: int,
context_length: int,
horizon: int = 0,
batch_size: int = 32,
model_version: str = "v2.5",
) -> dict[str, float]:
"""Estimate memory requirements for a dataset.
Args:
num_series: Number of time series in the dataset
context_length: Length of each time series context window
horizon: Forecast horizon (optional, for output storage)
batch_size: Batch size for inference
model_version: Model version being used
Returns:
Dictionary with memory estimates in GB for different components
"""
# Base model memory (weights + overhead)
model_memory_gb = 0.8 # ~800MB for model weights
overhead_gb = 0.5 # Python overhead, libraries, etc.
# Input data memory: each value is float32 (4 bytes)
# Formula: num_series * context_length * 4 bytes / (1024^3)
input_gb = (num_series * context_length * 4) / (1024**3)
# Batch processing memory (peak during inference)
# Each batch needs: batch_size * context_length * 4 bytes
batch_input_gb = (batch_size * context_length * 4) / (1024**3)
# Output memory: horizon * num_series * quantiles * 4 bytes
# Default is 10 quantiles (mean + 9 quantiles)
num_quantiles = 10
output_gb = (num_series * horizon * num_quantiles * 4) / (1024**3) if horizon > 0 else 0
# Total memory with some headroom for intermediate computations
total_gb = model_memory_gb + overhead_gb + input_gb + batch_input_gb + output_gb
# Add 20% buffer for intermediate tensors and OS overhead
total_with_buffer = total_gb * 1.2
return {
"model_weights": model_memory_gb,
"overhead": overhead_gb,
"input_data": input_gb,
"batch_processing": batch_input_gb,
"output_data": output_gb,
"total": total_gb,
"total_with_buffer": total_with_buffer,
}
def check_dataset_fit(
num_series: int,
context_length: int,
horizon: int = 0,
batch_size: int = 32,
model_version: str = "v2.5",
) -> tuple[bool, str, dict[str, float]]:
"""Check if a dataset will fit in available memory.
Args:
num_series: Number of time series in the dataset
context_length: Length of each time series context window
horizon: Forecast horizon (optional)
batch_size: Batch size for inference
model_version: Model version being used
Returns:
Tuple of (fits: bool, message: str, memory_details: dict)
"""
memory = estimate_memory_gb(num_series, context_length, horizon, batch_size, model_version)
total_ram = _get_total_ram_gb()
available_ram = _get_available_ram_gb()
required = memory["total_with_buffer"]
# Leave 10% headroom for OS and other processes
usable_ram = total_ram * 0.9
usable_available = available_ram * 0.9 if available_ram > 0 else usable_ram
if required > total_ram:
return (
False,
f"Dataset requires {required:.1f} GB but system only has {total_ram:.1f} GB RAM. "
f"Consider processing in chunks or using a machine with more RAM.",
memory,
)
elif required > usable_available:
return (
False,
f"Dataset requires {required:.1f} GB but only {available_ram:.1f} GB is available. "
f"Close other applications or restart to free memory.",
memory,
)
elif required > usable_ram * 0.8:
return (
True,
f"Dataset will fit ({required:.1f} GB needed, {total_ram:.1f} GB total) "
f"but memory usage will be high. Consider reducing batch_size.",
memory,
)
else:
return (
True,
f"Dataset fits comfortably: {required:.1f} GB needed, {total_ram:.1f} GB available.",
memory,
)
def print_memory_estimate(
num_series: int,
context_length: int,
horizon: int = 0,
batch_size: int = 32,
model_version: str = "v2.5",
) -> None:
"""Print a detailed memory estimate for a dataset.
Args:
num_series: Number of time series in the dataset
context_length: Length of each time series context window
horizon: Forecast horizon (optional)
batch_size: Batch size for inference
model_version: Model version being used
"""
memory = estimate_memory_gb(num_series, context_length, horizon, batch_size, model_version)
total_ram = _get_total_ram_gb()
available_ram = _get_available_ram_gb()
print(f"\n{'=' * 50}")
print(f" Memory Estimate for Dataset")
print(f"{'=' * 50}")
print(f" Dataset: {num_series:,} series × {context_length} context length")
if horizon > 0:
print(f" Horizon: {horizon} steps")
print(f" Batch size: {batch_size}")
print(f" Model: {model_version}")
print(f"{'-' * 50}")
print(f" Model weights: {memory['model_weights']:.2f} GB")
print(f" Overhead: {memory['overhead']:.2f} GB")
print(f" Input data: {memory['input_data']:.2f} GB")
print(f" Batch processing: {memory['batch_processing']:.2f} GB")
if horizon > 0:
print(f" Output data: {memory['output_data']:.2f} GB")
print(f"{'-' * 50}")
print(f" Total (raw): {memory['total']:.2f} GB")
print(f" Total (+20% buf): {memory['total_with_buffer']:.2f} GB")
print(f"{'-' * 50}")
print(f" System RAM: {total_ram:.1f} GB")
print(f" Available RAM: {available_ram:.1f} GB")
print(f"{'=' * 50}")
fits, message, _ = check_dataset_fit(
num_series, context_length, horizon, batch_size, model_version
)
status_icon = "✅" if fits else "🛑"
print(f" {status_icon} {message}")
print(f"{'=' * 50}\n")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def run_checks(model_version: str = "v2.5") -> SystemReport:
"""Run all system checks and return a report."""
profile = MODEL_PROFILES[model_version]
report = SystemReport(model=profile["name"])
# Run checks
report.checks.append(check_ram(profile))
report.checks.append(check_gpu())
report.checks.append(check_disk(profile))
report.checks.append(check_python())
report.checks.append(check_package("timesfm"))
report.checks.append(check_package("torch"))
# Determine mode
gpu_check = next((c for c in report.checks if c.name == "GPU"), None)
if gpu_check and gpu_check.status == "pass":
if "MPS" in gpu_check.value:
report.mode = "mps"
else:
report.mode = "gpu"
else:
report.mode = "cpu"
# Batch size
report.recommended_batch_size = recommend_batch_size(report)
# Verdict
if report.passed:
report.verdict = (
f"✅ System is ready for {profile['name']} ({report.mode.upper()} mode)"
)
report.verdict_detail = (
f"Recommended: per_core_batch_size={report.recommended_batch_size}"
)
else:
failed = [c for c in report.checks if c.status == "fail"]
report.verdict = f"🛑 System does NOT meet requirements for {profile['name']}"
report.verdict_detail = "; ".join(c.detail for c in failed)
return report
def print_report(report: SystemReport) -> None:
"""Print a human-readable report to stdout."""
print(f"\n{'=' * 50}")
print(f" TimesFM System Requirements Check")
print(f" Model: {report.model}")
print(f"{'=' * 50}\n")
for check in report.checks:
print(f" {check}")
print()
print(f" VERDICT: {report.verdict}")
if report.verdict_detail:
print(f" {report.verdict_detail}")
print()
def main() -> None:
parser = argparse.ArgumentParser(
description="Check system requirements for TimesFM.",
)
parser.add_argument(
"--model",
choices=list(MODEL_PROFILES.keys()),
default="v2.5",
help="Model version to check requirements for (default: v2.5)",
)
parser.add_argument(
"--json",
action="store_true",
help="Output results as JSON (machine-readable)",
)
# Dataset preflight options (NEW)
dataset_group = parser.add_argument_group("dataset preflight (optional)")
dataset_group.add_argument(
"--num-series",
type=int,
metavar="N",
help="Number of time series in your dataset (for memory estimation)",
)
dataset_group.add_argument(
"--context-length",
type=int,
metavar="LEN",
help="Length of each input time series (max_context value)",
)
dataset_group.add_argument(
"--horizon",
type=int,
metavar="H",
default=24,
help="Forecast horizon length (default: 24)",
)
dataset_group.add_argument(
"--batch-size",
type=int,
metavar="SIZE",
default=32,
help="per_core_batch_size from ForecastConfig (default: 32)",
)
dataset_group.add_argument(
"--estimate-only",
action="store_true",
help="Only show memory estimate, skip system checks",
)
args = parser.parse_args()
# Handle dataset estimation only mode
if args.estimate_only and args.num_series and args.context_length:
print_memory_estimate(
args.num_series,
args.context_length,
args.horizon,
args.batch_size,
args.model,
)
sys.exit(0)
# Run system checks
report = run_checks(args.model)
# Add dataset check if parameters provided
if args.num_series and args.context_length:
print_memory_estimate(
args.num_series,
args.context_length,
args.horizon,
args.batch_size,
args.model,
)
if args.json:
print(json.dumps(report.to_dict(), indent=2))
else:
print_report(report)
# Exit with non-zero if any check failed
sys.exit(0 if report.passed else 1)
if __name__ == "__main__":
main()
================================================
FILE: timesfm-forecasting/scripts/forecast_csv.py
================================================
#!/usr/bin/env python3
"""End-to-end CSV forecasting with TimesFM.
Loads a CSV, runs the system preflight check, loads TimesFM, forecasts
the requested columns, and writes results to a new CSV or JSON.
Usage:
python forecast_csv.py input.csv --horizon 24
python forecast_csv.py input.csv --horizon 12 --date-col date --value-cols sales,revenue
python forecast_csv.py input.csv --horizon 52 --output forecasts.csv
python forecast_csv.py input.csv --horizon 30 --output forecasts.json --format json
The script automatically:
1. Runs the system preflight check (exits if it fails).
2. Loads TimesFM 2.5 from Hugging Face.
3. Reads the CSV and identifies time series columns.
4. Forecasts each series with prediction intervals.
5. Writes results to the specified output file.
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import numpy as np
import pandas as pd
def run_preflight() -> dict:
"""Run the system preflight check and return the report."""
# Import the check_system module from the same directory
script_dir = Path(__file__).parent
sys.path.insert(0, str(script_dir))
from check_system import run_checks
report = run_checks("v2.5")
if not report.passed:
print("\n🛑 System check FAILED. Cannot proceed with forecasting.")
print(f" {report.verdict_detail}")
print("\nRun 'python scripts/check_system.py' for details.")
sys.exit(1)
return report.to_dict()
def load_model(batch_size: int = 32):
"""Load and compile the TimesFM model."""
import torch
import timesfm
torch.set_float32_matmul_precision("high")
print("Loading TimesFM 2.5 from Hugging Face...")
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
"google/timesfm-2.5-200m-pytorch"
)
print(f"Compiling with per_core_batch_size={batch_size}...")
model.compile(
timesfm.ForecastConfig(
max_context=1024,
max_horizon=256,
normalize_inputs=True,
use_continuous_quantile_head=True,
force_flip_invariance=True,
infer_is_positive=True,
fix_quantile_crossing=True,
per_core_batch_size=batch_size,
)
)
return model
def load_csv(
path: str,
date_col: str | None = None,
value_cols: list[str] | None = None,
) -> tuple[pd.DataFrame, list[str], str | None]:
"""Load CSV and identify time series columns.
Returns:
(dataframe, value_column_names, date_column_name_or_none)
"""
df = pd.read_csv(path)
# Identify date column
if date_col and date_col in df.columns:
df[date_col] = pd.to_datetime(df[date_col])
elif date_col:
print(f"⚠️ Date column '{date_col}' not found. Available: {list(df.columns)}")
date_col = None
# Identify value columns
if value_cols:
missing = [c for c in value_cols if c not in df.columns]
if missing:
print(f"⚠️ Columns not found: {missing}. Available: {list(df.columns)}")
value_cols = [c for c in value_cols if c in df.columns]
else:
# Auto-detect numeric columns (exclude date)
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
if date_col and date_col in numeric_cols:
numeric_cols.remove(date_col)
value_cols = numeric_cols
if not value_cols:
print("🛑 No numeric columns found to forecast.")
sys.exit(1)
print(f"Found {len(value_cols)} series to forecast: {value_cols}")
return df, value_cols, date_col
def forecast_series(
model, df: pd.DataFrame, value_cols: list[str], horizon: int
) -> dict[str, dict]:
"""Forecast all series and return results dict."""
inputs = []
for col in value_cols:
values = df[col].dropna().values.astype(np.float32)
inputs.append(values)
print(f"Forecasting {len(inputs)} series with horizon={horizon}...")
point, quantiles = model.forecast(horizon=horizon, inputs=inputs)
results = {}
for i, col in enumerate(value_cols):
results[col] = {
"forecast": point[i].tolist(),
"lower_90": quantiles[i, :, 1].tolist(), # 10th percentile
"lower_80": quantiles[i, :, 2].tolist(), # 20th percentile
"median": quantiles[i, :, 5].tolist(), # 50th percentile
"upper_80": quantiles[i, :, 8].tolist(), # 80th percentile
"upper_90": quantiles[i, :, 9].tolist(), # 90th percentile
}
return results
def write_csv_output(
results: dict[str, dict],
output_path: str,
df: pd.DataFrame,
date_col: str | None,
horizon: int,
) -> None:
"""Write forecast results to CSV."""
rows = []
for col, data in results.items():
# Try to generate future dates
future_dates = list(range(1, horizon + 1))
if date_col and date_col in df.columns:
try:
last_date = df[date_col].dropna().iloc[-1]
freq = pd.infer_freq(df[date_col].dropna())
if freq:
future_dates = pd.date_range(
last_date, periods=horizon + 1, freq=freq
)[1:].tolist()
except Exception:
pass
for h in range(horizon):
row = {
"series": col,
"step": h + 1,
"forecast": data["forecast"][h],
"lower_90": data["lower_90"][h],
"lower_80": data["lower_80"][h],
"median": data["median"][h],
"upper_80": data["upper_80"][h],
"upper_90": data["upper_90"][h],
}
if isinstance(future_dates[0], (pd.Timestamp,)):
row["date"] = future_dates[h]
rows.append(row)
out_df = pd.DataFrame(rows)
out_df.to_csv(output_path, index=False)
print(f"✅ Wrote {len(rows)} forecast rows to {output_path}")
def write_json_output(results: dict[str, dict], output_path: str) -> None:
"""Write forecast results to JSON."""
with open(output_path, "w") as f:
json.dump(results, f, indent=2)
print(f"✅ Wrote forecasts for {len(results)} series to {output_path}")
def main() -> None:
parser = argparse.ArgumentParser(
description="Forecast time series from CSV using TimesFM."
)
parser.add_argument("input", help="Path to input CSV file")
parser.add_argument(
"--horizon", type=int, required=True, help="Number of steps to forecast"
)
parser.add_argument("--date-col", help="Name of the date/time column")
parser.add_argument(
"--value-cols",
help="Comma-separated list of value columns to forecast (default: all numeric)",
)
parser.add_argument(
"--output",
default="forecasts.csv",
help="Output file path (default: forecasts.csv)",
)
parser.add_argument(
"--format",
choices=["csv", "json"],
default=None,
help="Output format (inferred from --output extension if not set)",
)
parser.add_argument(
"--batch-size",
type=int,
default=None,
help="Override per_core_batch_size (auto-detected from system check if omitted)",
)
parser.add_argument(
"--skip-check",
action="store_true",
help="Skip system preflight check (not recommended)",
)
args = parser.parse_args()
# Parse value columns
value_cols = None
if args.value_cols:
value_cols = [c.strip() for c in args.value_cols.split(",")]
# Determine output format
out_format = args.format
if not out_format:
out_format = "json" if args.output.endswith(".json") else "csv"
# 1. Preflight check
if not args.skip_check:
print("Running system preflight check...")
report = run_preflight()
batch_size = args.batch_size or report.get("recommended_batch_size", 32)
else:
print("⚠️ Skipping system check (--skip-check). Proceed with caution.")
batch_size = args.batch_size or 32
# 2. Load model
model = load_model(batch_size=batch_size)
# 3. Load CSV
df, cols, date_col = load_csv(args.input, args.date_col, value_cols)
# 4. Forecast
results = forecast_series(model, df, cols, args.horizon)
# 5. Write output
if out_format == "json":
write_json_output(results, args.output)
else:
write_csv_output(results, args.output, df, date_col, args.horizon)
print("\nDone! 🎉")
if __name__ == "__main__":
main()
================================================
FILE: v1/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: v1/README.md
================================================
# TimesFM
TimesFM (Time Series Foundation Model) is a pretrained time-series foundation model developed by Google
Research for time-series forecasting.
* Paper: [A decoder-only foundation model for time-series forecasting](https://arxiv.org/abs/2310.10688), to appear in ICML 2024.
* [Google Research blog](https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/)
* [Hugging Face release](https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6)
This repo contains the code to load public TimesFM checkpoints and run model
inference. Please visit our
[Hugging Face release](https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6)
to download model checkpoints.
This is not an officially supported Google product.
We recommend at least 32GB RAM to load TimesFM dependencies.
**Need help?** See [TROUBLESHOOTING.md](TROUBLESHOOTING.md) for common installation and usage issues.
## Update - Dec. 30, 2024
- We are launching a 500m checkpoint as a part of TimesFM-2.0 release. This new checkpoint can be upto 25% better than v1.0 on leading benchmarks and also has a 4 times longer max. context length.
- Launched [finetuning support](https://github.com/google-research/timesfm/blob/master/notebooks/finetuning.ipynb) that lets you finetune the weights of the pretrained TimesFM model on your own data.
- Launched [~zero-shot covariate support](https://github.com/google-research/timesfm/blob/master/notebooks/covariates.ipynb) with external regressors. More details [here](https://github.com/google-research/timesfm?tab=readme-ov-file#covariates-support).
## Update - Feb. 17, 2024
- We are providing the option for [finetuning using Pytorch](https://github.com/google-research/timesfm/blob/master/notebooks/finetuning_torch.ipynb), which mimics the previously added functionality from [finetuning support](https://github.com/google-research/timesfm/blob/master/notebooks/finetuning.ipynb).
- We are also providing the Multi-GPU finetuining with Pytorch. We currently support DDP multi-gpu finetuning, other variants of multi-gpu training (pipeline parallelism/model parallelism) might be added later. In order to use it, follow the steps in [finetuning example](https://github.com/google-research/timesfm/blob/master/finetuning/finetuning_example.py) .
## Checkpoint timesfm-1.0-200m (-pytorch)
timesfm-1.0-200m is our first open model checkpoint:
- It performs univariate time series forecasting for context lengths up to 512 timepoints and any horizon lengths, with an optional frequency indicator.
- It focuses on point forecasts, and does not support probabilistic forecasts. We experimentally offer quantile heads but they have not been calibrated after pretraining.
## Checkpoint timesfm-2.0-500m (-jax/-pytorch)
timesfm-2.0-500m is our second open model checkpoint:
- It performs univariate time series forecasting for context lengths up to 2048 timepoints and any horizon lengths, with an optional frequency indicator.
- It focuses on point forecasts. We experimentally offer 10 quantile heads but they have not been calibrated after pretraining.
- This new checkpoint can be upto 25% better than v1.0 on leading benchmarks and also has a 4 times longer max. context length.
## Benchmarking
TimesFM 2.0 has been added to [GIFT-Eval](https://huggingface.co/spaces/Salesforce/GIFT-Eval) which is one of the most comprehensive time-series bechmarks available. It takes the top spot in terms of aggregated MASE and CRPS, where it is 6\% better than the next best model in terms of aggregated MASE.
## Installation
### Local installation using poetry
We will be using `pyenv` and `poetry`. In order to set these things up please follow the instructions [here](https://substack.com/home/post/p-148747960?r=28a5lx&utm_campaign=post&utm_medium=web). Note that the PAX (or JAX) version needs to run on python 3.10.x and the PyTorch version can run on >=3.11.x. Therefore make sure you have two versions of python installed:
```
pyenv install 3.10
pyenv install 3.11
pyenv versions # to list the versions available (lets assume the versions are 3.10.15 and 3.11.10)
```
### For PAX version installation do the following.
```
pyenv local 3.10.15
poetry env use 3.10.15
poetry lock
poetry install -E pax
```
After than you can run the timesfm under `poetry shell` or do `poetry run python3 ...`.
### For PyTorch version installation do the following.
```
pyenv local 3.11.10
poetry env use 3.11.10
poetry lock
poetry install -E torch
```
After than you can run the timesfm under `poetry shell` or do `poetry run python3 ...`.
**Additional Note**:
If you plan to use the **`forecast_with_covariates`** function (which requires external regressors),
you need to install **JAX** and **jaxlib**. If you installed the base version of TimesFM (`torch`), you must manually install the dependencies for **`forecast_with_covariates`** support:
```
pip install jax jaxlib
```
**Why is this needed?**
The `forecast_with_covariates` method relies on the `xreg_lib` module, which depends on JAX and jaxlib. If these packages are not installed,
calling `forecast_with_covariates` will raise an error. However, due to a lazy import mechanism, `xreg_lib` (and hence JAX/jaxlib) is not needed for standard `forecast` calls.
### Notes
1. Running the provided benchmarks would require additional dependencies. Please see the `experiments` folder.
2. The dependency `lingvo` does not support ARM architectures, and the code is not working for machines with Apple silicon. We are aware of this issue and are working on a solution. Stay tuned.
### Install from PyPI (and publish)
On python 3.11 you can install the torch version using:
```pip install timesfm[torch]```
On python 3.10 you can install the pax version using:
```pip install timesfm[pax]```
## Usage
### Initialize the model and load a checkpoint.
Then the base class can be loaded as,
```python
import timesfm
# Loading the timesfm-2.0 checkpoint:
# For PAX
tfm = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
backend="gpu",
per_core_batch_size=32,
horizon_len=128,
num_layers=50,
context_len=2048,
use_positional_embedding=False,
),
checkpoint=timesfm.TimesFmCheckpoint(
huggingface_repo_id="google/timesfm-2.0-500m-jax"),
)
# For Torch
tfm = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
backend="gpu",
per_core_batch_size=32,
horizon_len=128,
num_layers=50,
use_positional_embedding=False,
context_len=2048,
),
checkpoint=timesfm.TimesFmCheckpoint(
huggingface_repo_id="google/timesfm-2.0-500m-pytorch"),
)
# Loading the timesfm-1.0 checkpoint:
# For PAX
tfm = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
backend="gpu",
per_core_batch_size=32,
horizon_len=128,
),
checkpoint=timesfm.TimesFmCheckpoint(
huggingface_repo_id="google/timesfm-1.0-200m"),
)
# For Torch
tfm = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
backend="gpu",
per_core_batch_size=32,
horizon_len=128,
),
checkpoint=timesfm.TimesFmCheckpoint(
huggingface_repo_id="google/timesfm-1.0-200m-pytorch"),
)
```
Note some of the parameters are fixed to load the 200m and 500m models
1. The `context_len` in `hparams` here can be set as the max context length **of the model** (a maximum of 2048 for 2.0 models and 512 for 1.0 models). **It needs to be a multiplier of `input_patch_len`, i.e. a multiplier of 32.** You can provide a shorter series to the `tfm.forecast()` function and the model will handle it. The input time series can have **any context length**. Padding / truncation will be handled by the inference code if needed.
2. The horizon length can be set to anything. We recommend setting it to the largest horizon length you would need in the forecasting tasks for your application. We generally recommend horizon length <= context length but it is not a requirement in the function call.
3. `backend` is one of "cpu", "gpu", case sensitive.
### Perform inference
We provide APIs to forecast from either array inputs or `pandas` dataframe. Both forecast methods expect (1) the input time series contexts, (2) along with their frequencies. Please look at the documentation of the functions `tfm.forecast()` and `tfm.forecast_on_df()` for detailed instructions.
In particular regarding the frequency, TimesFM expects a categorical indicator valued in {0, 1, 2}:
- **0** (default): high frequency, long horizon time series. We recommend using this for time series up to daily granularity.
- **1**: medium frequency time series. We recommend using this for weekly and monthly data.
- **2**: low frequency, short horizon time series. We recommend using this for anything beyond monthly, e.g. quarterly or yearly.
This categorical value should be directly provided with the array inputs. For dataframe inputs, we convert the conventional letter coding of frequencies to our expected categories, that
- **0**: T, MIN, H, D, B, U
- **1**: W, M
- **2**: Q, Y
Notice you do **NOT** have to strictly follow our recommendation here. Although this is our setup during model training and we expect it to offer the best forecast result, you can also view the frequency input as a free parameter and modify it per your specific use case.
Examples:
Array inputs, with the frequencies set to low, medium and high respectively.
```python
import numpy as np
forecast_input = [
np.sin(np.linspace(0, 20, 100)),
np.sin(np.linspace(0, 20, 200)),
np.sin(np.linspace(0, 20, 400)),
]
frequency_input = [0, 1, 2]
point_forecast, experimental_quantile_forecast = tfm.forecast(
forecast_input,
freq=frequency_input,
)
```
`pandas` dataframe, with the frequency set to "M" monthly.
```python
import pandas as pd
# e.g. input_df is
# unique_id ds y
# 0 T1 1975-12-31 697458.0
# 1 T1 1976-01-31 1187650.0
# 2 T1 1976-02-29 1069690.0
# 3 T1 1976-03-31 1078430.0
# 4 T1 1976-04-30 1059910.0
# ... ... ... ...
# 8175 T99 1986-01-31 602.0
# 8176 T99 1986-02-28 684.0
# 8177 T99 1986-03-31 818.0
# 8178 T99 1986-04-30 836.0
# 8179 T99 1986-05-31 878.0
forecast_df = tfm.forecast_on_df(
inputs=input_df,
freq="M", # monthly
value_name="y",
num_jobs=-1,
)
```
## Covariates Support
We now have an external regressors library on top of TimesFM that can support static covariates as well as dynamic covariates available in the future. We have an usage example in [notebooks/covariates.ipynb](https://github.com/google-research/timesfm/blob/master/notebooks/covariates.ipynb).
If you plan to use the **`forecast_with_covariates`** on timesfm `torch` version, you need to install **JAX** and **jaxlib**.
You must manually install the dependencies for **`forecast_with_covariates`** support:
```
pip install jax jaxlib
```
Let's take a toy example of forecasting sales for a grocery store:
**Task:** Given the observed the daily sales of this week (7 days), forecast the daily sales of next week (7 days).
```
Product: ice cream
Daily_sales: [30, 30, 4, 5, 7, 8, 10]
Category: food
Base_price: 1.99
Weekday: [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]
Has_promotion: [Yes, Yes, No, No, No, Yes, Yes, No, No, No, No, No, No, No]
Daily_temperature: [31.0, 24.3, 19.4, 26.2, 24.6, 30.0, 31.1, 32.4, 30.9, 26.0, 25.0, 27.8, 29.5, 31.2]
```
```
Product: sunscreen
Daily_sales: [5, 7, 12, 13, 5, 6, 10]
Category: skin product
Base_price: 29.99
Weekday: [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]
Has_promotion: [No, No, Yes, Yes, No, No, No, Yes, Yes, Yes, Yes, Yes, Yes, Yes]
Daily_temperature: [31.0, 24.3, 19.4, 26.2, 24.6, 30.0, 31.1, 32.4, 30.9, 26.0, 25.0, 27.8, 29.5, 31.2]
```
In this example, besides the `Daily_sales`, we also have covariates `Category`, `Base_price`, `Weekday`, `Has_promotion`, `Daily_temperature`. Let's introduce some concepts:
**Static covariates** are covariates for each time series.
- In our example, `Category` is a **static categorical covariate**,
- `Base_price` is a **static numerical covariates**.
**Dynamic covariates** are covaraites for each time stamps.
- Date / time related features can be usually treated as dynamic covariates.
- In our example, `Weekday` and `Has_promotion` are **dynamic categorical covariates**.
- `Daily_temperate` is a **dynamic numerical covariate**.
**Notice:** Here we make it mandatory that the dynamic covariates need to cover both the forecasting context and horizon. For example, all dynamic covariates in the example have 14 values: the first 7 correspond to the observed 7 days, and the last 7 correspond to the next 7 days.
We can now provide the past data of the two products along with static and dynamic covariates as a batch input to TimesFM and produce forecasts that take into the account the covariates. To learn more, check out the example in [notebooks/covariates.ipynb](https://github.com/google-research/timesfm/blob/master/notebooks/covariates.ipynb).
## Finetuning
We have provided an example of finetuning the model on a new dataset in [notebooks/finetuning.ipynb](https://github.com/google-research/timesfm/blob/master/notebooks/finetuning.ipynb).
## Contribution Style guide
If you would like to submit a PR please make sure that you use our formatting style. We use [yapf](https://github.com/google/yapf) for formatting with the following options,
```
[style]
based_on_style = google
# Add your custom style rules here
indent_width = 2
spaces_before_comment = 2
```
Please run `yapf --in-place --recursive ` on all affected files.
================================================
FILE: v1/TROUBLESHOOTING.md
================================================
# Troubleshooting
This document provides solutions to common issues encountered when using TimesFM.
## Installation Issues
### ARM/Apple Silicon Compatibility
**Problem:** `lingvo` dependency fails on Apple Silicon (M1/M2/M3) machines.
```
ERROR: Could not build wheels for lingvo
```
**Solution:** This is a known issue. The `lingvo` dependency doesn't support ARM architectures. We recommend:
- Use x86_64 emulation via Rosetta 2: `arch -x86_64 pip install timesfm[pax]`
- Use the PyTorch version instead, which has better ARM support: `pip install timesfm[torch]`
- Use Docker with x86_64 emulation for consistent environments
### Memory Issues During Installation
**Problem:** Installation fails with memory errors.
```
Killed (signal 9)
```
**Solution:**
- Ensure at least 32GB RAM is available
- Close other applications during installation
- Use `pip install --no-cache-dir timesfm[torch]` to reduce memory usage
- Install in a clean virtual environment
### JAX/PyTorch Version Conflicts
**Problem:** Conflicting JAX and PyTorch installations.
```
ImportError: cannot import name 'jax' from 'jax'
```
**Solution:**
- For PyTorch-only usage: `pip install timesfm[torch]`
- For covariates with PyTorch: `pip install timesfm[torch] && pip install jax jaxlib`
- For PAX version: `pip install timesfm[pax]`
## Runtime Errors
### Model Loading Issues
**Problem:** Checkpoint download fails or is corrupted.
```
HfFileNotFoundError: 404 Client Error
```
**Solution:**
- Check internet connectivity
- Verify Hugging Face Hub access: `huggingface-cli login`
- Clear cache: `rm -rf ~/.cache/huggingface/`
- Use explicit checkpoint paths if needed
### CUDA/GPU Issues
**Problem:** GPU not detected or CUDA errors.
```
RuntimeError: CUDA out of memory
```
**Solutions:**
- Reduce `per_core_batch_size` (try 16, 8, or 4)
- Reduce `context_len` to minimum needed
- Use `backend="cpu"` for testing
- Check GPU memory: `nvidia-smi`
### Context Length Errors
**Problem:** Input series longer than model capacity.
```
ValueError: context_len must be <= 512 for v1.0 models
```
**Solutions:**
- Use TimesFM-2.0 for longer contexts (up to 2048)
- Ensure `context_len` is multiple of 32
- Truncate input series if necessary
- Set appropriate `context_len` in model initialization
## Data Issues
### Frequency Mapping Problems
**Problem:** Unexpected forecasting results with wrong frequency.
```
Warning: Frequency 'D' mapped to category 0
```
**Solutions:**
- Verify frequency mapping: D→0 (high), W/M→1 (medium), Q/Y→2 (low)
- Override automatic mapping by specifying frequency manually
- Check data granularity matches chosen frequency category
### Missing Values in Time Series
**Problem:** NaN or missing values in input data.
```
ValueError: Input contains NaN values
```
**Solutions:**
- Pre-process data to handle missing values (forward fill, interpolation)
- Ensure continuous time series without gaps
- Remove or impute missing values before forecasting
### Covariate Dimension Mismatches
**Problem:** Covariate lengths don't match forecast horizon.
```
ValueError: Dynamic covariates must cover context + horizon
```
**Solutions:**
- Ensure dynamic covariates have length = context + horizon
- Check static vs dynamic covariate classification
- Verify covariate data alignment with time series
## Performance Issues
### Slow Inference
**Problem:** Forecasting takes unexpectedly long.
**Solutions:**
- Use GPU backend: `backend="gpu"`
- Optimize batch size: increase `per_core_batch_size`
- Use appropriate model size for your use case
- Profile with smaller data first
### Memory Usage
**Problem:** High memory consumption during inference.
**Solutions:**
- Reduce batch size: `per_core_batch_size=1`
- Process data in chunks
- Use smaller context length when possible
- Monitor memory with `htop` or `nvidia-smi`
## Common Error Messages
### `ModuleNotFoundError: No module named 'xreg_lib'`
**Cause:** Missing JAX dependencies for covariates functionality.
**Solution:** `pip install jax jaxlib`
### `ValueError: horizon_len must be positive`
**Cause:** Invalid horizon length specified.
**Solution:** Set `horizon_len > 0` in model initialization.
### `RuntimeError: Expected input batch_size (X) to be divisible by batch_size (Y)`
**Cause:** Batch size mismatch.
**Solution:** Adjust `per_core_batch_size` or input data batching.
## Getting Help
If you encounter issues not covered here:
1. Check the [GitHub Issues](https://github.com/google-research/timesfm/issues)
2. Review the [notebooks/](notebooks/) for working examples
3. Verify your installation follows the exact steps in the Installation section
4. Test with the provided example data before using your own datasets
================================================
FILE: v1/docs/contributing.md
================================================
# How to Contribute
We would love to accept your patches and contributions to this project.
## Before you begin
### Sign our Contributor License Agreement
Contributions to this project must be accompanied by a
[Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
You (or your employer) retain the copyright to your contribution; this simply
gives us permission to use and redistribute your contributions as part of the
project.
If you or your current employer have already signed the Google CLA (even if it
was for a different project), you probably don't need to do it again.
Visit to see your current agreements or to
sign a new one.
### Review our Community Guidelines
This project follows [Google's Open Source Community
Guidelines](https://opensource.google/conduct/).
## Contribution process
### Code Reviews
All submissions, including submissions by project members, require review. We
use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests)
for this purpose.
================================================
FILE: v1/experiments/baselines/__init__.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: v1/experiments/baselines/timegpt_pipeline.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from time import time
from typing import List, Optional, Tuple
from dotenv import load_dotenv
from gluonts.time_feature.seasonality import get_seasonality as _get_seasonality
from nixtla import NixtlaClient
import pandas as pd
from tqdm import tqdm
from utilsforecast.processing import (
backtest_splits,
drop_index_if_pandas,
join,
maybe_compute_sort_indices,
take_rows,
vertical_concat,
)
def get_seasonality(freq: str) -> int:
return _get_seasonality(freq, seasonalities={"D": 7})
def maybe_convert_col_to_datetime(
df: pd.DataFrame, col_name: str
) -> pd.DataFrame:
if not pd.api.types.is_datetime64_any_dtype(df[col_name]):
df = df.copy()
df[col_name] = pd.to_datetime(df[col_name])
return df
def zero_pad_time_series(df, freq, min_length=36):
"""If time_series length is less than min_length, front pad it with zeros."""
# 1. Calculate required padding for each unique_id
value_counts = df["unique_id"].value_counts()
to_pad = value_counts[value_counts < min_length].index
# 2. Create a new DataFrame to hold padded data
padded_data = []
for unique_id in to_pad:
# 2a. Filter data for the specific unique_id
subset = df[df["unique_id"] == unique_id]
if len(subset) > min_length:
padded_data.append(subset)
else:
# 2b. Determine earliest date and calculate padding dates
start_date = subset["ds"].min()
padding_dates = pd.date_range(
end=start_date,
periods=min_length - len(subset) + 1,
freq=freq, # 'MS' for month start
)[
:-1
] # Exclude the start_date itself
# 2c. Create padding data
padding_df = pd.DataFrame(
{"ds": padding_dates, "unique_id": unique_id, "y": 0} # Zero padding
)
# 2d. Combine original and padding data, and append to the list
padded_data.append(pd.concat([padding_df, subset]).sort_values("ds"))
# 3. Combine all padded data and original data (unchanged)
result_df = pd.concat(padded_data + [df[~df["unique_id"].isin(to_pad)]])
return result_df
class Forecaster:
"""Borrowed from
https://github.com/Nixtla/nixtla/tree/main/experiments/foundation-time-series-arena/xiuhmolpilli/models.
"""
def forecast(
self,
df: pd.DataFrame,
h: int,
freq: str,
) -> pd.DataFrame:
raise NotImplementedError
def cross_validation(
self,
df: pd.DataFrame,
h: int,
freq: str,
n_windows: int = 1,
step_size: int | None = None,
) -> pd.DataFrame:
df = maybe_convert_col_to_datetime(df, "ds")
# mlforecast cv code
results = []
sort_idxs = maybe_compute_sort_indices(df, "unique_id", "ds")
if sort_idxs is not None:
df = take_rows(df, sort_idxs)
splits = backtest_splits(
df,
n_windows=n_windows,
h=h,
id_col="unique_id",
time_col="ds",
freq=pd.tseries.frequencies.to_offset(freq),
step_size=h if step_size is None else step_size,
)
for _, (cutoffs, train, valid) in tqdm(enumerate(splits)):
if len(valid.columns) > 3:
raise NotImplementedError(
"Cross validation with exogenous variables is not yet supported."
)
y_pred = self.forecast(
df=train,
h=h,
freq=freq,
)
y_pred = join(y_pred, cutoffs, on="unique_id", how="left")
result = join(
valid[["unique_id", "ds", "y"]],
y_pred,
on=["unique_id", "ds"],
)
if result.shape[0] < valid.shape[0]:
raise ValueError(
"Cross validation result produced less results than expected."
" Please verify that the frequency parameter (freq) matches your"
" series' and that there aren't any missing periods."
)
results.append(result)
out = vertical_concat(results)
out = drop_index_if_pandas(out)
first_out_cols = ["unique_id", "ds", "cutoff", "y"]
remaining_cols = [c for c in out.columns if c not in first_out_cols]
fcst_cv_df = out[first_out_cols + remaining_cols]
return fcst_cv_df
class TimeGPT(Forecaster):
"""Borrowed from
https://github.com/Nixtla/nixtla/tree/main/experiments/foundation-time-series-arena/xiuhmolpilli/models.
We modify the class to take care of edge cases.
"""
def __init__(
self,
api_key: str | None = None,
base_url: Optional[str] = None,
max_retries: int = 1,
model: str = "timegpt-1",
alias: str = "TimeGPT",
):
self.api_key = api_key
self.base_url = base_url
self.max_retries = max_retries
self.model = model
self.alias = alias
def _get_client(self) -> NixtlaClient:
if self.api_key is None:
api_key = os.environ["NIXTLA_API_KEY"]
else:
api_key = self.api_key
return NixtlaClient(
api_key=api_key,
base_url=self.base_url,
max_retries=self.max_retries,
)
def forecast(
self,
df: pd.DataFrame,
h: int,
freq: str,
level: List = [90.0],
chunk_size: Optional[int] = None,
) -> pd.DataFrame:
client = self._get_client()
fcst_df = None
if chunk_size is None:
fcst_df = client.forecast(
df=df,
h=h,
freq=freq,
level=level,
model=self.model,
)
else:
all_unique_ids = df["unique_id"].unique()
all_fcst_df = []
for i in range(0, len(all_unique_ids), chunk_size):
chunk_ids = all_unique_ids[i : i + chunk_size]
chunk_df = df[df["unique_id"].isin(chunk_ids)]
fct_chunk_df = client.forecast(
df=chunk_df,
h=h,
freq=freq,
level=level,
)
all_fcst_df.append(fct_chunk_df)
fcst_df = pd.concat(all_fcst_df)
fcst_df["ds"] = pd.to_datetime(fcst_df["ds"])
replace_dict = {}
for col in fcst_df.columns:
if col.startswith("TimeGPT"):
replace_dict[col] = col.replace("TimeGPT", self.alias)
fcst_df = fcst_df.rename(columns=replace_dict)
return fcst_df
def run_timegpt(
train_df: pd.DataFrame,
horizon: int,
freq: str,
seasonality: int,
level: List[int],
dataset: str,
model: str = "timegpt-1",
) -> Tuple[pd.DataFrame, float, str]:
os.environ["NIXTLA_ID_AS_COL"] = "true"
model = TimeGPT(model="timegpt-1", alias=model)
padded_train_df = zero_pad_time_series(train_df, freq)
init_time = time()
# For these datasets the API fails if we do not chunk.
if dataset in ["m5", "m4_quarterly"]:
chunk_size = 5000
else:
chunk_size = None
fcsts_df = model.forecast(
df=padded_train_df,
h=horizon,
level=level,
freq=freq,
chunk_size=chunk_size,
)
total_time = time() - init_time
# In case levels are not returned we replace the levels with the mean predictions.
# Note that this does not affect the results table as we only compare on point
# forecastign metrics.
for lvl in level:
if f"{model.alias}-lo-{lvl}" not in fcsts_df.columns:
fcsts_df[f"{model.alias}-lo-{lvl}"] = fcsts_df[model.alias]
if f"{model.alias}-hi-{lvl}" not in fcsts_df.columns:
fcsts_df[f"{model.alias}-hi-{lvl}"] = fcsts_df[model.alias]
return fcsts_df, total_time, model.alias
================================================
FILE: v1/experiments/extended_benchmarks/README.md
================================================
# Extended Benchmarks
The benchmark setting has been borrowed from Nixtla's original [benchmarking](https://github.com/AzulGarza/nixtla/tree/main/experiments/amazon-chronos) of time-series foundation models against a strong statistical ensemble. Later more datasets were added by the Chronos team in this [pull request](https://github.com/shchur/nixtla/tree/chronos-full-eval/experiments/amazon-chronos). We compare on all the datasets in this extended benchmarks.
## Running TimesFM on the benchmark
We need to add the following packages for running these benchmarks. Follow the installation instructions till before `poetry lock`. Then,
```
poetry add git+https://github.com/awslabs/gluon-ts.git
poetry lock
poetry install --only
```
To run the timesfm on the benchmark do:
```
poetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=google/timesfm-1.0-200m(-pytorch) --backend="gpu"
```
Note: In the current version of TimesFM we focus on point forecasts and therefore the mase, smape have been calculated using the quantile head corresponding to the median i.e 0.5 quantile. We do offer 10 quantile heads but they have not been calibrated after pretraining. We recommend using them with caution or calibrate/conformalize them on a hold out for your applications. More to follow on later versions.
## Benchmark Results for TimesFM-1.0

__Update:__ We have added TimeGPT-1 to the benchmark results. We had to remove the Dominick dataset as we were not able to run TimeGPT-1 on this benchmark. Note that the previous results including Dominick remain available at `./tfm_results.png`. In order to reproduce the results for TimeGPT-1, please run `run_timegpt.py`.
_Remark:_ All baselines except the ones involving TimeGPT were run performed on a [g2-standard-32](https://cloud.google.com/compute/docs/gpus). Since TimeGPT-1 can only be accessed by an API, the time column might not reflect the true speed of the model as it also includes the communication cost. Moreover, we are not sure about the exact backend hardware for TimeGPT. The TimesFM latency numbers are from the PAX version.
We can see that TimesFM performs the best in terms of both mase and smape. More importantly it is much faster than the other methods, in particular it is more than 600x faster than StatisticalEnsemble and 80x faster than Chronos (Large).
Note: This benchmark only compares on `one` small horizon window for long horizon datasets like ETT hourly and 15 minutes. More in depth comparison on longer horizon rolling validation tasks are presented in our long horizon benchmarks.
================================================
FILE: v1/experiments/extended_benchmarks/run_timegpt.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation script for timegpt."""
import os
import sys
import time
from absl import flags
import numpy as np
import pandas as pd
from ..baselines.timegpt_pipeline import run_timegpt
from .utils import ExperimentHandler
dataset_names = [
"m1_monthly",
"m1_quarterly",
"m1_yearly",
"m3_monthly",
"m3_other",
"m3_quarterly",
"m3_yearly",
"m4_quarterly",
"m4_yearly",
"tourism_monthly",
"tourism_quarterly",
"tourism_yearly",
"nn5_daily_without_missing",
"m5",
"nn5_weekly",
"traffic",
"weather",
"australian_electricity_demand",
"car_parts_without_missing",
"cif_2016",
"covid_deaths",
"ercot",
"ett_small_15min",
"ett_small_1h",
"exchange_rate",
"fred_md",
"hospital",
]
_MODEL_NAME = flags.DEFINE_string(
"model_name",
"timegpt-1-long-horizon",
"Path to model, can also be set to timegpt-1",
)
_SAVE_DIR = flags.DEFINE_string("save_dir", "./results", "Save directory")
QUANTILES = list(np.arange(1, 10) / 10.0)
def main():
results_list = []
run_id = np.random.randint(100000)
model_name = _MODEL_NAME.value
for dataset in dataset_names:
print(f"Evaluating model {model_name} on dataset {dataset}", flush=True)
exp = ExperimentHandler(dataset, quantiles=QUANTILES)
train_df = exp.train_df
horizon = exp.horizon
seasonality = exp.seasonality
freq = exp.freq
level = exp.level
fcsts_df, total_time, model_name = run_timegpt(
train_df=train_df,
horizon=exp.horizon,
model=model_name,
seasonality=seasonality,
freq=freq,
dataset=dataset,
level=level,
)
time_df = pd.DataFrame({"time": [total_time], "model": model_name})
fcsts_df = exp.fcst_from_level_to_quantiles(fcsts_df, model_name)
results = exp.evaluate_from_predictions(
models=[model_name], fcsts_df=fcsts_df, times_df=time_df
)
print(results, flush=True)
results_list.append(results)
results_full = pd.concat(results_list)
save_path = os.path.join(_SAVE_DIR.value, str(run_id))
print(f"Saving results to {save_path}", flush=True)
os.makedirs(save_path, exist_ok=True)
results_full.to_csv(f"{save_path}/results.csv")
if __name__ == "__main__":
FLAGS = flags.FLAGS
FLAGS(sys.argv)
main()
================================================
FILE: v1/experiments/extended_benchmarks/run_timesfm.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation script for timesfm."""
import os
import sys
import time
from absl import flags
import numpy as np
import pandas as pd
import timesfm
from .utils import ExperimentHandler
dataset_names = [
"m1_monthly",
"m1_quarterly",
"m1_yearly",
"m3_monthly",
"m3_other",
"m3_quarterly",
"m3_yearly",
"m4_quarterly",
"m4_yearly",
"tourism_monthly",
"tourism_quarterly",
"tourism_yearly",
"nn5_daily_without_missing",
"m5",
"nn5_weekly",
"traffic",
"weather",
"australian_electricity_demand",
"car_parts_without_missing",
"cif_2016",
"covid_deaths",
"ercot",
"ett_small_15min",
"ett_small_1h",
"exchange_rate",
"fred_md",
"hospital",
]
context_dict_v2 = {}
context_dict_v1 = {
"cif_2016": 32,
"tourism_yearly": 64,
"covid_deaths": 64,
"tourism_quarterly": 64,
"tourism_monthly": 64,
"m1_monthly": 64,
"m1_quarterly": 64,
"m1_yearly": 64,
"m3_monthly": 64,
"m3_other": 64,
"m3_quarterly": 64,
"m3_yearly": 64,
"m4_quarterly": 64,
"m4_yearly": 64,
}
_MODEL_PATH = flags.DEFINE_string("model_path", "google/timesfm-2.0-500m-jax",
"Path to model")
_BATCH_SIZE = flags.DEFINE_integer("batch_size", 64, "Batch size")
_HORIZON = flags.DEFINE_integer("horizon", 128, "Horizon")
_BACKEND = flags.DEFINE_string("backend", "gpu", "Backend")
_NUM_JOBS = flags.DEFINE_integer("num_jobs", 1, "Number of jobs")
_SAVE_DIR = flags.DEFINE_string("save_dir", "./results", "Save directory")
QUANTILES = list(np.arange(1, 10) / 10.0)
def main():
results_list = []
model_path = _MODEL_PATH.value
num_layers = 20
max_context_len = 512
use_positional_embedding = True
context_dict = context_dict_v1
if "2.0" in model_path:
num_layers = 50
use_positional_embedding = False
max_context_len = 2048
context_dict = context_dict_v2
tfm = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
backend="gpu",
per_core_batch_size=32,
horizon_len=128,
num_layers=num_layers,
context_len=max_context_len,
use_positional_embedding=use_positional_embedding,
),
checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=model_path),
)
run_id = np.random.randint(100000)
model_name = "timesfm"
for dataset in dataset_names:
print(f"Evaluating model {model_name} on dataset {dataset}", flush=True)
exp = ExperimentHandler(dataset, quantiles=QUANTILES)
if dataset in context_dict:
context_len = context_dict[dataset]
else:
context_len = max_context_len
train_df = exp.train_df
freq = exp.freq
init_time = time.time()
fcsts_df = tfm.forecast_on_df(
inputs=train_df,
freq=freq,
value_name="y",
model_name=model_name,
forecast_context_len=context_len,
num_jobs=_NUM_JOBS.value,
normalize=True,
)
total_time = time.time() - init_time
time_df = pd.DataFrame({"time": [total_time], "model": model_name})
results = exp.evaluate_from_predictions(models=[model_name],
fcsts_df=fcsts_df,
times_df=time_df)
print(results, flush=True)
results_list.append(results)
results_full = pd.concat(results_list)
save_path = os.path.join(_SAVE_DIR.value, str(run_id))
print(f"Saving results to {save_path}", flush=True)
os.makedirs(save_path, exist_ok=True)
results_full.to_csv(f"{save_path}/results.csv")
if __name__ == "__main__":
FLAGS = flags.FLAGS
FLAGS(sys.argv)
main()
================================================
FILE: v1/experiments/extended_benchmarks/utils.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Forked from https://github.com/Nixtla/nixtla/blob/main/experiments/amazon-chronos/src/utils.py."""
from functools import partial
from itertools import repeat
import multiprocessing
import os
from pathlib import Path
from typing import List
from gluonts.dataset import Dataset
from gluonts.dataset.repository.datasets import (
dataset_names as gluonts_datasets,
get_dataset,
)
from gluonts.time_feature.seasonality import get_seasonality
import numpy as np
import pandas as pd
from utilsforecast.evaluation import evaluate
from utilsforecast.losses import mae, mase, smape
def parallel_transform(inp):
ts, last_n = inp[0], inp[1]
return ExperimentHandler._transform_gluonts_instance_to_df(ts, last_n=last_n)
def quantile_loss(
df: pd.DataFrame,
models: list,
q: float = 0.5,
id_col: str = "unique_id",
target_col: str = "y",
) -> pd.DataFrame:
delta_y = df[models].sub(df[target_col], axis=0)
res = (
np.maximum(q * delta_y, (q - 1) * delta_y)
.groupby(df[id_col], observed=True)
.mean()
)
res.index.name = id_col
res = res.reset_index()
return res
class ExperimentHandler:
def __init__(
self,
dataset: str,
quantiles: List[float] = list(np.arange(1, 10) / 10.0),
results_dir: str = "./results",
models_dir: str = "./models",
):
if dataset not in gluonts_datasets:
raise Exception(
f"dataset {dataset} not found in gluonts "
f"available datasets: {', '.join(gluonts_datasets)}"
)
self.dataset = dataset
self.quantiles = quantiles
self.level = self._transform_quantiles_to_levels(quantiles)
self.results_dir = results_dir
self.models_dir = models_dir
# defining datasets
self._maybe_download_m3_or_m5_file(self.dataset)
gluonts_dataset = get_dataset(self.dataset)
self.horizon = gluonts_dataset.metadata.prediction_length
if self.horizon is None:
raise Exception(
f"horizon not found for dataset {self.dataset} "
"experiment cannot be run"
)
self.freq = gluonts_dataset.metadata.freq
# get_seasonality() returns 1 for freq='D', override this to 7. This significantly improves the accuracy of
# statistical models on datasets like m5/nn5_daily. The models like AutoARIMA/AutoETS can still set
# seasonality=1 internally on datasets like weather by choosing non-seasonal models during model selection.
if self.freq == "D":
self.seasonality = 7
else:
self.seasonality = get_seasonality(self.freq)
self.gluonts_train_dataset = gluonts_dataset.train
self.gluonts_test_dataset = gluonts_dataset.test
self._create_dir_if_not_exists(self.results_dir)
try:
multiprocessing.set_start_method("spawn")
except RuntimeError:
print("Multiprocessing context has already been set.")
@staticmethod
def _maybe_download_m3_or_m5_file(dataset: str):
if dataset[:2] == "m3":
m3_file = Path.home() / ".gluonts" / "datasets" / "M3C.xls"
if not m3_file.exists():
from datasetsforecast.m3 import M3
from datasetsforecast.utils import download_file
download_file(m3_file.parent, M3.source_url)
elif dataset == "m5":
m5_raw_dir = Path.home() / ".gluonts" / "m5"
if not m5_raw_dir.exists():
import zipfile
from datasetsforecast.m5 import M5
from datasetsforecast.utils import download_file
download_file(m5_raw_dir, M5.source_url)
with zipfile.ZipFile(m5_raw_dir / "m5.zip", "r") as zip_ref:
zip_ref.extractall(m5_raw_dir)
@staticmethod
def _transform_quantiles_to_levels(quantiles: List[float]) -> List[int]:
level = [
int(100 - 200 * q) for q in quantiles if q < 0.5
] # in this case mean=mediain
level = sorted(list(set(level)))
return level
@staticmethod
def _create_dir_if_not_exists(directory: str):
Path(directory).mkdir(parents=True, exist_ok=True)
@staticmethod
def _transform_gluonts_instance_to_df(
ts: dict,
last_n: int | None = None,
) -> pd.DataFrame:
start_period = ts["start"]
start_ds, freq = start_period.to_timestamp(), start_period.freq
target = ts["target"]
ds = pd.date_range(start=start_ds, freq=freq, periods=len(target))
if last_n is not None:
target = target[-last_n:]
ds = ds[-last_n:]
ts_df = pd.DataFrame({"unique_id": ts["item_id"], "ds": ds, "y": target})
return ts_df
@staticmethod
def _transform_gluonts_dataset_to_df(
gluonts_dataset: Dataset,
last_n: int | None = None,
) -> pd.DataFrame:
with multiprocessing.Pool(os.cpu_count()) as pool: # Create a process pool
results = pool.map(
parallel_transform, zip(gluonts_dataset, repeat(last_n))
)
df = pd.concat(results)
df = df.reset_index(drop=True)
return df
@property
def train_df(self) -> pd.DataFrame:
train_df = self._transform_gluonts_dataset_to_df(self.gluonts_train_dataset)
return train_df
@property
def test_df(self) -> pd.DataFrame:
test_df = self._transform_gluonts_dataset_to_df(
self.gluonts_test_dataset,
last_n=self.horizon,
)
# Make sure that only the first backtest window is used for evaluation on `traffic` / `exchange_rate` datasets
return test_df.groupby("unique_id", sort=False).head(self.horizon)
def save_dataframe(self, df: pd.DataFrame, file_name: str):
df.to_csv(f"{self.results_dir}/{file_name}", index=False)
def save_results(
self, fcst_df: pd.DataFrame, total_time: float, model_name: str
):
self.save_dataframe(
fcst_df,
f"{model_name}-{self.dataset}-fcst.csv",
)
time_df = pd.DataFrame({"time": [total_time], "model": model_name})
self.save_dataframe(
time_df,
f"{model_name}-{self.dataset}-time.csv",
)
def fcst_from_level_to_quantiles(
self,
fcst_df: pd.DataFrame,
model_name: str,
) -> pd.DataFrame:
fcst_df = fcst_df.copy()
cols = ["unique_id", "ds", model_name]
for q in self.quantiles:
if q == 0.5:
col = f"{model_name}"
else:
lv = int(100 - 200 * q)
hi_or_lo = "lo" if lv > 0 else "hi"
lv = abs(lv)
col = f"{model_name}-{hi_or_lo}-{lv}"
q_col = f"{model_name}-q-{q}"
fcst_df[q_col] = fcst_df[col].values
cols.append(q_col)
return fcst_df[cols]
def evaluate_models(self, models: List[str]) -> pd.DataFrame:
fcsts_df = []
times_df = []
for model in models:
fcst_method_df = pd.read_csv(
f"{self.results_dir}/{model}-{self.dataset}-fcst.csv"
).set_index(["unique_id", "ds"])
fcsts_df.append(fcst_method_df)
time_method_df = pd.read_csv(
f"{self.results_dir}/{model}-{self.dataset}-time.csv"
)
times_df.append(time_method_df)
fcsts_df = pd.concat(fcsts_df, axis=1).reset_index()
fcsts_df["ds"] = pd.to_datetime(fcsts_df["ds"])
times_df = pd.concat(times_df)
return self.evaluate_from_predictions(
models=models, fcsts_df=fcsts_df, times_df=times_df
)
def evaluate_from_predictions(
self, models: List[str], fcsts_df: pd.DataFrame, times_df: pd.DataFrame
) -> pd.DataFrame:
test_df = self.test_df
train_df = self.train_df
test_df = test_df.merge(fcsts_df, how="left")
assert test_df.isna().sum().sum() == 0, "merge contains nas"
# point evaluation
point_fcsts_cols = ["unique_id", "ds", "y"] + models
test_df["unique_id"] = test_df["unique_id"].astype(str)
train_df["unique_id"] = train_df["unique_id"].astype(str)
mase_seas = partial(mase, seasonality=self.seasonality)
eval_df = evaluate(
test_df[point_fcsts_cols],
train_df=train_df,
metrics=[smape, mase_seas, mae],
)
# probabilistic evaluation
eval_prob_df = []
for q in self.quantiles:
prob_cols = [f"{model}-q-{q}" for model in models]
eval_q_df = quantile_loss(test_df, models=prob_cols, q=q)
eval_q_df[prob_cols] = eval_q_df[prob_cols] * self.horizon
eval_q_df = eval_q_df.rename(columns=dict(zip(prob_cols, models)))
eval_q_df["metric"] = f"quantile-loss-{q}"
eval_prob_df.append(eval_q_df)
eval_prob_df = pd.concat(eval_prob_df)
eval_prob_df = eval_prob_df.groupby("metric").sum().reset_index()
total_y = test_df["y"].sum()
eval_prob_df[models] = eval_prob_df[models] / total_y
eval_prob_df["metric"] = "scaled_crps"
eval_df = pd.concat([eval_df, eval_prob_df]).reset_index(drop=True)
eval_df = eval_df.groupby("metric").mean(numeric_only=True).reset_index()
eval_df = eval_df.melt(
id_vars="metric", value_name="value", var_name="model"
)
times_df.insert(0, "metric", "time")
times_df = times_df.rename(columns={"time": "value"})
eval_df = pd.concat([eval_df, times_df])
eval_df.insert(0, "dataset", self.dataset)
eval_df = eval_df.sort_values(["dataset", "metric", "model"])
eval_df = eval_df.reset_index(drop=True)
return eval_df
if __name__ == "__main__":
multiprocessing.set_start_method("spawn")
================================================
FILE: v1/experiments/long_horizon_benchmarks/README.md
================================================
# Extended Benchmarks
We benchmark on the original test set for ETT datasets as per long horizon benchmark papers (see [here](https://openreview.net/forum?id=pCbC3aQB5W) for example.) In the original benchmark, rolling validation task on all test windows (with a stride of 1) is considered. While we can easily run our method on this task, the baselines can take a very long time to run. Therefore we present results on a modified task with stride between windows set to Horizon length i.e all disjoint horizons in the test period is considered.
All experiments were performed on a [g2-standard-32](https://cloud.google.com/compute/docs/gpus). We compare TimesFM with [Amazon-Chronos](https://github.com/amazon-science/chronos-forecasting).
## Running TimesFM on the benchmark
We need to add the following packages for running these benchmarks. Follow the installation instructions till before `poetry lock`. Then,
```
poetry add git+https://github.com/awslabs/gluon-ts.git
poetry add git+https://github.com/amazon-science/chronos-forecasting.git
poetry lock
poetry install --only pax
```
Note that for now only the pax version runs on this benchmark, because we had to remove the old tf dependency from the pytorch version. We will fix this issue soon.
To run the timesfm on the benchmark do:
```
poetry run python3 -m experiments.long_horizon_benchmarks.run_eval \
--model_path=google/timesfm-1.0-200m --backend="gpu" \
--pred_len=96 --context_len=512 --dataset=etth1
```
In the above, `` should point to the checkpoint directory that can be downloaded from HuggingFace.
For running chronos on the same benchmark you can run the command,
```
poetry run python3 -m experiments.long_horizon_benchmarks.run_eval \
--model_path=amazon/chronos-t5-mini --backend="gpu" \
--pred_len=96 --context_len=512 --dataset=etth1
```
You can change the model size from "mini" to "large" as required. The datasets we benchmark on are etth1, etth2, ettm1 and ettm2.
## Benchmark Results for TimesFM-1.0

We compare the performance on horizon lengths of 96, 192 and 336, while context length is held fixed at 512.
We can see that TimesFM performs the best in terms of both wape and smape. More importantly it is much faster than the other methods, in particular it is more than 1000x faster than Chronos (Large).
================================================
FILE: v1/experiments/long_horizon_benchmarks/run_eval.py
================================================
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Eval pipeline."""
import json
import os
import sys
import time
from absl import flags
import chronos
import numpy as np
import pandas as pd
import timesfm
from timesfm import data_loader
import torch
import tqdm
FLAGS = flags.FLAGS
_BATCH_SIZE = flags.DEFINE_integer("batch_size", 64,
"Batch size for the randomly sampled batch")
_DATASET = flags.DEFINE_string("dataset", "etth1", "The name of the dataset.")
_MODEL_PATH = flags.DEFINE_string("model_path", "google/timesfm-2.0-500m-jax",
"The name of the model.")
_DATETIME_COL = flags.DEFINE_string("datetime_col", "date",
"Column having datetime.")
_NUM_COV_COLS = flags.DEFINE_list("num_cov_cols", None,
"Column having numerical features.")
_CAT_COV_COLS = flags.DEFINE_list("cat_cov_cols", None,
"Column having categorical features.")
_TS_COLS = flags.DEFINE_list("ts_cols", None, "Columns of time-series features")
_NORMALIZE = flags.DEFINE_bool("normalize", True,
"normalize data for eval or not")
_CONTEXT_LEN = flags.DEFINE_integer("context_len", 2048,
"Length of the context window")
_PRED_LEN = flags.DEFINE_integer("pred_len", 96, "prediction length.")
_BACKEND = flags.DEFINE_string("backend", "gpu", "backend to use")
_RESULTS_DIR = flags.DEFINE_string("results_dir", "./results/long_horizon",
"results directory")
DATA_DICT = {
"ettm2": {
"boundaries": [34560, 46080, 57600],
"data_path": "./datasets/ETT-small/ETTm2.csv",
"freq": "15min",
},
"ettm1": {
"boundaries": [34560, 46080, 57600],
"data_path": "./datasets/ETT-small/ETTm1.csv",
"freq": "15min",
},
"etth2": {
"boundaries": [8640, 11520, 14400],
"data_path": "./datasets/ETT-small/ETTh2.csv",
"freq": "H",
},
"etth1": {
"boundaries": [8640, 11520, 14400],
"data_path": "./datasets/ETT-small/ETTh1.csv",
"freq": "H",
},
"elec": {
"boundaries": [18413, 21044, 26304],
"data_path": "./datasets/electricity/electricity.csv",
"freq": "H",
},
"traffic": {
"boundaries": [12280, 14036, 17544],
"data_path": "./datasets/traffic/traffic.csv",
"freq": "H",
},
"weather": {
"boundaries": [36887, 42157, 52696],
"data_path": "./datasets/weather/weather.csv",
"freq": "10min",
},
}
QUANTILES = list(np.arange(1, 10) / 10.0)
EPS = 1e-7
def get_forecasts(model_path, model, past, freq, pred_len):
"""Get forecasts."""
if model_path.startswith("amazon"):
out = model.predict(
torch.tensor(past),
prediction_length=pred_len,
limit_prediction_length=False,
)
out = out.numpy()
out = np.median(out, axis=1)
else:
lfreq = [freq] * past.shape[0]
_, out = model.forecast(list(past), lfreq)
out = out[:, :, 5]
return out
def _mse(y_pred, y_true):
"""mse loss."""
return np.square(y_pred - y_true)
def _mae(y_pred, y_true):
"""mae loss."""
return np.abs(y_pred - y_true)
def _smape(y_pred, y_true):
"""_smape loss."""
abs_diff = np.abs(y_pred - y_true)
abs_val = (np.abs(y_true) + np.abs(y_pred)) / 2
abs_val = np.where(abs_val > EPS, abs_val, 1.0)
abs_diff = np.where(abs_val > EPS, abs_diff, 0.0)
return abs_diff / abs_val
def eval():
"""Eval pipeline."""
dataset = _DATASET.value
data_path = DATA_DICT[dataset]["data_path"]
freq = DATA_DICT[dataset]["freq"]
int_freq = timesfm.freq_map(freq)
boundaries = DATA_DICT[dataset]["boundaries"]
data_df = pd.read_csv(open(data_path, "r"))
if _TS_COLS.value is not None:
ts_cols = DATA_DICT[dataset]["ts_cols"]
num_cov_cols = DATA_DICT[dataset]["num_cov_cols"]
cat_cov_cols = DATA_DICT[dataset]["cat_cov_cols"]
else:
ts_cols = [col for col in data_df.columns if col != _DATETIME_COL.value]
num_cov_cols = None
cat_cov_cols = None
batch_size = min(_BATCH_SIZE.value, len(ts_cols))
dtl = data_loader.TimeSeriesdata(
data_path=data_path,
datetime_col=_DATETIME_COL.value,
num_cov_cols=num_cov_cols,
cat_cov_cols=cat_cov_cols,
ts_cols=np.array(ts_cols),
train_range=[0, boundaries[0]],
val_range=[boundaries[0], boundaries[1]],
test_range=[boundaries[1], boundaries[2]],
hist_len=_CONTEXT_LEN.value,
pred_len=_PRED_LEN.value,
batch_size=batch_size,
freq=freq,
normalize=_NORMALIZE.value,
epoch_len=None,
holiday=False,
permute=False,
)
eval_itr = dtl.tf_dataset(mode="test",
shift=_PRED_LEN.value).as_numpy_iterator()
model_path = _MODEL_PATH.value
if model_path.startswith("amazon"):
model = chronos.ChronosPipeline.from_pretrained(
model_path,
device_map="auto",
torch_dtype=torch.bfloat16,
)
else:
model = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
backend="gpu",
per_core_batch_size=32,
horizon_len=128,
num_layers=50,
context_len=_CONTEXT_LEN.value,
use_positional_embedding=False,
),
checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=model_path),
)
smape_run_losses = []
mse_run_losses = []
mae_run_losses = []
num_elements = 0
abs_sum = 0
start_time = time.time()
for batch in tqdm.tqdm(eval_itr):
past = batch[0]
actuals = batch[3]
forecasts = get_forecasts(model_path, model, past, int_freq,
_PRED_LEN.value)
forecasts = forecasts[:, 0:actuals.shape[1]]
mae_run_losses.append(_mae(forecasts, actuals).sum())
mse_run_losses.append(_mse(forecasts, actuals).sum())
smape_run_losses.append(_smape(forecasts, actuals).sum())
num_elements += actuals.shape[0] * actuals.shape[1]
abs_sum += np.abs(actuals).sum()
mse_val = np.sum(mse_run_losses) / num_elements
result_dict = {
"mse": mse_val,
"smape": np.sum(smape_run_losses) / num_elements,
"mae": np.sum(mae_run_losses) / num_elements,
"wape": np.sum(mae_run_losses) / abs_sum,
"nrmse": np.sqrt(mse_val) / (abs_sum / num_elements),
"num_elements": num_elements,
"abs_sum": abs_sum,
"total_time": time.time() - start_time,
"model_path": model_path,
"dataset": dataset,
"freq": freq,
"pred_len": _PRED_LEN.value,
"context_len": _CONTEXT_LEN.value,
}
run_id = np.random.randint(10000)
save_path = os.path.join(_RESULTS_DIR.value, str(run_id))
print(f"Saving results to {save_path}", flush=True)
os.makedirs(save_path, exist_ok=True)
with open(os.path.join(save_path, "results.json"), "w") as f:
json.dump(result_dict, f)
print(result_dict, flush=True)
if __name__ == "__main__":
FLAGS = flags.FLAGS
FLAGS(sys.argv)
eval()
================================================
FILE: v1/notebooks/covariates.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TimesFM with Covariates\n",
"\n",
"This toturial notebook demonstrates how to utilize exogenous covariates with TimesFM when making forecasts. Before running this notebook, make sure:\n",
"\n",
"- You've read through the README of TimesFM.\n",
"- A local kernel with Python 3.10 is up and running, for the jax version.\n",
"- Install the JAX version following the installation instructions."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup the environment and install TimesFM."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the checkpoint\n",
"\n",
"**Notice:** Please set up the backend as per your machine (\"cpu\", \"gpu\" or \"tpu\"). This notebook will run by default on GPU.\n",
"\n",
"We load the 2.0-500m model checkpoint from HuggingFace."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import timesfm\n",
"timesfm_backend = \"gpu\" # @param\n",
"\n",
"model = timesfm.TimesFm(\n",
" hparams=timesfm.TimesFmHparams(\n",
" backend=timesfm_backend,\n",
" per_core_batch_size=32,\n",
" horizon_len=128,\n",
" num_layers=50,\n",
" use_positional_embedding=False,\n",
" context_len=2048,\n",
" ),\n",
" checkpoint=timesfm.TimesFmCheckpoint(\n",
" huggingface_repo_id=\"google/timesfm-2.0-500m-jax\"),\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Covariates\n",
"\n",
"Let's take a toy example of forecasting sales for a grocery store: \n",
"\n",
"**Task:** Given the observed the daily sales of this week (7 days), forecast the daily sales of next week (7 days).\n",
"\n",
"```\n",
"Product: ice cream\n",
"Daily_sales: [30, 30, 4, 5, 7, 8, 10]\n",
"Category: food\n",
"Base_price: 1.99\n",
"Weekday: [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]\n",
"Has_promotion: [Yes, Yes, No, No, No, Yes, Yes, No, No, No, No, No, No, No]\n",
"Daily_temperature: [31.0, 24.3, 19.4, 26.2, 24.6, 30.0, 31.1, 32.4, 30.9, 26.0, 25.0, 27.8, 29.5, 31.2]\n",
"```\n",
"\n",
"```\n",
"Product: sunscreen\n",
"Daily_sales: [5, 7, 12, 13, 5, 6, 10]\n",
"Category: skin product\n",
"Base_price: 29.99\n",
"Weekday: [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]\n",
"Has_promotion: [No, No, Yes, Yes, No, No, No, Yes, Yes, Yes, Yes, Yes, Yes, Yes]\n",
"Daily_temperature: [31.0, 24.3, 19.4, 26.2, 24.6, 30.0, 31.1, 32.4, 30.9, 26.0, 25.0, 27.8, 29.5, 31.2]\n",
"```\n",
"\n",
"In this example, besides the `Daily_sales`, we also have covariates `Category`, `Base_price`, `Weekday`, `Has_promotion`, `Daily_temperature`. Let's introduce some concepts:\n",
"\n",
"**Static covariates** are covariates for each time series. \n",
"- In our example, `Category` is a **static categorical covariate**, \n",
"- `Base_price` is a **static numerical covariates**.\n",
"\n",
"**Dynamic covariates** are covaraites for each time stamps.\n",
"- Date / time related features can be usually treated as dynamic covariates.\n",
"- In our example, `Weekday` and `Has_promotion` are **dynamic categorical covariates**.\n",
"- `Daily_temperate` is a **dynamic numerical covariate**.\n",
"\n",
"**Notice:** Here we make it mandatory that the dynamic covariates need to cover both the forecasting context and horizon. For example, all dynamic covariates in the example have 14 values: the first 7 correspond to the observed 7 days, and the last 7 correspond to the next 7 days."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TimesFM with Covariates\n",
"\n",
"\n",
"The strategy we take here is to treat covariates as batched in-context exogenous regressors (XReg) and fit linear models on them outside of TimesFM. The final forecast will be the sum of the TimesFM forecast and the linear model forecast.\n",
"\n",
" In simple words, we consider these two options.\n",
"\n",
"**Option 1:** Get the TimesFM forecast, and fit the linear model regressing the residuals on the covariates (\"timesfm + xreg\").\n",
"\n",
"**Option 2:** Fit the linear model of the time series itself on the covariates, then forecast the residuals using TimesFM (\"xreg + timesfm\").\n",
"\n",
"Let's take a code at the example of Electricity Price Forecasting (EPF). \n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from collections import defaultdict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('https://datasets-nixtla.s3.amazonaws.com/EPF_FR_BE.csv')\n",
"df['ds'] = pd.to_datetime(df['ds'])\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This dataset has a few covariates beside the hourly target `y`:\n",
"\n",
"- `unique_id`: a static categorical covariate indicating the country.\n",
"- `gen_forecast`: a dynamic numerical covariate indicating the estimated electricity to be generated.\n",
"- `system_load`: the observed system load. Notice that this **CANNOT** be considered as a dynamic numerical covariate because we cannot know its values over the forecasting horizon in advance.\n",
"- `weekday`: a dynamic categorical covariate.\\\n",
"\n",
"Let's now make some forecasting tasks for TimesFM based on this dataset. For simplicity we create forecast contexts of 120 time points (hours) and forecast horizons of 24 time points."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Data pipelining\n",
"def get_batched_data_fn(\n",
" batch_size: int = 128, \n",
" context_len: int = 120, \n",
" horizon_len: int = 24,\n",
"):\n",
" examples = defaultdict(list)\n",
"\n",
" num_examples = 0\n",
" for country in (\"FR\", \"BE\"):\n",
" sub_df = df[df[\"unique_id\"] == country]\n",
" for start in range(0, len(sub_df) - (context_len + horizon_len), horizon_len):\n",
" num_examples += 1\n",
" examples[\"country\"].append(country)\n",
" examples[\"inputs\"].append(sub_df[\"y\"][start:(context_end := start + context_len)].tolist())\n",
" examples[\"gen_forecast\"].append(sub_df[\"gen_forecast\"][start:context_end + horizon_len].tolist())\n",
" examples[\"week_day\"].append(sub_df[\"week_day\"][start:context_end + horizon_len].tolist())\n",
" examples[\"outputs\"].append(sub_df[\"y\"][context_end:(context_end + horizon_len)].tolist())\n",
" \n",
" def data_fn():\n",
" for i in range(1 + (num_examples - 1) // batch_size):\n",
" yield {k: v[(i * batch_size) : ((i + 1) * batch_size)] for k, v in examples.items()}\n",
" \n",
" return data_fn\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Define metrics\n",
"def mse(y_pred, y_true):\n",
" y_pred = np.array(y_pred)\n",
" y_true = np.array(y_true)\n",
" return np.mean(np.square(y_pred - y_true), axis=1, keepdims=True)\n",
"\n",
"def mae(y_pred, y_true):\n",
" y_pred = np.array(y_pred)\n",
" y_true = np.array(y_true)\n",
" return np.mean(np.abs(y_pred - y_true), axis=1, keepdims=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's try `model.forecast_with_covariates`. \n",
"\n",
"In particular, the output is a tuple whose first element is the new forecast."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"# Benchmark\n",
"batch_size = 128\n",
"context_len = 120\n",
"horizon_len = 24\n",
"input_data = get_batched_data_fn(batch_size = 128)\n",
"metrics = defaultdict(list)\n",
"\n",
"\n",
"for i, example in enumerate(input_data()):\n",
" raw_forecast, _ = model.forecast(\n",
" inputs=example[\"inputs\"], freq=[0] * len(example[\"inputs\"])\n",
" )\n",
" start_time = time.time()\n",
" # Forecast with covariates\n",
" # Output: new forecast, forecast by the xreg\n",
" cov_forecast, ols_forecast = model.forecast_with_covariates( \n",
" inputs=example[\"inputs\"],\n",
" dynamic_numerical_covariates={\n",
" \"gen_forecast\": example[\"gen_forecast\"],\n",
" },\n",
" dynamic_categorical_covariates={\n",
" \"week_day\": example[\"week_day\"],\n",
" },\n",
" static_numerical_covariates={},\n",
" static_categorical_covariates={\n",
" \"country\": example[\"country\"]\n",
" },\n",
" freq=[0] * len(example[\"inputs\"]),\n",
" xreg_mode=\"xreg + timesfm\", # default\n",
" ridge=0.0,\n",
" force_on_cpu=False,\n",
" normalize_xreg_target_per_input=True, # default\n",
" )\n",
" print(\n",
" f\"\\rFinished batch {i} linear in {time.time() - start_time} seconds\",\n",
" end=\"\",\n",
" )\n",
" metrics[\"eval_mae_timesfm\"].extend(\n",
" mae(raw_forecast[:, :horizon_len], example[\"outputs\"])\n",
" )\n",
" metrics[\"eval_mae_xreg_timesfm\"].extend(mae(cov_forecast, example[\"outputs\"]))\n",
" metrics[\"eval_mae_xreg\"].extend(mae(ols_forecast, example[\"outputs\"]))\n",
" metrics[\"eval_mse_timesfm\"].extend(\n",
" mse(raw_forecast[:, :horizon_len], example[\"outputs\"])\n",
" )\n",
" metrics[\"eval_mse_xreg_timesfm\"].extend(mse(cov_forecast, example[\"outputs\"]))\n",
" metrics[\"eval_mse_xreg\"].extend(mse(ols_forecast, example[\"outputs\"]))\n",
"\n",
"print()\n",
"\n",
"for k, v in metrics.items():\n",
" print(f\"{k}: {np.mean(v)}\")\n",
"\n",
"# My output:\n",
"# eval_mae_timesfm: 6.762283045916956\n",
"# eval_mae_xreg_timesfm: 5.39219617611074\n",
"# eval_mae_xreg: 37.15275842572484\n",
"# eval_mse_timesfm: 166.7771466306823\n",
"# eval_mse_xreg_timesfm: 120.64757721021306\n",
"# eval_mse_xreg: 1672.2116821201796"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should see results close to \n",
"```\n",
"eval_mae_timesfm: 6.729583250571446\n",
"eval_mae_xreg_timesfm: 5.3375301110158\n",
"eval_mae_xreg: 37.152760709266\n",
"eval_mse_timesfm: 162.3132151851567\n",
"eval_mse_xreg_timesfm: 120.9900627409689\n",
"eval_mse_xreg: 1672.208769045399\n",
"```\n",
"\n",
"With the covariates, the TimesFM forecast Mean Absolute Error improves from 6.73 to 5.34, and Mean Squred Error from 162.31 to 120.99. The results of purely fitting the linear model are also provided for reference."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Formatting Your Request\n",
"\n",
"It is quite crucial to get the covariates properly formatted so that we can call this `model.forecast_with_covariates`. Please see its docstring for details. Here let's also grab a batch from a toy data input pipeline for quick explanations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"toy_input_pipeline = get_batched_data_fn(batch_size=2, context_len=5, horizon_len=2)\n",
"print(next(toy_input_pipeline()))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should see something similar to this\n",
"```\n",
"{\n",
" 'country': ['FR', 'FR'], \n",
" 'inputs': [[53.48, 51.93, 48.76, 42.27, 38.41], [48.76, 42.27, 38.41, 35.72, 32.66]], \n",
" 'gen_forecast': [[76905.0, 75492.0, 74394.0, 72639.0, 69347.0, 67960.0, 67564.0], [74394.0, 72639.0, 69347.0, 67960.0, 67564.0, 67277.0, 67019.0]], \n",
" 'week_day': [[3, 3, 3, 3, 3, 3, 3], [3, 3, 3, 3, 3, 3, 3]], \n",
" 'outputs': [[35.72, 32.66], [32.83, 30.06]],\n",
"}\n",
"```\n",
"\n",
"Notice:\n",
"- We have two examples in this batch.\n",
"- For each example we support different context lengths and horizon lengths just as `model.forecast`. Although it is not demonstrated in this dataset.\n",
"- If dynamic covariates are present, the horizon lengths will be inferred from them, e.g. how many values are provided in additional to the ones corresponding to the inputs. Make sure all your dynamic covariates have the same length per example.\n",
"- The static covariates are one per example.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## More Applications\n",
"\n",
"### Past Dynamic Covariates\n",
"\n",
"Past dynamic covariates are covariates that are only available for the context. For instance in our example `system_load` is a past dynamic covariate. Time series models generally can handle this, however it is something the batched in context regression cannot address, because these regressors are not available in the future. If you do have those covariates and consider them very meaningful, there are two hacky options to try immediately:\n",
"\n",
"1. Shift and repeat these past dynamic covariates to use their delayed version. For example, if you think the `system_load` for this week is meaningful for forecasting next week, you can create a `delay_7_system_load` by shifting 7 timestamps and use this as one dynamic numerical covariate for TimesFM.\n",
"2. Bootstrap, that is to run TimesFM once to forecast these past dynamic covariates into the horizon, then call TimesFM again using these forecasts as the future part for these dynamic covariates.\n",
"\n",
"### Multivariate Time Series\n",
"\n",
"For multivariate time series, if we need univariate forecast, we can try treating the main time series as the target and use the rest as the dynamic covariates."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "chronos-v2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: v1/notebooks/finetuning.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Importing relevant packages for finetuning"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'\n",
"os.environ['JAX_PMAP_USE_TENSORSTORE'] = 'false'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import timesfm\n",
"import gc\n",
"import numpy as np\n",
"import pandas as pd\n",
"from timesfm import patched_decoder\n",
"from timesfm import data_loader"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"import dataclasses\n",
"import IPython\n",
"import IPython.display\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"mpl.rcParams['figure.figsize'] = (8, 6)\n",
"mpl.rcParams['axes.grid'] = False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading TimesFM pretrained checkpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"timesfm_backend = \"gpu\" # @param\n",
"\n",
"tfm = timesfm.TimesFm(\n",
" hparams=timesfm.TimesFmHparams(\n",
" backend=timesfm_backend,\n",
" per_core_batch_size=32,\n",
" horizon_len=128,\n",
" num_layers=50,\n",
" # Se this to True for v1.0 checkpoints\n",
" use_positional_embedding=False,\n",
" # Note that we could set this to as high as 2048 but keeping it 512 here so that\n",
" # both v1.0 and 2.0 checkpoints work\n",
" context_len=512,\n",
" ),\n",
" checkpoint=timesfm.TimesFmCheckpoint(\n",
" huggingface_repo_id=\"google/timesfm-2.0-500m-jax\"),\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluating pretrained checkpoint on ETT datasets"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"DATA_DICT = {\n",
" \"ettm2\": {\n",
" \"boundaries\": [34560, 46080, 57600],\n",
" \"data_path\": \"../datasets/ETT-small/ETTm2.csv\",\n",
" \"freq\": \"15min\",\n",
" },\n",
" \"ettm1\": {\n",
" \"boundaries\": [34560, 46080, 57600],\n",
" \"data_path\": \"../datasets/ETT-small/ETTm1.csv\",\n",
" \"freq\": \"15min\",\n",
" },\n",
" \"etth2\": {\n",
" \"boundaries\": [8640, 11520, 14400],\n",
" \"data_path\": \"../datasets/ETT-small/ETTh2.csv\",\n",
" \"freq\": \"H\",\n",
" },\n",
" \"etth1\": {\n",
" \"boundaries\": [8640, 11520, 14400],\n",
" \"data_path\": \"../datasets/ETT-small/ETTh1.csv\",\n",
" \"freq\": \"H\",\n",
" },\n",
" \"elec\": {\n",
" \"boundaries\": [18413, 21044, 26304],\n",
" \"data_path\": \"../datasets/electricity/electricity.csv\",\n",
" \"freq\": \"H\",\n",
" },\n",
" \"traffic\": {\n",
" \"boundaries\": [12280, 14036, 17544],\n",
" \"data_path\": \"../datasets/traffic/traffic.csv\",\n",
" \"freq\": \"H\",\n",
" },\n",
" \"weather\": {\n",
" \"boundaries\": [36887, 42157, 52696],\n",
" \"data_path\": \"../datasets/weather/weather.csv\",\n",
" \"freq\": \"10min\",\n",
" },\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"dataset = \"ettm1\"\n",
"data_path = DATA_DICT[dataset][\"data_path\"]\n",
"freq = DATA_DICT[dataset][\"freq\"]\n",
"int_freq = timesfm.freq_map(freq)\n",
"boundaries = DATA_DICT[dataset][\"boundaries\"]\n",
"\n",
"data_df = pd.read_csv(open(data_path, \"r\"))\n",
"\n",
"\n",
"ts_cols = [col for col in data_df.columns if col != \"date\"]\n",
"num_cov_cols = None\n",
"cat_cov_cols = None\n",
"\n",
"context_len = 512\n",
"pred_len = 96\n",
"\n",
"num_ts = len(ts_cols)\n",
"batch_size = 8\n",
"\n",
"dtl = data_loader.TimeSeriesdata(\n",
" data_path=data_path,\n",
" datetime_col=\"date\",\n",
" num_cov_cols=num_cov_cols,\n",
" cat_cov_cols=cat_cov_cols,\n",
" ts_cols=np.array(ts_cols),\n",
" train_range=[0, boundaries[0]],\n",
" val_range=[boundaries[0], boundaries[1]],\n",
" test_range=[boundaries[1], boundaries[2]],\n",
" hist_len=context_len,\n",
" pred_len=pred_len,\n",
" batch_size=num_ts,\n",
" freq=freq,\n",
" normalize=True,\n",
" epoch_len=None,\n",
" holiday=False,\n",
" permute=True,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_batches = dtl.tf_dataset(mode=\"train\", shift=1).batch(batch_size)\n",
"val_batches = dtl.tf_dataset(mode=\"val\", shift=pred_len)\n",
"test_batches = dtl.tf_dataset(mode=\"test\", shift=pred_len)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for tbatch in tqdm(train_batches.as_numpy_iterator()):\n",
" break\n",
"print(tbatch[0].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### MAE on the test split for the pretrained TimesFM model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mae_losses = []\n",
"for batch in tqdm(test_batches.as_numpy_iterator()):\n",
" past = batch[0]\n",
" actuals = batch[3]\n",
" forecasts, _ = tfm.forecast(list(past), [0] * past.shape[0], normalize=True)\n",
" forecasts = forecasts[:, 0 : actuals.shape[1]]\n",
" mae_losses.append(np.abs(forecasts - actuals).mean())\n",
"\n",
"print(f\"MAE: {np.mean(mae_losses)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Finetuning the model on the ETT dataset"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from jax import numpy as jnp\n",
"from praxis import pax_fiddle\n",
"from praxis import py_utils\n",
"from praxis import pytypes\n",
"from praxis import base_model\n",
"from praxis import optimizers\n",
"from praxis import schedules\n",
"from praxis import base_hyperparams\n",
"from praxis import base_layer\n",
"from paxml import tasks_lib\n",
"from paxml import trainer_lib\n",
"from paxml import checkpoints\n",
"from paxml import learners\n",
"from paxml import partitioning\n",
"from paxml import checkpoint_types"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# PAX shortcuts\n",
"NestedMap = py_utils.NestedMap\n",
"WeightInit = base_layer.WeightInit\n",
"WeightHParams = base_layer.WeightHParams\n",
"InstantiableParams = py_utils.InstantiableParams\n",
"JTensor = pytypes.JTensor\n",
"NpTensor = pytypes.NpTensor\n",
"WeightedScalars = pytypes.WeightedScalars\n",
"instantiate = base_hyperparams.instantiate\n",
"LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]\n",
"AuxLossStruct = base_layer.AuxLossStruct\n",
"\n",
"AUX_LOSS = base_layer.AUX_LOSS\n",
"template_field = base_layer.template_field\n",
"\n",
"# Standard prng key names\n",
"PARAMS = base_layer.PARAMS\n",
"RANDOM = base_layer.RANDOM\n",
"\n",
"key = jax.random.PRNGKey(seed=1234)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"model = pax_fiddle.Config(\n",
" patched_decoder.PatchedDecoderFinetuneModel,\n",
" name='patched_decoder_finetune',\n",
" core_layer_tpl=tfm.model_p,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### We will hold the transformer layers fixed while finetuning, while training all other components."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"@pax_fiddle.auto_config\n",
"def build_learner() -> learners.Learner:\n",
" return pax_fiddle.Config(\n",
" learners.Learner,\n",
" name='learner',\n",
" loss_name='avg_qloss',\n",
" optimizer=optimizers.Adam(\n",
" epsilon=1e-7,\n",
" clip_threshold=1e2,\n",
" learning_rate=1e-2,\n",
" lr_schedule=pax_fiddle.Config(\n",
" schedules.Cosine,\n",
" initial_value=1e-3,\n",
" final_value=1e-4,\n",
" total_steps=40000,\n",
" ),\n",
" ema_decay=0.9999,\n",
" ),\n",
" # Linear probing i.e we hold the transformer layers fixed.\n",
" bprop_variable_exclusion=['.*/stacked_transformer_layer/.*'],\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"task_p = tasks_lib.SingleTask(\n",
" name='ts-learn',\n",
" model=model,\n",
" train=tasks_lib.SingleTask.Train(\n",
" learner=build_learner(),\n",
" ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"task_p.model.ici_mesh_shape = [1, 1, 1]\n",
"task_p.model.mesh_axis_names = ['replica', 'data', 'mdl']\n",
"\n",
"DEVICES = np.array(jax.devices()).reshape([1, 1, 1])\n",
"MESH = jax.sharding.Mesh(DEVICES, ['replica', 'data', 'mdl'])\n",
"\n",
"num_devices = jax.local_device_count()\n",
"print(f'num_devices: {num_devices}')\n",
"print(f'device kind: {jax.local_devices()[0].device_kind}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"jax_task = task_p\n",
"key, init_key = jax.random.split(key)\n",
"\n",
"# To correctly prepare a batch of data for model initialization (now that shape\n",
"# inference is merged), we take one devices*batch_size tensor tuple of data,\n",
"# slice out just one batch, then run the prepare_input_batch function over it.\n",
"\n",
"\n",
"def process_train_batch(batch):\n",
" past_ts = batch[0].reshape(batch_size * num_ts, -1)\n",
" actual_ts = batch[3].reshape(batch_size * num_ts, -1)\n",
" return NestedMap(input_ts=past_ts, actual_ts=actual_ts)\n",
"\n",
"\n",
"def process_eval_batch(batch):\n",
" past_ts = batch[0]\n",
" actual_ts = batch[3]\n",
" return NestedMap(input_ts=past_ts, actual_ts=actual_ts)\n",
"\n",
"\n",
"jax_model_states, _ = trainer_lib.initialize_model_state(\n",
" jax_task,\n",
" init_key,\n",
" process_train_batch(tbatch),\n",
" checkpoint_type=checkpoint_types.CheckpointType.GDA,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setting the initial model weights to the pretrained TimesFM parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"jax_model_states.mdl_vars['params']['core_layer'] = tfm._train_state.mdl_vars['params']\n",
"jax_vars = jax_model_states.mdl_vars\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training loop"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"jax_task = task_p\n",
"\n",
"\n",
"def train_step(states, prng_key, inputs):\n",
" return trainer_lib.train_step_single_learner(\n",
" jax_task, states, prng_key, inputs\n",
" )\n",
"\n",
"\n",
"def eval_step(states, prng_key, inputs):\n",
" states = states.to_eval_state()\n",
" return trainer_lib.eval_step_single_learner(\n",
" jax_task, states, prng_key, inputs\n",
" )\n",
"\n",
"key, train_key, eval_key = jax.random.split(key, 3)\n",
"train_prng_seed = jax.random.split(train_key, num=jax.local_device_count())\n",
"eval_prng_seed = jax.random.split(eval_key, num=jax.local_device_count())\n",
"\n",
"p_train_step = jax.pmap(train_step, axis_name='batch')\n",
"p_eval_step = jax.pmap(eval_step, axis_name='batch')"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"replicated_jax_states = trainer_lib.replicate_model_state(jax_model_states)\n",
"replicated_jax_vars = replicated_jax_states.mdl_vars"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"best_eval_loss = 1e7\n",
"step_count = 0\n",
"patience = 0\n",
"NUM_EPOCHS = 100\n",
"PATIENCE = 5\n",
"TRAIN_STEPS_PER_EVAL = 1000\n",
"CHECKPOINT_DIR='/home/senrajat_google_com/ettm1_finetune'"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def reshape_batch_for_pmap(batch, num_devices):\n",
" def _reshape(input_tensor):\n",
" bsize = input_tensor.shape[0]\n",
" residual_shape = list(input_tensor.shape[1:])\n",
" nbsize = bsize // num_devices\n",
" return jnp.reshape(input_tensor, [num_devices, nbsize] + residual_shape)\n",
"\n",
" return jax.tree.map(_reshape, batch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for epoch in range(NUM_EPOCHS):\n",
" print(f\"__________________Epoch: {epoch}__________________\", flush=True)\n",
" train_its = train_batches.as_numpy_iterator()\n",
" if patience >= PATIENCE:\n",
" print(\"Early stopping.\", flush=True)\n",
" break\n",
" for batch in tqdm(train_its):\n",
" train_losses = []\n",
" if patience >= PATIENCE:\n",
" print(\"Early stopping.\", flush=True)\n",
" break\n",
" tbatch = process_train_batch(batch)\n",
" tbatch = reshape_batch_for_pmap(tbatch, num_devices)\n",
" replicated_jax_states, step_fun_out = p_train_step(\n",
" replicated_jax_states, train_prng_seed, tbatch\n",
" )\n",
" train_losses.append(step_fun_out.loss[0])\n",
" if step_count % TRAIN_STEPS_PER_EVAL == 0:\n",
" print(\n",
" f\"Train loss at step {step_count}: {np.mean(train_losses)}\",\n",
" flush=True,\n",
" )\n",
" train_losses = []\n",
" print(\"Starting eval.\", flush=True)\n",
" val_its = val_batches.as_numpy_iterator()\n",
" eval_losses = []\n",
" for ev_batch in tqdm(val_its):\n",
" ebatch = process_eval_batch(ev_batch)\n",
" ebatch = reshape_batch_for_pmap(ebatch, num_devices)\n",
" _, step_fun_out = p_eval_step(\n",
" replicated_jax_states, eval_prng_seed, ebatch\n",
" )\n",
" eval_losses.append(step_fun_out.loss[0])\n",
" mean_loss = np.mean(eval_losses)\n",
" print(f\"Eval loss at step {step_count}: {mean_loss}\", flush=True)\n",
" if mean_loss < best_eval_loss or np.isnan(mean_loss):\n",
" best_eval_loss = mean_loss\n",
" print(\"Saving checkpoint.\")\n",
" jax_state_for_saving = py_utils.maybe_unreplicate_for_fully_replicated(\n",
" replicated_jax_states\n",
" )\n",
" checkpoints.save_checkpoint(\n",
" jax_state_for_saving, CHECKPOINT_DIR, overwrite=True\n",
" )\n",
" patience = 0\n",
" del jax_state_for_saving\n",
" gc.collect()\n",
" else:\n",
" patience += 1\n",
" print(f\"patience: {patience}\")\n",
" step_count += 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading and evaluating the best (according to validation loss) finetuned checkpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_state = checkpoints.restore_checkpoint(jax_model_states, CHECKPOINT_DIR)\n",
"print(train_state.step)\n",
"tfm._train_state.mdl_vars['params'] = train_state.mdl_vars['params']['core_layer']\n",
"tfm.jit_decode()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mae_losses = []\n",
"for batch in tqdm(test_batches.as_numpy_iterator()):\n",
" past = batch[0]\n",
" actuals = batch[3]\n",
" _, forecasts = tfm.forecast(list(past), [0] * past.shape[0])\n",
" forecasts = forecasts[:, 0 : actuals.shape[1], 5]\n",
" mae_losses.append(np.abs(forecasts - actuals).mean())\n",
"\n",
"print(f\"MAE: {np.mean(mae_losses)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## There is around a __7%__ reduction in MAE from finetuning."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "chronos-v2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: v1/notebooks/finetuning_torch.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Introduction\n",
"This notebook shows how to use TimesFM with finetuning. \n",
"\n",
"In order to perform finetuning, you need to create the Pytorch Dataset in a proper format. The example of the Dataset is provided below.\n",
"The finetuning code can be found in timesfm.finetuning_torch.py. This notebook just imports the methods from finetuning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dataset Creation"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.\n",
"Loaded Jax TimesFM.\n",
"Loaded PyTorch TimesFM.\n"
]
}
],
"source": [
"from os import path\n",
"from typing import Optional, Tuple\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"import torch.multiprocessing as mp\n",
"import yfinance as yf\n",
"from finetuning.finetuning_torch import FinetuningConfig, TimesFMFinetuner\n",
"from huggingface_hub import snapshot_download\n",
"from torch.utils.data import Dataset\n",
"\n",
"from timesfm import TimesFm, TimesFmCheckpoint, TimesFmHparams\n",
"from timesfm.pytorch_patched_decoder import PatchedTimeSeriesDecoder\n",
"import os\n",
"\n",
"\n",
"class TimeSeriesDataset(Dataset):\n",
" \"\"\"Dataset for time series data compatible with TimesFM.\"\"\"\n",
"\n",
" def __init__(self,\n",
" series: np.ndarray,\n",
" context_length: int,\n",
" horizon_length: int,\n",
" freq_type: int = 0):\n",
" \"\"\"\n",
" Initialize dataset.\n",
"\n",
" Args:\n",
" series: Time series data\n",
" context_length: Number of past timesteps to use as input\n",
" horizon_length: Number of future timesteps to predict\n",
" freq_type: Frequency type (0, 1, or 2)\n",
" \"\"\"\n",
" if freq_type not in [0, 1, 2]:\n",
" raise ValueError(\"freq_type must be 0, 1, or 2\")\n",
"\n",
" self.series = series\n",
" self.context_length = context_length\n",
" self.horizon_length = horizon_length\n",
" self.freq_type = freq_type\n",
" self._prepare_samples()\n",
"\n",
" def _prepare_samples(self) -> None:\n",
" \"\"\"Prepare sliding window samples from the time series.\"\"\"\n",
" self.samples = []\n",
" total_length = self.context_length + self.horizon_length\n",
"\n",
" for start_idx in range(0, len(self.series) - total_length + 1):\n",
" end_idx = start_idx + self.context_length\n",
" x_context = self.series[start_idx:end_idx]\n",
" x_future = self.series[end_idx:end_idx + self.horizon_length]\n",
" self.samples.append((x_context, x_future))\n",
"\n",
" def __len__(self) -> int:\n",
" return len(self.samples)\n",
"\n",
" def __getitem__(\n",
" self, index: int\n",
" ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n",
" x_context, x_future = self.samples[index]\n",
"\n",
" x_context = torch.tensor(x_context, dtype=torch.float32)\n",
" x_future = torch.tensor(x_future, dtype=torch.float32)\n",
"\n",
" input_padding = torch.zeros_like(x_context)\n",
" freq = torch.tensor([self.freq_type], dtype=torch.long)\n",
"\n",
" return x_context, input_padding, freq, x_future\n",
"\n",
"def prepare_datasets(series: np.ndarray,\n",
" context_length: int,\n",
" horizon_length: int,\n",
" freq_type: int = 0,\n",
" train_split: float = 0.8) -> Tuple[Dataset, Dataset]:\n",
" \"\"\"\n",
" Prepare training and validation datasets from time series data.\n",
"\n",
" Args:\n",
" series: Input time series data\n",
" context_length: Number of past timesteps to use\n",
" horizon_length: Number of future timesteps to predict\n",
" freq_type: Frequency type (0, 1, or 2)\n",
" train_split: Fraction of data to use for training\n",
"\n",
" Returns:\n",
" Tuple of (train_dataset, val_dataset)\n",
" \"\"\"\n",
" train_size = int(len(series) * train_split)\n",
" train_data = series[:train_size]\n",
" val_data = series[train_size:]\n",
"\n",
" # Create datasets with specified frequency type\n",
" train_dataset = TimeSeriesDataset(train_data,\n",
" context_length=context_length,\n",
" horizon_length=horizon_length,\n",
" freq_type=freq_type)\n",
"\n",
" val_dataset = TimeSeriesDataset(val_data,\n",
" context_length=context_length,\n",
" horizon_length=horizon_length,\n",
" freq_type=freq_type)\n",
"\n",
" return train_dataset, val_dataset\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model Creation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def get_model(load_weights: bool = False):\n",
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
" repo_id = \"google/timesfm-2.0-500m-pytorch\"\n",
" hparams = TimesFmHparams(\n",
" backend=device,\n",
" per_core_batch_size=32,\n",
" horizon_len=128,\n",
" num_layers=50,\n",
" use_positional_embedding=False,\n",
" context_len=\n",
" 192, # Context length can be anything up to 2048 in multiples of 32\n",
" )\n",
" tfm = TimesFm(hparams=hparams,\n",
" checkpoint=TimesFmCheckpoint(huggingface_repo_id=repo_id))\n",
"\n",
" model = PatchedTimeSeriesDecoder(tfm._model_config)\n",
" if load_weights:\n",
" checkpoint_path = path.join(snapshot_download(repo_id), \"torch_model.ckpt\")\n",
" loaded_checkpoint = torch.load(checkpoint_path, weights_only=True)\n",
" model.load_state_dict(loaded_checkpoint)\n",
" return model, hparams, tfm._model_config\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def plot_predictions(\n",
" model: TimesFm,\n",
" val_dataset: Dataset,\n",
" save_path: Optional[str] = \"predictions.png\",\n",
") -> None:\n",
" \"\"\"\n",
" Plot model predictions against ground truth for a batch of validation data.\n",
"\n",
" Args:\n",
" model: Trained TimesFM model\n",
" val_dataset: Validation dataset\n",
" save_path: Path to save the plot\n",
" \"\"\"\n",
" import matplotlib.pyplot as plt\n",
"\n",
" model.eval()\n",
"\n",
" x_context, x_padding, freq, x_future = val_dataset[0]\n",
" x_context = x_context.unsqueeze(0) # Add batch dimension\n",
" x_padding = x_padding.unsqueeze(0)\n",
" freq = freq.unsqueeze(0)\n",
" x_future = x_future.unsqueeze(0)\n",
"\n",
" device = next(model.parameters()).device\n",
" x_context = x_context.to(device)\n",
" x_padding = x_padding.to(device)\n",
" freq = freq.to(device)\n",
" x_future = x_future.to(device)\n",
"\n",
" with torch.no_grad():\n",
" predictions = model(x_context, x_padding.float(), freq)\n",
" predictions_mean = predictions[..., 0] # [B, N, horizon_len]\n",
" last_patch_pred = predictions_mean[:, -1, :] # [B, horizon_len]\n",
"\n",
" context_vals = x_context[0].cpu().numpy()\n",
" future_vals = x_future[0].cpu().numpy()\n",
" pred_vals = last_patch_pred[0].cpu().numpy()\n",
"\n",
" context_len = len(context_vals)\n",
" horizon_len = len(future_vals)\n",
"\n",
" plt.figure(figsize=(12, 6))\n",
"\n",
" plt.plot(range(context_len),\n",
" context_vals,\n",
" label=\"Historical Data\",\n",
" color=\"blue\",\n",
" linewidth=2)\n",
"\n",
" plt.plot(\n",
" range(context_len, context_len + horizon_len),\n",
" future_vals,\n",
" label=\"Ground Truth\",\n",
" color=\"green\",\n",
" linestyle=\"--\",\n",
" linewidth=2,\n",
" )\n",
"\n",
" plt.plot(range(context_len, context_len + horizon_len),\n",
" pred_vals,\n",
" label=\"Prediction\",\n",
" color=\"red\",\n",
" linewidth=2)\n",
"\n",
" plt.xlabel(\"Time Step\")\n",
" plt.ylabel(\"Value\")\n",
" plt.title(\"TimesFM Predictions vs Ground Truth\")\n",
" plt.legend()\n",
" plt.grid(True)\n",
"\n",
" if save_path:\n",
" plt.savefig(save_path)\n",
" print(f\"Plot saved to {save_path}\")\n",
"\n",
" plt.close()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def get_data(context_len: int,\n",
" horizon_len: int,\n",
" freq_type: int = 0) -> Tuple[Dataset, Dataset]:\n",
" df = yf.download(\"AAPL\", start=\"2010-01-01\", end=\"2019-01-01\")\n",
" time_series = df[\"Close\"].values\n",
"\n",
" train_dataset, val_dataset = prepare_datasets(\n",
" series=time_series,\n",
" context_length=context_len,\n",
" horizon_length=horizon_len,\n",
" freq_type=freq_type,\n",
" train_split=0.8,\n",
" )\n",
"\n",
" print(f\"Created datasets:\")\n",
" print(f\"- Training samples: {len(train_dataset)}\")\n",
" print(f\"- Validation samples: {len(val_dataset)}\")\n",
" print(f\"- Using frequency type: {freq_type}\")\n",
" return train_dataset, val_dataset\n",
"\n",
"\n",
"\n",
"def single_gpu_example():\n",
" \"\"\"Basic example of finetuning TimesFM on stock data.\"\"\"\n",
" model, hparams, tfm_config = get_model(load_weights=True)\n",
" config = FinetuningConfig(batch_size=256,\n",
" num_epochs=5,\n",
" learning_rate=1e-4,\n",
" use_wandb=True,\n",
" freq_type=1,\n",
" log_every_n_steps=10,\n",
" val_check_interval=0.5,\n",
" use_quantile_loss=True)\n",
"\n",
" train_dataset, val_dataset = get_data(128,\n",
" tfm_config.horizon_len,\n",
" freq_type=config.freq_type)\n",
" finetuner = TimesFMFinetuner(model, config)\n",
"\n",
" print(\"\\nStarting finetuning...\")\n",
" results = finetuner.finetune(train_dataset=train_dataset,\n",
" val_dataset=val_dataset)\n",
"\n",
" print(\"\\nFinetuning completed!\")\n",
" print(f\"Training history: {len(results['history']['train_loss'])} epochs\")\n",
"\n",
" plot_predictions(\n",
" model=model,\n",
" val_dataset=val_dataset,\n",
" save_path=\"timesfm_predictions.png\",\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ac84aeda3a1749ae8f30b06859067bb1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 3 files: 0%| | 0/3 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6d9d8081fc514c6d8601a2e0e63954a2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 3 files: 0%| | 0/3 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[*********************100%***********************] 1 of 1 completed\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Created datasets:\n",
"- Training samples: 1556\n",
"- Validation samples: 198\n",
"- Using frequency type: 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmishacamry\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
]
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.19.1"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in /home/chertushkin/forks/timesfm/notebooks/wandb/run-20250217_114343-tjs63ml2"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run chocolate-eon-50 to Weights & Biases (docs) "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at https://wandb.ai/mishacamry/timesfm-finetuning"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at https://wandb.ai/mishacamry/timesfm-finetuning/runs/tjs63ml2"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Starting finetuning...\n"
]
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"
Run history:
epoch
▁▃▅▆█
learning_rate
▁▁▁▁▁
train_loss
█▃▂▁▁
val_loss
█▁▄▁▂
Run summary:
epoch
5
learning_rate
0.0001
train_loss
2.85423
val_loss
26.7628
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run chocolate-eon-50 at: https://wandb.ai/mishacamry/timesfm-finetuning/runs/tjs63ml2 View project at: https://wandb.ai/mishacamry/timesfm-finetuning Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Find logs at: ./wandb/run-20250217_114343-tjs63ml2/logs"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Finetuning completed!\n",
"Training history: 5 epochs\n",
"Plot saved to timesfm_predictions.png\n"
]
}
],
"source": [
"single_gpu_example()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "timesfm-DnAbSweh-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: v1/peft/README.md
================================================
# Fine-Tuning Pipeline
This folder contains a generic fine-tuning pipeline designed to support multiple PEFT fine-tuning strategies.
## Features
- **Supported Fine-Tuning Strategies**:
- **Full Fine-Tuning**: Adjusts all parameters of the model during training.
- **[Linear Probing](https://arxiv.org/abs/2302.11939)**: Fine-tunes only the residual blocks and the embedding layer, leaving other parameters unchanged.
- **[LoRA (Low-Rank Adaptation)](https://arxiv.org/abs/2106.09685)**: A memory-efficient method that fine-tunes a small number of parameters by decomposing the weight matrices into low-rank matrices.
- **[DoRA (Directional LoRA)](https://arxiv.org/abs/2402.09353v4)**: An extension of LoRA that decomposes pre-trained weights into magnitude and direction components. It uses LoRA for directional adaptation, enhancing learning capacity and stability without additional inference overhead.
## Usage
### Fine-Tuning Script
The provided finetune.py script allows you to fine-tune a model with specific configurations. You can customize various parameters to suit your dataset and desired fine-tuning strategy.
Example Usage:
```zsh
source finetune.sh
```
This script runs the finetune.py file with a predefined set of hyperparameters for the model. You can adjust the parameters in the script as needed.
### Available Options
Run the script with the --help flag to see a full list of available options and their descriptions:
```zsh
python3 finetune.py --help
```
Script Configuration
You can modify the following key parameters directly in the finetune.sh script:
Fine-Tuning Strategy: Toggle between full fine-tuning, LoRA \[`--use-lora`\], DoRA [\[`--use-dora`\]], or Linear Probing \[`--use-linear-probing`\].
### Performance Comparison
The figure below compares the performance of LoRA/DoRA against Linear Probing under the following conditions:
- Training data split: 60% train, 20% validation, 20% test.
- Benchmark: context_len=128, horizon_len=96
- Fine-tuning: context_len=128, horizon_len=128
- Black: Best result.
- Blue: Second best result.
================================================
FILE: v1/peft/finetune.py
================================================
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Finetune pipeline.
"""
import gc
import logging
import warnings
from datetime import datetime
from typing import Tuple
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import typer
import wandb
from jax import numpy as jnp
from paxml import checkpoint_types, checkpoints, learners, tasks_lib, trainer_lib
from praxis import optimizers, pax_fiddle, py_utils, schedules
from rich import print
from tqdm import tqdm
from typing_extensions import Annotated
from adapter.utils import get_adapter_params, load_adapter_layer
from timesfm import TimesFm, data_loader, patched_decoder
NestedMap = py_utils.NestedMap
warnings.filterwarnings("ignore")
cmdstanpy_logger = logging.getLogger("cmdstanpy")
absl_logger = logging.getLogger("absl")
cmdstanpy_logger.disabled = True
absl_logger.disabled = True
"""
TimesFM model config. These are fixed since pre-training was done
with this configuration.
"""
INPUT_PATCH_LEN = 32
OUTPUT_PATCH_LEN = 128
NUM_LAYERS = 20
MODEL_DIMS = 1280
QUANTILES = list(np.arange(1, 10) / 10.0)
EPS = 1e-7
RANDOM_SEED = 1234
def finetune(
*,
model_name: Annotated[
str, typer.Option(help="Specify the name of the huggingface model.")
] = "google/timesfm-1.0-200m",
checkpoint_path: Annotated[
str, typer.Option(help="The path to the local model checkpoint.")
] = None,
datetime_col: Annotated[str, typer.Option(help="Column having datetime.")] = "ds",
ts_cols: Annotated[
list[str], typer.Option(help="Columns of time-series features.")
] = [],
normalize: Annotated[
bool, typer.Option(help="Normalize data for eval or not")
] = True,
context_len: Annotated[int, typer.Option(help="Length of the context window")],
horizon_len: Annotated[int, typer.Option(help="Prediction length.")],
freq: Annotated[
str,
typer.Option(
...,
help="Frequency Map Str",
),
],
data_path: Annotated[str, typer.Option(help="Path to dataset csv")],
boundaries: Annotated[
Tuple[int, int, int],
typer.Option(
help="boundaries of dataset to train, val, test",
),
] = (0, 0, 0),
backend: Annotated[str, typer.Option(help="Backend device: cpu, gpu, tpu")],
batch_size: Annotated[
int, typer.Option(help="Batch size for the randomly sampled batch")
] = 16,
num_epochs: Annotated[int, typer.Option(help="Number of epochs")],
learning_rate: Annotated[float, typer.Option(help="adam optimizer learning rate")],
adam_epsilon: Annotated[float, typer.Option(help="adam optimizer epsilon")],
adam_clip_threshold: Annotated[
float, typer.Option(help="adam optimizer clip threshold")
],
cos_initial_decay_value: Annotated[
float, typer.Option(help="cosine initial decay value")
],
cos_final_decay_value: Annotated[
float, typer.Option(help="cosine final decay value")
],
cos_decay_steps: Annotated[int, typer.Option(help="Number of cosine decay steps")],
ema_decay: Annotated[float, typer.Option(help="Exponential moving average decay")],
early_stop_patience: Annotated[
int, typer.Option(..., help="Early stopping patience")
] = 5,
use_lora: Annotated[
bool,
typer.Option(
help="Train low rank adapters for stacked transformer block",
),
] = False,
lora_rank: Annotated[
int,
typer.Option(
help="LoRA Rank",
),
] = 8,
lora_target_modules: Annotated[
str,
typer.Option(
help="LoRA target modules of the transformer block. Allowed values: [all, attention, mlp]"
),
] = "all",
use_dora: Annotated[
bool,
typer.Option(
help="Apply DoRA strategy along with LoRA.",
),
] = False,
use_linear_probing: Annotated[
bool,
typer.Option(
help="Linear Probing. Train only input/output and embedding params. Freeze params in stack transformer block.",
),
] = False,
checkpoint_dir: Annotated[
str, typer.Option(help="Checkpoint directory")
] = "./checkpoints",
wandb_project: Annotated[
str, typer.Option(help="Weights & Biases project name")
] = "google_timesfm_finetune",
) -> None:
key = jax.random.PRNGKey(seed=RANDOM_SEED)
wandb.init(project=wandb_project, config=locals())
data_df = pd.read_csv(open(data_path, "r"))
if boundaries == (0, 0, 0):
# Default boundaries: train 60%, val 20%, test 20%
boundaries = [
int(len(data_df) * 0.6),
int(len(data_df) * 0.8),
len(data_df) - 1,
]
ts_cols = [col for col in data_df.columns if col != datetime_col]
dtl = data_loader.TimeSeriesdata(
data_path=data_path,
datetime_col=datetime_col,
num_cov_cols=None,
cat_cov_cols=None,
ts_cols=np.array(ts_cols),
train_range=[0, boundaries[0]],
val_range=[boundaries[0], boundaries[1]],
test_range=[boundaries[1], boundaries[2]],
hist_len=context_len,
pred_len=horizon_len,
batch_size=batch_size,
freq=freq,
normalize=normalize,
epoch_len=None,
holiday=False,
permute=False,
)
train_batches = dtl.tf_dataset(mode="train", shift=1).batch(batch_size)
val_batches = dtl.tf_dataset(mode="val", shift=horizon_len)
for tbatch in tqdm(train_batches.as_numpy_iterator()):
pass
tfm = TimesFm(
context_len=context_len,
horizon_len=horizon_len,
input_patch_len=INPUT_PATCH_LEN,
output_patch_len=OUTPUT_PATCH_LEN,
num_layers=NUM_LAYERS,
model_dims=MODEL_DIMS,
backend=backend,
per_core_batch_size=batch_size,
quantiles=QUANTILES,
)
if checkpoint_path:
tfm.load_from_checkpoint(
checkpoint_path=checkpoint_path,
checkpoint_type=checkpoints.CheckpointType.FLAX,
)
else:
tfm.load_from_checkpoint(
repo_id=model_name,
checkpoint_type=checkpoints.CheckpointType.FLAX,
)
model = pax_fiddle.Config(
patched_decoder.PatchedDecoderFinetuneModel,
name="patched_decoder_finetune",
core_layer_tpl=tfm.model_p,
)
if use_lora:
load_adapter_layer(
mdl_vars=tfm._train_state.mdl_vars,
model=model.core_layer_tpl,
lora_rank=lora_rank,
lora_target_modules=lora_target_modules,
use_dora=use_dora,
)
@pax_fiddle.auto_config
def build_learner() -> learners.Learner:
bprop_variable_inclusion = []
bprop_variable_exclusion = []
if use_lora:
bprop_variable_inclusion.append(r"^.*lora.*$")
if use_dora:
bprop_variable_inclusion.append(r"^.*dora.*$")
elif use_linear_probing:
bprop_variable_exclusion = [".*/stacked_transformer_layer/.*"]
return pax_fiddle.Config(
learners.Learner,
name="learner",
loss_name="avg_qloss",
optimizer=optimizers.Adam(
epsilon=adam_epsilon,
clip_threshold=adam_clip_threshold,
learning_rate=learning_rate,
lr_schedule=pax_fiddle.Config(
schedules.Cosine,
initial_value=cos_initial_decay_value,
final_value=cos_final_decay_value,
total_steps=cos_decay_steps,
),
ema_decay=ema_decay,
),
bprop_variable_exclusion=bprop_variable_exclusion,
bprop_variable_inclusion=bprop_variable_inclusion,
)
task_p = tasks_lib.SingleTask(
name="ts-learn",
model=model,
train=tasks_lib.SingleTask.Train(
learner=build_learner(),
),
)
task_p.model.ici_mesh_shape = [1, 1, 1]
task_p.model.mesh_axis_names = ["replica", "data", "mdl"]
DEVICES = np.array(jax.devices()).reshape([1, 1, 1])
jax.sharding.Mesh(DEVICES, ["replica", "data", "mdl"])
num_devices = jax.local_device_count()
print(f"num_devices: {num_devices}")
print(f"device kind: {jax.local_devices()[0].device_kind}")
jax_task = task_p
key, init_key = jax.random.split(key)
def process_train_batch(batch):
past_ts = batch[0].reshape(batch_size * len(ts_cols), -1)
actual_ts = batch[3].reshape(batch_size * len(ts_cols), -1)
return NestedMap(input_ts=past_ts, actual_ts=actual_ts)
def process_eval_batch(batch):
past_ts = batch[0]
actual_ts = batch[3]
return NestedMap(input_ts=past_ts, actual_ts=actual_ts)
jax_model_states, _ = trainer_lib.initialize_model_state(
jax_task,
init_key,
process_train_batch(tbatch),
checkpoint_type=checkpoint_types.CheckpointType.GDA,
)
jax_model_states.mdl_vars["params"]["core_layer"] = tfm._train_state.mdl_vars[
"params"
]
gc.collect()
jax_task = task_p
def train_step(states, prng_key, inputs):
return trainer_lib.train_step_single_learner(jax_task, states, prng_key, inputs)
def eval_step(states, prng_key, inputs):
states = states.to_eval_state()
return trainer_lib.eval_step_single_learner(jax_task, states, prng_key, inputs)
key, train_key, eval_key = jax.random.split(key, 3)
train_prng_seed = jax.random.split(train_key, num=jax.local_device_count())
eval_prng_seed = jax.random.split(eval_key, num=jax.local_device_count())
p_train_step = jax.pmap(train_step, axis_name="batch")
p_eval_step = jax.pmap(eval_step, axis_name="batch")
replicated_jax_states = trainer_lib.replicate_model_state(jax_model_states)
def reshape_batch_for_pmap(batch, num_devices):
def _reshape(input_tensor):
bsize = input_tensor.shape[0]
residual_shape = list(input_tensor.shape[1:])
nbsize = bsize // num_devices
return jnp.reshape(input_tensor, [num_devices, nbsize] + residual_shape)
return jax.tree.map(_reshape, batch)
patience = 0
best_eval_loss = 1e7
checkpoint_dir = f"{checkpoint_dir}/run_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}"
for epoch in range(num_epochs):
if patience >= early_stop_patience:
print("Early stopping.")
break
print(f"Epoch: {epoch + 1}")
train_its = train_batches.as_numpy_iterator()
train_losses = []
for batch in tqdm(train_its):
tbatch = process_train_batch(batch)
tbatch = reshape_batch_for_pmap(tbatch, num_devices)
replicated_jax_states, step_fun_out = p_train_step(
replicated_jax_states, train_prng_seed, tbatch
)
train_losses.append(step_fun_out.loss[0])
wandb.log({"train_step_loss": step_fun_out.loss[0]})
avg_train_loss = np.mean(train_losses)
print("Starting eval.")
val_its = val_batches.as_numpy_iterator()
eval_losses = []
for ev_batch in tqdm(val_its):
ebatch = process_eval_batch(ev_batch)
ebatch = reshape_batch_for_pmap(ebatch, num_devices)
_, step_fun_out = p_eval_step(replicated_jax_states, eval_prng_seed, ebatch)
eval_losses.append(step_fun_out.loss[0])
wandb.log({"eval_step_loss": step_fun_out.loss[0]})
avg_eval_loss = np.mean(eval_losses)
print(f"Train Loss: {avg_train_loss}, Val Loss: {avg_eval_loss}")
wandb.log(
{
"epoch": epoch + 1,
"avg_train_loss": avg_train_loss,
"avg_val_loss": avg_eval_loss,
}
)
if avg_eval_loss < best_eval_loss or np.isnan(avg_eval_loss):
best_eval_loss = avg_eval_loss
print("Saving checkpoint.")
jax_state_for_saving = py_utils.maybe_unreplicate_for_fully_replicated(
replicated_jax_states
)
if use_lora:
adapter_params = get_adapter_params(
params=jax_state_for_saving.mdl_vars,
lora_target_modules=lora_target_modules,
num_layers=NUM_LAYERS,
use_dora=use_dora,
)
jax_state_for_saving.mdl_vars["params"] = adapter_params
checkpoints.save_checkpoint(
jax_state_for_saving, checkpoint_dir, overwrite=True
)
patience = 0
del jax_state_for_saving
gc.collect()
else:
patience += 1
print(f"patience: {patience}")
print("Fine-tuning completed.")
if __name__ == "__main__":
typer.run(finetune)
================================================
FILE: v1/peft/finetune.sh
================================================
#!/bin/bash
# Script to finetune a model with specific configurations
# Adjust the parameters below as needed. For a full list of options and descriptions, run the script with the --help flag.
export TF_CPP_MIN_LOG_LEVEL=2 XLA_PYTHON_CLIENT_PREALLOCATE=false
python3 finetune.py \
--model-name="google/timesfm-1.0-200m" \
--backend="cpu" \
--horizon-len=128 \
--context-len=512 \
--freq="15min" \
--data-path="../datasets/ETT-small/ETTm1.csv" \
--num-epochs=100 \
--learning-rate=1e-3 \
--adam-epsilon=1e-7 \
--adam-clip-threshold=1e2 \
--early-stop-patience=10 \
--datetime-col="date" \
--use-lora \
--lora-rank=1 \
--lora-target-modules="all" \
--use-dora \
--cos-initial-decay-value=1e-4 \
--cos-decay-steps=40000 \
--cos-final-decay-value=1e-5 \
--ema-decay=0.9999
# To see all available options and their descriptions, use the --help flag
# python3 finetune.py --help
================================================
FILE: v1/peft/usage.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Base Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from timesfm import TimesFm, freq_map, data_loader\n",
"from adapter.utils import load_adapter_checkpoint\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"\n",
"tfm = TimesFm(\n",
" context_len=512,\n",
" horizon_len=128,\n",
" input_patch_len=32,\n",
" output_patch_len=128,\n",
" num_layers=20,\n",
" model_dims=1280,\n",
" backend=\"cpu\",\n",
")\n",
"tfm.load_from_checkpoint(repo_id=\"google/timesfm-1.0-200m\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DATA_DICT = {\n",
" \"ettm2\": {\n",
" \"boundaries\": [34560, 46080, 57600],\n",
" \"data_path\": \"../datasets/ETT-small/ETTm2.csv\",\n",
" \"freq\": \"15min\",\n",
" },\n",
" \"ettm1\": {\n",
" \"boundaries\": [34560, 46080, 57600],\n",
" \"data_path\": \"../datasets/ETT-small/ETTm1.csv\",\n",
" \"freq\": \"15min\",\n",
" },\n",
" \"etth2\": {\n",
" \"boundaries\": [8640, 11520, 14400],\n",
" \"data_path\": \"../datasets/ETT-small/ETTh2.csv\",\n",
" \"freq\": \"H\",\n",
" },\n",
" \"etth1\": {\n",
" \"boundaries\": [8640, 11520, 14400],\n",
" \"data_path\": \"../datasets/ETT-small/ETTh1.csv\",\n",
" \"freq\": \"H\",\n",
" },\n",
" \"elec\": {\n",
" \"boundaries\": [18413, 21044, 26304],\n",
" \"data_path\": \"../datasets/electricity/electricity.csv\",\n",
" \"freq\": \"H\",\n",
" },\n",
" \"traffic\": {\n",
" \"boundaries\": [12280, 14036, 17544],\n",
" \"data_path\": \"../datasets/traffic/traffic.csv\",\n",
" \"freq\": \"H\",\n",
" },\n",
" \"weather\": {\n",
" \"boundaries\": [36887, 42157, 52696],\n",
" \"data_path\": \"../datasets/weather/weather.csv\",\n",
" \"freq\": \"10min\",\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Adapter Checkpoint\n",
"\n",
"Specify the adapter checkpoint path, rank and the modules used to train the adapters and whether dora was employed or not."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"load_adapter_checkpoint(\n",
" model=tfm,\n",
" adapter_checkpoint_path=\"./checkpoints/run_20240716_163900_lyo4psz3\",\n",
" lora_rank=1,\n",
" lora_target_modules=\"all\",\n",
" use_dora=True,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test Performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = \"ettm1\"\n",
"data_path = DATA_DICT[dataset][\"data_path\"]\n",
"freq = DATA_DICT[dataset][\"freq\"]\n",
"int_freq = freq_map(freq)\n",
"boundaries = DATA_DICT[dataset][\"boundaries\"]\n",
"\n",
"data_df = pd.read_csv(open(data_path, \"r\"))\n",
"\n",
"ts_cols = [col for col in data_df.columns if col != \"date\"]\n",
"num_cov_cols = None\n",
"cat_cov_cols = None\n",
"\n",
"context_len = 512\n",
"pred_len = 96\n",
"\n",
"num_ts = len(ts_cols)\n",
"batch_size = 16\n",
"\n",
"dtl = data_loader.TimeSeriesdata(\n",
" data_path=data_path,\n",
" datetime_col=\"date\",\n",
" num_cov_cols=num_cov_cols,\n",
" cat_cov_cols=cat_cov_cols,\n",
" ts_cols=np.array(ts_cols),\n",
" train_range=[0, boundaries[0]],\n",
" val_range=[boundaries[0], boundaries[1]],\n",
" test_range=[boundaries[1], boundaries[2]],\n",
" hist_len=context_len,\n",
" pred_len=pred_len,\n",
" batch_size=num_ts,\n",
" freq=\"15min\",\n",
" normalize=True,\n",
" epoch_len=None,\n",
" holiday=False,\n",
" permute=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_batches = dtl.tf_dataset(mode=\"test\", shift=pred_len)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mae_losses = []\n",
"for batch in tqdm(test_batches.as_numpy_iterator()):\n",
" past = batch[0]\n",
" actuals = batch[3]\n",
" _, forecasts = tfm.forecast(list(past), [0] * past.shape[0])\n",
" forecasts = forecasts[:, 0 : actuals.shape[1], 5]\n",
" mae_losses.append(np.abs(forecasts - actuals).mean())\n",
"\n",
"print(f\"MAE: {np.mean(mae_losses)}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tanmay_tfm_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: v1/pyproject.toml
================================================
[tool.poetry]
name = "timesfm"
packages = [
{ include = "timesfm", from = "src" },
{ include = "finetuning", from = "src" },
]
description = "Open weights time-series foundation model from Google Research."
version = "1.3.0"
authors = [
"Rajat Sen ",
"Yichen Zhou ",
"Abhimanyu Das ",
"Petros Mol ",
"Justin Güse ",
"Michael Chertushkin "
]
readme = "README.md"
keywords = ["time series", "timesfm", "forecast", "time series model"]
homepage = "https://github.com/google-research/timesfm"
repository = "https://github.com/google-research/timesfm"
classifiers = [
"Environment :: Console",
"Framework :: Flake8",
"Operating System :: OS Independent",
"Topic :: Software Development :: Documentation",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Software Development :: Quality Assurance",
]
include = ["LICENSE"]
[tool.poetry.dependencies]
python = ">=3.10,<3.12"
einshape = ">=1.0.0"
numpy = ">=1.26.4"
pandas = ">=2.0.0"
utilsforecast = ">=0.1.10"
huggingface_hub = { version = ">=0.23.0", extras = ["cli"] }
scikit-learn = ">=1.2.2"
typer = ">=0.12.3"
wandb = ">=0.17.5"
absl-py = ">=1.4.0"
safetensors = "^0.5.3"
[tool.poetry.extras]
# Note: `lingvo` is an optional Google-internal dependency with strict Python
# version and packaging constraints that cause install failures on some
# environments (Colab etc.). We omit it from the pax extra here so users can
# opt-in explicitly if they need it and have a compatible environment.
pax = ["paxml", "jax", "jaxlib"]
torch = ["torch"]
[tool.poetry.dependencies.paxml]
version = ">=1.4.0"
python = ">=3.10,<3.11"
[tool.poetry.dependencies.jax]
version = ">=0.4.26"
extras = ["cuda12"]
python = ">=3.10,<3.12" # Support both python versions
[tool.poetry.dependencies.jaxlib]
version = ">=0.4.26"
python = ">=3.10,<3.12" # Support both python versions
[tool.poetry.dependencies.torch]
version = ">=2.0.0"
extras = ["cuda"]
python = ">=3.11,<3.12"
[tool.poetry.group.dev.dependencies]
pytest = ">=8.3.2"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
================================================
FILE: v1/src/adapter/__init__.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""adapter init file."""
from .dora_layers import DoraAttentionProjection, DoraCombinedQKVProjection, DoraLinear
from .lora_layers import LoraAttentionProjection, LoraCombinedQKVProjection, LoraLinear
================================================
FILE: v1/src/adapter/dora_layers.py
================================================
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax import numpy as jnp
from praxis import base_layer
from praxis.layers import attentions, linears
WeightInit = base_layer.WeightInit
WeightHParams = base_layer.WeightHParams
class DoraTheta(base_layer.Theta):
def __init__(self, module):
self.module = module
def _dora_initialized(self):
if (
self.module.has_variable("params", "lora_a")
and self.module.has_variable("params", "lora_b")
and self.module.has_variable("params", "dora_m")
and "lora_a" in self.module._weight_hparams
and "lora_b" in self.module._weight_hparams
and "dora_m" in self.module._weight_hparams
):
return True
else:
return False
def _dorafy_var(self, w):
lora_a = super().__getattr__("lora_a")
lora_b = super().__getattr__("lora_b")
dora_m = super().__getattr__("dora_m")
lora_delta = self.module.einsum("...dr,...nr->...dn", lora_a, lora_b)
lora_delta = jnp.reshape(lora_delta, w.shape)
w_prime = w + lora_delta
column_norm = jnp.linalg.norm(w_prime, ord=2, axis=0, keepdims=True)
norm_adapted = w_prime / column_norm
w_prime = dora_m * norm_adapted
return w_prime
def __getattr__(self, k):
var = super().__getattr__(k)
if not self._dora_initialized():
return var
if k == "w":
return self._dorafy_var(var)
return var
def __getitem__(self, k):
var = super().__getattr__(k)
if not self._dora_initialized():
return var
if k == "w":
return self._dorafy_var(var)
return var
class DoraThetaDescriptor:
"""Dot syntax accession descriptor."""
def __get__(self, obj, objtype=None):
return DoraTheta(obj)
class DoraLinear(linears.Linear):
rank: int = 0
lora_init: WeightInit | None = None
theta = DoraThetaDescriptor()
def setup(self) -> None:
lora_init = self.lora_init if self.lora_init else self.weight_init
super().setup()
self.create_variable(
"lora_a",
WeightHParams(
shape=[self.input_dims, self.rank],
init=lora_init,
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[None, None],
),
)
self.create_variable(
"lora_b",
WeightHParams(
shape=[self.output_dims, self.rank],
init=WeightInit.Constant(scale=0.0),
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[None, None],
),
)
self.create_variable(
"dora_m",
WeightHParams(
shape=[1, self.output_dims],
init=lora_init,
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[None, None],
),
)
class DoraAttentionProjection(attentions.AttentionProjection):
rank: int = 0
lora_init: WeightInit | None = None
theta = DoraThetaDescriptor()
def setup(self) -> None:
super().setup()
w_weight_params = self._weight_hparams["w"]
lora_init = self.lora_init if self.lora_init else w_weight_params.init
self.create_variable(
"lora_a",
WeightHParams(
shape=[self.input_dim, self.rank],
init=lora_init,
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[
None,
None,
],
),
)
self.create_variable(
"lora_b",
WeightHParams(
shape=[self.dim_per_head * self.num_heads, self.rank],
init=WeightInit.Constant(scale=0.0),
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[
None,
None,
],
),
)
self.create_variable(
"dora_m",
WeightHParams(
shape=[1, self.num_heads, self.dim_per_head],
init=lora_init,
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[None, None, None],
),
)
class DoraCombinedQKVProjection(attentions.CombinedQKVProjectionLayer):
rank: int = 0
lora_init: WeightInit | None = None
theta = DoraThetaDescriptor()
def setup(self) -> None:
super().setup()
w_weight_params = self._weight_hparams["w"]
lora_init = self.lora_init if self.lora_init else w_weight_params.init
self.create_variable(
"lora_a",
WeightHParams(
shape=[3, self.input_dim, self.rank],
init=lora_init,
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[None, None, None],
),
)
self.create_variable(
"lora_b",
WeightHParams(
shape=[3, self.dim_per_head * self.num_heads, self.rank],
init=WeightInit.Constant(scale=0.0),
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[None, None, None],
),
)
self.create_variable(
"dora_m",
WeightHParams(
shape=[3, 1, self.num_heads, self.dim_per_head],
init=lora_init,
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[None, None, None, None],
),
)
================================================
FILE: v1/src/adapter/lora_layers.py
================================================
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax import numpy as jnp
from praxis import base_layer
from praxis.layers import attentions, linears
WeightInit = base_layer.WeightInit
WeightHParams = base_layer.WeightHParams
class LoraTheta(base_layer.Theta):
def __init__(self, module):
self.module = module
def _lora_initialized(self):
if (
self.module.has_variable("params", "lora_a")
and self.module.has_variable("params", "lora_b")
and "lora_a" in self.module._weight_hparams
and "lora_b" in self.module._weight_hparams
):
return True
else:
return False
def _lorafy_var(self, w):
lora_a = super().__getattr__("lora_a")
lora_b = super().__getattr__("lora_b")
lora_delta = self.module.einsum("...dr,...nr->...dn", lora_a, lora_b)
lora_delta = jnp.reshape(lora_delta, w.shape)
w_prime = w + lora_delta
return w_prime
def __getattr__(self, k):
var = super().__getattr__(k)
if not self._lora_initialized():
return var
if k == "w":
return self._lorafy_var(var)
return var
def __getitem__(self, k):
var = super().__getattr__(k)
if not self._lora_initialized():
return var
if k == "w":
return self._lorafy_var(var)
return var
class LoraThetaDescriptor:
"""Dot syntax accession descriptor."""
def __get__(self, obj, objtype=None):
return LoraTheta(obj)
class LoraLinear(linears.Linear):
rank: int = 0
lora_init: WeightInit | None = None
theta = LoraThetaDescriptor()
def setup(self) -> None:
lora_init = self.lora_init if self.lora_init else self.weight_init
super().setup()
self.create_variable(
"lora_a",
WeightHParams(
shape=[self.input_dims, self.rank],
init=lora_init,
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[None, None],
),
)
self.create_variable(
"lora_b",
WeightHParams(
shape=[self.output_dims, self.rank],
init=WeightInit.Constant(scale=0.0),
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[None, None],
),
)
class LoraAttentionProjection(attentions.AttentionProjection):
rank: int = 0
lora_init: WeightInit | None = None
theta = LoraThetaDescriptor()
def setup(self) -> None:
super().setup()
w_weight_params = self._weight_hparams["w"]
lora_init = self.lora_init if self.lora_init else w_weight_params.init
self.create_variable(
"lora_a",
WeightHParams(
shape=[self.input_dim, self.rank],
init=lora_init,
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[
None,
None,
],
),
)
self.create_variable(
"lora_b",
WeightHParams(
shape=[self.dim_per_head * self.num_heads, self.rank],
init=WeightInit.Constant(scale=0.0),
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[
None,
None,
],
),
)
class LoraCombinedQKVProjection(attentions.CombinedQKVProjectionLayer):
rank: int = 0
lora_init: WeightInit | None = None
theta = LoraThetaDescriptor()
def setup(self) -> None:
super().setup()
w_weight_params = self._weight_hparams["w"]
lora_init = self.lora_init if self.lora_init else w_weight_params.init
self.create_variable(
"lora_a",
WeightHParams(
shape=[3, self.input_dim, self.rank],
init=lora_init,
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[None, None, None],
),
)
self.create_variable(
"lora_b",
WeightHParams(
shape=[3, self.dim_per_head * self.num_heads, self.rank],
init=WeightInit.Constant(scale=0.0),
mesh_shape=self.mesh_shape,
tensor_split_dims_mapping=[None, None, None],
),
)
================================================
FILE: v1/src/adapter/utils.py
================================================
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file provides functionality for loading and merging adapter weights
in timesfm model, specifically for LoRA and DoRA.
LoRA: https://arxiv.org/abs/2106.09685
DoRA: https://arxiv.org/abs/2402.09353v4
"""
import time
import jax
import jax.numpy as jnp
from paxml import checkpoints, tasks_lib
from paxml.train_states import TrainState
from praxis import pax_fiddle
from adapter.dora_layers import (
DoraAttentionProjection,
DoraCombinedQKVProjection,
DoraLinear,
)
from adapter.lora_layers import (
LoraAttentionProjection,
LoraCombinedQKVProjection,
LoraLinear,
)
from timesfm import TimesFm
def get_adapter_params(
params: dict, lora_target_modules: str, num_layers: int, use_dora: bool = False
) -> dict:
"""
Extracts adapter parameters from the given model parameters for saving the checkpoint.
Args:
params (dict): The full model parameters.
lora_target_modules (str): Target modules for LoRA/DoRA adaptation.
num_layers (int): Number of transformer layers.
use_dora (bool, optional): Whether DoRA was used or not. Defaults to False.
Returns:
dict: A dictionary containing the extracted adapter parameters.
"""
adapter_params = {}
for i in range(num_layers):
layer_key = f"x_layers_{i}"
adapter_params[layer_key] = {}
if lora_target_modules in ["all", "mlp"]:
for ff_layer_key in ["ffn_layer1", "ffn_layer2"]:
linear = params["params"]["core_layer"]["stacked_transformer_layer"][
layer_key
]["ff_layer"][ff_layer_key]["linear"]
lora_a = linear["lora_a"]
lora_b = linear["lora_b"]
adapter_params[layer_key][ff_layer_key] = {
"lora_a": lora_a,
"lora_b": lora_b,
}
if use_dora:
adapter_params[layer_key][ff_layer_key]["dora_m"] = linear["dora_m"]
if lora_target_modules in ["all", "attention"]:
attention = params["params"]["core_layer"]["stacked_transformer_layer"][
layer_key
]["self_attention"]
for component in ["key", "query", "value", "post"]:
lora_a = attention[component]["lora_a"]
lora_b = attention[component]["lora_b"]
adapter_params[layer_key][component] = {
"lora_a": lora_a,
"lora_b": lora_b,
}
if use_dora:
adapter_params[layer_key][component]["dora_m"] = attention[
component
]["dora_m"]
return adapter_params
def load_adapter_checkpoint(
model: TimesFm,
adapter_checkpoint_path: str,
lora_rank: int,
lora_target_modules: str,
use_dora: bool,
) -> None:
"""
Loads an adapter checkpoint and merges it with the original model weights.
Args:
model (TimesFm): The model to update.
adapter_checkpoint_path (str): Path to the adapter checkpoint.
lora_rank (int): Rank of the LoRA adaptation.
lora_target_modules (str): Target modules for adaptation.
use_dora (bool): Whether DoRA was used or not.
Returns:
None
"""
"""
currently loading and initializing the model with adapter layers first and then merging the
adapter weights to original weights and replacing the adapter layers back to original layer.
# NOTE: refactor this. there should be a better way to load the LoRA checkpoint.
"""
model._logging(f"Restoring adapter checkpoint from {adapter_checkpoint_path}.")
start_time = time.time()
original_linear_tpl, original_attn_tpl, original_combined_qkv_tpl = (
load_adapter_layer(
mdl_vars=model._train_state.mdl_vars,
model=model._model,
lora_rank=lora_rank,
lora_target_modules=lora_target_modules,
use_dora=use_dora,
)
)
var_weight_hparams = model._model.abstract_init_with_metadata(
model._get_sample_inputs(), do_eval=True
)
adapter_weight_hparams = _get_adapter_weight_params(
var_weight_hparams=var_weight_hparams,
lora_target_modules=lora_target_modules,
num_layers=model._model.stacked_transformer_params_tpl.num_layers,
use_dora=use_dora,
)
adapter_state_partition_specs = tasks_lib.create_state_partition_specs(
adapter_weight_hparams,
mesh_shape=model.mesh_shape,
mesh_axis_names=model.mesh_name,
discard_opt_states=True,
learners=None,
)
adapter_state_local_shapes = tasks_lib.create_state_unpadded_shapes(
adapter_weight_hparams,
discard_opt_states=True,
learners=None,
)
adapter_train_state = checkpoints.restore_checkpoint(
state_global_shapes=adapter_state_local_shapes,
checkpoint_dir=adapter_checkpoint_path,
checkpoint_type=checkpoints.CheckpointType.FLAX,
state_specs=adapter_state_partition_specs,
step=None,
)
# add adapter weights to the original weights
_merge_adapter_weights(
model=model,
adapter_train_state=adapter_train_state,
lora_target_modules=lora_target_modules,
num_layers=model._model.stacked_transformer_params_tpl.num_layers,
use_dora=use_dora,
)
# replace back with the original model layer
if lora_target_modules in ["all", "mlp"]:
model._model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl = (
original_linear_tpl
)
if lora_target_modules in ["all", "attention"]:
model._model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_atten_tpl.proj_tpl = (
original_attn_tpl
)
model._model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_atten_tpl.combined_qkv_proj_tpl = (
original_combined_qkv_tpl
)
model._logging(
f"Restored adapter checkpoint in {time.time() - start_time:.2f} seconds."
)
# jit compile the model
model.jit_decode()
def _merge_adapter_weights(
model: TimesFm,
adapter_train_state: TrainState,
lora_target_modules: str,
num_layers: int,
use_dora: bool,
) -> None:
"""
Merges adapter weights with the original model weights.
Args:
model (TimesFm): The model to update.
adapter_train_state (TrainState): The adapter's train state.
lora_target_modules (str): Target modules for adaptation.
num_layers (int): Number of transformer layers.
use_dora (bool): Whether DoRA was used or not.
"""
for i in range(num_layers):
layer_key = f"x_layers_{i}"
if lora_target_modules in ["all", "mlp"]:
for ff_layer_key in ["ffn_layer1", "ffn_layer2"]:
linear = model._train_state.mdl_vars["params"][
"stacked_transformer_layer"
][layer_key]["ff_layer"][ff_layer_key]["linear"]
params = adapter_train_state.mdl_vars[layer_key][ff_layer_key]
lora_a = params["lora_a"]
lora_b = params["lora_b"]
w = linear["w"]
lora_delta = jnp.einsum("...dr,...nr->...dn", lora_a, lora_b)
lora_delta = jnp.reshape(lora_delta, w.shape)
w_prime = w + lora_delta
if use_dora:
dora_m = params["dora_m"]
column_norm = jnp.linalg.norm(w_prime, ord=2, axis=0, keepdims=True)
norm_adapted = w_prime / column_norm
w_prime = dora_m * norm_adapted
linear["w"] = w_prime
del linear["dora_m"]
else:
linear["w"] = w_prime
del linear["lora_a"]
del linear["lora_b"]
if lora_target_modules in ["all", "attention"]:
attention = model._train_state.mdl_vars["params"][
"stacked_transformer_layer"
][layer_key]["self_attention"]
for component in ["key", "query", "value", "post"]:
params = adapter_train_state.mdl_vars[layer_key][component]
lora_a = params["lora_a"]
lora_b = params["lora_b"]
w = attention[component]["w"]
lora_delta = jnp.einsum("...dr,...nr->...dn", lora_a, lora_b)
lora_delta = jnp.reshape(lora_delta, w.shape)
w_prime = w + lora_delta
if use_dora:
dora_m = params["dora_m"]
column_norm = jnp.linalg.norm(w_prime, ord=2, axis=0, keepdims=True)
norm_adapted = w_prime / column_norm
w_prime = dora_m * norm_adapted
attention[component]["w"] = w_prime
del attention[component]["dora_m"]
else:
attention[component]["w"] = w_prime
del attention[component]["lora_a"]
del attention[component]["lora_b"]
def _get_adapter_weight_params(
var_weight_hparams: dict, lora_target_modules: str, num_layers: int, use_dora: bool
) -> dict:
"""
Extracts adapter weight parameters from the given variable weight hyperparameters.
Args:
var_weight_hparams (dict): Variable weight hyperparameters.
lora_target_modules (str): Target modules for adaptation.
num_layers (int): Number of transformer layers.
use_dora (bool): Whether DoRA was used or not.
Returns:
dict: A dictionary containing the extracted adapter weight parameters.
"""
adapter_params = {}
for i in range(num_layers):
layer = f"x_layers_{i}"
adapter_params[layer] = {}
if lora_target_modules in ["all", "mlp"]:
for ff_layer_key in ["ffn_layer1", "ffn_layer2"]:
adapter_weight_params = var_weight_hparams["params"][
"stacked_transformer_layer"
][layer]["ff_layer"][ff_layer_key]["linear"]
adapter_params[layer][ff_layer_key] = {
"lora_a": adapter_weight_params["lora_a"],
"lora_b": adapter_weight_params["lora_b"],
}
if use_dora:
adapter_params[layer][ff_layer_key]["dora_m"] = (
adapter_weight_params["dora_m"]
)
if lora_target_modules in ["all", "attention"]:
for component in ["key", "value", "query", "post"]:
adapter_weight_params = var_weight_hparams["params"][
"stacked_transformer_layer"
][layer]["self_attention"][component]
adapter_params[layer][component] = {
"lora_a": adapter_weight_params["lora_a"],
"lora_b": adapter_weight_params["lora_b"],
}
if use_dora:
adapter_params[layer][component]["dora_m"] = adapter_weight_params[
"dora_m"
]
return adapter_params
def load_adapter_layer(
mdl_vars: dict,
model: pax_fiddle.Config,
lora_rank: int,
lora_target_modules: str,
use_dora: bool = False,
) -> tuple[pax_fiddle.Config, pax_fiddle.Config]:
"""
Updates target modules with adapter layers.
Args:
mdl_vars (dict): Model variables.
model (pax_fiddle.Config): Model configuration.
lora_rank (int): Rank of the LoRA adaptation.
lora_target_modules (str): Target modules for adaptation.
use_dora (bool, optional): Whether DoRA was used or not.
Returns:
tuple[pax_fiddle.Config, pax_fiddle.Config]: Updated model configurations.
"""
original_linear_tpl = original_attn_tpl = original_combined_qkv_tpl = None
if lora_target_modules in ["all", "mlp"]:
original_linear_tpl = (
model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl
)
adapter_linear_tpl = (
pax_fiddle.Config(
DoraLinear,
rank=lora_rank,
)
if use_dora
else pax_fiddle.Config(
LoraLinear,
rank=lora_rank,
)
)
adapter_linear_tpl.copy_fields_from(original_linear_tpl)
model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl = (
adapter_linear_tpl
)
if lora_target_modules in ["all", "attention"]:
original_attn_tpl = (
model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_atten_tpl.proj_tpl
)
adapter_attn_tpl = (
pax_fiddle.Config(DoraAttentionProjection, rank=lora_rank)
if use_dora
else pax_fiddle.Config(LoraAttentionProjection, rank=lora_rank)
)
adapter_attn_tpl.copy_fields_from(original_attn_tpl)
original_combined_qkv_tpl = (
model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_atten_tpl.combined_qkv_proj_tpl
)
adapter_combined_qkv_tpl = (
pax_fiddle.Config(DoraCombinedQKVProjection, rank=lora_rank)
if use_dora
else pax_fiddle.Config(LoraCombinedQKVProjection, rank=lora_rank)
)
adapter_combined_qkv_tpl.copy_fields_from(original_combined_qkv_tpl)
model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_atten_tpl.proj_tpl = (
adapter_attn_tpl
)
model.stacked_transformer_params_tpl.transformer_layer_params_tpl.tr_atten_tpl.combined_qkv_proj_tpl = (
adapter_combined_qkv_tpl
)
# initialize and add adapter weights
_initialize_adapter_params(
mdl_vars=mdl_vars,
num_layers=model.stacked_transformer_params_tpl.num_layers,
lora_rank=lora_rank,
lora_target_modules=lora_target_modules,
use_dora=use_dora,
)
return original_linear_tpl, original_attn_tpl, original_combined_qkv_tpl
def _initialize_adapter_params(
mdl_vars: dict,
num_layers,
lora_rank: int,
lora_target_modules: str,
use_dora: bool = False,
seed: int = 1234,
) -> dict:
"""
Initializes and adds adapter parameters to target modules.
Args:
mdl_vars (dict): Model variables.
num_layers (int): Number of transformer layers.
lora_rank (int): Rank of the LoRA adaptation.
lora_target_modules (str): Target modules for adaptation.
use_dora (bool, optional): Whether DoRA was used or not.
seed (int, optional): Random seed for initialization. Defaults to 1234.
Returns:
dict: Updated model variables with initialized adapter parameters.
"""
for i in range(num_layers):
layer_key = f"x_layers_{i}"
if lora_target_modules in ["all", "mlp"]:
for ff_layer_key in ["ffn_layer1", "ffn_layer2"]:
linear = mdl_vars["params"]["stacked_transformer_layer"][layer_key][
"ff_layer"
][ff_layer_key]["linear"]
original_w = linear["w"]
input_dim, output_dim = original_w.shape
std_dev = 1 / jnp.sqrt(lora_rank)
normal_initializer = jax.nn.initializers.normal(std_dev)
lora_a = normal_initializer(
jax.random.key(seed), (input_dim, lora_rank), jnp.float32
)
lora_b = jnp.zeros((output_dim, lora_rank))
linear["lora_a"] = lora_a
linear["lora_b"] = lora_b
if use_dora:
norm = jnp.linalg.norm(original_w, ord=2, axis=0, keepdims=True)
linear["dora_m"] = norm
if lora_target_modules in ["all", "attention"]:
attention = mdl_vars["params"]["stacked_transformer_layer"][layer_key][
"self_attention"
]
for component in ["key", "query", "value", "post"]:
original_w = attention[component]["w"]
w_dim = original_w.shape[0]
std_dev = 1 / jnp.sqrt(lora_rank)
normal_initializer = jax.nn.initializers.normal(std_dev)
lora_a = normal_initializer(
jax.random.key(seed), (w_dim, lora_rank), jnp.float32
)
lora_b = jnp.zeros((w_dim, lora_rank))
attention[component]["lora_a"] = lora_a
attention[component]["lora_b"] = lora_b
if use_dora:
norm = jnp.linalg.norm(
original_w, ord=2, axis=0, keepdims=True
).astype(jnp.float32)
attention[component]["dora_m"] = norm
return mdl_vars
================================================
FILE: v1/src/finetuning/__init__.py
================================================
================================================
FILE: v1/src/finetuning/finetuning_example.py
================================================
"""
Example usage of the TimesFM Finetuning Framework.
For single GPU:
python script.py --training_mode=single
For multiple GPUs:
python script.py --training_mode=multi --gpu_ids=0,1,2
"""
import os
from dataclasses import asdict
from os import path
from typing import Optional, Tuple
import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as mp
import yfinance as yf
from absl import app, flags
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from torch.utils.data import Dataset
from finetuning.finetuning_torch import FinetuningConfig, TimesFMFinetuner
from timesfm import TimesFm, TimesFmCheckpoint, TimesFmHparams
from timesfm.pytorch_patched_decoder import (PatchedTimeSeriesDecoder,
TimesFMConfig)
FLAGS = flags.FLAGS
flags.DEFINE_enum(
"training_mode",
"single",
["single", "multi"],
'Training mode: "single" for single-GPU or "multi" for multi-GPU training.',
)
flags.DEFINE_list(
"gpu_ids", ["0"],
"Comma-separated list of GPU IDs to use for multi-GPU training. Example: 0,1,2"
)
flags.DEFINE_string(
"local_model_path",
None,
"Path to a local .safetensors model file. If provided, overrides Hugging Face download."
)
class TimeSeriesDataset(Dataset):
"""Dataset for time series data compatible with TimesFM."""
def __init__(self,
series: np.ndarray,
context_length: int,
horizon_length: int,
freq_type: int = 0):
"""
Initialize dataset.
Args:
series: Time series data
context_length: Number of past timesteps to use as input
horizon_length: Number of future timesteps to predict
freq_type: Frequency type (0, 1, or 2)
"""
if freq_type not in [0, 1, 2]:
raise ValueError("freq_type must be 0, 1, or 2")
self.series = series
self.context_length = context_length
self.horizon_length = horizon_length
self.freq_type = freq_type
self._prepare_samples()
def _prepare_samples(self) -> None:
"""Prepare sliding window samples from the time series."""
self.samples = []
total_length = self.context_length + self.horizon_length
for start_idx in range(0, len(self.series) - total_length + 1):
end_idx = start_idx + self.context_length
x_context = self.series[start_idx:end_idx]
x_future = self.series[end_idx:end_idx + self.horizon_length]
self.samples.append((x_context, x_future))
def __len__(self) -> int:
return len(self.samples)
def __getitem__(
self, index: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x_context, x_future = self.samples[index]
x_context = torch.tensor(x_context, dtype=torch.float32)
x_future = torch.tensor(x_future, dtype=torch.float32)
input_padding = torch.zeros_like(x_context)
freq = torch.tensor([self.freq_type], dtype=torch.long)
return x_context, input_padding, freq, x_future
def prepare_datasets(series: np.ndarray,
context_length: int,
horizon_length: int,
freq_type: int = 0,
train_split: float = 0.8) -> Tuple[Dataset, Dataset]:
"""
Prepare training and validation datasets from time series data.
Args:
series: Input time series data
context_length: Number of past timesteps to use
horizon_length: Number of future timesteps to predict
freq_type: Frequency type (0, 1, or 2)
train_split: Fraction of data to use for training
Returns:
Tuple of (train_dataset, val_dataset)
"""
train_size = int(len(series) * train_split)
train_data = series[:train_size]
val_data = series[train_size:]
# Create datasets with specified frequency type
train_dataset = TimeSeriesDataset(train_data,
context_length=context_length,
horizon_length=horizon_length,
freq_type=freq_type)
val_dataset = TimeSeriesDataset(val_data,
context_length=context_length,
horizon_length=horizon_length,
freq_type=freq_type)
return train_dataset, val_dataset
def get_model(load_weights: bool = False):
device = "cuda" if torch.cuda.is_available() else "cpu"
hparams = TimesFmHparams(
backend=device,
per_core_batch_size=32,
horizon_len=128,
num_layers=50,
use_positional_embedding=False,
context_len=192,
)
if load_weights:
if FLAGS.local_model_path:
tfm_config = TimesFMConfig()
model = PatchedTimeSeriesDecoder(tfm_config)
loaded_checkpoint = load_file(FLAGS.local_model_path)
else:
repo_id = "google/timesfm-2.0-500m-pytorch"
tfm = TimesFm(hparams=hparams,
checkpoint=TimesFmCheckpoint(huggingface_repo_id=repo_id))
tfm_config = tfm._model_config
model = PatchedTimeSeriesDecoder(tfm_config)
checkpoint_path = path.join(snapshot_download(repo_id), "torch_model.ckpt")
loaded_checkpoint = torch.load(checkpoint_path, weights_only=True)
model.load_state_dict(loaded_checkpoint)
return model, hparams, tfm_config
def plot_predictions(
model: TimesFm,
val_dataset: Dataset,
save_path: Optional[str] = "predictions.png",
) -> None:
"""
Plot model predictions against ground truth for a batch of validation data.
Args:
model: Trained TimesFM model
val_dataset: Validation dataset
save_path: Path to save the plot
"""
import matplotlib.pyplot as plt
model.eval()
x_context, x_padding, freq, x_future = val_dataset[0]
x_context = x_context.unsqueeze(0) # Add batch dimension
x_padding = x_padding.unsqueeze(0)
freq = freq.unsqueeze(0)
x_future = x_future.unsqueeze(0)
device = next(model.parameters()).device
x_context = x_context.to(device)
x_padding = x_padding.to(device)
freq = freq.to(device)
x_future = x_future.to(device)
with torch.no_grad():
predictions = model(x_context, x_padding.float(), freq)
predictions_mean = predictions[..., 0] # [B, N, horizon_len]
last_patch_pred = predictions_mean[:, -1, :] # [B, horizon_len]
context_vals = x_context[0].cpu().numpy()
future_vals = x_future[0].cpu().numpy()
pred_vals = last_patch_pred[0].cpu().numpy()
context_len = len(context_vals)
horizon_len = len(future_vals)
plt.figure(figsize=(12, 6))
plt.plot(range(context_len),
context_vals,
label="Historical Data",
color="blue",
linewidth=2)
plt.plot(
range(context_len, context_len + horizon_len),
future_vals,
label="Ground Truth",
color="green",
linestyle="--",
linewidth=2,
)
plt.plot(range(context_len, context_len + horizon_len),
pred_vals,
label="Prediction",
color="red",
linewidth=2)
plt.xlabel("Time Step")
plt.ylabel("Value")
plt.title("TimesFM Predictions vs Ground Truth")
plt.legend()
plt.grid(True)
if save_path:
plt.savefig(save_path)
print(f"Plot saved to {save_path}")
plt.close()
def get_data(context_len: int,
horizon_len: int,
freq_type: int = 0) -> Tuple[Dataset, Dataset]:
df = yf.download("AAPL", start="2010-01-01", end="2019-01-01")
time_series = df["Close"].values
train_dataset, val_dataset = prepare_datasets(
series=time_series,
context_length=context_len,
horizon_length=horizon_len,
freq_type=freq_type,
train_split=0.8,
)
print(f"Created datasets:")
print(f"- Training samples: {len(train_dataset)}")
print(f"- Validation samples: {len(val_dataset)}")
print(f"- Using frequency type: {freq_type}")
return train_dataset, val_dataset
def single_gpu_example():
"""Basic example of finetuning TimesFM on stock data."""
model, hparams, tfm_config = get_model(load_weights=True)
config = FinetuningConfig(batch_size=256,
num_epochs=5,
learning_rate=1e-4,
use_wandb=True,
freq_type=1,
log_every_n_steps=10,
val_check_interval=0.5,
use_quantile_loss=True)
train_dataset, val_dataset = get_data(128,
tfm_config.horizon_len,
freq_type=config.freq_type)
finetuner = TimesFMFinetuner(model, config)
print("\nStarting finetuning...")
results = finetuner.finetune(train_dataset=train_dataset,
val_dataset=val_dataset)
print("\nFinetuning completed!")
print(f"Training history: {len(results['history']['train_loss'])} epochs")
plot_predictions(
model=model,
val_dataset=val_dataset,
save_path="timesfm_predictions.png",
)
def setup_process(rank, world_size, model, config, train_dataset, val_dataset,
return_dict):
"""Setup process function with optimized CUDA handling."""
try:
if torch.cuda.is_available():
torch.cuda.set_device(rank)
os.environ["MASTER_ADDR"] = config.master_addr
os.environ["MASTER_PORT"] = config.master_port
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl",
world_size=world_size,
rank=rank)
finetuner = TimesFMFinetuner(model, config, rank=rank)
results = finetuner.finetune(train_dataset=train_dataset,
val_dataset=val_dataset)
if rank == 0:
return_dict["results"] = results
plot_predictions(
model=model,
val_dataset=val_dataset,
save_path="timesfm_predictions.png",
)
except Exception as e:
print(f"Error in process {rank}: {str(e)}")
raise e
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
def multi_gpu_example():
"""Example of finetuning TimesFM using multiple GPUs with optimized spawn."""
mp.set_start_method("spawn", force=True)
gpu_ids = [0, 1]
world_size = len(gpu_ids)
model, hparams, tfm_config = get_model(load_weights=True)
# Create config
config = FinetuningConfig(
batch_size=256,
num_epochs=5,
learning_rate=3e-5,
use_wandb=True,
distributed=True,
gpu_ids=gpu_ids,
log_every_n_steps=50,
val_check_interval=0.5,
)
train_dataset, val_dataset = get_data(128, tfm_config.horizon_len)
manager = mp.Manager()
return_dict = manager.dict()
# Launch processes
mp.spawn(
setup_process,
args=(world_size, model, config, train_dataset, val_dataset, return_dict),
nprocs=world_size,
join=True,
)
results = return_dict.get("results", None)
print("\nFinetuning completed!")
return results
def main(argv):
"""Main function that selects and runs the appropriate training mode."""
try:
if FLAGS.training_mode == "single":
print("\nStarting single-GPU training...")
single_gpu_example()
else:
gpu_ids = [int(id) for id in FLAGS.gpu_ids]
print(f"\nStarting multi-GPU training using GPUs: {gpu_ids}...")
config = FinetuningConfig(
batch_size=256,
num_epochs=5,
learning_rate=3e-5,
use_wandb=True,
distributed=True,
gpu_ids=gpu_ids,
)
results = multi_gpu_example(config)
print("\nMulti-GPU training completed!")
except Exception as e:
print(f"Training failed: {str(e)}")
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
if __name__ == "__main__":
app.run(main)
================================================
FILE: v1/src/finetuning/finetuning_torch.py
================================================
"""
TimesFM Finetuner: A flexible framework for finetuning TimesFM models on custom datasets.
"""
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from timesfm.pytorch_patched_decoder import create_quantiles
import wandb
class MetricsLogger(ABC):
"""Abstract base class for logging metrics during training.
This class defines the interface for logging metrics during model training.
Concrete implementations can log to different backends (e.g., WandB, TensorBoard).
"""
@abstractmethod
def log_metrics(self,
metrics: Dict[str, Any],
step: Optional[int] = None) -> None:
"""Log metrics to the specified backend.
Args:
metrics: Dictionary containing metric names and values.
step: Optional step number or epoch for the metrics.
"""
pass
@abstractmethod
def close(self) -> None:
"""Clean up any resources used by the logger."""
pass
class WandBLogger(MetricsLogger):
"""Weights & Biases implementation of metrics logging.
Args:
project: Name of the W&B project.
config: Configuration dictionary to log.
rank: Process rank in distributed training.
"""
def __init__(self, project: str, config: Dict[str, Any], rank: int = 0):
self.rank = rank
if rank == 0:
wandb.init(project=project, config=config)
def log_metrics(self,
metrics: Dict[str, Any],
step: Optional[int] = None) -> None:
"""Log metrics to W&B if on the main process.
Args:
metrics: Dictionary of metrics to log.
step: Current training step or epoch.
"""
if self.rank == 0:
wandb.log(metrics, step=step)
def close(self) -> None:
"""Finish the W&B run if on the main process."""
if self.rank == 0:
wandb.finish()
class DistributedManager:
"""Manages distributed training setup and cleanup.
Args:
world_size: Total number of processes.
rank: Process rank.
master_addr: Address of the master process.
master_port: Port for distributed communication.
backend: PyTorch distributed backend to use.
"""
def __init__(
self,
world_size: int,
rank: int,
master_addr: str = "localhost",
master_port: str = "12358",
backend: str = "nccl",
):
self.world_size = world_size
self.rank = rank
self.master_addr = master_addr
self.master_port = master_port
self.backend = backend
def setup(self) -> None:
"""Initialize the distributed environment."""
os.environ["MASTER_ADDR"] = self.master_addr
os.environ["MASTER_PORT"] = self.master_port
if not dist.is_initialized():
dist.init_process_group(backend=self.backend,
world_size=self.world_size,
rank=self.rank)
def cleanup(self) -> None:
"""Clean up the distributed environment."""
if dist.is_initialized():
dist.destroy_process_group()
@dataclass
class FinetuningConfig:
"""Configuration for model training.
Args:
batch_size: Number of samples per batch.
num_epochs: Number of training epochs.
learning_rate: Initial learning rate.
weight_decay: L2 regularization factor.
freq_type: Frequency, can be [0, 1, 2].
use_quantile_loss: bool = False # Flag to enable/disable quantile loss
quantiles: Optional[List[float]] = None
device: Device to train on ('cuda' or 'cpu').
distributed: Whether to use distributed training.
gpu_ids: List of GPU IDs to use.
master_port: Port for distributed training.
master_addr: Address for distributed training.
use_wandb: Whether to use Weights & Biases logging.
wandb_project: W&B project name.
log_every_n_steps: Log metrics every N steps (batches), this is inspired from Pytorch Lightning
val_check_interval: How often within one training epoch to check val metrics. (also from Pytorch Lightning)
Can be: float (0.0-1.0): fraction of epoch (e.g., 0.5 = validate twice per epoch)
int: validate every N batches
"""
batch_size: int = 32
num_epochs: int = 20
learning_rate: float = 1e-4
weight_decay: float = 0.01
freq_type: int = 0
use_quantile_loss: bool = False
quantiles: Optional[List[float]] = None
device: str = "cuda" if torch.cuda.is_available() else "cpu"
distributed: bool = False
gpu_ids: List[int] = field(default_factory=lambda: [0])
master_port: str = "12358"
master_addr: str = "localhost"
use_wandb: bool = False
wandb_project: str = "timesfm-finetuning"
log_every_n_steps: int = 50
val_check_interval: float = 0.5
class TimesFMFinetuner:
"""Handles model training and validation.
Args:
model: PyTorch model to train.
config: Training configuration.
rank: Process rank for distributed training.
loss_fn: Loss function (defaults to MSE).
logger: Optional logging.Logger instance.
"""
def __init__(
self,
model: nn.Module,
config: FinetuningConfig,
rank: int = 0,
loss_fn: Optional[Callable] = None,
logger: Optional[logging.Logger] = None,
):
self.model = model
self.config = config
self.rank = rank
self.logger = logger or logging.getLogger(__name__)
self.device = torch.device(
f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
self.loss_fn = loss_fn or (lambda x, y: torch.mean((x - y.squeeze(-1))**2))
if config.use_wandb:
self.metrics_logger = WandBLogger(config.wandb_project, config.__dict__,
rank)
if config.distributed:
self.dist_manager = DistributedManager(
world_size=len(config.gpu_ids),
rank=rank,
master_addr=config.master_addr,
master_port=config.master_port,
)
self.dist_manager.setup()
self.model = self._setup_distributed_model()
def _setup_distributed_model(self) -> nn.Module:
"""Configure model for distributed training."""
self.model = self.model.to(self.device)
return DDP(self.model,
device_ids=[self.config.gpu_ids[self.rank]],
output_device=self.config.gpu_ids[self.rank])
def _create_dataloader(self, dataset: Dataset, is_train: bool) -> DataLoader:
"""Create appropriate DataLoader based on training configuration.
Args:
dataset: Dataset to create loader for.
is_train: Whether this is for training (affects shuffling).
Returns:
DataLoader instance.
"""
if self.config.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset,
num_replicas=len(self.config.gpu_ids),
rank=dist.get_rank(),
shuffle=is_train)
else:
sampler = None
return DataLoader(
dataset,
batch_size=self.config.batch_size,
shuffle=(is_train and not self.config.distributed),
sampler=sampler,
)
def _quantile_loss(self, pred: torch.Tensor, actual: torch.Tensor,
quantile: float) -> torch.Tensor:
"""Calculates quantile loss.
Args:
pred: Predicted values
actual: Actual values
quantile: Quantile at which loss is computed
Returns:
Quantile loss
"""
dev = actual - pred
loss_first = dev * quantile
loss_second = -dev * (1.0 - quantile)
return 2 * torch.where(loss_first >= 0, loss_first, loss_second)
def _process_batch(self, batch: List[torch.Tensor]) -> tuple:
"""Process a single batch of data.
Args:
batch: List of input tensors.
Returns:
Tuple of (loss, predictions).
"""
x_context, x_padding, freq, x_future = [
t.to(self.device, non_blocking=True) for t in batch
]
predictions = self.model(x_context, x_padding.float(), freq)
predictions_mean = predictions[..., 0]
last_patch_pred = predictions_mean[:, -1, :]
loss = self.loss_fn(last_patch_pred, x_future.squeeze(-1))
if self.config.use_quantile_loss:
quantiles = self.config.quantiles or create_quantiles()
for i, quantile in enumerate(quantiles):
last_patch_quantile = predictions[:, -1, :, i + 1]
loss += torch.mean(
self._quantile_loss(last_patch_quantile, x_future.squeeze(-1),
quantile))
return loss, predictions
def _train_epoch(self, train_loader: DataLoader,
optimizer: torch.optim.Optimizer) -> float:
"""Train for one epoch in a distributed setting.
Args:
train_loader: DataLoader for training data.
optimizer: Optimizer instance.
Returns:
Average training loss for the epoch.
"""
self.model.train()
total_loss = 0.0
num_batches = len(train_loader)
for batch in train_loader:
loss, _ = self._process_batch(batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / num_batches
if self.config.distributed:
avg_loss_tensor = torch.tensor(avg_loss, device=self.device)
dist.all_reduce(avg_loss_tensor, op=dist.ReduceOp.SUM)
avg_loss = (avg_loss_tensor / dist.get_world_size()).item()
return avg_loss
def _validate(self, val_loader: DataLoader) -> float:
"""Perform validation.
Args:
val_loader: DataLoader for validation data.
Returns:
Average validation loss.
"""
self.model.eval()
total_loss = 0.0
num_batches = len(val_loader)
with torch.no_grad():
for batch in val_loader:
loss, _ = self._process_batch(batch)
total_loss += loss.item()
avg_loss = total_loss / num_batches
if self.config.distributed:
avg_loss_tensor = torch.tensor(avg_loss, device=self.device)
dist.all_reduce(avg_loss_tensor, op=dist.ReduceOp.SUM)
avg_loss = (avg_loss_tensor / dist.get_world_size()).item()
return avg_loss
def finetune(self, train_dataset: Dataset,
val_dataset: Dataset) -> Dict[str, Any]:
"""Train the model.
Args:
train_dataset: Training dataset.
val_dataset: Validation dataset.
Returns:
Dictionary containing training history.
"""
self.model = self.model.to(self.device)
train_loader = self._create_dataloader(train_dataset, is_train=True)
val_loader = self._create_dataloader(val_dataset, is_train=False)
optimizer = torch.optim.Adam(self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay)
history = {"train_loss": [], "val_loss": [], "learning_rate": []}
self.logger.info(
f"Starting training for {self.config.num_epochs} epochs...")
self.logger.info(f"Training samples: {len(train_dataset)}")
self.logger.info(f"Validation samples: {len(val_dataset)}")
try:
for epoch in range(self.config.num_epochs):
train_loss = self._train_epoch(train_loader, optimizer)
val_loss = self._validate(val_loader)
current_lr = optimizer.param_groups[0]["lr"]
metrics = {
"train_loss": train_loss,
"val_loss": val_loss,
"learning_rate": current_lr,
"epoch": epoch + 1,
}
if self.config.use_wandb:
self.metrics_logger.log_metrics(metrics)
history["train_loss"].append(train_loss)
history["val_loss"].append(val_loss)
history["learning_rate"].append(current_lr)
if self.rank == 0:
self.logger.info(
f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}"
)
except KeyboardInterrupt:
self.logger.info("Training interrupted by user")
if self.config.distributed:
self.dist_manager.cleanup()
if self.config.use_wandb:
self.metrics_logger.close()
return {"history": history}
================================================
FILE: v1/src/timesfm/__init__.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM init file."""
print(
" See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs."
)
from timesfm.timesfm_base import (
freq_map,
TimesFmCheckpoint,
TimesFmHparams,
TimesFmBase,
)
import sys
try:
from timesfm.timesfm_jax import TimesFmJax as TimesFm
from timesfm import data_loader
print(f"Loaded Jax TimesFM, likely because python version is {sys.version}.")
except Exception as _:
from timesfm.timesfm_torch import TimesFmTorch as TimesFm
print(f"Loaded PyTorch TimesFM, likely because python version is {sys.version}.")
================================================
FILE: v1/src/timesfm/data_loader.py
================================================
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF dataloaders for general timeseries datasets.
The expected input format is csv file with a datetime index.
"""
from absl import logging
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from . import time_features
class TimeSeriesdata(object):
"""Data loader class."""
def __init__(
self,
data_path,
datetime_col,
num_cov_cols,
cat_cov_cols,
ts_cols,
train_range,
val_range,
test_range,
hist_len,
pred_len,
batch_size,
freq='H',
normalize=True,
epoch_len=None,
holiday=False,
permute=True,
):
"""Initialize objects.
Args:
data_path: path to csv file
datetime_col: column name for datetime col
num_cov_cols: list of numerical global covariates
cat_cov_cols: list of categorical global covariates
ts_cols: columns corresponding to ts
train_range: tuple of train ranges
val_range: tuple of validation ranges
test_range: tuple of test ranges
hist_len: historical context
pred_len: prediction length
batch_size: batch size (number of ts in a batch)
freq: freq of original data
normalize: std. normalize data or not
epoch_len: num iters in an epoch
holiday: use holiday features or not
permute: permute ts in train batches or not
Returns:
None
"""
self.data_df = pd.read_csv(open(data_path, 'r'))
if not num_cov_cols:
self.data_df['ncol'] = np.zeros(self.data_df.shape[0])
num_cov_cols = ['ncol']
if not cat_cov_cols:
self.data_df['ccol'] = np.zeros(self.data_df.shape[0])
cat_cov_cols = ['ccol']
self.data_df.fillna(0, inplace=True)
self.data_df.set_index(pd.DatetimeIndex(self.data_df[datetime_col]),
inplace=True)
self.num_cov_cols = num_cov_cols
self.cat_cov_cols = cat_cov_cols
self.ts_cols = ts_cols
self.train_range = train_range
self.val_range = val_range
self.test_range = test_range
data_df_idx = self.data_df.index
date_index = data_df_idx.union(
pd.date_range(
data_df_idx[-1] + pd.Timedelta(1, freq=freq),
periods=pred_len + 1,
freq=freq,
))
self.time_df = time_features.TimeCovariates(
date_index, holiday=holiday).get_covariates()
self.hist_len = hist_len
self.pred_len = pred_len
self.batch_size = batch_size
self.freq = freq
self.normalize = normalize
self.data_mat = self.data_df[self.ts_cols].to_numpy().transpose()
self.data_mat = self.data_mat[:, 0:self.test_range[1]]
self.time_mat = self.time_df.to_numpy().transpose()
self.num_feat_mat = self.data_df[num_cov_cols].to_numpy().transpose()
self.cat_feat_mat, self.cat_sizes = self._get_cat_cols(cat_cov_cols)
self.normalize = normalize
if normalize:
self._normalize_data()
logging.info(
'Data Shapes: %s, %s, %s, %s',
self.data_mat.shape,
self.time_mat.shape,
self.num_feat_mat.shape,
self.cat_feat_mat.shape,
)
self.epoch_len = epoch_len
self.permute = permute
def _get_cat_cols(self, cat_cov_cols):
"""Get categorical columns."""
cat_vars = []
cat_sizes = []
for col in cat_cov_cols:
dct = {x: i for i, x in enumerate(self.data_df[col].unique())}
cat_sizes.append(len(dct))
mapped = self.data_df[col].map(lambda x: dct[x]).to_numpy().transpose() # pylint: disable=cell-var-from-loop
cat_vars.append(mapped)
return np.vstack(cat_vars), cat_sizes
def _normalize_data(self):
self.scaler = StandardScaler()
train_mat = self.data_mat[:, 0:self.train_range[1]]
self.scaler = self.scaler.fit(train_mat.transpose())
self.data_mat = self.scaler.transform(self.data_mat.transpose()).transpose()
def train_gen(self):
"""Generator for training data."""
num_ts = len(self.ts_cols)
perm = np.arange(
self.train_range[0] + self.hist_len,
self.train_range[1] - self.pred_len,
)
perm = np.random.permutation(perm)
hist_len = self.hist_len
logging.info('Hist len: %s', hist_len)
if not self.epoch_len:
epoch_len = len(perm)
else:
epoch_len = self.epoch_len
for idx in perm[0:epoch_len]:
for _ in range(num_ts // self.batch_size + 1):
if self.permute:
tsidx = np.random.choice(num_ts, size=self.batch_size, replace=False)
else:
tsidx = np.arange(num_ts)
dtimes = np.arange(idx - hist_len, idx + self.pred_len)
(
bts_train,
bts_pred,
bfeats_train,
bfeats_pred,
bcf_train,
bcf_pred,
) = self._get_features_and_ts(dtimes, tsidx, hist_len)
all_data = [
bts_train,
bfeats_train,
bcf_train,
bts_pred,
bfeats_pred,
bcf_pred,
tsidx,
]
yield tuple(all_data)
def test_val_gen(self, mode='val', shift=1):
"""Generator for validation/test data."""
if mode == 'val':
start = self.val_range[0]
end = self.val_range[1] - self.pred_len + 1
elif mode == 'test':
start = self.test_range[0]
end = self.test_range[1] - self.pred_len + 1
else:
raise NotImplementedError('Eval mode not implemented')
num_ts = len(self.ts_cols)
hist_len = self.hist_len
logging.info('Hist len: %s', hist_len)
perm = np.arange(start, end)
if self.epoch_len:
epoch_len = self.epoch_len
else:
epoch_len = len(perm)
for i in range(0, epoch_len, shift):
idx = perm[i]
for batch_idx in range(0, num_ts, self.batch_size):
tsidx = np.arange(batch_idx, min(batch_idx + self.batch_size, num_ts))
dtimes = np.arange(idx - hist_len, idx + self.pred_len)
(
bts_train,
bts_pred,
bfeats_train,
bfeats_pred,
bcf_train,
bcf_pred,
) = self._get_features_and_ts(dtimes, tsidx, hist_len)
all_data = [
bts_train,
bfeats_train,
bcf_train,
bts_pred,
bfeats_pred,
bcf_pred,
tsidx,
]
yield tuple(all_data)
def _get_features_and_ts(self, dtimes, tsidx, hist_len=None):
"""Get features and ts in specified windows."""
if hist_len is None:
hist_len = self.hist_len
data_times = dtimes[dtimes < self.data_mat.shape[1]]
bdata = self.data_mat[:, data_times]
bts = bdata[tsidx, :]
bnf = self.num_feat_mat[:, data_times]
bcf = self.cat_feat_mat[:, data_times]
btf = self.time_mat[:, dtimes]
if bnf.shape[1] < btf.shape[1]:
rem_len = btf.shape[1] - bnf.shape[1]
rem_rep = np.repeat(bnf[:, [-1]], repeats=rem_len)
rem_rep_cat = np.repeat(bcf[:, [-1]], repeats=rem_len)
bnf = np.hstack([bnf, rem_rep.reshape(bnf.shape[0], -1)])
bcf = np.hstack([bcf, rem_rep_cat.reshape(bcf.shape[0], -1)])
bfeats = np.vstack([btf, bnf])
bts_train = bts[:, 0:hist_len]
bts_pred = bts[:, hist_len:]
bfeats_train = bfeats[:, 0:hist_len]
bfeats_pred = bfeats[:, hist_len:]
bcf_train = bcf[:, 0:hist_len]
bcf_pred = bcf[:, hist_len:]
return bts_train, bts_pred, bfeats_train, bfeats_pred, bcf_train, bcf_pred
def tf_dataset(self, mode='train', shift=1):
"""Tensorflow Dataset."""
if mode == 'train':
gen_fn = self.train_gen
else:
gen_fn = lambda: self.test_val_gen(mode, shift)
output_types = tuple([tf.float32] * 2 + [tf.int32] + [tf.float32] * 2 +
[tf.int32] * 2)
dataset = tf.data.Dataset.from_generator(gen_fn, output_types)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
================================================
FILE: v1/src/timesfm/patched_decoder.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pax ML model for patched time-series decoder.
The file implements Residual MLPs, Patched Decoder layers and PAX ML models.
"""
import dataclasses
from typing import Optional, Tuple
import einshape as es
from jax import lax
import jax.numpy as jnp
from praxis import base_layer
from praxis import base_model
from praxis import layers
from praxis import pax_fiddle
from praxis import py_utils
from praxis import pytypes
from praxis.layers import activations
from praxis.layers import embedding_softmax
from praxis.layers import linears
from praxis.layers import normalizations
from praxis.layers import stochastics
from praxis.layers import transformers
# PAX shortcuts
NestedMap = py_utils.NestedMap
JTensor = pytypes.JTensor
LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
template_field = base_layer.template_field
PAD_VAL = 1123581321.0
DEFAULT_QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# NestedMap keys
_INPUT_TS = "input_ts"
_TARGET_FUTURE = "actual_ts"
_INPUT_PADDING = "input_padding"
_OUTPUT_TS = "output_ts"
_FREQ = "freq"
_OUTPUT_TOKENS = "output_tokens"
_STATS = "stats"
# Small numerical value.
_TOLERANCE = 1e-7
def _shift_padded_seq(mask: JTensor, seq: JTensor) -> JTensor:
"""Shifts rows of seq based on the first 0 in each row of the mask."""
num = seq.shape[1]
# Find the index of the first 0 in each row of the mask
first_zero_idx = jnp.argmin(mask, axis=1)
# Create a range array for indexing
idx_range = jnp.arange(num)
def shift_row(carry, x):
seq_row, shift = x
shifted_idx = (idx_range - shift) % num
shifted_row = seq_row[shifted_idx]
return carry, shifted_row
# Use lax.scan to shift each row of seq based on the corresponding
# first_zero_idx.
_, shifted_seq = lax.scan(shift_row, None, (seq, first_zero_idx))
return shifted_seq
class ResidualBlock(base_layer.BaseLayer):
"""Simple feedforward block with residual connection.
Attributes:
input_dims: input dimension.
hidden_dims: hidden dimension.
output_dims: output dimension.
dropout_prob: dropout probability.
layer_norm: whether to use layer norm or not.
dropout_tpl: config for dropout.
ln_tpl: config for layer norm.
act_tpl: config for activation in hidden layer.
"""
input_dims: int = 0
hidden_dims: int = 0
output_dims: int = 0
dropout_prob: float = 0.0
layer_norm: bool = False
dropout_tpl: LayerTpl = template_field(stochastics.Dropout)
ln_tpl: LayerTpl = template_field(normalizations.LayerNorm)
act_tpl: LayerTpl = template_field(activations.Swish)
def setup(self):
lnorm_tpl = self.ln_tpl.clone()
lnorm_tpl.dim = self.output_dims
self.create_child("ln_layer", lnorm_tpl)
dropout_tpl = self.dropout_tpl.clone()
dropout_tpl.keep_prob = 1.0 - self.dropout_prob
self.create_child("dropout", dropout_tpl)
self.create_child(
"hidden_layer",
pax_fiddle.Config(
linears.FeedForward,
input_dims=self.input_dims,
output_dims=self.hidden_dims,
activation_tpl=self.act_tpl.clone(),
),
)
self.create_child(
"output_layer",
pax_fiddle.Config(
linears.FeedForward,
input_dims=self.hidden_dims,
output_dims=self.output_dims,
activation_tpl=pax_fiddle.Config(activations.Identity),
),
)
self.create_child(
"residual_layer",
pax_fiddle.Config(
linears.FeedForward,
input_dims=self.input_dims,
output_dims=self.output_dims,
activation_tpl=pax_fiddle.Config(activations.Identity),
),
)
def __call__(self, inputs: JTensor) -> JTensor:
hidden = self.hidden_layer(inputs)
output = self.output_layer(hidden)
output = self.dropout(output)
residual = self.residual_layer(inputs)
if self.layer_norm:
return self.ln_layer(output + residual)
else:
return output + residual
def _masked_mean_std(inputs: JTensor,
padding: JTensor) -> Tuple[JTensor, JTensor]:
"""Calculates mean and standard deviation of arr across axis 1.
It should exclude values where pad is 1.
Args:
inputs: A JAX array of shape [b, n, p].
padding: A JAX array of shape [b, n, p] with values 0 or 1.
Returns:
A tuple containing the mean and standard deviation of arr. We return the
statistics of the first patch with more than three non-padded values.
"""
# Selecting the first pad with more than 3 unpadded values.
pad_sum = jnp.sum(1 - padding, axis=2)
def _get_patch_index(arr: JTensor):
indices = jnp.argmax(arr >= 3, axis=1)
row_sum = (arr >= 3).sum(axis=1)
return jnp.where(row_sum == 0, arr.shape[1] - 1, indices)
patch_indices = _get_patch_index(pad_sum)
bidxs = jnp.arange(inputs.shape[0])
arr = inputs[bidxs, patch_indices, :]
pad = padding[bidxs, patch_indices, :]
# Create a mask where P is 0
mask = 1 - pad
# Calculate the number of valid elements
num_valid_elements = jnp.sum(mask, axis=1)
num_valid_elements = jnp.where(num_valid_elements == 0, 1, num_valid_elements)
# Calculate the masked sum for mean and centered squared sum for variance.
masked_sum = jnp.sum(arr * mask, axis=1)
# Calculate the masked mean and standard deviation
masked_mean = masked_sum / num_valid_elements
centered = (arr - masked_mean[:, None]) * mask
masked_var = jnp.sum(centered**2, axis=1) / num_valid_elements
masked_var = jnp.where(masked_var < 0.0, 0.0, masked_var)
masked_std = jnp.sqrt(masked_var)
return masked_mean, masked_std
def _create_quantiles() -> list[float]:
"""Returns the quantiles for forecasting."""
return DEFAULT_QUANTILES
class PatchedTimeSeriesDecoder(base_layer.BaseLayer):
"""Patch decoder layer for time-series foundation model.
Attributes:
patch_len: length of input patches.
horizon_len: length of output patches. Referred to as `output_patch_len`
during inference.
model_dims: model dimension of stacked transformer layer.
hidden_dims: hidden dimensions in fully connected layers.
quantiles: list of quantiles for non prob model.
residual_block_tpl: config for residual block.
stacked_transformer_params_tpl: config for stacked transformer.
use_freq: whether to use frequency encoding.
In all of what followed, except specified otherwise, B is batch size, T is
sequence length of time-series. N is the number of input patches that can be
obtained from T. P is the input patch length and H is the horizon length. Q is
number of output logits. D is model dimension.
"""
patch_len: int = 0
horizon_len: int = 0
model_dims: int = 0
hidden_dims: int = 0
quantiles: list[float] = dataclasses.field(default_factory=_create_quantiles)
residual_block_tpl: LayerTpl = template_field(ResidualBlock)
stacked_transformer_params_tpl: LayerTpl = template_field(
transformers.StackedTransformer)
use_freq: bool = True
use_pos_emb: bool = True
def setup(self) -> None:
"""Construct the model."""
num_outputs = len(self.quantiles) + 1
stl = self.stacked_transformer_params_tpl.clone()
stl.model_dims = self.model_dims
stl.hidden_dims = self.hidden_dims
stl.mask_self_attention = True
self.create_child("stacked_transformer_layer", stl)
input_resl = self.residual_block_tpl.clone()
ff_in_dims = 2 * self.patch_len
input_resl.input_dims = ff_in_dims
input_resl.hidden_dims = self.hidden_dims
input_resl.output_dims = self.model_dims
self.create_child(
"input_ff_layer",
input_resl,
)
horizon_resl = self.residual_block_tpl.clone()
horizon_resl.input_dims = self.model_dims
horizon_resl.hidden_dims = self.hidden_dims
horizon_resl.output_dims = self.horizon_len * num_outputs
self.create_child(
"horizon_ff_layer",
horizon_resl,
)
self.create_child(
"position_emb",
pax_fiddle.Config(layers.PositionalEmbedding,
embedding_dims=self.model_dims),
)
if self.use_freq:
self.create_child(
"freq_emb",
pax_fiddle.Config(
embedding_softmax.Embedding,
num_classes=3,
input_dims=self.model_dims,
),
)
def transform_decode_state(
self, transform_fn: base_layer.DecodeStateTransformFn) -> None:
"""Transforms all decode state variables based on transform_fn."""
self.stacked_transformer_layer.transform_decode_state(transform_fn)
def _forward_transform(
self, inputs: JTensor,
patched_pads: JTensor) -> Tuple[JTensor, Tuple[JTensor, JTensor]]:
"""Input is of shape [B, N, P]."""
mu, sigma = _masked_mean_std(inputs, patched_pads)
sigma = jnp.maximum(sigma, _TOLERANCE)
# Normalize each patch.
outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
outputs = jnp.where(
jnp.abs(inputs - PAD_VAL) < _TOLERANCE, PAD_VAL, outputs)
return outputs, (mu, sigma)
def _reverse_transform(self, outputs: JTensor,
stats: Tuple[JTensor, JTensor]) -> JTensor:
"""Output is of shape [B, N, P, Q]."""
mu, sigma = stats
return outputs * sigma[:, None, None, None] + mu[:, None, None, None]
def _preprocess_input(
self,
input_ts: JTensor,
input_padding: JTensor,
pos_emb: Optional[JTensor] = None,
) -> Tuple[JTensor, JTensor, Optional[Tuple[JTensor, JTensor]], JTensor]:
"""Preprocess input for stacked transformer."""
# Reshape into patches.
patched_inputs = es.jax_einshape("b(np)->bnp", input_ts, p=self.patch_len)
patched_pads = es.jax_einshape("b(np)->bnp",
input_padding,
p=self.patch_len)
patched_inputs = jnp.where(
jnp.abs(patched_pads - 1.0) < _TOLERANCE, 0.0, patched_inputs)
patched_pads = jnp.where(
jnp.abs(patched_inputs - PAD_VAL) < _TOLERANCE, 1, patched_pads)
patched_inputs, stats = self._forward_transform(patched_inputs,
patched_pads)
# B x N x D
patched_inputs = patched_inputs * (1.0 - patched_pads)
concat_inputs = jnp.concatenate([patched_inputs, patched_pads], axis=-1)
model_input = self.input_ff_layer(concat_inputs)
# A patch should not be padded even if there is at least one zero.
patched_padding = jnp.min(patched_pads, axis=-1)
if self.use_pos_emb:
if pos_emb is None:
position_emb = self.position_emb(seq_length=model_input.shape[1])
else:
position_emb = pos_emb
if self.do_eval:
if position_emb.shape[0] != model_input.shape[0]:
position_emb = jnp.repeat(position_emb, model_input.shape[0], axis=0)
position_emb = _shift_padded_seq(patched_padding, position_emb)
model_input += position_emb
return model_input, patched_padding, stats, patched_inputs
def _postprocess_output(
self,
model_output: JTensor,
num_outputs: int,
stats: Tuple[JTensor, JTensor],
) -> JTensor:
"""Postprocess output of stacked transformer."""
# B x N x (H.Q)
output_ts = self.horizon_ff_layer(model_output)
output_ts = es.jax_einshape("bn(hq)->bnhq",
output_ts,
q=num_outputs,
h=self.horizon_len)
return self._reverse_transform(output_ts, stats)
def __call__(self, inputs: NestedMap) -> NestedMap:
"""PatchTST call.
Args:
inputs: A NestedMap containing (1) input_ts: input sequence of shape [B,
T] where T must be multiple of patch_length; (2) input_padding: that
contains padding map.
Returns:
A nested map with two keys:
(1) 'output_tokens' of shape [B, N, D].
(2) 'output_ts' of shape [B, N, H, Q]
(3) 'stats' a Tuple of statistics for renormalization.
"""
input_ts, input_padding = inputs[_INPUT_TS], inputs[_INPUT_PADDING]
num_outputs = len(self.quantiles) + 1
model_input, patched_padding, stats, _ = self._preprocess_input(
input_ts=input_ts,
input_padding=input_padding,
)
if self.use_freq:
freq = inputs[_FREQ].astype(jnp.int32)
f_emb = self.freq_emb(freq) # B x 1 x D
f_emb = jnp.repeat(f_emb, model_input.shape[1], axis=1)
model_input += f_emb
model_output = self.stacked_transformer_layer(model_input, patched_padding)
output_ts = self._postprocess_output(model_output, num_outputs, stats)
return NestedMap({
_OUTPUT_TOKENS: model_output,
_OUTPUT_TS: output_ts,
_STATS: stats
})
def decode(
self,
inputs: NestedMap,
horizon_len: int,
output_patch_len: Optional[int] = None,
max_len: int | None = None,
return_forecast_on_context: bool = False,
) -> tuple[JTensor, JTensor]:
"""Auto-regressive decoding without caching.
Args:
inputs: input time-series and paddings. Time-series shape B x C, padding
shape shape B x (C + H) where H is the prediction length.
horizon_len: prediction length.
output_patch_len: output length to be fetched from one step of
auto-regressive decoding.
max_len: maximum training context length.
return_forecast_on_context: whether to return the model forecast on the
context except the first input patch.
Returns:
Tuple of two forecasting results:
- Point (mean) output predictions as a tensor with shape B x H'.
- Full predictions (mean and quantiles) as a tensor with shape
B x H' x (1 + # quantiles).
In particular, if return_forecast_on_context is True, H' is H plus
the forecastable context length, i.e. context_len - (first) patch_len.
"""
final_out = inputs[_INPUT_TS]
context_len = final_out.shape[1]
paddings = inputs[_INPUT_PADDING]
if max_len is None:
max_len = context_len
if self.use_freq:
freq = inputs[_FREQ].astype(jnp.int32)
else:
freq = jnp.zeros([final_out.shape[0], 1], dtype=jnp.int32)
full_outputs = []
if paddings.shape[1] != final_out.shape[1] + horizon_len:
raise ValueError(
"Length of paddings must match length of input + horizon_len:"
f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}")
if output_patch_len is None:
output_patch_len = self.horizon_len
num_decode_patches = (horizon_len + output_patch_len -
1) // output_patch_len
for step_index in range(num_decode_patches):
current_padding = paddings[:, 0:final_out.shape[1]]
input_ts = final_out[:, -max_len:]
input_padding = current_padding[:, -max_len:]
model_input = NestedMap(
input_ts=input_ts,
input_padding=input_padding,
freq=freq,
)
fprop_outputs = self(model_input)[_OUTPUT_TS]
if return_forecast_on_context and step_index == 0:
# For the first decodings step, collect the model forecast on the
# context except the unavailable first input batch forecast.
new_full_ts = fprop_outputs[:, :-1, :self.patch_len, :]
new_full_ts = es.jax_einshape("bnph->b(np)h", new_full_ts)
full_outputs.append(new_full_ts)
# (full batch, last patch, output_patch_len, index of mean forecast = 0)
new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
# (full batch, last patch, output_patch_len, all output indices)
full_outputs.append(new_full_ts)
final_out = jnp.concatenate([final_out, new_ts], axis=-1)
if return_forecast_on_context:
# `full_outputs` indexing starts at after the first input patch.
full_outputs = jnp.concatenate(full_outputs,
axis=1)[:, :(context_len - self.patch_len +
horizon_len), :]
else:
# `full_outputs` indexing starts at the forecast horizon.
full_outputs = jnp.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :]
return (full_outputs[:, :, 0], full_outputs)
class PatchedDecoderFinetuneModel(base_model.BaseModel):
"""Model class for finetuning patched time-series decoder.
Attributes:
core_layer_tpl: config for core layer.
freq: freq to finetune on.
"""
core_layer_tpl: LayerTpl = template_field(PatchedTimeSeriesDecoder)
freq: int = 0
def setup(self) -> None:
self.create_child("core_layer", self.core_layer_tpl)
def compute_predictions(self, input_batch: NestedMap) -> NestedMap:
input_ts = input_batch[_INPUT_TS]
input_padding = jnp.zeros_like(input_ts)
context_len = input_ts.shape[1]
input_patch_len = self.core_layer_tpl.patch_len
context_pad = ((context_len + input_patch_len - 1) //
input_patch_len) * input_patch_len - context_len
input_ts = jnp.pad(input_ts, [(0, 0), (context_pad, 0)])
input_padding = jnp.pad(input_padding, [(0, 0), (context_pad, 0)],
constant_values=1)
freq = jnp.ones([input_ts.shape[0], 1], dtype=jnp.int32) * self.freq
new_input_batch = NestedMap(
input_ts=input_ts,
input_padding=input_padding,
freq=freq,
)
return self.core_layer(new_input_batch)
def _quantile_loss(self, pred: JTensor, actual: JTensor,
quantile: float) -> JTensor:
"""Calculates quantile loss.
Args:
pred: B x T
actual: B x T
quantile: quantile at which loss is computed.
Returns:
per coordinate loss.
"""
dev = actual - pred
loss_first = dev * quantile
loss_second = -dev * (1.0 - quantile)
return 2 * jnp.where(loss_first >= 0, loss_first, loss_second)
def compute_loss(self, prediction_output: NestedMap,
input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]:
output_ts = prediction_output[_OUTPUT_TS]
actual_ts = input_batch[_TARGET_FUTURE]
pred_ts = output_ts[:, -1, 0:actual_ts.shape[1], :]
loss = jnp.square(pred_ts[:, :, 0] - actual_ts)
for i, quantile in enumerate(self.core_layer.quantiles):
loss += self._quantile_loss(pred_ts[:, :, i + 1], actual_ts, quantile)
loss = loss.mean()
loss_weight = jnp.array(1.0, dtype=jnp.float32)
per_example_out = NestedMap()
return {"avg_qloss": (loss, loss_weight)}, per_example_out
================================================
FILE: v1/src/timesfm/pytorch_patched_decoder.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pytorch version of patched decoder."""
import dataclasses
import math
from typing import List, Tuple
import torch
from torch import nn
import torch.nn.functional as F
def create_quantiles() -> list[float]:
return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
@dataclasses.dataclass
class TimesFMConfig:
"""Config for initializing timesfm patched_decoder class."""
# The number of blocks in the model.
num_layers: int = 20
# The number of attention heads used in the attention layers of the model.
num_heads: int = 16
# The number of key-value heads for implementing attention.
num_kv_heads: int = 16
# The hidden size of the model.
hidden_size: int = 1280
# The dimension of the MLP representations.
intermediate_size: int = 1280
# The number of head dimensions.
head_dim: int = 80
# The epsilon used by the rms normalization layers.
rms_norm_eps: float = 1e-6
# Patch length
patch_len: int = 32
# Horizon length
horizon_len: int = 128
# quantiles
quantiles: List[float] = dataclasses.field(default_factory=create_quantiles)
# Padding value
pad_val: float = 1123581321.0
# Tolerance
tolerance: float = 1e-6
# The dtype of the weights.
dtype: str = "bfloat32"
# use positional embedding
use_positional_embedding: bool = True
def _masked_mean_std(
inputs: torch.Tensor,
padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculates mean and standard deviation of `inputs` across axis 1.
It excludes values where `padding` is 1.
Args:
inputs: A PyTorch tensor of shape [b, n, p].
padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
Returns:
A tuple containing the mean and standard deviation.
We return the statistics of the first patch with more than three non-padded
values.
"""
# Selecting the first patch with more than 3 unpadded values.
pad_sum = torch.sum(1 - padding, dim=2)
def _get_patch_index(arr: torch.Tensor):
indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
patch_indices = _get_patch_index(pad_sum)
bidxs = torch.arange(inputs.shape[0])
arr = inputs[bidxs, patch_indices, :]
pad = padding[bidxs, patch_indices, :]
# Create a mask where padding is 0
mask = 1 - pad
# Calculate the number of valid elements
num_valid_elements = torch.sum(mask, dim=1)
num_valid_elements = torch.clamp(num_valid_elements, min=1.0)
# Calculate the masked sum and mean
masked_sum = torch.sum(arr * mask, dim=1)
masked_mean = masked_sum / num_valid_elements
# Calculate the masked variance using centered values (numerically stable)
masked_centered_arr = (arr - masked_mean.unsqueeze(-1)) * mask
masked_var = torch.sum(masked_centered_arr**2, dim=1) / num_valid_elements
masked_var = torch.clamp(masked_var, min=0.0)
masked_std = torch.sqrt(masked_var)
return masked_mean, masked_std
def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
"""Shifts rows of seq based on the first 0 in each row of the mask.
Args:
mask: mask tensor of shape [B, N]
seq: seq tensor of shape [B, N, P]
Returns:
Returns the shifted sequence.
"""
batch_size, num_seq, feature_dim = seq.shape
new_mask: torch.BoolTensor = mask == 0
# Use argmax to find the first True value in each row
indices = new_mask.to(torch.int32).argmax(dim=1)
# Handle rows with all zeros
indices[~new_mask.any(dim=1)] = -1
# Create index ranges for each sequence in the batch
idx_range = (torch.arange(num_seq).to(
seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1,
feature_dim))
# Calculate shifted indices for each element in each sequence
shifted_idx = (idx_range - indices[:, None, None]) % num_seq
# Gather values from seq using shifted indices
shifted_seq = seq.gather(1, shifted_idx)
return shifted_seq
def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor:
"""Returns a large negative value for the given dtype."""
if dtype.is_floating_point:
dtype_max = torch.finfo(dtype).max
else:
dtype_max = torch.iinfo(dtype).max
return torch.tensor(-0.7 * dtype_max, dtype=dtype)
def apply_mask_to_logits(logits: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
"""Applies a floating-point mask to a set of logits.
Args:
logits: A torch.Tensor of logit values.
mask: A torch.Tensor (float32) of mask values with the encoding described
in the function documentation.
Returns:
Masked logits.
"""
min_value = get_large_negative_number(logits.dtype)
return torch.where((mask >= min_value * 0.5), logits, min_value)
def convert_paddings_to_mask(
paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""Converts binary paddings to a logit mask ready to add to attention matrix.
Args:
paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding
token.
dtype: data type of the input.
Returns:
A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits.
"""
attention_mask = paddings.detach().clone()
attention_mask = attention_mask[:, None, None, :] # Equivalent to jnp.newaxis
attention_mask *= get_large_negative_number(dtype)
return attention_mask
def causal_mask(input_t: torch.Tensor) -> torch.Tensor:
"""Computes and returns causal mask.
Args:
input_t: A torch.Tensor of shape [B, T, D].
Returns:
An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has
already been converted to large negative values.
"""
assert input_t.dtype.is_floating_point, input_t.dtype
large_negative_number = get_large_negative_number(input_t.dtype)
t = input_t.shape[1]
col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1)
row_idx = torch.arange(t).unsqueeze(1).repeat(1, t)
mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number
return (mask.unsqueeze(0).unsqueeze(0).to(input_t.device)
) # Equivalent to jnp.newaxis
def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""Merges 2 masks.
logscale mask is expected but 0/1 mask is also fine.
Args:
a: torch.Tensor of shape [1|B, 1, 1|T, S].
b: torch.Tensor of shape [1|B, 1, 1|T, S].
Returns:
torch.Tensor of shape [1|B, 1, 1|T, S].
"""
def expand_t(key_mask):
query_mask = key_mask.transpose(-1, -2) # Equivalent of jnp.transpose
return torch.minimum(query_mask, key_mask)
if a.shape[2] != b.shape[2]:
if a.shape[2] == 1:
a = expand_t(a)
else:
assert b.shape[2] == 1
b = expand_t(b)
assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}."
return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum
class ResidualBlock(nn.Module):
"""TimesFM residual block."""
def __init__(
self,
input_dims,
hidden_dims,
output_dims,
):
super(ResidualBlock, self).__init__()
self.input_dims = input_dims
self.hidden_dims = hidden_dims
self.output_dims = output_dims
# Hidden Layer
self.hidden_layer = nn.Sequential(
nn.Linear(input_dims, hidden_dims),
nn.SiLU(),
)
# Output Layer
self.output_layer = nn.Linear(hidden_dims, output_dims)
# Residual Layer
self.residual_layer = nn.Linear(input_dims, output_dims)
def forward(self, x):
hidden = self.hidden_layer(x)
output = self.output_layer(hidden)
residual = self.residual_layer(x)
return output + residual
class RMSNorm(torch.nn.Module):
"""Pax rms norm in pytorch."""
def __init__(
self,
dim: int,
eps: float = 1e-6,
add_unit_offset: bool = False,
):
super().__init__()
self.eps = eps
self.add_unit_offset = add_unit_offset
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
if self.add_unit_offset:
output = output * (1 + self.weight.float())
else:
output = output * self.weight.float()
return output.type_as(x)
class TransformerMLP(nn.Module):
"""Pax transformer MLP in pytorch."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size)
self.down_proj = nn.Linear(intermediate_size, hidden_size)
self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
def forward(self, x, paddings=None):
gate_inp = self.layer_norm(x)
gate = self.gate_proj(gate_inp)
gate = F.relu(gate)
outputs = self.down_proj(gate)
if paddings is not None:
outputs = outputs * (1.0 - paddings[:, :, None])
return outputs + x
class TimesFMAttention(nn.Module):
"""Implements the attention used in TimesFM."""
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.hidden_size = hidden_size
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = nn.Parameter(
torch.empty((self.head_dim,), dtype=torch.float32),)
self.qkv_proj = nn.Linear(
self.hidden_size,
(self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor:
# [batch_size, n_local_heads, input_len, head_dim]
r_softplus_0 = 1.442695041
softplus_func = torch.nn.Softplus()
scale = r_softplus_0 / math.sqrt(self.head_dim)
scale = scale * softplus_func(self.scaling)
return query * scale[None, None, None, :]
def forward(
self,
hidden_states: torch.Tensor,
mask: torch.Tensor,
kv_write_indices: torch.Tensor | None = None,
kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
hidden_states_shape = hidden_states.shape
assert len(hidden_states_shape) == 3
batch_size, input_len, _ = hidden_states_shape
qkv = self.qkv_proj(hidden_states)
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xq = self._per_dim_scaling(xq)
# Write new kv cache.
# [batch_size, input_len, n_local_kv_heads, head_dim]
if kv_cache is not None and kv_write_indices is not None:
k_cache, v_cache = kv_cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)
key = k_cache
value = v_cache
else:
key = xk
value = xv
if self.num_kv_heads != self.num_heads:
# [batch_size, max_seq_len, n_local_heads, head_dim]
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2)
# [batch_size, n_local_heads, input_len, head_dim]
q = xq.transpose(1, 2)
# [batch_size, n_local_heads, max_seq_len, head_dim]
k = key.transpose(1, 2)
v = value.transpose(1, 2)
# [batch_size, n_local_heads, input_len, max_seq_len]
scores = torch.matmul(q, k.transpose(2, 3))
scores = scores + mask
scores = F.softmax(scores.float(), dim=-1).type_as(q)
# [batch_size, n_local_heads, input_len, head_dim]
output = torch.matmul(scores, v)
# return scores, output.transpose(1, 2).contiguous()
# [batch_size, input_len, hidden_dim]
output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)
output = self.o_proj(output)
return scores, output
class TimesFMDecoderLayer(nn.Module):
"""Transformer layer."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
rms_norm_eps: float = 1e-6,
):
super().__init__()
self.self_attn = TimesFMAttention(
hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
)
self.mlp = TransformerMLP(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
)
self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
mask: torch.Tensor,
paddings: torch.Tensor,
kv_write_indices: torch.Tensor | None = None,
kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
scores, hidden_states = self.self_attn(
hidden_states=hidden_states,
mask=mask,
kv_write_indices=kv_write_indices,
kv_cache=kv_cache,
)
hidden_states = residual + hidden_states
# MLP
hidden_states = self.mlp(hidden_states, paddings=paddings)
return scores, hidden_states
class StackedDecoder(nn.Module):
"""Stacked transformer layer."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
num_layers: int,
rms_norm_eps: float = 1e-6,
):
super().__init__()
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(
TimesFMDecoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
))
def forward(
self,
hidden_states: torch.Tensor,
paddings: torch.Tensor,
kv_write_indices: torch.Tensor | None = None,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,
) -> torch.Tensor:
padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype)
atten_mask = causal_mask(hidden_states)
mask = merge_masks(padding_mask, atten_mask)
for i in range(len(self.layers)):
layer = self.layers[i]
kv_cache = kv_caches[i] if kv_caches is not None else None
_, hidden_states = layer(
hidden_states=hidden_states,
mask=mask,
paddings=paddings,
kv_write_indices=kv_write_indices,
kv_cache=kv_cache,
)
return hidden_states
class PositionalEmbedding(torch.nn.Module):
"""Generates position embedding for a given 1-d sequence.
Attributes:
min_timescale: Start of the geometric index. Determines the periodicity of
the added signal.
max_timescale: End of the geometric index. Determines the frequency of the
added signal.
embedding_dims: Dimension of the embedding to be generated.
"""
def __init__(
self,
embedding_dims: int,
min_timescale: int = 1,
max_timescale: int = 10_000,
) -> None:
super().__init__()
self.min_timescale = min_timescale
self.max_timescale = max_timescale
self.embedding_dims = embedding_dims
def forward(self, seq_length=None, position=None):
"""Generates a Tensor of sinusoids with different frequencies.
Args:
seq_length: an optional Python int defining the output sequence length.
if the `position` argument is specified.
position: [B, seq_length], optional position for each token in the
sequence, only required when the sequence is packed.
Returns:
[B, seqlen, D] if `position` is specified, else [1, seqlen, D]
"""
if position is None:
assert seq_length is not None
# [1, seqlen]
position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0)
else:
assert position.ndim == 2, position.shape
num_timescales = self.embedding_dims // 2
log_timescale_increment = math.log(
float(self.max_timescale) / float(self.min_timescale)) / max(
num_timescales - 1, 1)
inv_timescales = self.min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float32) *
-log_timescale_increment)
scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(
0)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
# Padding to ensure correct embedding dimension
signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
return signal
class PatchedTimeSeriesDecoder(nn.Module):
"""Patched time-series decoder."""
def __init__(self, config: TimesFMConfig):
super().__init__()
self.config = config
self.input_ff_layer = ResidualBlock(
input_dims=2 * config.patch_len,
output_dims=config.hidden_size,
hidden_dims=config.intermediate_size,
)
self.freq_emb = nn.Embedding(num_embeddings=3,
embedding_dim=config.hidden_size)
self.horizon_ff_layer = ResidualBlock(
input_dims=config.hidden_size,
output_dims=config.horizon_len * (1 + len(config.quantiles)),
hidden_dims=config.intermediate_size,
)
self.stacked_transformer = StackedDecoder(
hidden_size=self.config.hidden_size,
intermediate_size=self.config.intermediate_size,
num_heads=self.config.num_heads,
num_kv_heads=self.config.num_kv_heads,
head_dim=self.config.head_dim,
num_layers=self.config.num_layers,
rms_norm_eps=self.config.rms_norm_eps,
)
if self.config.use_positional_embedding:
self.position_emb = PositionalEmbedding(self.config.hidden_size)
def _forward_transform(
self, inputs: torch.Tensor, patched_pads: torch.Tensor
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""Input is of shape [B, N, P]."""
mu, sigma = _masked_mean_std(inputs, patched_pads)
sigma = torch.clamp(sigma, min=self.config.tolerance)
# Normalize each patch
outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
outputs = torch.where(
torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
torch.tensor(self.config.pad_val,
dtype=outputs.dtype,
device=outputs.device),
outputs,
)
return outputs, (mu, sigma)
def _reverse_transform(
self, outputs: torch.Tensor, stats: tuple[torch.Tensor,
torch.Tensor]) -> torch.Tensor:
"""Output is of shape [B, N, P, Q]."""
mu, sigma = stats
return outputs * sigma[:, None, None, None] + mu[:, None, None, None]
def _preprocess_input(
self,
input_ts: torch.Tensor,
input_padding: torch.Tensor,
) -> tuple[
torch.Tensor,
torch.Tensor,
tuple[torch.Tensor, torch.Tensor] | None,
torch.Tensor,
]:
"""Preprocess input for stacked transformer."""
# Reshape into patches (using view for efficiency)
bsize = input_ts.shape[0]
patched_inputs = input_ts.view(bsize, -1, self.config.patch_len)
patched_pads = input_padding.view(bsize, -1, self.config.patch_len)
patched_inputs = torch.where(
torch.abs(patched_pads - 1.0) < self.config.tolerance,
torch.tensor(0.0,
dtype=patched_inputs.dtype,
device=patched_inputs.device),
patched_inputs,
)
patched_pads = torch.where(
torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
patched_pads,
)
patched_inputs, stats = self._forward_transform(patched_inputs,
patched_pads)
# B x N x D
patched_inputs = patched_inputs * (1.0 - patched_pads)
concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
model_input = self.input_ff_layer(concat_inputs)
# A patch should not be padded even if there is at least one zero.
patched_padding = torch.min(patched_pads,
dim=-1)[0] # Get the values from the min result
if self.config.use_positional_embedding:
pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device)
pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
pos_emb = _shift_padded_seq(patched_padding, pos_emb)
model_input += pos_emb
return model_input, patched_padding, stats, patched_inputs
def _postprocess_output(
self,
model_output: torch.Tensor,
num_outputs: int,
stats: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
"""Postprocess output of stacked transformer."""
# B x N x (H.Q)
output_ts = self.horizon_ff_layer(model_output)
# Reshape using view
b, n, _ = output_ts.shape
output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs)
return self._reverse_transform(output_ts, stats)
def forward(
self,
input_ts: torch.Tensor,
input_padding: torch.LongTensor,
freq: torch.Tensor,
) -> torch.Tensor:
num_outputs = len(self.config.quantiles) + 1
model_input, patched_padding, stats, _ = self._preprocess_input(
input_ts=input_ts,
input_padding=input_padding,
)
f_emb = self.freq_emb(freq) # B x 1 x D
model_input += f_emb
model_output = self.stacked_transformer(model_input, patched_padding)
output_ts = self._postprocess_output(model_output, num_outputs, stats)
return output_ts
def decode(
self,
input_ts: torch.Tensor,
paddings: torch.Tensor,
freq: torch.LongTensor,
horizon_len: int,
output_patch_len: int | None = None,
max_len: int | None = None,
return_forecast_on_context: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Auto-regressive decoding without caching.
Args:
input_ts: input time-series and paddings. Time-series shape B x C.
paddings: padding shape B x (C + H) where H is the prediction length.
freq: frequency shape B x 1
horizon_len: prediction length.
output_patch_len: output length to be fetched from one step of
auto-regressive decoding.
max_len: maximum training context length.
return_forecast_on_context: whether to return the model forecast on the
context except the first input patch.
Returns:
Tuple of two forecasting results:
- Point (mean) output predictions as a tensor with shape B x H'.
- Full predictions (mean and quantiles) as a tensor with shape
B x H' x (1 + # quantiles).
In particular, if return_forecast_on_context is True, H' is H plus
the forecastable context length, i.e. context_len - (first) patch_len.
"""
final_out = input_ts
context_len = final_out.shape[1]
full_outputs = []
if max_len is None:
max_len = context_len
if paddings.shape[1] != final_out.shape[1] + horizon_len:
raise ValueError(
"Length of paddings must match length of input + horizon_len:"
f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}")
if output_patch_len is None:
output_patch_len = self.config.horizon_len
num_decode_patches = (horizon_len + output_patch_len -
1) // output_patch_len
for step_index in range(num_decode_patches):
current_padding = paddings[:, 0:final_out.shape[1]]
input_ts = final_out[:, -max_len:]
input_padding = current_padding[:, -max_len:]
fprop_outputs = self(input_ts, input_padding, freq)
if return_forecast_on_context and step_index == 0:
# For the first decodings step, collect the model forecast on the
# context except the unavailable first input batch forecast.
new_full_ts = fprop_outputs[:, 0:-1, 0:self.config.patch_len, :]
new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1,
new_full_ts.size(3))
full_outputs.append(new_full_ts)
# (full batch, last patch, output_patch_len, index of mean forecast = 0)
new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
# (full batch, last patch, output_patch_len, all output indices)
full_outputs.append(new_full_ts)
final_out = torch.concatenate([final_out, new_ts], axis=-1)
if return_forecast_on_context:
# `full_outputs` indexing starts at after the first input patch.
full_outputs = torch.concatenate(
full_outputs,
axis=1)[:, :(context_len - self.config.patch_len + horizon_len), :]
else:
# `full_outputs` indexing starts at the forecast horizon.
full_outputs = torch.concatenate(full_outputs, axis=1)[:,
0:horizon_len, :]
return (full_outputs[:, :, 0], full_outputs)
================================================
FILE: v1/src/timesfm/time_features.py
================================================
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Directory to extract time covariates.
Extract time covariates from datetime.
"""
import numpy as np
import pandas as pd
from pandas.tseries.holiday import EasterMonday
from pandas.tseries.holiday import GoodFriday
from pandas.tseries.holiday import Holiday
from pandas.tseries.holiday import SU
from pandas.tseries.holiday import TH
from pandas.tseries.holiday import USColumbusDay
from pandas.tseries.holiday import USLaborDay
from pandas.tseries.holiday import USMartinLutherKingJr
from pandas.tseries.holiday import USMemorialDay
from pandas.tseries.holiday import USPresidentsDay
from pandas.tseries.holiday import USThanksgivingDay
from pandas.tseries.offsets import DateOffset
from pandas.tseries.offsets import Day
from pandas.tseries.offsets import Easter
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
# This is 183 to cover half a year (in both directions), also for leap years
# + 17 as Eastern can be between March, 22 - April, 25
MAX_WINDOW = 183 + 17
def _distance_to_holiday(holiday):
"""Return distance to given holiday."""
def _distance_to_day(index):
holiday_date = holiday.dates(
index - pd.Timedelta(days=MAX_WINDOW),
index + pd.Timedelta(days=MAX_WINDOW),
)
assert (
len(holiday_date) != 0 # pylint: disable=g-explicit-length-test
), f"No closest holiday for the date index {index} found."
# It sometimes returns two dates if it is exactly half a year after the
# holiday. In this case, the smaller distance (182 days) is returned.
return (index - holiday_date[0]).days
return _distance_to_day
EasterSunday = Holiday(
"Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)]
)
NewYearsDay = Holiday("New Years Day", month=1, day=1)
SuperBowl = Holiday(
"Superbowl", month=2, day=1, offset=DateOffset(weekday=SU(1))
)
MothersDay = Holiday(
"Mothers Day", month=5, day=1, offset=DateOffset(weekday=SU(2))
)
IndependenceDay = Holiday("Independence Day", month=7, day=4)
ChristmasEve = Holiday("Christmas", month=12, day=24)
ChristmasDay = Holiday("Christmas", month=12, day=25)
NewYearsEve = Holiday("New Years Eve", month=12, day=31)
BlackFriday = Holiday(
"Black Friday",
month=11,
day=1,
offset=[pd.DateOffset(weekday=TH(4)), Day(1)],
)
CyberMonday = Holiday(
"Cyber Monday",
month=11,
day=1,
offset=[pd.DateOffset(weekday=TH(4)), Day(4)],
)
HOLIDAYS = [
EasterMonday,
GoodFriday,
USColumbusDay,
USLaborDay,
USMartinLutherKingJr,
USMemorialDay,
USPresidentsDay,
USThanksgivingDay,
EasterSunday,
NewYearsDay,
SuperBowl,
MothersDay,
IndependenceDay,
ChristmasEve,
ChristmasDay,
NewYearsEve,
BlackFriday,
CyberMonday,
]
class TimeCovariates(object):
"""Extract all time covariates except for holidays."""
def __init__(
self,
datetimes,
normalized=True,
holiday=False,
):
"""Init function.
Args:
datetimes: pandas DatetimeIndex (lowest granularity supported is min)
normalized: whether to normalize features or not
holiday: fetch holiday features or not
Returns:
None
"""
self.normalized = normalized
self.dti = datetimes
self.holiday = holiday
def _minute_of_hour(self):
minutes = np.array(self.dti.minute, dtype=np.float32)
if self.normalized:
minutes = minutes / 59.0 - 0.5
return minutes
def _hour_of_day(self):
hours = np.array(self.dti.hour, dtype=np.float32)
if self.normalized:
hours = hours / 23.0 - 0.5
return hours
def _day_of_week(self):
day_week = np.array(self.dti.dayofweek, dtype=np.float32)
if self.normalized:
day_week = day_week / 6.0 - 0.5
return day_week
def _day_of_month(self):
day_month = np.array(self.dti.day, dtype=np.float32)
if self.normalized:
day_month = day_month / 30.0 - 0.5
return day_month
def _day_of_year(self):
day_year = np.array(self.dti.dayofyear, dtype=np.float32)
if self.normalized:
day_year = day_year / 364.0 - 0.5
return day_year
def _month_of_year(self):
month_year = np.array(self.dti.month, dtype=np.float32)
if self.normalized:
month_year = month_year / 11.0 - 0.5
return month_year
def _week_of_year(self):
week_year = np.array(self.dti.strftime("%U").astype(int), dtype=np.float32)
if self.normalized:
week_year = week_year / 51.0 - 0.5
return week_year
def _get_holidays(self):
dti_series = self.dti.to_series()
hol_variates = np.vstack([
dti_series.apply(_distance_to_holiday(h)).values for h in tqdm(HOLIDAYS)
])
# hol_variates is (num_holiday, num_time_steps), the normalization should be
# performed in the num_time_steps dimension.
return StandardScaler().fit_transform(hol_variates.T).T
def get_covariates(self):
"""Get all time covariates."""
moh = self._minute_of_hour().reshape(1, -1)
hod = self._hour_of_day().reshape(1, -1)
dom = self._day_of_month().reshape(1, -1)
dow = self._day_of_week().reshape(1, -1)
doy = self._day_of_year().reshape(1, -1)
moy = self._month_of_year().reshape(1, -1)
woy = self._week_of_year().reshape(1, -1)
all_covs = [
moh,
hod,
dom,
dow,
doy,
moy,
woy,
]
columns = ["moh", "hod", "dom", "dow", "doy", "moy", "woy"]
if self.holiday:
hol_covs = self._get_holidays()
all_covs.append(hol_covs)
columns += [f"hol_{i}" for i in range(len(HOLIDAYS))]
return pd.DataFrame(
data=np.vstack(all_covs).transpose(),
columns=columns,
index=self.dti,
)
================================================
FILE: v1/src/timesfm/timesfm_base.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for TimesFM inference. This will be common to PAX and Pytorch."""
import collections
import dataclasses
import logging
import multiprocessing
from typing import Any, Literal, Sequence, TYPE_CHECKING
import numpy as np
import pandas as pd
from utilsforecast.processing import make_future_dataframe
if TYPE_CHECKING:
from . import xreg_lib
Category = xreg_lib.Category
XRegMode = xreg_lib.XRegMode
else:
Category = int | str
XRegMode = str
_TOL = 1e-6
DEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
def process_group(key, group, value_name, forecast_context_len):
group = group.tail(forecast_context_len)
return np.array(group[value_name], dtype=np.float32), key
def moving_average(arr, window_size):
"""Calculates the moving average using NumPy's convolution function."""
# Pad with zeros to handle initial window positions
arr_padded = np.pad(arr, (window_size - 1, 0), "constant")
smoothed_arr = (np.convolve(arr_padded, np.ones(window_size), "valid") /
window_size)
return [smoothed_arr, arr - smoothed_arr]
def freq_map(freq: str):
"""Returns the frequency map for the given frequency string."""
freq = str.upper(freq)
if freq.endswith("MS"):
return 1
elif freq.endswith(("H", "T", "MIN", "D", "B", "U", "S")):
return 0
elif (
freq.endswith(("W", "M"))
or freq.startswith("W-")
or (freq.startswith("M") and len(freq) == 2)
):
return 1
elif (
freq.endswith(("Y", "Q", "A"))
or freq.startswith("Y-")
or freq.startswith("Q-")
or freq.startswith("A-")
):
return 2
else:
raise ValueError(f"Invalid frequency: {freq}")
def strip_leading_nans(arr):
"""
Removes contiguous NaN values from the beginning of a NumPy array.
Args:
arr: The input NumPy array.
Returns:
A new NumPy array with leading NaN values removed.
If the array is all NaNs or empty, returns an empty array.
"""
isnan = np.isnan(arr)
first_valid_index = np.argmax(~isnan)
return arr[first_valid_index:]
def linear_interpolation(arr):
"""
Performs linear interpolation to fill NaN values in a 1D numpy array.
Args:
arr: The 1D numpy array containing NaN values.
Returns:
A new numpy array with NaN values filled using linear interpolation,
or the original array if no NaNs are present.
Returns None if the input is not a 1D array.
Returns the original array if there are no NaN values.
"""
nans = np.isnan(arr)
if not np.any(nans): # Check if there are any NaNs
return arr
def x(z):
return z.nonzero()[0]
nans_indices = x(nans)
non_nans_indices = x(~nans)
non_nans_values = arr[~nans]
try:
arr[nans] = np.interp(nans_indices, non_nans_indices, non_nans_values)
except ValueError:
if len(non_nans_values) > 0:
mu = np.nanmean(arr)
else:
mu = 0.0
arr = np.where(np.isfinite(arr), arr, mu)
return arr
# Per time series normalization: forward.
def _normalize(batch):
stats = [
(np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch
]
new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)]
return new_batch, stats
# Per time series normalization: inverse.
def _renormalize(batch, stats):
return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)]
@dataclasses.dataclass(kw_only=True)
class TimesFmHparams:
"""Hparams used to initialize a TimesFM model for inference.
These are the sufficient subset of hparams to configure TimesFM inference
agnostic to the checkpoint version, and are not necessarily the same as the
hparams used to train the checkpoint.
Attributes:
context_len: Largest context length the model allows for each decode call.
This technically can be any large, but practically should set to the
context length the checkpoint was trained with.
horizon_len: Forecast horizon.
input_patch_len: Input patch len.
output_patch_len: Output patch len. How many timepoints is taken from a
single step of autoregressive decoding. Can be set as the training horizon
of the checkpoint.
num_layers: Number of transformer layers in the model.
model_dims: Model dimension.
per_core_batch_size: Batch size on each core for data parallelism.
backend: One of "cpu", "gpu" or "tpu".
quantiles: Which quantiles are output by the model.
"""
context_len: int = 512
horizon_len: int = 128
input_patch_len: int = 32
output_patch_len: int = 128
num_layers: int = 20
num_heads: int = 16
model_dims: int = 1280
per_core_batch_size: int = 32
backend: Literal["cpu", "gpu", "tpu"] = "cpu"
quantiles: Sequence[float] | None = DEFAULT_QUANTILES
use_positional_embedding: bool = True
# Hparams beyond the model.
point_forecast_mode: Literal["mean", "median"] = "median"
@dataclasses.dataclass(kw_only=True)
class TimesFmCheckpoint:
"""Checkpoint used to initialize a TimesFM model for inference.
Attributes:
version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc.
The factory will create the corresponding TimesFm inference class based on
this version.
path: Path to the checkpoint.
type: If provided, type of the checkpoint used by the specific checkpoint
loader per version.
step: If provided, step of the checkpoint.
"""
version: str = "jax"
path: str | None = None
huggingface_repo_id: str | None = None
type: Any = None
step: int | None = None
local_dir: str | None = None
class TimesFmBase:
"""Base TimesFM forecast API for inference.
This class is the scaffolding for calling TimesFM forecast. To properly use:
1. Create an instance with the correct hyperparameters of a TimesFM model.
2. Call `load_from_checkpoint` to load a compatible checkpoint.
3. Call `forecast` for inference.
"""
def _logging(self, s):
print(s)
def __post_init__(self) -> None:
"""Additional initialization for subclasses before checkpoint loading."""
pass
def __init__(self, hparams: TimesFmHparams,
checkpoint: TimesFmCheckpoint) -> None:
"""Initializes the TimesFM forecast API.
Args:
hparams: Hyperparameters of the model.
checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide
which TimesFM version to use.
"""
self.hparams = hparams
# Expand hparams for conciseness within the model code.
self.context_len = hparams.context_len
self.horizon_len = hparams.horizon_len
self.input_patch_len = hparams.input_patch_len
self.output_patch_len = hparams.output_patch_len
self.num_layers = hparams.num_layers
self.model_dims = hparams.model_dims
self.backend = hparams.backend
self.quantiles = hparams.quantiles
self.num_heads = hparams.num_heads
self.use_pos_emb = hparams.use_positional_embedding
# Rewrite these values in __post_init__ for SPMD.
self.num_cores = 1
self.per_core_batch_size = hparams.per_core_batch_size
self.global_batch_size = hparams.per_core_batch_size
self._horizon_start = self.context_len - self.input_patch_len
self.__post_init__()
self.load_from_checkpoint(checkpoint)
def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None:
"""Loads a checkpoint and compiles the decoder."""
raise NotImplementedError("`load_from_checkpoint` is not implemented.")
def _preprocess(
self, inputs: Sequence[np.ndarray],
freq: Sequence[int]) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
"""Formats and pads raw inputs to feed into the model.
This function both pads each time series to match the context length, and
pads the inputs to meet the SPMD shape requirement.
Args:
inputs: A list of 1d JTensors. Each JTensor is the context time series of
a single forecast task.
freq: list of frequencies
Returns:
A tuple of:
- the padded input time series to meet the model required context.
- the padding indicator.
- the frequency of each input time series.
- the number of padded examples for SPMD so that each core has the same
number (a multiple of `batch_size`) of examples.
"""
input_ts, input_padding, inp_freq = [], [], []
pmap_pad = ((len(inputs) - 1) // self.global_batch_size +
1) * self.global_batch_size - len(inputs)
for i, ts in enumerate(inputs):
input_len = ts.shape[0]
padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float)
if input_len < self.context_len:
num_front_pad = self.context_len - input_len
ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts],
axis=0)
padding = np.concatenate(
[np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0)
elif input_len > self.context_len:
ts = ts[-self.context_len:]
padding = padding[-(self.context_len + self.horizon_len):]
input_ts.append(ts)
input_padding.append(padding)
inp_freq.append(freq[i])
# Padding the remainder batch.
for _ in range(pmap_pad):
input_ts.append(input_ts[-1])
input_padding.append(input_padding[-1])
inp_freq.append(inp_freq[-1])
return (
np.stack(input_ts, axis=0),
np.stack(input_padding, axis=0),
np.array(inp_freq).astype(np.int32).reshape(-1, 1),
pmap_pad,
)
def _forecast(
self,
inputs: Sequence[Any],
freq: Sequence[int] | None = None,
window_size: int | None = None,
forecast_context_len: int | None = None,
return_forecast_on_context: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
"""Forecasts on a list of time series.
Args:
inputs: list of time series forecast contexts. Each context time series
should be in a format convertible to JTensor by `jnp.array`.
freq: frequency of each context time series. 0 for high frequency
(default), 1 for medium, and 2 for low. Notice this is different from
the `freq` required by `forecast_on_df`.
window_size: window size of trend + residual decomposition. If None then
we do not do decomposition.
forecast_context_len: optional max context length.
return_forecast_on_context: True to return the forecast on the context
when available, i.e. after the first input patch.
Returns:
A tuple for np.array:
- the mean forecast of size (# inputs, # forecast horizon),
- the full forecast (mean + quantiles) of size
(# inputs, # forecast horizon, 1 + # quantiles).
Raises:
ValueError: If the checkpoint is not properly loaded.
"""
raise NotImplementedError("`_forecast` is not implemented.")
def forecast(
self,
inputs: Sequence[Any],
freq: Sequence[int] | None = None,
window_size: int | None = None,
forecast_context_len: int | None = None,
return_forecast_on_context: bool = False,
normalize: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
"""Forecasts on a list of time series.
Args:
inputs: list of time series forecast contexts. Each context time series
should be in a format convertible to JTensor by `jnp.array`.
freq: frequency of each context time series. 0 for high frequency
(default), 1 for medium, and 2 for low. Notice this is different from
the `freq` required by `forecast_on_df`.
window_size: window size of trend + residual decomposition. If None then
we do not do decomposition.
forecast_context_len: optional max context length.
return_forecast_on_context: True to return the forecast on the context
when available, i.e. after the first input patch.
normalize: If True, then we normalize the inputs before forecasting and
the outputs are then renormalized to the original scale.
Returns:
A tuple for np.array:
- the mean forecast of size (# inputs, # forecast horizon),
- the full forecast (mean + quantiles) of size
(# inputs, # forecast horizon, 1 + # quantiles).
Raises:
ValueError: If the checkpoint is not properly loaded.
"""
stats = None
tmp_inputs = []
for each_input in inputs:
arr = np.array(each_input)
if not np.isfinite(arr).all():
arr = np.where(np.isfinite(arr), arr, np.nan)
arr = strip_leading_nans(arr)
arr = linear_interpolation(arr)
tmp_inputs.append(arr)
inputs = tmp_inputs
if normalize:
inputs, stats = _normalize(inputs)
mean_forecast, quantile_forecast = self._forecast(
inputs,
freq,
window_size,
forecast_context_len,
return_forecast_on_context,
)
if stats is not None:
stats = np.array(stats)
mu = stats[:, 0]
sigma = stats[:, 1]
mean_forecast = mean_forecast * sigma[:, None] + mu[:, None]
quantile_forecast = (quantile_forecast * sigma[:, None, None] +
mu[:, None, None])
if self.hparams.point_forecast_mode == "mean":
return mean_forecast, quantile_forecast
elif self.hparams.point_forecast_mode == "median":
if self._median_index == -1:
for i, quantile in enumerate(self.quantiles):
if quantile == 0.5:
self._median_index = i
break
if self._median_index == -1:
raise ValueError("Median (0.5) is not found in the model quantiles:"
f" {self.quantiles}. Please check the hparams.")
return (
quantile_forecast[:, :, 1 + self._median_index],
quantile_forecast,
)
else:
raise ValueError(
"Unsupported point forecast mode:"
f" {self.hparams.point_forecast_mode}. Use 'mean' or 'median'.")
def forecast_with_covariates(
self,
inputs: list[Sequence[float]],
dynamic_numerical_covariates: (dict[str, Sequence[Sequence[float]]] |
None) = None,
dynamic_categorical_covariates: (dict[str, Sequence[Sequence[Category]]] |
None) = None,
static_numerical_covariates: dict[str, Sequence[float]] | None = None,
static_categorical_covariates: (dict[str, Sequence[Category]] |
None) = None,
freq: Sequence[int] | None = None,
window_size: int | None = None,
forecast_context_len: int | None = None,
xreg_mode: XRegMode = "xreg + timesfm",
normalize_xreg_target_per_input: bool = True,
ridge: float = 0.0,
max_rows_per_col: int = 0,
force_on_cpu: bool = False,
):
"""Forecasts on a list of time series with covariates.
To optimize inference speed, avoid string valued categorical covariates.
Args:
inputs: A list of time series forecast contexts. Each context time series
should be in a format convertible to JTensor by `jnp.array`.
dynamic_numerical_covariates: A dict of dynamic numerical covariates.
dynamic_categorical_covariates: A dict of dynamic categorical covariates.
static_numerical_covariates: A dict of static numerical covariates.
static_categorical_covariates: A dict of static categorical covariates.
freq: frequency of each context time series. 0 for high frequency
(default), 1 for medium, and 2 for low. Notice this is different from
the `freq` required by `forecast_on_df`.
window_size: window size of trend + residual decomposition. If None then
we do not do decomposition.
forecast_context_len: optional max context length.
xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "xreg + timesfm"
fits a model on the residuals of the TimesFM forecast. "timesfm + xreg"
fits a model on the targets then forecasts on the residuals via TimesFM.
normalize_xreg_target_per_input: whether to normalize the xreg target per
input in the given batch.
ridge: ridge penalty for the linear model.
max_rows_per_col: max number of rows per column for the linear model.
force_on_cpu: whether to force running on cpu for the linear model.
Returns:
A tuple of two lists. The first is the outputs of the model. The second is
the outputs of the xreg.
"""
from . import xreg_lib
# Verify and bookkeep covariates.
if not (dynamic_numerical_covariates or dynamic_categorical_covariates or
static_numerical_covariates or static_categorical_covariates):
raise ValueError(
"At least one of dynamic_numerical_covariates,"
" dynamic_categorical_covariates, static_numerical_covariates,"
" static_categorical_covariates must be set.")
# Track the lengths of (1) each input, (2) the part that can be used in the
# linear model, and (3) the horizon.
input_lens, train_lens, test_lens = [], [], []
for i, input_ts in enumerate(inputs):
input_len = len(input_ts)
input_lens.append(input_len)
if xreg_mode == "timesfm + xreg":
# For fitting residuals, no TimesFM forecast on the first patch.
train_lens.append(max(0, input_len - self.input_patch_len))
elif xreg_mode == "xreg + timesfm":
train_lens.append(input_len)
else:
raise ValueError(f"Unsupported mode: {xreg_mode}")
if dynamic_numerical_covariates:
test_lens.append(
len(list(dynamic_numerical_covariates.values())[0][i]) - input_len)
elif dynamic_categorical_covariates:
test_lens.append(
len(list(dynamic_categorical_covariates.values())[0][i]) -
input_len)
else:
test_lens.append(self.horizon_len)
if test_lens[-1] > self.horizon_len:
raise ValueError(
"Forecast requested longer horizon than the model definition "
f"supports: {test_lens[-1]} vs {self.horizon_len}.")
# Prepare the covariates into train and test.
train_dynamic_numerical_covariates = collections.defaultdict(list)
test_dynamic_numerical_covariates = collections.defaultdict(list)
train_dynamic_categorical_covariates = collections.defaultdict(list)
test_dynamic_categorical_covariates = collections.defaultdict(list)
for covariates, train_covariates, test_covariates in (
(
dynamic_numerical_covariates,
train_dynamic_numerical_covariates,
test_dynamic_numerical_covariates,
),
(
dynamic_categorical_covariates,
train_dynamic_categorical_covariates,
test_dynamic_categorical_covariates,
),
):
if not covariates:
continue
for covariate_name, covariate_values in covariates.items():
for input_len, train_len, covariate_value in zip(
input_lens, train_lens, covariate_values):
train_covariates[covariate_name].append(
covariate_value[(input_len - train_len):input_len])
test_covariates[covariate_name].append(covariate_value[input_len:])
# Fit models.
if xreg_mode == "timesfm + xreg":
# Forecast via TimesFM then fit a model on the residuals.
mean_outputs, _ = self.forecast(
inputs,
freq,
window_size,
forecast_context_len,
return_forecast_on_context=True,
)
targets = [
(np.array(input_ts)[-train_len:] -
mean_output[(self._horizon_start - train_len):self._horizon_start])
for input_ts, mean_output, train_len in zip(inputs, mean_outputs,
train_lens)
]
per_instance_stats = None
if normalize_xreg_target_per_input:
targets, per_instance_stats = _normalize(targets)
xregs = xreg_lib.BatchedInContextXRegLinear(
targets=targets,
train_lens=train_lens,
test_lens=test_lens,
train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
train_dynamic_categorical_covariates=
train_dynamic_categorical_covariates,
test_dynamic_categorical_covariates=
test_dynamic_categorical_covariates,
static_numerical_covariates=static_numerical_covariates,
static_categorical_covariates=static_categorical_covariates,
).fit(
ridge=ridge,
one_hot_encoder_drop=None if ridge > 0 else "first",
max_rows_per_col=max_rows_per_col,
force_on_cpu=force_on_cpu,
debug_info=False,
assert_covariates=True,
assert_covariate_shapes=True,
)
if normalize_xreg_target_per_input:
xregs = _renormalize(xregs, per_instance_stats)
outputs = [
(mean_output[self._horizon_start:(self._horizon_start + test_len)] +
xreg)
for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs)
]
else:
# Fit a model on the targets then forecast on the residuals via TimesFM.
targets = [
np.array(input_ts)[-train_len:]
for input_ts, train_len in zip(inputs, train_lens)
]
per_instance_stats = None
if normalize_xreg_target_per_input:
targets, per_instance_stats = _normalize(targets)
xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear(
targets=targets,
train_lens=train_lens,
test_lens=test_lens,
train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
train_dynamic_categorical_covariates=
train_dynamic_categorical_covariates,
test_dynamic_categorical_covariates=
test_dynamic_categorical_covariates,
static_numerical_covariates=static_numerical_covariates,
static_categorical_covariates=static_categorical_covariates,
).fit(
ridge=ridge,
one_hot_encoder_drop=None if ridge > 0 else "first",
max_rows_per_col=max_rows_per_col,
force_on_cpu=force_on_cpu,
debug_info=True,
assert_covariates=True,
assert_covariate_shapes=True,
)
mean_outputs, _ = self.forecast(
[
target - xreg_on_context
for target, xreg_on_context in zip(targets, xregs_on_context)
],
freq,
window_size,
forecast_context_len,
return_forecast_on_context=True,
)
outputs = [
(mean_output[self._horizon_start:(self._horizon_start + test_len)] +
xreg)
for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs)
]
if normalize_xreg_target_per_input:
outputs = _renormalize(outputs, per_instance_stats)
return outputs, xregs
def forecast_on_df(
self,
inputs: pd.DataFrame,
freq: str,
forecast_context_len: int = 0,
value_name: str = "values",
model_name: str = "timesfm",
window_size: int | None = None,
num_jobs: int = 1,
normalize: bool = False,
verbose: bool = True,
) -> pd.DataFrame:
"""Forecasts on a list of time series.
Args:
inputs: A pd.DataFrame of all time series. The dataframe should have a
`unique_id` column for identifying the time series, a `ds` column for
timestamps and a value column for the time series values.
freq: string valued `freq` of data. Notice this is different from the
`freq` required by `forecast`. See `freq_map` for allowed values.
forecast_context_len: If provided none zero, we take the last
`forecast_context_len` time-points from each series as the forecast
context instead of the `context_len` set by the model.
value_name: The name of the value column.
model_name: name of the model to be written into future df.
window_size: window size of trend + residual decomposition. If None then
we do not do decomposition.
num_jobs: number of parallel processes to use for dataframe processing.
normalize: normalize context before forecasting or not.
verbose: output model states in terminal.
Returns:
Future forecasts dataframe.
"""
if not ("unique_id" in inputs.columns and "ds" in inputs.columns and
value_name in inputs.columns):
raise ValueError(
f"DataFrame must have unique_id, ds and {value_name} columns.")
if not forecast_context_len:
forecast_context_len = self.context_len
logging.info("Preprocessing dataframe.")
df_sorted = inputs.sort_values(by=["unique_id", "ds"])
new_inputs = []
uids = []
if num_jobs == 1:
if verbose:
print("Processing dataframe with single process.")
for key, group in df_sorted.groupby("unique_id"):
inp, uid = process_group(
key,
group,
value_name,
forecast_context_len,
)
new_inputs.append(inp)
uids.append(uid)
else:
if num_jobs == -1:
num_jobs = multiprocessing.cpu_count()
if verbose:
print("Processing dataframe with multiple processes.")
with multiprocessing.Pool(processes=num_jobs) as pool:
results = pool.starmap(
process_group,
[(key, group, value_name, forecast_context_len)
for key, group in df_sorted.groupby("unique_id")],
)
new_inputs, uids = zip(*results)
if verbose:
print("Finished preprocessing dataframe.")
freq_inps = [freq_map(freq)] * len(new_inputs)
_, full_forecast = self.forecast(new_inputs,
freq=freq_inps,
normalize=normalize,
window_size=window_size)
if verbose:
print("Finished forecasting.")
fcst_df = make_future_dataframe(
uids=uids,
last_times=df_sorted.groupby("unique_id")["ds"].tail(1),
h=self.horizon_len,
freq=freq,
)
fcst_df[model_name] = full_forecast[:, 0:self.horizon_len, 0].reshape(-1, 1)
for i, q in enumerate(self.quantiles):
q_col = f"{model_name}-q-{q}"
fcst_df[q_col] = full_forecast[:, 0:self.horizon_len,
1 + i].reshape(-1, 1)
if q == 0.5:
fcst_df[model_name] = fcst_df[q_col]
logging.info("Finished creating output dataframe.")
return fcst_df
================================================
FILE: v1/src/timesfm/timesfm_jax.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM JAX forecast API for inference."""
import logging
import multiprocessing
import time
from os import path
from typing import Any, Sequence
import einshape as es
import jax
import jax.numpy as jnp
import numpy as np
from huggingface_hub import snapshot_download
from paxml import checkpoints, tasks_lib
from praxis import base_hyperparams, base_layer, pax_fiddle, py_utils, pytypes
from praxis.layers import normalizations, transformers
from timesfm import timesfm_base
from timesfm import patched_decoder
instantiate = base_hyperparams.instantiate
NestedMap = py_utils.NestedMap
JTensor = pytypes.JTensor
_TOL = 1e-6
class TimesFmJax(timesfm_base.TimesFmBase):
"""TimesFM forecast API for inference.
This class is the scaffolding for calling TimesFM forecast. To properly use:
1. Create an instance with the correct hyperparameters of a TimesFM model.
2. Call `load_from_checkpoint` to load a compatible checkpoint.
3. Call `forecast` for inference.
Given the model size, this API does not shard the model weights for SPMD. All
parallelism happens on the data dimension.
Compilation happens during the first time `forecast` is called and uses the
`per_core_batch_size` to set and freeze the input signature. Subsequent calls
to `forecast` reflect the actual inference latency.
"""
def _get_sample_inputs(self):
return {
"input_ts":
jnp.zeros(
(
self.per_core_batch_size,
self.context_len + self.output_patch_len,
),
dtype=jnp.float32,
),
"input_padding":
jnp.zeros(
(
self.per_core_batch_size,
self.context_len + self.output_patch_len,
),
dtype=jnp.float32,
),
"freq":
jnp.zeros(
(
self.per_core_batch_size,
1,
),
dtype=jnp.int32,
),
}
def __post_init__(self):
self.num_cores = jax.local_device_count(self.backend)
self.global_batch_size = self.per_core_batch_size * self.num_cores
self._eval_context = base_layer.JaxContext.HParams(do_eval=True)
self._pmapped_decode = None
self._model = None
self._train_state = None
self._median_index = -1
def load_from_checkpoint(
self,
checkpoint: timesfm_base.TimesFmCheckpoint,
) -> None:
"""Loads a checkpoint and compiles the decoder."""
checkpoint_type = (checkpoints.CheckpointType.FLAX
if checkpoint.type is None else checkpoint.type)
checkpoint_path = checkpoint.path
step = checkpoint.step
repo_id = checkpoint.huggingface_repo_id
if checkpoint_path is None:
checkpoint_path = path.join(snapshot_download(repo_id), "checkpoints")
# Rewrite the devices for Jax.
self.mesh_shape = [1, self.num_cores, 1]
self.mesh_name = ["replica", "data", "mdl"]
self.model_p = pax_fiddle.Config(
patched_decoder.PatchedTimeSeriesDecoder,
name="patched_decoder",
horizon_len=self.output_patch_len,
patch_len=self.input_patch_len,
model_dims=self.model_dims,
hidden_dims=self.model_dims,
residual_block_tpl=pax_fiddle.Config(patched_decoder.ResidualBlock),
quantiles=self.quantiles,
use_freq=True,
use_pos_emb=self.use_pos_emb,
stacked_transformer_params_tpl=pax_fiddle.Config(
transformers.StackedTransformer,
num_heads=self.num_heads,
num_layers=self.num_layers,
transformer_layer_params_tpl=pax_fiddle.Config(
transformers.Transformer,
ln_tpl=pax_fiddle.Config(normalizations.RmsNorm,),
),
),
)
self._key1, self._key2 = jax.random.split(jax.random.PRNGKey(42))
self._model = None
self._train_state = None
self._pmapped_decode = None
self._eval_context = base_layer.JaxContext.HParams(do_eval=True)
try:
multiprocessing.set_start_method("spawn")
except RuntimeError:
print("Multiprocessing context has already been set.")
# Download the checkpoint from Hugging Face Hub if not given
# Initialize the model weights.
self._logging("Constructing model weights.")
start_time = time.time()
self._model = instantiate(self.model_p)
var_weight_hparams = self._model.abstract_init_with_metadata(
self._get_sample_inputs(), do_eval=True)
train_state_partition_specs = tasks_lib.create_state_partition_specs(
var_weight_hparams,
mesh_shape=self.mesh_shape,
mesh_axis_names=self.mesh_name,
discard_opt_states=True,
learners=None,
)
train_state_local_shapes = tasks_lib.create_state_unpadded_shapes(
var_weight_hparams,
discard_opt_states=True,
learners=None,
)
self._logging(
f"Constructed model weights in {time.time() - start_time:.2f} seconds.")
# Load the model weights.
self._logging(f"Restoring checkpoint from {checkpoint_path}.")
start_time = time.time()
self._train_state = checkpoints.restore_checkpoint(
train_state_local_shapes,
checkpoint_dir=checkpoint_path,
checkpoint_type=checkpoint_type,
state_specs=train_state_partition_specs,
step=step,
)
self._logging(
f"Restored checkpoint in {time.time() - start_time:.2f} seconds.")
self.jit_decode()
def jit_decode(self):
"""Jitting decoding function."""
# Initialize and jit the decode fn.
def _decode(inputs):
assert self._model is not None
assert self._train_state is not None
return self._model.apply(
self._train_state.mdl_vars,
inputs,
horizon_len=self.horizon_len,
output_patch_len=self.output_patch_len,
max_len=self.context_len,
return_forecast_on_context=True,
rngs={
base_layer.PARAMS: self._key1,
base_layer.RANDOM: self._key2,
},
method=self._model.decode,
)
self._logging("Jitting decoding.")
start_time = time.time()
self._pmapped_decode = jax.pmap(
_decode,
axis_name="batch",
devices=jax.devices(self.backend),
backend=self.backend,
axis_size=self.num_cores,
)
with base_layer.JaxContext.new_context(hparams=self._eval_context):
_ = self._pmapped_decode(
NestedMap({
"input_ts":
jnp.zeros(
(
self.num_cores,
self.per_core_batch_size,
self.context_len,
),
dtype=jnp.float32,
),
"input_padding":
jnp.zeros(
(
self.num_cores,
self.per_core_batch_size,
self.context_len + self.horizon_len,
),
dtype=jnp.float32,
),
"date_features":
None,
"freq":
jnp.zeros(
(self.num_cores, self.per_core_batch_size, 1),
dtype=jnp.int32,
),
}))
self._logging(f"Jitted decoding in {time.time() - start_time:.2f} seconds.")
def _forecast(
self,
inputs: Sequence[Any],
freq: Sequence[int] | None = None,
window_size: int | None = None,
forecast_context_len: int | None = None,
return_forecast_on_context: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
"""Forecasts on a list of time series.
Args:
inputs: list of time series forecast contexts. Each context time series
should be in a format convertible to JTensor by `jnp.array`.
freq: frequency of each context time series. 0 for high frequency
(default), 1 for medium, and 2 for low. Notice this is different from
the `freq` required by `forecast_on_df`.
window_size: window size of trend + residual decomposition. If None then
we do not do decomposition.
forecast_context_len: optional max context length.
return_forecast_on_context: True to return the forecast on the context
when available, i.e. after the first input patch.
Returns:
A tuple for JTensors:
- the mean forecast of size (# inputs, # forecast horizon),
- the full forecast (mean + quantiles) of size
(# inputs, # forecast horizon, 1 + # quantiles).
Raises:
ValueError: If the checkpoint is not properly loaded.
"""
if not self._train_state or not self._model:
raise ValueError(
"Checkpoint not loaded. Call `load_from_checkpoint` before"
" `forecast`.")
if forecast_context_len is None:
fcontext_len = self.context_len
else:
fcontext_len = forecast_context_len
inputs = [np.array(ts)[-fcontext_len:] for ts in inputs]
if window_size is not None:
new_inputs = []
for ts in inputs:
new_inputs.extend(timesfm_base.moving_average(ts, window_size))
inputs = new_inputs
if freq is None:
logging.info("No frequency provided via `freq`. Default to high (0).")
freq = [0] * len(inputs)
input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq)
with base_layer.JaxContext.new_context(hparams=self._eval_context):
mean_outputs = []
full_outputs = []
assert input_ts.shape[0] % self.global_batch_size == 0
for i in range(input_ts.shape[0] // self.global_batch_size):
input_ts_in = jnp.array(input_ts[i * self.global_batch_size:(i + 1) *
self.global_batch_size])
input_padding_in = jnp.array(
input_padding[i * self.global_batch_size:(i + 1) *
self.global_batch_size],)
inp_freq_in = jnp.array(
inp_freq[i * self.global_batch_size:(i + 1) *
self.global_batch_size, :],
dtype=jnp.int32,
)
pmapped_inputs = NestedMap({
"input_ts":
es.jax_einshape(
"(db)...->db...",
input_ts_in,
d=self.num_cores,
),
"input_padding":
es.jax_einshape(
"(db)...->db...",
input_padding_in,
d=self.num_cores,
),
"date_features":
None,
"freq":
es.jax_einshape(
"(db)...->db...",
inp_freq_in,
d=self.num_cores,
),
})
mean_output, full_output = self._pmapped_decode(pmapped_inputs)
if not return_forecast_on_context:
mean_output = mean_output[:, :, self._horizon_start:, ...]
full_output = full_output[:, :, self._horizon_start:, ...]
mean_output = es.jax_einshape("db...->(db)...",
mean_output,
d=self.num_cores)
full_output = es.jax_einshape("db...->(db)...",
full_output,
d=self.num_cores)
mean_output = np.array(mean_output)
full_output = np.array(full_output)
mean_outputs.append(mean_output)
full_outputs.append(full_output)
mean_outputs = np.concatenate(mean_outputs, axis=0)
full_outputs = np.concatenate(full_outputs, axis=0)
if pmap_pad > 0:
mean_outputs = mean_outputs[:-pmap_pad, ...]
full_outputs = full_outputs[:-pmap_pad, ...]
if window_size is not None:
mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
return mean_outputs, full_outputs
================================================
FILE: v1/src/timesfm/timesfm_torch.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM pytorch forecast API for inference."""
import logging
from os import path
from typing import Any, Sequence
import numpy as np
import torch
from huggingface_hub import snapshot_download
from timesfm import timesfm_base
from . import pytorch_patched_decoder as ppd
_TOL = 1e-6
class TimesFmTorch(timesfm_base.TimesFmBase):
"""TimesFM forecast API for inference."""
def __post_init__(self):
self._model_config = ppd.TimesFMConfig(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.model_dims,
intermediate_size=self.model_dims,
patch_len=self.input_patch_len,
horizon_len=self.output_patch_len,
head_dim=self.model_dims // self.num_heads,
quantiles=self.quantiles,
use_positional_embedding=self.use_pos_emb,
)
self._model = None
self.num_cores = 1
self.global_batch_size = self.per_core_batch_size
self._device = torch.device("cuda:0" if (
torch.cuda.is_available() and self.backend == "gpu") else "cpu")
self._median_index = -1
def load_from_checkpoint(
self,
checkpoint: timesfm_base.TimesFmCheckpoint,
) -> None:
"""Loads a checkpoint and compiles the decoder."""
checkpoint_path = checkpoint.path
repo_id = checkpoint.huggingface_repo_id
if checkpoint_path is None:
checkpoint_path = path.join(
snapshot_download(repo_id, local_dir=checkpoint.local_dir),
"torch_model.ckpt")
self._model = ppd.PatchedTimeSeriesDecoder(self._model_config)
loaded_checkpoint = torch.load(checkpoint_path, weights_only=True)
logging.info("Loading checkpoint from %s", checkpoint_path)
self._model.load_state_dict(loaded_checkpoint)
logging.info("Sending checkpoint to device %s", f"{self._device}")
self._model.to(self._device)
self._model.eval()
# TODO: add compilation.
def _forecast(
self,
inputs: Sequence[Any],
freq: Sequence[int] | None = None,
window_size: int | None = None,
forecast_context_len: int | None = None,
return_forecast_on_context: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
"""Forecasts on a list of time series.
Args:
inputs: list of time series forecast contexts. Each context time series
should be in a format convertible to JTensor by `jnp.array`.
freq: frequency of each context time series. 0 for high frequency
(default), 1 for medium, and 2 for low. Notice this is different from
the `freq` required by `forecast_on_df`.
window_size: window size of trend + residual decomposition. If None then
we do not do decomposition.
forecast_context_len: optional max context length.
return_forecast_on_context: True to return the forecast on the context
when available, i.e. after the first input patch.
Returns:
A tuple for JTensors:
- the mean forecast of size (# inputs, # forecast horizon),
- the full forecast (mean + quantiles) of size
(# inputs, # forecast horizon, 1 + # quantiles).
Raises:
ValueError: If the checkpoint is not properly loaded.
"""
if self._model is None:
raise ValueError("Checkpoint is not properly loaded.")
if forecast_context_len is None:
forecast_context_len = self.context_len
inputs = [np.array(ts)[-forecast_context_len:] for ts in inputs]
if window_size is not None:
new_inputs = []
for ts in inputs:
new_inputs.extend(timesfm_base.moving_average(ts, window_size))
inputs = new_inputs
if freq is None:
logging.info("No frequency provided via `freq`. Default to high (0).")
freq = [0] * len(inputs)
input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq)
with torch.no_grad():
mean_outputs = []
full_outputs = []
for i in range(input_ts.shape[0] // self.global_batch_size):
t_input_ts = torch.Tensor(input_ts[i * self.global_batch_size:(i + 1) *
self.global_batch_size]).to(
self._device)
t_input_padding = torch.Tensor(
input_padding[i * self.global_batch_size:(i + 1) *
self.global_batch_size]).to(self._device)
t_inp_freq = torch.LongTensor(
inp_freq[i * self.global_batch_size:(i + 1) *
self.global_batch_size, :]).to(self._device)
mean_output, full_output = self._model.decode(
input_ts=t_input_ts,
paddings=t_input_padding,
freq=t_inp_freq,
horizon_len=self.horizon_len,
output_patch_len=self.output_patch_len,
# Returns forecasts on context for parity with the Jax version.
return_forecast_on_context=True,
)
if not return_forecast_on_context:
mean_output = mean_output[:, self._horizon_start:, ...]
full_output = full_output[:, self._horizon_start:, ...]
if self.backend == "gpu":
mean_output = mean_output.cpu()
full_output = full_output.cpu()
mean_output = mean_output.detach().numpy()
full_output = full_output.detach().numpy()
mean_outputs.append(mean_output)
full_outputs.append(full_output)
mean_outputs = np.concatenate(mean_outputs, axis=0)
full_outputs = np.concatenate(full_outputs, axis=0)
if pmap_pad > 0:
mean_outputs = mean_outputs[:-pmap_pad, ...]
full_outputs = full_outputs[:-pmap_pad, ...]
if window_size is not None:
mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
return mean_outputs, full_outputs
================================================
FILE: v1/src/timesfm/xreg_lib.py
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for in-context covariates and regression."""
import itertools
import math
from typing import Any, Iterable, Literal, Mapping, Sequence
import jax
import jax.numpy as jnp
import numpy as np
from sklearn import preprocessing
Category = int | str
_TOL = 1e-6
XRegMode = Literal["timesfm + xreg", "xreg + timesfm"]
def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:
return np.array(list(itertools.chain.from_iterable(nested)))
def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:
return np.array(
list(
itertools.chain.from_iterable(map(itertools.repeat, elements,
counts))))
def _to_padded_jax_array(x: np.ndarray) -> jax.Array:
if x.ndim == 1:
(i,) = x.shape
di = 2**math.ceil(math.log2(i)) - i
return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0)
elif x.ndim == 2:
i, j = x.shape
di = 2**math.ceil(math.log2(i)) - i
dj = 2**math.ceil(math.log2(j)) - j
return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0)
else:
raise ValueError(f"Unsupported array shape: {x.shape}")
class BatchedInContextXRegBase:
"""Helper class for in-context regression covariate formatting.
Attributes:
targets: List of targets (responses) of the in-context regression.
train_lens: List of lengths of each target vector from the context.
test_lens: List of lengths of each forecast horizon.
train_dynamic_numerical_covariates: Dict of covariate names mapping to the
dynamic numerical covariates of each forecast task on the context. Their
lengths should match the corresponding lengths in `train_lens`.
train_dynamic_categorical_covariates: Dict of covariate names mapping to the
dynamic categorical covariates of each forecast task on the context. Their
lengths should match the corresponding lengths in `train_lens`.
test_dynamic_numerical_covariates: Dict of covariate names mapping to the
dynamic numerical covariates of each forecast task on the horizon. Their
lengths should match the corresponding lengths in `test_lens`.
test_dynamic_categorical_covariates: Dict of covariate names mapping to the
dynamic categorical covariates of each forecast task on the horizon. Their
lengths should match the corresponding lengths in `test_lens`.
static_numerical_covariates: Dict of covariate names mapping to the static
numerical covariates of each forecast task.
static_categorical_covariates: Dict of covariate names mapping to the static
categorical covariates of each forecast task.
"""
def __init__(
self,
targets: Sequence[Sequence[float]],
train_lens: Sequence[int],
test_lens: Sequence[int],
train_dynamic_numerical_covariates: (
Mapping[str, Sequence[Sequence[float]]] | None) = None,
train_dynamic_categorical_covariates: (
Mapping[str, Sequence[Sequence[Category]]] | None) = None,
test_dynamic_numerical_covariates: (
Mapping[str, Sequence[Sequence[float]]] | None) = None,
test_dynamic_categorical_covariates: (
Mapping[str, Sequence[Sequence[Category]]] | None) = None,
static_numerical_covariates: Mapping[str, Sequence[float]] | None = None,
static_categorical_covariates: (Mapping[str, Sequence[Category]] |
None) = None,
) -> None:
"""Initializes with the exogenous covariate inputs.
Here we use model fitting language to refer to the context as 'train' and
the horizon as 'test'. We assume batched inputs. To properly format the
request:
- `train_lens` represents the contexts in the batch. Targets and all train
dynamic covariates should have the same lengths as the corresponding
elements
in `train_lens`. Notice each `train_len` can be different from the exact
length of the corresponding context depending on how much of the context is
used for fitting the in-context model.
- `test_lens` represents the horizon lengths in the batch. All tesdt
dynamic
covariates should have the same lengths as the corresponding elements in
`test_lens`.
- Static covariates should be one for each input.
- For train and test dynamic covariates, they should have the same
covariate
names.
Pass an empty dict {} for a covariate type if it is not present.
Example:
Here is a set of valid inputs whose schema can be used for reference.
```
targets = [
[0.0, 0.1, 0.2],
[0.0, 0.1, 0.2, 0.3],
] # Two inputs in this batch.
train_lens = [3, 4]
test_lens = [2, 5] # Forecast horizons 2 and 5 respectively.
train_dynamic_numerical_covariates = {
"cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]],
"cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]],
} # Each train dynamic covariate has 3 and 4 elements respectively.
test_dynamic_numerical_covariates = {
"cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]],
"cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]],
} # Each test dynamic covariate has 2 and 5 elements respectively.
train_dynamic_categorical_covariates = {
"cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]],
"cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad",
"bad"]],
}
test_dynamic_categorical_covariates = {
"cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]],
"cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]],
}
static_numerical_covariates = {
"cov_1_sn": [0.0, 3.0],
"cov_2_sn": [2.0, 1.0],
"cov_3_sn": [1.0, 2.0],
} # Each static covariate has 1 element for each input.
static_categorical_covariates = {
"cov_1_sc": ["apple", "orange"],
"cov_2_sc": [2, 3],
}
```
Args:
targets: List of targets (responses) of the in-context regression.
train_lens: List of lengths of each target vector from the context.
test_lens: List of lengths of each forecast horizon.
train_dynamic_numerical_covariates: Dict of covariate names mapping to the
dynamic numerical covariates of each forecast task on the context. Their
lengths should match the corresponding lengths in `train_lens`.
train_dynamic_categorical_covariates: Dict of covariate names mapping to
the dynamic categorical covariates of each forecast task on the context.
Their lengths should match the corresponding lengths in `train_lens`.
test_dynamic_numerical_covariates: Dict of covariate names mapping to the
dynamic numerical covariates of each forecast task on the horizon. Their
lengths should match the corresponding lengths in `test_lens`.
test_dynamic_categorical_covariates: Dict of covariate names mapping to
the dynamic categorical covariates of each forecast task on the horizon.
Their lengths should match the corresponding lengths in `test_lens`.
static_numerical_covariates: Dict of covariate names mapping to the static
numerical covariates of each forecast task.
static_categorical_covariates: Dict of covariate names mapping to the
static categorical covariates of each forecast task.
"""
self.targets = targets
self.train_lens = train_lens
self.test_lens = test_lens
self.train_dynamic_numerical_covariates = (
train_dynamic_numerical_covariates or {})
self.train_dynamic_categorical_covariates = (
train_dynamic_categorical_covariates or {})
self.test_dynamic_numerical_covariates = (test_dynamic_numerical_covariates
or {})
self.test_dynamic_categorical_covariates = (
test_dynamic_categorical_covariates or {})
self.static_numerical_covariates = static_numerical_covariates or {}
self.static_categorical_covariates = static_categorical_covariates or {}
def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None:
"""Verifies the validity of the covariate inputs."""
# Check presence.
if (self.train_dynamic_numerical_covariates and
not self.test_dynamic_numerical_covariates) or (
not self.train_dynamic_numerical_covariates and
self.test_dynamic_numerical_covariates):
raise ValueError(
"train_dynamic_numerical_covariates and"
" test_dynamic_numerical_covariates must be both present or both"
" absent.")
if (self.train_dynamic_categorical_covariates and
not self.test_dynamic_categorical_covariates) or (
not self.train_dynamic_categorical_covariates and
self.test_dynamic_categorical_covariates):
raise ValueError(
"train_dynamic_categorical_covariates and"
" test_dynamic_categorical_covariates must be both present or both"
" absent.")
# Check keys.
for dict_a, dict_b, dict_a_name, dict_b_name in (
(
self.train_dynamic_numerical_covariates,
self.test_dynamic_numerical_covariates,
"train_dynamic_numerical_covariates",
"test_dynamic_numerical_covariates",
),
(
self.train_dynamic_categorical_covariates,
self.test_dynamic_categorical_covariates,
"train_dynamic_categorical_covariates",
"test_dynamic_categorical_covariates",
),
):
if w := set(dict_a.keys()) - set(dict_b.keys()):
raise ValueError(
f"{dict_a_name} has keys not present in {dict_b_name}: {w}")
if w := set(dict_b.keys()) - set(dict_a.keys()):
raise ValueError(
f"{dict_b_name} has keys not present in {dict_a_name}: {w}")
# Check shapes.
if assert_covariate_shapes:
if len(self.targets) != len(self.train_lens):
raise ValueError(
"targets and train_lens must have the same number of elements.")
if len(self.train_lens) != len(self.test_lens):
raise ValueError(
"train_lens and test_lens must have the same number of elements.")
for i, (target, train_len) in enumerate(zip(self.targets,
self.train_lens)):
if len(target) != train_len:
raise ValueError(
f"targets[{i}] has length {len(target)} != expected {train_len}.")
for key, values in self.static_numerical_covariates.items():
if len(values) != len(self.train_lens):
raise ValueError(
f"static_numerical_covariates has key {key} with number of"
f" examples {len(values)} != expected {len(self.train_lens)}.")
for key, values in self.static_categorical_covariates.items():
if len(values) != len(self.train_lens):
raise ValueError(
f"static_categorical_covariates has key {key} with number of"
f" examples {len(values)} != expected {len(self.train_lens)}.")
for lens, dict_cov, dict_cov_name in (
(
self.train_lens,
self.train_dynamic_numerical_covariates,
"train_dynamic_numerical_covariates",
),
(
self.train_lens,
self.train_dynamic_categorical_covariates,
"train_dynamic_categorical_covariates",
),
(
self.test_lens,
self.test_dynamic_numerical_covariates,
"test_dynamic_numerical_covariates",
),
(
self.test_lens,
self.test_dynamic_categorical_covariates,
"test_dynamic_categorical_covariates",
),
):
for key, cov_values in dict_cov.items():
if len(cov_values) != len(lens):
raise ValueError(
f"{dict_cov_name} has key {key} with number of examples"
f" {len(cov_values)} != expected {len(lens)}.")
for i, cov_value in enumerate(cov_values):
if len(cov_value) != lens[i]:
raise ValueError(
f"{dict_cov_name} has key {key} with its {i}-th example"
f" length {len(cov_value)} != expected {lens[i]}.")
def create_covariate_matrix(
self,
one_hot_encoder_drop: str | None = "first",
use_intercept: bool = True,
assert_covariates: bool = False,
assert_covariate_shapes: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Creates target vector and covariate matrices for in context regression.
Here we use model fitting language to refer to the context as 'train' and
the horizon as 'test'.
Args:
one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.
use_intercept: Whether to prepare an intercept (all 1) column in the
matrices.
assert_covariates: Whether to assert the validity of the covariate inputs.
assert_covariate_shapes: Whether to assert the shapes of the covariate
inputs when `assert_covariates` is True.
Returns:
A tuple of the target vector, the covariate matrix for the context, and
the covariate matrix for the horizon.
"""
if assert_covariates:
self._assert_covariates(assert_covariate_shapes)
x_train, x_test = [], []
# Numerical features.
for name in sorted(self.train_dynamic_numerical_covariates):
x_train.append(
_unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis])
x_test.append(
_unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis])
for covs in self.static_numerical_covariates.values():
x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis])
x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis])
if x_train:
x_train = np.concatenate(x_train, axis=1)
x_test = np.concatenate(x_test, axis=1)
# Normalize for robustness.
x_mean = np.mean(x_train, axis=0, keepdims=True)
x_std = np.where((w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w,
1.0)
x_train = [(x_train - x_mean) / x_std]
x_test = [(x_test - x_mean) / x_std]
# Categorical features. Encode one by one.
one_hot_encoder = preprocessing.OneHotEncoder(
drop=one_hot_encoder_drop,
sparse_output=False,
handle_unknown="ignore",
)
for name in sorted(self.train_dynamic_categorical_covariates.keys()):
ohe_train = _unnest(
self.train_dynamic_categorical_covariates[name])[:, np.newaxis]
ohe_test = _unnest(
self.test_dynamic_categorical_covariates[name])[:, np.newaxis]
x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train)))
x_test.append(np.array(one_hot_encoder.transform(ohe_test)))
for covs in self.static_categorical_covariates.values():
ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis])
x_train.append(_repeat(ohe, self.train_lens))
x_test.append(_repeat(ohe, self.test_lens))
x_train = np.concatenate(x_train, axis=1)
x_test = np.concatenate(x_test, axis=1)
if use_intercept:
x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0)
x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0)
return _unnest(self.targets), x_train, x_test
def fit(self) -> Any:
raise NotImplementedError("Fit is not implemented.")
class BatchedInContextXRegLinear(BatchedInContextXRegBase):
"""Linear in-context regression model."""
def fit(
self,
ridge: float = 0.0,
one_hot_encoder_drop: str | None = "first",
use_intercept: bool = True,
force_on_cpu: bool = False,
max_rows_per_col: int = 0,
max_rows_per_col_sample_seed: int = 42,
debug_info: bool = False,
assert_covariates: bool = False,
assert_covariate_shapes: bool = False,
) -> (list[np.ndarray] | tuple[list[np.ndarray], list[np.ndarray], jax.Array,
jax.Array, jax.Array]):
"""Fits a linear model for in-context regression.
Args:
ridge: A non-negative value for specifying the ridge regression penalty.
If 0 is provided, fallback to ordinary least squares. Note this penalty
is added to the normalized covariate matrix.
one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.
use_intercept: Whether to prepare an intercept (all 1) column in the
matrices.
force_on_cpu: Whether to force execution on cpu for accelerator machines.
max_rows_per_col: How many rows to subsample per column. 0 for no
subsampling. This is for speeding up model fitting.
max_rows_per_col_sample_seed: The seed for the subsampling if needed by
`max_rows_per_col`.
debug_info: Whether to return debug info.
assert_covariates: Whether to assert the validity of the covariate inputs.
assert_covariate_shapes: Whether to assert the shapes of the covariate
inputs when `assert_covariates` is True.
Returns:
If `debug_info` is False:
The linear fits on the horizon.
If `debug_info` is True:
A tuple of:
- the linear fits on the horizon,
- the linear fits on the context,
- the flattened target vector,
- the covariate matrix for the context, and
- the covariate matrix for the horizon.
"""
flat_targets, x_train_raw, x_test = self.create_covariate_matrix(
one_hot_encoder_drop=one_hot_encoder_drop,
use_intercept=use_intercept,
assert_covariates=assert_covariates,
assert_covariate_shapes=assert_covariate_shapes,
)
x_train = x_train_raw.copy()
if max_rows_per_col:
nrows, ncols = x_train.shape
if nrows > (w := ncols * max_rows_per_col):
subsample = jax.random.choice(
jax.random.PRNGKey(max_rows_per_col_sample_seed),
nrows,
(w,),
replace=False,
)
x_train = x_train[subsample]
flat_targets = flat_targets[subsample]
device = jax.devices("cpu")[0] if force_on_cpu else None
# Runs jitted version of the solvers which are quicker at the cost of
# running jitting during the first time calling. Re-jitting happens whenever
# new (padded) shapes are encountered.
# Ocassionally it helps with the speed and the accuracy if we force single
# thread execution on cpu for accelerator machines:
# 1. Avoid moving data to accelarator memory.
# 2. Avoid precision loss if any.
with jax.default_device(device):
x_train_raw = _to_padded_jax_array(x_train_raw)
x_train = _to_padded_jax_array(x_train)
flat_targets = _to_padded_jax_array(flat_targets)
x_test = _to_padded_jax_array(x_test)
beta_hat = (jnp.linalg.pinv(
x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]),
hermitian=True,
) @ x_train.T @ flat_targets)
y_hat = x_test @ beta_hat
y_hat_context = x_train_raw @ beta_hat if debug_info else None
outputs = []
outputs_context = []
# Reconstruct the ragged 2-dim batched forecasts from flattened linear fits.
train_index, test_index = 0, 0
for train_index_delta, test_index_delta in zip(self.train_lens,
self.test_lens):
outputs.append(np.array(y_hat[test_index:(test_index +
test_index_delta)]))
if debug_info:
outputs_context.append(
np.array(y_hat_context[train_index:(train_index +
train_index_delta)]))
train_index += train_index_delta
test_index += test_index_delta
if debug_info:
return outputs, outputs_context, flat_targets, x_train, x_test
else:
return outputs
================================================
FILE: v1/tests/test_timesfm.py
================================================
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime, timedelta
import numpy as np
import pandas as pd
import pytest
import timesfm
def create_sample_dataframe(
start_date: datetime, end_date: datetime, freq: str = "D"
) -> pd.DataFrame:
"""
Create a sample DataFrame with time series data.
Args:
start_date (datetime): Start date of the time series.
end_date (datetime): End date of the time series.
freq (str): Frequency of the time series (default: "D" for daily).
Returns:
pd.DataFrame: DataFrame with columns 'unique_id', 'ds', and 'ts'.
"""
date_range = pd.date_range(start=start_date, end=end_date, freq=freq)
ts_data = np.random.randn(len(date_range))
df = pd.DataFrame({"unique_id": "ts-1", "ds": date_range, "ts": ts_data})
return df
@pytest.mark.parametrize("context_length", [128, 256, 512])
@pytest.mark.parametrize("prediction_length", [96, 128, 256])
@pytest.mark.parametrize("freq", ["D", "H", "W"])
def test_timesfm_forecast_on_df(
context_length: int,
prediction_length: int,
freq: str,
) -> None:
model = timesfm.TimesFm(
context_len=context_length,
horizon_len=prediction_length,
input_patch_len=32,
output_patch_len=128,
num_layers=20,
model_dims=1280,
backend="cpu",
)
model.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")
end_date = datetime.now()
start_date = end_date - timedelta(days=context_length)
input_df = create_sample_dataframe(start_date, end_date, freq)
forecast_df = model.forecast_on_df(
inputs=input_df,
freq=freq,
value_name="ts",
num_jobs=-1,
)
assert (
len(forecast_df) == prediction_length
), f"Expected forecast length of {prediction_length}, but got {len(forecast_df)}"
assert (
"timesfm" in forecast_df.columns
), "Forecast DataFrame should contain 'timesfm' column"
last_input_date = input_df["ds"].max()
first_forecast_date = forecast_df["ds"].min()
expected_first_forecast_date = last_input_date + pd.Timedelta(1, unit=freq)
assert (
first_forecast_date == expected_first_forecast_date
), f"Forecast should start from {expected_first_forecast_date}, but starts from {first_forecast_date}"
print(
f"Successful forecast with context_length={context_length}, prediction_length={prediction_length}, freq={freq}"
)