diff --git a/examples/fixtures.go b/examples/fixtures.go index 7d0402d..79c23d3 100644 --- a/examples/fixtures.go +++ b/examples/fixtures.go @@ -3,10 +3,11 @@ package main import "time" func fixtureBlogCreate(i int) *Blog { + ts := time.Now() return &Blog{ ID: 1 * i, Title: "Title 1", - CreatedAt: time.Now(), + CreatedAt: &UnsetableTime{&ts}, Posts: []*Post{ { ID: 1 * i, diff --git a/examples/models.go b/examples/models.go index 4842361..234c2a0 100644 --- a/examples/models.go +++ b/examples/models.go @@ -1,21 +1,38 @@ package main import ( + "encoding/json" "fmt" "time" "github.com/hashicorp/jsonapi" ) +type UnsetableTime struct { + Value *time.Time +} + +func (t *UnsetableTime) MarshalAttribute() (interface{}, error) { + if t == nil { + return nil, nil + } + + if t.Value == nil { + return json.RawMessage(nil), nil + } else { + return t.Value, nil + } +} + // Blog is a model representing a blog site type Blog struct { - ID int `jsonapi:"primary,blogs"` - Title string `jsonapi:"attr,title"` - Posts []*Post `jsonapi:"relation,posts"` - CurrentPost *Post `jsonapi:"relation,current_post"` - CurrentPostID int `jsonapi:"attr,current_post_id"` - CreatedAt time.Time `jsonapi:"attr,created_at"` - ViewCount int `jsonapi:"attr,view_count"` + ID int `jsonapi:"primary,blogs"` + Title string `jsonapi:"attr,title"` + Posts []*Post `jsonapi:"relation,posts"` + CurrentPost *Post `jsonapi:"relation,current_post"` + CurrentPostID int `jsonapi:"attr,current_post_id"` + CreatedAt *UnsetableTime `jsonapi:"attr,created_at,omitempty,iso8601"` + ViewCount int `jsonapi:"attr,view_count"` } // Post is a model representing a post on a blog diff --git a/models_test.go b/models_test.go index 889142a..a3bcfbf 100644 --- a/models_test.go +++ b/models_test.go @@ -1,10 +1,14 @@ package jsonapi import ( + "encoding/json" + "errors" "fmt" "time" ) +var now = time.Now() + type BadModel struct { ID int `jsonapi:"primary"` } @@ -80,15 +84,56 @@ type GenericInterface struct { Data interface{} `jsonapi:"attr,interface"` } +type UnsetableTime struct { + Value *time.Time +} + +func (t *UnsetableTime) MarshalAttribute() (interface{}, error) { + if t == nil { + return nil, nil + } + + if t.Value == nil { + return json.RawMessage(nil), nil + } else { + return t.Value, nil + } +} + +func (t *UnsetableTime) UnmarshalAttribute(obj interface{}) error { + var ts time.Time + var err error + + if obj == nil { + t.Value = nil + return nil + } + + if tsStr, ok := obj.(string); ok { + ts, err = time.Parse(tsStr, time.RFC3339) + if err == nil { + t.Value = &ts + return nil + } + } else if tsFloat, ok := obj.(float64); ok { + ts = time.Unix(int64(tsFloat), 0) + + t.Value = &ts + return nil + } + + return errors.New("couldn't parse time") +} + type Blog struct { - ID int `jsonapi:"primary,blogs"` - ClientID string `jsonapi:"client-id"` - Title string `jsonapi:"attr,title"` - Posts []*Post `jsonapi:"relation,posts"` - CurrentPost *Post `jsonapi:"relation,current_post"` - CurrentPostID int `jsonapi:"attr,current_post_id"` - CreatedAt time.Time `jsonapi:"attr,created_at"` - ViewCount int `jsonapi:"attr,view_count"` + ID int `jsonapi:"primary,blogs"` + ClientID string `jsonapi:"client-id"` + Title string `jsonapi:"attr,title"` + Posts []*Post `jsonapi:"relation,posts"` + CurrentPost *Post `jsonapi:"relation,current_post"` + CurrentPostID int `jsonapi:"attr,current_post_id"` + CreatedAt *UnsetableTime `jsonapi:"attr,created_at,omitempty"` + ViewCount int `jsonapi:"attr,view_count"` Links Links `jsonapi:"links,omitempty"` } diff --git a/request.go b/request.go index a12c2b1..c9a52dc 100644 --- a/request.go +++ b/request.go @@ -36,6 +36,10 @@ var ( ErrTypeNotFound = errors.New("no primary type annotation found on model") ) +type AttributeUnmarshaler interface { + UnmarshalAttribute(interface{}) error +} + // ErrUnsupportedPtrType is returned when the Struct field was a pointer but // the JSON value was of a different type type ErrUnsupportedPtrType struct { @@ -589,6 +593,15 @@ func unmarshalAttribute( value = reflect.ValueOf(attribute) fieldType := structField.Type + i := reflect.TypeOf((*AttributeUnmarshaler)(nil)).Elem() + if fieldType.Implements(i) { + x := reflect.New(fieldType.Elem()) + y := (x.Interface()).(AttributeUnmarshaler) + err = y.UnmarshalAttribute(attribute) + value = reflect.ValueOf(y) + return + } + // Handle field of type []string if fieldValue.Type() == reflect.TypeOf([]string{}) { value, err = handleStringSlice(attribute) diff --git a/request_test.go b/request_test.go index 7eb9bda..b610d75 100644 --- a/request_test.go +++ b/request_test.go @@ -387,7 +387,7 @@ func TestUnmarshalSetsAttrs(t *testing.T) { t.Fatal(err) } - if out.CreatedAt.IsZero() { + if out.CreatedAt.Value.IsZero() { t.Fatalf("Did not parse time") } @@ -1431,7 +1431,7 @@ func testModel() *Blog { ID: 5, ClientID: "1", Title: "Title 1", - CreatedAt: time.Now(), + CreatedAt: &UnsetableTime{&now}, Posts: []*Post{ { ID: 1, diff --git a/response.go b/response.go index 602b16b..e105623 100644 --- a/response.go +++ b/response.go @@ -30,6 +30,15 @@ var ( ErrUnexpectedNil = errors.New("slice of struct pointers cannot contain nil") ) +// AttributeUnmarshaler can be implemented if custom marshaling is desired. +// This interface behaves differently than json.Marshaler in that it returns +// an interface rather than a byte array. The value returned can be a different +// type than the method reciever, and will be substituted for the original value +// as the jsonapi marshaling proceeds. +type AttributeMarshaler interface { + MarshalAttribute() (interface{}, error) +} + // MarshalPayload writes a jsonapi response for one or many records. The // related records are sideloaded into the "included" array. If this method is // given a struct pointer as an argument it will serialize in the form @@ -331,12 +340,29 @@ func visitModelNode(model interface{}, included *map[string]*Node, node.Attributes = make(map[string]interface{}) } - if fieldValue.Type() == reflect.TypeOf(time.Time{}) { - t := fieldValue.Interface().(time.Time) + // See if we need to omit this field + if omitEmpty { + if fieldValue.Interface() == nil { + continue + } - if t.IsZero() { + emptyValue := reflect.Zero(fieldValue.Type()) + if reflect.DeepEqual(fieldValue.Interface(), emptyValue.Interface()) { continue } + } + + if m, ok := fieldValue.Interface().(AttributeMarshaler); ok { + a, err := m.MarshalAttribute() + if err != nil { + return nil, err + } + + fieldValue = reflect.ValueOf(a) + } + + if fieldValue.Type() == reflect.TypeOf(time.Time{}) { + t := fieldValue.Interface().(time.Time) if iso8601 { node.Attributes[args[1]] = t.UTC().Format(iso8601TimeFormat) @@ -348,10 +374,6 @@ func visitModelNode(model interface{}, included *map[string]*Node, } else if fieldValue.Type() == reflect.TypeOf(new(time.Time)) { // A time pointer may be nil if fieldValue.IsNil() { - if omitEmpty { - continue - } - node.Attributes[args[1]] = nil } else { tm := fieldValue.Interface().(*time.Time) @@ -369,14 +391,6 @@ func visitModelNode(model interface{}, included *map[string]*Node, } } } else { - // Dealing with a fieldValue that is not a time - emptyValue := reflect.Zero(fieldValue.Type()) - - // See if we need to omit this field - if omitEmpty && reflect.DeepEqual(fieldValue.Interface(), emptyValue.Interface()) { - continue - } - strAttr, ok := fieldValue.Interface().(string) if ok { node.Attributes[args[1]] = strAttr diff --git a/response_test.go b/response_test.go index 599d999..78f320a 100644 --- a/response_test.go +++ b/response_test.go @@ -629,7 +629,7 @@ func TestHasPrimaryAnnotation(t *testing.T) { testModel := &Blog{ ID: 5, Title: "Title 1", - CreatedAt: time.Now(), + CreatedAt: &UnsetableTime{&now}, } out := bytes.NewBuffer(nil) @@ -658,7 +658,7 @@ func TestSupportsAttributes(t *testing.T) { testModel := &Blog{ ID: 5, Title: "Title 1", - CreatedAt: time.Now(), + CreatedAt: &UnsetableTime{&now}, } out := bytes.NewBuffer(nil) @@ -683,10 +683,10 @@ func TestSupportsAttributes(t *testing.T) { } func TestOmitsZeroTimes(t *testing.T) { - testModel := &Blog{ - ID: 5, - Title: "Title 1", - CreatedAt: time.Time{}, + testModel := &Company{ + ID: "id", + Name: "Company", + FoundedAt: time.Time{}, } out := bytes.NewBuffer(nil) @@ -705,8 +705,8 @@ func TestOmitsZeroTimes(t *testing.T) { t.Fatalf("Expected attributes") } - if data.Attributes["created_at"] != nil { - t.Fatalf("Created at was serialized even though it was a zero Time") + if data.Attributes["founded_at"] != nil { + t.Fatalf("Founded at was serialized even though it was a zero Time") } } @@ -824,7 +824,7 @@ func TestSupportsLinkable(t *testing.T) { testModel := &Blog{ ID: 5, Title: "Title 1", - CreatedAt: time.Now(), + CreatedAt: &UnsetableTime{&now}, } out := bytes.NewBuffer(nil) @@ -906,7 +906,7 @@ func TestSupportsMetable(t *testing.T) { testModel := &Blog{ ID: 5, Title: "Title 1", - CreatedAt: time.Now(), + CreatedAt: &UnsetableTime{&now}, } out := bytes.NewBuffer(nil) @@ -977,7 +977,7 @@ func TestRelations(t *testing.T) { } func TestNoRelations(t *testing.T) { - testModel := &Blog{ID: 1, Title: "Title 1", CreatedAt: time.Now()} + testModel := &Blog{ID: 1, Title: "Title 1", CreatedAt: &UnsetableTime{&now}} out := bytes.NewBuffer(nil) if err := MarshalPayload(out, testModel); err != nil { @@ -1037,7 +1037,7 @@ func TestMarshalPayload_many(t *testing.T) { &Blog{ ID: 5, Title: "Title 1", - CreatedAt: time.Now(), + CreatedAt: &UnsetableTime{&now}, Posts: []*Post{ { ID: 1, @@ -1059,7 +1059,7 @@ func TestMarshalPayload_many(t *testing.T) { &Blog{ ID: 6, Title: "Title 2", - CreatedAt: time.Now(), + CreatedAt: &UnsetableTime{&now}, Posts: []*Post{ { ID: 3, @@ -1200,7 +1200,7 @@ func testBlog() *Blog { return &Blog{ ID: 5, Title: "Title 1", - CreatedAt: time.Now(), + CreatedAt: &UnsetableTime{&now}, Posts: []*Post{ { ID: 1, @@ -1262,3 +1262,27 @@ func testBlog() *Blog { }, } } + +func TestCustomAttributeMarshaling(t *testing.T) { + blog := &Blog{ID: 1, Title: "Title 1", CreatedAt: nil} + + bytes := bytes.NewBuffer(nil) + MarshalPayload(bytes, blog) + + var jsonData map[string]interface{} + if err := json.Unmarshal(bytes.Bytes(), &jsonData); err != nil { + t.Fatal(err) + } + + if data, ok := jsonData["data"].(map[string]interface{}); ok { + if attrs, ok := data["attributes"].(map[string]interface{}); ok { + if _, ok := attrs["created_at"]; ok { + t.Fatalf("attributes should not contain `created_at`") + } + } else { + t.Fatalf("attributes key did not contain a Hash/Dict/Map") + } + } else { + t.Fatalf("data key did not contain a Hash/Dict/Map") + } +}