Skip to content

Commit bbb5636

Browse files
committed
Use string for SNI instead of byte[]
1 parent 2013a71 commit bbb5636

19 files changed

+202
-92
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj

+3
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,9 @@
515515
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
516516
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
517517
</Compile>
518+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPools.cs">
519+
<Link>Microsoft\Data\SqlClient\SqlObjectPools.cs</Link>
520+
</Compile>
518521
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlParameter.cs">
519522
<Link>Microsoft\Data\SqlClient\SqlParameter.cs</Link>
520523
</Compile>

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs

+7-6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Buffers;
7+
using System.Collections.Generic;
78
using System.Diagnostics;
89
using System.IO;
910
using System.Net;
@@ -50,7 +51,7 @@ internal static SNIHandle CreateConnectionHandle(
5051
string fullServerName,
5152
TimeoutTimer timeout,
5253
out byte[] instanceName,
53-
ref byte[][] spnBuffer,
54+
ref string[] spnBuffer,
5455
string serverSPN,
5556
bool flushCache,
5657
bool async,
@@ -114,12 +115,12 @@ internal static SNIHandle CreateConnectionHandle(
114115
return sniHandle;
115116
}
116117

117-
private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
118+
private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
118119
{
119120
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
120121
if (!string.IsNullOrWhiteSpace(serverSPN))
121122
{
122-
return new byte[1][] { Encoding.Unicode.GetBytes(serverSPN) };
123+
return new[] { serverSPN };
123124
}
124125

125126
string hostName = dataSource.ServerName;
@@ -137,7 +138,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
137138
return GetSqlServerSPNs(hostName, postfix, dataSource.ResolvedProtocol);
138139
}
139140

140-
private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
141+
private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
141142
{
142143
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
143144
IPHostEntry hostEntry = null;
@@ -168,12 +169,12 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
168169
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
169170
// Set both SPNs with and without Port as Port is optional for default instance
170171
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort);
171-
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn), Encoding.Unicode.GetBytes(serverSpnWithDefaultPort) };
172+
return new[] { serverSpn, serverSpnWithDefaultPort };
172173
}
173174
// else Named Pipes do not need to valid port
174175

175176
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn);
176-
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn) };
177+
return new[] { serverSpn };
177178
}
178179

179180
/// <summary>

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs

+10-11
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ internal sealed partial class TdsParser
112112

113113
private bool _is2022 = false;
114114

115-
private byte[][] _sniSpnBuffer = null;
116-
// UNDONE - need to have some for both instances - both command and default???
115+
private string[] _sniSpn = null;
117116

118117
// SqlStatistics
119118
private SqlStatistics _statistics = null;
@@ -390,7 +389,7 @@ internal void Connect(
390389
}
391390
else
392391
{
393-
_sniSpnBuffer = null;
392+
_sniSpn = null;
394393
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler.ObjectID,
395394
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
396395
}
@@ -402,7 +401,7 @@ internal void Connect(
402401
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
403402
}
404403

405-
_sniSpnBuffer = null;
404+
_sniSpn = null;
406405
_authenticationProvider = null;
407406

