Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support custom marshaling and unmarshaling for attributes #18

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package main
import "time"

func fixtureBlogCreate(i int) *Blog {
ts := time.Now()
return &Blog{
ID: 1 * i,
Title: "Title 1",
CreatedAt: time.Now(),
CreatedAt: &UnsetableTime{&ts},
Posts: []*Post{
{
ID: 1 * i,
Expand Down
31 changes: 24 additions & 7 deletions examples/models.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,38 @@
package main

import (
"encoding/json"
"fmt"
"time"

"github.com/hashicorp/jsonapi"
)

type UnsetableTime struct {
Value *time.Time
}

func (t *UnsetableTime) MarshalAttribute() (interface{}, error) {
if t == nil {
return nil, nil
}

if t.Value == nil {
return json.RawMessage(nil), nil
} else {
return t.Value, nil
}
}

// Blog is a model representing a blog site
type Blog struct {
ID int `jsonapi:"primary,blogs"`
Title string `jsonapi:"attr,title"`
Posts []*Post `jsonapi:"relation,posts"`
CurrentPost *Post `jsonapi:"relation,current_post"`
CurrentPostID int `jsonapi:"attr,current_post_id"`
CreatedAt time.Time `jsonapi:"attr,created_at"`
ViewCount int `jsonapi:"attr,view_count"`
ID int `jsonapi:"primary,blogs"`
Title string `jsonapi:"attr,title"`
Posts []*Post `jsonapi:"relation,posts"`
CurrentPost *Post `jsonapi:"relation,current_post"`
CurrentPostID int `jsonapi:"attr,current_post_id"`
CreatedAt *UnsetableTime `jsonapi:"attr,created_at,omitempty,iso8601"`
ViewCount int `jsonapi:"attr,view_count"`
}

// Post is a model representing a post on a blog
Expand Down
61 changes: 53 additions & 8 deletions models_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package jsonapi

import (
"encoding/json"
"errors"
"fmt"
"time"
)

var now = time.Now()

type BadModel struct {
ID int `jsonapi:"primary"`
}
Expand Down Expand Up @@ -80,15 +84,56 @@ type GenericInterface struct {
Data interface{} `jsonapi:"attr,interface"`
}

type UnsetableTime struct {
Value *time.Time
}

func (t *UnsetableTime) MarshalAttribute() (interface{}, error) {
if t == nil {
return nil, nil
}

if t.Value == nil {
return json.RawMessage(nil), nil
} else {
return t.Value, nil
}
}

func (t *UnsetableTime) UnmarshalAttribute(obj interface{}) error {
var ts time.Time
var err error

if obj == nil {
t.Value = nil
return nil
}

if tsStr, ok := obj.(string); ok {
ts, err = time.Parse(tsStr, time.RFC3339)
if err == nil {
t.Value = &ts
return nil
}
} else if tsFloat, ok := obj.(float64); ok {
ts = time.Unix(int64(tsFloat), 0)

t.Value = &ts
return nil
}

return errors.New("couldn't parse time")
}

type Blog struct {
ID int `jsonapi:"primary,blogs"`
ClientID string `jsonapi:"client-id"`
Title string `jsonapi:"attr,title"`
Posts []*Post `jsonapi:"relation,posts"`
CurrentPost *Post `jsonapi:"relation,current_post"`
CurrentPostID int `jsonapi:"attr,current_post_id"`
CreatedAt time.Time `jsonapi:"attr,created_at"`
ViewCount int `jsonapi:"attr,view_count"`
ID int `jsonapi:"primary,blogs"`
ClientID string `jsonapi:"client-id"`
Title string `jsonapi:"attr,title"`
Posts []*Post `jsonapi:"relation,posts"`
CurrentPost *Post `jsonapi:"relation,current_post"`
CurrentPostID int `jsonapi:"attr,current_post_id"`
CreatedAt *UnsetableTime `jsonapi:"attr,created_at,omitempty"`
ViewCount int `jsonapi:"attr,view_count"`

Links Links `jsonapi:"links,omitempty"`
}
Expand Down
13 changes: 13 additions & 0 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ var (
ErrTypeNotFound = errors.New("no primary type annotation found on model")
)

type AttributeUnmarshaler interface {
UnmarshalAttribute(interface{}) error
}

// ErrUnsupportedPtrType is returned when the Struct field was a pointer but
// the JSON value was of a different type
type ErrUnsupportedPtrType struct {
Expand Down Expand Up @@ -589,6 +593,15 @@ func unmarshalAttribute(
value = reflect.ValueOf(attribute)
fieldType := structField.Type

i := reflect.TypeOf((*AttributeUnmarshaler)(nil)).Elem()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be a package level var so it doesn't have to be evaluated on every attribute?

if fieldType.Implements(i) {
x := reflect.New(fieldType.Elem())
y := (x.Interface()).(AttributeUnmarshaler)
err = y.UnmarshalAttribute(attribute)
value = reflect.ValueOf(y)
return
}

// Handle field of type []string
if fieldValue.Type() == reflect.TypeOf([]string{}) {
value, err = handleStringSlice(attribute)
Expand Down
4 changes: 2 additions & 2 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ func TestUnmarshalSetsAttrs(t *testing.T) {
t.Fatal(err)
}

if out.CreatedAt.IsZero() {
if out.CreatedAt.Value.IsZero() {
t.Fatalf("Did not parse time")
}

Expand Down Expand Up @@ -1431,7 +1431,7 @@ func testModel() *Blog {
ID: 5,
ClientID: "1",
Title: "Title 1",
CreatedAt: time.Now(),
CreatedAt: &UnsetableTime{&now},
Posts: []*Post{
{
ID: 1,
Expand Down
44 changes: 29 additions & 15 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ var (
ErrUnexpectedNil = errors.New("slice of struct pointers cannot contain nil")
)

// AttributeUnmarshaler can be implemented if custom marshaling is desired.
// This interface behaves differently than json.Marshaler in that it returns
// an interface rather than a byte array. The value returned can be a different
// type than the method reciever, and will be substituted for the original value
// as the jsonapi marshaling proceeds.
type AttributeMarshaler interface {
MarshalAttribute() (interface{}, error)
}

// MarshalPayload writes a jsonapi response for one or many records. The
// related records are sideloaded into the "included" array. If this method is
// given a struct pointer as an argument it will serialize in the form
Expand Down Expand Up @@ -331,12 +340,29 @@ func visitModelNode(model interface{}, included *map[string]*Node,
node.Attributes = make(map[string]interface{})
}

if fieldValue.Type() == reflect.TypeOf(time.Time{}) {
t := fieldValue.Interface().(time.Time)
// See if we need to omit this field
if omitEmpty {
if fieldValue.Interface() == nil {
continue
}

if t.IsZero() {
emptyValue := reflect.Zero(fieldValue.Type())
if reflect.DeepEqual(fieldValue.Interface(), emptyValue.Interface()) {
continue
}
}

if m, ok := fieldValue.Interface().(AttributeMarshaler); ok {
a, err := m.MarshalAttribute()
if err != nil {
return nil, err
}

fieldValue = reflect.ValueOf(a)
}

if fieldValue.Type() == reflect.TypeOf(time.Time{}) {
t := fieldValue.Interface().(time.Time)

if iso8601 {
node.Attributes[args[1]] = t.UTC().Format(iso8601TimeFormat)
Expand All @@ -348,10 +374,6 @@ func visitModelNode(model interface{}, included *map[string]*Node,
} else if fieldValue.Type() == reflect.TypeOf(new(time.Time)) {
// A time pointer may be nil
if fieldValue.IsNil() {
if omitEmpty {
continue
}

node.Attributes[args[1]] = nil
} else {
tm := fieldValue.Interface().(*time.Time)
Expand All @@ -369,14 +391,6 @@ func visitModelNode(model interface{}, included *map[string]*Node,
}
}
} else {
// Dealing with a fieldValue that is not a time
emptyValue := reflect.Zero(fieldValue.Type())

// See if we need to omit this field
if omitEmpty && reflect.DeepEqual(fieldValue.Interface(), emptyValue.Interface()) {
continue
}

strAttr, ok := fieldValue.Interface().(string)
if ok {
node.Attributes[args[1]] = strAttr
Expand Down
52 changes: 38 additions & 14 deletions response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ func TestHasPrimaryAnnotation(t *testing.T) {
testModel := &Blog{
ID: 5,
Title: "Title 1",
CreatedAt: time.Now(),
CreatedAt: &UnsetableTime{&now},
}

out := bytes.NewBuffer(nil)
Expand Down Expand Up @@ -658,7 +658,7 @@ func TestSupportsAttributes(t *testing.T) {
testModel := &Blog{
ID: 5,
Title: "Title 1",
CreatedAt: time.Now(),
CreatedAt: &UnsetableTime{&now},
}

out := bytes.NewBuffer(nil)
Expand All @@ -683,10 +683,10 @@ func TestSupportsAttributes(t *testing.T) {
}

func TestOmitsZeroTimes(t *testing.T) {
testModel := &Blog{
ID: 5,
Title: "Title 1",
CreatedAt: time.Time{},
testModel := &Company{
ID: "id",
Name: "Company",
FoundedAt: time.Time{},
}

out := bytes.NewBuffer(nil)
Expand All @@ -705,8 +705,8 @@ func TestOmitsZeroTimes(t *testing.T) {
t.Fatalf("Expected attributes")
}

if data.Attributes["created_at"] != nil {
t.Fatalf("Created at was serialized even though it was a zero Time")
if data.Attributes["founded_at"] != nil {
t.Fatalf("Founded at was serialized even though it was a zero Time")
}
}

Expand Down Expand Up @@ -824,7 +824,7 @@ func TestSupportsLinkable(t *testing.T) {
testModel := &Blog{
ID: 5,
Title: "Title 1",
CreatedAt: time.Now(),
CreatedAt: &UnsetableTime{&now},
}

out := bytes.NewBuffer(nil)
Expand Down Expand Up @@ -906,7 +906,7 @@ func TestSupportsMetable(t *testing.T) {
testModel := &Blog{
ID: 5,
Title: "Title 1",
CreatedAt: time.Now(),
CreatedAt: &UnsetableTime{&now},
}

out := bytes.NewBuffer(nil)
Expand Down Expand Up @@ -977,7 +977,7 @@ func TestRelations(t *testing.T) {
}

func TestNoRelations(t *testing.T) {
testModel := &Blog{ID: 1, Title: "Title 1", CreatedAt: time.Now()}
testModel := &Blog{ID: 1, Title: "Title 1", CreatedAt: &UnsetableTime{&now}}

out := bytes.NewBuffer(nil)
if err := MarshalPayload(out, testModel); err != nil {
Expand Down Expand Up @@ -1037,7 +1037,7 @@ func TestMarshalPayload_many(t *testing.T) {
&Blog{
ID: 5,
Title: "Title 1",
CreatedAt: time.Now(),
CreatedAt: &UnsetableTime{&now},
Posts: []*Post{
{
ID: 1,
Expand All @@ -1059,7 +1059,7 @@ func TestMarshalPayload_many(t *testing.T) {
&Blog{
ID: 6,
Title: "Title 2",
CreatedAt: time.Now(),
CreatedAt: &UnsetableTime{&now},
Posts: []*Post{
{
ID: 3,
Expand Down Expand Up @@ -1200,7 +1200,7 @@ func testBlog() *Blog {
return &Blog{
ID: 5,
Title: "Title 1",
CreatedAt: time.Now(),
CreatedAt: &UnsetableTime{&now},
Posts: []*Post{
{
ID: 1,
Expand Down Expand Up @@ -1262,3 +1262,27 @@ func testBlog() *Blog {
},
}
}

func TestCustomAttributeMarshaling(t *testing.T) {
blog := &Blog{ID: 1, Title: "Title 1", CreatedAt: nil}

bytes := bytes.NewBuffer(nil)
MarshalPayload(bytes, blog)

var jsonData map[string]interface{}
if err := json.Unmarshal(bytes.Bytes(), &jsonData); err != nil {
t.Fatal(err)
}

if data, ok := jsonData["data"].(map[string]interface{}); ok {
if attrs, ok := data["attributes"].(map[string]interface{}); ok {
if _, ok := attrs["created_at"]; ok {
t.Fatalf("attributes should not contain `created_at`")
}
} else {
t.Fatalf("attributes key did not contain a Hash/Dict/Map")
}
} else {
t.Fatalf("data key did not contain a Hash/Dict/Map")
}
}