Skip to content

Commit

Permalink
Merge pull request #249 from ppedrot/even-leaner-symbol-table
Browse files Browse the repository at this point in the history
More compact representation of compiled programs
  • Loading branch information
gares authored Jul 30, 2024
2 parents 965f44f + f6e85a7 commit d7e778b
Showing 1 changed file with 41 additions and 40 deletions.
81 changes: 41 additions & 40 deletions src/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -105,27 +105,37 @@ type table = {
uuid : Util.UUID.t;
} [@@deriving show]

type pruned_table = {
c2s0 : string D.Constants.Map.t;
c2t0 : D.term D.Constants.Map.t;
} [@@deriving show]
(* One entry of a pruned symbol table.
   - [GlobalSymbol (c, name)]: built by [prune] for a negative constant [c],
     paired with the name found for it in the table's [c2s] map.
   - [BoundVariable (c, t)]: built by [prune] for any other constant [c],
     paired with the term bound to it in the table's [c2t] map. *)
type entry =
| GlobalSymbol of D.constant * string
| BoundVariable of D.constant * D.term
[@@deriving show]

(* Compact form of a symbol table: a flat array of entries.  [prune] fills it
   from the reversed bindings of the filtered [c2t] map, and [build_shift]
   folds over it from left to right. *)
type pruned_table = entry array [@@deriving show]

(* Reads the [locked] flag of a symbol table. *)
let locked tbl = tbl.locked
(* Returns a copy of the table with its [locked] flag set. *)
let lock tbl = { tbl with locked = true }
(* Reads the unique identifier of a symbol table. *)
let uuid tbl = tbl.uuid
(* Two tables compare equal only when both are locked and share a uuid. *)
let equal t1 t2 =
  if locked t1 && locked t2 then uuid t1 = uuid t2 else false

let size t = D.Constants.Map.cardinal t.c2t0
(* Number of entries stored in a pruned table. *)
let size entries = Array.length entries

let symbols { c2s0 } =
List.map (fun (c,s) -> s ^ ":" ^ string_of_int c) (D.Constants.Map.bindings c2s0)
(* Lists the global symbols of a pruned table as ["name:constant"] strings.
   Bound variables are skipped.  Consing while folding left to right yields
   the strings in the reverse of their order in the array. *)
let symbols table =
  Array.fold_left
    (fun acc entry ->
      match entry with
      | GlobalSymbol (c, s) -> (s ^ ":" ^ string_of_int c) :: acc
      | BoundVariable _ -> acc)
    [] table

let prune t ~alive =
{
c2s0 = D.Constants.Map.filter (fun k _ -> D.Constants.Set.mem k alive) t.c2s;
c2t0 = D.Constants.Map.filter (fun k _ -> D.Constants.Set.mem k alive) t.c2t;
}
let c2s = t.c2s in
let c2t0 = D.Constants.Map.filter (fun k _ -> D.Constants.Set.mem k alive) t.c2t in
let map k t =
if k < 0 then GlobalSymbol (k, D.Constants.Map.find k c2s)
else BoundVariable (k, t)
in
let c2t0 = D.Constants.Map.mapi map c2t0 in
Array.of_list @@ List.rev_map snd @@ D.Constants.Map.bindings c2t0

let table = D.State.declare
~descriptor:D.elpi_state_descriptor
Expand Down Expand Up @@ -247,10 +257,10 @@ let build_shift ?(lock_base=false) ~flags:{ print_units } ~base symbols =
(* We try hard to respect the same order if possible, since some tests
(grundlagen) depend on this order (for performance, the constant-timestamp
heuristic in unfolding) *)
List.fold_left (fun (base,shift as acc) (v, t) ->
if v < 0 then
let name = Map.find v symbols.c2s0 in
try
Array.fold_left (fun (base,shift as acc) e ->
match e with
| GlobalSymbol (v, name) ->
begin try
let c, _ = F.Map.find (F.from_string name) base.ast2ct in
if c == v then acc
else begin
Expand All @@ -262,13 +272,14 @@ let build_shift ?(lock_base=false) ~flags:{ print_units } ~base symbols =
| Not_found ->
let base, (c,_) = allocate_global_symbol_aux (Ast.Func.from_string name) base in
base, Map.add v c shift
else
end
| BoundVariable (v, t) ->
if Map.mem v base.c2t then acc
else
let base = { base with c2t = Map.add v t base.c2t } in
base, shift
)
(base,Map.empty) (List.rev (Map.bindings symbols.c2t0)))
(base, Map.empty) symbols)

let build_shift ?lock_base ~flags ~base symbols =
try Stdlib.Result.Ok (build_shift ?lock_base ~flags ~base symbols)
Expand Down Expand Up @@ -533,9 +544,6 @@ type program = {
clauses : (preterm,Ast.Structured.attribute) Ast.Clause.t list;
chr : (constant list * prechr_rule list) list;
local_names : int;
symbols : C.Set.t;

toplevel_macros : macro_declaration;
}
[@@deriving show]

