Skip to content

Commit 23c375b

Browse files
committed
Generate function definitions
1 parent 0ec659f commit 23c375b

File tree

3 files changed

+216
-13
lines changed

3 files changed

+216
-13
lines changed

pyk/src/pyk/klean/generate.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ def generate(
2828
k2lean4 = K2Lean4(defn)
2929
genmodel = {
3030
'Sorts': (k2lean4.sort_module, ['Prelude']),
31-
'Func': (k2lean4.func_module, ['Sorts']),
3231
'Inj': (k2lean4.inj_module, ['Sorts']),
33-
'Rewrite': (k2lean4.rewrite_module, ['Func', 'Inj']),
32+
'Func': (k2lean4.func_module, ['Inj']),
33+
'Rewrite': (k2lean4.rewrite_module, ['Func']),
3434
}
3535

3636
modules = _gen_modules(context['library_name'], genmodel)

pyk/src/pyk/klean/k2lean4.py

+209-11
Original file line numberDiff line numberDiff line change
@@ -11,40 +11,50 @@
1111
from ..konvert import unmunge
1212
from ..kore.internal import CollectionKind
1313
from ..kore.kompiled import KoreSymbolTable
14-
from ..kore.manip import elim_aliases, free_occs
14+
from ..kore.manip import collect_symbols, elim_aliases, free_occs
1515
from ..kore.syntax import DV, And, App, EVar, SortApp, String, Top
1616
from ..utils import FrozenDict, POSet
1717
from .model import (
1818
Alt,
1919
AltsFieldVal,
20+
AltsVal,
2021
Axiom,
2122
Ctor,
23+
Definition,
2224
ExplBinder,
2325
ImplBinder,
2426
Inductive,
2527
Instance,
2628
InstField,
29+
Modifiers,
2730
Module,
2831
Mutual,
2932
Signature,
3033
SimpleFieldVal,
34+
SimpleVal,
3135
StructCtor,
3236
Structure,
3337
StructVal,
3438
Term,
3539
)
3640

3741
if TYPE_CHECKING:
38-
from collections.abc import Iterable, Iterator, Mapping
42+
from collections.abc import Collection, Iterable, Iterator, Mapping
3943
from typing import Final
4044

4145
from ..kore.internal import KoreDefn
42-
from ..kore.rule import RewriteRule, Rule
46+
from ..kore.rule import FunctionRule, RewriteRule, Rule
4347
from ..kore.syntax import Pattern, Sort, SymbolDecl
44-
from .model import Binder, Command, Declaration, FieldVal
48+
from .model import Binder, Command, Declaration, DeclVal, FieldVal
4549

4650

4751
_PRELUDE_SORTS: Final = {'SortBool', 'SortBytes', 'SortId', 'SortInt', 'SortString', 'SortStringBuffer'}
52+
_PRELUDE_FUNCS: Final = {
53+
"Lbl'UndsPlus'Int'Unds'",
54+
"Lbl'Unds'-Int'Unds'",
55+
"Lbl'UndsStar'Int'Unds'",
56+
"Lbl'Unds-LT-Eqls'Int'Unds'",
57+
}
4858

4959

5060
class Field(NamedTuple):
@@ -96,6 +106,51 @@ def fields_of(sort: str) -> tuple[Field, ...] | None:
96106

97107
return FrozenDict((sort, fields) for sort in self.defn.sorts if (fields := fields_of(sort)) is not None)
98108

109+
@cached_property
110+
def _func_rules_by_uid(self) -> FrozenDict[str, FunctionRule]:
111+
return FrozenDict((rule.uid, rule) for rules in self.defn.functions.values() for rule in rules)
112+
113+
@cached_property
114+
def _func_deps(self) -> FrozenDict[str, frozenset[str]]:
115+
deps: list[tuple[str, str]] = []
116+
elems: set[str] = set()
117+
for func, rules in self.defn.functions.items():
118+
elems.add(func)
119+
for rule in rules:
120+
elems.add(rule.uid)
121+
# A function depends on the rules that define it
122+
deps.append((func, rule.uid))
123+
# A rule depends on all the functions that it applies
124+
symbols = collect_symbols(
125+
And(SortApp('SortFoo'), (rule.req or Top(SortApp('SortFoo')), *rule.lhs.args, rule.rhs))
126+
) # Collection functions like `_List_` can occur on the LHS
127+
deps.extend((rule.uid, symbol) for symbol in symbols if symbol in self.defn.functions)
128+
129+
closed = POSet(deps).image
130+
return FrozenDict((elem, frozenset(closed.get(elem, []))) for elem in elems)
131+
132+
@cached_property
133+
def _func_sccs(self) -> tuple[tuple[str, ...], ...]:
134+
sccs = _ordered_sccs(self._func_deps)
135+
return tuple(tuple(scc) for scc in sccs)
136+
137+
@cached_property
138+
def noncomputable(self) -> frozenset[str]:
139+
res: set[str] = set()
140+
141+
for scc in self._func_sccs:
142+
assert scc
143+
elem = scc[0]
144+
145+
if elem in self.defn.functions and elem not in _PRELUDE_FUNCS and not self.defn.functions[elem]:
146+
assert len(scc) == 1
147+
res.add(elem)
148+
149+
if any(dep in res for dep in self._func_deps[elem]):
150+
res.update(scc)
151+
152+
return frozenset(res)
153+
99154
@staticmethod
100155
def _is_cell(sort: str) -> bool:
101156
return sort.endswith('Cell')
@@ -247,12 +302,154 @@ def inj(subsort: str, supersort: str, x: str) -> Term:
247302
return res
248303

249304
def func_module(self) -> Module:
250-
commands = [self._func_axiom(func) for func in self.defn.functions]
251-
return Module(commands=commands)
305+
sccs = self._func_sccs
306+
return Module(commands=tuple(command for elems in sccs if (command := self._func_block(elems))))
307+
308+
def _func_block(self, elems: Iterable[str]) -> Command | None:
309+
assert elems
310+
elems = [elem for elem in elems if elem not in _PRELUDE_FUNCS]
311+
312+
if not elems:
313+
return None
252314

253-
def _func_axiom(self, func: str) -> Axiom:
315+
if len(elems) == 1:
316+
(elem,) = elems
317+
return self._func_command(elem)
318+
319+
return Mutual(commands=tuple(self._func_command(elem) for elem in elems))
320+
321+
def _func_command(self, elem: str) -> Command:
322+
if elem in self.defn.functions:
323+
decl = self.defn.symbols[elem]
324+
rules = self.defn.functions[elem]
325+
if rules:
326+
return self._func_def(decl, rules)
327+
return self._func_axiom(decl)
328+
rule = self._func_rules_by_uid[elem]
329+
return self._func_rule_def(rule)
330+
331+
def _func_def(self, decl: SymbolDecl, rules: tuple[FunctionRule, ...]) -> Definition:
332+
def sort_rules_by_priority(rules: tuple[FunctionRule, ...]) -> list[str]:
333+
grouped: dict[int, list[str]] = {}
334+
for rule in rules:
335+
grouped.setdefault(rule.priority, []).append(_rule_name(rule))
336+
groups = [sorted(grouped[priority]) for priority in sorted(grouped)]
337+
return [rule for group in groups for rule in group]
338+
339+
assert rules
340+
341+
sorted_rules = sort_rules_by_priority(rules)
342+
params = [f'x{i}' for i in range(len(decl.param_sorts))]
343+
arg_str = ' ' + ' '.join(params) if params else ''
344+
term: Term
345+
if len(sorted_rules) == 1:
346+
rule_str = sorted_rules[0]
347+
term = Term(f'{rule_str}{arg_str}')
348+
else:
349+
rules_str = f'[{", ".join(sorted_rules)}]'
350+
term = Term(f'{rules_str}.findSome? (·{arg_str})')
351+
352+
val = SimpleVal(term)
353+
func = decl.symbol.name
254354
ident = _symbol_ident(func)
255-
decl = self.defn.symbols[func]
355+
signature = self._func_signature(decl)
356+
modifiers = Modifiers(noncomputable=True) if func in self.noncomputable else None
357+
358+
return Definition(ident, val, signature, modifiers=modifiers)
359+
360+
def _func_rule_def(self, rule: FunctionRule) -> Definition:
361+
decl = self.defn.symbols[rule.lhs.symbol]
362+
sort_params = [var.name for var in decl.symbol.vars]
363+
param_sorts = [sort.name for sort in decl.param_sorts]
364+
sort = decl.sort.name
365+
366+
ident = _rule_name(rule)
367+
binders = (ImplBinder(sort_params, Term('Type')),) if sort_params else ()
368+
ty = Term(' → '.join(param_sorts + [f'Option {sort}']))
369+
modifiers = Modifiers(noncomputable=True) if rule.uid in self.noncomputable else None
370+
signature = Signature(binders, ty)
371+
372+
req, lhs, rhs, defs = self._extract_func_rule(rule)
373+
val = self._func_rule_val(lhs.args, req, rhs, defs)
374+
375+
return Definition(ident, val, signature, modifiers=modifiers)
376+
377+
def _extract_func_rule(self, rule: FunctionRule) -> tuple[Pattern, App, Pattern, dict[str, Pattern]]:
378+
req = rule.req if rule.req else Top(SortApp('Foo'))
379+
380+
pattern = elim_aliases(And(SortApp('Foo'), (req, rule.lhs, rule.rhs)))
381+
assert isinstance(pattern, And)
382+
req, lhs, rhs = pattern.ops
383+
assert isinstance(lhs, App)
384+
385+
free = (f"Var'Unds'Val{i}" for i in count())
386+
pattern, defs = self._elim_fun_apps(And(SortApp('Foo'), (req, rhs)), free)
387+
assert isinstance(pattern, And)
388+
req, rhs = pattern.ops
389+
390+
return req, lhs, rhs, defs
391+
392+
def _func_rule_val(
393+
self,
394+
args: tuple[Pattern, ...],
395+
req: Pattern,
396+
rhs: Pattern,
397+
defs: dict[str, Pattern],
398+
) -> DeclVal:
399+
term = self._func_rule_term(req, rhs, defs)
400+
401+
if not args:
402+
return SimpleVal(term)
403+
404+
alts: list[Alt] = []
405+
406+
match_alt = Alt(tuple(self._transform_pattern(arg, concrete=True) for arg in args), term)
407+
alts.append(match_alt)
408+
409+
if not all(self._is_exhaustive(arg) for arg in args):
410+
nomatch_alt = Alt((Term('_'),) * len(args), Term('none'))
411+
alts.append(nomatch_alt)
412+
413+
return AltsVal(alts)
414+
415+
def _func_rule_term(self, req: Pattern, rhs: Pattern, defs: dict[str, Pattern]) -> Term:
416+
if not defs and isinstance(req, Top):
417+
return Term(f'some {self._transform_arg(rhs)}')
418+
419+
seq_strs: list[str] = []
420+
seq_strs.extend(f'let {var} <- {self._transform_pattern(pattern)}' for var, pattern in defs.items())
421+
if not isinstance(req, Top):
422+
seq_strs.append(f'guard {self._transform_arg(req)}')
423+
seq_strs.append(f'return {self._transform_arg(rhs)}')
424+
do_str = '\n'.join(' ' + seq_str for seq_str in seq_strs)
425+
return Term(f'do\n{do_str}')
426+
427+
def _is_exhaustive(self, pattern: Pattern) -> bool:
428+
match pattern:
429+
case DV():
430+
return False
431+
case EVar():
432+
return True
433+
case App(symbol, _, args) as app:
434+
if symbol in self.defn.functions:
435+
# Collection function
436+
return False
437+
438+
_sort = self.symbol_table.infer_sort(app)
439+
assert isinstance(_sort, SortApp)
440+
sort = _sort.name
441+
n_ctors = len(self.defn.constructors.get(sort, ())) + len(self.defn.subsorts.get(sort, ()))
442+
assert n_ctors
443+
return n_ctors == 1 and all(self._is_exhaustive(arg) for arg in args)
444+
case _:
445+
raise AssertionError()
446+
447+
def _func_axiom(self, decl: SymbolDecl) -> Axiom:
448+
ident = _symbol_ident(decl.symbol.name)
449+
signature = self._func_signature(decl)
450+
return Axiom(ident, signature)
451+
452+
def _func_signature(self, decl: SymbolDecl) -> Signature:
256453
sort_params = [var.name for var in decl.symbol.vars]
257454
param_sorts = [sort.name for sort in decl.param_sorts]
258455
sort = decl.sort.name
@@ -261,7 +458,8 @@ def _func_axiom(self, func: str) -> Axiom:
261458
if sort_params:
262459
binders.append(ImplBinder(sort_params, Term('Type')))
263460
binders.extend(ExplBinder((f'x{i}',), Term(sort)) for i, sort in enumerate(param_sorts))
264-
return Axiom(ident, Signature(binders, Term(f'Option {sort}')))
461+
462+
return Signature(binders, Term(f'Option {sort}'))
265463

266464
def rewrite_module(self) -> Module:
267465
commands = (self._rewrite_inductive(),)
@@ -553,7 +751,7 @@ def _sort_dependencies(defn: KoreDefn) -> dict[str, set[str]]:
553751
} # Ensure that sorts without dependencies are also represented
554752

