
Commit 4b1400d

ZixuanJiang authored and Google-ML-Automation committed
#jax Optimize jax.numpy.take_along_axis along a dimension that satisfies all of the following:
* the dimension is not the one along which to take values,
* the dimension size of the input tensor is 1, and
* the dimension size of the indices is not 1.

Previously, we created a constant zero as the dummy index for that dimension, which is redundant. We can instead squeeze the input tensor and generate the `stablehlo.gather` directly. In the following example,

```
h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32)
g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8)
q0 = jnp.take_along_axis(h, g, axis=-2)
```

it lowers to the following module before this change:

```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %3 = stablehlo.compare LT, %0, %2, SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc35)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %4 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc36)
    %5 = stablehlo.add %0, %4 : tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.select %3, %5, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc37)
    %7 = stablehlo.concatenate %1, %6, dim = 4 : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x2xi32> loc(#loc38)
    %c_1 = stablehlo.constant dense<[0, 6]> : tensor<2xi64> loc(#loc39)
    %8 = stablehlo.convert %7 : (tensor<2x3x5x11x2xi32>) -> tensor<2x3x5x11x2xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc39)
    %9 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x2xi64> loc(#loc40)
    %10 = stablehlo.compare GE, %8, %9, SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<2xi64>) -> tensor<1x1x1x1x2xi64> loc(#loc34)
    %12 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x2xi64>) -> tensor<2x3x5x11x2xi64> loc(#loc41)
    %13 = stablehlo.compare LE, %8, %12, SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc41)
    %14 = stablehlo.and %10, %13 : tensor<2x3x5x11x2xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %15 = stablehlo.reduce(%14 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x2xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %16 = "stablehlo.gather"(%arg0, %8) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [1, 3], operand_batching_dims = [0, 2], start_indices_batching_dims = [0, 2], start_index_map = [1, 3], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 1, 13>}> : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc39)
    %17 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc34)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc39)
    %18 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc34)
    %19 = stablehlo.select %17, %16, %18 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc37)
    return %19 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

With this change, we have:

```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.compare LT, %0, %1, SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc34)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %3 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %4 = stablehlo.add %0, %3 : tensor<2x3x5x11x1xi32> loc(#loc35)
    %5 = stablehlo.select %2, %4, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.reshape %arg0 : (tensor<2x1x5x7x13xf32>) -> tensor<2x5x7x13xf32> loc(#loc37)
    %c_1 = stablehlo.constant dense<6> : tensor<1xi64> loc(#loc38)
    %7 = stablehlo.convert %5 : (tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc38)
    %8 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x1xi64> loc(#loc39)
    %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc39)
    %10 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<1xi64>) -> tensor<1x1x1x1x1xi64> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %10, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x1xi64>) -> tensor<2x3x5x11x1xi64> loc(#loc41)
    %12 = stablehlo.compare LE, %7, %11, SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc41)
    %13 = stablehlo.and %9, %12 : tensor<2x3x5x11x1xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %14 = stablehlo.reduce(%13 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x1xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %15 = "stablehlo.gather"(%6, %7) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [2], operand_batching_dims = [0, 1], start_indices_batching_dims = [0, 2], start_index_map = [2], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 13>}> : (tensor<2x5x7x13xf32>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc38)
    %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc40)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc38)
    %17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc40)
    %18 = stablehlo.select %16, %15, %17 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc36)
    return %18 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

PiperOrigin-RevId: 725506779
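To see why the dummy zero index is redundant for such a dimension, here is a small NumPy sketch (the shapes are illustrative, not the commit's example): a size-1 operand dimension always contributes its only element, so squeezing it away and gathering from the squeezed array yields the same values.

```
import numpy as np

arr = np.arange(2 * 1 * 4, dtype=np.float32).reshape(2, 1, 4)  # size-1 dim at axis 1
idx = np.array([[[3, 0], [1, 2], [0, 0]],
                [[2, 2], [1, 0], [3, 1]]])                     # shape (2, 3, 2)

# Take along the last axis; arr's size-1 dimension broadcasts against idx's 3.
out = np.take_along_axis(arr, idx, axis=-1)

# Every output element reads arr[i, 0, ...], so no per-element index is needed
# for the size-1 dimension: indexing the squeezed array gives the same result.
squeezed = arr.squeeze(1)                                      # shape (2, 4)
expected = squeezed[np.arange(2)[:, None, None], idx]
np.testing.assert_array_equal(out, expected)
```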
1 parent 1e447c8 commit 4b1400d

File tree

* jax/_src/numpy/lax_numpy.py
* tests/lax_numpy_test.py

2 files changed (+24, -10 lines changed)


Diff for: jax/_src/numpy/lax_numpy.py

+17-10
@@ -10835,29 +10835,33 @@ def replace(tup, val):
   collapsed_slice_dims = []
   operand_batching_dims = []
   start_indices_batching_dims = []
+
+  # We will squeeze the array. i is the index of the unsqueezed shape, while
+  # new_i is the index of the squeezed shape. j is the index of the gather
+  # indices.
+  dims_to_squeeze = []
+  new_i = 0
   j = 0
   for i in range(rank):
     if i == axis_int:
       if mode != 'promise_in_bounds':
         indices = _normalize_index(indices, axis_size)
       gather_indices.append(lax.reshape(indices, gather_index_shape))
       slice_sizes.append(1)
-      start_index_map.append(i)
-      collapsed_slice_dims.append(i)
+      start_index_map.append(new_i)
+      collapsed_slice_dims.append(new_i)
+      new_i += 1
       j += 1
     elif core.definitely_equal(idx_shape[i], 1):
       # If idx_shape[i] == 1, we can just take the entirety of the arr's axis
       # and avoid forming an iota index.
       offset_dims.append(i)
       slice_sizes.append(arr_shape[i])
+      new_i += 1
     elif core.definitely_equal(arr_shape[i], 1):
-      # If the array dimension is 1 but the index dimension is not, we
-      # broadcast the array dimension to the index dimension by repeatedly
-      # gathering the first element.
-      gather_indices.append(zeros(gather_index_shape, dtype=index_dtype))
-      slice_sizes.append(1)
-      start_index_map.append(i)
-      collapsed_slice_dims.append(i)
+      # If the array dimension is 1 but the index dimension is not, we will
+      # squeeze this dimension.
+      dims_to_squeeze.append(i)
       j += 1
     else:
       # Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both
@@ -10866,10 +10870,13 @@ def replace(tup, val):
         slice_sizes.append(0)
       else:
         slice_sizes.append(1)
-      operand_batching_dims.append(i)
+      operand_batching_dims.append(new_i)
       start_indices_batching_dims.append(j)
+      new_i += 1
       j += 1
 
+  # Squeeze a to remove singleton dimensions.
+  a = lax.squeeze(a, dims_to_squeeze)
   gather_indices_arr = lax.concatenate(gather_indices, dimension=j)
   dnums = lax.GatherDimensionNumbers(
       offset_dims=tuple(offset_dims),
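The `i`/`new_i` bookkeeping in the hunks above can be summarized with a standalone sketch (a hypothetical helper written for illustration, not part of the change): dimensions collected in `dims_to_squeeze` vanish from the gather operand, every surviving dimension advances `new_i`, and the gather dimension numbers that refer to operand dimensions must therefore be expressed in terms of `new_i`.

```
def squeezed_dim_mapping(arr_shape, idx_shape, axis):
  """Hypothetical helper: map unsqueezed dims (i) to squeezed dims (new_i)."""
  dims_to_squeeze = []
  mapping = {}
  new_i = 0
  for i in range(len(arr_shape)):
    if i != axis and arr_shape[i] == 1 and idx_shape[i] != 1:
      # This dimension is dropped from the gather operand entirely.
      dims_to_squeeze.append(i)
    else:
      # Surviving dimension: operand-relative gather dimension numbers
      # (start_index_map, collapsed_slice_dims, operand_batching_dims)
      # must use new_i rather than i.
      mapping[i] = new_i
      new_i += 1
  return dims_to_squeeze, mapping

# For the shapes in the commit's example (axis=-2, i.e. dimension 3):
dims, mapping = squeezed_dim_mapping((2, 1, 5, 7, 13), (2, 3, 5, 11, 1), axis=3)
assert dims == [1]
assert mapping == {0: 0, 2: 1, 3: 2, 4: 3}
```

With this mapping, the gather axis (dimension 3) becomes operand dimension 2, which matches `start_index_map = [2]` and `collapsed_slice_dims = [2]` in the optimized module above, and the batching dimensions 0 and 2 become 0 and 1, matching `operand_batching_dims = [0, 1]`.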

Diff for: tests/lax_numpy_test.py

+7
@@ -4721,6 +4721,13 @@ def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self):
     q1 = np.take_along_axis( h, g, axis=-1)
     np.testing.assert_equal(q0, q1)
 
+  def testTakeAlongAxisInputTensorHasSingletonDimension(self):
+    h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32)
+    g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8)
+    q0 = jnp.take_along_axis(h, g, axis=-2)
+    q1 = np.take_along_axis( h, g, axis=-2)
+    np.testing.assert_equal(q0, q1)
+
   def testTakeAlongAxisOutOfBounds(self):
     x = jnp.arange(10, dtype=jnp.float32)
     idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11])
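The before/after modules in the commit message can also be reproduced locally by dumping the lowered StableHLO for the shapes used in the new test. A minimal sketch of one such workflow (the wrapper function `foo` and the zero-filled inputs are assumptions for illustration, not part of the commit):

```
import jax
import jax.numpy as jnp
import numpy as np

def foo(x, y):
  return jnp.take_along_axis(x, y, axis=-2)

x = np.zeros((2, 1, 5, 7, 13), np.float32)
y = np.zeros((2, 3, 5, 11, 1), np.uint8)

# Print the StableHLO module; with this change, the input is reshaped to
# tensor<2x5x7x13xf32> and no constant-zero index column is concatenated
# before the gather.
print(jax.jit(foo).lower(x, y).as_text())
```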

0 commit comments
