Skip to content

Commit 5308ddd

Browse files
committed
Remove patch on Numba impl of Split
1 parent 4b1761b commit 5308ddd

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,7 @@ def join(axis, *tensors):
136136
def numba_funcify_Split(op, **kwargs):
137137
@numba_basic.numba_njit
138138
def split(tensor, axis, indices):
139-
# Work around for https://github.com/numba/numba/issues/8257
140-
axis = axis % tensor.ndim
141-
axis = numba_basic.to_scalar(axis)
142-
return np.split(tensor, np.cumsum(indices)[:-1], axis=axis)
139+
return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item())
143140

144141
return split
145142

0 commit comments

Comments
 (0)