
Commit 953d848

Failing to restore AdaBelief optimizer from checkpoint (#2705)
* Update adabelief.py
* Addressed PR comments
1 parent e5bfef1 · commit 953d848

File tree: 2 files changed, +25 −2 lines


tensorflow_addons/optimizers/adabelief.py (+2 −2)
@@ -140,7 +140,7 @@ def __init__(
         self._set_hyper("decay", self._initial_decay)
         self._set_hyper("weight_decay", weight_decay)
         self._set_hyper("sma_threshold", sma_threshold)
-        self._set_hyper("total_steps", int(total_steps))
+        self._set_hyper("total_steps", float(total_steps))
         self._set_hyper("warmup_proportion", warmup_proportion)
         self._set_hyper("min_lr", min_lr)
         self.epsilon = epsilon or tf.keras.backend.epsilon()
@@ -325,7 +325,7 @@ def get_config(self):
                 "epsilon": self.epsilon,
                 "amsgrad": self.amsgrad,
                 "rectify": self.rectify,
-                "total_steps": self._serialize_hyperparameter("total_steps"),
+                "total_steps": int(self._serialize_hyperparameter("total_steps")),
                 "warmup_proportion": self._serialize_hyperparameter(
                     "warmup_proportion"
                 ),
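
Taken together, the two hunks keep total_steps as a float inside the optimizer, matching the dtype of the other hyperparameters, which is what the checkpoint restore appears to trip over, while get_config casts it back to an int for the serialized config. Below is a minimal sketch of the resulting config round trip, assuming a TensorFlow Addons build that includes this change; the variable names and values are illustrative, not taken from the commit:

from tensorflow_addons.optimizers import AdaBelief

# Illustrative values: an optimizer with warmup enabled.
opt = AdaBelief(learning_rate=1e-3, total_steps=10000, warmup_proportion=0.1)

# After this change get_config() reports total_steps as an int, even though
# the hyperparameter itself is stored as a float.
config = opt.get_config()
assert isinstance(config["total_steps"], int)

# The config round-trips through from_config without losing the value.
restored = AdaBelief.from_config(config)
assert restored.get_config()["total_steps"] == 10000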

tensorflow_addons/optimizers/tests/adabelief_test.py (+23)
@@ -236,3 +236,26 @@ def test_scheduler_serialization():
         "class_name": "InverseTimeDecay",
         "config": wd_scheduler.get_config(),
     }
+
+
+def test_checkpoint_serialization(tmpdir):
+    optimizer = AdaBelief()
+    optimizer2 = AdaBelief()
+
+    var_0 = tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)
+    var_1 = tf.Variable([3.0, 4.0], dtype=tf.dtypes.float32)
+
+    grad_0 = tf.constant([0.1, 0.2], dtype=tf.dtypes.float32)
+    grad_1 = tf.constant([0.03, 0.04], dtype=tf.dtypes.float32)
+
+    grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1]))
+
+    optimizer.apply_gradients(grads_and_vars)
+
+    checkpoint = tf.train.Checkpoint(optimizer=optimizer)
+    checkpoint2 = tf.train.Checkpoint(optimizer=optimizer2)
+    model_path = str(tmpdir / "adabelief_chkpt")
+    checkpoint.write(model_path)
+    checkpoint2.read(model_path)
+
+    optimizer2.apply_gradients(grads_and_vars)
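
The new test exercises the scenario from the commit title: write a checkpoint from one AdaBelief instance, read it into a fresh one, and keep training. Outside of pytest the same round trip looks roughly like the sketch below, where the tempfile directory and path name stand in for the tmpdir fixture and are not part of the test; calling apply_gradients before writing matters because it is what creates the slot and hyperparameter variables that the checkpoint captures.

import tempfile

import tensorflow as tf
from tensorflow_addons.optimizers import AdaBelief

optimizer = AdaBelief()
fresh_optimizer = AdaBelief()

var = tf.Variable([1.0, 2.0])
grads_and_vars = [(tf.constant([0.1, 0.2]), var)]

# One update so the optimizer materializes its slot and hyperparameter
# variables before they are written to the checkpoint.
optimizer.apply_gradients(grads_and_vars)

with tempfile.TemporaryDirectory() as tmp:
    path = f"{tmp}/adabelief_chkpt"  # illustrative path, not from the test
    tf.train.Checkpoint(optimizer=optimizer).write(path)
    tf.train.Checkpoint(optimizer=fresh_optimizer).read(path)

    # Per the commit title, restoring and continuing to train is what used
    # to fail; with total_steps stored as a float this step completes.
    fresh_optimizer.apply_gradients(grads_and_vars)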
