We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 395abc4 commit 5068e8bCopy full SHA for 5068e8b
orion/primitives/timesfm.py
@@ -66,10 +66,18 @@ def __init__(self,
66
self.batch_size = batch_size
67
self.target = target
68
69
- self.model = tf.TimesFm(hparams=tf.TimesFmHparams(context_len=window_size,
70
- per_core_batch_size=batch_size,
71
- horizon_len=pred_len),
72
- checkpoint=tf.TimesFmCheckpoint(huggingface_repo_id=repo_id))
+ self.model = tf.TimesFm(
+ hparams=tf.TimesFmHparams(
+ backend="gpu",
+ per_core_batch_size=batch_size,
73
+ horizon_len=pred_len,
74
+ num_layers=50,
75
+ use_positional_embedding=False,
76
+ context_len=window_size,
77
+ ),
78
+ checkpoint=tf.TimesFmCheckpoint(
79
+ huggingface_repo_id=repo_id)
80
+ )
81
82
def predict(self, X, force=False):
83
"""Forecasting timeseries
0 commit comments