Merge pull request #247 from ppedrot/saner-merge-types
Use sets rather than lists in compiler types
gares authored Jul 25, 2024
2 parents 89aa9bd + 272b3e5 commit b0b0d6c
Showing 5 changed files with 122 additions and 66 deletions.
6 changes: 6 additions & 0 deletions
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Unreleased

- Compiler:
  - Improve performance of separate compilation

# v1.19.4 (July 2024)

Requires Menhir 20211230 and OCaml 4.08 or above.
148 changes: 99 additions & 49 deletions src/
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ type argmap = {
n2t : (D.term * D.Constants.t) StrMap.t;
n2i : int StrMap.t;
[@@ deriving show]
[@@ deriving show, ord]

let empty_amap = {
nargs = 0;
Expand Down Expand Up @@ -405,29 +405,29 @@ type preterm = {
loc : Loc.t;
spilling : bool;
[@@ deriving show]
[@@ deriving show, ord]

(* A type declaration: the declared constant [tname], its type [ttype]
   (a preterm) and the source location [tloc].
   NOTE(review): this excerpt is a diff view, so two deriving attributes are
   listed; the second one additionally requests [ord]. The closing brace of
   the record is not visible here. *)
type type_declaration = {
tname : D.constant;
ttype : preterm;
tloc : Loc.t;
[@@ deriving show]
[@@ deriving show, ord]

(* A type abbreviation: the abbreviated constant [taname], the preterm
   [tavalue] it stands for, an integer [taparams] (presumably the number of
   parameters - confirm against the users of this field) and the source
   location [taloc].
   NOTE(review): diff view - the second deriving attribute adds [ord]; the
   closing brace of the record is not visible here. *)
type type_abbrev_declaration = {
taname : D.constant;
tavalue : preterm;
taparams : int;
taloc : Loc.t;
[@@ deriving show]
[@@ deriving show, ord]

(* A sequent made of three terms: [peigen], [pcontext] and [pconclusion].
   NOTE(review): diff view - the second deriving attribute adds [ord]; the
   closing brace of the record is not visible here. *)
type presequent = {
peigen : D.term;
pcontext : D.term;
pconclusion : D.term;
[@@ deriving show]
[@@ deriving show, ord]
type prechr_rule = {
pto_match : presequent list;
pto_remove : presequent list;
Expand All @@ -438,7 +438,7 @@ type prechr_rule = {
pifexpr : string option;
pcloc : Loc.t;
[@@ deriving show]
[@@ deriving show, ord]

Intermediate program representation
Expand All @@ -447,6 +447,64 @@ type prechr_rule = {
open Data
module C = Constants

(* A group of type declarations kept both as a set (used for membership
   tests in [merge]) and as a list (used by [fold], [iter] and [for_all]). *)
module Types = struct

(* One type declaration together with its [tindex] attribute. *)
type typ = {
  tindex : Ast.Structured.tattribute;
  decl : type_declaration
}
[@@deriving show, ord]

(* Sets of [typ], ordered by the derived [compare_typ]. *)
module Set = Util.Set.Make(struct
  type t = typ
  let compare = compare_typ
  let show = show_typ
  let pp = pp_typ
end)

(* [set] and [lst] are filled together by every constructor below;
   [def] is one distinguished declaration of the group. *)
type types = {
  set : Set.t;
  lst : typ list;
  def : typ;
} [@@deriving show, ord]

(* [make t] is the group containing only [t], which is also its [def]. *)
let make t = { set = Set.singleton t; lst = [t]; def = t }

(* [merge t1 t2] appends to [t1] the elements of [t2.lst] that are not
   already in [t1.set]. When [t2] brings nothing new, [t1] itself is
   returned (so its [def] is kept); otherwise the result takes [t2.def]. *)
let merge t1 t2 =
  let l2 = List.filter (fun t -> not @@ Set.mem t t1.set) t2.lst in
  match l2 with
  | [] -> t1
  | _ :: _ ->
    {
      set = Set.union t1.set t2.set;
      lst = t1.lst @ l2;
      def = t2.def;
    }

(* [smart_map f t] applies [f] to every component of [t]. Physical equality
   is used to detect unchanged values, and [t] itself is returned when
   nothing changed. *)
let smart_map (f : typ -> typ) (t : types) : types =
  (* Replace an element in the set only when [f] returned a new value. *)
  let fold t accu =
    let t' = f t in
    if t' == t then accu
    else Set.add t' (Set.remove t accu)
  in
  let set' = Set.fold fold t.set t.set in
  (* This binding is not recursive, so [smart_map] below is the one already
     in scope outside this module (presumably the list version - confirm). *)
  let lst' = smart_map f t.lst in
  let def' = f t.def in
  if set' == t.set && lst' == t.lst && def' == t.def then t
  else { set = set'; lst = lst'; def = def' }

(* [append x t] adds [x] to the set and conses it onto the list; [def] is
   left unchanged. *)
let append x t = {
  set = Set.add x t.set;
  lst = x :: t.lst;
  def = t.def;
}

(* List-style traversals, all running over [t.lst]. *)
let fold f accu t = List.fold_left f accu t.lst
let iter f t = List.iter f t.lst
let for_all f t = List.for_all f t.lst

end

module Structured = struct

type program = {
Expand All @@ -455,7 +513,7 @@ type program = {
toplevel_macros : macro_declaration;
and pbody = {
types : typ list C.Map.t;
types : Types.types C.Map.t;
type_abbrevs : type_abbrev_declaration C.Map.t;
modes : (mode * Loc.t) C.Map.t;
body : block list;
Expand All @@ -467,18 +525,14 @@ and block =
| Namespace of string * pbody
| Shorten of C.t Ast.Structured.shorthand list * pbody
| Constraints of constant list * prechr_rule list * pbody
and typ = {
tindex : Ast.Structured.tattribute;
decl : type_declaration
[@@deriving show]
[@@deriving show, ord]


module Flat = struct

type program = {
types : Structured.typ list C.Map.t;
types : Types.types C.Map.t;
type_abbrevs : type_abbrev_declaration C.Map.t;
modes : (mode * Loc.t) C.Map.t;
clauses : (preterm,Ast.Structured.attribute) Ast.Clause.t list;
Expand All @@ -495,7 +549,7 @@ end
module Assembled = struct

type program = {
types : Structured.typ list C.Map.t;
types : Types.types C.Map.t;
type_abbrevs : type_abbrev_declaration C.Map.t;
modes : (mode * Loc.t) C.Map.t;
clauses_rev : (preterm,attribute) Ast.Clause.t list;
Expand Down Expand Up @@ -538,7 +592,7 @@ module WithMain = struct

(* The entire program + query, but still in "printable" format *)
type 'a query = {
types : Structured.typ list C.Map.t;
types : Types.types C.Map.t;
type_abbrevs : type_abbrev_declaration C.Map.t;
modes : mode C.Map.t;
clauses_rev : (preterm,Assembled.attribute) Ast.Clause.t list;
Expand Down Expand Up @@ -859,10 +913,10 @@ module ToDBL : sig
(* Exported since also used to flatten (here we "flatten" locals) *)
val prefix_const : State.t -> string list -> C.t -> State.t * C.t
val merge_modes : State.t -> (mode * Loc.t) Map.t -> (mode * Loc.t) Map.t -> (mode * Loc.t) Map.t
val merge_types :
Structured.typ list C.Map.t ->
Structured.typ list C.Map.t ->
Structured.typ list C.Map.t
val merge_types : State.t ->
Types.types C.Map.t ->
Types.types C.Map.t ->
Types.types C.Map.t
val merge_type_abbrevs : State.t ->
type_abbrev_declaration C.Map.t ->
type_abbrev_declaration C.Map.t ->
Expand Down Expand Up @@ -1234,7 +1288,7 @@ let query_preterm_of_ast ~depth macros state (loc, t) =
let state, ttype =
preterms_of_ast ~on_type:true loc ~depth:lcs F.Map.empty state (fun ~depth:_ state x -> state, [loc,x]) ty in
let ttype = assert(List.length ttype = 1); List.hd ttype in
state, { Structured.tindex = attributes; decl = { tname; ttype; tloc = loc } }
state, { Types.tindex = attributes; decl = { tname; ttype; tloc = loc } }

let funct_of_ast state c =
Expand All @@ -1258,21 +1312,16 @@ let query_preterm_of_ast ~depth macros state (loc, t) =
state, C.Map.add mname (args,loc) modes

(* [merge_modes state m1 m2] returns [m2] unchanged when [m1] is empty.
   Otherwise it folds over the bindings of [m2], starting from [m1]: each
   binding is first passed to [check_duplicate_mode] together with the map
   accumulated so far, then added to that map. *)
let merge_modes state m1 m2 =
if C.Map.is_empty m1 then m2 else
C.Map.fold (fun k v m ->
check_duplicate_mode state k v m;
C.Map.add k v m)
m2 m1

(* NOTE(review): this excerpt is a diff view and shows two definitions of
   [merge_types]; the first one appears to be the version being replaced.

   First definition (lists): key-wise merge of two maps. A key bound in only
   one map keeps its list; a key bound in both gets [l1] followed by the
   elements of [l2] that [List.mem] does not find in [l1]. *)
let merge_types t1 t2 =
C.Map.merge (fun _ l1 l2 ->
match l1, l2 with
| None, None -> None
| Some _ as l, None -> l
| None, (Some _ as l) -> l
| Some l1, Some l2 ->
Some (l1 @ (List.filter (fun x -> not @@ List.mem x l1) l2))) t1 t2
(* Second definition ([Types.types] values): takes an extra first argument
   [_s], which is ignored, and combines the values of keys bound in both
   maps with [Types.merge]. *)
let merge_types _s t1 t2 =
C.Map.union (fun _ l1 l2 -> Some (Types.merge l1 l2)) t1 t2

(* [merge_type_abbrevs s m1 m2] returns [m1] unchanged when [m2] is empty.
   Otherwise it folds over the values of [m1], starting from [m2], inserting
   each one with [add_to_index_type_abbrev]. *)
let merge_type_abbrevs s m1 m2 =
if C.Map.is_empty m2 then m1 else
C.Map.fold (fun _ v m -> add_to_index_type_abbrev s m v) m1 m2

let rec toplevel_clausify loc ~depth state t =
Expand Down Expand Up @@ -1351,9 +1400,9 @@ let query_preterm_of_ast ~depth macros state (loc, t) =
(* [map_append k v m] adds [v] to the collection bound to [k] in [m]; when
   [k] is unbound ([Not_found]) a new collection holding only [v] is bound.
   NOTE(review): diff view - each [C.Map.add] appears twice, first in its
   list form ([v::l], [[v]]) and then in its [Types] form ([Types.append],
   [Types.make]). The [try] keyword matching [with Not_found] is not visible
   in this excerpt. *)
let map_append k v m =
let l = C.Map.find k m in
C.Map.add k (v::l) m
C.Map.add k (Types.append v l) m
with Not_found ->
C.Map.add k [v] m
C.Map.add k (Types.make v) m

let run (state : State.t) ~toplevel_macros p =
(* FIXME: otypes omodes - NO, rewrite spilling on data.term *)
Expand All @@ -1365,7 +1414,7 @@ let query_preterm_of_ast ~depth macros state (loc, t) =
let type_abbrevs = List.fold_left (add_to_index_type_abbrev state) C.Map.empty type_abbrevs in
let state, types =
map_acc (compile_type lcs) state types in
let types = List.fold_left (fun m t -> map_append t.Structured.decl.tname t m) C.Map.empty types in
let types = List.fold_left (fun m t -> map_append t.Types.decl.tname t m) C.Map.empty types in
let state, modes = List.fold_left compile_mode (state,C.Map.empty) modes in
let defs_m = defs_of_modes modes in
let defs_t = defs_of_types types in
Expand All @@ -1391,7 +1440,7 @@ let query_preterm_of_ast ~depth macros state (loc, t) =
compile_program macros lcs state p in
let defs = C.Set.union defs symbols in
let modes = merge_modes state modes mp in
let types = merge_types types tp in
let types = merge_types state types tp in
let type_abbrevs = merge_type_abbrevs state type_abbrevs ta in
let state = set_varmap state orig_varmap in
let lcs, state, types, type_abbrevs, modes, defs, compiled_rest =
Expand Down Expand Up @@ -1514,12 +1563,12 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } =
t,c) n2t in
{ nargs; c2i; i2n; n2t; n2i }

(* [smart_map_type state f tdecl] applies [f] to the declared name, maps the
   type term with [smart_map_term ~on_type:true] and the argument map with
   [subst_amap]. The input [tdecl] is returned as is when the name and the
   term are physically unchanged and the argument map is structurally equal;
   otherwise a new record is built that keeps [tindex], [tloc], [loc] and
   [spilling].
   NOTE(review): diff view - the header line and the [else] line each appear
   twice, qualified first with [Structured.] and then with [Types.]. *)
let smart_map_type state f ({ Structured.tindex; decl = { tname; ttype; tloc }} as tdecl) =
let smart_map_type state f ({ Types.tindex; decl = { tname; ttype; tloc }} as tdecl) =
let tname1 = f tname in
let ttype1 = smart_map_term ~on_type:true state f ttype.term in
let tamap1 =subst_amap state f ttype.amap in
if tname1 == tname && ttype1 == ttype.term && ttype.amap = tamap1 then tdecl
else { Structured.tindex; decl = { tname = tname1; tloc; ttype = { term = ttype1; amap = tamap1; loc = ttype.loc; spilling = ttype.spilling } } }
else { Types.tindex; decl = { tname = tname1; tloc; ttype = { term = ttype1; amap = tamap1; loc = ttype.loc; spilling = ttype.spilling } } }

let map_sequent state f { peigen; pcontext; pconclusion } =
Expand Down Expand Up @@ -1577,7 +1626,7 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } =

(* [apply_subst_types ?live_symbols st s tm] rebuilds the map [tm] from
   [C.Map.empty]: every key is renamed with the constant substitution [ksub]
   obtained from [apply_subst_constant], and the value bound to it is mapped
   with [smart_map_type st ksub].
   NOTE(review): diff view - the fold line appears twice, first using the
   list [smart_map] and then [Types.smart_map]. *)
let apply_subst_types ?live_symbols st s tm =
let ksub = apply_subst_constant ?live_symbols s in
C.Map.fold (fun k tl m -> C.Map.add (ksub k) (smart_map (smart_map_type st ksub) tl) m) tm C.Map.empty
C.Map.fold (fun k tl m -> C.Map.add (ksub k) (Types.smart_map (smart_map_type st ksub) tl) m) tm C.Map.empty

(* Partial application of [tabbrevs_map] to [st] and to the constant
   substitution built by [apply_subst_constant ?live_symbols s]. *)
let apply_subst_type_abbrevs ?live_symbols st s = tabbrevs_map st (apply_subst_constant ?live_symbols s)

Expand Down Expand Up @@ -1611,15 +1660,15 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } =
| [] -> types, type_abbrevs, modes, clauses, chr
| Shorten(shorthands, { types = t; type_abbrevs = ta; modes = m; body; symbols = s }) :: rest ->
let insubst = push_subst_shorthands shorthands s subst in
let types = ToDBL.merge_types (apply_subst_types ~live_symbols state insubst t) types in
let types = ToDBL.merge_types state (apply_subst_types ~live_symbols state insubst t) types in
let type_abbrevs = ToDBL.merge_type_abbrevs state (apply_subst_type_abbrevs ~live_symbols state insubst ta) type_abbrevs in
let modes = ToDBL.merge_modes state (apply_subst_modes ~live_symbols insubst m) modes in
let types, type_abbrevs, modes, clauses, chr =
compile_body live_symbols state lcs types type_abbrevs modes clauses chr insubst body in
compile_body live_symbols state lcs types type_abbrevs modes clauses chr subst rest
| Namespace (extra, { types = t; type_abbrevs = ta; modes = m; body; symbols = s }) :: rest ->
let state, insubst = push_subst state extra s subst in
let types = ToDBL.merge_types (apply_subst_types ~live_symbols state insubst t) types in
let types = ToDBL.merge_types state (apply_subst_types ~live_symbols state insubst t) types in
let type_abbrevs = ToDBL.merge_type_abbrevs state (apply_subst_type_abbrevs ~live_symbols state insubst ta) type_abbrevs in
let modes = ToDBL.merge_modes state (apply_subst_modes ~live_symbols insubst m) modes in
let types, type_abbrevs, modes, clauses, chr =
Expand All @@ -1630,7 +1679,7 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } =
let clauses = clauses @ cl in
compile_body live_symbols state lcs types type_abbrevs modes clauses chr subst rest
| Constraints (clique, rules, { types = t; type_abbrevs = ta; modes = m; body }) :: rest ->
let types = ToDBL.merge_types (apply_subst_types ~live_symbols state subst t) types in
let types = ToDBL.merge_types state (apply_subst_types ~live_symbols state subst t) types in
let type_abbrevs = ToDBL.merge_type_abbrevs state (apply_subst_type_abbrevs ~live_symbols state subst ta) type_abbrevs in
let modes = ToDBL.merge_modes state (apply_subst_modes ~live_symbols subst m) modes in
let chr = apply_subst_chr ~live_symbols state subst (clique,rules) :: chr in
Expand Down Expand Up @@ -1697,16 +1746,16 @@ module Spill : sig

val spill_clause :
State.t -> types:Structured.typ list C.Map.t -> modes:(constant -> mode) ->
State.t -> types:Types.types C.Map.t -> modes:(constant -> mode) ->
(preterm, 'a) Ast.Clause.t -> (preterm, 'a) Ast.Clause.t

val spill_chr :
State.t -> types:Structured.typ list C.Map.t -> modes:(constant -> mode) ->
State.t -> types:Types.types C.Map.t -> modes:(constant -> mode) ->
(constant list * prechr_rule list) -> (constant list * prechr_rule list)

(* Exported to compile the query *)
val spill_preterm :
State.t -> Structured.typ list C.Map.t -> (C.t -> mode) -> preterm -> preterm
State.t -> Types.types C.Map.t -> (C.t -> mode) -> preterm -> preterm

end = struct (* {{{ *)

Expand All @@ -1722,7 +1771,7 @@ end = struct (* {{{ *)

(* [type_of_const types c] looks [c] up in [types] and passes the term of
   the selected declaration to [read_ty]; the [Not_found] case yields
   `Unknown.
   NOTE(review): diff view - the selecting [let] appears twice: the first
   form takes the last element of the list bound to [c], the second form
   takes the [def] field of the [Types.types] value. The [try] / [with]
   keywords around the body are not visible in this excerpt. *)
let type_of_const types c =
let { Structured.decl = { ttype } } = List.hd @@ List.rev @@ C.Map.find c types in
let { Types.decl = { ttype } } = (C.Map.find c types).Types.def in
read_ty ttype.term
Not_found -> `Unknown
Expand Down Expand Up @@ -2037,7 +2086,7 @@ let assemble flags state code (ul : compilation_unit list) =
state, code in
let modes = ToDBL.merge_modes state m1 m2 in
let type_abbrevs = ToDBL.merge_type_abbrevs state ta1 ta2 in
let types = ToDBL.merge_types t1 t2 in
let types = ToDBL.merge_types state t1 t2 in
let cl2 = filter_if flags clause_name cl2 in
let cl2 = (Spill.spill_clause state ~types ~modes:(fun c -> fst @@ C.Map.find c modes)) cl2 in
let c2 = (Spill.spill_chr state ~types ~modes:(fun c -> fst @@ C.Map.find c modes)) c2 in
Expand Down Expand Up @@ -2206,19 +2255,19 @@ let is_builtin state tname =
(* [check_all_builtin_are_typed state types] performs two checks.
   First, for every constant in [Builtins.all state], it reports an error
   unless [types] binds that constant and all of its declarations carry the
   [External] attribute (an unbound constant counts as a failure).
   Second, it iterates over a map of declarations and reports an error, at
   the declaration location, for each [External] declaration whose name is
   not a builtin according to [is_builtin].
   NOTE(review): diff view - the [for_all] and [iter] lines appear twice
   (list form, then [Types] form). The map argument of the final
   [C.Map.iter] is not visible in this excerpt; presumably it is [types]. *)
let check_all_builtin_are_typed state types =
Constants.Set.iter (fun c ->
if not (match C.Map.find c types with
| l -> l |> List.for_all (fun { Structured.tindex;_} -> tindex = Ast.Structured.External)
| l -> l |> Types.for_all (fun { Types.tindex;_} -> tindex = Ast.Structured.External)
| exception Not_found -> false) then
error ("Built-in without external type declaration: " ^ state c))
(Builtins.all state);
C.Map.iter (fun tname tl -> tl |> List.iter (fun { Structured.tindex; decl = { tname; tloc }} ->
C.Map.iter (fun tname tl -> tl |> Types.iter (fun { Types.tindex; decl = { tname; tloc }} ->
if tindex = Ast.Structured.External && not (is_builtin state tname) then
error ~loc:tloc ("external type declaration without Built-in: " ^ state tname)))

let check_no_regular_types_for_builtins state types =
C.Map.iter (fun tname l -> l |> List.iter (fun {Structured.tindex; decl = { tloc } } ->
C.Map.iter (fun tname l -> l |> Types.iter (fun {Types.tindex; decl = { tloc } } ->
if tindex <> Ast.Structured.External && is_builtin state tname then
anomaly ~loc:tloc ("type declaration for Built-in " ^ state tname ^ " must be flagged as external");
Expand Down Expand Up @@ -2467,7 +2516,7 @@ let run
with Not_found ->
C.Map.add name (mode,index) map in
let map = C.Map.fold (fun tname l acc -> l |> List.fold_left (fun acc { Structured.tindex } -> add_indexing_for tname (Some tindex) acc) acc) types C.Map.empty in
let map = C.Map.fold (fun tname l acc -> Types.fold (fun acc { Types.tindex } -> add_indexing_for tname (Some tindex) acc) acc l) types C.Map.empty in
let map = C.Map.fold (fun k _ m -> add_indexing_for k None m) modes map in
map in
let state, clauses_rev =
Expand Down Expand Up @@ -2715,8 +2764,9 @@ let static_check ~exec ~checker:(state,program)
let time = `Compiletime in
let state, p,q = quote_syntax time state q in
let state, tlist = C.Map.fold (fun tname l (state,tl) ->
let l = l.Types.lst in
let state, l =
List.rev l |> map_acc (fun state { Structured.decl = { ttype } } ->
List.rev l |> map_acc (fun state { Types.decl = { ttype } } ->
let state, c = mkQCon time ~compiler_state state ~on_type:false tname in
let ttypet = unfold_type_abbrevs ~compiler_state initial_depth type_abbrevs ttype in
let state, ttypet = quote_preterm time ~compiler_state state ~on_type:true ttypet in
Expand Down

