@@ -2343,12 +2343,14 @@ def local_log_sum_exp(fgraph, node):
2343
2343
else :
2344
2344
dimshuffle_op = None
2345
2345
2346
- if not sum_node or not isinstance (sum_node .op , Sum ):
2346
+ if not ( sum_node and isinstance (sum_node .op , Sum ) ):
2347
2347
return
2348
2348
2349
2349
exp_node , axis = sum_node .inputs [0 ].owner , sum_node .op .axis
2350
- if not exp_node or not (
2351
- isinstance (exp_node .op , Elemwise ) and isinstance (exp_node .op .scalar_op , ps .Exp )
2350
+ if not (
2351
+ exp_node
2352
+ and isinstance (exp_node .op , Elemwise )
2353
+ and isinstance (exp_node .op .scalar_op , ps .Exp )
2352
2354
):
2353
2355
return
2354
2356
@@ -2660,7 +2662,7 @@ def local_log_erfc(fgraph, node):
2660
2662
10.0541948,10.0541951,.0000001)]
2661
2663
"""
2662
2664
2663
- if not node .inputs [0 ].owner or node .inputs [0 ].owner .op != erfc :
2665
+ if not ( node .inputs [0 ].owner and node .inputs [0 ].owner .op == erfc ) :
2664
2666
return False
2665
2667
2666
2668
if hasattr (node .tag , "local_log_erfc_applied" ):
@@ -2725,7 +2727,7 @@ def local_grad_log_erfc_neg(fgraph, node):
2725
2727
if node .inputs [0 ].owner .op != mul :
2726
2728
mul_in = None
2727
2729
y = []
2728
- if not node .inputs [0 ].owner or node .inputs [0 ].owner .op != exp :
2730
+ if not ( node .inputs [0 ].owner and node .inputs [0 ].owner .op == exp ) :
2729
2731
return False
2730
2732
exp_in = node .inputs [0 ]
2731
2733
else :
@@ -2749,7 +2751,9 @@ def local_grad_log_erfc_neg(fgraph, node):
2749
2751
2750
2752
if exp_in .owner .inputs [0 ].owner .op == neg :
2751
2753
neg_in = exp_in .owner .inputs [0 ]
2752
- if not neg_in .owner .inputs [0 ].owner or neg_in .owner .inputs [0 ].owner .op != sqr :
2754
+ if not (
2755
+ neg_in .owner .inputs [0 ].owner and neg_in .owner .inputs [0 ].owner .op == sqr
2756
+ ):
2753
2757
return False
2754
2758
sqr_in = neg_in .owner .inputs [0 ]
2755
2759
x = sqr_in .owner .inputs [0 ]
@@ -2794,9 +2798,9 @@ def check_input(inputs):
2794
2798
return False
2795
2799
2796
2800
if len (mul_neg .owner .inputs ) == 2 :
2797
- if (
2798
- not mul_neg .owner .inputs [1 ].owner
2799
- or mul_neg .owner .inputs [1 ].owner .op ! = sqr
2801
+ if not (
2802
+ mul_neg .owner .inputs [1 ].owner
2803
+ and mul_neg .owner .inputs [1 ].owner .op = = sqr
2800
2804
):
2801
2805
return False
2802
2806
sqr_in = mul_neg .owner .inputs [1 ]
@@ -2809,10 +2813,10 @@ def check_input(inputs):
2809
2813
return False
2810
2814
2811
2815
if cst2 != - 1 :
2812
- if (
2813
- not erfc_x .owner
2814
- or erfc_x .owner .op ! = mul
2815
- or len (erfc_x .owner .inputs ) ! = 2
2816
+ if not (
2817
+ erfc_x .owner
2818
+ and erfc_x .owner .op = = mul
2819
+ and len (erfc_x .owner .inputs ) = = 2
2816
2820
):
2817
2821
# todo implement that case
2818
2822
return False
0 commit comments