File tree Expand file tree Collapse file tree 2 files changed +25
-2
lines changed
tensorflow_addons/optimizers Expand file tree Collapse file tree 2 files changed +25
-2
lines changed Original file line number Diff line number Diff line change @@ -140,7 +140,7 @@ def __init__(
140140 self ._set_hyper ("decay" , self ._initial_decay )
141141 self ._set_hyper ("weight_decay" , weight_decay )
142142 self ._set_hyper ("sma_threshold" , sma_threshold )
143- self ._set_hyper ("total_steps" , int (total_steps ))
143+ self ._set_hyper ("total_steps" , float (total_steps ))
144144 self ._set_hyper ("warmup_proportion" , warmup_proportion )
145145 self ._set_hyper ("min_lr" , min_lr )
146146 self .epsilon = epsilon or tf .keras .backend .epsilon ()
@@ -325,7 +325,7 @@ def get_config(self):
325325 "epsilon" : self .epsilon ,
326326 "amsgrad" : self .amsgrad ,
327327 "rectify" : self .rectify ,
328- "total_steps" : self ._serialize_hyperparameter ("total_steps" ),
328+ "total_steps" : int ( self ._serialize_hyperparameter ("total_steps" ) ),
329329 "warmup_proportion" : self ._serialize_hyperparameter (
330330 "warmup_proportion"
331331 ),
Original file line number Diff line number Diff line change @@ -236,3 +236,26 @@ def test_scheduler_serialization():
236236 "class_name" : "InverseTimeDecay" ,
237237 "config" : wd_scheduler .get_config (),
238238 }
239+
240+
def test_checkpoint_serialization(tmpdir):
    """Round-trip AdaBelief state through tf.train.Checkpoint.

    Writes one optimizer's slot variables to a checkpoint, restores them
    into a fresh optimizer, and verifies the restored optimizer can still
    apply gradients without error (regression test for hyperparameters,
    e.g. ``total_steps``, that must be checkpoint-serializable).
    """
    source_opt = AdaBelief()
    restored_opt = AdaBelief()

    weights = [
        tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32),
        tf.Variable([3.0, 4.0], dtype=tf.dtypes.float32),
    ]
    grads = [
        tf.constant([0.1, 0.2], dtype=tf.dtypes.float32),
        tf.constant([0.03, 0.04], dtype=tf.dtypes.float32),
    ]
    grads_and_vars = list(zip(grads, weights))

    # One update creates the slot variables, so there is real state to save.
    source_opt.apply_gradients(grads_and_vars)

    model_path = str(tmpdir / "adabelief_chkpt")
    tf.train.Checkpoint(optimizer=source_opt).write(model_path)
    tf.train.Checkpoint(optimizer=restored_opt).read(model_path)

    # The restored optimizer must be usable for a further step.
    restored_opt.apply_gradients(grads_and_vars)
You can’t perform that action at this time.
0 commit comments