Skip to content

Commit 184f624

Browse files
committed
Stop returning IndexedSlices gradients
1 parent aedb074 commit 184f624

File tree

1 file changed

+1
-12
lines changed

1 file changed

+1
-12
lines changed

tensorflow_addons/layers/embedding_bag.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,7 @@ def _embedding_bag_grad(op, grads):
6868
value_grads, weight_grads = _embedding_bag_so.ops.addons_embedding_bag_grad(
6969
indices, params, weights, grads, combiner=combiner
7070
)
71-
# Because value grads are sparse, returning IndexedSlices can be faster for optimizer.
72-
unique_indices = tf.unique(tf.reshape(indices, (-1,)))[0]
73-
sorted_unique_indices = tf.sort(unique_indices)
74-
return [
75-
None,
76-
tf.IndexedSlices(
77-
indices=sorted_unique_indices,
78-
values=tf.gather(value_grads, sorted_unique_indices),
79-
dense_shape=tf.shape(params),
80-
),
81-
weight_grads,
82-
]
71+
return [None, value_grads, weight_grads]
8372

8473

8574
@tf.keras.utils.register_keras_serializable(package="Addons")

0 commit comments

Comments
 (0)