diff --git a/graphql_test.go b/graphql_test.go index a6a318d9..081a708c 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -1,6 +1,7 @@ package graphql_test import ( + "bytes" "context" "errors" "fmt" @@ -2926,7 +2927,7 @@ func TestInput(t *testing.T) { }) } -type inputArgumentsHello struct {} +type inputArgumentsHello struct{} type inputArgumentsScalarMismatch1 struct{} @@ -2946,7 +2947,7 @@ type helloInputMismatch struct { World string } -func (r *inputArgumentsHello) Hello(args struct { Input *helloInput }) string { +func (r *inputArgumentsHello) Hello(args struct{ Input *helloInput }) string { return "Hello " + args.Input.Name + "!" } @@ -2954,7 +2955,7 @@ func (r *inputArgumentsScalarMismatch1) Hello(name string) string { return "Hello " + name + "!" } -func (r *inputArgumentsScalarMismatch2) Hello(args struct { World string }) string { +func (r *inputArgumentsScalarMismatch2) Hello(args struct{ World string }) string { return "Hello " + args.World + "!" } @@ -2962,11 +2963,11 @@ func (r *inputArgumentsObjectMismatch1) Hello(in helloInput) string { return "Hello " + in.Name + "!" } -func (r *inputArgumentsObjectMismatch2) Hello(args struct { Input *helloInputMismatch }) string { +func (r *inputArgumentsObjectMismatch2) Hello(args struct{ Input *helloInputMismatch }) string { return "Hello " + args.Input.World + "!" } -func (r *inputArgumentsObjectMismatch3) Hello(args struct { Input *struct { Thing string } }) string { +func (r *inputArgumentsObjectMismatch3) Hello(args struct{ Input *struct{ Thing string } }) string { return "Hello " + args.Input.Thing + "!" } @@ -3635,3 +3636,84 @@ func TestSubscriptions_In_Exec(t *testing.T) { }, }) } + +func TestOverlappingAlias(t *testing.T) { + query := ` + { + hero(episode: EMPIRE) { + a: name + a: id + } + } + ` + result := starwarsSchema.Exec(context.Background(), query, "", nil) + if len(result.Errors) == 0 { + t.Fatal("Expected error from overlapping alias") + } +} + +// go test -bench=FragmentQueries -benchmem +func BenchmarkFragmentQueries(b *testing.B) { + singleQuery := ` + composed_%d: hero(episode: EMPIRE) { + name + ...friendsNames + ...friendsIds + } + ` + + queryTemplate := ` + { + %s + } + + fragment friendsNames on Character { + friends { + name + } + } + + fragment friendsIds on Character { + friends { + id + } + } + ` + + testCases := []int{ + 1, + 10, + 100, + 1000, + 10000, + } + + for _, c := range testCases { + // for each count, add a case for overlapping aliases vs non-overlapping aliases + for _, o := range []bool{true} { + + var buffer bytes.Buffer + for i := 0; i < c; i++ { + idx := 0 + if o { + idx = i + } + buffer.WriteString(fmt.Sprintf(singleQuery, idx)) + } + + query := fmt.Sprintf(queryTemplate, buffer.String()) + a := "overlapping" + if o { + a = "non-overlapping" + } + b.Run(fmt.Sprintf("%d queries %s aliases", c, a), func(b *testing.B) { + for n := 0; n < b.N; n++ { + result := starwarsSchema.Exec(context.Background(), query, "", nil) + if len(result.Errors) != 0 { + b.Fatal(result.Errors[0]) + } + } + }) + } + } +} diff --git a/internal/common/lexer.go b/internal/common/lexer.go index 5807f7ae..c3b1fdf2 100644 --- a/internal/common/lexer.go +++ b/internal/common/lexer.go @@ -30,6 +30,10 @@ func NewLexer(s string, useStringDescriptions bool) *Lexer { } sc.Init(strings.NewReader(s)) + sc.Error = func(s *scanner.Scanner, msg string) { + // do nothing, as we get a large volume of bad requests and we dont want to log these + } + return &Lexer{sc: sc, useStringDescriptions: useStringDescriptions} } diff --git a/internal/validation/validation.go b/internal/validation/validation.go index 94a9faf8..d919f0c9 100644 --- a/internal/validation/validation.go +++ b/internal/validation/validation.go @@ -275,9 +275,10 @@ func validateSelectionSet(c *opContext, sels []query.Selection, t schema.NamedTy validateSelection(c, sel, t) } + useCache := len(sels) <= 100 for i, a := range sels { for _, b := range sels[i+1:] { - c.validateOverlap(a, b, nil, nil) + c.validateOverlap(a, b, nil, nil, useCache) } } } @@ -485,16 +486,21 @@ func detectFragmentCycleSel(c *context, sel query.Selection, fragVisited map[*qu } } -func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs *[]errors.Location) { +func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs *[]errors.Location, useCache bool) { if a == b { return } - if _, ok := c.overlapValidated[selectionPair{a, b}]; ok { - return + if useCache { + if _, ok := c.overlapValidated[selectionPair{a, b}]; ok { + return + } + key := selectionPair{b, a} + if _, ok := c.overlapValidated[key]; ok { + return + } + c.overlapValidated[key] = struct{}{} } - c.overlapValidated[selectionPair{a, b}] = struct{}{} - c.overlapValidated[selectionPair{b, a}] = struct{}{} switch a := a.(type) { case *query.Field: @@ -503,7 +509,7 @@ func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs if b.Alias.Loc.Before(a.Alias.Loc) { a, b = b, a } - if reasons2, locs2 := c.validateFieldOverlap(a, b); len(reasons2) != 0 { + if reasons2, locs2 := c.validateFieldOverlap(a, b, useCache); len(reasons2) != 0 { locs2 = append(locs2, a.Alias.Loc, b.Alias.Loc) if reasons == nil { c.addErrMultiLoc(locs2, "OverlappingFieldsCanBeMerged", "Fields %q conflict because %s. Use different aliases on the fields to fetch both if this was intentional.", a.Alias.Name, strings.Join(reasons2, " and ")) @@ -517,13 +523,13 @@ func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs case *query.InlineFragment: for _, sel := range b.Selections { - c.validateOverlap(a, sel, reasons, locs) + c.validateOverlap(a, sel, reasons, locs, useCache) } case *query.FragmentSpread: if frag := c.doc.Fragments.Get(b.Name.Name); frag != nil { for _, sel := range frag.Selections { - c.validateOverlap(a, sel, reasons, locs) + c.validateOverlap(a, sel, reasons, locs, useCache) } } @@ -533,13 +539,13 @@ func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs case *query.InlineFragment: for _, sel := range a.Selections { - c.validateOverlap(sel, b, reasons, locs) + c.validateOverlap(sel, b, reasons, locs, useCache) } case *query.FragmentSpread: if frag := c.doc.Fragments.Get(a.Name.Name); frag != nil { for _, sel := range frag.Selections { - c.validateOverlap(sel, b, reasons, locs) + c.validateOverlap(sel, b, reasons, locs, useCache) } } @@ -548,21 +554,24 @@ func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs } } -func (c *context) validateFieldOverlap(a, b *query.Field) ([]string, []errors.Location) { +func (c *context) validateFieldOverlap(a, b *query.Field, useCache bool) ([]string, []errors.Location) { if a.Alias.Name != b.Alias.Name { return nil, nil } - if asf := c.fieldMap[a].sf; asf != nil { - if bsf := c.fieldMap[b].sf; bsf != nil { + afm := c.fieldMap[a] + bfm := c.fieldMap[b] + + if asf := afm.sf; asf != nil { + if bsf := bfm.sf; bsf != nil { if !typesCompatible(asf.Type, bsf.Type) { return []string{fmt.Sprintf("they return conflicting types %s and %s", asf.Type, bsf.Type)}, nil } } } - at := c.fieldMap[a].parent - bt := c.fieldMap[b].parent + at := afm.parent + bt := bfm.parent if at == nil || bt == nil || at == bt { if a.Name.Name != b.Name.Name { return []string{fmt.Sprintf("%s and %s are different fields", a.Name.Name, b.Name.Name)}, nil @@ -577,7 +586,7 @@ func (c *context) validateFieldOverlap(a, b *query.Field) ([]string, []errors.Lo var locs []errors.Location for _, a2 := range a.Selections { for _, b2 := range b.Selections { - c.validateOverlap(a2, b2, &reasons, &locs) + c.validateOverlap(a2, b2, &reasons, &locs, useCache) } } return reasons, locs