Skip to content

Commit

Permalink
simplify definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
IgorBaratta committed Jan 3, 2024
1 parent 3e5442f commit aa636db
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 21 deletions.
24 changes: 14 additions & 10 deletions ffcx/codegeneration/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,21 @@ def __init__(self, ir, symbols, options):
ufl.geometry.CellOrientation,
ufl.geometry.FacetOrientation}

def get(self, t, mt, tabledata, quadrature_rule, access):
def get(self, mt, tabledata, quadrature_rule, access):
# Call appropriate handler, depending on the type of t
ttype = type(t)

terminal = mt.terminal
ttype = type(terminal)

if ttype in self.do_nothing_set:
return [], []
else:
elif ttype in self.call_lookup:
handler = self.call_lookup[ttype]
return handler(t, mt, tabledata, quadrature_rule, access)
return handler(mt, tabledata, quadrature_rule, access)
else:
raise NotImplementedError(f"No handler for terminal type: {ttype}")

def coefficient(self, t, mt, tabledata, quadrature_rule, access):
def coefficient(self, mt, tabledata, quadrature_rule, access):
"""Return definition code for coefficients."""
# For applying tensor product to coefficients, we need to know if the coefficient
# has a tensor factorisation and if the quadrature rule has a tensor factorisation.
Expand Down Expand Up @@ -142,7 +146,7 @@ def coefficient(self, t, mt, tabledata, quadrature_rule, access):

return pre_code, code

def _define_coordinate_dofs_lincomb(self, e, mt, tabledata, quadrature_rule, access):
def _define_coordinate_dofs_lincomb(self, mt, tabledata, quadrature_rule, access):
"""Define x or J as a linear combination of coordinate dofs with given table data."""
# Get properties of domain
domain = ufl.domain.extract_unique_domain(mt.terminal)
Expand Down Expand Up @@ -182,7 +186,7 @@ def _define_coordinate_dofs_lincomb(self, e, mt, tabledata, quadrature_rule, acc

return [], code

def spatial_coordinate(self, e, mt, tabledata, quadrature_rule, access):
def spatial_coordinate(self, mt, tabledata, quadrature_rule, access):
"""Return definition code for the physical spatial coordinates.
If physical coordinates are given:
Expand All @@ -200,8 +204,8 @@ def spatial_coordinate(self, e, mt, tabledata, quadrature_rule, access):
logging.exception("FIXME: Jacobian in custom integrals is not implemented.")
return []
else:
return self._define_coordinate_dofs_lincomb(e, mt, tabledata, quadrature_rule, access)
return self._define_coordinate_dofs_lincomb(mt, tabledata, quadrature_rule, access)

def jacobian(self, e, mt, tabledata, quadrature_rule, access):
def jacobian(self, mt, tabledata, quadrature_rule, access):
"""Return definition code for the Jacobian of x(X)."""
return self._define_coordinate_dofs_lincomb(e, mt, tabledata, quadrature_rule, access)
return self._define_coordinate_dofs_lincomb(mt, tabledata, quadrature_rule, access)
4 changes: 2 additions & 2 deletions ffcx/codegeneration/expression_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,9 @@ def generate_partition(self, symbol, F, mode):
tabledata = attr.get('tr')

# Backend specific modified terminal translation
vaccess = self.backend.access.get(mt.terminal, mt, tabledata, 0)
vaccess = self.backend.access.get(mt, tabledata, 0)

predef, vdef = self.backend.definitions.get(mt.terminal, mt, tabledata, 0, vaccess)
predef, vdef = self.backend.definitions.get(mt, tabledata, 0, vaccess)
if predef:
pre_definitions[str(predef[0].symbol.name)] = predef

Expand Down
15 changes: 6 additions & 9 deletions ffcx/codegeneration/integral_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ def generate_varying_partition(self, quadrature_rule):

# Get annotated graph of factorisation
F = self.ir.integrand[quadrature_rule]["factorization"]

arraysymbol = L.Symbol(f"sv_{quadrature_rule.id()}", dtype=L.DataType.SCALAR)
pre_definitions, parts = self.generate_partition(arraysymbol, F, "varying", quadrature_rule)
parts = L.commented_code_list(parts, f"Varying computations for quadrature rule {quadrature_rule.id()}")
Expand All @@ -291,27 +290,25 @@ def generate_partition(self, symbol, F, mode, quadrature_rule):
v = attr['expression']
mt = attr.get('mt')

# Generate code only if the expression is not already in
# cache
# Generate code only if the expression is not already in cache
if not self.get_var(quadrature_rule, v):
if v._ufl_is_literal_:
vaccess = L.ufl_to_lnodes(v)
elif mt is not None:
# All finite element based terminals have table
# data, as well as some, but not all, of the
# symbolic geometric terminals
elif mt:
assert mt is not None
tabledata = attr.get('tr')

# Backend specific modified terminal translation
vaccess = self.backend.access.get(mt.terminal, mt, tabledata, quadrature_rule)
predef, vdef = self.backend.definitions.get(mt.terminal, mt, tabledata, quadrature_rule, vaccess)
vaccess = self.backend.access.get(mt, tabledata, quadrature_rule)
predef, vdef = self.backend.definitions.get(mt, tabledata, quadrature_rule, vaccess)
if predef:
access = predef[0].symbol.name
pre_definitions[str(access)] = predef

# Store definitions of terminals in list
assert isinstance(vdef, list)
definitions[str(vaccess)] = vdef

else:
# Get previously visited operands
vops = [self.get_var(quadrature_rule, op) for op in v.ufl_operands]
Expand Down

0 comments on commit aa636db

Please sign in to comment.