@@ -1638,6 +1638,8 @@ def _linspace_core(
1638
1638
1639
1639
1640
1640
def _broadcast_inputs (* args ):
1641
+ """Helper function to preprocess inputs to *space Ops"""
1642
+
1641
1643
args = map (ptb .as_tensor_variable , args )
1642
1644
args = broadcast_arrays (* args )
1643
1645
@@ -1651,14 +1653,23 @@ def _broadcast_base_with_inputs(start, stop, base, axis):
1651
1653
1652
1654
Parameters
1653
1655
----------
1654
- start
1655
- stop
1656
- base
1657
- axis
1656
+ start: TensorVariable
1657
+ The start value(s) of the sequence(s).
1658
+ stop: TensorVariable
1659
+ The end value(s) of the sequence(s)
1660
+ base: TensorVariable
1661
+ The log base value(s) of the sequence(s)
1662
+ axis: int
1663
+ The axis along which to generate samples.
1658
1664
1659
1665
Returns
1660
1666
-------
1661
-
1667
+ start: TensorVariable
1668
+ The start value(s) of the sequence(s), broadcast with the base tensor if necessary.
1669
+ stop: TensorVariable
1670
+ The end value(s) of the sequence(s), broadcast with the base tensor if necessary.
1671
+ base: TensorVariable
1672
+ The log base value(s) of the sequence(s), broadcast with the start and stop tensors if necessary.
1662
1673
"""
1663
1674
base = ptb .as_tensor_variable (base )
1664
1675
if base .ndim > 0 :
@@ -1839,10 +1850,9 @@ def geomspace(
1839
1850
)
1840
1851
result = base ** result
1841
1852
1842
- if num > 0 :
1843
- result = set_subtensor (result [0 , ...], start )
1844
- if num > 1 and endpoint :
1845
- result = set_subtensor (result [- 1 , ...], stop )
1853
+ result = switch (gt (num , 0 ), set_subtensor (result [0 , ...], start ), result )
1854
+ if endpoint :
1855
+ result = switch (gt (num , 1 ), set_subtensor (result [- 1 , ...], stop ), result )
1846
1856
1847
1857
result = result * out_sign
1848
1858
0 commit comments