Skip to content

Commit fdb29a2

Browse files
committed
Add improvements
1 parent 7b0afb9 commit fdb29a2

File tree

4 files changed

+58
-40
lines changed

4 files changed

+58
-40
lines changed

models_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func (bc *BadComment) JSONAPILinks() *Links {
159159
type Company struct {
160160
ID string `jsonapi:"primary,companies"`
161161
Name string `jsonapi:"attr,name"`
162-
Boss Employee `jsonapi:"attr,boss"`
162+
Boss *Employee `jsonapi:"attr,boss"`
163163
Teams []Team `jsonapi:"attr,teams"`
164164
FoundedAt time.Time `jsonapi:"attr,founded-at,iso8601"`
165165
}

request.go

+24-34
Original file line numberDiff line numberDiff line change
@@ -248,26 +248,14 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node)
248248

249249
structField := fieldType
250250

251-
if structField.Type.Kind() != reflect.Struct ||
252-
fieldValue.Type() == reflect.TypeOf(new(time.Time)) ||
253-
fieldValue.Type() == reflect.TypeOf(time.Time{}) {
254-
value, err := unmarshalAttribute(attribute, args, structField, fieldValue)
255-
if err != nil {
256-
er = err
257-
break
258-
}
259-
assign(fieldValue, value)
260-
continue
261-
262-
} else {
263-
structModel, err := unmarshalFromAttribute(attribute, fieldValue)
264-
if err != nil {
265-
er = err
266-
break
267-
}
268-
fieldValue.Set((*structModel).Elem())
269-
continue
251+
value, err := unmarshalAttribute(attribute, args, structField, fieldValue)
252+
if err != nil {
253+
er = err
254+
break
270255
}
256+
257+
assign(fieldValue, value)
258+
continue
271259
} else if annotation == annotationRelation {
272260
isSlice := fieldValue.Type().Kind() == reflect.Slice
273261

@@ -346,21 +334,23 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node)
346334
return er
347335
}
348336

