@@ -105,27 +105,37 @@ type table = {
105
105
uuid : Util.UUID .t ;
106
106
} [@@ deriving show ]
107
107
108
- type pruned_table = {
109
- c2s0 : string D.Constants.Map .t ;
110
- c2t0 : D .term D.Constants.Map .t ;
111
- } [@@ deriving show ]
108
+ type entry =
109
+ | GlobalSymbol of D .constant * string
110
+ | BoundVariable of D .constant * D .term
111
+ [@@ deriving show ]
112
+
113
+ type pruned_table = entry array [@@ deriving show ]
112
114
113
115
let locked { locked } = locked
114
116
let lock t = { t with locked = true }
115
117
let uuid { uuid } = uuid
116
118
let equal t1 t2 =
117
119
locked t1 && locked t2 && uuid t1 = uuid t2
118
120
119
- let size t = D.Constants.Map. cardinal t.c2t0
121
+ let size t = Array. length t
120
122
121
- let symbols { c2s0 } =
122
- List. map (fun (c ,s ) -> s ^ " :" ^ string_of_int c) (D.Constants.Map. bindings c2s0)
123
+ let symbols table =
124
+ let map = function
125
+ | GlobalSymbol (c , s ) -> Some (s ^ " :" ^ string_of_int c)
126
+ | BoundVariable _ -> None
127
+ in
128
+ List. rev @@ List. filter_map map @@ Array. to_list table
123
129
124
130
let prune t ~alive =
125
- {
126
- c2s0 = D.Constants.Map. filter (fun k _ -> D.Constants.Set. mem k alive) t.c2s;
127
- c2t0 = D.Constants.Map. filter (fun k _ -> D.Constants.Set. mem k alive) t.c2t;
128
- }
131
+ let c2s = t.c2s in
132
+ let c2t0 = D.Constants.Map. filter (fun k _ -> D.Constants.Set. mem k alive) t.c2t in
133
+ let map k t =
134
+ if k < 0 then GlobalSymbol (k, D.Constants.Map. find k c2s)
135
+ else BoundVariable (k, t)
136
+ in
137
+ let c2t0 = D.Constants.Map. mapi map c2t0 in
138
+ Array. of_list @@ List. rev_map snd @@ D.Constants.Map. bindings c2t0
129
139
130
140
let table = D.State. declare
131
141
~descriptor: D. elpi_state_descriptor
@@ -247,10 +257,10 @@ let build_shift ?(lock_base=false) ~flags:{ print_units } ~base symbols =
247
257
(* We try hard to respect the same order if possible, since some tests
248
258
(grundlagen) depend on this order (for performance, the constant-timestamp
249
259
heuristic in unfolding) *)
250
- List . fold_left (fun (base ,shift as acc ) ( v , t ) ->
251
- if v < 0 then
252
- let name = Map. find v symbols.c2s0 in
253
- try
260
+ Array . fold_left (fun (base ,shift as acc ) e ->
261
+ match e with
262
+ | GlobalSymbol ( v , name ) ->
263
+ begin try
254
264
let c, _ = F.Map. find (F. from_string name) base.ast2ct in
255
265
if c == v then acc
256
266
else begin
@@ -262,13 +272,14 @@ let build_shift ?(lock_base=false) ~flags:{ print_units } ~base symbols =
262
272
| Not_found ->
263
273
let base, (c,_) = allocate_global_symbol_aux (Ast.Func. from_string name) base in
264
274
base, Map. add v c shift
265
- else
275
+ end
276
+ | BoundVariable (v , t ) ->
266
277
if Map. mem v base.c2t then acc
267
278
else
268
279
let base = { base with c2t = Map. add v t base.c2t } in
269
280
base, shift
270
281
)
271
- (base,Map. empty) ( List. rev ( Map. bindings symbols.c2t0)) )
282
+ (base, Map. empty) symbols)
272
283
273
284
let build_shift ?lock_base ~flags ~base symbols =
274
285
try Stdlib.Result. Ok (build_shift ?lock_base ~flags ~base symbols)
@@ -533,9 +544,6 @@ type program = {
533
544
clauses : (preterm ,Ast.Structured .attribute ) Ast.Clause .t list ;
534
545
chr : (constant list * prechr_rule list ) list ;
535
546
local_names : int ;
536
- symbols : C.Set .t ;
537
-
538
- toplevel_macros : macro_declaration ;
539
547
}
540
548
[@@ deriving show ]
541
549
@@ -579,7 +587,7 @@ type compilation_unit = {
579
587
580
588
type builtins = string * Data.BuiltInPredicate .declaration list
581
589
582
- type header = State .t * compilation_unit
590
+ type header = State .t * compilation_unit * macro_declaration
583
591
type program = State .t * Assembled .program
584
592
585
593
@@ -1497,7 +1505,7 @@ module Flatten : sig
1497
1505
1498
1506
(* Eliminating the structure (name spaces) *)
1499
1507
1500
- val run : State .t -> Structured .program -> Flat .program
1508
+ val run : State .t -> Structured .program -> C.Set .t * macro_declaration * Flat .program
1501
1509
1502
1510
val relocate : State .t -> D .constant D.Constants.Map .t -> Flat .program -> Flat .program
1503
1511
val relocate_term : State .t -> D .constant D.Constants.Map .t -> term -> term
@@ -1696,14 +1704,12 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } =
1696
1704
let modes = apply_subst_modes ~live_symbols empty_subst modes in
1697
1705
let types, type_abbrevs, modes, clauses, chr =
1698
1706
compile_body live_symbols state local_names types type_abbrevs modes [] [] empty_subst body in
1699
- { Flat. types;
1707
+ ! live_symbols, toplevel_macros, { Flat. types;
1700
1708
type_abbrevs;
1701
1709
modes;
1702
1710
clauses;
1703
1711
chr = List. rev chr;
1704
1712
local_names;
1705
- toplevel_macros;
1706
- symbols = ! live_symbols
1707
1713
}
1708
1714
let relocate_term state s t =
1709
1715
let ksub = apply_subst_constant ([] ,s) in
@@ -1716,8 +1722,6 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } =
1716
1722
clauses;
1717
1723
chr;
1718
1724
local_names;
1719
- toplevel_macros;
1720
- symbols;
1721
1725
} =
1722
1726
let f = [] , f in
1723
1727
{
@@ -1727,8 +1731,6 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } =
1727
1731
clauses = apply_subst_clauses state f clauses;
1728
1732
chr = smart_map (apply_subst_chr state f) chr;
1729
1733
local_names;
1730
- toplevel_macros;
1731
- symbols;
1732
1734
}
1733
1735
1734
1736
@@ -2073,7 +2075,7 @@ let assemble flags state code (ul : compilation_unit list) =
2073
2075
2074
2076
let state, clauses_rev, types, type_abbrevs, modes, chr_rev =
2075
2077
List. fold_left (fun (state , cl1 , t1 , ta1 , m1 , c1 ) ({ symbol_table; code } as _u ) ->
2076
- let state, { Flat. clauses = cl2; types = t2; type_abbrevs = ta2; modes = m2; chr = c2; toplevel_macros = _ } =
2078
+ let state, { Flat. clauses = cl2; types = t2; type_abbrevs = ta2; modes = m2; chr = c2; } =
2077
2079
let state, shift = Stdlib.Result. get_ok @@ Symbols. build_shift ~flags ~base: state symbol_table in
2078
2080
let code =
2079
2081
if C.Map. is_empty shift then code
@@ -2146,7 +2148,7 @@ let unit_or_header_of_ast { print_passes } s ?(toplevel_macros=F.Map.empty) p =
2146
2148
Format. eprintf " == Structured ================@\n @[<v 0>%a@]@\n "
2147
2149
(w_symbol_table s Structured. pp_program) p;
2148
2150
2149
- let p = Flatten. run s p in
2151
+ let alive, toplevel_macros, p = Flatten. run s p in
2150
2152
2151
2153
if print_passes then
2152
2154
Format. eprintf " == Flat ================@\n @[<v 0>%a@]@\n "
@@ -2155,8 +2157,8 @@ let unit_or_header_of_ast { print_passes } s ?(toplevel_macros=F.Map.empty) p =
2155
2157
s, {
2156
2158
version = " %%VERSION_NUM%%" ;
2157
2159
code = p;
2158
- symbol_table = Symbols. prune (State. get Symbols. table s) ~alive: p. Flat. symbols
2159
- }
2160
+ symbol_table = Symbols. prune (State. get Symbols. table s) ~alive
2161
+ }, toplevel_macros
2160
2162
;;
2161
2163
2162
2164
let print_unit { print_units } x =
@@ -2199,25 +2201,24 @@ let header_of_ast ~flags ~parser:p state_descriptor quotation_descriptor hoas_de
2199
2201
| Data.BuiltInPredicate. MLDataC _ -> state
2200
2202
| Data.BuiltInPredicate. LPCode _ -> state
2201
2203
| Data.BuiltInPredicate. LPDoc _ -> state) state decls) state builtins in
2202
- let state, u = unit_or_header_of_ast flags state ast in
2204
+ let state, u, toplevel_macros = unit_or_header_of_ast flags state ast in
2203
2205
print_unit flags u;
2204
- state, u
2206
+ state, u, toplevel_macros
2205
2207
2206
- let unit_of_ast ~flags ~header :(s , (header : compilation_unit )) p : compilation_unit =
2207
- let toplevel_macros = header.code.Flat. toplevel_macros in
2208
- let _, u = unit_or_header_of_ast flags s ~toplevel_macros p in
2208
+ let unit_of_ast ~flags ~header :(s , (header : compilation_unit ), toplevel_macros ) p : compilation_unit =
2209
+ let _, u, _ = unit_or_header_of_ast flags s ~toplevel_macros p in
2209
2210
print_unit flags u;
2210
2211
u
2211
2212
2212
- let assemble_units ~flags ~header :(s ,h ) units : program =
2213
+ let assemble_units ~flags ~header :(s ,h , toplevel_macros ) units : program =
2213
2214
2214
2215
let nunits_with_locals =
2215
2216
(h :: units) |> List. filter (fun {code = { Flat. local_names = x } } -> x > 0 ) |> List. length in
2216
2217
2217
2218
if nunits_with_locals > 0 then
2218
2219
error " Only 1 compilation unit is supported when local directives are used" ;
2219
2220
2220
- let init = { Assembled. empty with toplevel_macros = h.code.toplevel_macros ; local_names = h.code.local_names } in
2221
+ let init = { Assembled. empty with toplevel_macros; local_names = h.code.local_names } in
2221
2222
2222
2223
let s, p = Assemble. assemble flags s init (h :: units) in
2223
2224
0 commit comments