Skip to content

Commit

Permalink
Update TimesFM to handle exceptions (#611)
Browse files Browse the repository at this point in the history
* raise errors based on different cases

* fix lint tests

* create a pretrained units directory

* fix minimum tests

* fix tests

* remove pretrained-test and install manually
  • Loading branch information
sarahmish authored Jan 22, 2025
1 parent b56dbf9 commit 615cdb6
Show file tree
Hide file tree
Showing 22 changed files with 151 additions and 44 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ jobs:
docs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v1

- name: Python
uses: actions/setup-python@v1
uses: actions/setup-python@v2
with:
python-version: 3.8

Expand Down
29 changes: 24 additions & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,27 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install package and dependencies
run: pip install invoke .[test]
- name: invoke pytest
run: invoke pytest
- name: invoke unit
run: invoke unit


unit-pretrained:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.11']
os: [ubuntu-latest, macos-latest, windows-latest]
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install package and dependencies
run: |
pip install invoke pytest .[pretrained]
- name: invoke pretrained
run: invoke pretrained


minimum:
Expand Down Expand Up @@ -125,7 +144,7 @@ jobs:
run: invoke tutorials


pretrained:
pretrained-tutorials:
runs-on: ${{ matrix.os }}
strategy:
matrix:
Expand All @@ -141,5 +160,5 @@ jobs:
run: |
pip install "mistune>=2.0.3,<3.1"
pip install invoke jupyter .[pretrained]
- name: invoke pretrained
run: invoke pretrained
- name: invoke pretrained-tutorials
run: invoke pretrained-tutorials
8 changes: 6 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ fix-lint: ## fix lint issues using autoflake, autopep8, and isort

.PHONY: test-unit
test-unit: ## run tests quickly with the default Python
invoke pytest
invoke unit

.PHONY: test-readme
test-readme: ## run the readme snippets
Expand All @@ -135,9 +135,13 @@ test-readme: ## run the readme snippets
test-tutorials: ## run the tutorial notebooks
invoke tutorials

.PHONY: test-pretrained
test-pretrained: ## run the tutorial notebooks
invoke pretrained

.PHONY: test-pretrained-tutorials
test-pretrained-tutorials: ## run the tutorial notebooks
invoke pretrained
invoke pretrained-tutorials

.PHONY: test
test: test-unit test-readme test-tutorials ## test everything that needs test dependencies
Expand Down
5 changes: 5 additions & 0 deletions orion/primitives/jsons/orion.primitives.timesfm.TimesFM.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
{
"name": "X",
"type": "ndarray"
},
{
"name": "force",
"type": "bool",
"default": false
}
],
"output": [
Expand Down
69 changes: 39 additions & 30 deletions orion/primitives/timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,27 @@
https://github.com/google-research/timesfm?tab=readme-ov-file
"""

import numpy as np
import sys

import timesfm as tf
if sys.version_info < (3, 11):
msg = (
'`timesfm` requires Python >= 3.11 and your '
f'python version is {sys.version}.\n'
'Make sure you are using Python 3.11 or later.\n'
)
raise RuntimeError(msg)

try:
import timesfm as tf
except ImportError as ie:
ie.msg += (
'\n\nIt seems like `timesfm` is not installed.\n'
'Please install `timesfm` using:\n'
'\n pip install orion-ml[pretrained]'
)
raise

MAX_LENGTH = 93000


class TimesFM:
Expand Down Expand Up @@ -53,7 +71,7 @@ def __init__(self,
horizon_len=pred_len),
checkpoint=tf.TimesFmCheckpoint(huggingface_repo_id=repo_id))

def predict(self, X):
def predict(self, X, force=False):
"""Forecasting timeseries
Args:
Expand All @@ -63,30 +81,21 @@ def predict(self, X):
ndarray:
forecasted timeseries.
"""
frequency_input = [self.freq]*len(X)
d = X.shape[-1]

# Univariate
if d == 1:
y_hat, _ = self.model.forecast(X[:, :, 0], freq=frequency_input)
return y_hat[:, 0]

# Multivariate
covariates = list(range(d))
covariates = covariates.remove(self.target)
X_cont = X[:, :, self.target]
X_cov = np.delete(X, self.target, axis=2)

# Append covariates with future values
m, n, k = X_cov.shape
X_cov_new = np.zeros((m, n+self.pred_len, k))
X_cov_new[:, :-self.pred_len, :] = X_cov
X_cov_new[:-1, -self.pred_len:, :] = X_cov[1:, :self.pred_len, :]

x_cov = {str(i): X_cov_new[:, :, i] for i in range(k)}
y_hat, _ = self.model.forecast_with_covariates(
inputs=X_cont,
dynamic_numerical_covariates=x_cov,
freq=frequency_input,
)
return np.concatenate(y_hat)
frequency_input = [self.freq] * len(X)
m, n, d = X.shape

# does not support multivariate
if d > 1:
raise ValueError(f'Encountered X with too many channels (channels={d}).')

# does not support long time series
if not force and m > (MAX_LENGTH - self.window_size):
msg = (
f'`X` has {m} samples, which might result in out of memory issues.\n'
'If you are sure you want to proceed, set `force=True`.'
)

raise MemoryError(msg)

y_hat, _ = self.model.forecast(X[:, :, 0], freq=frequency_input)
return y_hat[:, 0]
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@
'test': tests_require,
'dev': development_requires + tests_require,
'pretrained': pretrained_requires,
'pretrained-dev': pretrained_requires + development_requires + tests_require,
},
include_package_data=True,
install_requires=install_requires,
Expand Down
13 changes: 9 additions & 4 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@


@task
def pytest(c):
c.run('python -m pytest --cov=orion')
def unit(c):
c.run('python -m pytest ./tests/unit --cov=orion')


@task
def pretrained(c):
c.run('python -m pytest ./tests/pretrained')


@task
Expand Down Expand Up @@ -70,7 +75,7 @@ def install_minimum(c):
def minimum(c):
install_minimum(c)
c.run('python -m pip check')
c.run('python -m pytest')
c.run('python -m pytest ./tests/unit')


@task
Expand Down Expand Up @@ -107,7 +112,7 @@ def tutorials(c):
), hide='out')

@task
def pretrained(c):
def pretrained_tutorials(c):
pipelines = os.listdir(os.path.join('orion', 'pipelines', 'pretrained'))
for ipynb_file in glob.glob('tutorials/pipelines/*.ipynb'):
for pipeline in pipelines:
Expand Down
66 changes: 66 additions & 0 deletions tests/pretrained/test_timesfm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import importlib
import unittest
from unittest.mock import patch

import numpy as np

import timesfm as tf
from orion.primitives.timesfm import MAX_LENGTH, TimesFM


class TestTimesFMImport(unittest.TestCase):

@patch('sys.version_info', (3, 10))
def test_runtime_error_python_version_less_than_3_11(self):
with self.assertRaises(RuntimeError) as context:
import orion.primitives.timesfm
importlib.reload(orion.primitives.timesfm)

self.assertIn('requires Python >= 3.11', str(context.exception))
self.assertIn('python version is', str(context.exception))

@patch('sys.version_info', (3, 11))
@patch('builtins.__import__', side_effect=ImportError())
def test_import_error_timesfm_not_installed(self, mock_import):
# simulate Python version 3.11 and timesfm not installed
with self.assertRaises(ImportError):
import orion.primitives.timesfm # noqa


class TestTimesFMPredict(unittest.TestCase):

def setUp(self):
self.model = TimesFM()

def test_value_error_multivariate_input(self):
# create a multivariate input with more than one channel
X = np.random.rand(10, 5, 2) # Shape (m, n, d) with d > 1

with self.assertRaises(ValueError) as context:
self.model.predict(X)

self.assertIn('Encountered X with too many channels', str(context.exception))

def test_memory_error_long_time_series(self):
# create a long time series input
m = MAX_LENGTH - self.model.window_size + 1
X = np.random.rand(m, 5, 1) # Shape (m, n, d) with d = 1

with self.assertRaises(MemoryError) as context:
self.model.predict(X)

self.assertIn('might result in out of memory issues', str(context.exception))

@patch.object(tf.TimesFm, 'forecast', return_value=(np.random.rand(10, 1), None))
def test_no_memory_error_with_force(self, mock_forecast):
# create a long time series input
m = MAX_LENGTH - self.model.window_size + 1
X = np.random.rand(m, 5, 1) # Shape (m, n, d) with d = 1

# should not raise MemoryError when force=True
try:
self.model.predict(X, force=True)
except MemoryError:
self.fail("predict() raised MemoryError unexpectedly with force=True")

mock_forecast.assert_called_once()
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 615cdb6

Please sign in to comment.