Skip to content

Commit 4713c06

Browse files
committed
IndexSum: Factor out common factors if possible
1 parent 36e0e5d commit 4713c06

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

test/test_apply_function_pullbacks.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ def test_apply_single_function_pullbacks_triangle3d():
151151
vc: as_vector(Jinv[j, i] * rvc[j], i),
152152
t: rt,
153153
s: as_tensor([[rs[0], rs[1], rs[2]], [rs[1], rs[3], rs[4]], [rs[2], rs[4], rs[5]]]),
154-
cov2t: as_tensor(Jinv[k, i] * rcov2t[k, l] * Jinv[l, j], (i, j)),
155-
contra2t: as_tensor((1.0 / detJ) ** 2 * J[i, k] * rcontra2t[k, l] * J[j, l], (i, j)),
154+
cov2t: as_tensor(Jinv[k, i] * (rcov2t[k, l] * Jinv[l, j]), (i, j)),
155+
contra2t: (1.0 / detJ) ** 2 * as_tensor(J[i, k] * (rcontra2t[k, l] * J[j, l]), (i, j)),
156156
# Mixed elements become a bit more complicated
157157
uml2: as_vector([ruml2[0] / detJ, ruml2[1] / detJ]),
158158
um: rum,

ufl/indexsum.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#
77
# SPDX-License-Identifier: LGPL-3.0-or-later
88

9+
from ufl.algebra import Product
910
from ufl.constantvalue import Zero
1011
from ufl.core.expr import Expr, ufl_err_str
1112
from ufl.core.multiindex import MultiIndex
@@ -21,7 +22,7 @@
2122
class IndexSum(Operator):
2223
"""Index sum."""
2324

24-
__slots__ = ("_dimension", "ufl_free_indices", "ufl_index_dimensions")
25+
__slots__ = ("_dimension", "_initialised", "ufl_free_indices", "ufl_index_dimensions")
2526

2627
def __new__(cls, summand, index):
2728
"""Create a new IndexSum."""
@@ -33,21 +34,33 @@ def __new__(cls, summand, index):
3334
if len(index) != 1:
3435
raise ValueError(f"Expecting a single Index but got {len(index)}.")
3536

37+
(j,) = index
3638
# Simplification to zero
3739
if isinstance(summand, Zero):
3840
sh = summand.ufl_shape
39-
(j,) = index
4041
fi = summand.ufl_free_indices
4142
fid = summand.ufl_index_dimensions
4243
pos = fi.index(j.count())
4344
fi = fi[:pos] + fi[pos + 1 :]
4445
fid = fid[:pos] + fid[pos + 1 :]
4546
return Zero(sh, fi, fid)
4647

47-
return Operator.__new__(cls)
48+
# Factor out common factors
49+
if isinstance(summand, Product):
50+
a, b = summand.ufl_operands
51+
if j.count() not in a.ufl_free_indices:
52+
return Product(a, IndexSum(b, index))
53+
elif j.count() not in b.ufl_free_indices:
54+
return Product(b, IndexSum(a, index))
55+
56+
self = Operator.__new__(cls)
57+
self._initialised = False
58+
return self
4859

4960
def __init__(self, summand, index):
5061
"""Initialise."""
62+
if self._initialised:
63+
return
5164
(j,) = index
5265
fi = summand.ufl_free_indices
5366
fid = summand.ufl_index_dimensions
@@ -56,6 +69,7 @@ def __init__(self, summand, index):
5669
self.ufl_free_indices = fi[:pos] + fi[pos + 1 :]
5770
self.ufl_index_dimensions = fid[:pos] + fid[pos + 1 :]
5871
Operator.__init__(self, (summand, index))
72+
self._initialised = True
5973

6074
def index(self):
6175
"""Get index."""

ufl/pullback.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def apply(self, expr):
270270
# Apply transform "row-wise" to TensorElement(PiolaMapped, ...)
271271
*k, i, j, m, n = indices(len(expr.ufl_shape) + 2)
272272
kmn = (*k, m, n)
273-
return as_tensor((1.0 / detJ) ** 2 * J[i, m] * expr[kmn] * J[j, n], (*k, i, j))
273+
return (1.0 / detJ) ** 2 * as_tensor(J[i, m] * (expr[kmn] * J[j, n]), (*k, i, j))
274274

275275
def physical_value_shape(self, element, domain) -> typing.Tuple[int, ...]:
276276
"""Get the physical value shape when this pull back is applied to an element on a domain.
@@ -313,7 +313,7 @@ def apply(self, expr):
313313
# Apply transform "row-wise" to TensorElement(PiolaMapped, ...)
314314
*k, i, j, m, n = indices(len(expr.ufl_shape) + 2)
315315
kmn = (*k, m, n)
316-
return as_tensor(K[m, i] * expr[kmn] * K[n, j], (*k, i, j))
316+
return as_tensor(K[m, i] * (expr[kmn] * K[n, j]), (*k, i, j))
317317

318318
def physical_value_shape(self, element, domain) -> typing.Tuple[int, ...]:
319319
"""Get the physical value shape when this pull back is applied to an element on a domain.
@@ -358,7 +358,7 @@ def apply(self, expr):
358358
# Apply transform "row-wise" to TensorElement(PiolaMapped, ...)
359359
*k, i, j, m, n = indices(len(expr.ufl_shape) + 2)
360360
kmn = (*k, m, n)
361-
return as_tensor((1.0 / detJ) * K[m, i] * expr[kmn] * J[j, n], (*k, i, j))
361+
return (1.0 / detJ) * as_tensor(K[m, i] * (expr[kmn] * J[j, n]), (*k, i, j))
362362

363363
def physical_value_shape(self, element, domain) -> typing.Tuple[int, ...]:
364364
"""Get the physical value shape when this pull back is applied to an element.

0 commit comments

Comments
 (0)