Skip to content

Commit 99bc353

Browse files
Parallelize SSRP requests when MSF is specified (#1578) (#1708)
1 parent 5cca4a7 commit 99bc353

File tree

5 files changed

+319
-35
lines changed

5 files changed

+319
-35
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ internal class SNICommon
108108
internal const int ConnTimeoutError = 11;
109109
internal const int ConnNotUsableError = 19;
110110
internal const int InvalidConnStringError = 25;
111+
internal const int ErrorLocatingServerInstance = 26;
111112
internal const int HandshakeFailureError = 31;
112113
internal const int InternalExceptionError = 35;
113114
internal const int ConnOpenFailedError = 40;

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
141141
/// <param name="isIntegratedSecurity"></param>
142142
/// <param name="ipPreference">IP address preference</param>
143143
/// <param name="cachedFQDN">Used for DNS Cache</param>
144-
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
144+
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
145145
/// <returns>SNI handle</returns>
146146
internal static SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer,
147147
bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
@@ -263,7 +263,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
263263
/// <param name="parallel">Should MultiSubnetFailover be used</param>
264264
/// <param name="ipPreference">IP address preference</param>
265265
/// <param name="cachedFQDN">Key for DNS Cache</param>
266-
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
266+
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
267267
/// <returns>SNITCPHandle</returns>
268268
private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
269269
{
@@ -285,12 +285,12 @@ private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire
285285
try
286286
{
287287
port = isAdminConnection ?
288-
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName) :
289-
SSRP.GetPortByInstanceName(hostName, details.InstanceName);
288+
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference) :
289+
SSRP.GetPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference);
290290
}
291291
catch (SocketException se)
292292
{
293-
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InvalidConnStringError, se);
293+
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.ErrorLocatingServerInstance, se);
294294
return null;
295295
}
296296
}

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs

+17-6
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, bool parallel
146146
bool reportError = true;
147147

148148
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Connecting to serverName {1} and port {2}", args0: _connectionId, args1: serverName, args2: port);
149-
// We will always first try to connect with serverName as before and let the DNS server to resolve the serverName.
150-
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with cached IPs based on IPAddressPreference.
151-
// The exceptions will be throw to upper level and be handled as before.
149+
// We will always first try to connect with serverName as before and let DNS resolve the serverName.
150+
// If DNS resolution fails, we will try with IPs in the DNS cache if they exist. We try with cached IPs based on IPAddressPreference.
151+
// Exceptions will be thrown to the caller and be handled as before.
152152
try
153153
{
154154
if (parallel)
@@ -280,7 +280,12 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
280280
Task<Socket> connectTask;
281281

282282
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(hostName);
283-
serverAddrTask.Wait(ts);
283+
bool complete = serverAddrTask.Wait(ts);
284+
285+
// DNS timed out - don't block
286+
if (!complete)
287+
return null;
288+
284289
IPAddress[] serverAddresses = serverAddrTask.Result;
285290

286291
if (serverAddresses.Length > MaxParallelIpAddresses)
@@ -324,7 +329,6 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
324329

325330
availableSocket = connectTask.Result;
326331
return availableSocket;
327-
328332
}
329333

330334
// Connect to server with hostName and port.
@@ -334,7 +338,14 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
334338
{
335339
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference));
336340

337-
IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName);
341+
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(serverName);
342+
bool complete = serverAddrTask.Wait(timeout);
343+
344+
// DNS timed out - don't block
345+
if (!complete)
346+
return null;
347+
348+
IPAddress[] ipAddresses = serverAddrTask.Result;
338349

339350
string IPv4String = null;
340351
string IPv6String = null;

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs

+180-16
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
67
using System.Diagnostics;
8+
using System.Linq;
79
using System.Net;
810
using System.Net.Sockets;
911
using System.Text;
12+
using System.Threading;
1013
using System.Threading.Tasks;
1114

