Skip to content

Commit d83702d

Browse files
committed
WIP
1 parent 768f403 commit d83702d

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

ufl/algorithms/remove_component_tensors.py

+31-12
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# SPDX-License-Identifier: LGPL-3.0-or-later
1010

1111
from ufl.algorithms.estimate_degrees import SumDegreeEstimator
12+
from ufl.checks import is_cellwise_constant
1213
from ufl.classes import (
1314
ComponentTensor,
1415
Form,
@@ -42,13 +43,10 @@ def zero(self, o):
4243
free_indices = []
4344
index_dimensions = []
4445
for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions):
45-
if Index(i) in self.fimap:
46-
ind_j = self.fimap[Index(i)]
47-
if isinstance(ind_j, Index):
48-
free_indices.append(ind_j.count())
49-
index_dimensions.append(d)
50-
else:
51-
free_indices.append(i)
46+
k = Index(i)
47+
j = self.fimap.get(k, k)
48+
if isinstance(j, Index):
49+
free_indices.append(j.count())
5250
index_dimensions.append(d)
5351
return Zero(
5452
shape=o.ufl_shape,
@@ -71,24 +69,45 @@ def __init__(self):
7169
self._object_cache = {}
7270
self.degree_estimator = SumDegreeEstimator(1, {})
7371

72+
def is_cellwise_constant(self, o):
73+
"""More precise checks for cellwise constants."""
74+
if is_cellwise_constant(o):
75+
return True
76+
degree = map_expr_dag(self.degree_estimator, o)
77+
return degree == 0
78+
7479
expr = MultiFunction.reuse_if_untouched
7580

81+
@memoized_handler
82+
def sum(self, o):
83+
"""Simplify Sum, allow adding Zeros with different free indices."""
84+
oa, ob = o.ufl_operands
85+
a = map_expr_dag(self, oa)
86+
b = map_expr_dag(self, ob)
87+
if isinstance(b, Zero):
88+
return a
89+
if isinstance(a, Zero):
90+
return b
91+
if a is oa and b is ob:
92+
# Reuse if untouched
93+
return o
94+
return o._ufl_expr_reconstruct_(a, b)
95+
7696
@memoized_handler
7797
def reference_grad(self, o):
7898
"""Simplify ReferenceGrad(Constant)."""
7999
(operand,) = o.ufl_operands
80-
operand = map_expr_dag(self, operand)
81-
degree = map_expr_dag(self.degree_estimator, operand)
82-
if degree == 0:
100+
f = map_expr_dag(self, operand)
101+
if self.is_cellwise_constant(f):
83102
return Zero(
84103
shape=o.ufl_shape,
85104
free_indices=o.ufl_free_indices,
86105
index_dimensions=o.ufl_index_dimensions,
87106
)
88-
if operand is o.ufl_operands[0]:
107+
if f is operand:
89108
# Reuse if untouched
90109
return o
91-
return o._ufl_expr_reconstruct_(operand)
110+
return o._ufl_expr_reconstruct_(f)
92111

93112
@memoized_handler
94113
def indexed(self, o):

0 commit comments

Comments
 (0)