Skip to content

Commit d6edf27

Browse files
committed
Add parameter concrete to _transform_pattern
1 parent d4abb24 commit d6edf27

File tree

1 file changed

+35
-24
lines changed

1 file changed

+35
-24
lines changed

pyk/src/pyk/klean/k2lean4.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -357,17 +357,28 @@ def _def_binders(self, defs: Mapping[str, Pattern]) -> list[Binder]:
357357
for ident, pattern in defs.items()
358358
]
359359

360-
def _transform_pattern(self, pattern: Pattern) -> Term:
360+
def _transform_pattern(self, pattern: Pattern, *, concrete: bool = False) -> Term:
361361
match pattern:
362362
case EVar(name):
363363
return self._transform_evar(name)
364364
case DV(SortApp(sort), String(value)):
365365
return self._transform_dv(sort, value)
366366
case App(symbol, sorts, args):
367-
return self._transform_app(symbol, sorts, args)
367+
return self._transform_app(symbol, sorts, args, concrete=concrete)
368368
case _:
369369
raise ValueError(f'Unsupported pattern: {pattern.text}')
370370

371+
def _transform_arg(self, pattern: Pattern, *, concrete: bool = False) -> Term:
372+
term = self._transform_pattern(pattern, concrete=concrete)
373+
374+
if not isinstance(pattern, App):
375+
return term
376+
377+
if pattern.symbol in self.structure_symbols:
378+
return term
379+
380+
return Term(f'({term})')
381+
371382
def _transform_evar(self, name: str) -> Term:
372383
return Term(_var_ident(name))
373384

@@ -418,47 +429,47 @@ def encode(c: str) -> str:
418429
encoded = ''.join(encode(c) for c in value)
419430
return Term(f'"{encoded}"')
420431

421-
def _transform_app(self, symbol: str, sorts: tuple[Sort, ...], args: tuple[Pattern, ...]) -> Term:
432+
def _transform_app(
433+
self,
434+
symbol: str,
435+
sorts: tuple[Sort, ...],
436+
args: tuple[Pattern, ...],
437+
*,
438+
concrete: bool,
439+
) -> Term:
422440
if symbol == 'inj':
423-
return self._transform_inj_app(sorts, args)
441+
return self._transform_inj_app(sorts, args, concrete=concrete)
424442

425443
if symbol in self.structure_symbols:
426444
fields = self.structures[self.structure_symbols[symbol]]
427-
return self._transform_structure_app(fields, args)
445+
return self._transform_structure_app(fields, args, concrete=concrete)
428446

429447
decl = self.defn.symbols[symbol]
430448
sort = decl.sort.name if isinstance(decl.sort, SortApp) else None
431-
return self._transform_basic_app(sort, symbol, args)
432-
433-
def _transform_arg(self, pattern: Pattern) -> Term:
434-
term = self._transform_pattern(pattern)
449+
return self._transform_basic_app(sort, symbol, args, concrete=concrete)
435450

436-
if not isinstance(pattern, App):
437-
return term
438-
439-
if pattern.symbol in self.structure_symbols:
440-
return term
441-
442-
return Term(f'({term})')
443-
444-
def _transform_inj_app(self, sorts: tuple[Sort, ...], args: tuple[Pattern, ...]) -> Term:
451+
def _transform_inj_app(self, sorts: tuple[Sort, ...], args: tuple[Pattern, ...], *, concrete: bool) -> Term:
445452
_from_sort, _to_sort = sorts
446453
assert isinstance(_from_sort, SortApp)
447454
assert isinstance(_to_sort, SortApp)
448455
from_str = _from_sort.name
449456
to_str = _to_sort.name
450457
(arg,) = args
451-
term = self._transform_arg(arg)
452-
return Term(f'(@inj {from_str} {to_str}) {term}')
458+
term = self._transform_arg(arg, concrete=concrete)
459+
if concrete:
460+
return Term(f'{to_str}.inj_{from_str} {term}')
461+
else:
462+
return Term(f'(@inj {from_str} {to_str}) {term}')
453463

454-
def _transform_structure_app(self, fields: Iterable[Field], args: Iterable[Pattern]) -> Term:
464+
def _transform_structure_app(self, fields: Iterable[Field], args: Iterable[Pattern], *, concrete: bool) -> Term:
455465
fields_str = ', '.join(
456-
f'{field.name} := {self._transform_pattern(arg)}' for field, arg in zip(fields, args, strict=True)
466+
f'{field.name} := {self._transform_pattern(arg, concrete=concrete)}'
467+
for field, arg in zip(fields, args, strict=True)
457468
)
458469
lbrace, rbrace = ['{', '}']
459470
return Term(f'{lbrace} {fields_str} {rbrace}')
460471

461-
def _transform_basic_app(self, sort: str | None, symbol: str, args: Iterable[Pattern]) -> Term:
472+
def _transform_basic_app(self, sort: str | None, symbol: str, args: Iterable[Pattern], *, concrete: bool) -> Term:
462473
chunks = []
463474

464475
ident: str
@@ -469,7 +480,7 @@ def _transform_basic_app(self, sort: str | None, symbol: str, args: Iterable[Pat
469480
ident = _symbol_ident(symbol)
470481

471482
chunks.append(ident)
472-
chunks.extend(str(self._transform_arg(arg)) for arg in args)
483+
chunks.extend(str(self._transform_arg(arg, concrete=concrete)) for arg in args)
473484
return Term(' '.join(chunks))
474485

475486

0 commit comments

Comments
 (0)