diff --git a/models_test.go b/models_test.go index 21a3a35..17efc60 100644 --- a/models_test.go +++ b/models_test.go @@ -159,7 +159,7 @@ func (bc *BadComment) JSONAPILinks() *Links { type Company struct { ID string `jsonapi:"primary,companies"` Name string `jsonapi:"attr,name"` - Boss Employee `jsonapi:"attr,boss"` + Boss *Employee `jsonapi:"attr,boss"` Teams []Team `jsonapi:"attr,teams"` FoundedAt time.Time `jsonapi:"attr,founded-at,iso8601"` } diff --git a/request.go b/request.go index f752a58..4dbc9fe 100644 --- a/request.go +++ b/request.go @@ -248,26 +248,14 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) structField := fieldType - if structField.Type.Kind() != reflect.Struct || - fieldValue.Type() == reflect.TypeOf(new(time.Time)) || - fieldValue.Type() == reflect.TypeOf(time.Time{}) { - value, err := unmarshalAttribute(attribute, args, structField, fieldValue) - if err != nil { - er = err - break - } - assign(fieldValue, value) - continue - - } else { - structModel, err := unmarshalFromAttribute(attribute, fieldValue) - if err != nil { - er = err - break - } - fieldValue.Set((*structModel).Elem()) - continue + value, err := unmarshalAttribute(attribute, args, structField, fieldValue) + if err != nil { + er = err + break } + + assign(fieldValue, value) + continue } else if annotation == annotationRelation { isSlice := fieldValue.Type().Kind() == reflect.Slice @@ -346,21 +334,22 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) return er } -func unmarshalFromAttribute(attribute interface{}, fieldValue reflect.Value) (*reflect.Value, error) { +func unmarshalFromAttribute(attribute interface{}, fieldValue reflect.Value) (reflect.Value, error) { structData, err := json.Marshal(attribute) if err != nil { - return nil, err + return reflect.Value{}, err } structNode := new(Node) if err := json.Unmarshal(structData, &structNode.Attributes); err != nil { - return nil, err + return reflect.Value{}, err } + structModel := reflect.New(fieldValue.Type()) if err := unmarshalNode(structNode, structModel, nil); err != nil { - return nil, err + return reflect.Value{}, err } - return &structModel, nil + return structModel, nil } func fullNode(n *Node, included *map[string]*Node) *Node { @@ -376,7 +365,7 @@ func fullNode(n *Node, included *map[string]*Node) *Node { // assign will take the value specified and assign it to the field; if // field is expecting a ptr assign will assign a ptr. func assign(field, value reflect.Value) { - if field.Kind() == reflect.Ptr || field.Kind() == reflect.Struct{ + if field.Kind() == reflect.Ptr { field.Set(value) } else { field.Set(reflect.Indirect(value)) @@ -402,12 +391,13 @@ func unmarshalAttribute( if fieldValue.Type() == reflect.TypeOf(time.Time{}) || fieldValue.Type() == reflect.TypeOf(new(time.Time)) { value, err = handleTime(attribute, args, fieldValue) + return } // Handle field of type struct - if fieldValue.Type().Kind() == reflect.Struct { - value, err = handleStruct(attribute, fieldValue) + if fieldValue.Kind() == reflect.Struct { + value, err = unmarshalFromAttribute(attribute, fieldValue) return } @@ -426,7 +416,7 @@ func unmarshalAttribute( // Field was a Pointer type if fieldValue.Kind() == reflect.Ptr { - value, err = handlePointer(attribute, args, fieldType, fieldValue, structField) + value, err = handlePointer(attribute, fieldType, fieldValue, structField) return } @@ -482,7 +472,6 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value) } var at int64 - if v.Kind() == reflect.Float64 { at = int64(v.Interface().(float64)) } else if v.Kind() == reflect.Int { @@ -492,7 +481,6 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value) } t := time.Unix(at, 0) - return reflect.ValueOf(t), nil } @@ -558,7 +546,6 @@ func handleNumeric( func handlePointer( attribute interface{}, - args []string, fieldType reflect.Type, fieldValue reflect.Value, structField reflect.StructField) (reflect.Value, error) { @@ -574,12 +561,15 @@ func handlePointer( concreteVal = reflect.ValueOf(&cVal) case map[string]interface{}: var err error - concreteVal, err = handleStruct(attribute, fieldValue) + fieldValueType := reflect.New(fieldValue.Type().Elem()).Elem() + concreteVal, err = unmarshalFromAttribute(attribute, fieldValueType) if err != nil { return reflect.Value{}, newErrUnsupportedPtrType( reflect.ValueOf(attribute), fieldType, structField) } + return concreteVal, err + default: return reflect.Value{}, newErrUnsupportedPtrType( reflect.ValueOf(attribute), fieldType, structField) @@ -624,13 +614,13 @@ func handleStruct( func handleStructSlice( attribute interface{}, fieldValue reflect.Value) (reflect.Value, error) { + models := reflect.New(fieldValue.Type()).Elem() dataMap := reflect.ValueOf(attribute).Interface().([]interface{}) for _, data := range dataMap { model := reflect.New(fieldValue.Type().Elem()).Elem() - value, err := handleStruct(data, model) - + value, err := unmarshalFromAttribute(data, model) if err != nil { continue } diff --git a/response.go b/response.go index 0c0f368..28daf87 100644 --- a/response.go +++ b/response.go @@ -359,8 +359,12 @@ func visitModelNode(model interface{}, included *map[string]*Node, newSlice[i] = nested.Attributes } node.Attributes[args[1]] = newSlice - } else if fieldValue.Kind() == reflect.Struct { + } else if fieldValue.Kind() == reflect.Struct || + (fieldValue.Kind() == reflect.Ptr && fieldValue.Elem().Kind() == reflect.Struct) { included := make(map[string]*Node) + if fieldValue.Kind() == reflect.Ptr { + fieldValue = fieldValue.Elem() + } nested, err := visitModelNode(fieldValue, &included, true) if err != nil { er = err diff --git a/response_test.go b/response_test.go index 7783b6f..bf5b7e1 100644 --- a/response_test.go +++ b/response_test.go @@ -918,13 +918,37 @@ func TestMarshalNestedStruct(t *testing.T) { }, } + now := time.Now() + company := Company { + ID: "an_id", + Name: "Awesome Company", + Boss: &Employee{ + Firstname: "Company", + Surname: "boss", + Age: 60, + }, + Teams: []Team { + team, + }, + FoundedAt: now, + } + buffer := bytes.NewBuffer(nil) - MarshalOnePayloadEmbedded(buffer, &team) + MarshalOnePayloadEmbedded(buffer, &company) reader := bytes.NewReader(buffer.Bytes()) - var finalTeam Team - UnmarshalPayload(reader, &finalTeam) + var finalCompany Company + UnmarshalPayload(reader, &finalCompany) + + diff := company.FoundedAt.Sub(finalCompany.FoundedAt) + + if diff.Seconds() > 1 { + t.Error("final unmarshal payload founded at must be approximately equal to the original.") + } + + company.FoundedAt = time.Time{} + finalCompany.FoundedAt = time.Time{} - if !reflect.DeepEqual(team, finalTeam) { + if !reflect.DeepEqual(company, finalCompany) { t.Error("final unmarshal payload should be equal to the original one.") } }