Skip to content

Commit 5714253

Browse files
committed
Simplify and speedup numba DimShuffle implementation
1 parent 26f852a commit 5714253

File tree

3 files changed

+58
-95
lines changed

3 files changed

+58
-95
lines changed

pytensor/link/numba/dispatch/elemwise.py

+25-77
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numba
55
import numpy as np
66
from numba.core.extending import overload
7+
from numpy.lib.stride_tricks import as_strided
78

89
from pytensor.graph.op import Op
910
from pytensor.link.numba.dispatch import basic as numba_basic
@@ -411,91 +412,38 @@ def numba_funcify_CAReduce(op, node, **kwargs):
411412

412413
@numba_funcify.register(DimShuffle)
413414
def numba_funcify_DimShuffle(op, node, **kwargs):
414-
shuffle = tuple(op.shuffle)
415-
transposition = tuple(op.transposition)
416-
augment = tuple(op.augment)
415+
# We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squeezing dimensions in one call
416+
# Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays.
417+
new_order = tuple(op._new_order)
418+
shape_template = (1,) * node.outputs[0].ndim
419+
strides_template = (0,) * node.outputs[0].ndim
417420

418-
ndim_new_shape = len(shuffle) + len(augment)
419-
420-
no_transpose = all(i == j for i, j in enumerate(transposition))
421-
if no_transpose:
422-
423-
@numba_basic.numba_njit
424-
def transpose(x):
425-
return x
426-
427-
else:
428-
429-
@numba_basic.numba_njit
430-
def transpose(x):
431-
return np.transpose(x, transposition)
432-
433-
shape_template = (1,) * ndim_new_shape
434-
435-
# When `len(shuffle) == 0`, the `shuffle_shape[j]` expression below
436-
# is typed as `getitem(Tuple(), int)`, which has no implementation
437-
# (since getting an item from an empty sequence doesn't make sense).
438-
# To avoid this compile-time error, we omit the expression altogether.
439-
if len(shuffle) > 0:
440-
# Use the statically known shape if available
441-
if all(length is not None for length in node.outputs[0].type.shape):
442-
shape = node.outputs[0].type.shape
443-
444-
@numba_basic.numba_njit
445-
def find_shape(array_shape):
446-
return shape
447-
448-
else:
449-
450-
@numba_basic.numba_njit
451-
def find_shape(array_shape):
452-
shape = shape_template
453-
j = 0
454-
for i in range(ndim_new_shape):
455-
if i not in augment:
456-
length = array_shape[j]
457-
shape = numba_basic.tuple_setitem(shape, i, length)
458-
j = j + 1
459-
return shape
460-
461-
else:
462-
463-
@numba_basic.numba_njit
464-
def find_shape(array_shape):
465-
return shape_template
466-
467-
if ndim_new_shape > 0:
421+
if new_order == ():
422+
# Special case needed because of https://github.com/numba/numba/issues/9933
468423

469424
@numba_basic.numba_njit
470-
def dimshuffle_inner(x, shuffle):
471-
x = transpose(x)
472-
shuffle_shape = x.shape[: len(shuffle)]
473-
new_shape = find_shape(shuffle_shape)
425+
def squeeze_to_0d(x):
426+
return as_strided(x, shape=(), strides=())
474427

475-
# FIXME: Numba's `array.reshape` only accepts C arrays.
476-
return np.reshape(np.ascontiguousarray(x), new_shape)
428+
return squeeze_to_0d
477429

478430
else:
479431

480432
@numba_basic.numba_njit
481-
def dimshuffle_inner(x, shuffle):
482-
return np.reshape(np.ascontiguousarray(x), ())
483-
484-
# Without the following wrapper function we would see this error:
485-
# E No implementation of function Function(<built-in function getitem>) found for signature:
486-
# E
487-
# E >>> getitem(UniTuple(int64 x 2), slice<a:b>)
488-
# E
489-
# E There are 22 candidate implementations:
490-
# E - Of which 22 did not match due to:
491-
# E Overload of function 'getitem': File: <numerous>: Line N/A.
492-
# E With argument(s): '(UniTuple(int64 x 2), slice<a:b>)':
493-
# E No match.
494-
# ...(on this line)...
495-
# E shuffle_shape = res.shape[: len(shuffle)]
496-
@numba_basic.numba_njit(inline="always")
497-
def dimshuffle(x):
498-
return dimshuffle_inner(np.asarray(x), shuffle)
433+
def dimshuffle(x):
434+
old_shape = x.shape
435+
old_strides = x.strides
436+
437+
new_shape = shape_template
438+
new_strides = strides_template
439+
for i, o in enumerate(new_order):
440+
if o != -1:
441+
new_shape = numba_basic.tuple_setitem(new_shape, i, old_shape[o])
442+
new_strides = numba_basic.tuple_setitem(
443+
new_strides, i, old_strides[o]
444+
)
445+
446+
return as_strided(x, shape=new_shape, strides=new_strides)
499447

