Skip to content

Commit f0307cd

Browse files
committed
fix(cs): Enum marshaling
1 parent e87f65c commit f0307cd

7 files changed

Lines changed: 191 additions & 16 deletions

File tree

hosts/dotnet/Hako.SourceGenerator.Tests/JSBindingGeneratorTests.cs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5229,7 +5229,7 @@ public void SetMode(FileMode mode) { }
52295229
var generatedCode = result.GeneratedTrees.First(t => t.FilePath.Contains("FileHandler")).GetText().ToString();
52305230

52315231
// Should marshal enum as string
5232-
Assert.Contains(".ToString())", generatedCode);
5232+
Assert.Contains(".ToStringFast())", generatedCode);
52335233
Assert.Contains("ctx.NewString", generatedCode);
52345234

52355235
// Should unmarshal enum from string
@@ -5332,7 +5332,7 @@ public partial class User
53325332

53335333
// Should handle nullable enum marshaling
53345334
Assert.Contains("ctx.Null()", generatedCode);
5335-
Assert.Contains(".ToString())", generatedCode);
5335+
Assert.Contains(".ToStringFast())", generatedCode);
53365336

53375337
// Should handle nullable enum unmarshaling
53385338
Assert.Contains("IsNullOrUndefined()", generatedCode);
@@ -5375,8 +5375,7 @@ public void SetPriorities(Priority[] priorities) { }
53755375
var generatedCode = result.GeneratedTrees.First(t => t.FilePath.Contains("TaskManager")).GetText().ToString();
53765376

53775377
// Should handle enum arrays
5378-
Assert.Contains("ToJSArrayOf", generatedCode);
5379-
Assert.Contains("ToArrayOf<global::TestNamespace.Priority>", generatedCode);
5378+
Assert.Contains(" var Priorities = args[0].ToArray<string>().Select(x => global::System.Enum.Parse<global::TestNamespace.Priority>(x, ignoreCase: true)).ToArray();", generatedCode);
53805379
}
53815380

