Skip to content

Commit

Permalink
Python backend: fix handling of renamed modules
Browse files Browse the repository at this point in the history
  • Loading branch information
AltGr committed Sep 24, 2024
1 parent a20bdc8 commit 6238619
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 19 deletions.
25 changes: 17 additions & 8 deletions compiler/scalc/to_python.ml
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,17 @@ let renaming =
~reserved:python_keywords
(* TODO: add catala runtime built-ins as reserved as well ? *)
~skip_constant_binders:false ~constant_binder_name:None
~namespaced_fields_constrs:true ~prefix_module:false ~f_var:String.to_ascii
~f_struct:String.to_camel_case ~f_enum:String.to_camel_case
~namespaced_fields_constrs:true ~prefix_module:false
~f_var:String.to_snake_case ~f_struct:String.to_camel_case
~f_enum:String.to_camel_case

let format_qualified (type id) (module Id: Uid.Qualified with type t = id) ctx ppf (s: id) =
match List.rev (Id.path s) with
| [] -> Format.pp_print_string ppf (Id.base s)
| m :: _ -> Format.fprintf ppf "%a.%s" VarName.format (ModuleName.Map.find m ctx.modules) (Id.base s)

let format_struct = format_qualified (module StructName)
let format_enum = format_qualified (module EnumName)

let typ_needs_parens (e : typ) : bool =
match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false
Expand All @@ -199,12 +208,12 @@ let rec format_typ ctx (fmt : Format.formatter) (typ : typ) : unit =
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t))
ts
| TStruct s -> StructName.format fmt s
| TStruct s -> format_struct ctx fmt s
| TOption some_typ ->
(* We translate the option type with an overloading by Python's [None] *)
Format.fprintf fmt "Optional[%a]" format_typ some_typ
| TDefault t -> format_typ fmt t
| TEnum e -> EnumName.format fmt e
| TEnum e -> format_enum ctx fmt e
| TArrow (t1, t2) ->
Format.fprintf fmt "Callable[[%a], %a]"
(Format.pp_print_list
Expand All @@ -223,7 +232,7 @@ let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit =
| EVar v -> VarName.format fmt v
| EFunc f -> FuncName.format fmt f
| EStruct { fields = es; name = s } ->
Format.fprintf fmt "%a(%a)" StructName.format s
Format.fprintf fmt "%a(%a)" (format_struct ctx) s
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (struct_field, e) ->
Expand All @@ -244,8 +253,8 @@ let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit =
(* We translate the option type with an overloading by Python's [None] *)
format_expression ctx fmt e
| EInj { e1 = e; cons; name = enum_name; _ } ->
Format.fprintf fmt "%a(%a_Code.%a,@ %a)" EnumName.format enum_name
EnumName.format enum_name EnumConstructor.format cons
Format.fprintf fmt "%a(%a_Code.%a,@ %a)" (format_enum ctx) enum_name
(format_enum ctx) enum_name EnumConstructor.format cons
(format_expression ctx) e
| EArray es ->
Format.fprintf fmt "[%a]"
Expand Down Expand Up @@ -406,7 +415,7 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[<hov 4>elif ")
(fun fmt (case, cons_name) ->
Format.fprintf fmt "%a.code == %a_Code.%a:@," VarName.format
switch_var EnumName.format e_name EnumConstructor.format cons_name;
switch_var (format_enum ctx) e_name EnumConstructor.format cons_name;
format_block ctx fmt
(Utils.subst_block case.payload_var_name
(* Not a real catala struct, but will print as <var>.value *)
Expand Down
14 changes: 7 additions & 7 deletions tests/backends/python_name_clash.catala_en
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ class BIn:
return "BIn()".format()


def SomeName(SomeName_in:SomeNameIn):
i = (SomeName_in.i_in)
def some_name(some_name_in:SomeNameIn):
i = (some_name_in.i_in)
o__1 = ((i + integer_of_string("1")))
if o__1 is None:
pos = (SourcePosition(
Expand All @@ -102,7 +102,7 @@ def SomeName(SomeName_in:SomeNameIn):
o = (o__1)
return SomeName(o = o)

def B(B_in:BIn):
def b(b_in:BIn):
result__2 = (integer_of_string("1"))
if result__2 is None:
pos = (SourcePosition(
Expand All @@ -112,12 +112,12 @@ def B(B_in:BIn):
raise NoValue(pos)
else:
result__1 = (result__2)
result = (SomeName(SomeNameIn(i_in = result__1)))
result = (some_name(SomeNameIn(i_in = result__1)))
result__3 = (SomeName(o = result.o))
if True:
some_name = (result__3)
some_name__1 = (result__3)
else:
some_name = (result__3)
return B(some_name = some_name)
some_name__1 = (result__3)
return B(some_name = some_name__1)
```
The above should *not* show `some_name = temp_some_name`, but instead `some_name_1 = ...`
8 changes: 4 additions & 4 deletions tests/name_resolution/good/toplevel_defs.catala_en
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ glob2 = (
decimal_of_string("17.")))
)

def S2(S2_in:S2In):
def s2(s2_in:S2In):
a__1 = ((glob3(money_of_cents_string("4400")) +
decimal_of_string("100.")))
if a__1 is None:
Expand All @@ -369,7 +369,7 @@ def S2(S2_in:S2In):
a = (a__1)
return S2(a = a)

def S3(S3_in:S3In):
def s3(s3_in:S3In):
a__1 = ((decimal_of_string("50.") +
glob4(money_of_cents_string("4400"), decimal_of_string("55."))))
if a__1 is None:
Expand All @@ -383,7 +383,7 @@ def S3(S3_in:S3In):
a = (a__1)
return S3(a = a)

def S4(S4_in:S4In):
def s4(s4_in:S4In):
a__1 = ((glob5 + decimal_of_string("1.")))
if a__1 is None:
pos = (SourcePosition(
Expand All @@ -396,7 +396,7 @@ def S4(S4_in:S4In):
a = (a__1)
return S4(a = a)

def S(S_in:SIn):
def s(s_in:SIn):
a__1 = ((glob1 * glob1))
if a__1 is None:
pos = (SourcePosition(
Expand Down

0 comments on commit 6238619

Please sign in to comment.