|
| 1 | +from collections import defaultdict |
| 2 | +from typing import Union |
| 3 | + |
| 4 | +from pysmt.fnode import FNode |
| 5 | +from pysmt.shortcuts import FALSE, And, Symbol, Solver, simplify |
| 6 | +from pysmt.walkers import IdentityDagWalker |
| 7 | + |
| 8 | +from lib_pea.location import Location |
| 9 | +from lib_pea.transition import Transition |
| 10 | + |
| 11 | +SOLVER_NAME = "z3" |
| 12 | +LOGIC = "UFLIRA" |
| 13 | + |
| 14 | + |
| 15 | +class PeaOperationsMixin: |
| 16 | + |
| 17 | + OP_TOKEN = 1 |
| 18 | + |
| 19 | + def intersect(self: Union["Pea", "PeaOperationsMixin"], other: "Pea") -> "Pea": |
| 20 | + """Naiive implementation of PEA intersection for building small examples""" |
| 21 | + from lib_pea.pea import Pea |
| 22 | + |
| 23 | + PeaOperationsMixin.OP_TOKEN += 2 # Some way to add unuiqe stuff to each pea |
| 24 | + # TODO Clock substitutions |
| 25 | + self_clocks = {c: f"{c}.{PeaOperationsMixin.OP_TOKEN-1}" for c in self.clocks} |
| 26 | + other_clocks = {c: f"{c}.{PeaOperationsMixin.OP_TOKEN}" for c in other.clocks} |
| 27 | + result = Pea() |
| 28 | + locations = PeaOperationsMixin.__union_locations(self.locations(), other.locations(), self_clocks, other_clocks) |
| 29 | + result.transitions = PeaOperationsMixin.__union_transitions( |
| 30 | + self.transitions, other.transitions, locations, self_clocks, other_clocks |
| 31 | + ) |
| 32 | + result.clocks = set(self_clocks.values()) | set(other_clocks.values()) |
| 33 | + # TODO: Minimize away all false edges and locations |
| 34 | + return result |
| 35 | + |
| 36 | + @staticmethod |
| 37 | + def __union_locations( |
| 38 | + self_locs: set[Location], |
| 39 | + other_locs: set[Location], |
| 40 | + self_clocks: dict[str, str], |
| 41 | + other_clocks: dict[str, str], |
| 42 | + ) -> dict[(Location, Location), Location]: |
| 43 | + result = dict() |
| 44 | + for sl in self_locs: |
| 45 | + for ol in other_locs: |
| 46 | + if not sl or not ol: |
| 47 | + continue # Cant combine an initial edge with a non-initnial edge |
| 48 | + ul = Location( |
| 49 | + state_invariant=PeaOperationsMixin.__conjunct_builder( |
| 50 | + sl.state_invariant, ol.state_invariant, self_clocks, other_clocks |
| 51 | + ), |
| 52 | + clock_invariant=PeaOperationsMixin.__conjunct_builder( |
| 53 | + sl.clock_invariant, ol.clock_invariant, self_clocks, other_clocks |
| 54 | + ), |
| 55 | + label=f"{sl.label}+{ol.label}", |
| 56 | + ) |
| 57 | + if ul.state_invariant is FALSE() or ul.clock_invariant is FALSE(): |
| 58 | + continue |
| 59 | + result[(sl, ol)] = ul |
| 60 | + return result |
| 61 | + |
| 62 | + @staticmethod |
| 63 | + def __union_transitions( |
| 64 | + self_transitions: defaultdict[Location, set[Transition]], |
| 65 | + other_transitions: defaultdict[Location, set[Transition]], |
| 66 | + locations: dict[(Location, Location), Location], |
| 67 | + self_clocks: dict[str, str], |
| 68 | + other_clocks: dict[str, str], |
| 69 | + ) -> defaultdict[Location, set[Transition]]: |
| 70 | + result = defaultdict(set) |
| 71 | + for (self_loc, other_loc), union_loc in locations.items(): |
| 72 | + for st in self_transitions[self_loc]: |
| 73 | + for ot in other_transitions[other_loc]: |
| 74 | + if (st.dst, ot.dst) not in locations: |
| 75 | + continue |
| 76 | + ut = Transition( |
| 77 | + src=union_loc, |
| 78 | + dst=locations[(st.dst, ot.dst)], |
| 79 | + guard=PeaOperationsMixin.__conjunct_builder(st.guard, ot.guard, self_clocks, other_clocks), |
| 80 | + resets=frozenset({self_clocks[c] for c in st.resets} | {other_clocks[c] for c in ot.resets}), |
| 81 | + ) |
| 82 | + if ut.guard is FALSE(): |
| 83 | + continue |
| 84 | + result[union_loc].add(ut) |
| 85 | + # Build initial trainsitions |
| 86 | + for si in self_transitions[None]: |
| 87 | + for oi in self_transitions[None]: |
| 88 | + if (si.dst, oi.dst) not in locations: |
| 89 | + continue |
| 90 | + ut = Transition( |
| 91 | + src=None, |
| 92 | + dst=locations[(si.dst, oi.dst)], |
| 93 | + guard=PeaOperationsMixin.__conjunct_builder(si.guard, oi.guard, self_clocks, other_clocks), |
| 94 | + resets=frozenset(), |
| 95 | + ) |
| 96 | + if ut.guard is FALSE(): |
| 97 | + continue |
| 98 | + result[None].add(ut) |
| 99 | + return result |
| 100 | + |
| 101 | + @staticmethod |
| 102 | + def __conjunct_builder( |
| 103 | + self_junct: FNode, other_junct: FNode, self_clocks: dict[str, str], other_clocks: dict[str, str] |
| 104 | + ): |
| 105 | + """Just conjuct the two Fnodes, but make clocks unique before""" |
| 106 | + self_junct = Renamer(self_clocks).walk(self_junct) |
| 107 | + other_junct = Renamer(other_clocks).walk(other_junct) |
| 108 | + g = And(self_junct, other_junct) |
| 109 | + with Solver(name=SOLVER_NAME, logic=LOGIC) as solver: |
| 110 | + if solver.is_unsat(g): |
| 111 | + return FALSE() |
| 112 | + g = simplify(g) |
| 113 | + return g |
| 114 | + |
| 115 | + |
| 116 | +class Renamer(IdentityDagWalker): |
| 117 | + def __init__(self, renaming_dict: dict): |
| 118 | + IdentityDagWalker.__init__(self) |
| 119 | + self.renaming_dict = renaming_dict |
| 120 | + |
| 121 | + def walk_symbol(self, formula, args, **kwargs): |
| 122 | + # lambda s: Symbol("renamed_" + s.symbol_name(), s.symbol_type()) |
| 123 | + if name := formula.symbol_name in self.renaming_dict: |
| 124 | + return Symbol(self.renaming_dict[name], formula.symbol_type()) |
| 125 | + return formula |
0 commit comments