Skip to content

Commit df9dcf6

Browse files
committed
Simplify boolean operations with any and all
1 parent 5e74536 commit df9dcf6

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

pytensor/gradient.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,13 +1041,12 @@ def access_term_cache(node):
10411041
# list of bools indicating if each input is connected to the cost
10421042
inputs_connected = [
10431043
(
1044-
True
1045-
in [
1044+
any(
10461045
input_to_output and output_to_cost
10471046
for input_to_output, output_to_cost in zip(
10481047
input_to_outputs, outputs_connected
10491048
)
1050-
]
1049+
)
10511050
)
10521051
for input_to_outputs in connection_pattern
10531052
]
@@ -1067,25 +1066,24 @@ def access_term_cache(node):
10671066
# List of bools indicating if each input only has NullType outputs
10681067
only_connected_to_nan = [
10691068
(
1070-
True
1071-
not in [
1069+
not any(
10721070
in_to_out and out_to_cost and not out_nan
10731071
for in_to_out, out_to_cost, out_nan in zip(
10741072
in_to_outs, outputs_connected, ograd_is_nan
10751073
)
1076-
]
1074+
)
10771075
)
10781076
for in_to_outs in connection_pattern
10791077
]
10801078

1081-
if True not in inputs_connected:
1079+
if not any(inputs_connected):
10821080
# All outputs of this op are disconnected so we can skip
10831081
# Calling the op's grad method and report that the inputs
10841082
# are disconnected
10851083
# (The op's grad method could do this too, but this saves the
10861084
# implementer the trouble of worrying about this case)
10871085
input_grads = [disconnected_type() for ipt in inputs]
1088-
elif False not in only_connected_to_nan:
1086+
elif all(only_connected_to_nan):
10891087
# All inputs are only connected to nan gradients, so we don't
10901088
# need to bother calling the grad method. We know the gradient
10911089
# with respect to all connected inputs is nan.

pytensor/tensor/elemwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,12 @@ def make_node(self, _input):
201201
f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
202202
)
203203
for expected, b in zip(self.input_broadcastable, ib):
204-
if expected is True and b is False:
204+
if expected and not b:
205205
raise TypeError(
206206
"The broadcastable pattern of the "
207207
f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
208208
)
209-
# else, expected == b or expected is False and b is True
209+
# else, expected == b or not expected and b
210210
# Both case are good.
211211

212212
out_static_shape = []

0 commit comments

Comments
 (0)