Skip to content

Commit 99b6d54

Browse files
committed
updates
1 parent 7d74b14 commit 99b6d54

File tree

3 files changed

+217
-139
lines changed

3 files changed

+217
-139
lines changed

mongo/integration/cmd_monitoring_helpers_test.go

Lines changed: 182 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -270,21 +270,24 @@ func checkExpectations(mt *mtest.T, expectations *[]*expectation, id0, id1 bson.
270270
return
271271
}
272272

273+
startedEvents := make([]*cmdStartedEvt, len(*expectations))
274+
succeededEvents := make([]*cmdSucceededEvt, len(*expectations))
275+
failedEvents := make([]*cmdFailedEvt, len(*expectations))
276+
273277
for idx, expectation := range *expectations {
274-
var err error
278+
startedEvents[idx] = expectation.CommandStartedEvent
279+
succeededEvents[idx] = expectation.CommandSucceededEvent
280+
failedEvents[idx] = expectation.CommandFailedEvent
281+
}
275282

276-
if expectation.CommandStartedEvent != nil {
277-
err = compareStartedEvent(mt, expectation, id0, id1)
278-
}
279-
if expectation.CommandSucceededEvent != nil {
280-
err = compareSucceededEvent(mt, expectation)
281-
}
282-
if expectation.CommandFailedEvent != nil {
283-
err = compareFailedEvent(mt, expectation)
284-
}
283+
var err error
284+
err = compareStartedEvents(mt, startedEvents, id0, id1)
285+
assert.Nil(mt, err, "expectation comparison %s", err)
286+
err = compareSucceededEvents(mt, succeededEvents)
287+
assert.Nil(mt, err, "expectation comparison %s", err)
288+
err = compareFailedEvents(mt, failedEvents)
289+
assert.Nil(mt, err, "expectation comparison %s", err)
285290

286-
assert.Nil(mt, err, "expectation comparison error at index %v: %s", idx, err)
287-
}
288291
}
289292

290293
// newMatchError appends `expected` and `actual` BSON data to an error.
@@ -298,83 +301,105 @@ func newMatchError(mt *mtest.T, expected bson.Raw, actual bson.Raw, format strin
298301
return fmt.Errorf("%s\nExpected %s\nGot: %s", msg, string(expectedJSON), string(actualJSON))
299302
}
300303

