diff --git a/models_test.go b/models_test.go index 2d4aae4..a7a97fc 100644 --- a/models_test.go +++ b/models_test.go @@ -159,14 +159,14 @@ func (bc *BadComment) JSONAPILinks() *Links { type Company struct { ID string `jsonapi:"primary,companies"` Name string `jsonapi:"attr,name"` - Boss Employee `jsonapi:"attr,boss"` + Boss *Employee `jsonapi:"attr,boss"` Teams []Team `jsonapi:"attr,teams"` FoundedAt time.Time `jsonapi:"attr,founded-at,iso8601"` } type Team struct { Name string `jsonapi:"attr,name"` - Leader *Employee `jsonapi:"attr,leader"` + Leader Employee `jsonapi:"attr,leader"` Members []Employee `jsonapi:"attr,members"` } diff --git a/request.go b/request.go index a7bb0b1..f4bdefb 100644 --- a/request.go +++ b/request.go @@ -247,6 +247,7 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) } structField := fieldType + value, err := unmarshalAttribute(attribute, args, structField, fieldValue) if err != nil { er = err @@ -332,6 +333,25 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) return er } +func unmarshalFromAttribute(attribute interface{}, fieldValue reflect.Value) (reflect.Value, error) { + structData, err := json.Marshal(attribute) + if err != nil { + return reflect.Value{}, err + } + + structNode := new(Node) + if err := json.Unmarshal(structData, &structNode.Attributes); err != nil { + return reflect.Value{}, err + } + + structModel := reflect.New(fieldValue.Type()) + if err := unmarshalNode(structNode, structModel, nil); err != nil { + return reflect.Value{}, err + } + + return structModel, nil +} + func fullNode(n *Node, included *map[string]*Node) *Node { includedKey := fmt.Sprintf("%s,%s", n.Type, n.ID) @@ -397,12 +417,13 @@ func unmarshalAttribute( if fieldValue.Type() == reflect.TypeOf(time.Time{}) || fieldValue.Type() == reflect.TypeOf(new(time.Time)) { value, err = handleTime(attribute, args, fieldValue) + return } // Handle field of type struct - if fieldValue.Type().Kind() == reflect.Struct { - value, err = handleStruct(attribute, fieldValue) + if fieldValue.Kind() == reflect.Struct { + value, err = unmarshalFromAttribute(attribute, fieldValue) return } @@ -421,7 +442,7 @@ func unmarshalAttribute( // Field was a Pointer type if fieldValue.Kind() == reflect.Ptr { - value, err = handlePointer(attribute, args, fieldType, fieldValue, structField) + value, err = handlePointer(attribute, fieldType, fieldValue, structField) return } @@ -477,7 +498,6 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value) } var at int64 - if v.Kind() == reflect.Float64 { at = int64(v.Interface().(float64)) } else if v.Kind() == reflect.Int { @@ -487,7 +507,6 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value) } t := time.Unix(at, 0) - return reflect.ValueOf(t), nil } @@ -553,7 +572,6 @@ func handleNumeric( func handlePointer( attribute interface{}, - args []string, fieldType reflect.Type, fieldValue reflect.Value, structField reflect.StructField) (reflect.Value, error) { @@ -569,11 +587,13 @@ func handlePointer( concreteVal = reflect.ValueOf(&cVal) case map[string]interface{}: var err error - concreteVal, err = handleStruct(attribute, fieldValue) + fieldValueType := reflect.New(fieldValue.Type().Elem()).Elem() + concreteVal, err = unmarshalFromAttribute(attribute, fieldValueType) if err != nil { return reflect.Value{}, newErrUnsupportedPtrType( reflect.ValueOf(attribute), fieldType, structField) } + return concreteVal, err default: return reflect.Value{}, newErrUnsupportedPtrType( @@ -591,7 +611,6 @@ func handlePointer( func handleStruct( attribute interface{}, fieldValue reflect.Value) (reflect.Value, error) { - data, err := json.Marshal(attribute) if err != nil { return reflect.Value{}, err @@ -619,13 +638,13 @@ func handleStruct( func handleStructSlice( attribute interface{}, fieldValue reflect.Value) (reflect.Value, error) { + models := reflect.New(fieldValue.Type()).Elem() dataMap := reflect.ValueOf(attribute).Interface().([]interface{}) for _, data := range dataMap { model := reflect.New(fieldValue.Type().Elem()).Elem() - value, err := handleStruct(data, model) - + value, err := unmarshalFromAttribute(data, model) if err != nil { continue } diff --git a/response.go b/response.go index e8e85fa..b4d2ea8 100644 --- a/response.go +++ b/response.go @@ -207,7 +207,13 @@ func visitModelNode(model interface{}, included *map[string]*Node, node := new(Node) var er error - value := reflect.ValueOf(model) + var value reflect.Value + if modelValue, ok := model.(reflect.Value); ok { + value = modelValue.Addr() + } else { + value = reflect.ValueOf(model) + } + if value.IsNil() { return nil, nil } @@ -340,6 +346,45 @@ func visitModelNode(model interface{}, included *map[string]*Node, node.Attributes[args[1]] = tm.Unix() } } + } else if fieldValue.Kind() == reflect.Slice && fieldValue.Type().Elem().Kind() == reflect.Struct { + if omitEmpty && fieldValue.Len() == 0 { + continue + } + + newSlice := make([]map[string]interface{}, fieldValue.Len()) + for i:=0; i < fieldValue.Len(); i++ { + included := make(map[string]*Node) + nested, err := visitModelNode(fieldValue.Index(i), &included, true) + if err != nil { + er = err + break + } + + newSlice[i] = nested.Attributes + } + node.Attributes[args[1]] = newSlice + } else if fieldValue.Kind() == reflect.Struct || + (fieldValue.Kind() == reflect.Ptr && fieldValue.Elem().Kind() == reflect.Struct) { + + // Dealing with a fieldValue that is not a time + emptyValue := reflect.Zero(fieldValue.Type()) + + // See if we need to omit this field + if omitEmpty && reflect.DeepEqual(fieldValue.Interface(), emptyValue.Interface()) { + continue + } + + included := make(map[string]*Node) + if fieldValue.Kind() == reflect.Ptr { + fieldValue = fieldValue.Elem() + } + nested, err := visitModelNode(fieldValue, &included, true) + if err != nil { + er = err + break + } + + node.Attributes[args[1]] = nested.Attributes } else { // Dealing with a fieldValue that is not a time emptyValue := reflect.Zero(fieldValue.Type()) diff --git a/response_test.go b/response_test.go index 5b42595..2475624 100644 --- a/response_test.go +++ b/response_test.go @@ -901,6 +901,57 @@ func TestMarshal_InvalidIntefaceArgument(t *testing.T) { } } +func TestMarshalNestedStruct(t *testing.T) { + team := Team{ + Name: "Awesome team", + Leader: Employee{ + Firstname: "John", + Surname: "Mota", + Age: 35, + }, + Members: []Employee{ + { + Firstname: "Henrique", + Surname: "Doe", + }, + }, + } + + now := time.Now() + company := Company { + ID: "an_id", + Name: "Awesome Company", + Boss: &Employee{ + Firstname: "Company", + Surname: "boss", + Age: 60, + }, + Teams: []Team { + team, + }, + FoundedAt: now, + } + + buffer := bytes.NewBuffer(nil) + MarshalOnePayloadEmbedded(buffer, &company) + reader := bytes.NewReader(buffer.Bytes()) + var finalCompany Company + UnmarshalPayload(reader, &finalCompany) + + diff := company.FoundedAt.Sub(finalCompany.FoundedAt) + + if diff.Seconds() > 1 { + t.Error("final unmarshal payload founded at must be approximately equal to the original.") + } + + company.FoundedAt = time.Time{} + finalCompany.FoundedAt = time.Time{} + + if !reflect.DeepEqual(company, finalCompany) { + t.Error("final unmarshal payload should be equal to the original one.") + } +} + func testBlog() *Blog { return &Blog{ ID: 5,