Skip to content

Commit 08741c9

Browse files
authored
Add experimental_aggregate_gradients support (#2137)
Changes: add `experimental_aggregate_gradients` support; format code; fix test case.
1 parent dc3aa06 commit 08741c9

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

tensorflow_addons/optimizers/average_wrapper.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def _create_hypers(self):
6565
def _prepare(self, var_list):
    """Delegate hyper-parameter preparation to the wrapped inner optimizer."""
    return self._optimizer._prepare(var_list=var_list)
6767

68-
def apply_gradients(self, grads_and_vars, name=None, **kwargs):
    """Apply gradients to variables through the wrapped optimizer.

    Accepts and forwards arbitrary keyword arguments (notably
    ``experimental_aggregate_gradients``) to the base-class
    ``apply_gradients`` — the previous signature dropped them, so callers
    using that argument would raise a ``TypeError``.

    Args:
        grads_and_vars: List of ``(gradient, variable)`` pairs.
        name: Optional name for the returned operation.
        **kwargs: Extra keyword arguments passed through to
            ``super().apply_gradients``.

    Returns:
        Whatever the base-class ``apply_gradients`` returns.
    """
    # Keep the inner optimizer's step counter in sync with the wrapper's
    # before applying, so both see the same iteration count.
    self._optimizer._iterations = self.iterations
    return super().apply_gradients(grads_and_vars, name, **kwargs)
7171

7272
@abc.abstractmethod
7373
def average_op(self, var, average_var):

tensorflow_addons/optimizers/tests/moving_average_test.py

+12
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,15 @@ def apply_gradients():
259259

260260
np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
261261
np.testing.assert_allclose(ema_var.read_value(), [0.9, 1.9])
262+
263+
264+
@pytest.mark.usefixtures("run_with_mixed_precision_policy")
def test_model_mixed_precision():
    """Train a Keras model with MovingAverage under the mixed_float16 policy.

    Regression test: compiling/fitting must not fail when gradients are
    applied under a mixed-precision policy (which routes extra keyword
    arguments through ``apply_gradients``).
    """
    tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
    # Synthetic linear-regression data: y = x @ w plus tiny Gaussian noise.
    x = np.random.standard_normal((10000, 3))
    w = np.random.standard_normal((3, 1))
    y = np.dot(x, w) + np.random.standard_normal((10000, 1)) * 1e-4
    # Single dense layer is enough to exercise the optimizer wrapper.
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
    model.compile(MovingAverage("sgd"), loss="mse")
    model.fit(x, y, epochs=3)

0 commit comments

Comments
 (0)