Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable callback to negotiate NTLM or Kerberos #2248

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
<NuGetRoot Condition="'$(NuGetRoot)' == ''">$(RepoRoot).nuget\</NuGetRoot>
<NuGetCmd>$(NuGetRoot)nuget.exe</NuGetCmd>
<!-- Respect environment variable for the .NET install directory if set; otherwise, use the current default location -->
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<!-- <TreatWarningsAsErrors Condition="!$(Configuration.Contains('Debug'))">true</TreatWarningsAsErrors>-->
<BuildSimulator Condition="'$(BuildSimulator)' != 'true'">false</BuildSimulator>
</PropertyGroup>
<PropertyGroup Condition="'$(BuildSimulator)' == 'true'">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@
<Compile Include="Microsoft\Data\SqlClient\SqlConnectionFactory.AssemblyLoadContext.cs" />
</ItemGroup>
<ItemGroup>
<Compile Include="..\..\netfx\src\Microsoft\Data\SqlClient\SSPIContextManager.cs" Link="Microsoft\Data\SqlClient\SSPIContextManager.cs" />
<Compile Include="..\..\src\Resources\StringsHelper.cs">
<Link>Resources\StringsHelper.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,9 @@ public Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticati
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/NegotiateCallback/*' />
public Func<SqlAuthenticationParameters, CancellationToken, ReadOnlyMemory<byte>> NegotiateCallback { get; set; }

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/Database/*' />
[ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)]
[ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,19 @@ internal SqlConnectionTimeoutErrorInternal TimeoutErrorInternal
private string _routingDestination = null;
private readonly TimeoutTimer _timeout;

public CancellationTokenSource CreateCancellationTokenSource()
{
CancellationTokenSource cts = new();

// Use Connection timeout value to cancel token acquire request after certain period of time.(int)
if (_timeout.MillisecondsRemaining < Int32.MaxValue)
{
cts.CancelAfter((int)_timeout.MillisecondsRemaining);
}

return cts;
}

// although the new password is generally not used it must be passed to the ctor
// the new Login7 packet will always write out the new password (or a length of zero and no bytes if not present)
//
Expand Down Expand Up @@ -2463,12 +2476,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
authParamsBuilder.WithPassword(ConnectionOptions.Password);
}
SqlAuthenticationParameters parameters = authParamsBuilder;
CancellationTokenSource cts = new();
// Use Connection timeout value to cancel token acquire request after certain period of time.(int)
if (_timeout.MillisecondsRemaining < Int32.MaxValue)
{
cts.CancelAfter((int)_timeout.MillisecondsRemaining);
}
using CancellationTokenSource cts = CreateCancellationTokenSource();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use curly braces for clarity rather than implicily scoped using.

_fedAuthToken = Task.Run(async () => await _accessTokenCallback(parameters, cts.Token)).GetAwaiter().GetResult().ToSqlFedAuthToken();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ private void FireInfoMessageEvent(SqlConnection connection, SqlCommand command,

sqlErs.Add(error);

SqlException exc = SqlException.CreateException(sqlErs, serverVersion, _connHandler, innerException:null, batchCommand: command?.GetCurrentBatchCommand());
SqlException exc = SqlException.CreateException(sqlErs, serverVersion, _connHandler, innerException: null, batchCommand: command?.GetCurrentBatchCommand());

bool notified;
connection.OnInfoMessage(new SqlInfoMessageEventArgs(exc), out notified);
Expand Down Expand Up @@ -3980,7 +3980,7 @@ internal bool TryProcessError(byte token, TdsParserStateObject stateObj, SqlComm
{
batchIndex = command.GetCurrentBatchIndex();
}
error = new SqlError(number, state, errorClass, _server, message, procedure, line,exception: null, batchIndex: batchIndex);
error = new SqlError(number, state, errorClass, _server, message, procedure, line, exception: null, batchIndex: batchIndex);
return true;
}

Expand Down Expand Up @@ -5723,7 +5723,7 @@ private bool TryReadSqlStringValue(SqlBuffer value, byte type, int length, Encod
s = "";
}
}

if (buffIsRented)
{
// do not use clearArray:true on the rented array because it can be massively larger
Expand Down Expand Up @@ -8581,7 +8581,18 @@ internal void SendFedAuthToken(SqlFedAuthToken fedAuthToken)

private void SSPIData(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength)
{
if (TdsParserStateObjectFactory.UseManagedSNI)
if (Connection.Connection.NegotiateCallback is { })
{
try
{
SSPIContextManager.Invoke(Connection, ref sendBuff, ref sendLength);
}
catch (Exception e)
{
SSPIError(e.Message + Environment.NewLine + e.StackTrace, TdsEnums.GEN_CLIENT_CONTEXT);
}
}
else if (TdsParserStateObjectFactory.UseManagedSNI)
{
try
{
Expand Down Expand Up @@ -9202,7 +9213,7 @@ internal Task TdsExecuteRPC(SqlCommand cmd, IList<_SqlRPC> rpcArray, int timeout
{
// Throw an exception if ForceColumnEncryption is set on a parameter and the ColumnEncryption is not enabled on SqlConnection or SqlCommand
if (
!(cmd.ColumnEncryptionSetting == SqlCommandColumnEncryptionSetting.Enabled
!(cmd.ColumnEncryptionSetting == SqlCommandColumnEncryptionSetting.Enabled
||
(cmd.ColumnEncryptionSetting == SqlCommandColumnEncryptionSetting.UseConnectionSetting && cmd.Connection.IsColumnEncryptionSettingEnabled)))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,7 @@
<Compile Include="Microsoft\Data\SqlClient\SqlSequentialTextReaderSmi.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlTransaction.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlUtil.cs" />
<Compile Include="Microsoft\Data\SqlClient\SSPIContextManager.cs" />
<Compile Include="Microsoft\Data\SqlClient\TdsParser.cs" />
<Compile Include="Microsoft\Data\SqlClient\TdsParserHelperClasses.cs" />
<Compile Include="Microsoft\Data\SqlClient\TdsParserStateObject.netfx.cs" />
Expand Down Expand Up @@ -741,4 +742,4 @@
<Import Project="$(NetFxSource)tools\targets\GenerateThisAssemblyCs.targets" />
<Import Project="$(NetFxSource)tools\targets\GenerateAssemblyRef.targets" />
<Import Project="$(NetFxSource)tools\targets\GenerateAssemblyInfo.targets" />
</Project>
</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using System;
using System.Diagnostics;

namespace Microsoft.Data.SqlClient
{
internal static class SSPIContextManager
{
#if NETFRAMEWORK
public static void Invoke(SqlInternalConnectionTds Connection, byte[] output, ref uint outputLength)
#else
public static void Invoke(SqlInternalConnectionTds Connection, ref byte[] output, ref uint outputLength)
#endif
{
Debug.Assert(Connection.Connection.NegotiateCallback is not null);

var result = Invoke(Connection);

#if !NETFRAMEWORK
output = new byte[result.Length];
#endif
result.CopyTo(output);
outputLength = (uint)result.Length;
}

private static ReadOnlyMemory<byte> Invoke(SqlInternalConnectionTds Connection)
{
var auth = new SqlAuthenticationParameters.Builder(Connection.ConnectionOptions.Authentication, Connection.ConnectionOptions.ObtainWorkstationId(), "auth", Connection.ConnectionOptions.DataSource, Connection.ConnectionOptions.InitialCatalog);

if (Connection.ConnectionOptions.UserID is { } userId)
{
auth.WithUserId(userId);
}

if (Connection.ConnectionOptions.Password is { } password)
{
auth.WithPassword(password);
}

using var cts = Connection.CreateCancellationTokenSource();

return Connection.Connection.NegotiateCallback(auth, cts.Token);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,9 @@ public Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticati
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/NegotiateCallback/*' />
public Func<SqlAuthenticationParameters, CancellationToken, ReadOnlyMemory<byte>> NegotiateCallback { get; set; }

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/CommandTimeout/*' />
[
DesignerSerializationVisibility(DesignerSerializationVisibility.Hidden),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,19 @@ internal bool IsDNSCachingBeforeRedirectSupported

private static HashSet<int> transientErrors = new HashSet<int>();

public CancellationTokenSource CreateCancellationTokenSource()
{
CancellationTokenSource cts = new();

// Use Connection timeout value to cancel token acquire request after certain period of time.(int)
if (_timeout.MillisecondsRemaining < Int32.MaxValue)
{
cts.CancelAfter((int)_timeout.MillisecondsRemaining);
}

return cts;
}

internal SessionData CurrentSessionData
{
get
Expand Down Expand Up @@ -2877,12 +2890,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
authParamsBuilder.WithPassword(ConnectionOptions.Password);
}
SqlAuthenticationParameters parameters = authParamsBuilder;
CancellationTokenSource cts = new();
// Use Connection timeout value to cancel token acquire request after certain period of time.(int)
if (_timeout.MillisecondsRemaining < Int32.MaxValue)
{
cts.CancelAfter((int)_timeout.MillisecondsRemaining);
}
using CancellationTokenSource cts = CreateCancellationTokenSource();
_fedAuthToken = Task.Run(async () => await _accessTokenCallback(parameters, cts.Token)).GetAwaiter().GetResult().ToSqlFedAuthToken();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1485,10 +1485,10 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake(
}

// Validate Certificate if Trust Server Certificate=false and Encryption forced (EncryptionOptions.ON) from Server.
bool shouldValidateServerCert = (_encryptionOption == EncryptionOptions.ON && !trustServerCert) ||
((authType != SqlAuthenticationMethod.NotSpecified || (_connHandler._accessTokenInBytes != null ||
_connHandler._accessTokenCallback != null))
&& !trustServerCert);
bool shouldValidateServerCert = (_encryptionOption == EncryptionOptions.ON && !trustServerCert) ||
((authType != SqlAuthenticationMethod.NotSpecified || (_connHandler._accessTokenInBytes != null ||
_connHandler._accessTokenCallback != null))
&& !trustServerCert);

UInt32 info = (shouldValidateServerCert ? TdsEnums.SNI_SSL_VALIDATE_CERTIFICATE : 0)
| (is2005OrLater && (_encryptionOption & EncryptionOptions.CLIENT_CERT) == 0 ? TdsEnums.SNI_SSL_USE_SCHANNEL_CACHE : 0);
Expand Down Expand Up @@ -2528,7 +2528,7 @@ internal bool TryRun(RunBehavior runBehavior, SqlCommand cmdHandler, SqlDataRead
(error.Class <= TdsEnums.MAX_USER_CORRECTABLE_ERROR_CLASS))
{
// Fire SqlInfoMessage here
FireInfoMessageEvent(connection,cmdHandler, stateObj, error);
FireInfoMessageEvent(connection, cmdHandler, stateObj, error);
}
else
{
Expand Down Expand Up @@ -8886,7 +8886,7 @@ internal void TdsLogin(SqlLogin rec,
TdsEnums.FeatureExtension requestedFeatures,
SessionData recoverySessionData,
FederatedAuthenticationFeatureExtensionData fedAuthFeatureExtensionData,
SqlClientOriginalNetworkAddressInfo originalNetworkAddressInfo,
SqlClientOriginalNetworkAddressInfo originalNetworkAddressInfo,
SqlConnectionEncryptOption encrypt)
{
_physicalStateObj.SetTimeoutSeconds(rec.timeout);
Expand Down Expand Up @@ -9462,10 +9462,25 @@ private void SNISSPIData(byte[] receivedBuff, UInt32 receivedLength, byte[] send
// we do not have SSPI data coming from server, so send over 0's for pointer and length
receivedLength = 0;
}
// we need to respond to the server's message with SSPI data
if (0 != SNINativeMethodWrapper.SNISecGenClientContext(_physicalStateObj.Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer))

if (Connection.Connection.NegotiateCallback is { })
{
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
try
{
SSPIContextManager.Invoke(Connection, sendBuff, ref sendLength);
}
catch (Exception e)
{
SSPIError(e.Message + Environment.NewLine + e.StackTrace, TdsEnums.GEN_CLIENT_CONTEXT);
}
}
else
{
// we need to respond to the server's message with SSPI data
if (0 != SNINativeMethodWrapper.SNISecGenClientContext(_physicalStateObj.Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer))
{
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
}
}
}

Expand Down Expand Up @@ -10101,7 +10116,7 @@ internal Task TdsExecuteRPC(SqlCommand cmd, IList<_SqlRPC> rpcArray, int timeout
{
// Throw an exception if ForceColumnEncryption is set on a parameter and the ColumnEncryption is not enabled on SqlConnection or SqlCommand
if (
!(cmd.ColumnEncryptionSetting == SqlCommandColumnEncryptionSetting.Enabled
!(cmd.ColumnEncryptionSetting == SqlCommandColumnEncryptionSetting.Enabled
||
(cmd.ColumnEncryptionSetting == SqlCommandColumnEncryptionSetting.UseConnectionSetting && cmd.Connection.IsColumnEncryptionSettingEnabled))
)
Expand Down