1215
namespace Microsoft.Data.SqlClient.SNI
@@ -21,8 +24,11 @@ internal class SSRP
2124
/// </summary>
2225
/// <param name="browserHostName">SQL Sever Browser hostname</param>
2326
/// <param name="instanceName">instance name to find port number</param>
27+
/// <param name="timerExpire">Connection timer expiration</param>
28+
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
29+
/// <param name="ipPreference">IP address preference</param>
2430
/// <returns>port number for given instance name</returns>
25-
internal static int GetPortByInstanceName(string browserHostName, string instanceName)
31+
internal static int GetPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference)
2632
{
2733
Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace");
2834
Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace");
@@ -32,7 +38,7 @@ internal static int GetPortByInstanceName(string browserHostName, string instanc
3238
byte[] responsePacket = null;
3339
try
3440
{
35-
responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest);
41+
responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timerExpire, allIPsInParallel, ipPreference);
3642
}
3743
catch (SocketException se)
3844
{
@@ -87,14 +93,17 @@ private static byte[] CreateInstanceInfoRequest(string instanceName)
8793
/// </summary>
8894
/// <param name="browserHostName">SQL Sever Browser hostname</param>
8995
/// <param name="instanceName">instance name to lookup DAC port</param>
96+
/// <param name="timerExpire">Connection timer expiration</param>
97+
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
98+
/// <param name="ipPreference">IP address preference</param>
9099
/// <returns>DAC port for given instance name</returns>
91-
internal static int GetDacPortByInstanceName(string browserHostName, string instanceName)
100+
internal static int GetDacPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference)
92101
{
93102
Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace");
94103
Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace");
95104

96105
byte[] dacPortInfoRequest = CreateDacPortInfoRequest(instanceName);
97-
byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest);
106+
byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timerExpire, allIPsInParallel, ipPreference);
98107

99108
const byte SvrResp = 0x05;
100109
const byte ProtocolVersion = 0x01;
@@ -131,43 +140,198 @@ private static byte[] CreateDacPortInfoRequest(string instanceName)
131140
return requestPacket;
132141
}
133142

