@@ -1640,6 +1640,8 @@ def _linspace_core(
1640
1640
1641
1641
1642
1642
def _broadcast_inputs (* args ):
1643
+ """Helper function to preprocess inputs to *space Ops"""
1644
+
1643
1645
args = map (ptb .as_tensor_variable , args )
1644
1646
args = broadcast_arrays (* args )
1645
1647
@@ -1653,14 +1655,23 @@ def _broadcast_base_with_inputs(start, stop, base, axis):
1653
1655
1654
1656
Parameters
1655
1657
----------
1656
- start
1657
- stop
1658
- base
1659
- axis
1658
+ start: TensorVariable
1659
+ The start value(s) of the sequence(s).
1660
+ stop: TensorVariable
1661
+ The end value(s) of the sequence(s)
1662
+ base: TensorVariable
1663
+ The log base value(s) of the sequence(s)
1664
+ axis: int
1665
+ The axis along which to generate samples.
1660
1666
1661
1667
Returns
1662
1668
-------
1663
-
1669
+ start: TensorVariable
1670
+ The start value(s) of the sequence(s), broadcast with the base tensor if necessary.
1671
+ stop: TensorVariable
1672
+ The end value(s) of the sequence(s), broadcast with the base tensor if necessary.
1673
+ base: TensorVariable
1674
+ The log base value(s) of the sequence(s), broadcast with the start and stop tensors if necessary.
1664
1675
"""
1665
1676
base = ptb .as_tensor_variable (base )
1666
1677
if base .ndim > 0 :
@@ -1841,10 +1852,9 @@ def geomspace(
1841
1852
)
1842
1853
result = base ** result
1843
1854
1844
- if num > 0 :
1845
- result = set_subtensor (result [0 , ...], start )
1846
- if num > 1 and endpoint :
1847
- result = set_subtensor (result [- 1 , ...], stop )
1855
+ result = switch (gt (num , 0 ), set_subtensor (result [0 , ...], start ), result )
1856
+ if endpoint :
1857
+ result = switch (gt (num , 1 ), set_subtensor (result [- 1 , ...], stop ), result )
1848
1858
1849
1859
result = result * out_sign
1850
1860
0 commit comments