|
4 | 4 | import numba
|
5 | 5 | import numpy as np
|
6 | 6 | from numba.core.extending import overload
|
| 7 | +from numpy.lib.stride_tricks import as_strided |
7 | 8 |
|
8 | 9 | from pytensor.graph.op import Op
|
9 | 10 | from pytensor.link.numba.dispatch import basic as numba_basic
|
@@ -411,91 +412,23 @@ def numba_funcify_CAReduce(op, node, **kwargs):
|
411 | 412 |
|
412 | 413 | @numba_funcify.register(DimShuffle)
|
413 | 414 | def numba_funcify_DimShuffle(op, node, **kwargs):
|
414 |
| - shuffle = tuple(op.shuffle) |
415 |
| - transposition = tuple(op.transposition) |
416 |
| - augment = tuple(op.augment) |
| 415 | + new_order = tuple(op._new_order) |
| 416 | + shape_template = (1,) * node.outputs[0].ndim |
| 417 | + strides_template = (0,) * node.outputs[0].ndim |
417 | 418 |
|
418 |
| - ndim_new_shape = len(shuffle) + len(augment) |
419 |
| - |
420 |
| - no_transpose = all(i == j for i, j in enumerate(transposition)) |
421 |
| - if no_transpose: |
422 |
| - |
423 |
| - @numba_basic.numba_njit |
424 |
| - def transpose(x): |
425 |
| - return x |
426 |
| - |
427 |
| - else: |
428 |
| - |
429 |
| - @numba_basic.numba_njit |
430 |
| - def transpose(x): |
431 |
| - return np.transpose(x, transposition) |
432 |
| - |
433 |
| - shape_template = (1,) * ndim_new_shape |
434 |
| - |
435 |
| - # When `len(shuffle) == 0`, the `shuffle_shape[j]` expression below |
436 |
| - # is typed as `getitem(Tuple(), int)`, which has no implementation |
437 |
| - # (since getting an item from an empty sequence doesn't make sense). |
438 |
| - # To avoid this compile-time error, we omit the expression altogether. |
439 |
| - if len(shuffle) > 0: |
440 |
| - # Use the statically known shape if available |
441 |
| - if all(length is not None for length in node.outputs[0].type.shape): |
442 |
| - shape = node.outputs[0].type.shape |
443 |
| - |
444 |
| - @numba_basic.numba_njit |
445 |
| - def find_shape(array_shape): |
446 |
| - return shape |
447 |
| - |
448 |
| - else: |
449 |
| - |
450 |
| - @numba_basic.numba_njit |
451 |
| - def find_shape(array_shape): |
452 |
| - shape = shape_template |
453 |
| - j = 0 |
454 |
| - for i in range(ndim_new_shape): |
455 |
| - if i not in augment: |
456 |
| - length = array_shape[j] |
457 |
| - shape = numba_basic.tuple_setitem(shape, i, length) |
458 |
| - j = j + 1 |
459 |
| - return shape |
460 |
| - |
461 |
| - else: |
462 |
| - |
463 |
| - @numba_basic.numba_njit |
464 |
| - def find_shape(array_shape): |
465 |
| - return shape_template |
466 |
| - |
467 |
| - if ndim_new_shape > 0: |
468 |
| - |
469 |
| - @numba_basic.numba_njit |
470 |
| - def dimshuffle_inner(x, shuffle): |
471 |
| - x = transpose(x) |
472 |
| - shuffle_shape = x.shape[: len(shuffle)] |
473 |
| - new_shape = find_shape(shuffle_shape) |
474 |
| - |
475 |
| - # FIXME: Numba's `array.reshape` only accepts C arrays. |
476 |
| - return np.reshape(np.ascontiguousarray(x), new_shape) |
| 419 | + @numba_basic.numba_njit |
| 420 | + def dimshuffle(x): |
| 421 | + old_shape = x.shape |
| 422 | + old_strides = x.strides |
477 | 423 |
|
478 |
| - else: |
| 424 | + new_shape = shape_template |
| 425 | + new_strides = strides_template |
| 426 | + for i, o in enumerate(new_order): |
| 427 | + if o != -1: |
| 428 | + new_shape = numba_basic.tuple_setitem(new_shape, i, old_shape[o]) |
| 429 | + new_strides = numba_basic.tuple_setitem(new_strides, i, old_strides[o]) |
479 | 430 |
|
480 |
| - @numba_basic.numba_njit |
481 |
| - def dimshuffle_inner(x, shuffle): |
482 |
| - return np.reshape(np.ascontiguousarray(x), ()) |
483 |
| - |
484 |
| - # Without the following wrapper function we would see this error: |
485 |
| - # E No implementation of function Function(<built-in function getitem>) found for signature: |
486 |
| - # E |
487 |
| - # E >>> getitem(UniTuple(int64 x 2), slice<a:b>) |
488 |
| - # E |
489 |
| - # E There are 22 candidate implementations: |
490 |
| - # E - Of which 22 did not match due to: |
491 |
| - # E Overload of function 'getitem': File: <numerous>: Line N/A. |
492 |
| - # E With argument(s): '(UniTuple(int64 x 2), slice<a:b>)': |
493 |
| - # E No match. |
494 |
| - # ...(on this line)... |
495 |
| - # E shuffle_shape = res.shape[: len(shuffle)] |
496 |
| - @numba_basic.numba_njit(inline="always") |
497 |
| - def dimshuffle(x): |
498 |
| - return dimshuffle_inner(np.asarray(x), shuffle) |
| 431 | + return as_strided(x, shape=new_shape, strides=new_strides) |
499 | 432 |
|
500 | 433 | return dimshuffle
|
501 | 434 |
|
|
0 commit comments