diff --git a/ufl/algorithms/remove_component_tensors.py b/ufl/algorithms/remove_component_tensors.py index af0a22b56..1b7b99d15 100644 --- a/ufl/algorithms/remove_component_tensors.py +++ b/ufl/algorithms/remove_component_tensors.py @@ -2,13 +2,14 @@ This module contains classes and functions to remove component tensors. """ -# Copyright (C) 2008-2016 Martin Sandve Alnæs +# Copyright (C) 2025 Pablo Brubeck # # This file is part of UFL (https://www.fenicsproject.org) # # SPDX-License-Identifier: LGPL-3.0-or-later -from ufl.classes import ComponentTensor, Form, Index, MultiIndex, Zero +from ufl.algorithms.map_integrands import map_integrand_dags +from ufl.classes import ComponentTensor, Index, MultiIndex, Zero from ufl.corealg.map_dag import map_expr_dag from ufl.corealg.multifunction import MultiFunction, memoized_handler @@ -107,13 +108,5 @@ def indexed(self, o): def remove_component_tensors(o): """Remove component tensors.""" - if isinstance(o, Form): - integrals = [] - for integral in o.integrals(): - integrand = remove_component_tensors(integral.integrand()) - if not isinstance(integrand, Zero): - integrals.append(integral.reconstruct(integrand=integrand)) - return o._ufl_expr_reconstruct_(integrals) - else: - rule = IndexRemover() - return map_expr_dag(rule, o) + rule = IndexRemover() + return map_integrand_dags(rule, o)