Skip to content

Commit 83ac04f

Browse files
committed
BUG: fix dtype of include_initial in cumulative_sum
In `concat([zeros(...), x])` zeros must have the same dtype as `x`.
1 parent 3ff4ca6 commit 83ac04f

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

array_api_strict/_statistical_functions.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def cumulative_sum(
3131
) -> Array:
3232
if x.dtype not in _numeric_dtypes:
3333
raise TypeError("Only numeric dtypes are allowed in cumulative_sum")
34-
dt = x.dtype if dtype is None else dtype
3534
if dtype is not None:
3635
dtype = dtype._np_dtype
3736

@@ -44,7 +43,7 @@ def cumulative_sum(
4443
if include_initial:
4544
if axis < 0:
4645
axis += x.ndim
47-
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis)
46+
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis)
4847
return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device)
4948

5049

0 commit comments

Comments
 (0)