Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use string for SNI instead of byte[] #2790

Merged
merged 12 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,9 @@
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPools.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPools.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlParameter.cs">
<Link>Microsoft\Data\SqlClient\SqlParameter.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
Expand Down Expand Up @@ -33,7 +34,7 @@ internal class SNIProxy
/// <param name="fullServerName">Full server name from connection string</param>
/// <param name="timeout">Timer expiration</param>
/// <param name="instanceName">Instance name</param>
/// <param name="spnBuffer">SPN</param>
/// <param name="spns">SPNs</param>
/// <param name="serverSPN">pre-defined SPN</param>
/// <param name="flushCache">Flush packet cache</param>
/// <param name="async">Asynchronous connection</param>
Expand All @@ -50,7 +51,7 @@ internal static SNIHandle CreateConnectionHandle(
string fullServerName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spns,
string serverSPN,
bool flushCache,
bool async,
Expand Down Expand Up @@ -102,7 +103,7 @@ internal static SNIHandle CreateConnectionHandle(
{
try
{
spnBuffer = GetSqlServerSPNs(details, serverSPN);
spns = GetSqlServerSPNs(details, serverSPN);
}
catch (Exception e)
{
Expand All @@ -114,12 +115,12 @@ internal static SNIHandle CreateConnectionHandle(
return sniHandle;
}

private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
{
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
if (!string.IsNullOrWhiteSpace(serverSPN))
{
return new byte[1][] { Encoding.Unicode.GetBytes(serverSPN) };
return new[] { serverSPN };
}

string hostName = dataSource.ServerName;
Expand All @@ -137,7 +138,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
return GetSqlServerSPNs(hostName, postfix, dataSource.ResolvedProtocol);
}

private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
{
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
IPHostEntry hostEntry = null;
Expand Down Expand Up @@ -168,12 +169,12 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
// Set both SPNs with and without Port as Port is optional for default instance
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort);
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn), Encoding.Unicode.GetBytes(serverSpnWithDefaultPort) };
return new[] { serverSpn, serverSpnWithDefaultPort };
}
// else Named Pipes do not need to valid port

SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn);
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn) };
return new[] { serverSpn };
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ internal sealed partial class TdsParser

private bool _is2022 = false;

private byte[][] _sniSpnBuffer = null;
// UNDONE - need to have some for both instances - both command and default???
private string[] _serverSpn = null;

// SqlStatistics
private SqlStatistics _statistics = null;
Expand Down Expand Up @@ -390,7 +389,7 @@ internal void Connect(
}
else
{
_sniSpnBuffer = null;
_serverSpn = null;
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler.ObjectID,
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
}
Expand All @@ -402,7 +401,7 @@ internal void Connect(
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
}

_sniSpnBuffer = null;
_serverSpn = null;
_authenticationProvider = null;

// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
Expand Down Expand Up @@ -441,7 +440,7 @@ internal void Connect(
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _sniSpnBuffer,
ref _serverSpn,
false,
true,
fParallel,
Expand Down Expand Up @@ -540,7 +539,7 @@ internal void Connect(
_physicalStateObj.CreatePhysicalSNIHandle(
serverInfo.ExtendedServerName,
timeout, out instanceName,
ref _sniSpnBuffer,
ref _serverSpn,
true,
true,
fParallel,
Expand Down Expand Up @@ -13317,7 +13316,7 @@ internal string TraceString()
_fMARS ? bool.TrueString : bool.FalseString,
_sessionPool == null ? "(null)" : _sessionPool.TraceString(),
_is2005 ? bool.TrueString : bool.FalseString,
_sniSpnBuffer == null ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null),
_serverSpn == null ? "(null)" : _serverSpn.Length.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ internal abstract void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spns,
bool flushCache,
bool async,
bool fParallel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spns,
bool flushCache,
bool async,
bool parallel,
Expand All @@ -94,7 +94,7 @@ internal override void CreatePhysicalSNIHandle(
string hostNameInCertificate,
string serverCertificateFilename)
{
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN,
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spns, serverSPN,
flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst,
hostNameInCertificate, serverCertificateFilename);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ internal override void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spns,
bool flushCache,
bool async,
bool fParallel,
Expand All @@ -157,31 +157,28 @@ internal override void CreatePhysicalSNIHandle(
string hostNameInCertificate,
string serverCertificateFilename)
{
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
spnBuffer = new byte[1][];
if (isIntegratedSecurity)
{
// now allocate proper length of buffer
if (!string.IsNullOrEmpty(serverSPN))
{
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN);
Trace.Assert(srvSPN.Length <= SniNativeWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
spnBuffer[0] = srvSPN;
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
}
else
{
spnBuffer[0] = new byte[SniNativeWrapper.SniMaxComposedSpnLength];
// This will signal to the interop layer that we need to retrieve the SPN
serverSPN = string.Empty;
}
}

ConsumerInfo myInfo = CreateConsumerInfo(async);
SQLDNSInfo cachedDNSInfo;
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);

_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName,
_sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName,
flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate);
spns = new[] { serverSPN.TrimEnd() };
}

protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,13 @@
<Compile Include="$(CommonSourceRoot)Microsoft\Data\Sql\SqlDataSourceEnumerator.cs">
<Link>Microsoft\Data\Sql\SqlDataSourceEnumerator.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AAsyncCallContext.cs">
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPools.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPools.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AAsyncCallContext.cs">
<Link>Microsoft\Data\SqlClient\AAsyncCallContext.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs">
Expand Down Expand Up @@ -854,6 +857,7 @@
<Compile Include="Microsoft\Data\Common\DbConnectionString.cs" />
<Compile Include="Microsoft\Data\Common\GreenMethods.cs" />
<Compile Include="Microsoft\Data\SqlClient\assemblycache.cs" />
<Compile Include="Microsoft\Data\SqlClient\BufferWriterExtensions.cs" />
<Compile Include="Microsoft\Data\SqlClient\Reliability\SqlConfigurableRetryLogicManager.LoadType.cs" />
<Compile Include="Microsoft\Data\SqlClient\Server\SmiConnection.cs" />
<Compile Include="Microsoft\Data\SqlClient\Server\SmiContext.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using System;
using System.Buffers;
using System.Text;

namespace Microsoft.Data.SqlClient
{
internal static class BufferWriterExtensions
{
internal static long GetBytes(this Encoding encoding, string str, IBufferWriter<byte> bufferWriter)
{
var count = encoding.GetByteCount(str);
var array = ArrayPool<byte>.Shared.Rent(count);

try
{
var length = encoding.GetBytes(str, 0, str.Length, array, 0);
bufferWriter.Write(array.AsSpan(0, length));
return length;
}
finally
{
ArrayPool<byte>.Shared.Return(array);
}
}
}
}
Loading
Loading