|
13 | 13 | # limitations under the License.
|
14 | 14 | # ==============================================================================
|
15 | 15 | """Rectified Adam (RAdam) optimizer."""
|
| 16 | +import warnings |
16 | 17 |
|
17 | 18 | import tensorflow as tf
|
18 | 19 | from tensorflow_addons.utils.types import FloatTensorLike
|
@@ -79,7 +80,10 @@ def __init__(
|
79 | 80 | weight_decay: FloatTensorLike = 0.0,
|
80 | 81 | amsgrad: bool = False,
|
81 | 82 | sma_threshold: FloatTensorLike = 5.0,
|
82 |
| - total_steps: int = 0, |
| 83 | + # float for total_steps is here to be able to load models created before |
| 84 | + # https://github.com/tensorflow/addons/pull/1375 was merged. It should be |
| 85 | + # removed for Addons 0.11. |
| 86 | + total_steps: Union[int, float] = 0, |
83 | 87 | warmup_proportion: FloatTensorLike = 0.1,
|
84 | 88 | min_lr: FloatTensorLike = 0.0,
|
85 | 89 | name: str = "RectifiedAdam",
|
@@ -123,7 +127,16 @@ def __init__(
|
123 | 127 | self._set_hyper("decay", self._initial_decay)
|
124 | 128 | self._set_hyper("weight_decay", weight_decay)
|
125 | 129 | self._set_hyper("sma_threshold", sma_threshold)
|
126 |
| - self._set_hyper("total_steps", float(total_steps)) |
| 130 | + if isinstance(total_steps, float): |
| 131 | + warnings.warn( |
| 132 | + "The parameter `total_steps` passed to the __init__ of RectifiedAdam " |
| 133 | + "is a float. This behavior is deprecated and in Addons 0.11, this " |
| 134 | + "will raise an error. Use an int instead. If you get this message " |
| 135 | + "when loading a model, save it again and the `total_steps` parameter " |
| 136 | + "will automatically be converted to a int.", |
| 137 | + DeprecationWarning, |
| 138 | + ) |
| 139 | + self._set_hyper("total_steps", int(total_steps)) |
127 | 140 | self._set_hyper("warmup_proportion", warmup_proportion)
|
128 | 141 | self._set_hyper("min_lr", min_lr)
|
129 | 142 | self.epsilon = epsilon or tf.keras.backend.epsilon()
|
|
0 commit comments