Skip to content

Commit d630ee5

Browse files
committed
Fix the pinv function, which was implicitly using __array__
1 parent bb28167 commit d630ee5

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

array_api_strict/_linalg.py

+2
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
267267
# default tolerance by max(M, N).
268268
if rtol is None:
269269
rtol = max(x.shape[-2:]) * finfo(x.dtype).eps
270+
if isinstance(rtol, Array):
271+
rtol = rtol._array
270272
return Array._new(np.linalg.pinv(x._array, rcond=rtol), device=x.device)
271273

272274
@requires_extension('linalg')

0 commit comments

Comments
 (0)