Skip to content

Commit 66f7666

Browse files
Internal change
PiperOrigin-RevId: 485676095
1 parent 00351f9 commit 66f7666

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

official/projects/edgetpu/vision/modeling/common_modules.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,14 @@ def _cross_replica_average(self, t: tf.Tensor, num_shards_per_group: int):
5151
return tf1.tpu.cross_replica_sum(t, group_assignment) / tf.cast(
5252
num_shards_per_group, t.dtype)
5353

54-
def _moments(self, inputs: tf.Tensor, reduction_axes: int, keep_dims: int):
54+
def _moments(self,
55+
inputs: tf.Tensor,
56+
reduction_axes: int,
57+
keep_dims: int,
58+
mask: Optional[tf.Tensor] = None):
5559
"""Compute the mean and variance: it overrides the original _moments."""
5660
shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
57-
inputs, reduction_axes, keep_dims=keep_dims)
61+
inputs, reduction_axes, keep_dims=keep_dims, mask=mask)
5862

5963
num_shards = tpu_function.get_tpu_context().number_of_shards or 1
6064
if num_shards <= 8: # Skip cross_replica for 2x2 or smaller slices.

0 commit comments

Comments
 (0)