Skip to content

Commit 0be9c1d

Browse files
committed
Move SPN to ReadOnlyList
1 parent ff11043 commit 0be9c1d

17 files changed

+159
-113
lines changed

Diff for: src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Unix.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.Collections.Generic;
56
using System.ComponentModel;
67
using System.Diagnostics;
78
using Microsoft.Win32.SafeHandles;
@@ -146,7 +147,7 @@ private static SecurityStatusPal EstablishSecurityContext(
146147
internal static SecurityStatusPal InitializeSecurityContext(
147148
SafeFreeCredentials credentialsHandle,
148149
ref SafeDeleteContext securityContext,
149-
string[] spns,
150+
IReadOnlyList<string> spns,
150151
ContextFlagsPal requestedContextFlags,
151152
SecurityBuffer[] inSecurityBufferArray,
152153
SecurityBuffer outSecurityBuffer,

Diff for: src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Globalization;
66
using System.ComponentModel;
77
using Microsoft.Data;
8+
using System.Collections.Generic;
89

910
namespace System.Net.Security
1011
{
@@ -71,7 +72,7 @@ internal static string QueryContextAuthenticationPackage(SafeDeleteContext secur
7172
internal static SecurityStatusPal InitializeSecurityContext(
7273
SafeFreeCredentials credentialsHandle,
7374
ref SafeDeleteContext securityContext,
74-
string[] spn,
75+
IReadOnlyList<string> spn,
7576
ContextFlagsPal requestedContextFlags,
7677
SecurityBuffer[] inSecurityBufferArray,
7778
SecurityBuffer outSecurityBuffer,

Diff for: src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs

+47-22
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Buffers;
67
using System.Runtime.InteropServices;
78
using System.Text;
89
using Microsoft.Data.Common;
@@ -398,7 +399,7 @@ internal static unsafe uint SNIOpenSyncEx(
398399
ConsumerInfo consumerInfo,
399400
string constring,
400401
ref IntPtr pConn,
401-
byte[] spnBuffer,
402+
ref string spn,
402403
byte[] instanceName,
403404
bool fOverrideCache,
404405
bool fSync,
@@ -436,13 +437,25 @@ internal static unsafe uint SNIOpenSyncEx(
436437
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
437438
clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port;
438439

439-
if (spnBuffer != null)
440+
if (spn != null)
440441
{
441-
fixed (byte* pin_spnBuffer = &spnBuffer[0])
442+
var array = ArrayPool<byte>.Shared.Rent(SniMaxComposedSpnLength);
443+
try
442444
{
443-
clientConsumerInfo.szSPN = pin_spnBuffer;
444-
clientConsumerInfo.cchSPN = (uint)spnBuffer.Length;
445-
return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
445+
fixed (byte* pin_spnBuffer = &array[0])
446+
{
447+
clientConsumerInfo.szSPN = pin_spnBuffer;
448+
clientConsumerInfo.cchSPN = (uint)SniMaxComposedSpnLength;
449+
var result = SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
450+
451+
spn = Encoding.Unicode.GetString(array, 0, (int)clientConsumerInfo.cchSPN).TrimEnd();
452+
453+
return result;
454+
}
455+
}
456+
finally
457+
{
458+
ArrayPool<byte>.Shared.Return(array);
446459
}
447460
}
448461
else
@@ -471,24 +484,36 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int
471484
}
472485
}
473486

474-
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, ref uint sendLength, byte[] serverUserName)
487+
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, ref uint sendLength, string serverUserName)
475488
{
476-
fixed (byte* pin_serverUserName = &serverUserName[0])
477-
fixed (byte* pInBuff = inBuff)
478-
fixed (byte* pOutBuff = outBuff)
489+
var serverNameLength = Encoding.Unicode.GetByteCount(serverUserName);
490+
var serverNameArray = ArrayPool<byte>.Shared.Rent(serverNameLength);
491+
492+
try
493+
{
494+
Encoding.Unicode.GetBytes(serverUserName, 0, serverUserName.Length, serverNameArray, 0);
495+
496+
fixed (byte* pin_serverUserName = serverNameArray)
497+
fixed (byte* pInBuff = inBuff)
498+
fixed (byte* pOutBuff = outBuff)
499+
{
500+
bool local_fDone;
501+
return SNISecGenClientContextWrapper(
502+
pConnectionObject,
503+
pInBuff,
504+
(uint)inBuff.Length,
505+
pOutBuff,
506+
ref sendLength,
507+
out local_fDone,
508+
pin_serverUserName,
509+
(uint)serverNameLength,
510+
null,
511+
null);
512+
}
513+
}
514+
finally
479515
{
480-
bool local_fDone;
481-
return SNISecGenClientContextWrapper(
482-
pConnectionObject,
483-
pInBuff,
484-
(uint)inBuff.Length,
485-
pOutBuff,
486-
ref sendLength,
487-
out local_fDone,
488-
pin_serverUserName,
489-
(uint)serverUserName.Length,
490-
null,
491-
null);
516+
ArrayPool<byte>.Shared.Return(serverNameArray);
492517
}
493518
}
494519

Diff for: src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs

+12-16
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Buffers;
7+
using System.Collections.Generic;
78
using System.Diagnostics;
89
using System.IO;
910
using System.Net;
@@ -32,25 +33,25 @@ internal class SNIProxy
3233
/// </summary>
3334
/// <param name="sspiClientContextStatus">SSPI client context status</param>
3435
/// <param name="receivedBuff">Receive buffer</param>
35-
/// <param name="serverName">Service Principal Name buffer</param>
36+
/// <param name="serverNames">Service Principal Name buffer</param>
3637
/// <returns>Memory for response</returns>
37-
internal static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, byte[][] serverName)
38+
internal static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, IReadOnlyList<string> serverNames)
3839
{
3940
// TODO: this should use ReadOnlyMemory all the way through
4041
var array = ArrayPool<byte>.Shared.Rent(receivedBuff.Length);
4142

4243
try
4344
{
4445
receivedBuff.CopyTo(array);
45-
return GenSspiClientContext(sspiClientContextStatus, array, receivedBuff.Length, serverName);
46+
return GenSspiClientContext(sspiClientContextStatus, array, receivedBuff.Length, serverNames);
4647
}
4748
finally
4849
{
4950
ArrayPool<byte>.Shared.Return(array);
5051
}
5152
}
5253

53-
private static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, int receivedBuffLength, byte[][] serverName)
54+
private static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, int receivedBuffLength, IReadOnlyList<string> serverSPNs)
5455
{
5556
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
5657
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
@@ -66,7 +67,7 @@ private static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus s
6667
SecurityBuffer[] inSecurityBufferArray;
6768
if (receivedBuff != null)
6869
{
69-
inSecurityBufferArray = new SecurityBuffer[] { new SecurityBuffer(receivedBuff, SecurityBufferType.SECBUFFER_TOKEN) };
70+
inSecurityBufferArray = new SecurityBuffer[] { new SecurityBuffer(receivedBuff, 0, receivedBuffLength, SecurityBufferType.SECBUFFER_TOKEN) };
7071
}
7172
else
7273
{
@@ -82,11 +83,6 @@ private static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus s
8283
| ContextFlagsPal.Delegate
8384
| ContextFlagsPal.MutualAuth;
8485

85-
string[] serverSPNs = new string[serverName.Length];
86-
for (int i = 0; i < serverName.Length; i++)
87-
{
88-
serverSPNs[i] = Encoding.Unicode.GetString(serverName[i]);
89-
}
9086
SecurityStatusPal statusCode = NegotiateStreamPal.InitializeSecurityContext(
9187
credentialsHandle,
9288
ref securityContext,
@@ -162,7 +158,7 @@ internal static SNIHandle CreateConnectionHandle(
162158
string fullServerName,
163159
TimeoutTimer timeout,
164160
out byte[] instanceName,
165-
ref byte[][] spnBuffer,
161+
ref string[] spnBuffer,
166162
string serverSPN,
167163
bool flushCache,
168164
bool async,
@@ -226,12 +222,12 @@ internal static SNIHandle CreateConnectionHandle(
226222
return sniHandle;
227223
}
228224

229-
private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
225+
private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
230226
{
231227
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
232228
if (!string.IsNullOrWhiteSpace(serverSPN))
233229
{
234-
return new byte[1][] { Encoding.Unicode.GetBytes(serverSPN) };
230+
return new[] { serverSPN };
235231
}
236232

237233
string hostName = dataSource.ServerName;
@@ -249,7 +245,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
249245
return GetSqlServerSPNs(hostName, postfix, dataSource._connectionProtocol);
250246
}
251247

252-
private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
248+
private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
253249
{
254250
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
255251
IPHostEntry hostEntry = null;
@@ -280,12 +276,12 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
280276
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
281277
// Set both SPNs with and without Port as Port is optional for default instance
282278
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort);
283-
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn), Encoding.Unicode.GetBytes(serverSpnWithDefaultPort) };
279+
return new[] { serverSpn, serverSpnWithDefaultPort };
284280
}
285281
// else Named Pipes do not need to valid port
286282

287283
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn);
288-
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn) };
284+
return new[] { serverSpn };
289285
}
290286

