We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4b1761b commit 5308dddCopy full SHA for 5308ddd
pytensor/link/numba/dispatch/tensor_basic.py
@@ -136,10 +136,7 @@ def join(axis, *tensors):
136
def numba_funcify_Split(op, **kwargs):
137
@numba_basic.numba_njit
138
def split(tensor, axis, indices):
139
- # Work around for https://github.com/numba/numba/issues/8257
140
- axis = axis % tensor.ndim
141
- axis = numba_basic.to_scalar(axis)
142
- return np.split(tensor, np.cumsum(indices)[:-1], axis=axis)
+ return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item())
143
144
return split
145
0 commit comments