diff --git a/event/monitoring.go b/event/monitoring.go index 2ca98969d7..6f6db625ac 100644 --- a/event/monitoring.go +++ b/event/monitoring.go @@ -75,17 +75,20 @@ const ( // strings for pool command monitoring types const ( - ConnectionPoolCreated = "ConnectionPoolCreated" - ConnectionPoolReady = "ConnectionPoolReady" - ConnectionPoolCleared = "ConnectionPoolCleared" - ConnectionPoolClosed = "ConnectionPoolClosed" - ConnectionCreated = "ConnectionCreated" - ConnectionReady = "ConnectionReady" - ConnectionClosed = "ConnectionClosed" - ConnectionCheckOutStarted = "ConnectionCheckOutStarted" - ConnectionCheckOutFailed = "ConnectionCheckOutFailed" - ConnectionCheckedOut = "ConnectionCheckedOut" - ConnectionCheckedIn = "ConnectionCheckedIn" + ConnectionPoolCreated = "ConnectionPoolCreated" + ConnectionPoolReady = "ConnectionPoolReady" + ConnectionPoolCleared = "ConnectionPoolCleared" + ConnectionPoolClosed = "ConnectionPoolClosed" + ConnectionCreated = "ConnectionCreated" + ConnectionReady = "ConnectionReady" + ConnectionClosed = "ConnectionClosed" + ConnectionCheckOutStarted = "ConnectionCheckOutStarted" + ConnectionCheckOutFailed = "ConnectionCheckOutFailed" + ConnectionCheckedOut = "ConnectionCheckedOut" + ConnectionCheckedIn = "ConnectionCheckedIn" + ConnectionPendingResponseStarted = "ConnectionPendingResponseStarted" + ConnectionPendingResponseSucceeded = "ConnectionPendingResponseSucceeded" + ConnectionPendingResponseFailed = "ConnectionPendingResponseFailed" ) // MonitorPoolOptions contains pool options as formatted in pool events @@ -105,9 +108,11 @@ type PoolEvent struct { Reason string `json:"reason"` // ServiceID is only set if the Type is PoolCleared and the server is deployed behind a load balancer. This field // can be used to distinguish between individual servers in a load balanced deployment. - ServiceID *bson.ObjectID `json:"serviceId"` - Interruption bool `json:"interruptInUseConnections"` - Error error `json:"error"` + ServiceID *bson.ObjectID `json:"serviceId"` + Interruption bool `json:"interruptInUseConnections"` + Error error `json:"error"` + RequestID int32 `json:"requestId"` + RemainingTime time.Duration `json:"remainingTime"` } // PoolMonitor is a function that allows the user to gain access to events occurring in the pool diff --git a/internal/driverutil/context.go b/internal/driverutil/context.go new file mode 100644 index 0000000000..5b8cd54b3a --- /dev/null +++ b/internal/driverutil/context.go @@ -0,0 +1,49 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driverutil + +import "context" + +// ContextKey is a custom type used for the keys in context values to avoid +// collisions.
+type ContextKey string + +const ( + // ContextKeyHasMaxTimeMS represents a boolean value that indicates if + // maxTimeMS will be set on the wire message for an operation. + ContextKeyHasMaxTimeMS ContextKey = "hasMaxTimeMS" + + // ContextKeyRequestID is the requestID for a given operation. This is used to + // propagate the requestID for a pending read during connection check out. + ContextKeyRequestID ContextKey = "requestID" +) + +// WithValueHasMaxTimeMS returns a copy of the parent context with an added +// value indicating whether an operation will append maxTimeMS to the wire +// message. +func WithValueHasMaxTimeMS(parentCtx context.Context, val bool) context.Context { + return context.WithValue(parentCtx, ContextKeyHasMaxTimeMS, val) +} + +// WithRequestID returns a copy of the parent context with an added request ID +// value. +func WithRequestID(parentCtx context.Context, requestID int32) context.Context { + return context.WithValue(parentCtx, ContextKeyRequestID, requestID) +} + +// HasMaxTimeMS checks if the context is for an operation that will append +// maxTimeMS to the wire message. +func HasMaxTimeMS(ctx context.Context) bool { + return ctx.Value(ContextKeyHasMaxTimeMS) != nil +} + +// GetRequestID retrieves the request ID from the context if it exists. +func GetRequestID(ctx context.Context) (int32, bool) { + val, ok := ctx.Value(ContextKeyRequestID).(int32) + + return val, ok +} diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 0478967a52..c500cd2217 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -13,6 +13,7 @@ import ( "os" "reflect" "strings" + "sync" "testing" "time" @@ -675,9 +676,9 @@ func TestClient(t *testing.T) { }, } + _, err := mt.Coll.InsertOne(context.Background(), bson.D{}) for _, tc := range testCases { mt.Run(tc.desc, func(mt *mtest.T) { - _, err := mt.Coll.InsertOne(context.Background(), bson.D{}) require.NoError(mt, err) mt.SetFailPoint(failpoint.FailPoint{ @@ -692,30 +693,47 @@ func TestClient(t *testing.T) { mt.ClearEvents() + wg := sync.WaitGroup{} + wg.Add(50) + for i := 0; i < 50; i++ { - // Run 50 operations, each with a timeout of 50ms. Expect + // Run 50 concurrent operations, each with a timeout of 50ms. Expect // them to all return a timeout error because the failpoint - // blocks find operations for 500ms. Run 50 to increase the + // blocks find operations for 50ms. Run 50 to increase the // probability that an operation will time out in a way that // can cause a retry. - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - err = tc.operation(ctx, mt.Coll) - cancel() - assert.ErrorIs(mt, err, context.DeadlineExceeded) - assert.True(mt, mongo.IsTimeout(err), "expected mongo.IsTimeout(err) to be true") - - // Assert that each operation reported exactly one command - // started events, which means the operation did not retry - // after the context timeout. 
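For orientation, the helpers added in internal/driverutil/context.go above are meant to be used in pairs: the operation layer tags the context before a round trip, and the connection layer consults those values when a read times out. The following is a minimal sketch of that call pattern only; the function names here are invented, it assumes the "context" and internal "driverutil" imports, and the real wiring is in the x/mongo/driver/operation.go and connection.go hunks later in this diff.

    // Hypothetical helpers, for illustration only; not part of this change set.
    func tagPendingReadContext(ctx context.Context, requestID int32) context.Context {
        // Operation side: record that maxTimeMS will be appended and which
        // request ID the outgoing wire message uses.
        ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
        return driverutil.WithRequestID(ctx, requestID)
    }

    func pendingReadInfo(ctx context.Context) (int32, bool) {
        // Connection side: a timed-out read is only worth resuming later when
        // the server was sent maxTimeMS; the request ID ties the pending read
        // back to the operation that started it.
        if !driverutil.HasMaxTimeMS(ctx) {
            return 0, false
        }
        return driverutil.GetRequestID(ctx)
    }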
- evts := mt.GetAllStartedEvents() - require.Len(mt, - mt.GetAllStartedEvents(), - 1, - "expected exactly 1 command started event per operation, but got %d after %d iterations", - len(evts), - i) - mt.ClearEvents() + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) + err := tc.operation(ctx, mt.Coll) + cancel() + assert.ErrorIs(mt, err, context.DeadlineExceeded) + assert.True(mt, mongo.IsTimeout(err), "expected mongo.IsTimeout(err) to be true") + + wg.Done() + }() } + + wg.Wait() + + // Every operation must check out a connection, and the driver attempts a + // pending read on a connection whose previous operation hit a socket + // timeout. Because this test forces 50 concurrent socket timeouts, an + // operation may check out a connection that still has a pending read. In + // that case the operation times out during connection check out and no + // command started event is published. So instead of asserting exactly 50 + // started events, assert that the started events plus the pending-read + // check outs account for all 50 operations. + pendingReadConns := mt.NumberConnectionsPendingReadStarted() + evts := mt.GetAllStartedEvents() + + require.Equal(mt, + len(evts)+pendingReadConns, + 50, + "expected started events plus pending-read check outs to account for all 50 operations, but got %d", + len(evts)+pendingReadConns) + mt.ClearEvents() + + mt.ClearFailPoints() }) } }) diff --git a/internal/integration/csot_prose_test.go b/internal/integration/csot_prose_test.go index ce7219b042..f2854559ce 100644 --- a/internal/integration/csot_prose_test.go +++ b/internal/integration/csot_prose_test.go @@ -176,6 +176,7 @@ func TestCSOTProse(t *testing.T) { time.Millisecond, "expected ping to fail within 150ms") }) + }) mt.RunOpts("11. multi-batch bulkWrites", mtest.NewOptions().MinServerVersion("8.0").
diff --git a/internal/integration/csot_test.go b/internal/integration/csot_test.go index 6808efb2a4..877f5e0341 100644 --- a/internal/integration/csot_test.go +++ b/internal/integration/csot_test.go @@ -38,12 +38,13 @@ func TestCSOT_maxTimeMS(t *testing.T) { mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) testCases := []struct { - desc string - commandName string - setup func(coll *mongo.Collection) error - operation func(ctx context.Context, coll *mongo.Collection) error - sendsMaxTimeMS bool - topologies []mtest.TopologyKind + desc string + commandName string + setup func(coll *mongo.Collection) error + operation func(ctx context.Context, coll *mongo.Collection) error + sendsMaxTimeMS bool + topologies []mtest.TopologyKind + preventsConnClosureWithTimeoutMS bool }{ { desc: "FindOne", @@ -55,7 +56,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { operation: func(ctx context.Context, coll *mongo.Collection) error { return coll.FindOne(ctx, bson.D{}).Err() }, - sendsMaxTimeMS: true, + sendsMaxTimeMS: true, + preventsConnClosureWithTimeoutMS: true, }, { desc: "Find", @@ -68,7 +70,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { _, err := coll.Find(ctx, bson.D{}) return err }, - sendsMaxTimeMS: false, + sendsMaxTimeMS: false, + preventsConnClosureWithTimeoutMS: false, }, { desc: "FindOneAndDelete", @@ -80,7 +83,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { operation: func(ctx context.Context, coll *mongo.Collection) error { return coll.FindOneAndDelete(ctx, bson.D{}).Err() }, - sendsMaxTimeMS: true, + sendsMaxTimeMS: true, + preventsConnClosureWithTimeoutMS: true, }, { desc: "FindOneAndUpdate", @@ -92,7 +96,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { operation: func(ctx context.Context, coll *mongo.Collection) error { return coll.FindOneAndUpdate(ctx, bson.D{}, bson.M{"$set": bson.M{"key": "value"}}).Err() }, - sendsMaxTimeMS: true, + sendsMaxTimeMS: true, + preventsConnClosureWithTimeoutMS: true, }, { desc: "FindOneAndReplace", @@ -104,7 +109,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { operation: func(ctx context.Context, coll *mongo.Collection) error { return coll.FindOneAndReplace(ctx, bson.D{}, bson.D{}).Err() }, - sendsMaxTimeMS: true, + sendsMaxTimeMS: true, + preventsConnClosureWithTimeoutMS: true, }, { desc: "InsertOne", @@ -113,7 +119,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { _, err := coll.InsertOne(ctx, bson.D{}) return err }, - sendsMaxTimeMS: true, + sendsMaxTimeMS: true, + preventsConnClosureWithTimeoutMS: true, }, { desc: "InsertMany", @@ -122,7 +129,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { _, err := coll.InsertMany(ctx, []interface{}{bson.D{}}) return err }, - sendsMaxTimeMS: true, + sendsMaxTimeMS: true, + preventsConnClosureWithTimeoutMS: true, }, { desc: "UpdateOne", @@ -131,7 +139,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { _, err := coll.UpdateOne(ctx, bson.D{}, bson.M{"$set": bson.M{"key": "value"}}) return err }, - sendsMaxTimeMS: true, + sendsMaxTimeMS: true, + preventsConnClosureWithTimeoutMS: true, }, { desc: "UpdateMany", @@ -140,7 +149,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { _, err := coll.UpdateMany(ctx, bson.D{}, bson.M{"$set": bson.M{"key": "value"}}) return err }, - sendsMaxTimeMS: true, + sendsMaxTimeMS: true, + preventsConnClosureWithTimeoutMS: true, }, { desc: "ReplaceOne", @@ -149,7 +159,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { _, err := coll.ReplaceOne(ctx, bson.D{}, bson.D{}) return err }, - sendsMaxTimeMS: true, + sendsMaxTimeMS: true, + preventsConnClosureWithTimeoutMS: true, }, { desc: "DeleteOne", @@ -158,7 +169,8 @@ func 
TestCSOT_maxTimeMS(t *testing.T) { _, err := coll.DeleteOne(ctx, bson.D{}) return err }, - sendsMaxTimeMS: true, + sendsMaxTimeMS: true, + preventsConnClosureWithTimeoutMS: true, }, { desc: "DeleteMany", @@ -168,6 +180,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { return err }, sendsMaxTimeMS: true, + + preventsConnClosureWithTimeoutMS: true, }, { desc: "Distinct", @@ -175,7 +189,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { operation: func(ctx context.Context, coll *mongo.Collection) error { return coll.Distinct(ctx, "name", bson.D{}).Err() }, - sendsMaxTimeMS: true, + sendsMaxTimeMS: true, + preventsConnClosureWithTimeoutMS: true, }, { desc: "Aggregate", @@ -184,7 +199,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { _, err := coll.Aggregate(ctx, mongo.Pipeline{}) return err }, - sendsMaxTimeMS: false, + sendsMaxTimeMS: false, + preventsConnClosureWithTimeoutMS: false, }, { desc: "Watch", @@ -196,7 +212,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { } return err }, - sendsMaxTimeMS: true, + sendsMaxTimeMS: true, + preventsConnClosureWithTimeoutMS: false, // Change Streams aren't supported on standalone topologies. topologies: []mtest.TopologyKind{ mtest.ReplicaSet, @@ -218,7 +235,8 @@ func TestCSOT_maxTimeMS(t *testing.T) { var res []bson.D return cursor.All(ctx, &res) }, - sendsMaxTimeMS: false, + sendsMaxTimeMS: false, + preventsConnClosureWithTimeoutMS: false, }, } @@ -348,56 +366,57 @@ func TestCSOT_maxTimeMS(t *testing.T) { assertMaxTimeMSNotSet(mt, evt.Command) } }) + if tc.preventsConnClosureWithTimeoutMS { + opts := mtest.NewOptions(). + // Blocking failpoints don't work on pre-4.2 and sharded + // clusters. + Topologies(mtest.Single, mtest.ReplicaSet). + MinServerVersion("4.2") + mt.RunOpts("prevents connection closure", opts, func(mt *mtest.T) { + if tc.setup != nil { + err := tc.setup(mt.Coll) + require.NoError(mt, err) + } - opts := mtest.NewOptions(). - // Blocking failpoints don't work on pre-4.2 and sharded - // clusters. - Topologies(mtest.Single, mtest.ReplicaSet). - MinServerVersion("4.2") - mt.RunOpts("prevents connection closure", opts, func(mt *mtest.T) { - if tc.setup != nil { - err := tc.setup(mt.Coll) - require.NoError(mt, err) - } - - mt.SetFailPoint(failpoint.FailPoint{ - ConfigureFailPoint: "failCommand", - Mode: failpoint.ModeAlwaysOn, - Data: failpoint.Data{ - FailCommands: []string{tc.commandName}, - BlockConnection: true, - // Note that some operations (currently Find and - // Aggregate) do not send maxTimeMS by default, meaning - // that the server will only respond after BlockTimeMS - // is elapsed. If the amount of time that the driver - // waits for responses after a timeout is significantly - // lower than BlockTimeMS, this test will start failing - // for those operations. - BlockTimeMS: 500, - }, - }) - - tpm := eventtest.NewTestPoolMonitor() - mt.ResetClient(options.Client(). - SetPoolMonitor(tpm.PoolMonitor)) - - // Run 5 operations that time out, then assert that no - // connections were closed. 
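The new preventsConnClosureWithTimeoutMS flag in the table above tracks which operations can keep their connection open after a client-side timeout: when maxTimeMS was appended, the server will stop the work on its own and the driver can drain the reply later as a pending read. For the entries where the flag is false (for example Find and Aggregate, which do not send maxTimeMS), the old close-on-timeout behavior still applies and no pending read is recorded. The guard that enforces this is in the connection.read hunk further down; a simplified sketch of it, with an invented helper name:

    // Illustrative only; mirrors the condition added to connection.read later
    // in this diff. A pending read is recorded only when the timeout came from
    // a CSOT deadline and maxTimeMS was appended to the wire message.
    func shouldRecordPendingRead(ctx context.Context, err error) bool {
        return isCSOTTimeout(err) && driverutil.HasMaxTimeMS(ctx)
    }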
- for i := 0; i < 5; i++ { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Millisecond) - err := tc.operation(ctx, mt.Coll) - cancel() - - if !mongo.IsTimeout(err) { - t.Logf("Operation %d returned a non-timeout error: %v", i, err) + mt.SetFailPoint(failpoint.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: failpoint.ModeAlwaysOn, + Data: failpoint.Data{ + FailCommands: []string{tc.commandName}, + BlockConnection: true, + // Note that some operations (currently Find and + // Aggregate) do not send maxTimeMS by default, meaning + // that the server will only respond after BlockTimeMS + // is elapsed. If the amount of time that the driver + // waits for responses after a timeout is significantly + // lower than BlockTimeMS, this test will start failing + // for those operations. + BlockTimeMS: 500, + }, + }) + + tpm := eventtest.NewTestPoolMonitor() + mt.ResetClient(options.Client(). + SetPoolMonitor(tpm.PoolMonitor)) + + // Run 5 operations that time out, then assert that no + // connections were closed. + for i := 0; i < 5; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Millisecond) + err := tc.operation(ctx, mt.Coll) + cancel() + + if !mongo.IsTimeout(err) { + t.Logf("Operation %d returned a non-timeout error: %v", i, err) + } } - } - closedEvents := tpm.Events(func(pe *event.PoolEvent) bool { - return pe.Type == event.ConnectionClosed + closedEvents := tpm.Events(func(pe *event.PoolEvent) bool { + return pe.Type == event.ConnectionClosed + }) + assert.Len(mt, closedEvents, 0, "expected no connection closed event") }) - assert.Len(mt, closedEvents, 0, "expected no connection closed event") - }) + } }) } diff --git a/internal/integration/mtest/mongotest.go b/internal/integration/mtest/mongotest.go index 3967bf7f82..1430ab6e0c 100644 --- a/internal/integration/mtest/mongotest.go +++ b/internal/integration/mtest/mongotest.go @@ -55,7 +55,10 @@ type T struct { // It must be accessed using the atomic package and should be at the beginning of the struct. // - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG // - suggested layout: https://go101.org/article/memory-layout.html - connsCheckedOut int64 + connsCheckedOut int64 + connPendingReadStarted int64 + connPendingReadSucceeded int64 + connPendingReadFailed int64 *testing.T @@ -348,6 +351,20 @@ func (t *T) NumberConnectionsCheckedOut() int { return int(atomic.LoadInt64(&t.connsCheckedOut)) } +// NumberConnectionsPendingReadStarted returns the number of connections that have +// started a pending read. +func (t *T) NumberConnectionsPendingReadStarted() int { + return int(atomic.LoadInt64(&t.connPendingReadStarted)) +} + +func (t *T) NumberConnectionsPendingReadSucceeded() int { + return int(atomic.LoadInt64(&t.connPendingReadSucceeded)) +} + +func (t *T) NumberConnectionsPendingReadFailed() int { + return int(atomic.LoadInt64(&t.connPendingReadFailed)) +} + // ClearEvents clears the existing command monitoring events. func (t *T) ClearEvents() { t.started = t.started[:0] @@ -547,6 +564,11 @@ func (t *T) TrackFailPoint(fpName string) { // ClearFailPoints disables all previously set failpoints for this test. func (t *T) ClearFailPoints() { + // Run some arbitrary command to ensure that any connection that would + // otherwise blocking during a pending read is closed. This could happen if + // the mode times > 1 and the blocking time is > default pending read timeout. 
+ _ = t.Client.Ping(context.Background(), nil) + db := t.Client.Database("admin") for _, fp := range t.failPointNames { cmd := failpoint.FailPoint{ @@ -640,6 +662,12 @@ func (t *T) createTestClient() { atomic.AddInt64(&t.connsCheckedOut, 1) case event.ConnectionCheckedIn: atomic.AddInt64(&t.connsCheckedOut, -1) + case event.ConnectionPendingResponseStarted: + atomic.AddInt64(&t.connPendingReadStarted, 1) + case event.ConnectionPendingResponseSucceeded: + atomic.AddInt64(&t.connPendingReadSucceeded, 1) + case event.ConnectionCheckOutFailed: + atomic.AddInt64(&t.connPendingReadFailed, 1) } }, }) diff --git a/internal/integration/unified/event.go b/internal/integration/unified/event.go index abbec74439..9ee8fe7404 100644 --- a/internal/integration/unified/event.go +++ b/internal/integration/unified/event.go @@ -16,27 +16,30 @@ import ( type monitoringEventType string const ( - commandStartedEvent monitoringEventType = "CommandStartedEvent" - commandSucceededEvent monitoringEventType = "CommandSucceededEvent" - commandFailedEvent monitoringEventType = "CommandFailedEvent" - poolCreatedEvent monitoringEventType = "PoolCreatedEvent" - poolReadyEvent monitoringEventType = "PoolReadyEvent" - poolClearedEvent monitoringEventType = "PoolClearedEvent" - poolClosedEvent monitoringEventType = "PoolClosedEvent" - connectionCreatedEvent monitoringEventType = "ConnectionCreatedEvent" - connectionReadyEvent monitoringEventType = "ConnectionReadyEvent" - connectionClosedEvent monitoringEventType = "ConnectionClosedEvent" - connectionCheckOutStartedEvent monitoringEventType = "ConnectionCheckOutStartedEvent" - connectionCheckOutFailedEvent monitoringEventType = "ConnectionCheckOutFailedEvent" - connectionCheckedOutEvent monitoringEventType = "ConnectionCheckedOutEvent" - connectionCheckedInEvent monitoringEventType = "ConnectionCheckedInEvent" - serverDescriptionChangedEvent monitoringEventType = "ServerDescriptionChangedEvent" - serverHeartbeatFailedEvent monitoringEventType = "ServerHeartbeatFailedEvent" - serverHeartbeatStartedEvent monitoringEventType = "ServerHeartbeatStartedEvent" - serverHeartbeatSucceededEvent monitoringEventType = "ServerHeartbeatSucceededEvent" - topologyDescriptionChangedEvent monitoringEventType = "TopologyDescriptionChangedEvent" - topologyOpeningEvent monitoringEventType = "TopologyOpeningEvent" - topologyClosedEvent monitoringEventType = "TopologyClosedEvent" + commandStartedEvent monitoringEventType = "CommandStartedEvent" + commandSucceededEvent monitoringEventType = "CommandSucceededEvent" + commandFailedEvent monitoringEventType = "CommandFailedEvent" + poolCreatedEvent monitoringEventType = "PoolCreatedEvent" + poolReadyEvent monitoringEventType = "PoolReadyEvent" + poolClearedEvent monitoringEventType = "PoolClearedEvent" + poolClosedEvent monitoringEventType = "PoolClosedEvent" + connectionCreatedEvent monitoringEventType = "ConnectionCreatedEvent" + connectionReadyEvent monitoringEventType = "ConnectionReadyEvent" + connectionClosedEvent monitoringEventType = "ConnectionClosedEvent" + connectionCheckOutStartedEvent monitoringEventType = "ConnectionCheckOutStartedEvent" + connectionCheckOutFailedEvent monitoringEventType = "ConnectionCheckOutFailedEvent" + connectionCheckedOutEvent monitoringEventType = "ConnectionCheckedOutEvent" + connectionCheckedInEvent monitoringEventType = "ConnectionCheckedInEvent" + connectionPendingResponseStarted monitoringEventType = "ConnectionPendingResponseStarted" + connectionPendingResponseSucceeded monitoringEventType = 
"ConnectionPendingResponseSucceeded" + connectionPendingResponseFailed monitoringEventType = "ConnectionPendingResponseFailed" + serverDescriptionChangedEvent monitoringEventType = "ServerDescriptionChangedEvent" + serverHeartbeatFailedEvent monitoringEventType = "ServerHeartbeatFailedEvent" + serverHeartbeatStartedEvent monitoringEventType = "ServerHeartbeatStartedEvent" + serverHeartbeatSucceededEvent monitoringEventType = "ServerHeartbeatSucceededEvent" + topologyDescriptionChangedEvent monitoringEventType = "TopologyDescriptionChangedEvent" + topologyOpeningEvent monitoringEventType = "TopologyOpeningEvent" + topologyClosedEvent monitoringEventType = "TopologyClosedEvent" ) func monitoringEventTypeFromString(eventStr string) (monitoringEventType, bool) { @@ -69,6 +72,12 @@ func monitoringEventTypeFromString(eventStr string) (monitoringEventType, bool) return connectionCheckedOutEvent, true case "connectioncheckedinevent": return connectionCheckedInEvent, true + case "connectionpendingresponsestarted": + return connectionPendingResponseStarted, true + case "connectionpendingresponsesucceeded": + return connectionPendingResponseSucceeded, true + case "connectionpendingresponsefailed": + return connectionPendingResponseFailed, true case "serverdescriptionchangedevent": return serverDescriptionChangedEvent, true case "serverheartbeatfailedevent": @@ -112,6 +121,12 @@ func monitoringEventTypeFromPoolEvent(evt *event.PoolEvent) monitoringEventType return connectionCheckedOutEvent case event.ConnectionCheckedIn: return connectionCheckedInEvent + case event.ConnectionPendingResponseStarted: + return connectionPendingResponseStarted + case event.ConnectionPendingResponseSucceeded: + return connectionPendingResponseSucceeded + case event.ConnectionPendingResponseFailed: + return connectionPendingResponseFailed default: return "" } diff --git a/internal/integration/unified/event_verification.go b/internal/integration/unified/event_verification.go index 56c53f8adb..eb3e8b49be 100644 --- a/internal/integration/unified/event_verification.go +++ b/internal/integration/unified/event_verification.go @@ -56,7 +56,10 @@ type cmapEvent struct { Reason *string `bson:"reason"` } `bson:"connectionCheckOutFailedEvent"` - ConnectionCheckedInEvent *struct{} `bson:"connectionCheckedInEvent"` + ConnectionCheckedInEvent *struct{} `bson:"connectionCheckedInEvent"` + ConnectionPendingResponseStarted *struct{} `bson:"connectionPendingResponseStarted"` + ConnectionPendingResponseSucceeded *struct{} `bson:"connectionPendingResponseSucceeded"` + ConnectionPendingResponseFailed *struct{} `bson:"connectionPendingResponseFailed"` PoolClearedEvent *struct { HasServiceID *bool `bson:"hasServiceId"` @@ -359,6 +362,18 @@ func verifyCMAPEvents(client *clientEntity, expectedEvents *expectedEvents) erro if _, pooled, err = getNextPoolEvent(pooled, event.ConnectionCheckedIn); err != nil { return newEventVerificationError(idx, client, "failed to get next pool event: %v", err.Error()) } + case evt.ConnectionPendingResponseStarted != nil: + if _, pooled, err = getNextPoolEvent(pooled, event.ConnectionPendingResponseStarted); err != nil { + return newEventVerificationError(idx, client, "failed to get next pool event: %v", err.Error()) + } + case evt.ConnectionPendingResponseSucceeded != nil: + if _, pooled, err = getNextPoolEvent(pooled, event.ConnectionPendingResponseSucceeded); err != nil { + return newEventVerificationError(idx, client, "failed to get next pool event: %v", err.Error()) + } + case evt.ConnectionPendingResponseFailed 
!= nil: + if _, pooled, err = getNextPoolEvent(pooled, event.ConnectionPendingResponseFailed); err != nil { + return newEventVerificationError(idx, client, "failed to get next pool event: %v", err.Error()) + } case evt.PoolClearedEvent != nil: var actual *event.PoolEvent if actual, pooled, err = getNextPoolEvent(pooled, event.ConnectionPoolCleared); err != nil { diff --git a/internal/logger/component.go b/internal/logger/component.go index a601707cbf..85acf05142 100644 --- a/internal/logger/component.go +++ b/internal/logger/component.go @@ -28,6 +28,9 @@ const ( ConnectionCheckoutFailed = "Connection checkout failed" ConnectionCheckedOut = "Connection checked out" ConnectionCheckedIn = "Connection checked in" + ConnectionPendingReadStarted = "Pending response started" + ConnectionPendingReadSucceeded = "Pending response succeeded" + ConnectionPendingReadFailed = "Pending response failed" ServerSelectionFailed = "Server selection failed" ServerSelectionStarted = "Server selection started" ServerSelectionSucceeded = "Server selection succeeded" diff --git a/testdata/specifications b/testdata/specifications index 43d2c7bacd..6118debee4 160000 --- a/testdata/specifications +++ b/testdata/specifications @@ -1 +1 @@ -Subproject commit 43d2c7bacd62249de8d2173bf8ee39e6fd7a686e +Subproject commit 6118debee41cfd1bca197b315bd1f10ad95f66ae diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 2597a5de66..cc9f631a99 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -785,6 +785,14 @@ func (op Operation) Execute(ctx context.Context) error { if moreToCome { roundTrip = op.moreToComeRoundTrip } + + // Set context values to handle a pending read in case of a socket + // timeout. + if maxTimeMS != 0 { + ctx = driverutil.WithValueHasMaxTimeMS(ctx, true) + ctx = driverutil.WithRequestID(ctx, startedInfo.requestID) + } + res, err = roundTrip(ctx, conn, *wm) if ep, ok := srvr.(ErrorProcessor); ok { diff --git a/x/mongo/driver/topology/cmap_prose_test.go b/x/mongo/driver/topology/cmap_prose_test.go index 0524b99e9c..1b33e263ec 100644 --- a/x/mongo/driver/topology/cmap_prose_test.go +++ b/x/mongo/driver/topology/cmap_prose_test.go @@ -9,13 +9,19 @@ package topology import ( "context" "errors" + "io" "net" + "regexp" + "sync" "testing" "time" "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/csot" + "go.mongodb.org/mongo-driver/v2/internal/driverutil" "go.mongodb.org/mongo-driver/v2/internal/require" + "go.mongodb.org/mongo-driver/v2/mongo/address" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/operation" ) @@ -263,6 +269,202 @@ func TestCMAPProse(t *testing.T) { }) }) }) + + // Need to test the case where we attempt a non-blocking read to determine if + // we should refresh the remaining time. In the case of the Go Driver, we do + // this by attempting to "peek" at 1 byte with a 1ms deadline. + t.Run("connection attempts peek but fails", func(t *testing.T) { + const requestID = int32(-1) + timeout := 10 * time.Millisecond + + // Mock a TCP listener that will write a byte sequence > 5 (to avoid errors + // due to size) to the TCP socket. Have the listener sleep for 2x the + // timeout provided to the connection AFTER writing the byte sequence. This + // will cause the connection to time out while reading from the socket.
+ addr := bootstrapConnections(t, 1, func(nc net.Conn) { + defer func() { + _ = nc.Close() + }() + + _, err := nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) + require.NoError(t, err) + time.Sleep(timeout * 2) + + // Write nothing so that the 1 millisecond "non-blocking" peek fails. + }) + + poolEventsByType := make(map[string][]event.PoolEvent) + poolEventsByTypeMu := &sync.Mutex{} + + monitor := &event.PoolMonitor{ + Event: func(pe *event.PoolEvent) { + poolEventsByTypeMu.Lock() + poolEventsByType[pe.Type] = append(poolEventsByType[pe.Type], *pe) + poolEventsByTypeMu.Unlock() + }, + } + + p := newPool( + poolConfig{ + Address: address.Address(addr.String()), + PoolMonitor: monitor, + }, + ) + defer p.close(context.Background()) + err := p.ready() + require.NoError(t, err) + + // Check out a connection and read from the socket, causing a timeout and + // pinning the connection to a pending read state. + conn, err := p.checkOut(context.Background()) + require.NoError(t, err) + + ctx, cancel := csot.WithTimeout(context.Background(), &timeout) + defer cancel() + + ctx = driverutil.WithValueHasMaxTimeMS(ctx, true) + ctx = driverutil.WithRequestID(ctx, requestID) + + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) + + // Check in the connection with a pending read state. The next time this + // connection is checked out, it should attempt to read the pending + // response. + err = p.checkIn(conn) + require.NoError(t, err) + + // Wait 3s to make sure there is no remaining time on the pending read + // state. + time.Sleep(3 * time.Second) + + // Check out the connection again. The remaining time should be exhausted + // requiring us to "peek" at the connection to determine if we should + _, err = p.checkOut(context.Background()) + assert.ErrorIs(t, err, io.EOF) + + // There should be 1 ConnectionPendingResponseStarted event. + started := poolEventsByType[event.ConnectionPendingResponseStarted] + require.Len(t, started, 1) + + assert.Equal(t, addr.String(), started[0].Address) + assert.Equal(t, conn.driverConnectionID, started[0].ConnectionID) + assert.Equal(t, requestID, started[0].RequestID) + + // There should be 1 ConnectionPendingResponseFailed event. + failed := poolEventsByType[event.ConnectionPendingResponseFailed] + require.Len(t, failed, 1) + + assert.Equal(t, addr.String(), failed[0].Address) + assert.Equal(t, conn.driverConnectionID, failed[0].ConnectionID) + assert.Equal(t, requestID, failed[0].RequestID) + assert.Equal(t, "error", failed[0].Reason) + assert.ErrorIs(t, failed[0].Error, io.EOF) + assert.Equal(t, time.Duration(0), failed[0].RemainingTime) + + // There should be 0 ConnectionPendingResponseSucceeded event. + require.Len(t, poolEventsByType[event.ConnectionPendingResponseSucceeded], 0) + }) + + t.Run("connection attempts peek and succeeds", func(t *testing.T) { + const requestID = int32(-1) + timeout := 10 * time.Millisecond + + // Mock a TCP listener that will write a byte sequence > 5 (to avoid errors + // due to size) to the TCP socket. Have the listener sleep for 2x the + // timeout provided to the connection AFTER writing the byte sequence. This + // wiill cause the connection to timeout while reading from the socket. 
+ addr := bootstrapConnections(t, 1, func(nc net.Conn) { + defer func() { + _ = nc.Close() + }() + + _, err := nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) + require.NoError(t, err) + time.Sleep(timeout * 2) + + // Write data that can be peeked at. + _, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) + require.NoError(t, err) + + }) + + poolEventsByType := make(map[string][]event.PoolEvent) + poolEventsByTypeMu := &sync.Mutex{} + + monitor := &event.PoolMonitor{ + Event: func(pe *event.PoolEvent) { + poolEventsByTypeMu.Lock() + poolEventsByType[pe.Type] = append(poolEventsByType[pe.Type], *pe) + poolEventsByTypeMu.Unlock() + }, + } + + p := newPool( + poolConfig{ + Address: address.Address(addr.String()), + PoolMonitor: monitor, + }, + ) + defer p.close(context.Background()) + err := p.ready() + require.NoError(t, err) + + // Check out a connection and read from the socket, causing a timeout and + // pinning the connection to a pending read state. + conn, err := p.checkOut(context.Background()) + require.NoError(t, err) + + ctx, cancel := csot.WithTimeout(context.Background(), &timeout) + defer cancel() + + ctx = driverutil.WithValueHasMaxTimeMS(ctx, true) + ctx = driverutil.WithRequestID(ctx, requestID) + + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) + + // Check in the connection with a pending read state. The next time this + // connection is checked out, it should attempt to read the pending + // response. + err = p.checkIn(conn) + require.NoError(t, err) + + // Wait 3s to make sure there is no remaining time on the pending read + // state. + time.Sleep(3 * time.Second) + + // Check out the connection again. The remaining time should be exhausted + // requiring us to "peek" at the connection to determine if we should + _, err = p.checkOut(context.Background()) + require.NoError(t, err) + + // There should be 1 ConnectionPendingResponseStarted event. + started := poolEventsByType[event.ConnectionPendingResponseStarted] + require.Len(t, started, 1) + + assert.Equal(t, addr.String(), started[0].Address) + assert.Equal(t, conn.driverConnectionID, started[0].ConnectionID) + assert.Equal(t, requestID, started[0].RequestID) + + // There should be 0 ConnectionPendingResponseFailed event. + require.Len(t, poolEventsByType[event.ConnectionPendingResponseFailed], 0) + + // There should be 1 ConnectionPendingResponseSucceeded event. 
+ succeeded := poolEventsByType[event.ConnectionPendingResponseSucceeded] + require.Len(t, succeeded, 1) + + assert.Equal(t, addr.String(), succeeded[0].Address) + assert.Equal(t, conn.driverConnectionID, succeeded[0].ConnectionID) + assert.Equal(t, requestID, succeeded[0].RequestID) + assert.Greater(t, int(succeeded[0].Duration), 0) + }) } func createTestPool(t *testing.T, cfg poolConfig, opts ...ConnectionOption) *pool { diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 24ad6a3a51..7edbd15ceb 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -47,6 +47,12 @@ var ( func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) } +type pendingReadState struct { + remainingBytes int32 + requestID int32 + start time.Time +} + type connection struct { // state must be accessed using the atomic package and should be at the beginning of the struct. // - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG @@ -82,9 +88,11 @@ type connection struct { // accessTokens in the OIDC authenticator cache. oidcTokenGenID uint64 - // awaitRemainingBytes indicates the size of server response that was not completely - // read before returning the connection to the pool. - awaitRemainingBytes *int32 + // pendingReadState contains information required to attempt a pending read + // in the event of a socket timeout for an operation that has appended + // maxTimeMS to the wire message. + pendingReadState *pendingReadState + pendingReadMu sync.Mutex } // newConnection handles the creation of a connection. It does not connect the connection. @@ -407,11 +415,14 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { dst, errMsg, err := c.read(ctx) if err != nil { - if c.awaitRemainingBytes == nil { - // If the connection was not marked as awaiting response, close the - // connection because we don't know what the connection state is. + c.pendingReadMu.Lock() + if c.pendingReadState == nil { + // If there is no pending read on the connection, use the pre-CSOT + // behavior and close the connection because we don't know if there are + // other bytes left to read. c.close() } + c.pendingReadMu.Unlock() message := errMsg if errors.Is(err, io.EOF) { message = "socket was unexpectedly closed" @@ -476,8 +487,14 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, // reading messages from an exhaust cursor. 
n, err := io.ReadFull(c.nc, sizeBuf[:]) if err != nil { - if l := int32(n); l == 0 && isCSOTTimeout(err) { - c.awaitRemainingBytes = &l + if l := int32(n); l == 0 && isCSOTTimeout(err) && driverutil.HasMaxTimeMS(ctx) { + requestID, _ := driverutil.GetRequestID(ctx) + + c.pendingReadState = &pendingReadState{ + remainingBytes: l, + requestID: requestID, + start: time.Now(), + } } return nil, "incomplete read of message header", err } @@ -492,8 +509,14 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, n, err = io.ReadFull(c.nc, dst[4:]) if err != nil { remainingBytes := size - 4 - int32(n) - if remainingBytes > 0 && isCSOTTimeout(err) { - c.awaitRemainingBytes = &remainingBytes + if remainingBytes > 0 && isCSOTTimeout(err) && driverutil.HasMaxTimeMS(ctx) { + requestID, _ := driverutil.GetRequestID(ctx) + + c.pendingReadState = &pendingReadState{ + remainingBytes: remainingBytes, + requestID: requestID, + start: time.Now(), + } } return dst, "incomplete read of full message", err } diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index d6568e844f..7bfd368818 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -7,7 +7,9 @@ package topology import ( + "bufio" "context" + "errors" "fmt" "io" "net" @@ -576,6 +578,10 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) { return nil, w.err } + if err := awaitPendingRead(ctx, p, w.conn); err != nil { + return nil, err + } + duration = time.Since(start) if mustLogPoolMessage(p) { keysAndValues := logger.KeyValues{ @@ -632,6 +638,10 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) { return nil, w.err } + if err := awaitPendingRead(ctx, p, w.conn); err != nil { + return nil, err + } + duration := time.Since(start) if mustLogPoolMessage(p) { keysAndValues := logger.KeyValues{ @@ -771,82 +781,289 @@ func (p *pool) removeConnection(conn *connection, reason reason, err error) erro return nil } -var ( - // BGReadTimeout is the maximum amount of the to wait when trying to read - // the server reply on a connection after an operation timed out. The - // default is 400ms. - // - // Deprecated: BGReadTimeout is intended for internal use only and may be - // removed or modified at any time. - BGReadTimeout = 400 * time.Millisecond +// PendingReadTimeout is the maximum amount of time to wait when trying to read +// the server reply on a connection after an operation timed out. The +// default is 3 seconds (3000 milliseconds). This value is refreshed for every +// 4KB read from the TCP stream. +// +// Deprecated: PendingReadTimeout is intended for internal use only and may be +// removed or modified at any time. +var PendingReadTimeout = 3000 * time.Millisecond + +// publishPendingReadStarted will log a message to the pool logger and +// publish an event to the pool monitor if they are set. +func publishPendingReadStarted(pool *pool, conn *connection) { + prs := conn.pendingReadState + if prs == nil { + return + } - // BGReadCallback is a callback for monitoring the behavior of the - // background-read-on-timeout connection preserving mechanism. - // - // Deprecated: BGReadCallback is intended for internal use only and may be - // removed or modified at any time. - BGReadCallback func(addr string, start, read time.Time, errs []error, connClosed bool) -) + // log a message to the pool logger if it is set.
+ if mustLogPoolMessage(pool) { + keysAndValues := logger.KeyValues{ + logger.KeyDriverConnectionID, conn.driverConnectionID, + logger.KeyRequestID, prs.requestID, + } -// bgRead sets a new read deadline on the provided connection and tries to read -// any bytes returned by the server. If successful, it checks the connection -// into the provided pool. If there are any errors, it closes the connection. -// -// It calls the package-global BGReadCallback function, if set, with the -// address, timings, and any errors that occurred. -func bgRead(pool *pool, conn *connection, size int32) { - var err error - start := time.Now() + logPoolMessage(pool, logger.ConnectionPendingReadStarted, keysAndValues...) + } - defer func() { - read := time.Now() - errs := make([]error, 0) - connClosed := false - if err != nil { - errs = append(errs, err) - connClosed = true - err = conn.close() - if err != nil { - errs = append(errs, fmt.Errorf("error closing conn after reading: %w", err)) - } + // publish an event to the pool monitor if it is set. + if pool.monitor != nil { + event := &event.PoolEvent{ + Type: event.ConnectionPendingResponseStarted, + Address: pool.address.String(), + ConnectionID: conn.driverConnectionID, + RequestID: prs.requestID, } - // No matter what happens, always check the connection back into the - // pool, which will either make it available for other operations or - // remove it from the pool if it was closed. - err = pool.checkInNoEvent(conn) - if err != nil { - errs = append(errs, fmt.Errorf("error checking in: %w", err)) + pool.monitor.Event(event) + } +} + +func publishPendingReadFailed(pool *pool, conn *connection, err error) { + prs := conn.pendingReadState + if prs == nil { + return + } + + reason := event.ReasonError + if errors.Is(err, context.DeadlineExceeded) { + reason = event.ReasonTimedOut + } + + if mustLogPoolMessage(pool) { + keysAndValues := logger.KeyValues{ + logger.KeyDriverConnectionID, conn.driverConnectionID, + logger.KeyRequestID, prs.requestID, + logger.KeyReason, reason, + logger.KeyError, err.Error(), } - if BGReadCallback != nil { - BGReadCallback(conn.addr.String(), start, read, errs, connClosed) + logPoolMessage(pool, logger.ConnectionPendingReadFailed, keysAndValues...) + } + + if pool.monitor != nil { + e := &event.PoolEvent{ + Type: event.ConnectionPendingResponseFailed, + Address: pool.address.String(), + ConnectionID: conn.driverConnectionID, + RequestID: prs.requestID, + Reason: reason, + Error: err, } - }() + pool.monitor.Event(e) + } +} - err = conn.nc.SetReadDeadline(time.Now().Add(BGReadTimeout)) - if err != nil { - err = fmt.Errorf("error setting a read deadline: %w", err) +func publishPendingReadSucceeded(pool *pool, conn *connection, dur time.Duration) { + prs := conn.pendingReadState + if prs == nil { return } - if size == 0 { + if mustLogPoolMessage(pool) { + keysAndValues := logger.KeyValues{ + logger.KeyDriverConnectionID, conn.driverConnectionID, + logger.KeyRequestID, prs.requestID, + logger.KeyDurationMS, dur.Milliseconds(), + } + + logPoolMessage(pool, logger.ConnectionPendingReadSucceeded, keysAndValues...) + } + + if pool.monitor != nil { + event := &event.PoolEvent{ + Type: event.ConnectionPendingResponseSucceeded, + Address: pool.address.String(), + ConnectionID: conn.driverConnectionID, + RequestID: prs.requestID, + Duration: dur, + } + + pool.monitor.Event(event) + } +} + +// peekConnectionAlive checks if the connection is alive by peeking at the +// buffered reader. If the connection is closed, it will return false. 
+func peekConnectionAlive(conn *connection) (int, error) { + // Set a very short deadline to avoid blocking. + if err := conn.nc.SetReadDeadline(time.Now().Add(1 * time.Millisecond)); err != nil { + return 0, err + } + + // Wrap the connection in a buffered reader to use peek. + reader := bufio.NewReader(conn.nc) + + // Try to peek at one byte. + bytes, err := reader.Peek(1) + return len(bytes), err +} + +func attemptPendingRead(ctx context.Context, conn *connection, remainingTime time.Duration) (int, error) { + pendingreadState := conn.pendingReadState + if pendingreadState == nil { + return 0, fmt.Errorf("no pending read state") + } + + dl, contextDeadlineUsed := ctx.Deadline() + calculatedDeadline := time.Now().Add(remainingTime) + + if contextDeadlineUsed { + // Use the minimum of the user-provided deadline and the calculated + // deadline. + if calculatedDeadline.Before(dl) { + dl = calculatedDeadline + } + } else { + dl = calculatedDeadline + } + + err := conn.nc.SetReadDeadline(dl) + if err != nil { + return 0, fmt.Errorf("error setting a read deadline: %w", err) + } + + size := pendingreadState.remainingBytes + + if size == 0 { // Question: Would this alawys equal to zero? var sizeBuf [4]byte - _, err = io.ReadFull(conn.nc, sizeBuf[:]) - if err != nil { - err = fmt.Errorf("error reading the message size: %w", err) - return + if bytesRead, err := io.ReadFull(conn.nc, sizeBuf[:]); err != nil { + err = transformNetworkError(ctx, err, contextDeadlineUsed) + + return bytesRead, fmt.Errorf("error reading the message size: %w", err) } + size, err = conn.parseWmSizeBytes(sizeBuf) if err != nil { - return + return int(size), transformNetworkError(ctx, err, contextDeadlineUsed) } size -= 4 } - _, err = io.CopyN(io.Discard, conn.nc, int64(size)) + + const bufSize = 4096 + buf := make([]byte, bufSize) + + var totalRead int64 + + // Iterate every 4KB of the TCP stream, refreshing the remainingTimeout for + // each successful read to avoid closing while streaming large (upto 16MiB) + // response messages. + for totalRead < int64(size) { + newDeadline := time.Now().Add(time.Until(dl)) + if err := conn.nc.SetReadDeadline(newDeadline); err != nil { + return int(totalRead), fmt.Errorf("error renewing read deadline: %w", err) + } + + remaining := int64(size) - totalRead + + readSize := bufSize + if int64(readSize) > remaining { + readSize = int(remaining) + } + + n, err := conn.nc.Read(buf[:readSize]) + if n > 0 { + totalRead += int64(n) + } + + if err != nil { + // If the read times out, record the bytes left to read before exiting. + // Reduce the remainingTime. + nerr := net.Error(nil) + if l := int32(n); l == 0 && errors.As(err, &nerr) && nerr.Timeout() { + pendingreadState.remainingBytes = l + pendingreadState.remainingBytes + } + + err = transformNetworkError(ctx, err, contextDeadlineUsed) + return n, fmt.Errorf("error discarding %d byte message: %w", size, err) + } + + pendingreadState.start = time.Now() + } + + return int(totalRead), nil +} + +// awaitPendingRead sets a new read deadline on the provided connection and +// tries to read any bytes returned by the server. If there are any errors, the +// connection will be checked back into the pool to be retried. +func awaitPendingRead(ctx context.Context, pool *pool, conn *connection) error { + conn.pendingReadMu.Lock() + defer conn.pendingReadMu.Unlock() + + // If there are no bytes pending read, do nothing. 
+ if conn.pendingReadState == nil { + return nil + } + + publishPendingReadStarted(pool, conn) + + var ( + pendingReadState = conn.pendingReadState + remainingTime = pendingReadState.start.Add(PendingReadTimeout).Sub(time.Now()) + err error + bytesRead int + ) + + st := time.Now() + if remainingTime <= 0 { + // If there is no remaining time, we can just peek at the connection to check + // aliveness. In such cases, we don't want to close the connection. + bytesRead, err = peekConnectionAlive(conn) + } else { + bytesRead, err = attemptPendingRead(ctx, conn, remainingTime) + } + + endTime := time.Now() + endDuration := time.Since(st) + + if err != nil { - err = fmt.Errorf("error discarding %d byte message: %w", size, err) + // No matter what happens, always check the connection back into the + // pool, which will either make it available for other operations or + // remove it from the pool if it was closed. + // + // TODO(GODRIVER-3385): Figure out how to handle this error. It's possible + // that a single connection can be checked out to handle multiple concurrent + // operations. This is likely a bug in the Go Driver. So it's possible that + // the connection is idle at the point of check-in. + defer func() { + publishPendingReadFailed(pool, conn, err) + + _ = pool.checkInNoEvent(conn) + }() + + if netErr, ok := err.(net.Error); ok && !netErr.Timeout() { + if err := conn.close(); err != nil { + return err + } + return err + } } + + // If the read was successful, then we should refresh the remaining timeout. + if bytesRead > 0 { + pendingReadState.start = endTime + } + + // If the remaining time has been exceeded, then close the connection. + if endTime.Sub(pendingReadState.start) > PendingReadTimeout { + if err := conn.close(); err != nil { + return err + } + } + + if err != nil { + return err + } + + publishPendingReadSucceeded(pool, conn, endDuration) + + conn.pendingReadState = nil + + return nil } // checkIn returns an idle connection to the pool. If the connection is perished or the pool is @@ -888,21 +1105,6 @@ func (p *pool) checkInNoEvent(conn *connection) error { return ErrWrongPool } - // If the connection has an awaiting server response, try to read the - // response in another goroutine before checking it back into the pool. - // - // Do this here because we want to publish checkIn events when the operation - // is done with the connection, not when it's ready to be used again. That - // means that connections in "awaiting response" state are checked in but - // not usable, which is not covered by the current pool events. We may need - // to add pool event information in the future to communicate that. - if conn.awaitRemainingBytes != nil { - size := *conn.awaitRemainingBytes - conn.awaitRemainingBytes = nil - go bgRead(p, conn, size) - return nil - } - - // Bump the connection idle start time here because we're about to make the // connection "available".
The idle start time is used to determine how long // a connection has been idle and when it has reached its max idle time and diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index 3d270de2e0..c05e8133f0 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -18,6 +18,7 @@ import ( "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/csot" + "go.mongodb.org/mongo-driver/v2/internal/driverutil" "go.mongodb.org/mongo-driver/v2/internal/eventtest" "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo/address" @@ -1233,24 +1234,10 @@ func TestPool_maintain(t *testing.T) { }) } -func TestBackgroundRead(t *testing.T) { +func TestAwaitPendingRead(t *testing.T) { t.Parallel() - newBGReadCallback := func(errsCh chan []error) func(string, time.Time, time.Time, []error, bool) { - return func(_ string, _, _ time.Time, errs []error, _ bool) { - errsCh <- errs - close(errsCh) - } - } - t.Run("incomplete read of message header", func(t *testing.T) { - errsCh := make(chan []error) - var originalCallback func(string, time.Time, time.Time, []error, bool) - originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) - t.Cleanup(func() { - BGReadCallback = originalCallback - }) - timeout := 10 * time.Millisecond cleanup := make(chan struct{}) @@ -1274,24 +1261,21 @@ func TestBackgroundRead(t *testing.T) { conn, err := p.checkOut(context.Background()) require.NoError(t, err) + ctx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() + + ctx = driverutil.WithValueHasMaxTimeMS(ctx, true) + ctx = driverutil.WithRequestID(ctx, -1) + _, err = conn.readWireMessage(ctx) regex := regexp.MustCompile( `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) - assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitRemainingBytes should be nil") - close(errsCh) // this line causes a double close if BGReadCallback is ever called. 
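A note on the rewrite of these tests: the old BGReadCallback hook is gone because a timed-out connection is no longer drained by a background goroutine at check-in; instead the pending read is resolved synchronously by awaitPendingRead the next time the connection is checked out, and the outcome is observable through the ConnectionPendingResponse* pool events that the updated subtests below capture. The budget for that deferred read is the PendingReadTimeout window measured from when the socket timeout occurred. A small sketch of that arithmetic follows; it is illustrative only, and the helper name is invented.

    // Remaining budget for a pending read, as awaitPendingRead computes it
    // from the time recorded when the socket timeout occurred.
    func pendingReadBudget(prs *pendingReadState) time.Duration {
        return time.Until(prs.start.Add(PendingReadTimeout))
    }

With PendingReadTimeout at 3 seconds, a read that timed out 1 second ago still has roughly 2 seconds to drain the reply at the next check-out, while one that timed out 4 seconds ago has a negative budget, so check-out only peeks at the socket to decide whether the connection is still alive.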
+ assert.Nil(t, conn.pendingReadState, "conn.awaitRemainingBytes should be nil") }) t.Run("timeout reading message header, successful background read", func(t *testing.T) { - errsCh := make(chan []error) - var originalCallback func(string, time.Time, time.Time, []error, bool) - originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) - t.Cleanup(func() { - BGReadCallback = originalCallback - }) - timeout := 10 * time.Millisecond addr := bootstrapConnections(t, 1, func(nc net.Conn) { @@ -1305,8 +1289,20 @@ func TestBackgroundRead(t *testing.T) { require.NoError(t, err) }) + var pendingReadError error + monitor := &event.PoolMonitor{ + Event: func(pe *event.PoolEvent) { + if pe.Type == event.ConnectionPendingResponseFailed { + pendingReadError = pe.Error + } + }, + } + p := newPool( - poolConfig{Address: address.Address(addr.String())}, + poolConfig{ + Address: address.Address(addr.String()), + PoolMonitor: monitor, + }, ) defer p.close(context.Background()) err := p.ready() @@ -1314,8 +1310,13 @@ func TestBackgroundRead(t *testing.T) { conn, err := p.checkOut(context.Background()) require.NoError(t, err) + ctx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() + + ctx = driverutil.WithValueHasMaxTimeMS(ctx, true) + ctx = driverutil.WithRequestID(ctx, -1) + _, err = conn.readWireMessage(ctx) regex := regexp.MustCompile( `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, @@ -1323,22 +1324,13 @@ func TestBackgroundRead(t *testing.T) { assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) require.NoError(t, err) - var bgErrs []error - select { - case bgErrs = <-errsCh: - case <-time.After(3 * time.Second): - assert.Fail(t, "did not receive expected error after waiting for 3 seconds") - } - require.Len(t, bgErrs, 0, "expected no error from bgRead()") + + _, err = p.checkOut(context.Background()) + require.NoError(t, err) + + require.NoError(t, pendingReadError) }) t.Run("timeout reading message header, incomplete head during background read", func(t *testing.T) { - errsCh := make(chan []error) - var originalCallback func(string, time.Time, time.Time, []error, bool) - originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) - t.Cleanup(func() { - BGReadCallback = originalCallback - }) - timeout := 10 * time.Millisecond addr := bootstrapConnections(t, 1, func(nc net.Conn) { @@ -1352,8 +1344,20 @@ func TestBackgroundRead(t *testing.T) { require.NoError(t, err) }) + var pendingReadError error + monitor := &event.PoolMonitor{ + Event: func(pe *event.PoolEvent) { + if pe.Type == event.ConnectionPendingResponseFailed { + pendingReadError = pe.Error + } + }, + } + p := newPool( - poolConfig{Address: address.Address(addr.String())}, + poolConfig{ + Address: address.Address(addr.String()), + PoolMonitor: monitor, + }, ) defer p.close(context.Background()) err := p.ready() @@ -1361,8 +1365,13 @@ func TestBackgroundRead(t *testing.T) { conn, err := p.checkOut(context.Background()) require.NoError(t, err) + ctx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() + + ctx = driverutil.WithValueHasMaxTimeMS(ctx, true) + ctx = driverutil.WithRequestID(ctx, -1) + _, err = conn.readWireMessage(ctx) regex := regexp.MustCompile( `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, @@ 
-1370,23 +1379,13 @@ func TestBackgroundRead(t *testing.T) { assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) require.NoError(t, err) - var bgErrs []error - select { - case bgErrs = <-errsCh: - case <-time.After(3 * time.Second): - assert.Fail(t, "did not receive expected error after waiting for 3 seconds") - } - require.Len(t, bgErrs, 1, "expected 1 error from bgRead()") - assert.EqualError(t, bgErrs[0], "error reading the message size: unexpected EOF") + + _, err = p.checkOut(context.Background()) + require.Error(t, err) + + assert.EqualError(t, pendingReadError, "error reading the message size: unexpected EOF") }) t.Run("timeout reading message header, background read timeout", func(t *testing.T) { - errsCh := make(chan []error) - var originalCallback func(string, time.Time, time.Time, []error, bool) - originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) - t.Cleanup(func() { - BGReadCallback = originalCallback - }) - timeout := 10 * time.Millisecond cleanup := make(chan struct{}) @@ -1404,17 +1403,35 @@ func TestBackgroundRead(t *testing.T) { require.NoError(t, err) }) + var pendingReadError error + monitor := &event.PoolMonitor{ + Event: func(pe *event.PoolEvent) { + if pe.Type == event.ConnectionPendingResponseFailed { + pendingReadError = pe.Error + } + }, + } + p := newPool( - poolConfig{Address: address.Address(addr.String())}, + poolConfig{ + Address: address.Address(addr.String()), + PoolMonitor: monitor, + }, ) + defer p.close(context.Background()) err := p.ready() require.NoError(t, err) conn, err := p.checkOut(context.Background()) require.NoError(t, err) + ctx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() + + ctx = driverutil.WithValueHasMaxTimeMS(ctx, true) + ctx = driverutil.WithRequestID(ctx, -1) + _, err = conn.readWireMessage(ctx) regex := regexp.MustCompile( `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, @@ -1422,26 +1439,16 @@ func TestBackgroundRead(t *testing.T) { assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) require.NoError(t, err) - var bgErrs []error - select { - case bgErrs = <-errsCh: - case <-time.After(3 * time.Second): - assert.Fail(t, "did not receive expected error after waiting for 3 seconds") - } - require.Len(t, bgErrs, 1, "expected 1 error from bgRead()") + + _, err = p.checkOut(context.Background()) + require.Error(t, err) + wantErr := regexp.MustCompile( `^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, ) - assert.True(t, wantErr.MatchString(bgErrs[0].Error()), "error %q does not match pattern %q", bgErrs[0], wantErr) + assert.True(t, wantErr.MatchString(pendingReadError.Error()), "error %q does not match pattern %q", pendingReadError, wantErr) }) t.Run("timeout reading full message, successful background read", func(t *testing.T) { - errsCh := make(chan []error) - var originalCallback func(string, time.Time, time.Time, []error, bool) - originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) - t.Cleanup(func() { - BGReadCallback = originalCallback - }) - timeout := 10 * time.Millisecond addr := bootstrapConnections(t, 1, func(nc net.Conn) { @@ -1458,17 +1465,35 @@ func TestBackgroundRead(t *testing.T) { require.NoError(t, err) }) + var pendingReadError error + monitor := &event.PoolMonitor{ + Event: func(pe 
*event.PoolEvent) { + if pe.Type == event.ConnectionPendingResponseFailed { + pendingReadError = pe.Error + } + }, + } + p := newPool( - poolConfig{Address: address.Address(addr.String())}, + poolConfig{ + Address: address.Address(addr.String()), + PoolMonitor: monitor, + }, ) + defer p.close(context.Background()) err := p.ready() require.NoError(t, err) conn, err := p.checkOut(context.Background()) require.NoError(t, err) + ctx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() + + ctx = driverutil.WithValueHasMaxTimeMS(ctx, true) + ctx = driverutil.WithRequestID(ctx, -1) + _, err = conn.readWireMessage(ctx) regex := regexp.MustCompile( `^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, @@ -1476,22 +1501,13 @@ func TestBackgroundRead(t *testing.T) { assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) require.NoError(t, err) - var bgErrs []error - select { - case bgErrs = <-errsCh: - case <-time.After(3 * time.Second): - assert.Fail(t, "did not receive expected error after waiting for 3 seconds") - } - require.Len(t, bgErrs, 0, "expected no error from bgRead()") + + _, err = p.checkOut(context.Background()) + require.NoError(t, err) + + require.NoError(t, pendingReadError) }) t.Run("timeout reading full message, background read EOF", func(t *testing.T) { - errsCh := make(chan []error) - var originalCallback func(string, time.Time, time.Time, []error, bool) - originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) - t.Cleanup(func() { - BGReadCallback = originalCallback - }) - timeout := 10 * time.Millisecond addr := bootstrapConnections(t, 1, func(nc net.Conn) { @@ -1508,17 +1524,35 @@ func TestBackgroundRead(t *testing.T) { require.NoError(t, err) }) + var pendingReadError error + monitor := &event.PoolMonitor{ + Event: func(pe *event.PoolEvent) { + if pe.Type == event.ConnectionPendingResponseFailed { + pendingReadError = pe.Error + } + }, + } + p := newPool( - poolConfig{Address: address.Address(addr.String())}, + poolConfig{ + Address: address.Address(addr.String()), + PoolMonitor: monitor, + }, ) + defer p.close(context.Background()) err := p.ready() require.NoError(t, err) conn, err := p.checkOut(context.Background()) require.NoError(t, err) + ctx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() + + ctx = driverutil.WithValueHasMaxTimeMS(ctx, true) + ctx = driverutil.WithRequestID(ctx, -1) + _, err = conn.readWireMessage(ctx) regex := regexp.MustCompile( `^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, @@ -1526,14 +1560,11 @@ func TestBackgroundRead(t *testing.T) { assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) require.NoError(t, err) - var bgErrs []error - select { - case bgErrs = <-errsCh: - case <-time.After(3 * time.Second): - assert.Fail(t, "did not receive expected error after waiting for 3 seconds") - } - require.Len(t, bgErrs, 1, "expected 1 error from bgRead()") - assert.EqualError(t, bgErrs[0], "error discarding 3 byte message: EOF") + + _, err = p.checkOut(context.Background()) + require.Error(t, err) + + assert.EqualError(t, pendingReadError, "error discarding 3 byte message: EOF") }) }
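Finally, a sketch of how an application could observe the new pending-response events once a client is configured with a pool monitor. This is illustrative only and not part of the change set: the URI and log messages are placeholders, the function name is invented, and it assumes the "log", "event", "mongo", and "options" imports from this module.

    // Assumes imports: "log", "go.mongodb.org/mongo-driver/v2/event",
    // "go.mongodb.org/mongo-driver/v2/mongo",
    // "go.mongodb.org/mongo-driver/v2/mongo/options".
    func observePendingResponses() (*mongo.Client, error) {
        monitor := &event.PoolMonitor{
            Event: func(pe *event.PoolEvent) {
                switch pe.Type {
                case event.ConnectionPendingResponseStarted:
                    log.Printf("pending response started: conn=%d request=%d", pe.ConnectionID, pe.RequestID)
                case event.ConnectionPendingResponseSucceeded:
                    log.Printf("pending response drained: conn=%d took=%s", pe.ConnectionID, pe.Duration)
                case event.ConnectionPendingResponseFailed:
                    log.Printf("pending response failed: conn=%d reason=%s err=%v", pe.ConnectionID, pe.Reason, pe.Error)
                }
            },
        }

        opts := options.Client().
            ApplyURI("mongodb://localhost:27017").
            SetPoolMonitor(monitor)

        return mongo.Connect(opts)
    }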