|
| 1 | +import warnings |
1 | 2 | from collections.abc import Collection, Iterable
|
2 | 3 |
|
3 | 4 | import numpy as np
|
|
20 | 21 | from pytensor.raise_op import Assert
|
21 | 22 | from pytensor.scalar import int32 as int_t
|
22 | 23 | from pytensor.scalar import upcast
|
23 |
| -from pytensor.tensor import as_tensor_variable |
| 24 | +from pytensor.tensor import TensorLike, as_tensor_variable |
24 | 25 | from pytensor.tensor import basic as ptb
|
25 | 26 | from pytensor.tensor.basic import alloc, second
|
26 | 27 | from pytensor.tensor.exceptions import NotScalarConstantError
|
27 | 28 | from pytensor.tensor.math import abs as pt_abs
|
28 | 29 | from pytensor.tensor.math import all as pt_all
|
29 | 30 | from pytensor.tensor.math import eq as pt_eq
|
30 |
| -from pytensor.tensor.math import ge, lt, maximum, minimum, prod, switch |
| 31 | +from pytensor.tensor.math import ge, log, lt, maximum, minimum, prod, sign, switch |
31 | 32 | from pytensor.tensor.math import max as pt_max
|
32 | 33 | from pytensor.tensor.math import sum as pt_sum
|
33 | 34 | from pytensor.tensor.shape import specify_broadcastable
|
@@ -1583,27 +1584,278 @@ def broadcast_shape_iter(
|
1583 | 1584 | return tuple(result_dims)
|
1584 | 1585 |
|
1585 | 1586 |
|
1586 |
| -def geomspace(start, end, steps, base=10.0): |
1587 |
| - from pytensor.tensor.math import log |
| 1587 | +def _check_deprecated_inputs(stop, end, num, steps): |
| 1588 | + if end is not None: |
| 1589 | + warnings.warn( |
| 1590 | + "The 'end' parameter is deprecated and will be removed in a future version. Use 'stop' instead.", |
| 1591 | + DeprecationWarning, |
| 1592 | + ) |
| 1593 | + stop = end |
| 1594 | + if steps is not None: |
| 1595 | + warnings.warn( |
| 1596 | + "The 'steps' parameter is deprecated and will be removed in a future version. Use 'num' instead.", |
| 1597 | + DeprecationWarning, |
| 1598 | + ) |
| 1599 | + num = steps |
| 1600 | + |
| 1601 | + return stop, num |
| 1602 | + |
| 1603 | + |
| 1604 | +def _linspace_core( |
| 1605 | + start, stop, num, endpoint=True, retstep=False, dtype=None, axis=0 |
| 1606 | +) -> TensorVariable | tuple[TensorVariable, TensorVariable]: |
| 1607 | + div = (num - 1) if endpoint else num |
| 1608 | + delta = (stop - start).astype(dtype) |
| 1609 | + step = delta / div |
| 1610 | + |
| 1611 | + samples = ptb.arange(0, num, dtype=dtype).reshape((-1,) + (1,) * delta.ndim) |
| 1612 | + samples = samples * step + start |
| 1613 | + |
| 1614 | + if endpoint and num > 1: |
| 1615 | + set_subtensor(samples[-1, ...], stop, inplace=True) |
| 1616 | + |
| 1617 | + if axis != 0: |
| 1618 | + samples = ptb.moveaxis(samples, 0, axis) |
| 1619 | + |
| 1620 | + if retstep: |
| 1621 | + return samples, step |
| 1622 | + |
| 1623 | + return samples |
| 1624 | + |
| 1625 | + |
| 1626 | +def _broadcast_inputs_and_dtypes(*args, dtype=None): |
| 1627 | + args = map(ptb.as_tensor_variable, args) |
| 1628 | + args = broadcast_arrays(*args) |
| 1629 | + |
| 1630 | + if dtype is None: |
| 1631 | + dtype = pytensor.config.floatX |
| 1632 | + |
| 1633 | + return args, dtype |
| 1634 | + |
| 1635 | + |
| 1636 | +def _broadcast_base_with_inputs(start, stop, base, dtype, axis): |
| 1637 | + """ |
| 1638 | + Broadcast the base tensor with the start and stop tensors if base is not a scalar. This is important because it |
| 1639 | + may change how the axis argument is interpreted in the final output. |
| 1640 | +
|
| 1641 | + Parameters |
| 1642 | + ---------- |
| 1643 | + start |
| 1644 | + stop |
| 1645 | + base |
| 1646 | + dtype |
| 1647 | + axis |
| 1648 | +
|
| 1649 | + Returns |
| 1650 | + ------- |
| 1651 | +
|
| 1652 | + """ |
| 1653 | + base = ptb.as_tensor_variable(base, dtype=dtype) |
| 1654 | + if base.ndim > 0: |
| 1655 | + ndmax = len(broadcast_shape(start, stop, base)) |
| 1656 | + start, stop, base = ( |
| 1657 | + ptb.shape_padleft(a, ndmax - a.ndim) for a in (start, stop, base) |
| 1658 | + ) |
| 1659 | + base = ptb.expand_dims(base, axis=(axis,)) |
| 1660 | + |
| 1661 | + return start, stop, base |
| 1662 | + |
| 1663 | + |
| 1664 | +def linspace( |
| 1665 | + start: TensorLike, |
| 1666 | + stop: TensorLike, |
| 1667 | + num: TensorLike = 50, |
| 1668 | + endpoint: bool = True, |
| 1669 | + retstep: bool = False, |
| 1670 | + dtype: str | None = None, |
| 1671 | + axis: int = 0, |
| 1672 | + end: TensorLike | None = None, |
| 1673 | + steps: TensorLike | None = None, |
| 1674 | +) -> TensorVariable | tuple[TensorVariable, TensorVariable]: |
| 1675 | + """ |
| 1676 | + Return evenly spaced numbers over a specified interval. |
| 1677 | +
|
| 1678 | + Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`]. |
| 1679 | +
|
| 1680 | + The endpoint of the interval can optionally be excluded. |
| 1681 | +
|
| 1682 | + Parameters |
| 1683 | + ---------- |
| 1684 | + start: int, float, or TensorVariable |
| 1685 | + The starting value of the sequence. |
| 1686 | +
|
| 1687 | + stop: int, float or TensorVariable |
| 1688 | + The end value of the sequence, unless `endpoint` is set to False. |
| 1689 | + In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `stop` is excluded. |
| 1690 | +
|
| 1691 | + num: int |
| 1692 | + Number of samples to generate. Must be non-negative. |
| 1693 | +
|
| 1694 | + endpoint: bool |
| 1695 | + Whether to include the endpoint in the range. |
| 1696 | +
|
| 1697 | + retstep: bool |
| 1698 | + If true, returns both the samples and an array of steps between samples. |
| 1699 | +
|
| 1700 | + dtype: str, optional |
| 1701 | + dtype of the output tensor(s). If None, the dtype is inferred from that of the values provided to the `start` |
| 1702 | + and `end` arguments. |
1588 | 1703 |
|
1589 |
| - start = ptb.as_tensor_variable(start) |
1590 |
| - end = ptb.as_tensor_variable(end) |
1591 |
| - return base ** linspace(log(start) / log(base), log(end) / log(base), steps) |
| 1704 | + axis: int |
| 1705 | + Axis along which to generate samples. Ignored if both `start` and `end` have dimension 0. By default, axis=0 |
| 1706 | + will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1. |
1592 | 1707 |
|
| 1708 | + end: int, float or TensorVariable |
| 1709 | + .. warning:: |
| 1710 | + The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead. |
| 1711 | + The end value of the sequence, unless `endpoint` is set to False. |
| 1712 | + In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is |
| 1713 | + excluded. |
1593 | 1714 |
|
1594 |
| -def logspace(start, end, steps, base=10.0): |
1595 |
| - start = ptb.as_tensor_variable(start) |
1596 |
| - end = ptb.as_tensor_variable(end) |
1597 |
| - return base ** linspace(start, end, steps) |
| 1715 | + steps: float, int, or TensorVariable |
| 1716 | + .. warning:: |
| 1717 | + The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead. |
1598 | 1718 |
|
| 1719 | + Number of samples to generate. Must be non-negative |
| 1720 | +
|
| 1721 | + Returns |
| 1722 | + ------- |
| 1723 | + samples: TensorVariable |
| 1724 | + Tensor containing `num` evenly-spaced values between [start, stop]. The range is inclusive if `endpoint` is True. |
| 1725 | +
|
| 1726 | + step: TensorVariable |
| 1727 | + Tensor containing the spacing between samples. Only returned if `retstep` is True. |
| 1728 | + """ |
| 1729 | + end, num = _check_deprecated_inputs(stop, end, num, steps) |
| 1730 | + (start, stop), type = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype) |
| 1731 | + |
| 1732 | + return _linspace_core(start, stop, num, endpoint, retstep, dtype, axis) |
| 1733 | + |
| 1734 | + |
| 1735 | +def geomspace( |
| 1736 | + start: TensorLike, |
| 1737 | + stop: TensorLike, |
| 1738 | + num: int = 50, |
| 1739 | + base: float = 10.0, |
| 1740 | + endpoint: bool = True, |
| 1741 | + dtype: str | None = None, |
| 1742 | + axis: int = 0, |
| 1743 | + end: TensorLike | None = None, |
| 1744 | + steps: TensorLike | None = None, |
| 1745 | +) -> TensorVariable: |
| 1746 | + """ |
| 1747 | + Return numbers spaced evenly on a log scale (a geometric progression). |
| 1748 | +
|
| 1749 | + Parameters |
| 1750 | + ---------- |
| 1751 | + Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`]. |
| 1752 | +
|
| 1753 | + The endpoint of the interval can optionally be excluded. |
| 1754 | +
|
| 1755 | + Parameters |
| 1756 | + ---------- |
| 1757 | + start: int, float, or TensorVariable |
| 1758 | + The starting value of the sequence. |
| 1759 | +
|
| 1760 | + stop: int, float or TensorVariable |
| 1761 | + The end value of the sequence, unless `endpoint` is set to False. |
| 1762 | + In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `stop` is excluded. |
| 1763 | +
|
| 1764 | + num: int |
| 1765 | + Number of samples to generate. Must be non-negative. |
| 1766 | +
|
| 1767 | + base: float |
| 1768 | + The base of the log space. The step size between the elements in ln(samples) / ln(base) |
| 1769 | + (or log_base(samples)) is uniform. |
| 1770 | +
|
| 1771 | + endpoint: bool |
| 1772 | + Whether to include the endpoint in the range. |
| 1773 | +
|
| 1774 | + dtype: str, optional |
| 1775 | + dtype of the output tensor(s). If None, the dtype is inferred from that of the values provided to the `start` |
| 1776 | + and `end` arguments. |
| 1777 | +
|
| 1778 | + axis: int |
| 1779 | + Axis along which to generate samples. Ignored if both `start` and `end` have dimension 0. By default, axis=0 |
| 1780 | + will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1. |
| 1781 | +
|
| 1782 | + end: int, float or TensorVariable |
| 1783 | + .. warning:: |
| 1784 | + The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead. |
| 1785 | + The end value of the sequence, unless `endpoint` is set to False. |
| 1786 | + In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is |
| 1787 | + excluded. |
| 1788 | +
|
| 1789 | + steps: float, int, or TensorVariable |
| 1790 | + .. warning:: |
| 1791 | + The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead. |
| 1792 | +
|
| 1793 | + Number of samples to generate. Must be non-negative |
| 1794 | +
|
| 1795 | + Returns |
| 1796 | + ------- |
| 1797 | + samples: TensorVariable |
| 1798 | + Tensor containing `num` evenly-spaced values between [start, stop]. The range is inclusive if `endpoint` is True. |
| 1799 | + """ |
| 1800 | + stop, num = _check_deprecated_inputs(stop, end, num, steps) |
| 1801 | + (start, stop), dtype = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype) |
| 1802 | + start, stop, base = _broadcast_base_with_inputs(start, stop, base, dtype, axis) |
| 1803 | + |
| 1804 | + out_sign = sign(start) |
| 1805 | + log_start, log_stop = ( |
| 1806 | + log(start * out_sign) / log(base), |
| 1807 | + log(stop * out_sign) / log(base), |
| 1808 | + ) |
| 1809 | + result = _linspace_core( |
| 1810 | + start=log_start, |
| 1811 | + stop=log_stop, |
| 1812 | + num=num, |
| 1813 | + endpoint=endpoint, |
| 1814 | + dtype=dtype, |
| 1815 | + axis=0, |
| 1816 | + retstep=False, |
| 1817 | + ) |
| 1818 | + result = base**result |
| 1819 | + |
| 1820 | + if num > 0: |
| 1821 | + set_subtensor(result[0, ...], start, inplace=True) |
| 1822 | + if num > 1 and endpoint: |
| 1823 | + set_subtensor(result[-1, ...], stop, inplace=True) |
| 1824 | + |
| 1825 | + result = result * out_sign |
| 1826 | + |
| 1827 | + if axis != 0: |
| 1828 | + result = ptb.moveaxis(result, 0, axis) |
| 1829 | + |
| 1830 | + return result |
| 1831 | + |
| 1832 | + |
| 1833 | +def logspace( |
| 1834 | + start: TensorLike, |
| 1835 | + stop: TensorLike, |
| 1836 | + num: int = 50, |
| 1837 | + base: float = 10.0, |
| 1838 | + endpoint: bool = True, |
| 1839 | + dtype: str | None = None, |
| 1840 | + axis: int = 0, |
| 1841 | + end: TensorLike | None = None, |
| 1842 | + steps: TensorLike | None = None, |
| 1843 | +) -> TensorVariable: |
| 1844 | + stop, num = _check_deprecated_inputs(stop, end, num, steps) |
| 1845 | + (start, stop), type = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype) |
| 1846 | + start, stop, base = _broadcast_base_with_inputs(start, stop, base, dtype, axis) |
| 1847 | + |
| 1848 | + ls = _linspace_core( |
| 1849 | + start=start, |
| 1850 | + stop=stop, |
| 1851 | + num=num, |
| 1852 | + endpoint=endpoint, |
| 1853 | + dtype=dtype, |
| 1854 | + axis=axis, |
| 1855 | + retstep=False, |
| 1856 | + ) |
1599 | 1857 |
|
1600 |
| -def linspace(start, end, steps): |
1601 |
| - start = ptb.as_tensor_variable(start) |
1602 |
| - end = ptb.as_tensor_variable(end) |
1603 |
| - arr = ptb.arange(steps) |
1604 |
| - arr = ptb.shape_padright(arr, max(start.ndim, end.ndim)) |
1605 |
| - multiplier = (end - start) / (steps - 1) |
1606 |
| - return start + arr * multiplier |
| 1858 | + return base**ls |
1607 | 1859 |
|
1608 | 1860 |
|
1609 | 1861 | def broadcast_to(
|
|
0 commit comments