@@ -270,21 +270,24 @@ func checkExpectations(mt *mtest.T, expectations *[]*expectation, id0, id1 bson.
270
270
return
271
271
}
272
272
273
+ startedEvents := make ([]* cmdStartedEvt , len (* expectations ))
274
+ succeededEvents := make ([]* cmdSucceededEvt , len (* expectations ))
275
+ failedEvents := make ([]* cmdFailedEvt , len (* expectations ))
276
+
273
277
for idx , expectation := range * expectations {
274
- var err error
278
+ startedEvents [idx ] = expectation .CommandStartedEvent
279
+ succeededEvents [idx ] = expectation .CommandSucceededEvent
280
+ failedEvents [idx ] = expectation .CommandFailedEvent
281
+ }
275
282
276
- if expectation .CommandStartedEvent != nil {
277
- err = compareStartedEvent (mt , expectation , id0 , id1 )
278
- }
279
- if expectation .CommandSucceededEvent != nil {
280
- err = compareSucceededEvent (mt , expectation )
281
- }
282
- if expectation .CommandFailedEvent != nil {
283
- err = compareFailedEvent (mt , expectation )
284
- }
283
+ var err error
284
+ err = compareStartedEvents (mt , startedEvents , id0 , id1 )
285
+ assert .Nil (mt , err , "expectation comparison %s" , err )
286
+ err = compareSucceededEvents (mt , succeededEvents )
287
+ assert .Nil (mt , err , "expectation comparison %s" , err )
288
+ err = compareFailedEvents (mt , failedEvents )
289
+ assert .Nil (mt , err , "expectation comparison %s" , err )
285
290
286
- assert .Nil (mt , err , "expectation comparison error at index %v: %s" , idx , err )
287
- }
288
291
}
289
292
290
293
// newMatchError appends `expected` and `actual` BSON data to an error.
@@ -298,83 +301,105 @@ func newMatchError(mt *mtest.T, expected bson.Raw, actual bson.Raw, format strin
298
301
return fmt .Errorf ("%s\n Expected %s\n Got: %s" , msg , string (expectedJSON ), string (actualJSON ))
299
302
}
300
303
301
- func compareStartedEvent (mt * mtest.T , expectation * expectation , id0 , id1 bson.Raw ) error {
304
+ func compareStartedEvents (mt * mtest.T , expectations [] * cmdStartedEvt , id0 , id1 bson.Raw ) error {
302
305
mt .Helper ()
303
306
304
- expected := expectation .CommandStartedEvent
305
-
306
- if len (expected .Extra ) > 0 {
307
- return fmt .Errorf ("unrecognized fields for CommandStartedEvent: %v" , expected .Extra )
308
- }
309
-
310
- evt := mt .GetStartedEvent ()
311
- if evt == nil {
312
- return errors .New ("expected CommandStartedEvent, got nil" )
313
- }
314
-
315
- if expected .CommandName != "" && expected .CommandName != evt .CommandName {
316
- return fmt .Errorf ("command name mismatch for started event; expected %s, got %s" , expected .CommandName , evt .CommandName )
317
- }
318
- if expected .DatabaseName != "" && expected .DatabaseName != evt .DatabaseName {
319
- return fmt .Errorf ("database name mismatch; expected %s, got %s" , expected .DatabaseName , evt .DatabaseName )
307
+ expectedCmds := make (map [string ]bool )
308
+ for _ , expected := range expectations {
309
+ expectedCmds [expected .CommandName ] = true
320
310
}
321
311
322
- eElems , err := expected .Command .Elements ()
323
- if err != nil {
324
- return fmt .Errorf ("error getting expected command elements: %s" , err )
325
- }
326
-
327
- for _ , elem := range eElems {
328
- key := elem .Key ()
329
- val := elem .Value ()
330
-
331
- actualVal , err := evt .Command .LookupErr (key )
312
+ compare := func (expected * cmdStartedEvt ) error {
313
+ if expected == nil {
314
+ return nil
315
+ }
316
+ if len (expected .Extra ) > 0 {
317
+ return fmt .Errorf ("unrecognized fields for CommandStartedEvent: %v" , expected .Extra )
318
+ }
332
319
333
- // Keys that may be nil
334
- if val .Type == bson .TypeNull {
335
- // Expected value is BSON null. Expect the actual field to be omitted.
336
- if errors .Is (err , bsoncore .ErrElementNotFound ) {
337
- continue
320
+ var evt * event.CommandStartedEvent
321
+ // skip events not in expectations
322
+ for {
323
+ evt = mt .GetStartedEvent ()
324
+ if evt == nil {
325
+ return errors .New ("expected CommandStartedEvent, got nil" )
338
326
}
339
- if err != nil {
340
- return newMatchError ( mt , expected . Command , evt . Command , "expected key %q to be omitted but got error: %v" , key , err )
327
+ if v , ok := expectedCmds [ expected . CommandName ]; ok && v {
328
+ break
341
329
}
342
- return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got %q" , key , actualVal )
343
330
}
344
- assert .Nil (mt , err , "expected command to contain key %q" , key )
345
331
346
- if key == "batchSize" {
347
- // Some command monitoring tests expect that the driver will send a lower batch size if the required batch
348
- // size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
349
- // versions do not support the limit option, but not for 3.2+. We've already validated that the command
350
- // contains a batchSize field above and we can skip the actual value comparison below.
351
- continue
332
+ if expected .CommandName != "" && expected .CommandName != evt .CommandName {
333
+ return fmt .Errorf ("command name mismatch for started event; expected %s, got %s" , expected .CommandName , evt .CommandName )
334
+ }
335
+ if expected .DatabaseName != "" && expected .DatabaseName != evt .DatabaseName {
336
+ return fmt .Errorf ("database name mismatch; expected %s, got %s" , expected .DatabaseName , evt .DatabaseName )
337
+ }
338
+
339
+ eElems , err := expected .Command .Elements ()
340
+ if err != nil {
341
+ return fmt .Errorf ("error getting expected command elements: %s" , err )
352
342
}
353
343
354
- switch key {
355
- case "lsid" :
356
- sessName := val .StringValue ()
357
- var expectedID bson.Raw
358
- actualID := actualVal .Document ()
344
+ for _ , elem := range eElems {
345
+ key := elem .Key ()
346
+ val := elem .Value ()
359
347
360
- switch sessName {
361
- case "session0" :
362
- expectedID = id0
363
- case "session1" :
364
- expectedID = id1
365
- default :
366
- return newMatchError (mt , expected .Command , evt .Command , "unrecognized session identifier in command document: %s" , sessName )
348
+ actualVal , err := evt .Command .LookupErr (key )
349
+
350
+ // Keys that may be nil
351
+ if val .Type == bson .TypeNull {
352
+ // Expected value is BSON null. Expect the actual field to be omitted.
353
+ if errors .Is (err , bsoncore .ErrElementNotFound ) {
354
+ continue
355
+ }
356
+ if err != nil {
357
+ return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got error: %v" , key , err )
358
+ }
359
+ return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got %q" , key , actualVal )
367
360
}
361
+ assert .Nil (mt , err , "expected command to contain key %q" , key )
368
362
369
- if ! bytes .Equal (expectedID , actualID ) {
370
- return newMatchError (mt , expected .Command , evt .Command , "session ID mismatch for session %s; expected %s, got %s" , sessName , expectedID ,
371
- actualID )
363
+ if key == "batchSize" {
364
+ // Some command monitoring tests expect that the driver will send a lower batch size if the required batch
365
+ // size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
366
+ // versions do not support the limit option, but not for 3.2+. We've already validated that the command
367
+ // contains a batchSize field above and we can skip the actual value comparison below.
368
+ continue
372
369
}
373
- default :
374
- if err := compareValues (mt , key , val , actualVal ); err != nil {
375
- return newMatchError (mt , expected .Command , evt .Command , "%s" , err )
370
+
371
+ switch key {
372
+ case "lsid" :
373
+ sessName := val .StringValue ()
374
+ var expectedID bson.Raw
375
+ actualID := actualVal .Document ()
376
+
377
+ switch sessName {
378
+ case "session0" :
379
+ expectedID = id0
380
+ case "session1" :
381
+ expectedID = id1
382
+ default :
383
+ return newMatchError (mt , expected .Command , evt .Command , "unrecognized session identifier in command document: %s" , sessName )
384
+ }
385
+
386
+ if ! bytes .Equal (expectedID , actualID ) {
387
+ return newMatchError (mt , expected .Command , evt .Command , "session ID mismatch for session %s; expected %s, got %s" , sessName , expectedID ,
388
+ actualID )
389
+ }
390
+ default :
391
+ if err := compareValues (mt , key , val , actualVal ); err != nil {
392
+ return newMatchError (mt , expected .Command , evt .Command , "%s" , err )
393
+ }
376
394
}
377
395
}
396
+ return nil
397
+ }
398
+ for idx , expected := range expectations {
399
+ err := compare (expected )
400
+ if err != nil {
401
+ return fmt .Errorf ("error at index %d: %s" , idx , err )
402
+ }
378
403
}
379
404
return nil
380
405
}
@@ -416,60 +441,108 @@ func compareWriteErrors(mt *mtest.T, expected, actual bson.Raw) error {
416
441
return nil
417
442
}
418
443
419
- func compareSucceededEvent (mt * mtest.T , expectation * expectation ) error {
444
+ func compareSucceededEvents (mt * mtest.T , expectations [] * cmdSucceededEvt ) error {
420
445
mt .Helper ()
421
446
422
- expected := expectation .CommandSucceededEvent
423
- if len (expected .Extra ) > 0 {
424
- return fmt .Errorf ("unrecognized fields for CommandSucceededEvent: %v" , expected .Extra )
425
- }
426
- evt := mt .GetSucceededEvent ()
427
- if evt == nil {
428
- return errors .New ("expected CommandSucceededEvent, got nil" )
447
+ expectedCmds := make (map [string ]bool )
448
+ for _ , expected := range expectations {
449
+ expectedCmds [expected .CommandName ] = true
429
450
}
430
451
431
- if expected .CommandName != "" && expected .CommandName != evt .CommandName {
432
- return fmt .Errorf ("command name mismatch for succeeded event; expected %s, got %s" , expected .CommandName , evt .CommandName )
433
- }
452
+ compare := func (expected * cmdSucceededEvt ) error {
453
+ if expected == nil {
454
+ return nil
455
+ }
456
+ if len (expected .Extra ) > 0 {
457
+ return fmt .Errorf ("unrecognized fields for CommandSucceededEvent: %v" , expected .Extra )
458
+ }
434
459
435
- eElems , err := expected .Reply .Elements ()
436
- if err != nil {
437
- return fmt .Errorf ("error getting expected reply elements: %s" , err )
438
- }
460
+ var evt * event.CommandSucceededEvent
461
+ // skip events not in expectations
462
+ for {
463
+ evt = mt .GetSucceededEvent ()
464
+ if evt == nil {
465
+ return errors .New ("expected CommandSucceededEvent, got nil" )
466
+ }
467
+ if v , ok := expectedCmds [expected .CommandName ]; ok && v {
468
+ break
469
+ }
470
+ }
439
471
440
- for _ , elem := range eElems {
441
- key := elem .Key ()
442
- val := elem .Value ()
443
- actualVal := evt .Reply .Lookup (key )
472
+ if expected .CommandName != "" && expected .CommandName != evt .CommandName {
473
+ return fmt .Errorf ("command name mismatch for succeeded event; expected %s, got %s" , expected .CommandName , evt .CommandName )
474
+ }
444
475
445
- switch key {
446
- case "writeErrors" :
447
- if err = compareWriteErrors (mt , val .Array (), actualVal .Array ()); err != nil {
448
- return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
449
- }
450
- default :
451
- if err := compareValues (mt , key , val , actualVal ); err != nil {
452
- return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
476
+ eElems , err := expected .Reply .Elements ()
477
+ if err != nil {
478
+ return fmt .Errorf ("error getting expected reply elements: %s" , err )
479
+ }
480
+
481
+ for _ , elem := range eElems {
482
+ key := elem .Key ()
483
+ val := elem .Value ()
484
+ actualVal := evt .Reply .Lookup (key )
485
+
486
+ switch key {
487
+ case "writeErrors" :
488
+ if err = compareWriteErrors (mt , val .Array (), actualVal .Array ()); err != nil {
489
+ return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
490
+ }
491
+ default :
492
+ if err := compareValues (mt , key , val , actualVal ); err != nil {
493
+ return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
494
+ }
453
495
}
454
496
}
497
+ return nil
498
+ }
499
+ for idx , expected := range expectations {
500
+ err := compare (expected )
501
+ if err != nil {
502
+ return fmt .Errorf ("error at index %d: %s" , idx , err )
503
+ }
455
504
}
456
505
return nil
457
506
}
458
507
459
- func compareFailedEvent (mt * mtest.T , expectation * expectation ) error {
508
+ func compareFailedEvents (mt * mtest.T , expectations [] * cmdFailedEvt ) error {
460
509
mt .Helper ()
461
510
462
- expected := expectation .CommandFailedEvent
463
- if len (expected .Extra ) > 0 {
464
- return fmt .Errorf ("unrecognized fields for CommandFailedEvent: %v" , expected .Extra )
465
- }
466
- evt := mt .GetFailedEvent ()
467
- if evt == nil {
468
- return errors .New ("expected CommandFailedEvent, got nil" )
511
+ expectedCmds := make (map [string ]bool )
512
+ for _ , expected := range expectations {
513
+ expectedCmds [expected .CommandName ] = true
469
514
}
470
515
471
- if expected .CommandName != "" && expected .CommandName != evt .CommandName {
472
- return fmt .Errorf ("command name mismatch for failed event; expected %s, got %s" , expected .CommandName , evt .CommandName )
516
+ compare := func (expected * cmdFailedEvt ) error {
517
+ if expected == nil {
518
+ return nil
519
+ }
520
+ if len (expected .Extra ) > 0 {
521
+ return fmt .Errorf ("unrecognized fields for CommandFailedEvent: %v" , expected .Extra )
522
+ }
523
+
524
+ var evt * event.CommandFailedEvent
525
+ // skip events not in expectations
526
+ for {
527
+ evt = mt .GetFailedEvent ()
528
+ if evt == nil {
529
+ return errors .New ("expected CommandFailedEvent, got nil" )
530
+ }
531
+ if v , ok := expectedCmds [expected .CommandName ]; ok && v {
532
+ break
533
+ }
534
+ }
535
+
536
+ if expected .CommandName != "" && expected .CommandName != evt .CommandName {
537
+ return fmt .Errorf ("command name mismatch for failed event; expected %s, got %s" , expected .CommandName , evt .CommandName )
538
+ }
539
+ return nil
540
+ }
541
+ for idx , expected := range expectations {
542
+ err := compare (expected )
543
+ if err != nil {
544
+ return fmt .Errorf ("error at index %d: %s" , idx , err )
545
+ }
473
546
}
474
547
return nil
475
548
}
0 commit comments