Expand Down Expand Up @@ -579,7 +587,7 @@ type compilation_unit = {

type builtins = string * Data.BuiltInPredicate.declaration list

type header = State.t * compilation_unit
type header = State.t * compilation_unit * macro_declaration
type program = State.t * Assembled.program


Expand Down Expand Up @@ -1497,7 +1505,7 @@ module Flatten : sig

(* Eliminating the structure (name spaces) *)

val run : State.t -> Structured.program -> Flat.program
val run : State.t -> Structured.program -> C.Set.t * macro_declaration * Flat.program

val relocate : State.t -> D.constant D.Constants.Map.t -> Flat.program -> Flat.program
val relocate_term : State.t -> D.constant D.Constants.Map.t -> term -> term
Expand Down Expand Up @@ -1696,14 +1704,12 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } =
let modes = apply_subst_modes ~live_symbols empty_subst modes in
let types, type_abbrevs, modes, clauses, chr =
compile_body live_symbols state local_names types type_abbrevs modes [] [] empty_subst body in
{ Flat.types;
!live_symbols, toplevel_macros, { Flat.types;
type_abbrevs;
modes;
clauses;
chr = List.rev chr;
local_names;
toplevel_macros;
symbols = !live_symbols
}
let relocate_term state s t =
let ksub = apply_subst_constant ([],s) in
Expand All @@ -1716,8 +1722,6 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } =
clauses;
chr;
local_names;
toplevel_macros;
symbols;
} =
let f = [], f in
{
Expand All @@ -1727,8 +1731,6 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } =
clauses = apply_subst_clauses state f clauses;
chr = smart_map (apply_subst_chr state f) chr;
local_names;
toplevel_macros;
symbols;
}


Expand Down Expand Up @@ -2073,7 +2075,7 @@ let assemble flags state code (ul : compilation_unit list) =

let state, clauses_rev, types, type_abbrevs, modes, chr_rev =
List.fold_left (fun (state, cl1, t1, ta1, m1, c1) ({ symbol_table; code } as _u) ->
let state, { Flat.clauses = cl2; types = t2; type_abbrevs = ta2; modes = m2; chr = c2; toplevel_macros = _ } =
let state, { Flat.clauses = cl2; types = t2; type_abbrevs = ta2; modes = m2; chr = c2; } =
let state, shift = Stdlib.Result.get_ok @@ Symbols.build_shift ~flags ~base:state symbol_table in
let code =
if C.Map.is_empty shift then code
Expand Down Expand Up @@ -2146,7 +2148,7 @@ let unit_or_header_of_ast { print_passes } s ?(toplevel_macros=F.Map.empty) p =
Format.eprintf "== Structured ================@\n@[<v 0>%a@]@\n"
(w_symbol_table s Structured.pp_program) p;

let p = Flatten.run s p in
let alive, toplevel_macros, p = Flatten.run s p in

if print_passes then
Format.eprintf "== Flat ================@\n@[<v 0>%a@]@\n"
Expand All @@ -2155,8 +2157,8 @@ let unit_or_header_of_ast { print_passes } s ?(toplevel_macros=F.Map.empty) p =
s, {
version = "%%VERSION_NUM%%";
code = p;
symbol_table = Symbols.prune (State.get Symbols.table s) ~alive:p.Flat.symbols
}
symbol_table = Symbols.prune (State.get Symbols.table s) ~alive
}, toplevel_macros
;;

let print_unit { print_units } x =
Expand Down Expand Up @@ -2199,25 +2201,24 @@ let header_of_ast ~flags ~parser:p state_descriptor quotation_descriptor hoas_de
| Data.BuiltInPredicate.MLDataC _ -> state
| Data.BuiltInPredicate.LPCode _ -> state
| Data.BuiltInPredicate.LPDoc _ -> state) state decls) state builtins in
let state, u = unit_or_header_of_ast flags state ast in
let state, u, toplevel_macros = unit_or_header_of_ast flags state ast in
print_unit flags u;
state, u
state, u, toplevel_macros

let unit_of_ast ~flags ~header:(s, (header : compilation_unit)) p : compilation_unit =
let toplevel_macros = header.code.Flat.toplevel_macros in
let _, u = unit_or_header_of_ast flags s ~toplevel_macros p in
(* Compiles the AST [p] into a compilation unit, starting from the state and
   toplevel macros carried by [header].  The header's own compilation unit is
   not consulted, and the state and macros returned by
   [unit_or_header_of_ast] are discarded.  The unit is passed to [print_unit]
   before being returned. *)
let unit_of_ast ~flags ~header:(s, (_header : compilation_unit), toplevel_macros) p : compilation_unit =
  let _state, compiled, _macros =
    unit_or_header_of_ast flags s ~toplevel_macros p in
  print_unit flags compiled;
  compiled

let assemble_units ~flags ~header:(s,h) units : program =
let assemble_units ~flags ~header:(s,h,toplevel_macros) units : program =

let nunits_with_locals =
(h :: units) |> List.filter (fun {code = { Flat.local_names = x }} -> x > 0) |> List.length in

if nunits_with_locals > 0 then
error "Only 1 compilation unit is supported when local directives are used";

let init = { Assembled.empty with toplevel_macros = h.code.toplevel_macros; local_names = h.code.local_names } in
let init = { Assembled.empty with toplevel_macros; local_names = h.code.local_names } in

let s, p = Assemble.assemble flags s init (h :: units) in

Expand Down

0 comments on commit d7e778b

Please sign in to comment.