diff --git a/models_test.go b/models_test.go index df3b43b..41f34d6 100644 --- a/models_test.go +++ b/models_test.go @@ -89,15 +89,36 @@ type GenericInterface struct { Data interface{} `jsonapi:"attr,interface"` } +type Organization struct { + ID int `jsonapi:"primary,organizations"` + ClientID string `jsonapi:"client-id"` + Name string `jsonapi:"attr,title"` + DefaultProject *Project `jsonapi:"relation,default_project"` + CreatedAt time.Time `jsonapi:"attr,created_at"` + + Links Links `jsonapi:"links,omitempty"` +} + +type Project struct { + ID int `jsonapi:"primary,projects"` + ClientID string `jsonapi:"client-id"` + Name string `jsonapi:"attr,name"` + Organization *Organization `jsonapi:"relation,organization"` + + Links Links `jsonapi:"links,omitempty"` +} + type Blog struct { - ID int `jsonapi:"primary,blogs"` - ClientID string `jsonapi:"client-id"` - Title string `jsonapi:"attr,title"` - Posts []*Post `jsonapi:"relation,posts"` - CurrentPost *Post `jsonapi:"relation,current_post"` - CurrentPostID int `jsonapi:"attr,current_post_id"` - CreatedAt time.Time `jsonapi:"attr,created_at"` - ViewCount int `jsonapi:"attr,view_count"` + ID int `jsonapi:"primary,blogs"` + ClientID string `jsonapi:"client-id"` + Title string `jsonapi:"attr,title"` + CurrentPostID int `jsonapi:"attr,current_post_id"` + CreatedAt time.Time `jsonapi:"attr,created_at"` + ViewCount int `jsonapi:"attr,view_count"` + Posts []*Post `jsonapi:"relation,posts"` + CurrentPost *Post `jsonapi:"relation,current_post"` + Organization *Organization `jsonapi:"relation,organization"` + Project *Project `jsonapi:"relation,project"` Links Links `jsonapi:"links,omitempty"` } diff --git a/request.go b/request.go index 27f628e..a21bc2b 100644 --- a/request.go +++ b/request.go @@ -60,6 +60,11 @@ func newErrUnsupportedPtrType(rf reflect.Value, t reflect.Type, structField refl return ErrUnsupportedPtrType{rf, t, structField} } +type includedNode struct { + node *Node + model *reflect.Value +} + // UnmarshalPayload converts an io into a struct instance using jsonapi tags on // struct fields. This method supports single request payloads only, at the // moment. Bulk creates and updates are not supported yet. @@ -94,19 +99,19 @@ func newErrUnsupportedPtrType(rf reflect.Value, t reflect.Type, structField refl // model interface{} should be a pointer to a struct. func UnmarshalPayload(in io.Reader, model interface{}) error { payload := new(OnePayload) + included := make(map[string]*includedNode) if err := json.NewDecoder(in).Decode(payload); err != nil { return err } if payload.Included != nil { - includedMap := make(map[string]*Node) - for _, included := range payload.Included { - key := fmt.Sprintf("%s,%s", included.Type, included.ID) - includedMap[key] = included + for _, include := range payload.Included { + key := fmt.Sprintf("%s,%s", include.Type, include.ID) + included[key] = &includedNode{include, nil} } - return unmarshalNode(payload.Data, reflect.ValueOf(model), &includedMap) + return unmarshalNode(payload.Data, reflect.ValueOf(model), &included) } return unmarshalNode(payload.Data, reflect.ValueOf(model), nil) } @@ -120,19 +125,19 @@ func UnmarshalManyPayload(in io.Reader, t reflect.Type) ([]interface{}, error) { return nil, err } - models := []interface{}{} // will be populated from the "data" - includedMap := map[string]*Node{} // will be populate from the "included" + models := []interface{}{} // will be populated from the "data" + included := map[string]*includedNode{} // will be populate from the "included" if payload.Included != nil { - for _, included := range payload.Included { - key := fmt.Sprintf("%s,%s", included.Type, included.ID) - includedMap[key] = included + for _, include := range payload.Included { + key := fmt.Sprintf("%s,%s", include.Type, include.ID) + included[key] = &includedNode{include, nil} } } for _, data := range payload.Data { model := reflect.New(t.Elem()) - err := unmarshalNode(data, model, &includedMap) + err := unmarshalNode(data, model, &included) if err != nil { return nil, err } @@ -263,7 +268,7 @@ func getStructTags(field reflect.StructField) ([]string, error) { // unmarshalNodeMaybeChoice populates a model that may or may not be // a choice type struct that corresponds to a polyrelation or relation -func unmarshalNodeMaybeChoice(m *reflect.Value, data *Node, annotation string, choiceTypeMapping map[string]structFieldIndex, included *map[string]*Node) error { +func unmarshalNodeMaybeChoice(m *reflect.Value, data *Node, annotation string, choiceTypeMapping map[string]structFieldIndex, included *map[string]*includedNode) error { // This will hold either the value of the choice type model or the actual // model, depending on annotation var actualModel = *m @@ -300,7 +305,7 @@ func unmarshalNodeMaybeChoice(m *reflect.Value, data *Node, annotation string, c return nil } -func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) (err error) { +func unmarshalNode(data *Node, model reflect.Value, included *map[string]*includedNode) (err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("data is not a jsonapi representation of '%v'", model.Type()) @@ -509,6 +514,23 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) // model, depending on annotation m := reflect.New(fieldValue.Type().Elem()) + // Check if the item in the relationship was already processed elsewhere. Avoids potential infinite recursive loops + // caused by circular references between included relationships (two included items include one another) + includedKey := fmt.Sprintf("%s,%s", relationship.Data.Type, relationship.Data.ID) + if included != nil && (*included)[includedKey] != nil { + if (*included)[includedKey].model != nil { + fieldValue.Set(*(*included)[includedKey].model) + } else { + (*included)[includedKey].model = &m + err := unmarshalNodeMaybeChoice(&m, (*included)[includedKey].node, annotation, choiceMapping, included) + if err != nil { + er = err + break + } + fieldValue.Set(m) + } + continue + } err = unmarshalNodeMaybeChoice(&m, relationship.Data, annotation, choiceMapping, included) if err != nil { er = err @@ -565,11 +587,11 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) return er } -func fullNode(n *Node, included *map[string]*Node) *Node { +func fullNode(n *Node, included *map[string]*includedNode) *Node { includedKey := fmt.Sprintf("%s,%s", n.Type, n.ID) if included != nil && (*included)[includedKey] != nil { - return (*included)[includedKey] + return (*included)[includedKey].node } return n diff --git a/request_test.go b/request_test.go index 1408ad9..4c11a19 100644 --- a/request_test.go +++ b/request_test.go @@ -689,6 +689,41 @@ func TestUnmarshalRelationships(t *testing.T) { } } +func TestUnmarshalMany_relationships_with_circular_inclusion(t *testing.T) { + data := samplePayloadWithCircularInclusion() + payload, err := json.Marshal(data) + if err != nil { + t.Fatal(err) + } + in := bytes.NewReader(payload) + model := reflect.TypeOf(new(Blog)) + + out, err := UnmarshalManyPayload(in, model) + if err != nil { + t.Fatal(err) + } + + result_1 := out[0].(*Blog) + + if result_1.Project != result_1.Organization.DefaultProject { + t.Errorf("expected blog.project (%p) to hold the same pointer as blog.organization.default-project (%p) ", result_1.Project, result_1.Organization.DefaultProject) + } + + if result_1.Organization != result_1.Project.Organization { + t.Errorf("expected blog.organization (%p) to hold the same pointer as blog.project.organization (%p)", result_1.Organization, result_1.Project.Organization) + } + + result_2 := out[1].(*Blog) + + if result_2.Project != result_2.Organization.DefaultProject { + t.Errorf("expected blog.project (%p) to hold the same pointer as blog.organization.default-project (%p) ", result_2.Project, result_2.Organization.DefaultProject) + } + + if result_2.Organization != result_2.Project.Organization { + t.Errorf("expected blog.organization (%p) to hold the same pointer as blog.project.organization (%p)", result_2.Organization, result_2.Project.Organization) + } +} + func Test_UnmarshalPayload_polymorphicRelations(t *testing.T) { in := bytes.NewReader([]byte(`{ "data": { @@ -1378,6 +1413,105 @@ func TestUnmarshalCustomTypeAttributes_ErrInvalidType(t *testing.T) { } } +func samplePayloadWithCircularInclusion() *ManyPayload { + payload := &ManyPayload{ + Data: []*Node{ + { + Type: "blogs", + ClientID: "1", + ID: "1", + Attributes: map[string]interface{}{ + "title": "Foo", + "current_post_id": 1, + "created_at": 1436216820, + "view_count": 1000, + }, + Relationships: map[string]interface{}{ + "project": &RelationshipOneNode{ + Data: &Node{ + Type: "projects", + ClientID: "1", + ID: "1", + }, + }, + "organization": &RelationshipOneNode{ + Data: &Node{ + Type: "organizations", + ClientID: "1", + ID: "1", + }, + }, + }, + }, + { + Type: "blogs", + ClientID: "2", + ID: "2", + Attributes: map[string]interface{}{ + "title": "Foo2", + "current_post_id": 1, + "created_at": 1436216820, + "view_count": 1000, + }, + Relationships: map[string]interface{}{ + "project": &RelationshipOneNode{ + Data: &Node{ + Type: "projects", + ClientID: "1", + ID: "1", + }, + }, + "organization": &RelationshipOneNode{ + Data: &Node{ + Type: "organizations", + ClientID: "1", + ID: "1", + }, + }, + }, + }, + }, + Included: []*Node{ + { + Type: "projects", + ClientID: "1", + ID: "1", + Attributes: map[string]interface{}{ + "name": "Bar", + }, + Relationships: map[string]interface{}{ + "organization": &RelationshipOneNode{ + Data: &Node{ + Type: "organizations", + ClientID: "1", + ID: "1", + }, + }, + }, + }, + { + Type: "organizations", + ClientID: "1", + ID: "1", + Attributes: map[string]interface{}{ + "name": "Baz", + }, + Relationships: map[string]interface{}{ + "default_project": &RelationshipOneNode{ + Data: &Node{ + Type: "projects", + ClientID: "1", + ID: "1", + }, + }, + }, + }, + }, + } + + return payload +} + func samplePayloadWithoutIncluded() map[string]interface{} { return map[string]interface{}{ "data": map[string]interface{}{