Skip to content

Commit 5b1b2d6

Browse files
Handle corner case when steps = 0
1 parent 3cc5766 commit 5b1b2d6

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

pytensor/tensor/extra_ops.py

+35-8
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,19 @@
2727
from pytensor.tensor.exceptions import NotScalarConstantError
2828
from pytensor.tensor.math import abs as pt_abs
2929
from pytensor.tensor.math import all as pt_all
30+
from pytensor.tensor.math import (
31+
bitwise_and,
32+
ge,
33+
gt,
34+
log,
35+
lt,
36+
maximum,
37+
minimum,
38+
prod,
39+
sign,
40+
switch,
41+
)
3042
from pytensor.tensor.math import eq as pt_eq
31-
from pytensor.tensor.math import ge, log, lt, maximum, minimum, prod, sign, switch
3243
from pytensor.tensor.math import max as pt_max
3344
from pytensor.tensor.math import sum as pt_sum
3445
from pytensor.tensor.shape import specify_broadcastable
@@ -1602,17 +1613,25 @@ def _check_deprecated_inputs(stop, end, num, steps):
16021613

16031614

16041615
def _linspace_core(
1605-
start, stop, num, endpoint=True, retstep=False, dtype=None, axis=0
1616+
start: TensorVariable,
1617+
stop: TensorVariable,
1618+
num: int,
1619+
dtype: str,
1620+
endpoint=True,
1621+
retstep=False,
1622+
axis=0,
16061623
) -> TensorVariable | tuple[TensorVariable, TensorVariable]:
16071624
div = (num - 1) if endpoint else num
16081625
delta = (stop - start).astype(dtype)
1609-
step = delta / div
1610-
16111626
samples = ptb.arange(0, num, dtype=dtype).reshape((-1,) + (1,) * delta.ndim)
1612-
samples = samples * step + start
16131627

1614-
if endpoint and num > 1:
1615-
set_subtensor(samples[-1, ...], stop, inplace=True)
1628+
step = switch(gt(div, 0), delta / div, np.nan)
1629+
samples = switch(gt(div, 0), samples * delta / div + start, samples * delta + start)
1630+
samples = switch(
1631+
bitwise_and(gt(num, 1), np.asarray(endpoint)),
1632+
set_subtensor(samples[-1, ...], stop),
1633+
samples,
1634+
)
16161635

16171636
if axis != 0:
16181637
samples = ptb.moveaxis(samples, 0, axis)
@@ -1729,7 +1748,15 @@ def linspace(
17291748
end, num = _check_deprecated_inputs(stop, end, num, steps)
17301749
(start, stop), type = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
17311750

1732-
return _linspace_core(start, stop, num, endpoint, retstep, dtype, axis)
1751+
return _linspace_core(
1752+
start=start,
1753+
stop=stop,
1754+
num=num,
1755+
dtype=dtype,
1756+
endpoint=endpoint,
1757+
retstep=retstep,
1758+
axis=axis,
1759+
)
17331760

17341761

17351762
def geomspace(

tests/tensor/test_extra_ops.py

+1
Original file line numberDiff line numberDiff line change
@@ -1286,6 +1286,7 @@ def test_broadcast_arrays():
12861286
"start, stop, num_samples, endpoint, axis",
12871287
[
12881288
(1, 10, 50, True, 0),
1289+
(1, 10, 1, True, 0),
12891290
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, True, 0),
12901291
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, True, 1),
12911292
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, False, -1),

0 commit comments

Comments
 (0)