|
6 | 6 | "testing"
|
7 | 7 |
|
8 | 8 | "github.com/graphql-go/graphql"
|
| 9 | + "github.com/graphql-go/graphql/language/ast" |
| 10 | + "github.com/graphql-go/graphql/language/kinds" |
| 11 | + "github.com/graphql-go/graphql/language/visitor" |
9 | 12 | "github.com/graphql-go/graphql/testutil"
|
10 | 13 | )
|
11 | 14 |
|
@@ -268,3 +271,59 @@ func TestEmptyStringIsNotNull(t *testing.T) {
|
268 | 271 | t.Errorf("wrong result, query: %v, graphql result diff: %v", query, testutil.Diff(expected, result))
|
269 | 272 | }
|
270 | 273 | }
|
| 274 | + |
| 275 | +func TestQueryWithCustomRule(t *testing.T) { |
| 276 | + // Test graphql.Do() with custom rule, it extracts query name from each |
| 277 | + // Tests. |
| 278 | + ruleN := len(graphql.SpecifiedRules) |
| 279 | + rules := make([]graphql.ValidationRuleFn, ruleN+1) |
| 280 | + copy(rules[:ruleN], graphql.SpecifiedRules) |
| 281 | + |
| 282 | + var ( |
| 283 | + queryFound bool |
| 284 | + queryName string |
| 285 | + ) |
| 286 | + rules[ruleN] = func(context *graphql.ValidationContext) *graphql.ValidationRuleInstance { |
| 287 | + return &graphql.ValidationRuleInstance{ |
| 288 | + VisitorOpts: &visitor.VisitorOptions{ |
| 289 | + KindFuncMap: map[string]visitor.NamedVisitFuncs{ |
| 290 | + kinds.OperationDefinition: { |
| 291 | + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { |
| 292 | + od, ok := p.Node.(*ast.OperationDefinition) |
| 293 | + if ok && od.Operation == "query" { |
| 294 | + queryFound = true |
| 295 | + if od.Name != nil { |
| 296 | + queryName = od.Name.Value |
| 297 | + } |
| 298 | + } |
| 299 | + return visitor.ActionNoChange, nil |
| 300 | + }, |
| 301 | + }, |
| 302 | + }, |
| 303 | + }, |
| 304 | + } |
| 305 | + } |
| 306 | + |
| 307 | + expectedNames := []string{ |
| 308 | + "HeroNameQuery", |
| 309 | + "HeroNameAndFriendsQuery", |
| 310 | + "HumanByIdQuery", |
| 311 | + } |
| 312 | + |
| 313 | + for i, test := range Tests { |
| 314 | + queryFound, queryName = false, "" |
| 315 | + params := graphql.Params{ |
| 316 | + Schema: test.Schema, |
| 317 | + RequestString: test.Query, |
| 318 | + VariableValues: test.Variables, |
| 319 | + ValidationRules: rules, |
| 320 | + } |
| 321 | + testGraphql(test, params, t) |
| 322 | + if !queryFound { |
| 323 | + t.Fatal("can't detect \"query\" operation by validation rule") |
| 324 | + } |
| 325 | + if queryName != expectedNames[i] { |
| 326 | + t.Fatalf("unexpected query name: want=%s got=%s", queryName, expectedNames) |
| 327 | + } |
| 328 | + } |
| 329 | +} |
0 commit comments