Skip to content

Commit ee12cd1

Browse files
authored
BUG: fix sinc for torch (#56)
1 parent a89ffdd commit ee12cd1

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/array_api_extra/_funcs.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,8 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
543543
raise ValueError(err_msg)
544544
# no scalars in `where` - array-api#807
545545
y = xp.pi * xp.where(
546-
x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device)
546+
xp.astype(x, xp.bool),
547+
x,
548+
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
547549
)
548550
return xp.sin(y) / y

0 commit comments

Comments
 (0)