
Commit c4fb0cf

Raise explicitly on Python methods that are incompatible with lazy variables
Notably changes the behavior of `__bool__` to always raise. Previously, a hack based on a `_is_nonzero` flag made `bool()` raise only for variables produced by comparison operators.
1 parent e00abf3 commit c4fb0cf
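
As a quick illustration of the new behavior (a minimal sketch, assuming `pytensor.tensor` is importable as `pt`; the variable name is made up for the example):

import pytensor.tensor as pt

x = pt.scalar("x")

# Implicit conversion to native Python types now raises TypeError ...
try:
    bool(x)              # also hit by `if x:`, `x and y`, `assert x`
except TypeError as exc:
    print(exc)           # "TensorVariable cannot be converted to Python boolean. ..."

# ... so either stay symbolic, or make the None check explicit:
x_as_bool = x.astype(bool)   # symbolic equivalent
is_given = x is not None     # what the touched call sites now do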

15 files changed: +107 −67 lines

pytensor/compile/function/pfunc.py (+2 −2)

@@ -569,7 +569,7 @@ def construct_pfunc_ins_and_outs(
     if not fgraph:
         # Extend the outputs with the updates on input variables so they are
         # also cloned
-        additional_outputs = [i.update for i in inputs if i.update]
+        additional_outputs = [i.update for i in inputs if i.update is not None]
         if outputs is None:
             out_list = []
         else:
@@ -608,7 +608,7 @@ def construct_pfunc_ins_and_outs(
         new_i.variable = iv

         # If needed, replace the input's update by its cloned equivalent
-        if i.update:
+        if i.update is not None:
             new_i.update = clone_d[i.update]

         new_inputs.append(new_i)

pytensor/compile/function/types.py (+7 −3)

@@ -198,7 +198,7 @@ def std_fgraph(
     update_mapping = {}
     out_idx = len(output_specs)
     for idx, input_spec in enumerate(input_specs):
-        if input_spec.update:
+        if input_spec.update is not None:
             updates.append(input_spec.update)
             update_mapping[out_idx] = idx
             out_idx += 1
@@ -1195,7 +1195,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
    updated_fgraph_inputs = {
        fgraph_i
        for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True)
-        if getattr(i, "update", False)
+        if getattr(i, "update", None) is not None
    }

    # We can't use fgraph.inputs as this don't include Constant Value.
@@ -1351,7 +1351,11 @@ def check_unused_inputs(inputs, outputs, on_unused_input):
        ancestors(
            (
                [o.variable for o in outputs]
-                + [i.update for i in inputs if getattr(i, "update", False)]
+                + [
+                    i.update
+                    for i in inputs
+                    if getattr(i, "update", None) is not None
+                ]
            ),
            blockers=[i.variable for i in inputs],
        )

pytensor/compile/nanguardmode.py (+1 −1)

@@ -36,7 +36,7 @@ def _is_numeric_value(arr, var):
        return False
    elif isinstance(arr, np.random.mtrand.RandomState | np.random.Generator):
        return False
-    elif var and isinstance(var.type, RandomType):
+    elif var is not None and isinstance(var.type, RandomType):
        return False
    elif isinstance(arr, slice):
        return False

pytensor/scalar/basic.py (+31 −5)

@@ -823,6 +823,37 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType:


 class _scalar_py_operators:
+    # These can't work because Python requires native output types
+    def __bool__(self):
+        raise TypeError(
+            "ScalarVariable cannot be converted to Python boolean. "
+            "Call `.astype(bool)` for the symbolic equivalent."
+        )
+
+    def __index__(self):
+        raise TypeError(
+            "ScalarVariable cannot be converted to Python integer. "
+            "Call `.astype(int)` for the symbolic equivalent."
+        )
+
+    def __int__(self):
+        raise TypeError(
+            "ScalarVariable cannot be converted to Python integer. "
+            "Call `.astype(int)` for the symbolic equivalent."
+        )
+
+    def __float__(self):
+        raise TypeError(
+            "ScalarVariable cannot be converted to Python float. "
+            "Call `.astype(float)` for the symbolic equivalent."
+        )
+
+    def __complex__(self):
+        raise TypeError(
+            "ScalarVariable cannot be converted to Python complex number. "
+            "Call `.astype(complex)` for the symbolic equivalent."
+        )
+
     # So that we can simplify checking code when we have a mixture of ScalarType
     # variables and Tensor variables
     ndim = 0
@@ -843,11 +874,6 @@ def __abs__(self):
     def __neg__(self):
         return neg(self)

-    # CASTS
-    # def __int__(self): return AsInt(self).out
-    # def __float__(self): return AsDouble(self).out
-    # def __complex__(self): return AsComplex(self).out
-
     # BITWISE
     def __invert__(self):
         return invert(self)

pytensor/scalar/loop.py (+2 −2)

@@ -60,12 +60,12 @@ def __init__(
            constant = []
        if not len(init) == len(update):
            raise ValueError("An update must be given for each init variable")
-        if until:
+        if until is not None:
            inputs, outputs = clone([*init, *constant], [*update, until])
        else:
            inputs, outputs = clone([*init, *constant], update)

-        self.is_while = bool(until)
+        self.is_while = until is not None
        self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
        self._validate_updates(self.inputs, self.outputs)

pytensor/scalar/math.py (+1 −1)

@@ -856,7 +856,7 @@ def inner_loop_a(sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x):
        dfac = k_minus_one_minus_n * dfac + fac
        fac *= k_minus_one_minus_n
        delta = dfac / xpow
-        return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), ()
+        return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), None

    init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac]
    constant = [x]

pytensor/scan/basic.py (+1 −1)

@@ -979,7 +979,7 @@ def wrap_into_list(x):
        # user-specified within the inner-function (e.g. by returning an update
        # `dict`) or the `SharedVariable.default_update`s of a shared variable
        # created in the inner-function.
-        if input.update and (is_local or input.variable in updates):
+        if input.update is not None and (is_local or input.variable in updates):
            # We need to remove the `default_update`s on the shared
            # variables created within the context of the loop function
            # (e.g. via use of `RandomStream`); otherwise, they'll get

pytensor/tensor/basic.py (+8 −1)

@@ -3430,7 +3430,14 @@ def __getitem__(self, *args):
                raise NotImplementedError(
                    "Not implemented for slices whose step is complex"
                )
-            ranges = [arange(sl.start or 0, sl.stop, sl.step or 1) for sl in args[0]]
+            ranges = [
+                arange(
+                    sl.start if sl.start is not None else 0,
+                    sl.stop,
+                    sl.step if sl.step is not None else 1,
+                )
+                for sl in args[0]
+            ]
            shapes = [
                tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j))
                for j, r in enumerate(ranges)
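
The change above is forced by the new `__bool__` behavior: `x or default` evaluates `bool(x)`, which now raises when `x` is symbolic. A minimal sketch of the failure mode and the replacement pattern (hypothetical names, assuming `pytensor.tensor` as `pt`):

import pytensor.tensor as pt

start = pt.lscalar("start")   # a symbolic slice bound
step = None

# old pattern: `start or 0` calls bool(start) and now raises TypeError;
# new pattern: decide on None explicitly and leave symbolic values untouched
start_val = start if start is not None else 0
step_val = step if step is not None else 1
print(start_val, step_val)    # start 1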

pytensor/tensor/conv/abstract_conv.py (+2 −2)

@@ -2199,7 +2199,7 @@ def __init__(
        ):
            border_mode = "valid"

-        self.imshp = tuple(imshp) if imshp else (None,) * (2 + convdim)
+        self.imshp = tuple(imshp) if imshp is not None else (None,) * (2 + convdim)
        for imshp_i in self.imshp:
            if imshp_i is not None:
                # Components of imshp should be constant or ints
@@ -2209,7 +2209,7 @@ def __init__(
                    raise ValueError(
                        "imshp should be None or a tuple of constant int values"
                    ).with_traceback(sys.exc_info()[2])
-        if kshp:
+        if kshp is not None:
            self.kshp = tuple(kshp)
        else:
            self.kshp = (None,) * ((2 + 2 * convdim) if unshared else (2 + convdim))

pytensor/tensor/math.py (+4 −4)

@@ -1811,14 +1811,14 @@ def R_op(self, inputs, eval_points):
        if eval_points[0] is None and eval_points[1] is None:
            return [None]

-        if eval_points[0]:
+        if eval_points[0] is not None:
            t1 = self(eval_points[0], inputs[1])
-        if eval_points[1]:
+        if eval_points[1] is not None:
            t2 = self(inputs[0], eval_points[1])

-        if eval_points[0] and eval_points[1]:
+        if eval_points[0] is not None and eval_points[1] is not None:
            return [t1 + t2]
-        elif eval_points[0]:
+        elif eval_points[0] is not None:
            return [t1]
        else:
            return [t2]

pytensor/tensor/rewriting/blas.py (+6 −4)

@@ -803,7 +803,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
    """
    if node.op != mul:
        return False
-    i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs]
+    i_dot22 = [x.owner is not None and x.owner.op == _dot22 for x in node.inputs]
    if not any(i_dot22):
        return False  # no dot22
    if i_dot22.count(True) > 1:
@@ -813,14 +813,16 @@ def local_dot22_to_dot22scalar(fgraph, node):
    dot22_idx = i_dot22.index(True)
    d = node.inputs[dot22_idx]
    i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs]
-    if not any(i_scalar):
+    if all(i is None for i in i_scalar):
        # Check if we can reorder the graph as this mul have a mul in inputs.
        # We support only 1 additional level of mul.
        # The canonizer should have merged those mul together.
        i_mul = [
            x.owner
            and x.owner.op == mul
-            and any(_as_scalar(x_i, dtype=d.dtype) for x_i in x.owner.inputs)
+            and any(
+                _as_scalar(x_i, dtype=d.dtype) is not None for x_i in x.owner.inputs
+            )
            for x in node.inputs
        ]
        if not any(i_mul):
@@ -834,7 +836,7 @@ def local_dot22_to_dot22scalar(fgraph, node):

        scalar_idx = -1
        for i, x in enumerate(m.owner.inputs):
-            if _as_scalar(x, dtype=d.dtype) and (
+            if _as_scalar(x, dtype=d.dtype) is not None and (
                pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype
            ):
                scalar_idx = i

pytensor/tensor/rewriting/math.py (+5 −5)

@@ -1331,14 +1331,14 @@ def local_sum_prod_of_mul_or_div(fgraph, node):

    # If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
    # that were contracted in the input
-    if isinstance(node.op, Prod) and inner_term:
+    if isinstance(node.op, Prod) and inner_term is not None:
        dtype = inner_term.dtype
        n_reduced_elements = prod(
            [inner_term.shape[i].astype(dtype) for i in reduced_axes]
        )
        outer_term = outer_term**n_reduced_elements

-    if not inner_term:
+    if inner_term is None:
        # Sum/Prod is useless, just return the outer_term
        # (This can only happen for mul, not division)
        new_out = outer_term
@@ -1992,7 +1992,7 @@ def local_pow_canonicalize(fgraph, node):
        # x ** 1 = x
        new_out = broadcast_arrays(*node.inputs)[0]

-    if not new_out:
+    if new_out is None:
        return

    if new_out.dtype != node.out.dtype:
@@ -2119,7 +2119,7 @@ def local_pow_to_nested_squaring(fgraph, node):
        rval1_scal = None
        while y_to_do > 0:
            log_to_do = int(np.log2(y_to_do))
-            if rval1:
+            if rval1 is not None:
                rval1 *= pow2[log_to_do]
                rval1_scal *= pow2_scal[log_to_do]
            else:
@@ -2137,7 +2137,7 @@ def local_pow_to_nested_squaring(fgraph, node):
            rval = [reciprocal(rval1)]
        else:
            rval = [rval1]
-        if rval:
+        if rval is not None:
            rval[0] = cast(rval[0], odtype)
        return rval

pytensor/tensor/rewriting/special.py (+1 −1)

@@ -162,7 +162,7 @@ def softmax_simplifier(numerators, denominators):
                matching_denom = denominator
                break

-        if matching_denom:
+        if matching_denom is not None:
            softmax = Softmax(axis=sum_axis)(numerator.owner.inputs[0])
            copy_stack_trace(numerator, softmax)
            numerators.remove(numerator)

pytensor/tensor/variable.py (+35 −34)

@@ -26,53 +26,54 @@


 class _tensor_py_operators:
+    # These can't work because Python requires native output types
+    def __bool__(self):
+        raise TypeError(
+            "TensorVariable cannot be converted to Python boolean. "
+            "Call `.astype(bool)` for the symbolic equivalent."
+        )
+
+    def __index__(self):
+        raise TypeError(
+            "TensorVariable cannot be converted to Python integer. "
+            "Call `.astype(int)` for the symbolic equivalent."
+        )
+
+    def __int__(self):
+        raise TypeError(
+            "TensorVariable cannot be converted to Python integer. "
+            "Call `.astype(int)` for the symbolic equivalent."
+        )
+
+    def __float__(self):
+        raise TypeError(
+            "TensorVariables cannot be converted to Python float. "
+            "Call `.astype(float)` for the symbolic equivalent."
+        )
+
+    def __complex__(self):
+        raise TypeError(
+            "TensorVariables cannot be converted to Python complex number. "
+            "Call `.astype(complex)` for the symbolic equivalent."
+        )
+
     def __abs__(self):
         return pt.math.abs(self)

     def __neg__(self):
         return pt.math.neg(self)

-    # These won't work because Python requires an int return value
-    # def __int__(self): return convert_to_int32(self)
-    # def __float__(self): return convert_to_float64(self)
-    # def __complex__(self): return convert_to_complex128(self)
-
-    _is_nonzero = True
-
     def __lt__(self, other):
-        rval = pt.math.lt(self, other)
-        rval._is_nonzero = False
-        return rval
+        return pt.math.lt(self, other)

     def __le__(self, other):
-        rval = pt.math.le(self, other)
-        rval._is_nonzero = False
-        return rval
+        return pt.math.le(self, other)

     def __gt__(self, other):
-        rval = pt.math.gt(self, other)
-        rval._is_nonzero = False
-        return rval
+        return pt.math.gt(self, other)

     def __ge__(self, other):
-        rval = pt.math.ge(self, other)
-        rval._is_nonzero = False
-        return rval
-
-    def __bool__(self):
-        # This is meant to prohibit stuff like a < b < c, which is internally
-        # implemented as (a < b) and (b < c). The trouble with this is the
-        # side-effect that checking for a non-NULL a by typing "if a: ..."
-        # uses the same __nonzero__ method. We want these both to work, but
-        # it seems impossible. Currently, all vars evaluate to nonzero except
-        # the return values of comparison operators, which raise this
-        # exception. If you can think of a better solution, go for it!
-        #
-        # __bool__ is Python 3.x data model. __nonzero__ is Python 2.x.
-        if self._is_nonzero:
-            return True
-        else:
-            raise TypeError("Variables do not support boolean operations.")
+        return pt.math.ge(self, other)

     def __invert__(self):
         return pt.math.invert(self)
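
With the `_is_nonzero` hack removed, both chained comparisons and truthiness checks on variables now fail loudly instead of behaving inconsistently. A minimal sketch of the idioms to use instead (assuming `pytensor.tensor` as `pt`; variable names are illustrative):

import pytensor.tensor as pt

a = pt.scalar("a")
b = pt.scalar("b")
c = pt.scalar("c")

# `a < b < c` expands to `(a < b) and (b < c)`, which calls bool() and raises;
# write the conjunction symbolically instead.
in_range = pt.and_(a < b, b < c)

# `if a:` also raises; a value-dependent branch needs a symbolic switch.
clipped = pt.switch(a > 0, a, 0)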

tests/tensor/test_type.py (+1 −1)

@@ -399,7 +399,7 @@ def test_tensor_creator_dtype_catch(dtype):
        tensor(dtype, shape=(None,))

    # This should work
-    assert tensor(dtype=dtype, shape=(None,))
+    assert tensor(dtype=dtype, shape=(None,)) is not None


def test_tensor_creator_ignores_rare_dtype_name():
