
Commit d62f4b1

Armavica authored and ricardoV94 committed
Replace more "if not a or not b" with "if not (a and b)"
1 parent: f0e9354
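
By De Morgan's laws, "not a or not b" is equivalent to "not (a and b)", and a
negated comparison absorbs the same way: "o.op != exp" inside a chain of "or"s
becomes "o.op == exp" inside the negated "and" chain. The hunks below also
trade "x is None" for plain truthiness, which is equivalent as long as owner
nodes and cached instances are never falsy. A quick illustrative check of the
equivalence (not part of the commit):

from itertools import product

# De Morgan: "not a or not b" <=> "not (a and b)" for all boolean inputs.
for a, b in product([False, True], repeat=2):
    assert (not a or not b) == (not (a and b))

# The same law covers negated comparisons: "x != y" <=> "not (x == y)",
# which is why "op != erfc" flips to "op == erfc" under the negation.
for x, y in product([0, 1], repeat=2):
    assert (x != y) == (not (x == y))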

6 files changed, +39 -34 lines changed

pytensor/sparse/__init__.py (+2 -2)

@@ -15,8 +15,8 @@ def sparse_grad(var):
     """
     from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1

-    if var.owner is None or not isinstance(
-        var.owner.op, AdvancedSubtensor | AdvancedSubtensor1
+    if not (
+        var.owner and isinstance(var.owner.op, AdvancedSubtensor | AdvancedSubtensor1)
     ):
         raise TypeError(
             "Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1"

pytensor/tensor/math.py (+2 -2)

@@ -2134,8 +2134,8 @@ def dense_dot(a, b):
     """
    a, b = as_tensor_variable(a), as_tensor_variable(b)

-    if not isinstance(a.type, DenseTensorType) or not isinstance(
-        b.type, DenseTensorType
+    if not (
+        isinstance(a.type, DenseTensorType) and isinstance(b.type, DenseTensorType)
     ):
         raise TypeError("The dense dot product is only supported for dense types")

pytensor/tensor/rewriting/basic.py (+8 -7)

@@ -658,13 +658,13 @@ def local_cast_cast(fgraph, node):
     and the first cast cause an upcast.

     """
-    if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, ps.Cast):
+    if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Cast)):
         return
     x = node.inputs[0]
-    if (
-        x.owner is None
-        or not isinstance(x.owner.op, Elemwise)
-        or not isinstance(x.owner.op.scalar_op, ps.Cast)
+    if not (
+        x.owner
+        and isinstance(x.owner.op, Elemwise)
+        and isinstance(x.owner.op.scalar_op, ps.Cast)
     ):
         return

@@ -1053,8 +1053,9 @@ def local_merge_switch_same_cond(fgraph, node):
     Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
     """
     # node must be binary elemwise or add or mul
-    if not isinstance(node.op, Elemwise) or not isinstance(
-        node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul
+    if not (
+        isinstance(node.op, Elemwise)
+        and isinstance(node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul)
     ):
         return
     # all inputs must be switch

pytensor/tensor/rewriting/elemwise.py (+9 -9)

@@ -473,10 +473,10 @@ def local_useless_dimshuffle_makevector(fgraph, node):

     makevector_out = node.inputs[0]

-    if (
-        not makevector_out.owner
-        or not isinstance(makevector_out.owner.op, MakeVector)
-        or not makevector_out.broadcastable == (True,)
+    if not (
+        makevector_out.owner
+        and isinstance(makevector_out.owner.op, MakeVector)
+        and makevector_out.broadcastable == (True,)
     ):
         return

@@ -570,8 +570,8 @@ def local_add_mul_fusion(fgraph, node):
     This rewrite is almost useless after the AlgebraicCanonizer is used,
     but it catches a few edge cases that are not canonicalized by it
     """
-    if not isinstance(node.op, Elemwise) or not isinstance(
-        node.op.scalar_op, ps.Add | ps.Mul
+    if not (
+        isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Add | ps.Mul)
     ):
         return False

@@ -1094,8 +1094,8 @@ def print_profile(stream, prof, level=0):
 @node_rewriter([Elemwise])
 def local_useless_composite_outputs(fgraph, node):
     """Remove inputs and outputs of Composite Ops that are not used anywhere."""
-    if not isinstance(node.op, Elemwise) or not isinstance(
-        node.op.scalar_op, ps.Composite
+    if not (
+        isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Composite)
     ):
         return
     comp = node.op.scalar_op

@@ -1135,7 +1135,7 @@ def local_careduce_fusion(fgraph, node):

     elm_node = car_input.owner

-    if elm_node is None or not isinstance(elm_node.op, Elemwise):
+    if not (elm_node and isinstance(elm_node.op, Elemwise)):
         return False

     elm_scalar_op = elm_node.op.scalar_op
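
The four hunks above all share one shape: an early-exit guard that walks a
variable's owner chain and bails out unless every condition holds. A
self-contained sketch of that guard style, using hypothetical stand-in classes
rather than PyTensor's real graph types:

from dataclasses import dataclass, field


@dataclass
class Op:
    """Stand-in for a graph operation (illustrative, not PyTensor's Op)."""


@dataclass
class Exp(Op):
    """Stand-in for an elementwise exp scalar op."""


@dataclass
class Apply:
    op: Op
    inputs: list = field(default_factory=list)


@dataclass
class Variable:
    owner: Apply | None = None


def is_exp_output(var: Variable) -> bool:
    # One negated conjunction instead of chained "or not" clauses: read as
    # "bail out unless the variable has an owner AND that owner's op is Exp".
    return bool(var.owner and isinstance(var.owner.op, Exp))


assert is_exp_output(Variable(owner=Apply(op=Exp())))
assert not is_exp_output(Variable())  # no owner: "and" short-circuits safely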

pytensor/tensor/rewriting/math.py (+17 -13)

@@ -2343,12 +2343,14 @@ def local_log_sum_exp(fgraph, node):
     else:
         dimshuffle_op = None

-    if not sum_node or not isinstance(sum_node.op, Sum):
+    if not (sum_node and isinstance(sum_node.op, Sum)):
         return

     exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis
-    if not exp_node or not (
-        isinstance(exp_node.op, Elemwise) and isinstance(exp_node.op.scalar_op, ps.Exp)
+    if not (
+        exp_node
+        and isinstance(exp_node.op, Elemwise)
+        and isinstance(exp_node.op.scalar_op, ps.Exp)
     ):
         return

@@ -2660,7 +2662,7 @@ def local_log_erfc(fgraph, node):
     10.0541948,10.0541951,.0000001)]
     """

-    if not node.inputs[0].owner or node.inputs[0].owner.op != erfc:
+    if not (node.inputs[0].owner and node.inputs[0].owner.op == erfc):
         return False

     if hasattr(node.tag, "local_log_erfc_applied"):

@@ -2725,7 +2727,7 @@ def local_grad_log_erfc_neg(fgraph, node):
     if node.inputs[0].owner.op != mul:
         mul_in = None
         y = []
-        if not node.inputs[0].owner or node.inputs[0].owner.op != exp:
+        if not (node.inputs[0].owner and node.inputs[0].owner.op == exp):
             return False
         exp_in = node.inputs[0]
     else:

@@ -2749,7 +2751,9 @@ def local_grad_log_erfc_neg(fgraph, node):

     if exp_in.owner.inputs[0].owner.op == neg:
         neg_in = exp_in.owner.inputs[0]
-        if not neg_in.owner.inputs[0].owner or neg_in.owner.inputs[0].owner.op != sqr:
+        if not (
+            neg_in.owner.inputs[0].owner and neg_in.owner.inputs[0].owner.op == sqr
+        ):
             return False
         sqr_in = neg_in.owner.inputs[0]
         x = sqr_in.owner.inputs[0]

@@ -2794,9 +2798,9 @@ def check_input(inputs):
             return False

         if len(mul_neg.owner.inputs) == 2:
-            if (
-                not mul_neg.owner.inputs[1].owner
-                or mul_neg.owner.inputs[1].owner.op != sqr
+            if not (
+                mul_neg.owner.inputs[1].owner
+                and mul_neg.owner.inputs[1].owner.op == sqr
             ):
                 return False
             sqr_in = mul_neg.owner.inputs[1]

@@ -2809,10 +2813,10 @@ def check_input(inputs):
             return False

         if cst2 != -1:
-            if (
-                not erfc_x.owner
-                or erfc_x.owner.op != mul
-                or len(erfc_x.owner.inputs) != 2
+            if not (
+                erfc_x.owner
+                and erfc_x.owner.op == mul
+                and len(erfc_x.owner.inputs) == 2
             ):
                 # todo implement that case
                 return False

pytensor/utils.py (+1 -1)

@@ -360,7 +360,7 @@ def __new__(cls):
         # don't want that, so we check the class. When we add one, we
         # add one only to the current class, so all is working
         # correctly.
-        if cls.__instance is None or not isinstance(cls.__instance, cls):
+        if not (cls.__instance and isinstance(cls.__instance, cls)):
             cls.__instance = super().__new__(cls)
         return cls.__instance
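
The utils.py guard protects a subclass-aware singleton: __new__ reuses the
cached instance only when one exists and belongs to the requesting class. A
minimal sketch of the pattern (hypothetical names; like the commit, it assumes
instances are never falsy, so truthiness can replace the old "is None" check):

class Singleton:
    __instance = None  # mangled to _Singleton__instance

    def __new__(cls):
        # Rebuild when nothing is cached yet, or when a subclass asks and
        # the cache holds an instance of the wrong class.
        if not (cls.__instance and isinstance(cls.__instance, cls)):
            cls.__instance = super().__new__(cls)
        return cls.__instance


class SubSingleton(Singleton):
    pass


assert Singleton() is Singleton()                 # instance is cached
assert isinstance(SubSingleton(), SubSingleton)   # subclass gets its own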
