Skip to content

Commit 0cb4674

Browse files
authored
Makes create_slots automatically set up weights for swap_weights (#2195)
* Enables the moving_average optimizer to call swap_weights without first calling shadow_copy.
* Updates moving_average.py
1 parent 7f7c97d commit 0cb4674

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

tensorflow_addons/optimizers/moving_average.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,13 @@ def get_config(self):
127127
return {**base_config, **config}
128128

129129
def _create_slots(self, var_list):
    """Create the wrapped optimizer's slots plus an "average" slot per variable.

    Besides creating the slots, this records the variables and their
    "average" slots on the instance as `_model_weights` and
    `_average_weights`, so that `swap_weights` can be called without a
    prior call to `shadow_copy`.

    Args:
        var_list: iterable of variables to create slots for. Each variable
            must provide `read_value()`, which seeds its "average" slot.
    """
    # Materialize once: the sequence is traversed several times below, and a
    # one-shot iterator would otherwise be exhausted after the first pass.
    var_list = list(var_list)

    # Let the wrapped optimizer create its own slots first.
    self._optimizer._create_slots(var_list=var_list)  # pylint: disable=protected-access

    for var in var_list:
        # The average starts out equal to the variable's current value.
        self.add_slot(var, "average", var.read_value())

    # Keep both lists in the same order so they can be paired element-wise.
    self._average_weights = [self.get_slot(var, "average") for var in var_list]
    self._model_weights = var_list
136+
136137
def shadow_copy(self, model_weights):
137138
"""Creates shadow variables for the given model weights."""
138139
for var in model_weights:

tensorflow_addons/optimizers/tests/moving_average_test.py

+32
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,38 @@ def test_dynamic_decay():
228228
np.testing.assert_allclose(ema_var0.read_value(), [0.64, 1.64])
229229

230230

231+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.with_device([tf.distribute.MirroredStrategy])
def test_swap_weight_no_shadow_copy(device):
    """Check that swap_weights works without an explicit shadow_copy call.

    One optimizer step is applied, then the model and averaged weights are
    swapped twice; the second swap must restore the original assignment.
    """
    with device.scope():
        var = tf.Variable([1.0, 2.0])
        grads = tf.constant([0.1, 0.1])

        # `learning_rate` is the canonical argument name; `lr` is a
        # deprecated alias.
        opt = MovingAverage(
            tf.keras.optimizers.SGD(learning_rate=2.0), average_decay=0.5
        )

    @tf.function
    def apply_gradients():
        opt.apply_gradients([(grads, var)])

    device.run(apply_gradients)

    # SGD step: var - 2.0 * 0.1 -> [0.8, 1.8].
    np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
    ema_var = opt.get_slot(var, "average")
    # Average with decay 0.5: 0.5 * [1.0, 2.0] + 0.5 * [0.8, 1.8] -> [0.9, 1.9].
    np.testing.assert_allclose(ema_var.read_value(), [0.9, 1.9])

    # First swap: the variable takes the averaged values and vice versa.
    with device.scope():
        opt.swap_weights()

    np.testing.assert_allclose(ema_var.read_value(), [0.8, 1.8])
    np.testing.assert_allclose(var.read_value(), [0.9, 1.9])

    # Second swap: both return to their pre-swap values.
    with device.scope():
        opt.swap_weights()

    np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
    np.testing.assert_allclose(ema_var.read_value(), [0.9, 1.9])
261+
262+
231263
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
232264
@pytest.mark.with_device([tf.distribute.MirroredStrategy])
233265
def test_swap_weights(device):

0 commit comments

Comments
 (0)