
Commit e62cc95

add mean reduction
1 parent 93794ec commit e62cc95

1 file changed: +8 −2 lines changed

tensorflow_addons/optimizers/gradient_accumulator.py

Lines changed: 8 additions & 2 deletions
@@ -26,6 +26,7 @@ def __init__(
         self,
         inner_optimizer: types.Optimizer,
         accum_steps: types.TensorLike = 4,
+        reduction: str = "SUM",
         name: str = "GradientAccumulator",
         **kwargs,
     ):
@@ -35,6 +36,7 @@ def __init__(
             inner_optimizer: str or `tf.keras.optimizers.Optimizer` that will be
                 used to compute and apply gradients.
            accum_steps: int > 0. Update gradient in every accumulation steps.
+            reduction: str, Reduction method ['SUM', 'MEAN']
             name: Optional name for the operations created when applying
                 gradients. Defaults to "GradientAccumulator".
             **kwargs: keyword arguments. Allowed to be {`clipnorm`,
@@ -49,6 +51,7 @@ def __init__(
         self._step = None
         self._gradients = {}
         self._accum_steps = accum_steps
+        self._reduction = reduction
 
         def _accum_grad(grads_and_vars):
             with tf.init_scope():
@@ -78,6 +81,8 @@ def _accum_grad(grads_and_vars):
 
                     def _get_grad():
                         new_grad = handle.read_value()
+                        if self._reduction == "MEAN":
+                            new_grad /= tf.cast(self._accum_steps, new_grad.dtype)
                         indices = tf.squeeze(
                             tf.where(
                                 tf.reduce_sum(
@@ -108,10 +113,11 @@ def _get_grad():
                     new_grads_and_vars.append((new_grad, var))
                 else:
                     handle.assign_add(grad)
-                    fake_grad = tf.zeros_like(var)
 
                     def _get_grad():
                         new_grad = handle.read_value()
+                        if self._reduction == "MEAN":
+                            new_grad /= tf.cast(self._accum_steps, new_grad.dtype)
                         handle.assign(
                             tf.zeros_like(handle), use_locking=self._use_locking
                         )
@@ -120,7 +126,7 @@ def _get_grad():
                     new_grad = tf.cond(
                         self.step % self._accum_steps == 0,
                         _get_grad,
-                        lambda: fake_grad,
+                        lambda: tf.zeros_like(grad),
                     )
                     new_grads_and_vars.append((new_grad, var))
             return new_grads_and_vars
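
For context, a minimal usage sketch of the new reduction argument. It assumes the class is importable from tensorflow_addons.optimizers (as the file path suggests); the inner optimizer and model below are purely illustrative, not part of the commit.

    import tensorflow as tf
    from tensorflow_addons.optimizers import GradientAccumulator

    # Accumulate gradients over 4 steps; with reduction="MEAN" the accumulated
    # gradient is divided by accum_steps before it reaches the inner optimizer.
    optimizer = GradientAccumulator(
        inner_optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
        accum_steps=4,
        reduction="MEAN",
    )

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
    model.compile(optimizer=optimizer, loss="mse")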

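The effect of the two reduction modes can be seen with plain tensors; a small sketch with illustrative values (not taken from the commit):

    import tensorflow as tf

    accum_steps = 4
    micro_batch_grads = [
        tf.constant([1.0, 2.0]),
        tf.constant([3.0, 4.0]),
        tf.constant([5.0, 6.0]),
        tf.constant([7.0, 8.0]),
    ]

    summed = tf.add_n(micro_batch_grads)                # "SUM":  [16., 20.]
    mean = summed / tf.cast(accum_steps, summed.dtype)  # "MEAN": [ 4.,  5.]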