555753

556-
def _ordered_sccs(deps: dict[str, set[str]]) -> list[list[str]]:
754+
def _ordered_sccs(deps: Mapping[str, Collection[str]]) -> list[list[str]]:
557755
sccs = _sccs(deps)
558756

559757
elems_by_scc: dict[int, set[str]] = {}
@@ -576,7 +774,7 @@ def _ordered_sccs(deps: dict[str, set[str]]) -> list[list[str]]:
576774

577775

578776
# TODO Implement a more efficient algorithm, e.g. Tarjan's algorithm
579-
def _sccs(deps: dict[str, set[str]]) -> dict[str, int]:
777+
def _sccs(deps: Mapping[str, Iterable[str]]) -> dict[str, int]:
580778
res: dict[str, int] = {}
581779

582780
scc = count()

pyk/src/pyk/klean/template/{{ cookiecutter.package_name }}/{{ cookiecutter.library_name }}/Prelude.lean

+5
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,8 @@ class Inj (From To : Type) : Type where
210210
inj (x : From) : To
211211

212212
def inj {From To : Type} [inst : Inj From To] := inst.inj
213+
214+
def «_+Int_» (x0 : SortInt) (x1 : SortInt) : Option SortInt := some (x0 + x1)
215+
def «_-Int_» (x0 : SortInt) (x1 : SortInt) : Option SortInt := some (x0 - x1)
216+
def «_*Int_» (x0 : SortInt) (x1 : SortInt) : Option SortInt := some (x0 * x1)
217+
def «_<=Int_» (x0 : SortInt) (x1 : SortInt) : Option SortBool := some (x0 <= x1)

0 commit comments

Comments
 (0)