Skip to content

Commit 9263a71

Browse files
committed
updates
1 parent a1b8b46 commit 9263a71

File tree

1 file changed

+17
-27
lines changed

1 file changed

+17
-27
lines changed

mongo/integration/mtest/mongotest.go

+17-27
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,12 @@ func New(wrapped *testing.T, opts ...*Options) *T {
176176

177177
t := newT(wrapped, opts...)
178178

179+
t.fpClient = t.createTestClient()
180+
179181
// only create a client if it needs to be shared in sub-tests
180182
// otherwise, a new client will be created for each subtest
181183
if t.shareClient != nil && *t.shareClient {
182-
t.createTestClient()
184+
t.Client = t.createTestClient()
183185
}
184186

185187
wrapped.Cleanup(t.cleanup)
@@ -229,10 +231,13 @@ func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) {
229231
if sub.shareClient != nil && *sub.shareClient && sub.clientType == t.clientType {
230232
sub.Client = t.Client
231233
}
234+
if sub.fpClient == nil {
235+
sub.fpClient = sub.createTestClient()
236+
}
232237
// only create a client if not already set
233238
if sub.Client == nil {
234239
if sub.createClient == nil || *sub.createClient {
235-
sub.createTestClient()
240+
sub.Client = sub.createTestClient()
236241
}
237242
}
238243
// create a collection for this test
@@ -406,10 +411,8 @@ func (t *T) ResetClient(opts *options.ClientOptions) {
406411
t.clientOpts = opts
407412
}
408413

409-
if len(t.failPointNames) == 0 {
410-
_ = t.Client.Disconnect(context.Background())
411-
}
412-
t.createTestClient()
414+
_ = t.Client.Disconnect(context.Background())
415+
t.Client = t.createTestClient()
413416
t.DB = t.Client.Database(t.dbName)
414417
t.Coll = t.DB.Collection(t.collName, t.collOpts)
415418

@@ -562,9 +565,6 @@ func (t *T) SetFailPoint(fp FailPoint) {
562565
}
563566
}
564567

565-
if t.fpClient == nil {
566-
t.fpClient = t.Client
567-
}
568568
if err := SetFailPoint(fp, t.fpClient); err != nil {
569569
t.Fatal(err)
570570
}
@@ -576,9 +576,6 @@ func (t *T) SetFailPoint(fp FailPoint) {
576576
// the failpoint will appear in command monitoring channels. The fail point will be automatically disabled after this
577577
// test has run.
578578
func (t *T) SetFailPointFromDocument(fp bson.Raw) {
579-
if t.fpClient == nil {
580-
t.fpClient = t.Client
581-
}
582579
if err := SetRawFailPoint(fp, t.fpClient); err != nil {
583580
t.Fatal(err)
584581
}
@@ -595,16 +592,7 @@ func (t *T) TrackFailPoint(fpName string) {
595592

596593
// ClearFailPoints disables all previously set failpoints for this test.
597594
func (t *T) ClearFailPoints() {
598-
client := t.fpClient
599-
if client == nil {
600-
client = t.Client
601-
} else {
602-
defer func() {
603-
// _ = t.fpClient.Disconnect(context.Background())
604-
t.fpClient = nil
605-
}()
606-
}
607-
db := client.Database("admin")
595+
db := t.fpClient.Database("admin")
608596
for _, fp := range t.failPointNames {
609597
cmd := bson.D{
610598
{"configureFailPoint", fp},
@@ -643,7 +631,7 @@ func sanitizeCollectionName(db string, coll string) string {
643631
return coll
644632
}
645633

646-
func (t *T) createTestClient() {
634+
func (t *T) createTestClient() *mongo.Client {
647635
clientOpts := t.clientOpts
648636
if clientOpts == nil {
649637
// default opts
@@ -702,19 +690,20 @@ func (t *T) createTestClient() {
702690
})
703691
}
704692

693+
var client *mongo.Client
705694
var err error
706695
switch t.clientType {
707696
case Pinned:
708697
// pin to first mongos
709698
pinnedHostList := []string{testContext.connString.Hosts[0]}
710699
uriOpts := options.Client().ApplyURI(testContext.connString.Original).SetHosts(pinnedHostList)
711-
t.Client, err = mongo.NewClient(uriOpts, clientOpts)
700+
client, err = mongo.NewClient(uriOpts, clientOpts)
712701
case Mock:
713702
// clear pool monitor to avoid configuration error
714703
clientOpts.PoolMonitor = nil
715704
t.mockDeployment = newMockDeployment()
716705
clientOpts.Deployment = t.mockDeployment
717-
t.Client, err = mongo.NewClient(clientOpts)
706+
client, err = mongo.NewClient(clientOpts)
718707
case Proxy:
719708
t.proxyDialer = newProxyDialer()
720709
clientOpts.SetDialer(t.proxyDialer)
@@ -732,14 +721,15 @@ func (t *T) createTestClient() {
732721
}
733722

734723
// Pass in uriOpts first so clientOpts wins if there are any conflicting settings.
735-
t.Client, err = mongo.NewClient(uriOpts, clientOpts)
724+
client, err = mongo.NewClient(uriOpts, clientOpts)
736725
}
737726
if err != nil {
738727
t.Fatalf("error creating client: %v", err)
739728
}
740-
if err := t.Client.Connect(context.Background()); err != nil {
729+
if err := client.Connect(context.Background()); err != nil {
741730
t.Fatalf("error connecting client: %v", err)
742731
}
732+
return client
743733
}
744734

745735
func (t *T) createTestCollection() {

0 commit comments

Comments
 (0)