Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable callback to negotiate NTLM or Kerberos #2248

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
<NuGetRoot Condition="'$(NuGetRoot)' == ''">$(RepoRoot).nuget\</NuGetRoot>
<NuGetCmd>$(NuGetRoot)nuget.exe</NuGetCmd>
<!-- Respect environment variable for the .NET install directory if set; otherwise, use the current default location -->
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<!-- <TreatWarningsAsErrors Condition="!$(Configuration.Contains('Debug'))">true</TreatWarningsAsErrors>-->
<BuildSimulator Condition="'$(BuildSimulator)' != 'true'">false</BuildSimulator>
</PropertyGroup>
<PropertyGroup Condition="'$(BuildSimulator)' == 'true'">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@
<Compile Include="Microsoft\Data\SqlClient\SqlConnectionFactory.AssemblyLoadContext.cs" />
</ItemGroup>
<ItemGroup>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\TdsParser.SSPI.cs" Link="Microsoft\Data\SqlClient\TdsParser.SSPI.cs" />
<Compile Include="..\..\src\Resources\StringsHelper.cs">
<Link>Resources\StringsHelper.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ private static readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>
/// Global custom provider list can only supplied once per application.
/// </summary>
private static IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> s_globalCustomColumnEncryptionKeyStoreProviders;
private Func<NegotiateCallbackContext, CancellationToken, Task<ReadOnlyMemory<byte>>> _negotiateCallback;

