From 8b4a188ea55e2c48dacf791e91d23ee15ed646c3 Mon Sep 17 00:00:00 2001 From: Josh Lane Date: Wed, 26 Nov 2025 05:38:25 -0800 Subject: [PATCH] feat(gen): add array element type discrimination for oneOf/anyOf Enable automatic discrimination between oneOf variants that have array fields with different element types (e.g., string[] vs integer[] vs boolean[]). This extends the type-based discrimination added in PR #1584 to support cases where variants share the same field name with array types but differ in their element types. Implementation: - Add ArrayElementType and ArrayElementTypeID fields to UniqueFieldVariant - Add getArrayElementTypeInfo() to extract element type from array type IDs - Update validation to allow discrimination when array element types differ - Generate decoder code that peeks into arrays using d.Capture() and d.ArrIter() to check first element type without consuming Supported cases: - Basic primitives: string[] vs integer[] vs boolean[] - Object vs primitive: object[] vs string[] - Mixed: array type combined with unique field discrimination Limitations (future work): - Nested arrays (array[array[string]] vs array[array[integer]]) - Complex object arrays (User[] vs Product[] with same object type) --- .../array_element_discrimination.json | 264 ++++++++++++++++++ gen/_template/json/encoders_sum.tmpl | 40 +++ gen/ir/type.go | 9 + gen/schema_gen.go | 10 +- gen/schema_gen_sum.go | 99 ++++++- gen/templates.go | 58 +++- 6 files changed, 468 insertions(+), 12 deletions(-) create mode 100644 _testdata/positive/array_element_discrimination.json diff --git a/_testdata/positive/array_element_discrimination.json b/_testdata/positive/array_element_discrimination.json new file mode 100644 index 000000000..a84ce535a --- /dev/null +++ b/_testdata/positive/array_element_discrimination.json @@ -0,0 +1,264 @@ +{ + "openapi": "3.0.3", + "info": { + "title": "Array Element Type Discrimination Test", + "version": "1.0.0", + "description": "Tests discrimination 
between oneOf variants based on array element types (currently unsupported but should work)" + }, + "paths": { + "/basic-arrays": { + "get": { + "operationId": "getBasicArrays", + "description": "Test basic array element discrimination: string[] vs integer[]", + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/BasicArrayResource" + } + } + } + } + } + } + } + }, + "/object-vs-primitive-arrays": { + "get": { + "operationId": "getObjectVsPrimitiveArrays", + "description": "Test array of objects vs array of primitives", + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ObjectVsPrimitiveResource" + } + } + } + } + } + } + } + }, + "/mixed-discrimination": { + "post": { + "operationId": "postMixedDiscrimination", + "description": "Test mixed discrimination with array field + other fields", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MixedDiscriminationResource" + } + } + } + }, + "responses": { + "200": { + "description": "OK" + } + } + } + } + }, + "components": { + "schemas": { + "BasicArrayResource": { + "oneOf": [ + { + "$ref": "#/components/schemas/StringArrayVariant" + }, + { + "$ref": "#/components/schemas/IntegerArrayVariant" + }, + { + "$ref": "#/components/schemas/BooleanArrayVariant" + } + ] + }, + "StringArrayVariant": { + "type": "object", + "required": ["name", "items"], + "properties": { + "name": { + "type": "string" + }, + "items": { + "type": "array", + "description": "Array of string items", + "items": { + "type": "string" + } + }, + "stringInfo": { + "type": "string", + "description": "Unique to string variant for fallback discrimination" + } + } + }, + "IntegerArrayVariant": { + "type": "object", + "required": ["name", "items"], + 
"properties": { + "name": { + "type": "string" + }, + "items": { + "type": "array", + "description": "Array of integer items", + "items": { + "type": "integer" + } + }, + "intInfo": { + "type": "integer", + "description": "Unique to integer variant for fallback discrimination" + } + } + }, + "BooleanArrayVariant": { + "type": "object", + "required": ["name", "items"], + "properties": { + "name": { + "type": "string" + }, + "items": { + "type": "array", + "description": "Array of boolean items", + "items": { + "type": "boolean" + } + }, + "boolInfo": { + "type": "boolean", + "description": "Unique to boolean variant for fallback discrimination" + } + } + }, + "ObjectVsPrimitiveResource": { + "oneOf": [ + { + "$ref": "#/components/schemas/ObjectArrayVariant" + }, + { + "$ref": "#/components/schemas/PrimitiveArrayVariant" + } + ] + }, + "ObjectArrayVariant": { + "type": "object", + "required": ["id", "items"], + "properties": { + "id": { + "type": "string" + }, + "items": { + "type": "array", + "description": "Array of object items", + "items": { + "type": "object", + "required": ["key", "value"], + "properties": { + "key": { + "type": "string" + }, + "value": { + "type": "integer" + } + } + } + }, + "objectMeta": { + "type": "object", + "description": "Unique to object variant for fallback discrimination", + "properties": { + "version": { + "type": "integer" + } + } + } + } + }, + "PrimitiveArrayVariant": { + "type": "object", + "required": ["id", "items"], + "properties": { + "id": { + "type": "string" + }, + "items": { + "type": "array", + "description": "Array of primitive string items", + "items": { + "type": "string" + } + }, + "primitiveMeta": { + "type": "string", + "description": "Unique to primitive variant for fallback discrimination" + } + } + }, + "MixedDiscriminationResource": { + "oneOf": [ + { + "$ref": "#/components/schemas/MixedStringArrayVariant" + }, + { + "$ref": "#/components/schemas/MixedNumberArrayVariant" + } + ] + }, + 
"MixedStringArrayVariant": { + "type": "object", + "required": ["id", "items", "metadata"], + "properties": { + "id": { + "type": "string" + }, + "items": { + "type": "array", + "description": "Array of strings", + "items": { + "type": "string" + } + }, + "metadata": { + "type": "string", + "description": "Unique to string variant" + } + } + }, + "MixedNumberArrayVariant": { + "type": "object", + "required": ["id", "items", "count"], + "properties": { + "id": { + "type": "string" + }, + "items": { + "type": "array", + "description": "Array of numbers", + "items": { + "type": "number" + } + }, + "count": { + "type": "integer", + "description": "Unique to number variant" + } + } + } + } + } +} diff --git a/gen/_template/json/encoders_sum.tmpl b/gen/_template/json/encoders_sum.tmpl index b8be27e2a..b338bd5c1 100644 --- a/gen/_template/json/encoders_sum.tmpl +++ b/gen/_template/json/encoders_sum.tmpl @@ -153,6 +153,46 @@ func (s *{{ $.Name }}) Decode(d *jx.Decoder) error { } found = true s.Type = match + {{- else if needsArrayElementDiscrimination $variants }} + // Array element discrimination: peek into array to check first element type + if typ := d.Next(); typ != jx.Array { + return d.Skip() + } + // Capture array to peek at first element without consuming + if err := d.Capture(func(d *jx.Decoder) error { + // Check if array is empty + iter, err := d.ArrIter() + if err != nil { + return err + } + if !iter.Next() { + // Empty array - use first variant as default + {{- $firstVariant := index (dedupeVariantsByArrayElementType $variants) 0 }} + if !found { + found = true + s.Type = {{ $firstVariant.VariantType }} + } + return nil + } + elemType := d.Next() + switch elemType { + {{- range $v := dedupeVariantsByArrayElementType $variants }} + {{- if $v.ArrayElementType }} + case {{ $v.ArrayElementType }}: + match := {{ $v.VariantType }} + if found && s.Type != match { + s.Type = "" + return errors.Errorf("multiple oneOf matches: (%v, %v)", s.Type, match) + } + found = 
true + s.Type = match + {{- end }} + {{- end }} + } + return nil + }); err != nil { + return err + } {{- else }} // Multiple variants have this field - use type checking to discriminate typ := d.Next() diff --git a/gen/ir/type.go b/gen/ir/type.go index 89f4cf800..4e474db8d 100644 --- a/gen/ir/type.go +++ b/gen/ir/type.go @@ -38,6 +38,15 @@ type UniqueFieldVariant struct { VariantType string // e.g., "SystemEventEvent" FieldType string // jx.Type constant, e.g., "jx.String" Nullable bool // true if field is nullable (accepts both base type and jx.Null) + + // ArrayElementType is the jx.Type of array elements for array element discrimination. + // Only set when FieldType is "jx.Array" and element type can distinguish variants. + // e.g., "jx.String" for array[string], "jx.Number" for array[integer], "jx.Object" for array[object] + ArrayElementType string + + // ArrayElementTypeID is the full type ID for array elements (e.g., "string", "integer", "object"). + // Used for more detailed discrimination like distinguishing integer vs number. + ArrayElementTypeID string } // SumSpec for KindSum. 
diff --git a/gen/schema_gen.go b/gen/schema_gen.go index e216e1e29..67748ec64 100644 --- a/gen/schema_gen.go +++ b/gen/schema_gen.go @@ -4,6 +4,7 @@ import ( "cmp" "fmt" "path" + "sort" "strings" "github.com/go-faster/errors" @@ -865,7 +866,14 @@ func inferSchemaFromObject(obj map[string]any) *jsonschema.Schema { schema := &jsonschema.Schema{ Type: jsonschema.Object, } - for fieldName, fieldValue := range obj { + // Sort keys for deterministic output + keys := make([]string, 0, len(obj)) + for k := range obj { + keys = append(keys, k) + } + sort.Strings(keys) + for _, fieldName := range keys { + fieldValue := obj[fieldName] prop := jsonschema.Property{ Name: fieldName, Schema: inferSchemaFromValue(fieldValue), diff --git a/gen/schema_gen_sum.go b/gen/schema_gen_sum.go index 3a10f3e3e..c4b61662e 100644 --- a/gen/schema_gen_sum.go +++ b/gen/schema_gen_sum.go @@ -39,6 +39,9 @@ const ( typeIDSum = "sum" typeIDAlias = "alias" typeIDPointer = "pointer" + + // jxTypeArray is the string representation of jx.Array for template generation. + jxTypeArray = "jx.Array" ) // jxTypeForFieldType returns the jx.Type constant name for runtime type checking. @@ -56,7 +59,7 @@ func jxTypeForFieldType(typeID string) string { case typeID == typeIDObject: return "jx.Object" case strings.HasPrefix(typeID, "array["): - return "jx.Array" + return jxTypeArray case strings.HasPrefix(typeID, "map["): return "jx.Object" case strings.HasPrefix(typeID, "enum_"): @@ -67,6 +70,24 @@ func jxTypeForFieldType(typeID string) string { } } +// getArrayElementTypeInfo extracts element type information from an array type ID. +// Returns the element type ID and its corresponding jx.Type. +// For non-array types, returns empty strings. 
+func getArrayElementTypeInfo(typeID string) (elementTypeID, elementJxType string) {
+	// Anything that is not "array[...]" carries no element information.
+	inner, isArray := strings.CutPrefix(typeID, "array[")
+	if !isArray {
+		return "", ""
+	}
+
+	// Drop the closing bracket: "array[string]" -> "string".
+	elementTypeID = strings.TrimSuffix(inner, "]")
+
+	// Resolve the jx.Type name that corresponds to the element type ID.
+	elementJxType = jxTypeForFieldType(elementTypeID)
+	return elementTypeID, elementJxType
+}
+
// getFieldTypeID returns a type identifier for discrimination purposes. // Fields with the same name but different typeIDs can discriminate variants. func getFieldTypeID(t *ir.Type) string { @@ -849,13 +870,18 @@ func (g *schemaGen) oneOf(name string, schema *jsonschema.Schema, side bool) (*i isNullable := (f.Type.IsGeneric() && f.Type.GenericVariant.Nullable) || (f.Type.IsPointer() && f.Type.NilSemantic.Null()) + // Get array element type info for array element discrimination + elemTypeID, elemJxType := getArrayElementTypeInfo(sig.typeID) + // Add to UniqueFields map for template iteration // Include entries even when jxType is empty (simple field-name discrimination) sum.SumSpec.UniqueFields[f.Tag.JSON] = append(sum.SumSpec.UniqueFields[f.Tag.JSON], ir.UniqueFieldVariant{ - VariantName: s.Name, - VariantType: s.Name + sum.Name, - FieldType: jxType, // Empty string means no runtime type check needed - Nullable: isNullable, // true if field accepts null values + VariantName: s.Name, + VariantType: s.Name + sum.Name, + FieldType: jxType, // Empty string means no runtime type check needed + Nullable: isNullable, // true if field accepts null values + ArrayElementType: elemJxType, + ArrayElementTypeID: elemTypeID, }) } } @@ -1022,7 +1048,7 @@ func (g *schemaGen) oneOf(name string, schema *jsonschema.Schema, side bool) (*i } } - // If all variants have the same jxType (or empty), try value-based discrimination + // If all variants have the same jxType (or empty), try value-based or array element discrimination if
len(uniqueJxTypes) <= 1 { // Try value-based discrimination (enum values) canUse, discriminator, err := canUseValueDiscrimination(fieldName, fieldVariants) @@ -1063,7 +1089,66 @@ func (g *schemaGen) oneOf(name string, schema *jsonschema.Schema, side bool) (*i continue // This field can discriminate, move to next field } - // Can't use value discrimination, record for potential error + // Value discrimination didn't work, check if array element discrimination is possible + allArrays := true + uniqueArrayElemTypes := make(map[string]bool) + for _, fv := range fieldVariants { + if fv.FieldType != jxTypeArray { + allArrays = false + break + } + if fv.ArrayElementType != "" { + uniqueArrayElemTypes[fv.ArrayElementType] = true + } + } + + // If all variants are arrays with different element types, check if we have other discriminating fields + if allArrays && len(uniqueArrayElemTypes) > 1 { + // Array element discrimination works, but we need to check if there are other + // unique fields to discriminate in case this array field is missing or empty. + // If this array field is the ONLY way to discriminate, we should reject it + // because the field might be optional and missing from the JSON. 
+ + // Check if any variant has a unique field that exists ONLY in that variant (by name) + hasOtherUniqueFields := false + variantFieldNames := make(map[string]map[string]struct{}) // variant -> set of field names + for _, variant := range sum.SumOf { + variantFieldNames[variant.Name] = make(map[string]struct{}) + for _, f := range variant.JSON().Fields() { + variantFieldNames[variant.Name][f.Tag.JSON] = struct{}{} + } + } + + // For each variant, check if it has any field name unique to it + for _, variant := range sum.SumOf { + for fieldName := range variantFieldNames[variant.Name] { + isUniqueToVariant := true + for _, otherVariant := range sum.SumOf { + if otherVariant.Name == variant.Name { + continue + } + if _, hasField := variantFieldNames[otherVariant.Name][fieldName]; hasField { + isUniqueToVariant = false + break + } + } + if isUniqueToVariant { + hasOtherUniqueFields = true + break + } + } + if hasOtherUniqueFields { + break + } + } + + if hasOtherUniqueFields { + continue // Can discriminate by array element type (with fallback to other fields) + } + // Fall through to the error below - array element discrimination alone is not sufficient + } + + // Can't use value discrimination or array element discrimination, record for potential error var typeIDs []string for _, v := range sortedVariants { for _, s := range sum.SumOf { diff --git a/gen/templates.go b/gen/templates.go index 1e1edeeb1..f666b39fc 100644 --- a/gen/templates.go +++ b/gen/templates.go @@ -206,10 +206,12 @@ func templateFunctions() template.FuncMap { "mod": func(a, b int) int { return a % b }, - "isObjectParam": isObjectParam, - "paramObjectFields": paramObjectFields, - "uniqueResponseTypes": uniqueResponseTypes, - "dedupeVariantsByType": dedupeVariantsByType, + "isObjectParam": isObjectParam, + "paramObjectFields": paramObjectFields, + "uniqueResponseTypes": uniqueResponseTypes, + "dedupeVariantsByType": dedupeVariantsByType, + "needsArrayElementDiscrimination": 
needsArrayElementDiscrimination, + "dedupeVariantsByArrayElementType": dedupeVariantsByArrayElementType, } } @@ -317,3 +319,51 @@ func dedupeVariantsByType(variants []ir.UniqueFieldVariant) []ir.UniqueFieldVari return result }
+
+// needsArrayElementDiscrimination reports whether every variant is a jx.Array field
+// and the variants carry more than one distinct non-empty ArrayElementType.
+func needsArrayElementDiscrimination(variants []ir.UniqueFieldVariant) bool {
+	if len(variants) < 2 {
+		return false
+	}
+
+	elemTypes := map[string]bool{}
+	for i := range variants {
+		// A single non-array variant rules this strategy out entirely.
+		if variants[i].FieldType != jxTypeArray {
+			return false
+		}
+		// Variants without a known element type do not count as distinct.
+		if elem := variants[i].ArrayElementType; elem != "" {
+			elemTypes[elem] = true
+		}
+	}
+
+	return len(elemTypes) > 1
+}
+
+// dedupeVariantsByArrayElementType keeps the first variant seen for each non-empty
+// ArrayElementType, in input order; variants with an empty element type are all kept.
+func dedupeVariantsByArrayElementType(variants []ir.UniqueFieldVariant) []ir.UniqueFieldVariant {
+	if len(variants) == 0 {
+		return variants
+	}
+
+	kept := make([]ir.UniqueFieldVariant, 0, len(variants))
+	taken := make(map[string]bool)
+	for _, variant := range variants {
+		elem := variant.ArrayElementType
+		switch {
+		case elem == "":
+			// Unknown element type: always retained, never recorded.
+		case taken[elem]:
+			// A variant with this element type was already kept.
+			continue
+		default:
+			taken[elem] = true
+		}
+		kept = append(kept, variant)
+	}
+
+	return kept
+}