291287
/// <summary>

Diff for: src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs

+9-9
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ internal static void Assert(string message)
134134

135135
private bool _is2022 = false;
136136

137-
private byte[][] _sniSpnBuffer = null;
137+
private string[] _sniSpn = null;
138138

139139
// SqlStatistics
140140
private SqlStatistics _statistics = null;
@@ -404,7 +404,7 @@ internal void Connect(
404404
}
405405
else
406406
{
407-
_sniSpnBuffer = null;
407+
_sniSpn = null;
408408
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler._objectID,
409409
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
410410
}
@@ -416,7 +416,7 @@ internal void Connect(
416416
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
417417
}
418418

419-
_sniSpnBuffer = null;
419+
_sniSpn = null;
420420
_authenticationProvider = null;
421421

422422
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
@@ -455,7 +455,7 @@ internal void Connect(
455455
serverInfo.ExtendedServerName,
456456
timeout,
457457
out instanceName,
458-
ref _sniSpnBuffer,
458+
ref _sniSpn,
459459
false,
460460
true,
461461
fParallel,
@@ -468,7 +468,7 @@ internal void Connect(
468468
hostNameInCertificate,
469469
serverCertificateFilename);
470470

471-
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
471+
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn);
472472

473473
if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
474474
{
@@ -554,7 +554,7 @@ internal void Connect(
554554
_physicalStateObj.CreatePhysicalSNIHandle(
555555
serverInfo.ExtendedServerName,
556556
timeout, out instanceName,
557-
ref _sniSpnBuffer,
557+
ref _sniSpn,
558558
true,
559559
true,
560560
fParallel,
@@ -567,15 +567,15 @@ internal void Connect(
567567
hostNameInCertificate,
568568
serverCertificateFilename);
569569

570-
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
571-
572570
if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
573571
{
574572
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
575573
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|ERR|SEC> Login failure");
576574
ThrowExceptionAndWarning(_physicalStateObj);
577575
}
578576

577+
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn);
578+
579579
uint retCode = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId);
580580

581581
Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId");
@@ -12850,7 +12850,7 @@ internal string TraceString()
1285012850
_fMARS ? bool.TrueString : bool.FalseString,
1285112851
null == _sessionPool ? "(null)" : _sessionPool.TraceString(),
1285212852
_is2005 ? bool.TrueString : bool.FalseString,
12853-
null == _sniSpnBuffer ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null),
12853+
null == _sniSpn ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null),
1285412854
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
1285512855
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
1285612856
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),

