Skip to content

Commit 5068e8b

Browse files
committed
change model loading
1 parent 395abc4 commit 5068e8b

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

orion/primitives/timesfm.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,18 @@ def __init__(self,
6666
self.batch_size = batch_size
6767
self.target = target
6868

69-
self.model = tf.TimesFm(hparams=tf.TimesFmHparams(context_len=window_size,
70-
per_core_batch_size=batch_size,
71-
horizon_len=pred_len),
72-
checkpoint=tf.TimesFmCheckpoint(huggingface_repo_id=repo_id))
69+
self.model = tf.TimesFm(
70+
hparams=tf.TimesFmHparams(
71+
backend="gpu",
72+
per_core_batch_size=batch_size,
73+
horizon_len=pred_len,
74+
num_layers=50,
75+
use_positional_embedding=False,
76+
context_len=window_size,
77+
),
78+
checkpoint=tf.TimesFmCheckpoint(
79+
huggingface_repo_id=repo_id)
80+
)
7381

7482
def predict(self, X, force=False):
7583
"""Forecasting timeseries

0 commit comments

Comments
 (0)