9
9
# SPDX-License-Identifier: LGPL-3.0-or-later
10
10
11
11
from ufl .algorithms .estimate_degrees import SumDegreeEstimator
12
+ from ufl .checks import is_cellwise_constant
12
13
from ufl .classes import (
13
14
ComponentTensor ,
14
15
Form ,
@@ -42,13 +43,10 @@ def zero(self, o):
42
43
free_indices = []
43
44
index_dimensions = []
44
45
for i , d in zip (o .ufl_free_indices , o .ufl_index_dimensions ):
45
- if Index (i ) in self .fimap :
46
- ind_j = self .fimap [Index (i )]
47
- if isinstance (ind_j , Index ):
48
- free_indices .append (ind_j .count ())
49
- index_dimensions .append (d )
50
- else :
51
- free_indices .append (i )
46
+ k = Index (i )
47
+ j = self .fimap .get (k , k )
48
+ if isinstance (j , Index ):
49
+ free_indices .append (j .count ())
52
50
index_dimensions .append (d )
53
51
return Zero (
54
52
shape = o .ufl_shape ,
@@ -71,24 +69,45 @@ def __init__(self):
71
69
self ._object_cache = {}
72
70
self .degree_estimator = SumDegreeEstimator (1 , {})
73
71
72
+ def is_cellwise_constant (self , o ):
73
+ """More precise checks for cellwise constants."""
74
+ if is_cellwise_constant (o ):
75
+ return True
76
+ degree = map_expr_dag (self .degree_estimator , o )
77
+ return degree == 0
78
+
74
79
expr = MultiFunction .reuse_if_untouched
75
80
81
+ @memoized_handler
82
+ def sum (self , o ):
83
+ """Simplify Sum, allow adding Zeros with different free indices."""
84
+ oa , ob = o .ufl_operands
85
+ a = map_expr_dag (self , oa )
86
+ b = map_expr_dag (self , ob )
87
+ if isinstance (b , Zero ):
88
+ return a
89
+ if isinstance (a , Zero ):
90
+ return b
91
+ if a is oa and b is ob :
92
+ # Reuse if untouched
93
+ return o
94
+ return o ._ufl_expr_reconstruct_ (a , b )
95
+
76
96
@memoized_handler
77
97
def reference_grad (self , o ):
78
98
"""Simplify ReferenceGrad(Constant)."""
79
99
(operand ,) = o .ufl_operands
80
- operand = map_expr_dag (self , operand )
81
- degree = map_expr_dag (self .degree_estimator , operand )
82
- if degree == 0 :
100
+ f = map_expr_dag (self , operand )
101
+ if self .is_cellwise_constant (f ):
83
102
return Zero (
84
103
shape = o .ufl_shape ,
85
104
free_indices = o .ufl_free_indices ,
86
105
index_dimensions = o .ufl_index_dimensions ,
87
106
)
88
- if operand is o . ufl_operands [ 0 ] :
107
+ if f is operand :
89
108
# Reuse if untouched
90
109
return o
91
- return o ._ufl_expr_reconstruct_ (operand )
110
+ return o ._ufl_expr_reconstruct_ (f )
92
111
93
112
@memoized_handler
94
113
def indexed (self , o ):
0 commit comments