Skip to content

Commit e25e8a2

Browse files
committed
Fix unauthorized inplace update of vector B in numba solve_triangular
1 parent cc8c499 commit e25e8a2

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

pytensor/link/numba/dispatch/slinalg.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,17 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
126126

127127
B_is_1d = B.ndim == 1
128128

129-
if not overwrite_b:
130-
B_copy = _copy_to_fortran_order(B)
131-
else:
129+
if overwrite_b:
132130
B_copy = B
131+
else:
132+
if B_is_1d:
133+
# _copy_to_fortran_order does nothing with vectors
134+
B_copy = np.copy(B)
135+
else:
136+
B_copy = _copy_to_fortran_order(B)
133137

134138
if B_is_1d:
135-
B_copy = np.expand_dims(B, -1)
139+
B_copy = np.expand_dims(B_copy, -1)
136140

137141
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
138142

tests/link/numba/test_slinalg.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ def A_func(x):
7979
A_val = A_val + np.random.normal(size=(5, 5)) * 1j
8080
b_val = b_val + np.random.normal(size=b_shape) * 1j
8181

82-
X_np = f(A_func(A_val.copy()), b_val.copy())
82+
X_np = f(A_func(A_val), b_val)
8383

84-
test_input = transpose_func(A_func(A_val.copy()), trans)
84+
test_input = transpose_func(A_func(A_val), trans)
8585

8686
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
8787
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
@@ -92,7 +92,7 @@ def A_func(x):
9292
compare_numba_and_py(
9393
compiled_fgraph.inputs,
9494
compiled_fgraph.outputs,
95-
[A_func(A_val.copy()), b_val.copy()],
95+
[A_func(A_val), b_val],
9696
)
9797

9898

0 commit comments

Comments
 (0)