3
3
// See the LICENSE file in the project root for more information.
4
4
5
5
using System ;
6
+ using System . Collections . Generic ;
6
7
using System . Diagnostics ;
8
+ using System . Linq ;
7
9
using System . Net ;
8
10
using System . Net . Sockets ;
9
11
using System . Text ;
12
+ using System . Threading ;
10
13
using System . Threading . Tasks ;
11
14
12
15
namespace Microsoft . Data . SqlClient . SNI
@@ -21,8 +24,11 @@ internal class SSRP
21
24
/// </summary>
22
25
/// <param name="browserHostName">SQL Sever Browser hostname</param>
23
26
/// <param name="instanceName">instance name to find port number</param>
27
+ /// <param name="timerExpire">Connection timer expiration</param>
28
+ /// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
29
+ /// <param name="ipPreference">IP address preference</param>
24
30
/// <returns>port number for given instance name</returns>
25
- internal static int GetPortByInstanceName ( string browserHostName , string instanceName )
31
+ internal static int GetPortByInstanceName ( string browserHostName , string instanceName , long timerExpire , bool allIPsInParallel , SqlConnectionIPAddressPreference ipPreference )
26
32
{
27
33
Debug . Assert ( ! string . IsNullOrWhiteSpace ( browserHostName ) , "browserHostName should not be null, empty, or whitespace" ) ;
28
34
Debug . Assert ( ! string . IsNullOrWhiteSpace ( instanceName ) , "instanceName should not be null, empty, or whitespace" ) ;
@@ -32,7 +38,7 @@ internal static int GetPortByInstanceName(string browserHostName, string instanc
32
38
byte [ ] responsePacket = null ;
33
39
try
34
40
{
35
- responsePacket = SendUDPRequest ( browserHostName , SqlServerBrowserPort , instanceInfoRequest ) ;
41
+ responsePacket = SendUDPRequest ( browserHostName , SqlServerBrowserPort , instanceInfoRequest , timerExpire , allIPsInParallel , ipPreference ) ;
36
42
}
37
43
catch ( SocketException se )
38
44
{
@@ -87,14 +93,17 @@ private static byte[] CreateInstanceInfoRequest(string instanceName)
87
93
/// </summary>
88
94
/// <param name="browserHostName">SQL Sever Browser hostname</param>
89
95
/// <param name="instanceName">instance name to lookup DAC port</param>
96
+ /// <param name="timerExpire">Connection timer expiration</param>
97
+ /// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
98
+ /// <param name="ipPreference">IP address preference</param>
90
99
/// <returns>DAC port for given instance name</returns>
91
- internal static int GetDacPortByInstanceName ( string browserHostName , string instanceName )
100
+ internal static int GetDacPortByInstanceName ( string browserHostName , string instanceName , long timerExpire , bool allIPsInParallel , SqlConnectionIPAddressPreference ipPreference )
92
101
{
93
102
Debug . Assert ( ! string . IsNullOrWhiteSpace ( browserHostName ) , "browserHostName should not be null, empty, or whitespace" ) ;
94
103
Debug . Assert ( ! string . IsNullOrWhiteSpace ( instanceName ) , "instanceName should not be null, empty, or whitespace" ) ;
95
104
96
105
byte [ ] dacPortInfoRequest = CreateDacPortInfoRequest ( instanceName ) ;
97
- byte [ ] responsePacket = SendUDPRequest ( browserHostName , SqlServerBrowserPort , dacPortInfoRequest ) ;
106
+ byte [ ] responsePacket = SendUDPRequest ( browserHostName , SqlServerBrowserPort , dacPortInfoRequest , timerExpire , allIPsInParallel , ipPreference ) ;
98
107
99
108
const byte SvrResp = 0x05 ;
100
109
const byte ProtocolVersion = 0x01 ;
@@ -131,43 +140,198 @@ private static byte[] CreateDacPortInfoRequest(string instanceName)
131
140
return requestPacket ;
132
141
}
133
142
143
+ private class SsrpResult
144
+ {
145
+ public byte [ ] ResponsePacket ;
146
+ public Exception Error ;
147
+ }
148
+
134
149
/// <summary>
135
150
/// Sends request to server, and receives response from server by UDP.
136
151
/// </summary>
137
152
/// <param name="browserHostname">UDP server hostname</param>
138
153
/// <param name="port">UDP server port</param>
139
154
/// <param name="requestPacket">request packet</param>
155
+ /// <param name="timerExpire">Connection timer expiration</param>
156
+ /// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
157
+ /// <param name="ipPreference">IP address preference</param>
140
158
/// <returns>response packet from UDP server</returns>
141
- private static byte [ ] SendUDPRequest ( string browserHostname , int port , byte [ ] requestPacket )
159
+ private static byte [ ] SendUDPRequest ( string browserHostname , int port , byte [ ] requestPacket , long timerExpire , bool allIPsInParallel , SqlConnectionIPAddressPreference ipPreference )
142
160
{
143
161
using ( TrySNIEventScope . Create ( nameof ( SSRP ) ) )
144
162
{
145
163
Debug . Assert ( ! string . IsNullOrWhiteSpace ( browserHostname ) , "browserhostname should not be null, empty, or whitespace" ) ;
146
164
Debug . Assert ( port >= 0 && port <= 65535 , "Invalid port" ) ;
147
165
Debug . Assert ( requestPacket != null && requestPacket . Length > 0 , "requestPacket should not be null or 0-length array" ) ;
148
166
149
- const int sendTimeOutMs = 1000 ;
150
- const int receiveTimeOutMs = 1000 ;
167
+ bool isIpAddress = IPAddress . TryParse ( browserHostname , out IPAddress address ) ;
151
168
152
- IPAddress address = null ;
153
- bool isIpAddress = IPAddress . TryParse ( browserHostname , out address ) ;
169
+ TimeSpan ts = default ;
170
+ // In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count
171
+ // The infinite Timeout is a function of ConnectionString Timeout=0
172
+ if ( long . MaxValue != timerExpire )
173
+ {
174
+ ts = DateTime . FromFileTime ( timerExpire ) - DateTime . Now ;
175
+ ts = ts . Ticks < 0 ? TimeSpan . FromTicks ( 0 ) : ts ;
176
+ }
154
177
155
- byte [ ] responsePacket = null ;
156
- using ( UdpClient client = new UdpClient ( ! isIpAddress ? AddressFamily . InterNetwork : address . AddressFamily ) )
178
+ IPAddress [ ] ipAddresses = null ;
179
+ if ( ! isIpAddress )
180
+ {
181
+ Task < IPAddress [ ] > serverAddrTask = Dns . GetHostAddressesAsync ( browserHostname ) ;
182
+ bool taskComplete ;
183
+ try
184
+ {
185
+ taskComplete = serverAddrTask . Wait ( ts ) ;
186
+ }
187
+ catch ( AggregateException ae )
188
+ {
189
+ throw ae . InnerException ;
190
+ }
191
+
192
+ // If DNS took too long, need to return instead of blocking
193
+ if ( ! taskComplete )
194
+ return null ;
195
+
196
+ ipAddresses = serverAddrTask . Result ;
197
+ }
198
+
199
+ Debug . Assert ( ipAddresses . Length > 0 , "DNS should throw if zero addresses resolve" ) ;
200
+
201
+ switch ( ipPreference )
157
202
{
158
- Task < int > sendTask = client . SendAsync ( requestPacket , requestPacket . Length , browserHostname , port ) ;
203
+ case SqlConnectionIPAddressPreference . IPv4First :
204
+ {
205
+ SsrpResult response4 = SendUDPRequest ( ipAddresses . Where ( i => i . AddressFamily == AddressFamily . InterNetwork ) . ToArray ( ) , port , requestPacket , allIPsInParallel ) ;
206
+ if ( response4 != null && response4 . ResponsePacket != null )
207
+ return response4 . ResponsePacket ;
208
+
209
+ SsrpResult response6 = SendUDPRequest ( ipAddresses . Where ( i => i . AddressFamily == AddressFamily . InterNetworkV6 ) . ToArray ( ) , port , requestPacket , allIPsInParallel ) ;
210
+ if ( response6 != null && response6 . ResponsePacket != null )
211
+ return response6 . ResponsePacket ;
212
+
213
+ // No responses so throw first error
214
+ if ( response4 != null && response4 . Error != null )
215
+ throw response4 . Error ;
216
+ else if ( response6 != null && response6 . Error != null )
217
+ throw response6 . Error ;
218
+
219
+ break ;
220
+ }
221
+ case SqlConnectionIPAddressPreference . IPv6First :
222
+ {
223
+ SsrpResult response6 = SendUDPRequest ( ipAddresses . Where ( i => i . AddressFamily == AddressFamily . InterNetworkV6 ) . ToArray ( ) , port , requestPacket , allIPsInParallel ) ;
224
+ if ( response6 != null && response6 . ResponsePacket != null )
225
+ return response6 . ResponsePacket ;
226
+
227
+ SsrpResult response4 = SendUDPRequest ( ipAddresses . Where ( i => i . AddressFamily == AddressFamily . InterNetwork ) . ToArray ( ) , port , requestPacket , allIPsInParallel ) ;
228
+ if ( response4 != null && response4 . ResponsePacket != null )
229
+ return response4 . ResponsePacket ;
230
+
231
+ // No responses so throw first error
232
+ if ( response6 != null && response6 . Error != null )
233
+ throw response6 . Error ;
234
+ else if ( response4 != null && response4 . Error != null )
235
+ throw response4 . Error ;
236
+
237
+ break ;
238
+ }
239
+ default :
240
+ {
241
+ SsrpResult response = SendUDPRequest ( ipAddresses , port , requestPacket , true ) ; // allIPsInParallel);
242
+ if ( response != null && response . ResponsePacket != null )
243
+ return response . ResponsePacket ;
244
+ else if ( response != null && response . Error != null )
245
+ throw response . Error ;
246
+
247
+ break ;
248
+ }
249
+ }
250
+
251
+ return null ;
252
+ }
253
+ }
254
+
255
+ /// <summary>
256
+ /// Sends request to server, and receives response from server by UDP.
257
+ /// </summary>
258
+ /// <param name="ipAddresses">IP Addresses</param>
259
+ /// <param name="port">UDP server port</param>
260
+ /// <param name="requestPacket">request packet</param>
261
+ /// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
262
+ /// <returns>response packet from UDP server</returns>
263
+ private static SsrpResult SendUDPRequest ( IPAddress [ ] ipAddresses , int port , byte [ ] requestPacket , bool allIPsInParallel )
264
+ {
265
+ if ( ipAddresses . Length == 0 )
266
+ return null ;
267
+
268
+ if ( allIPsInParallel ) // Used for MultiSubnetFailover
269
+ {
270
+ List < Task < SsrpResult > > tasks = new ( ipAddresses . Length ) ;
271
+ CancellationTokenSource cts = new CancellationTokenSource ( ) ;
272
+ for ( int i = 0 ; i < ipAddresses . Length ; i ++ )
273
+ {
274
+ IPEndPoint endPoint = new IPEndPoint ( ipAddresses [ i ] , port ) ;
275
+ tasks . Add ( Task . Factory . StartNew < SsrpResult > ( ( ) => SendUDPRequest ( endPoint , requestPacket ) ) ) ;
276
+ }
277
+
278
+ List < Task < SsrpResult > > completedTasks = new ( ) ;
279
+ while ( tasks . Count > 0 )
280
+ {
281
+ int first = Task . WaitAny ( tasks . ToArray ( ) ) ;
282
+ if ( tasks [ first ] . Result . ResponsePacket != null )
283
+ {
284
+ cts . Cancel ( ) ;
285
+ return tasks [ first ] . Result ;
286
+ }
287
+ else
288
+ {
289
+ completedTasks . Add ( tasks [ first ] ) ;
290
+ tasks . Remove ( tasks [ first ] ) ;
291
+ }
292
+ }
293
+
294
+ Debug . Assert ( completedTasks . Count > 0 , "completedTasks should never be 0" ) ;
295
+
296
+ // All tasks failed. Return the error from the first failure.
297
+ return completedTasks [ 0 ] . Result ;
298
+ }
299
+ else
300
+ {
301
+ // If not parallel, use the first IP address provided
302
+ IPEndPoint endPoint = new IPEndPoint ( ipAddresses [ 0 ] , port ) ;
303
+ return SendUDPRequest ( endPoint , requestPacket ) ;
304
+ }
305
+ }
306
+
307
+ private static SsrpResult SendUDPRequest ( IPEndPoint endPoint , byte [ ] requestPacket )
308
+ {
309
+ const int sendTimeOutMs = 1000 ;
310
+ const int receiveTimeOutMs = 1000 ;
311
+
312
+ SsrpResult result = new ( ) ;
313
+
314
+ try
315
+ {
316
+ using ( UdpClient client = new UdpClient ( endPoint . AddressFamily ) )
317
+ {
318
+ Task < int > sendTask = client . SendAsync ( requestPacket , requestPacket . Length , endPoint ) ;
159
319
Task < UdpReceiveResult > receiveTask = null ;
160
-
320
+
161
321
SqlClientEventSource . Log . TrySNITraceEvent ( nameof ( SSRP ) , EventType . INFO , "Waiting for UDP Client to fetch Port info." ) ;
162
322
if ( sendTask . Wait ( sendTimeOutMs ) && ( receiveTask = client . ReceiveAsync ( ) ) . Wait ( receiveTimeOutMs ) )
163
323
{
164
324
SqlClientEventSource . Log . TrySNITraceEvent ( nameof ( SSRP ) , EventType . INFO , "Received Port info from UDP Client." ) ;
165
- responsePacket = receiveTask . Result . Buffer ;
325
+ result . ResponsePacket = receiveTask . Result . Buffer ;
166
326
}
167
327
}
168
-
169
- return responsePacket ;
170
328
}
329
+ catch ( Exception e )
330
+ {
331
+ result . Error = e ;
332
+ }
333
+
334
+ return result ;
171
335
}
172
336
}
173
337
}
0 commit comments