500448
return dimshuffle
501449

tests/link/numba/test_elemwise.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tests.tensor.test_elemwise import (
2424
careduce_benchmark_tester,
2525
check_elemwise_runtime_broadcast,
26+
dimshuffle_benchmark,
2627
)
2728

2829

@@ -201,7 +202,7 @@ def test_Dimshuffle_returns_array():
201202

202203
def test_Dimshuffle_non_contiguous():
203204
"""The numba impl of reshape doesn't work with
204-
non-contiguous arrays, make sure we work around thpt."""
205+
non-contiguous arrays, make sure we work around that."""
205206
x = pt.dvector()
206207
idx = pt.vector(dtype="int64")
207208
op = DimShuffle(input_ndim=1, new_order=[])
@@ -643,3 +644,7 @@ def test_numba_careduce_benchmark(self, axis, c_contiguous, benchmark):
643644
return careduce_benchmark_tester(
644645
axis, c_contiguous, mode="NUMBA", benchmark=benchmark
645646
)
647+
648+
@pytest.mark.parametrize("c_contiguous", (True, False))
649+
def test_dimshuffle(self, c_contiguous, benchmark):
650+
dimshuffle_benchmark("NUMBA", c_contiguous, benchmark)

tests/tensor/test_elemwise.py

+27-17
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,30 @@ def custom_reduce(a):
6666
return np.apply_along_axis(custom_reduce, axis, x)
6767

6868

69+
def dimshuffle_benchmark(mode, c_contiguous, benchmark):
70+
x = tensor3("x")
71+
if c_contiguous:
72+
x_val = np.random.random((2, 3, 4)).astype(config.floatX)
73+
else:
74+
x_val = np.random.random((200, 300, 400)).transpose(1, 2, 0)
75+
ys = [x.transpose(t) for t in itertools.permutations((0, 1, 2))]
76+
ys += [
77+
x[None],
78+
x[:, None],
79+
x[:, :, None],
80+
x[:, :, :, None],
81+
]
82+
# Borrow to avoid deepcopy overhead
83+
fn = pytensor.function(
84+
[In(x, borrow=True)],
85+
[Out(y, borrow=True) for y in ys],
86+
mode=mode,
87+
)
88+
fn.trust_input = True
89+
fn(x_val) # JIT compile for JIT backends
90+
benchmark(fn, x_val)
91+
92+
6993
class TestDimShuffle(unittest_tools.InferShapeTester):
7094
op = DimShuffle
7195
type = TensorType
@@ -218,23 +242,9 @@ def test_valid_input_ndim(self):
218242
with pytest.raises(TypeError, match="input_ndim must be an integer"):
219243
DimShuffle(input_ndim=(True, False), new_order=(1, 0))
220244

221-
def test_benchmark(self, benchmark):
222-
x = tensor3("x")
223-
x_val = np.random.random((2, 3, 4)).astype(config.floatX)
224-
ys = [x.transpose(t) for t in itertools.permutations((0, 1, 2))]
225-
ys += [
226-
x[None],
227-
x[:, None],
228-
x[:, :, None],
229-
x[:, :, :, None],
230-
]
231-
# Borrow to avoid deepcopy overhead
232-
fn = pytensor.function(
233-
[In(x, borrow=True)],
234-
[Out(y, borrow=True) for y in ys],
235-
)
236-
fn.trust_input = True
237-
benchmark(fn, x_val)
245+
@pytest.mark.parametrize("c_contiguous", [True, False])
246+
def test_benchmark(self, c_contiguous, benchmark):
247+
dimshuffle_benchmark("FAST_RUN", c_contiguous, benchmark)
238248

239249

240250
class TestBroadcast:

0 commit comments

Comments
 (0)