Skip to content

Commit 89d85fb

Browse files
committed
Use string instead of byte[] for SNI
Majority of use cases end up decoding before doing anything with it, so it makes sense to just store it as a string rather than byte[]
1 parent 7e4d15f commit 89d85fb

16 files changed

+264
-117
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs

+80-21
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Buffers;
7+
using System.Diagnostics;
68
using System.Runtime.InteropServices;
79
using System.Text;
810
using Microsoft.Data.Common;
@@ -398,7 +400,7 @@ internal static unsafe uint SNIOpenSyncEx(
398400
ConsumerInfo consumerInfo,
399401
string constring,
400402
ref IntPtr pConn,
401-
byte[] spnBuffer,
403+
ref string spn,
402404
byte[] instanceName,
403405
bool fOverrideCache,
404406
bool fSync,
@@ -436,13 +438,59 @@ internal static unsafe uint SNIOpenSyncEx(
436438
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
437439
clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port;
438440

439-
if (spnBuffer != null)
441+
if (spn != null)
440442
{
441-
fixed (byte* pin_spnBuffer = &spnBuffer[0])
443+
// An empty string implies we need to find the SPN so we supply a buffer for the max size
444+
if (spn.Length == 0)
442445
{
443-
clientConsumerInfo.szSPN = pin_spnBuffer;
444-
clientConsumerInfo.cchSPN = (uint)spnBuffer.Length;
445-
return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
446+
var array = ArrayPool<byte>.Shared.Rent(SniMaxComposedSpnLength);
447+
array.AsSpan().Clear();
448+
449+
try
450+
{
451+
fixed (byte* pin_spnBuffer = array)
452+
{
453+
clientConsumerInfo.szSPN = pin_spnBuffer;
454+
clientConsumerInfo.cchSPN = (uint)SniMaxComposedSpnLength;
455+
456+
var result = SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
457+
458+
if (result == 0)
459+
{
460+
spn = Encoding.Unicode.GetString(array).TrimEnd('\0');
461+
}
462+
463+
return result;
464+
}
465+
}
466+
finally
467+
{
468+
ArrayPool<byte>.Shared.Return(array);
469+
}
470+
}
471+
472+
// We have a value of the SPN, so we marshal that and send it to the native layer
473+
else
474+
{
475+
var writer = SqlObjectPools.BufferWriter.Rent();
476+
477+
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
478+
Encoding.Unicode.GetBytes(spn, writer);
479+
Trace.Assert(writer.WrittenCount <= SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
480+
481+
try
482+
{
483+
fixed (byte* pin_spnBuffer = writer.WrittenSpan)
484+
{
485+
clientConsumerInfo.szSPN = pin_spnBuffer;
486+
clientConsumerInfo.cchSPN = (uint)writer.WrittenCount;
487+
return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
488+
}
489+
}
490+
finally
491+
{
492+
SqlObjectPools.BufferWriter.Return(writer);
493+
}
446494
}
447495
}
448496
else
@@ -471,25 +519,36 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int
471519
}
472520
}
473521

474-
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, byte[] serverUserName)
522+
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, string serverUserName)
475523
{
476524
sendLength = (uint)outBuff.Length;
477525

478-
fixed (byte* pin_serverUserName = &serverUserName[0])
479-
fixed (byte* pInBuff = inBuff)
480-
fixed (byte* pOutBuff = outBuff)
526+
var serverWriter = SqlObjectPools.BufferWriter.Rent();
527+
528+
try
529+
{
530+
Encoding.Unicode.GetBytes(serverUserName, serverWriter);
531+
532+
fixed (byte* pin_serverUserName = serverWriter.WrittenSpan)
533+
fixed (byte* pInBuff = inBuff)
534+
fixed (byte* pOutBuff = outBuff)
535+
{
536+
return SNISecGenClientContextWrapper(
537+
pConnectionObject,
538+
pInBuff,
539+
(uint)inBuff.Length,
540+
pOutBuff,
541+
ref sendLength,
542+
out _,
543+
pin_serverUserName,
544+
(uint)serverWriter.WrittenCount,
545+
null,
546+
null);
547+
}
548+
}
549+
finally
481550
{
482-
return SNISecGenClientContextWrapper(
483-
pConnectionObject,
484-
pInBuff,
485-
(uint)inBuff.Length,
486-
pOutBuff,
487-
ref sendLength,
488-
out _,
489-
pin_serverUserName,
490-
(uint)serverUserName.Length,
491-
null,
492-
null);
551+
SqlObjectPools.BufferWriter.Return(serverWriter);
493552
}
494553
}
495554

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs

+13-17
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Buffers;
7+
using System.Collections.Generic;
78
using System.Diagnostics;
89
using System.IO;
910
using System.Net;
@@ -33,11 +34,11 @@ internal class SNIProxy
3334
/// <param name="sspiClientContextStatus">SSPI client context status</param>
3435
/// <param name="receivedBuff">Receive buffer</param>
3536
/// <param name="sendWriter">Writer for send buffer</param>
36-
/// <param name="serverName">Service Principal Name buffer</param>
37+
/// <param name="serverNames">Service Principal Name</param>
3738
/// <returns>SNI error code</returns>
38-
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> sendWriter, byte[][] serverName)
39+
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> sendWriter, string[] serverNames)
3940
{
40-
// TODO: this should use ReadOnlyMemory all the way through
41+
// TODO: this should use ReadOnlySpan all the way through
4142
byte[] array = null;
4243

4344
if (!receivedBuff.IsEmpty)
@@ -46,10 +47,10 @@ internal static void GenSspiClientContext(SspiClientContextStatus sspiClientCont
4647
receivedBuff.CopyTo(array);
4748
}
4849

49-
GenSspiClientContext(sspiClientContextStatus, array, sendWriter, serverName);
50+
GenSspiClientContext(sspiClientContextStatus, array, sendWriter, serverNames);
5051
}
5152

52-
private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, IBufferWriter<byte> sendWriter, byte[][] serverName)
53+
private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, IBufferWriter<byte> sendWriter, string[] serverNames)
5354
{
5455
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
5556
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
@@ -81,15 +82,10 @@ private static void GenSspiClientContext(SspiClientContextStatus sspiClientConte
8182
| ContextFlagsPal.Delegate
8283
| ContextFlagsPal.MutualAuth;
8384

84-
string[] serverSPNs = new string[serverName.Length];
85-
for (int i = 0; i < serverName.Length; i++)
86-
{
87-
serverSPNs[i] = Encoding.Unicode.GetString(serverName[i]);
88-
}
8985
SecurityStatusPal statusCode = NegotiateStreamPal.InitializeSecurityContext(
9086
credentialsHandle,
9187
ref securityContext,
92-
serverSPNs,
88+
serverNames,
9389
requestedContextFlags,
9490
inSecurityBufferArray,
9591
outSecurityBuffer,
@@ -164,7 +160,7 @@ internal static SNIHandle CreateConnectionHandle(
164160
string fullServerName,
165161
TimeoutTimer timeout,
166162
out byte[] instanceName,
167-
ref byte[][] spnBuffer,
163+
ref string[] spnBuffer,
168164
string serverSPN,
169165
bool flushCache,
170166
bool async,
@@ -228,12 +224,12 @@ internal static SNIHandle CreateConnectionHandle(
228224
return sniHandle;
229225
}
230226

231-
private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
227+
private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
232228
{
233229
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
234230
if (!string.IsNullOrWhiteSpace(serverSPN))
235231
{
236-
return new byte[1][] { Encoding.Unicode.GetBytes(serverSPN) };
232+
return new[] { serverSPN };
237233
}
238234

239235
string hostName = dataSource.ServerName;
@@ -251,7 +247,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
251247
return GetSqlServerSPNs(hostName, postfix, dataSource.ResolvedProtocol);
252248
}
253249

254-
private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
250+
private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
255251
{
256252
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
257253
IPHostEntry hostEntry = null;
@@ -282,12 +278,12 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
282278
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
283279
// Set both SPNs with and without Port as Port is optional for default instance
284280
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort);
285-
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn), Encoding.Unicode.GetBytes(serverSpnWithDefaultPort) };
281+
return new[] { serverSpn, serverSpnWithDefaultPort };
286282
}
287283
// else Named Pipes do not need to valid port
288284

289285
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn);
290-
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn) };
286+
return new[] { serverSpn };
291287
}
292288

293289
/// <summary>

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs

+10-10
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ internal static void Assert(string message)
134134

135135
private bool _is2022 = false;
136136

137-
private byte[][] _sniSpnBuffer = null;
137+
private string[] _sniSpn = null;
138138

139139
// SqlStatistics
140140
private SqlStatistics _statistics = null;
@@ -404,7 +404,7 @@ internal void Connect(
404404
}
405405
else
406406
{
407-
_sniSpnBuffer = null;
407+
_sniSpn = null;
408408
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler._objectID,
409409
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
410410
}
@@ -416,7 +416,7 @@ internal void Connect(
416416
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
417417
}
418418