408407
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
@@ -441,7 +440,7 @@ internal void Connect(
441440
serverInfo.ExtendedServerName,
442441
timeout,
443442
out instanceName,
444-
ref _sniSpnBuffer,
443+
ref _sniSpn,
445444
false,
446445
true,
447446
fParallel,
@@ -454,8 +453,6 @@ internal void Connect(
454453
hostNameInCertificate,
455454
serverCertificateFilename);
456455

457-
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
458-
459456
if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
460457
{
461458
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
@@ -470,6 +467,8 @@ internal void Connect(
470467
Debug.Fail("SNI returned status != success, but no error thrown?");
471468
}
472469

470+
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
471+
473472
_server = serverInfo.ResolvedServerName;
474473

475474
if (connHandler.PoolGroupProviderInfo != null)
@@ -540,7 +539,7 @@ internal void Connect(
540539
_physicalStateObj.CreatePhysicalSNIHandle(
541540
serverInfo.ExtendedServerName,
542541
timeout, out instanceName,
543-
ref _sniSpnBuffer,
542+
ref _sniSpn,
544543
true,
545544
true,
546545
fParallel,
@@ -553,15 +552,15 @@ internal void Connect(
553552
hostNameInCertificate,
554553
serverCertificateFilename);
555554

556-
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
557-
558555
if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
559556
{
560557
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
561558
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|ERR|SEC> Login failure");
562559
ThrowExceptionAndWarning(_physicalStateObj);
563560
}
564561

562+
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
563+
565564
uint retCode = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId);
566565

567566
Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId");
@@ -13317,7 +13316,7 @@ internal string TraceString()
1331713316
_fMARS ? bool.TrueString : bool.FalseString,
1331813317
_sessionPool == null ? "(null)" : _sessionPool.TraceString(),
1331913318
_is2005 ? bool.TrueString : bool.FalseString,
13320-
_sniSpnBuffer == null ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null),
13319+
_sniSpn == null ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null),
1332113320
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
1332213321
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
1332313322
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ internal abstract void CreatePhysicalSNIHandle(
186186
string serverName,
187187
TimeoutTimer timeout,
188188
out byte[] instanceName,
189-
ref byte[][] spnBuffer,
189+
ref string[] spn,
190190
bool flushCache,
191191
bool async,
192192
bool fParallel,

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle(
8181
string serverName,
8282
TimeoutTimer timeout,
8383
out byte[] instanceName,
84-
ref byte[][] spnBuffer,
84+
ref string[] spn,
8585
bool flushCache,
8686
bool async,
8787
bool parallel,
@@ -94,7 +94,7 @@ internal override void CreatePhysicalSNIHandle(
9494
string hostNameInCertificate,
9595
string serverCertificateFilename)
9696
{
97-
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN,
97+
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spn, serverSPN,
9898
flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst,
9999
hostNameInCertificate, serverCertificateFilename);
100100

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs

+5-8
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ internal override void CreatePhysicalSNIHandle(
144144
string serverName,
145145
TimeoutTimer timeout,
146146
out byte[] instanceName,
147-
ref byte[][] spnBuffer,
147+
ref string[] spn,
148148
bool flushCache,
149149
bool async,
150150
bool fParallel,
@@ -157,31 +157,28 @@ internal override void CreatePhysicalSNIHandle(
157157
string hostNameInCertificate,
158158
string serverCertificateFilename)
159159
{
160-
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
161-
spnBuffer = new byte[1][];
162160
if (isIntegratedSecurity)
163161
{
164162
// now allocate proper length of buffer
165163
if (!string.IsNullOrEmpty(serverSPN))
166164
{
167165
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
168-
byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN);
169-
Trace.Assert(srvSPN.Length <= SniNativeWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
170-
spnBuffer[0] = srvSPN;
171166
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
172167
}
173168
else
174169
{
175-
spnBuffer[0] = new byte[SniNativeWrapper.SniMaxComposedSpnLength];
170+
// This will signal to the interop layer that we need to retrieve the SPN
171+
serverSPN = string.Empty;
176172
}
177173
}
178174

179175
ConsumerInfo myInfo = CreateConsumerInfo(async);
180176
SQLDNSInfo cachedDNSInfo;
181177
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);
182178

183-
_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName,
179+
_sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName,
184180
flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate);
181+
spn = new[] { serverSPN.TrimEnd() };
185182
}
186183

187184
protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)

src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj

+8-4
Original file line numberDiff line numberDiff line change
@@ -337,10 +337,13 @@
337337
<Compile Include="$(CommonSourceRoot)Microsoft\Data\Sql\SqlDataSourceEnumerator.cs">
338338
<Link>Microsoft\Data\Sql\SqlDataSourceEnumerator.cs</Link>
339339
</Compile>
340-
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
341-
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
342-
</Compile>
343-
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AAsyncCallContext.cs">
340+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
341+
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
342+
</Compile>
343+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPools.cs">
344+
<Link>Microsoft\Data\SqlClient\SqlObjectPools.cs</Link>
345+
</Compile>
346+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AAsyncCallContext.cs">
344347
<Link>Microsoft\Data\SqlClient\AAsyncCallContext.cs</Link>
345348
</Compile>
346349
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs">
@@ -854,6 +857,7 @@
854857
<Compile Include="Microsoft\Data\Common\DbConnectionString.cs" />
855858
<Compile Include="Microsoft\Data\Common\GreenMethods.cs" />
856859
<Compile Include="Microsoft\Data\SqlClient\assemblycache.cs" />
860+
<Compile Include="Microsoft\Data\SqlClient\BufferWriterExtensions.cs" />
857861
<Compile Include="Microsoft\Data\SqlClient\Reliability\SqlConfigurableRetryLogicManager.LoadType.cs" />
858862
<Compile Include="Microsoft\Data\SqlClient\Server\SmiConnection.cs" />
859863
<Compile Include="Microsoft\Data\SqlClient\Server\SmiContext.cs" />
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using System.Buffers;
2+
using System.Text;
3+
4+
namespace Microsoft.Data.SqlClient
5+
{
6+
internal static class BufferWriterExtensions
7+
{
8+
internal static long GetBytes(this Encoding encoding, string str, IBufferWriter<byte> bufferWriter)
9+
{
10+
var count = encoding.GetByteCount(str);
11+
var array = ArrayPool<byte>.Shared.Rent(count);
12+
13+
try
14+
{
15+
encoding.GetBytes(str, 0, str.Length, array, 0);
16+
bufferWriter.Write(array);
17+
return count;
18+
}
19+
finally
20+
{
21+
ArrayPool<byte>.Shared.Return(array);
22+
}
23+
}
24+
}
25+
}

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs

+14-17
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ internal int ObjectID
128128

129129
private bool _is2022 = false;
130130

131-
private byte[] _sniSpnBuffer = null;
131+
private string _sniSpn = null;
132132

133133
// UNDONE - need to have some for both instances - both command and default???
134134

@@ -430,27 +430,24 @@ internal void Connect(ServerInfo serverInfo,
430430
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
431431
if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
432432
{
433-
_authenticationProvider = _physicalStateObj.CreateSSPIContextProvider();
434-
435433
if (!string.IsNullOrEmpty(serverInfo.ServerSPN))
436434
{
437-
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
438-
byte[] srvSPN = Encoding.Unicode.GetBytes(serverInfo.ServerSPN);
439-
Trace.Assert(srvSPN.Length <= SniNativeWrapper.SniMaxComposedSpnLength, "The provided SPN length exceeded the buffer size.");
440-
_sniSpnBuffer = srvSPN;
435+
_sniSpn = serverInfo.ServerSPN;
441436
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Server SPN `{0}` from the connection string is used.", serverInfo.ServerSPN);
442437
}
443438
else
444439
{
445-
// now allocate proper length of buffer
446-
_sniSpnBuffer = new byte[SniNativeWrapper.SniMaxComposedSpnLength];
440+
// Empty signifies to interop layer that SNI needs to be generated
441+
_sniSpn = string.Empty;
447442
}
443+
444+
_authenticationProvider = _physicalStateObj.CreateSSPIContextProvider();
448445
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> SSPI or Active Directory Authentication Library for SQL Server based integrated authentication");
449446
}
450447
else
451448
{
452449
_authenticationProvider = null;
453-
_sniSpnBuffer = null;
450+
_sniSpn = null;
454451

455452
switch (authType)
456453
{
@@ -529,7 +526,7 @@ internal void Connect(ServerInfo serverInfo,
529526
serverInfo.ExtendedServerName,
530527
timeout,
531528
out instanceName,
532-
_sniSpnBuffer,
529+
ref _sniSpn,
533530
false,
534531
true,
535532
fParallel,
@@ -539,8 +536,6 @@ internal void Connect(ServerInfo serverInfo,
539536
FQDNforDNSCache,
540537
hostNameInCertificate);
541538

542-
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
543-
544539
if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
545540
{
546541
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
@@ -555,6 +550,8 @@ internal void Connect(ServerInfo serverInfo,
555550
Debug.Fail("SNI returned status != success, but no error thrown?");
556551
}
557552

553+
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
554+
558555
_server = serverInfo.ResolvedServerName;
559556

560557
if (connHandler.PoolGroupProviderInfo != null)
@@ -629,7 +626,7 @@ internal void Connect(ServerInfo serverInfo,
629626
serverInfo.ExtendedServerName,
630627
timeout,
631628
out instanceName,
632-
_sniSpnBuffer,
629+
ref _sniSpn,
633630
true,
634631
true,
635632
fParallel,
@@ -639,15 +636,15 @@ internal void Connect(ServerInfo serverInfo,
639636
serverInfo.ResolvedServerName,
640637
hostNameInCertificate);
641638

642-
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
643-
644639
if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
645640
{
646641
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
647642
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|ERR|SEC> Login failure");
648643
ThrowExceptionAndWarning(_physicalStateObj);
649644
}
650645

646+
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
647+
651648
uint retCode = SniNativeWrapper.SniGetConnectionId(_physicalStateObj.Handle, ref _connHandler._clientConnectionId);
652649
Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId");
653650
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Sending prelogin handshake");
@@ -13785,7 +13782,7 @@ internal string TraceString()
1378513782
_is2000 ? bool.TrueString : bool.FalseString,
1378613783
_is2000SP1 ? bool.TrueString : bool.FalseString,
1378713784
_is2005 ? bool.TrueString : bool.FalseString,
13788-
_sniSpnBuffer == null ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null),
13785+
_sniSpn == null ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null),
1378913786
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
1379013787
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
1379113788
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),

0 commit comments

Comments
 (0)