Skip to content

Commit 3cc5766

Browse files
Refactor linspace, logspace, and geomspace to match numpy implementation
1 parent 5d0acd4 commit 3cc5766

File tree

2 files changed

+291
-40
lines changed

2 files changed

+291
-40
lines changed

pytensor/tensor/extra_ops.py

+270-18
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Collection, Iterable
23

34
import numpy as np
@@ -20,14 +21,14 @@
2021
from pytensor.raise_op import Assert
2122
from pytensor.scalar import int32 as int_t
2223
from pytensor.scalar import upcast
23-
from pytensor.tensor import as_tensor_variable
24+
from pytensor.tensor import TensorLike, as_tensor_variable
2425
from pytensor.tensor import basic as ptb
2526
from pytensor.tensor.basic import alloc, second
2627
from pytensor.tensor.exceptions import NotScalarConstantError
2728
from pytensor.tensor.math import abs as pt_abs
2829
from pytensor.tensor.math import all as pt_all
2930
from pytensor.tensor.math import eq as pt_eq
30-
from pytensor.tensor.math import ge, lt, maximum, minimum, prod, switch
31+
from pytensor.tensor.math import ge, log, lt, maximum, minimum, prod, sign, switch
3132
from pytensor.tensor.math import max as pt_max
3233
from pytensor.tensor.math import sum as pt_sum
3334
from pytensor.tensor.shape import specify_broadcastable
@@ -1583,27 +1584,278 @@ def broadcast_shape_iter(
15831584
return tuple(result_dims)
15841585

15851586

1586-
def geomspace(start, end, steps, base=10.0):
1587-
from pytensor.tensor.math import log
1587+
def _check_deprecated_inputs(stop, end, num, steps):
1588+
if end is not None:
1589+
warnings.warn(
1590+
"The 'end' parameter is deprecated and will be removed in a future version. Use 'stop' instead.",
1591+
DeprecationWarning,
1592+
)
1593+
stop = end
1594+
if steps is not None:
1595+
warnings.warn(
1596+
"The 'steps' parameter is deprecated and will be removed in a future version. Use 'num' instead.",
1597+
DeprecationWarning,
1598+
)
1599+
num = steps
1600+
1601+
return stop, num
1602+
1603+
1604+
def _linspace_core(
1605+
start, stop, num, endpoint=True, retstep=False, dtype=None, axis=0
1606+
) -> TensorVariable | tuple[TensorVariable, TensorVariable]:
1607+
div = (num - 1) if endpoint else num
1608+
delta = (stop - start).astype(dtype)
1609+
step = delta / div
1610+
1611+
samples = ptb.arange(0, num, dtype=dtype).reshape((-1,) + (1,) * delta.ndim)
1612+
samples = samples * step + start
1613+
1614+
if endpoint and num > 1:
1615+
set_subtensor(samples[-1, ...], stop, inplace=True)
1616+
1617+
if axis != 0:
1618+
samples = ptb.moveaxis(samples, 0, axis)
1619+
1620+
if retstep:
1621+
return samples, step
1622+
1623+
return samples
1624+
1625+
1626+
def _broadcast_inputs_and_dtypes(*args, dtype=None):
1627+
args = map(ptb.as_tensor_variable, args)
1628+
args = broadcast_arrays(*args)
1629+
1630+
if dtype is None:
1631+
dtype = pytensor.config.floatX
1632+
1633+
return args, dtype
1634+
1635+
1636+
def _broadcast_base_with_inputs(start, stop, base, dtype, axis):
1637+
"""
1638+
Broadcast the base tensor with the start and stop tensors if base is not a scalar. This is important because it
1639+
may change how the axis argument is interpreted in the final output.
1640+
1641+
Parameters
1642+
----------
1643+
start
1644+
stop
1645+
base
1646+
dtype
1647+
axis
1648+
1649+
Returns
1650+
-------
1651+
1652+
"""
1653+
base = ptb.as_tensor_variable(base, dtype=dtype)
1654+
if base.ndim > 0:
1655+
ndmax = len(broadcast_shape(start, stop, base))
1656+
start, stop, base = (
1657+
ptb.shape_padleft(a, ndmax - a.ndim) for a in (start, stop, base)
1658+
)
1659+
base = ptb.expand_dims(base, axis=(axis,))
1660+
1661+
return start, stop, base
1662+
1663+
1664+
def linspace(
1665+
start: TensorLike,
1666+
stop: TensorLike,
1667+
num: TensorLike = 50,
1668+
endpoint: bool = True,
1669+
retstep: bool = False,
1670+
dtype: str | None = None,
1671+
axis: int = 0,
1672+
end: TensorLike | None = None,
1673+
steps: TensorLike | None = None,
1674+
) -> TensorVariable | tuple[TensorVariable, TensorVariable]:
1675+
"""
1676+
Return evenly spaced numbers over a specified interval.
1677+
1678+
Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`].
1679+
1680+
The endpoint of the interval can optionally be excluded.
1681+
1682+
Parameters
1683+
----------
1684+
start: int, float, or TensorVariable
1685+
The starting value of the sequence.
1686+
1687+
stop: int, float or TensorVariable
1688+
The end value of the sequence, unless `endpoint` is set to False.
1689+
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `stop` is excluded.
1690+
1691+
num: int
1692+
Number of samples to generate. Must be non-negative.
1693+
1694+
endpoint: bool
1695+
Whether to include the endpoint in the range.
1696+
1697+
retstep: bool
1698+
If true, returns both the samples and an array of steps between samples.
1699+
1700+
dtype: str, optional
1701+
dtype of the output tensor(s). If None, the dtype is inferred from that of the values provided to the `start`
1702+
and `end` arguments.
15881703
1589-
start = ptb.as_tensor_variable(start)
1590-
end = ptb.as_tensor_variable(end)
1591-
return base ** linspace(log(start) / log(base), log(end) / log(base), steps)
1704+
axis: int
1705+
Axis along which to generate samples. Ignored if both `start` and `end` have dimension 0. By default, axis=0
1706+
will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1.
15921707
1708+
end: int, float or TensorVariable
1709+
.. warning::
1710+
The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead.
1711+
The end value of the sequence, unless `endpoint` is set to False.
1712+
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is
1713+
excluded.
15931714
1594-
def logspace(start, end, steps, base=10.0):
1595-
start = ptb.as_tensor_variable(start)
1596-
end = ptb.as_tensor_variable(end)
1597-
return base ** linspace(start, end, steps)
1715+
steps: float, int, or TensorVariable
1716+
.. warning::
1717+
The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead.
15981718
1719+
Number of samples to generate. Must be non-negative
1720+
1721+
Returns
1722+
-------
1723+
samples: TensorVariable
1724+
Tensor containing `num` evenly-spaced values between [start, stop]. The range is inclusive if `endpoint` is True.
1725+
1726+
step: TensorVariable
1727+
Tensor containing the spacing between samples. Only returned if `retstep` is True.
1728+
"""
1729+
end, num = _check_deprecated_inputs(stop, end, num, steps)
1730+
(start, stop), type = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
1731+
1732+
return _linspace_core(start, stop, num, endpoint, retstep, dtype, axis)
1733+
1734+
1735+
def geomspace(
1736+
start: TensorLike,
1737+
stop: TensorLike,
1738+
num: int = 50,
1739+
base: float = 10.0,
1740+
endpoint: bool = True,
1741+
dtype: str | None = None,
1742+
axis: int = 0,
1743+
end: TensorLike | None = None,
1744+
steps: TensorLike | None = None,
1745+
) -> TensorVariable:
1746+
"""
1747+
Return numbers spaced evenly on a log scale (a geometric progression).
1748+
1749+
Parameters
1750+
----------
1751+
Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`].
1752+
1753+
The endpoint of the interval can optionally be excluded.
1754+
1755+
Parameters
1756+
----------
1757+
start: int, float, or TensorVariable
1758+
The starting value of the sequence.
1759+
1760+
stop: int, float or TensorVariable
1761+
The end value of the sequence, unless `endpoint` is set to False.
1762+
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `stop` is excluded.
1763+
1764+
num: int
1765+
Number of samples to generate. Must be non-negative.
1766+
1767+
base: float
1768+
The base of the log space. The step size between the elements in ln(samples) / ln(base)
1769+
(or log_base(samples)) is uniform.
1770+
1771+
endpoint: bool
1772+
Whether to include the endpoint in the range.
1773+
1774+
dtype: str, optional
1775+
dtype of the output tensor(s). If None, the dtype is inferred from that of the values provided to the `start`
1776+
and `end` arguments.
1777+
1778+
axis: int
1779+
Axis along which to generate samples. Ignored if both `start` and `end` have dimension 0. By default, axis=0
1780+
will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1.
1781+
1782+
end: int, float or TensorVariable
1783+
.. warning::
1784+
The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead.
1785+
The end value of the sequence, unless `endpoint` is set to False.
1786+
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is
1787+
excluded.
1788+
1789+
steps: float, int, or TensorVariable
1790+
.. warning::
1791+
The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead.
1792+
1793+
Number of samples to generate. Must be non-negative
1794+
1795+
Returns
1796+
-------
1797+
samples: TensorVariable
1798+
Tensor containing `num` evenly-spaced values between [start, stop]. The range is inclusive if `endpoint` is True.
1799+
"""
1800+
stop, num = _check_deprecated_inputs(stop, end, num, steps)
1801+
(start, stop), dtype = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
1802+
start, stop, base = _broadcast_base_with_inputs(start, stop, base, dtype, axis)
1803+
1804+
out_sign = sign(start)
1805+
log_start, log_stop = (
1806+
log(start * out_sign) / log(base),
1807+
log(stop * out_sign) / log(base),
1808+
)
1809+
result = _linspace_core(
1810+
start=log_start,
1811+
stop=log_stop,
1812+
num=num,
1813+
endpoint=endpoint,
1814+
dtype=dtype,
1815+
axis=0,
1816+
retstep=False,
1817+
)
1818+
result = base**result
1819+
1820+
if num > 0:
1821+
set_subtensor(result[0, ...], start, inplace=True)
1822+
if num > 1 and endpoint:
1823+
set_subtensor(result[-1, ...], stop, inplace=True)
1824+
1825+
result = result * out_sign
1826+
1827+
if axis != 0:
1828+
result = ptb.moveaxis(result, 0, axis)
1829+
1830+
return result
1831+
1832+
1833+
def logspace(
1834+
start: TensorLike,
1835+
stop: TensorLike,
1836+
num: int = 50,
1837+
base: float = 10.0,
1838+
endpoint: bool = True,
1839+
dtype: str | None = None,
1840+
axis: int = 0,
1841+
end: TensorLike | None = None,
1842+
steps: TensorLike | None = None,
1843+
) -> TensorVariable:
1844+
stop, num = _check_deprecated_inputs(stop, end, num, steps)
1845+
(start, stop), type = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
1846+
start, stop, base = _broadcast_base_with_inputs(start, stop, base, dtype, axis)
1847+
1848+
ls = _linspace_core(
1849+
start=start,
1850+
stop=stop,
1851+
num=num,
1852+
endpoint=endpoint,
1853+
dtype=dtype,
1854+
axis=axis,
1855+
retstep=False,
1856+
)
15991857

1600-
def linspace(start, end, steps):
1601-
start = ptb.as_tensor_variable(start)
1602-
end = ptb.as_tensor_variable(end)
1603-
arr = ptb.arange(steps)
1604-
arr = ptb.shape_padright(arr, max(start.ndim, end.ndim))
1605-
multiplier = (end - start) / (steps - 1)
1606-
return start + arr * multiplier
1858+
return base**ls
16071859

16081860

16091861
def broadcast_to(

tests/tensor/test_extra_ops.py

+21-22
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@
3535
diff,
3636
fill_diagonal,
3737
fill_diagonal_offset,
38-
geomspace,
39-
linspace,
40-
logspace,
4138
ravel_multi_index,
4239
repeat,
4340
searchsorted,
@@ -1281,25 +1278,27 @@ def test_broadcast_arrays():
12811278

12821279

12831280
@pytest.mark.parametrize(
1284-
"start, stop, num_samples",
1281+
"op",
1282+
["linspace", "logspace", "geomspace"],
1283+
ids=["linspace", "logspace", "geomspace"],
1284+
)
1285+
@pytest.mark.parametrize(
1286+
"start, stop, num_samples, endpoint, axis",
12851287
[
1286-
(1, 10, 50),
1287-
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25),
1288-
(1, np.array([5, 6]), 30),
1288+
(1, 10, 50, True, 0),
1289+
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, True, 0),
1290+
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, True, 1),
1291+
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, False, -1),
1292+
(1, np.array([5, 6]), 30, True, 0),
1293+
(1, np.array([5, 6]), 30, False, -1),
12891294
],
12901295
)
1291-
def test_space_ops(start, stop, num_samples):
1292-
z = linspace(start, stop, num_samples)
1293-
pytensor_res = function(inputs=[], outputs=z)()
1294-
numpy_res = np.linspace(start, stop, num=num_samples)
1295-
assert np.allclose(pytensor_res, numpy_res)
1296-
1297-
z = logspace(start, stop, num_samples)
1298-
pytensor_res = function(inputs=[], outputs=z)()
1299-
numpy_res = np.logspace(start, stop, num=num_samples)
1300-
assert np.allclose(pytensor_res, numpy_res)
1301-
1302-
z = geomspace(start, stop, num_samples)
1303-
pytensor_res = function(inputs=[], outputs=z)()
1304-
numpy_res = np.geomspace(start, stop, num=num_samples)
1305-
assert np.allclose(pytensor_res, numpy_res)
1296+
def test_space_ops(op, start, stop, num_samples, endpoint, axis):
1297+
pt_func = getattr(pt, op)
1298+
np_func = getattr(np, op)
1299+
z = pt_func(start, stop, num_samples, endpoint=endpoint, axis=axis)
1300+
1301+
numpy_res = np_func(start, stop, num=num_samples, endpoint=endpoint, axis=axis)
1302+
pytensor_res = function(inputs=[], outputs=z, mode="FAST_COMPILE")()
1303+
1304+
np.testing.assert_allclose(pytensor_res, numpy_res, atol=1e-6, rtol=1e-6)

0 commit comments

Comments
 (0)