419-
_sniSpnBuffer = null;
419+
_sniSpn = null;
420420
_authenticationProvider = null;
421421

422422
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
@@ -455,7 +455,7 @@ internal void Connect(
455455
serverInfo.ExtendedServerName,
456456
timeout,
457457
out instanceName,
458-
ref _sniSpnBuffer,
458+
ref _sniSpn,
459459
false,
460460
true,
461461
fParallel,
@@ -468,8 +468,6 @@ internal void Connect(
468468
hostNameInCertificate,
469469
serverCertificateFilename);
470470

471-
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
472-
473471
if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
474472
{
475473
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
@@ -484,6 +482,8 @@ internal void Connect(
484482
Debug.Fail("SNI returned status != success, but no error thrown?");
485483
}
486484

485+
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
486+
487487
_server = serverInfo.ResolvedServerName;
488488

489489
if (connHandler.PoolGroupProviderInfo != null)
@@ -554,7 +554,7 @@ internal void Connect(
554554
_physicalStateObj.CreatePhysicalSNIHandle(
555555
serverInfo.ExtendedServerName,
556556
timeout, out instanceName,
557-
ref _sniSpnBuffer,
557+
ref _sniSpn,
558558
true,
559559
true,
560560
fParallel,
@@ -567,15 +567,15 @@ internal void Connect(
567567
hostNameInCertificate,
568568
serverCertificateFilename);
569569

570-
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
571-
572570
if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
573571
{
574572
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
575573
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|ERR|SEC> Login failure");
576574
ThrowExceptionAndWarning(_physicalStateObj);
577575
}
578576

577+
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
578+
579579
uint retCode = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId);
580580

581581
Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId");
@@ -13167,7 +13167,7 @@ internal string TraceString()
1316713167
_fMARS ? bool.TrueString : bool.FalseString,
1316813168
_sessionPool == null ? "(null)" : _sessionPool.TraceString(),
1316913169
_is2005 ? bool.TrueString : bool.FalseString,
13170-
_sniSpnBuffer == null ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null),
13170+
_sniSpn == null ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null),
1317113171
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
1317213172
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
1317313173
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ internal abstract void CreatePhysicalSNIHandle(
186186
string serverName,
187187
TimeoutTimer timeout,
188188
out byte[] instanceName,
189-
ref byte[][] spnBuffer,
189+
ref string[] spn,
190190
bool flushCache,
191191
bool async,
192192
bool fParallel,

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle(
8181
string serverName,
8282
TimeoutTimer timeout,
8383
out byte[] instanceName,
84-
ref byte[][] spnBuffer,
84+
ref string[] spn,
8585
bool flushCache,
8686
bool async,
8787
bool parallel,
@@ -94,7 +94,7 @@ internal override void CreatePhysicalSNIHandle(
9494
string hostNameInCertificate,
9595
string serverCertificateFilename)
9696
{
97-
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN,
97+
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spn, serverSPN,
9898
flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst,
9999
hostNameInCertificate, serverCertificateFilename);
100100

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs

+5-8
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ internal override void CreatePhysicalSNIHandle(
143143
string serverName,
144144
TimeoutTimer timeout,
145145
out byte[] instanceName,
146-
ref byte[][] spnBuffer,
146+
ref string[] spn,
147147
bool flushCache,
148148
bool async,
149149
bool fParallel,
@@ -156,31 +156,28 @@ internal override void CreatePhysicalSNIHandle(
156156
string hostNameInCertificate,
157157
string serverCertificateFilename)
158158
{
159-
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
160-
spnBuffer = new byte[1][];
161159
if (isIntegratedSecurity)
162160
{
163161
// now allocate proper length of buffer
164162
if (!string.IsNullOrEmpty(serverSPN))
165163
{
166164
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
167-
byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN);
168-
Trace.Assert(srvSPN.Length <= SNINativeMethodWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
169-
spnBuffer[0] = srvSPN;
170165
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
171166
}
172167
else
173168
{
174-
spnBuffer[0] = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
169+
// This will signal to the interop layer that we need to retrieve the SPN
170+
serverSPN = string.Empty;
175171
}
176172
}
177173

178174
SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async);
179175
SQLDNSInfo cachedDNSInfo;
180176
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);
181177

182-
_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName,
178+
_sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName,
183179
flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate);
180+
spn = new[] { serverSPN.TrimEnd() };
184181
}
185182

186183
protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)

0 commit comments

Comments
 (0)