Skip to content

Commit 7aac34c

Browse files
Fix the serialization bug of rectified adam. (#1375)
* Fix the serialization bug of rectified adam.
* Better error message.
* Update tensorflow_addons/optimizers/rectified_adam.py
1 parent 711e725 commit 7aac34c

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

tensorflow_addons/optimizers/rectified_adam.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
"""Rectified Adam (RAdam) optimizer."""
16+
import warnings
1617

1718
import tensorflow as tf
1819
from tensorflow_addons.utils.types import FloatTensorLike
@@ -79,7 +80,10 @@ def __init__(
7980
weight_decay: FloatTensorLike = 0.0,
8081
amsgrad: bool = False,
8182
sma_threshold: FloatTensorLike = 5.0,
82-
total_steps: int = 0,
83+
# float for total_steps is here to be able to load models created before
84+
# https://github.com/tensorflow/addons/pull/1375 was merged. It should be
85+
# removed for Addons 0.11.
86+
total_steps: Union[int, float] = 0,
8387
warmup_proportion: FloatTensorLike = 0.1,
8488
min_lr: FloatTensorLike = 0.0,
8589
name: str = "RectifiedAdam",
@@ -123,7 +127,16 @@ def __init__(
123127
self._set_hyper("decay", self._initial_decay)
124128
self._set_hyper("weight_decay", weight_decay)
125129
self._set_hyper("sma_threshold", sma_threshold)
126-
self._set_hyper("total_steps", float(total_steps))
130+
if isinstance(total_steps, float):
131+
warnings.warn(
132+
"The parameter `total_steps` passed to the __init__ of RectifiedAdam "
133+
"is a float. This behavior is deprecated and in Addons 0.11, this "
134+
"will raise an error. Use an int instead. If you get this message "
135+
"when loading a model, save it again and the `total_steps` parameter "
136+
"will automatically be converted to a int.",
137+
DeprecationWarning,
138+
)
139+
self._set_hyper("total_steps", int(total_steps))
127140
self._set_hyper("warmup_proportion", warmup_proportion)
128141
self._set_hyper("min_lr", min_lr)
129142
self.epsilon = epsilon or tf.keras.backend.epsilon()

tensorflow_addons/optimizers/rectified_adam_test.py

+9
Original file line numberDiff line numberDiff line change
@@ -172,5 +172,14 @@ def test_get_config(self):
172172
self.assertEqual(config["total_steps"], 0)
173173

174174

175+
def test_serialization():
176+
optimizer = RectifiedAdam(
177+
lr=1e-3, total_steps=10000, warmup_proportion=0.1, min_lr=1e-5,
178+
)
179+
config = tf.keras.optimizers.serialize(optimizer)
180+
new_optimizer = tf.keras.optimizers.deserialize(config)
181+
assert new_optimizer.get_config() == optimizer.get_config()
182+
183+
175184
if __name__ == "__main__":
176185
sys.exit(pytest.main([__file__]))

0 commit comments

Comments (0)