Skip to content

Commit 9450fde

Browse files
authored
Use string for SNI instead of byte[] (#2790)
1 parent b875e1d commit 9450fde

18 files changed

+285
-158
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj

+3
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,9 @@
515515
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
516516
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
517517
</Compile>
518+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPools.cs">
519+
<Link>Microsoft\Data\SqlClient\SqlObjectPools.cs</Link>
520+
</Compile>
518521
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlParameter.cs">
519522
<Link>Microsoft\Data\SqlClient\SqlParameter.cs</Link>
520523
</Compile>

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs

+9-8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Buffers;
7+
using System.Collections.Generic;
78
using System.Diagnostics;
89
using System.IO;
910
using System.Net;
@@ -33,7 +34,7 @@ internal class SNIProxy
3334
/// <param name="fullServerName">Full server name from connection string</param>
3435
/// <param name="timeout">Timer expiration</param>
3536
/// <param name="instanceName">Instance name</param>
36-
/// <param name="spnBuffer">SPN</param>
37+
/// <param name="spns">SPNs</param>
3738
/// <param name="serverSPN">pre-defined SPN</param>
3839
/// <param name="flushCache">Flush packet cache</param>
3940
/// <param name="async">Asynchronous connection</param>
@@ -50,7 +51,7 @@ internal static SNIHandle CreateConnectionHandle(
5051
string fullServerName,
5152
TimeoutTimer timeout,
5253
out byte[] instanceName,
53-
ref byte[][] spnBuffer,
54+
ref string[] spns,
5455
string serverSPN,
5556
bool flushCache,
5657
bool async,
@@ -102,7 +103,7 @@ internal static SNIHandle CreateConnectionHandle(
102103
{
103104
try
104105
{
105-
spnBuffer = GetSqlServerSPNs(details, serverSPN);
106+
spns = GetSqlServerSPNs(details, serverSPN);
106107
}
107108
catch (Exception e)
108109
{
@@ -114,12 +115,12 @@ internal static SNIHandle CreateConnectionHandle(
114115
return sniHandle;
115116
}
116117

117-
private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
118+
private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
118119
{
119120
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
120121
if (!string.IsNullOrWhiteSpace(serverSPN))
121122
{
122-
return new byte[1][] { Encoding.Unicode.GetBytes(serverSPN) };
123+
return new[] { serverSPN };
123124
}
124125

125126
string hostName = dataSource.ServerName;
@@ -137,7 +138,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
137138
return GetSqlServerSPNs(hostName, postfix, dataSource.ResolvedProtocol);
138139
}
139140

140-
private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
141+
private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
141142
{
142143
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
143144
IPHostEntry hostEntry = null;
@@ -168,12 +169,12 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
168169
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
169170
// Set both SPNs with and without Port as Port is optional for default instance
170171
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort);
171-
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn), Encoding.Unicode.GetBytes(serverSpnWithDefaultPort) };
172+
return new[] { serverSpn, serverSpnWithDefaultPort };
172173
}
173174
// else Named Pipes do not need to valid port
174175

175176
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn);
176-
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn) };
177+
return new[] { serverSpn };
177178
}
178179

179180
/// <summary>

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs

+6-7
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,7 @@ internal sealed partial class TdsParser
113113

114114
private bool _is2022 = false;
115115

116-
private byte[][] _sniSpnBuffer = null;
117-
// UNDONE - need to have some for both instances - both command and default???
116+
private string[] _serverSpn = null;
118117

119118
// SqlStatistics
120119
private SqlStatistics _statistics = null;
@@ -391,7 +390,7 @@ internal void Connect(
391390
}
392391
else
393392
{
394-
_sniSpnBuffer = null;
393+
_serverSpn = null;
395394
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler.ObjectID,
396395
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
397396
}
@@ -403,7 +402,7 @@ internal void Connect(
403402
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
404403
}
405404

406-
_sniSpnBuffer = null;
405+
_serverSpn = null;
407406
_authenticationProvider = null;
408407

