diff --git a/src/compiler.ml b/src/compiler.ml index 0eed3ede3..9434198cb 100644 --- a/src/compiler.ml +++ b/src/compiler.ml @@ -105,10 +105,12 @@ type table = { uuid : Util.UUID.t; } [@@deriving show] -type pruned_table = { - c2s0 : string D.Constants.Map.t; - c2t0 : D.term D.Constants.Map.t; -} [@@deriving show] +type entry = +| GlobalSymbol of D.constant * string +| BoundVariable of D.constant * D.term +[@@deriving show] + +type pruned_table = entry array [@@deriving show] let locked { locked } = locked let lock t = { t with locked = true } @@ -116,16 +118,24 @@ let uuid { uuid } = uuid let equal t1 t2 = locked t1 && locked t2 && uuid t1 = uuid t2 -let size t = D.Constants.Map.cardinal t.c2t0 +let size t = Array.length t -let symbols { c2s0 } = - List.map (fun (c,s) -> s ^ ":" ^ string_of_int c) (D.Constants.Map.bindings c2s0) +let symbols table = + let map = function + | GlobalSymbol (c, s) -> Some (s ^ ":" ^ string_of_int c) + | BoundVariable _ -> None + in + List.rev @@ List.filter_map map @@ Array.to_list table let prune t ~alive = - { - c2s0 = D.Constants.Map.filter (fun k _ -> D.Constants.Set.mem k alive) t.c2s; - c2t0 = D.Constants.Map.filter (fun k _ -> D.Constants.Set.mem k alive) t.c2t; - } + let c2s = t.c2s in + let c2t0 = D.Constants.Map.filter (fun k _ -> D.Constants.Set.mem k alive) t.c2t in + let map k t = + if k < 0 then GlobalSymbol (k, D.Constants.Map.find k c2s) + else BoundVariable (k, t) + in + let c2t0 = D.Constants.Map.mapi map c2t0 in + Array.of_list @@ List.rev_map snd @@ D.Constants.Map.bindings c2t0 let table = D.State.declare ~descriptor:D.elpi_state_descriptor @@ -247,10 +257,10 @@ let build_shift ?(lock_base=false) ~flags:{ print_units } ~base symbols = (* We try hard to respect the same order if possible, since some tests (grundlagen) depend on this order (for performance, the constant-timestamp heuristic in unfolding) *) - List.fold_left (fun (base,shift as acc) (v, t) -> - if v < 0 then - let name = Map.find v symbols.c2s0 in - try + Array.fold_left (fun (base,shift as acc) e -> + match e with + | GlobalSymbol (v, name) -> + begin try let c, _ = F.Map.find (F.from_string name) base.ast2ct in if c == v then acc else begin @@ -262,13 +272,14 @@ let build_shift ?(lock_base=false) ~flags:{ print_units } ~base symbols = | Not_found -> let base, (c,_) = allocate_global_symbol_aux (Ast.Func.from_string name) base in base, Map.add v c shift - else + end + | BoundVariable (v, t) -> if Map.mem v base.c2t then acc else let base = { base with c2t = Map.add v t base.c2t } in base, shift ) - (base,Map.empty) (List.rev (Map.bindings symbols.c2t0))) + (base, Map.empty) symbols) let build_shift ?lock_base ~flags ~base symbols = try Stdlib.Result.Ok (build_shift ?lock_base ~flags ~base symbols) @@ -533,9 +544,6 @@ type program = { clauses : (preterm,Ast.Structured.attribute) Ast.Clause.t list; chr : (constant list * prechr_rule list) list; local_names : int; - symbols : C.Set.t; - - toplevel_macros : macro_declaration; } [@@deriving show] @@ -579,7 +587,7 @@ type compilation_unit = { type builtins = string * Data.BuiltInPredicate.declaration list -type header = State.t * compilation_unit +type header = State.t * compilation_unit * macro_declaration type program = State.t * Assembled.program @@ -1497,7 +1505,7 @@ module Flatten : sig (* Eliminating the structure (name spaces) *) - val run : State.t -> Structured.program -> Flat.program + val run : State.t -> Structured.program -> C.Set.t * macro_declaration * Flat.program val relocate : State.t -> D.constant D.Constants.Map.t -> Flat.program -> Flat.program val relocate_term : State.t -> D.constant D.Constants.Map.t -> term -> term @@ -1696,14 +1704,12 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } = let modes = apply_subst_modes ~live_symbols empty_subst modes in let types, type_abbrevs, modes, clauses, chr = compile_body live_symbols state local_names types type_abbrevs modes [] [] empty_subst body in - { Flat.types; + !live_symbols, toplevel_macros, { Flat.types; type_abbrevs; modes; clauses; chr = List.rev chr; local_names; - toplevel_macros; - symbols = !live_symbols } let relocate_term state s t = let ksub = apply_subst_constant ([],s) in @@ -1716,8 +1722,6 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } = clauses; chr; local_names; - toplevel_macros; - symbols; } = let f = [], f in { @@ -1727,8 +1731,6 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } = clauses = apply_subst_clauses state f clauses; chr = smart_map (apply_subst_chr state f) chr; local_names; - toplevel_macros; - symbols; } @@ -2073,7 +2075,7 @@ let assemble flags state code (ul : compilation_unit list) = let state, clauses_rev, types, type_abbrevs, modes, chr_rev = List.fold_left (fun (state, cl1, t1, ta1, m1, c1) ({ symbol_table; code } as _u) -> - let state, { Flat.clauses = cl2; types = t2; type_abbrevs = ta2; modes = m2; chr = c2; toplevel_macros = _ } = + let state, { Flat.clauses = cl2; types = t2; type_abbrevs = ta2; modes = m2; chr = c2; } = let state, shift = Stdlib.Result.get_ok @@ Symbols.build_shift ~flags ~base:state symbol_table in let code = if C.Map.is_empty shift then code @@ -2146,7 +2148,7 @@ let unit_or_header_of_ast { print_passes } s ?(toplevel_macros=F.Map.empty) p = Format.eprintf "== Structured ================@\n@[%a@]@\n" (w_symbol_table s Structured.pp_program) p; - let p = Flatten.run s p in + let alive, toplevel_macros, p = Flatten.run s p in if print_passes then Format.eprintf "== Flat ================@\n@[%a@]@\n" @@ -2155,8 +2157,8 @@ let unit_or_header_of_ast { print_passes } s ?(toplevel_macros=F.Map.empty) p = s, { version = "%%VERSION_NUM%%"; code = p; - symbol_table = Symbols.prune (State.get Symbols.table s) ~alive:p.Flat.symbols - } + symbol_table = Symbols.prune (State.get Symbols.table s) ~alive + }, toplevel_macros ;; let print_unit { print_units } x = @@ -2199,17 +2201,16 @@ let header_of_ast ~flags ~parser:p state_descriptor quotation_descriptor hoas_de | Data.BuiltInPredicate.MLDataC _ -> state | Data.BuiltInPredicate.LPCode _ -> state | Data.BuiltInPredicate.LPDoc _ -> state) state decls) state builtins in - let state, u = unit_or_header_of_ast flags state ast in + let state, u, toplevel_macros = unit_or_header_of_ast flags state ast in print_unit flags u; - state, u + state, u, toplevel_macros -let unit_of_ast ~flags ~header:(s, (header : compilation_unit)) p : compilation_unit = - let toplevel_macros = header.code.Flat.toplevel_macros in - let _, u = unit_or_header_of_ast flags s ~toplevel_macros p in +let unit_of_ast ~flags ~header:(s, (header : compilation_unit), toplevel_macros) p : compilation_unit = + let _, u, _ = unit_or_header_of_ast flags s ~toplevel_macros p in print_unit flags u; u -let assemble_units ~flags ~header:(s,h) units : program = +let assemble_units ~flags ~header:(s,h,toplevel_macros) units : program = let nunits_with_locals = (h :: units) |> List.filter (fun {code = { Flat.local_names = x }} -> x > 0) |> List.length in @@ -2217,7 +2218,7 @@ let assemble_units ~flags ~header:(s,h) units : program = if nunits_with_locals > 0 then error "Only 1 compilation unit is supported when local directives are used"; - let init = { Assembled.empty with toplevel_macros = h.code.toplevel_macros; local_names = h.code.local_names } in + let init = { Assembled.empty with toplevel_macros; local_names = h.code.local_names } in let s, p = Assemble.assemble flags s init (h :: units) in