Diff for: src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ internal abstract void CreatePhysicalSNIHandle(
186186
string serverName,
187187
TimeoutTimer timeout,
188188
out byte[] instanceName,
189-
ref byte[][] spnBuffer,
189+
ref string[] spn,
190190
bool flushCache,
191191
bool async,
192192
bool fParallel,

Diff for: src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle(
8181
string serverName,
8282
TimeoutTimer timeout,
8383
out byte[] instanceName,
84-
ref byte[][] spnBuffer,
84+
ref string[] spn,
8585
bool flushCache,
8686
bool async,
8787
bool parallel,
@@ -94,7 +94,7 @@ internal override void CreatePhysicalSNIHandle(
9494
string hostNameInCertificate,
9595
string serverCertificateFilename)
9696
{
97-
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN,
97+
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spn, serverSPN,
9898
flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst,
9999
hostNameInCertificate, serverCertificateFilename);
100100

Diff for: src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs

+4-7
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ internal override void CreatePhysicalSNIHandle(
143143
string serverName,
144144
TimeoutTimer timeout,
145145
out byte[] instanceName,
146-
ref byte[][] spnBuffer,
146+
ref string[] spn,
147147
bool flushCache,
148148
bool async,
149149
bool fParallel,
@@ -157,30 +157,27 @@ internal override void CreatePhysicalSNIHandle(
157157
string serverCertificateFilename)
158158
{
159159
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
160-
spnBuffer = new byte[1][];
161160
if (isIntegratedSecurity)
162161
{
163162
// now allocate proper length of buffer
164163
if (!string.IsNullOrEmpty(serverSPN))
165164
{
166165
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
167-
byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN);
168-
Trace.Assert(srvSPN.Length <= SNINativeMethodWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
169-
spnBuffer[0] = srvSPN;
170166
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
171167
}
172168
else
173169
{
174-
spnBuffer[0] = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
170+
serverSPN = string.Empty;
175171
}
176172
}
177173

178174
SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async);
179175
SQLDNSInfo cachedDNSInfo;
180176
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);
181177

182-
_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName,
178+
_sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName,
183179
flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate);
180+
spn = new[] { serverSPN.TrimEnd() };
184181
}
185182

186183
protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)

0 commit comments

Comments
 (0)