Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use string for SNI instead of byte[] #2790

Merged
merged 12 commits into from
Mar 3, 2025
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ internal class SNIProxy
/// <param name="fullServerName">Full server name from connection string</param>
/// <param name="timeout">Timer expiration</param>
/// <param name="instanceName">Instance name</param>
/// <param name="spnBuffer">SPN</param>
/// <param name="spns">SPNs</param>
/// <param name="serverSPN">pre-defined SPN</param>
/// <param name="flushCache">Flush packet cache</param>
/// <param name="async">Asynchronous connection</param>
Expand All @@ -51,7 +51,7 @@ internal static SNIHandle CreateConnectionHandle(
string fullServerName,
TimeoutTimer timeout,
out byte[] instanceName,
ref string[] spnBuffer,
ref string[] spns,
string serverSPN,
bool flushCache,
bool async,
Expand Down Expand Up @@ -103,7 +103,7 @@ internal static SNIHandle CreateConnectionHandle(
{
try
{
spnBuffer = GetSqlServerSPNs(details, serverSPN);
spns = GetSqlServerSPNs(details, serverSPN);
}
catch (Exception e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ internal sealed partial class TdsParser

private bool _is2022 = false;

private string[] _sniSpn = null;
private string[] _serverSpn = null;

// SqlStatistics
private SqlStatistics _statistics = null;
Expand Down Expand Up @@ -389,7 +389,7 @@ internal void Connect(
}
else
{
_sniSpn = null;
_serverSpn = null;
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler.ObjectID,
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
}
Expand All @@ -401,7 +401,7 @@ internal void Connect(
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
}

_sniSpn = null;
_serverSpn = null;
_authenticationProvider = null;

// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
Expand Down Expand Up @@ -440,7 +440,7 @@ internal void Connect(
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _sniSpn,
ref _serverSpn,
false,
true,
fParallel,
Expand All @@ -453,6 +453,8 @@ internal void Connect(
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand All @@ -467,8 +469,6 @@ internal void Connect(
Debug.Fail("SNI returned status != success, but no error thrown?");
}

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

_server = serverInfo.ResolvedServerName;

if (connHandler.PoolGroupProviderInfo != null)
Expand Down Expand Up @@ -539,7 +539,7 @@ internal void Connect(
_physicalStateObj.CreatePhysicalSNIHandle(
serverInfo.ExtendedServerName,
timeout, out instanceName,
ref _sniSpn,
ref _serverSpn,
true,
true,
fParallel,
Expand All @@ -552,15 +552,15 @@ internal void Connect(
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|ERR|SEC> Login failure");
ThrowExceptionAndWarning(_physicalStateObj);
}

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

uint retCode = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId);

Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId");
Expand Down Expand Up @@ -13316,7 +13316,7 @@ internal string TraceString()
_fMARS ? bool.TrueString : bool.FalseString,
_sessionPool == null ? "(null)" : _sessionPool.TraceString(),
_is2005 ? bool.TrueString : bool.FalseString,
_sniSpn == null ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null),
_serverSpn == null ? "(null)" : _serverSpn.Length.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ internal int ObjectID

private bool _is2022 = false;

private string _sniSpn = null;
private string _serverSpn = null;

// UNDONE - need to have some for both instances - both command and default???

Expand Down Expand Up @@ -163,7 +163,7 @@ internal int ObjectID
// now data length is 1 byte
// First bit is 1 indicating client support failover partner with readonly intent
private static readonly byte[] s_FeatureExtDataAzureSQLSupportFeatureRequest = { 0x01 };

// NOTE: You must take the internal connection's _parserLock before modifying this
internal bool _asyncWrite = false;

Expand Down Expand Up @@ -430,24 +430,25 @@ internal void Connect(ServerInfo serverInfo,
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
_authenticationProvider = _physicalStateObj.CreateSSPIContextProvider();

if (!string.IsNullOrEmpty(serverInfo.ServerSPN))
{
_sniSpn = serverInfo.ServerSPN;
_serverSpn = serverInfo.ServerSPN;
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Server SPN `{0}` from the connection string is used.", serverInfo.ServerSPN);
}
else
{
// Empty signifies to interop layer that SNI needs to be generated
_sniSpn = string.Empty;
// Empty signifies to interop layer that SPN needs to be generated
_serverSpn = string.Empty;
}

_authenticationProvider = _physicalStateObj.CreateSSPIContextProvider();
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> SSPI or Active Directory Authentication Library for SQL Server based integrated authentication");
}
else
{
_authenticationProvider = null;
_sniSpn = null;
_serverSpn = null;

switch (authType)
{
Expand Down Expand Up @@ -526,7 +527,7 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _sniSpn,
ref _serverSpn,
false,
true,
fParallel,
Expand All @@ -536,6 +537,8 @@ internal void Connect(ServerInfo serverInfo,
FQDNforDNSCache,
hostNameInCertificate);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand All @@ -550,8 +553,6 @@ internal void Connect(ServerInfo serverInfo,
Debug.Fail("SNI returned status != success, but no error thrown?");
}

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

_server = serverInfo.ResolvedServerName;

if (connHandler.PoolGroupProviderInfo != null)
Expand Down Expand Up @@ -626,7 +627,7 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _sniSpn,
ref _serverSpn,
true,
true,
fParallel,
Expand All @@ -636,15 +637,15 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ResolvedServerName,
hostNameInCertificate);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|ERR|SEC> Login failure");
ThrowExceptionAndWarning(_physicalStateObj);
}

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

uint retCode = SniNativeWrapper.SniGetConnectionId(_physicalStateObj.Handle, ref _connHandler._clientConnectionId);
Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId");
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Sending prelogin handshake");
Expand Down Expand Up @@ -3390,7 +3391,7 @@ private TdsOperationStatus TryProcessDone(SqlCommand cmd, SqlDataReader reader,
Debug.Assert(!((sqlTransaction != null && _distributedTransaction != null) ||
(_userStartedLocalTransaction != null && _distributedTransaction != null))
, "ProcessDone - have both distributed and local transactions not null!");
*/
*/
// WebData 112722

stateObj.DecrementOpenResultCount();
Expand Down Expand Up @@ -3877,8 +3878,8 @@ private TdsOperationStatus TryProcessSessionState(TdsParserStateObject stateObj,
if (!recoverable)
{
checked
{
sdata._unrecoverableStatesCount++;
{
sdata._unrecoverableStatesCount++;
}
}
}
Expand All @@ -3899,8 +3900,8 @@ private TdsOperationStatus TryProcessSessionState(TdsParserStateObject stateObj,
else
{
checked
{
sdata._unrecoverableStatesCount++;
{
sdata._unrecoverableStatesCount++;
}
}
sv._recoverable = recoverable;
Expand Down Expand Up @@ -3979,29 +3980,29 @@ private TdsOperationStatus TryProcessLoginAck(TdsParserStateObject stateObj, out
{
case TdsEnums.SQL2005_MAJOR << 24 | TdsEnums.SQL2005_RTM_MINOR: // 2005
if (increment != TdsEnums.SQL2005_INCREMENT)
{
throw SQL.InvalidTDSVersion();
{
throw SQL.InvalidTDSVersion();
}
_is2005 = true;
break;
case TdsEnums.SQL2008_MAJOR << 24 | TdsEnums.SQL2008_MINOR:
if (increment != TdsEnums.SQL2008_INCREMENT)
{
throw SQL.InvalidTDSVersion();
{
throw SQL.InvalidTDSVersion();
}
_is2008 = true;
break;
case TdsEnums.SQL2012_MAJOR << 24 | TdsEnums.SQL2012_MINOR:
if (increment != TdsEnums.SQL2012_INCREMENT)
{
throw SQL.InvalidTDSVersion();
{
throw SQL.InvalidTDSVersion();
}
_is2012 = true;
break;
case TdsEnums.TDS8_MAJOR << 24 | TdsEnums.TDS8_MINOR:
if (increment != TdsEnums.TDS8_INCREMENT)
{
throw SQL.InvalidTDSVersion();
{
throw SQL.InvalidTDSVersion();
}
_is2022 = true;
break;
Expand Down Expand Up @@ -5934,7 +5935,7 @@ private TdsOperationStatus TryProcessColInfo(_SqlMetaDataSet columns, SqlDataRea
for (int i = 0; i < columns.Length; i++)
{
_SqlMetaData col = columns[i];

result = stateObj.TryReadByte(out _);
if (result != TdsOperationStatus.Done)
{
Expand Down Expand Up @@ -6471,7 +6472,7 @@ private TdsOperationStatus TryReadSqlStringValue(SqlBuffer value, byte type, int
char[] cc = null;
bool buffIsRented = false;
result = TryReadPlpUnicodeChars(ref cc, 0, length >> 1, stateObj, out length, supportRentedBuff: true, rentedBuff: ref buffIsRented);

if (result == TdsOperationStatus.Done)
{
if (length > 0)
Expand Down Expand Up @@ -11370,7 +11371,7 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin
actualLengthInBytes = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length;
if (metadata.baseTI.length > 0 &&
actualLengthInBytes > metadata.baseTI.length)
{
{
// see comments above
actualLengthInBytes = metadata.baseTI.length;
}
Expand Down Expand Up @@ -12278,7 +12279,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati
_parser.WriteInt(count, _stateObj); // write length of chunk
task = _stateObj.WriteByteArray(buffer, count, offset, canAccumulate: false);
}

return task ?? Task.CompletedTask;
}
catch (System.OutOfMemoryException)
Expand Down Expand Up @@ -12511,7 +12512,7 @@ private async Task WriteTextFeed(TextDataFeed feed, Encoding encoding, bool need
char[] inBuff = ArrayPool<char>.Shared.Rent(constTextBufferSize);

encoding = encoding ?? new UnicodeEncoding(false, false);

using (ConstrainedTextWriter writer = new ConstrainedTextWriter(new StreamWriter(new TdsOutputStream(this, stateObj, null), encoding), size))
{
if (needBom)
Expand Down Expand Up @@ -13429,7 +13430,7 @@ internal TdsOperationStatus TryReadPlpUnicodeChars(
int charsRead = 0;
int charsLeft = 0;
char[] newbuf;

if (stateObj._longlen == 0)
{
Debug.Assert(stateObj._longlenleft == 0);
Expand Down Expand Up @@ -13546,7 +13547,7 @@ internal TdsOperationStatus TryReadPlpUnicodeChars(
totalCharsRead++;
}
if (stateObj._longlenleft == 0)
{
{
// Read the next chunk or cleanup state if hit the end
result = stateObj.TryReadPlpLength(false, out _);
if (result != TdsOperationStatus.Done)
Expand Down Expand Up @@ -13782,7 +13783,7 @@ internal string TraceString()
_is2000 ? bool.TrueString : bool.FalseString,
_is2000SP1 ? bool.TrueString : bool.FalseString,
_is2005 ? bool.TrueString : bool.FalseString,
_sniSpn == null ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null),
_serverSpn == null ? "(null)" : _serverSpn.Length.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ private void LoadSSPILibrary()
}
}

protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> _sniSpnBuffer)
protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpn)
{
#if NETFRAMEWORK
SNIHandle handle = _physicalStateObj.Handle;
Expand All @@ -62,7 +62,7 @@ protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlo
var sendLength = s_maxSSPILength;
var outBuff = outgoingBlobWriter.GetSpan((int)sendLength);

if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, _sniSpnBuffer[0]))
if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, serverSpn[0]))
{
throw new InvalidOperationException(SQLMessage.SSPIGenerateError());
}
Expand Down
Loading
Loading