349-
func unmarshalFromAttribute(attribute interface{}, fieldValue reflect.Value) (*reflect.Value, error) {
337+
func unmarshalFromAttribute(attribute interface{}, fieldValue reflect.Value) (reflect.Value, error) {
350338
structData, err := json.Marshal(attribute)
351339
if err != nil {
352-
return nil, err
340+
return reflect.Value{}, err
353341
}
342+
354343
structNode := new(Node)
355344
if err := json.Unmarshal(structData, &structNode.Attributes); err != nil {
356-
return nil, err
345+
return reflect.Value{}, err
357346
}
347+
358348
structModel := reflect.New(fieldValue.Type())
359349
if err := unmarshalNode(structNode, structModel, nil); err != nil {
360-
return nil, err
350+
return reflect.Value{}, err
361351
}
362352

363-
return &structModel, nil
353+
return structModel, nil
364354
}
365355

366356
func fullNode(n *Node, included *map[string]*Node) *Node {
@@ -376,7 +366,7 @@ func fullNode(n *Node, included *map[string]*Node) *Node {
376366
// assign will take the value specified and assign it to the field; if
377367
// field is expecting a ptr assign will assign a ptr.
378368
func assign(field, value reflect.Value) {
379-
if field.Kind() == reflect.Ptr || field.Kind() == reflect.Struct{
369+
if field.Kind() == reflect.Ptr {
380370
field.Set(value)
381371
} else {
382372
field.Set(reflect.Indirect(value))
@@ -402,12 +392,13 @@ func unmarshalAttribute(
402392
if fieldValue.Type() == reflect.TypeOf(time.Time{}) ||
403393
fieldValue.Type() == reflect.TypeOf(new(time.Time)) {
404394
value, err = handleTime(attribute, args, fieldValue)
395+
405396
return
406397
}
407398

408399
// Handle field of type struct
409-
if fieldValue.Type().Kind() == reflect.Struct {
410-
value, err = handleStruct(attribute, fieldValue)
400+
if fieldValue.Kind() == reflect.Struct {
401+
value, err = unmarshalFromAttribute(attribute, fieldValue)
411402
return
412403
}
413404

@@ -426,7 +417,7 @@ func unmarshalAttribute(
426417

427418
// Field was a Pointer type
428419
if fieldValue.Kind() == reflect.Ptr {
429-
value, err = handlePointer(attribute, args, fieldType, fieldValue, structField)
420+
value, err = handlePointer(attribute, fieldType, fieldValue, structField)
430421
return
431422
}
432423

@@ -482,7 +473,6 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value)
482473
}
483474

484475
var at int64
485-
486476
if v.Kind() == reflect.Float64 {
487477
at = int64(v.Interface().(float64))
488478
} else if v.Kind() == reflect.Int {
@@ -492,7 +482,6 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value)
492482
}
493483

494484
t := time.Unix(at, 0)
495-
496485
return reflect.ValueOf(t), nil
497486
}
498487

@@ -558,7 +547,6 @@ func handleNumeric(
558547

559548
func handlePointer(
560549
attribute interface{},
561-
args []string,
562550
fieldType reflect.Type,
563551
fieldValue reflect.Value,
564552
structField reflect.StructField) (reflect.Value, error) {
@@ -574,11 +562,13 @@ func handlePointer(
574562
concreteVal = reflect.ValueOf(&cVal)
575563
case map[string]interface{}:
576564
var err error
577-
concreteVal, err = handleStruct(attribute, fieldValue)
565+
fieldValueType := reflect.New(fieldValue.Type().Elem()).Elem()
566+
concreteVal, err = unmarshalFromAttribute(attribute, fieldValueType)
578567
if err != nil {
579568
return reflect.Value{}, newErrUnsupportedPtrType(
580569
reflect.ValueOf(attribute), fieldType, structField)
581570
}
571+
582572
return concreteVal, err
583573
default:
584574
return reflect.Value{}, newErrUnsupportedPtrType(
@@ -624,13 +614,13 @@ func handleStruct(
624614
func handleStructSlice(
625615
attribute interface{},
626616
fieldValue reflect.Value) (reflect.Value, error) {
617+
627618
models := reflect.New(fieldValue.Type()).Elem()
628619
dataMap := reflect.ValueOf(attribute).Interface().([]interface{})
629620
for _, data := range dataMap {
630621
model := reflect.New(fieldValue.Type().Elem()).Elem()
631622

632-
value, err := handleStruct(data, model)
633-
623+
value, err := unmarshalFromAttribute(data, model)
634624
if err != nil {
635625
continue
636626
}

response.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,12 @@ func visitModelNode(model interface{}, included *map[string]*Node,
359359
newSlice[i] = nested.Attributes
360360
}
361361
node.Attributes[args[1]] = newSlice
362-
} else if fieldValue.Kind() == reflect.Struct {
362+
} else if fieldValue.Kind() == reflect.Struct ||
363+
(fieldValue.Kind() == reflect.Ptr && fieldValue.Elem().Kind() == reflect.Struct) {
363364
included := make(map[string]*Node)
365+
if fieldValue.Kind() == reflect.Ptr {
366+
fieldValue = fieldValue.Elem()
367+
}
364368
nested, err := visitModelNode(fieldValue, &included, true)
365369
if err != nil {
366370
er = err

response_test.go

+28-4
Original file line numberDiff line numberDiff line change
@@ -918,13 +918,37 @@ func TestMarshalNestedStruct(t *testing.T) {
918918
},
919919
}
920920

921+
now := time.Now()
922+
company := Company {
923+
ID: "an_id",
924+
Name: "Awesome Company",
925+
Boss: &Employee{
926+
Firstname: "Company",
927+
Surname: "boss",
928+
Age: 60,
929+
},
930+
Teams: []Team {
931+
team,
932+
},
933+
FoundedAt: now,
934+
}
935+
921936
buffer := bytes.NewBuffer(nil)
922-
MarshalOnePayloadEmbedded(buffer, &team)
937+
MarshalOnePayloadEmbedded(buffer, &company)
923938
reader := bytes.NewReader(buffer.Bytes())
924-
var finalTeam Team
925-
UnmarshalPayload(reader, &finalTeam)
939+
var finalCompany Company
940+
UnmarshalPayload(reader, &finalCompany)
941+
942+
diff := company.FoundedAt.Sub(finalCompany.FoundedAt)
943+
944+
if diff.Seconds() > 1 {
945+
t.Error("final unmarshal payload founded at must be approximately equal to the original.")
946+
}
947+
948+
company.FoundedAt = time.Time{}
949+
finalCompany.FoundedAt = time.Time{}
926950

927-
if !reflect.DeepEqual(team, finalTeam) {
951+
if !reflect.DeepEqual(company, finalCompany) {
928952
t.Error("final unmarshal payload should be equal to the original one.")
929953
}
930954
}

0 commit comments

Comments
 (0)