143+
private class SsrpResult
144+
{
145+
public byte[] ResponsePacket;
146+
public Exception Error;
147+
}
148+
134149
/// <summary>
135150
/// Sends request to server, and receives response from server by UDP.
136151
/// </summary>
137152
/// <param name="browserHostname">UDP server hostname</param>
138153
/// <param name="port">UDP server port</param>
139154
/// <param name="requestPacket">request packet</param>
155+
/// <param name="timerExpire">Connection timer expiration</param>
156+
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
157+
/// <param name="ipPreference">IP address preference</param>
140158
/// <returns>response packet from UDP server</returns>
141-
private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket)
159+
private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference)
142160
{
143161
using (TrySNIEventScope.Create(nameof(SSRP)))
144162
{
145163
Debug.Assert(!string.IsNullOrWhiteSpace(browserHostname), "browserhostname should not be null, empty, or whitespace");
146164
Debug.Assert(port >= 0 && port <= 65535, "Invalid port");
147165
Debug.Assert(requestPacket != null && requestPacket.Length > 0, "requestPacket should not be null or 0-length array");
148166

149-
const int sendTimeOutMs = 1000;
150-
const int receiveTimeOutMs = 1000;
167+
bool isIpAddress = IPAddress.TryParse(browserHostname, out IPAddress address);
151168

152-
IPAddress address = null;
153-
bool isIpAddress = IPAddress.TryParse(browserHostname, out address);
169+
TimeSpan ts = default;
170+
// In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count
171+
// The infinite Timeout is a function of ConnectionString Timeout=0
172+
if (long.MaxValue != timerExpire)
173+
{
174+
ts = DateTime.FromFileTime(timerExpire) - DateTime.Now;
175+
ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts;
176+
}
154177

155-
byte[] responsePacket = null;
156-
using (UdpClient client = new UdpClient(!isIpAddress ? AddressFamily.InterNetwork : address.AddressFamily))
178+
IPAddress[] ipAddresses = null;
179+
if (!isIpAddress)
180+
{
181+
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(browserHostname);
182+
bool taskComplete;
183+
try
184+
{
185+
taskComplete = serverAddrTask.Wait(ts);
186+
}
187+
catch (AggregateException ae)
188+
{
189+
throw ae.InnerException;
190+
}
191+
192+
// If DNS took too long, need to return instead of blocking
193+
if (!taskComplete)
194+
return null;
195+
196+
ipAddresses = serverAddrTask.Result;
197+
}
198+
199+
Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve");
200+
201+
switch (ipPreference)
157202
{
158-
Task<int> sendTask = client.SendAsync(requestPacket, requestPacket.Length, browserHostname, port);
203+
case SqlConnectionIPAddressPreference.IPv4First:
204+
{
205+
SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel);
206+
if (response4 != null && response4.ResponsePacket != null)
207+
return response4.ResponsePacket;
208+
209+
SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel);
210+
if (response6 != null && response6.ResponsePacket != null)
211+
return response6.ResponsePacket;
212+
213+
// No responses so throw first error
214+
if (response4 != null && response4.Error != null)
215+
throw response4.Error;
216+
else if (response6 != null && response6.Error != null)
217+
throw response6.Error;
218+
219+
break;
220+
}
221+
case SqlConnectionIPAddressPreference.IPv6First:
222+
{
223+
SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel);
224+
if (response6 != null && response6.ResponsePacket != null)
225+
return response6.ResponsePacket;
226+
227+
SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel);
228+
if (response4 != null && response4.ResponsePacket != null)
229+
return response4.ResponsePacket;
230+
231+
// No responses so throw first error
232+
if (response6 != null && response6.Error != null)
233+
throw response6.Error;
234+
else if (response4 != null && response4.Error != null)
235+
throw response4.Error;
236+
237+
break;
238+
}
239+
default:
240+
{
241+
SsrpResult response = SendUDPRequest(ipAddresses, port, requestPacket, true); // allIPsInParallel);
242+
if (response != null && response.ResponsePacket != null)
243+
return response.ResponsePacket;
244+
else if (response != null && response.Error != null)
245+
throw response.Error;
246+
247+
break;
248+
}
249+
}
250+
251+
return null;
252+
}
253+
}
254+
255+
/// <summary>
256+
/// Sends request to server, and receives response from server by UDP.
257+
/// </summary>
258+
/// <param name="ipAddresses">IP Addresses</param>
259+
/// <param name="port">UDP server port</param>
260+
/// <param name="requestPacket">request packet</param>
261+
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
262+
/// <returns>response packet from UDP server</returns>
263+
private static SsrpResult SendUDPRequest(IPAddress[] ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel)
264+
{
265+
if (ipAddresses.Length == 0)
266+
return null;
267+
268+
if (allIPsInParallel) // Used for MultiSubnetFailover
269+
{
270+
List<Task<SsrpResult>> tasks = new(ipAddresses.Length);
271+
CancellationTokenSource cts = new CancellationTokenSource();
272+
for (int i = 0; i < ipAddresses.Length; i++)
273+
{
274+
IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port);
275+
tasks.Add(Task.Factory.StartNew<SsrpResult>(() => SendUDPRequest(endPoint, requestPacket)));
276+
}
277+
278+
List<Task<SsrpResult>> completedTasks = new();
279+
while (tasks.Count > 0)
280+
{
281+
int first = Task.WaitAny(tasks.ToArray());
282+
if (tasks[first].Result.ResponsePacket != null)
283+
{
284+
cts.Cancel();
285+
return tasks[first].Result;
286+
}
287+
else
288+
{
289+
completedTasks.Add(tasks[first]);
290+
tasks.Remove(tasks[first]);
291+
}
292+
}
293+
294+
Debug.Assert(completedTasks.Count > 0, "completedTasks should never be 0");
295+
296+
// All tasks failed. Return the error from the first failure.
297+
return completedTasks[0].Result;
298+
}
299+
else
300+
{
301+
// If not parallel, use the first IP address provided
302+
IPEndPoint endPoint = new IPEndPoint(ipAddresses[0], port);
303+
return SendUDPRequest(endPoint, requestPacket);
304+
}
305+
}
306+
307+
private static SsrpResult SendUDPRequest(IPEndPoint endPoint, byte[] requestPacket)
308+
{
309+
const int sendTimeOutMs = 1000;
310+
const int receiveTimeOutMs = 1000;
311+
312+
SsrpResult result = new();
313+
314+
try
315+
{
316+
using (UdpClient client = new UdpClient(endPoint.AddressFamily))
317+
{
318+
Task<int> sendTask = client.SendAsync(requestPacket, requestPacket.Length, endPoint);
159319
Task<UdpReceiveResult> receiveTask = null;
160-
320+
161321
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Waiting for UDP Client to fetch Port info.");
162322
if (sendTask.Wait(sendTimeOutMs) && (receiveTask = client.ReceiveAsync()).Wait(receiveTimeOutMs))
163323
{
164324
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Received Port info from UDP Client.");
165-
responsePacket = receiveTask.Result.Buffer;
325+
result.ResponsePacket = receiveTask.Result.Buffer;
166326
}
167327
}
168-
169-
return responsePacket;
170328
}
329+
catch (Exception e)
330+
{
331+
result.Error = e;
332+
}
333+
334+
return result;
171335
}
172336
}
173337
}

0 commit comments

Comments
 (0)