
Speedup DimShuffle and Reshape in C and Numba backends #1226


Merged: 4 commits merged into pymc-devs:main from the reshape_regression branch on Feb 25, 2025

Conversation

ricardoV94 (Member) commented Feb 20, 2025

TLDR: Make views cheap again!

All backends

Remove the unused inplace option; less branching is always helpful.

Python backend

Implementation based on as_strided that mimics the old (and now current again) C-backend approach. This is debatable, since the Python implementation has historically been there more for readability than performance; still, it's useful for fast-compiling graphs not to be needlessly slow.
as_strided is actually very slow in Python, so I'm back to transpose + reshape. The removal of some checks and unnecessary asarray calls provides a small speedup, although this is obviously not critical.
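
For illustration, here is a minimal NumPy-level sketch of the transpose + reshape idea (the helper name dimshuffle_py is hypothetical; the actual Python implementation lives in pytensor/tensor/elemwise.py):

import numpy as np

def dimshuffle_py(x, new_order):
    # Hypothetical sketch of DimShuffle via transpose + reshape.
    # new_order mixes input axis indices and "x" for new length-1 axes;
    # input axes not listed must have length 1 and are dropped.
    kept = [d for d in new_order if d != "x"]
    dropped = [d for d in range(x.ndim) if d not in kept]
    # Kept axes go in their requested order; dropped (length-1) axes go
    # last, so they don't change the C-order element sequence.
    transposed = x.transpose(kept + dropped)
    new_shape = tuple(1 if d == "x" else x.shape[d] for d in new_order)
    # reshape returns a view when it can, otherwise it copies
    return transposed.reshape(new_shape)

x = np.zeros((1, 2, 3))
assert dimshuffle_py(x, (2, "x", 1, "x")).shape == (3, 1, 2, 1)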

Numba backend

Use as_strided, which simplifies the implementation by a ton, with two benefits (a NumPy-level sketch of the idea follows the benchmark snippet below):

  1. Much faster compilation / caching. The new benchmark runs in 1.6s (including many runs) after caching, vs 4.6s. When checking a simple function directly, I saw wall times go from 1.2s down to 600ms (first run after starting the interpreter) and from 300ms down to 100ms in subsequent runs. This seems to scale with the number of DimShuffles in the graph.
  2. Slightly faster for contiguous arrays: the new benchmark is roughly 1.1-1.3x faster.

import pytensor
import pytensor.tensor as pt
import numpy as np

x = pt.tensor("x", shape=(1, 2, 3))
y = x.dimshuffle(2, "x", 1, "x")

# borrow/trust_input minimize call overhead so the Op cost dominates the timing
out = pytensor.function([pytensor.In(x, borrow=True)], pytensor.Out(y, borrow=True), mode="NUMBA")
out.trust_input = True

x_test = np.zeros((1, 2, 3))
out(x_test)  # first call triggers Numba compilation
%timeit out(x_test)
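
A minimal NumPy-level sketch of the as_strided idea behind the new Numba implementation (dimshuffle_view is a hypothetical name; the real code lives in pytensor/link/numba/dispatch/elemwise.py):

import numpy as np
from numpy.lib.stride_tricks import as_strided

def dimshuffle_view(x, new_order):
    # Hypothetical sketch: reorder/expand axes of x as a pure view.
    # new_order mixes input axis indices and "x" for new length-1 axes;
    # dropped input axes (not listed) must have length 1.
    new_shape = tuple(1 if d == "x" else x.shape[d] for d in new_order)
    # The stride of a length-1 axis never matters; 0 is a safe choice.
    new_strides = tuple(0 if d == "x" else x.strides[d] for d in new_order)
    return as_strided(x, shape=new_shape, strides=new_strides)

x_test = np.zeros((1, 2, 3))
y = dimshuffle_view(x_test, (2, "x", 1, "x"))
assert y.shape == (3, 1, 2, 1)
assert np.shares_memory(x_test, y)  # no copy, even for non-contiguous inputs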

The example with non-contiguous arrays in #1111 has no penalty now, which amounts to a ~10,000x speedup (the example is exaggerated, with a very large array).

C backend

This was the original goal of this PR.

There were two regressions caused by e593b0a and 223ee15.

The DimShuffle changes introduced a PyArray_IntpConverter to compute the new shape and used two calls to the numpy C functions PyArray_Transpose and PyArray_NewShape. I suspect the slowdown comes mostly from the introduction of PyArray_IntpConverter, but I couldn't find anything wrong with the simpler logic from Theano times, other than...

The bug that motivated the changes had to do with a useless second pass over the dimension calculations:

    /* set the strides of the broadcasted dimensions.
     * This algorithm is from numpy: PyArray_Newshape() in
     * cvs/numpy/numpy/core/src/multiarraymodule.c */
    if (nd_out > 0) {
        if (strides[nd_out - 1] == 0)
            strides[nd_out - 1] = PyArray_DESCR(basename)->elsize;
        for (npy_intp i = nd_out - 2; i > -1; --i) {
            if (strides[i] == 0)
                strides[i] = strides[i + 1] * dimensions[i + 1];
        }
    }

Which is baffling, in that DimShuffle does not do any broadcasting, other than expand_dims, which the first pass already handles correctly. Removing this loop fixes the original bug.
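
For context, a hypothetical C sketch (not the actual PyTensor code) of how a single pass can fill the output shape and strides when DimShuffle only reorders axes and inserts length-1 axes:

#include <Python.h>
#include <numpy/arrayobject.h>

/* Hypothetical helper: new_order[i] == -1 marks an expanded ("x") axis,
 * anything else indexes an input dimension. One pass is enough; no
 * broadcast-style second pass over the strides is needed. */
static void fill_dimshuffle_shape_strides(PyArrayObject *input,
                                          const npy_intp *new_order,
                                          npy_intp nd_out,
                                          npy_intp *dimensions,
                                          npy_intp *strides)
{
    for (npy_intp i = 0; i < nd_out; ++i) {
        if (new_order[i] == -1) {
            dimensions[i] = 1;  /* expand_dims: stride value is irrelevant */
            strides[i] = 0;
        } else {
            dimensions[i] = PyArray_DIMS(input)[new_order[i]];
            strides[i] = PyArray_STRIDES(input)[new_order[i]];
        }
    }
}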

The Reshape changes were simpler: they introduced a generic PyArray_IntpConverter to convert the shape numpy vector into a simple C array for the Reshape operation.

The concern about strides in the commit message doesn't hold, because they were being used directly. I added a direct test for odd strides just in case.
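
A hedged sketch of what such a test might look like (not the exact test added in this PR):

import numpy as np
import pytensor
import pytensor.tensor as pt

def test_reshape_with_odd_strides():
    # Hypothetical sketch: push a non-contiguous, oddly-strided input
    # through Reshape and compare against NumPy's result.
    x = pt.matrix("x", dtype="float64")
    fn = pytensor.function([x], x.reshape((6,)))

    base = np.arange(24, dtype="float64").reshape(4, 6)
    x_val = base[::2, ::2]  # shape (2, 3) with non-contiguous strides
    np.testing.assert_array_equal(fn(x_val), x_val.reshape(6))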

In the long term, Reshape should take the separate shape entries as scalars instead of a single vector. The user never defines a graph with a vector anyway, but with isolated entries, so the whole packing and unpacking is just overhead. This is tracked in #881.

Profiling a simple function suggests the previous changes caused an 8.6-11x slowdown per Op call for the DimShuffle operation and a 5.4x slowdown for the Reshape operation. The new benchmarks detect only a 3.6x and 2.3x slowdown, respectively, due to the PyTensor call overhead.

## Op
### DimShuffle: (8.6x slowdown)
Before this PR: 1.74e-06s per call (class)
After this PR: 2.02e-07s

### Reshape: (5.4x slowdown)
Before this PR: 1.50e-06s per call (class)
After this PR: 2.78e-07s
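
For reference, per-Op timings like the ones above can be collected with PyTensor's built-in profiler; a minimal sketch with an illustrative toy graph:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(1, 2, 3))
y = x.dimshuffle(2, "x", 1, "x").reshape((3, 2))

# profile=True attaches a profiler that reports time spent per Op class
fn = pytensor.function([x], y, profile=True)
for _ in range(1000):
    fn(np.zeros((1, 2, 3), dtype=x.dtype))
fn.profile.summary()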

codecov bot commented Feb 20, 2025

Codecov Report

Attention: Patch coverage is 96.55172% with 1 line in your changes missing coverage. Please review.

Project coverage is 81.99%. Comparing base (53763f5) to head (50801a3).
Report is 4 commits behind head on main.

Files with missing lines | Patch % | Lines
pytensor/link/pytorch/dispatch/elemwise.py | 0.00% | 1 Missing ⚠️

❌ Your patch status has failed because the patch coverage (96.55%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@           Coverage Diff           @@
##             main    #1226   +/-   ##
=======================================
  Coverage   81.98%   81.99%           
=======================================
  Files         188      188           
  Lines       48568    48523   -45     
  Branches     8677     8668    -9     
=======================================
- Hits        39819    39785   -34     
+ Misses       6584     6579    -5     
+ Partials     2165     2159    -6     
Files with missing lines | Coverage | Δ
pytensor/link/jax/dispatch/elemwise.py | 83.82% <100.00%> | +2.13% ⬆️
pytensor/link/numba/dispatch/elemwise.py | 94.82% <100.00%> | +0.14% ⬆️
pytensor/tensor/elemwise.py | 89.59% <100.00%> | +0.59% ⬆️
pytensor/tensor/shape.py | 90.10% <100.00%> | -0.05% ⬇️
pytensor/link/pytorch/dispatch/elemwise.py | 66.16% <0.00%> | +1.45% ⬆️

... and 1 file with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the reshape_regression branch 2 times, most recently from ebc4263 to b24ee8b Compare February 20, 2025 13:01
@ricardoV94 ricardoV94 changed the title Revert regression in Reshape c-impl speed Revert regression in DimShuffle and Reshape C-impl speed Feb 20, 2025
@ricardoV94 ricardoV94 marked this pull request as ready for review February 20, 2025 14:23
@ricardoV94 ricardoV94 changed the title Revert regression in DimShuffle and Reshape C-impl speed Speedup DimShuffle and Reshape in C and Numba backends Feb 21, 2025
@ricardoV94 ricardoV94 marked this pull request as draft February 21, 2025 12:43
@numba_basic.numba_njit(inline="always")
def dimshuffle(x):
    return dimshuffle_inner(np.asarray(x), shuffle)

return as_strided(x, shape=new_shape, strides=new_strides)
Member commented:
Is it correct that, since we only allow changing the order of the axes and don't introduce new axes or change the total number of elements, this never produces arrays where the same memory location is used multiple times?
I think it would be a bad idea to introduce cases where that could happen. The as_strided docs also contain a warning note about this: https://numpy.org/devdocs/reference/generated/numpy.lib.stride_tricks.as_strided.html#numpy.lib.stride_tricks.as_strided

If that is in fact not an issue here, I think we should add a small note in the code pointing out why it's not a problem.

ricardoV94 (Member, Author) replied:
This does produce multiple arrays that use the same memory, and it does introduce/remove axes. That's why the Op is marked as a view Op: to signal that its output can't be mutated in place if the original input is needed elsewhere or is protected (i.e., provided by the user).

ricardoV94 (Member, Author) replied Feb 22, 2025:
We don't create self-overlapping memory, and we use the original strides, which is what the docs recommend. The user could still provide such an input, but that would break any inplace rewrite regardless of this Op.
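
For readers unfamiliar with the mechanism, here is a hypothetical toy Op showing how view_map declares an output as a view of an input (a sketch, not PyTensor's actual DimShuffle):

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op

class IdentityView(Op):
    # Declaring that output 0 aliases input 0 tells the rewriter that the
    # input must not be destroyed in place while the output may still be
    # needed, or when the input is protected (e.g. user-provided).
    view_map = {0: [0]}

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        (x,) = inputs
        output_storage[0][0] = x  # no copy: the output is a genuine alias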

ricardoV94 (Member, Author) commented Feb 22, 2025

@jessegrabowski @aseyboldt the new Numba DimShuffle is revealing some inplace bug in the failing linalg grad tests (at least test_numba_Cholesky_grad, and test_solve_triangular_grad with "defaults").

Before, the DimShuffles in the grad graphs were probably forcing a copy via the as_contiguous call inside reshape.

It can be avoided by doing a .copy() on the output of DimShuffle (which is not a fix!) or by disabling inplacing.

There are 3 options for the source of the bug:

  1. An Op is destroying the input when it shouldn't, similar to e25e8a2
  2. An Op is returning a view of the input when it shouldn't (it didn't advertise it in its view_map)
  3. Some bug in the inplacing rewrites is introducing a rewrite that is technically not valid

Unfortunately we don't have something like DebugMode for Numba that would identify which Op is the culprit. Can you think of a way to debug this?

Here is the compiled grad graph of the Cholesky test (I removed the Elemwise inplace optimization, so only SolveTriangular is supposedly allowed to do inplacing):


Composite{(((i0 + i1) * i2) - i3)} [id A] <Matrix(float64, shape=(5, 5))> 17
 ├─ SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=2, overwrite_b=True} [id B] <Matrix(float64, shape=(5, 5))> d={0: [1]} 13
 │  ├─ Transpose{axes=[1, 0]} [id C] <Matrix(float64, shape=(5, 5))> v={0: [0]} 3
 │  │  └─ Cholesky{lower=True, check_finite=True, on_error='raise', overwrite_a=False} [id D] <Matrix(float64, shape=(5, 5))> 1
 │  │     └─ input 0 [id E] <Matrix(float64, shape=(5, 5))>
 │  └─ Transpose{axes=[1, 0]} [id F] <Matrix(float64, shape=(5, 5))> v={0: [0]} 12
 │     └─ SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=2, overwrite_b=True} [id G] <Matrix(float64, shape=(5, 5))> d={0: [1]} 11
 │        ├─ Transpose{axes=[1, 0]} [id C] <Matrix(float64, shape=(5, 5))> v={0: [0]} 3
 │        │  └─ ···
 │        └─ Composite{((i0 * i1) - i2)} [id H] <Matrix(float64, shape=(5, 5))> 10
 │           ├─ dot [id I] <Matrix(float64, shape=(5, 5))> 4
 │           │  ├─ Transpose{axes=[1, 0]} [id J] <Matrix(float64, shape=(5, 5))> v={0: [0]} 2
 │           │  │  └─ SpecifyShape [id K] <Matrix(float64, shape=(5, 5))> v={0: [0]} 0
 │           │  │     ├─ random_projection [id L] <Matrix(float64, shape=(?, ?))>
 │           │  │     ├─ 5 [id M] <Scalar(int8, shape=())>
 │           │  │     └─ 5 [id M] <Scalar(int8, shape=())>
 │           │  └─ Cholesky{lower=True, check_finite=True, on_error='raise', overwrite_a=False} [id D] <Matrix(float64, shape=(5, 5))> 1
 │           │     └─ ···
 │           ├─ [[1. 1. 1. ... 0. 0. 1.]] [id N] <Matrix(float64, shape=(5, 5))>
 │           └─ Transpose{axes=[1, 0]} [id O] <Matrix(float64, shape=(?, ?))> v={0: [0]} 9
 │              └─ AllocDiag{self.axis1=0, self.axis2=1, self.offset=0} [id P] <Matrix(float64, shape=(?, ?))> 8
 │                 └─ Mul [id Q] <Vector(float64, shape=(?,))> 7
 │                    ├─ [0.5] [id R] <Vector(float64, shape=(1,))>
 │                    └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id S] <Vector(float64, shape=(?,))> 6
 │                       └─ dot [id T] <Matrix(float64, shape=(5, 5))> 5
 │                          ├─ Transpose{axes=[1, 0]} [id C] <Matrix(float64, shape=(5, 5))> v={0: [0]} 3
 │                          │  └─ ···
 │                          └─ SpecifyShape [id K] <Matrix(float64, shape=(5, 5))> v={0: [0]} 0
 │                             └─ ···
 ├─ Transpose{axes=[1, 0]} [id U] <Matrix(float64, shape=(5, 5))> v={0: [0]} 15
 │  └─ SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=2, overwrite_b=True} [id B] <Matrix(float64, shape=(5, 5))> d={0: [1]} 13
 │     └─ ···
 ├─ [[1. 0. 0. ... 1. 1. 1.]] [id V] <Matrix(float64, shape=(5, 5))>
 └─ AllocDiag{self.axis1=0, self.axis2=1, self.offset=0} [id W] <Matrix(float64, shape=(?, ?))> 16
    └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id X] <Vector(float64, shape=(?,))> 14
       └─ SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=2, overwrite_b=True} [id B] <Matrix(float64, shape=(5, 5))> d={0: [1]} 13
          └─ ···

jessegrabowski (Member) commented:
Running the test suite locally, I see that the test_solve_triangular_grad test is also failing, but only when lower=False, which is the mode used in the Cholesky gradient, so the bug might be there?

@ricardoV94 ricardoV94 force-pushed the reshape_regression branch 3 times, most recently from 0693144 to e96b6a9 Compare February 23, 2025 07:26
dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
ricardoV94 (Member, Author) commented:
Actually I don't know why we're using the ExternalCOp approach for this. It used to be a normal COp before, and then you didn't need this sort of alloc/params stuff...

ricardoV94 (Member, Author) commented:
Found the bug: #1233

@ricardoV94 ricardoV94 force-pushed the reshape_regression branch 3 times, most recently from 5714253 to 315a937 Compare February 25, 2025 10:16
This was caused by 223ee15, which used the generic `PyArray_IntpConverter` to convert the shape numpy vector into a simple C array for the Reshape operation. There seems to be no need for this change, as the strides were correctly used.

Profiling suggests the previous changes caused a 7.5x slowdown. The benchmark detects only a 2.3x slowdown due to the PyTensor call overhead.
This regression was introduced in e593b0a due to a bug when inputs had zero strides.

The bug can be fixed just by removing a block that assumed some `full`/`broadcasting` behavior by the operation, which DimShuffle does not actually do.
@ricardoV94 ricardoV94 marked this pull request as ready for review February 25, 2025 10:22
@ricardoV94 ricardoV94 requested a review from aseyboldt February 25, 2025 10:23
aseyboldt (Member) left a comment:
Looks good :-)

@ricardoV94 ricardoV94 merged commit 5008fab into pymc-devs:main Feb 25, 2025
72 of 73 checks passed