Skip to content

Commit

Permalink
C backend: fix handling of struct inits that may be const or not
Browse files Browse the repository at this point in the history
This happens when there are two branches that are bound to an existing struct,
and to a literal one that needs a malloc+non-const;

the fix refines the handling of `no_struct_literals` in scalc, to ensure that
`SLocalInit` is used on a local temporary variable for the constructs needing
malloc.
  • Loading branch information
AltGr committed Sep 16, 2024
1 parent ee56965 commit 2e68915
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 77 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ reset-tests: .FORCE $(CLERK_BIN)
# $(CLERK_TEST) test $@

%.c.exe: %.catala_en .FORCE
$(CATALA_BIN) c $<
$(CATALA_BIN) c $(CATALAOPTS) $<
cc --std=c89 -Wall -pedantic $*.c -lcatala_runtime -lgmp -Wno-unused-but-set-variable -Wno-unused-variable -I $$(ocamlfind query dates_calc)/c -I_build/install/default/lib/catala/runtime_c -L_build/install/default/lib/catala/runtime_c -o $*.c.exe
$@
.FORCE:
Expand Down
59 changes: 37 additions & 22 deletions compiler/scalc/from_lcalc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -319,26 +319,34 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) :
block of statements, and an expression containing the variable holding the
result *)
and spill_expr ctxt expr =
let tmp_var, ctxt = fresh_var ctxt ctxt.context_name ~pos:(Expr.pos expr) in
let pos = Expr.pos expr in
let typ = Expr.maybe_ty (Mark.get expr) in
let tmp_var, ctxt = fresh_var ctxt ctxt.context_name ~pos in
let ctxt =
{ ctxt with context_name = Mark.remove (A.VarName.get_info tmp_var) }
in
let tmp_stmts, ren_ctx =
translate_assignment ctxt (Some (tmp_var, Expr.pos expr)) expr
in
let stmts =
RevBlock.make
[
( A.SLocalDecl
{
name = tmp_var, Expr.pos expr;
typ = Expr.maybe_ty (Mark.get expr);
},
Expr.pos expr );
]
++ tmp_stmts
in
stmts, (A.EVar tmp_var, Expr.pos expr), ren_ctx
match Mark.remove expr with
| (EArray _ | EStruct _ | EInj _ | ETuple _)
when ctxt.config.no_struct_literals ->
(* We want [SLocalInit] for these constructs requiring malloc *)
let stmts, expr, ren_ctx = translate_struct_literal ctxt expr in
( stmts +> (A.SLocalInit { name = tmp_var, pos; expr; typ }, pos),
(A.EVar tmp_var, pos),
ren_ctx )
| _ ->
let tmp_stmts, ren_ctx =
translate_assignment ctxt (Some (tmp_var, Expr.pos expr)) expr
in
let stmts =
RevBlock.make
[
( A.SLocalDecl
{ name = tmp_var, pos; typ = Expr.maybe_ty (Mark.get expr) },
pos );
]
++ tmp_stmts
in
stmts, (A.EVar tmp_var, pos), ren_ctx

