@@ -176,10 +176,12 @@ func New(wrapped *testing.T, opts ...*Options) *T {
176
176
177
177
t := newT (wrapped , opts ... )
178
178
179
+ t .fpClient = t .createTestClient ()
180
+
179
181
// only create a client if it needs to be shared in sub-tests
180
182
// otherwise, a new client will be created for each subtest
181
183
if t .shareClient != nil && * t .shareClient {
182
- t .createTestClient ()
184
+ t .Client = t . createTestClient ()
183
185
}
184
186
185
187
wrapped .Cleanup (t .cleanup )
@@ -229,10 +231,13 @@ func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) {
229
231
if sub .shareClient != nil && * sub .shareClient && sub .clientType == t .clientType {
230
232
sub .Client = t .Client
231
233
}
234
+ if sub .fpClient == nil {
235
+ sub .fpClient = sub .createTestClient ()
236
+ }
232
237
// only create a client if not already set
233
238
if sub .Client == nil {
234
239
if sub .createClient == nil || * sub .createClient {
235
- sub .createTestClient ()
240
+ sub .Client = sub . createTestClient ()
236
241
}
237
242
}
238
243
// create a collection for this test
@@ -406,10 +411,8 @@ func (t *T) ResetClient(opts *options.ClientOptions) {
406
411
t .clientOpts = opts
407
412
}
408
413
409
- if len (t .failPointNames ) == 0 {
410
- _ = t .Client .Disconnect (context .Background ())
411
- }
412
- t .createTestClient ()
414
+ _ = t .Client .Disconnect (context .Background ())
415
+ t .Client = t .createTestClient ()
413
416
t .DB = t .Client .Database (t .dbName )
414
417
t .Coll = t .DB .Collection (t .collName , t .collOpts )
415
418
@@ -562,9 +565,6 @@ func (t *T) SetFailPoint(fp FailPoint) {
562
565
}
563
566
}
564
567
565
- if t .fpClient == nil {
566
- t .fpClient = t .Client
567
- }
568
568
if err := SetFailPoint (fp , t .fpClient ); err != nil {
569
569
t .Fatal (err )
570
570
}
@@ -576,9 +576,6 @@ func (t *T) SetFailPoint(fp FailPoint) {
576
576
// the failpoint will appear in command monitoring channels. The fail point will be automatically disabled after this
577
577
// test has run.
578
578
func (t * T ) SetFailPointFromDocument (fp bson.Raw ) {
579
- if t .fpClient == nil {
580
- t .fpClient = t .Client
581
- }
582
579
if err := SetRawFailPoint (fp , t .fpClient ); err != nil {
583
580
t .Fatal (err )
584
581
}
@@ -595,16 +592,7 @@ func (t *T) TrackFailPoint(fpName string) {
595
592
596
593
// ClearFailPoints disables all previously set failpoints for this test.
597
594
func (t * T ) ClearFailPoints () {
598
- client := t .fpClient
599
- if client == nil {
600
- client = t .Client
601
- } else {
602
- defer func () {
603
- // _ = t.fpClient.Disconnect(context.Background())
604
- t .fpClient = nil
605
- }()
606
- }
607
- db := client .Database ("admin" )
595
+ db := t .fpClient .Database ("admin" )
608
596
for _ , fp := range t .failPointNames {
609
597
cmd := bson.D {
610
598
{"configureFailPoint" , fp },
@@ -643,7 +631,7 @@ func sanitizeCollectionName(db string, coll string) string {
643
631
return coll
644
632
}
645
633
646
- func (t * T ) createTestClient () {
634
+ func (t * T ) createTestClient () * mongo. Client {
647
635
clientOpts := t .clientOpts
648
636
if clientOpts == nil {
649
637
// default opts
@@ -702,19 +690,20 @@ func (t *T) createTestClient() {
702
690
})
703
691
}
704
692
693
+ var client * mongo.Client
705
694
var err error
706
695
switch t .clientType {
707
696
case Pinned :
708
697
// pin to first mongos
709
698
pinnedHostList := []string {testContext .connString .Hosts [0 ]}
710
699
uriOpts := options .Client ().ApplyURI (testContext .connString .Original ).SetHosts (pinnedHostList )
711
- t . Client , err = mongo .NewClient (uriOpts , clientOpts )
700
+ client , err = mongo .NewClient (uriOpts , clientOpts )
712
701
case Mock :
713
702
// clear pool monitor to avoid configuration error
714
703
clientOpts .PoolMonitor = nil
715
704
t .mockDeployment = newMockDeployment ()
716
705
clientOpts .Deployment = t .mockDeployment
717
- t . Client , err = mongo .NewClient (clientOpts )
706
+ client , err = mongo .NewClient (clientOpts )
718
707
case Proxy :
719
708
t .proxyDialer = newProxyDialer ()
720
709
clientOpts .SetDialer (t .proxyDialer )
@@ -732,14 +721,15 @@ func (t *T) createTestClient() {
732
721
}
733
722
734
723
// Pass in uriOpts first so clientOpts wins if there are any conflicting settings.
735
- t . Client , err = mongo .NewClient (uriOpts , clientOpts )
724
+ client , err = mongo .NewClient (uriOpts , clientOpts )
736
725
}
737
726
if err != nil {
738
727
t .Fatalf ("error creating client: %v" , err )
739
728
}
740
- if err := t . Client .Connect (context .Background ()); err != nil {
729
+ if err := client .Connect (context .Background ()); err != nil {
741
730
t .Fatalf ("error connecting client: %v" , err )
742
731
}
732
+ return client
743
733
}
744
734
745
735
func (t * T ) createTestCollection () {
0 commit comments