301-
func compareStartedEvent(mt *mtest.T, expectation *expectation, id0, id1 bson.Raw) error {
304+
func compareStartedEvents(mt *mtest.T, expectations []*cmdStartedEvt, id0, id1 bson.Raw) error {
302305
mt.Helper()
303306

304-
expected := expectation.CommandStartedEvent
305-
306-
if len(expected.Extra) > 0 {
307-
return fmt.Errorf("unrecognized fields for CommandStartedEvent: %v", expected.Extra)
308-
}
309-
310-
evt := mt.GetStartedEvent()
311-
if evt == nil {
312-
return errors.New("expected CommandStartedEvent, got nil")
313-
}
314-
315-
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
316-
return fmt.Errorf("command name mismatch for started event; expected %s, got %s", expected.CommandName, evt.CommandName)
317-
}
318-
if expected.DatabaseName != "" && expected.DatabaseName != evt.DatabaseName {
319-
return fmt.Errorf("database name mismatch; expected %s, got %s", expected.DatabaseName, evt.DatabaseName)
307+
expectedCmds := make(map[string]bool)
308+
for _, expected := range expectations {
309+
expectedCmds[expected.CommandName] = true
320310
}
321311

322-
eElems, err := expected.Command.Elements()
323-
if err != nil {
324-
return fmt.Errorf("error getting expected command elements: %s", err)
325-
}
326-
327-
for _, elem := range eElems {
328-
key := elem.Key()
329-
val := elem.Value()
330-
331-
actualVal, err := evt.Command.LookupErr(key)
312+
compare := func(expected *cmdStartedEvt) error {
313+
if expected == nil {
314+
return nil
315+
}
316+
if len(expected.Extra) > 0 {
317+
return fmt.Errorf("unrecognized fields for CommandStartedEvent: %v", expected.Extra)
318+
}
332319

333-
// Keys that may be nil
334-
if val.Type == bson.TypeNull {
335-
// Expected value is BSON null. Expect the actual field to be omitted.
336-
if errors.Is(err, bsoncore.ErrElementNotFound) {
337-
continue
320+
var evt *event.CommandStartedEvent
321+
// skip events not in expectations
322+
for {
323+
evt = mt.GetStartedEvent()
324+
if evt == nil {
325+
return errors.New("expected CommandStartedEvent, got nil")
338326
}
339-
if err != nil {
340-
return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got error: %v", key, err)
327+
if v, ok := expectedCmds[expected.CommandName]; ok && v {
328+
break
341329
}
342-
return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got %q", key, actualVal)
343330
}
344-
assert.Nil(mt, err, "expected command to contain key %q", key)
345331

346-
if key == "batchSize" {
347-
// Some command monitoring tests expect that the driver will send a lower batch size if the required batch
348-
// size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
349-
// versions do not support the limit option, but not for 3.2+. We've already validated that the command
350-
// contains a batchSize field above and we can skip the actual value comparison below.
351-
continue
332+
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
333+
return fmt.Errorf("command name mismatch for started event; expected %s, got %s", expected.CommandName, evt.CommandName)
334+
}
335+
if expected.DatabaseName != "" && expected.DatabaseName != evt.DatabaseName {
336+
return fmt.Errorf("database name mismatch; expected %s, got %s", expected.DatabaseName, evt.DatabaseName)
337+
}
338+
339+
eElems, err := expected.Command.Elements()
340+
if err != nil {
341+
return fmt.Errorf("error getting expected command elements: %s", err)
352342
}
353343

354-
switch key {
355-
case "lsid":
356-
sessName := val.StringValue()
357-
var expectedID bson.Raw
358-
actualID := actualVal.Document()
344+
for _, elem := range eElems {
345+
key := elem.Key()
346+
val := elem.Value()
359347

360-
switch sessName {
361-
case "session0":
362-
expectedID = id0
363-
case "session1":
364-
expectedID = id1
365-
default:
366-
return newMatchError(mt, expected.Command, evt.Command, "unrecognized session identifier in command document: %s", sessName)
348+
actualVal, err := evt.Command.LookupErr(key)
349+
350+
// Keys that may be nil
351+
if val.Type == bson.TypeNull {
352+
// Expected value is BSON null. Expect the actual field to be omitted.
353+
if errors.Is(err, bsoncore.ErrElementNotFound) {
354+
continue
355+
}
356+
if err != nil {
357+
return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got error: %v", key, err)
358+
}
359+
return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got %q", key, actualVal)
367360
}
361+
assert.Nil(mt, err, "expected command to contain key %q", key)
368362

369-
if !bytes.Equal(expectedID, actualID) {
370-
return newMatchError(mt, expected.Command, evt.Command, "session ID mismatch for session %s; expected %s, got %s", sessName, expectedID,
371-
actualID)
363+
if key == "batchSize" {
364+
// Some command monitoring tests expect that the driver will send a lower batch size if the required batch
365+
// size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
366+
// versions do not support the limit option, but not for 3.2+. We've already validated that the command
367+
// contains a batchSize field above and we can skip the actual value comparison below.
368+
continue
372369
}
373-
default:
374-
if err := compareValues(mt, key, val, actualVal); err != nil {
375-
return newMatchError(mt, expected.Command, evt.Command, "%s", err)
370+
371+
switch key {
372+
case "lsid":
373+
sessName := val.StringValue()
374+
var expectedID bson.Raw
375+
actualID := actualVal.Document()
376+
377+
switch sessName {
378+
case "session0":
379+
expectedID = id0
380+
case "session1":
381+
expectedID = id1
382+
default:
383+
return newMatchError(mt, expected.Command, evt.Command, "unrecognized session identifier in command document: %s", sessName)
384+
}
385+
386+
if !bytes.Equal(expectedID, actualID) {
387+
return newMatchError(mt, expected.Command, evt.Command, "session ID mismatch for session %s; expected %s, got %s", sessName, expectedID,
388+
actualID)
389+
}
390+
default:
391+
if err := compareValues(mt, key, val, actualVal); err != nil {
392+
return newMatchError(mt, expected.Command, evt.Command, "%s", err)
393+
}
376394
}
377395
}
396+
return nil
397+
}
398+
for idx, expected := range expectations {
399+
err := compare(expected)
400+
if err != nil {
401+
return fmt.Errorf("error at index %d: %s", idx, err)
402+
}
378403
}
379404
return nil
380405
}
@@ -416,60 +441,108 @@ func compareWriteErrors(mt *mtest.T, expected, actual bson.Raw) error {
416441
return nil
417442
}
418443

419-
func compareSucceededEvent(mt *mtest.T, expectation *expectation) error {
444+
func compareSucceededEvents(mt *mtest.T, expectations []*cmdSucceededEvt) error {
420445
mt.Helper()
421446

422-
expected := expectation.CommandSucceededEvent
423-
if len(expected.Extra) > 0 {
424-
return fmt.Errorf("unrecognized fields for CommandSucceededEvent: %v", expected.Extra)
425-
}
426-
evt := mt.GetSucceededEvent()
427-
if evt == nil {
428-
return errors.New("expected CommandSucceededEvent, got nil")
447+
expectedCmds := make(map[string]bool)
448+
for _, expected := range expectations {
449+
expectedCmds[expected.CommandName] = true
429450
}
430451

431-
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
432-
return fmt.Errorf("command name mismatch for succeeded event; expected %s, got %s", expected.CommandName, evt.CommandName)
433-
}
452+
compare := func(expected *cmdSucceededEvt) error {
453+
if expected == nil {
454+
return nil
455+
}
456+
if len(expected.Extra) > 0 {
457+
return fmt.Errorf("unrecognized fields for CommandSucceededEvent: %v", expected.Extra)
458+
}
434459

435-
eElems, err := expected.Reply.Elements()
436-
if err != nil {
437-
return fmt.Errorf("error getting expected reply elements: %s", err)
438-
}
460+
var evt *event.CommandSucceededEvent
461+
// skip events not in expectations
462+
for {
463+
evt = mt.GetSucceededEvent()
464+
if evt == nil {
465+
return errors.New("expected CommandSucceededEvent, got nil")
466+
}
467+
if v, ok := expectedCmds[expected.CommandName]; ok && v {
468+
break
469+
}
470+
}
439471

440-
for _, elem := range eElems {
441-
key := elem.Key()
442-
val := elem.Value()
443-
actualVal := evt.Reply.Lookup(key)
472+
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
473+
return fmt.Errorf("command name mismatch for succeeded event; expected %s, got %s", expected.CommandName, evt.CommandName)
474+
}
444475

445-
switch key {
446-
case "writeErrors":
447-
if err = compareWriteErrors(mt, val.Array(), actualVal.Array()); err != nil {
448-
return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
449-
}
450-
default:
451-
if err := compareValues(mt, key, val, actualVal); err != nil {
452-
return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
476+
eElems, err := expected.Reply.Elements()
477+
if err != nil {
478+
return fmt.Errorf("error getting expected reply elements: %s", err)
479+
}
480+
481+
for _, elem := range eElems {
482+
key := elem.Key()
483+
val := elem.Value()
484+
actualVal := evt.Reply.Lookup(key)
485+
486+
switch key {
487+
case "writeErrors":
488+
if err = compareWriteErrors(mt, val.Array(), actualVal.Array()); err != nil {
489+
return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
490+
}
491+
default:
492+
if err := compareValues(mt, key, val, actualVal); err != nil {
493+
return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
494+
}
453495
}
454496
}
497+
return nil
498+
}
499+
for idx, expected := range expectations {
500+
err := compare(expected)
501+
if err != nil {
502+
return fmt.Errorf("error at index %d: %s", idx, err)
503+
}
455504
}
456505
return nil
457506
}
458507

459-
func compareFailedEvent(mt *mtest.T, expectation *expectation) error {
508+
func compareFailedEvents(mt *mtest.T, expectations []*cmdFailedEvt) error {
460509
mt.Helper()
461510

462-
expected := expectation.CommandFailedEvent
463-
if len(expected.Extra) > 0 {
464-
return fmt.Errorf("unrecognized fields for CommandFailedEvent: %v", expected.Extra)
465-
}
466-
evt := mt.GetFailedEvent()
467-
if evt == nil {
468-
return errors.New("expected CommandFailedEvent, got nil")
511+
expectedCmds := make(map[string]bool)
512+
for _, expected := range expectations {
513+
expectedCmds[expected.CommandName] = true
469514
}
470515

471-
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
472-
return fmt.Errorf("command name mismatch for failed event; expected %s, got %s", expected.CommandName, evt.CommandName)
516+
compare := func(expected *cmdFailedEvt) error {
517+
if expected == nil {
518+
return nil
519+
}
520+
if len(expected.Extra) > 0 {
521+
return fmt.Errorf("unrecognized fields for CommandFailedEvent: %v", expected.Extra)
522+
}
523+
524+
var evt *event.CommandFailedEvent
525+
// skip events not in expectations
526+
for {
527+
evt = mt.GetFailedEvent()
528+
if evt == nil {
529+
return errors.New("expected CommandFailedEvent, got nil")
530+
}
531+
if v, ok := expectedCmds[expected.CommandName]; ok && v {
532+
break
533+
}
534+
}
535+
536+
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
537+
return fmt.Errorf("command name mismatch for failed event; expected %s, got %s", expected.CommandName, evt.CommandName)
538+
}
539+
return nil
540+
}
541+
for idx, expected := range expectations {
542+
err := compare(expected)
543+
if err != nil {
544+
return fmt.Errorf("error at index %d: %s", idx, err)
545+
}
473546
}
474547
return nil
475548
}

0 commit comments

Comments
 (0)