diff --git a/graphql_test.go b/graphql_test.go index e09dcc9..c47ff27 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -2,6 +2,7 @@ package graphql_test import ( "context" + "encoding/json" "io" "io/ioutil" "net/http" @@ -64,6 +65,55 @@ func TestClient_Query_partialDataWithErrorResponse(t *testing.T) { } } +func TestClient_Query_partialDataRawQueryWithErrorResponse(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/graphql", func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + mustWrite(w, `{ + "data": { + "node1": { "id": "MDEyOklzc3VlQ29tbWVudDE2OTQwNzk0Ng==" }, + "node2": null + }, + "errors": [ + { + "message": "Could not resolve to a node with the global id of 'NotExist'", + "type": "NOT_FOUND", + "path": [ + "node2" + ], + "locations": [ + { + "line": 10, + "column": 4 + } + ] + } + ] + }`) + }) + client := graphql.NewClient("/graphql", &http.Client{Transport: localRoundTripper{handler: mux}}) + + var q struct { + Node1 json.RawMessage `graphql:"node1"` + Node2 *struct { + ID graphql.ID + } `graphql:"node2: node(id: \"NotExist\")"` + } + err := client.Query(context.Background(), &q, nil) + if err == nil { + t.Fatal("got error: nil, want: non-nil\n") + } + if got, want := err.Error(), "Could not resolve to a node with the global id of 'NotExist'"; got != want { + t.Errorf("got error: %v, want: %v\n", got, want) + } + if q.Node1 == nil || string(q.Node1) != `{"id":"MDEyOklzc3VlQ29tbWVudDE2OTQwNzk0Ng=="}` { + t.Errorf("got wrong q.Node1: %v\n", string(q.Node1)) + } + if q.Node2 != nil { + t.Errorf("got non-nil q.Node2: %v, want: nil\n", *q.Node2) + } +} + func TestClient_Query_noDataWithErrorResponse(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/graphql", func(w http.ResponseWriter, req *http.Request) { diff --git a/internal/jsonutil/graphql.go b/internal/jsonutil/graphql.go index 15bae24..5df1465 100644 --- a/internal/jsonutil/graphql.go +++ b/internal/jsonutil/graphql.go @@ -42,6 +42,7 @@ func UnmarshalGraphQL(data []byte, v interface{}) error { type decoder struct { tokenizer interface { Token() (json.Token, error) + Decode(v interface{}) error } // Stack of what part of input JSON we're in the middle of - objects, arrays. @@ -68,10 +69,14 @@ func (d *decoder) Decode(v interface{}) error { // decode decodes a single JSON value from d.tokenizer into d.vs. func (d *decoder) decode() error { + rawMessageValue := reflect.ValueOf(json.RawMessage{}) + // The loop invariant is that the top of each d.vs stack // is where we try to unmarshal the next JSON value we see. for len(d.vs) > 0 { + var tok interface{} tok, err := d.tokenizer.Token() + if err == io.EOF { return errors.New("unexpected end of JSON input") } else if err != nil { @@ -87,6 +92,8 @@ func (d *decoder) decode() error { return errors.New("unexpected non-key in JSON input") } someFieldExist := false + // If one field is raw all must be treated as raw + rawMessage := false for i := range d.vs { v := d.vs[i][len(d.vs[i])-1] if v.Kind() == reflect.Ptr { @@ -97,6 +104,10 @@ func (d *decoder) decode() error { f = fieldByGraphQLName(v, key) if f.IsValid() { someFieldExist = true + // Check for special embedded json + if f.Type() == rawMessageValue.Type() { + rawMessage = true + } } } d.vs[i] = append(d.vs[i], f) @@ -105,13 +116,20 @@ func (d *decoder) decode() error { return fmt.Errorf("struct field for %q doesn't exist in any of %v places to unmarshal", key, len(d.vs)) } - // We've just consumed the current token, which was the key. - // Read the next token, which should be the value, and let the rest of code process it. - tok, err = d.tokenizer.Token() - if err == io.EOF { - return errors.New("unexpected end of JSON input") - } else if err != nil { - return err + if rawMessage { + // Read the next complete object from the json stream + var data json.RawMessage + d.tokenizer.Decode(&data) + tok = data + } else { + // We've just consumed the current token, which was the key. + // Read the next token, which should be the value, and let the rest of code process it. + tok, err = d.tokenizer.Token() + if err == io.EOF { + return errors.New("unexpected end of JSON input") + } else if err != nil { + return err + } } // Are we inside an array and seeing next value (rather than end of array)? @@ -136,7 +154,7 @@ func (d *decoder) decode() error { } switch tok := tok.(type) { - case string, json.Number, bool, nil: + case string, json.Number, bool, nil, json.RawMessage: // Value. for i := range d.vs { @@ -302,7 +320,7 @@ func isGraphQLFragment(f reflect.StructField) bool { // unmarshalValue unmarshals JSON value into v. // v must be addressable and not obtained by the use of unexported // struct fields, otherwise unmarshalValue will panic. -func unmarshalValue(value json.Token, v reflect.Value) error { +func unmarshalValue(value interface{}, v reflect.Value) error { b, err := json.Marshal(value) // TODO: Short-circuit (if profiling says it's worth it). if err != nil { return err diff --git a/internal/jsonutil/graphql_test.go b/internal/jsonutil/graphql_test.go index 6329ed8..8a63876 100644 --- a/internal/jsonutil/graphql_test.go +++ b/internal/jsonutil/graphql_test.go @@ -1,6 +1,7 @@ package jsonutil_test import ( + "encoding/json" "reflect" "testing" "time" @@ -80,6 +81,29 @@ func TestUnmarshalGraphQL_jsonTag(t *testing.T) { } } +func TestUnmarshalGraphQL_jsonRawTag(t *testing.T) { + type query struct { + Data json.RawMessage + Another string + } + var got query + err := jsonutil.UnmarshalGraphQL([]byte(`{ + "Data": { "foo":"bar" }, + "Another" : "stuff" + }`), &got) + + if err != nil { + t.Fatal(err) + } + want := query{ + Another: "stuff", + Data: []byte(`{"foo":"bar"}`), + } + if !reflect.DeepEqual(got, want) { + t.Errorf("not equal: %v %v", want, got) + } +} + func TestUnmarshalGraphQL_array(t *testing.T) { type query struct { Foo []graphql.String