409408
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
@@ -442,7 +441,7 @@ internal void Connect(
442441
serverInfo.ExtendedServerName,
443442
timeout,
444443
out instanceName,
445-
ref _sniSpnBuffer,
444+
ref _serverSpn,
446445
false,
447446
true,
448447
fParallel,
@@ -541,7 +540,7 @@ internal void Connect(
541540
_physicalStateObj.CreatePhysicalSNIHandle(
542541
serverInfo.ExtendedServerName,
543542
timeout, out instanceName,
544-
ref _sniSpnBuffer,
543+
ref _serverSpn,
545544
true,
546545
true,
547546
fParallel,
@@ -13318,7 +13317,7 @@ internal string TraceString()
1331813317
_fMARS ? bool.TrueString : bool.FalseString,
1331913318
_sessionPool == null ? "(null)" : _sessionPool.TraceString(),
1332013319
_is2005 ? bool.TrueString : bool.FalseString,
13321-
_sniSpnBuffer == null ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null),
13320+
_serverSpn == null ? "(null)" : _serverSpn.Length.ToString((IFormatProvider)null),
1332213321
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
1332313322
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
1332413323
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ internal abstract void CreatePhysicalSNIHandle(
186186
string serverName,
187187
TimeoutTimer timeout,
188188
out byte[] instanceName,
189-
ref byte[][] spnBuffer,
189+
ref string[] spns,
190190
bool flushCache,
191191
bool async,
192192
bool fParallel,

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle(
8181
string serverName,
8282
TimeoutTimer timeout,
8383
out byte[] instanceName,
84-
ref byte[][] spnBuffer,
84+
ref string[] spns,
8585
bool flushCache,
8686
bool async,
8787
bool parallel,
@@ -94,7 +94,7 @@ internal override void CreatePhysicalSNIHandle(
9494
string hostNameInCertificate,
9595
string serverCertificateFilename)
9696
{
97-
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN,
97+
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spns, serverSPN,
9898
flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst,
9999
hostNameInCertificate, serverCertificateFilename);
100100

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs

+5-8
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ internal override void CreatePhysicalSNIHandle(
144144
string serverName,
145145
TimeoutTimer timeout,
146146
out byte[] instanceName,
147-
ref byte[][] spnBuffer,
147+
ref string[] spns,
148148
bool flushCache,
149149
bool async,
150150
bool fParallel,
@@ -157,31 +157,28 @@ internal override void CreatePhysicalSNIHandle(
157157
string hostNameInCertificate,
158158
string serverCertificateFilename)
159159
{
160-
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
161-
spnBuffer = new byte[1][];
162160
if (isIntegratedSecurity)
163161
{
164162
// now allocate proper length of buffer
165163
if (!string.IsNullOrEmpty(serverSPN))
166164
{
167165
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
168-
byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN);
169-
Trace.Assert(srvSPN.Length <= SniNativeWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
170-
spnBuffer[0] = srvSPN;
171166
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
172167
}
173168
else
174169
{
175-
spnBuffer[0] = new byte[SniNativeWrapper.SniMaxComposedSpnLength];
170+
// This will signal to the interop layer that we need to retrieve the SPN
171+
serverSPN = string.Empty;
176172
}
177173
}
178174

179175
ConsumerInfo myInfo = CreateConsumerInfo(async);
180176
SQLDNSInfo cachedDNSInfo;
181177
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);
182178

183-
_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName,
179+
_sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName,
184180
flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate);
181+
spns = new[] { serverSPN.TrimEnd() };
185182
}
186183

187184
protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)

src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj

+8-4
Original file line numberDiff line numberDiff line change
@@ -337,10 +337,13 @@
337337
<Compile Include="$(CommonSourceRoot)Microsoft\Data\Sql\SqlDataSourceEnumerator.cs">
338338
<Link>Microsoft\Data\Sql\SqlDataSourceEnumerator.cs</Link>
339339
</Compile>
340-
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
341-
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
342-
</Compile>
343-
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AAsyncCallContext.cs">
340+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
341+
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
342+
</Compile>
343+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPools.cs">
344+
<Link>Microsoft\Data\SqlClient\SqlObjectPools.cs</Link>
345+
</Compile>
346+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AAsyncCallContext.cs">
344347
<Link>Microsoft\Data\SqlClient\AAsyncCallContext.cs</Link>
345348
</Compile>
346349
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs">
@@ -854,6 +857,7 @@
854857
<Compile Include="Microsoft\Data\Common\DbConnectionString.cs" />
855858
<Compile Include="Microsoft\Data\Common\GreenMethods.cs" />
856859
<Compile Include="Microsoft\Data\SqlClient\assemblycache.cs" />
860+
<Compile Include="Microsoft\Data\SqlClient\BufferWriterExtensions.cs" />
857861
<Compile Include="Microsoft\Data\SqlClient\Reliability\SqlConfigurableRetryLogicManager.LoadType.cs" />
858862
<Compile Include="Microsoft\Data\SqlClient\Server\SmiConnection.cs" />
859863
<Compile Include="Microsoft\Data\SqlClient\Server\SmiContext.cs" />
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
using System.Buffers;
3+
using System.Text;
4+
5+
namespace Microsoft.Data.SqlClient
6+
{
7+
internal static class BufferWriterExtensions
8+
{
9+
internal static long GetBytes(this Encoding encoding, string str, IBufferWriter<byte> bufferWriter)
10+
{
11+
var count = encoding.GetByteCount(str);
12+
var array = ArrayPool<byte>.Shared.Rent(count);
13+
14+
try
15+
{
16+
var length = encoding.GetBytes(str, 0, str.Length, array, 0);
17+
bufferWriter.Write(array.AsSpan(0, length));
18+
return length;
19+
}
20+
finally
21+
{
22+
ArrayPool<byte>.Shared.Return(array);
23+
}
24+
}
25+
}
26+
}

0 commit comments

Comments
 (0)