Skip to content

Commit aedb074

Browse files
committed
Update test syntax
1 parent 4b82c7f commit aedb074

File tree

1 file changed

+2
-21
lines changed

1 file changed

+2
-21
lines changed

tensorflow_addons/layers/tests/embedding_bag_test.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,7 @@ def test_backward(input_shape, input_dim, dtype, indices_dtype, combiner, graph_
100100
expected_grads = tape.gradient(expected, [params, weights])
101101
# Gather returns sparse IndexedSlices so we have to sum them together.
102102
test_utils.assert_allclose_according_to_type(
103-
tf.math.unsorted_segment_sum(
104-
expected_grads[0].values,
105-
expected_grads[0].indices,
106-
expected_grads[0].dense_shape[0],
107-
),
108-
tf.math.unsorted_segment_sum(
109-
grads[0].values,
110-
grads[0].indices,
111-
grads[0].dense_shape[0],
112-
),
103+
tf.convert_to_tensor(expected_grads[0]), tf.convert_to_tensor(grads[0]),
113104
)
114105
test_utils.assert_allclose_according_to_type(
115106
expected_grads[1],
@@ -125,14 +116,4 @@ def test_backward(input_shape, input_dim, dtype, indices_dtype, combiner, graph_
125116
expected_grads = tape.gradient(expected, [params])
126117
# Gather returns sparse IndexedSlices so we have to sum them together.
127118
test_utils.assert_allclose_according_to_type(
128-
tf.math.unsorted_segment_sum(
129-
expected_grads[0].values,
130-
expected_grads[0].indices,
131-
expected_grads[0].dense_shape[0],
132-
),
133-
tf.math.unsorted_segment_sum(
134-
grads[0].values,
135-
grads[0].indices,
136-
grads[0].dense_shape[0],
137-
),
138-
)
119+
tf.convert_to_tensor(expected_grads[0]), tf.convert_to_tensor(grads[0]))

0 commit comments

Comments
 (0)