/// <summary>
/// Dictionary object holding trusted key paths for various SQL Servers.
Expand Down Expand Up @@ -721,6 +722,17 @@ public Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticati
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/NegotiateCallback/*' />
public Func<NegotiateCallbackContext, CancellationToken, Task<ReadOnlyMemory<byte>>> NegotiateCallback
{
get => _negotiateCallback;
set
{
_negotiateCallback = value;
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, negotiateCallback: value));
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/Database/*' />
[ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)]
[ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)]
Expand Down Expand Up @@ -1059,7 +1071,7 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(S
throw ADP.InvalidMixedUsageOfCredentialAndAccessToken();
}

if(_accessTokenCallback != null)
if (_accessTokenCallback != null)
{
throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
}
Expand All @@ -1081,7 +1093,7 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCa
throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndAuthentication();
}

if(_accessToken != null)
if (_accessToken != null)
{
throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null);
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
}
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool, key.AccessTokenCallback);
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool, key.AccessTokenCallback, key.NegotiateCallback);
}

protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken _fedAuthToken = null;
internal byte[] _accessTokenInBytes;
internal readonly Func<SqlAuthenticationParameters, CancellationToken,Task<SqlAuthenticationToken>> _accessTokenCallback;
internal readonly Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;

internal readonly Func<NegotiateCallbackContext, CancellationToken, Task<ReadOnlyMemory<byte>>> _negotiateCallback;

private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
private readonly SqlAuthenticationProviderManager _sqlAuthenticationProviderManager;
Expand Down Expand Up @@ -431,6 +433,19 @@ internal SqlConnectionTimeoutErrorInternal TimeoutErrorInternal
private string _routingDestination = null;
private readonly TimeoutTimer _timeout;

public CancellationTokenSource CreateCancellationTokenSource()
{
CancellationTokenSource cts = new();

// Use Connection timeout value to cancel token acquire request after certain period of time.(int)
if (_timeout.MillisecondsRemaining < Int32.MaxValue)
{
cts.CancelAfter((int)_timeout.MillisecondsRemaining);
}

return cts;
}

// although the new password is generally not used it must be passed to the ctor
// the new Login7 packet will always write out the new password (or a length of zero and no bytes if not present)
//
Expand All @@ -447,8 +462,9 @@ internal SqlInternalConnectionTds(
bool applyTransientFaultHandling = false,
string accessToken = null,
DbConnectionPool pool = null,
Func<SqlAuthenticationParameters, CancellationToken,
Task<SqlAuthenticationToken>> accessTokenCallback = null) : base(connectionOptions)
Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> accessTokenCallback = null,
Func<NegotiateCallbackContext, CancellationToken, Task<ReadOnlyMemory<byte>>> negotiateCallback = null)
: base(connectionOptions)

{
#if DEBUG
Expand Down Expand Up @@ -482,6 +498,7 @@ internal SqlInternalConnectionTds(
}

_accessTokenCallback = accessTokenCallback;
_negotiateCallback = negotiateCallback;

_activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper();
_sqlAuthenticationProviderManager = SqlAuthenticationProviderManager.Instance;
Expand Down Expand Up @@ -2463,12 +2480,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
authParamsBuilder.WithPassword(ConnectionOptions.Password);
}
SqlAuthenticationParameters parameters = authParamsBuilder;
CancellationTokenSource cts = new();
// Use Connection timeout value to cancel token acquire request after certain period of time.(int)
if (_timeout.MillisecondsRemaining < Int32.MaxValue)
{
cts.CancelAfter((int)_timeout.MillisecondsRemaining);
}
using CancellationTokenSource cts = CreateCancellationTokenSource();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use curly braces for clarity rather than implicily scoped using.

_fedAuthToken = Task.Run(async () => await _accessTokenCallback(parameters, cts.Token)).GetAwaiter().GetResult().ToSqlFedAuthToken();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ private void FireInfoMessageEvent(SqlConnection connection, SqlCommand command,

sqlErs.Add(error);

SqlException exc = SqlException.CreateException(sqlErs, serverVersion, _connHandler, innerException:null, batchCommand: command?.GetCurrentBatchCommand());
SqlException exc = SqlException.CreateException(sqlErs, serverVersion, _connHandler, innerException: null, batchCommand: command?.GetCurrentBatchCommand());

bool notified;
connection.OnInfoMessage(new SqlInfoMessageEventArgs(exc), out notified);
Expand Down Expand Up @@ -3980,7 +3980,7 @@ internal bool TryProcessError(byte token, TdsParserStateObject stateObj, SqlComm
{
batchIndex = command.GetCurrentBatchIndex();
}
error = new SqlError(number, state, errorClass, _server, message, procedure, line,exception: null, batchIndex: batchIndex);
error = new SqlError(number, state, errorClass, _server, message, procedure, line, exception: null, batchIndex: batchIndex);
return true;
}

Expand Down Expand Up @@ -5723,7 +5723,7 @@ private bool TryReadSqlStringValue(SqlBuffer value, byte type, int length, Encod
s = "";
}
}

if (buffIsRented)
{
// do not use clearArray:true on the rented array because it can be massively larger
Expand Down Expand Up @@ -8581,7 +8581,18 @@ internal void SendFedAuthToken(SqlFedAuthToken fedAuthToken)

private void SSPIData(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength)
{
if (TdsParserStateObjectFactory.UseManagedSNI)
if (Connection._negotiateCallback is { })
{
try
{
HandleNegotiateCallback(Connection, receivedBuff.AsMemory(0, (int)receivedLength), ref sendBuff, ref sendLength);
}
catch (Exception e)
{
SSPIError(e.Message + Environment.NewLine + e.StackTrace, TdsEnums.GEN_CLIENT_CONTEXT);
}
}
else if (TdsParserStateObjectFactory.UseManagedSNI)
{
try
{
Expand Down Expand Up @@ -9202,7 +9213,7 @@ internal Task TdsExecuteRPC(SqlCommand cmd, IList<_SqlRPC> rpcArray, int timeout
{
// Throw an exception if ForceColumnEncryption is set on a parameter and the ColumnEncryption is not enabled on SqlConnection or SqlCommand
if (
!(cmd.ColumnEncryptionSetting == SqlCommandColumnEncryptionSetting.Enabled
!(cmd.ColumnEncryptionSetting == SqlCommandColumnEncryptionSetting.Enabled
||
(cmd.ColumnEncryptionSetting == SqlCommandColumnEncryptionSetting.UseConnectionSetting && cmd.Connection.IsColumnEncryptionSettingEnabled)))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@
<Compile Include="..\..\src\Microsoft\Data\ProviderBase\TimeoutTimer.cs">
<Link>Microsoft\Data\ProviderBase\TimeoutTimer.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\TdsParser.SSPI.cs">
<Link>Microsoft\Data\SqlClient\TdsParser.SSPI.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\Sql\SqlDataSourceEnumerator.cs">
<Link>Microsoft\Data\Sql\SqlDataSourceEnumerator.cs</Link>
</Compile>
Expand Down Expand Up @@ -741,4 +744,4 @@
<Import Project="$(NetFxSource)tools\targets\GenerateThisAssemblyCs.targets" />
<Import Project="$(NetFxSource)tools\targets\GenerateAssemblyRef.targets" />
<Import Project="$(NetFxSource)tools\targets\GenerateAssemblyInfo.targets" />
</Project>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ public SqlRetryLogicBaseProvider RetryLogicProvider
// The downstream handling of Connection open is the same for idle connection resiliency. Currently we want to apply transient fault handling only to the connections opened
// using SqlConnection.Open() method.
internal bool _applyTransientFaultHandling = false;
private Func<NegotiateCallbackContext, CancellationToken, Task<ReadOnlyMemory<byte>>> _negotiateCallback;

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/ctorConnectionString/*' />
public SqlConnection(string connectionString) : this(connectionString, null)
Expand Down Expand Up @@ -768,6 +769,17 @@ public Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticati
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/NegotiateCallback/*' />
public Func<NegotiateCallbackContext, CancellationToken, Task<ReadOnlyMemory<byte>>> NegotiateCallback
{
get => _negotiateCallback;
set
{
_negotiateCallback = value;
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, negotiateCallback: value));
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/CommandTimeout/*' />
[
DesignerSerializationVisibility(DesignerSerializationVisibility.Hidden),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
opt = new SqlConnectionString(opt, instanceName, false /* user instance=false */, null /* do not modify the Enlist value */);
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
}
result = new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, key.ServerCertificateValidationCallback, key.ClientCertificateRetrievalCallback, pool, key.AccessToken, key.OriginalNetworkAddressInfo, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessTokenCallback);
result = new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, key.ServerCertificateValidationCallback, key.ClientCertificateRetrievalCallback, pool, key.AccessToken, key.OriginalNetworkAddressInfo, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessTokenCallback, key.NegotiateCallback);
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,11 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa

// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken _fedAuthToken = null;

internal byte[] _accessTokenInBytes;
internal readonly Func<SqlAuthenticationParameters, CancellationToken,Task<SqlAuthenticationToken>> _accessTokenCallback;
internal readonly Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;

internal readonly Func<NegotiateCallbackContext, CancellationToken, Task<ReadOnlyMemory<byte>>> _negotiateCallback;

private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
private readonly SqlAuthenticationProviderManager _sqlAuthenticationProviderManager;
Expand Down Expand Up @@ -259,6 +262,19 @@ internal bool IsDNSCachingBeforeRedirectSupported

private static HashSet<int> transientErrors = new HashSet<int>();

public CancellationTokenSource CreateCancellationTokenSource()
{
CancellationTokenSource cts = new();

// Use Connection timeout value to cancel token acquire request after certain period of time.(int)
if (_timeout.MillisecondsRemaining < Int32.MaxValue)
{
cts.CancelAfter((int)_timeout.MillisecondsRemaining);
}

return cts;
}

internal SessionData CurrentSessionData
{
get
Expand Down Expand Up @@ -434,8 +450,9 @@ internal SqlInternalConnectionTds(
string accessToken = null,
SqlClientOriginalNetworkAddressInfo originalNetworkAddressInfo = null,
bool applyTransientFaultHandling = false,
Func<SqlAuthenticationParameters, CancellationToken,
Task<SqlAuthenticationToken>> accessTokenCallback = null) : base(connectionOptions)
Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> accessTokenCallback = null,
Func<NegotiateCallbackContext, CancellationToken, Task<ReadOnlyMemory<byte>>> negotiateCallback = null)
: base(connectionOptions)
{

#if DEBUG
Expand Down Expand Up @@ -491,6 +508,7 @@ internal SqlInternalConnectionTds(
}

_accessTokenCallback = accessTokenCallback;
_negotiateCallback = negotiateCallback;

_activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper();
_sqlAuthenticationProviderManager = SqlAuthenticationProviderManager.Instance;
Expand Down Expand Up @@ -2877,12 +2895,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
authParamsBuilder.WithPassword(ConnectionOptions.Password);
}
SqlAuthenticationParameters parameters = authParamsBuilder;
CancellationTokenSource cts = new();
// Use Connection timeout value to cancel token acquire request after certain period of time.(int)
if (_timeout.MillisecondsRemaining < Int32.MaxValue)
{
cts.CancelAfter((int)_timeout.MillisecondsRemaining);
}
using CancellationTokenSource cts = CreateCancellationTokenSource();
_fedAuthToken = Task.Run(async () => await _accessTokenCallback(parameters, cts.Token)).GetAwaiter().GetResult().ToSqlFedAuthToken();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
Expand Down
Loading