Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FormSum weights #335

Merged
merged 6 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions ufl/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,21 @@ def __new__(cls, *args, **kw):
if isinstance(right, (Coargument, Argument)):
return left

if isinstance(left, (FormSum, Sum)):
# Action distributes over sums on the LHS
return FormSum(*[(Action(component, right), 1) for component in left.ufl_operands])
if isinstance(right, (FormSum, Sum)):
# Action also distributes over sums on the RHS
return FormSum(*[(Action(left, component), 1) for component in right.ufl_operands])
# Action distributes over sums on the LHS
if isinstance(left, Sum):
return FormSum(*((Action(component, right), 1) for component in left.ufl_operands))
elif isinstance(left, FormSum):
return FormSum(
*((Action(c, right), w) for c, w in zip(left.components(), left.weights()))
)

# Action also distributes over sums on the RHS
if isinstance(right, Sum):
return FormSum(*((Action(left, component), 1) for component in right.ufl_operands))
elif isinstance(right, FormSum):
return FormSum(
*((Action(left, c), w) for c, w in zip(right.components(), right.weights()))
)

return super(Action, cls).__new__(cls)

Expand Down
2 changes: 1 addition & 1 deletion ufl/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __new__(cls, *args, **kw):
return form._form
elif isinstance(form, FormSum):
# Adjoint distributes over sums
return FormSum(*[(Adjoint(component), 1) for component in form.components()])
return FormSum(*((Adjoint(c), w) for c, w in zip(form.components(), form.weights())))
elif isinstance(form, Coargument):
# The adjoint of a coargument `c: V* -> V*` is the identity
# matrix mapping from V to V (i.e. V x V* -> R).
Expand Down
2 changes: 1 addition & 1 deletion ufl/algorithms/map_integrands.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def map_integrands(function, form, only_integral_type=None):
# Simplification of `BaseForm` objects may turn `FormSum` into a sum of `Expr` objects
# that are not `BaseForm`, i.e. into a `Sum` object.
# Example: `Action(Adjoint(c*), u)` with `c*` a `Coargument` and u a `Coefficient`.
return sum([component for component, _ in nonzero_components])
return sum(component * w for component, w in nonzero_components)
return FormSum(*nonzero_components)
elif isinstance(form, Adjoint):
# Zeros are caught inside `Adjoint.__new__`
Expand Down
16 changes: 10 additions & 6 deletions ufl/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ def _analyze_domains(self):

# Collect unique domains
self._domains = sort_domains(
join_domains(chain.from_iterable(e.ufl_domains() for e in self.ufl_operands))
join_domains(chain.from_iterable(c.ufl_domains() for c in self.components()))
)

def ufl_domains(self):
Expand All @@ -799,7 +799,9 @@ def ufl_domains(self):
def __hash__(self):
"""Hash."""
if self._hash is None:
self._hash = hash(tuple(hash(component) for component in self.components()))
self._hash = hash(
tuple((hash(c), hash(w)) for c, w in zip(self.components(), self.weights()))
)
return self._hash

def equals(self, other):
Expand All @@ -808,8 +810,10 @@ def equals(self, other):
return False
if self is other:
return True
return len(self.components()) == len(other.components()) and all(
a == b for a, b in zip(self.components(), other.components())
return (
len(self.components()) == len(other.components())
and all(a == b for a, b in zip(self.components(), other.components()))
and all(a == b for a, b in zip(self.weights(), other.weights()))
)

def __str__(self):
Expand All @@ -818,7 +822,7 @@ def __str__(self):
# warning("Calling str on form is potentially expensive and
# should be avoided except during debugging.")
# Not caching this because it can be huge
s = "\n + ".join(str(component) for component in self.components())
s = "\n + ".join(f"{w}*{c}" for c, w in zip(self.components(), self.weights()))
return s or "<empty FormSum>"

def __repr__(self):
Expand All @@ -827,7 +831,7 @@ def __repr__(self):
# warning("Calling repr on form is potentially expensive and
# should be avoided except during debugging.")
# Not caching this because it can be huge
itgs = ", ".join(repr(component) for component in self.components())
itgs = ", ".join(f"{w!r}*{c!r}" for c, w in zip(self.components(), self.weights()))
r = "FormSum([" + itgs + "])"
return r

Expand Down
4 changes: 2 additions & 2 deletions ufl/formoperators.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@ def derivative(form, coefficient, argument=None, coefficient_derivatives=None):
# Distribute derivative over FormSum components
return FormSum(
*[
(derivative(component, coefficient, argument, coefficient_derivatives), 1)
for component in form.components()
(derivative(component, coefficient, argument, coefficient_derivatives), w)
for component, w in zip(form.components(), form.weights())
]
)
elif isinstance(form, Adjoint):
Expand Down
Loading