@@ -327,3 +327,66 @@ func TestQueryWithCustomRule(t *testing.T) {
327
327
}
328
328
}
329
329
}
330
+
331
+ // TestCustomRuleWithArgs tests graphql.GetArgumentValues() be able to access
332
+ // field's argument values from custom validation rule.
333
+ func TestCustomRuleWithArgs (t * testing.T ) {
334
+ fieldDef , ok := testutil .StarWarsSchema .QueryType ().Fields ()["human" ]
335
+ if ! ok {
336
+ t .Fatal ("can't retrieve \" human\" field definition" )
337
+ }
338
+
339
+ // a custom validation rule to extract argument values of "human" field.
340
+ var actual map [string ]interface {}
341
+ enter := func (p visitor.VisitFuncParams ) (string , interface {}) {
342
+ // only interested in "human" field.
343
+ fieldNode , ok := p .Node .(* ast.Field )
344
+ if ! ok || fieldNode .Name == nil || fieldNode .Name .Value != "human" {
345
+ return visitor .ActionNoChange , nil
346
+ }
347
+ // extract argument values by graphql.GetArgumentValues().
348
+ actual = graphql .GetArgumentValues (fieldDef .Args , fieldNode .Arguments , nil )
349
+ return visitor .ActionNoChange , nil
350
+ }
351
+ checkHumanArgs := func (context * graphql.ValidationContext ) * graphql.ValidationRuleInstance {
352
+ return & graphql.ValidationRuleInstance {
353
+ VisitorOpts : & visitor.VisitorOptions {
354
+ KindFuncMap : map [string ]visitor.NamedVisitFuncs {
355
+ kinds .Field : {Enter : enter },
356
+ },
357
+ },
358
+ }
359
+ }
360
+
361
+ for _ , tc := range []struct {
362
+ query string
363
+ expected map [string ]interface {}
364
+ }{
365
+ {
366
+ `query { human(id: "1000") { name } }` ,
367
+ map [string ]interface {}{"id" : "1000" },
368
+ },
369
+ {
370
+ `query { human(id: "1002") { name } }` ,
371
+ map [string ]interface {}{"id" : "1002" },
372
+ },
373
+ {
374
+ `query { human(id: "9999") { name } }` ,
375
+ map [string ]interface {}{"id" : "9999" },
376
+ },
377
+ } {
378
+ actual = nil
379
+ params := graphql.Params {
380
+ Schema : testutil .StarWarsSchema ,
381
+ RequestString : tc .query ,
382
+ ValidationRules : append (graphql .SpecifiedRules , checkHumanArgs ),
383
+ }
384
+ result := graphql .Do (params )
385
+ if len (result .Errors ) > 0 {
386
+ t .Fatalf ("wrong result, unexpected errors: %v" , result .Errors )
387
+ }
388
+ if ! reflect .DeepEqual (actual , tc .expected ) {
389
+ t .Fatalf ("unexpected result: want=%+v got=%+v" , tc .expected , actual )
390
+ }
391
+ }
392
+ }
0 commit comments