diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 1e39d3e4d..d3b0e6dff 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -173,8 +173,17 @@ let renaming = ~reserved:python_keywords (* TODO: add catala runtime built-ins as reserved as well ? *) ~skip_constant_binders:false ~constant_binder_name:None - ~namespaced_fields_constrs:true ~prefix_module:false ~f_var:String.to_ascii - ~f_struct:String.to_camel_case ~f_enum:String.to_camel_case + ~namespaced_fields_constrs:true ~prefix_module:false + ~f_var:String.to_snake_case ~f_struct:String.to_camel_case + ~f_enum:String.to_camel_case + +let format_qualified (type id) (module Id: Uid.Qualified with type t = id) ctx ppf (s: id) = + match List.rev (Id.path s) with + | [] -> Format.pp_print_string ppf (Id.base s) + | m :: _ -> Format.fprintf ppf "%a.%s" VarName.format (ModuleName.Map.find m ctx.modules) (Id.base s) + +let format_struct = format_qualified (module StructName) +let format_enum = format_qualified (module EnumName) let typ_needs_parens (e : typ) : bool = match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false @@ -199,12 +208,12 @@ let rec format_typ ctx (fmt : Format.formatter) (typ : typ) : unit = ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") (fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t)) ts - | TStruct s -> StructName.format fmt s + | TStruct s -> format_struct ctx fmt s | TOption some_typ -> (* We translate the option type with an overloading by Python's [None] *) Format.fprintf fmt "Optional[%a]" format_typ some_typ | TDefault t -> format_typ fmt t - | TEnum e -> EnumName.format fmt e + | TEnum e -> format_enum ctx fmt e | TArrow (t1, t2) -> Format.fprintf fmt "Callable[[%a], %a]" (Format.pp_print_list @@ -223,7 +232,7 @@ let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit = | EVar v -> VarName.format fmt v | EFunc f -> FuncName.format fmt f | EStruct { fields = es; name = s } -> - Format.fprintf fmt "%a(%a)" StructName.format s + Format.fprintf fmt "%a(%a)" (format_struct ctx) s (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (struct_field, e) -> @@ -244,8 +253,8 @@ let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit = (* We translate the option type with an overloading by Python's [None] *) format_expression ctx fmt e | EInj { e1 = e; cons; name = enum_name; _ } -> - Format.fprintf fmt "%a(%a_Code.%a,@ %a)" EnumName.format enum_name - EnumName.format enum_name EnumConstructor.format cons + Format.fprintf fmt "%a(%a_Code.%a,@ %a)" (format_enum ctx) enum_name + (format_enum ctx) enum_name EnumConstructor.format cons (format_expression ctx) e | EArray es -> Format.fprintf fmt "[%a]" @@ -406,7 +415,7 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[elif ") (fun fmt (case, cons_name) -> Format.fprintf fmt "%a.code == %a_Code.%a:@," VarName.format - switch_var EnumName.format e_name EnumConstructor.format cons_name; + switch_var (format_enum ctx) e_name EnumConstructor.format cons_name; format_block ctx fmt (Utils.subst_block case.payload_var_name (* Not a real catala struct, but will print as .value *) diff --git a/tests/backends/python_name_clash.catala_en b/tests/backends/python_name_clash.catala_en index 33788aa74..60525217f 100644 --- a/tests/backends/python_name_clash.catala_en +++ b/tests/backends/python_name_clash.catala_en @@ -89,8 +89,8 @@ class BIn: return "BIn()".format() -def SomeName(SomeName_in:SomeNameIn): - i = (SomeName_in.i_in) +def some_name(some_name_in:SomeNameIn): + i = (some_name_in.i_in) o__1 = ((i + integer_of_string("1"))) if o__1 is None: pos = (SourcePosition( @@ -102,7 +102,7 @@ def SomeName(SomeName_in:SomeNameIn): o = (o__1) return SomeName(o = o) -def B(B_in:BIn): +def b(b_in:BIn): result__2 = (integer_of_string("1")) if result__2 is None: pos = (SourcePosition( @@ -112,12 +112,12 @@ def B(B_in:BIn): raise NoValue(pos) else: result__1 = (result__2) - result = (SomeName(SomeNameIn(i_in = result__1))) + result = (some_name(SomeNameIn(i_in = result__1))) result__3 = (SomeName(o = result.o)) if True: - some_name = (result__3) + some_name__1 = (result__3) else: - some_name = (result__3) - return B(some_name = some_name) + some_name__1 = (result__3) + return B(some_name = some_name__1) ``` The above should *not* show `some_name = temp_some_name`, but instead `some_name_1 = ...` diff --git a/tests/name_resolution/good/toplevel_defs.catala_en b/tests/name_resolution/good/toplevel_defs.catala_en index d5b5a1970..fdd6c7234 100644 --- a/tests/name_resolution/good/toplevel_defs.catala_en +++ b/tests/name_resolution/good/toplevel_defs.catala_en @@ -355,7 +355,7 @@ glob2 = ( decimal_of_string("17."))) ) -def S2(S2_in:S2In): +def s2(s2_in:S2In): a__1 = ((glob3(money_of_cents_string("4400")) + decimal_of_string("100."))) if a__1 is None: @@ -369,7 +369,7 @@ def S2(S2_in:S2In): a = (a__1) return S2(a = a) -def S3(S3_in:S3In): +def s3(s3_in:S3In): a__1 = ((decimal_of_string("50.") + glob4(money_of_cents_string("4400"), decimal_of_string("55.")))) if a__1 is None: @@ -383,7 +383,7 @@ def S3(S3_in:S3In): a = (a__1) return S3(a = a) -def S4(S4_in:S4In): +def s4(s4_in:S4In): a__1 = ((glob5 + decimal_of_string("1."))) if a__1 is None: pos = (SourcePosition( @@ -396,7 +396,7 @@ def S4(S4_in:S4In): a = (a__1) return S4(a = a) -def S(S_in:SIn): +def s(s_in:SIn): a__1 = ((glob1 * glob1)) if a__1 is None: pos = (SourcePosition(