@@ -26,6 +26,7 @@ def __init__(
         self,
         inner_optimizer: types.Optimizer,
         accum_steps: types.TensorLike = 4,
+        reduction: str = "SUM",
         name: str = "GradientAccumulator",
         **kwargs,
     ):
@@ -35,6 +36,7 @@ def __init__(
             inner_optimizer: str or `tf.keras.optimizers.Optimizer` that will be
                 used to compute and apply gradients.
             accum_steps: int > 0. Update gradient in every accumulation steps.
+            reduction: str, reduction method, one of ["SUM", "MEAN"]. Defaults to "SUM".
             name: Optional name for the operations created when applying
                 gradients. Defaults to "GradientAccumulator".
             **kwargs: keyword arguments. Allowed to be {`clipnorm`,
@@ -49,6 +51,7 @@ def __init__(
         self._step = None
         self._gradients = {}
         self._accum_steps = accum_steps
+        self._reduction = reduction

         def _accum_grad(grads_and_vars):
             with tf.init_scope():
@@ -78,6 +81,8 @@ def _accum_grad(grads_and_vars):

                     def _get_grad():
                         new_grad = handle.read_value()
+                        if self._reduction == "MEAN":
+                            new_grad /= tf.cast(self._accum_steps, new_grad.dtype)
                         indices = tf.squeeze(
                             tf.where(
                                 tf.reduce_sum(
@@ -108,10 +113,11 @@ def _get_grad():
                     new_grads_and_vars.append((new_grad, var))
                 else:
                     handle.assign_add(grad)
-                    fake_grad = tf.zeros_like(var)

                     def _get_grad():
                         new_grad = handle.read_value()
+                        if self._reduction == "MEAN":
+                            new_grad /= tf.cast(self._accum_steps, new_grad.dtype)
                         handle.assign(
                             tf.zeros_like(handle), use_locking=self._use_locking
                         )
@@ -120,7 +126,7 @@ def _get_grad():
                     new_grad = tf.cond(
                         self.step % self._accum_steps == 0,
                         _get_grad,
-                        lambda: fake_grad,
+                        lambda: tf.zeros_like(grad),
                     )
                     new_grads_and_vars.append((new_grad, var))
             return new_grads_and_vars
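
For context, a minimal usage sketch of the new reduction argument. The import path, the SGD inner optimizer, and the toy model below are illustrative assumptions rather than part of the patch; only the constructor arguments mirror the signature shown in the diff.

# Usage sketch (assumption: the class is importable as below; not taken from the patch).
import tensorflow as tf
from gradient_accumulator import GradientAccumulator  # hypothetical import path

# With reduction="MEAN", the accumulated gradient is divided by accum_steps when it is
# released (every accum_steps-th step); with the default "SUM" the raw sum is applied.
opt = GradientAccumulator(
    inner_optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    accum_steps=4,
    reduction="MEAN",
)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer=opt, loss="mse")
model.fit(tf.random.normal((32, 8)), tf.random.normal((32, 1)), batch_size=8, epochs=1)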