Skip to content

Commit

Permalink
IndexSum: Factor out common factors if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 26, 2025
1 parent 36e0e5d commit 438c594
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
20 changes: 17 additions & 3 deletions ufl/indexsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from ufl.algebra import Product
from ufl.constantvalue import Zero
from ufl.core.expr import Expr, ufl_err_str
from ufl.core.multiindex import MultiIndex
Expand All @@ -21,7 +22,7 @@
class IndexSum(Operator):
"""Index sum."""

__slots__ = ("_dimension", "ufl_free_indices", "ufl_index_dimensions")
__slots__ = ("_dimension", "_initialised", "ufl_free_indices", "ufl_index_dimensions")

def __new__(cls, summand, index):
"""Create a new IndexSum."""
Expand All @@ -33,21 +34,33 @@ def __new__(cls, summand, index):
if len(index) != 1:
raise ValueError(f"Expecting a single Index but got {len(index)}.")

(j,) = index
# Simplification to zero
if isinstance(summand, Zero):
sh = summand.ufl_shape
(j,) = index
fi = summand.ufl_free_indices
fid = summand.ufl_index_dimensions
pos = fi.index(j.count())
fi = fi[:pos] + fi[pos + 1 :]
fid = fid[:pos] + fid[pos + 1 :]
return Zero(sh, fi, fid)

return Operator.__new__(cls)
# Factor out common factors
if isinstance(summand, Product):
a, b = summand.ufl_operands
if j.count() not in a.ufl_free_indices:
return Product(a, IndexSum(b, index))
elif j.count() not in b.ufl_free_indices:
return Product(b, IndexSum(a, index))

self = Operator.__new__(cls)
self._initialised = False
return self

def __init__(self, summand, index):
"""Initialise."""
if self._initialised:
return
(j,) = index
fi = summand.ufl_free_indices
fid = summand.ufl_index_dimensions
Expand All @@ -56,6 +69,7 @@ def __init__(self, summand, index):
self.ufl_free_indices = fi[:pos] + fi[pos + 1 :]
self.ufl_index_dimensions = fid[:pos] + fid[pos + 1 :]
Operator.__init__(self, (summand, index))
self._initialised = True

def index(self):
"""Get index."""
Expand Down
6 changes: 3 additions & 3 deletions ufl/pullback.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def apply(self, expr):
# Apply transform "row-wise" to TensorElement(PiolaMapped, ...)
*k, i, j, m, n = indices(len(expr.ufl_shape) + 2)
kmn = (*k, m, n)
return as_tensor((1.0 / detJ) ** 2 * J[i, m] * expr[kmn] * J[j, n], (*k, i, j))
return (1.0 / detJ) ** 2 * as_tensor(J[i, m] * (expr[kmn] * J[j, n]), (*k, i, j))

def physical_value_shape(self, element, domain) -> typing.Tuple[int, ...]:
"""Get the physical value shape when this pull back is applied to an element on a domain.
Expand Down Expand Up @@ -313,7 +313,7 @@ def apply(self, expr):
# Apply transform "row-wise" to TensorElement(PiolaMapped, ...)
*k, i, j, m, n = indices(len(expr.ufl_shape) + 2)
kmn = (*k, m, n)
return as_tensor(K[m, i] * expr[kmn] * K[n, j], (*k, i, j))
return as_tensor(K[m, i] * (expr[kmn] * K[n, j]), (*k, i, j))

def physical_value_shape(self, element, domain) -> typing.Tuple[int, ...]:
"""Get the physical value shape when this pull back is applied to an element on a domain.
Expand Down Expand Up @@ -358,7 +358,7 @@ def apply(self, expr):
# Apply transform "row-wise" to TensorElement(PiolaMapped, ...)
*k, i, j, m, n = indices(len(expr.ufl_shape) + 2)
kmn = (*k, m, n)
return as_tensor((1.0 / detJ) * K[m, i] * expr[kmn] * J[j, n], (*k, i, j))
return (1.0 / detJ) * as_tensor(K[m, i] * (expr[kmn] * J[j, n]), (*k, i, j))

def physical_value_shape(self, element, domain) -> typing.Tuple[int, ...]:
"""Get the physical value shape when this pull back is applied to an element.
Expand Down

0 comments on commit 438c594

Please sign in to comment.