@@ -228,6 +228,38 @@ def test_dynamic_decay():
228
228
np .testing .assert_allclose (ema_var0 .read_value (), [0.64 , 1.64 ])
229
229
230
230
231
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.with_device([tf.distribute.MirroredStrategy])
def test_swap_weight_no_shadow_copy(device):
    """Check that ``swap_weights`` exchanges a variable with its average slot.

    The test applies one gradient step through a ``MovingAverage``-wrapped
    SGD optimizer under the distribution strategy supplied by ``device``,
    then calls ``swap_weights`` twice. No explicit shadow-copy step is
    performed before swapping (hence the test name).

    Expected values:
      * after the step, ``var`` is ``[0.8, 1.8]`` (``[1, 2] - 2.0 * 0.1``);
      * the ``"average"`` slot is ``[0.9, 1.9]`` -- consistent with an
        equal-weight blend of the initial and updated values at
        ``average_decay=0.5`` (NOTE(review): confirm against the
        ``MovingAverage`` implementation);
      * the first swap exchanges the two tensors, the second restores them.

    Args:
        device: distribution strategy object provided by the ``with_device``
            marker; it is used here through ``scope()`` and ``run()``.
    """
    with device.scope():
        var = tf.Variable([1.0, 2.0])
        grads = tf.constant([0.1, 0.1])

        # `learning_rate` is the supported keyword; `lr` is only a
        # deprecated alias and is rejected by newer Keras releases.
        opt = MovingAverage(
            tf.keras.optimizers.SGD(learning_rate=2.0), average_decay=0.5
        )

    @tf.function
    def apply_gradients():
        opt.apply_gradients([(grads, var)])

    # Run the update step through the strategy rather than calling it directly.
    device.run(apply_gradients)

    np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
    ema_var = opt.get_slot(var, "average")
    np.testing.assert_allclose(ema_var.read_value(), [0.9, 1.9])

    # First swap: the variable takes the averaged value and vice versa.
    with device.scope():
        opt.swap_weights()

    np.testing.assert_allclose(ema_var.read_value(), [0.8, 1.8])
    np.testing.assert_allclose(var.read_value(), [0.9, 1.9])

    # Second swap: both tensors must return to their pre-swap values.
    with device.scope():
        opt.swap_weights()

    np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
    np.testing.assert_allclose(ema_var.read_value(), [0.9, 1.9])
261
+
262
+
231
263
@pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
232
264
@pytest .mark .with_device ([tf .distribute .MirroredStrategy ])
233
265
def test_swap_weights (device ):
0 commit comments