|
16 | 16 |
|
17 | 17 | import numpy as np
|
18 | 18 | import pytest
|
| 19 | +from packaging.version import Version |
19 | 20 |
|
20 | 21 | import tensorflow.compat.v2 as tf
|
21 | 22 | from tensorflow_addons.optimizers import AdaBelief, Lookahead
|
@@ -227,15 +228,31 @@ def test_scheduler_serialization():
|
227 | 228 | new_optimizer = tf.keras.optimizers.deserialize(config)
|
228 | 229 | assert new_optimizer.get_config() == optimizer.get_config()
|
229 | 230 |
|
230 |
| - assert new_optimizer.get_config()["learning_rate"] == { |
231 |
| - "class_name": "ExponentialDecay", |
232 |
| - "config": lr_scheduler.get_config(), |
233 |
| - } |
234 |
| - |
235 |
| - assert new_optimizer.get_config()["weight_decay"] == { |
236 |
| - "class_name": "InverseTimeDecay", |
237 |
| - "config": wd_scheduler.get_config(), |
238 |
| - } |
| 231 | + # TODO: Remove after 2.13 is oldest version supported due to new serialization |
| 232 | + if Version(tf.__version__) >= Version("2.13"): |
| 233 | + assert new_optimizer.get_config()["learning_rate"] == { |
| 234 | + "class_name": "ExponentialDecay", |
| 235 | + "config": lr_scheduler.get_config(), |
| 236 | + "module": "keras.optimizers.schedules", |
| 237 | + "registered_name": None, |
| 238 | + } |
| 239 | + assert new_optimizer.get_config()["weight_decay"] == { |
| 240 | + "class_name": "InverseTimeDecay", |
| 241 | + "config": wd_scheduler.get_config(), |
| 242 | + "module": "keras.optimizers.schedules", |
| 243 | + "registered_name": None, |
| 244 | + } |
| 245 | + |
| 246 | + else: |
| 247 | + assert new_optimizer.get_config()["learning_rate"] == { |
| 248 | + "class_name": "ExponentialDecay", |
| 249 | + "config": lr_scheduler.get_config(), |
| 250 | + } |
| 251 | + |
| 252 | + assert new_optimizer.get_config()["weight_decay"] == { |
| 253 | + "class_name": "InverseTimeDecay", |
| 254 | + "config": wd_scheduler.get_config(), |
| 255 | + } |
239 | 256 |
|
240 | 257 |
|
241 | 258 | def test_checkpoint_serialization(tmpdir):
|
|
0 commit comments