53825381
[Fact]
@@ -5412,7 +5411,7 @@ public partial record LogEntry(
54125411

54135412
// ToJSValue should marshal enum
54145413
Assert.Contains("realm.NewString", generatedCode);
5415-
Assert.Contains(".ToString())", generatedCode);
5414+
Assert.Contains(".ToStringFast())", generatedCode);
54165415

54175416
// FromJSValue should unmarshal enum
54185417
Assert.Contains("global::System.Enum.Parse<", generatedCode);
@@ -5459,7 +5458,7 @@ public static void Request(string url, HttpMethod method) { }
54595458

54605459
// Should marshal enum in module
54615460
Assert.Contains("ctx.NewString", generatedCode);
5462-
Assert.Contains(".ToString())", generatedCode);
5461+
Assert.Contains(".ToStringFast())", generatedCode);
54635462

54645463
// Should unmarshal enum from module method parameter
54655464
Assert.Contains("global::System.Enum.Parse<", generatedCode);

hosts/dotnet/Hako.SourceGenerator/JSBindingGenerator.Enums.cs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System;
12
using System.Collections.Generic;
23
using System.Collections.Immutable;
34
using System.Linq;
@@ -104,6 +105,81 @@ private static string GenerateEnumBinding(
104105
sb.AppendLine();
105106

106107
var accessibility = GetAccessibilityModifier(model.DeclaredAccessibility);
108+
109+
// Generate the extension class for ToStringFast
110+
sb.AppendLine($"{accessibility} static class {model.EnumName}Extensions");
111+
sb.AppendLine("{");
112+
113+
// Generate ToStringFast method
114+
sb.AppendLine($" {accessibility} static string ToStringFast(this {model.EnumName} value)");
115+
116+
if (model.IsFlags)
117+
{
118+
// For flags enums, handle combined values
119+
sb.AppendLine(" {");
120+
sb.AppendLine(" return value switch");
121+
sb.AppendLine(" {");
122+
123+
// Add case for zero
124+
var zeroValue = model.Values.FirstOrDefault(v => Convert.ToInt64(v.Value) == 0);
125+
if (zeroValue != null)
126+
{
127+
sb.AppendLine($" 0 => nameof({model.EnumName}.{zeroValue.Name}),");
128+
}
129+
else
130+
{
131+
sb.AppendLine(" 0 => \"0\",");
132+
}
133+
134+
// Add cases for individual known values
135+
foreach (var enumValue in model.Values.Where(v => Convert.ToInt64(v.Value) != 0))
136+
{
137+
sb.AppendLine(
138+
$" {model.EnumName}.{enumValue.Name} => nameof({model.EnumName}.{enumValue.Name}),");
139+
}
140+
141+
sb.AppendLine(" _ => FormatFlags(value)");
142+
sb.AppendLine(" };");
143+
sb.AppendLine();
144+
145+
// Generate the helper method for building the flags string
146+
sb.AppendLine($" static string FormatFlags({model.EnumName} value)");
147+
sb.AppendLine(" {");
148+
sb.AppendLine(" var flags = new System.Collections.Generic.List<string>();");
149+
sb.AppendLine();
150+
151+
foreach (var enumValue in model.Values.Where(v => Convert.ToInt64(v.Value) != 0))
152+
{
153+
var enumValueLong = Convert.ToInt64(enumValue.Value);
154+
sb.AppendLine(
155+
$" if ((value & {model.EnumName}.{enumValue.Name}) == {model.EnumName}.{enumValue.Name})");
156+
sb.AppendLine($" flags.Add(nameof({model.EnumName}.{enumValue.Name}));");
157+
}
158+
159+
sb.AppendLine();
160+
sb.AppendLine(" return flags.Count > 0 ? string.Join(\", \", flags) : value.ToString();");
161+
sb.AppendLine(" }");
162+
sb.AppendLine(" }");
163+
}
164+
else
165+
{
166+
// For normal enums, use a switch expression
167+
sb.AppendLine(" => value switch");
168+
sb.AppendLine(" {");
169+
170+
foreach (var enumValue in model.Values)
171+
{
172+
sb.AppendLine(
173+
$" {model.EnumName}.{enumValue.Name} => nameof({model.EnumName}.{enumValue.Name}),");
174+
}
175+
176+
sb.AppendLine(" _ => value.ToString(),");
177+
sb.AppendLine(" };");
178+
}
179+
180+
sb.AppendLine("}");
181+
sb.AppendLine();
182+
107183
var useCSharp14Extensions = compilationSettings.LanguageVersion is >= LanguageVersion.CSharp14;
108184

109185
if (useCSharp14Extensions)

hosts/dotnet/Hako.SourceGenerator/JSBindingGenerator.Marshaling.cs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,22 @@ private static string GetStrictUnmarshalCode(TypeInfo type, string jsValueName,
11501150
{
11511151
var elementTypeName = type.ElementType.Replace("global::", "");
11521152

1153+
// Check if element is a [JSEnum]
1154+
if (type.ItemTypeSymbol != null && type.ItemTypeSymbol.IsJSEnum())
1155+
{
1156+
var fullEnumType = type.ItemTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
1157+
1158+
if (type.ItemTypeSymbol.IsJSEnumFlags())
1159+
{
1160+
return $"{jsValueName}.ToArray<int>().Select(x => ({fullEnumType})x).ToArray()";
1161+
}
1162+
else
1163+
{
1164+
return
1165+
$"{jsValueName}.ToArray<string>().Select(x => global::System.Enum.Parse<{fullEnumType}>(x, ignoreCase: true)).ToArray()";
1166+
}
1167+
}
1168+
11531169
if (IsPrimitiveTypeName(type.ElementType) ||
11541170
elementTypeName is "System.Object" or "object")
11551171
return $"{jsValueName}.ToArray<{type.ElementType}>()";
@@ -1199,6 +1215,25 @@ private static string GetUnmarshalCode(TypeInfo type, string jsValueName, string
11991215
return
12001216
$"({jsValueName}.IsArrayBuffer() ? {jsValueName}.CopyArrayBuffer() : {jsValueName}.CopyTypedArray())";
12011217

1218+
if (type is { IsArray: true, ElementType: not null })
1219+
{
1220+
// Check if element is a [JSEnum]
1221+
if (type.ItemTypeSymbol != null && type.ItemTypeSymbol.IsJSEnum())
1222+
{
1223+
var fullEnumType = type.ItemTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
1224+
1225+
if (type.ItemTypeSymbol.IsJSEnumFlags())
1226+
{
1227+
return $"{jsValueName}.ToArray<int>().Select(x => ({fullEnumType})x).ToArray()";
1228+
}
1229+
else
1230+
{
1231+
return
1232+
$"{jsValueName}.ToArray<string>().Select(x => global::System.Enum.Parse<{fullEnumType}>(x, ignoreCase: true)).ToArray()";
1233+
}
1234+
}
1235+
}
1236+
12021237
return $"{type.FullName}.FromJSValue({contextVarName}, {jsValueName})";
12031238
}
12041239

@@ -1368,8 +1403,8 @@ private static string GetMarshalCodeForPrimitive(TypeInfo type, string valueName
13681403
: $"{ctxName}.NewNumber((int){valueName})";
13691404

13701405
return type.IsNullable
1371-
? $"({valueName} == null ? {ctxName}.Null() : {ctxName}.NewString({valueName}.ToString()))"
1372-
: $"{ctxName}.NewString({valueName}.ToString())";
1406+
? $"({valueName} == null ? {ctxName}.Null() : {ctxName}.NewString({valueName}.ToStringFast()))"
1407+
: $"{ctxName}.NewString({valueName}.ToStringFast())";
13731408
}
13741409

13751410
if (type.FullName == "global::System.Byte[]")
@@ -1381,6 +1416,23 @@ private static string GetMarshalCodeForPrimitive(TypeInfo type, string valueName
13811416
{
13821417
var elementTypeName = type.ElementType.Replace("global::", "");
13831418

1419+
// Check if element is a [JSEnum]
1420+
if (type.ItemTypeSymbol != null && type.ItemTypeSymbol.IsJSEnum())
1421+
{
1422+
if (type.ItemTypeSymbol.IsJSEnumFlags())
1423+
{
1424+
return type.IsNullable
1425+
? $"({valueName} == null ? {ctxName}.Null() : {ctxName}.NewArray({valueName}.Select(x => (int)x)))"
1426+
: $"{ctxName}.NewArray({valueName}.Select(x => (int)x))";
1427+
}
1428+
else
1429+
{
1430+
return type.IsNullable
1431+
? $"({valueName} == null ? {ctxName}.Null() : {ctxName}.NewArray({valueName}.Select(x => x.ToStringFast())))"
1432+
: $"{ctxName}.NewArray({valueName}.Select(x => x.ToStringFast()))";
1433+
}
1434+
}
1435+
13841436
if (IsPrimitiveTypeName(type.ElementType) ||
13851437
elementTypeName is "System.Object" or "object")
13861438
return type.IsNullable

hosts/dotnet/Hako.SourceGenerator/JSBindingGenerator.cs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1766,16 +1766,20 @@ private static TypeInfo CreateTypeInfo(ITypeSymbol type)
17661766
var isArray = type.TypeKind == TypeKind.Array;
17671767

17681768
string? elementType = null;
1769+
ITypeSymbol? itemTypeSymbol = null;
1770+
17691771
if (isArray && type is IArrayTypeSymbol arrayType)
1772+
{
1773+
itemTypeSymbol = arrayType.ElementType;
17701774
elementType = arrayType.ElementType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
1775+
}
17711776

17721777
var specialType = type.SpecialType;
17731778

17741779
ITypeSymbol? underlyingType = null;
17751780
if (type.IsNullableValueType() && type is INamedTypeSymbol { TypeArguments.Length: > 0 } namedType)
17761781
underlyingType = namedType.TypeArguments[0];
1777-
1778-
// Check if it's a [JSEnum]
1782+
17791783
var isEnum = false;
17801784
var isFlags = false;
17811785

@@ -1788,8 +1792,7 @@ private static TypeInfo CreateTypeInfo(ITypeSymbol type)
17881792
isFlags = enumSymbol.GetAttributes()
17891793
.Any(a => a.AttributeClass?.ToDisplayString() == "System.FlagsAttribute");
17901794
}
1791-
1792-
// Check if it's a generic dictionary
1795+
17931796
var isGenericDictionary = false;
17941797
string? keyType = null;
17951798
string? valueType = null;
@@ -1811,13 +1814,11 @@ private static TypeInfo CreateTypeInfo(ITypeSymbol type)
18111814
valueType = valueTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
18121815
}
18131816
}
1814-
1815-
// Check if it's a generic collection
1817+
18161818
var isGenericCollection = false;
18171819
string? itemType = null;
1818-
ITypeSymbol? itemTypeSymbol = null;
18191820

1820-
if (type is INamedTypeSymbol { IsGenericType: true } collectionType && !isGenericDictionary)
1821+
if (type is INamedTypeSymbol { IsGenericType: true } collectionType && !isGenericDictionary && !isArray)
18211822
{
18221823
var typeDefinition = collectionType.ConstructedFrom.ToDisplayString();
18231824
if (typeDefinition is "System.Collections.Generic.List<T>" or

hosts/dotnet/Hako.SourceGenerator/TypeSymbolExtensions.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,4 +189,17 @@ public static IEnumerable<IPropertySymbol> GetPropertiesInHierarchy(this INamedT
189189
break;
190190
}
191191
}
192+
193+
public static bool IsJSEnum(this ITypeSymbol typeSymbol)
194+
{
195+
return typeSymbol.TypeKind == TypeKind.Enum &&
196+
HasAttribute(typeSymbol, "HakoJS.SourceGeneration.JSEnumAttribute");
197+
}
198+
199+
public static bool IsJSEnumFlags(this ITypeSymbol typeSymbol)
200+
{
201+
return typeSymbol.IsJSEnum() &&
202+
typeSymbol.GetAttributes()
203+
.Any(a => a.AttributeClass?.ToDisplayString() == "System.FlagsAttribute");
204+
}
192205
}

hosts/dotnet/Hako/Extensions/JSValueExtensions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,7 @@ public static T[] ToArray<T>(this JSValue jsValue)
625625

626626
return array;
627627
}
628+
628629

629630
/// <summary>
630631
/// Converts a JavaScript array to a .NET array of types implementing <see cref="IJSMarshalable{T}" />.

hosts/dotnet/Hako/VM/Realm.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,39 @@ public JSValue NewArray()
830830
{
831831
return _valueFactory.FromNativeValue(Array.Empty<object>());
832832
}
833+
834+
/// <summary>
835+
/// Creates a new JavaScript array from a variable number of JSValue objects.
836+
/// </summary>
837+
/// <param name="items">The JSValue objects to populate the array with.</param>
838+
/// <returns>A <see cref="JSValue"/> representing the new array.</returns>
839+
public JSValue NewArray(params object[] items)
840+
{
841+
var array = NewArray();
842+
try
843+
{
844+
for (int i = 0; i < items.Length; i++)
845+
{
846+
array.SetProperty(i, items[i]);
847+
}
848+
return array;
849+
}
850+
catch
851+
{
852+
array.Dispose();
853+
throw;
854+
}
855+
}
856+
857+
/// <summary>
858+
/// Creates a new JavaScript array from an enumerable collection of JSValue objects.
859+
/// </summary>
860+
/// <param name="items">The JSValue objects to populate the array with.</param>
861+
/// <returns>A <see cref="JSValue"/> representing the new array.</returns>
862+
public JSValue NewArray(IEnumerable<object> items)
863+
{
864+
return NewArray(items.ToArray());
865+
}
833866

834867
/// <summary>
835868
/// Creates a new ArrayBuffer from byte data.

0 commit comments

Comments
 (0)