Skip to content

Commit 6c93750

Browse files
committed
Simplify numba DimShuffle implementation
1 parent c83e04f commit 6c93750

File tree

1 file changed

+15
-82
lines changed

1 file changed

+15
-82
lines changed

pytensor/link/numba/dispatch/elemwise.py

+15-82
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numba
55
import numpy as np
66
from numba.core.extending import overload
7+
from numpy.lib.stride_tricks import as_strided
78

89
from pytensor.graph.op import Op
910
from pytensor.link.numba.dispatch import basic as numba_basic
@@ -411,91 +412,23 @@ def numba_funcify_CAReduce(op, node, **kwargs):
411412

412413
@numba_funcify.register(DimShuffle)
413414
def numba_funcify_DimShuffle(op, node, **kwargs):
414-
shuffle = tuple(op.shuffle)
415-
transposition = tuple(op.transposition)
416-
augment = tuple(op.augment)
415+
new_order = tuple(op._new_order)
416+
shape_template = (1,) * node.outputs[0].ndim
417+
strides_template = (0,) * node.outputs[0].ndim
417418

418-
ndim_new_shape = len(shuffle) + len(augment)
419-
420-
no_transpose = all(i == j for i, j in enumerate(transposition))
421-
if no_transpose:
422-
423-
@numba_basic.numba_njit
424-
def transpose(x):
425-
return x
426-
427-
else:
428-
429-
@numba_basic.numba_njit
430-
def transpose(x):
431-
return np.transpose(x, transposition)
432-
433-
shape_template = (1,) * ndim_new_shape
434-
435-
# When `len(shuffle) == 0`, the `shuffle_shape[j]` expression below
436-
# is typed as `getitem(Tuple(), int)`, which has no implementation
437-
# (since getting an item from an empty sequence doesn't make sense).
438-
# To avoid this compile-time error, we omit the expression altogether.
439-
if len(shuffle) > 0:
440-
# Use the statically known shape if available
441-
if all(length is not None for length in node.outputs[0].type.shape):
442-
shape = node.outputs[0].type.shape
443-
444-
@numba_basic.numba_njit
445-
def find_shape(array_shape):
446-
return shape
447-
448-
else:
449-
450-
@numba_basic.numba_njit
451-
def find_shape(array_shape):
452-
shape = shape_template
453-
j = 0
454-
for i in range(ndim_new_shape):
455-
if i not in augment:
456-
length = array_shape[j]
457-
shape = numba_basic.tuple_setitem(shape, i, length)
458-
j = j + 1
459-
return shape
460-
461-
else:
462-
463-
@numba_basic.numba_njit
464-
def find_shape(array_shape):
465-
return shape_template
466-
467-
if ndim_new_shape > 0:
468-
469-
@numba_basic.numba_njit
470-
def dimshuffle_inner(x, shuffle):
471-
x = transpose(x)
472-
shuffle_shape = x.shape[: len(shuffle)]
473-
new_shape = find_shape(shuffle_shape)
474-
475-
# FIXME: Numba's `array.reshape` only accepts C arrays.
476-
return np.reshape(np.ascontiguousarray(x), new_shape)
419+
@numba_basic.numba_njit
420+
def dimshuffle(x):
421+
old_shape = x.shape
422+
old_strides = x.strides
477423

478-
else:
424+
new_shape = shape_template
425+
new_strides = strides_template
426+
for i, o in enumerate(new_order):
427+
if o != -1:
428+
new_shape = numba_basic.tuple_setitem(new_shape, i, old_shape[o])
429+
new_strides = numba_basic.tuple_setitem(new_strides, i, old_strides[o])
479430

480-
@numba_basic.numba_njit
481-
def dimshuffle_inner(x, shuffle):
482-
return np.reshape(np.ascontiguousarray(x), ())
483-
484-
# Without the following wrapper function we would see this error:
485-
# E No implementation of function Function(<built-in function getitem>) found for signature:
486-
# E
487-
# E >>> getitem(UniTuple(int64 x 2), slice<a:b>)
488-
# E
489-
# E There are 22 candidate implementations:
490-
# E - Of which 22 did not match due to:
491-
# E Overload of function 'getitem': File: <numerous>: Line N/A.
492-
# E With argument(s): '(UniTuple(int64 x 2), slice<a:b>)':
493-
# E No match.
494-
# ...(on this line)...
495-
# E shuffle_shape = res.shape[: len(shuffle)]
496-
@numba_basic.numba_njit(inline="always")
497-
def dimshuffle(x):
498-
return dimshuffle_inner(np.asarray(x), shuffle)
431+
return as_strided(x, shape=new_shape, strides=new_strides)
499432

500433
return dimshuffle
501434

0 commit comments

Comments
 (0)