Skip to content

Commit 3bb378e

Browse files
David-EngelDavoudEshtehari
authored andcommitted
[Fix] Hang on infinite timeout and managed SNI (dotnet#1742)
* Fix for issue dotnet#1733 - hang on infinite timeout Also fix SSRP when DataSource is an IP address
1 parent a06894f commit 3bb378e

File tree

5 files changed

+53
-41
lines changed

5 files changed

+53
-41
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs

+10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Net;
67
using System.Net.Security;
78
using System.Security.Cryptography.X509Certificates;
89

@@ -194,6 +195,15 @@ internal static bool ValidateSslServerCertificate(string targetServerName, X509C
194195
}
195196
}
196197

198+
internal static IPAddress[] GetDnsIpAddresses(string serverName)
199+
{
200+
using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses)))
201+
{
202+
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0}.", args0: serverName);
203+
return Dns.GetHostAddresses(serverName);
204+
}
205+
}
206+
197207
/// <summary>
198208
/// Sets last error encountered for SNI
199209
/// </summary>

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs

+2-16
Original file line numberDiff line numberDiff line change
@@ -279,14 +279,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
279279
Socket availableSocket = null;
280280
Task<Socket> connectTask;
281281

282-
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(hostName);
283-
bool complete = serverAddrTask.Wait(ts);
284-
285-
// DNS timed out - don't block
286-
if (!complete)
287-
return null;
288-
289-
IPAddress[] serverAddresses = serverAddrTask.Result;
282+
IPAddress[] serverAddresses = SNICommon.GetDnsIpAddresses(hostName);
290283

291284
if (serverAddresses.Length > MaxParallelIpAddresses)
292285
{
@@ -338,14 +331,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
338331
{
339332
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference));
340333

341-
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(serverName);
342-
bool complete = serverAddrTask.Wait(timeout);
343-
344-
// DNS timed out - don't block
345-
if (!complete)
346-
return null;
347-
348-
IPAddress[] ipAddresses = serverAddrTask.Result;
334+
IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(serverName);
349335

350336
string IPv4String = null;
351337
string IPv6String = null;

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs

+12-23
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,16 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re
164164
Debug.Assert(port >= 0 && port <= 65535, "Invalid port");
165165
Debug.Assert(requestPacket != null && requestPacket.Length > 0, "requestPacket should not be null or 0-length array");
166166

167-
bool isIpAddress = IPAddress.TryParse(browserHostname, out IPAddress address);
167+
if (IPAddress.TryParse(browserHostname, out IPAddress address))
168+
{
169+
SsrpResult response = SendUDPRequest(new IPAddress[] { address }, port, requestPacket, allIPsInParallel);
170+
if (response != null && response.ResponsePacket != null)
171+
return response.ResponsePacket;
172+
else if (response != null && response.Error != null)
173+
throw response.Error;
174+
else
175+
return null;
176+
}
168177

169178
TimeSpan ts = default;
170179
// In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count
@@ -175,27 +184,7 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re
175184
ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts;
176185
}
177186

178-
IPAddress[] ipAddresses = null;
179-
if (!isIpAddress)
180-
{
181-
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(browserHostname);
182-
bool taskComplete;
183-
try
184-
{
185-
taskComplete = serverAddrTask.Wait(ts);
186-
}
187-
catch (AggregateException ae)
188-
{
189-
throw ae.InnerException;
190-
}
191-
192-
// If DNS took too long, need to return instead of blocking
193-
if (!taskComplete)
194-
return null;
195-
196-
ipAddresses = serverAddrTask.Result;
197-
}
198-
187+
IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(browserHostname);
199188
Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve");
200189

201190
switch (ipPreference)
@@ -272,7 +261,7 @@ private static SsrpResult SendUDPRequest(IPAddress[] ipAddresses, int port, byte
272261
for (int i = 0; i < ipAddresses.Length; i++)
273262
{
274263
IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port);
275-
tasks.Add(Task.Factory.StartNew<SsrpResult>(() => SendUDPRequest(endPoint, requestPacket)));
264+
tasks.Add(Task.Factory.StartNew<SsrpResult>(() => SendUDPRequest(endPoint, requestPacket), cts.Token));
276265
}
277266

278267
List<Task<SsrpResult>> completedTasks = new();

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/ConnectivityTest.cs

+15
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,21 @@ public static void EnvironmentHostNameSPIDTest()
8585
Assert.True(false, "No non-empty hostname found for the application");
8686
}
8787

88+
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
89+
public static async void ConnectionTimeoutInfiniteTest()
90+
{
91+
// Exercise the special-case infinite connect timeout code path
92+
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString)
93+
{
94+
ConnectTimeout = 0 // Infinite
95+
};
96+
97+
using SqlConnection conn = new(builder.ConnectionString);
98+
CancellationTokenSource cts = new(30000);
99+
// Will throw TaskCanceledException and fail the test in the event of a hang
100+
await conn.OpenAsync(cts.Token);
101+
}
102+
88103
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
89104
public static void ConnectionTimeoutTestWithThread()
90105
{

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs

+14-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Net;
67
using System.Net.Sockets;
78
using System.Text;
89
using System.Threading.Tasks;
@@ -12,17 +13,17 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests
1213
{
1314
public static class InstanceNameTest
1415
{
15-
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
16+
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.AreConnStringsSetup))]
1617
public static void ConnectToSQLWithInstanceNameTest()
1718
{
1819
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);
1920

2021
bool proceed = true;
2122
string dataSourceStr = builder.DataSource.Replace("tcp:", "");
2223
string[] serverNamePartsByBackSlash = dataSourceStr.Split('\\');
24+
string hostname = serverNamePartsByBackSlash[0];
2325
if (!dataSourceStr.Contains(",") && serverNamePartsByBackSlash.Length == 2)
2426
{
25-
string hostname = serverNamePartsByBackSlash[0];
2627
proceed = !string.IsNullOrWhiteSpace(hostname) && IsBrowserAlive(hostname);
2728
}
2829

@@ -31,6 +32,17 @@ public static void ConnectToSQLWithInstanceNameTest()
3132
using SqlConnection connection = new(builder.ConnectionString);
3233
connection.Open();
3334
connection.Close();
35+
36+
if (builder.Encrypt != SqlConnectionEncryptOption.Strict)
37+
{
38+
// Exercise the IP address-specific code in SSRP
39+
IPAddress[] addresses = Dns.GetHostAddresses(hostname);
40+
builder.DataSource = builder.DataSource.Replace(hostname, addresses[0].ToString());
41+
builder.TrustServerCertificate = true;
42+
using SqlConnection connection2 = new(builder.ConnectionString);
43+
connection2.Open();
44+
connection2.Close();
45+
}
3446
}
3547
}
3648

0 commit comments

Comments
 (0)