11
11
from ..konvert import unmunge
12
12
from ..kore .internal import CollectionKind
13
13
from ..kore .kompiled import KoreSymbolTable
14
- from ..kore .manip import elim_aliases , free_occs
14
+ from ..kore .manip import collect_symbols , elim_aliases , free_occs
15
15
from ..kore .syntax import DV , And , App , EVar , SortApp , String , Top
16
16
from ..utils import FrozenDict , POSet
17
17
from .model import (
18
18
Alt ,
19
19
AltsFieldVal ,
20
+ AltsVal ,
20
21
Axiom ,
21
22
Ctor ,
23
+ Definition ,
22
24
ExplBinder ,
23
25
ImplBinder ,
24
26
Inductive ,
25
27
Instance ,
26
28
InstField ,
29
+ Modifiers ,
27
30
Module ,
28
31
Mutual ,
29
32
Signature ,
30
33
SimpleFieldVal ,
34
+ SimpleVal ,
31
35
StructCtor ,
32
36
Structure ,
33
37
StructVal ,
34
38
Term ,
35
39
)
36
40
37
41
if TYPE_CHECKING :
38
- from collections .abc import Iterable , Iterator , Mapping
42
+ from collections .abc import Collection , Iterable , Iterator , Mapping
39
43
from typing import Final
40
44
41
45
from ..kore .internal import KoreDefn
42
- from ..kore .rule import RewriteRule , Rule
46
+ from ..kore .rule import FunctionRule , RewriteRule , Rule
43
47
from ..kore .syntax import Pattern , Sort , SymbolDecl
44
- from .model import Binder , Command , Declaration , FieldVal
48
+ from .model import Binder , Command , Declaration , DeclVal , FieldVal
45
49
46
50
47
51
_PRELUDE_SORTS : Final = {'SortBool' , 'SortBytes' , 'SortId' , 'SortInt' , 'SortString' , 'SortStringBuffer' }
52
+ _PRELUDE_FUNCS : Final = {
53
+ "Lbl'UndsPlus'Int'Unds'" ,
54
+ "Lbl'Unds'-Int'Unds'" ,
55
+ "Lbl'UndsStar'Int'Unds'" ,
56
+ "Lbl'Unds-LT-Eqls'Int'Unds'" ,
57
+ }
48
58
49
59
50
60
class Field (NamedTuple ):
@@ -96,6 +106,51 @@ def fields_of(sort: str) -> tuple[Field, ...] | None:
96
106
97
107
return FrozenDict ((sort , fields ) for sort in self .defn .sorts if (fields := fields_of (sort )) is not None )
98
108
109
+ @cached_property
110
+ def _func_rules_by_uid (self ) -> FrozenDict [str , FunctionRule ]:
111
+ return FrozenDict ((rule .uid , rule ) for rules in self .defn .functions .values () for rule in rules )
112
+
113
+ @cached_property
114
+ def _func_deps (self ) -> FrozenDict [str , frozenset [str ]]:
115
+ deps : list [tuple [str , str ]] = []
116
+ elems : set [str ] = set ()
117
+ for func , rules in self .defn .functions .items ():
118
+ elems .add (func )
119
+ for rule in rules :
120
+ elems .add (rule .uid )
121
+ # A function depends on the rules that define it
122
+ deps .append ((func , rule .uid ))
123
+ # A rule depends on all the functions that it applies
124
+ symbols = collect_symbols (
125
+ And (SortApp ('SortFoo' ), (rule .req or Top (SortApp ('SortFoo' )), * rule .lhs .args , rule .rhs ))
126
+ ) # Collection functions like `_List_` can occur on the LHS
127
+ deps .extend ((rule .uid , symbol ) for symbol in symbols if symbol in self .defn .functions )
128
+
129
+ closed = POSet (deps ).image
130
+ return FrozenDict ((elem , frozenset (closed .get (elem , []))) for elem in elems )
131
+
132
+ @cached_property
133
+ def _func_sccs (self ) -> tuple [tuple [str , ...], ...]:
134
+ sccs = _ordered_sccs (self ._func_deps )
135
+ return tuple (tuple (scc ) for scc in sccs )
136
+
137
+ @cached_property
138
+ def noncomputable (self ) -> frozenset [str ]:
139
+ res : set [str ] = set ()
140
+
141
+ for scc in self ._func_sccs :
142
+ assert scc
143
+ elem = scc [0 ]
144
+
145
+ if elem in self .defn .functions and elem not in _PRELUDE_FUNCS and not self .defn .functions [elem ]:
146
+ assert len (scc ) == 1
147
+ res .add (elem )
148
+
149
+ if any (dep in res for dep in self ._func_deps [elem ]):
150
+ res .update (scc )
151
+
152
+ return frozenset (res )
153
+
99
154
@staticmethod
100
155
def _is_cell (sort : str ) -> bool :
101
156
return sort .endswith ('Cell' )
@@ -247,12 +302,154 @@ def inj(subsort: str, supersort: str, x: str) -> Term:
247
302
return res
248
303
249
304
def func_module (self ) -> Module :
250
- commands = [self ._func_axiom (func ) for func in self .defn .functions ]
251
- return Module (commands = commands )
305
+ sccs = self ._func_sccs
306
+ return Module (commands = tuple (command for elems in sccs if (command := self ._func_block (elems ))))
307
+
308
+ def _func_block (self , elems : Iterable [str ]) -> Command | None :
309
+ assert elems
310
+ elems = [elem for elem in elems if elem not in _PRELUDE_FUNCS ]
311
+
312
+ if not elems :
313
+ return None
252
314
253
- def _func_axiom (self , func : str ) -> Axiom :
315
+ if len (elems ) == 1 :
316
+ (elem ,) = elems
317
+ return self ._func_command (elem )
318
+
319
+ return Mutual (commands = tuple (self ._func_command (elem ) for elem in elems ))
320
+
321
+ def _func_command (self , elem : str ) -> Command :
322
+ if elem in self .defn .functions :
323
+ decl = self .defn .symbols [elem ]
324
+ rules = self .defn .functions [elem ]
325
+ if rules :
326
+ return self ._func_def (decl , rules )
327
+ return self ._func_axiom (decl )
328
+ rule = self ._func_rules_by_uid [elem ]
329
+ return self ._func_rule_def (rule )
330
+
331
+ def _func_def (self , decl : SymbolDecl , rules : tuple [FunctionRule , ...]) -> Definition :
332
+ def sort_rules_by_priority (rules : tuple [FunctionRule , ...]) -> list [str ]:
333
+ grouped : dict [int , list [str ]] = {}
334
+ for rule in rules :
335
+ grouped .setdefault (rule .priority , []).append (_rule_name (rule ))
336
+ groups = [sorted (grouped [priority ]) for priority in sorted (grouped )]
337
+ return [rule for group in groups for rule in group ]
338
+
339
+ assert rules
340
+
341
+ sorted_rules = sort_rules_by_priority (rules )
342
+ params = [f'x{ i } ' for i in range (len (decl .param_sorts ))]
343
+ arg_str = ' ' + ' ' .join (params ) if params else ''
344
+ term : Term
345
+ if len (sorted_rules ) == 1 :
346
+ rule_str = sorted_rules [0 ]
347
+ term = Term (f'{ rule_str } { arg_str } ' )
348
+ else :
349
+ rules_str = f'[{ ", " .join (sorted_rules )} ]'
350
+ term = Term (f'{ rules_str } .findSome? (·{ arg_str } )' )
351
+
352
+ val = SimpleVal (term )
353
+ func = decl .symbol .name
254
354
ident = _symbol_ident (func )
255
- decl = self .defn .symbols [func ]
355
+ signature = self ._func_signature (decl )
356
+ modifiers = Modifiers (noncomputable = True ) if func in self .noncomputable else None
357
+
358
+ return Definition (ident , val , signature , modifiers = modifiers )
359
+
360
+ def _func_rule_def (self , rule : FunctionRule ) -> Definition :
361
+ decl = self .defn .symbols [rule .lhs .symbol ]
362
+ sort_params = [var .name for var in decl .symbol .vars ]
363
+ param_sorts = [sort .name for sort in decl .param_sorts ]
364
+ sort = decl .sort .name
365
+
366
+ ident = _rule_name (rule )
367
+ binders = (ImplBinder (sort_params , Term ('Type' )),) if sort_params else ()
368
+ ty = Term (' → ' .join (param_sorts + [f'Option { sort } ' ]))
369
+ modifiers = Modifiers (noncomputable = True ) if rule .uid in self .noncomputable else None
370
+ signature = Signature (binders , ty )
371
+
372
+ req , lhs , rhs , defs = self ._extract_func_rule (rule )
373
+ val = self ._func_rule_val (lhs .args , req , rhs , defs )
374
+
375
+ return Definition (ident , val , signature , modifiers = modifiers )
376
+
377
+ def _extract_func_rule (self , rule : FunctionRule ) -> tuple [Pattern , App , Pattern , dict [str , Pattern ]]:
378
+ req = rule .req if rule .req else Top (SortApp ('Foo' ))
379
+
380
+ pattern = elim_aliases (And (SortApp ('Foo' ), (req , rule .lhs , rule .rhs )))
381
+ assert isinstance (pattern , And )
382
+ req , lhs , rhs = pattern .ops
383
+ assert isinstance (lhs , App )
384
+
385
+ free = (f"Var'Unds'Val{ i } " for i in count ())
386
+ pattern , defs = self ._elim_fun_apps (And (SortApp ('Foo' ), (req , rhs )), free )
387
+ assert isinstance (pattern , And )
388
+ req , rhs = pattern .ops
389
+
390
+ return req , lhs , rhs , defs
391
+
392
+ def _func_rule_val (
393
+ self ,
394
+ args : tuple [Pattern , ...],
395
+ req : Pattern ,
396
+ rhs : Pattern ,
397
+ defs : dict [str , Pattern ],
398
+ ) -> DeclVal :
399
+ term = self ._func_rule_term (req , rhs , defs )
400
+
401
+ if not args :
402
+ return SimpleVal (term )
403
+
404
+ alts : list [Alt ] = []
405
+
406
+ match_alt = Alt (tuple (self ._transform_pattern (arg , concrete = True ) for arg in args ), term )
407
+ alts .append (match_alt )
408
+
409
+ if not all (self ._is_exhaustive (arg ) for arg in args ):
410
+ nomatch_alt = Alt ((Term ('_' ),) * len (args ), Term ('none' ))
411
+ alts .append (nomatch_alt )
412
+
413
+ return AltsVal (alts )
414
+
415
+ def _func_rule_term (self , req : Pattern , rhs : Pattern , defs : dict [str , Pattern ]) -> Term :
416
+ if not defs and isinstance (req , Top ):
417
+ return Term (f'some { self ._transform_arg (rhs )} ' )
418
+
419
+ seq_strs : list [str ] = []
420
+ seq_strs .extend (f'let { var } <- { self ._transform_pattern (pattern )} ' for var , pattern in defs .items ())
421
+ if not isinstance (req , Top ):
422
+ seq_strs .append (f'guard { self ._transform_arg (req )} ' )
423
+ seq_strs .append (f'return { self ._transform_arg (rhs )} ' )
424
+ do_str = '\n ' .join (' ' + seq_str for seq_str in seq_strs )
425
+ return Term (f'do\n { do_str } ' )
426
+
427
+ def _is_exhaustive (self , pattern : Pattern ) -> bool :
428
+ match pattern :
429
+ case DV ():
430
+ return False
431
+ case EVar ():
432
+ return True
433
+ case App (symbol , _, args ) as app :
434
+ if symbol in self .defn .functions :
435
+ # Collection function
436
+ return False
437
+
438
+ _sort = self .symbol_table .infer_sort (app )
439
+ assert isinstance (_sort , SortApp )
440
+ sort = _sort .name
441
+ n_ctors = len (self .defn .constructors .get (sort , ())) + len (self .defn .subsorts .get (sort , ()))
442
+ assert n_ctors
443
+ return n_ctors == 1 and all (self ._is_exhaustive (arg ) for arg in args )
444
+ case _:
445
+ raise AssertionError ()
446
+
447
+ def _func_axiom (self , decl : SymbolDecl ) -> Axiom :
448
+ ident = _symbol_ident (decl .symbol .name )
449
+ signature = self ._func_signature (decl )
450
+ return Axiom (ident , signature )
451
+
452
+ def _func_signature (self , decl : SymbolDecl ) -> Signature :
256
453
sort_params = [var .name for var in decl .symbol .vars ]
257
454
param_sorts = [sort .name for sort in decl .param_sorts ]
258
455
sort = decl .sort .name
@@ -261,7 +458,8 @@ def _func_axiom(self, func: str) -> Axiom:
261
458
if sort_params :
262
459
binders .append (ImplBinder (sort_params , Term ('Type' )))
263
460
binders .extend (ExplBinder ((f'x{ i } ' ,), Term (sort )) for i , sort in enumerate (param_sorts ))
264
- return Axiom (ident , Signature (binders , Term (f'Option { sort } ' )))
461
+
462
+ return Signature (binders , Term (f'Option { sort } ' ))
265
463
266
464
def rewrite_module (self ) -> Module :
267
465
commands = (self ._rewrite_inductive (),)
@@ -553,7 +751,7 @@ def _sort_dependencies(defn: KoreDefn) -> dict[str, set[str]]:
553
751
} # Ensure that sorts without dependencies are also represented
554
752
555
753
556
- def _ordered_sccs (deps : dict [str , set [str ]]) -> list [list [str ]]:
754
+ def _ordered_sccs (deps : Mapping [str , Collection [str ]]) -> list [list [str ]]:
557
755
sccs = _sccs (deps )
558
756
559
757
elems_by_scc : dict [int , set [str ]] = {}
@@ -576,7 +774,7 @@ def _ordered_sccs(deps: dict[str, set[str]]) -> list[list[str]]:
576
774
577
775
578
776
# TODO Implement a more efficient algorithm, e.g. Tarjan's algorithm
579
- def _sccs (deps : dict [str , set [str ]]) -> dict [str , int ]:
777
+ def _sccs (deps : Mapping [str , Iterable [str ]]) -> dict [str , int ]:
580
778
res : dict [str , int ] = {}
581
779
582
780
scc = count ()
0 commit comments