Skip to content

Commit aa636db

Browse files
committed
simplify definitions
1 parent 3e5442f commit aa636db

File tree

3 files changed

+22
-21
lines changed

3 files changed

+22
-21
lines changed

ffcx/codegeneration/definitions.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,21 @@ def __init__(self, ir, symbols, options):
7474
ufl.geometry.CellOrientation,
7575
ufl.geometry.FacetOrientation}
7676

77-
def get(self, t, mt, tabledata, quadrature_rule, access):
77+
def get(self, mt, tabledata, quadrature_rule, access):
7878
# Call appropriate handler, depending on the type of t
79-
ttype = type(t)
79+
80+
terminal = mt.terminal
81+
ttype = type(terminal)
8082

8183
if ttype in self.do_nothing_set:
8284
return [], []
83-
else:
85+
elif ttype in self.call_lookup:
8486
handler = self.call_lookup[ttype]
85-
return handler(t, mt, tabledata, quadrature_rule, access)
87+
return handler(mt, tabledata, quadrature_rule, access)
88+
else:
89+
raise NotImplementedError(f"No handler for terminal type: {ttype}")
8690

87-
def coefficient(self, t, mt, tabledata, quadrature_rule, access):
91+
def coefficient(self, mt, tabledata, quadrature_rule, access):
8892
"""Return definition code for coefficients."""
8993
# For applying tensor product to coefficients, we need to know if the coefficient
9094
# has a tensor factorisation and if the quadrature rule has a tensor factorisation.
@@ -142,7 +146,7 @@ def coefficient(self, t, mt, tabledata, quadrature_rule, access):
142146

143147
return pre_code, code
144148

145-
def _define_coordinate_dofs_lincomb(self, e, mt, tabledata, quadrature_rule, access):
149+
def _define_coordinate_dofs_lincomb(self, mt, tabledata, quadrature_rule, access):
146150
"""Define x or J as a linear combination of coordinate dofs with given table data."""
147151
# Get properties of domain
148152
domain = ufl.domain.extract_unique_domain(mt.terminal)
@@ -182,7 +186,7 @@ def _define_coordinate_dofs_lincomb(self, e, mt, tabledata, quadrature_rule, acc
182186

183187
return [], code
184188

185-
def spatial_coordinate(self, e, mt, tabledata, quadrature_rule, access):
189+
def spatial_coordinate(self, mt, tabledata, quadrature_rule, access):
186190
"""Return definition code for the physical spatial coordinates.
187191
188192
If physical coordinates are given:
@@ -200,8 +204,8 @@ def spatial_coordinate(self, e, mt, tabledata, quadrature_rule, access):
200204
logging.exception("FIXME: Jacobian in custom integrals is not implemented.")
201205
return []
202206
else:
203-
return self._define_coordinate_dofs_lincomb(e, mt, tabledata, quadrature_rule, access)
207+
return self._define_coordinate_dofs_lincomb(mt, tabledata, quadrature_rule, access)
204208

205-
def jacobian(self, e, mt, tabledata, quadrature_rule, access):
209+
def jacobian(self, mt, tabledata, quadrature_rule, access):
206210
"""Return definition code for the Jacobian of x(X)."""
207-
return self._define_coordinate_dofs_lincomb(e, mt, tabledata, quadrature_rule, access)
211+
return self._define_coordinate_dofs_lincomb(mt, tabledata, quadrature_rule, access)

ffcx/codegeneration/expression_generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,9 @@ def generate_partition(self, symbol, F, mode):
323323
tabledata = attr.get('tr')
324324

325325
# Backend specific modified terminal translation
326-
vaccess = self.backend.access.get(mt.terminal, mt, tabledata, 0)
326+
vaccess = self.backend.access.get(mt, tabledata, 0)
327327

328-
predef, vdef = self.backend.definitions.get(mt.terminal, mt, tabledata, 0, vaccess)
328+
predef, vdef = self.backend.definitions.get(mt, tabledata, 0, vaccess)
329329
if predef:
330330
pre_definitions[str(predef[0].symbol.name)] = predef
331331

ffcx/codegeneration/integral_generator.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ def generate_varying_partition(self, quadrature_rule):
270270

271271
# Get annotated graph of factorisation
272272
F = self.ir.integrand[quadrature_rule]["factorization"]
273-
274273
arraysymbol = L.Symbol(f"sv_{quadrature_rule.id()}", dtype=L.DataType.SCALAR)
275274
pre_definitions, parts = self.generate_partition(arraysymbol, F, "varying", quadrature_rule)
276275
parts = L.commented_code_list(parts, f"Varying computations for quadrature rule {quadrature_rule.id()}")
@@ -291,27 +290,25 @@ def generate_partition(self, symbol, F, mode, quadrature_rule):
291290
v = attr['expression']
292291
mt = attr.get('mt')
293292

294-
# Generate code only if the expression is not already in
295-
# cache
293+
# Generate code only if the expression is not already in cache
296294
if not self.get_var(quadrature_rule, v):
297295
if v._ufl_is_literal_:
298296
vaccess = L.ufl_to_lnodes(v)
299-
elif mt is not None:
300-
# All finite element based terminals have table
301-
# data, as well as some, but not all, of the
302-
# symbolic geometric terminals
297+
elif mt:
298+
assert mt is not None
303299
tabledata = attr.get('tr')
304300

305301
# Backend specific modified terminal translation
306-
vaccess = self.backend.access.get(mt.terminal, mt, tabledata, quadrature_rule)
307-
predef, vdef = self.backend.definitions.get(mt.terminal, mt, tabledata, quadrature_rule, vaccess)
302+
vaccess = self.backend.access.get(mt, tabledata, quadrature_rule)
303+
predef, vdef = self.backend.definitions.get(mt, tabledata, quadrature_rule, vaccess)
308304
if predef:
309305
access = predef[0].symbol.name
310306
pre_definitions[str(access)] = predef
311307

312308
# Store definitions of terminals in list
313309
assert isinstance(vdef, list)
314310
definitions[str(vaccess)] = vdef
311+
315312
else:
316313
# Get previously visited operands
317314
vops = [self.get_var(quadrature_rule, op) for op in v.ufl_operands]

0 commit comments

Comments
 (0)