|
27 | 27 | from pytensor.tensor.exceptions import NotScalarConstantError
|
28 | 28 | from pytensor.tensor.math import abs as pt_abs
|
29 | 29 | from pytensor.tensor.math import all as pt_all
|
| 30 | +from pytensor.tensor.math import ( |
| 31 | + bitwise_and, |
| 32 | + ge, |
| 33 | + gt, |
| 34 | + log, |
| 35 | + lt, |
| 36 | + maximum, |
| 37 | + minimum, |
| 38 | + prod, |
| 39 | + sign, |
| 40 | + switch, |
| 41 | +) |
30 | 42 | from pytensor.tensor.math import eq as pt_eq
|
31 |
| -from pytensor.tensor.math import ge, log, lt, maximum, minimum, prod, sign, switch |
32 | 43 | from pytensor.tensor.math import max as pt_max
|
33 | 44 | from pytensor.tensor.math import sum as pt_sum
|
34 | 45 | from pytensor.tensor.shape import specify_broadcastable
|
@@ -1602,17 +1613,25 @@ def _check_deprecated_inputs(stop, end, num, steps):
|
1602 | 1613 |
|
1603 | 1614 |
|
1604 | 1615 | def _linspace_core(
|
1605 |
| - start, stop, num, endpoint=True, retstep=False, dtype=None, axis=0 |
| 1616 | + start: TensorVariable, |
| 1617 | + stop: TensorVariable, |
| 1618 | + num: int, |
| 1619 | + dtype: str, |
| 1620 | + endpoint=True, |
| 1621 | + retstep=False, |
| 1622 | + axis=0, |
1606 | 1623 | ) -> TensorVariable | tuple[TensorVariable, TensorVariable]:
|
1607 | 1624 | div = (num - 1) if endpoint else num
|
1608 | 1625 | delta = (stop - start).astype(dtype)
|
1609 |
| - step = delta / div |
1610 |
| - |
1611 | 1626 | samples = ptb.arange(0, num, dtype=dtype).reshape((-1,) + (1,) * delta.ndim)
|
1612 |
| - samples = samples * step + start |
1613 | 1627 |
|
1614 |
| - if endpoint and num > 1: |
1615 |
| - set_subtensor(samples[-1, ...], stop, inplace=True) |
| 1628 | + step = switch(gt(div, 0), delta / div, np.nan) |
| 1629 | + samples = switch(gt(div, 0), samples * delta / div + start, samples * delta + start) |
| 1630 | + samples = switch( |
| 1631 | + bitwise_and(gt(num, 1), np.asarray(endpoint)), |
| 1632 | + set_subtensor(samples[-1, ...], stop), |
| 1633 | + samples, |
| 1634 | + ) |
1616 | 1635 |
|
1617 | 1636 | if axis != 0:
|
1618 | 1637 | samples = ptb.moveaxis(samples, 0, axis)
|
@@ -1729,7 +1748,15 @@ def linspace(
|
1729 | 1748 | end, num = _check_deprecated_inputs(stop, end, num, steps)
|
1730 | 1749 | (start, stop), type = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
|
1731 | 1750 |
|
1732 |
| - return _linspace_core(start, stop, num, endpoint, retstep, dtype, axis) |
| 1751 | + return _linspace_core( |
| 1752 | + start=start, |
| 1753 | + stop=stop, |
| 1754 | + num=num, |
| 1755 | + dtype=dtype, |
| 1756 | + endpoint=endpoint, |
| 1757 | + retstep=retstep, |
| 1758 | + axis=axis, |
| 1759 | + ) |
1733 | 1760 |
|
1734 | 1761 |
|
1735 | 1762 | def geomspace(
|
|
0 commit comments