(** This translates an expression [block_expr] to a series of statements that
compute its value, and either assign to the given variable, or return it. *)
Expand Down Expand Up @@ -493,13 +501,20 @@ and translate_assignment
},
Expr.pos block_expr ),
ren_ctx )
| ELit _ | EAppOp _ | EArray _ | EVar _ | EStruct _ | EInj _ | ETuple _
| EArray _ | EStruct _ | EInj _ | ETuple _ | ELit _ | EAppOp _ | EVar _
| ETupleAccess _ | EStructAccess _ | EExternal _ | EApp _ ->
let stmts, expr, ren_ctx =
match block_expr with
| ((EStruct _ | EInj _ | ETuple _ | EArray _), _) as e ->
translate_struct_literal ctxt e
| e -> translate_expr ctxt e
match Mark.remove block_expr with
| (EArray _ | EStruct _ | EInj _ | ETuple _) as e ->
let is_option =
match e with
| EInj { name; _ } -> EnumName.equal name Expr.option_enum
| _ -> false
in
if ctxt.config.no_struct_literals && not is_option then
spill_expr ctxt block_expr
else translate_struct_literal ctxt block_expr
| _ -> translate_expr ctxt block_expr
in
( (stmts
+>
Expand Down
49 changes: 21 additions & 28 deletions compiler/scalc/to_c.ml
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,13 @@ let rec format_expression
else VarName.format fmt v
| EFunc f -> FuncName.format fmt f
| EStructFieldAccess { e1; field; _ } ->
Format.fprintf fmt "(%a)->%a" format_expression e1 StructField.format field
let lpar, rpar =
match e1 with
| EVar _, _ | EStructFieldAccess _, _ -> "", ""
| _ -> "(", ")"
in
Format.fprintf fmt "%s%a%s->%a" lpar format_expression e1 rpar
StructField.format field
| EInj { e1; cons; name = enum_name; _ }
when EnumName.equal enum_name Expr.option_enum ->
if EnumConstructor.equal cons Expr.none_constr then
Expand All @@ -269,11 +275,9 @@ let rec format_expression
| ELit l -> Format.fprintf fmt "%a" format_lit (Mark.copy e l)
| EPosLit -> assert false (* Handled only as toplevel definitions *)
| EAppOp { op = ToClosureEnv, _; args = [arg]; _ } ->
Format.fprintf fmt "((catala_closure *)%a)"
format_expression arg
Format.fprintf fmt "((catala_closure *)%a)" format_expression arg
| EAppOp { op = FromClosureEnv, _; args = [arg]; _ } ->
Format.fprintf fmt "((CATALA_TUPLE)%a)"
format_expression arg
Format.fprintf fmt "((CATALA_TUPLE)%a)" format_expression arg
| EAppOp { op = ((Map | Filter), _) as op; args = [arg1; arg2]; _ } ->
Format.fprintf fmt "%a(%a,@ %a)" format_op op format_expression arg1
format_expression arg2
Expand Down Expand Up @@ -603,26 +607,8 @@ and format_block (ctx : ctx) (env : env) (fmt : Format.formatter) (b : block) :
when Mark.equal VarName.equal name n1 ->
format_decls defined_vars remaining
((SLocalInit { name; typ; expr }, m) :: r)
| ((SLocalDecl { name; typ }, _) as decl) :: r ->
let () =
match typ with
| (TArray _ | TStruct _ | TEnum _ | TTuple _), _ ->
let defs =
Utils.filter_map_block
(function
| SLocalDef { name = n1; expr; _ }, _
when Mark.equal VarName.equal name n1 -> Some expr
| _ -> None)
r
in
let malloc, no_malloc = List.partition requires_malloc defs in
(* NOTE: if there are branches that need a malloc and others not, we choose to do the malloc anyway, but without marking the pointer as const. It could be better to delay the malloc to just before the definitions that will need it. *)
if malloc <> [] then
print_init_malloc fmt (no_malloc = []) (Mark.remove name) typ
else
format_statement ctx env fmt decl
| _ -> format_statement ctx env fmt decl
in
| ((SLocalDecl _, _) as decl) :: r ->
format_statement ctx env fmt decl;
format_decls defined_vars remaining r
| ((SLocalInit { name; typ; expr }, m) as init) :: r ->
if requires_malloc expr then (
Expand All @@ -646,9 +632,16 @@ and format_block (ctx : ctx) (env : env) (fmt : Format.formatter) (b : block) :
| [] -> List.rev remaining
in
match List.find_opt (function SFatalError _, _ -> true | _ -> false) b with
| Some (SFatalError { pos_expr = EVar vpos, _; _ }, _ as fatal) ->
(* avoid printing dead code: only print the fatal error (this also avoids warnings about unused or undefined variables) *)
let pos_def = List.find_opt (function SLocalInit {name = v, _; _}, _ -> VarName.equal v vpos | _ -> false) b in
| Some ((SFatalError { pos_expr = EVar vpos, _; _ }, _) as fatal) ->
(* avoid printing dead code: only print the fatal error (this also avoids
warnings about unused or undefined variables) *)
let pos_def =
List.find_opt
(function
| SLocalInit { name = v, _; _ }, _ -> VarName.equal v vpos
| _ -> false)
b
in
Option.iter (format_statement ctx env fmt) pos_def;
format_statement ctx env fmt fatal;
Format.fprintf fmt "@,return NULL;" (* unreachable, but avoids a warning *)
Expand Down
24 changes: 12 additions & 12 deletions compiler/scalc/utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,16 @@ let rec find_block pred = function

let rec filter_map_block pred = function
| [] -> []
| (SIfThenElse { then_block; else_block; _ }, _) as stmt :: r ->
Option.to_list (pred stmt) @
filter_map_block pred then_block @
filter_map_block pred else_block @
filter_map_block pred r
| (SSwitch { switch_cases; _ }, _) as stmt :: r ->
Option.to_list (pred stmt) @
List.flatten
(List.map (fun case -> filter_map_block pred case.case_block) switch_cases)
| ((SIfThenElse { then_block; else_block; _ }, _) as stmt) :: r ->
Option.to_list (pred stmt)
@ filter_map_block pred then_block
@ filter_map_block pred else_block
@ filter_map_block pred r
| stmt :: r ->
Option.to_list (pred stmt) @
filter_map_block pred r
| ((SSwitch { switch_cases; _ }, _) as stmt) :: r ->
Option.to_list (pred stmt)
@ List.flatten
(List.map
(fun case -> filter_map_block pred case.case_block)
switch_cases)
@ filter_map_block pred r
| stmt :: r -> Option.to_list (pred stmt) @ filter_map_block pred r
30 changes: 16 additions & 14 deletions tests/backends/output/simple.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,27 @@ typedef struct Baz_in {

const Baz* baz (const Baz_in* baz_in)
{
const catala_closure* a;
const catala_closure* a = baz_in->a_in;
const Bar* a__1;
const CATALA_OPTION(Bar*) a__2;
const catala_closure* code_and_env;
const CATALA_OPTION(Bar*) a__3;
const catala_closure* code_and_env = a;
const CATALA_OPTION(Bar*) a__3 =
((const CATALA_OPTION(Bar*)(*)(const CLOSURE_ENV, CATALA_UNIT))
code_and_env->funcp)(code_and_env->env, CATALA_UNITVAL);
CATALA_DEC b;
const CATALA_OPTION(CATALA_DEC) b__1;
const CATALA_OPTION(CATALA_DEC) b__2;
CATALA_BOOL b__3;
const CATALA_ARRAY(CATALA_DEC) c;
const CATALA_OPTION(CATALA_ARRAY(CATALA_DEC)) c__1 =
catala_malloc(sizeof(catala_option));
CATALA_ARRAY(CATALA_DEC) c__2 = catala_malloc(sizeof(catala_array));
Baz* baz__1 = catala_malloc(sizeof(Baz));
a = baz_in->a_in;
code_and_env = a;
a__3 = ((const CATALA_OPTION(Bar*)(*)(const CLOSURE_ENV, CATALA_UNIT))
code_and_env->funcp)(code_and_env->env, CATALA_UNITVAL);
CATALA_ARRAY(CATALA_DEC) const c__2 = catala_malloc(sizeof(catala_array));
const CATALA_OPTION(CATALA_ARRAY(CATALA_DEC)) c__1;
Baz* const baz__1 = catala_malloc(sizeof(Baz));
if (a__3->code == catala_option_some) {
a__2 = catala_some(a__3->payload);
} else {
const Bar* a__4;
const CATALA_OPTION(Bar*) a__5 = catala_malloc(sizeof(catala_option));
Bar* a__6 = catala_malloc(sizeof(Bar));
Bar* const a__6 = catala_malloc(sizeof(Bar));
const CATALA_OPTION(Bar*) a__5;
a__6->code = Bar_No;
a__6->payload.No = CATALA_UNITVAL;
a__5 = catala_some(a__6);
Expand All @@ -67,6 +64,7 @@ const Baz* baz (const Baz_in* baz_in)
static const catala_code_position pos[1] =
{{"tests/backends/simple.catala_en", 11, 11, 11, 12}};
catala_error(catala_no_value, pos);
return NULL;
}
a__2 = catala_some(a__4);
}
Expand All @@ -76,6 +74,7 @@ const Baz* baz (const Baz_in* baz_in)
static const catala_code_position pos[1] =
{{"tests/backends/simple.catala_en", 11, 11, 11, 12}};
catala_error(catala_no_value, pos);
return NULL;
}
switch (a__1->code) {
case Bar_No: {
Expand All @@ -102,7 +101,7 @@ const Baz* baz (const Baz_in* baz_in)
break;
}
case Bar_Yes: {
Foo* foo = a__1->payload.Yes;
const Foo* foo = a__1->payload.Yes;
CATALA_DEC b__5;
if (foo->x == CATALA_TRUE) {
b__5 = catala_new_dec_str("1");
Expand All @@ -121,6 +120,7 @@ const Baz* baz (const Baz_in* baz_in)
static const catala_code_position pos[1] =
{{"tests/backends/simple.catala_en", 12, 10, 12, 11}};
catala_error(catala_no_value, pos);
return NULL;
}
c__2->size = 2;
c__2->elements = catala_malloc(2 * sizeof(void*));
Expand All @@ -133,12 +133,14 @@ const Baz* baz (const Baz_in* baz_in)
static const catala_code_position pos[1] =
{{"tests/backends/simple.catala_en", 13, 10, 13, 11}};
catala_error(catala_no_value, pos);
return NULL;
}
baz__1->b = b;
baz__1->c = c;
return baz__1;
}


int main (int argc, char** argv)
{
catala_init();
Expand Down

0 comments on commit 2e68915

Please sign in to comment.